mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 14:58:12 +00:00
Compare commits
715 Commits
v1.1.1
...
cache_lear
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a87910bc65 | ||
|
|
f48662e863 | ||
|
|
8d8bfc7f54 | ||
|
|
8e15dcd289 | ||
|
|
586ac9d8a6 | ||
|
|
288bbc7128 | ||
|
|
5002ac74dc | ||
|
|
863a6ba597 | ||
|
|
b08bc1470d | ||
|
|
94b57e9677 | ||
|
|
3fb037d33a | ||
|
|
b3b63fef3e | ||
|
|
f6d85f3c2e | ||
|
|
2f22e598b7 | ||
|
|
888caf8b88 | ||
|
|
b6e39c97af | ||
|
|
02124c4034 | ||
|
|
fddc98ff16 | ||
|
|
0dfcd25cf3 | ||
|
|
ff10fde47f | ||
|
|
dc94614c80 | ||
|
|
e56a4d5730 | ||
|
|
3f8468893a | ||
|
|
1b47e1dc22 | ||
|
|
b0bf78e915 | ||
|
|
abdf66d09e | ||
|
|
27b1fe240b | ||
|
|
1635897516 | ||
|
|
8d172127cd | ||
|
|
fccb1ecdd7 | ||
|
|
c0f7e1db7c | ||
|
|
53890bafa4 | ||
|
|
6886f7ba35 | ||
|
|
afd48cd706 | ||
|
|
24b68c2392 | ||
|
|
280ff7cca6 | ||
|
|
b4b62e2f7c | ||
|
|
051b957adb | ||
|
|
ca9b5e64ea | ||
|
|
6d1be405b9 | ||
|
|
25c3a3d3e2 | ||
|
|
49bc84f78e | ||
|
|
25a9e75030 | ||
|
|
2a7ac73eb5 | ||
|
|
f4f991d409 | ||
|
|
a781138413 | ||
|
|
91a5623976 | ||
|
|
28cd355aba | ||
|
|
005389fca7 | ||
|
|
a6282056eb | ||
|
|
21a6eb8e2f | ||
|
|
98ab238340 | ||
|
|
2070bbd925 | ||
|
|
1c8a0f8317 | ||
|
|
9f07d65ebb | ||
|
|
5f1d5adfce | ||
|
|
4f23caa55f | ||
|
|
b4f6a4de6c | ||
|
|
53fe42af1b | ||
|
|
ee9a3b4405 | ||
|
|
b1a2782ad7 | ||
|
|
8d303b47e9 | ||
|
|
00da4b6c4f | ||
|
|
22695e9be0 | ||
|
|
3140199c96 | ||
|
|
98290190ec | ||
|
|
3f4de2cc7f | ||
|
|
8d0df403ca | ||
|
|
4e9db263b0 | ||
|
|
d12bf71bcc | ||
|
|
35e0776022 | ||
|
|
ffb7a138f7 | ||
|
|
548304667f | ||
|
|
273143136c | ||
|
|
030ebe649a | ||
|
|
90921d2293 | ||
|
|
b61131c693 | ||
|
|
37fbb3248a | ||
|
|
d13f533f42 | ||
|
|
b3cc652dea | ||
|
|
d879d66c62 | ||
|
|
848bfd6993 | ||
|
|
269da09f6e | ||
|
|
e30514a00c | ||
|
|
3743b1307c | ||
|
|
a835df984c | ||
|
|
3e4b47e424 | ||
|
|
dd8d902624 | ||
|
|
a8b340c098 | ||
|
|
88497b5c13 | ||
|
|
1e90c72d94 | ||
|
|
3dd82a738e | ||
|
|
8ad2d9884b | ||
|
|
70f531b724 | ||
|
|
37c2868b61 | ||
|
|
a18e6233b5 | ||
|
|
2336d5f6b3 | ||
|
|
b6ccb362b9 | ||
|
|
ae52d93694 | ||
|
|
ad91d41601 | ||
|
|
dce77ec4d1 | ||
|
|
5c0b07d939 | ||
|
|
19e429d889 | ||
|
|
209a350c0f | ||
|
|
a3c2744a43 | ||
|
|
55e8346da3 | ||
|
|
b7979b2633 | ||
|
|
c90aaa2798 | ||
|
|
0c617d5d9e | ||
|
|
fd87b72754 | ||
|
|
db75508ba0 | ||
|
|
acba342a63 | ||
|
|
d16877e695 | ||
|
|
e99cdcf3b8 | ||
|
|
a236a17f17 | ||
|
|
03e530dc39 | ||
|
|
6be244233a | ||
|
|
544c391936 | ||
|
|
f4d06ce3fc | ||
|
|
ffedb9eb52 | ||
|
|
381067515c | ||
|
|
00f2d1aa5d | ||
|
|
8cc3bece6d | ||
|
|
f4bf592064 | ||
|
|
3235393fb5 | ||
|
|
3b662da31e | ||
|
|
19ce3048c1 | ||
|
|
de0aa946f7 | ||
|
|
f376202a49 | ||
|
|
a13ecfc46b | ||
|
|
10a1853eda | ||
|
|
0efab85674 | ||
|
|
f45a0ffd02 | ||
|
|
8ba528a8f6 | ||
|
|
dd479e5bff | ||
|
|
bac39b1cd2 | ||
|
|
c1c9a4853b | ||
|
|
3ee5f53a36 | ||
|
|
32449a6aa0 | ||
|
|
a6884f6b3a | ||
|
|
b078666640 | ||
|
|
7604ca1e52 | ||
|
|
62c3d406d9 | ||
|
|
5745c9f200 | ||
|
|
86829120c2 | ||
|
|
60ac96525b | ||
|
|
07b1f5702f | ||
|
|
507e7e5d36 | ||
|
|
ab8580f77e | ||
|
|
6454259853 | ||
|
|
9cc1697d4d | ||
|
|
c758769a02 | ||
|
|
a5935e973a | ||
|
|
9834d72e4d | ||
|
|
01234e59c0 | ||
|
|
8f1d10fb43 | ||
|
|
20e1aaf908 | ||
|
|
c6722b3f56 | ||
|
|
11315d7a40 | ||
|
|
68d97a9844 | ||
|
|
4629d4cf9e | ||
|
|
3cb5cec906 | ||
|
|
b7e16b9034 | ||
|
|
83d1e7361f | ||
|
|
1547c3f786 | ||
|
|
bfaaf12bf4 | ||
|
|
47545e1aab | ||
|
|
7c6905a432 | ||
|
|
2883bc1b76 | ||
|
|
78d8842ddf | ||
|
|
5821a664a0 | ||
|
|
ab9aa1a087 | ||
|
|
a4d34d9f3d | ||
|
|
127cc9007a | ||
|
|
e1f5db5f5c | ||
|
|
e316fb717f | ||
|
|
64c5139502 | ||
|
|
5da9611a74 | ||
|
|
733750d01b | ||
|
|
edc95359d0 | ||
|
|
f2d0241e26 | ||
|
|
7b5d7f4af5 | ||
|
|
1fa9a6c60c | ||
|
|
51efa128d3 | ||
|
|
421c6a5fce | ||
|
|
864080d8f2 | ||
|
|
ba372dd295 | ||
|
|
1ceb02f673 | ||
|
|
30f93161fb | ||
|
|
3ee3cc3104 | ||
|
|
c2218f5c73 | ||
|
|
72af7122b3 | ||
|
|
afd101f345 | ||
|
|
1313f4dd63 | ||
|
|
8332ecebb7 | ||
|
|
401d7d74a5 | ||
|
|
b8d7d55568 | ||
|
|
a30ed9093f | ||
|
|
b73e713028 | ||
|
|
e0eabaa426 | ||
|
|
538017177a | ||
|
|
30292d9411 | ||
|
|
b168d7aa8b | ||
|
|
8ea45b0daa | ||
|
|
0a1c172a00 | ||
|
|
77fac2a03f | ||
|
|
084bc2fc78 | ||
|
|
c63d474b60 | ||
|
|
7540568156 | ||
|
|
c5d426c254 | ||
|
|
a36f2f6032 | ||
|
|
ed256ef8be | ||
|
|
15079a6cb8 | ||
|
|
c084d6377b | ||
|
|
e9bc42f233 | ||
|
|
0d6de58af9 | ||
|
|
acbf932974 | ||
|
|
9d64ed7042 | ||
|
|
0b4b337e9a | ||
|
|
99908d9a1c | ||
|
|
73ced7a46d | ||
|
|
32b8b9b51e | ||
|
|
f6534a5b63 | ||
|
|
034c9b6c60 | ||
|
|
76335e0fe5 | ||
|
|
c0b589d934 | ||
|
|
833ba1e1fa | ||
|
|
7a5974d964 | ||
|
|
b0abdaffb4 | ||
|
|
e9f29bc402 | ||
|
|
1a7f482fbd | ||
|
|
3a0d51d100 | ||
|
|
bffdb901ed | ||
|
|
d93e8738cd | ||
|
|
7e5ce5d5c9 | ||
|
|
7aef554d83 | ||
|
|
090074e395 | ||
|
|
2dcdeefca8 | ||
|
|
452a6ca5cf | ||
|
|
d6cf20ef33 | ||
|
|
efdd6a59b6 | ||
|
|
42ec7b08eb | ||
|
|
d049fb6d1d | ||
|
|
144365b07d | ||
|
|
cb8de6be1b | ||
|
|
8c13362dcf | ||
|
|
c13fd7e0ee | ||
|
|
958ebf1352 | ||
|
|
b6da77e468 | ||
|
|
260e32217f | ||
|
|
5cee326f92 | ||
|
|
1d240994e7 | ||
|
|
a0bae07825 | ||
|
|
ff71720297 | ||
|
|
dea85643e6 | ||
|
|
6a46f32afe | ||
|
|
4641d0f360 | ||
|
|
826bab5962 | ||
|
|
5b6d112c15 | ||
|
|
febdaf6067 | ||
|
|
0a78bb9d38 | ||
|
|
9cea10cc69 | ||
|
|
caa17da5b9 | ||
|
|
fdeb363fa2 | ||
|
|
4147473c81 | ||
|
|
8a0bd7c377 | ||
|
|
b541b9bed2 | ||
|
|
419d47c195 | ||
|
|
ac2e859960 | ||
|
|
6663dca015 | ||
|
|
86e509ad31 | ||
|
|
8fcfa1dd2d | ||
|
|
2b7a2548b4 | ||
|
|
f0916e6bae | ||
|
|
822e80ec2f | ||
|
|
04e39f7de5 | ||
|
|
ce0b948655 | ||
|
|
c795e35142 | ||
|
|
f7c01f1367 | ||
|
|
cb49f0283f | ||
|
|
6a45815b23 | ||
|
|
8dae8d7bc8 | ||
|
|
f6418004bb | ||
|
|
c4b97cd591 | ||
|
|
b6d1ff01e0 | ||
|
|
0d81626fe7 | ||
|
|
e3f47a799b | ||
|
|
e014cad820 | ||
|
|
89bf3ce5cf | ||
|
|
3ebe118f23 | ||
|
|
7f719cefe6 | ||
|
|
46bd05b54d | ||
|
|
613dafbd09 | ||
|
|
952933eeb1 | ||
|
|
c0172e70b1 | ||
|
|
6ab426e641 | ||
|
|
d0467a7e8d | ||
|
|
36838a05ee | ||
|
|
5e6f9f89f1 | ||
|
|
2dad9a319c | ||
|
|
9ec0652339 | ||
|
|
7e348083ae | ||
|
|
29b12b2f4e | ||
|
|
b3f57ed920 | ||
|
|
c9fea729d8 | ||
|
|
9d0683df25 | ||
|
|
838b8109b1 | ||
|
|
3a9621f6da | ||
|
|
fff2c89360 | ||
|
|
ce61bef2b0 | ||
|
|
123f6dbadb | ||
|
|
f9ce261a0e | ||
|
|
d93de98a21 | ||
|
|
ad1da43476 | ||
|
|
398b1dbd7a | ||
|
|
9f6922bba9 | ||
|
|
f11a91e610 | ||
|
|
7ed09bb78d | ||
|
|
ac931856d5 | ||
|
|
2d09318236 | ||
|
|
7dc49bd036 | ||
|
|
4d16bdf853 | ||
|
|
01a1f48f70 | ||
|
|
6a9d875d65 | ||
|
|
f1c96d31b4 | ||
|
|
aafcca8d77 | ||
|
|
bf369cad4d | ||
|
|
024fdad76d | ||
|
|
e1c2eda5f5 | ||
|
|
0b574cc0c2 | ||
|
|
3212c83398 | ||
|
|
49f9a11eb3 | ||
|
|
fa36739f01 | ||
|
|
42e9764b60 | ||
|
|
f7f5c07570 | ||
|
|
ec1a936624 | ||
|
|
6e6136586c | ||
|
|
34766863f8 | ||
|
|
1d76d5e828 | ||
|
|
250540a398 | ||
|
|
46f3c38c37 | ||
|
|
9a8982efb1 | ||
|
|
3c815cce4b | ||
|
|
39d199c8bb | ||
|
|
f5506d1e13 | ||
|
|
166a8734fe | ||
|
|
b2273ec568 | ||
|
|
89c4e3bdb6 | ||
|
|
051ebf3439 | ||
|
|
7cfadc2ca8 | ||
|
|
32cf5d32ce | ||
|
|
4f7c3b6a1e | ||
|
|
57128dc89f | ||
|
|
d20680baae | ||
|
|
970403f78e | ||
|
|
bee2a969e5 | ||
|
|
2803ffcb38 | ||
|
|
d3224e1fdc | ||
|
|
3c2f85606f | ||
|
|
1f25ad416b | ||
|
|
d0b9b25db7 | ||
|
|
ef09db69cd | ||
|
|
84ede171fd | ||
|
|
6f4e38276e | ||
|
|
a3b67436a6 | ||
|
|
829ca3414b | ||
|
|
3915bc3ee6 | ||
|
|
4299c999b5 | ||
|
|
6bae70eee0 | ||
|
|
6452edb738 | ||
|
|
bc739c78cd | ||
|
|
2feaeb1a64 | ||
|
|
09360cf4f5 | ||
|
|
26461c1963 | ||
|
|
0412fc7232 | ||
|
|
8d2f6ad32e | ||
|
|
1625894694 | ||
|
|
c35f2d8bda | ||
|
|
a8ee7ec9ef | ||
|
|
46d390cf8a | ||
|
|
6b8e3880ff | ||
|
|
c1c3be2420 | ||
|
|
b2554db100 | ||
|
|
b63f81c6e3 | ||
|
|
cb2caa3a36 | ||
|
|
f0ea049faa | ||
|
|
0954e8a017 | ||
|
|
e4178e2501 | ||
|
|
0b860abf1b | ||
|
|
8c558b3526 | ||
|
|
aef982a53c | ||
|
|
db124fa6bc | ||
|
|
2ed3860085 | ||
|
|
87ab7d020b | ||
|
|
03c8fd5e61 | ||
|
|
9c51623fc2 | ||
|
|
8ec545d70c | ||
|
|
79fa8607dc | ||
|
|
7df48fc2b5 | ||
|
|
8ef91b3672 | ||
|
|
2860470b4e | ||
|
|
c125728ce0 | ||
|
|
63eaa9e7ea | ||
|
|
158567ca20 | ||
|
|
de4e2703ca | ||
|
|
9e683bfe25 | ||
|
|
0befa05014 | ||
|
|
283f35447a | ||
|
|
c35414a652 | ||
|
|
68aafab09e | ||
|
|
29663b25a6 | ||
|
|
2861ec4d9f | ||
|
|
729c512c66 | ||
|
|
2af3a6f6a2 | ||
|
|
05dba91f79 | ||
|
|
b8f05bb342 | ||
|
|
5f68727ad3 | ||
|
|
bba44173d2 | ||
|
|
9015d08927 | ||
|
|
1dfa32f0ae | ||
|
|
c98e31fee3 | ||
|
|
f3d2470e84 | ||
|
|
4ad6bd4e23 | ||
|
|
3aed244c6f | ||
|
|
783c435d88 | ||
|
|
cd1ba7281b | ||
|
|
970ff12ff5 | ||
|
|
2827b60330 | ||
|
|
b3df7e5e21 | ||
|
|
c18b5a0c71 | ||
|
|
b9f7d08219 | ||
|
|
11ea986e67 | ||
|
|
b06066f25b | ||
|
|
0b3400bca3 | ||
|
|
0d509241c0 | ||
|
|
ebeda32215 | ||
|
|
ff95c56884 | ||
|
|
2871535f3b | ||
|
|
e3c5d2540b | ||
|
|
22705a44b4 | ||
|
|
43a8d9768c | ||
|
|
dbee3a1ae0 | ||
|
|
f1f00c4255 | ||
|
|
c05b1a2fd0 | ||
|
|
55951590f5 | ||
|
|
1384de0353 | ||
|
|
05c6b49b90 | ||
|
|
d19fcc8c04 | ||
|
|
af6b1d4246 | ||
|
|
cbd10fb27d | ||
|
|
836fa5c957 | ||
|
|
dc066aca2d | ||
|
|
44f6ffbf56 | ||
|
|
0a24d0819f | ||
|
|
f0106cd48c | ||
|
|
dee4075380 | ||
|
|
a692389df0 | ||
|
|
629e9be4ce | ||
|
|
3a3d9010b8 | ||
|
|
a25334b352 | ||
|
|
00279a8375 | ||
|
|
89397c755a | ||
|
|
77676b5cea | ||
|
|
0f4b08daa3 | ||
|
|
63b2c51e11 | ||
|
|
8a9dbbd3ba | ||
|
|
22d28665fe | ||
|
|
1363a0559f | ||
|
|
9cb887015b | ||
|
|
789dade026 | ||
|
|
9bb51fe879 | ||
|
|
d9c812818d | ||
|
|
c8e9a96196 | ||
|
|
6143af4654 | ||
|
|
9458e382b0 | ||
|
|
4f2d9226cf | ||
|
|
f688a469b1 | ||
|
|
c8ea3b3356 | ||
|
|
6e9472b470 | ||
|
|
a5c03c5272 | ||
|
|
8068ac2592 | ||
|
|
5f80e7ac5e | ||
|
|
157e0be49d | ||
|
|
3dbe271aab | ||
|
|
44e2eecdf1 | ||
|
|
8c226e83a6 | ||
|
|
009f26bb40 | ||
|
|
fcf2fbc07f | ||
|
|
b603acd36a | ||
|
|
6c8bb6438b | ||
|
|
8072d3839d | ||
|
|
c8ad643374 | ||
|
|
31f9df5e62 | ||
|
|
e2f415524a | ||
|
|
3eb7e7530e | ||
|
|
916aa54595 | ||
|
|
6ddbd43f7b | ||
|
|
a37a83ecc3 | ||
|
|
f2a0d0c85f | ||
|
|
93194f44e8 | ||
|
|
c4e5033532 | ||
|
|
cc6cd26733 | ||
|
|
1113d305d1 | ||
|
|
6d5f8b7423 | ||
|
|
1b3c204d20 | ||
|
|
1788d50f0a | ||
|
|
e7a21dbf0b | ||
|
|
3b3e1e4d44 | ||
|
|
24426e3a32 | ||
|
|
31369bab15 | ||
|
|
551721658b | ||
|
|
46f052375f | ||
|
|
c2d35a2157 | ||
|
|
4c052e42bc | ||
|
|
a88613555d | ||
|
|
c164519ef1 | ||
|
|
afff5ffb21 | ||
|
|
a8481fd5e1 | ||
|
|
8584e50309 | ||
|
|
9f3e02f167 | ||
|
|
7ad9b9aecc | ||
|
|
b6a111d3a2 | ||
|
|
bd6f2695a9 | ||
|
|
6eecc9d442 | ||
|
|
35269783d7 | ||
|
|
9534a78167 | ||
|
|
830b1b7202 | ||
|
|
436a91e0c9 | ||
|
|
40760ab88b | ||
|
|
8badd63a2d | ||
|
|
b1afff1728 | ||
|
|
6e977e1181 | ||
|
|
62f6ca2b8a | ||
|
|
4e00c109e3 | ||
|
|
8f10a9c353 | ||
|
|
a3a35acc7e | ||
|
|
675eefa07e | ||
|
|
dbef6122e9 | ||
|
|
d150bcf622 | ||
|
|
451aab0116 | ||
|
|
3edf3583b1 | ||
|
|
ef2a7abad4 | ||
|
|
32f630ff5f | ||
|
|
109a0a0d49 | ||
|
|
4f01b37a2a | ||
|
|
cc6306136c | ||
|
|
419ace37f3 | ||
|
|
ccf24c363f | ||
|
|
b7a1ac6671 | ||
|
|
e54c0a8468 | ||
|
|
5f4cb32255 | ||
|
|
7b6cf39618 | ||
|
|
bf81de0c88 | ||
|
|
b36cad6929 | ||
|
|
b161bd6dfd | ||
|
|
538cfcbb77 | ||
|
|
a4105d2c0e | ||
|
|
553b341f5f | ||
|
|
e9e24b8cf1 | ||
|
|
1b693d0028 | ||
|
|
a4c3c07229 | ||
|
|
6b24748c80 | ||
|
|
8f2f8646eb | ||
|
|
e3ac438f5a | ||
|
|
b731628112 | ||
|
|
0dc56d9dcc | ||
|
|
b925b402e2 | ||
|
|
61d9653536 | ||
|
|
53f01e72e6 | ||
|
|
55e5e373dd | ||
|
|
4a0921ada1 | ||
|
|
5129d3dc52 | ||
|
|
ee9bab80f2 | ||
|
|
cd8884c9ef | ||
|
|
46744362de | ||
|
|
0f0cdc3afc | ||
|
|
a33c63af87 | ||
|
|
3cc9764bc9 | ||
|
|
f6c6e3c640 | ||
|
|
60a9db706e | ||
|
|
a98700feb2 | ||
|
|
5418ca781e | ||
|
|
71eee780fb | ||
|
|
4864453e0a | ||
|
|
c5a32f76c2 | ||
|
|
c4ed3d3e4b | ||
|
|
803ddcccc7 | ||
|
|
4cd51fecf2 | ||
|
|
3b0211a547 | ||
|
|
e88328d152 | ||
|
|
52896fa8dd | ||
|
|
c7035ad911 | ||
|
|
070811e517 | ||
|
|
7e010d88a5 | ||
|
|
4e43d4d461 | ||
|
|
d7efe7e539 | ||
|
|
633f789c47 | ||
|
|
88607f404e | ||
|
|
6d405b669c | ||
|
|
d0fed6ba72 | ||
|
|
64eaa0d76a | ||
|
|
3dc28f428f | ||
|
|
3c8a3fe2e1 | ||
|
|
e28c246bcc | ||
|
|
04d03500ff | ||
|
|
54081bdcbb | ||
|
|
d8b250607a | ||
|
|
1e58e6ef82 | ||
|
|
42cb7d96bb | ||
|
|
39890f023f | ||
|
|
e425753f79 | ||
|
|
ca40074d72 | ||
|
|
1fd3d67379 | ||
|
|
3acd9c73be | ||
|
|
32422b49ee | ||
|
|
5c4d3185fb | ||
|
|
762bcbee58 | ||
|
|
6b411ada16 | ||
|
|
a25bd74d8b | ||
|
|
fb5fc09bad | ||
|
|
3fdba19e02 | ||
|
|
4bec2983a9 | ||
|
|
03ea27893f | ||
|
|
718b45f2af | ||
|
|
63a79eeb2a | ||
|
|
e757013a14 | ||
|
|
a05f647633 | ||
|
|
7604be0301 | ||
|
|
945b43492e | ||
|
|
b548d7caf2 | ||
|
|
6e316fd825 | ||
|
|
84fb61aaaf | ||
|
|
50a9946b57 | ||
|
|
384d1a8198 | ||
|
|
a58c193d0c | ||
|
|
34a5ef8c15 | ||
|
|
41e3e4e157 | ||
|
|
e576d71908 | ||
|
|
906aadbf1b | ||
|
|
bf0bf2d5ba | ||
|
|
fe0fff1399 | ||
|
|
50fceb84d2 | ||
|
|
100da41034 | ||
|
|
c382237833 | ||
|
|
98ac191750 | ||
|
|
2f73dbe7a3 | ||
|
|
a66203a391 | ||
|
|
fab61f614b | ||
|
|
6b67a11ad6 | ||
|
|
91f77d268c | ||
|
|
eb4d5187d8 | ||
|
|
ee4b02247c | ||
|
|
da8e1fe7e4 | ||
|
|
3db824c281 | ||
|
|
df2ecafd3f | ||
|
|
217652d28e | ||
|
|
f64c766dcd | ||
|
|
076fd85556 | ||
|
|
c7912ed827 | ||
|
|
e63f9d6993 | ||
|
|
d80ef3a677 | ||
|
|
852c3d831f | ||
|
|
ceb92ee7aa | ||
|
|
3a75026176 | ||
|
|
6a92b08244 | ||
|
|
38bc785ea9 | ||
|
|
a466fdca8f | ||
|
|
f9f49e3c78 | ||
|
|
61a30673c2 | ||
|
|
a48822ec00 | ||
|
|
b6c3d2b74a | ||
|
|
5006c2176c | ||
|
|
d3d3556ff6 | ||
|
|
6fa8dbe077 | ||
|
|
a57749ef60 | ||
|
|
b5c1d33e58 | ||
|
|
34a9f82865 | ||
|
|
18dc6cb962 | ||
|
|
490d420d82 | ||
|
|
0aca943a39 | ||
|
|
c760208614 | ||
|
|
fad7aea58a | ||
|
|
b42eb1444c | ||
|
|
25a247dd3f | ||
|
|
7792017a02 | ||
|
|
0219e8d2f3 | ||
|
|
1d309a14a3 | ||
|
|
7df73ceaaf | ||
|
|
0dbb3d333f | ||
|
|
1419bec53d | ||
|
|
cf12723c89 | ||
|
|
4268f5466b | ||
|
|
b9f5a00d98 | ||
|
|
7d44dc99fb | ||
|
|
b20de1b44d | ||
|
|
366ee0f542 | ||
|
|
bed770248b | ||
|
|
020560d2b5 | ||
|
|
af7d305f00 | ||
|
|
427232cbc0 | ||
|
|
2899283c01 | ||
|
|
9cff769fbd | ||
|
|
23e33273f1 | ||
|
|
f191353cf4 | ||
|
|
66a094fc84 | ||
|
|
3681adc5ac | ||
|
|
4449faaa01 | ||
|
|
991ba162bd | ||
|
|
77d0f4d297 | ||
|
|
a834371d50 | ||
|
|
acda7d891a | ||
|
|
7434ec8fcd | ||
|
|
0699212665 | ||
|
|
f47de78b59 | ||
|
|
5fdc8039ec |
BIN
.github/workflows/logo.gif
vendored
Normal file
BIN
.github/workflows/logo.gif
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 146 KiB |
4
.github/workflows/publish.yaml
vendored
4
.github/workflows/publish.yaml
vendored
@@ -20,9 +20,9 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
python-version: '3.10'
|
python-version: '3.10'
|
||||||
- name: Install wheel
|
- name: Install wheel
|
||||||
run: pip install wheel && pip install -r requirements.txt
|
run: pip install wheel==0.44.0 && pip install -r requirements.txt
|
||||||
- name: Build DiffSynth
|
- name: Build DiffSynth
|
||||||
run: python setup.py sdist bdist_wheel
|
run: python -m build
|
||||||
- name: Publish package to PyPI
|
- name: Publish package to PyPI
|
||||||
run: |
|
run: |
|
||||||
pip install twine
|
pip install twine
|
||||||
|
|||||||
175
.gitignore
vendored
Normal file
175
.gitignore
vendored
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
/data
|
||||||
|
/models
|
||||||
|
/scripts
|
||||||
|
/diffusers
|
||||||
|
*.pkl
|
||||||
|
*.safetensors
|
||||||
|
*.pth
|
||||||
|
*.ckpt
|
||||||
|
*.pt
|
||||||
|
*.bin
|
||||||
|
*.DS_Store
|
||||||
|
*.msc
|
||||||
|
*.mv
|
||||||
|
log*.txt
|
||||||
|
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# poetry
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||||
|
#poetry.lock
|
||||||
|
|
||||||
|
# pdm
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||||
|
#pdm.lock
|
||||||
|
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||||
|
# in version control.
|
||||||
|
# https://pdm.fming.dev/#use-with-ide
|
||||||
|
.pdm.toml
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||||
|
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
|
#.idea/
|
||||||
953
README_zh.md
Normal file
953
README_zh.md
Normal file
@@ -0,0 +1,953 @@
|
|||||||
|
# DiffSynth-Studio
|
||||||
|
|
||||||
|
<a href="https://github.com/modelscope/DiffSynth-Studio"><img src=".github/workflows/logo.gif" title="Logo" style="max-width:100%;" width="55" /></a> <a href="https://trendshift.io/repositories/10946" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10946" alt="modelscope%2FDiffSynth-Studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a></p>
|
||||||
|
|
||||||
|
[](https://pypi.org/project/DiffSynth/)
|
||||||
|
[](https://github.com/modelscope/DiffSynth-Studio/blob/master/LICENSE)
|
||||||
|
[](https://github.com/modelscope/DiffSynth-Studio/issues)
|
||||||
|
[](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
|
||||||
|
[](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
|
||||||
|
|
||||||
|
[Switch to English](./README.md)
|
||||||
|
|
||||||
|
## 简介
|
||||||
|
|
||||||
|
> DiffSynth-Studio 文档:[中文版](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/)、[English version](https://diffsynth-studio-doc.readthedocs.io/en/latest/)
|
||||||
|
|
||||||
|
欢迎来到 Diffusion 模型的魔法世界!DiffSynth-Studio 是由[魔搭社区](https://www.modelscope.cn/)团队开发和维护的开源 Diffusion 模型引擎。我们期望以框架建设孵化技术创新,凝聚开源社区的力量,探索生成式模型技术的边界!
|
||||||
|
|
||||||
|
DiffSynth 目前包括两个开源项目:
|
||||||
|
* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): 聚焦于激进的技术探索,面向学术界,提供更前沿的模型能力支持。
|
||||||
|
* [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): 聚焦于稳定的模型部署,面向工业界,提供更高的计算性能与更稳定的功能。
|
||||||
|
|
||||||
|
[DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) 与 [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) 是魔搭社区 AIGC 专区的核心引擎,欢迎体验我们精心打造的产品化功能:
|
||||||
|
|
||||||
|
* 魔搭社区 AIGC 专区 (面向中国用户): https://modelscope.cn/aigc/home
|
||||||
|
* ModelScope Civision (for global users): https://modelscope.ai/civision/home
|
||||||
|
|
||||||
|
我们相信,一个完善的开源代码框架能够降低技术探索的门槛,我们基于这个代码库搞出了不少[有意思的技术](#创新成果)。或许你也有许多天马行空的构想,借助 DiffSynth-Studio,你可以快速实现这些想法。为此,我们为开发者准备了详细的文档,我们希望通过这些文档,帮助开发者理解 Diffusion 模型的原理,更期待与你一同拓展技术的边界。
|
||||||
|
|
||||||
|
## 更新历史
|
||||||
|
|
||||||
|
> DiffSynth-Studio 经历了大版本更新,部分旧功能已停止维护,如需使用旧版功能,请切换到大版本更新前的[最后一个历史版本](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3)。
|
||||||
|
|
||||||
|
> 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
|
||||||
|
- **2026年2月26日** 新增对[LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2)音视频生成模型全量微调与LoRA训练支持,详见[文档](docs/zh/Model_Details/LTX-2.md)。
|
||||||
|
|
||||||
|
- **2026年2月10日** 新增对[LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2)音视频生成模型的推理支持,详见[文档](docs/zh/Model_Details/LTX-2.md),后续将推进模型训练的支持。
|
||||||
|
|
||||||
|
- **2026年2月2日** Research Tutorial 的第一篇文档上线,带你从零开始训练一个 0.1B 的小型文生图模型,详见[文档](/docs/zh/Research_Tutorial/train_from_scratch.md)、[模型](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel),我们希望 DiffSynth-Studio 能够成为一个更强大的 Diffusion 模型训练框架。
|
||||||
|
|
||||||
|
- **2026年1月27日** [Z-Image](https://modelscope.cn/models/Tongyi-MAI/Z-Image) 发布,我们的 [Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L) 模型同步发布,在[魔搭创空间](https://modelscope.cn/studios/DiffSynth-Studio/Z-Image-i2L)可直接体验,详见[文档](/docs/zh/Model_Details/Z-Image.md)。
|
||||||
|
|
||||||
|
- **2026年1月19日** 新增对 [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) 和 [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/FLUX2.md)和[示例代码](/examples/flux2/)现已可用。
|
||||||
|
|
||||||
|
- **2026年1月12日** 我们训练并开源了一个文本引导的图层拆分模型([模型链接](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)),这一模型输入一张图与一段文本描述,模型会将图像中与文本描述相关的图层拆分出来。更多细节请阅读我们的 blog([中文版](https://modelscope.cn/learn/4938)、[英文版](https://huggingface.co/blog/kelseye/qwen-image-layered-control))。
|
||||||
|
|
||||||
|
- **2025年12月24日** 我们基于 Qwen-Image-Edit-2511 训练了一个 In-Context Editing LoRA 模型([模型链接](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)),这个模型可以输入三张图:图A、图B、图C,模型会自行分析图A到图B的变化,并将这样的变化应用到图C,生成图D。更多细节请阅读我们的 blog([中文版](https://mp.weixin.qq.com/s/41aEiN3lXKGCJs1-we4Q2g)、[英文版](https://huggingface.co/blog/kelseye/qwen-image-edit-2511-icedit-lora))。
|
||||||
|
|
||||||
|
- **2025年12月9日** 我们基于 DiffSynth-Studio 2.0 训练了一个疯狂的模型:[Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)(Image to LoRA)。这一模型以图像为输入,以 LoRA 为输出。尽管这个版本的模型在泛化能力、细节保持能力等方面还有很大改进空间,我们将这些模型开源,以启发更多创新性的研究工作。更多细节,请参考我们的 [blog](https://huggingface.co/blog/kelseye/qwen-image-i2l)。
|
||||||
|
|
||||||
|
- **2025年12月4日** DiffSynth-Studio 2.0 发布!众多新功能上线
|
||||||
|
- [文档](/docs/zh/README.md)上线:我们的文档还在持续优化更新中
|
||||||
|
- [显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md)模块升级,支持 Layer 级别的 Disk Offload,同时释放内存与显存
|
||||||
|
- 新模型支持
|
||||||
|
- Z-Image Turbo: [模型](https://www.modelscope.ai/models/Tongyi-MAI/Z-Image-Turbo)、[文档](/docs/zh/Model_Details/Z-Image.md)、[代码](/examples/z_image/)
|
||||||
|
- FLUX.2-dev: [模型](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)、[文档](/docs/zh/Model_Details/FLUX2.md)、[代码](/examples/flux2/)
|
||||||
|
- 训练框架升级
|
||||||
|
- [拆分训练](/docs/zh/Training/Split_Training.md):支持自动化地将训练过程拆分为数据处理和训练两阶段(即使训练的是 ControlNet 或其他任意模型),在数据处理阶段进行文本编码、VAE 编码等不需要梯度回传的计算,在训练阶段处理其他计算。速度更快,显存需求更少。
|
||||||
|
- [差分 LoRA 训练](/docs/zh/Training/Differential_LoRA.md):这是我们曾在 [ArtAug](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) 中使用的训练技术,目前已可用于任意模型的 LoRA 训练。
|
||||||
|
- [FP8 训练](/docs/zh/Training/FP8_Precision.md):FP8 在训练中支持应用到任意非训练模型,即梯度关闭或者梯度仅影响 LoRA 权重的模型。
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>更多</summary>
|
||||||
|
|
||||||
|
- **2025年11月4日** 支持了 [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) 模型,该模型基于 Wan 2.1 训练,支持根据参考视频生成相应的动作。
|
||||||
|
|
||||||
|
- **2025年10月30日** 支持了 [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) 模型,该模型支持文生视频、图生视频、视频续写。这个模型在本项目中沿用 Wan 的框架进行推理和训练。
|
||||||
|
|
||||||
|
- **2025年10月27日** 支持了 [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) 模型,Wan 模型生态再添一员。
|
||||||
|
|
||||||
|
- **2025年9月23日** [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) 发布!本模型由我们与淘天体验设计团队联合研发并开源。模型基于 Qwen-Image 构建,专为电商海报场景设计,支持精确的分区布局控制。 请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)。
|
||||||
|
|
||||||
|
- **2025年9月9日** 我们的训练框架支持了多种训练模式,目前已适配 Qwen-Image,除标准 SFT 训练模式外,已支持 Direct Distill,请参考[我们的示例代码](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)。这项功能是实验性的,我们将会继续完善已支持更全面的模型训练功能。
|
||||||
|
|
||||||
|
- **2025年8月28日** 我们支持了Wan2.2-S2V,一个音频驱动的电影级视频生成模型。请参见[./examples/wanvideo/](./examples/wanvideo/)。
|
||||||
|
|
||||||
|
- **2025年8月21日** [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) 发布!相比于 V1 版本,训练数据集变为 [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset),因此,生成的图像更符合 Qwen-Image 本身的图像分布和风格。 请参考[我们的示例代码](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)。
|
||||||
|
|
||||||
|
- **2025年8月21日** 我们开源了 [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) 结构控制 LoRA 模型,采用 In Context 的技术路线,支持多种类别的结构控制条件,包括 canny, depth, lineart, softedge, normal, openpose。 请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)。
|
||||||
|
|
||||||
|
- **2025年8月20日** 我们开源了 [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) 模型,提升了 Qwen-Image-Edit 对低分辨率图像输入的编辑效果。请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)
|
||||||
|
|
||||||
|
- **2025年8月19日** 🔥 Qwen-Image-Edit 开源,欢迎图像编辑模型新成员!
|
||||||
|
|
||||||
|
- **2025年8月18日** 我们训练并开源了 Qwen-Image 的图像重绘 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)。
|
||||||
|
|
||||||
|
- **2025年8月15日** 我们开源了 [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset) 数据集。这是一个使用 Qwen-Image 模型生成的图像数据集,共包含 160,000 张`1024 x 1024`图像。它包括通用、英文文本渲染和中文文本渲染子集。我们为每张图像提供了图像描述、实体和结构控制图像的标注。开发者可以使用这个数据集来训练 Qwen-Image 模型的 ControlNet 和 EliGen 等模型,我们旨在通过开源推动技术发展!
|
||||||
|
|
||||||
|
- **2025年8月13日** 我们训练并开源了 Qwen-Image 的 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)。
|
||||||
|
|
||||||
|
- **2025年8月12日** 我们训练并开源了 Qwen-Image 的 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)。
|
||||||
|
|
||||||
|
- **2025年8月11日** 我们开源了 Qwen-Image 的蒸馏加速模型 [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA),沿用了与 [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) 相同的训练流程,但模型结构修改为了 LoRA,因此能够更好地与其他开源生态模型兼容。
|
||||||
|
|
||||||
|
- **2025年8月7日** 我们开源了 Qwen-Image 的实体控制 LoRA 模型 [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)。Qwen-Image-EliGen 能够实现实体级可控的文生图。技术细节请参见[论文](https://arxiv.org/abs/2501.01097)。训练数据集:[EliGenTrainSet](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)。
|
||||||
|
|
||||||
|
- **2025年8月5日** 我们开源了 Qwen-Image 的蒸馏加速模型 [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full),实现了约 5 倍加速。
|
||||||
|
|
||||||
|
- **2025年8月4日** 🔥 Qwen-Image 开源,欢迎图像生成模型家族新成员!
|
||||||
|
|
||||||
|
- **2025年8月1日** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) 开源,这是一个专注于美学摄影的文生图模型。我们第一时间提供了全方位支持,包括低显存逐层 offload、LoRA 训练、全量训练。详细信息请参考 [./examples/flux/](./examples/flux/)。
|
||||||
|
|
||||||
|
- **2025年7月28日** Wan 2.2 开源,我们第一时间提供了全方位支持,包括低显存逐层 offload、FP8 量化、序列并行、LoRA 训练、全量训练。详细信息请参考 [./examples/wanvideo/](./examples/wanvideo/)。
|
||||||
|
|
||||||
|
- **2025年7月11日** 我们提出 Nexus-Gen,一个将大语言模型(LLM)的语言推理能力与扩散模型的图像生成能力相结合的统一框架。该框架支持无缝的图像理解、生成和编辑任务。
|
||||||
|
- 论文: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
|
||||||
|
- Github 仓库: https://github.com/modelscope/Nexus-Gen
|
||||||
|
- 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
|
||||||
|
- 训练数据集: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
|
||||||
|
- 在线体验: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
|
||||||
|
|
||||||
|
- **2025年6月15日** ModelScope 官方评测框架 [EvalScope](https://github.com/modelscope/evalscope) 现已支持文生图生成评测。请参考[最佳实践](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html)指南进行尝试。
|
||||||
|
|
||||||
|
- **2025年3月25日** 我们的新开源项目 [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) 现已开源!专注于稳定的模型部署,面向工业界,提供更好的工程支持、更高的计算性能和更稳定的功能。
|
||||||
|
|
||||||
|
- **2025年3月31日** 我们支持 InfiniteYou,一种用于 FLUX 的人脸特征保留方法。更多细节请参考 [./examples/InfiniteYou/](./examples/InfiniteYou/)。
|
||||||
|
|
||||||
|
- **2025年3月13日** 我们支持 HunyuanVideo-I2V,即腾讯开源的 HunyuanVideo 的图像到视频生成版本。更多细节请参考 [./examples/HunyuanVideo/](./examples/HunyuanVideo/)。
|
||||||
|
|
||||||
|
- **2025年2月25日** 我们支持 Wan-Video,这是阿里巴巴开源的一系列最先进的视频合成模型。详见 [./examples/wanvideo/](./examples/wanvideo/)。
|
||||||
|
|
||||||
|
- **2025年2月17日** 我们支持 [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary)!先进的视频合成模型!详见 [./examples/stepvideo](./examples/stepvideo/)。
|
||||||
|
|
||||||
|
- **2024年12月31日** 我们提出 EliGen,一种用于精确实体级别控制的文本到图像生成的新框架,并辅以修复融合管道,将其能力扩展到图像修复任务。EliGen 可以无缝集成现有的社区模型,如 IP-Adapter 和 In-Context LoRA,提升其通用性。更多详情,请见 [./examples/EntityControl](./examples/EntityControl/)。
|
||||||
|
- 论文: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
|
||||||
|
- 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
|
||||||
|
- 在线体验: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
|
||||||
|
- 训练数据集: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
|
||||||
|
|
||||||
|
- **2024年12月19日** 我们为 HunyuanVideo 实现了高级显存管理,使得在 24GB 显存下可以生成分辨率为 129x720x1280 的视频,或在仅 6GB 显存下生成分辨率为 129x512x384 的视频。更多细节请参考 [./examples/HunyuanVideo/](./examples/HunyuanVideo/)。
|
||||||
|
|
||||||
|
- **2024年12月18日** 我们提出 ArtAug,一种通过合成-理解交互来改进文生图模型的方法。我们以 LoRA 格式为 FLUX.1-dev 训练了一个 ArtAug 增强模块。该模型将 Qwen2-VL-72B 的美学理解融入 FLUX.1-dev,从而提升了生成图像的质量。
|
||||||
|
- 论文: https://arxiv.org/abs/2412.12888
|
||||||
|
- 示例: https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/ArtAug
|
||||||
|
- 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
|
||||||
|
- 演示: [ModelScope](https://modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0), HuggingFace (即将上线)
|
||||||
|
|
||||||
|
- **2024年10月25日** 我们提供了广泛的 FLUX ControlNet 支持。该项目支持许多不同的 ControlNet 模型,并且可以自由组合,即使它们的结构不同。此外,ControlNet 模型兼容高分辨率优化和分区控制技术,能够实现非常强大的可控图像生成。详见 [`./examples/ControlNet/`](./examples/ControlNet/)。
|
||||||
|
|
||||||
|
- **2024年10月8日** 我们发布了基于 CogVideoX-5B 和 ExVideo 的扩展 LoRA。您可以从 [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) 或 [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) 下载此模型。
|
||||||
|
|
||||||
|
- **2024年8月22日** 本项目现已支持 CogVideoX-5B。详见 [此处](/examples/video_synthesis/)。我们为这个文生视频模型提供了几个有趣的功能,包括:
|
||||||
|
- 文本到视频
|
||||||
|
- 视频编辑
|
||||||
|
- 自我超分
|
||||||
|
- 视频插帧
|
||||||
|
|
||||||
|
- **2024年8月22日** 我们实现了一个有趣的画笔功能,支持所有文生图模型。现在,您可以在 AI 的辅助下使用画笔创作惊艳的图像了!
|
||||||
|
- 在我们的 [WebUI](#usage-in-webui) 中使用它。
|
||||||
|
|
||||||
|
- **2024年8月21日** DiffSynth-Studio 现已支持 FLUX。
|
||||||
|
- 启用 CFG 和高分辨率修复以提升视觉质量。详见 [此处](/examples/image_synthesis/README.md)
|
||||||
|
- LoRA、ControlNet 和其他附加模型将很快推出。
|
||||||
|
|
||||||
|
- **2024年6月21日** 我们提出 ExVideo,一种旨在增强视频生成模型能力的后训练微调技术。我们将 Stable Video Diffusion 进行了扩展,实现了长达 128 帧的长视频生成。
|
||||||
|
- [项目页面](https://ecnu-cilab.github.io/ExVideoProjectPage/)
|
||||||
|
- 源代码已在此仓库中发布。详见 [`examples/ExVideo`](./examples/ExVideo/)。
|
||||||
|
- 模型已发布于 [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) 和 [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1)。
|
||||||
|
- 技术报告已发布于 [arXiv](https://arxiv.org/abs/2406.14130)。
|
||||||
|
- 您可以在此 [演示](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1) 中试用 ExVideo!
|
||||||
|
|
||||||
|
- **2024年6月13日** DiffSynth Studio 已迁移至 ModelScope。开发团队也从“我”转变为“我们”。当然,我仍会参与后续的开发和维护工作。
|
||||||
|
|
||||||
|
- **2024年1月29日** 我们提出 Diffutoon,这是一个出色的卡通着色解决方案。
|
||||||
|
- [项目页面](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
|
||||||
|
- 源代码已在此项目中发布。
|
||||||
|
- 技术报告(IJCAI 2024)已发布于 [arXiv](https://arxiv.org/abs/2401.16224)。
|
||||||
|
|
||||||
|
- **2023年12月8日** 我们决定启动一个新项目,旨在释放扩散模型的潜力,尤其是在视频合成方面。该项目的开发工作正式开始。
|
||||||
|
|
||||||
|
- **2023年11月15日** 我们提出 FastBlend,一种强大的视频去闪烁算法。
|
||||||
|
- sd-webui 扩展已发布于 [GitHub](https://github.com/Artiprocher/sd-webui-fastblend)。
|
||||||
|
- 演示视频已在 Bilibili 上展示,包含三个任务:
|
||||||
|
- [视频去闪烁](https://www.bilibili.com/video/BV1d94y1W7PE)
|
||||||
|
- [视频插帧](https://www.bilibili.com/video/BV1Lw411m71p)
|
||||||
|
- [图像驱动的视频渲染](https://www.bilibili.com/video/BV1RB4y1Z7LF)
|
||||||
|
- 技术报告已发布于 [arXiv](https://arxiv.org/abs/2311.09265)。
|
||||||
|
- 其他用户开发的非官方 ComfyUI 扩展已发布于 [GitHub](https://github.com/AInseven/ComfyUI-fastblend)。
|
||||||
|
|
||||||
|
- **2023年10月1日** 我们发布了该项目的早期版本,名为 FastSDXL。这是构建一个扩散引擎的初步尝试。
|
||||||
|
- 源代码已发布于 [GitHub](https://github.com/Artiprocher/FastSDXL)。
|
||||||
|
- FastSDXL 包含一个可训练的 OLSS 调度器,以提高效率。
|
||||||
|
- OLSS 的原始仓库位于 [此处](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler)。
|
||||||
|
- 技术报告(CIKM 2023)已发布于 [arXiv](https://arxiv.org/abs/2305.14677)。
|
||||||
|
- 演示视频已发布于 [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj)。
|
||||||
|
- 由于 OLSS 需要额外训练,我们未在本项目中实现它。
|
||||||
|
|
||||||
|
- **2023年8月29日** 我们提出 DiffSynth,一个视频合成框架。
|
||||||
|
- [项目页面](https://ecnu-cilab.github.io/DiffSynth.github.io/)。
|
||||||
|
- 源代码已发布在 [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth)。
|
||||||
|
- 技术报告(ECML PKDD 2024)已发布于 [arXiv](https://arxiv.org/abs/2308.03463)。
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## 安装
|
||||||
|
|
||||||
|
从源码安装(推荐):
|
||||||
|
|
||||||
|
```
|
||||||
|
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||||
|
cd DiffSynth-Studio
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
更多安装方式,以及非 NVIDIA GPU 的安装,请参考[安装文档](/docs/zh/Pipeline_Usage/Setup.md)。
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## 基础框架
|
||||||
|
|
||||||
|
DiffSynth-Studio 为主流 Diffusion 模型(包括 FLUX、Wan 等)重新设计了推理和训练流水线,能够实现高效的显存管理、灵活的模型训练。
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>环境变量配置</summary>
|
||||||
|
|
||||||
|
> 在进行模型推理和训练前,可通过[环境变量](/docs/zh/Pipeline_Usage/Environment_Variables.md)配置模型下载源等。
|
||||||
|
>
|
||||||
|
> 本项目默认从魔搭社区下载模型。对于非中国区域的用户,可以通过以下配置从魔搭社区的国际站下载模型:
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> import os
|
||||||
|
> os.environ["MODELSCOPE_DOMAIN"] = "www.modelscope.ai"
|
||||||
|
> ```
|
||||||
|
>
|
||||||
|
> 如需从其他站点下载,请修改[环境变量 DIFFSYNTH_DOWNLOAD_SOURCE](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_download_source)。
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
### 图像生成模型
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
#### Z-Image:[/docs/zh/Model_Details/Z-Image.md](/docs/zh/Model_Details/Z-Image.md)
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>快速开始</summary>
|
||||||
|
|
||||||
|
运行以下代码可以快速加载 [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) 模型并进行推理。FP8 精度量化会导致明显的图像质量劣化,因此不建议在 Z-Image Turbo 模型上开启任何量化,仅建议开启 CPU Offload,最低 8G 显存即可运行。
|
||||||
|
|
||||||
|
```python
|
||||||
|
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
|
||||||
|
import torch
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.bfloat16,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.bfloat16,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.bfloat16,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = ZImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
|
||||||
|
image = pipe(prompt=prompt, seed=42, rand_device="cuda")
|
||||||
|
image.save("image.jpg")
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>示例代码</summary>
|
||||||
|
|
||||||
|
Z-Image 的示例代码位于:[/examples/z_image/](/examples/z_image/)
|
||||||
|
|
||||||
|
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||||
|
|-|-|-|-|-|-|-|
|
||||||
|
|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
|
||||||
|
|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|
||||||
|
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|
||||||
|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
|
||||||
|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
|
||||||
|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
#### FLUX.2: [/docs/zh/Model_Details/FLUX2.md](/docs/zh/Model_Details/FLUX2.md)
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>快速开始</summary>
|
||||||
|
|
||||||
|
运行以下代码可以快速加载 [black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 10G 显存即可运行。
|
||||||
|
|
||||||
|
```python
|
||||||
|
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||||
|
import torch
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": "disk",
|
||||||
|
"offload_device": "disk",
|
||||||
|
"onload_dtype": torch.float8_e4m3fn,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float8_e4m3fn,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = Flux2ImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
prompt = "High resolution. A dreamy underwater portrait of a serene young woman in a flowing blue dress. Her hair floats softly around her face, strands delicately suspended in the water. Clear, shimmering light filters through, casting gentle highlights, while tiny bubbles rise around her. Her expression is calm, her features finely detailed—creating a tranquil, ethereal scene."
|
||||||
|
image = pipe(prompt, seed=42, rand_device="cuda", num_inference_steps=50)
|
||||||
|
image.save("image.jpg")
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>示例代码</summary>
|
||||||
|
|
||||||
|
FLUX.2 的示例代码位于:[/examples/flux2/](/examples/flux2/)
|
||||||
|
|
||||||
|
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||||
|
|-|-|-|-|-|-|-|
|
||||||
|
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|
||||||
|
|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
|
||||||
|
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|
||||||
|
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|
||||||
|
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
#### Qwen-Image: [/docs/zh/Model_Details/Qwen-Image.md](/docs/zh/Model_Details/Qwen-Image.md)
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>快速开始</summary>
|
||||||
|
|
||||||
|
运行以下代码可以快速加载 [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8G 显存即可运行。
|
||||||
|
|
||||||
|
```python
|
||||||
|
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||||
|
import torch
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": "disk",
|
||||||
|
"offload_device": "disk",
|
||||||
|
"onload_dtype": torch.float8_e4m3fn,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float8_e4m3fn,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = QwenImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||||
|
image = pipe(prompt, seed=0, num_inference_steps=40)
|
||||||
|
image.save("image.jpg")
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>模型血缘</summary>
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
graph LR;
|
||||||
|
Qwen/Qwen-Image-->Qwen/Qwen-Image-Edit;
|
||||||
|
Qwen/Qwen-Image-Edit-->Qwen/Qwen-Image-Edit-2509;
|
||||||
|
Qwen/Qwen-Image-->EliGen-Series;
|
||||||
|
EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen;
|
||||||
|
DiffSynth-Studio/Qwen-Image-EliGen-->DiffSynth-Studio/Qwen-Image-EliGen-V2;
|
||||||
|
EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen-Poster;
|
||||||
|
Qwen/Qwen-Image-->Distill-Series;
|
||||||
|
Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-Full;
|
||||||
|
Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-LoRA;
|
||||||
|
Qwen/Qwen-Image-->ControlNet-Series;
|
||||||
|
ControlNet-Series-->Blockwise-ControlNet-Series;
|
||||||
|
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny;
|
||||||
|
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth;
|
||||||
|
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint;
|
||||||
|
ControlNet-Series-->DiffSynth-Studio/Qwen-Image-In-Context-Control-Union;
|
||||||
|
Qwen/Qwen-Image-->DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix;
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>示例代码</summary>
|
||||||
|
|
||||||
|
Qwen-Image 的示例代码位于:[/examples/qwen_image/](/examples/qwen_image/)
|
||||||
|
|
||||||
|
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||||
|
|-|-|-|-|-|-|-|
|
||||||
|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|
||||||
|
|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|
|
||||||
|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|
||||||
|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|
||||||
|
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|
||||||
|
|[FireRedTeam/FireRed-Image-Edit-1.0](https://www.modelscope.cn/models/FireRedTeam/FireRed-Image-Edit-1.0)|[code](/examples/qwen_image/model_inference/FireRed-Image-Edit-1.0.py)|[code](/examples/qwen_image/model_inference_low_vram/FireRed-Image-Edit-1.0.py)|[code](/examples/qwen_image/model_training/full/FireRed-Image-Edit-1.0.sh)|[code](/examples/qwen_image/model_training/validate_full/FireRed-Image-Edit-1.0.py)|[code](/examples/qwen_image/model_training/lora/FireRed-Image-Edit-1.0.sh)|[code](/examples/qwen_image/model_training/validate_lora/FireRed-Image-Edit-1.0.py)|
|
||||||
|
|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|
|
||||||
|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|
||||||
|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|
||||||
|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||||
|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||||
|
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|
||||||
|
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|
||||||
|
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|
|
||||||
|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
|
||||||
|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
|
||||||
|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|
||||||
|
|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
|
||||||
|
|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
|
||||||
|
|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-|
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
#### FLUX.1: [/docs/zh/Model_Details/FLUX.md](/docs/zh/Model_Details/FLUX.md)
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>快速开始</summary>
|
||||||
|
|
||||||
|
运行以下代码可以快速加载 [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8G 显存即可运行。
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float8_e4m3fn,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float8_e4m3fn,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float8_e4m3fn,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = FluxImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 1,
|
||||||
|
)
|
||||||
|
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
|
||||||
|
image = pipe(prompt=prompt, seed=0)
|
||||||
|
image.save("image.jpg")
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>模型血缘</summary>
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
graph LR;
|
||||||
|
FLUX.1-Series-->black-forest-labs/FLUX.1-dev;
|
||||||
|
FLUX.1-Series-->black-forest-labs/FLUX.1-Krea-dev;
|
||||||
|
FLUX.1-Series-->black-forest-labs/FLUX.1-Kontext-dev;
|
||||||
|
black-forest-labs/FLUX.1-dev-->FLUX.1-dev-ControlNet-Series;
|
||||||
|
FLUX.1-dev-ControlNet-Series-->alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta;
|
||||||
|
FLUX.1-dev-ControlNet-Series-->InstantX/FLUX.1-dev-Controlnet-Union-alpha;
|
||||||
|
FLUX.1-dev-ControlNet-Series-->jasperai/Flux.1-dev-Controlnet-Upscaler;
|
||||||
|
black-forest-labs/FLUX.1-dev-->InstantX/FLUX.1-dev-IP-Adapter;
|
||||||
|
black-forest-labs/FLUX.1-dev-->ByteDance/InfiniteYou;
|
||||||
|
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Eligen;
|
||||||
|
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev;
|
||||||
|
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev;
|
||||||
|
black-forest-labs/FLUX.1-dev-->ostris/Flex.2-preview;
|
||||||
|
black-forest-labs/FLUX.1-dev-->stepfun-ai/Step1X-Edit;
|
||||||
|
Qwen/Qwen2.5-VL-7B-Instruct-->stepfun-ai/Step1X-Edit;
|
||||||
|
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Nexus-GenV2;
|
||||||
|
Qwen/Qwen2.5-VL-7B-Instruct-->DiffSynth-Studio/Nexus-GenV2;
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>示例代码</summary>
|
||||||
|
|
||||||
|
FLUX.1 的示例代码位于:[/examples/flux/](/examples/flux/)
|
||||||
|
|
||||||
|
|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||||
|
|-|-|-|-|-|-|-|-|
|
||||||
|
|[black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py)|
|
||||||
|
|[black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](/examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)|
|
||||||
|
|[black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)|
|
||||||
|
|[alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|
|
||||||
|
|[InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|
|
||||||
|
|[jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|
|
||||||
|
|[InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|
|
||||||
|
|[ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
|
||||||
|
|[DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)|
|
||||||
|
|[DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
|
||||||
|
|[DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
|
||||||
|
|[stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](/examples/flux/model_inference/Step1X-Edit.py)|[code](/examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](/examples/flux/model_training/full/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](/examples/flux/model_training/lora/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_lora/Step1X-Edit.py)|
|
||||||
|
|[ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](/examples/flux/model_inference/FLEX.2-preview.py)|[code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](/examples/flux/model_training/full/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](/examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py)|
|
||||||
|
|[DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](/examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](/examples/flux/model_training/full/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](/examples/flux/model_training/lora/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_lora/Nexus-Gen.py)|
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
### 视频生成模型
|
||||||
|
|
||||||
|
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
|
||||||
|
|
||||||
|
#### LTX-2: [/docs/zh/Model_Details/LTX-2.md](/docs/zh/Model_Details/LTX-2.md)
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>快速开始</summary>
|
||||||
|
|
||||||
|
运行以下代码可以快速加载 [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8GB 显存即可运行。
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||||
|
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float8_e5m2,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float8_e5m2,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float8_e5m2,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
Offical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2
|
||||||
|
Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage
|
||||||
|
For base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors"))
|
||||||
|
and repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported.
|
||||||
|
We have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules,
|
||||||
|
and avoid redundant memory usage when users only want to use part of the model.
|
||||||
|
"""
|
||||||
|
# use the repackaged modelconfig from "DiffSynth-Studio/LTX-2-Repackage" to avoid redundant model loading
|
||||||
|
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
# use the following modelconfig if you want to initialize model from offical checkpoints from "Lightricks/LTX-2"
|
||||||
|
# pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||||
|
# torch_dtype=torch.bfloat16,
|
||||||
|
# device="cuda",
|
||||||
|
# model_configs=[
|
||||||
|
# ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
|
# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||||
|
# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
|
# ],
|
||||||
|
# tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
# stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
|
||||||
|
# vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
# )
|
||||||
|
|
||||||
|
prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\""
|
||||||
|
negative_prompt = (
|
||||||
|
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||||
|
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||||
|
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||||
|
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||||
|
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||||
|
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||||
|
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||||
|
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||||
|
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||||
|
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||||
|
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||||
|
)
|
||||||
|
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||||
|
video, audio = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
seed=43,
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
num_frames=num_frames,
|
||||||
|
tiled=True,
|
||||||
|
use_two_stage_pipeline=True,
|
||||||
|
)
|
||||||
|
write_video_audio_ltx2(
|
||||||
|
video=video,
|
||||||
|
audio=audio,
|
||||||
|
output_path='ltx2_twostage.mp4',
|
||||||
|
fps=24,
|
||||||
|
audio_sample_rate=24000,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>示例代码</summary>
|
||||||
|
|
||||||
|
LTX-2 的示例代码位于:[/examples/ltx2/](/examples/ltx2/)
|
||||||
|
|
||||||
|
|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||||
|
|-|-|-|-|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py)|
|
||||||
|
|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-In.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-In.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Up.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Up.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Down.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Down.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Static](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Static)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Static.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Static.py)|-|-|-|-|
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
#### Wan: [/docs/zh/Model_Details/Wan.md](/docs/zh/Model_Details/Wan.md)
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>快速开始</summary>
|
||||||
|
|
||||||
|
运行以下代码可以快速加载 [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8G 显存即可运行。
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffsynth.utils.data import save_video, VideoData
|
||||||
|
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": "disk",
|
||||||
|
"offload_device": "disk",
|
||||||
|
"onload_dtype": torch.bfloat16,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.bfloat16,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = WanVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config),
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
|
||||||
|
)
|
||||||
|
|
||||||
|
video = pipe(
|
||||||
|
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
|
||||||
|
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||||
|
seed=0, tiled=True,
|
||||||
|
)
|
||||||
|
save_video(video, "video.mp4", fps=15, quality=5)
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>模型血缘</summary>
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
graph LR;
|
||||||
|
Wan-Series-->Wan2.1-Series;
|
||||||
|
Wan-Series-->Wan2.2-Series;
|
||||||
|
Wan2.1-Series-->Wan-AI/Wan2.1-T2V-1.3B;
|
||||||
|
Wan2.1-Series-->Wan-AI/Wan2.1-T2V-14B;
|
||||||
|
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-I2V-14B-480P;
|
||||||
|
Wan-AI/Wan2.1-I2V-14B-480P-->Wan-AI/Wan2.1-I2V-14B-720P;
|
||||||
|
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-FLF2V-14B-720P;
|
||||||
|
Wan-AI/Wan2.1-T2V-1.3B-->iic/VACE-Wan2.1-1.3B-Preview;
|
||||||
|
iic/VACE-Wan2.1-1.3B-Preview-->Wan-AI/Wan2.1-VACE-1.3B;
|
||||||
|
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-VACE-14B;
|
||||||
|
Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-1.3B-Series;
|
||||||
|
Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-InP;
|
||||||
|
Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-Control;
|
||||||
|
Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-14B-Series;
|
||||||
|
Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-InP;
|
||||||
|
Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-Control;
|
||||||
|
Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-V1.1-1.3B-Series;
|
||||||
|
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control;
|
||||||
|
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-InP;
|
||||||
|
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera;
|
||||||
|
Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-V1.1-14B-Series;
|
||||||
|
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control;
|
||||||
|
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-InP;
|
||||||
|
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control-Camera;
|
||||||
|
Wan-AI/Wan2.1-T2V-1.3B-->DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1;
|
||||||
|
Wan-AI/Wan2.1-T2V-14B-->krea/krea-realtime-video;
|
||||||
|
Wan-AI/Wan2.1-T2V-14B-->meituan-longcat/LongCat-Video;
|
||||||
|
Wan-AI/Wan2.1-I2V-14B-720P-->ByteDance/Video-As-Prompt-Wan2.1-14B;
|
||||||
|
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-Animate-14B;
|
||||||
|
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-S2V-14B;
|
||||||
|
Wan2.2-Series-->Wan-AI/Wan2.2-T2V-A14B;
|
||||||
|
Wan2.2-Series-->Wan-AI/Wan2.2-I2V-A14B;
|
||||||
|
Wan2.2-Series-->Wan-AI/Wan2.2-TI2V-5B;
|
||||||
|
Wan-AI/Wan2.2-T2V-A14B-->Wan2.2-Fun-Series;
|
||||||
|
Wan2.2-Fun-Series-->PAI/Wan2.2-VACE-Fun-A14B;
|
||||||
|
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-InP;
|
||||||
|
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control;
|
||||||
|
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control-Camera;
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>示例代码</summary>
|
||||||
|
|
||||||
|
Wan 的示例代码位于:[/examples/wanvideo/](/examples/wanvideo/)
|
||||||
|
|
||||||
|
|模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||||
|
|-|-|-|-|-|-|-|
|
||||||
|
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|
||||||
|
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
|
||||||
|
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|
||||||
|
|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
|
||||||
|
|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
|
||||||
|
|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
|
||||||
|
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|
||||||
|
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|
|
||||||
|
|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
|
||||||
|
|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
|
||||||
|
|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
|
||||||
|
|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
|
||||||
|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
|
||||||
|
|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
|
||||||
|
|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
|
||||||
|
|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
|
||||||
|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
|
||||||
|
|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
|
||||||
|
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|
||||||
|
|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](/examples/wanvideo/model_inference/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|
|
||||||
|
|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](/examples/wanvideo/model_inference/LongCat-Video.py)|[code](/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|
|
||||||
|
|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
|
||||||
|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|
||||||
|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|
||||||
|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|
||||||
|
|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
|
||||||
|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|
|
||||||
|
|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
|
||||||
|
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|
||||||
|
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|
||||||
|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## 创新成果
|
||||||
|
|
||||||
|
DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果的孵化器。
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Spectral Evolution Search: 用于奖励对齐图像生成的高效推理阶段缩放</summary>
|
||||||
|
|
||||||
|
- 论文:[Spectral Evolution Search: Efficient Inference-Time Scaling for Reward-Aligned Image Generation
|
||||||
|
](https://arxiv.org/abs/2602.03208)
|
||||||
|
- 代码样例:coming soon
|
||||||
|
|
||||||
|
|FLUX.1-dev|FLUX.1-dev + SES|Qwen-Image|Qwen-Image + SES|
|
||||||
|
|-|-|-|-|
|
||||||
|
|||||
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>VIRAL:基于DiT模型的类比视觉上下文推理</summary>
|
||||||
|
|
||||||
|
- 论文:[VIRAL: Visual In-Context Reasoning via Analogy in Diffusion Transformers
|
||||||
|
](https://arxiv.org/abs/2602.03210)
|
||||||
|
- 代码样例:[/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-ICEdit.py](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-ICEdit.py)
|
||||||
|
- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)
|
||||||
|
|
||||||
|
|Example 1|Example 2|Query|Output|
|
||||||
|
|-|-|-|-|
|
||||||
|
|||||
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>AttriCtrl: 图像生成模型的属性强度控制</summary>
|
||||||
|
|
||||||
|
- 论文:[AttriCtrl: Fine-Grained Control of Aesthetic Attribute Intensity in Diffusion Models
|
||||||
|
](https://arxiv.org/abs/2508.02151)
|
||||||
|
- 代码样例:[/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py](/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py)
|
||||||
|
- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev)
|
||||||
|
|
||||||
|
|brightness scale = 0.1|brightness scale = 0.3|brightness scale = 0.5|brightness scale = 0.7|brightness scale = 0.9|
|
||||||
|
|-|-|-|-|-|
|
||||||
|
||||||
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>AutoLoRA: 自动化的 LoRA 检索和融合</summary>
|
||||||
|
|
||||||
|
- 论文:[AutoLoRA: Automatic LoRA Retrieval and Fine-Grained Gated Fusion for Text-to-Image Generation
|
||||||
|
](https://arxiv.org/abs/2508.02107)
|
||||||
|
- 代码样例:[/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)
|
||||||
|
- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)
|
||||||
|
|
||||||
|
||[LoRA 1](https://modelscope.cn/models/cancel13/cxsk)|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2)|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL)|
|
||||||
|
|-|-|-|-|-|
|
||||||
|
|[LoRA 1](https://modelscope.cn/models/cancel13/cxsk) |||||
|
||||||
|
|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2) |||||
|
||||||
|
|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) |||||
|
||||||
|
|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL) |||||
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Nexus-Gen: 统一架构的图像理解、生成、编辑</summary>
|
||||||
|
|
||||||
|
- 详细页面:https://github.com/modelscope/Nexus-Gen
|
||||||
|
- 论文:[Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
|
||||||
|
- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
|
||||||
|
- 数据集:[ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
|
||||||
|
- 在线体验:[ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>ArtAug: 图像生成模型的美学提升</summary>
|
||||||
|
|
||||||
|
- 详细页面:[./examples/ArtAug/](./examples/ArtAug/)
|
||||||
|
- 论文:[ArtAug: Enhancing Text-to-Image Generation through Synthesis-Understanding Interaction](https://arxiv.org/abs/2412.12888)
|
||||||
|
- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
|
||||||
|
- 在线体验:[ModelScope AIGC Tab](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0)
|
||||||
|
|
||||||
|
|FLUX.1-dev|FLUX.1-dev + ArtAug LoRA|
|
||||||
|
|-|-|
|
||||||
|
|||
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>EliGen: 精准的图像分区控制</summary>
|
||||||
|
|
||||||
|
- 论文:[EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
|
||||||
|
- 代码样例:[/examples/flux/model_inference/FLUX.1-dev-EliGen.py](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)
|
||||||
|
- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
|
||||||
|
- 在线体验:[ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
|
||||||
|
- 数据集:[EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
|
||||||
|
|
||||||
|
|实体控制区域|生成图像|
|
||||||
|
|-|-|
|
||||||
|
|||
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>ExVideo: 视频生成模型的扩展训练</summary>
|
||||||
|
|
||||||
|
- 项目页面:[Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
|
||||||
|
- 论文:[ExVideo: Extending Video Diffusion Models via Parameter-Efficient Post-Tuning](https://arxiv.org/abs/2406.14130)
|
||||||
|
- 代码样例:请前往[旧版本](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/ExVideo)查看
|
||||||
|
- 模型:[ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
|
||||||
|
|
||||||
|
https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Diffutoon: 高分辨率动漫风格视频渲染</summary>
|
||||||
|
|
||||||
|
- 项目页面:[Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
|
||||||
|
- 论文:[Diffutoon: High-Resolution Editable Toon Shading via Diffusion Models](https://arxiv.org/abs/2401.16224)
|
||||||
|
- 代码样例:请前往[旧版本](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/Diffutoon)查看
|
||||||
|
|
||||||
|
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>DiffSynth: 本项目的初代版本</summary>
|
||||||
|
|
||||||
|
- 项目页面:[Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/)
|
||||||
|
- 论文:[DiffSynth: Latent In-Iteration Deflickering for Realistic Video Synthesis](https://arxiv.org/abs/2308.03463)
|
||||||
|
- 代码样例:请前往[旧版本](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/diffsynth)查看
|
||||||
|
|
||||||
|
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
|
||||||
|
|
||||||
|
</details>
|
||||||
@@ -1,252 +0,0 @@
|
|||||||
import gradio as gr
|
|
||||||
from diffsynth import ModelManager, SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline, FluxImagePipeline
|
|
||||||
import os, torch
|
|
||||||
from PIL import Image
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
config = {
|
|
||||||
"model_config": {
|
|
||||||
"Stable Diffusion": {
|
|
||||||
"model_folder": "models/stable_diffusion",
|
|
||||||
"pipeline_class": SDImagePipeline,
|
|
||||||
"default_parameters": {
|
|
||||||
"cfg_scale": 7.0,
|
|
||||||
"height": 512,
|
|
||||||
"width": 512,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"Stable Diffusion XL": {
|
|
||||||
"model_folder": "models/stable_diffusion_xl",
|
|
||||||
"pipeline_class": SDXLImagePipeline,
|
|
||||||
"default_parameters": {
|
|
||||||
"cfg_scale": 7.0,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"Stable Diffusion 3": {
|
|
||||||
"model_folder": "models/stable_diffusion_3",
|
|
||||||
"pipeline_class": SD3ImagePipeline,
|
|
||||||
"default_parameters": {
|
|
||||||
"cfg_scale": 7.0,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"Stable Diffusion XL Turbo": {
|
|
||||||
"model_folder": "models/stable_diffusion_xl_turbo",
|
|
||||||
"pipeline_class": SDXLImagePipeline,
|
|
||||||
"default_parameters": {
|
|
||||||
"negative_prompt": "",
|
|
||||||
"cfg_scale": 1.0,
|
|
||||||
"num_inference_steps": 1,
|
|
||||||
"height": 512,
|
|
||||||
"width": 512,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"Kolors": {
|
|
||||||
"model_folder": "models/kolors",
|
|
||||||
"pipeline_class": SDXLImagePipeline,
|
|
||||||
"default_parameters": {
|
|
||||||
"cfg_scale": 7.0,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"HunyuanDiT": {
|
|
||||||
"model_folder": "models/HunyuanDiT",
|
|
||||||
"pipeline_class": HunyuanDiTImagePipeline,
|
|
||||||
"default_parameters": {
|
|
||||||
"cfg_scale": 7.0,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"FLUX": {
|
|
||||||
"model_folder": "models/FLUX",
|
|
||||||
"pipeline_class": FluxImagePipeline,
|
|
||||||
"default_parameters": {
|
|
||||||
"cfg_scale": 1.0,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"max_num_painter_layers": 8,
|
|
||||||
"max_num_model_cache": 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_list(model_type):
|
|
||||||
if model_type is None:
|
|
||||||
return []
|
|
||||||
folder = config["model_config"][model_type]["model_folder"]
|
|
||||||
file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
|
|
||||||
if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
|
|
||||||
file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
|
|
||||||
file_list = sorted(file_list)
|
|
||||||
return file_list
|
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_type, model_path):
|
|
||||||
global model_dict
|
|
||||||
model_key = f"{model_type}:{model_path}"
|
|
||||||
if model_key in model_dict:
|
|
||||||
return model_dict[model_key]
|
|
||||||
model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
|
|
||||||
model_manager = ModelManager()
|
|
||||||
if model_type == "HunyuanDiT":
|
|
||||||
model_manager.load_models([
|
|
||||||
os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
|
|
||||||
os.path.join(model_path, "mt5/pytorch_model.bin"),
|
|
||||||
os.path.join(model_path, "model/pytorch_model_ema.pt"),
|
|
||||||
os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
|
|
||||||
])
|
|
||||||
elif model_type == "Kolors":
|
|
||||||
model_manager.load_models([
|
|
||||||
os.path.join(model_path, "text_encoder"),
|
|
||||||
os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
|
|
||||||
os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
])
|
|
||||||
elif model_type == "FLUX":
|
|
||||||
model_manager.torch_dtype = torch.bfloat16
|
|
||||||
file_list = [
|
|
||||||
os.path.join(model_path, "text_encoder/model.safetensors"),
|
|
||||||
os.path.join(model_path, "text_encoder_2"),
|
|
||||||
]
|
|
||||||
for file_name in os.listdir(model_path):
|
|
||||||
if file_name.endswith(".safetensors"):
|
|
||||||
file_list.append(os.path.join(model_path, file_name))
|
|
||||||
model_manager.load_models(file_list)
|
|
||||||
else:
|
|
||||||
model_manager.load_model(model_path)
|
|
||||||
pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
|
|
||||||
while len(model_dict) + 1 > config["max_num_model_cache"]:
|
|
||||||
key = next(iter(model_dict.keys()))
|
|
||||||
model_manager_to_release, _ = model_dict[key]
|
|
||||||
model_manager_to_release.to("cpu")
|
|
||||||
del model_dict[key]
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
model_dict[model_key] = model_manager, pipe
|
|
||||||
return model_manager, pipe
|
|
||||||
|
|
||||||
|
|
||||||
model_dict = {}
|
|
||||||
|
|
||||||
with gr.Blocks() as app:
|
|
||||||
gr.Markdown("# DiffSynth-Studio Painter")
|
|
||||||
with gr.Row():
|
|
||||||
with gr.Column(scale=382, min_width=100):
|
|
||||||
|
|
||||||
with gr.Accordion(label="Model"):
|
|
||||||
model_type = gr.Dropdown(choices=[i for i in config["model_config"]], label="Model type")
|
|
||||||
model_path = gr.Dropdown(choices=[], interactive=True, label="Model path")
|
|
||||||
|
|
||||||
@gr.on(inputs=model_type, outputs=model_path, triggers=model_type.change)
|
|
||||||
def model_type_to_model_path(model_type):
|
|
||||||
return gr.Dropdown(choices=load_model_list(model_type))
|
|
||||||
|
|
||||||
with gr.Accordion(label="Prompt"):
|
|
||||||
prompt = gr.Textbox(label="Prompt", lines=3)
|
|
||||||
negative_prompt = gr.Textbox(label="Negative prompt", lines=1)
|
|
||||||
cfg_scale = gr.Slider(minimum=1.0, maximum=10.0, value=7.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
|
|
||||||
embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=0.0, step=0.1, interactive=True, label="Embedded guidance scale (only for FLUX)")
|
|
||||||
|
|
||||||
with gr.Accordion(label="Image"):
|
|
||||||
num_inference_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, interactive=True, label="Inference steps")
|
|
||||||
height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
|
|
||||||
width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
|
|
||||||
with gr.Column():
|
|
||||||
use_fixed_seed = gr.Checkbox(value=True, interactive=False, label="Use fixed seed")
|
|
||||||
seed = gr.Number(minimum=0, maximum=10**9, value=0, interactive=True, label="Random seed", show_label=False)
|
|
||||||
|
|
||||||
@gr.on(
|
|
||||||
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
|
|
||||||
outputs=[prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
|
|
||||||
triggers=model_path.change
|
|
||||||
)
|
|
||||||
def model_path_to_default_params(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width):
|
|
||||||
load_model(model_type, model_path)
|
|
||||||
cfg_scale = config["model_config"][model_type]["default_parameters"].get("cfg_scale", cfg_scale)
|
|
||||||
embedded_guidance = config["model_config"][model_type]["default_parameters"].get("embedded_guidance", embedded_guidance)
|
|
||||||
num_inference_steps = config["model_config"][model_type]["default_parameters"].get("num_inference_steps", num_inference_steps)
|
|
||||||
height = config["model_config"][model_type]["default_parameters"].get("height", height)
|
|
||||||
width = config["model_config"][model_type]["default_parameters"].get("width", width)
|
|
||||||
return prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width
|
|
||||||
|
|
||||||
|
|
||||||
with gr.Column(scale=618, min_width=100):
|
|
||||||
with gr.Accordion(label="Painter"):
|
|
||||||
enable_local_prompt_list = []
|
|
||||||
local_prompt_list = []
|
|
||||||
mask_scale_list = []
|
|
||||||
canvas_list = []
|
|
||||||
for painter_layer_id in range(config["max_num_painter_layers"]):
|
|
||||||
with gr.Tab(label=f"Layer {painter_layer_id}"):
|
|
||||||
enable_local_prompt = gr.Checkbox(label="Enable", value=False, key=f"enable_local_prompt_{painter_layer_id}")
|
|
||||||
local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
|
|
||||||
mask_scale = gr.Slider(minimum=0.0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Mask scale", key=f"mask_scale_{painter_layer_id}")
|
|
||||||
canvas = gr.ImageEditor(canvas_size=(512, 1), sources=None, layers=False, interactive=True, image_mode="RGBA",
|
|
||||||
brush=gr.Brush(default_size=100, default_color="#000000", colors=["#000000"]),
|
|
||||||
label="Painter", key=f"canvas_{painter_layer_id}")
|
|
||||||
@gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear, enable_local_prompt.change], show_progress="hidden")
|
|
||||||
def resize_canvas(height, width, canvas):
|
|
||||||
h, w = canvas["background"].shape[:2]
|
|
||||||
if h != height or width != w:
|
|
||||||
return np.ones((height, width, 3), dtype=np.uint8) * 255
|
|
||||||
else:
|
|
||||||
return canvas
|
|
||||||
|
|
||||||
enable_local_prompt_list.append(enable_local_prompt)
|
|
||||||
local_prompt_list.append(local_prompt)
|
|
||||||
mask_scale_list.append(mask_scale)
|
|
||||||
canvas_list.append(canvas)
|
|
||||||
with gr.Accordion(label="Results"):
|
|
||||||
run_button = gr.Button(value="Generate", variant="primary")
|
|
||||||
output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
|
|
||||||
with gr.Row():
|
|
||||||
with gr.Column():
|
|
||||||
output_to_painter_button = gr.Button(value="Set as painter's background")
|
|
||||||
with gr.Column():
|
|
||||||
output_to_input_button = gr.Button(value="Set as input image")
|
|
||||||
painter_background = gr.State(None)
|
|
||||||
input_background = gr.State(None)
|
|
||||||
@gr.on(
|
|
||||||
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed] + enable_local_prompt_list + local_prompt_list + mask_scale_list + canvas_list,
|
|
||||||
outputs=[output_image],
|
|
||||||
triggers=run_button.click
|
|
||||||
)
|
|
||||||
def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed, *args, progress=gr.Progress()):
|
|
||||||
_, pipe = load_model(model_type, model_path)
|
|
||||||
input_params = {
|
|
||||||
"prompt": prompt,
|
|
||||||
"negative_prompt": negative_prompt,
|
|
||||||
"cfg_scale": cfg_scale,
|
|
||||||
"num_inference_steps": num_inference_steps,
|
|
||||||
"height": height,
|
|
||||||
"width": width,
|
|
||||||
"progress_bar_cmd": progress.tqdm,
|
|
||||||
}
|
|
||||||
if isinstance(pipe, FluxImagePipeline):
|
|
||||||
input_params["embedded_guidance"] = embedded_guidance
|
|
||||||
enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list = (
|
|
||||||
args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
|
|
||||||
args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
|
|
||||||
args[2 * config["max_num_painter_layers"]: 3 * config["max_num_painter_layers"]],
|
|
||||||
args[3 * config["max_num_painter_layers"]: 4 * config["max_num_painter_layers"]]
|
|
||||||
)
|
|
||||||
local_prompts, masks, mask_scales = [], [], []
|
|
||||||
for enable_local_prompt, local_prompt, mask_scale, canvas in zip(
|
|
||||||
enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list
|
|
||||||
):
|
|
||||||
if enable_local_prompt:
|
|
||||||
local_prompts.append(local_prompt)
|
|
||||||
masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
|
|
||||||
mask_scales.append(mask_scale)
|
|
||||||
input_params.update({
|
|
||||||
"local_prompts": local_prompts,
|
|
||||||
"masks": masks,
|
|
||||||
"mask_scales": mask_scales,
|
|
||||||
})
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
image = pipe(**input_params)
|
|
||||||
return image
|
|
||||||
|
|
||||||
@gr.on(inputs=[output_image] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
|
|
||||||
def send_output_to_painter_background(output_image, *canvas_list):
|
|
||||||
for canvas in canvas_list:
|
|
||||||
h, w = canvas["background"].shape[:2]
|
|
||||||
canvas["background"] = output_image.resize((w, h))
|
|
||||||
return tuple(canvas_list)
|
|
||||||
app.launch()
|
|
||||||
@@ -1,390 +0,0 @@
|
|||||||
import os
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image, ImageDraw, ImageFont
|
|
||||||
import random
|
|
||||||
import json
|
|
||||||
import gradio as gr
|
|
||||||
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
|
|
||||||
from modelscope import dataset_snapshot_download
|
|
||||||
|
|
||||||
|
|
||||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/entity_control/*")
|
|
||||||
example_json = 'data/examples/eligen/entity_control/ui_examples.json'
|
|
||||||
with open(example_json, 'r') as f:
|
|
||||||
examples = json.load(f)['examples']
|
|
||||||
|
|
||||||
for idx in range(len(examples)):
|
|
||||||
example_id = examples[idx]['example_id']
|
|
||||||
entity_prompts = examples[idx]['local_prompt_list']
|
|
||||||
examples[idx]['mask_lists'] = [Image.open(f"data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
|
|
||||||
|
|
||||||
def create_canvas_data(background, masks):
|
|
||||||
if background.shape[-1] == 3:
|
|
||||||
background = np.dstack([background, np.full(background.shape[:2], 255, dtype=np.uint8)])
|
|
||||||
layers = []
|
|
||||||
for mask in masks:
|
|
||||||
if mask is not None:
|
|
||||||
mask_single_channel = mask if mask.ndim == 2 else mask[..., 0]
|
|
||||||
layer = np.zeros((mask_single_channel.shape[0], mask_single_channel.shape[1], 4), dtype=np.uint8)
|
|
||||||
layer[..., -1] = mask_single_channel
|
|
||||||
layers.append(layer)
|
|
||||||
else:
|
|
||||||
layers.append(np.zeros_like(background))
|
|
||||||
|
|
||||||
composite = background.copy()
|
|
||||||
for layer in layers:
|
|
||||||
if layer.size > 0:
|
|
||||||
composite = np.where(layer[..., -1:] > 0, layer, composite)
|
|
||||||
return {
|
|
||||||
"background": background,
|
|
||||||
"layers": layers,
|
|
||||||
"composite": composite,
|
|
||||||
}
|
|
||||||
|
|
||||||
def load_example(load_example_button):
|
|
||||||
example_idx = int(load_example_button.split()[-1]) - 1
|
|
||||||
example = examples[example_idx]
|
|
||||||
result = [
|
|
||||||
50,
|
|
||||||
example["global_prompt"],
|
|
||||||
example["negative_prompt"],
|
|
||||||
example["seed"],
|
|
||||||
*example["local_prompt_list"],
|
|
||||||
]
|
|
||||||
num_entities = len(example["local_prompt_list"])
|
|
||||||
result += [""] * (config["max_num_painter_layers"] - num_entities)
|
|
||||||
masks = []
|
|
||||||
for mask in example["mask_lists"]:
|
|
||||||
mask_single_channel = np.array(mask.convert("L"))
|
|
||||||
masks.append(mask_single_channel)
|
|
||||||
for _ in range(config["max_num_painter_layers"] - len(masks)):
|
|
||||||
blank_mask = np.zeros_like(masks[0]) if masks else np.zeros((512, 512), dtype=np.uint8)
|
|
||||||
masks.append(blank_mask)
|
|
||||||
background = np.ones((masks[0].shape[0], masks[0].shape[1], 4), dtype=np.uint8) * 255
|
|
||||||
canvas_data_list = []
|
|
||||||
for mask in masks:
|
|
||||||
canvas_data = create_canvas_data(background, [mask])
|
|
||||||
canvas_data_list.append(canvas_data)
|
|
||||||
result.extend(canvas_data_list)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'):
|
|
||||||
save_dir = os.path.join('workdirs/tmp_mask', random_dir)
|
|
||||||
print(f'save to {save_dir}')
|
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
|
||||||
for i, mask in enumerate(masks):
|
|
||||||
save_path = os.path.join(save_dir, f'{i}.png')
|
|
||||||
mask.save(save_path)
|
|
||||||
sample = {
|
|
||||||
"global_prompt": global_prompt,
|
|
||||||
"mask_prompts": mask_prompts,
|
|
||||||
"seed": seed,
|
|
||||||
}
|
|
||||||
with open(os.path.join(save_dir, f"prompts.json"), 'w') as f:
|
|
||||||
json.dump(sample, f, indent=4)
|
|
||||||
|
|
||||||
def visualize_masks(image, masks, mask_prompts, font_size=35, use_random_colors=False):
|
|
||||||
# Create a blank image for overlays
|
|
||||||
overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
|
|
||||||
colors = [
|
|
||||||
(165, 238, 173, 80),
|
|
||||||
(76, 102, 221, 80),
|
|
||||||
(221, 160, 77, 80),
|
|
||||||
(204, 93, 71, 80),
|
|
||||||
(145, 187, 149, 80),
|
|
||||||
(134, 141, 172, 80),
|
|
||||||
(157, 137, 109, 80),
|
|
||||||
(153, 104, 95, 80),
|
|
||||||
(165, 238, 173, 80),
|
|
||||||
(76, 102, 221, 80),
|
|
||||||
(221, 160, 77, 80),
|
|
||||||
(204, 93, 71, 80),
|
|
||||||
(145, 187, 149, 80),
|
|
||||||
(134, 141, 172, 80),
|
|
||||||
(157, 137, 109, 80),
|
|
||||||
(153, 104, 95, 80),
|
|
||||||
]
|
|
||||||
# Generate random colors for each mask
|
|
||||||
if use_random_colors:
|
|
||||||
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
|
|
||||||
# Font settings
|
|
||||||
try:
|
|
||||||
font = ImageFont.truetype("arial", font_size) # Adjust as needed
|
|
||||||
except IOError:
|
|
||||||
font = ImageFont.load_default(font_size)
|
|
||||||
# Overlay each mask onto the overlay image
|
|
||||||
for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
|
|
||||||
if mask is None:
|
|
||||||
continue
|
|
||||||
# Convert mask to RGBA mode
|
|
||||||
mask_rgba = mask.convert('RGBA')
|
|
||||||
mask_data = mask_rgba.getdata()
|
|
||||||
new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
|
|
||||||
mask_rgba.putdata(new_data)
|
|
||||||
# Draw the mask prompt text on the mask
|
|
||||||
draw = ImageDraw.Draw(mask_rgba)
|
|
||||||
mask_bbox = mask.getbbox() # Get the bounding box of the mask
|
|
||||||
if mask_bbox is None:
|
|
||||||
continue
|
|
||||||
text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
|
|
||||||
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
|
|
||||||
# Alpha composite the overlay with this mask
|
|
||||||
overlay = Image.alpha_composite(overlay, mask_rgba)
|
|
||||||
# Composite the overlay onto the original image
|
|
||||||
result = Image.alpha_composite(image.convert('RGBA'), overlay)
|
|
||||||
return result
|
|
||||||
|
|
||||||
config = {
|
|
||||||
"model_config": {
|
|
||||||
"FLUX": {
|
|
||||||
"model_folder": "models/FLUX",
|
|
||||||
"pipeline_class": FluxImagePipeline,
|
|
||||||
"default_parameters": {
|
|
||||||
"cfg_scale": 3.0,
|
|
||||||
"embedded_guidance": 3.5,
|
|
||||||
"num_inference_steps": 30,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"max_num_painter_layers": 8,
|
|
||||||
"max_num_model_cache": 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
model_dict = {}
|
|
||||||
|
|
||||||
def load_model(model_type='FLUX', model_path='FLUX.1-dev'):
|
|
||||||
global model_dict
|
|
||||||
model_key = f"{model_type}:{model_path}"
|
|
||||||
if model_key in model_dict:
|
|
||||||
return model_dict[model_key]
|
|
||||||
model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
|
|
||||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
|
|
||||||
model_manager.load_lora(
|
|
||||||
download_customized_models(
|
|
||||||
model_id="DiffSynth-Studio/Eligen",
|
|
||||||
origin_file_path="model_bf16.safetensors",
|
|
||||||
local_dir="models/lora/entity_control",
|
|
||||||
),
|
|
||||||
lora_alpha=1,
|
|
||||||
)
|
|
||||||
pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
|
|
||||||
model_dict[model_key] = model_manager, pipe
|
|
||||||
return model_manager, pipe
|
|
||||||
|
|
||||||
|
|
||||||
with gr.Blocks() as app:
|
|
||||||
gr.Markdown(
|
|
||||||
"""## EliGen: Entity-Level Controllable Text-to-Image Model
|
|
||||||
1. On the left, input the **global prompt** for the overall image, such as "a person stands by the river."
|
|
||||||
2. On the right, input the **local prompt** for each entity, such as "person," and draw the corresponding mask in the **Entity Mask Painter**. Generally, solid rectangular masks yield better results.
|
|
||||||
3. Click the **Generate** button to create the image. By selecting different **random seeds**, you can generate diverse images.
|
|
||||||
4. **You can directly click the "Load Example" button on any sample at the bottom to load example inputs.**
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
loading_status = gr.Textbox(label="Loading Model...", value="Loading model... Please wait...", visible=True)
|
|
||||||
main_interface = gr.Column(visible=False)
|
|
||||||
|
|
||||||
def initialize_model():
|
|
||||||
try:
|
|
||||||
load_model()
|
|
||||||
return {
|
|
||||||
loading_status: gr.update(value="Model loaded successfully!", visible=False),
|
|
||||||
main_interface: gr.update(visible=True),
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
print(f'Failed to load model with error: {e}')
|
|
||||||
return {
|
|
||||||
loading_status: gr.update(value=f"Failed to load model: {str(e)}", visible=True),
|
|
||||||
main_interface: gr.update(visible=True),
|
|
||||||
}
|
|
||||||
|
|
||||||
app.load(initialize_model, inputs=None, outputs=[loading_status, main_interface])
|
|
||||||
|
|
||||||
with main_interface:
|
|
||||||
with gr.Row():
|
|
||||||
local_prompt_list = []
|
|
||||||
canvas_list = []
|
|
||||||
random_mask_dir = gr.State(f'{random.randint(0, 1000000):08d}')
|
|
||||||
with gr.Column(scale=382, min_width=100):
|
|
||||||
model_type = gr.State('FLUX')
|
|
||||||
model_path = gr.State('FLUX.1-dev')
|
|
||||||
with gr.Accordion(label="Global prompt"):
|
|
||||||
prompt = gr.Textbox(label="Global Prompt", lines=3)
|
|
||||||
negative_prompt = gr.Textbox(label="Negative prompt", value="worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw, blur,", lines=3)
|
|
||||||
with gr.Accordion(label="Inference Options", open=True):
|
|
||||||
seed = gr.Number(minimum=0, maximum=10**9, value=42, interactive=True, label="Random seed", show_label=True)
|
|
||||||
num_inference_steps = gr.Slider(minimum=1, maximum=100, value=30, step=1, interactive=True, label="Inference steps")
|
|
||||||
cfg_scale = gr.Slider(minimum=2.0, maximum=10.0, value=3.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
|
|
||||||
embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=3.5, step=0.1, interactive=True, label="Embedded guidance scale")
|
|
||||||
height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
|
|
||||||
width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
|
|
||||||
with gr.Accordion(label="Inpaint Input Image", open=False):
|
|
||||||
input_image = gr.Image(sources=None, show_label=False, interactive=True, type="pil")
|
|
||||||
background_weight = gr.Slider(minimum=0.0, maximum=1000., value=0., step=1, interactive=False, label="background_weight", visible=False)
|
|
||||||
|
|
||||||
with gr.Column():
|
|
||||||
reset_input_button = gr.Button(value="Reset Inpaint Input")
|
|
||||||
send_input_to_painter = gr.Button(value="Set as painter's background")
|
|
||||||
@gr.on(inputs=[input_image], outputs=[input_image], triggers=reset_input_button.click)
|
|
||||||
def reset_input_image(input_image):
|
|
||||||
return None
|
|
||||||
|
|
||||||
with gr.Column(scale=618, min_width=100):
|
|
||||||
with gr.Accordion(label="Entity Painter"):
|
|
||||||
for painter_layer_id in range(config["max_num_painter_layers"]):
|
|
||||||
with gr.Tab(label=f"Entity {painter_layer_id}"):
|
|
||||||
local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
|
|
||||||
canvas = gr.ImageEditor(
|
|
||||||
canvas_size=(512, 512),
|
|
||||||
sources=None,
|
|
||||||
layers=False,
|
|
||||||
interactive=True,
|
|
||||||
image_mode="RGBA",
|
|
||||||
brush=gr.Brush(
|
|
||||||
default_size=50,
|
|
||||||
default_color="#000000",
|
|
||||||
colors=["#000000"],
|
|
||||||
),
|
|
||||||
label="Entity Mask Painter",
|
|
||||||
key=f"canvas_{painter_layer_id}",
|
|
||||||
width=width,
|
|
||||||
height=height,
|
|
||||||
)
|
|
||||||
@gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear], show_progress="hidden")
|
|
||||||
def resize_canvas(height, width, canvas):
|
|
||||||
h, w = canvas["background"].shape[:2]
|
|
||||||
if h != height or width != w:
|
|
||||||
return np.ones((height, width, 3), dtype=np.uint8) * 255
|
|
||||||
else:
|
|
||||||
return canvas
|
|
||||||
local_prompt_list.append(local_prompt)
|
|
||||||
canvas_list.append(canvas)
|
|
||||||
with gr.Accordion(label="Results"):
|
|
||||||
run_button = gr.Button(value="Generate", variant="primary")
|
|
||||||
output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
|
|
||||||
with gr.Row():
|
|
||||||
with gr.Column():
|
|
||||||
output_to_painter_button = gr.Button(value="Set as painter's background")
|
|
||||||
with gr.Column():
|
|
||||||
return_with_mask = gr.Checkbox(value=False, interactive=True, label="show result with mask painting")
|
|
||||||
output_to_input_button = gr.Button(value="Set as input image", visible=False, interactive=False)
|
|
||||||
real_output = gr.State(None)
|
|
||||||
mask_out = gr.State(None)
|
|
||||||
|
|
||||||
@gr.on(
|
|
||||||
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir] + local_prompt_list + canvas_list,
|
|
||||||
outputs=[output_image, real_output, mask_out],
|
|
||||||
triggers=run_button.click
|
|
||||||
)
|
|
||||||
def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()):
|
|
||||||
_, pipe = load_model(model_type, model_path)
|
|
||||||
input_params = {
|
|
||||||
"prompt": prompt,
|
|
||||||
"negative_prompt": negative_prompt,
|
|
||||||
"cfg_scale": cfg_scale,
|
|
||||||
"num_inference_steps": num_inference_steps,
|
|
||||||
"height": height,
|
|
||||||
"width": width,
|
|
||||||
"progress_bar_cmd": progress.tqdm,
|
|
||||||
}
|
|
||||||
if isinstance(pipe, FluxImagePipeline):
|
|
||||||
input_params["embedded_guidance"] = embedded_guidance
|
|
||||||
if input_image is not None:
|
|
||||||
input_params["input_image"] = input_image.resize((width, height)).convert("RGB")
|
|
||||||
input_params["enable_eligen_inpaint"] = True
|
|
||||||
|
|
||||||
local_prompt_list, canvas_list = (
|
|
||||||
args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
|
|
||||||
args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
|
|
||||||
)
|
|
||||||
local_prompts, masks = [], []
|
|
||||||
for local_prompt, canvas in zip(local_prompt_list, canvas_list):
|
|
||||||
if isinstance(local_prompt, str) and len(local_prompt) > 0:
|
|
||||||
local_prompts.append(local_prompt)
|
|
||||||
masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
|
|
||||||
entity_masks = None if len(masks) == 0 else masks
|
|
||||||
entity_prompts = None if len(local_prompts) == 0 else local_prompts
|
|
||||||
input_params.update({
|
|
||||||
"eligen_entity_prompts": entity_prompts,
|
|
||||||
"eligen_entity_masks": entity_masks,
|
|
||||||
})
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
# save_mask_prompts(masks, local_prompts, prompt, seed, random_mask_dir)
|
|
||||||
image = pipe(**input_params)
|
|
||||||
masks = [mask.resize(image.size) for mask in masks]
|
|
||||||
image_with_mask = visualize_masks(image, masks, local_prompts)
|
|
||||||
|
|
||||||
real_output = gr.State(image)
|
|
||||||
mask_out = gr.State(image_with_mask)
|
|
||||||
|
|
||||||
if return_with_mask:
|
|
||||||
return image_with_mask, real_output, mask_out
|
|
||||||
return image, real_output, mask_out
|
|
||||||
|
|
||||||
@gr.on(inputs=[input_image] + canvas_list, outputs=canvas_list, triggers=send_input_to_painter.click)
|
|
||||||
def send_input_to_painter_background(input_image, *canvas_list):
|
|
||||||
if input_image is None:
|
|
||||||
return tuple(canvas_list)
|
|
||||||
for canvas in canvas_list:
|
|
||||||
h, w = canvas["background"].shape[:2]
|
|
||||||
canvas["background"] = input_image.resize((w, h))
|
|
||||||
return tuple(canvas_list)
|
|
||||||
@gr.on(inputs=[real_output] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
|
|
||||||
def send_output_to_painter_background(real_output, *canvas_list):
|
|
||||||
if real_output is None:
|
|
||||||
return tuple(canvas_list)
|
|
||||||
for canvas in canvas_list:
|
|
||||||
h, w = canvas["background"].shape[:2]
|
|
||||||
canvas["background"] = real_output.value.resize((w, h))
|
|
||||||
return tuple(canvas_list)
|
|
||||||
@gr.on(inputs=[return_with_mask, real_output, mask_out], outputs=[output_image], triggers=[return_with_mask.change], show_progress="hidden")
|
|
||||||
def show_output(return_with_mask, real_output, mask_out):
|
|
||||||
if return_with_mask:
|
|
||||||
return mask_out.value
|
|
||||||
else:
|
|
||||||
return real_output.value
|
|
||||||
@gr.on(inputs=[real_output], outputs=[input_image], triggers=output_to_input_button.click)
|
|
||||||
def send_output_to_pipe_input(real_output):
|
|
||||||
return real_output.value
|
|
||||||
|
|
||||||
with gr.Column():
|
|
||||||
gr.Markdown("## Examples")
|
|
||||||
for i in range(0, len(examples), 2):
|
|
||||||
with gr.Row():
|
|
||||||
if i < len(examples):
|
|
||||||
example = examples[i]
|
|
||||||
with gr.Column():
|
|
||||||
example_image = gr.Image(
|
|
||||||
value=f"data/examples/eligen/entity_control/example_{example['example_id']}/example_image.png",
|
|
||||||
label=example["description"],
|
|
||||||
interactive=False,
|
|
||||||
width=1024,
|
|
||||||
height=512
|
|
||||||
)
|
|
||||||
load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
|
|
||||||
load_example_button.click(
|
|
||||||
load_example,
|
|
||||||
inputs=[load_example_button],
|
|
||||||
outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
|
|
||||||
)
|
|
||||||
|
|
||||||
if i + 1 < len(examples):
|
|
||||||
example = examples[i + 1]
|
|
||||||
with gr.Column():
|
|
||||||
example_image = gr.Image(
|
|
||||||
value=f"data/examples/eligen/entity_control/example_{example['example_id']}/example_image.png",
|
|
||||||
label=example["description"],
|
|
||||||
interactive=False,
|
|
||||||
width=1024,
|
|
||||||
height=512
|
|
||||||
)
|
|
||||||
load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
|
|
||||||
load_example_button.click(
|
|
||||||
load_example,
|
|
||||||
inputs=[load_example_button],
|
|
||||||
outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
|
|
||||||
)
|
|
||||||
app.config["show_progress"] = "hidden"
|
|
||||||
app.launch()
|
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
# Set web page format
|
|
||||||
import streamlit as st
|
|
||||||
st.set_page_config(layout="wide")
|
|
||||||
# Diasble virtual VRAM on windows system
|
|
||||||
import torch
|
|
||||||
torch.cuda.set_per_process_memory_fraction(0.999, 0)
|
|
||||||
|
|
||||||
|
|
||||||
st.markdown("""
|
|
||||||
# DiffSynth Studio
|
|
||||||
|
|
||||||
[Source Code](https://github.com/Artiprocher/DiffSynth-Studio)
|
|
||||||
|
|
||||||
Welcome to DiffSynth Studio.
|
|
||||||
""")
|
|
||||||
@@ -1,362 +0,0 @@
|
|||||||
import torch, os, io, json, time
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
import streamlit as st
|
|
||||||
st.set_page_config(layout="wide")
|
|
||||||
from streamlit_drawable_canvas import st_canvas
|
|
||||||
from diffsynth.models import ModelManager
|
|
||||||
from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline, FluxImagePipeline
|
|
||||||
from diffsynth.data.video import crop_and_resize
|
|
||||||
|
|
||||||
|
|
||||||
config = {
|
|
||||||
"Stable Diffusion": {
|
|
||||||
"model_folder": "models/stable_diffusion",
|
|
||||||
"pipeline_class": SDImagePipeline,
|
|
||||||
"fixed_parameters": {}
|
|
||||||
},
|
|
||||||
"Stable Diffusion XL": {
|
|
||||||
"model_folder": "models/stable_diffusion_xl",
|
|
||||||
"pipeline_class": SDXLImagePipeline,
|
|
||||||
"fixed_parameters": {}
|
|
||||||
},
|
|
||||||
"Stable Diffusion 3": {
|
|
||||||
"model_folder": "models/stable_diffusion_3",
|
|
||||||
"pipeline_class": SD3ImagePipeline,
|
|
||||||
"fixed_parameters": {}
|
|
||||||
},
|
|
||||||
"Stable Diffusion XL Turbo": {
|
|
||||||
"model_folder": "models/stable_diffusion_xl_turbo",
|
|
||||||
"pipeline_class": SDXLImagePipeline,
|
|
||||||
"fixed_parameters": {
|
|
||||||
"negative_prompt": "",
|
|
||||||
"cfg_scale": 1.0,
|
|
||||||
"num_inference_steps": 1,
|
|
||||||
"height": 512,
|
|
||||||
"width": 512,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"Kolors": {
|
|
||||||
"model_folder": "models/kolors",
|
|
||||||
"pipeline_class": SDXLImagePipeline,
|
|
||||||
"fixed_parameters": {}
|
|
||||||
},
|
|
||||||
"HunyuanDiT": {
|
|
||||||
"model_folder": "models/HunyuanDiT",
|
|
||||||
"pipeline_class": HunyuanDiTImagePipeline,
|
|
||||||
"fixed_parameters": {
|
|
||||||
"height": 1024,
|
|
||||||
"width": 1024,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"FLUX": {
|
|
||||||
"model_folder": "models/FLUX",
|
|
||||||
"pipeline_class": FluxImagePipeline,
|
|
||||||
"fixed_parameters": {
|
|
||||||
"cfg_scale": 1.0,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_list(model_type):
|
|
||||||
folder = config[model_type]["model_folder"]
|
|
||||||
file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
|
|
||||||
if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
|
|
||||||
file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
|
|
||||||
file_list = sorted(file_list)
|
|
||||||
return file_list
|
|
||||||
|
|
||||||
|
|
||||||
def release_model():
|
|
||||||
if "model_manager" in st.session_state:
|
|
||||||
st.session_state["model_manager"].to("cpu")
|
|
||||||
del st.session_state["loaded_model_path"]
|
|
||||||
del st.session_state["model_manager"]
|
|
||||||
del st.session_state["pipeline"]
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_type, model_path):
|
|
||||||
model_manager = ModelManager()
|
|
||||||
if model_type == "HunyuanDiT":
|
|
||||||
model_manager.load_models([
|
|
||||||
os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
|
|
||||||
os.path.join(model_path, "mt5/pytorch_model.bin"),
|
|
||||||
os.path.join(model_path, "model/pytorch_model_ema.pt"),
|
|
||||||
os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
|
|
||||||
])
|
|
||||||
elif model_type == "Kolors":
|
|
||||||
model_manager.load_models([
|
|
||||||
os.path.join(model_path, "text_encoder"),
|
|
||||||
os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
|
|
||||||
os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
])
|
|
||||||
elif model_type == "FLUX":
|
|
||||||
model_manager.torch_dtype = torch.bfloat16
|
|
||||||
file_list = [
|
|
||||||
os.path.join(model_path, "text_encoder/model.safetensors"),
|
|
||||||
os.path.join(model_path, "text_encoder_2"),
|
|
||||||
]
|
|
||||||
for file_name in os.listdir(model_path):
|
|
||||||
if file_name.endswith(".safetensors"):
|
|
||||||
file_list.append(os.path.join(model_path, file_name))
|
|
||||||
model_manager.load_models(file_list)
|
|
||||||
else:
|
|
||||||
model_manager.load_model(model_path)
|
|
||||||
pipeline = config[model_type]["pipeline_class"].from_model_manager(model_manager)
|
|
||||||
st.session_state.loaded_model_path = model_path
|
|
||||||
st.session_state.model_manager = model_manager
|
|
||||||
st.session_state.pipeline = pipeline
|
|
||||||
return model_manager, pipeline
|
|
||||||
|
|
||||||
|
|
||||||
def use_output_image_as_input(update=True):
|
|
||||||
# Search for input image
|
|
||||||
output_image_id = 0
|
|
||||||
selected_output_image = None
|
|
||||||
while True:
|
|
||||||
if f"use_output_as_input_{output_image_id}" not in st.session_state:
|
|
||||||
break
|
|
||||||
if st.session_state[f"use_output_as_input_{output_image_id}"]:
|
|
||||||
selected_output_image = st.session_state["output_images"][output_image_id]
|
|
||||||
break
|
|
||||||
output_image_id += 1
|
|
||||||
if update and selected_output_image is not None:
|
|
||||||
st.session_state["input_image"] = selected_output_image
|
|
||||||
return selected_output_image is not None
|
|
||||||
|
|
||||||
|
|
||||||
def apply_stroke_to_image(stroke_image, image):
|
|
||||||
image = np.array(image.convert("RGB")).astype(np.float32)
|
|
||||||
height, width, _ = image.shape
|
|
||||||
|
|
||||||
stroke_image = np.array(Image.fromarray(stroke_image).resize((width, height))).astype(np.float32)
|
|
||||||
weight = stroke_image[:, :, -1:] / 255
|
|
||||||
stroke_image = stroke_image[:, :, :-1]
|
|
||||||
|
|
||||||
image = stroke_image * weight + image * (1 - weight)
|
|
||||||
image = np.clip(image, 0, 255).astype(np.uint8)
|
|
||||||
image = Image.fromarray(image)
|
|
||||||
return image
|
|
||||||
|
|
||||||
|
|
||||||
@st.cache_data
|
|
||||||
def image2bits(image):
|
|
||||||
image_byte = io.BytesIO()
|
|
||||||
image.save(image_byte, format="PNG")
|
|
||||||
image_byte = image_byte.getvalue()
|
|
||||||
return image_byte
|
|
||||||
|
|
||||||
|
|
||||||
def show_output_image(image):
|
|
||||||
st.image(image, use_column_width="always")
|
|
||||||
st.button("Use it as input image", key=f"use_output_as_input_{image_id}")
|
|
||||||
st.download_button("Download", data=image2bits(image), file_name="image.png", mime="image/png", key=f"download_output_{image_id}")
|
|
||||||
|
|
||||||
|
|
||||||
column_input, column_output = st.columns(2)
|
|
||||||
with st.sidebar:
|
|
||||||
# Select a model
|
|
||||||
with st.expander("Model", expanded=True):
|
|
||||||
model_type = st.selectbox("Model type", [model_type_ for model_type_ in config])
|
|
||||||
fixed_parameters = config[model_type]["fixed_parameters"]
|
|
||||||
model_path_list = ["None"] + load_model_list(model_type)
|
|
||||||
model_path = st.selectbox("Model path", model_path_list)
|
|
||||||
|
|
||||||
# Load the model
|
|
||||||
if model_path == "None":
|
|
||||||
# No models are selected. Release VRAM.
|
|
||||||
st.markdown("No models are selected.")
|
|
||||||
release_model()
|
|
||||||
else:
|
|
||||||
# A model is selected.
|
|
||||||
model_path = os.path.join(config[model_type]["model_folder"], model_path)
|
|
||||||
if st.session_state.get("loaded_model_path", "") != model_path:
|
|
||||||
# The loaded model is not the selected model. Reload it.
|
|
||||||
st.markdown(f"Loading model at {model_path}.")
|
|
||||||
st.markdown("Please wait a moment...")
|
|
||||||
release_model()
|
|
||||||
model_manager, pipeline = load_model(model_type, model_path)
|
|
||||||
st.markdown("Done.")
|
|
||||||
else:
|
|
||||||
# The loaded model is not the selected model. Fetch it from `st.session_state`.
|
|
||||||
st.markdown(f"Loading model at {model_path}.")
|
|
||||||
st.markdown("Please wait a moment...")
|
|
||||||
model_manager, pipeline = st.session_state.model_manager, st.session_state.pipeline
|
|
||||||
st.markdown("Done.")
|
|
||||||
|
|
||||||
# Show parameters
|
|
||||||
with st.expander("Prompt", expanded=True):
|
|
||||||
prompt = st.text_area("Positive prompt")
|
|
||||||
if "negative_prompt" in fixed_parameters:
|
|
||||||
negative_prompt = fixed_parameters["negative_prompt"]
|
|
||||||
else:
|
|
||||||
negative_prompt = st.text_area("Negative prompt")
|
|
||||||
if "cfg_scale" in fixed_parameters:
|
|
||||||
cfg_scale = fixed_parameters["cfg_scale"]
|
|
||||||
else:
|
|
||||||
cfg_scale = st.slider("Classifier-free guidance scale", min_value=1.0, max_value=10.0, value=7.5)
|
|
||||||
with st.expander("Image", expanded=True):
|
|
||||||
if "num_inference_steps" in fixed_parameters:
|
|
||||||
num_inference_steps = fixed_parameters["num_inference_steps"]
|
|
||||||
else:
|
|
||||||
num_inference_steps = st.slider("Inference steps", min_value=1, max_value=100, value=20)
|
|
||||||
if "height" in fixed_parameters:
|
|
||||||
height = fixed_parameters["height"]
|
|
||||||
else:
|
|
||||||
height = st.select_slider("Height", options=[256, 512, 768, 1024, 2048], value=512)
|
|
||||||
if "width" in fixed_parameters:
|
|
||||||
width = fixed_parameters["width"]
|
|
||||||
else:
|
|
||||||
width = st.select_slider("Width", options=[256, 512, 768, 1024, 2048], value=512)
|
|
||||||
num_images = st.number_input("Number of images", value=2)
|
|
||||||
use_fixed_seed = st.checkbox("Use fixed seed", value=False)
|
|
||||||
if use_fixed_seed:
|
|
||||||
seed = st.number_input("Random seed", min_value=0, max_value=10**9, step=1, value=0)
|
|
||||||
|
|
||||||
# Other fixed parameters
|
|
||||||
denoising_strength = 1.0
|
|
||||||
repetition = 1
|
|
||||||
|
|
||||||
|
|
||||||
# Show input image
|
|
||||||
with column_input:
|
|
||||||
with st.expander("Input image (Optional)", expanded=True):
|
|
||||||
with st.container(border=True):
|
|
||||||
column_white_board, column_upload_image = st.columns([1, 2])
|
|
||||||
with column_white_board:
|
|
||||||
create_white_board = st.button("Create white board")
|
|
||||||
delete_input_image = st.button("Delete input image")
|
|
||||||
with column_upload_image:
|
|
||||||
upload_image = st.file_uploader("Upload image", type=["png", "jpg"], key="upload_image")
|
|
||||||
|
|
||||||
if upload_image is not None:
|
|
||||||
st.session_state["input_image"] = crop_and_resize(Image.open(upload_image), height, width)
|
|
||||||
elif create_white_board:
|
|
||||||
st.session_state["input_image"] = Image.fromarray(np.ones((height, width, 3), dtype=np.uint8) * 255)
|
|
||||||
else:
|
|
||||||
use_output_image_as_input()
|
|
||||||
|
|
||||||
if delete_input_image and "input_image" in st.session_state:
|
|
||||||
del st.session_state.input_image
|
|
||||||
if delete_input_image and "upload_image" in st.session_state:
|
|
||||||
del st.session_state.upload_image
|
|
||||||
|
|
||||||
input_image = st.session_state.get("input_image", None)
|
|
||||||
if input_image is not None:
|
|
||||||
with st.container(border=True):
|
|
||||||
column_drawing_mode, column_color_1, column_color_2 = st.columns([4, 1, 1])
|
|
||||||
with column_drawing_mode:
|
|
||||||
drawing_mode = st.radio("Drawing tool", ["transform", "freedraw", "line", "rect"], horizontal=True, index=1)
|
|
||||||
with column_color_1:
|
|
||||||
stroke_color = st.color_picker("Stroke color")
|
|
||||||
with column_color_2:
|
|
||||||
fill_color = st.color_picker("Fill color")
|
|
||||||
stroke_width = st.slider("Stroke width", min_value=1, max_value=50, value=10)
|
|
||||||
with st.container(border=True):
|
|
||||||
denoising_strength = st.slider("Denoising strength", min_value=0.0, max_value=1.0, value=0.7)
|
|
||||||
repetition = st.slider("Repetition", min_value=1, max_value=8, value=1)
|
|
||||||
with st.container(border=True):
|
|
||||||
input_width, input_height = input_image.size
|
|
||||||
canvas_result = st_canvas(
|
|
||||||
fill_color=fill_color,
|
|
||||||
stroke_width=stroke_width,
|
|
||||||
stroke_color=stroke_color,
|
|
||||||
background_color="rgba(255, 255, 255, 0)",
|
|
||||||
background_image=input_image,
|
|
||||||
update_streamlit=True,
|
|
||||||
height=int(512 / input_width * input_height),
|
|
||||||
width=512,
|
|
||||||
drawing_mode=drawing_mode,
|
|
||||||
key="canvas"
|
|
||||||
)
|
|
||||||
|
|
||||||
num_painter_layer = st.number_input("Number of painter layers", min_value=0, max_value=10, step=1, value=0)
|
|
||||||
local_prompts, masks, mask_scales = [], [], []
|
|
||||||
white_board = Image.fromarray(np.ones((512, 512, 3), dtype=np.uint8) * 255)
|
|
||||||
painter_layers_json_data = []
|
|
||||||
for painter_tab_id in range(num_painter_layer):
|
|
||||||
with st.expander(f"Painter layer {painter_tab_id}", expanded=True):
|
|
||||||
enable_local_prompt = st.checkbox(f"Enable prompt {painter_tab_id}", value=True)
|
|
||||||
local_prompt = st.text_area(f"Prompt {painter_tab_id}")
|
|
||||||
mask_scale = st.slider(f"Mask scale {painter_tab_id}", min_value=0.0, max_value=3.0, value=1.0)
|
|
||||||
stroke_width = st.slider(f"Stroke width {painter_tab_id}", min_value=1, max_value=300, value=100)
|
|
||||||
canvas_result_local = st_canvas(
|
|
||||||
fill_color="#000000",
|
|
||||||
stroke_width=stroke_width,
|
|
||||||
stroke_color="#000000",
|
|
||||||
background_color="rgba(255, 255, 255, 0)",
|
|
||||||
background_image=white_board,
|
|
||||||
update_streamlit=True,
|
|
||||||
height=512,
|
|
||||||
width=512,
|
|
||||||
drawing_mode="freedraw",
|
|
||||||
key=f"canvas_{painter_tab_id}"
|
|
||||||
)
|
|
||||||
if canvas_result_local.json_data is not None:
|
|
||||||
painter_layers_json_data.append(canvas_result_local.json_data.copy())
|
|
||||||
painter_layers_json_data[-1]["prompt"] = local_prompt
|
|
||||||
if enable_local_prompt:
|
|
||||||
local_prompts.append(local_prompt)
|
|
||||||
if canvas_result_local.image_data is not None:
|
|
||||||
mask = apply_stroke_to_image(canvas_result_local.image_data, white_board)
|
|
||||||
else:
|
|
||||||
mask = white_board
|
|
||||||
mask = Image.fromarray(255 - np.array(mask))
|
|
||||||
masks.append(mask)
|
|
||||||
mask_scales.append(mask_scale)
|
|
||||||
save_painter_layers = st.button("Save painter layers")
|
|
||||||
if save_painter_layers:
|
|
||||||
os.makedirs("data/painter_layers", exist_ok=True)
|
|
||||||
json_file_path = f"data/painter_layers/{time.time_ns()}.json"
|
|
||||||
with open(json_file_path, "w") as f:
|
|
||||||
json.dump(painter_layers_json_data, f, indent=4)
|
|
||||||
st.markdown(f"Painter layers are saved in {json_file_path}.")
|
|
||||||
|
|
||||||
|
|
||||||
with column_output:
|
|
||||||
run_button = st.button("Generate image", type="primary")
|
|
||||||
auto_update = st.checkbox("Auto update", value=False)
|
|
||||||
num_image_columns = st.slider("Columns", min_value=1, max_value=8, value=2)
|
|
||||||
image_columns = st.columns(num_image_columns)
|
|
||||||
|
|
||||||
# Run
|
|
||||||
if (run_button or auto_update) and model_path != "None":
|
|
||||||
|
|
||||||
if input_image is not None:
|
|
||||||
input_image = input_image.resize((width, height))
|
|
||||||
if canvas_result.image_data is not None:
|
|
||||||
input_image = apply_stroke_to_image(canvas_result.image_data, input_image)
|
|
||||||
|
|
||||||
output_images = []
|
|
||||||
for image_id in range(num_images * repetition):
|
|
||||||
if use_fixed_seed:
|
|
||||||
torch.manual_seed(seed + image_id)
|
|
||||||
else:
|
|
||||||
torch.manual_seed(np.random.randint(0, 10**9))
|
|
||||||
if image_id >= num_images:
|
|
||||||
input_image = output_images[image_id - num_images]
|
|
||||||
with image_columns[image_id % num_image_columns]:
|
|
||||||
progress_bar_st = st.progress(0.0)
|
|
||||||
image = pipeline(
|
|
||||||
prompt, negative_prompt=negative_prompt,
|
|
||||||
local_prompts=local_prompts, masks=masks, mask_scales=mask_scales,
|
|
||||||
cfg_scale=cfg_scale, num_inference_steps=num_inference_steps,
|
|
||||||
height=height, width=width,
|
|
||||||
input_image=input_image, denoising_strength=denoising_strength,
|
|
||||||
progress_bar_st=progress_bar_st
|
|
||||||
)
|
|
||||||
output_images.append(image)
|
|
||||||
progress_bar_st.progress(1.0)
|
|
||||||
show_output_image(image)
|
|
||||||
st.session_state["output_images"] = output_images
|
|
||||||
|
|
||||||
elif "output_images" in st.session_state:
|
|
||||||
for image_id in range(len(st.session_state.output_images)):
|
|
||||||
with image_columns[image_id % num_image_columns]:
|
|
||||||
image = st.session_state.output_images[image_id]
|
|
||||||
progress_bar = st.progress(1.0)
|
|
||||||
show_output_image(image)
|
|
||||||
if "upload_image" in st.session_state and use_output_image_as_input(update=False):
|
|
||||||
st.markdown("If you want to use an output image as input image, please delete the uploaded image manually.")
|
|
||||||
@@ -1,197 +0,0 @@
|
|||||||
import streamlit as st
|
|
||||||
st.set_page_config(layout="wide")
|
|
||||||
from diffsynth import SDVideoPipelineRunner
|
|
||||||
import os
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_list(folder):
|
|
||||||
file_list = os.listdir(folder)
|
|
||||||
file_list = [i for i in file_list if i.endswith(".safetensors") or i.endswith(".pth") or i.endswith(".ckpt")]
|
|
||||||
file_list = sorted(file_list)
|
|
||||||
return file_list
|
|
||||||
|
|
||||||
|
|
||||||
def match_processor_id(model_name, supported_processor_id_list):
|
|
||||||
sorted_processor_id = [i[1] for i in sorted([(-len(i), i) for i in supported_processor_id_list])]
|
|
||||||
for processor_id in sorted_processor_id:
|
|
||||||
if processor_id in model_name:
|
|
||||||
return supported_processor_id_list.index(processor_id) + 1
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
config = {
|
|
||||||
"models": {
|
|
||||||
"model_list": [],
|
|
||||||
"textual_inversion_folder": "models/textual_inversion",
|
|
||||||
"device": "cuda",
|
|
||||||
"lora_alphas": [],
|
|
||||||
"controlnet_units": []
|
|
||||||
},
|
|
||||||
"data": {
|
|
||||||
"input_frames": None,
|
|
||||||
"controlnet_frames": [],
|
|
||||||
"output_folder": "output",
|
|
||||||
"fps": 60
|
|
||||||
},
|
|
||||||
"pipeline": {
|
|
||||||
"seed": 0,
|
|
||||||
"pipeline_inputs": {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
with st.expander("Model", expanded=True):
|
|
||||||
stable_diffusion_ckpt = st.selectbox("Stable Diffusion", ["None"] + load_model_list("models/stable_diffusion"))
|
|
||||||
if stable_diffusion_ckpt != "None":
|
|
||||||
config["models"]["model_list"].append(os.path.join("models/stable_diffusion", stable_diffusion_ckpt))
|
|
||||||
animatediff_ckpt = st.selectbox("AnimateDiff", ["None"] + load_model_list("models/AnimateDiff"))
|
|
||||||
if animatediff_ckpt != "None":
|
|
||||||
config["models"]["model_list"].append(os.path.join("models/AnimateDiff", animatediff_ckpt))
|
|
||||||
column_lora, column_lora_alpha = st.columns([2, 1])
|
|
||||||
with column_lora:
|
|
||||||
sd_lora_ckpt = st.selectbox("LoRA", ["None"] + load_model_list("models/lora"))
|
|
||||||
with column_lora_alpha:
|
|
||||||
lora_alpha = st.slider("LoRA Alpha", min_value=-4.0, max_value=4.0, value=1.0, step=0.1)
|
|
||||||
if sd_lora_ckpt != "None":
|
|
||||||
config["models"]["model_list"].append(os.path.join("models/lora", sd_lora_ckpt))
|
|
||||||
config["models"]["lora_alphas"].append(lora_alpha)
|
|
||||||
|
|
||||||
|
|
||||||
with st.expander("Data", expanded=True):
|
|
||||||
with st.container(border=True):
|
|
||||||
input_video = st.text_input("Input Video File Path (e.g., data/your_video.mp4)", value="")
|
|
||||||
column_height, column_width, column_start_frame_index, column_end_frame_index = st.columns([2, 2, 1, 1])
|
|
||||||
with column_height:
|
|
||||||
height = st.select_slider("Height", options=[256, 512, 768, 1024, 1536, 2048], value=1024)
|
|
||||||
with column_width:
|
|
||||||
width = st.select_slider("Width", options=[256, 512, 768, 1024, 1536, 2048], value=1024)
|
|
||||||
with column_start_frame_index:
|
|
||||||
start_frame_id = st.number_input("Start Frame id", value=0)
|
|
||||||
with column_end_frame_index:
|
|
||||||
end_frame_id = st.number_input("End Frame id", value=16)
|
|
||||||
if input_video != "":
|
|
||||||
config["data"]["input_frames"] = {
|
|
||||||
"video_file": input_video,
|
|
||||||
"image_folder": None,
|
|
||||||
"height": height,
|
|
||||||
"width": width,
|
|
||||||
"start_frame_id": start_frame_id,
|
|
||||||
"end_frame_id": end_frame_id
|
|
||||||
}
|
|
||||||
with st.container(border=True):
|
|
||||||
output_video = st.text_input("Output Video File Path (e.g., data/a_folder_to_save_something)", value="output")
|
|
||||||
fps = st.number_input("FPS", value=60)
|
|
||||||
config["data"]["output_folder"] = output_video
|
|
||||||
config["data"]["fps"] = fps
|
|
||||||
|
|
||||||
|
|
||||||
with st.expander("ControlNet Units", expanded=True):
|
|
||||||
supported_processor_id_list = ["canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"]
|
|
||||||
controlnet_units = st.tabs(["ControlNet Unit 0", "ControlNet Unit 1", "ControlNet Unit 2"])
|
|
||||||
for controlnet_id in range(len(controlnet_units)):
|
|
||||||
with controlnet_units[controlnet_id]:
|
|
||||||
controlnet_ckpt = st.selectbox("ControlNet", ["None"] + load_model_list("models/ControlNet"),
|
|
||||||
key=f"controlnet_ckpt_{controlnet_id}")
|
|
||||||
processor_id = st.selectbox("Processor", ["None"] + supported_processor_id_list,
|
|
||||||
index=match_processor_id(controlnet_ckpt, supported_processor_id_list),
|
|
||||||
disabled=controlnet_ckpt == "None", key=f"processor_id_{controlnet_id}")
|
|
||||||
controlnet_scale = st.slider("Scale", min_value=0.0, max_value=1.0, step=0.01, value=0.5,
|
|
||||||
disabled=controlnet_ckpt == "None", key=f"controlnet_scale_{controlnet_id}")
|
|
||||||
use_input_video_as_controlnet_input = st.checkbox("Use input video as ControlNet input", value=True,
|
|
||||||
disabled=controlnet_ckpt == "None",
|
|
||||||
key=f"use_input_video_as_controlnet_input_{controlnet_id}")
|
|
||||||
if not use_input_video_as_controlnet_input:
|
|
||||||
controlnet_input_video = st.text_input("ControlNet Input Video File Path", value="",
|
|
||||||
disabled=controlnet_ckpt == "None", key=f"controlnet_input_video_{controlnet_id}")
|
|
||||||
column_height, column_width, column_start_frame_index, column_end_frame_index = st.columns([2, 2, 1, 1])
|
|
||||||
with column_height:
|
|
||||||
height = st.select_slider("Height", options=[256, 512, 768, 1024, 1536, 2048], value=1024,
|
|
||||||
disabled=controlnet_ckpt == "None", key=f"controlnet_height_{controlnet_id}")
|
|
||||||
with column_width:
|
|
||||||
width = st.select_slider("Width", options=[256, 512, 768, 1024, 1536, 2048], value=1024,
|
|
||||||
disabled=controlnet_ckpt == "None", key=f"controlnet_width_{controlnet_id}")
|
|
||||||
with column_start_frame_index:
|
|
||||||
start_frame_id = st.number_input("Start Frame id", value=0,
|
|
||||||
disabled=controlnet_ckpt == "None", key=f"controlnet_start_frame_id_{controlnet_id}")
|
|
||||||
with column_end_frame_index:
|
|
||||||
end_frame_id = st.number_input("End Frame id", value=16,
|
|
||||||
disabled=controlnet_ckpt == "None", key=f"controlnet_end_frame_id_{controlnet_id}")
|
|
||||||
if input_video != "":
|
|
||||||
config["data"]["input_video"] = {
|
|
||||||
"video_file": input_video,
|
|
||||||
"image_folder": None,
|
|
||||||
"height": height,
|
|
||||||
"width": width,
|
|
||||||
"start_frame_id": start_frame_id,
|
|
||||||
"end_frame_id": end_frame_id
|
|
||||||
}
|
|
||||||
if controlnet_ckpt != "None":
|
|
||||||
config["models"]["model_list"].append(os.path.join("models/ControlNet", controlnet_ckpt))
|
|
||||||
config["models"]["controlnet_units"].append({
|
|
||||||
"processor_id": processor_id,
|
|
||||||
"model_path": os.path.join("models/ControlNet", controlnet_ckpt),
|
|
||||||
"scale": controlnet_scale,
|
|
||||||
})
|
|
||||||
if use_input_video_as_controlnet_input:
|
|
||||||
config["data"]["controlnet_frames"].append(config["data"]["input_frames"])
|
|
||||||
else:
|
|
||||||
config["data"]["controlnet_frames"].append({
|
|
||||||
"video_file": input_video,
|
|
||||||
"image_folder": None,
|
|
||||||
"height": height,
|
|
||||||
"width": width,
|
|
||||||
"start_frame_id": start_frame_id,
|
|
||||||
"end_frame_id": end_frame_id
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
with st.container(border=True):
|
|
||||||
with st.expander("Seed", expanded=True):
|
|
||||||
use_fixed_seed = st.checkbox("Use fixed seed", value=False)
|
|
||||||
if use_fixed_seed:
|
|
||||||
seed = st.number_input("Random seed", min_value=0, max_value=10**9, step=1, value=0)
|
|
||||||
else:
|
|
||||||
seed = np.random.randint(0, 10**9)
|
|
||||||
with st.expander("Textual Guidance", expanded=True):
|
|
||||||
prompt = st.text_area("Positive prompt")
|
|
||||||
negative_prompt = st.text_area("Negative prompt")
|
|
||||||
column_cfg_scale, column_clip_skip = st.columns(2)
|
|
||||||
with column_cfg_scale:
|
|
||||||
cfg_scale = st.slider("Classifier-free guidance scale", min_value=1.0, max_value=10.0, value=7.0)
|
|
||||||
with column_clip_skip:
|
|
||||||
clip_skip = st.slider("Clip Skip", min_value=1, max_value=4, value=1)
|
|
||||||
with st.expander("Denoising", expanded=True):
|
|
||||||
column_num_inference_steps, column_denoising_strength = st.columns(2)
|
|
||||||
with column_num_inference_steps:
|
|
||||||
num_inference_steps = st.slider("Inference steps", min_value=1, max_value=100, value=10)
|
|
||||||
with column_denoising_strength:
|
|
||||||
denoising_strength = st.slider("Denoising strength", min_value=0.0, max_value=1.0, value=1.0)
|
|
||||||
with st.expander("Efficiency", expanded=False):
|
|
||||||
animatediff_batch_size = st.slider("Animatediff batch size (sliding window size)", min_value=1, max_value=32, value=16, step=1)
|
|
||||||
animatediff_stride = st.slider("Animatediff stride",
|
|
||||||
min_value=1,
|
|
||||||
max_value=max(2, animatediff_batch_size),
|
|
||||||
value=max(1, animatediff_batch_size // 2),
|
|
||||||
step=1)
|
|
||||||
unet_batch_size = st.slider("UNet batch size", min_value=1, max_value=32, value=1, step=1)
|
|
||||||
controlnet_batch_size = st.slider("ControlNet batch size", min_value=1, max_value=32, value=1, step=1)
|
|
||||||
cross_frame_attention = st.checkbox("Enable Cross-Frame Attention", value=False)
|
|
||||||
config["pipeline"]["seed"] = seed
|
|
||||||
config["pipeline"]["pipeline_inputs"] = {
|
|
||||||
"prompt": prompt,
|
|
||||||
"negative_prompt": negative_prompt,
|
|
||||||
"cfg_scale": cfg_scale,
|
|
||||||
"clip_skip": clip_skip,
|
|
||||||
"denoising_strength": denoising_strength,
|
|
||||||
"num_inference_steps": num_inference_steps,
|
|
||||||
"animatediff_batch_size": animatediff_batch_size,
|
|
||||||
"animatediff_stride": animatediff_stride,
|
|
||||||
"unet_batch_size": unet_batch_size,
|
|
||||||
"controlnet_batch_size": controlnet_batch_size,
|
|
||||||
"cross_frame_attention": cross_frame_attention,
|
|
||||||
}
|
|
||||||
|
|
||||||
run_button = st.button("☢️Run☢️", type="primary")
|
|
||||||
if run_button:
|
|
||||||
SDVideoPipelineRunner(in_streamlit=True).run(config)
|
|
||||||
@@ -1,6 +1 @@
|
|||||||
from .data import *
|
from .core import *
|
||||||
from .models import *
|
|
||||||
from .prompters import *
|
|
||||||
from .schedulers import *
|
|
||||||
from .pipelines import *
|
|
||||||
from .controlnets import *
|
|
||||||
|
|||||||
@@ -0,0 +1,2 @@
|
|||||||
|
from .model_configs import MODEL_CONFIGS
|
||||||
|
from .vram_management_module_maps import VRAM_MANAGEMENT_MODULE_MAPS
|
||||||
|
|||||||
@@ -1,736 +0,0 @@
|
|||||||
from typing_extensions import Literal, TypeAlias
|
|
||||||
|
|
||||||
from ..models.sd_text_encoder import SDTextEncoder
|
|
||||||
from ..models.sd_unet import SDUNet
|
|
||||||
from ..models.sd_vae_encoder import SDVAEEncoder
|
|
||||||
from ..models.sd_vae_decoder import SDVAEDecoder
|
|
||||||
|
|
||||||
from ..models.sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
|
||||||
from ..models.sdxl_unet import SDXLUNet
|
|
||||||
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
|
|
||||||
from ..models.sdxl_vae_encoder import SDXLVAEEncoder
|
|
||||||
|
|
||||||
from ..models.sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
|
|
||||||
from ..models.sd3_dit import SD3DiT
|
|
||||||
from ..models.sd3_vae_decoder import SD3VAEDecoder
|
|
||||||
from ..models.sd3_vae_encoder import SD3VAEEncoder
|
|
||||||
|
|
||||||
from ..models.sd_controlnet import SDControlNet
|
|
||||||
from ..models.sdxl_controlnet import SDXLControlNetUnion
|
|
||||||
|
|
||||||
from ..models.sd_motion import SDMotionModel
|
|
||||||
from ..models.sdxl_motion import SDXLMotionModel
|
|
||||||
|
|
||||||
from ..models.svd_image_encoder import SVDImageEncoder
|
|
||||||
from ..models.svd_unet import SVDUNet
|
|
||||||
from ..models.svd_vae_decoder import SVDVAEDecoder
|
|
||||||
from ..models.svd_vae_encoder import SVDVAEEncoder
|
|
||||||
|
|
||||||
from ..models.sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
|
|
||||||
from ..models.sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
|
||||||
|
|
||||||
from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
|
||||||
from ..models.hunyuan_dit import HunyuanDiT
|
|
||||||
|
|
||||||
from ..models.flux_dit import FluxDiT
|
|
||||||
from ..models.flux_text_encoder import FluxTextEncoder2
|
|
||||||
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
|
||||||
from ..models.flux_controlnet import FluxControlNet
|
|
||||||
from ..models.flux_ipadapter import FluxIpAdapter
|
|
||||||
|
|
||||||
from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
|
|
||||||
from ..models.cog_dit import CogDiT
|
|
||||||
|
|
||||||
from ..models.omnigen import OmniGenTransformer
|
|
||||||
|
|
||||||
from ..models.hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
|
|
||||||
from ..models.hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder
|
|
||||||
|
|
||||||
from ..extensions.RIFE import IFNet
|
|
||||||
from ..extensions.ESRGAN import RRDBNet
|
|
||||||
|
|
||||||
from ..models.hunyuan_video_dit import HunyuanVideoDiT
|
|
||||||
|
|
||||||
|
|
||||||
model_loader_configs = [
|
|
||||||
# These configs are provided for detecting model type automatically.
|
|
||||||
# The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
|
|
||||||
(None, "091b0e30e77c76626b3ba62acdf95343", ["sd_controlnet"], [SDControlNet], "civitai"),
|
|
||||||
(None, "4a6c8306a27d916dea81263c8c88f450", ["hunyuan_dit_clip_text_encoder"], [HunyuanDiTCLIPTextEncoder], "civitai"),
|
|
||||||
(None, "f4aec400fe394297961218c768004521", ["hunyuan_dit"], [HunyuanDiT], "civitai"),
|
|
||||||
(None, "9e6e58043a5a2e332803ed42f6ee7181", ["hunyuan_dit_t5_text_encoder"], [HunyuanDiTT5TextEncoder], "civitai"),
|
|
||||||
(None, "13115dd45a6e1c39860f91ab073b8a78", ["sdxl_vae_encoder", "sdxl_vae_decoder"], [SDXLVAEEncoder, SDXLVAEDecoder], "diffusers"),
|
|
||||||
(None, "d78aa6797382a6d455362358a3295ea9", ["sd_ipadapter_clip_image_encoder"], [IpAdapterCLIPImageEmbedder], "diffusers"),
|
|
||||||
(None, "e291636cc15e803186b47404262ef812", ["sd_ipadapter"], [SDIpAdapter], "civitai"),
|
|
||||||
(None, "399c81f2f8de8d1843d0127a00f3c224", ["sdxl_ipadapter_clip_image_encoder"], [IpAdapterXLCLIPImageEmbedder], "diffusers"),
|
|
||||||
(None, "a64eac9aa0db4b9602213bc0131281c7", ["sdxl_ipadapter"], [SDXLIpAdapter], "civitai"),
|
|
||||||
(None, "52817e4fdd89df154f02749ca6f692ac", ["sdxl_unet"], [SDXLUNet], "diffusers"),
|
|
||||||
(None, "03343c606f16d834d6411d0902b53636", ["sd_text_encoder", "sd_unet", "sd_vae_decoder", "sd_vae_encoder"], [SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder], "civitai"),
|
|
||||||
(None, "d4ba77a7ece070679b4a987f58f201e9", ["sd_text_encoder"], [SDTextEncoder], "civitai"),
|
|
||||||
(None, "d0c89e55c5a57cf3981def0cb1c9e65a", ["sd_vae_decoder", "sd_vae_encoder"], [SDVAEDecoder, SDVAEEncoder], "civitai"),
|
|
||||||
(None, "3926bf373b39a67eeafd7901478a47a7", ["sd_unet"], [SDUNet], "civitai"),
|
|
||||||
(None, "1e0c39ec176b9007c05f76d52b554a4d", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
|
||||||
(None, "d9e0290829ba8d98e28e1a2b1407db4a", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
|
||||||
(None, "5072d0b24e406b49507abe861cf97691", ["sd3_text_encoder_3"], [SD3TextEncoder3], "civitai"),
|
|
||||||
(None, "4cf64a799d04260df438c6f33c9a047e", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"),
|
|
||||||
(None, "d9b008a867c498ab12ad24042eff8e3f", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"), # SDXL-Turbo
|
|
||||||
(None, "025bb7452e531a3853d951d77c63f032", ["sdxl_text_encoder", "sdxl_text_encoder_2"], [SDXLTextEncoder, SDXLTextEncoder2], "civitai"),
|
|
||||||
(None, "298997b403a4245c04102c9f36aac348", ["sdxl_unet"], [SDXLUNet], "civitai"),
|
|
||||||
(None, "2a07abce74b4bdc696b76254ab474da6", ["svd_image_encoder", "svd_unet", "svd_vae_decoder", "svd_vae_encoder"], [SVDImageEncoder, SVDUNet, SVDVAEDecoder, SVDVAEEncoder], "civitai"),
|
|
||||||
(None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
|
|
||||||
(None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
|
|
||||||
(None, "31d2d9614fba60511fc9bf2604aa01f7", ["sdxl_controlnet"], [SDXLControlNetUnion], "diffusers"),
|
|
||||||
(None, "94eefa3dac9cec93cb1ebaf1747d7b78", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
|
|
||||||
(None, "1aafa3cc91716fb6b300cc1cd51b85a3", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "diffusers"),
|
|
||||||
(None, "21ea55f476dfc4fd135587abb59dfe5d", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "civitai"),
|
|
||||||
(None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai"),
|
|
||||||
(None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
|
|
||||||
(None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"),
|
|
||||||
(None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
|
|
||||||
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
|
|
||||||
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
|
|
||||||
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
|
|
||||||
(None, "61cbcbc7ac11f169c5949223efa960d1", ["omnigen_transformer"], [OmniGenTransformer], "diffusers"),
|
|
||||||
(None, "78d18b9101345ff695f312e7e62538c0", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
|
||||||
(None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
|
||||||
(None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
|
||||||
(None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
|
||||||
(None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"),
|
|
||||||
(None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
|
||||||
(None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
|
|
||||||
(None, "77ff18050dbc23f50382e45d51a779fe", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
|
||||||
(None, "5da81baee73198a7c19e6d2fe8b5148e", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
|
|
||||||
(None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder", "hunyuan_video_vae_encoder"], [HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder], "diffusers"),
|
|
||||||
(None, "b9588f02e78f5ccafc9d7c0294e46308", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
|
|
||||||
(None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
|
|
||||||
]
|
|
||||||
huggingface_model_loader_configs = [
|
|
||||||
# These configs are provided for detecting model type automatically.
|
|
||||||
# The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
|
|
||||||
("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
|
|
||||||
("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
|
|
||||||
("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
|
|
||||||
("Qwen2ForCausalLM", "transformers.models.qwen2.modeling_qwen2", "qwen_prompt", None),
|
|
||||||
# ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
|
|
||||||
("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
|
|
||||||
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
|
|
||||||
("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
|
|
||||||
("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder")
|
|
||||||
]
|
|
||||||
patch_model_loader_configs = [
|
|
||||||
# These configs are provided for detecting model type automatically.
|
|
||||||
# The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
|
|
||||||
("9a4ab6869ac9b7d6e31f9854e397c867", ["svd_unet"], [SVDUNet], {"add_positional_conv": 128}),
|
|
||||||
]
|
|
||||||
|
|
||||||
preset_models_on_huggingface = {
|
|
||||||
"HunyuanDiT": [
|
|
||||||
("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
|
|
||||||
("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
|
|
||||||
("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
|
|
||||||
("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
|
|
||||||
],
|
|
||||||
"stable-video-diffusion-img2vid-xt": [
|
|
||||||
("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
|
|
||||||
],
|
|
||||||
"ExVideo-SVD-128f-v1": [
|
|
||||||
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
|
|
||||||
],
|
|
||||||
# Stable Diffusion
|
|
||||||
"StableDiffusion_v15": [
|
|
||||||
("benjamin-paine/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
|
|
||||||
],
|
|
||||||
"DreamShaper_8": [
|
|
||||||
("Yntec/Dreamshaper8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
|
|
||||||
],
|
|
||||||
# Textual Inversion
|
|
||||||
"TextualInversion_VeryBadImageNegative_v1.3": [
|
|
||||||
("gemasai/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
|
|
||||||
],
|
|
||||||
# Stable Diffusion XL
|
|
||||||
"StableDiffusionXL_v1": [
|
|
||||||
("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
|
|
||||||
],
|
|
||||||
"BluePencilXL_v200": [
|
|
||||||
("frankjoshua/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
|
|
||||||
],
|
|
||||||
"StableDiffusionXL_Turbo": [
|
|
||||||
("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
|
|
||||||
],
|
|
||||||
# Stable Diffusion 3
|
|
||||||
"StableDiffusion3": [
|
|
||||||
("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
|
|
||||||
],
|
|
||||||
"StableDiffusion3_without_T5": [
|
|
||||||
("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
|
|
||||||
],
|
|
||||||
# ControlNet
|
|
||||||
"ControlNet_v11f1p_sd15_depth": [
|
|
||||||
("lllyasviel/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
|
|
||||||
("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
|
||||||
],
|
|
||||||
"ControlNet_v11p_sd15_softedge": [
|
|
||||||
("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
|
|
||||||
("lllyasviel/Annotators", "ControlNetHED.pth", "models/Annotators")
|
|
||||||
],
|
|
||||||
"ControlNet_v11f1e_sd15_tile": [
|
|
||||||
("lllyasviel/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
|
|
||||||
],
|
|
||||||
"ControlNet_v11p_sd15_lineart": [
|
|
||||||
("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
|
|
||||||
("lllyasviel/Annotators", "sk_model.pth", "models/Annotators"),
|
|
||||||
("lllyasviel/Annotators", "sk_model2.pth", "models/Annotators")
|
|
||||||
],
|
|
||||||
"ControlNet_union_sdxl_promax": [
|
|
||||||
("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
|
|
||||||
("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
|
||||||
],
|
|
||||||
# AnimateDiff
|
|
||||||
"AnimateDiff_v2": [
|
|
||||||
("guoyww/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
|
|
||||||
],
|
|
||||||
"AnimateDiff_xl_beta": [
|
|
||||||
("guoyww/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
|
|
||||||
],
|
|
||||||
|
|
||||||
# Qwen Prompt
|
|
||||||
"QwenPrompt": [
|
|
||||||
("Qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
("Qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
("Qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
("Qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
("Qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
("Qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
("Qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
("Qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
],
|
|
||||||
# Beautiful Prompt
|
|
||||||
"BeautifulPrompt": [
|
|
||||||
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
|
||||||
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
|
||||||
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
|
||||||
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
|
||||||
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
|
||||||
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
|
||||||
],
|
|
||||||
# Omost prompt
|
|
||||||
"OmostPrompt":[
|
|
||||||
("lllyasviel/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
("lllyasviel/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
("lllyasviel/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
("lllyasviel/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
("lllyasviel/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
("lllyasviel/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
("lllyasviel/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
("lllyasviel/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
],
|
|
||||||
# Translator
|
|
||||||
"opus-mt-zh-en": [
|
|
||||||
("Helsinki-NLP/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
|
|
||||||
("Helsinki-NLP/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
|
|
||||||
("Helsinki-NLP/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
|
|
||||||
("Helsinki-NLP/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
|
|
||||||
("Helsinki-NLP/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
|
|
||||||
("Helsinki-NLP/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
|
|
||||||
("Helsinki-NLP/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
|
|
||||||
("Helsinki-NLP/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
|
|
||||||
],
|
|
||||||
# IP-Adapter
|
|
||||||
"IP-Adapter-SD": [
|
|
||||||
("h94/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
|
|
||||||
("h94/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
|
|
||||||
],
|
|
||||||
"IP-Adapter-SDXL": [
|
|
||||||
("h94/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
|
|
||||||
("h94/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
|
|
||||||
],
|
|
||||||
"SDXL-vae-fp16-fix": [
|
|
||||||
("madebyollin/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
|
|
||||||
],
|
|
||||||
# Kolors
|
|
||||||
"Kolors": [
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
|
|
||||||
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
|
|
||||||
],
|
|
||||||
# FLUX
|
|
||||||
"FLUX.1-dev": [
|
|
||||||
("black-forest-labs/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
|
|
||||||
("black-forest-labs/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
|
||||||
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
|
||||||
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
|
||||||
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
|
||||||
("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
|
||||||
("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
|
|
||||||
],
|
|
||||||
"InstantX/FLUX.1-dev-IP-Adapter": {
|
|
||||||
"file_list": [
|
|
||||||
("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
|
|
||||||
("google/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
|
||||||
("google/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
|
||||||
],
|
|
||||||
"load_path": [
|
|
||||||
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
|
|
||||||
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
# RIFE
|
|
||||||
"RIFE": [
|
|
||||||
("AlexWortega/RIFE", "flownet.pkl", "models/RIFE"),
|
|
||||||
],
|
|
||||||
# CogVideo
|
|
||||||
"CogVideoX-5B": [
|
|
||||||
("THUDM/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
|
||||||
("THUDM/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
|
||||||
("THUDM/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
|
||||||
("THUDM/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
|
||||||
("THUDM/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
|
||||||
("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
|
||||||
("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
|
||||||
("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
|
||||||
("THUDM/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
|
|
||||||
],
|
|
||||||
# Stable Diffusion 3.5
|
|
||||||
"StableDiffusion3.5-large": [
|
|
||||||
("stabilityai/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
|
|
||||||
("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
|
|
||||||
("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
|
|
||||||
("stabilityai/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
|
|
||||||
],
|
|
||||||
}
|
|
||||||
preset_models_on_modelscope = {
|
|
||||||
# Hunyuan DiT
|
|
||||||
"HunyuanDiT": [
|
|
||||||
("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
|
|
||||||
("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
|
|
||||||
("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
|
|
||||||
("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
|
|
||||||
],
|
|
||||||
# Stable Video Diffusion
|
|
||||||
"stable-video-diffusion-img2vid-xt": [
|
|
||||||
("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
|
|
||||||
],
|
|
||||||
# ExVideo
|
|
||||||
"ExVideo-SVD-128f-v1": [
|
|
||||||
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
|
|
||||||
],
|
|
||||||
"ExVideo-CogVideoX-LoRA-129f-v1": [
|
|
||||||
("ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1", "ExVideo-CogVideoX-LoRA-129f-v1.safetensors", "models/lora"),
|
|
||||||
],
|
|
||||||
# Stable Diffusion
|
|
||||||
"StableDiffusion_v15": [
|
|
||||||
("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
|
|
||||||
],
|
|
||||||
"DreamShaper_8": [
|
|
||||||
("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
|
|
||||||
],
|
|
||||||
"AingDiffusion_v12": [
|
|
||||||
("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
|
|
||||||
],
|
|
||||||
"Flat2DAnimerge_v45Sharp": [
|
|
||||||
("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
|
|
||||||
],
|
|
||||||
# Textual Inversion
|
|
||||||
"TextualInversion_VeryBadImageNegative_v1.3": [
|
|
||||||
("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
|
|
||||||
],
|
|
||||||
# Stable Diffusion XL
|
|
||||||
"StableDiffusionXL_v1": [
|
|
||||||
("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
|
|
||||||
],
|
|
||||||
"BluePencilXL_v200": [
|
|
||||||
("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
|
|
||||||
],
|
|
||||||
"StableDiffusionXL_Turbo": [
|
|
||||||
("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
|
|
||||||
],
|
|
||||||
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0": [
|
|
||||||
("sd_lora/zyd232_ChineseInkStyle_SDXL_v1_0", "zyd232_ChineseInkStyle_SDXL_v1_0.safetensors", "models/lora"),
|
|
||||||
],
|
|
||||||
# Stable Diffusion 3
|
|
||||||
"StableDiffusion3": [
|
|
||||||
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
|
|
||||||
],
|
|
||||||
"StableDiffusion3_without_T5": [
|
|
||||||
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
|
|
||||||
],
|
|
||||||
# ControlNet
|
|
||||||
"ControlNet_v11f1p_sd15_depth": [
|
|
||||||
("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
|
|
||||||
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
|
||||||
],
|
|
||||||
"ControlNet_v11p_sd15_softedge": [
|
|
||||||
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
|
|
||||||
("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
|
|
||||||
],
|
|
||||||
"ControlNet_v11f1e_sd15_tile": [
|
|
||||||
("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
|
|
||||||
],
|
|
||||||
"ControlNet_v11p_sd15_lineart": [
|
|
||||||
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
|
|
||||||
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
|
|
||||||
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
|
|
||||||
],
|
|
||||||
"ControlNet_union_sdxl_promax": [
|
|
||||||
("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
|
|
||||||
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
|
||||||
],
|
|
||||||
"Annotators:Depth": [
|
|
||||||
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
|
|
||||||
],
|
|
||||||
"Annotators:Softedge": [
|
|
||||||
("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators"),
|
|
||||||
],
|
|
||||||
"Annotators:Lineart": [
|
|
||||||
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
|
|
||||||
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators"),
|
|
||||||
],
|
|
||||||
"Annotators:Normal": [
|
|
||||||
("sd_lora/Annotators", "scannet.pt", "models/Annotators"),
|
|
||||||
],
|
|
||||||
"Annotators:Openpose": [
|
|
||||||
("sd_lora/Annotators", "body_pose_model.pth", "models/Annotators"),
|
|
||||||
("sd_lora/Annotators", "facenet.pth", "models/Annotators"),
|
|
||||||
("sd_lora/Annotators", "hand_pose_model.pth", "models/Annotators"),
|
|
||||||
],
|
|
||||||
# AnimateDiff
|
|
||||||
"AnimateDiff_v2": [
|
|
||||||
("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
|
|
||||||
],
|
|
||||||
"AnimateDiff_xl_beta": [
|
|
||||||
("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
|
|
||||||
],
|
|
||||||
# RIFE
|
|
||||||
"RIFE": [
|
|
||||||
("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
|
|
||||||
],
|
|
||||||
# Qwen Prompt
|
|
||||||
"QwenPrompt": {
|
|
||||||
"file_list": [
|
|
||||||
("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
],
|
|
||||||
"load_path": [
|
|
||||||
"models/QwenPrompt/qwen2-1.5b-instruct",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
# Beautiful Prompt
|
|
||||||
"BeautifulPrompt": {
|
|
||||||
"file_list": [
|
|
||||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
|
||||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
|
||||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
|
||||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
|
||||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
|
||||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
|
||||||
],
|
|
||||||
"load_path": [
|
|
||||||
"models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
# Omost prompt
|
|
||||||
"OmostPrompt": {
|
|
||||||
"file_list": [
|
|
||||||
("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
],
|
|
||||||
"load_path": [
|
|
||||||
"models/OmostPrompt/omost-llama-3-8b-4bits",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
# Translator
|
|
||||||
"opus-mt-zh-en": {
|
|
||||||
"file_list": [
|
|
||||||
("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
|
|
||||||
("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
|
|
||||||
("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
|
|
||||||
("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
|
|
||||||
("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
|
|
||||||
("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
|
|
||||||
("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
|
|
||||||
("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
|
|
||||||
],
|
|
||||||
"load_path": [
|
|
||||||
"models/translator/opus-mt-zh-en",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
# IP-Adapter
|
|
||||||
"IP-Adapter-SD": [
|
|
||||||
("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
|
|
||||||
("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
|
|
||||||
],
|
|
||||||
"IP-Adapter-SDXL": [
|
|
||||||
("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
|
|
||||||
("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
|
|
||||||
],
|
|
||||||
# Kolors
|
|
||||||
"Kolors": {
|
|
||||||
"file_list": [
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
|
|
||||||
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
|
|
||||||
],
|
|
||||||
"load_path": [
|
|
||||||
"models/kolors/Kolors/text_encoder",
|
|
||||||
"models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
|
|
||||||
"models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
"SDXL-vae-fp16-fix": [
|
|
||||||
("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
|
|
||||||
],
|
|
||||||
# FLUX
|
|
||||||
"FLUX.1-dev": {
|
|
||||||
"file_list": [
|
|
||||||
("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
|
|
||||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
|
||||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
|
||||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
|
||||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
|
||||||
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
|
||||||
("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
|
|
||||||
],
|
|
||||||
"load_path": [
|
|
||||||
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
|
||||||
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
|
||||||
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
|
||||||
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
|
|
||||||
],
|
|
||||||
},
|
|
||||||
"FLUX.1-schnell": {
|
|
||||||
"file_list": [
|
|
||||||
("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
|
|
||||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
|
||||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
|
||||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
|
||||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
|
||||||
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
|
||||||
("AI-ModelScope/FLUX.1-schnell", "flux1-schnell.safetensors", "models/FLUX/FLUX.1-schnell"),
|
|
||||||
],
|
|
||||||
"load_path": [
|
|
||||||
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
|
||||||
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
|
||||||
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
|
||||||
"models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors"
|
|
||||||
],
|
|
||||||
},
|
|
||||||
"InstantX/FLUX.1-dev-Controlnet-Union-alpha": [
|
|
||||||
("InstantX/FLUX.1-dev-Controlnet-Union-alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha"),
|
|
||||||
],
|
|
||||||
"jasperai/Flux.1-dev-Controlnet-Depth": [
|
|
||||||
("jasperai/Flux.1-dev-Controlnet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Depth"),
|
|
||||||
],
|
|
||||||
"jasperai/Flux.1-dev-Controlnet-Surface-Normals": [
|
|
||||||
("jasperai/Flux.1-dev-Controlnet-Surface-Normals", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals"),
|
|
||||||
],
|
|
||||||
"jasperai/Flux.1-dev-Controlnet-Upscaler": [
|
|
||||||
("jasperai/Flux.1-dev-Controlnet-Upscaler", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler"),
|
|
||||||
],
|
|
||||||
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha": [
|
|
||||||
("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha"),
|
|
||||||
],
|
|
||||||
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta": [
|
|
||||||
("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"),
|
|
||||||
],
|
|
||||||
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth": [
|
|
||||||
("Shakker-Labs/FLUX.1-dev-ControlNet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Depth"),
|
|
||||||
],
|
|
||||||
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [
|
|
||||||
("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"),
|
|
||||||
],
|
|
||||||
"InstantX/FLUX.1-dev-IP-Adapter": {
|
|
||||||
"file_list": [
|
|
||||||
("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
|
|
||||||
("AI-ModelScope/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
|
||||||
("AI-ModelScope/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
|
||||||
],
|
|
||||||
"load_path": [
|
|
||||||
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
|
|
||||||
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
# ESRGAN
|
|
||||||
"ESRGAN_x4": [
|
|
||||||
("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
|
|
||||||
],
|
|
||||||
# RIFE
|
|
||||||
"RIFE": [
|
|
||||||
("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
|
|
||||||
],
|
|
||||||
# Omnigen
|
|
||||||
"OmniGen-v1": {
|
|
||||||
"file_list": [
|
|
||||||
("BAAI/OmniGen-v1", "vae/diffusion_pytorch_model.safetensors", "models/OmniGen/OmniGen-v1/vae"),
|
|
||||||
("BAAI/OmniGen-v1", "model.safetensors", "models/OmniGen/OmniGen-v1"),
|
|
||||||
("BAAI/OmniGen-v1", "config.json", "models/OmniGen/OmniGen-v1"),
|
|
||||||
("BAAI/OmniGen-v1", "special_tokens_map.json", "models/OmniGen/OmniGen-v1"),
|
|
||||||
("BAAI/OmniGen-v1", "tokenizer_config.json", "models/OmniGen/OmniGen-v1"),
|
|
||||||
("BAAI/OmniGen-v1", "tokenizer.json", "models/OmniGen/OmniGen-v1"),
|
|
||||||
],
|
|
||||||
"load_path": [
|
|
||||||
"models/OmniGen/OmniGen-v1/vae/diffusion_pytorch_model.safetensors",
|
|
||||||
"models/OmniGen/OmniGen-v1/model.safetensors",
|
|
||||||
]
|
|
||||||
},
|
|
||||||
# CogVideo
|
|
||||||
"CogVideoX-5B": {
|
|
||||||
"file_list": [
|
|
||||||
("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
|
||||||
("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
|
||||||
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
|
||||||
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
|
||||||
("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
|
||||||
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
|
||||||
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
|
||||||
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
|
||||||
("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
|
|
||||||
],
|
|
||||||
"load_path": [
|
|
||||||
"models/CogVideo/CogVideoX-5b/text_encoder",
|
|
||||||
"models/CogVideo/CogVideoX-5b/transformer",
|
|
||||||
"models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
# Stable Diffusion 3.5
|
|
||||||
"StableDiffusion3.5-large": [
|
|
||||||
("AI-ModelScope/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
|
|
||||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
|
|
||||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
|
|
||||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
|
|
||||||
],
|
|
||||||
"StableDiffusion3.5-medium": [
|
|
||||||
("AI-ModelScope/stable-diffusion-3.5-medium", "sd3.5_medium.safetensors", "models/stable_diffusion_3"),
|
|
||||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
|
|
||||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
|
|
||||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
|
|
||||||
],
|
|
||||||
"StableDiffusion3.5-large-turbo": [
|
|
||||||
("AI-ModelScope/stable-diffusion-3.5-large-turbo", "sd3.5_large_turbo.safetensors", "models/stable_diffusion_3"),
|
|
||||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
|
|
||||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
|
|
||||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
|
|
||||||
],
|
|
||||||
"HunyuanVideo":{
|
|
||||||
"file_list": [
|
|
||||||
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
|
|
||||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
|
||||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
|
||||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
|
||||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
|
||||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
|
|
||||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
|
|
||||||
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
|
|
||||||
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideo/transformers")
|
|
||||||
],
|
|
||||||
"load_path": [
|
|
||||||
"models/HunyuanVideo/text_encoder/model.safetensors",
|
|
||||||
"models/HunyuanVideo/text_encoder_2",
|
|
||||||
"models/HunyuanVideo/vae/pytorch_model.pt",
|
|
||||||
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
|
|
||||||
],
|
|
||||||
},
|
|
||||||
"HunyuanVideo-fp8":{
|
|
||||||
"file_list": [
|
|
||||||
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
|
|
||||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
|
||||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
|
||||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
|
||||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
|
||||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
|
|
||||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
|
|
||||||
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
|
|
||||||
("DiffSynth-Studio/HunyuanVideo-safetensors", "model.fp8.safetensors", "models/HunyuanVideo/transformers")
|
|
||||||
],
|
|
||||||
"load_path": [
|
|
||||||
"models/HunyuanVideo/text_encoder/model.safetensors",
|
|
||||||
"models/HunyuanVideo/text_encoder_2",
|
|
||||||
"models/HunyuanVideo/vae/pytorch_model.pt",
|
|
||||||
"models/HunyuanVideo/transformers/model.fp8.safetensors"
|
|
||||||
],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
Preset_model_id: TypeAlias = Literal[
|
|
||||||
"HunyuanDiT",
|
|
||||||
"stable-video-diffusion-img2vid-xt",
|
|
||||||
"ExVideo-SVD-128f-v1",
|
|
||||||
"ExVideo-CogVideoX-LoRA-129f-v1",
|
|
||||||
"StableDiffusion_v15",
|
|
||||||
"DreamShaper_8",
|
|
||||||
"AingDiffusion_v12",
|
|
||||||
"Flat2DAnimerge_v45Sharp",
|
|
||||||
"TextualInversion_VeryBadImageNegative_v1.3",
|
|
||||||
"StableDiffusionXL_v1",
|
|
||||||
"BluePencilXL_v200",
|
|
||||||
"StableDiffusionXL_Turbo",
|
|
||||||
"ControlNet_v11f1p_sd15_depth",
|
|
||||||
"ControlNet_v11p_sd15_softedge",
|
|
||||||
"ControlNet_v11f1e_sd15_tile",
|
|
||||||
"ControlNet_v11p_sd15_lineart",
|
|
||||||
"AnimateDiff_v2",
|
|
||||||
"AnimateDiff_xl_beta",
|
|
||||||
"RIFE",
|
|
||||||
"BeautifulPrompt",
|
|
||||||
"opus-mt-zh-en",
|
|
||||||
"IP-Adapter-SD",
|
|
||||||
"IP-Adapter-SDXL",
|
|
||||||
"StableDiffusion3",
|
|
||||||
"StableDiffusion3_without_T5",
|
|
||||||
"Kolors",
|
|
||||||
"SDXL-vae-fp16-fix",
|
|
||||||
"ControlNet_union_sdxl_promax",
|
|
||||||
"FLUX.1-dev",
|
|
||||||
"FLUX.1-schnell",
|
|
||||||
"InstantX/FLUX.1-dev-Controlnet-Union-alpha",
|
|
||||||
"jasperai/Flux.1-dev-Controlnet-Depth",
|
|
||||||
"jasperai/Flux.1-dev-Controlnet-Surface-Normals",
|
|
||||||
"jasperai/Flux.1-dev-Controlnet-Upscaler",
|
|
||||||
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
|
|
||||||
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
|
|
||||||
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
|
|
||||||
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
|
|
||||||
"InstantX/FLUX.1-dev-IP-Adapter",
|
|
||||||
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
|
|
||||||
"QwenPrompt",
|
|
||||||
"OmostPrompt",
|
|
||||||
"ESRGAN_x4",
|
|
||||||
"RIFE",
|
|
||||||
"OmniGen-v1",
|
|
||||||
"CogVideoX-5B",
|
|
||||||
"Annotators:Depth",
|
|
||||||
"Annotators:Softedge",
|
|
||||||
"Annotators:Lineart",
|
|
||||||
"Annotators:Normal",
|
|
||||||
"Annotators:Openpose",
|
|
||||||
"StableDiffusion3.5-large",
|
|
||||||
"StableDiffusion3.5-medium",
|
|
||||||
"HunyuanVideo",
|
|
||||||
"HunyuanVideo-fp8",
|
|
||||||
]
|
|
||||||
722
diffsynth/configs/model_configs.py
Normal file
722
diffsynth/configs/model_configs.py
Normal file
@@ -0,0 +1,722 @@
|
|||||||
|
qwen_image_series = [
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors")
|
||||||
|
"model_hash": "0319a1cb19835fb510907dd3367c95ff",
|
||||||
|
"model_name": "qwen_image_dit",
|
||||||
|
"model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
|
||||||
|
"model_hash": "8004730443f55db63092006dd9f7110e",
|
||||||
|
"model_name": "qwen_image_text_encoder",
|
||||||
|
"model_class": "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.qwen_image_text_encoder.QwenImageTextEncoderStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
|
||||||
|
"model_hash": "ed4ea5824d55ec3107b09815e318123a",
|
||||||
|
"model_name": "qwen_image_vae",
|
||||||
|
"model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth", origin_file_pattern="model.safetensors")
|
||||||
|
"model_hash": "073bce9cf969e317e5662cd570c3e79c",
|
||||||
|
"model_name": "qwen_image_blockwise_controlnet",
|
||||||
|
"model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors")
|
||||||
|
"model_hash": "a9e54e480a628f0b956a688a81c33bab",
|
||||||
|
"model_name": "qwen_image_blockwise_controlnet",
|
||||||
|
"model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet",
|
||||||
|
"extra_kwargs": {"additional_in_dim": 4},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors")
|
||||||
|
"model_hash": "469c78b61e3e31bc9eec0d0af3d3f2f8",
|
||||||
|
"model_name": "siglip2_image_encoder",
|
||||||
|
"model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors")
|
||||||
|
"model_hash": "5722b5c873720009de96422993b15682",
|
||||||
|
"model_name": "dinov3_image_encoder",
|
||||||
|
"model_class": "diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example:
|
||||||
|
"model_hash": "a166c33455cdbd89c0888a3645ca5c0f",
|
||||||
|
"model_name": "qwen_image_image2lora_coarse",
|
||||||
|
"model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example:
|
||||||
|
"model_hash": "a5476e691767a4da6d3a6634a10f7408",
|
||||||
|
"model_name": "qwen_image_image2lora_fine",
|
||||||
|
"model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
|
||||||
|
"extra_kwargs": {"residual_length": 37*37+7, "residual_mid_dim": 64}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example:
|
||||||
|
"model_hash": "0aad514690602ecaff932c701cb4b0bb",
|
||||||
|
"model_name": "qwen_image_image2lora_style",
|
||||||
|
"model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
|
||||||
|
"extra_kwargs": {"compress_dim": 64, "use_residual": False}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
|
||||||
|
"model_hash": "8dc8cda05de16c73afa755e2c1ce2839",
|
||||||
|
"model_name": "qwen_image_dit",
|
||||||
|
"model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT",
|
||||||
|
"extra_kwargs": {"use_layer3d_rope": True, "use_additional_t_cond": True}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
|
||||||
|
"model_hash": "44b39ddc499e027cfb24f7878d7416b9",
|
||||||
|
"model_name": "qwen_image_vae",
|
||||||
|
"model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE",
|
||||||
|
"extra_kwargs": {"image_channels": 4}
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
wan_series = [
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="krea/krea-realtime-video", origin_file_pattern="krea-realtime-video-14b.safetensors")
|
||||||
|
"model_hash": "5ec04e02b42d2580483ad69f4e76346a",
|
||||||
|
"model_name": "wan_video_dit",
|
||||||
|
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||||
|
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth")
|
||||||
|
"model_hash": "9c8818c2cbea55eca56c7b447df170da",
|
||||||
|
"model_name": "wan_video_text_encoder",
|
||||||
|
"model_class": "diffsynth.models.wan_video_text_encoder.WanTextEncoder",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth")
|
||||||
|
"model_hash": "ccc42284ea13e1ad04693284c7a09be6",
|
||||||
|
"model_name": "wan_video_vae",
|
||||||
|
"model_class": "diffsynth.models.wan_video_vae.WanVideoVAE",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="meituan-longcat/LongCat-Video", origin_file_pattern="dit/diffusion_pytorch_model*.safetensors")
|
||||||
|
"model_hash": "8b27900f680d7251ce44e2dc8ae1ffef",
|
||||||
|
"model_name": "wan_video_dit",
|
||||||
|
"model_class": "diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
|
||||||
|
"model_hash": "5f90e66a0672219f12d9a626c8c21f61",
|
||||||
|
"model_name": "wan_video_dit",
|
||||||
|
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||||
|
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTFromDiffusers"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
|
||||||
|
"model_hash": "5f90e66a0672219f12d9a626c8c21f61",
|
||||||
|
"model_name": "wan_video_vap",
|
||||||
|
"model_class": "diffsynth.models.wan_video_mot.MotWanModel",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_mot.WanVideoMotStateDictConverter"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth")
|
||||||
|
"model_hash": "5941c53e207d62f20f9025686193c40b",
|
||||||
|
"model_name": "wan_video_image_encoder",
|
||||||
|
"model_class": "diffsynth.models.wan_video_image_encoder.WanImageEncoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_image_encoder.WanImageEncoderStateDictConverter"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", origin_file_pattern="model.safetensors")
|
||||||
|
"model_hash": "dbd5ec76bbf977983f972c151d545389",
|
||||||
|
"model_name": "wan_video_motion_controller",
|
||||||
|
"model_class": "diffsynth.models.wan_video_motion_controller.WanMotionControllerModel",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||||
|
"model_hash": "9269f8db9040a9d860eaca435be61814",
|
||||||
|
"model_name": "wan_video_dit",
|
||||||
|
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||||
|
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||||
|
"model_hash": "3ef3b1f8e1dab83d5b71fd7b617f859f",
|
||||||
|
"model_name": "wan_video_dit",
|
||||||
|
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||||
|
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_image_pos_emb': True}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||||
|
"model_hash": "349723183fc063b2bfc10bb2835cf677",
|
||||||
|
"model_name": "wan_video_dit",
|
||||||
|
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||||
|
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||||
|
"model_hash": "6d6ccde6845b95ad9114ab993d917893",
|
||||||
|
"model_name": "wan_video_dit",
|
||||||
|
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||||
|
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||||
|
"model_hash": "efa44cddf936c70abd0ea28b6cbe946c",
|
||||||
|
"model_name": "wan_video_dit",
|
||||||
|
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||||
|
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||||
|
"model_hash": "6bfcfb3b342cb286ce886889d519a77e",
|
||||||
|
"model_name": "wan_video_dit",
|
||||||
|
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||||
|
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||||
|
"model_hash": "ac6a5aa74f4a0aab6f64eb9a72f19901",
|
||||||
|
"model_name": "wan_video_dit",
|
||||||
|
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||||
|
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||||
|
"model_hash": "70ddad9d3a133785da5ea371aae09504",
|
||||||
|
"model_name": "wan_video_dit",
|
||||||
|
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||||
|
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': True}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||||
|
"model_hash": "b61c605c2adbd23124d152ed28e049ae",
|
||||||
|
"model_name": "wan_video_dit",
|
||||||
|
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||||
|
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||||
|
"model_hash": "26bde73488a92e64cc20b0a7485b9e5b",
|
||||||
|
"model_name": "wan_video_dit",
|
||||||
|
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||||
|
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||||
|
"model_hash": "aafcfd9672c3a2456dc46e1cb6e52c70",
|
||||||
|
"model_name": "wan_video_dit",
|
||||||
|
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||||
|
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||||
|
"model_hash": "a61453409b67cd3246cf0c3bebad47ba",
|
||||||
|
"model_name": "wan_video_dit",
|
||||||
|
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||||
|
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06},
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||||
|
"model_hash": "a61453409b67cd3246cf0c3bebad47ba",
|
||||||
|
"model_name": "wan_video_vace",
|
||||||
|
"model_class": "diffsynth.models.wan_video_vace.VaceWanModel",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||||
|
"model_hash": "7a513e1f257a861512b1afd387a8ecd9",
|
||||||
|
"model_name": "wan_video_dit",
|
||||||
|
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||||
|
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||||
|
"model_hash": "7a513e1f257a861512b1afd387a8ecd9",
|
||||||
|
"model_name": "wan_video_vace",
|
||||||
|
"model_class": "diffsynth.models.wan_video_vace.VaceWanModel",
|
||||||
|
"extra_kwargs": {'vace_layers': (0, 5, 10, 15, 20, 25, 30, 35), 'vace_in_dim': 96, 'patch_size': (1, 2, 2), 'has_image_input': False, 'dim': 5120, 'num_heads': 40, 'ffn_dim': 13824, 'eps': 1e-06},
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||||
|
"model_hash": "31fa352acb8a1b1d33cd8764273d80a2",
|
||||||
|
"model_name": "wan_video_dit",
|
||||||
|
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||||
|
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||||
|
"model_hash": "31fa352acb8a1b1d33cd8764273d80a2",
|
||||||
|
"model_name": "wan_video_animate_adapter",
|
||||||
|
"model_class": "diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_animate_adapter.WanAnimateAdapterStateDictConverter"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
|
||||||
|
"model_hash": "47dbeab5e560db3180adf51dc0232fb1",
|
||||||
|
"model_name": "wan_video_dit",
|
||||||
|
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||||
|
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24, 'require_clip_embedding': False}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
|
||||||
|
"model_hash": "2267d489f0ceb9f21836532952852ee5",
|
||||||
|
"model_name": "wan_video_dit",
|
||||||
|
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||||
|
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 52, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True, 'require_clip_embedding': False},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
|
||||||
|
"model_hash": "5b013604280dd715f8457c6ed6d6a626",
|
||||||
|
"model_name": "wan_video_dit",
|
||||||
|
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||||
|
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'require_clip_embedding': False}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||||
|
"model_hash": "966cffdcc52f9c46c391768b27637614",
|
||||||
|
"model_name": "wan_video_dit",
|
||||||
|
"model_class": "diffsynth.models.wan_video_dit_s2v.WanS2VModel",
|
||||||
|
"extra_kwargs": {'dim': 5120, 'in_dim': 16, 'ffn_dim': 13824, 'out_dim': 16, 'text_dim': 4096, 'freq_dim': 256, 'eps': 1e-06, 'patch_size': (1, 2, 2), 'num_heads': 40, 'num_layers': 40, 'cond_dim': 16, 'audio_dim': 1024, 'num_audio_token': 4}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||||
|
"model_hash": "1f5ab7703c6fc803fdded85ff040c316",
|
||||||
|
"model_name": "wan_video_dit",
|
||||||
|
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||||
|
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 3072, 'ffn_dim': 14336, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 48, 'num_heads': 24, 'num_layers': 30, 'eps': 1e-06, 'seperated_timestep': True, 'require_clip_embedding': False, 'require_vae_embedding': False, 'fuse_vae_embedding_in_latents': True}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth")
|
||||||
|
"model_hash": "e1de6c02cdac79f8b739f4d3698cd216",
|
||||||
|
"model_name": "wan_video_vae",
|
||||||
|
"model_class": "diffsynth.models.wan_video_vae.WanVideoVAE38",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors")
|
||||||
|
"model_hash": "06be60f3a4526586d8431cd038a71486",
|
||||||
|
"model_name": "wans2v_audio_encoder",
|
||||||
|
"model_class": "diffsynth.models.wav2vec.WanS2VAudioEncoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.wans2v_audio_encoder.WanS2VAudioEncoderStateDictConverter",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
flux_series = [
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors")
|
||||||
|
"model_hash": "a29710fea6dddb0314663ee823598e50",
|
||||||
|
"model_name": "flux_dit",
|
||||||
|
"model_class": "diffsynth.models.flux_dit.FluxDiT",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Supported due to historical reasons.
|
||||||
|
"model_hash": "605c56eab23e9e2af863ad8f0813a25d",
|
||||||
|
"model_name": "flux_dit",
|
||||||
|
"model_class": "diffsynth.models.flux_dit.FluxDiT",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverterFromDiffusers",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors")
|
||||||
|
"model_hash": "94eefa3dac9cec93cb1ebaf1747d7b78",
|
||||||
|
"model_name": "flux_text_encoder_clip",
|
||||||
|
"model_class": "diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_clip.FluxTextEncoderClipStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors")
|
||||||
|
"model_hash": "22540b49eaedbc2f2784b2091a234c7c",
|
||||||
|
"model_name": "flux_text_encoder_t5",
|
||||||
|
"model_class": "diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_t5.FluxTextEncoderT5StateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors")
|
||||||
|
"model_hash": "21ea55f476dfc4fd135587abb59dfe5d",
|
||||||
|
"model_name": "flux_vae_encoder",
|
||||||
|
"model_class": "diffsynth.models.flux_vae.FluxVAEEncoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors")
|
||||||
|
"model_hash": "21ea55f476dfc4fd135587abb59dfe5d",
|
||||||
|
"model_name": "flux_vae_decoder",
|
||||||
|
"model_class": "diffsynth.models.flux_vae.FluxVAEDecoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors")
|
||||||
|
"model_hash": "d02f41c13549fa5093d3521f62a5570a",
|
||||||
|
"model_name": "flux_dit",
|
||||||
|
"model_class": "diffsynth.models.flux_dit.FluxDiT",
|
||||||
|
"extra_kwargs": {'input_dim': 196, 'num_blocks': 8},
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="DiffSynth-Studio/AttriCtrl-FLUX.1-Dev", origin_file_pattern="models/brightness.safetensors")
|
||||||
|
"model_hash": "0629116fce1472503a66992f96f3eb1a",
|
||||||
|
"model_name": "flux_value_controller",
|
||||||
|
"model_class": "diffsynth.models.flux_value_control.SingleValueEncoder",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors")
|
||||||
|
"model_hash": "52357cb26250681367488a8954c271e8",
|
||||||
|
"model_name": "flux_controlnet",
|
||||||
|
"model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
|
||||||
|
"extra_kwargs": {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors")
|
||||||
|
"model_hash": "78d18b9101345ff695f312e7e62538c0",
|
||||||
|
"model_name": "flux_controlnet",
|
||||||
|
"model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
|
||||||
|
"extra_kwargs": {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors")
|
||||||
|
"model_hash": "b001c89139b5f053c715fe772362dd2a",
|
||||||
|
"model_name": "flux_controlnet",
|
||||||
|
"model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
|
||||||
|
"extra_kwargs": {"num_single_blocks": 0},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin")
|
||||||
|
"model_hash": "c07c0f04f5ff55e86b4e937c7a40d481",
|
||||||
|
"model_name": "infiniteyou_image_projector",
|
||||||
|
"model_class": "diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_infiniteyou.FluxInfiniteYouImageProjectorStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors")
|
||||||
|
"model_hash": "7f9583eb8ba86642abb9a21a4b2c9e16",
|
||||||
|
"model_name": "flux_controlnet",
|
||||||
|
"model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
|
||||||
|
"extra_kwargs": {"num_joint_blocks": 4, "num_single_blocks": 10},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors")
|
||||||
|
"model_hash": "77c2e4dd2440269eb33bfaa0d004f6ab",
|
||||||
|
"model_name": "flux_lora_encoder",
|
||||||
|
"model_class": "diffsynth.models.flux_lora_encoder.FluxLoRAEncoder",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev", origin_file_pattern="model.safetensors")
|
||||||
|
"model_hash": "30143afb2dea73d1ac580e0787628f8c",
|
||||||
|
"model_name": "flux_lora_patcher",
|
||||||
|
"model_class": "diffsynth.models.flux_lora_patcher.FluxLoraPatcher",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors")
|
||||||
|
"model_hash": "2bd19e845116e4f875a0a048e27fc219",
|
||||||
|
"model_name": "nexus_gen_llm",
|
||||||
|
"model_class": "diffsynth.models.nexus_gen.NexusGenAutoregressiveModel",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen.NexusGenAutoregressiveModelStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin")
|
||||||
|
"model_hash": "63c969fd37cce769a90aa781fbff5f81",
|
||||||
|
"model_name": "nexus_gen_editing_adapter",
|
||||||
|
"model_class": "diffsynth.models.nexus_gen_projector.NexusGenImageEmbeddingMerger",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenMergerStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin")
|
||||||
|
"model_hash": "63c969fd37cce769a90aa781fbff5f81",
|
||||||
|
"model_name": "flux_dit",
|
||||||
|
"model_class": "diffsynth.models.flux_dit.FluxDiT",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin")
|
||||||
|
"model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d",
|
||||||
|
"model_name": "nexus_gen_generation_adapter",
|
||||||
|
"model_class": "diffsynth.models.nexus_gen_projector.NexusGenAdapter",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenAdapterStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin")
|
||||||
|
"model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d",
|
||||||
|
"model_name": "flux_dit",
|
||||||
|
"model_class": "diffsynth.models.flux_dit.FluxDiT",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin")
|
||||||
|
"model_hash": "4daaa66cc656a8fe369908693dad0a35",
|
||||||
|
"model_name": "flux_ipadapter",
|
||||||
|
"model_class": "diffsynth.models.flux_ipadapter.FluxIpAdapter",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.FluxIpAdapterStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="google/siglip-so400m-patch14-384", origin_file_pattern="model.safetensors")
|
||||||
|
"model_hash": "04d8c1e20a1f1b25f7434f111992a33f",
|
||||||
|
"model_name": "siglip_vision_model",
|
||||||
|
"model_class": "diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.SiglipStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
|
||||||
|
"model_hash": "d30fb9e02b1dbf4e509142f05cf7dd50",
|
||||||
|
"model_name": "step1x_connector",
|
||||||
|
"model_class": "diffsynth.models.step1x_connector.Qwen2Connector",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.step1x_connector.Qwen2ConnectorStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
|
||||||
|
"model_hash": "d30fb9e02b1dbf4e509142f05cf7dd50",
|
||||||
|
"model_name": "flux_dit",
|
||||||
|
"model_class": "diffsynth.models.flux_dit.FluxDiT",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
|
||||||
|
"extra_kwargs": {"disable_guidance_embedder": True},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="MAILAND/majicflus_v1", origin_file_pattern="majicflus_v134.safetensors")
|
||||||
|
"model_hash": "3394f306c4cbf04334b712bf5aaed95f",
|
||||||
|
"model_name": "flux_dit",
|
||||||
|
"model_class": "diffsynth.models.flux_dit.FluxDiT",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
flux2_series = [
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors")
|
||||||
|
"model_hash": "28fca3d8e5bf2a2d1271748a773f6757",
|
||||||
|
"model_name": "flux2_text_encoder",
|
||||||
|
"model_class": "diffsynth.models.flux2_text_encoder.Flux2TextEncoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux2_text_encoder.Flux2TextEncoderStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors")
|
||||||
|
"model_hash": "d38e1d5c5aec3b0a11e79327ac6e3b0f",
|
||||||
|
"model_name": "flux2_dit",
|
||||||
|
"model_class": "diffsynth.models.flux2_dit.Flux2DiT",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
|
||||||
|
"model_hash": "c54288e3ee12ca215898840682337b95",
|
||||||
|
"model_name": "flux2_vae",
|
||||||
|
"model_class": "diffsynth.models.flux2_vae.Flux2VAE",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors")
|
||||||
|
"model_hash": "3bde7b817fec8143028b6825a63180df",
|
||||||
|
"model_name": "flux2_dit",
|
||||||
|
"model_class": "diffsynth.models.flux2_dit.Flux2DiT",
|
||||||
|
"extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 7680, "num_attention_heads": 24, "num_layers": 5, "num_single_layers": 20}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors")
|
||||||
|
"model_hash": "9195f3ea256fcd0ae6d929c203470754",
|
||||||
|
"model_name": "z_image_text_encoder",
|
||||||
|
"model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
|
||||||
|
"extra_kwargs": {"model_size": "8B"},
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors")
|
||||||
|
"model_hash": "39c6fc48f07bebecedbbaa971ff466c8",
|
||||||
|
"model_name": "flux2_dit",
|
||||||
|
"model_class": "diffsynth.models.flux2_dit.Flux2DiT",
|
||||||
|
"extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 12288, "num_attention_heads": 32, "num_layers": 8, "num_single_layers": 24}
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
z_image_series = [
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors")
|
||||||
|
"model_hash": "fc3a8a1247fe185ce116ccbe0e426c28",
|
||||||
|
"model_name": "z_image_dit",
|
||||||
|
"model_class": "diffsynth.models.z_image_dit.ZImageDiT",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors")
|
||||||
|
"model_hash": "0f050f62a88876fea6eae0a18dac5a2e",
|
||||||
|
"model_name": "z_image_text_encoder",
|
||||||
|
"model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors")
|
||||||
|
"model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3",
|
||||||
|
"model_name": "flux_vae_encoder",
|
||||||
|
"model_class": "diffsynth.models.flux_vae.FluxVAEEncoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverterDiffusers",
|
||||||
|
"extra_kwargs": {"use_conv_attention": False},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors")
|
||||||
|
"model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3",
|
||||||
|
"model_name": "flux_vae_decoder",
|
||||||
|
"model_class": "diffsynth.models.flux_vae.FluxVAEDecoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers",
|
||||||
|
"extra_kwargs": {"use_conv_attention": False},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors")
|
||||||
|
"model_hash": "aa3563718e5c3ecde3dfbb020ca61180",
|
||||||
|
"model_name": "z_image_dit",
|
||||||
|
"model_class": "diffsynth.models.z_image_dit.ZImageDiT",
|
||||||
|
"extra_kwargs": {"siglip_feat_dim": 1152},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors")
|
||||||
|
"model_hash": "89d48e420f45cff95115a9f3e698d44a",
|
||||||
|
"model_name": "siglip_vision_model_428m",
|
||||||
|
"model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors")
|
||||||
|
"model_hash": "1677708d40029ab380a95f6c731a57d7",
|
||||||
|
"model_name": "z_image_controlnet",
|
||||||
|
"model_class": "diffsynth.models.z_image_controlnet.ZImageControlNet",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ???
|
||||||
|
"model_hash": "9510cb8cd1dd34ee0e4f111c24905510",
|
||||||
|
"model_name": "z_image_image2lora_style",
|
||||||
|
"model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel",
|
||||||
|
"extra_kwargs": {"compress_dim": 128},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors")
|
||||||
|
"model_hash": "1392adecee344136041e70553f875f31",
|
||||||
|
"model_name": "z_image_text_encoder",
|
||||||
|
"model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
|
||||||
|
"extra_kwargs": {"model_size": "0.6B"},
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
"""
|
||||||
|
Offical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2
|
||||||
|
Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage
|
||||||
|
For base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors"))
|
||||||
|
and repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported.
|
||||||
|
We have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules,
|
||||||
|
and avoid redundant memory usage when users only want to use part of the model.
|
||||||
|
"""
|
||||||
|
ltx2_series = [
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||||
|
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||||
|
"model_name": "ltx2_dit",
|
||||||
|
"model_class": "diffsynth.models.ltx2_dit.LTXModel",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors")
|
||||||
|
"model_hash": "c567aaa37d5ed7454c73aa6024458661",
|
||||||
|
"model_name": "ltx2_dit",
|
||||||
|
"model_class": "diffsynth.models.ltx2_dit.LTXModel",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||||
|
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||||
|
"model_name": "ltx2_video_vae_encoder",
|
||||||
|
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors")
|
||||||
|
"model_hash": "7f7e904a53260ec0351b05f32153754b",
|
||||||
|
"model_name": "ltx2_video_vae_encoder",
|
||||||
|
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||||
|
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||||
|
"model_name": "ltx2_video_vae_decoder",
|
||||||
|
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors")
|
||||||
|
"model_hash": "dc6029ca2825147872b45e35a2dc3a97",
|
||||||
|
"model_name": "ltx2_video_vae_decoder",
|
||||||
|
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||||
|
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||||
|
"model_name": "ltx2_audio_vae_decoder",
|
||||||
|
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors")
|
||||||
|
"model_hash": "7d7823dde8f1ea0b50fb07ac329dd4cb",
|
||||||
|
"model_name": "ltx2_audio_vae_decoder",
|
||||||
|
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||||
|
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||||
|
"model_name": "ltx2_audio_vocoder",
|
||||||
|
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors")
|
||||||
|
"model_hash": "f471360f6b24bef702ab73133d9f8bb9",
|
||||||
|
"model_name": "ltx2_audio_vocoder",
|
||||||
|
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||||
|
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||||
|
"model_name": "ltx2_audio_vae_encoder",
|
||||||
|
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_encoder.safetensors")
|
||||||
|
"model_hash": "29338f3b95e7e312a3460a482e4f4554",
|
||||||
|
"model_name": "ltx2_audio_vae_encoder",
|
||||||
|
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||||
|
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||||
|
"model_name": "ltx2_text_encoder_post_modules",
|
||||||
|
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors")
|
||||||
|
"model_hash": "981629689c8be92a712ab3c5eb4fc3f6",
|
||||||
|
"model_name": "ltx2_text_encoder_post_modules",
|
||||||
|
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors")
|
||||||
|
"model_hash": "33917f31c4a79196171154cca39f165e",
|
||||||
|
"model_name": "ltx2_text_encoder",
|
||||||
|
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||||
|
"model_hash": "c79c458c6e99e0e14d47e676761732d2",
|
||||||
|
"model_name": "ltx2_latent_upsampler",
|
||||||
|
"model_class": "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series
|
||||||
246
diffsynth/configs/vram_management_module_maps.py
Normal file
246
diffsynth/configs/vram_management_module_maps.py
Normal file
@@ -0,0 +1,246 @@
|
|||||||
|
flux_general_vram_config = {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.general_modules.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.flux_lora_encoder.LoRALayerBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.flux_lora_patcher.LoraMerger": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
}
|
||||||
|
|
||||||
|
VRAM_MANAGEMENT_MODULE_MAPS = {
|
||||||
|
"diffsynth.models.qwen_image_dit.QwenImageDiT": {
|
||||||
|
"diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionPatchEmbed": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.qwen_image_vae.QwenImageVAE": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.qwen_image_vae.QwenImageRMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.qwen_image_controlnet.BlockWiseControlBlock": {
|
||||||
|
"diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
},
|
||||||
|
"diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder": {
|
||||||
|
"transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
},
|
||||||
|
"diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder": {
|
||||||
|
"transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTLayerScale": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTRopePositionEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
},
|
||||||
|
"diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
},
|
||||||
|
"diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter": {
|
||||||
|
"diffsynth.models.wan_video_animate_adapter.FaceEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.wan_video_animate_adapter.EqualLinear": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.wan_video_animate_adapter.ConvLayer": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.wan_video_animate_adapter.FusedLeakyReLU": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.wan_video_animate_adapter.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.wan_video_dit_s2v.WanS2VModel": {
|
||||||
|
"diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.wan_video_dit_s2v.WanS2VDiTBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.wan_video_dit_s2v.CausalAudioEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.wan_video_dit.WanModel": {
|
||||||
|
"diffsynth.models.wan_video_dit.MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule",
|
||||||
|
"diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.wan_video_image_encoder.WanImageEncoder": {
|
||||||
|
"diffsynth.models.wan_video_image_encoder.VisionTransformer": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.wan_video_mot.MotWanModel": {
|
||||||
|
"diffsynth.models.wan_video_mot.MotWanAttentionBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.wan_video_motion_controller.WanMotionControllerModel": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
},
|
||||||
|
"diffsynth.models.wan_video_text_encoder.WanTextEncoder": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.wan_video_text_encoder.T5RelativeEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.wan_video_text_encoder.T5LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.wan_video_vace.VaceWanModel": {
|
||||||
|
"diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.wan_video_vae.WanVideoVAE": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.wan_video_vae.RMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.wan_video_vae.CausalConv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.wan_video_vae.Upsample": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.wan_video_vae.WanVideoVAE38": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.wan_video_vae.RMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.wan_video_vae.CausalConv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.wan_video_vae.Upsample": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.wav2vec.WanS2VAudioEncoder": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.longcat_video_dit.RMSNorm_FP32": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.longcat_video_dit.LayerNorm_FP32": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.flux_dit.FluxDiT": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"diffsynth.models.flux_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip": flux_general_vram_config,
|
||||||
|
"diffsynth.models.flux_vae.FluxVAEEncoder": flux_general_vram_config,
|
||||||
|
"diffsynth.models.flux_vae.FluxVAEDecoder": flux_general_vram_config,
|
||||||
|
"diffsynth.models.flux_controlnet.FluxControlNet": flux_general_vram_config,
|
||||||
|
"diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector": flux_general_vram_config,
|
||||||
|
"diffsynth.models.flux_ipadapter.FluxIpAdapter": flux_general_vram_config,
|
||||||
|
"diffsynth.models.flux_lora_patcher.FluxLoraPatcher": flux_general_vram_config,
|
||||||
|
"diffsynth.models.step1x_connector.Qwen2Connector": flux_general_vram_config,
|
||||||
|
"diffsynth.models.flux_lora_encoder.FluxLoRAEncoder": flux_general_vram_config,
|
||||||
|
"diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.t5.modeling_t5.T5LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.t5.modeling_t5.T5DenseActDense": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.t5.modeling_t5.T5DenseGatedActDense": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M": {
|
||||||
|
"transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.siglip.modeling_siglip.SiglipEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.MultiheadAttention": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.flux2_dit.Flux2DiT": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.flux2_text_encoder.Flux2TextEncoder": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.mistral.modeling_mistral.MistralRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.flux2_vae.Flux2VAE": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.z_image_text_encoder.ZImageTextEncoder": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.z_image_dit.ZImageDiT": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.z_image_controlnet.ZImageControlNet": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
},
|
||||||
|
"diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M": {
|
||||||
|
"transformers.models.siglip2.modeling_siglip2.Siglip2VisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.siglip2.modeling_siglip2.Siglip2MultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
},
|
||||||
|
"diffsynth.models.ltx2_dit.LTXModel": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler": {
|
||||||
|
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.ltx2_video_vae.LTX2VideoEncoder": {
|
||||||
|
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.ltx2_video_vae.LTX2VideoDecoder": {
|
||||||
|
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder": {
|
||||||
|
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.ltx2_audio_vae.LTX2Vocoder": {
|
||||||
|
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.ltx2_text_encoder.Embeddings1DConnector": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.ltx2_text_encoder.LTX2TextEncoder": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"transformers.models.gemma3.modeling_gemma3.Gemma3MultiModalProjector": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
}
|
||||||
@@ -1,2 +0,0 @@
|
|||||||
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager, FluxMultiControlNetManager
|
|
||||||
from .processors import Annotator
|
|
||||||
@@ -1,91 +0,0 @@
|
|||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from .processors import Processor_id
|
|
||||||
|
|
||||||
|
|
||||||
class ControlNetConfigUnit:
|
|
||||||
def __init__(self, processor_id: Processor_id, model_path, scale=1.0, skip_processor=False):
|
|
||||||
self.processor_id = processor_id
|
|
||||||
self.model_path = model_path
|
|
||||||
self.scale = scale
|
|
||||||
self.skip_processor = skip_processor
|
|
||||||
|
|
||||||
|
|
||||||
class ControlNetUnit:
|
|
||||||
def __init__(self, processor, model, scale=1.0):
|
|
||||||
self.processor = processor
|
|
||||||
self.model = model
|
|
||||||
self.scale = scale
|
|
||||||
|
|
||||||
|
|
||||||
class MultiControlNetManager:
|
|
||||||
def __init__(self, controlnet_units=[]):
|
|
||||||
self.processors = [unit.processor for unit in controlnet_units]
|
|
||||||
self.models = [unit.model for unit in controlnet_units]
|
|
||||||
self.scales = [unit.scale for unit in controlnet_units]
|
|
||||||
|
|
||||||
def cpu(self):
|
|
||||||
for model in self.models:
|
|
||||||
model.cpu()
|
|
||||||
|
|
||||||
def to(self, device):
|
|
||||||
for model in self.models:
|
|
||||||
model.to(device)
|
|
||||||
for processor in self.processors:
|
|
||||||
processor.to(device)
|
|
||||||
|
|
||||||
def process_image(self, image, processor_id=None):
|
|
||||||
if processor_id is None:
|
|
||||||
processed_image = [processor(image) for processor in self.processors]
|
|
||||||
else:
|
|
||||||
processed_image = [self.processors[processor_id](image)]
|
|
||||||
processed_image = torch.concat([
|
|
||||||
torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
|
|
||||||
for image_ in processed_image
|
|
||||||
], dim=0)
|
|
||||||
return processed_image
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
sample, timestep, encoder_hidden_states, conditionings,
|
|
||||||
tiled=False, tile_size=64, tile_stride=32, **kwargs
|
|
||||||
):
|
|
||||||
res_stack = None
|
|
||||||
for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
|
|
||||||
res_stack_ = model(
|
|
||||||
sample, timestep, encoder_hidden_states, conditioning, **kwargs,
|
|
||||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
|
|
||||||
processor_id=processor.processor_id
|
|
||||||
)
|
|
||||||
res_stack_ = [res * scale for res in res_stack_]
|
|
||||||
if res_stack is None:
|
|
||||||
res_stack = res_stack_
|
|
||||||
else:
|
|
||||||
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
|
|
||||||
return res_stack
|
|
||||||
|
|
||||||
|
|
||||||
class FluxMultiControlNetManager(MultiControlNetManager):
|
|
||||||
def __init__(self, controlnet_units=[]):
|
|
||||||
super().__init__(controlnet_units=controlnet_units)
|
|
||||||
|
|
||||||
def process_image(self, image, processor_id=None):
|
|
||||||
if processor_id is None:
|
|
||||||
processed_image = [processor(image) for processor in self.processors]
|
|
||||||
else:
|
|
||||||
processed_image = [self.processors[processor_id](image)]
|
|
||||||
return processed_image
|
|
||||||
|
|
||||||
def __call__(self, conditionings, **kwargs):
|
|
||||||
res_stack, single_res_stack = None, None
|
|
||||||
for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
|
|
||||||
res_stack_, single_res_stack_ = model(controlnet_conditioning=conditioning, processor_id=processor.processor_id, **kwargs)
|
|
||||||
res_stack_ = [res * scale for res in res_stack_]
|
|
||||||
single_res_stack_ = [res * scale for res in single_res_stack_]
|
|
||||||
if res_stack is None:
|
|
||||||
res_stack = res_stack_
|
|
||||||
single_res_stack = single_res_stack_
|
|
||||||
else:
|
|
||||||
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
|
|
||||||
single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)]
|
|
||||||
return res_stack, single_res_stack
|
|
||||||
6
diffsynth/core/__init__.py
Normal file
6
diffsynth/core/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from .attention import *
|
||||||
|
from .data import *
|
||||||
|
from .gradient import *
|
||||||
|
from .loader import *
|
||||||
|
from .vram import *
|
||||||
|
from .device import *
|
||||||
1
diffsynth/core/attention/__init__.py
Normal file
1
diffsynth/core/attention/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .attention import attention_forward
|
||||||
121
diffsynth/core/attention/attention.py
Normal file
121
diffsynth/core/attention/attention.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
import torch, os
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
import flash_attn_interface
|
||||||
|
FLASH_ATTN_3_AVAILABLE = True
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
FLASH_ATTN_3_AVAILABLE = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import flash_attn
|
||||||
|
FLASH_ATTN_2_AVAILABLE = True
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
FLASH_ATTN_2_AVAILABLE = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
from sageattention import sageattn
|
||||||
|
SAGE_ATTN_AVAILABLE = True
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
SAGE_ATTN_AVAILABLE = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import xformers.ops as xops
|
||||||
|
XFORMERS_AVAILABLE = True
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
XFORMERS_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_attention_priority():
|
||||||
|
if os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION') is not None:
|
||||||
|
return os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION').lower()
|
||||||
|
elif FLASH_ATTN_3_AVAILABLE:
|
||||||
|
return "flash_attention_3"
|
||||||
|
elif FLASH_ATTN_2_AVAILABLE:
|
||||||
|
return "flash_attention_2"
|
||||||
|
elif SAGE_ATTN_AVAILABLE:
|
||||||
|
return "sage_attention"
|
||||||
|
elif XFORMERS_AVAILABLE:
|
||||||
|
return "xformers"
|
||||||
|
else:
|
||||||
|
return "torch"
|
||||||
|
|
||||||
|
|
||||||
|
ATTENTION_IMPLEMENTATION = initialize_attention_priority()
|
||||||
|
|
||||||
|
|
||||||
|
def rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", required_in_pattern="b n s d", dims=None):
|
||||||
|
dims = {} if dims is None else dims
|
||||||
|
if q_pattern != required_in_pattern:
|
||||||
|
q = rearrange(q, f"{q_pattern} -> {required_in_pattern}", **dims)
|
||||||
|
if k_pattern != required_in_pattern:
|
||||||
|
k = rearrange(k, f"{k_pattern} -> {required_in_pattern}", **dims)
|
||||||
|
if v_pattern != required_in_pattern:
|
||||||
|
v = rearrange(v, f"{v_pattern} -> {required_in_pattern}", **dims)
|
||||||
|
return q, k, v
|
||||||
|
|
||||||
|
|
||||||
|
def rearrange_out(out: torch.Tensor, out_pattern="b n s d", required_out_pattern="b n s d", dims=None):
|
||||||
|
dims = {} if dims is None else dims
|
||||||
|
if out_pattern != required_out_pattern:
|
||||||
|
out = rearrange(out, f"{required_out_pattern} -> {out_pattern}", **dims)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def torch_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None):
|
||||||
|
required_in_pattern, required_out_pattern= "b n s d", "b n s d"
|
||||||
|
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
|
||||||
|
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask, scale=scale)
|
||||||
|
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def flash_attention_3(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
|
||||||
|
required_in_pattern, required_out_pattern= "b s n d", "b s n d"
|
||||||
|
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
|
||||||
|
out = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=scale)
|
||||||
|
if isinstance(out, tuple):
|
||||||
|
out = out[0]
|
||||||
|
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def flash_attention_2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
|
||||||
|
required_in_pattern, required_out_pattern= "b s n d", "b s n d"
|
||||||
|
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
|
||||||
|
out = flash_attn.flash_attn_func(q, k, v, softmax_scale=scale)
|
||||||
|
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def sage_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
|
||||||
|
required_in_pattern, required_out_pattern= "b n s d", "b n s d"
|
||||||
|
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
|
||||||
|
out = sageattn(q, k, v, sm_scale=scale)
|
||||||
|
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def xformers_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
|
||||||
|
required_in_pattern, required_out_pattern= "b s n d", "b s n d"
|
||||||
|
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
|
||||||
|
out = xops.memory_efficient_attention(q, k, v, scale=scale)
|
||||||
|
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None, compatibility_mode=False):
|
||||||
|
if compatibility_mode or (attn_mask is not None):
|
||||||
|
return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, attn_mask=attn_mask, scale=scale)
|
||||||
|
else:
|
||||||
|
if ATTENTION_IMPLEMENTATION == "flash_attention_3":
|
||||||
|
return flash_attention_3(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
|
||||||
|
elif ATTENTION_IMPLEMENTATION == "flash_attention_2":
|
||||||
|
return flash_attention_2(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
|
||||||
|
elif ATTENTION_IMPLEMENTATION == "sage_attention":
|
||||||
|
return sage_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
|
||||||
|
elif ATTENTION_IMPLEMENTATION == "xformers":
|
||||||
|
return xformers_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
|
||||||
|
else:
|
||||||
|
return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
|
||||||
1
diffsynth/core/data/__init__.py
Normal file
1
diffsynth/core/data/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .unified_dataset import UnifiedDataset
|
||||||
237
diffsynth/core/data/operators.py
Normal file
237
diffsynth/core/data/operators.py
Normal file
@@ -0,0 +1,237 @@
|
|||||||
|
import torch, torchvision, imageio, os
|
||||||
|
import imageio.v3 as iio
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
class DataProcessingPipeline:
|
||||||
|
def __init__(self, operators=None):
|
||||||
|
self.operators: list[DataProcessingOperator] = [] if operators is None else operators
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
for operator in self.operators:
|
||||||
|
data = operator(data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def __rshift__(self, pipe):
|
||||||
|
if isinstance(pipe, DataProcessingOperator):
|
||||||
|
pipe = DataProcessingPipeline([pipe])
|
||||||
|
return DataProcessingPipeline(self.operators + pipe.operators)
|
||||||
|
|
||||||
|
|
||||||
|
class DataProcessingOperator:
|
||||||
|
def __call__(self, data):
|
||||||
|
raise NotImplementedError("DataProcessingOperator cannot be called directly.")
|
||||||
|
|
||||||
|
def __rshift__(self, pipe):
|
||||||
|
if isinstance(pipe, DataProcessingOperator):
|
||||||
|
pipe = DataProcessingPipeline([pipe])
|
||||||
|
return DataProcessingPipeline([self]).__rshift__(pipe)
|
||||||
|
|
||||||
|
|
||||||
|
class DataProcessingOperatorRaw(DataProcessingOperator):
|
||||||
|
def __call__(self, data):
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class ToInt(DataProcessingOperator):
|
||||||
|
def __call__(self, data):
|
||||||
|
return int(data)
|
||||||
|
|
||||||
|
|
||||||
|
class ToFloat(DataProcessingOperator):
|
||||||
|
def __call__(self, data):
|
||||||
|
return float(data)
|
||||||
|
|
||||||
|
|
||||||
|
class ToStr(DataProcessingOperator):
|
||||||
|
def __init__(self, none_value=""):
|
||||||
|
self.none_value = none_value
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
if data is None: data = self.none_value
|
||||||
|
return str(data)
|
||||||
|
|
||||||
|
|
||||||
|
class LoadImage(DataProcessingOperator):
|
||||||
|
def __init__(self, convert_RGB=True, convert_RGBA=False):
|
||||||
|
self.convert_RGB = convert_RGB
|
||||||
|
self.convert_RGBA = convert_RGBA
|
||||||
|
|
||||||
|
def __call__(self, data: str):
|
||||||
|
image = Image.open(data)
|
||||||
|
if self.convert_RGB: image = image.convert("RGB")
|
||||||
|
if self.convert_RGBA: image = image.convert("RGBA")
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
class ImageCropAndResize(DataProcessingOperator):
|
||||||
|
def __init__(self, height=None, width=None, max_pixels=None, height_division_factor=1, width_division_factor=1):
|
||||||
|
self.height = height
|
||||||
|
self.width = width
|
||||||
|
self.max_pixels = max_pixels
|
||||||
|
self.height_division_factor = height_division_factor
|
||||||
|
self.width_division_factor = width_division_factor
|
||||||
|
|
||||||
|
def crop_and_resize(self, image, target_height, target_width):
|
||||||
|
width, height = image.size
|
||||||
|
scale = max(target_width / width, target_height / height)
|
||||||
|
image = torchvision.transforms.functional.resize(
|
||||||
|
image,
|
||||||
|
(round(height*scale), round(width*scale)),
|
||||||
|
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
|
||||||
|
)
|
||||||
|
image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
|
||||||
|
return image
|
||||||
|
|
||||||
|
def get_height_width(self, image):
|
||||||
|
if self.height is None or self.width is None:
|
||||||
|
width, height = image.size
|
||||||
|
if width * height > self.max_pixels:
|
||||||
|
scale = (width * height / self.max_pixels) ** 0.5
|
||||||
|
height, width = int(height / scale), int(width / scale)
|
||||||
|
height = height // self.height_division_factor * self.height_division_factor
|
||||||
|
width = width // self.width_division_factor * self.width_division_factor
|
||||||
|
else:
|
||||||
|
height, width = self.height, self.width
|
||||||
|
return height, width
|
||||||
|
|
||||||
|
def __call__(self, data: Image.Image):
|
||||||
|
image = self.crop_and_resize(data, *self.get_height_width(data))
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
class ToList(DataProcessingOperator):
|
||||||
|
def __call__(self, data):
|
||||||
|
return [data]
|
||||||
|
|
||||||
|
|
||||||
|
class LoadVideo(DataProcessingOperator):
|
||||||
|
def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
|
||||||
|
self.num_frames = num_frames
|
||||||
|
self.time_division_factor = time_division_factor
|
||||||
|
self.time_division_remainder = time_division_remainder
|
||||||
|
# frame_processor is build in the video loader for high efficiency.
|
||||||
|
self.frame_processor = frame_processor
|
||||||
|
|
||||||
|
def get_num_frames(self, reader):
|
||||||
|
num_frames = self.num_frames
|
||||||
|
if int(reader.count_frames()) < num_frames:
|
||||||
|
num_frames = int(reader.count_frames())
|
||||||
|
while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
|
||||||
|
num_frames -= 1
|
||||||
|
return num_frames
|
||||||
|
|
||||||
|
def __call__(self, data: str):
|
||||||
|
reader = imageio.get_reader(data)
|
||||||
|
num_frames = self.get_num_frames(reader)
|
||||||
|
frames = []
|
||||||
|
for frame_id in range(num_frames):
|
||||||
|
frame = reader.get_data(frame_id)
|
||||||
|
frame = Image.fromarray(frame)
|
||||||
|
frame = self.frame_processor(frame)
|
||||||
|
frames.append(frame)
|
||||||
|
reader.close()
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
class SequencialProcess(DataProcessingOperator):
|
||||||
|
def __init__(self, operator=lambda x: x):
|
||||||
|
self.operator = operator
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
return [self.operator(i) for i in data]
|
||||||
|
|
||||||
|
|
||||||
|
class LoadGIF(DataProcessingOperator):
|
||||||
|
def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
|
||||||
|
self.num_frames = num_frames
|
||||||
|
self.time_division_factor = time_division_factor
|
||||||
|
self.time_division_remainder = time_division_remainder
|
||||||
|
# frame_processor is build in the video loader for high efficiency.
|
||||||
|
self.frame_processor = frame_processor
|
||||||
|
|
||||||
|
def get_num_frames(self, path):
|
||||||
|
num_frames = self.num_frames
|
||||||
|
images = iio.imread(path, mode="RGB")
|
||||||
|
if len(images) < num_frames:
|
||||||
|
num_frames = len(images)
|
||||||
|
while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
|
||||||
|
num_frames -= 1
|
||||||
|
return num_frames
|
||||||
|
|
||||||
|
def __call__(self, data: str):
|
||||||
|
num_frames = self.get_num_frames(data)
|
||||||
|
frames = []
|
||||||
|
images = iio.imread(data, mode="RGB")
|
||||||
|
for img in images:
|
||||||
|
frame = Image.fromarray(img)
|
||||||
|
frame = self.frame_processor(frame)
|
||||||
|
frames.append(frame)
|
||||||
|
if len(frames) >= num_frames:
|
||||||
|
break
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
class RouteByExtensionName(DataProcessingOperator):
|
||||||
|
def __init__(self, operator_map):
|
||||||
|
self.operator_map = operator_map
|
||||||
|
|
||||||
|
def __call__(self, data: str):
|
||||||
|
file_ext_name = data.split(".")[-1].lower()
|
||||||
|
for ext_names, operator in self.operator_map:
|
||||||
|
if ext_names is None or file_ext_name in ext_names:
|
||||||
|
return operator(data)
|
||||||
|
raise ValueError(f"Unsupported file: {data}")
|
||||||
|
|
||||||
|
|
||||||
|
class RouteByType(DataProcessingOperator):
|
||||||
|
def __init__(self, operator_map):
|
||||||
|
self.operator_map = operator_map
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
for dtype, operator in self.operator_map:
|
||||||
|
if dtype is None or isinstance(data, dtype):
|
||||||
|
return operator(data)
|
||||||
|
raise ValueError(f"Unsupported data: {data}")
|
||||||
|
|
||||||
|
|
||||||
|
class LoadTorchPickle(DataProcessingOperator):
|
||||||
|
def __init__(self, map_location="cpu"):
|
||||||
|
self.map_location = map_location
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
return torch.load(data, map_location=self.map_location, weights_only=False)
|
||||||
|
|
||||||
|
|
||||||
|
class ToAbsolutePath(DataProcessingOperator):
|
||||||
|
def __init__(self, base_path=""):
|
||||||
|
self.base_path = base_path
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
return os.path.join(self.base_path, data)
|
||||||
|
|
||||||
|
|
||||||
|
class LoadAudio(DataProcessingOperator):
|
||||||
|
def __init__(self, sr=16000):
|
||||||
|
self.sr = sr
|
||||||
|
def __call__(self, data: str):
|
||||||
|
import librosa
|
||||||
|
input_audio, sample_rate = librosa.load(data, sr=self.sr)
|
||||||
|
return input_audio
|
||||||
|
|
||||||
|
|
||||||
|
class LoadAudioWithTorchaudio(DataProcessingOperator):
|
||||||
|
def __init__(self, duration=5):
|
||||||
|
self.duration = duration
|
||||||
|
|
||||||
|
def __call__(self, data: str):
|
||||||
|
import torchaudio
|
||||||
|
waveform, sample_rate = torchaudio.load(data)
|
||||||
|
target_samples = int(self.duration * sample_rate)
|
||||||
|
current_samples = waveform.shape[-1]
|
||||||
|
if current_samples > target_samples:
|
||||||
|
waveform = waveform[..., :target_samples]
|
||||||
|
elif current_samples < target_samples:
|
||||||
|
padding = target_samples - current_samples
|
||||||
|
waveform = torch.nn.functional.pad(waveform, (0, padding))
|
||||||
|
return waveform, sample_rate
|
||||||
116
diffsynth/core/data/unified_dataset.py
Normal file
116
diffsynth/core/data/unified_dataset.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
from .operators import *
|
||||||
|
import torch, json, pandas
|
||||||
|
|
||||||
|
|
||||||
|
class UnifiedDataset(torch.utils.data.Dataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_path=None, metadata_path=None,
|
||||||
|
repeat=1,
|
||||||
|
data_file_keys=tuple(),
|
||||||
|
main_data_operator=lambda x: x,
|
||||||
|
special_operator_map=None,
|
||||||
|
max_data_items=None,
|
||||||
|
):
|
||||||
|
self.base_path = base_path
|
||||||
|
self.metadata_path = metadata_path
|
||||||
|
self.repeat = repeat
|
||||||
|
self.data_file_keys = data_file_keys
|
||||||
|
self.main_data_operator = main_data_operator
|
||||||
|
self.cached_data_operator = LoadTorchPickle()
|
||||||
|
self.special_operator_map = {} if special_operator_map is None else special_operator_map
|
||||||
|
self.max_data_items = max_data_items
|
||||||
|
self.data = []
|
||||||
|
self.cached_data = []
|
||||||
|
self.load_from_cache = metadata_path is None
|
||||||
|
self.load_metadata(metadata_path)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def default_image_operator(
|
||||||
|
base_path="",
|
||||||
|
max_pixels=1920*1080, height=None, width=None,
|
||||||
|
height_division_factor=16, width_division_factor=16,
|
||||||
|
):
|
||||||
|
return RouteByType(operator_map=[
|
||||||
|
(str, ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)),
|
||||||
|
(list, SequencialProcess(ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor))),
|
||||||
|
])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def default_video_operator(
|
||||||
|
base_path="",
|
||||||
|
max_pixels=1920*1080, height=None, width=None,
|
||||||
|
height_division_factor=16, width_division_factor=16,
|
||||||
|
num_frames=81, time_division_factor=4, time_division_remainder=1,
|
||||||
|
):
|
||||||
|
return RouteByType(operator_map=[
|
||||||
|
(str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[
|
||||||
|
(("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()),
|
||||||
|
(("gif",), LoadGIF(
|
||||||
|
num_frames, time_division_factor, time_division_remainder,
|
||||||
|
frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
|
||||||
|
)),
|
||||||
|
(("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo(
|
||||||
|
num_frames, time_division_factor, time_division_remainder,
|
||||||
|
frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
|
||||||
|
)),
|
||||||
|
])),
|
||||||
|
])
|
||||||
|
|
||||||
|
def search_for_cached_data_files(self, path):
|
||||||
|
for file_name in os.listdir(path):
|
||||||
|
subpath = os.path.join(path, file_name)
|
||||||
|
if os.path.isdir(subpath):
|
||||||
|
self.search_for_cached_data_files(subpath)
|
||||||
|
elif subpath.endswith(".pth"):
|
||||||
|
self.cached_data.append(subpath)
|
||||||
|
|
||||||
|
def load_metadata(self, metadata_path):
|
||||||
|
if metadata_path is None:
|
||||||
|
print("No metadata_path. Searching for cached data files.")
|
||||||
|
self.search_for_cached_data_files(self.base_path)
|
||||||
|
print(f"{len(self.cached_data)} cached data files found.")
|
||||||
|
elif metadata_path.endswith(".json"):
|
||||||
|
with open(metadata_path, "r") as f:
|
||||||
|
metadata = json.load(f)
|
||||||
|
self.data = metadata
|
||||||
|
elif metadata_path.endswith(".jsonl"):
|
||||||
|
metadata = []
|
||||||
|
with open(metadata_path, 'r') as f:
|
||||||
|
for line in f:
|
||||||
|
metadata.append(json.loads(line.strip()))
|
||||||
|
self.data = metadata
|
||||||
|
else:
|
||||||
|
metadata = pandas.read_csv(metadata_path)
|
||||||
|
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
|
||||||
|
|
||||||
|
def __getitem__(self, data_id):
|
||||||
|
if self.load_from_cache:
|
||||||
|
data = self.cached_data[data_id % len(self.cached_data)]
|
||||||
|
data = self.cached_data_operator(data)
|
||||||
|
else:
|
||||||
|
data = self.data[data_id % len(self.data)].copy()
|
||||||
|
for key in self.data_file_keys:
|
||||||
|
if key in data:
|
||||||
|
if key in self.special_operator_map:
|
||||||
|
data[key] = self.special_operator_map[key](data[key])
|
||||||
|
elif key in self.data_file_keys:
|
||||||
|
data[key] = self.main_data_operator(data[key])
|
||||||
|
return data
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
if self.max_data_items is not None:
|
||||||
|
return self.max_data_items
|
||||||
|
elif self.load_from_cache:
|
||||||
|
return len(self.cached_data) * self.repeat
|
||||||
|
else:
|
||||||
|
return len(self.data) * self.repeat
|
||||||
|
|
||||||
|
def check_data_equal(self, data1, data2):
|
||||||
|
# Debug only
|
||||||
|
if len(data1) != len(data2):
|
||||||
|
return False
|
||||||
|
for k in data1:
|
||||||
|
if data1[k] != data2[k]:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
2
diffsynth/core/device/__init__.py
Normal file
2
diffsynth/core/device/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type, get_device_name
|
||||||
|
from .npu_compatible_device import IS_NPU_AVAILABLE, IS_CUDA_AVAILABLE
|
||||||
107
diffsynth/core/device/npu_compatible_device.py
Normal file
107
diffsynth/core/device/npu_compatible_device.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
import importlib
|
||||||
|
import torch
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def is_torch_npu_available():
|
||||||
|
return importlib.util.find_spec("torch_npu") is not None
|
||||||
|
|
||||||
|
|
||||||
|
IS_CUDA_AVAILABLE = torch.cuda.is_available()
|
||||||
|
IS_NPU_AVAILABLE = is_torch_npu_available() and torch.npu.is_available()
|
||||||
|
|
||||||
|
if IS_NPU_AVAILABLE:
|
||||||
|
import torch_npu
|
||||||
|
|
||||||
|
torch.npu.config.allow_internal_format = False
|
||||||
|
|
||||||
|
|
||||||
|
def get_device_type() -> str:
|
||||||
|
"""Get device type based on current machine, currently only support CPU, CUDA, NPU."""
|
||||||
|
if IS_CUDA_AVAILABLE:
|
||||||
|
device = "cuda"
|
||||||
|
elif IS_NPU_AVAILABLE:
|
||||||
|
device = "npu"
|
||||||
|
else:
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
|
return device
|
||||||
|
|
||||||
|
|
||||||
|
def get_torch_device() -> Any:
|
||||||
|
"""Get torch attribute based on device type, e.g. torch.cuda or torch.npu"""
|
||||||
|
device_name = get_device_type()
|
||||||
|
|
||||||
|
try:
|
||||||
|
return getattr(torch, device_name)
|
||||||
|
except AttributeError:
|
||||||
|
print(f"Device namespace '{device_name}' not found in torch, try to load 'torch.cuda'.")
|
||||||
|
return torch.cuda
|
||||||
|
|
||||||
|
|
||||||
|
def get_device_id() -> int:
|
||||||
|
"""Get current device id based on device type."""
|
||||||
|
return get_torch_device().current_device()
|
||||||
|
|
||||||
|
|
||||||
|
def get_device_name() -> str:
|
||||||
|
"""Get current device name based on device type."""
|
||||||
|
return f"{get_device_type()}:{get_device_id()}"
|
||||||
|
|
||||||
|
|
||||||
|
def synchronize() -> None:
|
||||||
|
"""Execute torch synchronize operation."""
|
||||||
|
get_torch_device().synchronize()
|
||||||
|
|
||||||
|
|
||||||
|
def empty_cache() -> None:
|
||||||
|
"""Execute torch empty cache operation."""
|
||||||
|
get_torch_device().empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
def get_nccl_backend() -> str:
|
||||||
|
"""Return distributed communication backend type based on device type."""
|
||||||
|
if IS_CUDA_AVAILABLE:
|
||||||
|
return "nccl"
|
||||||
|
elif IS_NPU_AVAILABLE:
|
||||||
|
return "hccl"
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"No available distributed communication backend found on device type {get_device_type()}.")
|
||||||
|
|
||||||
|
|
||||||
|
def enable_high_precision_for_bf16():
|
||||||
|
"""
|
||||||
|
Set high accumulation dtype for matmul and reduction.
|
||||||
|
"""
|
||||||
|
if IS_CUDA_AVAILABLE:
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = False
|
||||||
|
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
|
||||||
|
|
||||||
|
if IS_NPU_AVAILABLE:
|
||||||
|
torch.npu.matmul.allow_tf32 = False
|
||||||
|
torch.npu.matmul.allow_bf16_reduced_precision_reduction = False
|
||||||
|
|
||||||
|
|
||||||
|
def parse_device_type(device):
|
||||||
|
if isinstance(device, str):
|
||||||
|
if device.startswith("cuda"):
|
||||||
|
return "cuda"
|
||||||
|
elif device.startswith("npu"):
|
||||||
|
return "npu"
|
||||||
|
else:
|
||||||
|
return "cpu"
|
||||||
|
elif isinstance(device, torch.device):
|
||||||
|
return device.type
|
||||||
|
|
||||||
|
|
||||||
|
def parse_nccl_backend(device_type):
|
||||||
|
if device_type == "cuda":
|
||||||
|
return "nccl"
|
||||||
|
elif device_type == "npu":
|
||||||
|
return "hccl"
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"No available distributed communication backend found on device type {device_type}.")
|
||||||
|
|
||||||
|
|
||||||
|
def get_available_device_type():
|
||||||
|
return get_device_type()
|
||||||
1
diffsynth/core/gradient/__init__.py
Normal file
1
diffsynth/core/gradient/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .gradient_checkpoint import gradient_checkpoint_forward
|
||||||
34
diffsynth/core/gradient/gradient_checkpoint.py
Normal file
34
diffsynth/core/gradient/gradient_checkpoint.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs, **kwargs):
|
||||||
|
return module(*inputs, **kwargs)
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
|
||||||
|
def gradient_checkpoint_forward(
|
||||||
|
model,
|
||||||
|
use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if use_gradient_checkpointing_offload:
|
||||||
|
with torch.autograd.graph.save_on_cpu():
|
||||||
|
model_output = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(model),
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
use_reentrant=False,
|
||||||
|
)
|
||||||
|
elif use_gradient_checkpointing:
|
||||||
|
model_output = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(model),
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
use_reentrant=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model_output = model(*args, **kwargs)
|
||||||
|
return model_output
|
||||||
3
diffsynth/core/loader/__init__.py
Normal file
3
diffsynth/core/loader/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .file import load_state_dict, hash_state_dict_keys, hash_model_file
|
||||||
|
from .model import load_model, load_model_with_disk_offload
|
||||||
|
from .config import ModelConfig
|
||||||
119
diffsynth/core/loader/config.py
Normal file
119
diffsynth/core/loader/config.py
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
import torch, glob, os
|
||||||
|
from typing import Optional, Union, Dict
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from modelscope import snapshot_download
|
||||||
|
from huggingface_hub import snapshot_download as hf_snapshot_download
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelConfig:
|
||||||
|
path: Union[str, list[str]] = None
|
||||||
|
model_id: str = None
|
||||||
|
origin_file_pattern: Union[str, list[str]] = None
|
||||||
|
download_source: str = None
|
||||||
|
local_model_path: str = None
|
||||||
|
skip_download: bool = None
|
||||||
|
offload_device: Optional[Union[str, torch.device]] = None
|
||||||
|
offload_dtype: Optional[torch.dtype] = None
|
||||||
|
onload_device: Optional[Union[str, torch.device]] = None
|
||||||
|
onload_dtype: Optional[torch.dtype] = None
|
||||||
|
preparing_device: Optional[Union[str, torch.device]] = None
|
||||||
|
preparing_dtype: Optional[torch.dtype] = None
|
||||||
|
computation_device: Optional[Union[str, torch.device]] = None
|
||||||
|
computation_dtype: Optional[torch.dtype] = None
|
||||||
|
clear_parameters: bool = False
|
||||||
|
state_dict: Dict[str, torch.Tensor] = None
|
||||||
|
|
||||||
|
def check_input(self):
|
||||||
|
if self.path is None and self.model_id is None:
|
||||||
|
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""")
|
||||||
|
|
||||||
|
def parse_original_file_pattern(self):
|
||||||
|
if self.origin_file_pattern in [None, "", "./"]:
|
||||||
|
return "*"
|
||||||
|
elif self.origin_file_pattern.endswith("/"):
|
||||||
|
return self.origin_file_pattern + "*"
|
||||||
|
else:
|
||||||
|
return self.origin_file_pattern
|
||||||
|
|
||||||
|
def parse_download_source(self):
|
||||||
|
if self.download_source is None:
|
||||||
|
if os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE') is not None:
|
||||||
|
return os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE')
|
||||||
|
else:
|
||||||
|
return "modelscope"
|
||||||
|
else:
|
||||||
|
return self.download_source
|
||||||
|
|
||||||
|
def parse_skip_download(self):
|
||||||
|
if self.skip_download is None:
|
||||||
|
if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD') is not None:
|
||||||
|
if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "true":
|
||||||
|
return True
|
||||||
|
elif os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "false":
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return self.skip_download
|
||||||
|
|
||||||
|
def download(self):
|
||||||
|
origin_file_pattern = self.parse_original_file_pattern()
|
||||||
|
downloaded_files = glob.glob(origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
|
||||||
|
download_source = self.parse_download_source()
|
||||||
|
if download_source.lower() == "modelscope":
|
||||||
|
snapshot_download(
|
||||||
|
self.model_id,
|
||||||
|
local_dir=os.path.join(self.local_model_path, self.model_id),
|
||||||
|
allow_file_pattern=origin_file_pattern,
|
||||||
|
ignore_file_pattern=downloaded_files,
|
||||||
|
local_files_only=False
|
||||||
|
)
|
||||||
|
elif download_source.lower() == "huggingface":
|
||||||
|
hf_snapshot_download(
|
||||||
|
self.model_id,
|
||||||
|
local_dir=os.path.join(self.local_model_path, self.model_id),
|
||||||
|
allow_patterns=origin_file_pattern,
|
||||||
|
ignore_patterns=downloaded_files,
|
||||||
|
local_files_only=False
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("`download_source` should be `modelscope` or `huggingface`.")
|
||||||
|
|
||||||
|
def require_downloading(self):
|
||||||
|
if self.path is not None:
|
||||||
|
return False
|
||||||
|
skip_download = self.parse_skip_download()
|
||||||
|
return not skip_download
|
||||||
|
|
||||||
|
def reset_local_model_path(self):
|
||||||
|
if os.environ.get('DIFFSYNTH_MODEL_BASE_PATH') is not None:
|
||||||
|
self.local_model_path = os.environ.get('DIFFSYNTH_MODEL_BASE_PATH')
|
||||||
|
elif self.local_model_path is None:
|
||||||
|
self.local_model_path = "./models"
|
||||||
|
|
||||||
|
def download_if_necessary(self):
|
||||||
|
self.check_input()
|
||||||
|
self.reset_local_model_path()
|
||||||
|
if self.require_downloading():
|
||||||
|
self.download()
|
||||||
|
if self.path is None:
|
||||||
|
if self.origin_file_pattern in [None, "", "./"]:
|
||||||
|
self.path = os.path.join(self.local_model_path, self.model_id)
|
||||||
|
else:
|
||||||
|
self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
|
||||||
|
if isinstance(self.path, list) and len(self.path) == 1:
|
||||||
|
self.path = self.path[0]
|
||||||
|
|
||||||
|
def vram_config(self):
|
||||||
|
return {
|
||||||
|
"offload_device": self.offload_device,
|
||||||
|
"offload_dtype": self.offload_dtype,
|
||||||
|
"onload_device": self.onload_device,
|
||||||
|
"onload_dtype": self.onload_dtype,
|
||||||
|
"preparing_device": self.preparing_device,
|
||||||
|
"preparing_dtype": self.preparing_dtype,
|
||||||
|
"computation_device": self.computation_device,
|
||||||
|
"computation_dtype": self.computation_dtype,
|
||||||
|
}
|
||||||
130
diffsynth/core/loader/file.py
Normal file
130
diffsynth/core/loader/file.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
from safetensors import safe_open
|
||||||
|
import torch, hashlib
|
||||||
|
|
||||||
|
|
||||||
|
def load_state_dict(file_path, torch_dtype=None, device="cpu", pin_memory=False, verbose=0):
|
||||||
|
if isinstance(file_path, list):
|
||||||
|
state_dict = {}
|
||||||
|
for file_path_ in file_path:
|
||||||
|
state_dict.update(load_state_dict(file_path_, torch_dtype, device, pin_memory=pin_memory, verbose=verbose))
|
||||||
|
else:
|
||||||
|
if verbose >= 1:
|
||||||
|
print(f"Loading file [started]: {file_path}")
|
||||||
|
if file_path.endswith(".safetensors"):
|
||||||
|
state_dict = load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
|
||||||
|
else:
|
||||||
|
state_dict = load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
|
||||||
|
# If load state dict in CPU memory, `pin_memory=True` will make `model.to("cuda")` faster.
|
||||||
|
if pin_memory:
|
||||||
|
for i in state_dict:
|
||||||
|
state_dict[i] = state_dict[i].pin_memory()
|
||||||
|
if verbose >= 1:
|
||||||
|
print(f"Loading file [done]: {file_path}")
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
|
||||||
|
state_dict = {}
|
||||||
|
with safe_open(file_path, framework="pt", device=str(device)) as f:
|
||||||
|
for k in f.keys():
|
||||||
|
state_dict[k] = f.get_tensor(k)
|
||||||
|
if torch_dtype is not None:
|
||||||
|
state_dict[k] = state_dict[k].to(torch_dtype)
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
|
||||||
|
state_dict = torch.load(file_path, map_location=device, weights_only=True)
|
||||||
|
if len(state_dict) == 1:
|
||||||
|
if "state_dict" in state_dict:
|
||||||
|
state_dict = state_dict["state_dict"]
|
||||||
|
elif "module" in state_dict:
|
||||||
|
state_dict = state_dict["module"]
|
||||||
|
elif "model_state" in state_dict:
|
||||||
|
state_dict = state_dict["model_state"]
|
||||||
|
if torch_dtype is not None:
|
||||||
|
for i in state_dict:
|
||||||
|
if isinstance(state_dict[i], torch.Tensor):
|
||||||
|
state_dict[i] = state_dict[i].to(torch_dtype)
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
|
||||||
|
keys = []
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
if isinstance(key, str):
|
||||||
|
if isinstance(value, torch.Tensor):
|
||||||
|
if with_shape:
|
||||||
|
shape = "_".join(map(str, list(value.shape)))
|
||||||
|
keys.append(key + ":" + shape)
|
||||||
|
keys.append(key)
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
|
||||||
|
keys.sort()
|
||||||
|
keys_str = ",".join(keys)
|
||||||
|
return keys_str
|
||||||
|
|
||||||
|
|
||||||
|
def hash_state_dict_keys(state_dict, with_shape=True):
|
||||||
|
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
|
||||||
|
keys_str = keys_str.encode(encoding="UTF-8")
|
||||||
|
return hashlib.md5(keys_str).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def load_keys_dict(file_path):
|
||||||
|
if isinstance(file_path, list):
|
||||||
|
state_dict = {}
|
||||||
|
for file_path_ in file_path:
|
||||||
|
state_dict.update(load_keys_dict(file_path_))
|
||||||
|
return state_dict
|
||||||
|
if file_path.endswith(".safetensors"):
|
||||||
|
return load_keys_dict_from_safetensors(file_path)
|
||||||
|
else:
|
||||||
|
return load_keys_dict_from_bin(file_path)
|
||||||
|
|
||||||
|
|
||||||
|
def load_keys_dict_from_safetensors(file_path):
|
||||||
|
keys_dict = {}
|
||||||
|
with safe_open(file_path, framework="pt", device="cpu") as f:
|
||||||
|
for k in f.keys():
|
||||||
|
keys_dict[k] = f.get_slice(k).get_shape()
|
||||||
|
return keys_dict
|
||||||
|
|
||||||
|
|
||||||
|
def convert_state_dict_to_keys_dict(state_dict):
|
||||||
|
keys_dict = {}
|
||||||
|
for k, v in state_dict.items():
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
keys_dict[k] = list(v.shape)
|
||||||
|
else:
|
||||||
|
keys_dict[k] = convert_state_dict_to_keys_dict(v)
|
||||||
|
return keys_dict
|
||||||
|
|
||||||
|
|
||||||
|
def load_keys_dict_from_bin(file_path):
|
||||||
|
state_dict = load_state_dict_from_bin(file_path)
|
||||||
|
keys_dict = convert_state_dict_to_keys_dict(state_dict)
|
||||||
|
return keys_dict
|
||||||
|
|
||||||
|
|
||||||
|
def convert_keys_dict_to_single_str(state_dict, with_shape=True):
|
||||||
|
keys = []
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
if isinstance(key, str):
|
||||||
|
if isinstance(value, dict):
|
||||||
|
keys.append(key + "|" + convert_keys_dict_to_single_str(value, with_shape=with_shape))
|
||||||
|
else:
|
||||||
|
if with_shape:
|
||||||
|
shape = "_".join(map(str, list(value)))
|
||||||
|
keys.append(key + ":" + shape)
|
||||||
|
keys.append(key)
|
||||||
|
keys.sort()
|
||||||
|
keys_str = ",".join(keys)
|
||||||
|
return keys_str
|
||||||
|
|
||||||
|
|
||||||
|
def hash_model_file(path, with_shape=True):
|
||||||
|
keys_dict = load_keys_dict(path)
|
||||||
|
keys_str = convert_keys_dict_to_single_str(keys_dict, with_shape=with_shape)
|
||||||
|
keys_str = keys_str.encode(encoding="UTF-8")
|
||||||
|
return hashlib.md5(keys_str).hexdigest()
|
||||||
105
diffsynth/core/loader/model.py
Normal file
105
diffsynth/core/loader/model.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
from ..vram.initialization import skip_model_initialization
|
||||||
|
from ..vram.disk_map import DiskMap
|
||||||
|
from ..vram.layers import enable_vram_management
|
||||||
|
from .file import load_state_dict
|
||||||
|
import torch
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||||
|
from transformers.utils import ContextManagers
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
|
||||||
|
config = {} if config is None else config
|
||||||
|
with ContextManagers(get_init_context(torch_dtype=torch_dtype, device=device)):
|
||||||
|
model = model_class(**config)
|
||||||
|
# What is `module_map`?
|
||||||
|
# This is a module mapping table for VRAM management.
|
||||||
|
if module_map is not None:
|
||||||
|
devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"], vram_config["computation_device"]]
|
||||||
|
device = [d for d in devices if d != "disk"][0]
|
||||||
|
dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
|
||||||
|
dtype = [d for d in dtypes if d != "disk"][0]
|
||||||
|
if vram_config["offload_device"] != "disk":
|
||||||
|
if state_dict is None: state_dict = DiskMap(path, device, torch_dtype=dtype)
|
||||||
|
if state_dict_converter is not None:
|
||||||
|
state_dict = state_dict_converter(state_dict)
|
||||||
|
else:
|
||||||
|
state_dict = {i: state_dict[i] for i in state_dict}
|
||||||
|
model.load_state_dict(state_dict, assign=True)
|
||||||
|
model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None, vram_limit=vram_limit)
|
||||||
|
else:
|
||||||
|
disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
|
||||||
|
model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=vram_limit)
|
||||||
|
else:
|
||||||
|
# Why do we use `DiskMap`?
|
||||||
|
# Sometimes a model file contains multiple models,
|
||||||
|
# and DiskMap can load only the parameters of a single model,
|
||||||
|
# avoiding the need to load all parameters in the file.
|
||||||
|
if state_dict is not None:
|
||||||
|
pass
|
||||||
|
elif use_disk_map:
|
||||||
|
state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
|
||||||
|
else:
|
||||||
|
state_dict = load_state_dict(path, torch_dtype, device)
|
||||||
|
# Why do we use `state_dict_converter`?
|
||||||
|
# Some models are saved in complex formats,
|
||||||
|
# and we need to convert the state dict into the appropriate format.
|
||||||
|
if state_dict_converter is not None:
|
||||||
|
state_dict = state_dict_converter(state_dict)
|
||||||
|
else:
|
||||||
|
state_dict = {i: state_dict[i] for i in state_dict}
|
||||||
|
# Why does DeepSpeed ZeRO Stage 3 need to be handled separately?
|
||||||
|
# Because at this stage, model parameters are partitioned across multiple GPUs.
|
||||||
|
# Loading them directly could lead to excessive GPU memory consumption.
|
||||||
|
if is_deepspeed_zero3_enabled():
|
||||||
|
from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model
|
||||||
|
_load_state_dict_into_zero3_model(model, state_dict)
|
||||||
|
else:
|
||||||
|
model.load_state_dict(state_dict, assign=True)
|
||||||
|
# Why do we call `to()`?
|
||||||
|
# Because some models override the behavior of `to()`,
|
||||||
|
# especially those from libraries like Transformers.
|
||||||
|
model = model.to(dtype=torch_dtype, device=device)
|
||||||
|
if hasattr(model, "eval"):
|
||||||
|
model = model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, module_map=None):
|
||||||
|
if isinstance(path, str):
|
||||||
|
path = [path]
|
||||||
|
config = {} if config is None else config
|
||||||
|
with skip_model_initialization():
|
||||||
|
model = model_class(**config)
|
||||||
|
if hasattr(model, "eval"):
|
||||||
|
model = model.eval()
|
||||||
|
disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": "disk",
|
||||||
|
"offload_device": "disk",
|
||||||
|
"onload_dtype": "disk",
|
||||||
|
"onload_device": "disk",
|
||||||
|
"preparing_dtype": torch.float8_e4m3fn,
|
||||||
|
"preparing_device": device,
|
||||||
|
"computation_dtype": torch_dtype,
|
||||||
|
"computation_device": device,
|
||||||
|
}
|
||||||
|
enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def get_init_context(torch_dtype, device):
|
||||||
|
if is_deepspeed_zero3_enabled():
|
||||||
|
from transformers.modeling_utils import set_zero3_state
|
||||||
|
import deepspeed
|
||||||
|
# Why do we use "deepspeed.zero.Init"?
|
||||||
|
# Weight segmentation of the model can be performed on the CPU side
|
||||||
|
# and loading the segmented weights onto the computing card
|
||||||
|
init_contexts = [deepspeed.zero.Init(remote_device=device, dtype=torch_dtype), set_zero3_state()]
|
||||||
|
else:
|
||||||
|
# Why do we use `skip_model_initialization`?
|
||||||
|
# It skips the random initialization of model parameters,
|
||||||
|
# thereby speeding up model loading and avoiding excessive memory usage.
|
||||||
|
init_contexts = [skip_model_initialization()]
|
||||||
|
|
||||||
|
return init_contexts
|
||||||
30
diffsynth/core/npu_patch/npu_fused_operator.py
Normal file
30
diffsynth/core/npu_patch/npu_fused_operator.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
import torch
|
||||||
|
from ..device.npu_compatible_device import get_device_type
|
||||||
|
try:
|
||||||
|
import torch_npu
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def rms_norm_forward_npu(self, hidden_states):
|
||||||
|
"npu rms fused operator for RMSNorm.forward from diffsynth\models\general_modules.py"
|
||||||
|
if hidden_states.dtype != self.weight.dtype:
|
||||||
|
hidden_states = hidden_states.to(self.weight.dtype)
|
||||||
|
return torch_npu.npu_rms_norm(hidden_states, self.weight, self.eps)[0]
|
||||||
|
|
||||||
|
|
||||||
|
def rms_norm_forward_transformers_npu(self, hidden_states):
|
||||||
|
"npu rms fused operator for transformers"
|
||||||
|
if hidden_states.dtype != self.weight.dtype:
|
||||||
|
hidden_states = hidden_states.to(self.weight.dtype)
|
||||||
|
return torch_npu.npu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
|
||||||
|
|
||||||
|
|
||||||
|
def rotary_emb_Zimage_npu(self, x_in: torch.Tensor, freqs_cis: torch.Tensor):
|
||||||
|
"npu rope fused operator for Zimage"
|
||||||
|
with torch.amp.autocast(get_device_type(), enabled=False):
|
||||||
|
freqs_cis = freqs_cis.unsqueeze(2)
|
||||||
|
cos, sin = torch.chunk(torch.view_as_real(freqs_cis), 2, dim=-1)
|
||||||
|
cos = cos.expand(-1, -1, -1, -1, 2).flatten(-2)
|
||||||
|
sin = sin.expand(-1, -1, -1, -1, 2).flatten(-2)
|
||||||
|
return torch_npu.npu_rotary_mul(x_in, cos, sin, rotary_mode="interleave").to(x_in)
|
||||||
2
diffsynth/core/vram/__init__.py
Normal file
2
diffsynth/core/vram/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
from .initialization import skip_model_initialization
|
||||||
|
from .layers import *
|
||||||
93
diffsynth/core/vram/disk_map.py
Normal file
93
diffsynth/core/vram/disk_map.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
from safetensors import safe_open
|
||||||
|
import torch, os
|
||||||
|
|
||||||
|
|
||||||
|
class SafetensorsCompatibleTensor:
|
||||||
|
def __init__(self, tensor):
|
||||||
|
self.tensor = tensor
|
||||||
|
|
||||||
|
def get_shape(self):
|
||||||
|
return list(self.tensor.shape)
|
||||||
|
|
||||||
|
|
||||||
|
class SafetensorsCompatibleBinaryLoader:
|
||||||
|
def __init__(self, path, device):
|
||||||
|
print("Detected non-safetensors files, which may cause slower loading. It's recommended to convert it to a safetensors file.")
|
||||||
|
self.state_dict = torch.load(path, weights_only=True, map_location=device)
|
||||||
|
|
||||||
|
def keys(self):
|
||||||
|
return self.state_dict.keys()
|
||||||
|
|
||||||
|
def get_tensor(self, name):
|
||||||
|
return self.state_dict[name]
|
||||||
|
|
||||||
|
def get_slice(self, name):
|
||||||
|
return SafetensorsCompatibleTensor(self.state_dict[name])
|
||||||
|
|
||||||
|
|
||||||
|
class DiskMap:
|
||||||
|
|
||||||
|
def __init__(self, path, device, torch_dtype=None, state_dict_converter=None, buffer_size=10**9):
|
||||||
|
self.path = path if isinstance(path, list) else [path]
|
||||||
|
self.device = device
|
||||||
|
self.torch_dtype = torch_dtype
|
||||||
|
if os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE') is not None:
|
||||||
|
self.buffer_size = int(os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE'))
|
||||||
|
else:
|
||||||
|
self.buffer_size = buffer_size
|
||||||
|
self.files = []
|
||||||
|
self.flush_files()
|
||||||
|
self.name_map = {}
|
||||||
|
for file_id, file in enumerate(self.files):
|
||||||
|
for name in file.keys():
|
||||||
|
self.name_map[name] = file_id
|
||||||
|
self.rename_dict = self.fetch_rename_dict(state_dict_converter)
|
||||||
|
|
||||||
|
def flush_files(self):
|
||||||
|
if len(self.files) == 0:
|
||||||
|
for path in self.path:
|
||||||
|
if path.endswith(".safetensors"):
|
||||||
|
self.files.append(safe_open(path, framework="pt", device=str(self.device)))
|
||||||
|
else:
|
||||||
|
self.files.append(SafetensorsCompatibleBinaryLoader(path, device=self.device))
|
||||||
|
else:
|
||||||
|
for i, path in enumerate(self.path):
|
||||||
|
if path.endswith(".safetensors"):
|
||||||
|
self.files[i] = safe_open(path, framework="pt", device=str(self.device))
|
||||||
|
self.num_params = 0
|
||||||
|
|
||||||
|
def __getitem__(self, name):
|
||||||
|
if self.rename_dict is not None: name = self.rename_dict[name]
|
||||||
|
file_id = self.name_map[name]
|
||||||
|
param = self.files[file_id].get_tensor(name)
|
||||||
|
if self.torch_dtype is not None and isinstance(param, torch.Tensor):
|
||||||
|
param = param.to(self.torch_dtype)
|
||||||
|
if isinstance(param, torch.Tensor) and param.device == "cpu":
|
||||||
|
param = param.clone()
|
||||||
|
if isinstance(param, torch.Tensor):
|
||||||
|
self.num_params += param.numel()
|
||||||
|
if self.num_params > self.buffer_size:
|
||||||
|
self.flush_files()
|
||||||
|
return param
|
||||||
|
|
||||||
|
def fetch_rename_dict(self, state_dict_converter):
|
||||||
|
if state_dict_converter is None:
|
||||||
|
return None
|
||||||
|
state_dict = {}
|
||||||
|
for file in self.files:
|
||||||
|
for name in file.keys():
|
||||||
|
state_dict[name] = name
|
||||||
|
state_dict = state_dict_converter(state_dict)
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
if self.rename_dict is not None:
|
||||||
|
return self.rename_dict.__iter__()
|
||||||
|
else:
|
||||||
|
return self.name_map.__iter__()
|
||||||
|
|
||||||
|
def __contains__(self, x):
|
||||||
|
if self.rename_dict is not None:
|
||||||
|
return x in self.rename_dict
|
||||||
|
else:
|
||||||
|
return x in self.name_map
|
||||||
21
diffsynth/core/vram/initialization.py
Normal file
21
diffsynth/core/vram/initialization.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
import torch
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def skip_model_initialization(device=torch.device("meta")):
|
||||||
|
|
||||||
|
def register_empty_parameter(module, name, param):
|
||||||
|
old_register_parameter(module, name, param)
|
||||||
|
if param is not None:
|
||||||
|
param_cls = type(module._parameters[name])
|
||||||
|
kwargs = module._parameters[name].__dict__
|
||||||
|
kwargs["requires_grad"] = param.requires_grad
|
||||||
|
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
|
||||||
|
|
||||||
|
old_register_parameter = torch.nn.Module.register_parameter
|
||||||
|
torch.nn.Module.register_parameter = register_empty_parameter
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
torch.nn.Module.register_parameter = old_register_parameter
|
||||||
479
diffsynth/core/vram/layers.py
Normal file
479
diffsynth/core/vram/layers.py
Normal file
@@ -0,0 +1,479 @@
|
|||||||
|
import torch, copy
|
||||||
|
from typing import Union
|
||||||
|
from .initialization import skip_model_initialization
|
||||||
|
from .disk_map import DiskMap
|
||||||
|
from ..device import parse_device_type, get_device_name, IS_NPU_AVAILABLE
|
||||||
|
|
||||||
|
|
||||||
|
class AutoTorchModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
offload_dtype: torch.dtype = None,
|
||||||
|
offload_device: Union[str, torch.device] = None,
|
||||||
|
onload_dtype: torch.dtype = None,
|
||||||
|
onload_device: Union[str, torch.device] = None,
|
||||||
|
preparing_dtype: torch.dtype = None,
|
||||||
|
preparing_device: Union[str, torch.device] = None,
|
||||||
|
computation_dtype: torch.dtype = None,
|
||||||
|
computation_device: Union[str, torch.device] = None,
|
||||||
|
vram_limit: float = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.set_dtype_and_device(
|
||||||
|
offload_dtype,
|
||||||
|
offload_device,
|
||||||
|
onload_dtype,
|
||||||
|
onload_device,
|
||||||
|
preparing_dtype,
|
||||||
|
preparing_device,
|
||||||
|
computation_dtype,
|
||||||
|
computation_device,
|
||||||
|
vram_limit,
|
||||||
|
)
|
||||||
|
self.state = 0
|
||||||
|
self.name = ""
|
||||||
|
self.computation_device_type = parse_device_type(self.computation_device)
|
||||||
|
|
||||||
|
def set_dtype_and_device(
|
||||||
|
self,
|
||||||
|
offload_dtype: torch.dtype = None,
|
||||||
|
offload_device: Union[str, torch.device] = None,
|
||||||
|
onload_dtype: torch.dtype = None,
|
||||||
|
onload_device: Union[str, torch.device] = None,
|
||||||
|
preparing_dtype: torch.dtype = None,
|
||||||
|
preparing_device: Union[str, torch.device] = None,
|
||||||
|
computation_dtype: torch.dtype = None,
|
||||||
|
computation_device: Union[str, torch.device] = None,
|
||||||
|
vram_limit: float = None,
|
||||||
|
):
|
||||||
|
self.offload_dtype = offload_dtype or computation_dtype
|
||||||
|
self.offload_device = offload_device or computation_dtype
|
||||||
|
self.onload_dtype = onload_dtype or computation_dtype
|
||||||
|
self.onload_device = onload_device or computation_dtype
|
||||||
|
self.preparing_dtype = preparing_dtype or computation_dtype
|
||||||
|
self.preparing_device = preparing_device or computation_dtype
|
||||||
|
self.computation_dtype = computation_dtype
|
||||||
|
self.computation_device = computation_device
|
||||||
|
self.vram_limit = vram_limit
|
||||||
|
|
||||||
|
def cast_to(self, weight, dtype, device):
|
||||||
|
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||||
|
r.copy_(weight)
|
||||||
|
return r
|
||||||
|
|
||||||
|
def check_free_vram(self):
|
||||||
|
device = self.computation_device if not IS_NPU_AVAILABLE else get_device_name()
|
||||||
|
gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(device)
|
||||||
|
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
|
||||||
|
return used_memory < self.vram_limit
|
||||||
|
|
||||||
|
def offload(self):
|
||||||
|
if self.state != 0:
|
||||||
|
self.to(dtype=self.offload_dtype, device=self.offload_device)
|
||||||
|
self.state = 0
|
||||||
|
|
||||||
|
def onload(self):
|
||||||
|
if self.state != 1:
|
||||||
|
self.to(dtype=self.onload_dtype, device=self.onload_device)
|
||||||
|
self.state = 1
|
||||||
|
|
||||||
|
def param_name(self, name):
|
||||||
|
if self.name == "":
|
||||||
|
return name
|
||||||
|
else:
|
||||||
|
return self.name + "." + name
|
||||||
|
|
||||||
|
|
||||||
|
class AutoWrappedModule(AutoTorchModule):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
module: torch.nn.Module,
|
||||||
|
offload_dtype: torch.dtype = None,
|
||||||
|
offload_device: Union[str, torch.device] = None,
|
||||||
|
onload_dtype: torch.dtype = None,
|
||||||
|
onload_device: Union[str, torch.device] = None,
|
||||||
|
preparing_dtype: torch.dtype = None,
|
||||||
|
preparing_device: Union[str, torch.device] = None,
|
||||||
|
computation_dtype: torch.dtype = None,
|
||||||
|
computation_device: Union[str, torch.device] = None,
|
||||||
|
vram_limit: float = None,
|
||||||
|
name: str = "",
|
||||||
|
disk_map: DiskMap = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
offload_dtype,
|
||||||
|
offload_device,
|
||||||
|
onload_dtype,
|
||||||
|
onload_device,
|
||||||
|
preparing_dtype,
|
||||||
|
preparing_device,
|
||||||
|
computation_dtype,
|
||||||
|
computation_device,
|
||||||
|
vram_limit,
|
||||||
|
)
|
||||||
|
self.module = module
|
||||||
|
if offload_dtype == "disk":
|
||||||
|
self.name = name
|
||||||
|
self.disk_map = disk_map
|
||||||
|
self.required_params = [name for name, _ in self.module.named_parameters()]
|
||||||
|
self.disk_offload = True
|
||||||
|
else:
|
||||||
|
self.disk_offload = False
|
||||||
|
|
||||||
|
def load_from_disk(self, torch_dtype, device, copy_module=False):
|
||||||
|
if copy_module:
|
||||||
|
module = copy.deepcopy(self.module)
|
||||||
|
else:
|
||||||
|
module = self.module
|
||||||
|
state_dict = {}
|
||||||
|
for name in self.required_params:
|
||||||
|
param = self.disk_map[self.param_name(name)]
|
||||||
|
param = param.to(dtype=torch_dtype, device=device)
|
||||||
|
state_dict[name] = param
|
||||||
|
module.load_state_dict(state_dict, assign=True)
|
||||||
|
module.to(dtype=torch_dtype, device=device)
|
||||||
|
return module
|
||||||
|
|
||||||
|
def offload_to_disk(self, model: torch.nn.Module):
|
||||||
|
for buf in model.buffers():
|
||||||
|
# If there are some parameters are registed in buffers (not in state dict),
|
||||||
|
# We cannot offload the model.
|
||||||
|
for children in model.children():
|
||||||
|
self.offload_to_disk(children)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
model.to("meta")
|
||||||
|
|
||||||
|
def offload(self):
|
||||||
|
# offload / onload / preparing -> offload
|
||||||
|
if self.state != 0:
|
||||||
|
if self.disk_offload:
|
||||||
|
self.offload_to_disk(self.module)
|
||||||
|
else:
|
||||||
|
self.to(dtype=self.offload_dtype, device=self.offload_device)
|
||||||
|
self.state = 0
|
||||||
|
|
||||||
|
def onload(self):
|
||||||
|
# offload / onload / preparing -> onload
|
||||||
|
if self.state < 1:
|
||||||
|
if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk":
|
||||||
|
self.load_from_disk(self.onload_dtype, self.onload_device)
|
||||||
|
elif self.onload_device != "disk":
|
||||||
|
self.to(dtype=self.onload_dtype, device=self.onload_device)
|
||||||
|
self.state = 1
|
||||||
|
|
||||||
|
def preparing(self):
|
||||||
|
# onload / preparing -> preparing
|
||||||
|
if self.state != 2:
|
||||||
|
if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk":
|
||||||
|
self.load_from_disk(self.preparing_dtype, self.preparing_device)
|
||||||
|
elif self.preparing_device != "disk":
|
||||||
|
self.to(dtype=self.preparing_dtype, device=self.preparing_device)
|
||||||
|
self.state = 2
|
||||||
|
|
||||||
|
def cast_to(self, module, dtype, device):
|
||||||
|
return copy.deepcopy(module).to(dtype=dtype, device=device)
|
||||||
|
|
||||||
|
def computation(self):
|
||||||
|
# onload / preparing -> computation (temporary)
|
||||||
|
if self.state == 2:
|
||||||
|
torch_dtype, device = self.preparing_dtype, self.preparing_device
|
||||||
|
else:
|
||||||
|
torch_dtype, device = self.onload_dtype, self.onload_device
|
||||||
|
if torch_dtype == self.computation_dtype and device == self.computation_device:
|
||||||
|
module = self.module
|
||||||
|
elif self.disk_offload and device == "disk":
|
||||||
|
module = self.load_from_disk(self.computation_dtype, self.computation_device, copy_module=True)
|
||||||
|
else:
|
||||||
|
module = self.cast_to(self.module, dtype=self.computation_dtype, device=self.computation_device)
|
||||||
|
return module
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
if self.state == 1 and (self.vram_limit is None or self.check_free_vram()):
|
||||||
|
self.preparing()
|
||||||
|
module = self.computation()
|
||||||
|
return module(*args, **kwargs)
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
if name in self.__dict__ or name == "module":
|
||||||
|
return super().__getattr__(name)
|
||||||
|
else:
|
||||||
|
return getattr(self.module, name)
|
||||||
|
|
||||||
|
|
||||||
|
class AutoWrappedNonRecurseModule(AutoWrappedModule):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
module: torch.nn.Module,
|
||||||
|
offload_dtype: torch.dtype = None,
|
||||||
|
offload_device: Union[str, torch.device] = None,
|
||||||
|
onload_dtype: torch.dtype = None,
|
||||||
|
onload_device: Union[str, torch.device] = None,
|
||||||
|
preparing_dtype: torch.dtype = None,
|
||||||
|
preparing_device: Union[str, torch.device] = None,
|
||||||
|
computation_dtype: torch.dtype = None,
|
||||||
|
computation_device: Union[str, torch.device] = None,
|
||||||
|
vram_limit: float = None,
|
||||||
|
name: str = "",
|
||||||
|
disk_map: DiskMap = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
module,
|
||||||
|
offload_dtype,
|
||||||
|
offload_device,
|
||||||
|
onload_dtype,
|
||||||
|
onload_device,
|
||||||
|
preparing_dtype,
|
||||||
|
preparing_device,
|
||||||
|
computation_dtype,
|
||||||
|
computation_device,
|
||||||
|
vram_limit,
|
||||||
|
name,
|
||||||
|
disk_map,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
if self.disk_offload:
|
||||||
|
self.required_params = [name for name, _ in self.module.named_parameters(recurse=False)]
|
||||||
|
|
||||||
|
def load_from_disk(self, torch_dtype, device, copy_module=False):
|
||||||
|
if copy_module:
|
||||||
|
module = copy.deepcopy(self.module)
|
||||||
|
else:
|
||||||
|
module = self.module
|
||||||
|
state_dict = {}
|
||||||
|
for name in self.required_params:
|
||||||
|
param = self.disk_map[self.param_name(name)]
|
||||||
|
param = param.to(dtype=torch_dtype, device=device)
|
||||||
|
state_dict[name] = param
|
||||||
|
module.load_state_dict(state_dict, assign=True, strict=False)
|
||||||
|
return module
|
||||||
|
|
||||||
|
def offload_to_disk(self, model: torch.nn.Module):
|
||||||
|
for name in self.required_params:
|
||||||
|
getattr(self, name).to("meta")
|
||||||
|
|
||||||
|
def cast_to(self, module, dtype, device):
|
||||||
|
# Parameter casting is implemented in the model architecture.
|
||||||
|
return module
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
if name in self.__dict__ or name == "module":
|
||||||
|
return super().__getattr__(name)
|
||||||
|
else:
|
||||||
|
return getattr(self.module, name)
|
||||||
|
|
||||||
|
|
||||||
|
class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
module: torch.nn.Linear,
|
||||||
|
offload_dtype: torch.dtype = None,
|
||||||
|
offload_device: Union[str, torch.device] = None,
|
||||||
|
onload_dtype: torch.dtype = None,
|
||||||
|
onload_device: Union[str, torch.device] = None,
|
||||||
|
preparing_dtype: torch.dtype = None,
|
||||||
|
preparing_device: Union[str, torch.device] = None,
|
||||||
|
computation_dtype: torch.dtype = None,
|
||||||
|
computation_device: Union[str, torch.device] = None,
|
||||||
|
vram_limit: float = None,
|
||||||
|
name: str = "",
|
||||||
|
disk_map: DiskMap = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
with skip_model_initialization():
|
||||||
|
super().__init__(
|
||||||
|
in_features=module.in_features,
|
||||||
|
out_features=module.out_features,
|
||||||
|
bias=module.bias is not None,
|
||||||
|
)
|
||||||
|
self.set_dtype_and_device(
|
||||||
|
offload_dtype,
|
||||||
|
offload_device,
|
||||||
|
onload_dtype,
|
||||||
|
onload_device,
|
||||||
|
preparing_dtype,
|
||||||
|
preparing_device,
|
||||||
|
computation_dtype,
|
||||||
|
computation_device,
|
||||||
|
vram_limit,
|
||||||
|
)
|
||||||
|
self.weight = module.weight
|
||||||
|
self.bias = module.bias
|
||||||
|
self.state = 0
|
||||||
|
self.name = name
|
||||||
|
self.lora_A_weights = []
|
||||||
|
self.lora_B_weights = []
|
||||||
|
self.lora_merger = None
|
||||||
|
self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
|
||||||
|
self.computation_device_type = parse_device_type(self.computation_device)
|
||||||
|
|
||||||
|
if offload_dtype == "disk":
|
||||||
|
self.disk_map = disk_map
|
||||||
|
self.disk_offload = True
|
||||||
|
else:
|
||||||
|
self.disk_offload = False
|
||||||
|
|
||||||
|
def fp8_linear(
|
||||||
|
self,
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
bias: torch.Tensor = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
device = input.device
|
||||||
|
origin_dtype = input.dtype
|
||||||
|
origin_shape = input.shape
|
||||||
|
input = input.reshape(-1, origin_shape[-1])
|
||||||
|
|
||||||
|
x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
|
||||||
|
fp8_max = 448.0
|
||||||
|
# For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
|
||||||
|
# To avoid overflow and ensure numerical compatibility during FP8 computation,
|
||||||
|
# we scale down the input by 2.0 in advance.
|
||||||
|
# This scaling will be compensated later during the final result scaling.
|
||||||
|
if self.computation_dtype == torch.float8_e4m3fnuz:
|
||||||
|
fp8_max = fp8_max / 2.0
|
||||||
|
scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
|
||||||
|
scale_b = torch.ones((weight.shape[0], 1)).to(device=device)
|
||||||
|
input = input / (scale_a + 1e-8)
|
||||||
|
input = input.to(self.computation_dtype)
|
||||||
|
weight = weight.to(self.computation_dtype)
|
||||||
|
bias = bias.to(torch.bfloat16)
|
||||||
|
|
||||||
|
result = torch._scaled_mm(
|
||||||
|
input,
|
||||||
|
weight.T,
|
||||||
|
scale_a=scale_a,
|
||||||
|
scale_b=scale_b.T,
|
||||||
|
bias=bias,
|
||||||
|
out_dtype=origin_dtype,
|
||||||
|
)
|
||||||
|
new_shape = origin_shape[:-1] + result.shape[-1:]
|
||||||
|
result = result.reshape(new_shape)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def load_from_disk(self, torch_dtype, device, assign=True):
|
||||||
|
weight = self.disk_map[self.name + ".weight"].to(dtype=torch_dtype, device=device)
|
||||||
|
bias = None if self.bias is None else self.disk_map[self.name + ".bias"].to(dtype=torch_dtype, device=device)
|
||||||
|
if assign:
|
||||||
|
state_dict = {"weight": weight}
|
||||||
|
if bias is not None: state_dict["bias"] = bias
|
||||||
|
self.load_state_dict(state_dict, assign=True)
|
||||||
|
return weight, bias
|
||||||
|
|
||||||
|
def offload(self):
|
||||||
|
# offload / onload / preparing -> offload
|
||||||
|
if self.state != 0:
|
||||||
|
if self.disk_offload:
|
||||||
|
self.to("meta")
|
||||||
|
else:
|
||||||
|
self.to(dtype=self.offload_dtype, device=self.offload_device)
|
||||||
|
self.state = 0
|
||||||
|
|
||||||
|
def onload(self):
|
||||||
|
# offload / onload / preparing -> onload
|
||||||
|
if self.state < 1:
|
||||||
|
if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk":
|
||||||
|
self.load_from_disk(self.onload_dtype, self.onload_device)
|
||||||
|
elif self.onload_device != "disk":
|
||||||
|
self.to(dtype=self.onload_dtype, device=self.onload_device)
|
||||||
|
self.state = 1
|
||||||
|
|
||||||
|
def preparing(self):
|
||||||
|
# onload / preparing -> preparing
|
||||||
|
if self.state != 2:
|
||||||
|
if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk":
|
||||||
|
self.load_from_disk(self.preparing_dtype, self.preparing_device)
|
||||||
|
elif self.preparing_device != "disk":
|
||||||
|
self.to(dtype=self.preparing_dtype, device=self.preparing_device)
|
||||||
|
self.state = 2
|
||||||
|
|
||||||
|
def computation(self):
|
||||||
|
# onload / preparing -> computation (temporary)
|
||||||
|
if self.state == 2:
|
||||||
|
torch_dtype, device = self.preparing_dtype, self.preparing_device
|
||||||
|
else:
|
||||||
|
torch_dtype, device = self.onload_dtype, self.onload_device
|
||||||
|
if torch_dtype == self.computation_dtype and device == self.computation_device:
|
||||||
|
weight, bias = self.weight, self.bias
|
||||||
|
elif self.disk_offload and device == "disk":
|
||||||
|
weight, bias = self.load_from_disk(self.computation_dtype, self.computation_device, assign=False)
|
||||||
|
else:
|
||||||
|
weight = self.cast_to(self.weight, self.computation_dtype, self.computation_device)
|
||||||
|
bias = None if self.bias is None else self.cast_to(self.bias, self.computation_dtype, self.computation_device)
|
||||||
|
return weight, bias
|
||||||
|
|
||||||
|
def linear_forward(self, x, weight, bias):
|
||||||
|
if self.enable_fp8:
|
||||||
|
out = self.fp8_linear(x, weight, bias)
|
||||||
|
else:
|
||||||
|
out = torch.nn.functional.linear(x, weight, bias)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def lora_forward(self, x, out):
|
||||||
|
if self.lora_merger is None:
|
||||||
|
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
|
||||||
|
out = out + x @ lora_A.T @ lora_B.T
|
||||||
|
else:
|
||||||
|
lora_output = []
|
||||||
|
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
|
||||||
|
lora_output.append(x @ lora_A.T @ lora_B.T)
|
||||||
|
lora_output = torch.stack(lora_output)
|
||||||
|
out = self.lora_merger(out, lora_output)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def forward(self, x, *args, **kwargs):
|
||||||
|
if self.state == 1 and (self.vram_limit is None or self.check_free_vram()):
|
||||||
|
self.preparing()
|
||||||
|
weight, bias = self.computation()
|
||||||
|
out = self.linear_forward(x, weight, bias)
|
||||||
|
if len(self.lora_A_weights) > 0:
|
||||||
|
out = self.lora_forward(x, out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, name_prefix="", disk_map=None, **kwargs):
|
||||||
|
if isinstance(model, AutoWrappedNonRecurseModule):
|
||||||
|
model = model.module
|
||||||
|
for name, module in model.named_children():
|
||||||
|
layer_name = name if name_prefix == "" else name_prefix + "." + name
|
||||||
|
for source_module, target_module in module_map.items():
|
||||||
|
if isinstance(module, source_module):
|
||||||
|
module_ = target_module(module, **vram_config, vram_limit=vram_limit, name=layer_name, disk_map=disk_map, **kwargs)
|
||||||
|
if isinstance(module_, AutoWrappedNonRecurseModule):
|
||||||
|
enable_vram_management_recursively(module_, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)
|
||||||
|
setattr(model, name, module_)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
enable_vram_management_recursively(module, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def fill_vram_config(model, vram_config):
|
||||||
|
vram_config_ = vram_config.copy()
|
||||||
|
vram_config_["onload_dtype"] = vram_config["computation_dtype"]
|
||||||
|
vram_config_["onload_device"] = vram_config["computation_device"]
|
||||||
|
vram_config_["preparing_dtype"] = vram_config["computation_dtype"]
|
||||||
|
vram_config_["preparing_device"] = vram_config["computation_device"]
|
||||||
|
for k in vram_config:
|
||||||
|
if vram_config[k] != vram_config_[k]:
|
||||||
|
print(f"No fine-grained VRAM configuration is provided for {model.__class__.__name__}. [`onload`, `preparing`, `computation`] will be the same state. `vram_config` is set to {vram_config_}")
|
||||||
|
break
|
||||||
|
return vram_config_
|
||||||
|
|
||||||
|
|
||||||
|
def enable_vram_management(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, disk_map=None, **kwargs):
|
||||||
|
for source_module, target_module in module_map.items():
|
||||||
|
# If no fine-grained VRAM configuration is provided, the entire model will be managed uniformly.
|
||||||
|
if isinstance(model, source_module):
|
||||||
|
vram_config = fill_vram_config(model, vram_config)
|
||||||
|
model = target_module(model, **vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
enable_vram_management_recursively(model, module_map, vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs)
|
||||||
|
# `vram_management_enabled` is a flag that allows the pipeline to determine whether VRAM management is enabled.
|
||||||
|
model.vram_management_enabled = True
|
||||||
|
return model
|
||||||
@@ -1 +0,0 @@
|
|||||||
from .video import VideoData, save_video, save_frames
|
|
||||||
@@ -1,41 +0,0 @@
|
|||||||
import torch, os, torchvision
|
|
||||||
from torchvision import transforms
|
|
||||||
import pandas as pd
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TextImageDataset(torch.utils.data.Dataset):
|
|
||||||
def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
|
|
||||||
self.steps_per_epoch = steps_per_epoch
|
|
||||||
metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
|
|
||||||
self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
|
|
||||||
self.text = metadata["text"].to_list()
|
|
||||||
self.height = height
|
|
||||||
self.width = width
|
|
||||||
self.image_processor = transforms.Compose(
|
|
||||||
[
|
|
||||||
transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
|
|
||||||
transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
|
|
||||||
transforms.ToTensor(),
|
|
||||||
transforms.Normalize([0.5], [0.5]),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def __getitem__(self, index):
|
|
||||||
data_id = torch.randint(0, len(self.path), (1,))[0]
|
|
||||||
data_id = (data_id + index) % len(self.path) # For fixed seed.
|
|
||||||
text = self.text[data_id]
|
|
||||||
image = Image.open(self.path[data_id]).convert("RGB")
|
|
||||||
target_height, target_width = self.height, self.width
|
|
||||||
width, height = image.size
|
|
||||||
scale = max(target_width / width, target_height / height)
|
|
||||||
shape = [round(height*scale),round(width*scale)]
|
|
||||||
image = torchvision.transforms.functional.resize(image,shape,interpolation=transforms.InterpolationMode.BILINEAR)
|
|
||||||
image = self.image_processor(image)
|
|
||||||
return {"text": text, "image": image}
|
|
||||||
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return self.steps_per_epoch
|
|
||||||
6
diffsynth/diffusion/__init__.py
Normal file
6
diffsynth/diffusion/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from .flow_match import FlowMatchScheduler
|
||||||
|
from .training_module import DiffusionTrainingModule
|
||||||
|
from .logger import ModelLogger
|
||||||
|
from .runner import launch_training_task, launch_data_process_task
|
||||||
|
from .parsers import *
|
||||||
|
from .loss import *
|
||||||
459
diffsynth/diffusion/base_pipeline.py
Normal file
459
diffsynth/diffusion/base_pipeline.py
Normal file
@@ -0,0 +1,459 @@
|
|||||||
|
from PIL import Image
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from einops import repeat, reduce
|
||||||
|
from typing import Union
|
||||||
|
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
|
||||||
|
from ..core.device.npu_compatible_device import get_device_type
|
||||||
|
from ..utils.lora import GeneralLoRALoader
|
||||||
|
from ..models.model_loader import ModelPool
|
||||||
|
from ..utils.controlnet import ControlNetInput
|
||||||
|
from ..core.device import get_device_name, IS_NPU_AVAILABLE
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineUnit:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
seperate_cfg: bool = False,
|
||||||
|
take_over: bool = False,
|
||||||
|
input_params: tuple[str] = None,
|
||||||
|
output_params: tuple[str] = None,
|
||||||
|
input_params_posi: dict[str, str] = None,
|
||||||
|
input_params_nega: dict[str, str] = None,
|
||||||
|
onload_model_names: tuple[str] = None
|
||||||
|
):
|
||||||
|
self.seperate_cfg = seperate_cfg
|
||||||
|
self.take_over = take_over
|
||||||
|
self.input_params = input_params
|
||||||
|
self.output_params = output_params
|
||||||
|
self.input_params_posi = input_params_posi
|
||||||
|
self.input_params_nega = input_params_nega
|
||||||
|
self.onload_model_names = onload_model_names
|
||||||
|
|
||||||
|
def fetch_input_params(self):
|
||||||
|
params = []
|
||||||
|
if self.input_params is not None:
|
||||||
|
for param in self.input_params:
|
||||||
|
params.append(param)
|
||||||
|
if self.input_params_posi is not None:
|
||||||
|
for _, param in self.input_params_posi.items():
|
||||||
|
params.append(param)
|
||||||
|
if self.input_params_nega is not None:
|
||||||
|
for _, param in self.input_params_nega.items():
|
||||||
|
params.append(param)
|
||||||
|
params = sorted(list(set(params)))
|
||||||
|
return params
|
||||||
|
|
||||||
|
def fetch_output_params(self):
|
||||||
|
params = []
|
||||||
|
if self.output_params is not None:
|
||||||
|
for param in self.output_params:
|
||||||
|
params.append(param)
|
||||||
|
return params
|
||||||
|
|
||||||
|
def process(self, pipe, **kwargs) -> dict:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def post_process(self, pipe, **kwargs) -> dict:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
class BasePipeline(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
device=get_device_type(), torch_dtype=torch.float16,
|
||||||
|
height_division_factor=64, width_division_factor=64,
|
||||||
|
time_division_factor=None, time_division_remainder=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
# The device and torch_dtype is used for the storage of intermediate variables, not models.
|
||||||
|
self.device = device
|
||||||
|
self.torch_dtype = torch_dtype
|
||||||
|
self.device_type = parse_device_type(device)
|
||||||
|
# The following parameters are used for shape check.
|
||||||
|
self.height_division_factor = height_division_factor
|
||||||
|
self.width_division_factor = width_division_factor
|
||||||
|
self.time_division_factor = time_division_factor
|
||||||
|
self.time_division_remainder = time_division_remainder
|
||||||
|
# VRAM management
|
||||||
|
self.vram_management_enabled = False
|
||||||
|
# Pipeline Unit Runner
|
||||||
|
self.unit_runner = PipelineUnitRunner()
|
||||||
|
# LoRA Loader
|
||||||
|
self.lora_loader = GeneralLoRALoader
|
||||||
|
|
||||||
|
|
||||||
|
def to(self, *args, **kwargs):
|
||||||
|
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
|
||||||
|
if device is not None:
|
||||||
|
self.device = device
|
||||||
|
if dtype is not None:
|
||||||
|
self.torch_dtype = dtype
|
||||||
|
super().to(*args, **kwargs)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
def check_resize_height_width(self, height, width, num_frames=None):
|
||||||
|
# Shape check
|
||||||
|
if height % self.height_division_factor != 0:
|
||||||
|
height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
|
||||||
|
print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
|
||||||
|
if width % self.width_division_factor != 0:
|
||||||
|
width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
|
||||||
|
print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
|
||||||
|
if num_frames is None:
|
||||||
|
return height, width
|
||||||
|
else:
|
||||||
|
if num_frames % self.time_division_factor != self.time_division_remainder:
|
||||||
|
num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
|
||||||
|
print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
|
||||||
|
return height, width, num_frames
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1):
|
||||||
|
# Transform a PIL.Image to torch.Tensor
|
||||||
|
image = torch.Tensor(np.array(image, dtype=np.float32))
|
||||||
|
image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
|
||||||
|
image = image * ((max_value - min_value) / 255) + min_value
|
||||||
|
image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {}))
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1):
|
||||||
|
# Transform a list of PIL.Image to torch.Tensor
|
||||||
|
video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video]
|
||||||
|
video = torch.stack(video, dim=pattern.index("T") // 2)
|
||||||
|
return video
|
||||||
|
|
||||||
|
|
||||||
|
def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1):
|
||||||
|
# Transform a torch.Tensor to PIL.Image
|
||||||
|
if pattern != "H W C":
|
||||||
|
vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean")
|
||||||
|
image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255)
|
||||||
|
image = image.to(device="cpu", dtype=torch.uint8)
|
||||||
|
image = Image.fromarray(image.numpy())
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1):
|
||||||
|
# Transform a torch.Tensor to list of PIL.Image
|
||||||
|
if pattern != "T H W C":
|
||||||
|
vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean")
|
||||||
|
video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
|
||||||
|
return video
|
||||||
|
|
||||||
|
|
||||||
|
def load_models_to_device(self, model_names):
|
||||||
|
if self.vram_management_enabled:
|
||||||
|
# offload models
|
||||||
|
for name, model in self.named_children():
|
||||||
|
if name not in model_names:
|
||||||
|
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
|
||||||
|
if hasattr(model, "offload"):
|
||||||
|
model.offload()
|
||||||
|
else:
|
||||||
|
for module in model.modules():
|
||||||
|
if hasattr(module, "offload"):
|
||||||
|
module.offload()
|
||||||
|
getattr(torch, self.device_type).empty_cache()
|
||||||
|
# onload models
|
||||||
|
for name, model in self.named_children():
|
||||||
|
if name in model_names:
|
||||||
|
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
|
||||||
|
if hasattr(model, "onload"):
|
||||||
|
model.onload()
|
||||||
|
else:
|
||||||
|
for module in model.modules():
|
||||||
|
if hasattr(module, "onload"):
|
||||||
|
module.onload()
|
||||||
|
|
||||||
|
|
||||||
|
def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None):
|
||||||
|
# Initialize Gaussian noise
|
||||||
|
generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed)
|
||||||
|
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype)
|
||||||
|
noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
|
||||||
|
return noise
|
||||||
|
|
||||||
|
|
||||||
|
def get_vram(self):
|
||||||
|
device = self.device if not IS_NPU_AVAILABLE else get_device_name()
|
||||||
|
return getattr(torch, self.device_type).mem_get_info(device)[1] / (1024 ** 3)
|
||||||
|
|
||||||
|
def get_module(self, model, name):
|
||||||
|
if "." in name:
|
||||||
|
name, suffix = name[:name.index(".")], name[name.index(".") + 1:]
|
||||||
|
if name.isdigit():
|
||||||
|
return self.get_module(model[int(name)], suffix)
|
||||||
|
else:
|
||||||
|
return self.get_module(getattr(model, name), suffix)
|
||||||
|
else:
|
||||||
|
return getattr(model, name)
|
||||||
|
|
||||||
|
def freeze_except(self, model_names):
|
||||||
|
self.eval()
|
||||||
|
self.requires_grad_(False)
|
||||||
|
for name in model_names:
|
||||||
|
module = self.get_module(self, name)
|
||||||
|
if module is None:
|
||||||
|
print(f"No {name} models in the pipeline. We cannot enable training on the model. If this occurs during the data processing stage, it is normal.")
|
||||||
|
continue
|
||||||
|
module.train()
|
||||||
|
module.requires_grad_(True)
|
||||||
|
|
||||||
|
|
||||||
|
def blend_with_mask(self, base, addition, mask):
|
||||||
|
return base * (1 - mask) + addition * mask
|
||||||
|
|
||||||
|
|
||||||
|
def step(self, scheduler, latents, progress_id, noise_pred, input_latents=None, inpaint_mask=None, **kwargs):
|
||||||
|
timestep = scheduler.timesteps[progress_id]
|
||||||
|
if inpaint_mask is not None:
|
||||||
|
noise_pred_expected = scheduler.return_to_timestep(scheduler.timesteps[progress_id], latents, input_latents)
|
||||||
|
noise_pred = self.blend_with_mask(noise_pred_expected, noise_pred, inpaint_mask)
|
||||||
|
latents_next = scheduler.step(noise_pred, timestep, latents)
|
||||||
|
return latents_next
|
||||||
|
|
||||||
|
|
||||||
|
def split_pipeline_units(self, model_names: list[str]):
|
||||||
|
return PipelineUnitGraph().split_pipeline_units(self.units, model_names)
|
||||||
|
|
||||||
|
|
||||||
|
def flush_vram_management_device(self, device):
|
||||||
|
for module in self.modules():
|
||||||
|
if isinstance(module, AutoTorchModule):
|
||||||
|
module.offload_device = device
|
||||||
|
module.onload_device = device
|
||||||
|
module.preparing_device = device
|
||||||
|
module.computation_device = device
|
||||||
|
|
||||||
|
|
||||||
|
def load_lora(
|
||||||
|
self,
|
||||||
|
module: torch.nn.Module,
|
||||||
|
lora_config: Union[ModelConfig, str] = None,
|
||||||
|
alpha=1,
|
||||||
|
hotload=None,
|
||||||
|
state_dict=None,
|
||||||
|
verbose=1,
|
||||||
|
):
|
||||||
|
if state_dict is None:
|
||||||
|
if isinstance(lora_config, str):
|
||||||
|
lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
|
||||||
|
else:
|
||||||
|
lora_config.download_if_necessary()
|
||||||
|
lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
|
||||||
|
else:
|
||||||
|
lora = state_dict
|
||||||
|
lora_loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device)
|
||||||
|
lora = lora_loader.convert_state_dict(lora)
|
||||||
|
if hotload is None:
|
||||||
|
hotload = hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")
|
||||||
|
if hotload:
|
||||||
|
if not (hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")):
|
||||||
|
raise ValueError("VRAM Management is not enabled. LoRA hotloading is not supported.")
|
||||||
|
updated_num = 0
|
||||||
|
for _, module in module.named_modules():
|
||||||
|
if isinstance(module, AutoWrappedLinear):
|
||||||
|
name = module.name
|
||||||
|
lora_a_name = f'{name}.lora_A.weight'
|
||||||
|
lora_b_name = f'{name}.lora_B.weight'
|
||||||
|
if lora_a_name in lora and lora_b_name in lora:
|
||||||
|
updated_num += 1
|
||||||
|
module.lora_A_weights.append(lora[lora_a_name] * alpha)
|
||||||
|
module.lora_B_weights.append(lora[lora_b_name])
|
||||||
|
if verbose >= 1:
|
||||||
|
print(f"{updated_num} tensors are patched by LoRA. You can use `pipe.clear_lora()` to clear all LoRA layers.")
|
||||||
|
else:
|
||||||
|
lora_loader.fuse_lora_to_base_model(module, lora, alpha=alpha)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_lora(self, verbose=1):
|
||||||
|
cleared_num = 0
|
||||||
|
for name, module in self.named_modules():
|
||||||
|
if isinstance(module, AutoWrappedLinear):
|
||||||
|
if hasattr(module, "lora_A_weights"):
|
||||||
|
if len(module.lora_A_weights) > 0:
|
||||||
|
cleared_num += 1
|
||||||
|
module.lora_A_weights.clear()
|
||||||
|
if hasattr(module, "lora_B_weights"):
|
||||||
|
module.lora_B_weights.clear()
|
||||||
|
if verbose >= 1:
|
||||||
|
print(f"{cleared_num} LoRA layers are cleared.")
|
||||||
|
|
||||||
|
|
||||||
|
def download_and_load_models(self, model_configs: list[ModelConfig] = [], vram_limit: float = None):
|
||||||
|
model_pool = ModelPool()
|
||||||
|
for model_config in model_configs:
|
||||||
|
model_config.download_if_necessary()
|
||||||
|
vram_config = model_config.vram_config()
|
||||||
|
vram_config["computation_dtype"] = vram_config["computation_dtype"] or self.torch_dtype
|
||||||
|
vram_config["computation_device"] = vram_config["computation_device"] or self.device
|
||||||
|
model_pool.auto_load_model(
|
||||||
|
model_config.path,
|
||||||
|
vram_config=vram_config,
|
||||||
|
vram_limit=vram_limit,
|
||||||
|
clear_parameters=model_config.clear_parameters,
|
||||||
|
state_dict=model_config.state_dict,
|
||||||
|
)
|
||||||
|
return model_pool
|
||||||
|
|
||||||
|
|
||||||
|
def check_vram_management_state(self):
|
||||||
|
vram_management_enabled = False
|
||||||
|
for module in self.children():
|
||||||
|
if hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled"):
|
||||||
|
vram_management_enabled = True
|
||||||
|
return vram_management_enabled
|
||||||
|
|
||||||
|
|
||||||
|
def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
|
||||||
|
if inputs_shared.get("positive_only_lora", None) is not None:
|
||||||
|
self.clear_lora(verbose=0)
|
||||||
|
self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
|
||||||
|
noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
|
||||||
|
if cfg_scale != 1.0:
|
||||||
|
if inputs_shared.get("positive_only_lora", None) is not None:
|
||||||
|
self.clear_lora(verbose=0)
|
||||||
|
noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
|
||||||
|
if isinstance(noise_pred_posi, tuple):
|
||||||
|
# Separately handling different output types of latents, eg. video and audio latents.
|
||||||
|
noise_pred = tuple(
|
||||||
|
n_nega + cfg_scale * (n_posi - n_nega)
|
||||||
|
for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||||
|
else:
|
||||||
|
noise_pred = noise_pred_posi
|
||||||
|
return noise_pred
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineUnitGraph:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def build_edges(self, units: list[PipelineUnit]):
|
||||||
|
# Establish dependencies between units
|
||||||
|
# to search for subsequent related computation units.
|
||||||
|
last_compute_unit_id = {}
|
||||||
|
edges = []
|
||||||
|
for unit_id, unit in enumerate(units):
|
||||||
|
for input_param in unit.fetch_input_params():
|
||||||
|
if input_param in last_compute_unit_id:
|
||||||
|
edges.append((last_compute_unit_id[input_param], unit_id))
|
||||||
|
for output_param in unit.fetch_output_params():
|
||||||
|
last_compute_unit_id[output_param] = unit_id
|
||||||
|
return edges
|
||||||
|
|
||||||
|
def build_chains(self, units: list[PipelineUnit]):
|
||||||
|
# Establish updating chains for each variable
|
||||||
|
# to track their computation process.
|
||||||
|
params = sum([unit.fetch_input_params() + unit.fetch_output_params() for unit in units], [])
|
||||||
|
params = sorted(list(set(params)))
|
||||||
|
chains = {param: [] for param in params}
|
||||||
|
for unit_id, unit in enumerate(units):
|
||||||
|
for param in unit.fetch_output_params():
|
||||||
|
chains[param].append(unit_id)
|
||||||
|
return chains
|
||||||
|
|
||||||
|
def search_direct_unit_ids(self, units: list[PipelineUnit], model_names: list[str]):
|
||||||
|
# Search for units that directly participate in the model's computation.
|
||||||
|
related_unit_ids = []
|
||||||
|
for unit_id, unit in enumerate(units):
|
||||||
|
for model_name in model_names:
|
||||||
|
if unit.onload_model_names is not None and model_name in unit.onload_model_names:
|
||||||
|
related_unit_ids.append(unit_id)
|
||||||
|
break
|
||||||
|
return related_unit_ids
|
||||||
|
|
||||||
|
def search_related_unit_ids(self, edges, start_unit_ids, direction="target"):
|
||||||
|
# Search for subsequent related computation units.
|
||||||
|
related_unit_ids = [unit_id for unit_id in start_unit_ids]
|
||||||
|
while True:
|
||||||
|
neighbors = []
|
||||||
|
for source, target in edges:
|
||||||
|
if direction == "target" and source in related_unit_ids and target not in related_unit_ids:
|
||||||
|
neighbors.append(target)
|
||||||
|
elif direction == "source" and source not in related_unit_ids and target in related_unit_ids:
|
||||||
|
neighbors.append(source)
|
||||||
|
neighbors = sorted(list(set(neighbors)))
|
||||||
|
if len(neighbors) == 0:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
related_unit_ids.extend(neighbors)
|
||||||
|
related_unit_ids = sorted(list(set(related_unit_ids)))
|
||||||
|
return related_unit_ids
|
||||||
|
|
||||||
|
def search_updating_unit_ids(self, units: list[PipelineUnit], chains, related_unit_ids):
|
||||||
|
# If the input parameters of this subgraph are updated outside the subgraph,
|
||||||
|
# search for the units where these updates occur.
|
||||||
|
first_compute_unit_id = {}
|
||||||
|
for unit_id in related_unit_ids:
|
||||||
|
for param in units[unit_id].fetch_input_params():
|
||||||
|
if param not in first_compute_unit_id:
|
||||||
|
first_compute_unit_id[param] = unit_id
|
||||||
|
updating_unit_ids = []
|
||||||
|
for param in first_compute_unit_id:
|
||||||
|
unit_id = first_compute_unit_id[param]
|
||||||
|
chain = chains[param]
|
||||||
|
if unit_id in chain and chain.index(unit_id) != len(chain) - 1:
|
||||||
|
for unit_id_ in chain[chain.index(unit_id) + 1:]:
|
||||||
|
if unit_id_ not in related_unit_ids:
|
||||||
|
updating_unit_ids.append(unit_id_)
|
||||||
|
related_unit_ids.extend(updating_unit_ids)
|
||||||
|
related_unit_ids = sorted(list(set(related_unit_ids)))
|
||||||
|
return related_unit_ids
|
||||||
|
|
||||||
|
def split_pipeline_units(self, units: list[PipelineUnit], model_names: list[str]):
|
||||||
|
# Split the computation graph,
|
||||||
|
# separating all model-related computations.
|
||||||
|
related_unit_ids = self.search_direct_unit_ids(units, model_names)
|
||||||
|
edges = self.build_edges(units)
|
||||||
|
chains = self.build_chains(units)
|
||||||
|
while True:
|
||||||
|
num_related_unit_ids = len(related_unit_ids)
|
||||||
|
related_unit_ids = self.search_related_unit_ids(edges, related_unit_ids, "target")
|
||||||
|
related_unit_ids = self.search_updating_unit_ids(units, chains, related_unit_ids)
|
||||||
|
if len(related_unit_ids) == num_related_unit_ids:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
num_related_unit_ids = len(related_unit_ids)
|
||||||
|
related_units = [units[i] for i in related_unit_ids]
|
||||||
|
unrelated_units = [units[i] for i in range(len(units)) if i not in related_unit_ids]
|
||||||
|
return related_units, unrelated_units
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineUnitRunner:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __call__(self, unit: PipelineUnit, pipe: BasePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict]:
|
||||||
|
if unit.take_over:
|
||||||
|
# Let the pipeline unit take over this function.
|
||||||
|
inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega)
|
||||||
|
elif unit.seperate_cfg:
|
||||||
|
# Positive side
|
||||||
|
processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()}
|
||||||
|
if unit.input_params is not None:
|
||||||
|
for name in unit.input_params:
|
||||||
|
processor_inputs[name] = inputs_shared.get(name)
|
||||||
|
processor_outputs = unit.process(pipe, **processor_inputs)
|
||||||
|
inputs_posi.update(processor_outputs)
|
||||||
|
# Negative side
|
||||||
|
if inputs_shared["cfg_scale"] != 1:
|
||||||
|
processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()}
|
||||||
|
if unit.input_params is not None:
|
||||||
|
for name in unit.input_params:
|
||||||
|
processor_inputs[name] = inputs_shared.get(name)
|
||||||
|
processor_outputs = unit.process(pipe, **processor_inputs)
|
||||||
|
inputs_nega.update(processor_outputs)
|
||||||
|
else:
|
||||||
|
inputs_nega.update(processor_outputs)
|
||||||
|
else:
|
||||||
|
processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params}
|
||||||
|
processor_outputs = unit.process(pipe, **processor_inputs)
|
||||||
|
inputs_shared.update(processor_outputs)
|
||||||
|
return inputs_shared, inputs_posi, inputs_nega
|
||||||
236
diffsynth/diffusion/flow_match.py
Normal file
236
diffsynth/diffusion/flow_match.py
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
import torch, math
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
|
||||||
|
class FlowMatchScheduler():
|
||||||
|
|
||||||
|
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning"] = "FLUX.1"):
|
||||||
|
self.set_timesteps_fn = {
|
||||||
|
"FLUX.1": FlowMatchScheduler.set_timesteps_flux,
|
||||||
|
"Wan": FlowMatchScheduler.set_timesteps_wan,
|
||||||
|
"Qwen-Image": FlowMatchScheduler.set_timesteps_qwen_image,
|
||||||
|
"FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
|
||||||
|
"Z-Image": FlowMatchScheduler.set_timesteps_z_image,
|
||||||
|
"LTX-2": FlowMatchScheduler.set_timesteps_ltx2,
|
||||||
|
"Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning,
|
||||||
|
}.get(template, FlowMatchScheduler.set_timesteps_flux)
|
||||||
|
self.num_train_timesteps = 1000
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_timesteps_flux(num_inference_steps=100, denoising_strength=1.0, shift=None):
|
||||||
|
sigma_min = 0.003/1.002
|
||||||
|
sigma_max = 1.0
|
||||||
|
shift = 3 if shift is None else shift
|
||||||
|
num_train_timesteps = 1000
|
||||||
|
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||||
|
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
|
||||||
|
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||||
|
timesteps = sigmas * num_train_timesteps
|
||||||
|
return sigmas, timesteps
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_timesteps_wan(num_inference_steps=100, denoising_strength=1.0, shift=None):
|
||||||
|
sigma_min = 0.0
|
||||||
|
sigma_max = 1.0
|
||||||
|
shift = 5 if shift is None else shift
|
||||||
|
num_train_timesteps = 1000
|
||||||
|
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||||
|
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
|
||||||
|
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||||
|
timesteps = sigmas * num_train_timesteps
|
||||||
|
return sigmas, timesteps
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _calculate_shift_qwen_image(image_seq_len, base_seq_len=256, max_seq_len=8192, base_shift=0.5, max_shift=0.9):
|
||||||
|
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||||
|
b = base_shift - m * base_seq_len
|
||||||
|
mu = image_seq_len * m + b
|
||||||
|
return mu
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_timesteps_qwen_image(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None):
|
||||||
|
sigma_min = 0.0
|
||||||
|
sigma_max = 1.0
|
||||||
|
num_train_timesteps = 1000
|
||||||
|
shift_terminal = 0.02
|
||||||
|
# Sigmas
|
||||||
|
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||||
|
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
|
||||||
|
# Mu
|
||||||
|
if exponential_shift_mu is not None:
|
||||||
|
mu = exponential_shift_mu
|
||||||
|
elif dynamic_shift_len is not None:
|
||||||
|
mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len)
|
||||||
|
else:
|
||||||
|
mu = 0.8
|
||||||
|
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
|
||||||
|
# Shift terminal
|
||||||
|
one_minus_z = 1 - sigmas
|
||||||
|
scale_factor = one_minus_z[-1] / (1 - shift_terminal)
|
||||||
|
sigmas = 1 - (one_minus_z / scale_factor)
|
||||||
|
# Timesteps
|
||||||
|
timesteps = sigmas * num_train_timesteps
|
||||||
|
return sigmas, timesteps
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_timesteps_qwen_image_lightning(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None):
|
||||||
|
sigma_min = 0.0
|
||||||
|
sigma_max = 1.0
|
||||||
|
num_train_timesteps = 1000
|
||||||
|
base_shift = math.log(3)
|
||||||
|
max_shift = math.log(3)
|
||||||
|
# Sigmas
|
||||||
|
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||||
|
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
|
||||||
|
# Mu
|
||||||
|
if exponential_shift_mu is not None:
|
||||||
|
mu = exponential_shift_mu
|
||||||
|
elif dynamic_shift_len is not None:
|
||||||
|
mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len, base_shift=base_shift, max_shift=max_shift)
|
||||||
|
else:
|
||||||
|
mu = 0.8
|
||||||
|
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
|
||||||
|
# Timesteps
|
||||||
|
timesteps = sigmas * num_train_timesteps
|
||||||
|
return sigmas, timesteps
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def compute_empirical_mu(image_seq_len, num_steps):
|
||||||
|
a1, b1 = 8.73809524e-05, 1.89833333
|
||||||
|
a2, b2 = 0.00016927, 0.45666666
|
||||||
|
|
||||||
|
if image_seq_len > 4300:
|
||||||
|
mu = a2 * image_seq_len + b2
|
||||||
|
return float(mu)
|
||||||
|
|
||||||
|
m_200 = a2 * image_seq_len + b2
|
||||||
|
m_10 = a1 * image_seq_len + b1
|
||||||
|
|
||||||
|
a = (m_200 - m_10) / 190.0
|
||||||
|
b = m_200 - 200.0 * a
|
||||||
|
mu = a * num_steps + b
|
||||||
|
|
||||||
|
return float(mu)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None):
|
||||||
|
sigma_min = 1 / num_inference_steps
|
||||||
|
sigma_max = 1.0
|
||||||
|
num_train_timesteps = 1000
|
||||||
|
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||||
|
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
|
||||||
|
if dynamic_shift_len is None:
|
||||||
|
# If you ask me why I set mu=0.8,
|
||||||
|
# I can only say that it yields better training results.
|
||||||
|
mu = 0.8
|
||||||
|
else:
|
||||||
|
mu = FlowMatchScheduler.compute_empirical_mu(dynamic_shift_len, num_inference_steps)
|
||||||
|
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
|
||||||
|
timesteps = sigmas * num_train_timesteps
|
||||||
|
return sigmas, timesteps
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None):
|
||||||
|
sigma_min = 0.0
|
||||||
|
sigma_max = 1.0
|
||||||
|
shift = 3 if shift is None else shift
|
||||||
|
num_train_timesteps = 1000
|
||||||
|
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||||
|
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
|
||||||
|
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||||
|
timesteps = sigmas * num_train_timesteps
|
||||||
|
if target_timesteps is not None:
|
||||||
|
target_timesteps = target_timesteps.to(dtype=timesteps.dtype, device=timesteps.device)
|
||||||
|
for timestep in target_timesteps:
|
||||||
|
timestep_id = torch.argmin((timesteps - timestep).abs())
|
||||||
|
timesteps[timestep_id] = timestep
|
||||||
|
return sigmas, timesteps
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, terminal=0.1, special_case=None):
|
||||||
|
num_train_timesteps = 1000
|
||||||
|
if special_case == "stage2":
|
||||||
|
sigmas = torch.Tensor([0.909375, 0.725, 0.421875])
|
||||||
|
elif special_case == "ditilled_stage1":
|
||||||
|
sigmas = torch.Tensor([1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875])
|
||||||
|
else:
|
||||||
|
dynamic_shift_len = dynamic_shift_len or 4096
|
||||||
|
sigma_shift = FlowMatchScheduler._calculate_shift_qwen_image(
|
||||||
|
image_seq_len=dynamic_shift_len,
|
||||||
|
base_seq_len=1024,
|
||||||
|
max_seq_len=4096,
|
||||||
|
base_shift=0.95,
|
||||||
|
max_shift=2.05,
|
||||||
|
)
|
||||||
|
sigma_min = 0.0
|
||||||
|
sigma_max = 1.0
|
||||||
|
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||||
|
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
|
||||||
|
sigmas = math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1))
|
||||||
|
# Shift terminal
|
||||||
|
one_minus_z = 1.0 - sigmas
|
||||||
|
scale_factor = one_minus_z[-1] / (1 - terminal)
|
||||||
|
sigmas = 1.0 - (one_minus_z / scale_factor)
|
||||||
|
timesteps = sigmas * num_train_timesteps
|
||||||
|
return sigmas, timesteps
|
||||||
|
|
||||||
|
def set_training_weight(self):
|
||||||
|
steps = 1000
|
||||||
|
x = self.timesteps
|
||||||
|
y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
|
||||||
|
y_shifted = y - y.min()
|
||||||
|
bsmntw_weighing = y_shifted * (steps / y_shifted.sum())
|
||||||
|
if len(self.timesteps) != 1000:
|
||||||
|
# This is an empirical formula.
|
||||||
|
bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps)
|
||||||
|
bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1]
|
||||||
|
self.linear_timesteps_weights = bsmntw_weighing
|
||||||
|
|
||||||
|
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
|
||||||
|
self.sigmas, self.timesteps = self.set_timesteps_fn(
|
||||||
|
num_inference_steps=num_inference_steps,
|
||||||
|
denoising_strength=denoising_strength,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
if training:
|
||||||
|
self.set_training_weight()
|
||||||
|
self.training = True
|
||||||
|
else:
|
||||||
|
self.training = False
|
||||||
|
|
||||||
|
def step(self, model_output, timestep, sample, to_final=False, **kwargs):
|
||||||
|
if isinstance(timestep, torch.Tensor):
|
||||||
|
timestep = timestep.cpu()
|
||||||
|
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||||
|
sigma = self.sigmas[timestep_id]
|
||||||
|
if to_final or timestep_id + 1 >= len(self.timesteps):
|
||||||
|
sigma_ = 0
|
||||||
|
else:
|
||||||
|
sigma_ = self.sigmas[timestep_id + 1]
|
||||||
|
prev_sample = sample + model_output * (sigma_ - sigma)
|
||||||
|
return prev_sample
|
||||||
|
|
||||||
|
def return_to_timestep(self, timestep, sample, sample_stablized):
|
||||||
|
if isinstance(timestep, torch.Tensor):
|
||||||
|
timestep = timestep.cpu()
|
||||||
|
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||||
|
sigma = self.sigmas[timestep_id]
|
||||||
|
model_output = (sample - sample_stablized) / sigma
|
||||||
|
return model_output
|
||||||
|
|
||||||
|
def add_noise(self, original_samples, noise, timestep):
|
||||||
|
if isinstance(timestep, torch.Tensor):
|
||||||
|
timestep = timestep.cpu()
|
||||||
|
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||||
|
sigma = self.sigmas[timestep_id]
|
||||||
|
sample = (1 - sigma) * original_samples + sigma * noise
|
||||||
|
return sample
|
||||||
|
|
||||||
|
def training_target(self, sample, noise, timestep):
|
||||||
|
target = noise - sample
|
||||||
|
return target
|
||||||
|
|
||||||
|
def training_weight(self, timestep):
|
||||||
|
timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
|
||||||
|
weights = self.linear_timesteps_weights[timestep_id]
|
||||||
|
return weights
|
||||||
43
diffsynth/diffusion/logger.py
Normal file
43
diffsynth/diffusion/logger.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
import os, torch
|
||||||
|
from accelerate import Accelerator
|
||||||
|
|
||||||
|
|
||||||
|
class ModelLogger:
|
||||||
|
def __init__(self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x:x):
|
||||||
|
self.output_path = output_path
|
||||||
|
self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
|
||||||
|
self.state_dict_converter = state_dict_converter
|
||||||
|
self.num_steps = 0
|
||||||
|
|
||||||
|
|
||||||
|
def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None, **kwargs):
|
||||||
|
self.num_steps += 1
|
||||||
|
if save_steps is not None and self.num_steps % save_steps == 0:
|
||||||
|
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id):
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
state_dict = accelerator.get_state_dict(model)
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
|
||||||
|
state_dict = self.state_dict_converter(state_dict)
|
||||||
|
os.makedirs(self.output_path, exist_ok=True)
|
||||||
|
path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
|
||||||
|
accelerator.save(state_dict, path, safe_serialization=True)
|
||||||
|
|
||||||
|
|
||||||
|
def on_training_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None):
|
||||||
|
if save_steps is not None and self.num_steps % save_steps != 0:
|
||||||
|
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
def save_model(self, accelerator: Accelerator, model: torch.nn.Module, file_name):
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
state_dict = accelerator.get_state_dict(model)
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
|
||||||
|
state_dict = self.state_dict_converter(state_dict)
|
||||||
|
os.makedirs(self.output_path, exist_ok=True)
|
||||||
|
path = os.path.join(self.output_path, file_name)
|
||||||
|
accelerator.save(state_dict, path, safe_serialization=True)
|
||||||
156
diffsynth/diffusion/loss.py
Normal file
156
diffsynth/diffusion/loss.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
from .base_pipeline import BasePipeline
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
|
||||||
|
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
|
||||||
|
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
|
||||||
|
|
||||||
|
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
|
||||||
|
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||||
|
|
||||||
|
noise = torch.randn_like(inputs["input_latents"])
|
||||||
|
inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
|
||||||
|
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
|
||||||
|
|
||||||
|
if "first_frame_latents" in inputs:
|
||||||
|
inputs["latents"][:, :, 0:1] = inputs["first_frame_latents"]
|
||||||
|
|
||||||
|
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
||||||
|
noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)
|
||||||
|
|
||||||
|
if "first_frame_latents" in inputs:
|
||||||
|
noise_pred = noise_pred[:, :, 1:]
|
||||||
|
training_target = training_target[:, :, 1:]
|
||||||
|
|
||||||
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||||
|
loss = loss * pipe.scheduler.training_weight(timestep)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def FlowMatchSFTAudioVideoLoss(pipe: BasePipeline, **inputs):
|
||||||
|
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
|
||||||
|
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
|
||||||
|
|
||||||
|
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
|
||||||
|
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||||
|
|
||||||
|
# video
|
||||||
|
noise = torch.randn_like(inputs["input_latents"])
|
||||||
|
inputs["video_latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
|
||||||
|
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
|
||||||
|
|
||||||
|
# audio
|
||||||
|
if inputs.get("audio_input_latents") is not None:
|
||||||
|
audio_noise = torch.randn_like(inputs["audio_input_latents"])
|
||||||
|
inputs["audio_latents"] = pipe.scheduler.add_noise(inputs["audio_input_latents"], audio_noise, timestep)
|
||||||
|
training_target_audio = pipe.scheduler.training_target(inputs["audio_input_latents"], audio_noise, timestep)
|
||||||
|
|
||||||
|
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
||||||
|
noise_pred, noise_pred_audio = pipe.model_fn(**models, **inputs, timestep=timestep)
|
||||||
|
|
||||||
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||||
|
loss = loss * pipe.scheduler.training_weight(timestep)
|
||||||
|
if inputs.get("audio_input_latents") is not None:
|
||||||
|
loss_audio = torch.nn.functional.mse_loss(noise_pred_audio.float(), training_target_audio.float())
|
||||||
|
loss_audio = loss_audio * pipe.scheduler.training_weight(timestep)
|
||||||
|
loss = loss + loss_audio
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def DirectDistillLoss(pipe: BasePipeline, **inputs):
|
||||||
|
pipe.scheduler.set_timesteps(inputs["num_inference_steps"])
|
||||||
|
pipe.scheduler.training = True
|
||||||
|
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
||||||
|
for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
|
||||||
|
timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||||
|
noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep, progress_id=progress_id)
|
||||||
|
inputs["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs)
|
||||||
|
loss = torch.nn.functional.mse_loss(inputs["latents"].float(), inputs["input_latents"].float())
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
class TrajectoryImitationLoss(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.initialized = False
|
||||||
|
|
||||||
|
def initialize(self, device):
|
||||||
|
import lpips # TODO: remove it
|
||||||
|
self.loss_fn = lpips.LPIPS(net='alex').to(device)
|
||||||
|
self.initialized = True
|
||||||
|
|
||||||
|
def fetch_trajectory(self, pipe: BasePipeline, timesteps_student, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
|
||||||
|
trajectory = [inputs_shared["latents"].clone()]
|
||||||
|
|
||||||
|
pipe.scheduler.set_timesteps(num_inference_steps, target_timesteps=timesteps_student)
|
||||||
|
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
||||||
|
for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
|
||||||
|
timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||||
|
noise_pred = pipe.cfg_guided_model_fn(
|
||||||
|
pipe.model_fn, cfg_scale,
|
||||||
|
inputs_shared, inputs_posi, inputs_nega,
|
||||||
|
**models, timestep=timestep, progress_id=progress_id
|
||||||
|
)
|
||||||
|
inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)
|
||||||
|
|
||||||
|
trajectory.append(inputs_shared["latents"].clone())
|
||||||
|
return pipe.scheduler.timesteps, trajectory
|
||||||
|
|
||||||
|
def align_trajectory(self, pipe: BasePipeline, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
|
||||||
|
loss = 0
|
||||||
|
pipe.scheduler.set_timesteps(num_inference_steps, training=True)
|
||||||
|
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
||||||
|
for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
|
||||||
|
timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||||
|
|
||||||
|
progress_id_teacher = torch.argmin((timesteps_teacher - timestep).abs())
|
||||||
|
inputs_shared["latents"] = trajectory_teacher[progress_id_teacher]
|
||||||
|
|
||||||
|
noise_pred = pipe.cfg_guided_model_fn(
|
||||||
|
pipe.model_fn, cfg_scale,
|
||||||
|
inputs_shared, inputs_posi, inputs_nega,
|
||||||
|
**models, timestep=timestep, progress_id=progress_id
|
||||||
|
)
|
||||||
|
|
||||||
|
sigma = pipe.scheduler.sigmas[progress_id]
|
||||||
|
sigma_ = 0 if progress_id + 1 >= len(pipe.scheduler.timesteps) else pipe.scheduler.sigmas[progress_id + 1]
|
||||||
|
if progress_id + 1 >= len(pipe.scheduler.timesteps):
|
||||||
|
latents_ = trajectory_teacher[-1]
|
||||||
|
else:
|
||||||
|
progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs())
|
||||||
|
latents_ = trajectory_teacher[progress_id_teacher]
|
||||||
|
|
||||||
|
target = (latents_ - inputs_shared["latents"]) / (sigma_ - sigma)
|
||||||
|
loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def compute_regularization(self, pipe: BasePipeline, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
|
||||||
|
inputs_shared["latents"] = trajectory_teacher[0]
|
||||||
|
pipe.scheduler.set_timesteps(num_inference_steps)
|
||||||
|
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
||||||
|
for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
|
||||||
|
timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||||
|
noise_pred = pipe.cfg_guided_model_fn(
|
||||||
|
pipe.model_fn, cfg_scale,
|
||||||
|
inputs_shared, inputs_posi, inputs_nega,
|
||||||
|
**models, timestep=timestep, progress_id=progress_id
|
||||||
|
)
|
||||||
|
inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)
|
||||||
|
|
||||||
|
image_pred = pipe.vae_decoder(inputs_shared["latents"])
|
||||||
|
image_real = pipe.vae_decoder(trajectory_teacher[-1])
|
||||||
|
loss = self.loss_fn(image_pred.float(), image_real.float())
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def forward(self, pipe: BasePipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||||
|
if not self.initialized:
|
||||||
|
self.initialize(pipe.device)
|
||||||
|
with torch.no_grad():
|
||||||
|
pipe.scheduler.set_timesteps(8)
|
||||||
|
timesteps_teacher, trajectory_teacher = self.fetch_trajectory(inputs_shared["teacher"], pipe.scheduler.timesteps, inputs_shared, inputs_posi, inputs_nega, 50, 2)
|
||||||
|
timesteps_teacher = timesteps_teacher.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||||
|
loss_1 = self.align_trajectory(pipe, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
|
||||||
|
loss_2 = self.compute_regularization(pipe, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
|
||||||
|
loss = loss_1 + loss_2
|
||||||
|
return loss
|
||||||
70
diffsynth/diffusion/parsers.py
Normal file
70
diffsynth/diffusion/parsers.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
def add_dataset_base_config(parser: argparse.ArgumentParser):
|
||||||
|
parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.")
|
||||||
|
parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.")
|
||||||
|
parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.")
|
||||||
|
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
|
||||||
|
parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in the metadata. Comma-separated.")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
def add_image_size_config(parser: argparse.ArgumentParser):
|
||||||
|
parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||||
|
parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||||
|
parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
def add_video_size_config(parser: argparse.ArgumentParser):
|
||||||
|
parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||||
|
parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||||
|
parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.")
|
||||||
|
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
def add_model_config(parser: argparse.ArgumentParser):
|
||||||
|
parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.")
|
||||||
|
parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")
|
||||||
|
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
|
||||||
|
parser.add_argument("--fp8_models", default=None, help="Models with FP8 precision, comma-separated.")
|
||||||
|
parser.add_argument("--offload_models", default=None, help="Models with offload, comma-separated. Only used in splited training.")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
def add_training_config(parser: argparse.ArgumentParser):
|
||||||
|
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
|
||||||
|
parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
|
||||||
|
parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.")
|
||||||
|
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
|
||||||
|
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
|
||||||
|
parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
def add_output_config(parser: argparse.ArgumentParser):
|
||||||
|
parser.add_argument("--output_path", type=str, default="./models", help="Output save path.")
|
||||||
|
parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.")
|
||||||
|
parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
def add_lora_config(parser: argparse.ArgumentParser):
|
||||||
|
parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
|
||||||
|
parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
|
||||||
|
parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
|
||||||
|
parser.add_argument("--lora_checkpoint", type=str, default=None, help="Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.")
|
||||||
|
parser.add_argument("--preset_lora_path", type=str, default=None, help="Path to the preset LoRA checkpoint. If provided, this LoRA will be fused to the base model.")
|
||||||
|
parser.add_argument("--preset_lora_model", type=str, default=None, help="Which model the preset LoRA is fused to.")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
def add_gradient_config(parser: argparse.ArgumentParser):
|
||||||
|
parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
|
||||||
|
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
|
||||||
|
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
def add_general_config(parser: argparse.ArgumentParser):
|
||||||
|
parser = add_dataset_base_config(parser)
|
||||||
|
parser = add_model_config(parser)
|
||||||
|
parser = add_training_config(parser)
|
||||||
|
parser = add_output_config(parser)
|
||||||
|
parser = add_lora_config(parser)
|
||||||
|
parser = add_gradient_config(parser)
|
||||||
|
return parser
|
||||||
72
diffsynth/diffusion/runner.py
Normal file
72
diffsynth/diffusion/runner.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
import os, torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
from accelerate import Accelerator
|
||||||
|
from .training_module import DiffusionTrainingModule
|
||||||
|
from .logger import ModelLogger
|
||||||
|
|
||||||
|
|
||||||
|
def launch_training_task(
|
||||||
|
accelerator: Accelerator,
|
||||||
|
dataset: torch.utils.data.Dataset,
|
||||||
|
model: DiffusionTrainingModule,
|
||||||
|
model_logger: ModelLogger,
|
||||||
|
learning_rate: float = 1e-5,
|
||||||
|
weight_decay: float = 1e-2,
|
||||||
|
num_workers: int = 1,
|
||||||
|
save_steps: int = None,
|
||||||
|
num_epochs: int = 1,
|
||||||
|
args = None,
|
||||||
|
):
|
||||||
|
if args is not None:
|
||||||
|
learning_rate = args.learning_rate
|
||||||
|
weight_decay = args.weight_decay
|
||||||
|
num_workers = args.dataset_num_workers
|
||||||
|
save_steps = args.save_steps
|
||||||
|
num_epochs = args.num_epochs
|
||||||
|
|
||||||
|
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
|
||||||
|
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
|
||||||
|
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
|
||||||
|
model.to(device=accelerator.device)
|
||||||
|
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
|
||||||
|
|
||||||
|
for epoch_id in range(num_epochs):
|
||||||
|
for data in tqdm(dataloader):
|
||||||
|
with accelerator.accumulate(model):
|
||||||
|
optimizer.zero_grad()
|
||||||
|
if dataset.load_from_cache:
|
||||||
|
loss = model({}, inputs=data)
|
||||||
|
else:
|
||||||
|
loss = model(data)
|
||||||
|
accelerator.backward(loss)
|
||||||
|
optimizer.step()
|
||||||
|
model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
|
||||||
|
scheduler.step()
|
||||||
|
if save_steps is None:
|
||||||
|
model_logger.on_epoch_end(accelerator, model, epoch_id)
|
||||||
|
model_logger.on_training_end(accelerator, model, save_steps)
|
||||||
|
|
||||||
|
|
||||||
|
def launch_data_process_task(
|
||||||
|
accelerator: Accelerator,
|
||||||
|
dataset: torch.utils.data.Dataset,
|
||||||
|
model: DiffusionTrainingModule,
|
||||||
|
model_logger: ModelLogger,
|
||||||
|
num_workers: int = 8,
|
||||||
|
args = None,
|
||||||
|
):
|
||||||
|
if args is not None:
|
||||||
|
num_workers = args.dataset_num_workers
|
||||||
|
|
||||||
|
dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
|
||||||
|
model.to(device=accelerator.device)
|
||||||
|
model, dataloader = accelerator.prepare(model, dataloader)
|
||||||
|
|
||||||
|
for data_id, data in enumerate(tqdm(dataloader)):
|
||||||
|
with accelerator.accumulate(model):
|
||||||
|
with torch.no_grad():
|
||||||
|
folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
|
||||||
|
os.makedirs(folder, exist_ok=True)
|
||||||
|
save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f"{data_id}.pth")
|
||||||
|
data = model(data)
|
||||||
|
torch.save(data, save_path)
|
||||||
263
diffsynth/diffusion/training_module.py
Normal file
263
diffsynth/diffusion/training_module.py
Normal file
@@ -0,0 +1,263 @@
|
|||||||
|
import torch, json, os
|
||||||
|
from ..core import ModelConfig, load_state_dict
|
||||||
|
from ..utils.controlnet import ControlNetInput
|
||||||
|
from peft import LoraConfig, inject_adapter_in_model
|
||||||
|
|
||||||
|
|
||||||
|
class DiffusionTrainingModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
|
||||||
|
def to(self, *args, **kwargs):
|
||||||
|
for name, model in self.named_children():
|
||||||
|
model.to(*args, **kwargs)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
def trainable_modules(self):
|
||||||
|
trainable_modules = filter(lambda p: p.requires_grad, self.parameters())
|
||||||
|
return trainable_modules
|
||||||
|
|
||||||
|
|
||||||
|
def trainable_param_names(self):
|
||||||
|
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.named_parameters()))
|
||||||
|
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
|
||||||
|
return trainable_param_names
|
||||||
|
|
||||||
|
|
||||||
|
def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None, upcast_dtype=None):
|
||||||
|
if lora_alpha is None:
|
||||||
|
lora_alpha = lora_rank
|
||||||
|
if isinstance(target_modules, list) and len(target_modules) == 1:
|
||||||
|
target_modules = target_modules[0]
|
||||||
|
lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules)
|
||||||
|
model = inject_adapter_in_model(lora_config, model)
|
||||||
|
if upcast_dtype is not None:
|
||||||
|
for param in model.parameters():
|
||||||
|
if param.requires_grad:
|
||||||
|
param.data = param.to(upcast_dtype)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def mapping_lora_state_dict(self, state_dict):
|
||||||
|
new_state_dict = {}
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
if "lora_A.weight" in key or "lora_B.weight" in key:
|
||||||
|
new_key = key.replace("lora_A.weight", "lora_A.default.weight").replace("lora_B.weight", "lora_B.default.weight")
|
||||||
|
new_state_dict[new_key] = value
|
||||||
|
elif "lora_A.default.weight" in key or "lora_B.default.weight" in key:
|
||||||
|
new_state_dict[key] = value
|
||||||
|
return new_state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def export_trainable_state_dict(self, state_dict, remove_prefix=None):
|
||||||
|
trainable_param_names = self.trainable_param_names()
|
||||||
|
state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names}
|
||||||
|
if remove_prefix is not None:
|
||||||
|
state_dict_ = {}
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
if name.startswith(remove_prefix):
|
||||||
|
name = name[len(remove_prefix):]
|
||||||
|
state_dict_[name] = param
|
||||||
|
state_dict = state_dict_
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def transfer_data_to_device(self, data, device, torch_float_dtype=None):
|
||||||
|
if data is None:
|
||||||
|
return data
|
||||||
|
elif isinstance(data, torch.Tensor):
|
||||||
|
data = data.to(device)
|
||||||
|
if torch_float_dtype is not None and data.dtype in [torch.float, torch.float16, torch.bfloat16]:
|
||||||
|
data = data.to(torch_float_dtype)
|
||||||
|
return data
|
||||||
|
elif isinstance(data, tuple):
|
||||||
|
data = tuple(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data)
|
||||||
|
return data
|
||||||
|
elif isinstance(data, list):
|
||||||
|
data = list(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data)
|
||||||
|
return data
|
||||||
|
elif isinstance(data, dict):
|
||||||
|
data = {i: self.transfer_data_to_device(data[i], device, torch_float_dtype) for i in data}
|
||||||
|
return data
|
||||||
|
else:
|
||||||
|
return data
|
||||||
|
|
||||||
|
def parse_vram_config(self, fp8=False, offload=False, device="cpu"):
|
||||||
|
if fp8:
|
||||||
|
return {
|
||||||
|
"offload_dtype": torch.float8_e4m3fn,
|
||||||
|
"offload_device": device,
|
||||||
|
"onload_dtype": torch.float8_e4m3fn,
|
||||||
|
"onload_device": device,
|
||||||
|
"preparing_dtype": torch.float8_e4m3fn,
|
||||||
|
"preparing_device": device,
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": device,
|
||||||
|
}
|
||||||
|
elif offload:
|
||||||
|
return {
|
||||||
|
"offload_dtype": "disk",
|
||||||
|
"offload_device": "disk",
|
||||||
|
"onload_dtype": "disk",
|
||||||
|
"onload_device": "disk",
|
||||||
|
"preparing_dtype": torch.bfloat16,
|
||||||
|
"preparing_device": device,
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": device,
|
||||||
|
"clear_parameters": True,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def parse_model_configs(self, model_paths, model_id_with_origin_paths, fp8_models=None, offload_models=None, device="cpu"):
|
||||||
|
fp8_models = [] if fp8_models is None else fp8_models.split(",")
|
||||||
|
offload_models = [] if offload_models is None else offload_models.split(",")
|
||||||
|
model_configs = []
|
||||||
|
if model_paths is not None:
|
||||||
|
model_paths = json.loads(model_paths)
|
||||||
|
for path in model_paths:
|
||||||
|
vram_config = self.parse_vram_config(
|
||||||
|
fp8=path in fp8_models,
|
||||||
|
offload=path in offload_models,
|
||||||
|
device=device
|
||||||
|
)
|
||||||
|
model_configs.append(ModelConfig(path=path, **vram_config))
|
||||||
|
if model_id_with_origin_paths is not None:
|
||||||
|
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
|
||||||
|
for model_id_with_origin_path in model_id_with_origin_paths:
|
||||||
|
vram_config = self.parse_vram_config(
|
||||||
|
fp8=model_id_with_origin_path in fp8_models,
|
||||||
|
offload=model_id_with_origin_path in offload_models,
|
||||||
|
device=device
|
||||||
|
)
|
||||||
|
config = self.parse_path_or_model_id(model_id_with_origin_path)
|
||||||
|
model_configs.append(ModelConfig(model_id=config.model_id, origin_file_pattern=config.origin_file_pattern, **vram_config))
|
||||||
|
return model_configs
|
||||||
|
|
||||||
|
|
||||||
|
def parse_path_or_model_id(self, model_id_with_origin_path, default_value=None):
|
||||||
|
if model_id_with_origin_path is None:
|
||||||
|
return default_value
|
||||||
|
elif os.path.exists(model_id_with_origin_path):
|
||||||
|
return ModelConfig(path=model_id_with_origin_path)
|
||||||
|
else:
|
||||||
|
if ":" not in model_id_with_origin_path:
|
||||||
|
raise ValueError(f"Failed to parse model config: {model_id_with_origin_path}. This is neither a valid path nor in the format of `model_id/origin_file_pattern`.")
|
||||||
|
split_id = model_id_with_origin_path.rfind(":")
|
||||||
|
model_id = model_id_with_origin_path[:split_id]
|
||||||
|
origin_file_pattern = model_id_with_origin_path[split_id + 1:]
|
||||||
|
return ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern)
|
||||||
|
|
||||||
|
|
||||||
|
def auto_detect_lora_target_modules(
|
||||||
|
self,
|
||||||
|
model: torch.nn.Module,
|
||||||
|
search_for_linear=False,
|
||||||
|
linear_detector=lambda x: min(x.weight.shape) >= 512,
|
||||||
|
block_list_detector=lambda x: isinstance(x, torch.nn.ModuleList) and len(x) > 1,
|
||||||
|
name_prefix="",
|
||||||
|
):
|
||||||
|
lora_target_modules = []
|
||||||
|
if search_for_linear:
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
module_name = name_prefix + ["", "."][name_prefix != ""] + name
|
||||||
|
if isinstance(module, torch.nn.Linear) and linear_detector(module):
|
||||||
|
lora_target_modules.append(module_name)
|
||||||
|
else:
|
||||||
|
for name, module in model.named_children():
|
||||||
|
module_name = name_prefix + ["", "."][name_prefix != ""] + name
|
||||||
|
lora_target_modules += self.auto_detect_lora_target_modules(
|
||||||
|
module,
|
||||||
|
search_for_linear=block_list_detector(module),
|
||||||
|
linear_detector=linear_detector,
|
||||||
|
block_list_detector=block_list_detector,
|
||||||
|
name_prefix=module_name,
|
||||||
|
)
|
||||||
|
return lora_target_modules
|
||||||
|
|
||||||
|
|
||||||
|
def parse_lora_target_modules(self, model, lora_target_modules):
|
||||||
|
if lora_target_modules == "":
|
||||||
|
print("No LoRA target modules specified. The framework will automatically search for them.")
|
||||||
|
lora_target_modules = self.auto_detect_lora_target_modules(model)
|
||||||
|
print(f"LoRA will be patched at {lora_target_modules}.")
|
||||||
|
else:
|
||||||
|
lora_target_modules = lora_target_modules.split(",")
|
||||||
|
return lora_target_modules
|
||||||
|
|
||||||
|
|
||||||
|
def switch_pipe_to_training_mode(
|
||||||
|
self,
|
||||||
|
pipe,
|
||||||
|
trainable_models=None,
|
||||||
|
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
|
||||||
|
preset_lora_path=None, preset_lora_model=None,
|
||||||
|
task="sft",
|
||||||
|
):
|
||||||
|
# Scheduler
|
||||||
|
pipe.scheduler.set_timesteps(1000, training=True)
|
||||||
|
|
||||||
|
# Freeze untrainable models
|
||||||
|
pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
|
||||||
|
|
||||||
|
# Preset LoRA
|
||||||
|
if preset_lora_path is not None:
|
||||||
|
pipe.load_lora(getattr(pipe, preset_lora_model), preset_lora_path)
|
||||||
|
|
||||||
|
# FP8
|
||||||
|
# FP8 relies on a model-specific memory management scheme.
|
||||||
|
# It is delegated to the subclass.
|
||||||
|
|
||||||
|
# Add LoRA to the base models
|
||||||
|
if lora_base_model is not None and not task.endswith(":data_process"):
|
||||||
|
if (not hasattr(pipe, lora_base_model)) or getattr(pipe, lora_base_model) is None:
|
||||||
|
print(f"No {lora_base_model} models in the pipeline. We cannot patch LoRA on the model. If this occurs during the data processing stage, it is normal.")
|
||||||
|
return
|
||||||
|
model = self.add_lora_to_model(
|
||||||
|
getattr(pipe, lora_base_model),
|
||||||
|
target_modules=self.parse_lora_target_modules(getattr(pipe, lora_base_model), lora_target_modules),
|
||||||
|
lora_rank=lora_rank,
|
||||||
|
upcast_dtype=pipe.torch_dtype,
|
||||||
|
)
|
||||||
|
if lora_checkpoint is not None:
|
||||||
|
state_dict = load_state_dict(lora_checkpoint)
|
||||||
|
state_dict = self.mapping_lora_state_dict(state_dict)
|
||||||
|
load_result = model.load_state_dict(state_dict, strict=False)
|
||||||
|
print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
|
||||||
|
if len(load_result[1]) > 0:
|
||||||
|
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
|
||||||
|
setattr(pipe, lora_base_model, model)
|
||||||
|
|
||||||
|
|
||||||
|
def split_pipeline_units(self, task, pipe, trainable_models=None, lora_base_model=None):
|
||||||
|
models_require_backward = []
|
||||||
|
if trainable_models is not None:
|
||||||
|
models_require_backward += trainable_models.split(",")
|
||||||
|
if lora_base_model is not None:
|
||||||
|
models_require_backward += [lora_base_model]
|
||||||
|
if task.endswith(":data_process"):
|
||||||
|
_, pipe.units = pipe.split_pipeline_units(models_require_backward)
|
||||||
|
elif task.endswith(":train"):
|
||||||
|
pipe.units, _ = pipe.split_pipeline_units(models_require_backward)
|
||||||
|
return pipe
|
||||||
|
|
||||||
|
def parse_extra_inputs(self, data, extra_inputs, inputs_shared):
|
||||||
|
controlnet_keys_map = (
|
||||||
|
("blockwise_controlnet_", "blockwise_controlnet_inputs",),
|
||||||
|
("controlnet_", "controlnet_inputs"),
|
||||||
|
)
|
||||||
|
controlnet_inputs = {}
|
||||||
|
for extra_input in extra_inputs:
|
||||||
|
for prefix, name in controlnet_keys_map:
|
||||||
|
if extra_input.startswith(prefix):
|
||||||
|
if name not in controlnet_inputs:
|
||||||
|
controlnet_inputs[name] = {}
|
||||||
|
controlnet_inputs[name][extra_input.replace(prefix, "")] = data[extra_input]
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
inputs_shared[extra_input] = data[extra_input]
|
||||||
|
for name, params in controlnet_inputs.items():
|
||||||
|
inputs_shared[name] = [ControlNetInput(**params)]
|
||||||
|
return inputs_shared
|
||||||
@@ -1,137 +0,0 @@
|
|||||||
import torch
|
|
||||||
from einops import repeat
|
|
||||||
from PIL import Image
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
class ResidualDenseBlock(torch.nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, num_feat=64, num_grow_ch=32):
|
|
||||||
super(ResidualDenseBlock, self).__init__()
|
|
||||||
self.conv1 = torch.nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
|
|
||||||
self.conv2 = torch.nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
|
|
||||||
self.conv3 = torch.nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
|
|
||||||
self.conv4 = torch.nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
|
|
||||||
self.conv5 = torch.nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
|
|
||||||
self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x1 = self.lrelu(self.conv1(x))
|
|
||||||
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
|
||||||
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
|
||||||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
|
||||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
|
||||||
return x5 * 0.2 + x
|
|
||||||
|
|
||||||
|
|
||||||
class RRDB(torch.nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, num_feat, num_grow_ch=32):
|
|
||||||
super(RRDB, self).__init__()
|
|
||||||
self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
|
|
||||||
self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
|
|
||||||
self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
out = self.rdb1(x)
|
|
||||||
out = self.rdb2(out)
|
|
||||||
out = self.rdb3(out)
|
|
||||||
return out * 0.2 + x
|
|
||||||
|
|
||||||
|
|
||||||
class RRDBNet(torch.nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, **kwargs):
|
|
||||||
super(RRDBNet, self).__init__()
|
|
||||||
self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
|
|
||||||
self.body = torch.torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
|
|
||||||
self.conv_body = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
|
||||||
# upsample
|
|
||||||
self.conv_up1 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
|
||||||
self.conv_up2 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
|
||||||
self.conv_hr = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
|
||||||
self.conv_last = torch.nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
|
||||||
self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
feat = x
|
|
||||||
feat = self.conv_first(feat)
|
|
||||||
body_feat = self.conv_body(self.body(feat))
|
|
||||||
feat = feat + body_feat
|
|
||||||
# upsample
|
|
||||||
feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
|
|
||||||
feat = self.lrelu(self.conv_up1(feat))
|
|
||||||
feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
|
|
||||||
feat = self.lrelu(self.conv_up2(feat))
|
|
||||||
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
|
|
||||||
return out
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def state_dict_converter():
|
|
||||||
return RRDBNetStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
class RRDBNetStateDictConverter:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def from_diffusers(self, state_dict):
|
|
||||||
return state_dict, {"upcast_to_float32": True}
|
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
|
||||||
return state_dict, {"upcast_to_float32": True}
|
|
||||||
|
|
||||||
|
|
||||||
class ESRGAN(torch.nn.Module):
|
|
||||||
def __init__(self, model):
|
|
||||||
super().__init__()
|
|
||||||
self.model = model
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_model_manager(model_manager):
|
|
||||||
return ESRGAN(model_manager.fetch_model("esrgan"))
|
|
||||||
|
|
||||||
def process_image(self, image):
|
|
||||||
image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)
|
|
||||||
return image
|
|
||||||
|
|
||||||
def process_images(self, images):
|
|
||||||
images = [self.process_image(image) for image in images]
|
|
||||||
images = torch.stack(images)
|
|
||||||
return images
|
|
||||||
|
|
||||||
def decode_images(self, images):
|
|
||||||
images = (images.permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
|
|
||||||
images = [Image.fromarray(image) for image in images]
|
|
||||||
return images
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def upscale(self, images, batch_size=4, progress_bar=lambda x:x):
|
|
||||||
if not isinstance(images, list):
|
|
||||||
images = [images]
|
|
||||||
is_single_image = True
|
|
||||||
else:
|
|
||||||
is_single_image = False
|
|
||||||
|
|
||||||
# Preprocess
|
|
||||||
input_tensor = self.process_images(images)
|
|
||||||
|
|
||||||
# Interpolate
|
|
||||||
output_tensor = []
|
|
||||||
for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
|
|
||||||
batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
|
|
||||||
batch_input_tensor = input_tensor[batch_id: batch_id_]
|
|
||||||
batch_input_tensor = batch_input_tensor.to(
|
|
||||||
device=self.model.conv_first.weight.device,
|
|
||||||
dtype=self.model.conv_first.weight.dtype)
|
|
||||||
batch_output_tensor = self.model(batch_input_tensor)
|
|
||||||
output_tensor.append(batch_output_tensor.cpu())
|
|
||||||
|
|
||||||
# Output
|
|
||||||
output_tensor = torch.concat(output_tensor, dim=0)
|
|
||||||
|
|
||||||
# To images
|
|
||||||
output_images = self.decode_images(output_tensor)
|
|
||||||
if is_single_image:
|
|
||||||
output_images = output_images[0]
|
|
||||||
return output_images
|
|
||||||
@@ -1,63 +0,0 @@
|
|||||||
from .runners.fast import TableManager, PyramidPatchMatcher
|
|
||||||
from PIL import Image
|
|
||||||
import numpy as np
|
|
||||||
import cupy as cp
|
|
||||||
|
|
||||||
|
|
||||||
class FastBlendSmoother:
|
|
||||||
def __init__(self):
|
|
||||||
self.batch_size = 8
|
|
||||||
self.window_size = 64
|
|
||||||
self.ebsynth_config = {
|
|
||||||
"minimum_patch_size": 5,
|
|
||||||
"threads_per_block": 8,
|
|
||||||
"num_iter": 5,
|
|
||||||
"gpu_id": 0,
|
|
||||||
"guide_weight": 10.0,
|
|
||||||
"initialize": "identity",
|
|
||||||
"tracking_window_size": 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_model_manager(model_manager):
|
|
||||||
# TODO: fetch GPU ID from model_manager
|
|
||||||
return FastBlendSmoother()
|
|
||||||
|
|
||||||
def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config):
|
|
||||||
frames_guide = [np.array(frame) for frame in frames_guide]
|
|
||||||
frames_style = [np.array(frame) for frame in frames_style]
|
|
||||||
table_manager = TableManager()
|
|
||||||
patch_match_engine = PyramidPatchMatcher(
|
|
||||||
image_height=frames_style[0].shape[0],
|
|
||||||
image_width=frames_style[0].shape[1],
|
|
||||||
channel=3,
|
|
||||||
**ebsynth_config
|
|
||||||
)
|
|
||||||
# left part
|
|
||||||
table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="FastBlend Step 1/4")
|
|
||||||
table_l = table_manager.remapping_table_to_blending_table(table_l)
|
|
||||||
table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="FastBlend Step 2/4")
|
|
||||||
# right part
|
|
||||||
table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="FastBlend Step 3/4")
|
|
||||||
table_r = table_manager.remapping_table_to_blending_table(table_r)
|
|
||||||
table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="FastBlend Step 4/4")[::-1]
|
|
||||||
# merge
|
|
||||||
frames = []
|
|
||||||
for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
|
|
||||||
weight_m = -1
|
|
||||||
weight = weight_l + weight_m + weight_r
|
|
||||||
frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
|
|
||||||
frames.append(frame)
|
|
||||||
frames = [Image.fromarray(frame.clip(0, 255).astype("uint8")) for frame in frames]
|
|
||||||
return frames
|
|
||||||
|
|
||||||
def __call__(self, rendered_frames, original_frames=None, **kwargs):
|
|
||||||
frames = self.run(
|
|
||||||
original_frames, rendered_frames,
|
|
||||||
self.batch_size, self.window_size, self.ebsynth_config
|
|
||||||
)
|
|
||||||
mempool = cp.get_default_memory_pool()
|
|
||||||
pinned_mempool = cp.get_default_pinned_memory_pool()
|
|
||||||
mempool.free_all_blocks()
|
|
||||||
pinned_mempool.free_all_blocks()
|
|
||||||
return frames
|
|
||||||
@@ -1,397 +0,0 @@
|
|||||||
from .runners import AccurateModeRunner, FastModeRunner, BalancedModeRunner, InterpolationModeRunner, InterpolationModeSingleFrameRunner
|
|
||||||
from .data import VideoData, get_video_fps, save_video, search_for_images
|
|
||||||
import os
|
|
||||||
import gradio as gr
|
|
||||||
|
|
||||||
|
|
||||||
def check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder):
|
|
||||||
frames_guide = VideoData(video_guide, video_guide_folder)
|
|
||||||
frames_style = VideoData(video_style, video_style_folder)
|
|
||||||
message = ""
|
|
||||||
if len(frames_guide) < len(frames_style):
|
|
||||||
message += f"The number of frames mismatches. Only the first {len(frames_guide)} frames of style video will be used.\n"
|
|
||||||
frames_style.set_length(len(frames_guide))
|
|
||||||
elif len(frames_guide) > len(frames_style):
|
|
||||||
message += f"The number of frames mismatches. Only the first {len(frames_style)} frames of guide video will be used.\n"
|
|
||||||
frames_guide.set_length(len(frames_style))
|
|
||||||
height_guide, width_guide = frames_guide.shape()
|
|
||||||
height_style, width_style = frames_style.shape()
|
|
||||||
if height_guide != height_style or width_guide != width_style:
|
|
||||||
message += f"The shape of frames mismatches. The frames in style video will be resized to (height: {height_guide}, width: {width_guide})\n"
|
|
||||||
frames_style.set_shape(height_guide, width_guide)
|
|
||||||
return frames_guide, frames_style, message
|
|
||||||
|
|
||||||
|
|
||||||
def smooth_video(
|
|
||||||
video_guide,
|
|
||||||
video_guide_folder,
|
|
||||||
video_style,
|
|
||||||
video_style_folder,
|
|
||||||
mode,
|
|
||||||
window_size,
|
|
||||||
batch_size,
|
|
||||||
tracking_window_size,
|
|
||||||
output_path,
|
|
||||||
fps,
|
|
||||||
minimum_patch_size,
|
|
||||||
num_iter,
|
|
||||||
guide_weight,
|
|
||||||
initialize,
|
|
||||||
progress = None,
|
|
||||||
):
|
|
||||||
# input
|
|
||||||
frames_guide, frames_style, message = check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder)
|
|
||||||
if len(message) > 0:
|
|
||||||
print(message)
|
|
||||||
# output
|
|
||||||
if output_path == "":
|
|
||||||
if video_style is None:
|
|
||||||
output_path = os.path.join(video_style_folder, "output")
|
|
||||||
else:
|
|
||||||
output_path = os.path.join(os.path.split(video_style)[0], "output")
|
|
||||||
os.makedirs(output_path, exist_ok=True)
|
|
||||||
print("No valid output_path. Your video will be saved here:", output_path)
|
|
||||||
elif not os.path.exists(output_path):
|
|
||||||
os.makedirs(output_path, exist_ok=True)
|
|
||||||
print("Your video will be saved here:", output_path)
|
|
||||||
frames_path = os.path.join(output_path, "frames")
|
|
||||||
video_path = os.path.join(output_path, "video.mp4")
|
|
||||||
os.makedirs(frames_path, exist_ok=True)
|
|
||||||
# process
|
|
||||||
if mode == "Fast" or mode == "Balanced":
|
|
||||||
tracking_window_size = 0
|
|
||||||
ebsynth_config = {
|
|
||||||
"minimum_patch_size": minimum_patch_size,
|
|
||||||
"threads_per_block": 8,
|
|
||||||
"num_iter": num_iter,
|
|
||||||
"gpu_id": 0,
|
|
||||||
"guide_weight": guide_weight,
|
|
||||||
"initialize": initialize,
|
|
||||||
"tracking_window_size": tracking_window_size,
|
|
||||||
}
|
|
||||||
if mode == "Fast":
|
|
||||||
FastModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
|
|
||||||
elif mode == "Balanced":
|
|
||||||
BalancedModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
|
|
||||||
elif mode == "Accurate":
|
|
||||||
AccurateModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
|
|
||||||
# output
|
|
||||||
try:
|
|
||||||
fps = int(fps)
|
|
||||||
except:
|
|
||||||
fps = get_video_fps(video_style) if video_style is not None else 30
|
|
||||||
print("Fps:", fps)
|
|
||||||
print("Saving video...")
|
|
||||||
video_path = save_video(frames_path, video_path, num_frames=len(frames_style), fps=fps)
|
|
||||||
print("Success!")
|
|
||||||
print("Your frames are here:", frames_path)
|
|
||||||
print("Your video is here:", video_path)
|
|
||||||
return output_path, fps, video_path
|
|
||||||
|
|
||||||
|
|
||||||
class KeyFrameMatcher:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def extract_number_from_filename(self, file_name):
|
|
||||||
result = []
|
|
||||||
number = -1
|
|
||||||
for i in file_name:
|
|
||||||
if ord(i)>=ord("0") and ord(i)<=ord("9"):
|
|
||||||
if number == -1:
|
|
||||||
number = 0
|
|
||||||
number = number*10 + ord(i) - ord("0")
|
|
||||||
else:
|
|
||||||
if number != -1:
|
|
||||||
result.append(number)
|
|
||||||
number = -1
|
|
||||||
if number != -1:
|
|
||||||
result.append(number)
|
|
||||||
result = tuple(result)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def extract_number_from_filenames(self, file_names):
|
|
||||||
numbers = [self.extract_number_from_filename(file_name) for file_name in file_names]
|
|
||||||
min_length = min(len(i) for i in numbers)
|
|
||||||
for i in range(min_length-1, -1, -1):
|
|
||||||
if len(set(number[i] for number in numbers))==len(file_names):
|
|
||||||
return [number[i] for number in numbers]
|
|
||||||
return list(range(len(file_names)))
|
|
||||||
|
|
||||||
def match_using_filename(self, file_names_a, file_names_b):
|
|
||||||
file_names_b_set = set(file_names_b)
|
|
||||||
matched_file_name = []
|
|
||||||
for file_name in file_names_a:
|
|
||||||
if file_name not in file_names_b_set:
|
|
||||||
matched_file_name.append(None)
|
|
||||||
else:
|
|
||||||
matched_file_name.append(file_name)
|
|
||||||
return matched_file_name
|
|
||||||
|
|
||||||
def match_using_numbers(self, file_names_a, file_names_b):
|
|
||||||
numbers_a = self.extract_number_from_filenames(file_names_a)
|
|
||||||
numbers_b = self.extract_number_from_filenames(file_names_b)
|
|
||||||
numbers_b_dict = {number: file_name for number, file_name in zip(numbers_b, file_names_b)}
|
|
||||||
matched_file_name = []
|
|
||||||
for number in numbers_a:
|
|
||||||
if number in numbers_b_dict:
|
|
||||||
matched_file_name.append(numbers_b_dict[number])
|
|
||||||
else:
|
|
||||||
matched_file_name.append(None)
|
|
||||||
return matched_file_name
|
|
||||||
|
|
||||||
def match_filenames(self, file_names_a, file_names_b):
|
|
||||||
matched_file_name = self.match_using_filename(file_names_a, file_names_b)
|
|
||||||
if sum([i is not None for i in matched_file_name]) > 0:
|
|
||||||
return matched_file_name
|
|
||||||
matched_file_name = self.match_using_numbers(file_names_a, file_names_b)
|
|
||||||
return matched_file_name
|
|
||||||
|
|
||||||
|
|
||||||
def detect_frames(frames_path, keyframes_path):
|
|
||||||
if not os.path.exists(frames_path) and not os.path.exists(keyframes_path):
|
|
||||||
return "Please input the directory of guide video and rendered frames"
|
|
||||||
elif not os.path.exists(frames_path):
|
|
||||||
return "Please input the directory of guide video"
|
|
||||||
elif not os.path.exists(keyframes_path):
|
|
||||||
return "Please input the directory of rendered frames"
|
|
||||||
frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
|
|
||||||
keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
|
|
||||||
if len(frames)==0:
|
|
||||||
return f"No images detected in {frames_path}"
|
|
||||||
if len(keyframes)==0:
|
|
||||||
return f"No images detected in {keyframes_path}"
|
|
||||||
matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
|
|
||||||
max_filename_length = max([len(i) for i in frames])
|
|
||||||
if sum([i is not None for i in matched_keyframes])==0:
|
|
||||||
message = ""
|
|
||||||
for frame, matched_keyframe in zip(frames, matched_keyframes):
|
|
||||||
message += frame + " " * (max_filename_length - len(frame) + 1)
|
|
||||||
message += "--> No matched keyframes\n"
|
|
||||||
else:
|
|
||||||
message = ""
|
|
||||||
for frame, matched_keyframe in zip(frames, matched_keyframes):
|
|
||||||
message += frame + " " * (max_filename_length - len(frame) + 1)
|
|
||||||
if matched_keyframe is None:
|
|
||||||
message += "--> [to be rendered]\n"
|
|
||||||
else:
|
|
||||||
message += f"--> {matched_keyframe}\n"
|
|
||||||
return message
|
|
||||||
|
|
||||||
|
|
||||||
def check_input_for_interpolating(frames_path, keyframes_path):
|
|
||||||
# search for images
|
|
||||||
frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
|
|
||||||
keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
|
|
||||||
# match frames
|
|
||||||
matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
|
|
||||||
file_list = [file_name for file_name in matched_keyframes if file_name is not None]
|
|
||||||
index_style = [i for i, file_name in enumerate(matched_keyframes) if file_name is not None]
|
|
||||||
frames_guide = VideoData(None, frames_path)
|
|
||||||
frames_style = VideoData(None, keyframes_path, file_list=file_list)
|
|
||||||
# match shape
|
|
||||||
message = ""
|
|
||||||
height_guide, width_guide = frames_guide.shape()
|
|
||||||
height_style, width_style = frames_style.shape()
|
|
||||||
if height_guide != height_style or width_guide != width_style:
|
|
||||||
message += f"The shape of frames mismatches. The rendered keyframes will be resized to (height: {height_guide}, width: {width_guide})\n"
|
|
||||||
frames_style.set_shape(height_guide, width_guide)
|
|
||||||
return frames_guide, frames_style, index_style, message
|
|
||||||
|
|
||||||
|
|
||||||
def interpolate_video(
|
|
||||||
frames_path,
|
|
||||||
keyframes_path,
|
|
||||||
output_path,
|
|
||||||
fps,
|
|
||||||
batch_size,
|
|
||||||
tracking_window_size,
|
|
||||||
minimum_patch_size,
|
|
||||||
num_iter,
|
|
||||||
guide_weight,
|
|
||||||
initialize,
|
|
||||||
progress = None,
|
|
||||||
):
|
|
||||||
# input
|
|
||||||
frames_guide, frames_style, index_style, message = check_input_for_interpolating(frames_path, keyframes_path)
|
|
||||||
if len(message) > 0:
|
|
||||||
print(message)
|
|
||||||
# output
|
|
||||||
if output_path == "":
|
|
||||||
output_path = os.path.join(keyframes_path, "output")
|
|
||||||
os.makedirs(output_path, exist_ok=True)
|
|
||||||
print("No valid output_path. Your video will be saved here:", output_path)
|
|
||||||
elif not os.path.exists(output_path):
|
|
||||||
os.makedirs(output_path, exist_ok=True)
|
|
||||||
print("Your video will be saved here:", output_path)
|
|
||||||
output_frames_path = os.path.join(output_path, "frames")
|
|
||||||
output_video_path = os.path.join(output_path, "video.mp4")
|
|
||||||
os.makedirs(output_frames_path, exist_ok=True)
|
|
||||||
# process
|
|
||||||
ebsynth_config = {
|
|
||||||
"minimum_patch_size": minimum_patch_size,
|
|
||||||
"threads_per_block": 8,
|
|
||||||
"num_iter": num_iter,
|
|
||||||
"gpu_id": 0,
|
|
||||||
"guide_weight": guide_weight,
|
|
||||||
"initialize": initialize,
|
|
||||||
"tracking_window_size": tracking_window_size
|
|
||||||
}
|
|
||||||
if len(index_style)==1:
|
|
||||||
InterpolationModeSingleFrameRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
|
|
||||||
else:
|
|
||||||
InterpolationModeRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
|
|
||||||
try:
|
|
||||||
fps = int(fps)
|
|
||||||
except:
|
|
||||||
fps = 30
|
|
||||||
print("Fps:", fps)
|
|
||||||
print("Saving video...")
|
|
||||||
video_path = save_video(output_frames_path, output_video_path, num_frames=len(frames_guide), fps=fps)
|
|
||||||
print("Success!")
|
|
||||||
print("Your frames are here:", output_frames_path)
|
|
||||||
print("Your video is here:", video_path)
|
|
||||||
return output_path, fps, video_path
|
|
||||||
|
|
||||||
|
|
||||||
def on_ui_tabs():
|
|
||||||
with gr.Blocks(analytics_enabled=False) as ui_component:
|
|
||||||
with gr.Tab("Blend"):
|
|
||||||
gr.Markdown("""
|
|
||||||
# Blend
|
|
||||||
|
|
||||||
Given a guide video and a style video, this algorithm will make the style video fluent according to the motion features of the guide video. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/208d902d-6aba-48d7-b7d5-cd120ebd306d) to see the example. Note that this extension doesn't support long videos. Please use short videos (e.g., several seconds). The algorithm is mainly designed for 512*512 resolution. Please use a larger `Minimum patch size` for higher resolution.
|
|
||||||
""")
|
|
||||||
with gr.Row():
|
|
||||||
with gr.Column():
|
|
||||||
with gr.Tab("Guide video"):
|
|
||||||
video_guide = gr.Video(label="Guide video")
|
|
||||||
with gr.Tab("Guide video (images format)"):
|
|
||||||
video_guide_folder = gr.Textbox(label="Guide video (images format)", value="")
|
|
||||||
with gr.Column():
|
|
||||||
with gr.Tab("Style video"):
|
|
||||||
video_style = gr.Video(label="Style video")
|
|
||||||
with gr.Tab("Style video (images format)"):
|
|
||||||
video_style_folder = gr.Textbox(label="Style video (images format)", value="")
|
|
||||||
with gr.Column():
|
|
||||||
output_path = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of style video")
|
|
||||||
fps = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
|
|
||||||
video_output = gr.Video(label="Output video", interactive=False, show_share_button=True)
|
|
||||||
btn = gr.Button(value="Blend")
|
|
||||||
with gr.Row():
|
|
||||||
with gr.Column():
|
|
||||||
gr.Markdown("# Settings")
|
|
||||||
mode = gr.Radio(["Fast", "Balanced", "Accurate"], label="Inference mode", value="Fast", interactive=True)
|
|
||||||
window_size = gr.Slider(label="Sliding window size", value=15, minimum=1, maximum=1000, step=1, interactive=True)
|
|
||||||
batch_size = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
|
|
||||||
tracking_window_size = gr.Slider(label="Tracking window size (only for accurate mode)", value=0, minimum=0, maximum=10, step=1, interactive=True)
|
|
||||||
gr.Markdown("## Advanced Settings")
|
|
||||||
minimum_patch_size = gr.Slider(label="Minimum patch size (odd number)", value=5, minimum=5, maximum=99, step=2, interactive=True)
|
|
||||||
num_iter = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
|
|
||||||
guide_weight = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
|
|
||||||
initialize = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
|
|
||||||
with gr.Column():
|
|
||||||
gr.Markdown("""
|
|
||||||
# Reference
|
|
||||||
|
|
||||||
* Output directory: the directory to save the video.
|
|
||||||
* Inference mode
|
|
||||||
|
|
||||||
|Mode|Time|Memory|Quality|Frame by frame output|Description|
|
|
||||||
|-|-|-|-|-|-|
|
|
||||||
|Fast|■|■■■|■■|No|Blend the frames using a tree-like data structure, which requires much RAM but is fast.|
|
|
||||||
|Balanced|■■|■|■■|Yes|Blend the frames naively.|
|
|
||||||
|Accurate|■■■|■|■■■|Yes|Blend the frames and align them together for higher video quality. When [batch size] >= [sliding window size] * 2 + 1, the performance is the best.|
|
|
||||||
|
|
||||||
* Sliding window size: our algorithm will blend the frames in a sliding windows. If the size is n, each frame will be blended with the last n frames and the next n frames. A large sliding window can make the video fluent but sometimes smoggy.
|
|
||||||
* Batch size: a larger batch size makes the program faster but requires more VRAM.
|
|
||||||
* Tracking window size (only for accurate mode): The size of window in which our algorithm tracks moving objects. Empirically, 1 is enough.
|
|
||||||
* Advanced settings
|
|
||||||
* Minimum patch size (odd number): the minimum patch size used for patch matching. (Default: 5)
|
|
||||||
* Number of iterations: the number of iterations of patch matching. (Default: 5)
|
|
||||||
* Guide weight: a parameter that determines how much motion feature applied to the style video. (Default: 10)
|
|
||||||
* NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
|
|
||||||
""")
|
|
||||||
btn.click(
|
|
||||||
smooth_video,
|
|
||||||
inputs=[
|
|
||||||
video_guide,
|
|
||||||
video_guide_folder,
|
|
||||||
video_style,
|
|
||||||
video_style_folder,
|
|
||||||
mode,
|
|
||||||
window_size,
|
|
||||||
batch_size,
|
|
||||||
tracking_window_size,
|
|
||||||
output_path,
|
|
||||||
fps,
|
|
||||||
minimum_patch_size,
|
|
||||||
num_iter,
|
|
||||||
guide_weight,
|
|
||||||
initialize
|
|
||||||
],
|
|
||||||
outputs=[output_path, fps, video_output]
|
|
||||||
)
|
|
||||||
with gr.Tab("Interpolate"):
|
|
||||||
gr.Markdown("""
|
|
||||||
# Interpolate
|
|
||||||
|
|
||||||
Given a guide video and some rendered keyframes, this algorithm will render the remaining frames. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/3490c5b4-8f67-478f-86de-f9adc2ace16a) to see the example. The algorithm is experimental and is only tested for 512*512 resolution.
|
|
||||||
""")
|
|
||||||
with gr.Row():
|
|
||||||
with gr.Column():
|
|
||||||
with gr.Row():
|
|
||||||
with gr.Column():
|
|
||||||
video_guide_folder_ = gr.Textbox(label="Guide video (images format)", value="")
|
|
||||||
with gr.Column():
|
|
||||||
rendered_keyframes_ = gr.Textbox(label="Rendered keyframes (images format)", value="")
|
|
||||||
with gr.Row():
|
|
||||||
detected_frames = gr.Textbox(label="Detected frames", value="Please input the directory of guide video and rendered frames", lines=9, max_lines=9, interactive=False)
|
|
||||||
video_guide_folder_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
|
|
||||||
rendered_keyframes_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
|
|
||||||
with gr.Column():
|
|
||||||
output_path_ = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of rendered keyframes")
|
|
||||||
fps_ = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
|
|
||||||
video_output_ = gr.Video(label="Output video", interactive=False, show_share_button=True)
|
|
||||||
btn_ = gr.Button(value="Interpolate")
|
|
||||||
with gr.Row():
|
|
||||||
with gr.Column():
|
|
||||||
gr.Markdown("# Settings")
|
|
||||||
batch_size_ = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
|
|
||||||
tracking_window_size_ = gr.Slider(label="Tracking window size", value=0, minimum=0, maximum=10, step=1, interactive=True)
|
|
||||||
gr.Markdown("## Advanced Settings")
|
|
||||||
minimum_patch_size_ = gr.Slider(label="Minimum patch size (odd number, larger is better)", value=15, minimum=5, maximum=99, step=2, interactive=True)
|
|
||||||
num_iter_ = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
|
|
||||||
guide_weight_ = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
|
|
||||||
initialize_ = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
|
|
||||||
with gr.Column():
|
|
||||||
gr.Markdown("""
|
|
||||||
# Reference
|
|
||||||
|
|
||||||
* Output directory: the directory to save the video.
|
|
||||||
* Batch size: a larger batch size makes the program faster but requires more VRAM.
|
|
||||||
* Tracking window size (only for accurate mode): The size of window in which our algorithm tracks moving objects. Empirically, 1 is enough.
|
|
||||||
* Advanced settings
|
|
||||||
* Minimum patch size (odd number): the minimum patch size used for patch matching. **This parameter should be larger than that in blending. (Default: 15)**
|
|
||||||
* Number of iterations: the number of iterations of patch matching. (Default: 5)
|
|
||||||
* Guide weight: a parameter that determines how much motion feature applied to the style video. (Default: 10)
|
|
||||||
* NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
|
|
||||||
""")
|
|
||||||
btn_.click(
|
|
||||||
interpolate_video,
|
|
||||||
inputs=[
|
|
||||||
video_guide_folder_,
|
|
||||||
rendered_keyframes_,
|
|
||||||
output_path_,
|
|
||||||
fps_,
|
|
||||||
batch_size_,
|
|
||||||
tracking_window_size_,
|
|
||||||
minimum_patch_size_,
|
|
||||||
num_iter_,
|
|
||||||
guide_weight_,
|
|
||||||
initialize_,
|
|
||||||
],
|
|
||||||
outputs=[output_path_, fps_, video_output_]
|
|
||||||
)
|
|
||||||
|
|
||||||
return [(ui_component, "FastBlend", "FastBlend_ui")]
|
|
||||||
@@ -1,119 +0,0 @@
|
|||||||
import cupy as cp
|
|
||||||
|
|
||||||
remapping_kernel = cp.RawKernel(r'''
|
|
||||||
extern "C" __global__
|
|
||||||
void remap(
|
|
||||||
const int height,
|
|
||||||
const int width,
|
|
||||||
const int channel,
|
|
||||||
const int patch_size,
|
|
||||||
const int pad_size,
|
|
||||||
const float* source_style,
|
|
||||||
const int* nnf,
|
|
||||||
float* target_style
|
|
||||||
) {
|
|
||||||
const int r = (patch_size - 1) / 2;
|
|
||||||
const int x = blockDim.x * blockIdx.x + threadIdx.x;
|
|
||||||
const int y = blockDim.y * blockIdx.y + threadIdx.y;
|
|
||||||
if (x >= height or y >= width) return;
|
|
||||||
const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
|
|
||||||
const int pid = (x + pad_size) * (width + pad_size * 2) + (y + pad_size);
|
|
||||||
const int min_px = x < r ? -x : -r;
|
|
||||||
const int max_px = x + r > height - 1 ? height - 1 - x : r;
|
|
||||||
const int min_py = y < r ? -y : -r;
|
|
||||||
const int max_py = y + r > width - 1 ? width - 1 - y : r;
|
|
||||||
int num = 0;
|
|
||||||
for (int px = min_px; px <= max_px; px++){
|
|
||||||
for (int py = min_py; py <= max_py; py++){
|
|
||||||
const int nid = (x + px) * width + y + py;
|
|
||||||
const int x_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 0] - px;
|
|
||||||
const int y_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 1] - py;
|
|
||||||
if (x_ < 0 or y_ < 0 or x_ >= height or y_ >= width)continue;
|
|
||||||
const int pid_ = (x_ + pad_size) * (width + pad_size * 2) + (y_ + pad_size);
|
|
||||||
num++;
|
|
||||||
for (int c = 0; c < channel; c++){
|
|
||||||
target_style[z + pid * channel + c] += source_style[z + pid_ * channel + c];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (int c = 0; c < channel; c++){
|
|
||||||
target_style[z + pid * channel + c] /= num;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
''', 'remap')
|
|
||||||
|
|
||||||
|
|
||||||
patch_error_kernel = cp.RawKernel(r'''
|
|
||||||
extern "C" __global__
|
|
||||||
void patch_error(
|
|
||||||
const int height,
|
|
||||||
const int width,
|
|
||||||
const int channel,
|
|
||||||
const int patch_size,
|
|
||||||
const int pad_size,
|
|
||||||
const float* source,
|
|
||||||
const int* nnf,
|
|
||||||
const float* target,
|
|
||||||
float* error
|
|
||||||
) {
|
|
||||||
const int r = (patch_size - 1) / 2;
|
|
||||||
const int x = blockDim.x * blockIdx.x + threadIdx.x;
|
|
||||||
const int y = blockDim.y * blockIdx.y + threadIdx.y;
|
|
||||||
const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
|
|
||||||
if (x >= height or y >= width) return;
|
|
||||||
const int x_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 0];
|
|
||||||
const int y_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 1];
|
|
||||||
float e = 0;
|
|
||||||
for (int px = -r; px <= r; px++){
|
|
||||||
for (int py = -r; py <= r; py++){
|
|
||||||
const int pid = (x + pad_size + px) * (width + pad_size * 2) + y + pad_size + py;
|
|
||||||
const int pid_ = (x_ + pad_size + px) * (width + pad_size * 2) + y_ + pad_size + py;
|
|
||||||
for (int c = 0; c < channel; c++){
|
|
||||||
const float diff = target[z + pid * channel + c] - source[z + pid_ * channel + c];
|
|
||||||
e += diff * diff;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
error[blockIdx.z * height * width + x * width + y] = e;
|
|
||||||
}
|
|
||||||
''', 'patch_error')
|
|
||||||
|
|
||||||
|
|
||||||
pairwise_patch_error_kernel = cp.RawKernel(r'''
|
|
||||||
extern "C" __global__
|
|
||||||
void pairwise_patch_error(
|
|
||||||
const int height,
|
|
||||||
const int width,
|
|
||||||
const int channel,
|
|
||||||
const int patch_size,
|
|
||||||
const int pad_size,
|
|
||||||
const float* source_a,
|
|
||||||
const int* nnf_a,
|
|
||||||
const float* source_b,
|
|
||||||
const int* nnf_b,
|
|
||||||
float* error
|
|
||||||
) {
|
|
||||||
const int r = (patch_size - 1) / 2;
|
|
||||||
const int x = blockDim.x * blockIdx.x + threadIdx.x;
|
|
||||||
const int y = blockDim.y * blockIdx.y + threadIdx.y;
|
|
||||||
const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
|
|
||||||
if (x >= height or y >= width) return;
|
|
||||||
const int z_nnf = blockIdx.z * height * width * 2 + (x * width + y) * 2;
|
|
||||||
const int x_a = nnf_a[z_nnf + 0];
|
|
||||||
const int y_a = nnf_a[z_nnf + 1];
|
|
||||||
const int x_b = nnf_b[z_nnf + 0];
|
|
||||||
const int y_b = nnf_b[z_nnf + 1];
|
|
||||||
float e = 0;
|
|
||||||
for (int px = -r; px <= r; px++){
|
|
||||||
for (int py = -r; py <= r; py++){
|
|
||||||
const int pid_a = (x_a + pad_size + px) * (width + pad_size * 2) + y_a + pad_size + py;
|
|
||||||
const int pid_b = (x_b + pad_size + px) * (width + pad_size * 2) + y_b + pad_size + py;
|
|
||||||
for (int c = 0; c < channel; c++){
|
|
||||||
const float diff = source_a[z + pid_a * channel + c] - source_b[z + pid_b * channel + c];
|
|
||||||
e += diff * diff;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
error[blockIdx.z * height * width + x * width + y] = e;
|
|
||||||
}
|
|
||||||
''', 'pairwise_patch_error')
|
|
||||||
@@ -1,146 +0,0 @@
|
|||||||
import imageio, os
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
|
|
||||||
def read_video(file_name):
|
|
||||||
reader = imageio.get_reader(file_name)
|
|
||||||
video = []
|
|
||||||
for frame in reader:
|
|
||||||
frame = np.array(frame)
|
|
||||||
video.append(frame)
|
|
||||||
reader.close()
|
|
||||||
return video
|
|
||||||
|
|
||||||
|
|
||||||
def get_video_fps(file_name):
|
|
||||||
reader = imageio.get_reader(file_name)
|
|
||||||
fps = reader.get_meta_data()["fps"]
|
|
||||||
reader.close()
|
|
||||||
return fps
|
|
||||||
|
|
||||||
|
|
||||||
def save_video(frames_path, video_path, num_frames, fps):
|
|
||||||
writer = imageio.get_writer(video_path, fps=fps, quality=9)
|
|
||||||
for i in range(num_frames):
|
|
||||||
frame = np.array(Image.open(os.path.join(frames_path, "%05d.png" % i)))
|
|
||||||
writer.append_data(frame)
|
|
||||||
writer.close()
|
|
||||||
return video_path
|
|
||||||
|
|
||||||
|
|
||||||
class LowMemoryVideo:
|
|
||||||
def __init__(self, file_name):
|
|
||||||
self.reader = imageio.get_reader(file_name)
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return self.reader.count_frames()
|
|
||||||
|
|
||||||
def __getitem__(self, item):
|
|
||||||
return np.array(self.reader.get_data(item))
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
self.reader.close()
|
|
||||||
|
|
||||||
|
|
||||||
def split_file_name(file_name):
|
|
||||||
result = []
|
|
||||||
number = -1
|
|
||||||
for i in file_name:
|
|
||||||
if ord(i)>=ord("0") and ord(i)<=ord("9"):
|
|
||||||
if number == -1:
|
|
||||||
number = 0
|
|
||||||
number = number*10 + ord(i) - ord("0")
|
|
||||||
else:
|
|
||||||
if number != -1:
|
|
||||||
result.append(number)
|
|
||||||
number = -1
|
|
||||||
result.append(i)
|
|
||||||
if number != -1:
|
|
||||||
result.append(number)
|
|
||||||
result = tuple(result)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def search_for_images(folder):
|
|
||||||
file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
|
|
||||||
file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
|
|
||||||
file_list = [i[1] for i in sorted(file_list)]
|
|
||||||
file_list = [os.path.join(folder, i) for i in file_list]
|
|
||||||
return file_list
|
|
||||||
|
|
||||||
|
|
||||||
def read_images(folder):
|
|
||||||
file_list = search_for_images(folder)
|
|
||||||
frames = [np.array(Image.open(i)) for i in file_list]
|
|
||||||
return frames
|
|
||||||
|
|
||||||
|
|
||||||
class LowMemoryImageFolder:
|
|
||||||
def __init__(self, folder, file_list=None):
|
|
||||||
if file_list is None:
|
|
||||||
self.file_list = search_for_images(folder)
|
|
||||||
else:
|
|
||||||
self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.file_list)
|
|
||||||
|
|
||||||
def __getitem__(self, item):
|
|
||||||
return np.array(Image.open(self.file_list[item]))
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class VideoData:
|
|
||||||
def __init__(self, video_file, image_folder, **kwargs):
|
|
||||||
if video_file is not None:
|
|
||||||
self.data_type = "video"
|
|
||||||
self.data = LowMemoryVideo(video_file, **kwargs)
|
|
||||||
elif image_folder is not None:
|
|
||||||
self.data_type = "images"
|
|
||||||
self.data = LowMemoryImageFolder(image_folder, **kwargs)
|
|
||||||
else:
|
|
||||||
raise ValueError("Cannot open video or image folder")
|
|
||||||
self.length = None
|
|
||||||
self.height = None
|
|
||||||
self.width = None
|
|
||||||
|
|
||||||
def raw_data(self):
|
|
||||||
frames = []
|
|
||||||
for i in range(self.__len__()):
|
|
||||||
frames.append(self.__getitem__(i))
|
|
||||||
return frames
|
|
||||||
|
|
||||||
def set_length(self, length):
|
|
||||||
self.length = length
|
|
||||||
|
|
||||||
def set_shape(self, height, width):
|
|
||||||
self.height = height
|
|
||||||
self.width = width
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
if self.length is None:
|
|
||||||
return len(self.data)
|
|
||||||
else:
|
|
||||||
return self.length
|
|
||||||
|
|
||||||
def shape(self):
|
|
||||||
if self.height is not None and self.width is not None:
|
|
||||||
return self.height, self.width
|
|
||||||
else:
|
|
||||||
height, width, _ = self.__getitem__(0).shape
|
|
||||||
return height, width
|
|
||||||
|
|
||||||
def __getitem__(self, item):
|
|
||||||
frame = self.data.__getitem__(item)
|
|
||||||
height, width, _ = frame.shape
|
|
||||||
if self.height is not None and self.width is not None:
|
|
||||||
if self.height != height or self.width != width:
|
|
||||||
frame = Image.fromarray(frame).resize((self.width, self.height))
|
|
||||||
frame = np.array(frame)
|
|
||||||
return frame
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
pass
|
|
||||||
@@ -1,298 +0,0 @@
|
|||||||
from .cupy_kernels import remapping_kernel, patch_error_kernel, pairwise_patch_error_kernel
|
|
||||||
import numpy as np
|
|
||||||
import cupy as cp
|
|
||||||
import cv2
|
|
||||||
|
|
||||||
|
|
||||||
class PatchMatcher:
|
|
||||||
def __init__(
|
|
||||||
self, height, width, channel, minimum_patch_size,
|
|
||||||
threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
|
|
||||||
random_search_steps=3, random_search_range=4,
|
|
||||||
use_mean_target_style=False, use_pairwise_patch_error=False,
|
|
||||||
tracking_window_size=0
|
|
||||||
):
|
|
||||||
self.height = height
|
|
||||||
self.width = width
|
|
||||||
self.channel = channel
|
|
||||||
self.minimum_patch_size = minimum_patch_size
|
|
||||||
self.threads_per_block = threads_per_block
|
|
||||||
self.num_iter = num_iter
|
|
||||||
self.gpu_id = gpu_id
|
|
||||||
self.guide_weight = guide_weight
|
|
||||||
self.random_search_steps = random_search_steps
|
|
||||||
self.random_search_range = random_search_range
|
|
||||||
self.use_mean_target_style = use_mean_target_style
|
|
||||||
self.use_pairwise_patch_error = use_pairwise_patch_error
|
|
||||||
self.tracking_window_size = tracking_window_size
|
|
||||||
|
|
||||||
self.patch_size_list = [minimum_patch_size + i*2 for i in range(num_iter)][::-1]
|
|
||||||
self.pad_size = self.patch_size_list[0] // 2
|
|
||||||
self.grid = (
|
|
||||||
(height + threads_per_block - 1) // threads_per_block,
|
|
||||||
(width + threads_per_block - 1) // threads_per_block
|
|
||||||
)
|
|
||||||
self.block = (threads_per_block, threads_per_block)
|
|
||||||
|
|
||||||
def pad_image(self, image):
|
|
||||||
return cp.pad(image, ((0, 0), (self.pad_size, self.pad_size), (self.pad_size, self.pad_size), (0, 0)))
|
|
||||||
|
|
||||||
def unpad_image(self, image):
|
|
||||||
return image[:, self.pad_size: -self.pad_size, self.pad_size: -self.pad_size, :]
|
|
||||||
|
|
||||||
def apply_nnf_to_image(self, nnf, source):
|
|
||||||
batch_size = source.shape[0]
|
|
||||||
target = cp.zeros((batch_size, self.height + self.pad_size * 2, self.width + self.pad_size * 2, self.channel), dtype=cp.float32)
|
|
||||||
remapping_kernel(
|
|
||||||
self.grid + (batch_size,),
|
|
||||||
self.block,
|
|
||||||
(self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target)
|
|
||||||
)
|
|
||||||
return target
|
|
||||||
|
|
||||||
def get_patch_error(self, source, nnf, target):
|
|
||||||
batch_size = source.shape[0]
|
|
||||||
error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
|
|
||||||
patch_error_kernel(
|
|
||||||
self.grid + (batch_size,),
|
|
||||||
self.block,
|
|
||||||
(self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target, error)
|
|
||||||
)
|
|
||||||
return error
|
|
||||||
|
|
||||||
def get_pairwise_patch_error(self, source, nnf):
|
|
||||||
batch_size = source.shape[0]//2
|
|
||||||
error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
|
|
||||||
source_a, nnf_a = source[0::2].copy(), nnf[0::2].copy()
|
|
||||||
source_b, nnf_b = source[1::2].copy(), nnf[1::2].copy()
|
|
||||||
pairwise_patch_error_kernel(
|
|
||||||
self.grid + (batch_size,),
|
|
||||||
self.block,
|
|
||||||
(self.height, self.width, self.channel, self.patch_size, self.pad_size, source_a, nnf_a, source_b, nnf_b, error)
|
|
||||||
)
|
|
||||||
error = error.repeat(2, axis=0)
|
|
||||||
return error
|
|
||||||
|
|
||||||
def get_error(self, source_guide, target_guide, source_style, target_style, nnf):
|
|
||||||
error_guide = self.get_patch_error(source_guide, nnf, target_guide)
|
|
||||||
if self.use_mean_target_style:
|
|
||||||
target_style = self.apply_nnf_to_image(nnf, source_style)
|
|
||||||
target_style = target_style.mean(axis=0, keepdims=True)
|
|
||||||
target_style = target_style.repeat(source_guide.shape[0], axis=0)
|
|
||||||
if self.use_pairwise_patch_error:
|
|
||||||
error_style = self.get_pairwise_patch_error(source_style, nnf)
|
|
||||||
else:
|
|
||||||
error_style = self.get_patch_error(source_style, nnf, target_style)
|
|
||||||
error = error_guide * self.guide_weight + error_style
|
|
||||||
return error
|
|
||||||
|
|
||||||
def clamp_bound(self, nnf):
|
|
||||||
nnf[:,:,:,0] = cp.clip(nnf[:,:,:,0], 0, self.height-1)
|
|
||||||
nnf[:,:,:,1] = cp.clip(nnf[:,:,:,1], 0, self.width-1)
|
|
||||||
return nnf
|
|
||||||
|
|
||||||
def random_step(self, nnf, r):
|
|
||||||
batch_size = nnf.shape[0]
|
|
||||||
step = cp.random.randint(-r, r+1, size=(batch_size, self.height, self.width, 2), dtype=cp.int32)
|
|
||||||
upd_nnf = self.clamp_bound(nnf + step)
|
|
||||||
return upd_nnf
|
|
||||||
|
|
||||||
def neighboor_step(self, nnf, d):
|
|
||||||
if d==0:
|
|
||||||
upd_nnf = cp.concatenate([nnf[:, :1, :], nnf[:, :-1, :]], axis=1)
|
|
||||||
upd_nnf[:, :, :, 0] += 1
|
|
||||||
elif d==1:
|
|
||||||
upd_nnf = cp.concatenate([nnf[:, :, :1], nnf[:, :, :-1]], axis=2)
|
|
||||||
upd_nnf[:, :, :, 1] += 1
|
|
||||||
elif d==2:
|
|
||||||
upd_nnf = cp.concatenate([nnf[:, 1:, :], nnf[:, -1:, :]], axis=1)
|
|
||||||
upd_nnf[:, :, :, 0] -= 1
|
|
||||||
elif d==3:
|
|
||||||
upd_nnf = cp.concatenate([nnf[:, :, 1:], nnf[:, :, -1:]], axis=2)
|
|
||||||
upd_nnf[:, :, :, 1] -= 1
|
|
||||||
upd_nnf = self.clamp_bound(upd_nnf)
|
|
||||||
return upd_nnf
|
|
||||||
|
|
||||||
def shift_nnf(self, nnf, d):
|
|
||||||
if d>0:
|
|
||||||
d = min(nnf.shape[0], d)
|
|
||||||
upd_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
|
|
||||||
else:
|
|
||||||
d = max(-nnf.shape[0], d)
|
|
||||||
upd_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
|
|
||||||
return upd_nnf
|
|
||||||
|
|
||||||
def track_step(self, nnf, d):
|
|
||||||
if self.use_pairwise_patch_error:
|
|
||||||
upd_nnf = cp.zeros_like(nnf)
|
|
||||||
upd_nnf[0::2] = self.shift_nnf(nnf[0::2], d)
|
|
||||||
upd_nnf[1::2] = self.shift_nnf(nnf[1::2], d)
|
|
||||||
else:
|
|
||||||
upd_nnf = self.shift_nnf(nnf, d)
|
|
||||||
return upd_nnf
|
|
||||||
|
|
||||||
def C(self, n, m):
|
|
||||||
# not used
|
|
||||||
c = 1
|
|
||||||
for i in range(1, n+1):
|
|
||||||
c *= i
|
|
||||||
for i in range(1, m+1):
|
|
||||||
c //= i
|
|
||||||
for i in range(1, n-m+1):
|
|
||||||
c //= i
|
|
||||||
return c
|
|
||||||
|
|
||||||
def bezier_step(self, nnf, r):
|
|
||||||
# not used
|
|
||||||
n = r * 2 - 1
|
|
||||||
upd_nnf = cp.zeros(shape=nnf.shape, dtype=cp.float32)
|
|
||||||
for i, d in enumerate(list(range(-r, 0)) + list(range(1, r+1))):
|
|
||||||
if d>0:
|
|
||||||
ctl_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
|
|
||||||
elif d<0:
|
|
||||||
ctl_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
|
|
||||||
upd_nnf += ctl_nnf * (self.C(n, i) / 2**n)
|
|
||||||
upd_nnf = self.clamp_bound(upd_nnf).astype(nnf.dtype)
|
|
||||||
return upd_nnf
|
|
||||||
|
|
||||||
def update(self, source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf):
|
|
||||||
upd_err = self.get_error(source_guide, target_guide, source_style, target_style, upd_nnf)
|
|
||||||
upd_idx = (upd_err < err)
|
|
||||||
nnf[upd_idx] = upd_nnf[upd_idx]
|
|
||||||
err[upd_idx] = upd_err[upd_idx]
|
|
||||||
return nnf, err
|
|
||||||
|
|
||||||
def propagation(self, source_guide, target_guide, source_style, target_style, nnf, err):
|
|
||||||
for d in cp.random.permutation(4):
|
|
||||||
upd_nnf = self.neighboor_step(nnf, d)
|
|
||||||
nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
|
|
||||||
return nnf, err
|
|
||||||
|
|
||||||
def random_search(self, source_guide, target_guide, source_style, target_style, nnf, err):
|
|
||||||
for i in range(self.random_search_steps):
|
|
||||||
upd_nnf = self.random_step(nnf, self.random_search_range)
|
|
||||||
nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
|
|
||||||
return nnf, err
|
|
||||||
|
|
||||||
def track(self, source_guide, target_guide, source_style, target_style, nnf, err):
|
|
||||||
for d in range(1, self.tracking_window_size + 1):
|
|
||||||
upd_nnf = self.track_step(nnf, d)
|
|
||||||
nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
|
|
||||||
upd_nnf = self.track_step(nnf, -d)
|
|
||||||
nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
|
|
||||||
return nnf, err
|
|
||||||
|
|
||||||
def iteration(self, source_guide, target_guide, source_style, target_style, nnf, err):
|
|
||||||
nnf, err = self.propagation(source_guide, target_guide, source_style, target_style, nnf, err)
|
|
||||||
nnf, err = self.random_search(source_guide, target_guide, source_style, target_style, nnf, err)
|
|
||||||
nnf, err = self.track(source_guide, target_guide, source_style, target_style, nnf, err)
|
|
||||||
return nnf, err
|
|
||||||
|
|
||||||
def estimate_nnf(self, source_guide, target_guide, source_style, nnf):
|
|
||||||
with cp.cuda.Device(self.gpu_id):
|
|
||||||
source_guide = self.pad_image(source_guide)
|
|
||||||
target_guide = self.pad_image(target_guide)
|
|
||||||
source_style = self.pad_image(source_style)
|
|
||||||
for it in range(self.num_iter):
|
|
||||||
self.patch_size = self.patch_size_list[it]
|
|
||||||
target_style = self.apply_nnf_to_image(nnf, source_style)
|
|
||||||
err = self.get_error(source_guide, target_guide, source_style, target_style, nnf)
|
|
||||||
nnf, err = self.iteration(source_guide, target_guide, source_style, target_style, nnf, err)
|
|
||||||
target_style = self.unpad_image(self.apply_nnf_to_image(nnf, source_style))
|
|
||||||
return nnf, target_style
|
|
||||||
|
|
||||||
|
|
||||||
class PyramidPatchMatcher:
|
|
||||||
def __init__(
|
|
||||||
self, image_height, image_width, channel, minimum_patch_size,
|
|
||||||
threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
|
|
||||||
use_mean_target_style=False, use_pairwise_patch_error=False,
|
|
||||||
tracking_window_size=0,
|
|
||||||
initialize="identity"
|
|
||||||
):
|
|
||||||
maximum_patch_size = minimum_patch_size + (num_iter - 1) * 2
|
|
||||||
self.pyramid_level = int(np.log2(min(image_height, image_width) / maximum_patch_size))
|
|
||||||
self.pyramid_heights = []
|
|
||||||
self.pyramid_widths = []
|
|
||||||
self.patch_matchers = []
|
|
||||||
self.minimum_patch_size = minimum_patch_size
|
|
||||||
self.num_iter = num_iter
|
|
||||||
self.gpu_id = gpu_id
|
|
||||||
self.initialize = initialize
|
|
||||||
for level in range(self.pyramid_level):
|
|
||||||
height = image_height//(2**(self.pyramid_level - 1 - level))
|
|
||||||
width = image_width//(2**(self.pyramid_level - 1 - level))
|
|
||||||
self.pyramid_heights.append(height)
|
|
||||||
self.pyramid_widths.append(width)
|
|
||||||
self.patch_matchers.append(PatchMatcher(
|
|
||||||
height, width, channel, minimum_patch_size=minimum_patch_size,
|
|
||||||
threads_per_block=threads_per_block, num_iter=num_iter, gpu_id=gpu_id, guide_weight=guide_weight,
|
|
||||||
use_mean_target_style=use_mean_target_style, use_pairwise_patch_error=use_pairwise_patch_error,
|
|
||||||
tracking_window_size=tracking_window_size
|
|
||||||
))
|
|
||||||
|
|
||||||
def resample_image(self, images, level):
|
|
||||||
height, width = self.pyramid_heights[level], self.pyramid_widths[level]
|
|
||||||
images = images.get()
|
|
||||||
images_resample = []
|
|
||||||
for image in images:
|
|
||||||
image_resample = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
|
|
||||||
images_resample.append(image_resample)
|
|
||||||
images_resample = cp.array(np.stack(images_resample), dtype=cp.float32)
|
|
||||||
return images_resample
|
|
||||||
|
|
||||||
def initialize_nnf(self, batch_size):
|
|
||||||
if self.initialize == "random":
|
|
||||||
height, width = self.pyramid_heights[0], self.pyramid_widths[0]
|
|
||||||
nnf = cp.stack([
|
|
||||||
cp.random.randint(0, height, (batch_size, height, width), dtype=cp.int32),
|
|
||||||
cp.random.randint(0, width, (batch_size, height, width), dtype=cp.int32)
|
|
||||||
], axis=3)
|
|
||||||
elif self.initialize == "identity":
|
|
||||||
height, width = self.pyramid_heights[0], self.pyramid_widths[0]
|
|
||||||
nnf = cp.stack([
|
|
||||||
cp.repeat(cp.arange(height), width).reshape(height, width),
|
|
||||||
cp.tile(cp.arange(width), height).reshape(height, width)
|
|
||||||
], axis=2)
|
|
||||||
nnf = cp.stack([nnf] * batch_size)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError()
|
|
||||||
return nnf
|
|
||||||
|
|
||||||
def update_nnf(self, nnf, level):
|
|
||||||
# upscale
|
|
||||||
nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2
|
|
||||||
nnf[:,[i for i in range(nnf.shape[0]) if i&1],:,0] += 1
|
|
||||||
nnf[:,:,[i for i in range(nnf.shape[0]) if i&1],1] += 1
|
|
||||||
# check if scale is 2
|
|
||||||
height, width = self.pyramid_heights[level], self.pyramid_widths[level]
|
|
||||||
if height != nnf.shape[0] * 2 or width != nnf.shape[1] * 2:
|
|
||||||
nnf = nnf.get().astype(np.float32)
|
|
||||||
nnf = [cv2.resize(n, (width, height), interpolation=cv2.INTER_LINEAR) for n in nnf]
|
|
||||||
nnf = cp.array(np.stack(nnf), dtype=cp.int32)
|
|
||||||
nnf = self.patch_matchers[level].clamp_bound(nnf)
|
|
||||||
return nnf
|
|
||||||
|
|
||||||
def apply_nnf_to_image(self, nnf, image):
|
|
||||||
with cp.cuda.Device(self.gpu_id):
|
|
||||||
image = self.patch_matchers[-1].pad_image(image)
|
|
||||||
image = self.patch_matchers[-1].apply_nnf_to_image(nnf, image)
|
|
||||||
return image
|
|
||||||
|
|
||||||
def estimate_nnf(self, source_guide, target_guide, source_style):
|
|
||||||
with cp.cuda.Device(self.gpu_id):
|
|
||||||
if not isinstance(source_guide, cp.ndarray):
|
|
||||||
source_guide = cp.array(source_guide, dtype=cp.float32)
|
|
||||||
if not isinstance(target_guide, cp.ndarray):
|
|
||||||
target_guide = cp.array(target_guide, dtype=cp.float32)
|
|
||||||
if not isinstance(source_style, cp.ndarray):
|
|
||||||
source_style = cp.array(source_style, dtype=cp.float32)
|
|
||||||
for level in range(self.pyramid_level):
|
|
||||||
nnf = self.initialize_nnf(source_guide.shape[0]) if level==0 else self.update_nnf(nnf, level)
|
|
||||||
source_guide_ = self.resample_image(source_guide, level)
|
|
||||||
target_guide_ = self.resample_image(target_guide, level)
|
|
||||||
source_style_ = self.resample_image(source_style, level)
|
|
||||||
nnf, target_style = self.patch_matchers[level].estimate_nnf(
|
|
||||||
source_guide_, target_guide_, source_style_, nnf
|
|
||||||
)
|
|
||||||
return nnf.get(), target_style.get()
|
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
from .accurate import AccurateModeRunner
|
|
||||||
from .fast import FastModeRunner
|
|
||||||
from .balanced import BalancedModeRunner
|
|
||||||
from .interpolation import InterpolationModeRunner, InterpolationModeSingleFrameRunner
|
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
from ..patch_match import PyramidPatchMatcher
|
|
||||||
import os
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
|
|
||||||
class AccurateModeRunner:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Accurate Mode", save_path=None):
|
|
||||||
patch_match_engine = PyramidPatchMatcher(
|
|
||||||
image_height=frames_style[0].shape[0],
|
|
||||||
image_width=frames_style[0].shape[1],
|
|
||||||
channel=3,
|
|
||||||
use_mean_target_style=True,
|
|
||||||
**ebsynth_config
|
|
||||||
)
|
|
||||||
# run
|
|
||||||
n = len(frames_style)
|
|
||||||
for target in tqdm(range(n), desc=desc):
|
|
||||||
l, r = max(target - window_size, 0), min(target + window_size + 1, n)
|
|
||||||
remapped_frames = []
|
|
||||||
for i in range(l, r, batch_size):
|
|
||||||
j = min(i + batch_size, r)
|
|
||||||
source_guide = np.stack([frames_guide[source] for source in range(i, j)])
|
|
||||||
target_guide = np.stack([frames_guide[target]] * (j - i))
|
|
||||||
source_style = np.stack([frames_style[source] for source in range(i, j)])
|
|
||||||
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
|
|
||||||
remapped_frames.append(target_style)
|
|
||||||
frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
|
|
||||||
frame = frame.clip(0, 255).astype("uint8")
|
|
||||||
if save_path is not None:
|
|
||||||
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
|
|
||||||
@@ -1,46 +0,0 @@
|
|||||||
from ..patch_match import PyramidPatchMatcher
|
|
||||||
import os
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
|
|
||||||
class BalancedModeRunner:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Balanced Mode", save_path=None):
|
|
||||||
patch_match_engine = PyramidPatchMatcher(
|
|
||||||
image_height=frames_style[0].shape[0],
|
|
||||||
image_width=frames_style[0].shape[1],
|
|
||||||
channel=3,
|
|
||||||
**ebsynth_config
|
|
||||||
)
|
|
||||||
# tasks
|
|
||||||
n = len(frames_style)
|
|
||||||
tasks = []
|
|
||||||
for target in range(n):
|
|
||||||
for source in range(target - window_size, target + window_size + 1):
|
|
||||||
if source >= 0 and source < n and source != target:
|
|
||||||
tasks.append((source, target))
|
|
||||||
# run
|
|
||||||
frames = [(None, 1) for i in range(n)]
|
|
||||||
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
|
|
||||||
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
|
|
||||||
source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
|
|
||||||
target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
|
|
||||||
source_style = np.stack([frames_style[source] for source, target in tasks_batch])
|
|
||||||
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
|
|
||||||
for (source, target), result in zip(tasks_batch, target_style):
|
|
||||||
frame, weight = frames[target]
|
|
||||||
if frame is None:
|
|
||||||
frame = frames_style[target]
|
|
||||||
frames[target] = (
|
|
||||||
frame * (weight / (weight + 1)) + result / (weight + 1),
|
|
||||||
weight + 1
|
|
||||||
)
|
|
||||||
if weight + 1 == min(n, target + window_size + 1) - max(0, target - window_size):
|
|
||||||
frame = frame.clip(0, 255).astype("uint8")
|
|
||||||
if save_path is not None:
|
|
||||||
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
|
|
||||||
frames[target] = (None, 1)
|
|
||||||
@@ -1,141 +0,0 @@
|
|||||||
from ..patch_match import PyramidPatchMatcher
|
|
||||||
import functools, os
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
|
|
||||||
class TableManager:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def task_list(self, n):
|
|
||||||
tasks = []
|
|
||||||
max_level = 1
|
|
||||||
while (1<<max_level)<=n:
|
|
||||||
max_level += 1
|
|
||||||
for i in range(n):
|
|
||||||
j = i
|
|
||||||
for level in range(max_level):
|
|
||||||
if i&(1<<level):
|
|
||||||
continue
|
|
||||||
j |= 1<<level
|
|
||||||
if j>=n:
|
|
||||||
break
|
|
||||||
meta_data = {
|
|
||||||
"source": i,
|
|
||||||
"target": j,
|
|
||||||
"level": level + 1
|
|
||||||
}
|
|
||||||
tasks.append(meta_data)
|
|
||||||
tasks.sort(key=functools.cmp_to_key(lambda u, v: u["level"]-v["level"]))
|
|
||||||
return tasks
|
|
||||||
|
|
||||||
def build_remapping_table(self, frames_guide, frames_style, patch_match_engine, batch_size, desc=""):
|
|
||||||
n = len(frames_guide)
|
|
||||||
tasks = self.task_list(n)
|
|
||||||
remapping_table = [[(frames_style[i], 1)] for i in range(n)]
|
|
||||||
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
|
|
||||||
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
|
|
||||||
source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
|
|
||||||
target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
|
|
||||||
source_style = np.stack([frames_style[task["source"]] for task in tasks_batch])
|
|
||||||
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
|
|
||||||
for task, result in zip(tasks_batch, target_style):
|
|
||||||
target, level = task["target"], task["level"]
|
|
||||||
if len(remapping_table[target])==level:
|
|
||||||
remapping_table[target].append((result, 1))
|
|
||||||
else:
|
|
||||||
frame, weight = remapping_table[target][level]
|
|
||||||
remapping_table[target][level] = (
|
|
||||||
frame * (weight / (weight + 1)) + result / (weight + 1),
|
|
||||||
weight + 1
|
|
||||||
)
|
|
||||||
return remapping_table
|
|
||||||
|
|
||||||
def remapping_table_to_blending_table(self, table):
|
|
||||||
for i in range(len(table)):
|
|
||||||
for j in range(1, len(table[i])):
|
|
||||||
frame_1, weight_1 = table[i][j-1]
|
|
||||||
frame_2, weight_2 = table[i][j]
|
|
||||||
frame = (frame_1 + frame_2) / 2
|
|
||||||
weight = weight_1 + weight_2
|
|
||||||
table[i][j] = (frame, weight)
|
|
||||||
return table
|
|
||||||
|
|
||||||
def tree_query(self, leftbound, rightbound):
|
|
||||||
node_list = []
|
|
||||||
node_index = rightbound
|
|
||||||
while node_index>=leftbound:
|
|
||||||
node_level = 0
|
|
||||||
while (1<<node_level)&node_index and node_index-(1<<node_level+1)+1>=leftbound:
|
|
||||||
node_level += 1
|
|
||||||
node_list.append((node_index, node_level))
|
|
||||||
node_index -= 1<<node_level
|
|
||||||
return node_list
|
|
||||||
|
|
||||||
def process_window_sum(self, frames_guide, blending_table, patch_match_engine, window_size, batch_size, desc=""):
|
|
||||||
n = len(blending_table)
|
|
||||||
tasks = []
|
|
||||||
frames_result = []
|
|
||||||
for target in range(n):
|
|
||||||
node_list = self.tree_query(max(target-window_size, 0), target)
|
|
||||||
for source, level in node_list:
|
|
||||||
if source!=target:
|
|
||||||
meta_data = {
|
|
||||||
"source": source,
|
|
||||||
"target": target,
|
|
||||||
"level": level
|
|
||||||
}
|
|
||||||
tasks.append(meta_data)
|
|
||||||
else:
|
|
||||||
frames_result.append(blending_table[target][level])
|
|
||||||
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
|
|
||||||
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
|
|
||||||
source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
|
|
||||||
target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
|
|
||||||
source_style = np.stack([blending_table[task["source"]][task["level"]][0] for task in tasks_batch])
|
|
||||||
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
|
|
||||||
for task, frame_2 in zip(tasks_batch, target_style):
|
|
||||||
source, target, level = task["source"], task["target"], task["level"]
|
|
||||||
frame_1, weight_1 = frames_result[target]
|
|
||||||
weight_2 = blending_table[source][level][1]
|
|
||||||
weight = weight_1 + weight_2
|
|
||||||
frame = frame_1 * (weight_1 / weight) + frame_2 * (weight_2 / weight)
|
|
||||||
frames_result[target] = (frame, weight)
|
|
||||||
return frames_result
|
|
||||||
|
|
||||||
|
|
||||||
class FastModeRunner:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, save_path=None):
|
|
||||||
frames_guide = frames_guide.raw_data()
|
|
||||||
frames_style = frames_style.raw_data()
|
|
||||||
table_manager = TableManager()
|
|
||||||
patch_match_engine = PyramidPatchMatcher(
|
|
||||||
image_height=frames_style[0].shape[0],
|
|
||||||
image_width=frames_style[0].shape[1],
|
|
||||||
channel=3,
|
|
||||||
**ebsynth_config
|
|
||||||
)
|
|
||||||
# left part
|
|
||||||
table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="Fast Mode Step 1/4")
|
|
||||||
table_l = table_manager.remapping_table_to_blending_table(table_l)
|
|
||||||
table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 2/4")
|
|
||||||
# right part
|
|
||||||
table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="Fast Mode Step 3/4")
|
|
||||||
table_r = table_manager.remapping_table_to_blending_table(table_r)
|
|
||||||
table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 4/4")[::-1]
|
|
||||||
# merge
|
|
||||||
frames = []
|
|
||||||
for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
|
|
||||||
weight_m = -1
|
|
||||||
weight = weight_l + weight_m + weight_r
|
|
||||||
frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
|
|
||||||
frames.append(frame)
|
|
||||||
frames = [frame.clip(0, 255).astype("uint8") for frame in frames]
|
|
||||||
if save_path is not None:
|
|
||||||
for target, frame in enumerate(frames):
|
|
||||||
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
|
|
||||||
@@ -1,121 +0,0 @@
|
|||||||
from ..patch_match import PyramidPatchMatcher
|
|
||||||
import os
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
|
|
||||||
class InterpolationModeRunner:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_index_dict(self, index_style):
|
|
||||||
index_dict = {}
|
|
||||||
for i, index in enumerate(index_style):
|
|
||||||
index_dict[index] = i
|
|
||||||
return index_dict
|
|
||||||
|
|
||||||
def get_weight(self, l, m, r):
|
|
||||||
weight_l, weight_r = abs(m - r), abs(m - l)
|
|
||||||
if weight_l + weight_r == 0:
|
|
||||||
weight_l, weight_r = 0.5, 0.5
|
|
||||||
else:
|
|
||||||
weight_l, weight_r = weight_l / (weight_l + weight_r), weight_r / (weight_l + weight_r)
|
|
||||||
return weight_l, weight_r
|
|
||||||
|
|
||||||
def get_task_group(self, index_style, n):
|
|
||||||
task_group = []
|
|
||||||
index_style = sorted(index_style)
|
|
||||||
# first frame
|
|
||||||
if index_style[0]>0:
|
|
||||||
tasks = []
|
|
||||||
for m in range(index_style[0]):
|
|
||||||
tasks.append((index_style[0], m, index_style[0]))
|
|
||||||
task_group.append(tasks)
|
|
||||||
# middle frames
|
|
||||||
for l, r in zip(index_style[:-1], index_style[1:]):
|
|
||||||
tasks = []
|
|
||||||
for m in range(l, r):
|
|
||||||
tasks.append((l, m, r))
|
|
||||||
task_group.append(tasks)
|
|
||||||
# last frame
|
|
||||||
tasks = []
|
|
||||||
for m in range(index_style[-1], n):
|
|
||||||
tasks.append((index_style[-1], m, index_style[-1]))
|
|
||||||
task_group.append(tasks)
|
|
||||||
return task_group
|
|
||||||
|
|
||||||
def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
|
|
||||||
patch_match_engine = PyramidPatchMatcher(
|
|
||||||
image_height=frames_style[0].shape[0],
|
|
||||||
image_width=frames_style[0].shape[1],
|
|
||||||
channel=3,
|
|
||||||
use_mean_target_style=False,
|
|
||||||
use_pairwise_patch_error=True,
|
|
||||||
**ebsynth_config
|
|
||||||
)
|
|
||||||
# task
|
|
||||||
index_dict = self.get_index_dict(index_style)
|
|
||||||
task_group = self.get_task_group(index_style, len(frames_guide))
|
|
||||||
# run
|
|
||||||
for tasks in task_group:
|
|
||||||
index_start, index_end = min([i[1] for i in tasks]), max([i[1] for i in tasks])
|
|
||||||
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=f"Rendering frames {index_start}...{index_end}"):
|
|
||||||
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
|
|
||||||
source_guide, target_guide, source_style = [], [], []
|
|
||||||
for l, m, r in tasks_batch:
|
|
||||||
# l -> m
|
|
||||||
source_guide.append(frames_guide[l])
|
|
||||||
target_guide.append(frames_guide[m])
|
|
||||||
source_style.append(frames_style[index_dict[l]])
|
|
||||||
# r -> m
|
|
||||||
source_guide.append(frames_guide[r])
|
|
||||||
target_guide.append(frames_guide[m])
|
|
||||||
source_style.append(frames_style[index_dict[r]])
|
|
||||||
source_guide = np.stack(source_guide)
|
|
||||||
target_guide = np.stack(target_guide)
|
|
||||||
source_style = np.stack(source_style)
|
|
||||||
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
|
|
||||||
if save_path is not None:
|
|
||||||
for frame_l, frame_r, (l, m, r) in zip(target_style[0::2], target_style[1::2], tasks_batch):
|
|
||||||
weight_l, weight_r = self.get_weight(l, m, r)
|
|
||||||
frame = frame_l * weight_l + frame_r * weight_r
|
|
||||||
frame = frame.clip(0, 255).astype("uint8")
|
|
||||||
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % m))
|
|
||||||
|
|
||||||
|
|
||||||
class InterpolationModeSingleFrameRunner:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
|
|
||||||
# check input
|
|
||||||
tracking_window_size = ebsynth_config["tracking_window_size"]
|
|
||||||
if tracking_window_size * 2 >= batch_size:
|
|
||||||
raise ValueError("batch_size should be larger than track_window_size * 2")
|
|
||||||
frame_style = frames_style[0]
|
|
||||||
frame_guide = frames_guide[index_style[0]]
|
|
||||||
patch_match_engine = PyramidPatchMatcher(
|
|
||||||
image_height=frame_style.shape[0],
|
|
||||||
image_width=frame_style.shape[1],
|
|
||||||
channel=3,
|
|
||||||
**ebsynth_config
|
|
||||||
)
|
|
||||||
# run
|
|
||||||
frame_id, n = 0, len(frames_guide)
|
|
||||||
for i in tqdm(range(0, n, batch_size - tracking_window_size * 2), desc=f"Rendering frames 0...{n}"):
|
|
||||||
if i + batch_size > n:
|
|
||||||
l, r = max(n - batch_size, 0), n
|
|
||||||
else:
|
|
||||||
l, r = i, i + batch_size
|
|
||||||
source_guide = np.stack([frame_guide] * (r-l))
|
|
||||||
target_guide = np.stack([frames_guide[i] for i in range(l, r)])
|
|
||||||
source_style = np.stack([frame_style] * (r-l))
|
|
||||||
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
|
|
||||||
for i, frame in zip(range(l, r), target_style):
|
|
||||||
if i==frame_id:
|
|
||||||
frame = frame.clip(0, 255).astype("uint8")
|
|
||||||
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % frame_id))
|
|
||||||
frame_id += 1
|
|
||||||
if r < n and r-frame_id <= tracking_window_size:
|
|
||||||
break
|
|
||||||
@@ -1,242 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
|
|
||||||
def warp(tenInput, tenFlow, device):
|
|
||||||
backwarp_tenGrid = {}
|
|
||||||
k = (str(tenFlow.device), str(tenFlow.size()))
|
|
||||||
if k not in backwarp_tenGrid:
|
|
||||||
tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
|
|
||||||
1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
|
|
||||||
tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
|
|
||||||
1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
|
|
||||||
backwarp_tenGrid[k] = torch.cat(
|
|
||||||
[tenHorizontal, tenVertical], 1).to(device)
|
|
||||||
|
|
||||||
tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
|
|
||||||
tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
|
|
||||||
|
|
||||||
g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
|
|
||||||
return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
|
|
||||||
|
|
||||||
|
|
||||||
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
|
|
||||||
return nn.Sequential(
|
|
||||||
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
|
|
||||||
padding=padding, dilation=dilation, bias=True),
|
|
||||||
nn.PReLU(out_planes)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class IFBlock(nn.Module):
|
|
||||||
def __init__(self, in_planes, c=64):
|
|
||||||
super(IFBlock, self).__init__()
|
|
||||||
self.conv0 = nn.Sequential(conv(in_planes, c//2, 3, 2, 1), conv(c//2, c, 3, 2, 1),)
|
|
||||||
self.convblock0 = nn.Sequential(conv(c, c), conv(c, c))
|
|
||||||
self.convblock1 = nn.Sequential(conv(c, c), conv(c, c))
|
|
||||||
self.convblock2 = nn.Sequential(conv(c, c), conv(c, c))
|
|
||||||
self.convblock3 = nn.Sequential(conv(c, c), conv(c, c))
|
|
||||||
self.conv1 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 4, 4, 2, 1))
|
|
||||||
self.conv2 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 1, 4, 2, 1))
|
|
||||||
|
|
||||||
def forward(self, x, flow, scale=1):
|
|
||||||
x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
|
|
||||||
flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 1. / scale
|
|
||||||
feat = self.conv0(torch.cat((x, flow), 1))
|
|
||||||
feat = self.convblock0(feat) + feat
|
|
||||||
feat = self.convblock1(feat) + feat
|
|
||||||
feat = self.convblock2(feat) + feat
|
|
||||||
feat = self.convblock3(feat) + feat
|
|
||||||
flow = self.conv1(feat)
|
|
||||||
mask = self.conv2(feat)
|
|
||||||
flow = F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * scale
|
|
||||||
mask = F.interpolate(mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
|
|
||||||
return flow, mask
|
|
||||||
|
|
||||||
|
|
||||||
class IFNet(nn.Module):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super(IFNet, self).__init__()
|
|
||||||
self.block0 = IFBlock(7+4, c=90)
|
|
||||||
self.block1 = IFBlock(7+4, c=90)
|
|
||||||
self.block2 = IFBlock(7+4, c=90)
|
|
||||||
self.block_tea = IFBlock(10+4, c=90)
|
|
||||||
|
|
||||||
def forward(self, x, scale_list=[4, 2, 1], training=False):
|
|
||||||
if training == False:
|
|
||||||
channel = x.shape[1] // 2
|
|
||||||
img0 = x[:, :channel]
|
|
||||||
img1 = x[:, channel:]
|
|
||||||
flow_list = []
|
|
||||||
merged = []
|
|
||||||
mask_list = []
|
|
||||||
warped_img0 = img0
|
|
||||||
warped_img1 = img1
|
|
||||||
flow = (x[:, :4]).detach() * 0
|
|
||||||
mask = (x[:, :1]).detach() * 0
|
|
||||||
block = [self.block0, self.block1, self.block2]
|
|
||||||
for i in range(3):
|
|
||||||
f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
|
|
||||||
f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i])
|
|
||||||
flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
|
|
||||||
mask = mask + (m0 + (-m1)) / 2
|
|
||||||
mask_list.append(mask)
|
|
||||||
flow_list.append(flow)
|
|
||||||
warped_img0 = warp(img0, flow[:, :2], device=x.device)
|
|
||||||
warped_img1 = warp(img1, flow[:, 2:4], device=x.device)
|
|
||||||
merged.append((warped_img0, warped_img1))
|
|
||||||
'''
|
|
||||||
c0 = self.contextnet(img0, flow[:, :2])
|
|
||||||
c1 = self.contextnet(img1, flow[:, 2:4])
|
|
||||||
tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
|
|
||||||
res = tmp[:, 1:4] * 2 - 1
|
|
||||||
'''
|
|
||||||
for i in range(3):
|
|
||||||
mask_list[i] = torch.sigmoid(mask_list[i])
|
|
||||||
merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
|
|
||||||
return flow_list, mask_list[2], merged
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def state_dict_converter():
|
|
||||||
return IFNetStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
class IFNetStateDictConverter:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def from_diffusers(self, state_dict):
|
|
||||||
state_dict_ = {k.replace("module.", ""): v for k, v in state_dict.items()}
|
|
||||||
return state_dict_
|
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
|
||||||
return self.from_diffusers(state_dict), {"upcast_to_float32": True}
|
|
||||||
|
|
||||||
|
|
||||||
class RIFEInterpolater:
|
|
||||||
def __init__(self, model, device="cuda"):
|
|
||||||
self.model = model
|
|
||||||
self.device = device
|
|
||||||
# IFNet only does not support float16
|
|
||||||
self.torch_dtype = torch.float32
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_model_manager(model_manager):
|
|
||||||
return RIFEInterpolater(model_manager.fetch_model("rife"), device=model_manager.device)
|
|
||||||
|
|
||||||
def process_image(self, image):
|
|
||||||
width, height = image.size
|
|
||||||
if width % 32 != 0 or height % 32 != 0:
|
|
||||||
width = (width + 31) // 32
|
|
||||||
height = (height + 31) // 32
|
|
||||||
image = image.resize((width, height))
|
|
||||||
image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1)
|
|
||||||
return image
|
|
||||||
|
|
||||||
def process_images(self, images):
|
|
||||||
images = [self.process_image(image) for image in images]
|
|
||||||
images = torch.stack(images)
|
|
||||||
return images
|
|
||||||
|
|
||||||
def decode_images(self, images):
|
|
||||||
images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
|
|
||||||
images = [Image.fromarray(image) for image in images]
|
|
||||||
return images
|
|
||||||
|
|
||||||
def add_interpolated_images(self, images, interpolated_images):
|
|
||||||
output_images = []
|
|
||||||
for image, interpolated_image in zip(images, interpolated_images):
|
|
||||||
output_images.append(image)
|
|
||||||
output_images.append(interpolated_image)
|
|
||||||
output_images.append(images[-1])
|
|
||||||
return output_images
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def interpolate_(self, images, scale=1.0):
|
|
||||||
input_tensor = self.process_images(images)
|
|
||||||
input_tensor = torch.cat((input_tensor[:-1], input_tensor[1:]), dim=1)
|
|
||||||
input_tensor = input_tensor.to(device=self.device, dtype=self.torch_dtype)
|
|
||||||
flow, mask, merged = self.model(input_tensor, [4/scale, 2/scale, 1/scale])
|
|
||||||
output_images = self.decode_images(merged[2].cpu())
|
|
||||||
if output_images[0].size != images[0].size:
|
|
||||||
output_images = [image.resize(images[0].size) for image in output_images]
|
|
||||||
return output_images
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def interpolate(self, images, scale=1.0, batch_size=4, num_iter=1, progress_bar=lambda x:x):
|
|
||||||
# Preprocess
|
|
||||||
processed_images = self.process_images(images)
|
|
||||||
|
|
||||||
for iter in range(num_iter):
|
|
||||||
# Input
|
|
||||||
input_tensor = torch.cat((processed_images[:-1], processed_images[1:]), dim=1)
|
|
||||||
|
|
||||||
# Interpolate
|
|
||||||
output_tensor = []
|
|
||||||
for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
|
|
||||||
batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
|
|
||||||
batch_input_tensor = input_tensor[batch_id: batch_id_]
|
|
||||||
batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
|
|
||||||
flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
|
|
||||||
output_tensor.append(merged[2].cpu())
|
|
||||||
|
|
||||||
# Output
|
|
||||||
output_tensor = torch.concat(output_tensor, dim=0).clip(0, 1)
|
|
||||||
processed_images = self.add_interpolated_images(processed_images, output_tensor)
|
|
||||||
processed_images = torch.stack(processed_images)
|
|
||||||
|
|
||||||
# To images
|
|
||||||
output_images = self.decode_images(processed_images)
|
|
||||||
if output_images[0].size != images[0].size:
|
|
||||||
output_images = [image.resize(images[0].size) for image in output_images]
|
|
||||||
return output_images
|
|
||||||
|
|
||||||
|
|
||||||
class RIFESmoother(RIFEInterpolater):
|
|
||||||
def __init__(self, model, device="cuda"):
|
|
||||||
super(RIFESmoother, self).__init__(model, device=device)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_model_manager(model_manager):
|
|
||||||
return RIFEInterpolater(model_manager.fetch_model("rife"), device=model_manager.device)
|
|
||||||
|
|
||||||
def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
|
|
||||||
output_tensor = []
|
|
||||||
for batch_id in range(0, input_tensor.shape[0], batch_size):
|
|
||||||
batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
|
|
||||||
batch_input_tensor = input_tensor[batch_id: batch_id_]
|
|
||||||
batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
|
|
||||||
flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
|
|
||||||
output_tensor.append(merged[2].cpu())
|
|
||||||
output_tensor = torch.concat(output_tensor, dim=0)
|
|
||||||
return output_tensor
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def __call__(self, rendered_frames, scale=1.0, batch_size=4, num_iter=1, **kwargs):
|
|
||||||
# Preprocess
|
|
||||||
processed_images = self.process_images(rendered_frames)
|
|
||||||
|
|
||||||
for iter in range(num_iter):
|
|
||||||
# Input
|
|
||||||
input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1)
|
|
||||||
|
|
||||||
# Interpolate
|
|
||||||
output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)
|
|
||||||
|
|
||||||
# Blend
|
|
||||||
input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1)
|
|
||||||
output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)
|
|
||||||
|
|
||||||
# Add to frames
|
|
||||||
processed_images[1:-1] = output_tensor
|
|
||||||
|
|
||||||
# To images
|
|
||||||
output_images = self.decode_images(processed_images)
|
|
||||||
if output_images[0].size != rendered_frames[0].size:
|
|
||||||
output_images = [image.resize(rendered_frames[0].size) for image in output_images]
|
|
||||||
return output_images
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
from .model_manager import *
|
|
||||||
@@ -1,89 +0,0 @@
|
|||||||
import torch
|
|
||||||
from einops import rearrange
|
|
||||||
|
|
||||||
|
|
||||||
def low_version_attention(query, key, value, attn_bias=None):
|
|
||||||
scale = 1 / query.shape[-1] ** 0.5
|
|
||||||
query = query * scale
|
|
||||||
attn = torch.matmul(query, key.transpose(-2, -1))
|
|
||||||
if attn_bias is not None:
|
|
||||||
attn = attn + attn_bias
|
|
||||||
attn = attn.softmax(-1)
|
|
||||||
return attn @ value
|
|
||||||
|
|
||||||
|
|
||||||
class Attention(torch.nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
|
|
||||||
super().__init__()
|
|
||||||
dim_inner = head_dim * num_heads
|
|
||||||
kv_dim = kv_dim if kv_dim is not None else q_dim
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.head_dim = head_dim
|
|
||||||
|
|
||||||
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
|
|
||||||
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
|
||||||
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
|
||||||
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
|
|
||||||
|
|
||||||
def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
|
|
||||||
batch_size = q.shape[0]
|
|
||||||
ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
|
||||||
ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
|
||||||
ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
|
|
||||||
hidden_states = hidden_states + scale * ip_hidden_states
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
|
|
||||||
if encoder_hidden_states is None:
|
|
||||||
encoder_hidden_states = hidden_states
|
|
||||||
|
|
||||||
batch_size = encoder_hidden_states.shape[0]
|
|
||||||
|
|
||||||
q = self.to_q(hidden_states)
|
|
||||||
k = self.to_k(encoder_hidden_states)
|
|
||||||
v = self.to_v(encoder_hidden_states)
|
|
||||||
|
|
||||||
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
|
||||||
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
|
||||||
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
|
||||||
|
|
||||||
if qkv_preprocessor is not None:
|
|
||||||
q, k, v = qkv_preprocessor(q, k, v)
|
|
||||||
|
|
||||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
|
||||||
if ipadapter_kwargs is not None:
|
|
||||||
hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
|
|
||||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
|
||||||
hidden_states = hidden_states.to(q.dtype)
|
|
||||||
|
|
||||||
hidden_states = self.to_out(hidden_states)
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
|
|
||||||
if encoder_hidden_states is None:
|
|
||||||
encoder_hidden_states = hidden_states
|
|
||||||
|
|
||||||
q = self.to_q(hidden_states)
|
|
||||||
k = self.to_k(encoder_hidden_states)
|
|
||||||
v = self.to_v(encoder_hidden_states)
|
|
||||||
|
|
||||||
q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
|
|
||||||
k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
|
|
||||||
v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)
|
|
||||||
|
|
||||||
if attn_mask is not None:
|
|
||||||
hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
|
|
||||||
else:
|
|
||||||
import xformers.ops as xops
|
|
||||||
hidden_states = xops.memory_efficient_attention(q, k, v)
|
|
||||||
hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)
|
|
||||||
|
|
||||||
hidden_states = hidden_states.to(q.dtype)
|
|
||||||
hidden_states = self.to_out(hidden_states)
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
|
|
||||||
return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)
|
|
||||||
@@ -1,408 +0,0 @@
|
|||||||
import torch
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
from .sd3_dit import TimestepEmbeddings
|
|
||||||
from .attention import Attention
|
|
||||||
from .utils import load_state_dict_from_folder
|
|
||||||
from .tiler import TileWorker2Dto3D
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CogPatchify(torch.nn.Module):
|
|
||||||
def __init__(self, dim_in, dim_out, patch_size) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.proj = torch.nn.Conv3d(dim_in, dim_out, kernel_size=(1, patch_size, patch_size), stride=(1, patch_size, patch_size))
|
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
|
||||||
hidden_states = self.proj(hidden_states)
|
|
||||||
hidden_states = rearrange(hidden_states, "B C T H W -> B (T H W) C")
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CogAdaLayerNorm(torch.nn.Module):
|
|
||||||
def __init__(self, dim, dim_cond, single=False):
|
|
||||||
super().__init__()
|
|
||||||
self.single = single
|
|
||||||
self.linear = torch.nn.Linear(dim_cond, dim * (2 if single else 6))
|
|
||||||
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=True, eps=1e-5)
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, hidden_states, prompt_emb, emb):
|
|
||||||
emb = self.linear(torch.nn.functional.silu(emb))
|
|
||||||
if self.single:
|
|
||||||
shift, scale = emb.unsqueeze(1).chunk(2, dim=2)
|
|
||||||
hidden_states = self.norm(hidden_states) * (1 + scale) + shift
|
|
||||||
return hidden_states
|
|
||||||
else:
|
|
||||||
shift_a, scale_a, gate_a, shift_b, scale_b, gate_b = emb.unsqueeze(1).chunk(6, dim=2)
|
|
||||||
hidden_states = self.norm(hidden_states) * (1 + scale_a) + shift_a
|
|
||||||
prompt_emb = self.norm(prompt_emb) * (1 + scale_b) + shift_b
|
|
||||||
return hidden_states, prompt_emb, gate_a, gate_b
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CogDiTBlock(torch.nn.Module):
|
|
||||||
def __init__(self, dim, dim_cond, num_heads):
|
|
||||||
super().__init__()
|
|
||||||
self.norm1 = CogAdaLayerNorm(dim, dim_cond)
|
|
||||||
self.attn1 = Attention(q_dim=dim, num_heads=48, head_dim=dim//num_heads, bias_q=True, bias_kv=True, bias_out=True)
|
|
||||||
self.norm_q = torch.nn.LayerNorm((dim//num_heads,), eps=1e-06, elementwise_affine=True)
|
|
||||||
self.norm_k = torch.nn.LayerNorm((dim//num_heads,), eps=1e-06, elementwise_affine=True)
|
|
||||||
|
|
||||||
self.norm2 = CogAdaLayerNorm(dim, dim_cond)
|
|
||||||
self.ff = torch.nn.Sequential(
|
|
||||||
torch.nn.Linear(dim, dim*4),
|
|
||||||
torch.nn.GELU(approximate="tanh"),
|
|
||||||
torch.nn.Linear(dim*4, dim)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb(self, x, freqs_cis):
|
|
||||||
cos, sin = freqs_cis # [S, D]
|
|
||||||
cos = cos[None, None]
|
|
||||||
sin = sin[None, None]
|
|
||||||
cos, sin = cos.to(x.device), sin.to(x.device)
|
|
||||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
|
||||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
|
||||||
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def process_qkv(self, q, k, v, image_rotary_emb, text_seq_length):
|
|
||||||
q = self.norm_q(q)
|
|
||||||
k = self.norm_k(k)
|
|
||||||
q[:, :, text_seq_length:] = self.apply_rotary_emb(q[:, :, text_seq_length:], image_rotary_emb)
|
|
||||||
k[:, :, text_seq_length:] = self.apply_rotary_emb(k[:, :, text_seq_length:], image_rotary_emb)
|
|
||||||
return q, k, v
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, hidden_states, prompt_emb, time_emb, image_rotary_emb):
|
|
||||||
# Attention
|
|
||||||
norm_hidden_states, norm_encoder_hidden_states, gate_a, gate_b = self.norm1(
|
|
||||||
hidden_states, prompt_emb, time_emb
|
|
||||||
)
|
|
||||||
attention_io = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
|
||||||
attention_io = self.attn1(
|
|
||||||
attention_io,
|
|
||||||
qkv_preprocessor=lambda q, k, v: self.process_qkv(q, k, v, image_rotary_emb, prompt_emb.shape[1])
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = hidden_states + gate_a * attention_io[:, prompt_emb.shape[1]:]
|
|
||||||
prompt_emb = prompt_emb + gate_b * attention_io[:, :prompt_emb.shape[1]]
|
|
||||||
|
|
||||||
# Feed forward
|
|
||||||
norm_hidden_states, norm_encoder_hidden_states, gate_a, gate_b = self.norm2(
|
|
||||||
hidden_states, prompt_emb, time_emb
|
|
||||||
)
|
|
||||||
ff_io = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
|
||||||
ff_io = self.ff(ff_io)
|
|
||||||
|
|
||||||
hidden_states = hidden_states + gate_a * ff_io[:, prompt_emb.shape[1]:]
|
|
||||||
prompt_emb = prompt_emb + gate_b * ff_io[:, :prompt_emb.shape[1]]
|
|
||||||
|
|
||||||
return hidden_states, prompt_emb
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CogDiT(torch.nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.patchify = CogPatchify(16, 3072, 2)
|
|
||||||
self.time_embedder = TimestepEmbeddings(3072, 512)
|
|
||||||
self.context_embedder = torch.nn.Linear(4096, 3072)
|
|
||||||
self.blocks = torch.nn.ModuleList([CogDiTBlock(3072, 512, 48) for _ in range(42)])
|
|
||||||
self.norm_final = torch.nn.LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
|
|
||||||
self.norm_out = CogAdaLayerNorm(3072, 512, single=True)
|
|
||||||
self.proj_out = torch.nn.Linear(3072, 64, bias=True)
|
|
||||||
|
|
||||||
|
|
||||||
def get_resize_crop_region_for_grid(self, src, tgt_width, tgt_height):
|
|
||||||
tw = tgt_width
|
|
||||||
th = tgt_height
|
|
||||||
h, w = src
|
|
||||||
r = h / w
|
|
||||||
if r > (th / tw):
|
|
||||||
resize_height = th
|
|
||||||
resize_width = int(round(th / h * w))
|
|
||||||
else:
|
|
||||||
resize_width = tw
|
|
||||||
resize_height = int(round(tw / w * h))
|
|
||||||
|
|
||||||
crop_top = int(round((th - resize_height) / 2.0))
|
|
||||||
crop_left = int(round((tw - resize_width) / 2.0))
|
|
||||||
|
|
||||||
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
|
||||||
|
|
||||||
|
|
||||||
def get_3d_rotary_pos_embed(
|
|
||||||
self, embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
|
|
||||||
):
|
|
||||||
start, stop = crops_coords
|
|
||||||
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
|
|
||||||
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
|
|
||||||
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
|
|
||||||
|
|
||||||
# Compute dimensions for each axis
|
|
||||||
dim_t = embed_dim // 4
|
|
||||||
dim_h = embed_dim // 8 * 3
|
|
||||||
dim_w = embed_dim // 8 * 3
|
|
||||||
|
|
||||||
# Temporal frequencies
|
|
||||||
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
|
|
||||||
grid_t = torch.from_numpy(grid_t).float()
|
|
||||||
freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
|
|
||||||
freqs_t = freqs_t.repeat_interleave(2, dim=-1)
|
|
||||||
|
|
||||||
# Spatial frequencies for height and width
|
|
||||||
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
|
|
||||||
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
|
|
||||||
grid_h = torch.from_numpy(grid_h).float()
|
|
||||||
grid_w = torch.from_numpy(grid_w).float()
|
|
||||||
freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
|
|
||||||
freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
|
|
||||||
freqs_h = freqs_h.repeat_interleave(2, dim=-1)
|
|
||||||
freqs_w = freqs_w.repeat_interleave(2, dim=-1)
|
|
||||||
|
|
||||||
# Broadcast and concatenate tensors along specified dimension
|
|
||||||
def broadcast(tensors, dim=-1):
|
|
||||||
num_tensors = len(tensors)
|
|
||||||
shape_lens = {len(t.shape) for t in tensors}
|
|
||||||
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
|
|
||||||
shape_len = list(shape_lens)[0]
|
|
||||||
dim = (dim + shape_len) if dim < 0 else dim
|
|
||||||
dims = list(zip(*(list(t.shape) for t in tensors)))
|
|
||||||
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
|
||||||
assert all(
|
|
||||||
[*(len(set(t[1])) <= 2 for t in expandable_dims)]
|
|
||||||
), "invalid dimensions for broadcastable concatenation"
|
|
||||||
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
|
|
||||||
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
|
|
||||||
expanded_dims.insert(dim, (dim, dims[dim]))
|
|
||||||
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
|
|
||||||
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
|
|
||||||
return torch.cat(tensors, dim=dim)
|
|
||||||
|
|
||||||
freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
|
|
||||||
|
|
||||||
t, h, w, d = freqs.shape
|
|
||||||
freqs = freqs.view(t * h * w, d)
|
|
||||||
|
|
||||||
# Generate sine and cosine components
|
|
||||||
sin = freqs.sin()
|
|
||||||
cos = freqs.cos()
|
|
||||||
|
|
||||||
if use_real:
|
|
||||||
return cos, sin
|
|
||||||
else:
|
|
||||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
|
||||||
return freqs_cis
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_rotary_positional_embeddings(
|
|
||||||
self,
|
|
||||||
height: int,
|
|
||||||
width: int,
|
|
||||||
num_frames: int,
|
|
||||||
device: torch.device,
|
|
||||||
):
|
|
||||||
grid_height = height // 2
|
|
||||||
grid_width = width // 2
|
|
||||||
base_size_width = 720 // (8 * 2)
|
|
||||||
base_size_height = 480 // (8 * 2)
|
|
||||||
|
|
||||||
grid_crops_coords = self.get_resize_crop_region_for_grid(
|
|
||||||
(grid_height, grid_width), base_size_width, base_size_height
|
|
||||||
)
|
|
||||||
freqs_cos, freqs_sin = self.get_3d_rotary_pos_embed(
|
|
||||||
embed_dim=64,
|
|
||||||
crops_coords=grid_crops_coords,
|
|
||||||
grid_size=(grid_height, grid_width),
|
|
||||||
temporal_size=num_frames,
|
|
||||||
use_real=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
freqs_cos = freqs_cos.to(device=device)
|
|
||||||
freqs_sin = freqs_sin.to(device=device)
|
|
||||||
return freqs_cos, freqs_sin
|
|
||||||
|
|
||||||
|
|
||||||
def unpatchify(self, hidden_states, height, width):
|
|
||||||
hidden_states = rearrange(hidden_states, "B (T H W) (C P Q) -> B C T (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
def build_mask(self, T, H, W, dtype, device, is_bound):
|
|
||||||
t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W)
|
|
||||||
h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W)
|
|
||||||
w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W)
|
|
||||||
border_width = (H + W) // 4
|
|
||||||
pad = torch.ones_like(h) * border_width
|
|
||||||
mask = torch.stack([
|
|
||||||
pad if is_bound[0] else t + 1,
|
|
||||||
pad if is_bound[1] else T - t,
|
|
||||||
pad if is_bound[2] else h + 1,
|
|
||||||
pad if is_bound[3] else H - h,
|
|
||||||
pad if is_bound[4] else w + 1,
|
|
||||||
pad if is_bound[5] else W - w
|
|
||||||
]).min(dim=0).values
|
|
||||||
mask = mask.clip(1, border_width)
|
|
||||||
mask = (mask / border_width).to(dtype=dtype, device=device)
|
|
||||||
mask = rearrange(mask, "T H W -> 1 1 T H W")
|
|
||||||
return mask
|
|
||||||
|
|
||||||
|
|
||||||
def tiled_forward(self, hidden_states, timestep, prompt_emb, tile_size=(60, 90), tile_stride=(30, 45)):
|
|
||||||
B, C, T, H, W = hidden_states.shape
|
|
||||||
value = torch.zeros((B, C, T, H, W), dtype=hidden_states.dtype, device=hidden_states.device)
|
|
||||||
weight = torch.zeros((B, C, T, H, W), dtype=hidden_states.dtype, device=hidden_states.device)
|
|
||||||
|
|
||||||
# Split tasks
|
|
||||||
tasks = []
|
|
||||||
for h in range(0, H, tile_stride):
|
|
||||||
for w in range(0, W, tile_stride):
|
|
||||||
if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
|
|
||||||
continue
|
|
||||||
h_, w_ = h + tile_size, w + tile_size
|
|
||||||
if h_ > H: h, h_ = max(H - tile_size, 0), H
|
|
||||||
if w_ > W: w, w_ = max(W - tile_size, 0), W
|
|
||||||
tasks.append((h, h_, w, w_))
|
|
||||||
|
|
||||||
# Run
|
|
||||||
for hl, hr, wl, wr in tasks:
|
|
||||||
mask = self.build_mask(
|
|
||||||
value.shape[2], (hr-hl), (wr-wl),
|
|
||||||
hidden_states.dtype, hidden_states.device,
|
|
||||||
is_bound=(True, True, hl==0, hr>=H, wl==0, wr>=W)
|
|
||||||
)
|
|
||||||
model_output = self.forward(hidden_states[:, :, :, hl:hr, wl:wr], timestep, prompt_emb)
|
|
||||||
value[:, :, :, hl:hr, wl:wr] += model_output * mask
|
|
||||||
weight[:, :, :, hl:hr, wl:wr] += mask
|
|
||||||
value = value / weight
|
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, hidden_states, timestep, prompt_emb, image_rotary_emb=None, tiled=False, tile_size=90, tile_stride=30, use_gradient_checkpointing=False):
|
|
||||||
if tiled:
|
|
||||||
return TileWorker2Dto3D().tiled_forward(
|
|
||||||
forward_fn=lambda x: self.forward(x, timestep, prompt_emb),
|
|
||||||
model_input=hidden_states,
|
|
||||||
tile_size=tile_size, tile_stride=tile_stride,
|
|
||||||
tile_device=hidden_states.device, tile_dtype=hidden_states.dtype,
|
|
||||||
computation_device=self.context_embedder.weight.device, computation_dtype=self.context_embedder.weight.dtype
|
|
||||||
)
|
|
||||||
num_frames, height, width = hidden_states.shape[-3:]
|
|
||||||
if image_rotary_emb is None:
|
|
||||||
image_rotary_emb = self.prepare_rotary_positional_embeddings(height, width, num_frames, device=self.context_embedder.weight.device)
|
|
||||||
hidden_states = self.patchify(hidden_states)
|
|
||||||
time_emb = self.time_embedder(timestep, dtype=hidden_states.dtype)
|
|
||||||
prompt_emb = self.context_embedder(prompt_emb)
|
|
||||||
|
|
||||||
def create_custom_forward(module):
|
|
||||||
def custom_forward(*inputs):
|
|
||||||
return module(*inputs)
|
|
||||||
return custom_forward
|
|
||||||
|
|
||||||
for block in self.blocks:
|
|
||||||
if self.training and use_gradient_checkpointing:
|
|
||||||
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
|
||||||
create_custom_forward(block),
|
|
||||||
hidden_states, prompt_emb, time_emb, image_rotary_emb,
|
|
||||||
use_reentrant=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, time_emb, image_rotary_emb)
|
|
||||||
|
|
||||||
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
|
||||||
hidden_states = self.norm_final(hidden_states)
|
|
||||||
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
|
|
||||||
hidden_states = self.norm_out(hidden_states, prompt_emb, time_emb)
|
|
||||||
hidden_states = self.proj_out(hidden_states)
|
|
||||||
hidden_states = self.unpatchify(hidden_states, height, width)
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def state_dict_converter():
|
|
||||||
return CogDiTStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_pretrained(file_path, torch_dtype=torch.bfloat16):
|
|
||||||
model = CogDiT().to(torch_dtype)
|
|
||||||
state_dict = load_state_dict_from_folder(file_path, torch_dtype=torch_dtype)
|
|
||||||
state_dict = CogDiT.state_dict_converter().from_diffusers(state_dict)
|
|
||||||
model.load_state_dict(state_dict)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CogDiTStateDictConverter:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def from_diffusers(self, state_dict):
|
|
||||||
rename_dict = {
|
|
||||||
"patch_embed.proj.weight": "patchify.proj.weight",
|
|
||||||
"patch_embed.proj.bias": "patchify.proj.bias",
|
|
||||||
"patch_embed.text_proj.weight": "context_embedder.weight",
|
|
||||||
"patch_embed.text_proj.bias": "context_embedder.bias",
|
|
||||||
"time_embedding.linear_1.weight": "time_embedder.timestep_embedder.0.weight",
|
|
||||||
"time_embedding.linear_1.bias": "time_embedder.timestep_embedder.0.bias",
|
|
||||||
"time_embedding.linear_2.weight": "time_embedder.timestep_embedder.2.weight",
|
|
||||||
"time_embedding.linear_2.bias": "time_embedder.timestep_embedder.2.bias",
|
|
||||||
|
|
||||||
"norm_final.weight": "norm_final.weight",
|
|
||||||
"norm_final.bias": "norm_final.bias",
|
|
||||||
"norm_out.linear.weight": "norm_out.linear.weight",
|
|
||||||
"norm_out.linear.bias": "norm_out.linear.bias",
|
|
||||||
"norm_out.norm.weight": "norm_out.norm.weight",
|
|
||||||
"norm_out.norm.bias": "norm_out.norm.bias",
|
|
||||||
"proj_out.weight": "proj_out.weight",
|
|
||||||
"proj_out.bias": "proj_out.bias",
|
|
||||||
}
|
|
||||||
suffix_dict = {
|
|
||||||
"norm1.linear.weight": "norm1.linear.weight",
|
|
||||||
"norm1.linear.bias": "norm1.linear.bias",
|
|
||||||
"norm1.norm.weight": "norm1.norm.weight",
|
|
||||||
"norm1.norm.bias": "norm1.norm.bias",
|
|
||||||
"attn1.norm_q.weight": "norm_q.weight",
|
|
||||||
"attn1.norm_q.bias": "norm_q.bias",
|
|
||||||
"attn1.norm_k.weight": "norm_k.weight",
|
|
||||||
"attn1.norm_k.bias": "norm_k.bias",
|
|
||||||
"attn1.to_q.weight": "attn1.to_q.weight",
|
|
||||||
"attn1.to_q.bias": "attn1.to_q.bias",
|
|
||||||
"attn1.to_k.weight": "attn1.to_k.weight",
|
|
||||||
"attn1.to_k.bias": "attn1.to_k.bias",
|
|
||||||
"attn1.to_v.weight": "attn1.to_v.weight",
|
|
||||||
"attn1.to_v.bias": "attn1.to_v.bias",
|
|
||||||
"attn1.to_out.0.weight": "attn1.to_out.weight",
|
|
||||||
"attn1.to_out.0.bias": "attn1.to_out.bias",
|
|
||||||
"norm2.linear.weight": "norm2.linear.weight",
|
|
||||||
"norm2.linear.bias": "norm2.linear.bias",
|
|
||||||
"norm2.norm.weight": "norm2.norm.weight",
|
|
||||||
"norm2.norm.bias": "norm2.norm.bias",
|
|
||||||
"ff.net.0.proj.weight": "ff.0.weight",
|
|
||||||
"ff.net.0.proj.bias": "ff.0.bias",
|
|
||||||
"ff.net.2.weight": "ff.2.weight",
|
|
||||||
"ff.net.2.bias": "ff.2.bias",
|
|
||||||
}
|
|
||||||
state_dict_ = {}
|
|
||||||
for name, param in state_dict.items():
|
|
||||||
if name in rename_dict:
|
|
||||||
if name == "patch_embed.proj.weight":
|
|
||||||
param = param.unsqueeze(2)
|
|
||||||
state_dict_[rename_dict[name]] = param
|
|
||||||
else:
|
|
||||||
names = name.split(".")
|
|
||||||
if names[0] == "transformer_blocks":
|
|
||||||
suffix = ".".join(names[2:])
|
|
||||||
state_dict_[f"blocks.{names[1]}." + suffix_dict[suffix]] = param
|
|
||||||
return state_dict_
|
|
||||||
|
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
|
||||||
return self.from_diffusers(state_dict)
|
|
||||||
@@ -1,518 +0,0 @@
|
|||||||
import torch
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
from .tiler import TileWorker2Dto3D
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Downsample3D(torch.nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels: int,
|
|
||||||
out_channels: int,
|
|
||||||
kernel_size: int = 3,
|
|
||||||
stride: int = 2,
|
|
||||||
padding: int = 0,
|
|
||||||
compress_time: bool = False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
|
||||||
self.compress_time = compress_time
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor:
|
|
||||||
if self.compress_time:
|
|
||||||
batch_size, channels, frames, height, width = x.shape
|
|
||||||
|
|
||||||
# (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
|
|
||||||
x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
|
|
||||||
|
|
||||||
if x.shape[-1] % 2 == 1:
|
|
||||||
x_first, x_rest = x[..., 0], x[..., 1:]
|
|
||||||
if x_rest.shape[-1] > 0:
|
|
||||||
# (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
|
|
||||||
x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
|
|
||||||
|
|
||||||
x = torch.cat([x_first[..., None], x_rest], dim=-1)
|
|
||||||
# (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
|
|
||||||
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
|
|
||||||
else:
|
|
||||||
# (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
|
|
||||||
x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
|
|
||||||
# (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
|
|
||||||
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
|
|
||||||
|
|
||||||
# Pad the tensor
|
|
||||||
pad = (0, 1, 0, 1)
|
|
||||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
|
||||||
batch_size, channels, frames, height, width = x.shape
|
|
||||||
# (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
|
|
||||||
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
|
|
||||||
x = self.conv(x)
|
|
||||||
# (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
|
|
||||||
x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Upsample3D(torch.nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels: int,
|
|
||||||
out_channels: int,
|
|
||||||
kernel_size: int = 3,
|
|
||||||
stride: int = 1,
|
|
||||||
padding: int = 1,
|
|
||||||
compress_time: bool = False,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
|
||||||
self.compress_time = compress_time
|
|
||||||
|
|
||||||
def forward(self, inputs: torch.Tensor, xq: torch.Tensor) -> torch.Tensor:
|
|
||||||
if self.compress_time:
|
|
||||||
if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
|
|
||||||
# split first frame
|
|
||||||
x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
|
|
||||||
|
|
||||||
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0)
|
|
||||||
x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0)
|
|
||||||
x_first = x_first[:, :, None, :, :]
|
|
||||||
inputs = torch.cat([x_first, x_rest], dim=2)
|
|
||||||
elif inputs.shape[2] > 1:
|
|
||||||
inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
|
|
||||||
else:
|
|
||||||
inputs = inputs.squeeze(2)
|
|
||||||
inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
|
|
||||||
inputs = inputs[:, :, None, :, :]
|
|
||||||
else:
|
|
||||||
# only interpolate 2D
|
|
||||||
b, c, t, h, w = inputs.shape
|
|
||||||
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
|
||||||
inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
|
|
||||||
inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
|
|
||||||
|
|
||||||
b, c, t, h, w = inputs.shape
|
|
||||||
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
|
||||||
inputs = self.conv(inputs)
|
|
||||||
inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
|
|
||||||
|
|
||||||
return inputs
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CogVideoXSpatialNorm3D(torch.nn.Module):
|
|
||||||
def __init__(self, f_channels, zq_channels, groups):
|
|
||||||
super().__init__()
|
|
||||||
self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
|
|
||||||
self.conv_y = torch.nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
|
||||||
self.conv_b = torch.nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
|
|
||||||
if f.shape[2] > 1 and f.shape[2] % 2 == 1:
|
|
||||||
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
|
|
||||||
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
|
|
||||||
z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
|
|
||||||
z_first = torch.nn.functional.interpolate(z_first, size=f_first_size)
|
|
||||||
z_rest = torch.nn.functional.interpolate(z_rest, size=f_rest_size)
|
|
||||||
zq = torch.cat([z_first, z_rest], dim=2)
|
|
||||||
else:
|
|
||||||
zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:])
|
|
||||||
|
|
||||||
norm_f = self.norm_layer(f)
|
|
||||||
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
|
|
||||||
return new_f
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Resnet3DBlock(torch.nn.Module):
|
|
||||||
def __init__(self, in_channels, out_channels, spatial_norm_dim, groups, eps=1e-6, use_conv_shortcut=False):
|
|
||||||
super().__init__()
|
|
||||||
self.nonlinearity = torch.nn.SiLU()
|
|
||||||
if spatial_norm_dim is None:
|
|
||||||
self.norm1 = torch.nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
|
|
||||||
self.norm2 = torch.nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
|
|
||||||
else:
|
|
||||||
self.norm1 = CogVideoXSpatialNorm3D(in_channels, spatial_norm_dim, groups)
|
|
||||||
self.norm2 = CogVideoXSpatialNorm3D(out_channels, spatial_norm_dim, groups)
|
|
||||||
|
|
||||||
self.conv1 = CachedConv3d(in_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
|
|
||||||
|
|
||||||
self.conv2 = CachedConv3d(out_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
|
|
||||||
|
|
||||||
if in_channels != out_channels:
|
|
||||||
if use_conv_shortcut:
|
|
||||||
self.conv_shortcut = CachedConv3d(in_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
|
|
||||||
else:
|
|
||||||
self.conv_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1)
|
|
||||||
else:
|
|
||||||
self.conv_shortcut = lambda x: x
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, hidden_states, zq):
|
|
||||||
residual = hidden_states
|
|
||||||
|
|
||||||
hidden_states = self.norm1(hidden_states, zq) if isinstance(self.norm1, CogVideoXSpatialNorm3D) else self.norm1(hidden_states)
|
|
||||||
hidden_states = self.nonlinearity(hidden_states)
|
|
||||||
hidden_states = self.conv1(hidden_states)
|
|
||||||
|
|
||||||
hidden_states = self.norm2(hidden_states, zq) if isinstance(self.norm2, CogVideoXSpatialNorm3D) else self.norm2(hidden_states)
|
|
||||||
hidden_states = self.nonlinearity(hidden_states)
|
|
||||||
hidden_states = self.conv2(hidden_states)
|
|
||||||
|
|
||||||
hidden_states = hidden_states + self.conv_shortcut(residual)
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CachedConv3d(torch.nn.Conv3d):
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
|
|
||||||
super().__init__(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
|
||||||
self.cached_tensor = None
|
|
||||||
|
|
||||||
|
|
||||||
def clear_cache(self):
|
|
||||||
self.cached_tensor = None
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor, use_cache = True) -> torch.Tensor:
|
|
||||||
if use_cache:
|
|
||||||
if self.cached_tensor is None:
|
|
||||||
self.cached_tensor = torch.concat([input[:, :, :1]] * 2, dim=2)
|
|
||||||
input = torch.concat([self.cached_tensor, input], dim=2)
|
|
||||||
self.cached_tensor = input[:, :, -2:]
|
|
||||||
return super().forward(input)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CogVAEDecoder(torch.nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.scaling_factor = 0.7
|
|
||||||
self.conv_in = CachedConv3d(16, 512, kernel_size=3, stride=1, padding=(0, 1, 1))
|
|
||||||
|
|
||||||
self.blocks = torch.nn.ModuleList([
|
|
||||||
Resnet3DBlock(512, 512, 16, 32),
|
|
||||||
Resnet3DBlock(512, 512, 16, 32),
|
|
||||||
Resnet3DBlock(512, 512, 16, 32),
|
|
||||||
Resnet3DBlock(512, 512, 16, 32),
|
|
||||||
Resnet3DBlock(512, 512, 16, 32),
|
|
||||||
Resnet3DBlock(512, 512, 16, 32),
|
|
||||||
Upsample3D(512, 512, compress_time=True),
|
|
||||||
Resnet3DBlock(512, 256, 16, 32),
|
|
||||||
Resnet3DBlock(256, 256, 16, 32),
|
|
||||||
Resnet3DBlock(256, 256, 16, 32),
|
|
||||||
Resnet3DBlock(256, 256, 16, 32),
|
|
||||||
Upsample3D(256, 256, compress_time=True),
|
|
||||||
Resnet3DBlock(256, 256, 16, 32),
|
|
||||||
Resnet3DBlock(256, 256, 16, 32),
|
|
||||||
Resnet3DBlock(256, 256, 16, 32),
|
|
||||||
Resnet3DBlock(256, 256, 16, 32),
|
|
||||||
Upsample3D(256, 256, compress_time=False),
|
|
||||||
Resnet3DBlock(256, 128, 16, 32),
|
|
||||||
Resnet3DBlock(128, 128, 16, 32),
|
|
||||||
Resnet3DBlock(128, 128, 16, 32),
|
|
||||||
Resnet3DBlock(128, 128, 16, 32),
|
|
||||||
])
|
|
||||||
|
|
||||||
self.norm_out = CogVideoXSpatialNorm3D(128, 16, 32)
|
|
||||||
self.conv_act = torch.nn.SiLU()
|
|
||||||
self.conv_out = CachedConv3d(128, 3, kernel_size=3, stride=1, padding=(0, 1, 1))
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, sample):
|
|
||||||
sample = sample / self.scaling_factor
|
|
||||||
hidden_states = self.conv_in(sample)
|
|
||||||
|
|
||||||
for block in self.blocks:
|
|
||||||
hidden_states = block(hidden_states, sample)
|
|
||||||
|
|
||||||
hidden_states = self.norm_out(hidden_states, sample)
|
|
||||||
hidden_states = self.conv_act(hidden_states)
|
|
||||||
hidden_states = self.conv_out(hidden_states)
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
def decode_video(self, sample, tiled=True, tile_size=(60, 90), tile_stride=(30, 45), progress_bar=lambda x:x):
|
|
||||||
if tiled:
|
|
||||||
B, C, T, H, W = sample.shape
|
|
||||||
return TileWorker2Dto3D().tiled_forward(
|
|
||||||
forward_fn=lambda x: self.decode_small_video(x),
|
|
||||||
model_input=sample,
|
|
||||||
tile_size=tile_size, tile_stride=tile_stride,
|
|
||||||
tile_device=sample.device, tile_dtype=sample.dtype,
|
|
||||||
computation_device=sample.device, computation_dtype=sample.dtype,
|
|
||||||
scales=(3/16, (T//2*8+T%2)/T, 8, 8),
|
|
||||||
progress_bar=progress_bar
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return self.decode_small_video(sample)
|
|
||||||
|
|
||||||
|
|
||||||
def decode_small_video(self, sample):
|
|
||||||
B, C, T, H, W = sample.shape
|
|
||||||
computation_device = self.conv_in.weight.device
|
|
||||||
computation_dtype = self.conv_in.weight.dtype
|
|
||||||
value = []
|
|
||||||
for i in range(T//2):
|
|
||||||
tl = i*2 + T%2 - (T%2 and i==0)
|
|
||||||
tr = i*2 + 2 + T%2
|
|
||||||
model_input = sample[:, :, tl: tr, :, :].to(dtype=computation_dtype, device=computation_device)
|
|
||||||
model_output = self.forward(model_input).to(dtype=sample.dtype, device=sample.device)
|
|
||||||
value.append(model_output)
|
|
||||||
value = torch.concat(value, dim=2)
|
|
||||||
for name, module in self.named_modules():
|
|
||||||
if isinstance(module, CachedConv3d):
|
|
||||||
module.clear_cache()
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def state_dict_converter():
|
|
||||||
return CogVAEDecoderStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CogVAEEncoder(torch.nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.scaling_factor = 0.7
|
|
||||||
self.conv_in = CachedConv3d(3, 128, kernel_size=3, stride=1, padding=(0, 1, 1))
|
|
||||||
|
|
||||||
self.blocks = torch.nn.ModuleList([
|
|
||||||
Resnet3DBlock(128, 128, None, 32),
|
|
||||||
Resnet3DBlock(128, 128, None, 32),
|
|
||||||
Resnet3DBlock(128, 128, None, 32),
|
|
||||||
Downsample3D(128, 128, compress_time=True),
|
|
||||||
Resnet3DBlock(128, 256, None, 32),
|
|
||||||
Resnet3DBlock(256, 256, None, 32),
|
|
||||||
Resnet3DBlock(256, 256, None, 32),
|
|
||||||
Downsample3D(256, 256, compress_time=True),
|
|
||||||
Resnet3DBlock(256, 256, None, 32),
|
|
||||||
Resnet3DBlock(256, 256, None, 32),
|
|
||||||
Resnet3DBlock(256, 256, None, 32),
|
|
||||||
Downsample3D(256, 256, compress_time=False),
|
|
||||||
Resnet3DBlock(256, 512, None, 32),
|
|
||||||
Resnet3DBlock(512, 512, None, 32),
|
|
||||||
Resnet3DBlock(512, 512, None, 32),
|
|
||||||
Resnet3DBlock(512, 512, None, 32),
|
|
||||||
Resnet3DBlock(512, 512, None, 32),
|
|
||||||
])
|
|
||||||
|
|
||||||
self.norm_out = torch.nn.GroupNorm(32, 512, eps=1e-06, affine=True)
|
|
||||||
self.conv_act = torch.nn.SiLU()
|
|
||||||
self.conv_out = CachedConv3d(512, 32, kernel_size=3, stride=1, padding=(0, 1, 1))
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, sample):
|
|
||||||
hidden_states = self.conv_in(sample)
|
|
||||||
|
|
||||||
for block in self.blocks:
|
|
||||||
hidden_states = block(hidden_states, sample)
|
|
||||||
|
|
||||||
hidden_states = self.norm_out(hidden_states)
|
|
||||||
hidden_states = self.conv_act(hidden_states)
|
|
||||||
hidden_states = self.conv_out(hidden_states)[:, :16]
|
|
||||||
hidden_states = hidden_states * self.scaling_factor
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
def encode_video(self, sample, tiled=True, tile_size=(60, 90), tile_stride=(30, 45), progress_bar=lambda x:x):
|
|
||||||
if tiled:
|
|
||||||
B, C, T, H, W = sample.shape
|
|
||||||
return TileWorker2Dto3D().tiled_forward(
|
|
||||||
forward_fn=lambda x: self.encode_small_video(x),
|
|
||||||
model_input=sample,
|
|
||||||
tile_size=(i * 8 for i in tile_size), tile_stride=(i * 8 for i in tile_stride),
|
|
||||||
tile_device=sample.device, tile_dtype=sample.dtype,
|
|
||||||
computation_device=sample.device, computation_dtype=sample.dtype,
|
|
||||||
scales=(16/3, (T//4+T%2)/T, 1/8, 1/8),
|
|
||||||
progress_bar=progress_bar
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return self.encode_small_video(sample)
|
|
||||||
|
|
||||||
|
|
||||||
def encode_small_video(self, sample):
|
|
||||||
B, C, T, H, W = sample.shape
|
|
||||||
computation_device = self.conv_in.weight.device
|
|
||||||
computation_dtype = self.conv_in.weight.dtype
|
|
||||||
value = []
|
|
||||||
for i in range(T//8):
|
|
||||||
t = i*8 + T%2 - (T%2 and i==0)
|
|
||||||
t_ = i*8 + 8 + T%2
|
|
||||||
model_input = sample[:, :, t: t_, :, :].to(dtype=computation_dtype, device=computation_device)
|
|
||||||
model_output = self.forward(model_input).to(dtype=sample.dtype, device=sample.device)
|
|
||||||
value.append(model_output)
|
|
||||||
value = torch.concat(value, dim=2)
|
|
||||||
for name, module in self.named_modules():
|
|
||||||
if isinstance(module, CachedConv3d):
|
|
||||||
module.clear_cache()
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def state_dict_converter():
|
|
||||||
return CogVAEEncoderStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CogVAEEncoderStateDictConverter:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def from_diffusers(self, state_dict):
|
|
||||||
rename_dict = {
|
|
||||||
"encoder.conv_in.conv.weight": "conv_in.weight",
|
|
||||||
"encoder.conv_in.conv.bias": "conv_in.bias",
|
|
||||||
"encoder.down_blocks.0.downsamplers.0.conv.weight": "blocks.3.conv.weight",
|
|
||||||
"encoder.down_blocks.0.downsamplers.0.conv.bias": "blocks.3.conv.bias",
|
|
||||||
"encoder.down_blocks.1.downsamplers.0.conv.weight": "blocks.7.conv.weight",
|
|
||||||
"encoder.down_blocks.1.downsamplers.0.conv.bias": "blocks.7.conv.bias",
|
|
||||||
"encoder.down_blocks.2.downsamplers.0.conv.weight": "blocks.11.conv.weight",
|
|
||||||
"encoder.down_blocks.2.downsamplers.0.conv.bias": "blocks.11.conv.bias",
|
|
||||||
"encoder.norm_out.weight": "norm_out.weight",
|
|
||||||
"encoder.norm_out.bias": "norm_out.bias",
|
|
||||||
"encoder.conv_out.conv.weight": "conv_out.weight",
|
|
||||||
"encoder.conv_out.conv.bias": "conv_out.bias",
|
|
||||||
}
|
|
||||||
prefix_dict = {
|
|
||||||
"encoder.down_blocks.0.resnets.0.": "blocks.0.",
|
|
||||||
"encoder.down_blocks.0.resnets.1.": "blocks.1.",
|
|
||||||
"encoder.down_blocks.0.resnets.2.": "blocks.2.",
|
|
||||||
"encoder.down_blocks.1.resnets.0.": "blocks.4.",
|
|
||||||
"encoder.down_blocks.1.resnets.1.": "blocks.5.",
|
|
||||||
"encoder.down_blocks.1.resnets.2.": "blocks.6.",
|
|
||||||
"encoder.down_blocks.2.resnets.0.": "blocks.8.",
|
|
||||||
"encoder.down_blocks.2.resnets.1.": "blocks.9.",
|
|
||||||
"encoder.down_blocks.2.resnets.2.": "blocks.10.",
|
|
||||||
"encoder.down_blocks.3.resnets.0.": "blocks.12.",
|
|
||||||
"encoder.down_blocks.3.resnets.1.": "blocks.13.",
|
|
||||||
"encoder.down_blocks.3.resnets.2.": "blocks.14.",
|
|
||||||
"encoder.mid_block.resnets.0.": "blocks.15.",
|
|
||||||
"encoder.mid_block.resnets.1.": "blocks.16.",
|
|
||||||
}
|
|
||||||
suffix_dict = {
|
|
||||||
"norm1.norm_layer.weight": "norm1.norm_layer.weight",
|
|
||||||
"norm1.norm_layer.bias": "norm1.norm_layer.bias",
|
|
||||||
"norm1.conv_y.conv.weight": "norm1.conv_y.weight",
|
|
||||||
"norm1.conv_y.conv.bias": "norm1.conv_y.bias",
|
|
||||||
"norm1.conv_b.conv.weight": "norm1.conv_b.weight",
|
|
||||||
"norm1.conv_b.conv.bias": "norm1.conv_b.bias",
|
|
||||||
"norm2.norm_layer.weight": "norm2.norm_layer.weight",
|
|
||||||
"norm2.norm_layer.bias": "norm2.norm_layer.bias",
|
|
||||||
"norm2.conv_y.conv.weight": "norm2.conv_y.weight",
|
|
||||||
"norm2.conv_y.conv.bias": "norm2.conv_y.bias",
|
|
||||||
"norm2.conv_b.conv.weight": "norm2.conv_b.weight",
|
|
||||||
"norm2.conv_b.conv.bias": "norm2.conv_b.bias",
|
|
||||||
"conv1.conv.weight": "conv1.weight",
|
|
||||||
"conv1.conv.bias": "conv1.bias",
|
|
||||||
"conv2.conv.weight": "conv2.weight",
|
|
||||||
"conv2.conv.bias": "conv2.bias",
|
|
||||||
"conv_shortcut.weight": "conv_shortcut.weight",
|
|
||||||
"conv_shortcut.bias": "conv_shortcut.bias",
|
|
||||||
"norm1.weight": "norm1.weight",
|
|
||||||
"norm1.bias": "norm1.bias",
|
|
||||||
"norm2.weight": "norm2.weight",
|
|
||||||
"norm2.bias": "norm2.bias",
|
|
||||||
}
|
|
||||||
state_dict_ = {}
|
|
||||||
for name, param in state_dict.items():
|
|
||||||
if name in rename_dict:
|
|
||||||
state_dict_[rename_dict[name]] = param
|
|
||||||
else:
|
|
||||||
for prefix in prefix_dict:
|
|
||||||
if name.startswith(prefix):
|
|
||||||
suffix = name[len(prefix):]
|
|
||||||
state_dict_[prefix_dict[prefix] + suffix_dict[suffix]] = param
|
|
||||||
return state_dict_
|
|
||||||
|
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
|
||||||
return self.from_diffusers(state_dict)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CogVAEDecoderStateDictConverter:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def from_diffusers(self, state_dict):
|
|
||||||
rename_dict = {
|
|
||||||
"decoder.conv_in.conv.weight": "conv_in.weight",
|
|
||||||
"decoder.conv_in.conv.bias": "conv_in.bias",
|
|
||||||
"decoder.up_blocks.0.upsamplers.0.conv.weight": "blocks.6.conv.weight",
|
|
||||||
"decoder.up_blocks.0.upsamplers.0.conv.bias": "blocks.6.conv.bias",
|
|
||||||
"decoder.up_blocks.1.upsamplers.0.conv.weight": "blocks.11.conv.weight",
|
|
||||||
"decoder.up_blocks.1.upsamplers.0.conv.bias": "blocks.11.conv.bias",
|
|
||||||
"decoder.up_blocks.2.upsamplers.0.conv.weight": "blocks.16.conv.weight",
|
|
||||||
"decoder.up_blocks.2.upsamplers.0.conv.bias": "blocks.16.conv.bias",
|
|
||||||
"decoder.norm_out.norm_layer.weight": "norm_out.norm_layer.weight",
|
|
||||||
"decoder.norm_out.norm_layer.bias": "norm_out.norm_layer.bias",
|
|
||||||
"decoder.norm_out.conv_y.conv.weight": "norm_out.conv_y.weight",
|
|
||||||
"decoder.norm_out.conv_y.conv.bias": "norm_out.conv_y.bias",
|
|
||||||
"decoder.norm_out.conv_b.conv.weight": "norm_out.conv_b.weight",
|
|
||||||
"decoder.norm_out.conv_b.conv.bias": "norm_out.conv_b.bias",
|
|
||||||
"decoder.conv_out.conv.weight": "conv_out.weight",
|
|
||||||
"decoder.conv_out.conv.bias": "conv_out.bias"
|
|
||||||
}
|
|
||||||
prefix_dict = {
|
|
||||||
"decoder.mid_block.resnets.0.": "blocks.0.",
|
|
||||||
"decoder.mid_block.resnets.1.": "blocks.1.",
|
|
||||||
"decoder.up_blocks.0.resnets.0.": "blocks.2.",
|
|
||||||
"decoder.up_blocks.0.resnets.1.": "blocks.3.",
|
|
||||||
"decoder.up_blocks.0.resnets.2.": "blocks.4.",
|
|
||||||
"decoder.up_blocks.0.resnets.3.": "blocks.5.",
|
|
||||||
"decoder.up_blocks.1.resnets.0.": "blocks.7.",
|
|
||||||
"decoder.up_blocks.1.resnets.1.": "blocks.8.",
|
|
||||||
"decoder.up_blocks.1.resnets.2.": "blocks.9.",
|
|
||||||
"decoder.up_blocks.1.resnets.3.": "blocks.10.",
|
|
||||||
"decoder.up_blocks.2.resnets.0.": "blocks.12.",
|
|
||||||
"decoder.up_blocks.2.resnets.1.": "blocks.13.",
|
|
||||||
"decoder.up_blocks.2.resnets.2.": "blocks.14.",
|
|
||||||
"decoder.up_blocks.2.resnets.3.": "blocks.15.",
|
|
||||||
"decoder.up_blocks.3.resnets.0.": "blocks.17.",
|
|
||||||
"decoder.up_blocks.3.resnets.1.": "blocks.18.",
|
|
||||||
"decoder.up_blocks.3.resnets.2.": "blocks.19.",
|
|
||||||
"decoder.up_blocks.3.resnets.3.": "blocks.20.",
|
|
||||||
}
|
|
||||||
suffix_dict = {
|
|
||||||
"norm1.norm_layer.weight": "norm1.norm_layer.weight",
|
|
||||||
"norm1.norm_layer.bias": "norm1.norm_layer.bias",
|
|
||||||
"norm1.conv_y.conv.weight": "norm1.conv_y.weight",
|
|
||||||
"norm1.conv_y.conv.bias": "norm1.conv_y.bias",
|
|
||||||
"norm1.conv_b.conv.weight": "norm1.conv_b.weight",
|
|
||||||
"norm1.conv_b.conv.bias": "norm1.conv_b.bias",
|
|
||||||
"norm2.norm_layer.weight": "norm2.norm_layer.weight",
|
|
||||||
"norm2.norm_layer.bias": "norm2.norm_layer.bias",
|
|
||||||
"norm2.conv_y.conv.weight": "norm2.conv_y.weight",
|
|
||||||
"norm2.conv_y.conv.bias": "norm2.conv_y.bias",
|
|
||||||
"norm2.conv_b.conv.weight": "norm2.conv_b.weight",
|
|
||||||
"norm2.conv_b.conv.bias": "norm2.conv_b.bias",
|
|
||||||
"conv1.conv.weight": "conv1.weight",
|
|
||||||
"conv1.conv.bias": "conv1.bias",
|
|
||||||
"conv2.conv.weight": "conv2.weight",
|
|
||||||
"conv2.conv.bias": "conv2.bias",
|
|
||||||
"conv_shortcut.weight": "conv_shortcut.weight",
|
|
||||||
"conv_shortcut.bias": "conv_shortcut.bias",
|
|
||||||
}
|
|
||||||
state_dict_ = {}
|
|
||||||
for name, param in state_dict.items():
|
|
||||||
if name in rename_dict:
|
|
||||||
state_dict_[rename_dict[name]] = param
|
|
||||||
else:
|
|
||||||
for prefix in prefix_dict:
|
|
||||||
if name.startswith(prefix):
|
|
||||||
suffix = name[len(prefix):]
|
|
||||||
state_dict_[prefix_dict[prefix] + suffix_dict[suffix]] = param
|
|
||||||
return state_dict_
|
|
||||||
|
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
|
||||||
return self.from_diffusers(state_dict)
|
|
||||||
|
|
||||||
96
diffsynth/models/dinov3_image_encoder.py
Normal file
96
diffsynth/models/dinov3_image_encoder.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast
|
||||||
|
from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ..core.device.npu_compatible_device import get_device_type
|
||||||
|
|
||||||
|
|
||||||
|
class DINOv3ImageEncoder(DINOv3ViTModel):
|
||||||
|
def __init__(self):
|
||||||
|
config = DINOv3ViTConfig(
|
||||||
|
architectures = [
|
||||||
|
"DINOv3ViTModel"
|
||||||
|
],
|
||||||
|
attention_dropout = 0.0,
|
||||||
|
drop_path_rate = 0.0,
|
||||||
|
dtype = "float32",
|
||||||
|
hidden_act = "silu",
|
||||||
|
hidden_size = 4096,
|
||||||
|
image_size = 224,
|
||||||
|
initializer_range = 0.02,
|
||||||
|
intermediate_size = 8192,
|
||||||
|
key_bias = False,
|
||||||
|
layer_norm_eps = 1e-05,
|
||||||
|
layerscale_value = 1.0,
|
||||||
|
mlp_bias = True,
|
||||||
|
model_type = "dinov3_vit",
|
||||||
|
num_attention_heads = 32,
|
||||||
|
num_channels = 3,
|
||||||
|
num_hidden_layers = 40,
|
||||||
|
num_register_tokens = 4,
|
||||||
|
patch_size = 16,
|
||||||
|
pos_embed_jitter = None,
|
||||||
|
pos_embed_rescale = 2.0,
|
||||||
|
pos_embed_shift = None,
|
||||||
|
proj_bias = True,
|
||||||
|
query_bias = False,
|
||||||
|
rope_theta = 100.0,
|
||||||
|
transformers_version = "4.56.1",
|
||||||
|
use_gated_mlp = True,
|
||||||
|
value_bias = False
|
||||||
|
)
|
||||||
|
super().__init__(config)
|
||||||
|
self.processor = DINOv3ViTImageProcessorFast(
|
||||||
|
crop_size = None,
|
||||||
|
data_format = "channels_first",
|
||||||
|
default_to_square = True,
|
||||||
|
device = None,
|
||||||
|
disable_grouping = None,
|
||||||
|
do_center_crop = None,
|
||||||
|
do_convert_rgb = None,
|
||||||
|
do_normalize = True,
|
||||||
|
do_rescale = True,
|
||||||
|
do_resize = True,
|
||||||
|
image_mean = [
|
||||||
|
0.485,
|
||||||
|
0.456,
|
||||||
|
0.406
|
||||||
|
],
|
||||||
|
image_processor_type = "DINOv3ViTImageProcessorFast",
|
||||||
|
image_std = [
|
||||||
|
0.229,
|
||||||
|
0.224,
|
||||||
|
0.225
|
||||||
|
],
|
||||||
|
input_data_format = None,
|
||||||
|
resample = 2,
|
||||||
|
rescale_factor = 0.00392156862745098,
|
||||||
|
return_tensors = None,
|
||||||
|
size = {
|
||||||
|
"height": 224,
|
||||||
|
"width": 224
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
|
||||||
|
inputs = self.processor(images=image, return_tensors="pt")
|
||||||
|
pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device)
|
||||||
|
bool_masked_pos = None
|
||||||
|
head_mask = None
|
||||||
|
|
||||||
|
pixel_values = pixel_values.to(torch_dtype)
|
||||||
|
hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
|
||||||
|
position_embeddings = self.rope_embeddings(pixel_values)
|
||||||
|
|
||||||
|
for i, layer_module in enumerate(self.layer):
|
||||||
|
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||||
|
hidden_states = layer_module(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=layer_head_mask,
|
||||||
|
position_embeddings=position_embeddings,
|
||||||
|
)
|
||||||
|
|
||||||
|
sequence_output = self.norm(hidden_states)
|
||||||
|
pooled_output = sequence_output[:, 0, :]
|
||||||
|
|
||||||
|
return pooled_output
|
||||||
@@ -1,111 +0,0 @@
|
|||||||
from huggingface_hub import hf_hub_download
|
|
||||||
from modelscope import snapshot_download
|
|
||||||
import os, shutil
|
|
||||||
from typing_extensions import Literal, TypeAlias
|
|
||||||
from typing import List
|
|
||||||
from ..configs.model_config import preset_models_on_huggingface, preset_models_on_modelscope, Preset_model_id
|
|
||||||
|
|
||||||
|
|
||||||
def download_from_modelscope(model_id, origin_file_path, local_dir):
|
|
||||||
os.makedirs(local_dir, exist_ok=True)
|
|
||||||
file_name = os.path.basename(origin_file_path)
|
|
||||||
if file_name in os.listdir(local_dir):
|
|
||||||
print(f" {file_name} has been already in {local_dir}.")
|
|
||||||
else:
|
|
||||||
print(f" Start downloading {os.path.join(local_dir, file_name)}")
|
|
||||||
snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
|
|
||||||
downloaded_file_path = os.path.join(local_dir, origin_file_path)
|
|
||||||
target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
|
|
||||||
if downloaded_file_path != target_file_path:
|
|
||||||
shutil.move(downloaded_file_path, target_file_path)
|
|
||||||
shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
|
|
||||||
|
|
||||||
|
|
||||||
def download_from_huggingface(model_id, origin_file_path, local_dir):
|
|
||||||
os.makedirs(local_dir, exist_ok=True)
|
|
||||||
file_name = os.path.basename(origin_file_path)
|
|
||||||
if file_name in os.listdir(local_dir):
|
|
||||||
print(f" {file_name} has been already in {local_dir}.")
|
|
||||||
else:
|
|
||||||
print(f" Start downloading {os.path.join(local_dir, file_name)}")
|
|
||||||
hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
|
|
||||||
downloaded_file_path = os.path.join(local_dir, origin_file_path)
|
|
||||||
target_file_path = os.path.join(local_dir, file_name)
|
|
||||||
if downloaded_file_path != target_file_path:
|
|
||||||
shutil.move(downloaded_file_path, target_file_path)
|
|
||||||
shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
|
|
||||||
|
|
||||||
|
|
||||||
Preset_model_website: TypeAlias = Literal[
|
|
||||||
"HuggingFace",
|
|
||||||
"ModelScope",
|
|
||||||
]
|
|
||||||
website_to_preset_models = {
|
|
||||||
"HuggingFace": preset_models_on_huggingface,
|
|
||||||
"ModelScope": preset_models_on_modelscope,
|
|
||||||
}
|
|
||||||
website_to_download_fn = {
|
|
||||||
"HuggingFace": download_from_huggingface,
|
|
||||||
"ModelScope": download_from_modelscope,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def download_customized_models(
|
|
||||||
model_id,
|
|
||||||
origin_file_path,
|
|
||||||
local_dir,
|
|
||||||
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
|
|
||||||
):
|
|
||||||
downloaded_files = []
|
|
||||||
for website in downloading_priority:
|
|
||||||
# Check if the file is downloaded.
|
|
||||||
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
|
|
||||||
if file_to_download in downloaded_files:
|
|
||||||
continue
|
|
||||||
# Download
|
|
||||||
website_to_download_fn[website](model_id, origin_file_path, local_dir)
|
|
||||||
if os.path.basename(origin_file_path) in os.listdir(local_dir):
|
|
||||||
downloaded_files.append(file_to_download)
|
|
||||||
return downloaded_files
|
|
||||||
|
|
||||||
|
|
||||||
def download_models(
|
|
||||||
model_id_list: List[Preset_model_id] = [],
|
|
||||||
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
|
|
||||||
):
|
|
||||||
print(f"Downloading models: {model_id_list}")
|
|
||||||
downloaded_files = []
|
|
||||||
load_files = []
|
|
||||||
|
|
||||||
for model_id in model_id_list:
|
|
||||||
for website in downloading_priority:
|
|
||||||
if model_id in website_to_preset_models[website]:
|
|
||||||
|
|
||||||
# Parse model metadata
|
|
||||||
model_metadata = website_to_preset_models[website][model_id]
|
|
||||||
if isinstance(model_metadata, list):
|
|
||||||
file_data = model_metadata
|
|
||||||
else:
|
|
||||||
file_data = model_metadata.get("file_list", [])
|
|
||||||
|
|
||||||
# Try downloading the model from this website.
|
|
||||||
model_files = []
|
|
||||||
for model_id, origin_file_path, local_dir in file_data:
|
|
||||||
# Check if the file is downloaded.
|
|
||||||
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
|
|
||||||
if file_to_download in downloaded_files:
|
|
||||||
continue
|
|
||||||
# Download
|
|
||||||
website_to_download_fn[website](model_id, origin_file_path, local_dir)
|
|
||||||
if os.path.basename(origin_file_path) in os.listdir(local_dir):
|
|
||||||
downloaded_files.append(file_to_download)
|
|
||||||
model_files.append(file_to_download)
|
|
||||||
|
|
||||||
# If the model is successfully downloaded, break.
|
|
||||||
if len(model_files) > 0:
|
|
||||||
if isinstance(model_metadata, dict) and "load_path" in model_metadata:
|
|
||||||
model_files = model_metadata["load_path"]
|
|
||||||
load_files.extend(model_files)
|
|
||||||
break
|
|
||||||
|
|
||||||
return load_files
|
|
||||||
1050
diffsynth/models/flux2_dit.py
Normal file
1050
diffsynth/models/flux2_dit.py
Normal file
File diff suppressed because it is too large
Load Diff
58
diffsynth/models/flux2_text_encoder.py
Normal file
58
diffsynth/models/flux2_text_encoder.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
from transformers import Mistral3ForConditionalGeneration, Mistral3Config
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2TextEncoder(Mistral3ForConditionalGeneration):
|
||||||
|
def __init__(self):
|
||||||
|
config = Mistral3Config(**{
|
||||||
|
"architectures": [
|
||||||
|
"Mistral3ForConditionalGeneration"
|
||||||
|
],
|
||||||
|
"dtype": "bfloat16",
|
||||||
|
"image_token_index": 10,
|
||||||
|
"model_type": "mistral3",
|
||||||
|
"multimodal_projector_bias": False,
|
||||||
|
"projector_hidden_act": "gelu",
|
||||||
|
"spatial_merge_size": 2,
|
||||||
|
"text_config": {
|
||||||
|
"attention_dropout": 0.0,
|
||||||
|
"dtype": "bfloat16",
|
||||||
|
"head_dim": 128,
|
||||||
|
"hidden_act": "silu",
|
||||||
|
"hidden_size": 5120,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"intermediate_size": 32768,
|
||||||
|
"max_position_embeddings": 131072,
|
||||||
|
"model_type": "mistral",
|
||||||
|
"num_attention_heads": 32,
|
||||||
|
"num_hidden_layers": 40,
|
||||||
|
"num_key_value_heads": 8,
|
||||||
|
"rms_norm_eps": 1e-05,
|
||||||
|
"rope_theta": 1000000000.0,
|
||||||
|
"sliding_window": None,
|
||||||
|
"use_cache": True,
|
||||||
|
"vocab_size": 131072
|
||||||
|
},
|
||||||
|
"transformers_version": "4.57.1",
|
||||||
|
"vision_config": {
|
||||||
|
"attention_dropout": 0.0,
|
||||||
|
"dtype": "bfloat16",
|
||||||
|
"head_dim": 64,
|
||||||
|
"hidden_act": "silu",
|
||||||
|
"hidden_size": 1024,
|
||||||
|
"image_size": 1540,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"intermediate_size": 4096,
|
||||||
|
"model_type": "pixtral",
|
||||||
|
"num_attention_heads": 16,
|
||||||
|
"num_channels": 3,
|
||||||
|
"num_hidden_layers": 24,
|
||||||
|
"patch_size": 14,
|
||||||
|
"rope_theta": 10000.0
|
||||||
|
},
|
||||||
|
"vision_feature_layer": -1
|
||||||
|
})
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
def forward(self, input_ids = None, pixel_values = None, attention_mask = None, position_ids = None, past_key_values = None, inputs_embeds = None, labels = None, use_cache = None, output_attentions = None, output_hidden_states = None, return_dict = None, cache_position = None, logits_to_keep = 0, image_sizes = None, **kwargs):
|
||||||
|
return super().forward(input_ids, pixel_values, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, logits_to_keep, image_sizes, **kwargs)
|
||||||
|
|
||||||
2322
diffsynth/models/flux2_vae.py
Normal file
2322
diffsynth/models/flux2_vae.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,9 +1,62 @@
|
|||||||
import torch
|
import torch
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock, RMSNorm
|
from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock, RMSNorm
|
||||||
from .utils import hash_state_dict_keys, init_weights_on_device
|
# from .utils import hash_state_dict_keys, init_weights_on_device
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
def hash_state_dict_keys(state_dict, with_shape=True):
|
||||||
|
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
|
||||||
|
keys_str = keys_str.encode(encoding="UTF-8")
|
||||||
|
return hashlib.md5(keys_str).hexdigest()
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):
|
||||||
|
|
||||||
|
old_register_parameter = torch.nn.Module.register_parameter
|
||||||
|
if include_buffers:
|
||||||
|
old_register_buffer = torch.nn.Module.register_buffer
|
||||||
|
|
||||||
|
def register_empty_parameter(module, name, param):
|
||||||
|
old_register_parameter(module, name, param)
|
||||||
|
if param is not None:
|
||||||
|
param_cls = type(module._parameters[name])
|
||||||
|
kwargs = module._parameters[name].__dict__
|
||||||
|
kwargs["requires_grad"] = param.requires_grad
|
||||||
|
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
|
||||||
|
|
||||||
|
def register_empty_buffer(module, name, buffer, persistent=True):
|
||||||
|
old_register_buffer(module, name, buffer, persistent=persistent)
|
||||||
|
if buffer is not None:
|
||||||
|
module._buffers[name] = module._buffers[name].to(device)
|
||||||
|
|
||||||
|
def patch_tensor_constructor(fn):
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
kwargs["device"] = device
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
if include_buffers:
|
||||||
|
tensor_constructors_to_patch = {
|
||||||
|
torch_function_name: getattr(torch, torch_function_name)
|
||||||
|
for torch_function_name in ["empty", "zeros", "ones", "full"]
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
tensor_constructors_to_patch = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
torch.nn.Module.register_parameter = register_empty_parameter
|
||||||
|
if include_buffers:
|
||||||
|
torch.nn.Module.register_buffer = register_empty_buffer
|
||||||
|
for torch_function_name in tensor_constructors_to_patch.keys():
|
||||||
|
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
torch.nn.Module.register_parameter = old_register_parameter
|
||||||
|
if include_buffers:
|
||||||
|
torch.nn.Module.register_buffer = old_register_buffer
|
||||||
|
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
|
||||||
|
setattr(torch, torch_function_name, old_torch_function)
|
||||||
|
|
||||||
class FluxControlNet(torch.nn.Module):
|
class FluxControlNet(torch.nn.Module):
|
||||||
def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0):
|
def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0):
|
||||||
@@ -102,9 +155,9 @@ class FluxControlNet(torch.nn.Module):
|
|||||||
return controlnet_res_stack, controlnet_single_res_stack
|
return controlnet_res_stack, controlnet_single_res_stack
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
# @staticmethod
|
||||||
def state_dict_converter():
|
# def state_dict_converter():
|
||||||
return FluxControlNetStateDictConverter()
|
# return FluxControlNetStateDictConverter()
|
||||||
|
|
||||||
def quantize(self):
|
def quantize(self):
|
||||||
def cast_to(weight, dtype=None, device=None, copy=False):
|
def cast_to(weight, dtype=None, device=None, copy=False):
|
||||||
@@ -318,6 +371,10 @@ class FluxControlNetStateDictConverter:
|
|||||||
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
|
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
|
||||||
elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
|
elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
|
||||||
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
|
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
|
||||||
|
elif hash_value == "7f9583eb8ba86642abb9a21a4b2c9e16":
|
||||||
|
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 10}
|
||||||
|
elif hash_value == "43ad5aaa27dd4ee01b832ed16773fa52":
|
||||||
|
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0}
|
||||||
else:
|
else:
|
||||||
extra_kwargs = {}
|
extra_kwargs = {}
|
||||||
return state_dict_, extra_kwargs
|
return state_dict_, extra_kwargs
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
from .sd3_dit import TimestepEmbeddings, AdaLayerNorm, RMSNorm
|
from .general_modules import TimestepEmbeddings, AdaLayerNorm, RMSNorm
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from .tiler import TileWorker
|
|
||||||
from .utils import init_weights_on_device
|
|
||||||
|
|
||||||
def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0):
|
def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0):
|
||||||
batch_size, num_tokens = hidden_states.shape[0:2]
|
batch_size, num_tokens = hidden_states.shape[0:2]
|
||||||
@@ -269,27 +268,29 @@ class AdaLayerNormContinuous(torch.nn.Module):
|
|||||||
|
|
||||||
def forward(self, x, conditioning):
|
def forward(self, x, conditioning):
|
||||||
emb = self.linear(self.silu(conditioning))
|
emb = self.linear(self.silu(conditioning))
|
||||||
scale, shift = torch.chunk(emb, 2, dim=1)
|
shift, scale = torch.chunk(emb, 2, dim=1)
|
||||||
x = self.norm(x) * (1 + scale)[:, None] + shift[:, None]
|
x = self.norm(x) * (1 + scale)[:, None] + shift[:, None]
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class FluxDiT(torch.nn.Module):
|
class FluxDiT(torch.nn.Module):
|
||||||
def __init__(self, disable_guidance_embedder=False):
|
def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
|
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
|
||||||
self.time_embedder = TimestepEmbeddings(256, 3072)
|
self.time_embedder = TimestepEmbeddings(256, 3072)
|
||||||
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
|
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
|
||||||
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
|
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
|
||||||
self.context_embedder = torch.nn.Linear(4096, 3072)
|
self.context_embedder = torch.nn.Linear(4096, 3072)
|
||||||
self.x_embedder = torch.nn.Linear(64, 3072)
|
self.x_embedder = torch.nn.Linear(input_dim, 3072)
|
||||||
|
|
||||||
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(19)])
|
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_blocks)])
|
||||||
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])
|
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])
|
||||||
|
|
||||||
self.final_norm_out = AdaLayerNormContinuous(3072)
|
self.final_norm_out = AdaLayerNormContinuous(3072)
|
||||||
self.final_proj_out = torch.nn.Linear(3072, 64)
|
self.final_proj_out = torch.nn.Linear(3072, 64)
|
||||||
|
|
||||||
|
self.input_dim = input_dim
|
||||||
|
|
||||||
|
|
||||||
def patchify(self, hidden_states):
|
def patchify(self, hidden_states):
|
||||||
@@ -319,25 +320,6 @@ class FluxDiT(torch.nn.Module):
|
|||||||
return latent_image_ids
|
return latent_image_ids
|
||||||
|
|
||||||
|
|
||||||
def tiled_forward(
|
|
||||||
self,
|
|
||||||
hidden_states,
|
|
||||||
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
|
|
||||||
tile_size=128, tile_stride=64,
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
# Due to the global positional embedding, we cannot implement layer-wise tiled forward.
|
|
||||||
hidden_states = TileWorker().tiled_forward(
|
|
||||||
lambda x: self.forward(x, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None),
|
|
||||||
hidden_states,
|
|
||||||
tile_size,
|
|
||||||
tile_stride,
|
|
||||||
tile_device=hidden_states.device,
|
|
||||||
tile_dtype=hidden_states.dtype
|
|
||||||
)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len):
|
def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len):
|
||||||
N = len(entity_masks)
|
N = len(entity_masks)
|
||||||
batch_size = entity_masks[0].shape[0]
|
batch_size = entity_masks[0].shape[0]
|
||||||
@@ -373,8 +355,7 @@ class FluxDiT(torch.nn.Module):
|
|||||||
return attention_mask
|
return attention_mask
|
||||||
|
|
||||||
|
|
||||||
def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids):
|
def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, repeat_dim):
|
||||||
repeat_dim = hidden_states.shape[1]
|
|
||||||
max_masks = 0
|
max_masks = 0
|
||||||
attention_mask = None
|
attention_mask = None
|
||||||
prompt_embs = [prompt_emb]
|
prompt_embs = [prompt_emb]
|
||||||
@@ -410,330 +391,5 @@ class FluxDiT(torch.nn.Module):
|
|||||||
use_gradient_checkpointing=False,
|
use_gradient_checkpointing=False,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
if tiled:
|
# (Deprecated) The real forward is in `pipelines.flux_image`.
|
||||||
return self.tiled_forward(
|
return None
|
||||||
hidden_states,
|
|
||||||
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
|
|
||||||
tile_size=tile_size, tile_stride=tile_stride,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if image_ids is None:
|
|
||||||
image_ids = self.prepare_image_ids(hidden_states)
|
|
||||||
|
|
||||||
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
|
|
||||||
if self.guidance_embedder is not None:
|
|
||||||
guidance = guidance * 1000
|
|
||||||
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
|
|
||||||
|
|
||||||
height, width = hidden_states.shape[-2:]
|
|
||||||
hidden_states = self.patchify(hidden_states)
|
|
||||||
hidden_states = self.x_embedder(hidden_states)
|
|
||||||
|
|
||||||
if entity_prompt_emb is not None and entity_masks is not None:
|
|
||||||
prompt_emb, image_rotary_emb, attention_mask = self.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
|
|
||||||
else:
|
|
||||||
prompt_emb = self.context_embedder(prompt_emb)
|
|
||||||
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
|
||||||
attention_mask = None
|
|
||||||
|
|
||||||
def create_custom_forward(module):
|
|
||||||
def custom_forward(*inputs):
|
|
||||||
return module(*inputs)
|
|
||||||
return custom_forward
|
|
||||||
|
|
||||||
for block in self.blocks:
|
|
||||||
if self.training and use_gradient_checkpointing:
|
|
||||||
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
|
||||||
create_custom_forward(block),
|
|
||||||
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
|
|
||||||
use_reentrant=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
|
|
||||||
|
|
||||||
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
|
||||||
for block in self.single_blocks:
|
|
||||||
if self.training and use_gradient_checkpointing:
|
|
||||||
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
|
||||||
create_custom_forward(block),
|
|
||||||
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
|
|
||||||
use_reentrant=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
|
|
||||||
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
|
|
||||||
|
|
||||||
hidden_states = self.final_norm_out(hidden_states, conditioning)
|
|
||||||
hidden_states = self.final_proj_out(hidden_states)
|
|
||||||
hidden_states = self.unpatchify(hidden_states, height, width)
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
def quantize(self):
|
|
||||||
def cast_to(weight, dtype=None, device=None, copy=False):
|
|
||||||
if device is None or weight.device == device:
|
|
||||||
if not copy:
|
|
||||||
if dtype is None or weight.dtype == dtype:
|
|
||||||
return weight
|
|
||||||
return weight.to(dtype=dtype, copy=copy)
|
|
||||||
|
|
||||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
|
||||||
r.copy_(weight)
|
|
||||||
return r
|
|
||||||
|
|
||||||
def cast_weight(s, input=None, dtype=None, device=None):
|
|
||||||
if input is not None:
|
|
||||||
if dtype is None:
|
|
||||||
dtype = input.dtype
|
|
||||||
if device is None:
|
|
||||||
device = input.device
|
|
||||||
weight = cast_to(s.weight, dtype, device)
|
|
||||||
return weight
|
|
||||||
|
|
||||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
|
||||||
if input is not None:
|
|
||||||
if dtype is None:
|
|
||||||
dtype = input.dtype
|
|
||||||
if bias_dtype is None:
|
|
||||||
bias_dtype = dtype
|
|
||||||
if device is None:
|
|
||||||
device = input.device
|
|
||||||
bias = None
|
|
||||||
weight = cast_to(s.weight, dtype, device)
|
|
||||||
bias = cast_to(s.bias, bias_dtype, device)
|
|
||||||
return weight, bias
|
|
||||||
|
|
||||||
class quantized_layer:
|
|
||||||
class Linear(torch.nn.Linear):
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
def forward(self,input,**kwargs):
|
|
||||||
weight,bias= cast_bias_weight(self,input)
|
|
||||||
return torch.nn.functional.linear(input,weight,bias)
|
|
||||||
|
|
||||||
class RMSNorm(torch.nn.Module):
|
|
||||||
def __init__(self, module):
|
|
||||||
super().__init__()
|
|
||||||
self.module = module
|
|
||||||
|
|
||||||
def forward(self,hidden_states,**kwargs):
|
|
||||||
weight= cast_weight(self.module,hidden_states)
|
|
||||||
input_dtype = hidden_states.dtype
|
|
||||||
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
|
||||||
hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
|
|
||||||
hidden_states = hidden_states.to(input_dtype) * weight
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
def replace_layer(model):
|
|
||||||
for name, module in model.named_children():
|
|
||||||
if isinstance(module, torch.nn.Linear):
|
|
||||||
with init_weights_on_device():
|
|
||||||
new_layer = quantized_layer.Linear(module.in_features,module.out_features)
|
|
||||||
new_layer.weight = module.weight
|
|
||||||
if module.bias is not None:
|
|
||||||
new_layer.bias = module.bias
|
|
||||||
# del module
|
|
||||||
setattr(model, name, new_layer)
|
|
||||||
elif isinstance(module, RMSNorm):
|
|
||||||
if hasattr(module,"quantized"):
|
|
||||||
continue
|
|
||||||
module.quantized= True
|
|
||||||
new_layer = quantized_layer.RMSNorm(module)
|
|
||||||
setattr(model, name, new_layer)
|
|
||||||
else:
|
|
||||||
replace_layer(module)
|
|
||||||
|
|
||||||
replace_layer(self)
|
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def state_dict_converter():
|
|
||||||
return FluxDiTStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
class FluxDiTStateDictConverter:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def from_diffusers(self, state_dict):
|
|
||||||
global_rename_dict = {
|
|
||||||
"context_embedder": "context_embedder",
|
|
||||||
"x_embedder": "x_embedder",
|
|
||||||
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
|
|
||||||
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
|
|
||||||
"time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
|
|
||||||
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
|
|
||||||
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
|
|
||||||
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
|
|
||||||
"norm_out.linear": "final_norm_out.linear",
|
|
||||||
"proj_out": "final_proj_out",
|
|
||||||
}
|
|
||||||
rename_dict = {
|
|
||||||
"proj_out": "proj_out",
|
|
||||||
"norm1.linear": "norm1_a.linear",
|
|
||||||
"norm1_context.linear": "norm1_b.linear",
|
|
||||||
"attn.to_q": "attn.a_to_q",
|
|
||||||
"attn.to_k": "attn.a_to_k",
|
|
||||||
"attn.to_v": "attn.a_to_v",
|
|
||||||
"attn.to_out.0": "attn.a_to_out",
|
|
||||||
"attn.add_q_proj": "attn.b_to_q",
|
|
||||||
"attn.add_k_proj": "attn.b_to_k",
|
|
||||||
"attn.add_v_proj": "attn.b_to_v",
|
|
||||||
"attn.to_add_out": "attn.b_to_out",
|
|
||||||
"ff.net.0.proj": "ff_a.0",
|
|
||||||
"ff.net.2": "ff_a.2",
|
|
||||||
"ff_context.net.0.proj": "ff_b.0",
|
|
||||||
"ff_context.net.2": "ff_b.2",
|
|
||||||
"attn.norm_q": "attn.norm_q_a",
|
|
||||||
"attn.norm_k": "attn.norm_k_a",
|
|
||||||
"attn.norm_added_q": "attn.norm_q_b",
|
|
||||||
"attn.norm_added_k": "attn.norm_k_b",
|
|
||||||
}
|
|
||||||
rename_dict_single = {
|
|
||||||
"attn.to_q": "a_to_q",
|
|
||||||
"attn.to_k": "a_to_k",
|
|
||||||
"attn.to_v": "a_to_v",
|
|
||||||
"attn.norm_q": "norm_q_a",
|
|
||||||
"attn.norm_k": "norm_k_a",
|
|
||||||
"norm.linear": "norm.linear",
|
|
||||||
"proj_mlp": "proj_in_besides_attn",
|
|
||||||
"proj_out": "proj_out",
|
|
||||||
}
|
|
||||||
state_dict_ = {}
|
|
||||||
for name, param in state_dict.items():
|
|
||||||
if name.endswith(".weight") or name.endswith(".bias"):
|
|
||||||
suffix = ".weight" if name.endswith(".weight") else ".bias"
|
|
||||||
prefix = name[:-len(suffix)]
|
|
||||||
if prefix in global_rename_dict:
|
|
||||||
state_dict_[global_rename_dict[prefix] + suffix] = param
|
|
||||||
elif prefix.startswith("transformer_blocks."):
|
|
||||||
names = prefix.split(".")
|
|
||||||
names[0] = "blocks"
|
|
||||||
middle = ".".join(names[2:])
|
|
||||||
if middle in rename_dict:
|
|
||||||
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
|
|
||||||
state_dict_[name_] = param
|
|
||||||
elif prefix.startswith("single_transformer_blocks."):
|
|
||||||
names = prefix.split(".")
|
|
||||||
names[0] = "single_blocks"
|
|
||||||
middle = ".".join(names[2:])
|
|
||||||
if middle in rename_dict_single:
|
|
||||||
name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
|
|
||||||
state_dict_[name_] = param
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
for name in list(state_dict_.keys()):
|
|
||||||
if ".proj_in_besides_attn." in name:
|
|
||||||
name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
|
|
||||||
param = torch.concat([
|
|
||||||
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
|
|
||||||
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
|
|
||||||
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
|
|
||||||
state_dict_[name],
|
|
||||||
], dim=0)
|
|
||||||
state_dict_[name_] = param
|
|
||||||
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
|
|
||||||
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
|
|
||||||
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
|
|
||||||
state_dict_.pop(name)
|
|
||||||
for name in list(state_dict_.keys()):
|
|
||||||
for component in ["a", "b"]:
|
|
||||||
if f".{component}_to_q." in name:
|
|
||||||
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
|
|
||||||
param = torch.concat([
|
|
||||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
|
|
||||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
|
|
||||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
|
|
||||||
], dim=0)
|
|
||||||
state_dict_[name_] = param
|
|
||||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
|
|
||||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
|
|
||||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
|
|
||||||
return state_dict_
|
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
|
||||||
rename_dict = {
|
|
||||||
"time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias",
|
|
||||||
"time_in.in_layer.weight": "time_embedder.timestep_embedder.0.weight",
|
|
||||||
"time_in.out_layer.bias": "time_embedder.timestep_embedder.2.bias",
|
|
||||||
"time_in.out_layer.weight": "time_embedder.timestep_embedder.2.weight",
|
|
||||||
"txt_in.bias": "context_embedder.bias",
|
|
||||||
"txt_in.weight": "context_embedder.weight",
|
|
||||||
"vector_in.in_layer.bias": "pooled_text_embedder.0.bias",
|
|
||||||
"vector_in.in_layer.weight": "pooled_text_embedder.0.weight",
|
|
||||||
"vector_in.out_layer.bias": "pooled_text_embedder.2.bias",
|
|
||||||
"vector_in.out_layer.weight": "pooled_text_embedder.2.weight",
|
|
||||||
"final_layer.linear.bias": "final_proj_out.bias",
|
|
||||||
"final_layer.linear.weight": "final_proj_out.weight",
|
|
||||||
"guidance_in.in_layer.bias": "guidance_embedder.timestep_embedder.0.bias",
|
|
||||||
"guidance_in.in_layer.weight": "guidance_embedder.timestep_embedder.0.weight",
|
|
||||||
"guidance_in.out_layer.bias": "guidance_embedder.timestep_embedder.2.bias",
|
|
||||||
"guidance_in.out_layer.weight": "guidance_embedder.timestep_embedder.2.weight",
|
|
||||||
"img_in.bias": "x_embedder.bias",
|
|
||||||
"img_in.weight": "x_embedder.weight",
|
|
||||||
"final_layer.adaLN_modulation.1.weight": "final_norm_out.linear.weight",
|
|
||||||
"final_layer.adaLN_modulation.1.bias": "final_norm_out.linear.bias",
|
|
||||||
}
|
|
||||||
suffix_rename_dict = {
|
|
||||||
"img_attn.norm.key_norm.scale": "attn.norm_k_a.weight",
|
|
||||||
"img_attn.norm.query_norm.scale": "attn.norm_q_a.weight",
|
|
||||||
"img_attn.proj.bias": "attn.a_to_out.bias",
|
|
||||||
"img_attn.proj.weight": "attn.a_to_out.weight",
|
|
||||||
"img_attn.qkv.bias": "attn.a_to_qkv.bias",
|
|
||||||
"img_attn.qkv.weight": "attn.a_to_qkv.weight",
|
|
||||||
"img_mlp.0.bias": "ff_a.0.bias",
|
|
||||||
"img_mlp.0.weight": "ff_a.0.weight",
|
|
||||||
"img_mlp.2.bias": "ff_a.2.bias",
|
|
||||||
"img_mlp.2.weight": "ff_a.2.weight",
|
|
||||||
"img_mod.lin.bias": "norm1_a.linear.bias",
|
|
||||||
"img_mod.lin.weight": "norm1_a.linear.weight",
|
|
||||||
"txt_attn.norm.key_norm.scale": "attn.norm_k_b.weight",
|
|
||||||
"txt_attn.norm.query_norm.scale": "attn.norm_q_b.weight",
|
|
||||||
"txt_attn.proj.bias": "attn.b_to_out.bias",
|
|
||||||
"txt_attn.proj.weight": "attn.b_to_out.weight",
|
|
||||||
"txt_attn.qkv.bias": "attn.b_to_qkv.bias",
|
|
||||||
"txt_attn.qkv.weight": "attn.b_to_qkv.weight",
|
|
||||||
"txt_mlp.0.bias": "ff_b.0.bias",
|
|
||||||
"txt_mlp.0.weight": "ff_b.0.weight",
|
|
||||||
"txt_mlp.2.bias": "ff_b.2.bias",
|
|
||||||
"txt_mlp.2.weight": "ff_b.2.weight",
|
|
||||||
"txt_mod.lin.bias": "norm1_b.linear.bias",
|
|
||||||
"txt_mod.lin.weight": "norm1_b.linear.weight",
|
|
||||||
|
|
||||||
"linear1.bias": "to_qkv_mlp.bias",
|
|
||||||
"linear1.weight": "to_qkv_mlp.weight",
|
|
||||||
"linear2.bias": "proj_out.bias",
|
|
||||||
"linear2.weight": "proj_out.weight",
|
|
||||||
"modulation.lin.bias": "norm.linear.bias",
|
|
||||||
"modulation.lin.weight": "norm.linear.weight",
|
|
||||||
"norm.key_norm.scale": "norm_k_a.weight",
|
|
||||||
"norm.query_norm.scale": "norm_q_a.weight",
|
|
||||||
}
|
|
||||||
state_dict_ = {}
|
|
||||||
for name, param in state_dict.items():
|
|
||||||
if name.startswith("model.diffusion_model."):
|
|
||||||
name = name[len("model.diffusion_model."):]
|
|
||||||
names = name.split(".")
|
|
||||||
if name in rename_dict:
|
|
||||||
rename = rename_dict[name]
|
|
||||||
if name.startswith("final_layer.adaLN_modulation.1."):
|
|
||||||
param = torch.concat([param[3072:], param[:3072]], dim=0)
|
|
||||||
state_dict_[rename] = param
|
|
||||||
elif names[0] == "double_blocks":
|
|
||||||
rename = f"blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
|
|
||||||
state_dict_[rename] = param
|
|
||||||
elif names[0] == "single_blocks":
|
|
||||||
if ".".join(names[2:]) in suffix_rename_dict:
|
|
||||||
rename = f"single_blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
|
|
||||||
state_dict_[rename] = param
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_:
|
|
||||||
return state_dict_, {"disable_guidance_embedder": True}
|
|
||||||
else:
|
|
||||||
return state_dict_
|
|
||||||
|
|||||||
129
diffsynth/models/flux_infiniteyou.py
Normal file
129
diffsynth/models/flux_infiniteyou.py
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
# FFN
|
||||||
|
def FeedForward(dim, mult=4):
|
||||||
|
inner_dim = int(dim * mult)
|
||||||
|
return nn.Sequential(
|
||||||
|
nn.LayerNorm(dim),
|
||||||
|
nn.Linear(dim, inner_dim, bias=False),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Linear(inner_dim, dim, bias=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def reshape_tensor(x, heads):
|
||||||
|
bs, length, width = x.shape
|
||||||
|
#(bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
||||||
|
x = x.view(bs, length, heads, -1)
|
||||||
|
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
||||||
|
x = x.transpose(1, 2)
|
||||||
|
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
||||||
|
x = x.reshape(bs, heads, length, -1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class PerceiverAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, *, dim, dim_head=64, heads=8):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = dim_head**-0.5
|
||||||
|
self.dim_head = dim_head
|
||||||
|
self.heads = heads
|
||||||
|
inner_dim = dim_head * heads
|
||||||
|
|
||||||
|
self.norm1 = nn.LayerNorm(dim)
|
||||||
|
self.norm2 = nn.LayerNorm(dim)
|
||||||
|
|
||||||
|
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
||||||
|
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
||||||
|
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x, latents):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): image features
|
||||||
|
shape (b, n1, D)
|
||||||
|
latent (torch.Tensor): latent features
|
||||||
|
shape (b, n2, D)
|
||||||
|
"""
|
||||||
|
x = self.norm1(x)
|
||||||
|
latents = self.norm2(latents)
|
||||||
|
|
||||||
|
b, l, _ = latents.shape
|
||||||
|
|
||||||
|
q = self.to_q(latents)
|
||||||
|
kv_input = torch.cat((x, latents), dim=-2)
|
||||||
|
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
||||||
|
|
||||||
|
q = reshape_tensor(q, self.heads)
|
||||||
|
k = reshape_tensor(k, self.heads)
|
||||||
|
v = reshape_tensor(v, self.heads)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
||||||
|
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
||||||
|
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||||
|
out = weight @ v
|
||||||
|
|
||||||
|
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
||||||
|
|
||||||
|
return self.to_out(out)
|
||||||
|
|
||||||
|
|
||||||
|
class InfiniteYouImageProjector(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim=1280,
|
||||||
|
depth=4,
|
||||||
|
dim_head=64,
|
||||||
|
heads=20,
|
||||||
|
num_queries=8,
|
||||||
|
embedding_dim=512,
|
||||||
|
output_dim=4096,
|
||||||
|
ff_mult=4,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
|
||||||
|
self.proj_in = nn.Linear(embedding_dim, dim)
|
||||||
|
|
||||||
|
self.proj_out = nn.Linear(dim, output_dim)
|
||||||
|
self.norm_out = nn.LayerNorm(output_dim)
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList([])
|
||||||
|
for _ in range(depth):
|
||||||
|
self.layers.append(
|
||||||
|
nn.ModuleList([
|
||||||
|
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
||||||
|
FeedForward(dim=dim, mult=ff_mult),
|
||||||
|
]))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
|
||||||
|
latents = self.latents.repeat(x.size(0), 1, 1)
|
||||||
|
latents = latents.to(dtype=x.dtype, device=x.device)
|
||||||
|
|
||||||
|
x = self.proj_in(x)
|
||||||
|
|
||||||
|
for attn, ff in self.layers:
|
||||||
|
latents = attn(x, latents) + latents
|
||||||
|
latents = ff(latents) + latents
|
||||||
|
|
||||||
|
latents = self.proj_out(latents)
|
||||||
|
return self.norm_out(latents)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return FluxInfiniteYouImageProjectorStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
class FluxInfiniteYouImageProjectorStateDictConverter:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_diffusers(self, state_dict):
|
||||||
|
return state_dict['image_proj']
|
||||||
@@ -1,9 +1,25 @@
|
|||||||
from .svd_image_encoder import SVDImageEncoder
|
from .general_modules import RMSNorm
|
||||||
from .sd3_dit import RMSNorm
|
from transformers import SiglipVisionModel, SiglipVisionConfig
|
||||||
from transformers import CLIPImageProcessor
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class SiglipVisionModelSO400M(SiglipVisionModel):
|
||||||
|
def __init__(self):
|
||||||
|
config = SiglipVisionConfig(
|
||||||
|
hidden_size=1152,
|
||||||
|
image_size=384,
|
||||||
|
intermediate_size=4304,
|
||||||
|
model_type="siglip_vision_model",
|
||||||
|
num_attention_heads=16,
|
||||||
|
num_hidden_layers=27,
|
||||||
|
patch_size=14,
|
||||||
|
architectures=["SiglipModel"],
|
||||||
|
initializer_factor=1.0,
|
||||||
|
torch_dtype="float32",
|
||||||
|
transformers_version="4.37.0.dev0"
|
||||||
|
)
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
class MLPProjModel(torch.nn.Module):
|
class MLPProjModel(torch.nn.Module):
|
||||||
def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
|
def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
521
diffsynth/models/flux_lora_encoder.py
Normal file
521
diffsynth/models/flux_lora_encoder.py
Normal file
@@ -0,0 +1,521 @@
|
|||||||
|
import torch
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
def low_version_attention(query, key, value, attn_bias=None):
|
||||||
|
scale = 1 / query.shape[-1] ** 0.5
|
||||||
|
query = query * scale
|
||||||
|
attn = torch.matmul(query, key.transpose(-2, -1))
|
||||||
|
if attn_bias is not None:
|
||||||
|
attn = attn + attn_bias
|
||||||
|
attn = attn.softmax(-1)
|
||||||
|
return attn @ value
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
|
||||||
|
super().__init__()
|
||||||
|
dim_inner = head_dim * num_heads
|
||||||
|
kv_dim = kv_dim if kv_dim is not None else q_dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
|
||||||
|
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
|
||||||
|
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||||
|
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||||
|
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
|
||||||
|
|
||||||
|
def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
|
||||||
|
batch_size = q.shape[0]
|
||||||
|
ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
|
||||||
|
hidden_states = hidden_states + scale * ip_hidden_states
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
|
||||||
|
if encoder_hidden_states is None:
|
||||||
|
encoder_hidden_states = hidden_states
|
||||||
|
|
||||||
|
batch_size = encoder_hidden_states.shape[0]
|
||||||
|
|
||||||
|
q = self.to_q(hidden_states)
|
||||||
|
k = self.to_k(encoder_hidden_states)
|
||||||
|
v = self.to_v(encoder_hidden_states)
|
||||||
|
|
||||||
|
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
if qkv_preprocessor is not None:
|
||||||
|
q, k, v = qkv_preprocessor(q, k, v)
|
||||||
|
|
||||||
|
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||||
|
if ipadapter_kwargs is not None:
|
||||||
|
hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
|
||||||
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||||
|
hidden_states = hidden_states.to(q.dtype)
|
||||||
|
|
||||||
|
hidden_states = self.to_out(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
|
||||||
|
if encoder_hidden_states is None:
|
||||||
|
encoder_hidden_states = hidden_states
|
||||||
|
|
||||||
|
q = self.to_q(hidden_states)
|
||||||
|
k = self.to_k(encoder_hidden_states)
|
||||||
|
v = self.to_v(encoder_hidden_states)
|
||||||
|
|
||||||
|
q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
|
||||||
|
k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
|
||||||
|
v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)
|
||||||
|
|
||||||
|
if attn_mask is not None:
|
||||||
|
hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
|
||||||
|
else:
|
||||||
|
import xformers.ops as xops
|
||||||
|
hidden_states = xops.memory_efficient_attention(q, k, v)
|
||||||
|
hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)
|
||||||
|
|
||||||
|
hidden_states = hidden_states.to(q.dtype)
|
||||||
|
hidden_states = self.to_out(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
|
||||||
|
return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPEncoderLayer(torch.nn.Module):
|
||||||
|
def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
|
||||||
|
super().__init__()
|
||||||
|
self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)
|
||||||
|
self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
|
||||||
|
self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
|
||||||
|
self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
|
||||||
|
self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)
|
||||||
|
|
||||||
|
self.use_quick_gelu = use_quick_gelu
|
||||||
|
|
||||||
|
def quickGELU(self, x):
|
||||||
|
return x * torch.sigmoid(1.702 * x)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, attn_mask=None):
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.layer_norm1(hidden_states)
|
||||||
|
hidden_states = self.attn(hidden_states, attn_mask=attn_mask)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.layer_norm2(hidden_states)
|
||||||
|
hidden_states = self.fc1(hidden_states)
|
||||||
|
if self.use_quick_gelu:
|
||||||
|
hidden_states = self.quickGELU(hidden_states)
|
||||||
|
else:
|
||||||
|
hidden_states = torch.nn.functional.gelu(hidden_states)
|
||||||
|
hidden_states = self.fc2(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class SDTextEncoder(torch.nn.Module):
|
||||||
|
def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# token_embedding
|
||||||
|
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
|
||||||
|
|
||||||
|
# position_embeds (This is a fixed tensor)
|
||||||
|
self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
|
||||||
|
|
||||||
|
# encoders
|
||||||
|
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
|
||||||
|
|
||||||
|
# attn_mask
|
||||||
|
self.attn_mask = self.attention_mask(max_position_embeddings)
|
||||||
|
|
||||||
|
# final_layer_norm
|
||||||
|
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
|
||||||
|
|
||||||
|
def attention_mask(self, length):
|
||||||
|
mask = torch.empty(length, length)
|
||||||
|
mask.fill_(float("-inf"))
|
||||||
|
mask.triu_(1)
|
||||||
|
return mask
|
||||||
|
|
||||||
|
def forward(self, input_ids, clip_skip=1):
|
||||||
|
embeds = self.token_embedding(input_ids) + self.position_embeds
|
||||||
|
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
|
||||||
|
for encoder_id, encoder in enumerate(self.encoders):
|
||||||
|
embeds = encoder(embeds, attn_mask=attn_mask)
|
||||||
|
if encoder_id + clip_skip == len(self.encoders):
|
||||||
|
break
|
||||||
|
embeds = self.final_layer_norm(embeds)
|
||||||
|
return embeds
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return SDTextEncoderStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
class SDTextEncoderStateDictConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_diffusers(self, state_dict):
|
||||||
|
rename_dict = {
|
||||||
|
"text_model.embeddings.token_embedding.weight": "token_embedding.weight",
|
||||||
|
"text_model.embeddings.position_embedding.weight": "position_embeds",
|
||||||
|
"text_model.final_layer_norm.weight": "final_layer_norm.weight",
|
||||||
|
"text_model.final_layer_norm.bias": "final_layer_norm.bias"
|
||||||
|
}
|
||||||
|
attn_rename_dict = {
|
||||||
|
"self_attn.q_proj": "attn.to_q",
|
||||||
|
"self_attn.k_proj": "attn.to_k",
|
||||||
|
"self_attn.v_proj": "attn.to_v",
|
||||||
|
"self_attn.out_proj": "attn.to_out",
|
||||||
|
"layer_norm1": "layer_norm1",
|
||||||
|
"layer_norm2": "layer_norm2",
|
||||||
|
"mlp.fc1": "fc1",
|
||||||
|
"mlp.fc2": "fc2",
|
||||||
|
}
|
||||||
|
state_dict_ = {}
|
||||||
|
for name in state_dict:
|
||||||
|
if name in rename_dict:
|
||||||
|
param = state_dict[name]
|
||||||
|
if name == "text_model.embeddings.position_embedding.weight":
|
||||||
|
param = param.reshape((1, param.shape[0], param.shape[1]))
|
||||||
|
state_dict_[rename_dict[name]] = param
|
||||||
|
elif name.startswith("text_model.encoder.layers."):
|
||||||
|
param = state_dict[name]
|
||||||
|
names = name.split(".")
|
||||||
|
layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
|
||||||
|
name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
|
||||||
|
state_dict_[name_] = param
|
||||||
|
return state_dict_
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
rename_dict = {
|
||||||
|
"cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "encoders.11.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "position_embeds"
|
||||||
|
}
|
||||||
|
state_dict_ = {}
|
||||||
|
for name in state_dict:
|
||||||
|
if name in rename_dict:
|
||||||
|
param = state_dict[name]
|
||||||
|
if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight":
|
||||||
|
param = param.reshape((1, param.shape[0], param.shape[1]))
|
||||||
|
state_dict_[rename_dict[name]] = param
|
||||||
|
return state_dict_
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class LoRALayerBlock(torch.nn.Module):
|
||||||
|
def __init__(self, L, dim_in, dim_out):
|
||||||
|
super().__init__()
|
||||||
|
self.x = torch.nn.Parameter(torch.randn(1, L, dim_in))
|
||||||
|
self.layer_norm = torch.nn.LayerNorm(dim_out)
|
||||||
|
|
||||||
|
def forward(self, lora_A, lora_B):
|
||||||
|
x = self.x @ lora_A.T @ lora_B.T
|
||||||
|
x = self.layer_norm(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class LoRAEmbedder(torch.nn.Module):
|
||||||
|
def __init__(self, lora_patterns=None, L=1, out_dim=2048):
|
||||||
|
super().__init__()
|
||||||
|
if lora_patterns is None:
|
||||||
|
lora_patterns = self.default_lora_patterns()
|
||||||
|
|
||||||
|
model_dict = {}
|
||||||
|
for lora_pattern in lora_patterns:
|
||||||
|
name, dim = lora_pattern["name"], lora_pattern["dim"]
|
||||||
|
model_dict[name.replace(".", "___")] = LoRALayerBlock(L, dim[0], dim[1])
|
||||||
|
self.model_dict = torch.nn.ModuleDict(model_dict)
|
||||||
|
|
||||||
|
proj_dict = {}
|
||||||
|
for lora_pattern in lora_patterns:
|
||||||
|
layer_type, dim = lora_pattern["type"], lora_pattern["dim"]
|
||||||
|
if layer_type not in proj_dict:
|
||||||
|
proj_dict[layer_type.replace(".", "___")] = torch.nn.Linear(dim[1], out_dim)
|
||||||
|
self.proj_dict = torch.nn.ModuleDict(proj_dict)
|
||||||
|
|
||||||
|
self.lora_patterns = lora_patterns
|
||||||
|
|
||||||
|
|
||||||
|
def default_lora_patterns(self):
|
||||||
|
lora_patterns = []
|
||||||
|
lora_dict = {
|
||||||
|
"attn.a_to_qkv": (3072, 9216), "attn.a_to_out": (3072, 3072), "ff_a.0": (3072, 12288), "ff_a.2": (12288, 3072), "norm1_a.linear": (3072, 18432),
|
||||||
|
"attn.b_to_qkv": (3072, 9216), "attn.b_to_out": (3072, 3072), "ff_b.0": (3072, 12288), "ff_b.2": (12288, 3072), "norm1_b.linear": (3072, 18432),
|
||||||
|
}
|
||||||
|
for i in range(19):
|
||||||
|
for suffix in lora_dict:
|
||||||
|
lora_patterns.append({
|
||||||
|
"name": f"blocks.{i}.{suffix}",
|
||||||
|
"dim": lora_dict[suffix],
|
||||||
|
"type": suffix,
|
||||||
|
})
|
||||||
|
lora_dict = {"to_qkv_mlp": (3072, 21504), "proj_out": (15360, 3072), "norm.linear": (3072, 9216)}
|
||||||
|
for i in range(38):
|
||||||
|
for suffix in lora_dict:
|
||||||
|
lora_patterns.append({
|
||||||
|
"name": f"single_blocks.{i}.{suffix}",
|
||||||
|
"dim": lora_dict[suffix],
|
||||||
|
"type": suffix,
|
||||||
|
})
|
||||||
|
return lora_patterns
|
||||||
|
|
||||||
|
def forward(self, lora):
|
||||||
|
lora_emb = []
|
||||||
|
for lora_pattern in self.lora_patterns:
|
||||||
|
name, layer_type = lora_pattern["name"], lora_pattern["type"]
|
||||||
|
lora_A = lora[name + ".lora_A.weight"]
|
||||||
|
lora_B = lora[name + ".lora_B.weight"]
|
||||||
|
lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B)
|
||||||
|
lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out)
|
||||||
|
lora_emb.append(lora_out)
|
||||||
|
lora_emb = torch.concat(lora_emb, dim=1)
|
||||||
|
return lora_emb
|
||||||
|
|
||||||
|
|
||||||
|
class FluxLoRAEncoder(torch.nn.Module):
|
||||||
|
def __init__(self, embed_dim=4096, encoder_intermediate_size=8192, num_encoder_layers=1, num_embeds_per_lora=16, num_special_embeds=1):
|
||||||
|
super().__init__()
|
||||||
|
self.num_embeds_per_lora = num_embeds_per_lora
|
||||||
|
# embedder
|
||||||
|
self.embedder = LoRAEmbedder(L=num_embeds_per_lora, out_dim=embed_dim)
|
||||||
|
|
||||||
|
# encoders
|
||||||
|
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=32, head_dim=128) for _ in range(num_encoder_layers)])
|
||||||
|
|
||||||
|
# special embedding
|
||||||
|
self.special_embeds = torch.nn.Parameter(torch.randn(1, num_special_embeds, embed_dim))
|
||||||
|
self.num_special_embeds = num_special_embeds
|
||||||
|
|
||||||
|
# final layer
|
||||||
|
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
|
||||||
|
self.final_linear = torch.nn.Linear(embed_dim, embed_dim)
|
||||||
|
|
||||||
|
def forward(self, lora):
|
||||||
|
lora_embeds = self.embedder(lora)
|
||||||
|
special_embeds = self.special_embeds.to(dtype=lora_embeds.dtype, device=lora_embeds.device)
|
||||||
|
embeds = torch.concat([special_embeds, lora_embeds], dim=1)
|
||||||
|
for encoder_id, encoder in enumerate(self.encoders):
|
||||||
|
embeds = encoder(embeds)
|
||||||
|
embeds = embeds[:, :self.num_special_embeds]
|
||||||
|
embeds = self.final_layer_norm(embeds)
|
||||||
|
embeds = self.final_linear(embeds)
|
||||||
|
return embeds
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return FluxLoRAEncoderStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
class FluxLoRAEncoderStateDictConverter:
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
return state_dict
|
||||||
306
diffsynth/models/flux_lora_patcher.py
Normal file
306
diffsynth/models/flux_lora_patcher.py
Normal file
@@ -0,0 +1,306 @@
|
|||||||
|
import torch, math
|
||||||
|
from ..core.loader import load_state_dict
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
class GeneralLoRALoader:
|
||||||
|
def __init__(self, device="cpu", torch_dtype=torch.float32):
|
||||||
|
self.device = device
|
||||||
|
self.torch_dtype = torch_dtype
|
||||||
|
|
||||||
|
|
||||||
|
def get_name_dict(self, lora_state_dict):
|
||||||
|
lora_name_dict = {}
|
||||||
|
for key in lora_state_dict:
|
||||||
|
if ".lora_B." not in key:
|
||||||
|
continue
|
||||||
|
keys = key.split(".")
|
||||||
|
if len(keys) > keys.index("lora_B") + 2:
|
||||||
|
keys.pop(keys.index("lora_B") + 1)
|
||||||
|
keys.pop(keys.index("lora_B"))
|
||||||
|
if keys[0] == "diffusion_model":
|
||||||
|
keys.pop(0)
|
||||||
|
keys.pop(-1)
|
||||||
|
target_name = ".".join(keys)
|
||||||
|
lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
|
||||||
|
return lora_name_dict
|
||||||
|
|
||||||
|
|
||||||
|
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
|
||||||
|
updated_num = 0
|
||||||
|
lora_name_dict = self.get_name_dict(state_dict_lora)
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if name in lora_name_dict:
|
||||||
|
weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=self.device, dtype=self.torch_dtype)
|
||||||
|
weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=self.device, dtype=self.torch_dtype)
|
||||||
|
if len(weight_up.shape) == 4:
|
||||||
|
weight_up = weight_up.squeeze(3).squeeze(2)
|
||||||
|
weight_down = weight_down.squeeze(3).squeeze(2)
|
||||||
|
weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||||
|
else:
|
||||||
|
weight_lora = alpha * torch.mm(weight_up, weight_down)
|
||||||
|
state_dict = module.state_dict()
|
||||||
|
state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora
|
||||||
|
module.load_state_dict(state_dict)
|
||||||
|
updated_num += 1
|
||||||
|
print(f"{updated_num} tensors are updated by LoRA.")
|
||||||
|
|
||||||
|
class FluxLoRALoader(GeneralLoRALoader):
|
||||||
|
def __init__(self, device="cpu", torch_dtype=torch.float32):
|
||||||
|
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||||
|
|
||||||
|
self.diffusers_rename_dict = {
|
||||||
|
"transformer.single_transformer_blocks.blockid.attn.to_k.lora_A.weight":"single_blocks.blockid.a_to_k.lora_A.default.weight",
|
||||||
|
"transformer.single_transformer_blocks.blockid.attn.to_k.lora_B.weight":"single_blocks.blockid.a_to_k.lora_B.default.weight",
|
||||||
|
"transformer.single_transformer_blocks.blockid.attn.to_q.lora_A.weight":"single_blocks.blockid.a_to_q.lora_A.default.weight",
|
||||||
|
"transformer.single_transformer_blocks.blockid.attn.to_q.lora_B.weight":"single_blocks.blockid.a_to_q.lora_B.default.weight",
|
||||||
|
"transformer.single_transformer_blocks.blockid.attn.to_v.lora_A.weight":"single_blocks.blockid.a_to_v.lora_A.default.weight",
|
||||||
|
"transformer.single_transformer_blocks.blockid.attn.to_v.lora_B.weight":"single_blocks.blockid.a_to_v.lora_B.default.weight",
|
||||||
|
"transformer.single_transformer_blocks.blockid.norm.linear.lora_A.weight":"single_blocks.blockid.norm.linear.lora_A.default.weight",
|
||||||
|
"transformer.single_transformer_blocks.blockid.norm.linear.lora_B.weight":"single_blocks.blockid.norm.linear.lora_B.default.weight",
|
||||||
|
"transformer.single_transformer_blocks.blockid.proj_mlp.lora_A.weight":"single_blocks.blockid.proj_in_besides_attn.lora_A.default.weight",
|
||||||
|
"transformer.single_transformer_blocks.blockid.proj_mlp.lora_B.weight":"single_blocks.blockid.proj_in_besides_attn.lora_B.default.weight",
|
||||||
|
"transformer.single_transformer_blocks.blockid.proj_out.lora_A.weight":"single_blocks.blockid.proj_out.lora_A.default.weight",
|
||||||
|
"transformer.single_transformer_blocks.blockid.proj_out.lora_B.weight":"single_blocks.blockid.proj_out.lora_B.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.add_k_proj.lora_A.weight":"blocks.blockid.attn.b_to_k.lora_A.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.add_k_proj.lora_B.weight":"blocks.blockid.attn.b_to_k.lora_B.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.add_q_proj.lora_A.weight":"blocks.blockid.attn.b_to_q.lora_A.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.add_q_proj.lora_B.weight":"blocks.blockid.attn.b_to_q.lora_B.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.add_v_proj.lora_A.weight":"blocks.blockid.attn.b_to_v.lora_A.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.add_v_proj.lora_B.weight":"blocks.blockid.attn.b_to_v.lora_B.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.to_add_out.lora_A.weight":"blocks.blockid.attn.b_to_out.lora_A.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.to_add_out.lora_B.weight":"blocks.blockid.attn.b_to_out.lora_B.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.to_k.lora_A.weight":"blocks.blockid.attn.a_to_k.lora_A.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.to_k.lora_B.weight":"blocks.blockid.attn.a_to_k.lora_B.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.to_out.0.lora_A.weight":"blocks.blockid.attn.a_to_out.lora_A.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.to_out.0.lora_B.weight":"blocks.blockid.attn.a_to_out.lora_B.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.to_q.lora_A.weight":"blocks.blockid.attn.a_to_q.lora_A.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.to_q.lora_B.weight":"blocks.blockid.attn.a_to_q.lora_B.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.to_v.lora_A.weight":"blocks.blockid.attn.a_to_v.lora_A.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.to_v.lora_B.weight":"blocks.blockid.attn.a_to_v.lora_B.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.ff.net.0.proj.lora_A.weight":"blocks.blockid.ff_a.0.lora_A.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.ff.net.0.proj.lora_B.weight":"blocks.blockid.ff_a.0.lora_B.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.ff.net.2.lora_A.weight":"blocks.blockid.ff_a.2.lora_A.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.ff.net.2.lora_B.weight":"blocks.blockid.ff_a.2.lora_B.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_A.weight":"blocks.blockid.ff_b.0.lora_A.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_B.weight":"blocks.blockid.ff_b.0.lora_B.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.ff_context.net.2.lora_A.weight":"blocks.blockid.ff_b.2.lora_A.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.ff_context.net.2.lora_B.weight":"blocks.blockid.ff_b.2.lora_B.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.norm1.linear.lora_A.weight":"blocks.blockid.norm1_a.linear.lora_A.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.norm1.linear.lora_B.weight":"blocks.blockid.norm1_a.linear.lora_B.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.norm1_context.linear.lora_A.weight":"blocks.blockid.norm1_b.linear.lora_A.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.norm1_context.linear.lora_B.weight":"blocks.blockid.norm1_b.linear.lora_B.default.weight",
|
||||||
|
}
|
||||||
|
|
||||||
|
self.civitai_rename_dict = {
|
||||||
|
"lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight",
|
||||||
|
"lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight",
|
||||||
|
"lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight",
|
||||||
|
"lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight",
|
||||||
|
"lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight",
|
||||||
|
"lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight",
|
||||||
|
"lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight",
|
||||||
|
}
|
||||||
|
|
||||||
|
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
|
||||||
|
super().load(model, state_dict_lora, alpha)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_state_dict(self,state_dict):
|
||||||
|
|
||||||
|
def guess_block_id(name,model_resource):
|
||||||
|
if model_resource == 'civitai':
|
||||||
|
names = name.split("_")
|
||||||
|
for i in names:
|
||||||
|
if i.isdigit():
|
||||||
|
return i, name.replace(f"_{i}_", "_blockid_")
|
||||||
|
if model_resource == 'diffusers':
|
||||||
|
names = name.split(".")
|
||||||
|
for i in names:
|
||||||
|
if i.isdigit():
|
||||||
|
return i, name.replace(f"transformer_blocks.{i}.", "transformer_blocks.blockid.")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
def guess_resource(state_dict):
|
||||||
|
for k in state_dict:
|
||||||
|
if "lora_unet_" in k:
|
||||||
|
return 'civitai'
|
||||||
|
elif k.startswith("transformer."):
|
||||||
|
return 'diffusers'
|
||||||
|
else:
|
||||||
|
None
|
||||||
|
|
||||||
|
model_resource = guess_resource(state_dict)
|
||||||
|
if model_resource is None:
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
rename_dict = self.diffusers_rename_dict if model_resource == 'diffusers' else self.civitai_rename_dict
|
||||||
|
def guess_alpha(state_dict):
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
if ".alpha" in name:
|
||||||
|
for suffix in [".lora_down.weight", ".lora_A.weight"]:
|
||||||
|
name_ = name.replace(".alpha", suffix)
|
||||||
|
if name_ in state_dict:
|
||||||
|
lora_alpha = param.item() / state_dict[name_].shape[0]
|
||||||
|
lora_alpha = math.sqrt(lora_alpha)
|
||||||
|
return lora_alpha
|
||||||
|
|
||||||
|
return 1
|
||||||
|
|
||||||
|
alpha = guess_alpha(state_dict)
|
||||||
|
|
||||||
|
state_dict_ = {}
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
block_id, source_name = guess_block_id(name,model_resource)
|
||||||
|
if alpha != 1:
|
||||||
|
param *= alpha
|
||||||
|
if source_name in rename_dict:
|
||||||
|
target_name = rename_dict[source_name]
|
||||||
|
target_name = target_name.replace(".blockid.", f".{block_id}.")
|
||||||
|
state_dict_[target_name] = param
|
||||||
|
else:
|
||||||
|
state_dict_[name] = param
|
||||||
|
|
||||||
|
if model_resource == 'diffusers':
|
||||||
|
for name in list(state_dict_.keys()):
|
||||||
|
if "single_blocks." in name and ".a_to_q." in name:
|
||||||
|
mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
|
||||||
|
if mlp is None:
|
||||||
|
dim = 4
|
||||||
|
if 'lora_A' in name:
|
||||||
|
dim = 1
|
||||||
|
mlp = torch.zeros(dim * state_dict_[name].shape[0],
|
||||||
|
*state_dict_[name].shape[1:],
|
||||||
|
dtype=state_dict_[name].dtype)
|
||||||
|
else:
|
||||||
|
state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
|
||||||
|
if 'lora_A' in name:
|
||||||
|
param = torch.concat([
|
||||||
|
state_dict_.pop(name),
|
||||||
|
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
|
||||||
|
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
|
||||||
|
mlp,
|
||||||
|
], dim=0)
|
||||||
|
elif 'lora_B' in name:
|
||||||
|
d, r = state_dict_[name].shape
|
||||||
|
param = torch.zeros((3*d+mlp.shape[0], 3*r+mlp.shape[1]), dtype=state_dict_[name].dtype, device=state_dict_[name].device)
|
||||||
|
param[:d, :r] = state_dict_.pop(name)
|
||||||
|
param[d:2*d, r:2*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_k."))
|
||||||
|
param[2*d:3*d, 2*r:3*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_v."))
|
||||||
|
param[3*d:, 3*r:] = mlp
|
||||||
|
else:
|
||||||
|
param = torch.concat([
|
||||||
|
state_dict_.pop(name),
|
||||||
|
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
|
||||||
|
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
|
||||||
|
mlp,
|
||||||
|
], dim=0)
|
||||||
|
name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
|
||||||
|
state_dict_[name_] = param
|
||||||
|
for name in list(state_dict_.keys()):
|
||||||
|
for component in ["a", "b"]:
|
||||||
|
if f".{component}_to_q." in name:
|
||||||
|
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
|
||||||
|
concat_dim = 0
|
||||||
|
if 'lora_A' in name:
|
||||||
|
param = torch.concat([
|
||||||
|
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
|
||||||
|
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
|
||||||
|
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
|
||||||
|
], dim=0)
|
||||||
|
elif 'lora_B' in name:
|
||||||
|
origin = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")]
|
||||||
|
d, r = origin.shape
|
||||||
|
# print(d, r)
|
||||||
|
param = torch.zeros((3*d, 3*r), dtype=origin.dtype, device=origin.device)
|
||||||
|
param[:d, :r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")]
|
||||||
|
param[d:2*d, r:2*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")]
|
||||||
|
param[2*d:3*d, 2*r:3*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")]
|
||||||
|
else:
|
||||||
|
param = torch.concat([
|
||||||
|
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
|
||||||
|
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
|
||||||
|
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
|
||||||
|
], dim=0)
|
||||||
|
state_dict_[name_] = param
|
||||||
|
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
|
||||||
|
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
|
||||||
|
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
|
||||||
|
return state_dict_
|
||||||
|
|
||||||
|
|
||||||
|
class LoraMerger(torch.nn.Module):
|
||||||
|
def __init__(self, dim):
|
||||||
|
super().__init__()
|
||||||
|
self.weight_base = torch.nn.Parameter(torch.randn((dim,)))
|
||||||
|
self.weight_lora = torch.nn.Parameter(torch.randn((dim,)))
|
||||||
|
self.weight_cross = torch.nn.Parameter(torch.randn((dim,)))
|
||||||
|
self.weight_out = torch.nn.Parameter(torch.ones((dim,)))
|
||||||
|
self.bias = torch.nn.Parameter(torch.randn((dim,)))
|
||||||
|
self.activation = torch.nn.Sigmoid()
|
||||||
|
self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5)
|
||||||
|
self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5)
|
||||||
|
|
||||||
|
def forward(self, base_output, lora_outputs):
|
||||||
|
norm_base_output = self.norm_base(base_output)
|
||||||
|
norm_lora_outputs = self.norm_lora(lora_outputs)
|
||||||
|
gate = self.activation(
|
||||||
|
norm_base_output * self.weight_base \
|
||||||
|
+ norm_lora_outputs * self.weight_lora \
|
||||||
|
+ norm_base_output * norm_lora_outputs * self.weight_cross + self.bias
|
||||||
|
)
|
||||||
|
output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0)
|
||||||
|
return output
|
||||||
|
|
||||||
|
class FluxLoraPatcher(torch.nn.Module):
|
||||||
|
def __init__(self, lora_patterns=None):
|
||||||
|
super().__init__()
|
||||||
|
if lora_patterns is None:
|
||||||
|
lora_patterns = self.default_lora_patterns()
|
||||||
|
model_dict = {}
|
||||||
|
for lora_pattern in lora_patterns:
|
||||||
|
name, dim = lora_pattern["name"], lora_pattern["dim"]
|
||||||
|
model_dict[name.replace(".", "___")] = LoraMerger(dim)
|
||||||
|
self.model_dict = torch.nn.ModuleDict(model_dict)
|
||||||
|
|
||||||
|
def default_lora_patterns(self):
|
||||||
|
lora_patterns = []
|
||||||
|
lora_dict = {
|
||||||
|
"attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432,
|
||||||
|
"attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432,
|
||||||
|
}
|
||||||
|
for i in range(19):
|
||||||
|
for suffix in lora_dict:
|
||||||
|
lora_patterns.append({
|
||||||
|
"name": f"blocks.{i}.{suffix}",
|
||||||
|
"dim": lora_dict[suffix]
|
||||||
|
})
|
||||||
|
lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216}
|
||||||
|
for i in range(38):
|
||||||
|
for suffix in lora_dict:
|
||||||
|
lora_patterns.append({
|
||||||
|
"name": f"single_blocks.{i}.{suffix}",
|
||||||
|
"dim": lora_dict[suffix]
|
||||||
|
})
|
||||||
|
return lora_patterns
|
||||||
|
|
||||||
|
def forward(self, base_output, lora_outputs, name):
|
||||||
|
return self.model_dict[name.replace(".", "___")](base_output, lora_outputs)
|
||||||
@@ -1,32 +0,0 @@
|
|||||||
import torch
|
|
||||||
from transformers import T5EncoderModel, T5Config
|
|
||||||
from .sd_text_encoder import SDTextEncoder
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class FluxTextEncoder2(T5EncoderModel):
|
|
||||||
def __init__(self, config):
|
|
||||||
super().__init__(config)
|
|
||||||
self.eval()
|
|
||||||
|
|
||||||
def forward(self, input_ids):
|
|
||||||
outputs = super().forward(input_ids=input_ids)
|
|
||||||
prompt_emb = outputs.last_hidden_state
|
|
||||||
return prompt_emb
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def state_dict_converter():
|
|
||||||
return FluxTextEncoder2StateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class FluxTextEncoder2StateDictConverter():
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def from_diffusers(self, state_dict):
|
|
||||||
state_dict_ = state_dict
|
|
||||||
return state_dict_
|
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
|
||||||
return self.from_diffusers(state_dict)
|
|
||||||
112
diffsynth/models/flux_text_encoder_clip.py
Normal file
112
diffsynth/models/flux_text_encoder_clip.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
|
||||||
|
super().__init__()
|
||||||
|
dim_inner = head_dim * num_heads
|
||||||
|
kv_dim = kv_dim if kv_dim is not None else q_dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
|
||||||
|
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
|
||||||
|
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||||
|
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||||
|
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
|
||||||
|
if encoder_hidden_states is None:
|
||||||
|
encoder_hidden_states = hidden_states
|
||||||
|
|
||||||
|
batch_size = encoder_hidden_states.shape[0]
|
||||||
|
|
||||||
|
q = self.to_q(hidden_states)
|
||||||
|
k = self.to_k(encoder_hidden_states)
|
||||||
|
v = self.to_v(encoder_hidden_states)
|
||||||
|
|
||||||
|
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||||
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||||
|
hidden_states = hidden_states.to(q.dtype)
|
||||||
|
|
||||||
|
hidden_states = self.to_out(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPEncoderLayer(torch.nn.Module):
|
||||||
|
def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
|
||||||
|
super().__init__()
|
||||||
|
self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)
|
||||||
|
self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
|
||||||
|
self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
|
||||||
|
self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
|
||||||
|
self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)
|
||||||
|
|
||||||
|
self.use_quick_gelu = use_quick_gelu
|
||||||
|
|
||||||
|
def quickGELU(self, x):
|
||||||
|
return x * torch.sigmoid(1.702 * x)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, attn_mask=None):
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.layer_norm1(hidden_states)
|
||||||
|
hidden_states = self.attn(hidden_states, attn_mask=attn_mask)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.layer_norm2(hidden_states)
|
||||||
|
hidden_states = self.fc1(hidden_states)
|
||||||
|
if self.use_quick_gelu:
|
||||||
|
hidden_states = self.quickGELU(hidden_states)
|
||||||
|
else:
|
||||||
|
hidden_states = torch.nn.functional.gelu(hidden_states)
|
||||||
|
hidden_states = self.fc2(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class FluxTextEncoderClip(torch.nn.Module):
|
||||||
|
def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# token_embedding
|
||||||
|
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
|
||||||
|
|
||||||
|
# position_embeds (This is a fixed tensor)
|
||||||
|
self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
|
||||||
|
|
||||||
|
# encoders
|
||||||
|
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
|
||||||
|
|
||||||
|
# attn_mask
|
||||||
|
self.attn_mask = self.attention_mask(max_position_embeddings)
|
||||||
|
|
||||||
|
# final_layer_norm
|
||||||
|
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
|
||||||
|
|
||||||
|
def attention_mask(self, length):
|
||||||
|
mask = torch.empty(length, length)
|
||||||
|
mask.fill_(float("-inf"))
|
||||||
|
mask.triu_(1)
|
||||||
|
return mask
|
||||||
|
|
||||||
|
def forward(self, input_ids, clip_skip=2, extra_mask=None):
|
||||||
|
embeds = self.token_embedding(input_ids)
|
||||||
|
embeds = embeds + self.position_embeds.to(dtype=embeds.dtype, device=input_ids.device)
|
||||||
|
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
|
||||||
|
if extra_mask is not None:
|
||||||
|
attn_mask[:, extra_mask[0]==0] = float("-inf")
|
||||||
|
for encoder_id, encoder in enumerate(self.encoders):
|
||||||
|
embeds = encoder(embeds, attn_mask=attn_mask)
|
||||||
|
if encoder_id + clip_skip == len(self.encoders):
|
||||||
|
hidden_states = embeds
|
||||||
|
embeds = self.final_layer_norm(embeds)
|
||||||
|
pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
|
||||||
|
return pooled_embeds, hidden_states
|
||||||
43
diffsynth/models/flux_text_encoder_t5.py
Normal file
43
diffsynth/models/flux_text_encoder_t5.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
import torch
|
||||||
|
from transformers import T5EncoderModel, T5Config
|
||||||
|
|
||||||
|
|
||||||
|
class FluxTextEncoderT5(T5EncoderModel):
|
||||||
|
def __init__(self):
|
||||||
|
config = T5Config(**{
|
||||||
|
"architectures": [
|
||||||
|
"T5EncoderModel"
|
||||||
|
],
|
||||||
|
"classifier_dropout": 0.0,
|
||||||
|
"d_ff": 10240,
|
||||||
|
"d_kv": 64,
|
||||||
|
"d_model": 4096,
|
||||||
|
"decoder_start_token_id": 0,
|
||||||
|
"dense_act_fn": "gelu_new",
|
||||||
|
"dropout_rate": 0.1,
|
||||||
|
"dtype": "bfloat16",
|
||||||
|
"eos_token_id": 1,
|
||||||
|
"feed_forward_proj": "gated-gelu",
|
||||||
|
"initializer_factor": 1.0,
|
||||||
|
"is_encoder_decoder": True,
|
||||||
|
"is_gated_act": True,
|
||||||
|
"layer_norm_epsilon": 1e-06,
|
||||||
|
"model_type": "t5",
|
||||||
|
"num_decoder_layers": 24,
|
||||||
|
"num_heads": 64,
|
||||||
|
"num_layers": 24,
|
||||||
|
"output_past": True,
|
||||||
|
"pad_token_id": 0,
|
||||||
|
"relative_attention_max_distance": 128,
|
||||||
|
"relative_attention_num_buckets": 32,
|
||||||
|
"tie_word_embeddings": False,
|
||||||
|
"transformers_version": "4.57.1",
|
||||||
|
"use_cache": True,
|
||||||
|
"vocab_size": 32128
|
||||||
|
})
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
def forward(self, input_ids):
|
||||||
|
outputs = super().forward(input_ids=input_ids)
|
||||||
|
prompt_emb = outputs.last_hidden_state
|
||||||
|
return prompt_emb
|
||||||
@@ -1,303 +1,451 @@
|
|||||||
from .sd3_vae_encoder import SD3VAEEncoder, SDVAEEncoderStateDictConverter
|
import torch
|
||||||
from .sd3_vae_decoder import SD3VAEDecoder, SDVAEDecoderStateDictConverter
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
|
|
||||||
class FluxVAEEncoder(SD3VAEEncoder):
|
class TileWorker:
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.scaling_factor = 0.3611
|
|
||||||
self.shift_factor = 0.1159
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def state_dict_converter():
|
|
||||||
return FluxVAEEncoderStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
class FluxVAEDecoder(SD3VAEDecoder):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.scaling_factor = 0.3611
|
|
||||||
self.shift_factor = 0.1159
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def state_dict_converter():
|
|
||||||
return FluxVAEDecoderStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
class FluxVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
|
||||||
rename_dict = {
|
def mask(self, height, width, border_width):
|
||||||
"encoder.conv_in.bias": "conv_in.bias",
|
# Create a mask with shape (height, width).
|
||||||
"encoder.conv_in.weight": "conv_in.weight",
|
# The centre area is filled with 1, and the border line is filled with values in range (0, 1].
|
||||||
"encoder.conv_out.bias": "conv_out.bias",
|
x = torch.arange(height).repeat(width, 1).T
|
||||||
"encoder.conv_out.weight": "conv_out.weight",
|
y = torch.arange(width).repeat(height, 1)
|
||||||
"encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
|
mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values
|
||||||
"encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
|
mask = (mask / border_width).clip(0, 1)
|
||||||
"encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
|
return mask
|
||||||
"encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
|
|
||||||
"encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
|
|
||||||
"encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
|
def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype):
|
||||||
"encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
|
# Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num)
|
||||||
"encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
|
batch_size, channel, _, _ = model_input.shape
|
||||||
"encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
|
model_input = model_input.to(device=tile_device, dtype=tile_dtype)
|
||||||
"encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
|
unfold_operator = torch.nn.Unfold(
|
||||||
"encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
|
kernel_size=(tile_size, tile_size),
|
||||||
"encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
|
stride=(tile_stride, tile_stride)
|
||||||
"encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
|
)
|
||||||
"encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
|
model_input = unfold_operator(model_input)
|
||||||
"encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
|
model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1))
|
||||||
"encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
|
|
||||||
"encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
|
return model_input
|
||||||
"encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
|
|
||||||
"encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
|
|
||||||
"encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
|
def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype):
|
||||||
"encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
|
# Call y=forward_fn(x) for each tile
|
||||||
"encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
|
tile_num = model_input.shape[-1]
|
||||||
"encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
|
model_output_stack = []
|
||||||
"encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
|
|
||||||
"encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
|
for tile_id in range(0, tile_num, tile_batch_size):
|
||||||
"encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
|
|
||||||
"encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
|
# process input
|
||||||
"encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
|
tile_id_ = min(tile_id + tile_batch_size, tile_num)
|
||||||
"encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
|
x = model_input[:, :, :, :, tile_id: tile_id_]
|
||||||
"encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
|
x = x.to(device=inference_device, dtype=inference_dtype)
|
||||||
"encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
|
x = rearrange(x, "b c h w n -> (n b) c h w")
|
||||||
"encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
|
|
||||||
"encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
|
# process output
|
||||||
"encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
|
y = forward_fn(x)
|
||||||
"encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
|
y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_-tile_id)
|
||||||
"encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
|
y = y.to(device=tile_device, dtype=tile_dtype)
|
||||||
"encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
|
model_output_stack.append(y)
|
||||||
"encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
|
|
||||||
"encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
|
model_output = torch.concat(model_output_stack, dim=-1)
|
||||||
"encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
|
return model_output
|
||||||
"encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
|
|
||||||
"encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
|
|
||||||
"encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
|
def io_scale(self, model_output, tile_size):
|
||||||
"encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
|
# Determine the size modification happened in forward_fn
|
||||||
"encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
|
# We only consider the same scale on height and width.
|
||||||
"encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
|
io_scale = model_output.shape[2] / tile_size
|
||||||
"encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
|
return io_scale
|
||||||
"encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
|
|
||||||
"encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
|
|
||||||
"encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
|
|
||||||
"encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
|
|
||||||
"encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
|
|
||||||
"encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
|
|
||||||
"encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
|
|
||||||
"encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
|
|
||||||
"encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
|
|
||||||
"encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
|
|
||||||
"encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
|
|
||||||
"encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
|
|
||||||
"encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
|
|
||||||
"encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
|
|
||||||
"encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
|
|
||||||
"encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
|
|
||||||
"encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
|
|
||||||
"encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
|
|
||||||
"encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
|
|
||||||
"encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
|
|
||||||
"encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
|
|
||||||
"encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
|
|
||||||
"encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
|
|
||||||
"encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
|
|
||||||
"encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
|
|
||||||
"encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
|
|
||||||
"encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
|
|
||||||
"encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
|
|
||||||
"encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
|
|
||||||
"encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
|
|
||||||
"encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
|
|
||||||
"encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
|
|
||||||
"encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
|
|
||||||
"encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
|
|
||||||
"encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
|
|
||||||
"encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
|
|
||||||
"encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
|
|
||||||
"encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
|
|
||||||
"encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
|
|
||||||
"encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
|
|
||||||
"encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
|
|
||||||
"encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
|
|
||||||
"encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
|
|
||||||
"encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
|
|
||||||
"encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
|
|
||||||
"encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
|
|
||||||
"encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
|
|
||||||
"encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
|
|
||||||
"encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
|
|
||||||
"encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
|
|
||||||
"encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
|
|
||||||
"encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
|
|
||||||
"encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
|
|
||||||
"encoder.norm_out.bias": "conv_norm_out.bias",
|
|
||||||
"encoder.norm_out.weight": "conv_norm_out.weight",
|
|
||||||
}
|
|
||||||
state_dict_ = {}
|
|
||||||
for name in state_dict:
|
|
||||||
if name in rename_dict:
|
|
||||||
param = state_dict[name]
|
|
||||||
if "transformer_blocks" in rename_dict[name]:
|
|
||||||
param = param.squeeze()
|
|
||||||
state_dict_[rename_dict[name]] = param
|
|
||||||
return state_dict_
|
|
||||||
|
|
||||||
|
|
||||||
|
def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype):
|
||||||
|
# The reversed function of tile
|
||||||
|
mask = self.mask(tile_size, tile_size, border_width)
|
||||||
|
mask = mask.to(device=tile_device, dtype=tile_dtype)
|
||||||
|
mask = rearrange(mask, "h w -> 1 1 h w 1")
|
||||||
|
model_output = model_output * mask
|
||||||
|
|
||||||
class FluxVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
|
fold_operator = torch.nn.Fold(
|
||||||
def __init__(self):
|
output_size=(height, width),
|
||||||
pass
|
kernel_size=(tile_size, tile_size),
|
||||||
|
stride=(tile_stride, tile_stride)
|
||||||
|
)
|
||||||
|
mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1])
|
||||||
|
model_output = rearrange(model_output, "b c h w n -> b (c h w) n")
|
||||||
|
model_output = fold_operator(model_output) / fold_operator(mask)
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
return model_output
|
||||||
rename_dict = {
|
|
||||||
"decoder.conv_in.bias": "conv_in.bias",
|
|
||||||
"decoder.conv_in.weight": "conv_in.weight",
|
def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
|
||||||
"decoder.conv_out.bias": "conv_out.bias",
|
# Prepare
|
||||||
"decoder.conv_out.weight": "conv_out.weight",
|
inference_device, inference_dtype = model_input.device, model_input.dtype
|
||||||
"decoder.mid.attn_1.k.bias": "blocks.1.transformer_blocks.0.to_k.bias",
|
height, width = model_input.shape[2], model_input.shape[3]
|
||||||
"decoder.mid.attn_1.k.weight": "blocks.1.transformer_blocks.0.to_k.weight",
|
border_width = int(tile_stride*0.5) if border_width is None else border_width
|
||||||
"decoder.mid.attn_1.norm.bias": "blocks.1.norm.bias",
|
|
||||||
"decoder.mid.attn_1.norm.weight": "blocks.1.norm.weight",
|
# tile
|
||||||
"decoder.mid.attn_1.proj_out.bias": "blocks.1.transformer_blocks.0.to_out.bias",
|
model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype)
|
||||||
"decoder.mid.attn_1.proj_out.weight": "blocks.1.transformer_blocks.0.to_out.weight",
|
|
||||||
"decoder.mid.attn_1.q.bias": "blocks.1.transformer_blocks.0.to_q.bias",
|
# inference
|
||||||
"decoder.mid.attn_1.q.weight": "blocks.1.transformer_blocks.0.to_q.weight",
|
model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype)
|
||||||
"decoder.mid.attn_1.v.bias": "blocks.1.transformer_blocks.0.to_v.bias",
|
|
||||||
"decoder.mid.attn_1.v.weight": "blocks.1.transformer_blocks.0.to_v.weight",
|
# resize
|
||||||
"decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
|
io_scale = self.io_scale(model_output, tile_size)
|
||||||
"decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
|
height, width = int(height*io_scale), int(width*io_scale)
|
||||||
"decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
|
tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale)
|
||||||
"decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
|
border_width = int(border_width*io_scale)
|
||||||
"decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
|
|
||||||
"decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
|
# untile
|
||||||
"decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
|
model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype)
|
||||||
"decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
|
|
||||||
"decoder.mid.block_2.conv1.bias": "blocks.2.conv1.bias",
|
# Done!
|
||||||
"decoder.mid.block_2.conv1.weight": "blocks.2.conv1.weight",
|
model_output = model_output.to(device=inference_device, dtype=inference_dtype)
|
||||||
"decoder.mid.block_2.conv2.bias": "blocks.2.conv2.bias",
|
return model_output
|
||||||
"decoder.mid.block_2.conv2.weight": "blocks.2.conv2.weight",
|
|
||||||
"decoder.mid.block_2.norm1.bias": "blocks.2.norm1.bias",
|
|
||||||
"decoder.mid.block_2.norm1.weight": "blocks.2.norm1.weight",
|
class ConvAttention(torch.nn.Module):
|
||||||
"decoder.mid.block_2.norm2.bias": "blocks.2.norm2.bias",
|
|
||||||
"decoder.mid.block_2.norm2.weight": "blocks.2.norm2.weight",
|
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
|
||||||
"decoder.norm_out.bias": "conv_norm_out.bias",
|
super().__init__()
|
||||||
"decoder.norm_out.weight": "conv_norm_out.weight",
|
dim_inner = head_dim * num_heads
|
||||||
"decoder.up.0.block.0.conv1.bias": "blocks.15.conv1.bias",
|
kv_dim = kv_dim if kv_dim is not None else q_dim
|
||||||
"decoder.up.0.block.0.conv1.weight": "blocks.15.conv1.weight",
|
self.num_heads = num_heads
|
||||||
"decoder.up.0.block.0.conv2.bias": "blocks.15.conv2.bias",
|
self.head_dim = head_dim
|
||||||
"decoder.up.0.block.0.conv2.weight": "blocks.15.conv2.weight",
|
|
||||||
"decoder.up.0.block.0.nin_shortcut.bias": "blocks.15.conv_shortcut.bias",
|
self.to_q = torch.nn.Conv2d(q_dim, dim_inner, kernel_size=(1, 1), bias=bias_q)
|
||||||
"decoder.up.0.block.0.nin_shortcut.weight": "blocks.15.conv_shortcut.weight",
|
self.to_k = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv)
|
||||||
"decoder.up.0.block.0.norm1.bias": "blocks.15.norm1.bias",
|
self.to_v = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv)
|
||||||
"decoder.up.0.block.0.norm1.weight": "blocks.15.norm1.weight",
|
self.to_out = torch.nn.Conv2d(dim_inner, q_dim, kernel_size=(1, 1), bias=bias_out)
|
||||||
"decoder.up.0.block.0.norm2.bias": "blocks.15.norm2.bias",
|
|
||||||
"decoder.up.0.block.0.norm2.weight": "blocks.15.norm2.weight",
|
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
|
||||||
"decoder.up.0.block.1.conv1.bias": "blocks.16.conv1.bias",
|
if encoder_hidden_states is None:
|
||||||
"decoder.up.0.block.1.conv1.weight": "blocks.16.conv1.weight",
|
encoder_hidden_states = hidden_states
|
||||||
"decoder.up.0.block.1.conv2.bias": "blocks.16.conv2.bias",
|
|
||||||
"decoder.up.0.block.1.conv2.weight": "blocks.16.conv2.weight",
|
batch_size = encoder_hidden_states.shape[0]
|
||||||
"decoder.up.0.block.1.norm1.bias": "blocks.16.norm1.bias",
|
|
||||||
"decoder.up.0.block.1.norm1.weight": "blocks.16.norm1.weight",
|
conv_input = rearrange(hidden_states, "B L C -> B C L 1")
|
||||||
"decoder.up.0.block.1.norm2.bias": "blocks.16.norm2.bias",
|
q = self.to_q(conv_input)
|
||||||
"decoder.up.0.block.1.norm2.weight": "blocks.16.norm2.weight",
|
q = rearrange(q[:, :, :, 0], "B C L -> B L C")
|
||||||
"decoder.up.0.block.2.conv1.bias": "blocks.17.conv1.bias",
|
conv_input = rearrange(encoder_hidden_states, "B L C -> B C L 1")
|
||||||
"decoder.up.0.block.2.conv1.weight": "blocks.17.conv1.weight",
|
k = self.to_k(conv_input)
|
||||||
"decoder.up.0.block.2.conv2.bias": "blocks.17.conv2.bias",
|
v = self.to_v(conv_input)
|
||||||
"decoder.up.0.block.2.conv2.weight": "blocks.17.conv2.weight",
|
k = rearrange(k[:, :, :, 0], "B C L -> B L C")
|
||||||
"decoder.up.0.block.2.norm1.bias": "blocks.17.norm1.bias",
|
v = rearrange(v[:, :, :, 0], "B C L -> B L C")
|
||||||
"decoder.up.0.block.2.norm1.weight": "blocks.17.norm1.weight",
|
|
||||||
"decoder.up.0.block.2.norm2.bias": "blocks.17.norm2.bias",
|
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
"decoder.up.0.block.2.norm2.weight": "blocks.17.norm2.weight",
|
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
"decoder.up.1.block.0.conv1.bias": "blocks.11.conv1.bias",
|
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
"decoder.up.1.block.0.conv1.weight": "blocks.11.conv1.weight",
|
|
||||||
"decoder.up.1.block.0.conv2.bias": "blocks.11.conv2.bias",
|
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||||
"decoder.up.1.block.0.conv2.weight": "blocks.11.conv2.weight",
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||||
"decoder.up.1.block.0.nin_shortcut.bias": "blocks.11.conv_shortcut.bias",
|
hidden_states = hidden_states.to(q.dtype)
|
||||||
"decoder.up.1.block.0.nin_shortcut.weight": "blocks.11.conv_shortcut.weight",
|
|
||||||
"decoder.up.1.block.0.norm1.bias": "blocks.11.norm1.bias",
|
conv_input = rearrange(hidden_states, "B L C -> B C L 1")
|
||||||
"decoder.up.1.block.0.norm1.weight": "blocks.11.norm1.weight",
|
hidden_states = self.to_out(conv_input)
|
||||||
"decoder.up.1.block.0.norm2.bias": "blocks.11.norm2.bias",
|
hidden_states = rearrange(hidden_states[:, :, :, 0], "B C L -> B L C")
|
||||||
"decoder.up.1.block.0.norm2.weight": "blocks.11.norm2.weight",
|
|
||||||
"decoder.up.1.block.1.conv1.bias": "blocks.12.conv1.bias",
|
return hidden_states
|
||||||
"decoder.up.1.block.1.conv1.weight": "blocks.12.conv1.weight",
|
|
||||||
"decoder.up.1.block.1.conv2.bias": "blocks.12.conv2.bias",
|
|
||||||
"decoder.up.1.block.1.conv2.weight": "blocks.12.conv2.weight",
|
class Attention(torch.nn.Module):
|
||||||
"decoder.up.1.block.1.norm1.bias": "blocks.12.norm1.bias",
|
|
||||||
"decoder.up.1.block.1.norm1.weight": "blocks.12.norm1.weight",
|
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
|
||||||
"decoder.up.1.block.1.norm2.bias": "blocks.12.norm2.bias",
|
super().__init__()
|
||||||
"decoder.up.1.block.1.norm2.weight": "blocks.12.norm2.weight",
|
dim_inner = head_dim * num_heads
|
||||||
"decoder.up.1.block.2.conv1.bias": "blocks.13.conv1.bias",
|
kv_dim = kv_dim if kv_dim is not None else q_dim
|
||||||
"decoder.up.1.block.2.conv1.weight": "blocks.13.conv1.weight",
|
self.num_heads = num_heads
|
||||||
"decoder.up.1.block.2.conv2.bias": "blocks.13.conv2.bias",
|
self.head_dim = head_dim
|
||||||
"decoder.up.1.block.2.conv2.weight": "blocks.13.conv2.weight",
|
|
||||||
"decoder.up.1.block.2.norm1.bias": "blocks.13.norm1.bias",
|
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
|
||||||
"decoder.up.1.block.2.norm1.weight": "blocks.13.norm1.weight",
|
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||||
"decoder.up.1.block.2.norm2.bias": "blocks.13.norm2.bias",
|
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||||
"decoder.up.1.block.2.norm2.weight": "blocks.13.norm2.weight",
|
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
|
||||||
"decoder.up.1.upsample.conv.bias": "blocks.14.conv.bias",
|
|
||||||
"decoder.up.1.upsample.conv.weight": "blocks.14.conv.weight",
|
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
|
||||||
"decoder.up.2.block.0.conv1.bias": "blocks.7.conv1.bias",
|
if encoder_hidden_states is None:
|
||||||
"decoder.up.2.block.0.conv1.weight": "blocks.7.conv1.weight",
|
encoder_hidden_states = hidden_states
|
||||||
"decoder.up.2.block.0.conv2.bias": "blocks.7.conv2.bias",
|
|
||||||
"decoder.up.2.block.0.conv2.weight": "blocks.7.conv2.weight",
|
batch_size = encoder_hidden_states.shape[0]
|
||||||
"decoder.up.2.block.0.norm1.bias": "blocks.7.norm1.bias",
|
|
||||||
"decoder.up.2.block.0.norm1.weight": "blocks.7.norm1.weight",
|
q = self.to_q(hidden_states)
|
||||||
"decoder.up.2.block.0.norm2.bias": "blocks.7.norm2.bias",
|
k = self.to_k(encoder_hidden_states)
|
||||||
"decoder.up.2.block.0.norm2.weight": "blocks.7.norm2.weight",
|
v = self.to_v(encoder_hidden_states)
|
||||||
"decoder.up.2.block.1.conv1.bias": "blocks.8.conv1.bias",
|
|
||||||
"decoder.up.2.block.1.conv1.weight": "blocks.8.conv1.weight",
|
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
"decoder.up.2.block.1.conv2.bias": "blocks.8.conv2.bias",
|
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
"decoder.up.2.block.1.conv2.weight": "blocks.8.conv2.weight",
|
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
"decoder.up.2.block.1.norm1.bias": "blocks.8.norm1.bias",
|
|
||||||
"decoder.up.2.block.1.norm1.weight": "blocks.8.norm1.weight",
|
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||||
"decoder.up.2.block.1.norm2.bias": "blocks.8.norm2.bias",
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||||
"decoder.up.2.block.1.norm2.weight": "blocks.8.norm2.weight",
|
hidden_states = hidden_states.to(q.dtype)
|
||||||
"decoder.up.2.block.2.conv1.bias": "blocks.9.conv1.bias",
|
|
||||||
"decoder.up.2.block.2.conv1.weight": "blocks.9.conv1.weight",
|
hidden_states = self.to_out(hidden_states)
|
||||||
"decoder.up.2.block.2.conv2.bias": "blocks.9.conv2.bias",
|
|
||||||
"decoder.up.2.block.2.conv2.weight": "blocks.9.conv2.weight",
|
return hidden_states
|
||||||
"decoder.up.2.block.2.norm1.bias": "blocks.9.norm1.bias",
|
|
||||||
"decoder.up.2.block.2.norm1.weight": "blocks.9.norm1.weight",
|
|
||||||
"decoder.up.2.block.2.norm2.bias": "blocks.9.norm2.bias",
|
class VAEAttentionBlock(torch.nn.Module):
|
||||||
"decoder.up.2.block.2.norm2.weight": "blocks.9.norm2.weight",
|
|
||||||
"decoder.up.2.upsample.conv.bias": "blocks.10.conv.bias",
|
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5, use_conv_attention=True):
|
||||||
"decoder.up.2.upsample.conv.weight": "blocks.10.conv.weight",
|
super().__init__()
|
||||||
"decoder.up.3.block.0.conv1.bias": "blocks.3.conv1.bias",
|
inner_dim = num_attention_heads * attention_head_dim
|
||||||
"decoder.up.3.block.0.conv1.weight": "blocks.3.conv1.weight",
|
|
||||||
"decoder.up.3.block.0.conv2.bias": "blocks.3.conv2.bias",
|
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
|
||||||
"decoder.up.3.block.0.conv2.weight": "blocks.3.conv2.weight",
|
|
||||||
"decoder.up.3.block.0.norm1.bias": "blocks.3.norm1.bias",
|
if use_conv_attention:
|
||||||
"decoder.up.3.block.0.norm1.weight": "blocks.3.norm1.weight",
|
self.transformer_blocks = torch.nn.ModuleList([
|
||||||
"decoder.up.3.block.0.norm2.bias": "blocks.3.norm2.bias",
|
ConvAttention(
|
||||||
"decoder.up.3.block.0.norm2.weight": "blocks.3.norm2.weight",
|
inner_dim,
|
||||||
"decoder.up.3.block.1.conv1.bias": "blocks.4.conv1.bias",
|
num_attention_heads,
|
||||||
"decoder.up.3.block.1.conv1.weight": "blocks.4.conv1.weight",
|
attention_head_dim,
|
||||||
"decoder.up.3.block.1.conv2.bias": "blocks.4.conv2.bias",
|
bias_q=True,
|
||||||
"decoder.up.3.block.1.conv2.weight": "blocks.4.conv2.weight",
|
bias_kv=True,
|
||||||
"decoder.up.3.block.1.norm1.bias": "blocks.4.norm1.bias",
|
bias_out=True
|
||||||
"decoder.up.3.block.1.norm1.weight": "blocks.4.norm1.weight",
|
)
|
||||||
"decoder.up.3.block.1.norm2.bias": "blocks.4.norm2.bias",
|
for d in range(num_layers)
|
||||||
"decoder.up.3.block.1.norm2.weight": "blocks.4.norm2.weight",
|
])
|
||||||
"decoder.up.3.block.2.conv1.bias": "blocks.5.conv1.bias",
|
else:
|
||||||
"decoder.up.3.block.2.conv1.weight": "blocks.5.conv1.weight",
|
self.transformer_blocks = torch.nn.ModuleList([
|
||||||
"decoder.up.3.block.2.conv2.bias": "blocks.5.conv2.bias",
|
Attention(
|
||||||
"decoder.up.3.block.2.conv2.weight": "blocks.5.conv2.weight",
|
inner_dim,
|
||||||
"decoder.up.3.block.2.norm1.bias": "blocks.5.norm1.bias",
|
num_attention_heads,
|
||||||
"decoder.up.3.block.2.norm1.weight": "blocks.5.norm1.weight",
|
attention_head_dim,
|
||||||
"decoder.up.3.block.2.norm2.bias": "blocks.5.norm2.bias",
|
bias_q=True,
|
||||||
"decoder.up.3.block.2.norm2.weight": "blocks.5.norm2.weight",
|
bias_kv=True,
|
||||||
"decoder.up.3.upsample.conv.bias": "blocks.6.conv.bias",
|
bias_out=True
|
||||||
"decoder.up.3.upsample.conv.weight": "blocks.6.conv.weight",
|
)
|
||||||
}
|
for d in range(num_layers)
|
||||||
state_dict_ = {}
|
])
|
||||||
for name in state_dict:
|
|
||||||
if name in rename_dict:
|
def forward(self, hidden_states, time_emb, text_emb, res_stack):
|
||||||
param = state_dict[name]
|
batch, _, height, width = hidden_states.shape
|
||||||
if "transformer_blocks" in rename_dict[name]:
|
residual = hidden_states
|
||||||
param = param.squeeze()
|
|
||||||
state_dict_[rename_dict[name]] = param
|
hidden_states = self.norm(hidden_states)
|
||||||
return state_dict_
|
inner_dim = hidden_states.shape[1]
|
||||||
|
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||||
|
|
||||||
|
for block in self.transformer_blocks:
|
||||||
|
hidden_states = block(hidden_states)
|
||||||
|
|
||||||
|
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||||
|
hidden_states = hidden_states + residual
|
||||||
|
|
||||||
|
return hidden_states, time_emb, text_emb, res_stack
|
||||||
|
|
||||||
|
|
||||||
|
class ResnetBlock(torch.nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, temb_channels=None, groups=32, eps=1e-5):
|
||||||
|
super().__init__()
|
||||||
|
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
||||||
|
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||||
|
if temb_channels is not None:
|
||||||
|
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||||
|
self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
|
||||||
|
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.nonlinearity = torch.nn.SiLU()
|
||||||
|
self.conv_shortcut = None
|
||||||
|
if in_channels != out_channels:
|
||||||
|
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
|
||||||
|
x = hidden_states
|
||||||
|
x = self.norm1(x)
|
||||||
|
x = self.nonlinearity(x)
|
||||||
|
x = self.conv1(x)
|
||||||
|
if time_emb is not None:
|
||||||
|
emb = self.nonlinearity(time_emb)
|
||||||
|
emb = self.time_emb_proj(emb)[:, :, None, None]
|
||||||
|
x = x + emb
|
||||||
|
x = self.norm2(x)
|
||||||
|
x = self.nonlinearity(x)
|
||||||
|
x = self.conv2(x)
|
||||||
|
if self.conv_shortcut is not None:
|
||||||
|
hidden_states = self.conv_shortcut(hidden_states)
|
||||||
|
hidden_states = hidden_states + x
|
||||||
|
return hidden_states, time_emb, text_emb, res_stack
|
||||||
|
|
||||||
|
|
||||||
|
class UpSampler(torch.nn.Module):
|
||||||
|
def __init__(self, channels):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = torch.nn.Conv2d(channels, channels, 3, padding=1)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
|
||||||
|
hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
||||||
|
hidden_states = self.conv(hidden_states)
|
||||||
|
return hidden_states, time_emb, text_emb, res_stack
|
||||||
|
|
||||||
|
|
||||||
|
class DownSampler(torch.nn.Module):
|
||||||
|
def __init__(self, channels, padding=1, extra_padding=False):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = torch.nn.Conv2d(channels, channels, 3, stride=2, padding=padding)
|
||||||
|
self.extra_padding = extra_padding
|
||||||
|
|
||||||
|
def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
|
||||||
|
if self.extra_padding:
|
||||||
|
hidden_states = torch.nn.functional.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
|
||||||
|
hidden_states = self.conv(hidden_states)
|
||||||
|
return hidden_states, time_emb, text_emb, res_stack
|
||||||
|
|
||||||
|
|
||||||
|
class FluxVAEDecoder(torch.nn.Module):
|
||||||
|
def __init__(self, use_conv_attention=True):
|
||||||
|
super().__init__()
|
||||||
|
self.scaling_factor = 0.3611
|
||||||
|
self.shift_factor = 0.1159
|
||||||
|
self.conv_in = torch.nn.Conv2d(16, 512, kernel_size=3, padding=1) # Different from SD 1.x
|
||||||
|
|
||||||
|
self.blocks = torch.nn.ModuleList([
|
||||||
|
# UNetMidBlock2D
|
||||||
|
ResnetBlock(512, 512, eps=1e-6),
|
||||||
|
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention),
|
||||||
|
ResnetBlock(512, 512, eps=1e-6),
|
||||||
|
# UpDecoderBlock2D
|
||||||
|
ResnetBlock(512, 512, eps=1e-6),
|
||||||
|
ResnetBlock(512, 512, eps=1e-6),
|
||||||
|
ResnetBlock(512, 512, eps=1e-6),
|
||||||
|
UpSampler(512),
|
||||||
|
# UpDecoderBlock2D
|
||||||
|
ResnetBlock(512, 512, eps=1e-6),
|
||||||
|
ResnetBlock(512, 512, eps=1e-6),
|
||||||
|
ResnetBlock(512, 512, eps=1e-6),
|
||||||
|
UpSampler(512),
|
||||||
|
# UpDecoderBlock2D
|
||||||
|
ResnetBlock(512, 256, eps=1e-6),
|
||||||
|
ResnetBlock(256, 256, eps=1e-6),
|
||||||
|
ResnetBlock(256, 256, eps=1e-6),
|
||||||
|
UpSampler(256),
|
||||||
|
# UpDecoderBlock2D
|
||||||
|
ResnetBlock(256, 128, eps=1e-6),
|
||||||
|
ResnetBlock(128, 128, eps=1e-6),
|
||||||
|
ResnetBlock(128, 128, eps=1e-6),
|
||||||
|
])
|
||||||
|
|
||||||
|
self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6)
|
||||||
|
self.conv_act = torch.nn.SiLU()
|
||||||
|
self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
|
||||||
|
|
||||||
|
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
|
||||||
|
hidden_states = TileWorker().tiled_forward(
|
||||||
|
lambda x: self.forward(x),
|
||||||
|
sample,
|
||||||
|
tile_size,
|
||||||
|
tile_stride,
|
||||||
|
tile_device=sample.device,
|
||||||
|
tile_dtype=sample.dtype
|
||||||
|
)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
||||||
|
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
||||||
|
if tiled:
|
||||||
|
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
||||||
|
|
||||||
|
# 1. pre-process
|
||||||
|
hidden_states = sample / self.scaling_factor + self.shift_factor
|
||||||
|
hidden_states = self.conv_in(hidden_states)
|
||||||
|
time_emb = None
|
||||||
|
text_emb = None
|
||||||
|
res_stack = None
|
||||||
|
|
||||||
|
# 2. blocks
|
||||||
|
for i, block in enumerate(self.blocks):
|
||||||
|
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||||
|
|
||||||
|
# 3. output
|
||||||
|
hidden_states = self.conv_norm_out(hidden_states)
|
||||||
|
hidden_states = self.conv_act(hidden_states)
|
||||||
|
hidden_states = self.conv_out(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class FluxVAEEncoder(torch.nn.Module):
|
||||||
|
def __init__(self, use_conv_attention=True):
|
||||||
|
super().__init__()
|
||||||
|
self.scaling_factor = 0.3611
|
||||||
|
self.shift_factor = 0.1159
|
||||||
|
self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)
|
||||||
|
|
||||||
|
self.blocks = torch.nn.ModuleList([
|
||||||
|
# DownEncoderBlock2D
|
||||||
|
ResnetBlock(128, 128, eps=1e-6),
|
||||||
|
ResnetBlock(128, 128, eps=1e-6),
|
||||||
|
DownSampler(128, padding=0, extra_padding=True),
|
||||||
|
# DownEncoderBlock2D
|
||||||
|
ResnetBlock(128, 256, eps=1e-6),
|
||||||
|
ResnetBlock(256, 256, eps=1e-6),
|
||||||
|
DownSampler(256, padding=0, extra_padding=True),
|
||||||
|
# DownEncoderBlock2D
|
||||||
|
ResnetBlock(256, 512, eps=1e-6),
|
||||||
|
ResnetBlock(512, 512, eps=1e-6),
|
||||||
|
DownSampler(512, padding=0, extra_padding=True),
|
||||||
|
# DownEncoderBlock2D
|
||||||
|
ResnetBlock(512, 512, eps=1e-6),
|
||||||
|
ResnetBlock(512, 512, eps=1e-6),
|
||||||
|
# UNetMidBlock2D
|
||||||
|
ResnetBlock(512, 512, eps=1e-6),
|
||||||
|
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention),
|
||||||
|
ResnetBlock(512, 512, eps=1e-6),
|
||||||
|
])
|
||||||
|
|
||||||
|
self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
|
||||||
|
self.conv_act = torch.nn.SiLU()
|
||||||
|
self.conv_out = torch.nn.Conv2d(512, 32, kernel_size=3, padding=1)
|
||||||
|
|
||||||
|
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
|
||||||
|
hidden_states = TileWorker().tiled_forward(
|
||||||
|
lambda x: self.forward(x),
|
||||||
|
sample,
|
||||||
|
tile_size,
|
||||||
|
tile_stride,
|
||||||
|
tile_device=sample.device,
|
||||||
|
tile_dtype=sample.dtype
|
||||||
|
)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
||||||
|
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
||||||
|
if tiled:
|
||||||
|
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
||||||
|
|
||||||
|
# 1. pre-process
|
||||||
|
hidden_states = self.conv_in(sample)
|
||||||
|
time_emb = None
|
||||||
|
text_emb = None
|
||||||
|
res_stack = None
|
||||||
|
|
||||||
|
# 2. blocks
|
||||||
|
for i, block in enumerate(self.blocks):
|
||||||
|
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||||
|
|
||||||
|
# 3. output
|
||||||
|
hidden_states = self.conv_norm_out(hidden_states)
|
||||||
|
hidden_states = self.conv_act(hidden_states)
|
||||||
|
hidden_states = self.conv_out(hidden_states)
|
||||||
|
hidden_states = hidden_states[:, :16]
|
||||||
|
hidden_states = (hidden_states - self.shift_factor) * self.scaling_factor
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def encode_video(self, sample, batch_size=8):
|
||||||
|
B = sample.shape[0]
|
||||||
|
hidden_states = []
|
||||||
|
|
||||||
|
for i in range(0, sample.shape[2], batch_size):
|
||||||
|
|
||||||
|
j = min(i + batch_size, sample.shape[2])
|
||||||
|
sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W")
|
||||||
|
|
||||||
|
hidden_states_batch = self(sample_batch)
|
||||||
|
hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)
|
||||||
|
|
||||||
|
hidden_states.append(hidden_states_batch)
|
||||||
|
|
||||||
|
hidden_states = torch.concat(hidden_states, dim=2)
|
||||||
|
return hidden_states
|
||||||
|
|||||||
56
diffsynth/models/flux_value_control.py
Normal file
56
diffsynth/models/flux_value_control.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
import torch
|
||||||
|
from .general_modules import TemporalTimesteps
|
||||||
|
|
||||||
|
|
||||||
|
class MultiValueEncoder(torch.nn.Module):
|
||||||
|
def __init__(self, encoders=()):
|
||||||
|
super().__init__()
|
||||||
|
if not isinstance(encoders, list):
|
||||||
|
encoders = [encoders]
|
||||||
|
self.encoders = torch.nn.ModuleList(encoders)
|
||||||
|
|
||||||
|
def __call__(self, values, dtype):
|
||||||
|
emb = []
|
||||||
|
for encoder, value in zip(self.encoders, values):
|
||||||
|
if value is not None:
|
||||||
|
value = value.unsqueeze(0)
|
||||||
|
emb.append(encoder(value, dtype))
|
||||||
|
emb = torch.concat(emb, dim=0)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
class SingleValueEncoder(torch.nn.Module):
|
||||||
|
def __init__(self, dim_in=256, dim_out=4096, prefer_len=32, computation_device=None):
|
||||||
|
super().__init__()
|
||||||
|
self.prefer_len = prefer_len
|
||||||
|
self.prefer_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device)
|
||||||
|
self.prefer_value_embedder = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
|
||||||
|
)
|
||||||
|
self.positional_embedding = torch.nn.Parameter(
|
||||||
|
torch.randn(self.prefer_len, dim_out)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, value, dtype):
|
||||||
|
value = value * 1000
|
||||||
|
emb = self.prefer_proj(value).to(dtype)
|
||||||
|
emb = self.prefer_value_embedder(emb).squeeze(0)
|
||||||
|
base_embeddings = emb.expand(self.prefer_len, -1)
|
||||||
|
positional_embedding = self.positional_embedding.to(dtype=base_embeddings.dtype, device=base_embeddings.device)
|
||||||
|
learned_embeddings = base_embeddings + positional_embedding
|
||||||
|
return learned_embeddings
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return SingleValueEncoderStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
class SingleValueEncoderStateDictConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_diffusers(self, state_dict):
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
return state_dict
|
||||||
146
diffsynth/models/general_modules.py
Normal file
146
diffsynth/models/general_modules.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
import torch, math
|
||||||
|
|
||||||
|
|
||||||
|
def get_timestep_embedding(
|
||||||
|
timesteps: torch.Tensor,
|
||||||
|
embedding_dim: int,
|
||||||
|
flip_sin_to_cos: bool = False,
|
||||||
|
downscale_freq_shift: float = 1,
|
||||||
|
scale: float = 1,
|
||||||
|
max_period: int = 10000,
|
||||||
|
computation_device = None,
|
||||||
|
align_dtype_to_timestep = False,
|
||||||
|
):
|
||||||
|
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
||||||
|
|
||||||
|
half_dim = embedding_dim // 2
|
||||||
|
exponent = -math.log(max_period) * torch.arange(
|
||||||
|
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device if computation_device is None else computation_device
|
||||||
|
)
|
||||||
|
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||||
|
|
||||||
|
emb = torch.exp(exponent)
|
||||||
|
if align_dtype_to_timestep:
|
||||||
|
emb = emb.to(timesteps.dtype)
|
||||||
|
emb = timesteps[:, None].float() * emb[None, :]
|
||||||
|
|
||||||
|
# scale embeddings
|
||||||
|
emb = scale * emb
|
||||||
|
|
||||||
|
# concat sine and cosine embeddings
|
||||||
|
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
||||||
|
|
||||||
|
# flip sine and cosine embeddings
|
||||||
|
if flip_sin_to_cos:
|
||||||
|
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
||||||
|
|
||||||
|
# zero pad
|
||||||
|
if embedding_dim % 2 == 1:
|
||||||
|
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
class TemporalTimesteps(torch.nn.Module):
|
||||||
|
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, computation_device = None, scale=1, align_dtype_to_timestep=False):
|
||||||
|
super().__init__()
|
||||||
|
self.num_channels = num_channels
|
||||||
|
self.flip_sin_to_cos = flip_sin_to_cos
|
||||||
|
self.downscale_freq_shift = downscale_freq_shift
|
||||||
|
self.computation_device = computation_device
|
||||||
|
self.scale = scale
|
||||||
|
self.align_dtype_to_timestep = align_dtype_to_timestep
|
||||||
|
|
||||||
|
def forward(self, timesteps):
|
||||||
|
t_emb = get_timestep_embedding(
|
||||||
|
timesteps,
|
||||||
|
self.num_channels,
|
||||||
|
flip_sin_to_cos=self.flip_sin_to_cos,
|
||||||
|
downscale_freq_shift=self.downscale_freq_shift,
|
||||||
|
computation_device=self.computation_device,
|
||||||
|
scale=self.scale,
|
||||||
|
align_dtype_to_timestep=self.align_dtype_to_timestep,
|
||||||
|
)
|
||||||
|
return t_emb
|
||||||
|
|
||||||
|
|
||||||
|
class DiffusersCompatibleTimestepProj(torch.nn.Module):
|
||||||
|
def __init__(self, dim_in, dim_out):
|
||||||
|
super().__init__()
|
||||||
|
self.linear_1 = torch.nn.Linear(dim_in, dim_out)
|
||||||
|
self.act = torch.nn.SiLU()
|
||||||
|
self.linear_2 = torch.nn.Linear(dim_out, dim_out)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.linear_1(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.linear_2(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEmbeddings(torch.nn.Module):
|
||||||
|
def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False, use_additional_t_cond=False):
|
||||||
|
super().__init__()
|
||||||
|
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device, scale=scale, align_dtype_to_timestep=align_dtype_to_timestep)
|
||||||
|
if diffusers_compatible_format:
|
||||||
|
self.timestep_embedder = DiffusersCompatibleTimestepProj(dim_in, dim_out)
|
||||||
|
else:
|
||||||
|
self.timestep_embedder = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
|
||||||
|
)
|
||||||
|
self.use_additional_t_cond = use_additional_t_cond
|
||||||
|
if use_additional_t_cond:
|
||||||
|
self.addition_t_embedding = torch.nn.Embedding(2, dim_out)
|
||||||
|
|
||||||
|
def forward(self, timestep, dtype, addition_t_cond=None):
|
||||||
|
time_emb = self.time_proj(timestep).to(dtype)
|
||||||
|
time_emb = self.timestep_embedder(time_emb)
|
||||||
|
if addition_t_cond is not None:
|
||||||
|
addition_t_emb = self.addition_t_embedding(addition_t_cond)
|
||||||
|
addition_t_emb = addition_t_emb.to(dtype=dtype)
|
||||||
|
time_emb = time_emb + addition_t_emb
|
||||||
|
return time_emb
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(torch.nn.Module):
|
||||||
|
def __init__(self, dim, eps, elementwise_affine=True):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
if elementwise_affine:
|
||||||
|
self.weight = torch.nn.Parameter(torch.ones((dim,)))
|
||||||
|
else:
|
||||||
|
self.weight = None
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
input_dtype = hidden_states.dtype
|
||||||
|
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
||||||
|
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
||||||
|
hidden_states = hidden_states.to(input_dtype)
|
||||||
|
if self.weight is not None:
|
||||||
|
hidden_states = hidden_states * self.weight
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class AdaLayerNorm(torch.nn.Module):
|
||||||
|
def __init__(self, dim, single=False, dual=False):
|
||||||
|
super().__init__()
|
||||||
|
self.single = single
|
||||||
|
self.dual = dual
|
||||||
|
self.linear = torch.nn.Linear(dim, dim * [[6, 2][single], 9][dual])
|
||||||
|
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||||
|
|
||||||
|
def forward(self, x, emb):
|
||||||
|
emb = self.linear(torch.nn.functional.silu(emb))
|
||||||
|
if self.single:
|
||||||
|
scale, shift = emb.unsqueeze(1).chunk(2, dim=2)
|
||||||
|
x = self.norm(x) * (1 + scale) + shift
|
||||||
|
return x
|
||||||
|
elif self.dual:
|
||||||
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.unsqueeze(1).chunk(9, dim=2)
|
||||||
|
norm_x = self.norm(x)
|
||||||
|
x = norm_x * (1 + scale_msa) + shift_msa
|
||||||
|
norm_x2 = norm_x * (1 + scale_msa2) + shift_msa2
|
||||||
|
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_x2, gate_msa2
|
||||||
|
else:
|
||||||
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)
|
||||||
|
x = self.norm(x) * (1 + scale_msa) + shift_msa
|
||||||
|
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
||||||
@@ -1,451 +0,0 @@
|
|||||||
from .attention import Attention
|
|
||||||
from einops import repeat, rearrange
|
|
||||||
import math
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanDiTRotaryEmbedding(torch.nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, q_norm_shape=88, k_norm_shape=88, rotary_emb_on_k=True):
|
|
||||||
super().__init__()
|
|
||||||
self.q_norm = torch.nn.LayerNorm((q_norm_shape,), elementwise_affine=True, eps=1e-06)
|
|
||||||
self.k_norm = torch.nn.LayerNorm((k_norm_shape,), elementwise_affine=True, eps=1e-06)
|
|
||||||
self.rotary_emb_on_k = rotary_emb_on_k
|
|
||||||
self.k_cache, self.v_cache = [], []
|
|
||||||
|
|
||||||
def reshape_for_broadcast(self, freqs_cis, x):
|
|
||||||
ndim = x.ndim
|
|
||||||
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
|
||||||
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
|
|
||||||
|
|
||||||
def rotate_half(self, x):
|
|
||||||
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
|
|
||||||
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
|
||||||
|
|
||||||
def apply_rotary_emb(self, xq, xk, freqs_cis):
|
|
||||||
xk_out = None
|
|
||||||
cos, sin = self.reshape_for_broadcast(freqs_cis, xq)
|
|
||||||
cos, sin = cos.to(xq.device), sin.to(xq.device)
|
|
||||||
xq_out = (xq.float() * cos + self.rotate_half(xq.float()) * sin).type_as(xq)
|
|
||||||
if xk is not None:
|
|
||||||
xk_out = (xk.float() * cos + self.rotate_half(xk.float()) * sin).type_as(xk)
|
|
||||||
return xq_out, xk_out
|
|
||||||
|
|
||||||
def forward(self, q, k, v, freqs_cis_img, to_cache=False):
|
|
||||||
# norm
|
|
||||||
q = self.q_norm(q)
|
|
||||||
k = self.k_norm(k)
|
|
||||||
|
|
||||||
# RoPE
|
|
||||||
if self.rotary_emb_on_k:
|
|
||||||
q, k = self.apply_rotary_emb(q, k, freqs_cis_img)
|
|
||||||
else:
|
|
||||||
q, _ = self.apply_rotary_emb(q, None, freqs_cis_img)
|
|
||||||
|
|
||||||
if to_cache:
|
|
||||||
self.k_cache.append(k)
|
|
||||||
self.v_cache.append(v)
|
|
||||||
elif len(self.k_cache) > 0 and len(self.v_cache) > 0:
|
|
||||||
k = torch.concat([k] + self.k_cache, dim=2)
|
|
||||||
v = torch.concat([v] + self.v_cache, dim=2)
|
|
||||||
self.k_cache, self.v_cache = [], []
|
|
||||||
return q, k, v
|
|
||||||
|
|
||||||
|
|
||||||
class FP32_Layernorm(torch.nn.LayerNorm):
|
|
||||||
def forward(self, inputs):
|
|
||||||
origin_dtype = inputs.dtype
|
|
||||||
return torch.nn.functional.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).to(origin_dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class FP32_SiLU(torch.nn.SiLU):
|
|
||||||
def forward(self, inputs):
|
|
||||||
origin_dtype = inputs.dtype
|
|
||||||
return torch.nn.functional.silu(inputs.float(), inplace=False).to(origin_dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanDiTFinalLayer(torch.nn.Module):
|
|
||||||
def __init__(self, final_hidden_size=1408, condition_dim=1408, patch_size=2, out_channels=8):
|
|
||||||
super().__init__()
|
|
||||||
self.norm_final = torch.nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
|
|
||||||
self.linear = torch.nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True)
|
|
||||||
self.adaLN_modulation = torch.nn.Sequential(
|
|
||||||
FP32_SiLU(),
|
|
||||||
torch.nn.Linear(condition_dim, 2 * final_hidden_size, bias=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
def modulate(self, x, shift, scale):
|
|
||||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
|
||||||
|
|
||||||
def forward(self, hidden_states, condition_emb):
|
|
||||||
shift, scale = self.adaLN_modulation(condition_emb).chunk(2, dim=1)
|
|
||||||
hidden_states = self.modulate(self.norm_final(hidden_states), shift, scale)
|
|
||||||
hidden_states = self.linear(hidden_states)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanDiTBlock(torch.nn.Module):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
hidden_dim=1408,
|
|
||||||
condition_dim=1408,
|
|
||||||
num_heads=16,
|
|
||||||
mlp_ratio=4.3637,
|
|
||||||
text_dim=1024,
|
|
||||||
skip_connection=False
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.norm1 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
|
|
||||||
self.rota1 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads)
|
|
||||||
self.attn1 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, bias_q=True, bias_kv=True, bias_out=True)
|
|
||||||
self.norm2 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
|
|
||||||
self.rota2 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads, rotary_emb_on_k=False)
|
|
||||||
self.attn2 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, kv_dim=text_dim, bias_q=True, bias_kv=True, bias_out=True)
|
|
||||||
self.norm3 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
|
|
||||||
self.modulation = torch.nn.Sequential(FP32_SiLU(), torch.nn.Linear(condition_dim, hidden_dim, bias=True))
|
|
||||||
self.mlp = torch.nn.Sequential(
|
|
||||||
torch.nn.Linear(hidden_dim, int(hidden_dim*mlp_ratio), bias=True),
|
|
||||||
torch.nn.GELU(approximate="tanh"),
|
|
||||||
torch.nn.Linear(int(hidden_dim*mlp_ratio), hidden_dim, bias=True)
|
|
||||||
)
|
|
||||||
if skip_connection:
|
|
||||||
self.skip_norm = FP32_Layernorm((hidden_dim * 2,), eps=1e-6, elementwise_affine=True)
|
|
||||||
self.skip_linear = torch.nn.Linear(hidden_dim * 2, hidden_dim, bias=True)
|
|
||||||
else:
|
|
||||||
self.skip_norm, self.skip_linear = None, None
|
|
||||||
|
|
||||||
def forward(self, hidden_states, condition_emb, text_emb, freq_cis_img, residual=None, to_cache=False):
|
|
||||||
# Long Skip Connection
|
|
||||||
if self.skip_norm is not None and self.skip_linear is not None:
|
|
||||||
hidden_states = torch.cat([hidden_states, residual], dim=-1)
|
|
||||||
hidden_states = self.skip_norm(hidden_states)
|
|
||||||
hidden_states = self.skip_linear(hidden_states)
|
|
||||||
|
|
||||||
# Self-Attention
|
|
||||||
shift_msa = self.modulation(condition_emb).unsqueeze(dim=1)
|
|
||||||
attn_input = self.norm1(hidden_states) + shift_msa
|
|
||||||
hidden_states = hidden_states + self.attn1(attn_input, qkv_preprocessor=lambda q, k, v: self.rota1(q, k, v, freq_cis_img, to_cache=to_cache))
|
|
||||||
|
|
||||||
# Cross-Attention
|
|
||||||
attn_input = self.norm3(hidden_states)
|
|
||||||
hidden_states = hidden_states + self.attn2(attn_input, text_emb, qkv_preprocessor=lambda q, k, v: self.rota2(q, k, v, freq_cis_img))
|
|
||||||
|
|
||||||
# FFN Layer
|
|
||||||
mlp_input = self.norm2(hidden_states)
|
|
||||||
hidden_states = hidden_states + self.mlp(mlp_input)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class AttentionPool(torch.nn.Module):
|
|
||||||
def __init__(self, spacial_dim, embed_dim, num_heads, output_dim = None):
|
|
||||||
super().__init__()
|
|
||||||
self.positional_embedding = torch.nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
|
|
||||||
self.k_proj = torch.nn.Linear(embed_dim, embed_dim)
|
|
||||||
self.q_proj = torch.nn.Linear(embed_dim, embed_dim)
|
|
||||||
self.v_proj = torch.nn.Linear(embed_dim, embed_dim)
|
|
||||||
self.c_proj = torch.nn.Linear(embed_dim, output_dim or embed_dim)
|
|
||||||
self.num_heads = num_heads
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = x.permute(1, 0, 2) # NLC -> LNC
|
|
||||||
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
|
|
||||||
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
|
|
||||||
x, _ = torch.nn.functional.multi_head_attention_forward(
|
|
||||||
query=x[:1], key=x, value=x,
|
|
||||||
embed_dim_to_check=x.shape[-1],
|
|
||||||
num_heads=self.num_heads,
|
|
||||||
q_proj_weight=self.q_proj.weight,
|
|
||||||
k_proj_weight=self.k_proj.weight,
|
|
||||||
v_proj_weight=self.v_proj.weight,
|
|
||||||
in_proj_weight=None,
|
|
||||||
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
|
||||||
bias_k=None,
|
|
||||||
bias_v=None,
|
|
||||||
add_zero_attn=False,
|
|
||||||
dropout_p=0,
|
|
||||||
out_proj_weight=self.c_proj.weight,
|
|
||||||
out_proj_bias=self.c_proj.bias,
|
|
||||||
use_separate_proj_weight=True,
|
|
||||||
training=self.training,
|
|
||||||
need_weights=False
|
|
||||||
)
|
|
||||||
return x.squeeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
class PatchEmbed(torch.nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
patch_size=(2, 2),
|
|
||||||
in_chans=4,
|
|
||||||
embed_dim=1408,
|
|
||||||
bias=True,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.proj = torch.nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.proj(x)
|
|
||||||
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def timestep_embedding(t, dim, max_period=10000, repeat_only=False):
|
|
||||||
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
|
||||||
if not repeat_only:
|
|
||||||
half = dim // 2
|
|
||||||
freqs = torch.exp(
|
|
||||||
-math.log(max_period)
|
|
||||||
* torch.arange(start=0, end=half, dtype=torch.float32)
|
|
||||||
/ half
|
|
||||||
).to(device=t.device) # size: [dim/2], 一个指数衰减的曲线
|
|
||||||
args = t[:, None].float() * freqs[None]
|
|
||||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
|
||||||
if dim % 2:
|
|
||||||
embedding = torch.cat(
|
|
||||||
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
embedding = repeat(t, "b -> b d", d=dim)
|
|
||||||
return embedding
|
|
||||||
|
|
||||||
|
|
||||||
class TimestepEmbedder(torch.nn.Module):
|
|
||||||
def __init__(self, hidden_size=1408, frequency_embedding_size=256):
|
|
||||||
super().__init__()
|
|
||||||
self.mlp = torch.nn.Sequential(
|
|
||||||
torch.nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
|
||||||
torch.nn.SiLU(),
|
|
||||||
torch.nn.Linear(hidden_size, hidden_size, bias=True),
|
|
||||||
)
|
|
||||||
self.frequency_embedding_size = frequency_embedding_size
|
|
||||||
|
|
||||||
def forward(self, t):
|
|
||||||
t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
|
|
||||||
t_emb = self.mlp(t_freq)
|
|
||||||
return t_emb
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanDiT(torch.nn.Module):
|
|
||||||
def __init__(self, num_layers_down=21, num_layers_up=19, in_channels=4, out_channels=8, hidden_dim=1408, text_dim=1024, t5_dim=2048, text_length=77, t5_length=256):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
# Embedders
|
|
||||||
self.text_emb_padding = torch.nn.Parameter(torch.randn(text_length + t5_length, text_dim, dtype=torch.float32))
|
|
||||||
self.t5_embedder = torch.nn.Sequential(
|
|
||||||
torch.nn.Linear(t5_dim, t5_dim * 4, bias=True),
|
|
||||||
FP32_SiLU(),
|
|
||||||
torch.nn.Linear(t5_dim * 4, text_dim, bias=True),
|
|
||||||
)
|
|
||||||
self.t5_pooler = AttentionPool(t5_length, t5_dim, num_heads=8, output_dim=1024)
|
|
||||||
self.style_embedder = torch.nn.Parameter(torch.randn(hidden_dim))
|
|
||||||
self.patch_embedder = PatchEmbed(in_chans=in_channels)
|
|
||||||
self.timestep_embedder = TimestepEmbedder()
|
|
||||||
self.extra_embedder = torch.nn.Sequential(
|
|
||||||
torch.nn.Linear(256 * 6 + 1024 + hidden_dim, hidden_dim * 4),
|
|
||||||
FP32_SiLU(),
|
|
||||||
torch.nn.Linear(hidden_dim * 4, hidden_dim),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Transformer blocks
|
|
||||||
self.num_layers_down = num_layers_down
|
|
||||||
self.num_layers_up = num_layers_up
|
|
||||||
self.blocks = torch.nn.ModuleList(
|
|
||||||
[HunyuanDiTBlock(skip_connection=False) for _ in range(num_layers_down)] + \
|
|
||||||
[HunyuanDiTBlock(skip_connection=True) for _ in range(num_layers_up)]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Output layers
|
|
||||||
self.final_layer = HunyuanDiTFinalLayer()
|
|
||||||
self.out_channels = out_channels
|
|
||||||
|
|
||||||
def prepare_text_emb(self, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5):
|
|
||||||
text_emb_mask = text_emb_mask.bool()
|
|
||||||
text_emb_mask_t5 = text_emb_mask_t5.bool()
|
|
||||||
text_emb_t5 = self.t5_embedder(text_emb_t5)
|
|
||||||
text_emb = torch.cat([text_emb, text_emb_t5], dim=1)
|
|
||||||
text_emb_mask = torch.cat([text_emb_mask, text_emb_mask_t5], dim=-1)
|
|
||||||
text_emb = torch.where(text_emb_mask.unsqueeze(2), text_emb, self.text_emb_padding.to(text_emb))
|
|
||||||
return text_emb
|
|
||||||
|
|
||||||
def prepare_extra_emb(self, text_emb_t5, timestep, size_emb, dtype, batch_size):
|
|
||||||
# Text embedding
|
|
||||||
pooled_text_emb_t5 = self.t5_pooler(text_emb_t5)
|
|
||||||
|
|
||||||
# Timestep embedding
|
|
||||||
timestep_emb = self.timestep_embedder(timestep)
|
|
||||||
|
|
||||||
# Size embedding
|
|
||||||
size_emb = timestep_embedding(size_emb.view(-1), 256).to(dtype)
|
|
||||||
size_emb = size_emb.view(-1, 6 * 256)
|
|
||||||
|
|
||||||
# Style embedding
|
|
||||||
style_emb = repeat(self.style_embedder, "D -> B D", B=batch_size)
|
|
||||||
|
|
||||||
# Concatenate all extra vectors
|
|
||||||
extra_emb = torch.cat([pooled_text_emb_t5, size_emb, style_emb], dim=1)
|
|
||||||
condition_emb = timestep_emb + self.extra_embedder(extra_emb)
|
|
||||||
|
|
||||||
return condition_emb
|
|
||||||
|
|
||||||
def unpatchify(self, x, h, w):
|
|
||||||
return rearrange(x, "B (H W) (P Q C) -> B C (H P) (W Q)", H=h, W=w, P=2, Q=2)
|
|
||||||
|
|
||||||
def build_mask(self, data, is_bound):
|
|
||||||
_, _, H, W = data.shape
|
|
||||||
h = repeat(torch.arange(H), "H -> H W", H=H, W=W)
|
|
||||||
w = repeat(torch.arange(W), "W -> H W", H=H, W=W)
|
|
||||||
border_width = (H + W) // 4
|
|
||||||
pad = torch.ones_like(h) * border_width
|
|
||||||
mask = torch.stack([
|
|
||||||
pad if is_bound[0] else h + 1,
|
|
||||||
pad if is_bound[1] else H - h,
|
|
||||||
pad if is_bound[2] else w + 1,
|
|
||||||
pad if is_bound[3] else W - w
|
|
||||||
]).min(dim=0).values
|
|
||||||
mask = mask.clip(1, border_width)
|
|
||||||
mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
|
|
||||||
mask = rearrange(mask, "H W -> 1 H W")
|
|
||||||
return mask
|
|
||||||
|
|
||||||
def tiled_block_forward(self, block, hidden_states, condition_emb, text_emb, freq_cis_img, residual, torch_dtype, data_device, computation_device, tile_size, tile_stride):
|
|
||||||
B, C, H, W = hidden_states.shape
|
|
||||||
|
|
||||||
weight = torch.zeros((1, 1, H, W), dtype=torch_dtype, device=data_device)
|
|
||||||
values = torch.zeros((B, C, H, W), dtype=torch_dtype, device=data_device)
|
|
||||||
|
|
||||||
# Split tasks
|
|
||||||
tasks = []
|
|
||||||
for h in range(0, H, tile_stride):
|
|
||||||
for w in range(0, W, tile_stride):
|
|
||||||
if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
|
|
||||||
continue
|
|
||||||
h_, w_ = h + tile_size, w + tile_size
|
|
||||||
if h_ > H: h, h_ = H - tile_size, H
|
|
||||||
if w_ > W: w, w_ = W - tile_size, W
|
|
||||||
tasks.append((h, h_, w, w_))
|
|
||||||
|
|
||||||
# Run
|
|
||||||
for hl, hr, wl, wr in tasks:
|
|
||||||
hidden_states_batch = hidden_states[:, :, hl:hr, wl:wr].to(computation_device)
|
|
||||||
hidden_states_batch = rearrange(hidden_states_batch, "B C H W -> B (H W) C")
|
|
||||||
if residual is not None:
|
|
||||||
residual_batch = residual[:, :, hl:hr, wl:wr].to(computation_device)
|
|
||||||
residual_batch = rearrange(residual_batch, "B C H W -> B (H W) C")
|
|
||||||
else:
|
|
||||||
residual_batch = None
|
|
||||||
|
|
||||||
# Forward
|
|
||||||
hidden_states_batch = block(hidden_states_batch, condition_emb, text_emb, freq_cis_img, residual_batch).to(data_device)
|
|
||||||
hidden_states_batch = rearrange(hidden_states_batch, "B (H W) C -> B C H W", H=hr-hl)
|
|
||||||
|
|
||||||
mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W))
|
|
||||||
values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask
|
|
||||||
weight[:, :, hl:hr, wl:wr] += mask
|
|
||||||
values /= weight
|
|
||||||
return values
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self, hidden_states, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5, timestep, size_emb, freq_cis_img,
|
|
||||||
tiled=False, tile_size=64, tile_stride=32,
|
|
||||||
to_cache=False,
|
|
||||||
use_gradient_checkpointing=False,
|
|
||||||
):
|
|
||||||
# Embeddings
|
|
||||||
text_emb = self.prepare_text_emb(text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5)
|
|
||||||
condition_emb = self.prepare_extra_emb(text_emb_t5, timestep, size_emb, hidden_states.dtype, hidden_states.shape[0])
|
|
||||||
|
|
||||||
# Input
|
|
||||||
height, width = hidden_states.shape[-2], hidden_states.shape[-1]
|
|
||||||
hidden_states = self.patch_embedder(hidden_states)
|
|
||||||
|
|
||||||
# Blocks
|
|
||||||
def create_custom_forward(module):
|
|
||||||
def custom_forward(*inputs):
|
|
||||||
return module(*inputs)
|
|
||||||
return custom_forward
|
|
||||||
if tiled:
|
|
||||||
hidden_states = rearrange(hidden_states, "B (H W) C -> B C H W", H=height//2)
|
|
||||||
residuals = []
|
|
||||||
for block_id, block in enumerate(self.blocks):
|
|
||||||
residual = residuals.pop() if block_id >= self.num_layers_down else None
|
|
||||||
hidden_states = self.tiled_block_forward(
|
|
||||||
block, hidden_states, condition_emb, text_emb, freq_cis_img, residual,
|
|
||||||
torch_dtype=hidden_states.dtype, data_device=hidden_states.device, computation_device=hidden_states.device,
|
|
||||||
tile_size=tile_size, tile_stride=tile_stride
|
|
||||||
)
|
|
||||||
if block_id < self.num_layers_down - 2:
|
|
||||||
residuals.append(hidden_states)
|
|
||||||
hidden_states = rearrange(hidden_states, "B C H W -> B (H W) C")
|
|
||||||
else:
|
|
||||||
residuals = []
|
|
||||||
for block_id, block in enumerate(self.blocks):
|
|
||||||
residual = residuals.pop() if block_id >= self.num_layers_down else None
|
|
||||||
if self.training and use_gradient_checkpointing:
|
|
||||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
|
||||||
create_custom_forward(block),
|
|
||||||
hidden_states, condition_emb, text_emb, freq_cis_img, residual,
|
|
||||||
use_reentrant=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
hidden_states = block(hidden_states, condition_emb, text_emb, freq_cis_img, residual, to_cache=to_cache)
|
|
||||||
if block_id < self.num_layers_down - 2:
|
|
||||||
residuals.append(hidden_states)
|
|
||||||
|
|
||||||
# Output
|
|
||||||
hidden_states = self.final_layer(hidden_states, condition_emb)
|
|
||||||
hidden_states = self.unpatchify(hidden_states, height//2, width//2)
|
|
||||||
hidden_states, _ = hidden_states.chunk(2, dim=1)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def state_dict_converter():
|
|
||||||
return HunyuanDiTStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanDiTStateDictConverter():
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def from_diffusers(self, state_dict):
|
|
||||||
state_dict_ = {}
|
|
||||||
for name, param in state_dict.items():
|
|
||||||
name_ = name
|
|
||||||
name_ = name_.replace(".default_modulation.", ".modulation.")
|
|
||||||
name_ = name_.replace(".mlp.fc1.", ".mlp.0.")
|
|
||||||
name_ = name_.replace(".mlp.fc2.", ".mlp.2.")
|
|
||||||
name_ = name_.replace(".attn1.q_norm.", ".rota1.q_norm.")
|
|
||||||
name_ = name_.replace(".attn2.q_norm.", ".rota2.q_norm.")
|
|
||||||
name_ = name_.replace(".attn1.k_norm.", ".rota1.k_norm.")
|
|
||||||
name_ = name_.replace(".attn2.k_norm.", ".rota2.k_norm.")
|
|
||||||
name_ = name_.replace(".q_proj.", ".to_q.")
|
|
||||||
name_ = name_.replace(".out_proj.", ".to_out.")
|
|
||||||
name_ = name_.replace("text_embedding_padding", "text_emb_padding")
|
|
||||||
name_ = name_.replace("mlp_t5.0.", "t5_embedder.0.")
|
|
||||||
name_ = name_.replace("mlp_t5.2.", "t5_embedder.2.")
|
|
||||||
name_ = name_.replace("pooler.", "t5_pooler.")
|
|
||||||
name_ = name_.replace("x_embedder.", "patch_embedder.")
|
|
||||||
name_ = name_.replace("t_embedder.", "timestep_embedder.")
|
|
||||||
name_ = name_.replace("t5_pooler.to_q.", "t5_pooler.q_proj.")
|
|
||||||
name_ = name_.replace("style_embedder.weight", "style_embedder")
|
|
||||||
if ".kv_proj." in name_:
|
|
||||||
param_k = param[:param.shape[0]//2]
|
|
||||||
param_v = param[param.shape[0]//2:]
|
|
||||||
state_dict_[name_.replace(".kv_proj.", ".to_k.")] = param_k
|
|
||||||
state_dict_[name_.replace(".kv_proj.", ".to_v.")] = param_v
|
|
||||||
elif ".Wqkv." in name_:
|
|
||||||
param_q = param[:param.shape[0]//3]
|
|
||||||
param_k = param[param.shape[0]//3:param.shape[0]//3*2]
|
|
||||||
param_v = param[param.shape[0]//3*2:]
|
|
||||||
state_dict_[name_.replace(".Wqkv.", ".to_q.")] = param_q
|
|
||||||
state_dict_[name_.replace(".Wqkv.", ".to_k.")] = param_k
|
|
||||||
state_dict_[name_.replace(".Wqkv.", ".to_v.")] = param_v
|
|
||||||
elif "style_embedder" in name_:
|
|
||||||
state_dict_[name_] = param.squeeze()
|
|
||||||
else:
|
|
||||||
state_dict_[name_] = param
|
|
||||||
return state_dict_
|
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
|
||||||
return self.from_diffusers(state_dict)
|
|
||||||
@@ -1,163 +0,0 @@
|
|||||||
from transformers import BertModel, BertConfig, T5EncoderModel, T5Config
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanDiTCLIPTextEncoder(BertModel):
|
|
||||||
def __init__(self):
|
|
||||||
config = BertConfig(
|
|
||||||
_name_or_path = "",
|
|
||||||
architectures = ["BertModel"],
|
|
||||||
attention_probs_dropout_prob = 0.1,
|
|
||||||
bos_token_id = 0,
|
|
||||||
classifier_dropout = None,
|
|
||||||
directionality = "bidi",
|
|
||||||
eos_token_id = 2,
|
|
||||||
hidden_act = "gelu",
|
|
||||||
hidden_dropout_prob = 0.1,
|
|
||||||
hidden_size = 1024,
|
|
||||||
initializer_range = 0.02,
|
|
||||||
intermediate_size = 4096,
|
|
||||||
layer_norm_eps = 1e-12,
|
|
||||||
max_position_embeddings = 512,
|
|
||||||
model_type = "bert",
|
|
||||||
num_attention_heads = 16,
|
|
||||||
num_hidden_layers = 24,
|
|
||||||
output_past = True,
|
|
||||||
pad_token_id = 0,
|
|
||||||
pooler_fc_size = 768,
|
|
||||||
pooler_num_attention_heads = 12,
|
|
||||||
pooler_num_fc_layers = 3,
|
|
||||||
pooler_size_per_head = 128,
|
|
||||||
pooler_type = "first_token_transform",
|
|
||||||
position_embedding_type = "absolute",
|
|
||||||
torch_dtype = "float32",
|
|
||||||
transformers_version = "4.37.2",
|
|
||||||
type_vocab_size = 2,
|
|
||||||
use_cache = True,
|
|
||||||
vocab_size = 47020
|
|
||||||
)
|
|
||||||
super().__init__(config, add_pooling_layer=False)
|
|
||||||
self.eval()
|
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask, clip_skip=1):
|
|
||||||
input_shape = input_ids.size()
|
|
||||||
|
|
||||||
batch_size, seq_length = input_shape
|
|
||||||
device = input_ids.device
|
|
||||||
|
|
||||||
past_key_values_length = 0
|
|
||||||
|
|
||||||
if attention_mask is None:
|
|
||||||
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
|
||||||
|
|
||||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
|
||||||
|
|
||||||
embedding_output = self.embeddings(
|
|
||||||
input_ids=input_ids,
|
|
||||||
position_ids=None,
|
|
||||||
token_type_ids=None,
|
|
||||||
inputs_embeds=None,
|
|
||||||
past_key_values_length=0,
|
|
||||||
)
|
|
||||||
encoder_outputs = self.encoder(
|
|
||||||
embedding_output,
|
|
||||||
attention_mask=extended_attention_mask,
|
|
||||||
head_mask=None,
|
|
||||||
encoder_hidden_states=None,
|
|
||||||
encoder_attention_mask=None,
|
|
||||||
past_key_values=None,
|
|
||||||
use_cache=False,
|
|
||||||
output_attentions=False,
|
|
||||||
output_hidden_states=True,
|
|
||||||
return_dict=True,
|
|
||||||
)
|
|
||||||
all_hidden_states = encoder_outputs.hidden_states
|
|
||||||
prompt_emb = all_hidden_states[-clip_skip]
|
|
||||||
if clip_skip > 1:
|
|
||||||
mean, std = all_hidden_states[-1].mean(), all_hidden_states[-1].std()
|
|
||||||
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
|
|
||||||
return prompt_emb
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def state_dict_converter():
|
|
||||||
return HunyuanDiTCLIPTextEncoderStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanDiTT5TextEncoder(T5EncoderModel):
|
|
||||||
def __init__(self):
|
|
||||||
config = T5Config(
|
|
||||||
_name_or_path = "../HunyuanDiT/t2i/mt5",
|
|
||||||
architectures = ["MT5ForConditionalGeneration"],
|
|
||||||
classifier_dropout = 0.0,
|
|
||||||
d_ff = 5120,
|
|
||||||
d_kv = 64,
|
|
||||||
d_model = 2048,
|
|
||||||
decoder_start_token_id = 0,
|
|
||||||
dense_act_fn = "gelu_new",
|
|
||||||
dropout_rate = 0.1,
|
|
||||||
eos_token_id = 1,
|
|
||||||
feed_forward_proj = "gated-gelu",
|
|
||||||
initializer_factor = 1.0,
|
|
||||||
is_encoder_decoder = True,
|
|
||||||
is_gated_act = True,
|
|
||||||
layer_norm_epsilon = 1e-06,
|
|
||||||
model_type = "t5",
|
|
||||||
num_decoder_layers = 24,
|
|
||||||
num_heads = 32,
|
|
||||||
num_layers = 24,
|
|
||||||
output_past = True,
|
|
||||||
pad_token_id = 0,
|
|
||||||
relative_attention_max_distance = 128,
|
|
||||||
relative_attention_num_buckets = 32,
|
|
||||||
tie_word_embeddings = False,
|
|
||||||
tokenizer_class = "T5Tokenizer",
|
|
||||||
transformers_version = "4.37.2",
|
|
||||||
use_cache = True,
|
|
||||||
vocab_size = 250112
|
|
||||||
)
|
|
||||||
super().__init__(config)
|
|
||||||
self.eval()
|
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask, clip_skip=1):
|
|
||||||
outputs = super().forward(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
output_hidden_states=True,
|
|
||||||
)
|
|
||||||
prompt_emb = outputs.hidden_states[-clip_skip]
|
|
||||||
if clip_skip > 1:
|
|
||||||
mean, std = outputs.hidden_states[-1].mean(), outputs.hidden_states[-1].std()
|
|
||||||
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
|
|
||||||
return prompt_emb
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def state_dict_converter():
|
|
||||||
return HunyuanDiTT5TextEncoderStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanDiTCLIPTextEncoderStateDictConverter():
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def from_diffusers(self, state_dict):
|
|
||||||
state_dict_ = {name[5:]: param for name, param in state_dict.items() if name.startswith("bert.")}
|
|
||||||
return state_dict_
|
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
|
||||||
return self.from_diffusers(state_dict)
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanDiTT5TextEncoderStateDictConverter():
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def from_diffusers(self, state_dict):
|
|
||||||
state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("encoder.")}
|
|
||||||
state_dict_["shared.weight"] = state_dict["shared.weight"]
|
|
||||||
return state_dict_
|
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
|
||||||
return self.from_diffusers(state_dict)
|
|
||||||
@@ -1,885 +0,0 @@
|
|||||||
import torch
|
|
||||||
from .sd3_dit import TimestepEmbeddings, RMSNorm
|
|
||||||
from .utils import init_weights_on_device
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
from tqdm import tqdm
|
|
||||||
from typing import Union, Tuple, List
|
|
||||||
|
|
||||||
|
|
||||||
def HunyuanVideoRope(latents):
|
|
||||||
def _to_tuple(x, dim=2):
|
|
||||||
if isinstance(x, int):
|
|
||||||
return (x,) * dim
|
|
||||||
elif len(x) == dim:
|
|
||||||
return x
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Expected length {dim} or int, but got {x}")
|
|
||||||
|
|
||||||
|
|
||||||
def get_meshgrid_nd(start, *args, dim=2):
|
|
||||||
"""
|
|
||||||
Get n-D meshgrid with start, stop and num.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
|
|
||||||
step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
|
|
||||||
should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
|
|
||||||
n-tuples.
|
|
||||||
*args: See above.
|
|
||||||
dim (int): Dimension of the meshgrid. Defaults to 2.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
grid (np.ndarray): [dim, ...]
|
|
||||||
"""
|
|
||||||
if len(args) == 0:
|
|
||||||
# start is grid_size
|
|
||||||
num = _to_tuple(start, dim=dim)
|
|
||||||
start = (0,) * dim
|
|
||||||
stop = num
|
|
||||||
elif len(args) == 1:
|
|
||||||
# start is start, args[0] is stop, step is 1
|
|
||||||
start = _to_tuple(start, dim=dim)
|
|
||||||
stop = _to_tuple(args[0], dim=dim)
|
|
||||||
num = [stop[i] - start[i] for i in range(dim)]
|
|
||||||
elif len(args) == 2:
|
|
||||||
# start is start, args[0] is stop, args[1] is num
|
|
||||||
start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
|
|
||||||
stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
|
|
||||||
num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
|
|
||||||
else:
|
|
||||||
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
|
|
||||||
|
|
||||||
# PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
|
|
||||||
axis_grid = []
|
|
||||||
for i in range(dim):
|
|
||||||
a, b, n = start[i], stop[i], num[i]
|
|
||||||
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
|
|
||||||
axis_grid.append(g)
|
|
||||||
grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
|
|
||||||
grid = torch.stack(grid, dim=0) # [dim, W, H, D]
|
|
||||||
|
|
||||||
return grid
|
|
||||||
|
|
||||||
|
|
||||||
def get_1d_rotary_pos_embed(
|
|
||||||
dim: int,
|
|
||||||
pos: Union[torch.FloatTensor, int],
|
|
||||||
theta: float = 10000.0,
|
|
||||||
use_real: bool = False,
|
|
||||||
theta_rescale_factor: float = 1.0,
|
|
||||||
interpolation_factor: float = 1.0,
|
|
||||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
|
||||||
"""
|
|
||||||
Precompute the frequency tensor for complex exponential (cis) with given dimensions.
|
|
||||||
(Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
|
|
||||||
|
|
||||||
This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
|
|
||||||
and the end index 'end'. The 'theta' parameter scales the frequencies.
|
|
||||||
The returned tensor contains complex values in complex64 data type.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dim (int): Dimension of the frequency tensor.
|
|
||||||
pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
|
|
||||||
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
|
|
||||||
use_real (bool, optional): If True, return real part and imaginary part separately.
|
|
||||||
Otherwise, return complex numbers.
|
|
||||||
theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
|
|
||||||
freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
|
|
||||||
"""
|
|
||||||
if isinstance(pos, int):
|
|
||||||
pos = torch.arange(pos).float()
|
|
||||||
|
|
||||||
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
|
||||||
# has some connection to NTK literature
|
|
||||||
if theta_rescale_factor != 1.0:
|
|
||||||
theta *= theta_rescale_factor ** (dim / (dim - 2))
|
|
||||||
|
|
||||||
freqs = 1.0 / (
|
|
||||||
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
|
|
||||||
) # [D/2]
|
|
||||||
# assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
|
|
||||||
freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
|
|
||||||
if use_real:
|
|
||||||
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
|
|
||||||
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
|
|
||||||
return freqs_cos, freqs_sin
|
|
||||||
else:
|
|
||||||
freqs_cis = torch.polar(
|
|
||||||
torch.ones_like(freqs), freqs
|
|
||||||
) # complex64 # [S, D/2]
|
|
||||||
return freqs_cis
|
|
||||||
|
|
||||||
|
|
||||||
def get_nd_rotary_pos_embed(
|
|
||||||
rope_dim_list,
|
|
||||||
start,
|
|
||||||
*args,
|
|
||||||
theta=10000.0,
|
|
||||||
use_real=False,
|
|
||||||
theta_rescale_factor: Union[float, List[float]] = 1.0,
|
|
||||||
interpolation_factor: Union[float, List[float]] = 1.0,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
|
|
||||||
sum(rope_dim_list) should equal to head_dim of attention layer.
|
|
||||||
start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
|
|
||||||
args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
|
|
||||||
*args: See above.
|
|
||||||
theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
|
|
||||||
use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
|
||||||
Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real
|
|
||||||
part and an imaginary part separately.
|
|
||||||
theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
pos_embed (torch.Tensor): [HW, D/2]
|
|
||||||
"""
|
|
||||||
|
|
||||||
grid = get_meshgrid_nd(
|
|
||||||
start, *args, dim=len(rope_dim_list)
|
|
||||||
) # [3, W, H, D] / [2, W, H]
|
|
||||||
|
|
||||||
if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
|
|
||||||
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
|
|
||||||
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
|
|
||||||
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
|
|
||||||
assert len(theta_rescale_factor) == len(
|
|
||||||
rope_dim_list
|
|
||||||
), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
|
|
||||||
|
|
||||||
if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
|
|
||||||
interpolation_factor = [interpolation_factor] * len(rope_dim_list)
|
|
||||||
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
|
|
||||||
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
|
|
||||||
assert len(interpolation_factor) == len(
|
|
||||||
rope_dim_list
|
|
||||||
), "len(interpolation_factor) should equal to len(rope_dim_list)"
|
|
||||||
|
|
||||||
# use 1/ndim of dimensions to encode grid_axis
|
|
||||||
embs = []
|
|
||||||
for i in range(len(rope_dim_list)):
|
|
||||||
emb = get_1d_rotary_pos_embed(
|
|
||||||
rope_dim_list[i],
|
|
||||||
grid[i].reshape(-1),
|
|
||||||
theta,
|
|
||||||
use_real=use_real,
|
|
||||||
theta_rescale_factor=theta_rescale_factor[i],
|
|
||||||
interpolation_factor=interpolation_factor[i],
|
|
||||||
) # 2 x [WHD, rope_dim_list[i]]
|
|
||||||
embs.append(emb)
|
|
||||||
|
|
||||||
if use_real:
|
|
||||||
cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
|
|
||||||
sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
|
|
||||||
return cos, sin
|
|
||||||
else:
|
|
||||||
emb = torch.cat(embs, dim=1) # (WHD, D/2)
|
|
||||||
return emb
|
|
||||||
|
|
||||||
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
|
|
||||||
[16, 56, 56],
|
|
||||||
[latents.shape[2], latents.shape[3] // 2, latents.shape[4] // 2],
|
|
||||||
theta=256,
|
|
||||||
use_real=True,
|
|
||||||
theta_rescale_factor=1,
|
|
||||||
)
|
|
||||||
return freqs_cos, freqs_sin
|
|
||||||
|
|
||||||
|
|
||||||
class PatchEmbed(torch.nn.Module):
|
|
||||||
def __init__(self, patch_size=(1, 2, 2), in_channels=16, embed_dim=3072):
|
|
||||||
super().__init__()
|
|
||||||
self.proj = torch.nn.Conv3d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.proj(x)
|
|
||||||
x = x.flatten(2).transpose(1, 2)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class IndividualTokenRefinerBlock(torch.nn.Module):
|
|
||||||
def __init__(self, hidden_size=3072, num_heads=24):
|
|
||||||
super().__init__()
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.norm1 = torch.nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
|
|
||||||
self.self_attn_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
|
|
||||||
self.self_attn_proj = torch.nn.Linear(hidden_size, hidden_size)
|
|
||||||
|
|
||||||
self.norm2 = torch.nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
|
|
||||||
self.mlp = torch.nn.Sequential(
|
|
||||||
torch.nn.Linear(hidden_size, hidden_size * 4),
|
|
||||||
torch.nn.SiLU(),
|
|
||||||
torch.nn.Linear(hidden_size * 4, hidden_size)
|
|
||||||
)
|
|
||||||
self.adaLN_modulation = torch.nn.Sequential(
|
|
||||||
torch.nn.SiLU(),
|
|
||||||
torch.nn.Linear(hidden_size, hidden_size * 2, device="cuda", dtype=torch.bfloat16),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x, c, attn_mask=None):
|
|
||||||
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
|
|
||||||
|
|
||||||
norm_x = self.norm1(x)
|
|
||||||
qkv = self.self_attn_qkv(norm_x)
|
|
||||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
|
||||||
|
|
||||||
attn = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
|
||||||
attn = rearrange(attn, "B H L D -> B L (H D)")
|
|
||||||
|
|
||||||
x = x + self.self_attn_proj(attn) * gate_msa.unsqueeze(1)
|
|
||||||
x = x + self.mlp(self.norm2(x)) * gate_mlp.unsqueeze(1)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class SingleTokenRefiner(torch.nn.Module):
|
|
||||||
def __init__(self, in_channels=4096, hidden_size=3072, depth=2):
|
|
||||||
super().__init__()
|
|
||||||
self.input_embedder = torch.nn.Linear(in_channels, hidden_size, bias=True)
|
|
||||||
self.t_embedder = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
|
|
||||||
self.c_embedder = torch.nn.Sequential(
|
|
||||||
torch.nn.Linear(in_channels, hidden_size),
|
|
||||||
torch.nn.SiLU(),
|
|
||||||
torch.nn.Linear(hidden_size, hidden_size)
|
|
||||||
)
|
|
||||||
self.blocks = torch.nn.ModuleList([IndividualTokenRefinerBlock(hidden_size=hidden_size) for _ in range(depth)])
|
|
||||||
|
|
||||||
def forward(self, x, t, mask=None):
|
|
||||||
timestep_aware_representations = self.t_embedder(t, dtype=torch.float32)
|
|
||||||
|
|
||||||
mask_float = mask.float().unsqueeze(-1)
|
|
||||||
context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
|
|
||||||
context_aware_representations = self.c_embedder(context_aware_representations)
|
|
||||||
c = timestep_aware_representations + context_aware_representations
|
|
||||||
|
|
||||||
x = self.input_embedder(x)
|
|
||||||
|
|
||||||
mask = mask.to(device=x.device, dtype=torch.bool)
|
|
||||||
mask = repeat(mask, "B L -> B 1 D L", D=mask.shape[-1])
|
|
||||||
mask = mask & mask.transpose(2, 3)
|
|
||||||
mask[:, :, :, 0] = True
|
|
||||||
|
|
||||||
for block in self.blocks:
|
|
||||||
x = block(x, c, mask)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class ModulateDiT(torch.nn.Module):
|
|
||||||
def __init__(self, hidden_size, factor=6):
|
|
||||||
super().__init__()
|
|
||||||
self.act = torch.nn.SiLU()
|
|
||||||
self.linear = torch.nn.Linear(hidden_size, factor * hidden_size)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.linear(self.act(x))
|
|
||||||
|
|
||||||
|
|
||||||
def modulate(x, shift=None, scale=None):
|
|
||||||
if scale is None and shift is None:
|
|
||||||
return x
|
|
||||||
elif shift is None:
|
|
||||||
return x * (1 + scale.unsqueeze(1))
|
|
||||||
elif scale is None:
|
|
||||||
return x + shift.unsqueeze(1)
|
|
||||||
else:
|
|
||||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
|
||||||
|
|
||||||
|
|
||||||
def reshape_for_broadcast(
|
|
||||||
freqs_cis,
|
|
||||||
x: torch.Tensor,
|
|
||||||
head_first=False,
|
|
||||||
):
|
|
||||||
ndim = x.ndim
|
|
||||||
assert 0 <= 1 < ndim
|
|
||||||
|
|
||||||
if isinstance(freqs_cis, tuple):
|
|
||||||
# freqs_cis: (cos, sin) in real space
|
|
||||||
if head_first:
|
|
||||||
assert freqs_cis[0].shape == (
|
|
||||||
x.shape[-2],
|
|
||||||
x.shape[-1],
|
|
||||||
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
|
|
||||||
shape = [
|
|
||||||
d if i == ndim - 2 or i == ndim - 1 else 1
|
|
||||||
for i, d in enumerate(x.shape)
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
assert freqs_cis[0].shape == (
|
|
||||||
x.shape[1],
|
|
||||||
x.shape[-1],
|
|
||||||
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
|
|
||||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
|
||||||
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
|
|
||||||
else:
|
|
||||||
# freqs_cis: values in complex space
|
|
||||||
if head_first:
|
|
||||||
assert freqs_cis.shape == (
|
|
||||||
x.shape[-2],
|
|
||||||
x.shape[-1],
|
|
||||||
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
|
|
||||||
shape = [
|
|
||||||
d if i == ndim - 2 or i == ndim - 1 else 1
|
|
||||||
for i, d in enumerate(x.shape)
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
assert freqs_cis.shape == (
|
|
||||||
x.shape[1],
|
|
||||||
x.shape[-1],
|
|
||||||
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
|
|
||||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
|
||||||
return freqs_cis.view(*shape)
|
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x):
|
|
||||||
x_real, x_imag = (
|
|
||||||
x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
|
|
||||||
) # [B, S, H, D//2]
|
|
||||||
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb(
|
|
||||||
xq: torch.Tensor,
|
|
||||||
xk: torch.Tensor,
|
|
||||||
freqs_cis,
|
|
||||||
head_first: bool = False,
|
|
||||||
):
|
|
||||||
xk_out = None
|
|
||||||
if isinstance(freqs_cis, tuple):
|
|
||||||
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
|
|
||||||
cos, sin = cos.to(xq.device), sin.to(xq.device)
|
|
||||||
# real * cos - imag * sin
|
|
||||||
# imag * cos + real * sin
|
|
||||||
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
|
|
||||||
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
|
|
||||||
else:
|
|
||||||
# view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
|
|
||||||
xq_ = torch.view_as_complex(
|
|
||||||
xq.float().reshape(*xq.shape[:-1], -1, 2)
|
|
||||||
) # [B, S, H, D//2]
|
|
||||||
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
|
|
||||||
xq.device
|
|
||||||
) # [S, D//2] --> [1, S, 1, D//2]
|
|
||||||
# (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
|
|
||||||
# view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
|
|
||||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
|
|
||||||
xk_ = torch.view_as_complex(
|
|
||||||
xk.float().reshape(*xk.shape[:-1], -1, 2)
|
|
||||||
) # [B, S, H, D//2]
|
|
||||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
|
|
||||||
|
|
||||||
return xq_out, xk_out
|
|
||||||
|
|
||||||
|
|
||||||
def attention(q, k, v):
|
|
||||||
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
|
|
||||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
|
||||||
x = x.transpose(1, 2).flatten(2, 3)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class MMDoubleStreamBlockComponent(torch.nn.Module):
|
|
||||||
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
|
||||||
super().__init__()
|
|
||||||
self.heads_num = heads_num
|
|
||||||
|
|
||||||
self.mod = ModulateDiT(hidden_size)
|
|
||||||
self.norm1 = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
|
||||||
|
|
||||||
self.to_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
|
|
||||||
self.norm_q = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
|
||||||
self.norm_k = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
|
||||||
self.to_out = torch.nn.Linear(hidden_size, hidden_size)
|
|
||||||
|
|
||||||
self.norm2 = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
|
||||||
self.ff = torch.nn.Sequential(
|
|
||||||
torch.nn.Linear(hidden_size, hidden_size * mlp_width_ratio),
|
|
||||||
torch.nn.GELU(approximate="tanh"),
|
|
||||||
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, hidden_states, conditioning, freqs_cis=None):
|
|
||||||
mod1_shift, mod1_scale, mod1_gate, mod2_shift, mod2_scale, mod2_gate = self.mod(conditioning).chunk(6, dim=-1)
|
|
||||||
|
|
||||||
norm_hidden_states = self.norm1(hidden_states)
|
|
||||||
norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale)
|
|
||||||
qkv = self.to_qkv(norm_hidden_states)
|
|
||||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
|
||||||
|
|
||||||
q = self.norm_q(q)
|
|
||||||
k = self.norm_k(k)
|
|
||||||
|
|
||||||
if freqs_cis is not None:
|
|
||||||
q, k = apply_rotary_emb(q, k, freqs_cis, head_first=False)
|
|
||||||
|
|
||||||
return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate)
|
|
||||||
|
|
||||||
def process_ff(self, hidden_states, attn_output, mod):
|
|
||||||
mod1_gate, mod2_shift, mod2_scale, mod2_gate = mod
|
|
||||||
hidden_states = hidden_states + self.to_out(attn_output) * mod1_gate.unsqueeze(1)
|
|
||||||
hidden_states = hidden_states + self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale)) * mod2_gate.unsqueeze(1)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class MMDoubleStreamBlock(torch.nn.Module):
|
|
||||||
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
|
||||||
super().__init__()
|
|
||||||
self.component_a = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
|
|
||||||
self.component_b = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
|
|
||||||
|
|
||||||
def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis):
|
|
||||||
(q_a, k_a, v_a), mod_a = self.component_a(hidden_states_a, conditioning, freqs_cis)
|
|
||||||
(q_b, k_b, v_b), mod_b = self.component_b(hidden_states_b, conditioning, freqs_cis=None)
|
|
||||||
|
|
||||||
q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
|
|
||||||
k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()
|
|
||||||
v_a, v_b = torch.concat([v_a, v_b[:, :71]], dim=1), v_b[:, 71:].contiguous()
|
|
||||||
attn_output_a = attention(q_a, k_a, v_a)
|
|
||||||
attn_output_b = attention(q_b, k_b, v_b)
|
|
||||||
attn_output_a, attn_output_b = attn_output_a[:, :-71].contiguous(), torch.concat([attn_output_a[:, -71:], attn_output_b], dim=1)
|
|
||||||
|
|
||||||
hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a)
|
|
||||||
hidden_states_b = self.component_b.process_ff(hidden_states_b, attn_output_b, mod_b)
|
|
||||||
return hidden_states_a, hidden_states_b
|
|
||||||
|
|
||||||
|
|
||||||
class MMSingleStreamBlockOriginal(torch.nn.Module):
|
|
||||||
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
|
||||||
super().__init__()
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.heads_num = heads_num
|
|
||||||
self.mlp_hidden_dim = hidden_size * mlp_width_ratio
|
|
||||||
|
|
||||||
self.linear1 = torch.nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
|
||||||
self.linear2 = torch.nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
|
||||||
|
|
||||||
self.q_norm = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
|
||||||
self.k_norm = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
|
||||||
|
|
||||||
self.pre_norm = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
|
||||||
|
|
||||||
self.mlp_act = torch.nn.GELU(approximate="tanh")
|
|
||||||
self.modulation = ModulateDiT(hidden_size, factor=3)
|
|
||||||
|
|
||||||
def forward(self, x, vec, freqs_cis=None, txt_len=256):
|
|
||||||
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
|
|
||||||
x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
|
|
||||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
|
||||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
|
||||||
q = self.q_norm(q)
|
|
||||||
k = self.k_norm(k)
|
|
||||||
|
|
||||||
q_a, q_b = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
|
|
||||||
k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
|
|
||||||
q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)
|
|
||||||
q = torch.cat((q_a, q_b), dim=1)
|
|
||||||
k = torch.cat((k_a, k_b), dim=1)
|
|
||||||
|
|
||||||
attn_output_a = attention(q[:, :-185].contiguous(), k[:, :-185].contiguous(), v[:, :-185].contiguous())
|
|
||||||
attn_output_b = attention(q[:, -185:].contiguous(), k[:, -185:].contiguous(), v[:, -185:].contiguous())
|
|
||||||
attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)
|
|
||||||
|
|
||||||
output = self.linear2(torch.cat((attn_output, self.mlp_act(mlp)), 2))
|
|
||||||
return x + output * mod_gate.unsqueeze(1)
|
|
||||||
|
|
||||||
|
|
||||||
class MMSingleStreamBlock(torch.nn.Module):
|
|
||||||
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
|
||||||
super().__init__()
|
|
||||||
self.heads_num = heads_num
|
|
||||||
|
|
||||||
self.mod = ModulateDiT(hidden_size, factor=3)
|
|
||||||
self.norm = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
|
||||||
|
|
||||||
self.to_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
|
|
||||||
self.norm_q = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
|
||||||
self.norm_k = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
|
||||||
self.to_out = torch.nn.Linear(hidden_size, hidden_size)
|
|
||||||
|
|
||||||
self.ff = torch.nn.Sequential(
|
|
||||||
torch.nn.Linear(hidden_size, hidden_size * mlp_width_ratio),
|
|
||||||
torch.nn.GELU(approximate="tanh"),
|
|
||||||
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size, bias=False)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256):
|
|
||||||
mod_shift, mod_scale, mod_gate = self.mod(conditioning).chunk(3, dim=-1)
|
|
||||||
|
|
||||||
norm_hidden_states = self.norm(hidden_states)
|
|
||||||
norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale)
|
|
||||||
qkv = self.to_qkv(norm_hidden_states)
|
|
||||||
|
|
||||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
|
||||||
|
|
||||||
q = self.norm_q(q)
|
|
||||||
k = self.norm_k(k)
|
|
||||||
|
|
||||||
q_a, q_b = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
|
|
||||||
k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
|
|
||||||
q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)
|
|
||||||
|
|
||||||
q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
|
|
||||||
k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()
|
|
||||||
v_a, v_b = v[:, :-185].contiguous(), v[:, -185:].contiguous()
|
|
||||||
|
|
||||||
attn_output_a = attention(q_a, k_a, v_a)
|
|
||||||
attn_output_b = attention(q_b, k_b, v_b)
|
|
||||||
attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)
|
|
||||||
|
|
||||||
hidden_states = hidden_states + self.to_out(attn_output) * mod_gate.unsqueeze(1)
|
|
||||||
hidden_states = hidden_states + self.ff(norm_hidden_states) * mod_gate.unsqueeze(1)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class FinalLayer(torch.nn.Module):
|
|
||||||
def __init__(self, hidden_size=3072, patch_size=(1, 2, 2), out_channels=16):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.norm_final = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
|
||||||
self.linear = torch.nn.Linear(hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels)
|
|
||||||
|
|
||||||
self.adaLN_modulation = torch.nn.Sequential(torch.nn.SiLU(), torch.nn.Linear(hidden_size, 2 * hidden_size))
|
|
||||||
|
|
||||||
def forward(self, x, c):
|
|
||||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
|
||||||
x = modulate(self.norm_final(x), shift=shift, scale=scale)
|
|
||||||
x = self.linear(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanVideoDiT(torch.nn.Module):
|
|
||||||
def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40):
|
|
||||||
super().__init__()
|
|
||||||
self.img_in = PatchEmbed(in_channels=in_channels, embed_dim=hidden_size)
|
|
||||||
self.txt_in = SingleTokenRefiner(in_channels=text_dim, hidden_size=hidden_size)
|
|
||||||
self.time_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
|
|
||||||
self.vector_in = torch.nn.Sequential(
|
|
||||||
torch.nn.Linear(768, hidden_size),
|
|
||||||
torch.nn.SiLU(),
|
|
||||||
torch.nn.Linear(hidden_size, hidden_size)
|
|
||||||
)
|
|
||||||
self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
|
|
||||||
self.double_blocks = torch.nn.ModuleList([MMDoubleStreamBlock(hidden_size) for _ in range(num_double_blocks)])
|
|
||||||
self.single_blocks = torch.nn.ModuleList([MMSingleStreamBlock(hidden_size) for _ in range(num_single_blocks)])
|
|
||||||
self.final_layer = FinalLayer(hidden_size)
|
|
||||||
|
|
||||||
# TODO: remove these parameters
|
|
||||||
self.dtype = torch.bfloat16
|
|
||||||
self.patch_size = [1, 2, 2]
|
|
||||||
self.hidden_size = 3072
|
|
||||||
self.heads_num = 24
|
|
||||||
self.rope_dim_list = [16, 56, 56]
|
|
||||||
|
|
||||||
def unpatchify(self, x, T, H, W):
|
|
||||||
x = rearrange(x, "B (T H W) (C pT pH pW) -> B C (T pT) (H pH) (W pW)", H=H, W=W, pT=1, pH=2, pW=2)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def enable_block_wise_offload(self, warm_device="cuda", cold_device="cpu"):
|
|
||||||
self.warm_device = warm_device
|
|
||||||
self.cold_device = cold_device
|
|
||||||
self.to(self.cold_device)
|
|
||||||
|
|
||||||
def load_models_to_device(self, loadmodel_names=[], device="cpu"):
|
|
||||||
for model_name in loadmodel_names:
|
|
||||||
model = getattr(self, model_name)
|
|
||||||
if model is not None:
|
|
||||||
model.to(device)
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
def prepare_freqs(self, latents):
|
|
||||||
return HunyuanVideoRope(latents)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
t: torch.Tensor,
|
|
||||||
prompt_emb: torch.Tensor = None,
|
|
||||||
text_mask: torch.Tensor = None,
|
|
||||||
pooled_prompt_emb: torch.Tensor = None,
|
|
||||||
freqs_cos: torch.Tensor = None,
|
|
||||||
freqs_sin: torch.Tensor = None,
|
|
||||||
guidance: torch.Tensor = None,
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
B, C, T, H, W = x.shape
|
|
||||||
|
|
||||||
vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb) + self.guidance_in(guidance * 1000, dtype=torch.float32)
|
|
||||||
img = self.img_in(x)
|
|
||||||
txt = self.txt_in(prompt_emb, t, text_mask)
|
|
||||||
|
|
||||||
for block in tqdm(self.double_blocks, desc="Double stream blocks"):
|
|
||||||
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
|
|
||||||
|
|
||||||
x = torch.concat([img, txt], dim=1)
|
|
||||||
for block in tqdm(self.single_blocks, desc="Single stream blocks"):
|
|
||||||
x = block(x, vec, (freqs_cos, freqs_sin))
|
|
||||||
|
|
||||||
img = x[:, :-256]
|
|
||||||
img = self.final_layer(img, vec)
|
|
||||||
img = self.unpatchify(img, T=T//1, H=H//2, W=W//2)
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def enable_auto_offload(self, dtype=torch.bfloat16, device="cuda"):
|
|
||||||
def cast_to(weight, dtype=None, device=None, copy=False):
|
|
||||||
if device is None or weight.device == device:
|
|
||||||
if not copy:
|
|
||||||
if dtype is None or weight.dtype == dtype:
|
|
||||||
return weight
|
|
||||||
return weight.to(dtype=dtype, copy=copy)
|
|
||||||
|
|
||||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
|
||||||
r.copy_(weight)
|
|
||||||
return r
|
|
||||||
|
|
||||||
def cast_weight(s, input=None, dtype=None, device=None):
|
|
||||||
if input is not None:
|
|
||||||
if dtype is None:
|
|
||||||
dtype = input.dtype
|
|
||||||
if device is None:
|
|
||||||
device = input.device
|
|
||||||
weight = cast_to(s.weight, dtype, device)
|
|
||||||
return weight
|
|
||||||
|
|
||||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
|
||||||
if input is not None:
|
|
||||||
if dtype is None:
|
|
||||||
dtype = input.dtype
|
|
||||||
if bias_dtype is None:
|
|
||||||
bias_dtype = dtype
|
|
||||||
if device is None:
|
|
||||||
device = input.device
|
|
||||||
weight = cast_to(s.weight, dtype, device)
|
|
||||||
bias = cast_to(s.bias, bias_dtype, device) if s.bias is not None else None
|
|
||||||
return weight, bias
|
|
||||||
|
|
||||||
class quantized_layer:
|
|
||||||
class Linear(torch.nn.Linear):
|
|
||||||
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.dtype = dtype
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
def block_forward_(self, x, i, j, dtype, device):
|
|
||||||
weight_ = cast_to(
|
|
||||||
self.weight[j * self.block_size: (j + 1) * self.block_size, i * self.block_size: (i + 1) * self.block_size],
|
|
||||||
dtype=dtype, device=device
|
|
||||||
)
|
|
||||||
if self.bias is None or i > 0:
|
|
||||||
bias_ = None
|
|
||||||
else:
|
|
||||||
bias_ = cast_to(self.bias[j * self.block_size: (j + 1) * self.block_size], dtype=dtype, device=device)
|
|
||||||
x_ = x[..., i * self.block_size: (i + 1) * self.block_size]
|
|
||||||
y_ = torch.nn.functional.linear(x_, weight_, bias_)
|
|
||||||
del x_, weight_, bias_
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
return y_
|
|
||||||
|
|
||||||
def block_forward(self, x, **kwargs):
|
|
||||||
# This feature can only reduce 2GB VRAM, so we disable it.
|
|
||||||
y = torch.zeros(x.shape[:-1] + (self.out_features,), dtype=x.dtype, device=x.device)
|
|
||||||
for i in range((self.in_features + self.block_size - 1) // self.block_size):
|
|
||||||
for j in range((self.out_features + self.block_size - 1) // self.block_size):
|
|
||||||
y[..., j * self.block_size: (j + 1) * self.block_size] += self.block_forward_(x, i, j, dtype=x.dtype, device=x.device)
|
|
||||||
return y
|
|
||||||
|
|
||||||
def forward(self, x, **kwargs):
|
|
||||||
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
|
||||||
return torch.nn.functional.linear(x, weight, bias)
|
|
||||||
|
|
||||||
|
|
||||||
class RMSNorm(torch.nn.Module):
|
|
||||||
def __init__(self, module, dtype=torch.bfloat16, device="cuda"):
|
|
||||||
super().__init__()
|
|
||||||
self.module = module
|
|
||||||
self.dtype = dtype
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
def forward(self, hidden_states, **kwargs):
|
|
||||||
input_dtype = hidden_states.dtype
|
|
||||||
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
|
||||||
hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
|
|
||||||
hidden_states = hidden_states.to(input_dtype)
|
|
||||||
if self.module.weight is not None:
|
|
||||||
weight = cast_weight(self.module, hidden_states, dtype=torch.bfloat16, device="cuda")
|
|
||||||
hidden_states = hidden_states * weight
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
class Conv3d(torch.nn.Conv3d):
|
|
||||||
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.dtype = dtype
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
|
||||||
return torch.nn.functional.conv3d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)
|
|
||||||
|
|
||||||
class LayerNorm(torch.nn.LayerNorm):
|
|
||||||
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.dtype = dtype
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if self.weight is not None and self.bias is not None:
|
|
||||||
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
|
||||||
return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)
|
|
||||||
else:
|
|
||||||
return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
|
||||||
|
|
||||||
def replace_layer(model, dtype=torch.bfloat16, device="cuda"):
|
|
||||||
for name, module in model.named_children():
|
|
||||||
if isinstance(module, torch.nn.Linear):
|
|
||||||
with init_weights_on_device():
|
|
||||||
new_layer = quantized_layer.Linear(
|
|
||||||
module.in_features, module.out_features, bias=module.bias is not None,
|
|
||||||
dtype=dtype, device=device
|
|
||||||
)
|
|
||||||
new_layer.load_state_dict(module.state_dict(), assign=True)
|
|
||||||
setattr(model, name, new_layer)
|
|
||||||
elif isinstance(module, torch.nn.Conv3d):
|
|
||||||
with init_weights_on_device():
|
|
||||||
new_layer = quantized_layer.Conv3d(
|
|
||||||
module.in_channels, module.out_channels, kernel_size=module.kernel_size, stride=module.stride,
|
|
||||||
dtype=dtype, device=device
|
|
||||||
)
|
|
||||||
new_layer.load_state_dict(module.state_dict(), assign=True)
|
|
||||||
setattr(model, name, new_layer)
|
|
||||||
elif isinstance(module, RMSNorm):
|
|
||||||
new_layer = quantized_layer.RMSNorm(
|
|
||||||
module,
|
|
||||||
dtype=dtype, device=device
|
|
||||||
)
|
|
||||||
setattr(model, name, new_layer)
|
|
||||||
elif isinstance(module, torch.nn.LayerNorm):
|
|
||||||
with init_weights_on_device():
|
|
||||||
new_layer = quantized_layer.LayerNorm(
|
|
||||||
module.normalized_shape, elementwise_affine=module.elementwise_affine, eps=module.eps,
|
|
||||||
dtype=dtype, device=device
|
|
||||||
)
|
|
||||||
new_layer.load_state_dict(module.state_dict(), assign=True)
|
|
||||||
setattr(model, name, new_layer)
|
|
||||||
else:
|
|
||||||
replace_layer(module, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
replace_layer(self, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def state_dict_converter():
|
|
||||||
return HunyuanVideoDiTStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanVideoDiTStateDictConverter:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
|
||||||
if "module" in state_dict:
|
|
||||||
state_dict = state_dict["module"]
|
|
||||||
direct_dict = {
|
|
||||||
"img_in.proj": "img_in.proj",
|
|
||||||
"time_in.mlp.0": "time_in.timestep_embedder.0",
|
|
||||||
"time_in.mlp.2": "time_in.timestep_embedder.2",
|
|
||||||
"vector_in.in_layer": "vector_in.0",
|
|
||||||
"vector_in.out_layer": "vector_in.2",
|
|
||||||
"guidance_in.mlp.0": "guidance_in.timestep_embedder.0",
|
|
||||||
"guidance_in.mlp.2": "guidance_in.timestep_embedder.2",
|
|
||||||
"txt_in.input_embedder": "txt_in.input_embedder",
|
|
||||||
"txt_in.t_embedder.mlp.0": "txt_in.t_embedder.timestep_embedder.0",
|
|
||||||
"txt_in.t_embedder.mlp.2": "txt_in.t_embedder.timestep_embedder.2",
|
|
||||||
"txt_in.c_embedder.linear_1": "txt_in.c_embedder.0",
|
|
||||||
"txt_in.c_embedder.linear_2": "txt_in.c_embedder.2",
|
|
||||||
"final_layer.linear": "final_layer.linear",
|
|
||||||
"final_layer.adaLN_modulation.1": "final_layer.adaLN_modulation.1",
|
|
||||||
}
|
|
||||||
txt_suffix_dict = {
|
|
||||||
"norm1": "norm1",
|
|
||||||
"self_attn_qkv": "self_attn_qkv",
|
|
||||||
"self_attn_proj": "self_attn_proj",
|
|
||||||
"norm2": "norm2",
|
|
||||||
"mlp.fc1": "mlp.0",
|
|
||||||
"mlp.fc2": "mlp.2",
|
|
||||||
"adaLN_modulation.1": "adaLN_modulation.1",
|
|
||||||
}
|
|
||||||
double_suffix_dict = {
|
|
||||||
"img_mod.linear": "component_a.mod.linear",
|
|
||||||
"img_attn_qkv": "component_a.to_qkv",
|
|
||||||
"img_attn_q_norm": "component_a.norm_q",
|
|
||||||
"img_attn_k_norm": "component_a.norm_k",
|
|
||||||
"img_attn_proj": "component_a.to_out",
|
|
||||||
"img_mlp.fc1": "component_a.ff.0",
|
|
||||||
"img_mlp.fc2": "component_a.ff.2",
|
|
||||||
"txt_mod.linear": "component_b.mod.linear",
|
|
||||||
"txt_attn_qkv": "component_b.to_qkv",
|
|
||||||
"txt_attn_q_norm": "component_b.norm_q",
|
|
||||||
"txt_attn_k_norm": "component_b.norm_k",
|
|
||||||
"txt_attn_proj": "component_b.to_out",
|
|
||||||
"txt_mlp.fc1": "component_b.ff.0",
|
|
||||||
"txt_mlp.fc2": "component_b.ff.2",
|
|
||||||
}
|
|
||||||
single_suffix_dict = {
|
|
||||||
"linear1": ["to_qkv", "ff.0"],
|
|
||||||
"linear2": ["to_out", "ff.2"],
|
|
||||||
"q_norm": "norm_q",
|
|
||||||
"k_norm": "norm_k",
|
|
||||||
"modulation.linear": "mod.linear",
|
|
||||||
}
|
|
||||||
# single_suffix_dict = {
|
|
||||||
# "linear1": "linear1",
|
|
||||||
# "linear2": "linear2",
|
|
||||||
# "q_norm": "q_norm",
|
|
||||||
# "k_norm": "k_norm",
|
|
||||||
# "modulation.linear": "modulation.linear",
|
|
||||||
# }
|
|
||||||
state_dict_ = {}
|
|
||||||
for name, param in state_dict.items():
|
|
||||||
names = name.split(".")
|
|
||||||
direct_name = ".".join(names[:-1])
|
|
||||||
if direct_name in direct_dict:
|
|
||||||
name_ = direct_dict[direct_name] + "." + names[-1]
|
|
||||||
state_dict_[name_] = param
|
|
||||||
elif names[0] == "double_blocks":
|
|
||||||
prefix = ".".join(names[:2])
|
|
||||||
suffix = ".".join(names[2:-1])
|
|
||||||
name_ = prefix + "." + double_suffix_dict[suffix] + "." + names[-1]
|
|
||||||
state_dict_[name_] = param
|
|
||||||
elif names[0] == "single_blocks":
|
|
||||||
prefix = ".".join(names[:2])
|
|
||||||
suffix = ".".join(names[2:-1])
|
|
||||||
if isinstance(single_suffix_dict[suffix], list):
|
|
||||||
if suffix == "linear1":
|
|
||||||
name_a, name_b = single_suffix_dict[suffix]
|
|
||||||
param_a, param_b = torch.split(param, (3072*3, 3072*4), dim=0)
|
|
||||||
state_dict_[prefix + "." + name_a + "." + names[-1]] = param_a
|
|
||||||
state_dict_[prefix + "." + name_b + "." + names[-1]] = param_b
|
|
||||||
elif suffix == "linear2":
|
|
||||||
if names[-1] == "weight":
|
|
||||||
name_a, name_b = single_suffix_dict[suffix]
|
|
||||||
param_a, param_b = torch.split(param, (3072*1, 3072*4), dim=-1)
|
|
||||||
state_dict_[prefix + "." + name_a + "." + names[-1]] = param_a
|
|
||||||
state_dict_[prefix + "." + name_b + "." + names[-1]] = param_b
|
|
||||||
else:
|
|
||||||
name_a, name_b = single_suffix_dict[suffix]
|
|
||||||
state_dict_[prefix + "." + name_a + "." + names[-1]] = param
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
name_ = prefix + "." + single_suffix_dict[suffix] + "." + names[-1]
|
|
||||||
state_dict_[name_] = param
|
|
||||||
elif names[0] == "txt_in":
|
|
||||||
prefix = ".".join(names[:4]).replace(".individual_token_refiner.", ".")
|
|
||||||
suffix = ".".join(names[4:-1])
|
|
||||||
name_ = prefix + "." + txt_suffix_dict[suffix] + "." + names[-1]
|
|
||||||
state_dict_[name_] = param
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
return state_dict_
|
|
||||||
@@ -1,55 +0,0 @@
|
|||||||
from transformers import LlamaModel, LlamaConfig, DynamicCache
|
|
||||||
from copy import deepcopy
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanVideoLLMEncoder(LlamaModel):
|
|
||||||
def __init__(self, config: LlamaConfig):
|
|
||||||
super().__init__(config)
|
|
||||||
self.auto_offload = False
|
|
||||||
|
|
||||||
|
|
||||||
def enable_auto_offload(self, **kwargs):
|
|
||||||
self.auto_offload = True
|
|
||||||
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input_ids,
|
|
||||||
attention_mask,
|
|
||||||
hidden_state_skip_layer=2
|
|
||||||
):
|
|
||||||
embed_tokens = deepcopy(self.embed_tokens).to(input_ids.device) if self.auto_offload else self.embed_tokens
|
|
||||||
inputs_embeds = embed_tokens(input_ids)
|
|
||||||
|
|
||||||
past_key_values = DynamicCache()
|
|
||||||
|
|
||||||
cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
|
|
||||||
position_ids = cache_position.unsqueeze(0)
|
|
||||||
|
|
||||||
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, None, False)
|
|
||||||
hidden_states = inputs_embeds
|
|
||||||
|
|
||||||
# create position embeddings to be shared across the decoder layers
|
|
||||||
rotary_emb = deepcopy(self.rotary_emb).to(input_ids.device) if self.auto_offload else self.rotary_emb
|
|
||||||
position_embeddings = rotary_emb(hidden_states, position_ids)
|
|
||||||
|
|
||||||
# decoder layers
|
|
||||||
for layer_id, decoder_layer in enumerate(self.layers):
|
|
||||||
if self.auto_offload:
|
|
||||||
decoder_layer = deepcopy(decoder_layer).to(hidden_states.device)
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=causal_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_values,
|
|
||||||
output_attentions=False,
|
|
||||||
use_cache=True,
|
|
||||||
cache_position=cache_position,
|
|
||||||
position_embeddings=position_embeddings,
|
|
||||||
)
|
|
||||||
hidden_states = layer_outputs[0]
|
|
||||||
if layer_id + hidden_state_skip_layer + 1 >= len(self.layers):
|
|
||||||
break
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
@@ -1,507 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from einops import rearrange
|
|
||||||
import numpy as np
|
|
||||||
from tqdm import tqdm
|
|
||||||
from einops import repeat
|
|
||||||
|
|
||||||
|
|
||||||
class CausalConv3d(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, in_channel, out_channel, kernel_size, stride=1, dilation=1, pad_mode='replicate', **kwargs):
|
|
||||||
super().__init__()
|
|
||||||
self.pad_mode = pad_mode
|
|
||||||
self.time_causal_padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size - 1, 0
|
|
||||||
) # W, H, T
|
|
||||||
self.conv = nn.Conv3d(in_channel, out_channel, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
|
|
||||||
return self.conv(x)
|
|
||||||
|
|
||||||
|
|
||||||
class UpsampleCausal3D(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, channels, use_conv=False, out_channels=None, kernel_size=None, bias=True, upsample_factor=(2, 2, 2)):
|
|
||||||
super().__init__()
|
|
||||||
self.channels = channels
|
|
||||||
self.out_channels = out_channels or channels
|
|
||||||
self.upsample_factor = upsample_factor
|
|
||||||
self.conv = None
|
|
||||||
if use_conv:
|
|
||||||
kernel_size = 3 if kernel_size is None else kernel_size
|
|
||||||
self.conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias)
|
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
|
||||||
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
|
||||||
dtype = hidden_states.dtype
|
|
||||||
if dtype == torch.bfloat16:
|
|
||||||
hidden_states = hidden_states.to(torch.float32)
|
|
||||||
|
|
||||||
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
|
||||||
if hidden_states.shape[0] >= 64:
|
|
||||||
hidden_states = hidden_states.contiguous()
|
|
||||||
|
|
||||||
# interpolate
|
|
||||||
B, C, T, H, W = hidden_states.shape
|
|
||||||
first_h, other_h = hidden_states.split((1, T - 1), dim=2)
|
|
||||||
if T > 1:
|
|
||||||
other_h = F.interpolate(other_h, scale_factor=self.upsample_factor, mode="nearest")
|
|
||||||
first_h = F.interpolate(first_h.squeeze(2), scale_factor=self.upsample_factor[1:], mode="nearest").unsqueeze(2)
|
|
||||||
hidden_states = torch.cat((first_h, other_h), dim=2) if T > 1 else first_h
|
|
||||||
|
|
||||||
# If the input is bfloat16, we cast back to bfloat16
|
|
||||||
if dtype == torch.bfloat16:
|
|
||||||
hidden_states = hidden_states.to(dtype)
|
|
||||||
|
|
||||||
if self.conv:
|
|
||||||
hidden_states = self.conv(hidden_states)
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class ResnetBlockCausal3D(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels=None, dropout=0.0, groups=32, eps=1e-6, conv_shortcut_bias=True):
|
|
||||||
super().__init__()
|
|
||||||
self.pre_norm = True
|
|
||||||
self.in_channels = in_channels
|
|
||||||
out_channels = in_channels if out_channels is None else out_channels
|
|
||||||
self.out_channels = out_channels
|
|
||||||
|
|
||||||
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
|
||||||
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)
|
|
||||||
|
|
||||||
self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
|
|
||||||
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, stride=1)
|
|
||||||
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
|
||||||
self.nonlinearity = nn.SiLU()
|
|
||||||
|
|
||||||
self.conv_shortcut = None
|
|
||||||
if in_channels != out_channels:
|
|
||||||
self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, bias=conv_shortcut_bias)
|
|
||||||
|
|
||||||
def forward(self, input_tensor):
|
|
||||||
hidden_states = input_tensor
|
|
||||||
# conv1
|
|
||||||
hidden_states = self.norm1(hidden_states)
|
|
||||||
hidden_states = self.nonlinearity(hidden_states)
|
|
||||||
hidden_states = self.conv1(hidden_states)
|
|
||||||
|
|
||||||
# conv2
|
|
||||||
hidden_states = self.norm2(hidden_states)
|
|
||||||
hidden_states = self.nonlinearity(hidden_states)
|
|
||||||
hidden_states = self.dropout(hidden_states)
|
|
||||||
hidden_states = self.conv2(hidden_states)
|
|
||||||
# shortcut
|
|
||||||
if self.conv_shortcut is not None:
|
|
||||||
input_tensor = (self.conv_shortcut(input_tensor))
|
|
||||||
# shortcut and scale
|
|
||||||
output_tensor = input_tensor + hidden_states
|
|
||||||
|
|
||||||
return output_tensor
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_causal_attention_mask(n_frame, n_hw, dtype, device, batch_size=None):
|
|
||||||
seq_len = n_frame * n_hw
|
|
||||||
mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
|
|
||||||
for i in range(seq_len):
|
|
||||||
i_frame = i // n_hw
|
|
||||||
mask[i, :(i_frame + 1) * n_hw] = 0
|
|
||||||
if batch_size is not None:
|
|
||||||
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
|
|
||||||
return mask
|
|
||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
in_channels,
|
|
||||||
num_heads,
|
|
||||||
head_dim,
|
|
||||||
num_groups=32,
|
|
||||||
dropout=0.0,
|
|
||||||
eps=1e-6,
|
|
||||||
bias=True,
|
|
||||||
residual_connection=True):
|
|
||||||
super().__init__()
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.head_dim = head_dim
|
|
||||||
self.residual_connection = residual_connection
|
|
||||||
dim_inner = head_dim * num_heads
|
|
||||||
self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=eps, affine=True)
|
|
||||||
self.to_q = nn.Linear(in_channels, dim_inner, bias=bias)
|
|
||||||
self.to_k = nn.Linear(in_channels, dim_inner, bias=bias)
|
|
||||||
self.to_v = nn.Linear(in_channels, dim_inner, bias=bias)
|
|
||||||
self.to_out = nn.Sequential(nn.Linear(dim_inner, in_channels, bias=bias), nn.Dropout(dropout))
|
|
||||||
|
|
||||||
def forward(self, input_tensor, attn_mask=None):
|
|
||||||
hidden_states = self.group_norm(input_tensor.transpose(1, 2)).transpose(1, 2)
|
|
||||||
batch_size = hidden_states.shape[0]
|
|
||||||
|
|
||||||
q = self.to_q(hidden_states)
|
|
||||||
k = self.to_k(hidden_states)
|
|
||||||
v = self.to_v(hidden_states)
|
|
||||||
|
|
||||||
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
|
||||||
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
|
||||||
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
|
||||||
|
|
||||||
if attn_mask is not None:
|
|
||||||
attn_mask = attn_mask.view(batch_size, self.num_heads, -1, attn_mask.shape[-1])
|
|
||||||
hidden_states = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
|
||||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
|
||||||
hidden_states = self.to_out(hidden_states)
|
|
||||||
if self.residual_connection:
|
|
||||||
output_tensor = input_tensor + hidden_states
|
|
||||||
return output_tensor
|
|
||||||
|
|
||||||
|
|
||||||
class UNetMidBlockCausal3D(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, in_channels, dropout=0.0, num_layers=1, eps=1e-6, num_groups=32, attention_head_dim=None):
|
|
||||||
super().__init__()
|
|
||||||
resnets = [
|
|
||||||
ResnetBlockCausal3D(
|
|
||||||
in_channels=in_channels,
|
|
||||||
out_channels=in_channels,
|
|
||||||
dropout=dropout,
|
|
||||||
groups=num_groups,
|
|
||||||
eps=eps,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
attentions = []
|
|
||||||
attention_head_dim = attention_head_dim or in_channels
|
|
||||||
|
|
||||||
for _ in range(num_layers):
|
|
||||||
attentions.append(
|
|
||||||
Attention(
|
|
||||||
in_channels,
|
|
||||||
num_heads=in_channels // attention_head_dim,
|
|
||||||
head_dim=attention_head_dim,
|
|
||||||
num_groups=num_groups,
|
|
||||||
dropout=dropout,
|
|
||||||
eps=eps,
|
|
||||||
bias=True,
|
|
||||||
residual_connection=True,
|
|
||||||
))
|
|
||||||
|
|
||||||
resnets.append(
|
|
||||||
ResnetBlockCausal3D(
|
|
||||||
in_channels=in_channels,
|
|
||||||
out_channels=in_channels,
|
|
||||||
dropout=dropout,
|
|
||||||
groups=num_groups,
|
|
||||||
eps=eps,
|
|
||||||
))
|
|
||||||
|
|
||||||
self.attentions = nn.ModuleList(attentions)
|
|
||||||
self.resnets = nn.ModuleList(resnets)
|
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
|
||||||
hidden_states = self.resnets[0](hidden_states)
|
|
||||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
|
||||||
B, C, T, H, W = hidden_states.shape
|
|
||||||
hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
|
|
||||||
attn_mask = prepare_causal_attention_mask(T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B)
|
|
||||||
hidden_states = attn(hidden_states, attn_mask=attn_mask)
|
|
||||||
hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
|
|
||||||
hidden_states = resnet(hidden_states)
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class UpDecoderBlockCausal3D(nn.Module):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
dropout=0.0,
|
|
||||||
num_layers=1,
|
|
||||||
eps=1e-6,
|
|
||||||
num_groups=32,
|
|
||||||
add_upsample=True,
|
|
||||||
upsample_scale_factor=(2, 2, 2),
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
resnets = []
|
|
||||||
for i in range(num_layers):
|
|
||||||
cur_in_channel = in_channels if i == 0 else out_channels
|
|
||||||
resnets.append(
|
|
||||||
ResnetBlockCausal3D(
|
|
||||||
in_channels=cur_in_channel,
|
|
||||||
out_channels=out_channels,
|
|
||||||
groups=num_groups,
|
|
||||||
dropout=dropout,
|
|
||||||
eps=eps,
|
|
||||||
))
|
|
||||||
self.resnets = nn.ModuleList(resnets)
|
|
||||||
|
|
||||||
self.upsamplers = None
|
|
||||||
if add_upsample:
|
|
||||||
self.upsamplers = nn.ModuleList([
|
|
||||||
UpsampleCausal3D(
|
|
||||||
out_channels,
|
|
||||||
use_conv=True,
|
|
||||||
out_channels=out_channels,
|
|
||||||
upsample_factor=upsample_scale_factor,
|
|
||||||
)
|
|
||||||
])
|
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
|
||||||
for resnet in self.resnets:
|
|
||||||
hidden_states = resnet(hidden_states)
|
|
||||||
if self.upsamplers is not None:
|
|
||||||
for upsampler in self.upsamplers:
|
|
||||||
hidden_states = upsampler(hidden_states)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class DecoderCausal3D(nn.Module):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels=16,
|
|
||||||
out_channels=3,
|
|
||||||
eps=1e-6,
|
|
||||||
dropout=0.0,
|
|
||||||
block_out_channels=[128, 256, 512, 512],
|
|
||||||
layers_per_block=2,
|
|
||||||
num_groups=32,
|
|
||||||
time_compression_ratio=4,
|
|
||||||
spatial_compression_ratio=8,
|
|
||||||
gradient_checkpointing=False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.layers_per_block = layers_per_block
|
|
||||||
|
|
||||||
self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
|
|
||||||
self.up_blocks = nn.ModuleList([])
|
|
||||||
|
|
||||||
# mid
|
|
||||||
self.mid_block = UNetMidBlockCausal3D(
|
|
||||||
in_channels=block_out_channels[-1],
|
|
||||||
dropout=dropout,
|
|
||||||
eps=eps,
|
|
||||||
num_groups=num_groups,
|
|
||||||
attention_head_dim=block_out_channels[-1],
|
|
||||||
)
|
|
||||||
|
|
||||||
# up
|
|
||||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
|
||||||
output_channel = reversed_block_out_channels[0]
|
|
||||||
for i in range(len(block_out_channels)):
|
|
||||||
prev_output_channel = output_channel
|
|
||||||
output_channel = reversed_block_out_channels[i]
|
|
||||||
is_final_block = i == len(block_out_channels) - 1
|
|
||||||
num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
|
|
||||||
num_time_upsample_layers = int(np.log2(time_compression_ratio))
|
|
||||||
|
|
||||||
add_spatial_upsample = bool(i < num_spatial_upsample_layers)
|
|
||||||
add_time_upsample = bool(i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block)
|
|
||||||
|
|
||||||
upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
|
|
||||||
upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
|
|
||||||
upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
|
|
||||||
|
|
||||||
up_block = UpDecoderBlockCausal3D(
|
|
||||||
in_channels=prev_output_channel,
|
|
||||||
out_channels=output_channel,
|
|
||||||
dropout=dropout,
|
|
||||||
num_layers=layers_per_block + 1,
|
|
||||||
eps=eps,
|
|
||||||
num_groups=num_groups,
|
|
||||||
add_upsample=bool(add_spatial_upsample or add_time_upsample),
|
|
||||||
upsample_scale_factor=upsample_scale_factor,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.up_blocks.append(up_block)
|
|
||||||
prev_output_channel = output_channel
|
|
||||||
|
|
||||||
# out
|
|
||||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups, eps=eps)
|
|
||||||
self.conv_act = nn.SiLU()
|
|
||||||
self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
|
|
||||||
|
|
||||||
self.gradient_checkpointing = gradient_checkpointing
|
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
|
||||||
hidden_states = self.conv_in(hidden_states)
|
|
||||||
if self.training and self.gradient_checkpointing:
|
|
||||||
|
|
||||||
def create_custom_forward(module):
|
|
||||||
|
|
||||||
def custom_forward(*inputs):
|
|
||||||
return module(*inputs)
|
|
||||||
|
|
||||||
return custom_forward
|
|
||||||
|
|
||||||
# middle
|
|
||||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
|
||||||
create_custom_forward(self.mid_block),
|
|
||||||
hidden_states,
|
|
||||||
use_reentrant=False,
|
|
||||||
)
|
|
||||||
# up
|
|
||||||
for up_block in self.up_blocks:
|
|
||||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
|
||||||
create_custom_forward(up_block),
|
|
||||||
hidden_states,
|
|
||||||
use_reentrant=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# middle
|
|
||||||
hidden_states = self.mid_block(hidden_states)
|
|
||||||
# up
|
|
||||||
for up_block in self.up_blocks:
|
|
||||||
hidden_states = up_block(hidden_states)
|
|
||||||
# post-process
|
|
||||||
hidden_states = self.conv_norm_out(hidden_states)
|
|
||||||
hidden_states = self.conv_act(hidden_states)
|
|
||||||
hidden_states = self.conv_out(hidden_states)
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanVideoVAEDecoder(nn.Module):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels=16,
|
|
||||||
out_channels=3,
|
|
||||||
eps=1e-6,
|
|
||||||
dropout=0.0,
|
|
||||||
block_out_channels=[128, 256, 512, 512],
|
|
||||||
layers_per_block=2,
|
|
||||||
num_groups=32,
|
|
||||||
time_compression_ratio=4,
|
|
||||||
spatial_compression_ratio=8,
|
|
||||||
gradient_checkpointing=False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.decoder = DecoderCausal3D(
|
|
||||||
in_channels=in_channels,
|
|
||||||
out_channels=out_channels,
|
|
||||||
eps=eps,
|
|
||||||
dropout=dropout,
|
|
||||||
block_out_channels=block_out_channels,
|
|
||||||
layers_per_block=layers_per_block,
|
|
||||||
num_groups=num_groups,
|
|
||||||
time_compression_ratio=time_compression_ratio,
|
|
||||||
spatial_compression_ratio=spatial_compression_ratio,
|
|
||||||
gradient_checkpointing=gradient_checkpointing,
|
|
||||||
)
|
|
||||||
self.post_quant_conv = nn.Conv3d(in_channels, in_channels, kernel_size=1)
|
|
||||||
self.scaling_factor = 0.476986
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, latents):
|
|
||||||
latents = latents / self.scaling_factor
|
|
||||||
latents = self.post_quant_conv(latents)
|
|
||||||
dec = self.decoder(latents)
|
|
||||||
return dec
|
|
||||||
|
|
||||||
|
|
||||||
def build_1d_mask(self, length, left_bound, right_bound, border_width):
|
|
||||||
x = torch.ones((length,))
|
|
||||||
if not left_bound:
|
|
||||||
x[:border_width] = (torch.arange(border_width) + 1) / border_width
|
|
||||||
if not right_bound:
|
|
||||||
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def build_mask(self, data, is_bound, border_width):
|
|
||||||
_, _, T, H, W = data.shape
|
|
||||||
t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
|
|
||||||
h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1])
|
|
||||||
w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2])
|
|
||||||
|
|
||||||
t = repeat(t, "T -> T H W", T=T, H=H, W=W)
|
|
||||||
h = repeat(h, "H -> T H W", T=T, H=H, W=W)
|
|
||||||
w = repeat(w, "W -> T H W", T=T, H=H, W=W)
|
|
||||||
|
|
||||||
mask = torch.stack([t, h, w]).min(dim=0).values
|
|
||||||
mask = rearrange(mask, "T H W -> 1 1 T H W")
|
|
||||||
return mask
|
|
||||||
|
|
||||||
|
|
||||||
def tile_forward(self, hidden_states, tile_size, tile_stride):
|
|
||||||
B, C, T, H, W = hidden_states.shape
|
|
||||||
size_t, size_h, size_w = tile_size
|
|
||||||
stride_t, stride_h, stride_w = tile_stride
|
|
||||||
|
|
||||||
# Split tasks
|
|
||||||
tasks = []
|
|
||||||
for t in range(0, T, stride_t):
|
|
||||||
if (t-stride_t >= 0 and t-stride_t+size_t >= T): continue
|
|
||||||
for h in range(0, H, stride_h):
|
|
||||||
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
|
|
||||||
for w in range(0, W, stride_w):
|
|
||||||
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
|
|
||||||
t_, h_, w_ = t + size_t, h + size_h, w + size_w
|
|
||||||
tasks.append((t, t_, h, h_, w, w_))
|
|
||||||
|
|
||||||
# Run
|
|
||||||
torch_dtype = self.post_quant_conv.weight.dtype
|
|
||||||
data_device = hidden_states.device
|
|
||||||
computation_device = self.post_quant_conv.weight.device
|
|
||||||
|
|
||||||
weight = torch.zeros((1, 1, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
|
|
||||||
values = torch.zeros((B, 3, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
|
|
||||||
|
|
||||||
for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
|
|
||||||
hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
|
|
||||||
hidden_states_batch = self.forward(hidden_states_batch).to(data_device)
|
|
||||||
if t > 0:
|
|
||||||
hidden_states_batch = hidden_states_batch[:, :, 1:]
|
|
||||||
|
|
||||||
mask = self.build_mask(
|
|
||||||
hidden_states_batch,
|
|
||||||
is_bound=(t==0, t_>=T, h==0, h_>=H, w==0, w_>=W),
|
|
||||||
border_width=((size_t - stride_t) * 4, (size_h - stride_h) * 8, (size_w - stride_w) * 8)
|
|
||||||
).to(dtype=torch_dtype, device=data_device)
|
|
||||||
|
|
||||||
target_t = 0 if t==0 else t * 4 + 1
|
|
||||||
target_h = h * 8
|
|
||||||
target_w = w * 8
|
|
||||||
values[
|
|
||||||
:,
|
|
||||||
:,
|
|
||||||
target_t: target_t + hidden_states_batch.shape[2],
|
|
||||||
target_h: target_h + hidden_states_batch.shape[3],
|
|
||||||
target_w: target_w + hidden_states_batch.shape[4],
|
|
||||||
] += hidden_states_batch * mask
|
|
||||||
weight[
|
|
||||||
:,
|
|
||||||
:,
|
|
||||||
target_t: target_t + hidden_states_batch.shape[2],
|
|
||||||
target_h: target_h + hidden_states_batch.shape[3],
|
|
||||||
target_w: target_w + hidden_states_batch.shape[4],
|
|
||||||
] += mask
|
|
||||||
return values / weight
|
|
||||||
|
|
||||||
|
|
||||||
def decode_video(self, latents, tile_size=(17, 32, 32), tile_stride=(12, 24, 24)):
|
|
||||||
latents = latents.to(self.post_quant_conv.weight.dtype)
|
|
||||||
return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def state_dict_converter():
|
|
||||||
return HunyuanVideoVAEDecoderStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanVideoVAEDecoderStateDictConverter:
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def from_diffusers(self, state_dict):
|
|
||||||
state_dict_ = {}
|
|
||||||
for name in state_dict:
|
|
||||||
if name.startswith('decoder.') or name.startswith('post_quant_conv.'):
|
|
||||||
state_dict_[name] = state_dict[name]
|
|
||||||
return state_dict_
|
|
||||||
@@ -1,307 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
import numpy as np
|
|
||||||
from tqdm import tqdm
|
|
||||||
from .hunyuan_video_vae_decoder import CausalConv3d, ResnetBlockCausal3D, UNetMidBlockCausal3D
|
|
||||||
|
|
||||||
|
|
||||||
class DownsampleCausal3D(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, channels, out_channels, kernel_size=3, bias=True, stride=2):
|
|
||||||
super().__init__()
|
|
||||||
self.conv = CausalConv3d(channels, out_channels, kernel_size, stride=stride, bias=bias)
|
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
|
||||||
hidden_states = self.conv(hidden_states)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class DownEncoderBlockCausal3D(nn.Module):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
dropout=0.0,
|
|
||||||
num_layers=1,
|
|
||||||
eps=1e-6,
|
|
||||||
num_groups=32,
|
|
||||||
add_downsample=True,
|
|
||||||
downsample_stride=2,
|
|
||||||
):
|
|
||||||
|
|
||||||
super().__init__()
|
|
||||||
resnets = []
|
|
||||||
for i in range(num_layers):
|
|
||||||
cur_in_channel = in_channels if i == 0 else out_channels
|
|
||||||
resnets.append(
|
|
||||||
ResnetBlockCausal3D(
|
|
||||||
in_channels=cur_in_channel,
|
|
||||||
out_channels=out_channels,
|
|
||||||
groups=num_groups,
|
|
||||||
dropout=dropout,
|
|
||||||
eps=eps,
|
|
||||||
))
|
|
||||||
self.resnets = nn.ModuleList(resnets)
|
|
||||||
|
|
||||||
self.downsamplers = None
|
|
||||||
if add_downsample:
|
|
||||||
self.downsamplers = nn.ModuleList([DownsampleCausal3D(
|
|
||||||
out_channels,
|
|
||||||
out_channels,
|
|
||||||
stride=downsample_stride,
|
|
||||||
)])
|
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
|
||||||
for resnet in self.resnets:
|
|
||||||
hidden_states = resnet(hidden_states)
|
|
||||||
|
|
||||||
if self.downsamplers is not None:
|
|
||||||
for downsampler in self.downsamplers:
|
|
||||||
hidden_states = downsampler(hidden_states)
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class EncoderCausal3D(nn.Module):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels: int = 3,
|
|
||||||
out_channels: int = 16,
|
|
||||||
eps=1e-6,
|
|
||||||
dropout=0.0,
|
|
||||||
block_out_channels=[128, 256, 512, 512],
|
|
||||||
layers_per_block=2,
|
|
||||||
num_groups=32,
|
|
||||||
time_compression_ratio: int = 4,
|
|
||||||
spatial_compression_ratio: int = 8,
|
|
||||||
gradient_checkpointing=False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
|
|
||||||
self.down_blocks = nn.ModuleList([])
|
|
||||||
|
|
||||||
# down
|
|
||||||
output_channel = block_out_channels[0]
|
|
||||||
for i in range(len(block_out_channels)):
|
|
||||||
input_channel = output_channel
|
|
||||||
output_channel = block_out_channels[i]
|
|
||||||
is_final_block = i == len(block_out_channels) - 1
|
|
||||||
num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
|
|
||||||
num_time_downsample_layers = int(np.log2(time_compression_ratio))
|
|
||||||
|
|
||||||
add_spatial_downsample = bool(i < num_spatial_downsample_layers)
|
|
||||||
add_time_downsample = bool(i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block)
|
|
||||||
|
|
||||||
downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
|
|
||||||
downsample_stride_T = (2,) if add_time_downsample else (1,)
|
|
||||||
downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
|
|
||||||
down_block = DownEncoderBlockCausal3D(
|
|
||||||
in_channels=input_channel,
|
|
||||||
out_channels=output_channel,
|
|
||||||
dropout=dropout,
|
|
||||||
num_layers=layers_per_block,
|
|
||||||
eps=eps,
|
|
||||||
num_groups=num_groups,
|
|
||||||
add_downsample=bool(add_spatial_downsample or add_time_downsample),
|
|
||||||
downsample_stride=downsample_stride,
|
|
||||||
)
|
|
||||||
self.down_blocks.append(down_block)
|
|
||||||
|
|
||||||
# mid
|
|
||||||
self.mid_block = UNetMidBlockCausal3D(
|
|
||||||
in_channels=block_out_channels[-1],
|
|
||||||
dropout=dropout,
|
|
||||||
eps=eps,
|
|
||||||
num_groups=num_groups,
|
|
||||||
attention_head_dim=block_out_channels[-1],
|
|
||||||
)
|
|
||||||
# out
|
|
||||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups, eps=eps)
|
|
||||||
self.conv_act = nn.SiLU()
|
|
||||||
self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3)
|
|
||||||
|
|
||||||
self.gradient_checkpointing = gradient_checkpointing
|
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
|
||||||
hidden_states = self.conv_in(hidden_states)
|
|
||||||
if self.training and self.gradient_checkpointing:
|
|
||||||
|
|
||||||
def create_custom_forward(module):
|
|
||||||
|
|
||||||
def custom_forward(*inputs):
|
|
||||||
return module(*inputs)
|
|
||||||
|
|
||||||
return custom_forward
|
|
||||||
|
|
||||||
# down
|
|
||||||
for down_block in self.down_blocks:
|
|
||||||
torch.utils.checkpoint.checkpoint(
|
|
||||||
create_custom_forward(down_block),
|
|
||||||
hidden_states,
|
|
||||||
use_reentrant=False,
|
|
||||||
)
|
|
||||||
# middle
|
|
||||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
|
||||||
create_custom_forward(self.mid_block),
|
|
||||||
hidden_states,
|
|
||||||
use_reentrant=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# down
|
|
||||||
for down_block in self.down_blocks:
|
|
||||||
hidden_states = down_block(hidden_states)
|
|
||||||
# middle
|
|
||||||
hidden_states = self.mid_block(hidden_states)
|
|
||||||
# post-process
|
|
||||||
hidden_states = self.conv_norm_out(hidden_states)
|
|
||||||
hidden_states = self.conv_act(hidden_states)
|
|
||||||
hidden_states = self.conv_out(hidden_states)
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanVideoVAEEncoder(nn.Module):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels=3,
|
|
||||||
out_channels=16,
|
|
||||||
eps=1e-6,
|
|
||||||
dropout=0.0,
|
|
||||||
block_out_channels=[128, 256, 512, 512],
|
|
||||||
layers_per_block=2,
|
|
||||||
num_groups=32,
|
|
||||||
time_compression_ratio=4,
|
|
||||||
spatial_compression_ratio=8,
|
|
||||||
gradient_checkpointing=False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.encoder = EncoderCausal3D(
|
|
||||||
in_channels=in_channels,
|
|
||||||
out_channels=out_channels,
|
|
||||||
eps=eps,
|
|
||||||
dropout=dropout,
|
|
||||||
block_out_channels=block_out_channels,
|
|
||||||
layers_per_block=layers_per_block,
|
|
||||||
num_groups=num_groups,
|
|
||||||
time_compression_ratio=time_compression_ratio,
|
|
||||||
spatial_compression_ratio=spatial_compression_ratio,
|
|
||||||
gradient_checkpointing=gradient_checkpointing,
|
|
||||||
)
|
|
||||||
self.quant_conv = nn.Conv3d(2 * out_channels, 2 * out_channels, kernel_size=1)
|
|
||||||
self.scaling_factor = 0.476986
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, images):
|
|
||||||
latents = self.encoder(images)
|
|
||||||
latents = self.quant_conv(latents)
|
|
||||||
latents = latents[:, :16]
|
|
||||||
latents = latents * self.scaling_factor
|
|
||||||
return latents
|
|
||||||
|
|
||||||
|
|
||||||
def build_1d_mask(self, length, left_bound, right_bound, border_width):
|
|
||||||
x = torch.ones((length,))
|
|
||||||
if not left_bound:
|
|
||||||
x[:border_width] = (torch.arange(border_width) + 1) / border_width
|
|
||||||
if not right_bound:
|
|
||||||
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def build_mask(self, data, is_bound, border_width):
|
|
||||||
_, _, T, H, W = data.shape
|
|
||||||
t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
|
|
||||||
h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1])
|
|
||||||
w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2])
|
|
||||||
|
|
||||||
t = repeat(t, "T -> T H W", T=T, H=H, W=W)
|
|
||||||
h = repeat(h, "H -> T H W", T=T, H=H, W=W)
|
|
||||||
w = repeat(w, "W -> T H W", T=T, H=H, W=W)
|
|
||||||
|
|
||||||
mask = torch.stack([t, h, w]).min(dim=0).values
|
|
||||||
mask = rearrange(mask, "T H W -> 1 1 T H W")
|
|
||||||
return mask
|
|
||||||
|
|
||||||
|
|
||||||
def tile_forward(self, hidden_states, tile_size, tile_stride):
|
|
||||||
B, C, T, H, W = hidden_states.shape
|
|
||||||
size_t, size_h, size_w = tile_size
|
|
||||||
stride_t, stride_h, stride_w = tile_stride
|
|
||||||
|
|
||||||
# Split tasks
|
|
||||||
tasks = []
|
|
||||||
for t in range(0, T, stride_t):
|
|
||||||
if (t-stride_t >= 0 and t-stride_t+size_t >= T): continue
|
|
||||||
for h in range(0, H, stride_h):
|
|
||||||
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
|
|
||||||
for w in range(0, W, stride_w):
|
|
||||||
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
|
|
||||||
t_, h_, w_ = t + size_t, h + size_h, w + size_w
|
|
||||||
tasks.append((t, t_, h, h_, w, w_))
|
|
||||||
|
|
||||||
# Run
|
|
||||||
torch_dtype = self.quant_conv.weight.dtype
|
|
||||||
data_device = hidden_states.device
|
|
||||||
computation_device = self.quant_conv.weight.device
|
|
||||||
|
|
||||||
weight = torch.zeros((1, 1, (T - 1) // 4 + 1, H // 8, W // 8), dtype=torch_dtype, device=data_device)
|
|
||||||
values = torch.zeros((B, 16, (T - 1) // 4 + 1, H // 8, W // 8), dtype=torch_dtype, device=data_device)
|
|
||||||
|
|
||||||
for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
|
|
||||||
hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
|
|
||||||
hidden_states_batch = self.forward(hidden_states_batch).to(data_device)
|
|
||||||
if t > 0:
|
|
||||||
hidden_states_batch = hidden_states_batch[:, :, 1:]
|
|
||||||
|
|
||||||
mask = self.build_mask(
|
|
||||||
hidden_states_batch,
|
|
||||||
is_bound=(t==0, t_>=T, h==0, h_>=H, w==0, w_>=W),
|
|
||||||
border_width=((size_t - stride_t) // 4, (size_h - stride_h) // 8, (size_w - stride_w) // 8)
|
|
||||||
).to(dtype=torch_dtype, device=data_device)
|
|
||||||
|
|
||||||
target_t = 0 if t==0 else t // 4 + 1
|
|
||||||
target_h = h // 8
|
|
||||||
target_w = w // 8
|
|
||||||
values[
|
|
||||||
:,
|
|
||||||
:,
|
|
||||||
target_t: target_t + hidden_states_batch.shape[2],
|
|
||||||
target_h: target_h + hidden_states_batch.shape[3],
|
|
||||||
target_w: target_w + hidden_states_batch.shape[4],
|
|
||||||
] += hidden_states_batch * mask
|
|
||||||
weight[
|
|
||||||
:,
|
|
||||||
:,
|
|
||||||
target_t: target_t + hidden_states_batch.shape[2],
|
|
||||||
target_h: target_h + hidden_states_batch.shape[3],
|
|
||||||
target_w: target_w + hidden_states_batch.shape[4],
|
|
||||||
] += mask
|
|
||||||
return values / weight
|
|
||||||
|
|
||||||
|
|
||||||
def encode_video(self, latents, tile_size=(65, 256, 256), tile_stride=(48, 192, 192)):
|
|
||||||
latents = latents.to(self.quant_conv.weight.dtype)
|
|
||||||
return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride)
|
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def state_dict_converter():
|
|
||||||
return HunyuanVideoVAEEncoderStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanVideoVAEEncoderStateDictConverter:
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def from_diffusers(self, state_dict):
|
|
||||||
state_dict_ = {}
|
|
||||||
for name in state_dict:
|
|
||||||
if name.startswith('encoder.') or name.startswith('quant_conv.'):
|
|
||||||
state_dict_[name] = state_dict[name]
|
|
||||||
return state_dict_
|
|
||||||
File diff suppressed because one or more lines are too long
902
diffsynth/models/longcat_video_dit.py
Normal file
902
diffsynth/models/longcat_video_dit.py
Normal file
@@ -0,0 +1,902 @@
|
|||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.amp as amp
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from .wan_video_dit import flash_attention
|
||||||
|
from ..core.device.npu_compatible_device import get_device_type
|
||||||
|
from ..core.gradient import gradient_checkpoint_forward
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm_FP32(torch.nn.Module):
|
||||||
|
def __init__(self, dim: int, eps: float):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
self.weight = nn.Parameter(torch.ones(dim))
|
||||||
|
|
||||||
|
def _norm(self, x):
|
||||||
|
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
output = self._norm(x.float()).type_as(x)
|
||||||
|
return output * self.weight
|
||||||
|
|
||||||
|
|
||||||
|
def broadcat(tensors, dim=-1):
|
||||||
|
num_tensors = len(tensors)
|
||||||
|
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
|
||||||
|
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
|
||||||
|
shape_len = list(shape_lens)[0]
|
||||||
|
dim = (dim + shape_len) if dim < 0 else dim
|
||||||
|
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
|
||||||
|
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
||||||
|
assert all(
|
||||||
|
[*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
|
||||||
|
), "invalid dimensions for broadcastable concatentation"
|
||||||
|
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
|
||||||
|
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
|
||||||
|
expanded_dims.insert(dim, (dim, dims[dim]))
|
||||||
|
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
|
||||||
|
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
|
||||||
|
return torch.cat(tensors, dim=dim)
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_half(x):
|
||||||
|
x = rearrange(x, "... (d r) -> ... d r", r=2)
|
||||||
|
x1, x2 = x.unbind(dim=-1)
|
||||||
|
x = torch.stack((-x2, x1), dim=-1)
|
||||||
|
return rearrange(x, "... d r -> ... (d r)")
|
||||||
|
|
||||||
|
|
||||||
|
class RotaryPositionalEmbedding(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
head_dim,
|
||||||
|
cp_split_hw=None
|
||||||
|
):
|
||||||
|
"""Rotary positional embedding for 3D
|
||||||
|
Reference : https://blog.eleuther.ai/rotary-embeddings/
|
||||||
|
Paper: https://arxiv.org/pdf/2104.09864.pdf
|
||||||
|
Args:
|
||||||
|
dim: Dimension of embedding
|
||||||
|
base: Base value for exponential
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.head_dim = head_dim
|
||||||
|
assert self.head_dim % 8 == 0, 'Dim must be a multiply of 8 for 3D RoPE.'
|
||||||
|
self.cp_split_hw = cp_split_hw
|
||||||
|
# We take the assumption that the longest side of grid will not larger than 512, i.e, 512 * 8 = 4098 input pixels
|
||||||
|
self.base = 10000
|
||||||
|
self.freqs_dict = {}
|
||||||
|
|
||||||
|
def register_grid_size(self, grid_size):
|
||||||
|
if grid_size not in self.freqs_dict:
|
||||||
|
self.freqs_dict.update({
|
||||||
|
grid_size: self.precompute_freqs_cis_3d(grid_size)
|
||||||
|
})
|
||||||
|
|
||||||
|
def precompute_freqs_cis_3d(self, grid_size):
|
||||||
|
num_frames, height, width = grid_size
|
||||||
|
dim_t = self.head_dim - 4 * (self.head_dim // 6)
|
||||||
|
dim_h = 2 * (self.head_dim // 6)
|
||||||
|
dim_w = 2 * (self.head_dim // 6)
|
||||||
|
freqs_t = 1.0 / (self.base ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t))
|
||||||
|
freqs_h = 1.0 / (self.base ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h))
|
||||||
|
freqs_w = 1.0 / (self.base ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w))
|
||||||
|
grid_t = np.linspace(0, num_frames, num_frames, endpoint=False, dtype=np.float32)
|
||||||
|
grid_h = np.linspace(0, height, height, endpoint=False, dtype=np.float32)
|
||||||
|
grid_w = np.linspace(0, width, width, endpoint=False, dtype=np.float32)
|
||||||
|
grid_t = torch.from_numpy(grid_t).float()
|
||||||
|
grid_h = torch.from_numpy(grid_h).float()
|
||||||
|
grid_w = torch.from_numpy(grid_w).float()
|
||||||
|
freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t)
|
||||||
|
freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h)
|
||||||
|
freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w)
|
||||||
|
freqs_t = repeat(freqs_t, "... n -> ... (n r)", r=2)
|
||||||
|
freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
|
||||||
|
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
|
||||||
|
freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
|
||||||
|
# (T H W D)
|
||||||
|
freqs = rearrange(freqs, "T H W D -> (T H W) D")
|
||||||
|
# if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
|
||||||
|
# with torch.no_grad():
|
||||||
|
# freqs = rearrange(freqs, "(T H W) D -> T H W D", T=num_frames, H=height, W=width)
|
||||||
|
# freqs = context_parallel_util.split_cp_2d(freqs, seq_dim_hw=(1, 2), split_hw=self.cp_split_hw)
|
||||||
|
# freqs = rearrange(freqs, "T H W D -> (T H W) D")
|
||||||
|
|
||||||
|
return freqs
|
||||||
|
|
||||||
|
def forward(self, q, k, grid_size):
|
||||||
|
"""3D RoPE.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: [B, head, seq, head_dim]
|
||||||
|
key: [B, head, seq, head_dim]
|
||||||
|
Returns:
|
||||||
|
query and key with the same shape as input.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if grid_size not in self.freqs_dict:
|
||||||
|
self.register_grid_size(grid_size)
|
||||||
|
|
||||||
|
freqs_cis = self.freqs_dict[grid_size].to(q.device)
|
||||||
|
q_, k_ = q.float(), k.float()
|
||||||
|
freqs_cis = freqs_cis.float().to(q.device)
|
||||||
|
cos, sin = freqs_cis.cos(), freqs_cis.sin()
|
||||||
|
cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
|
||||||
|
q_ = (q_ * cos) + (rotate_half(q_) * sin)
|
||||||
|
k_ = (k_ * cos) + (rotate_half(k_) * sin)
|
||||||
|
|
||||||
|
return q_.type_as(q), k_.type_as(k)
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
num_heads: int,
|
||||||
|
enable_flashattn3: bool = False,
|
||||||
|
enable_flashattn2: bool = False,
|
||||||
|
enable_xformers: bool = False,
|
||||||
|
enable_bsa: bool = False,
|
||||||
|
bsa_params: dict = None,
|
||||||
|
cp_split_hw: Optional[List[int]] = None
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
||||||
|
self.dim = dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = dim // num_heads
|
||||||
|
self.scale = self.head_dim**-0.5
|
||||||
|
self.enable_flashattn3 = enable_flashattn3
|
||||||
|
self.enable_flashattn2 = enable_flashattn2
|
||||||
|
self.enable_xformers = enable_xformers
|
||||||
|
self.enable_bsa = enable_bsa
|
||||||
|
self.bsa_params = bsa_params
|
||||||
|
self.cp_split_hw = cp_split_hw
|
||||||
|
|
||||||
|
self.qkv = nn.Linear(dim, dim * 3, bias=True)
|
||||||
|
self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
|
||||||
|
self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
|
||||||
|
self.proj = nn.Linear(dim, dim)
|
||||||
|
|
||||||
|
self.rope_3d = RotaryPositionalEmbedding(
|
||||||
|
self.head_dim,
|
||||||
|
cp_split_hw=cp_split_hw
|
||||||
|
)
|
||||||
|
|
||||||
|
def _process_attn(self, q, k, v, shape):
|
||||||
|
q = rearrange(q, "B H S D -> B S (H D)")
|
||||||
|
k = rearrange(k, "B H S D -> B S (H D)")
|
||||||
|
v = rearrange(v, "B H S D -> B S (H D)")
|
||||||
|
x = flash_attention(q, k, v, num_heads=self.num_heads)
|
||||||
|
x = rearrange(x, "B S (H D) -> B H S D", H=self.num_heads)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, shape=None, num_cond_latents=None, return_kv=False) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
B, N, C = x.shape
|
||||||
|
qkv = self.qkv(x)
|
||||||
|
|
||||||
|
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
|
||||||
|
qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4)) # [3, B, H, N, D]
|
||||||
|
q, k, v = qkv.unbind(0)
|
||||||
|
q, k = self.q_norm(q), self.k_norm(k)
|
||||||
|
|
||||||
|
if return_kv:
|
||||||
|
k_cache, v_cache = k.clone(), v.clone()
|
||||||
|
|
||||||
|
q, k = self.rope_3d(q, k, shape)
|
||||||
|
|
||||||
|
# cond mode
|
||||||
|
if num_cond_latents is not None and num_cond_latents > 0:
|
||||||
|
num_cond_latents_thw = num_cond_latents * (N // shape[0])
|
||||||
|
# process the condition tokens
|
||||||
|
q_cond = q[:, :, :num_cond_latents_thw].contiguous()
|
||||||
|
k_cond = k[:, :, :num_cond_latents_thw].contiguous()
|
||||||
|
v_cond = v[:, :, :num_cond_latents_thw].contiguous()
|
||||||
|
x_cond = self._process_attn(q_cond, k_cond, v_cond, shape)
|
||||||
|
# process the noise tokens
|
||||||
|
q_noise = q[:, :, num_cond_latents_thw:].contiguous()
|
||||||
|
x_noise = self._process_attn(q_noise, k, v, shape)
|
||||||
|
# merge x_cond and x_noise
|
||||||
|
x = torch.cat([x_cond, x_noise], dim=2).contiguous()
|
||||||
|
else:
|
||||||
|
x = self._process_attn(q, k, v, shape)
|
||||||
|
|
||||||
|
x_output_shape = (B, N, C)
|
||||||
|
x = x.transpose(1, 2) # [B, H, N, D] --> [B, N, H, D]
|
||||||
|
x = x.reshape(x_output_shape) # [B, N, H, D] --> [B, N, C]
|
||||||
|
x = self.proj(x)
|
||||||
|
|
||||||
|
if return_kv:
|
||||||
|
return x, (k_cache, v_cache)
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward_with_kv_cache(self, x: torch.Tensor, shape=None, num_cond_latents=None, kv_cache=None) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
B, N, C = x.shape
|
||||||
|
qkv = self.qkv(x)
|
||||||
|
|
||||||
|
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
|
||||||
|
qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4)) # [3, B, H, N, D]
|
||||||
|
q, k, v = qkv.unbind(0)
|
||||||
|
q, k = self.q_norm(q), self.k_norm(k)
|
||||||
|
|
||||||
|
T, H, W = shape
|
||||||
|
k_cache, v_cache = kv_cache
|
||||||
|
assert k_cache.shape[0] == v_cache.shape[0] and k_cache.shape[0] in [1, B]
|
||||||
|
if k_cache.shape[0] == 1:
|
||||||
|
k_cache = k_cache.repeat(B, 1, 1, 1)
|
||||||
|
v_cache = v_cache.repeat(B, 1, 1, 1)
|
||||||
|
|
||||||
|
if num_cond_latents is not None and num_cond_latents > 0:
|
||||||
|
k_full = torch.cat([k_cache, k], dim=2).contiguous()
|
||||||
|
v_full = torch.cat([v_cache, v], dim=2).contiguous()
|
||||||
|
q_padding = torch.cat([torch.empty_like(k_cache), q], dim=2).contiguous()
|
||||||
|
q_padding, k_full = self.rope_3d(q_padding, k_full, (T + num_cond_latents, H, W))
|
||||||
|
q = q_padding[:, :, -N:].contiguous()
|
||||||
|
|
||||||
|
x = self._process_attn(q, k_full, v_full, shape)
|
||||||
|
|
||||||
|
x_output_shape = (B, N, C)
|
||||||
|
x = x.transpose(1, 2) # [B, H, N, D] --> [B, N, H, D]
|
||||||
|
x = x.reshape(x_output_shape) # [B, N, H, D] --> [B, N, C]
|
||||||
|
x = self.proj(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadCrossAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
num_heads,
|
||||||
|
enable_flashattn3=False,
|
||||||
|
enable_flashattn2=False,
|
||||||
|
enable_xformers=False,
|
||||||
|
):
|
||||||
|
super(MultiHeadCrossAttention, self).__init__()
|
||||||
|
assert dim % num_heads == 0, "d_model must be divisible by num_heads"
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = dim // num_heads
|
||||||
|
|
||||||
|
self.q_linear = nn.Linear(dim, dim)
|
||||||
|
self.kv_linear = nn.Linear(dim, dim * 2)
|
||||||
|
self.proj = nn.Linear(dim, dim)
|
||||||
|
|
||||||
|
self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
|
||||||
|
self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
|
||||||
|
|
||||||
|
self.enable_flashattn3 = enable_flashattn3
|
||||||
|
self.enable_flashattn2 = enable_flashattn2
|
||||||
|
self.enable_xformers = enable_xformers
|
||||||
|
|
||||||
|
def _process_cross_attn(self, x, cond, kv_seqlen):
|
||||||
|
B, N, C = x.shape
|
||||||
|
assert C == self.dim and cond.shape[2] == self.dim
|
||||||
|
|
||||||
|
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
|
||||||
|
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
|
||||||
|
k, v = kv.unbind(2)
|
||||||
|
|
||||||
|
q, k = self.q_norm(q), self.k_norm(k)
|
||||||
|
|
||||||
|
q = rearrange(q, "B S H D -> B S (H D)")
|
||||||
|
k = rearrange(k, "B S H D -> B S (H D)")
|
||||||
|
v = rearrange(v, "B S H D -> B S (H D)")
|
||||||
|
x = flash_attention(q, k, v, num_heads=self.num_heads)
|
||||||
|
|
||||||
|
x = x.view(B, -1, C)
|
||||||
|
x = self.proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x, cond, kv_seqlen, num_cond_latents=None, shape=None):
|
||||||
|
"""
|
||||||
|
x: [B, N, C]
|
||||||
|
cond: [B, M, C]
|
||||||
|
"""
|
||||||
|
if num_cond_latents is None or num_cond_latents == 0:
|
||||||
|
return self._process_cross_attn(x, cond, kv_seqlen)
|
||||||
|
else:
|
||||||
|
B, N, C = x.shape
|
||||||
|
if num_cond_latents is not None and num_cond_latents > 0:
|
||||||
|
assert shape is not None, "SHOULD pass in the shape"
|
||||||
|
num_cond_latents_thw = num_cond_latents * (N // shape[0])
|
||||||
|
x_noise = x[:, num_cond_latents_thw:] # [B, N_noise, C]
|
||||||
|
output_noise = self._process_cross_attn(x_noise, cond, kv_seqlen) # [B, N_noise, C]
|
||||||
|
output = torch.cat([
|
||||||
|
torch.zeros((B, num_cond_latents_thw, C), dtype=output_noise.dtype, device=output_noise.device),
|
||||||
|
output_noise
|
||||||
|
], dim=1).contiguous()
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm_FP32(nn.LayerNorm):
|
||||||
|
def __init__(self, dim, eps, elementwise_affine):
|
||||||
|
super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||||
|
|
||||||
|
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||||
|
origin_dtype = inputs.dtype
|
||||||
|
out = F.layer_norm(
|
||||||
|
inputs.float(),
|
||||||
|
self.normalized_shape,
|
||||||
|
None if self.weight is None else self.weight.float(),
|
||||||
|
None if self.bias is None else self.bias.float() ,
|
||||||
|
self.eps
|
||||||
|
).to(origin_dtype)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def modulate_fp32(norm_func, x, shift, scale):
|
||||||
|
# Suppose x is (B, N, D), shift is (B, -1, D), scale is (B, -1, D)
|
||||||
|
# ensure the modulation params be fp32
|
||||||
|
assert shift.dtype == torch.float32, scale.dtype == torch.float32
|
||||||
|
dtype = x.dtype
|
||||||
|
x = norm_func(x.to(torch.float32))
|
||||||
|
x = x * (scale + 1) + shift
|
||||||
|
x = x.to(dtype)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class FinalLayer_FP32(nn.Module):
|
||||||
|
"""
|
||||||
|
The final layer of DiT.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hidden_size, num_patch, out_channels, adaln_tembed_dim):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_patch = num_patch
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.adaln_tembed_dim = adaln_tembed_dim
|
||||||
|
|
||||||
|
self.norm_final = LayerNorm_FP32(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
|
||||||
|
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(adaln_tembed_dim, 2 * hidden_size, bias=True))
|
||||||
|
|
||||||
|
def forward(self, x, t, latent_shape):
|
||||||
|
# timestep shape: [B, T, C]
|
||||||
|
assert t.dtype == torch.float32
|
||||||
|
B, N, C = x.shape
|
||||||
|
T, _, _ = latent_shape
|
||||||
|
|
||||||
|
with amp.autocast(get_device_type(), dtype=torch.float32):
|
||||||
|
shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C]
|
||||||
|
x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)
|
||||||
|
x = self.linear(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForwardSwiGLU(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
hidden_dim: int,
|
||||||
|
multiple_of: int = 256,
|
||||||
|
ffn_dim_multiplier: Optional[float] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
hidden_dim = int(2 * hidden_dim / 3)
|
||||||
|
# custom dim factor multiplier
|
||||||
|
if ffn_dim_multiplier is not None:
|
||||||
|
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||||
|
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.hidden_dim = hidden_dim
|
||||||
|
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||||
|
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||||
|
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEmbedder(nn.Module):
|
||||||
|
"""
|
||||||
|
Embeds scalar timesteps into vector representations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, t_embed_dim, frequency_embedding_size=256):
|
||||||
|
super().__init__()
|
||||||
|
self.t_embed_dim = t_embed_dim
|
||||||
|
self.frequency_embedding_size = frequency_embedding_size
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
nn.Linear(frequency_embedding_size, t_embed_dim, bias=True),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(t_embed_dim, t_embed_dim, bias=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def timestep_embedding(t, dim, max_period=10000):
|
||||||
|
"""
|
||||||
|
Create sinusoidal timestep embeddings.
|
||||||
|
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||||
|
These may be fractional.
|
||||||
|
:param dim: the dimension of the output.
|
||||||
|
:param max_period: controls the minimum frequency of the embeddings.
|
||||||
|
:return: an (N, D) Tensor of positional embeddings.
|
||||||
|
"""
|
||||||
|
half = dim // 2
|
||||||
|
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
|
||||||
|
freqs = freqs.to(device=t.device)
|
||||||
|
args = t[:, None].float() * freqs[None]
|
||||||
|
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||||
|
if dim % 2:
|
||||||
|
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
def forward(self, t, dtype):
|
||||||
|
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||||
|
if t_freq.dtype != dtype:
|
||||||
|
t_freq = t_freq.to(dtype)
|
||||||
|
t_emb = self.mlp(t_freq)
|
||||||
|
return t_emb
|
||||||
|
|
||||||
|
|
||||||
|
class CaptionEmbedder(nn.Module):
|
||||||
|
"""
|
||||||
|
Embeds class labels into vector representations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels, hidden_size):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.y_proj = nn.Sequential(
|
||||||
|
nn.Linear(in_channels, hidden_size, bias=True),
|
||||||
|
nn.GELU(approximate="tanh"),
|
||||||
|
nn.Linear(hidden_size, hidden_size, bias=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, caption):
|
||||||
|
B, _, N, C = caption.shape
|
||||||
|
caption = self.y_proj(caption)
|
||||||
|
return caption
|
||||||
|
|
||||||
|
|
||||||
|
class PatchEmbed3D(nn.Module):
|
||||||
|
"""Video to Patch Embedding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
patch_size (int): Patch token size. Default: (2,4,4).
|
||||||
|
in_chans (int): Number of input video channels. Default: 3.
|
||||||
|
embed_dim (int): Number of linear projection output channels. Default: 96.
|
||||||
|
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
patch_size=(2, 4, 4),
|
||||||
|
in_chans=3,
|
||||||
|
embed_dim=96,
|
||||||
|
norm_layer=None,
|
||||||
|
flatten=True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.flatten = flatten
|
||||||
|
|
||||||
|
self.in_chans = in_chans
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
|
||||||
|
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||||
|
if norm_layer is not None:
|
||||||
|
self.norm = norm_layer(embed_dim)
|
||||||
|
else:
|
||||||
|
self.norm = None
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""Forward function."""
|
||||||
|
# padding
|
||||||
|
_, _, D, H, W = x.size()
|
||||||
|
if W % self.patch_size[2] != 0:
|
||||||
|
x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
|
||||||
|
if H % self.patch_size[1] != 0:
|
||||||
|
x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
|
||||||
|
if D % self.patch_size[0] != 0:
|
||||||
|
x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
|
||||||
|
|
||||||
|
B, C, T, H, W = x.shape
|
||||||
|
x = self.proj(x) # (B C T H W)
|
||||||
|
if self.norm is not None:
|
||||||
|
D, Wh, Ww = x.size(2), x.size(3), x.size(4)
|
||||||
|
x = x.flatten(2).transpose(1, 2)
|
||||||
|
x = self.norm(x)
|
||||||
|
x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
|
||||||
|
if self.flatten:
|
||||||
|
x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class LongCatSingleStreamBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
mlp_ratio: int,
|
||||||
|
adaln_tembed_dim: int,
|
||||||
|
enable_flashattn3: bool = False,
|
||||||
|
enable_flashattn2: bool = False,
|
||||||
|
enable_xformers: bool = False,
|
||||||
|
enable_bsa: bool = False,
|
||||||
|
bsa_params=None,
|
||||||
|
cp_split_hw=None
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
|
||||||
|
# scale and gate modulation
|
||||||
|
self.adaLN_modulation = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(adaln_tembed_dim, 6 * hidden_size, bias=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.mod_norm_attn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False)
|
||||||
|
self.mod_norm_ffn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False)
|
||||||
|
self.pre_crs_attn_norm = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=True)
|
||||||
|
|
||||||
|
self.attn = Attention(
|
||||||
|
dim=hidden_size,
|
||||||
|
num_heads=num_heads,
|
||||||
|
enable_flashattn3=enable_flashattn3,
|
||||||
|
enable_flashattn2=enable_flashattn2,
|
||||||
|
enable_xformers=enable_xformers,
|
||||||
|
enable_bsa=enable_bsa,
|
||||||
|
bsa_params=bsa_params,
|
||||||
|
cp_split_hw=cp_split_hw
|
||||||
|
)
|
||||||
|
self.cross_attn = MultiHeadCrossAttention(
|
||||||
|
dim=hidden_size,
|
||||||
|
num_heads=num_heads,
|
||||||
|
enable_flashattn3=enable_flashattn3,
|
||||||
|
enable_flashattn2=enable_flashattn2,
|
||||||
|
enable_xformers=enable_xformers,
|
||||||
|
)
|
||||||
|
self.ffn = FeedForwardSwiGLU(dim=hidden_size, hidden_dim=int(hidden_size * mlp_ratio))
|
||||||
|
|
||||||
|
def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return_kv=False, kv_cache=None, skip_crs_attn=False):
|
||||||
|
"""
|
||||||
|
x: [B, N, C]
|
||||||
|
y: [1, N_valid_tokens, C]
|
||||||
|
t: [B, T, C_t]
|
||||||
|
y_seqlen: [B]; type of a list
|
||||||
|
latent_shape: latent shape of a single item
|
||||||
|
"""
|
||||||
|
x_dtype = x.dtype
|
||||||
|
|
||||||
|
B, N, C = x.shape
|
||||||
|
T, _, _ = latent_shape # S != T*H*W in case of CP split on H*W.
|
||||||
|
|
||||||
|
# compute modulation params in fp32
|
||||||
|
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
|
||||||
|
shift_msa, scale_msa, gate_msa, \
|
||||||
|
shift_mlp, scale_mlp, gate_mlp = \
|
||||||
|
self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C]
|
||||||
|
|
||||||
|
# self attn with modulation
|
||||||
|
x_m = modulate_fp32(self.mod_norm_attn, x.view(B, T, -1, C), shift_msa, scale_msa).view(B, N, C)
|
||||||
|
|
||||||
|
if kv_cache is not None:
|
||||||
|
kv_cache = (kv_cache[0].to(x.device), kv_cache[1].to(x.device))
|
||||||
|
attn_outputs = self.attn.forward_with_kv_cache(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, kv_cache=kv_cache)
|
||||||
|
else:
|
||||||
|
attn_outputs = self.attn(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, return_kv=return_kv)
|
||||||
|
|
||||||
|
if return_kv:
|
||||||
|
x_s, kv_cache = attn_outputs
|
||||||
|
else:
|
||||||
|
x_s = attn_outputs
|
||||||
|
|
||||||
|
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
|
||||||
|
x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
|
||||||
|
x = x.to(x_dtype)
|
||||||
|
|
||||||
|
# cross attn
|
||||||
|
if not skip_crs_attn:
|
||||||
|
if kv_cache is not None:
|
||||||
|
num_cond_latents = None
|
||||||
|
x = x + self.cross_attn(self.pre_crs_attn_norm(x), y, y_seqlen, num_cond_latents=num_cond_latents, shape=latent_shape)
|
||||||
|
|
||||||
|
# ffn with modulation
|
||||||
|
x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C)
|
||||||
|
x_s = self.ffn(x_m)
|
||||||
|
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
|
||||||
|
x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
|
||||||
|
x = x.to(x_dtype)
|
||||||
|
|
||||||
|
if return_kv:
|
||||||
|
return x, kv_cache
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class LongCatVideoTransformer3DModel(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int = 16,
|
||||||
|
out_channels: int = 16,
|
||||||
|
hidden_size: int = 4096,
|
||||||
|
depth: int = 48,
|
||||||
|
num_heads: int = 32,
|
||||||
|
caption_channels: int = 4096,
|
||||||
|
mlp_ratio: int = 4,
|
||||||
|
adaln_tembed_dim: int = 512,
|
||||||
|
frequency_embedding_size: int = 256,
|
||||||
|
# default params
|
||||||
|
patch_size: Tuple[int] = (1, 2, 2),
|
||||||
|
# attention config
|
||||||
|
enable_flashattn3: bool = False,
|
||||||
|
enable_flashattn2: bool = True,
|
||||||
|
enable_xformers: bool = False,
|
||||||
|
enable_bsa: bool = False,
|
||||||
|
bsa_params: dict = {'sparsity': 0.9375, 'chunk_3d_shape_q': [4, 4, 4], 'chunk_3d_shape_k': [4, 4, 4]},
|
||||||
|
cp_split_hw: Optional[List[int]] = [1, 1],
|
||||||
|
text_tokens_zero_pad: bool = True,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.cp_split_hw = cp_split_hw
|
||||||
|
|
||||||
|
self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
|
||||||
|
self.t_embedder = TimestepEmbedder(t_embed_dim=adaln_tembed_dim, frequency_embedding_size=frequency_embedding_size)
|
||||||
|
self.y_embedder = CaptionEmbedder(
|
||||||
|
in_channels=caption_channels,
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
LongCatSingleStreamBlock(
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
num_heads=num_heads,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
adaln_tembed_dim=adaln_tembed_dim,
|
||||||
|
enable_flashattn3=enable_flashattn3,
|
||||||
|
enable_flashattn2=enable_flashattn2,
|
||||||
|
enable_xformers=enable_xformers,
|
||||||
|
enable_bsa=enable_bsa,
|
||||||
|
bsa_params=bsa_params,
|
||||||
|
cp_split_hw=cp_split_hw
|
||||||
|
)
|
||||||
|
for i in range(depth)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.final_layer = FinalLayer_FP32(
|
||||||
|
hidden_size,
|
||||||
|
np.prod(self.patch_size),
|
||||||
|
out_channels,
|
||||||
|
adaln_tembed_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
self.text_tokens_zero_pad = text_tokens_zero_pad
|
||||||
|
|
||||||
|
self.lora_dict = {}
|
||||||
|
self.active_loras = []
|
||||||
|
|
||||||
|
def enable_loras(self, lora_key_list=[]):
|
||||||
|
self.disable_all_loras()
|
||||||
|
|
||||||
|
module_loras = {} # {module_name: [lora1, lora2, ...]}
|
||||||
|
model_device = next(self.parameters()).device
|
||||||
|
model_dtype = next(self.parameters()).dtype
|
||||||
|
|
||||||
|
for lora_key in lora_key_list:
|
||||||
|
if lora_key in self.lora_dict:
|
||||||
|
for lora in self.lora_dict[lora_key].loras:
|
||||||
|
lora.to(model_device, dtype=model_dtype, non_blocking=True)
|
||||||
|
module_name = lora.lora_name.replace("lora___lorahyphen___", "").replace("___lorahyphen___", ".")
|
||||||
|
if module_name not in module_loras:
|
||||||
|
module_loras[module_name] = []
|
||||||
|
module_loras[module_name].append(lora)
|
||||||
|
self.active_loras.append(lora_key)
|
||||||
|
|
||||||
|
for module_name, loras in module_loras.items():
|
||||||
|
module = self._get_module_by_name(module_name)
|
||||||
|
if not hasattr(module, 'org_forward'):
|
||||||
|
module.org_forward = module.forward
|
||||||
|
module.forward = self._create_multi_lora_forward(module, loras)
|
||||||
|
|
||||||
|
def _create_multi_lora_forward(self, module, loras):
|
||||||
|
def multi_lora_forward(x, *args, **kwargs):
|
||||||
|
weight_dtype = x.dtype
|
||||||
|
org_output = module.org_forward(x, *args, **kwargs)
|
||||||
|
|
||||||
|
total_lora_output = 0
|
||||||
|
for lora in loras:
|
||||||
|
if lora.use_lora:
|
||||||
|
lx = lora.lora_down(x.to(lora.lora_down.weight.dtype))
|
||||||
|
lx = lora.lora_up(lx)
|
||||||
|
lora_output = lx.to(weight_dtype) * lora.multiplier * lora.alpha_scale
|
||||||
|
total_lora_output += lora_output
|
||||||
|
|
||||||
|
return org_output + total_lora_output
|
||||||
|
|
||||||
|
return multi_lora_forward
|
||||||
|
|
||||||
|
def _get_module_by_name(self, module_name):
|
||||||
|
try:
|
||||||
|
module = self
|
||||||
|
for part in module_name.split('.'):
|
||||||
|
module = getattr(module, part)
|
||||||
|
return module
|
||||||
|
except AttributeError as e:
|
||||||
|
raise ValueError(f"Cannot find module: {module_name}, error: {e}")
|
||||||
|
|
||||||
|
def disable_all_loras(self):
|
||||||
|
for name, module in self.named_modules():
|
||||||
|
if hasattr(module, 'org_forward'):
|
||||||
|
module.forward = module.org_forward
|
||||||
|
delattr(module, 'org_forward')
|
||||||
|
|
||||||
|
for lora_key, lora_network in self.lora_dict.items():
|
||||||
|
for lora in lora_network.loras:
|
||||||
|
lora.to("cpu")
|
||||||
|
|
||||||
|
self.active_loras.clear()
|
||||||
|
|
||||||
|
def enable_bsa(self,):
|
||||||
|
for block in self.blocks:
|
||||||
|
block.attn.enable_bsa = True
|
||||||
|
|
||||||
|
def disable_bsa(self,):
|
||||||
|
for block in self.blocks:
|
||||||
|
block.attn.enable_bsa = False
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
timestep,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
num_cond_latents=0,
|
||||||
|
return_kv=False,
|
||||||
|
kv_cache_dict={},
|
||||||
|
skip_crs_attn=False,
|
||||||
|
offload_kv_cache=False,
|
||||||
|
use_gradient_checkpointing=False,
|
||||||
|
use_gradient_checkpointing_offload=False,
|
||||||
|
):
|
||||||
|
|
||||||
|
B, _, T, H, W = hidden_states.shape
|
||||||
|
|
||||||
|
N_t = T // self.patch_size[0]
|
||||||
|
N_h = H // self.patch_size[1]
|
||||||
|
N_w = W // self.patch_size[2]
|
||||||
|
|
||||||
|
assert self.patch_size[0]==1, "Currently, 3D x_embedder should not compress the temporal dimension."
|
||||||
|
|
||||||
|
# expand the shape of timestep from [B] to [B, T]
|
||||||
|
if len(timestep.shape) == 1:
|
||||||
|
timestep = timestep.unsqueeze(1).expand(-1, N_t).clone() # [B, T]
|
||||||
|
timestep[:, :num_cond_latents] = 0
|
||||||
|
|
||||||
|
dtype = hidden_states.dtype
|
||||||
|
hidden_states = hidden_states.to(dtype)
|
||||||
|
timestep = timestep.to(dtype)
|
||||||
|
encoder_hidden_states = encoder_hidden_states.to(dtype)
|
||||||
|
|
||||||
|
hidden_states = self.x_embedder(hidden_states) # [B, N, C]
|
||||||
|
|
||||||
|
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
|
||||||
|
t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t]
|
||||||
|
|
||||||
|
encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C]
|
||||||
|
|
||||||
|
if self.text_tokens_zero_pad and encoder_attention_mask is not None:
|
||||||
|
encoder_hidden_states = encoder_hidden_states * encoder_attention_mask[:, None, :, None]
|
||||||
|
encoder_attention_mask = (encoder_attention_mask * 0 + 1).to(encoder_attention_mask.dtype)
|
||||||
|
|
||||||
|
if encoder_attention_mask is not None:
|
||||||
|
encoder_attention_mask = encoder_attention_mask.squeeze(1).squeeze(1)
|
||||||
|
encoder_hidden_states = encoder_hidden_states.squeeze(1).masked_select(encoder_attention_mask.unsqueeze(-1) != 0).view(1, -1, hidden_states.shape[-1]) # [1, N_valid_tokens, C]
|
||||||
|
y_seqlens = encoder_attention_mask.sum(dim=1).tolist() # [B]
|
||||||
|
else:
|
||||||
|
y_seqlens = [encoder_hidden_states.shape[2]] * encoder_hidden_states.shape[0]
|
||||||
|
encoder_hidden_states = encoder_hidden_states.squeeze(1).view(1, -1, hidden_states.shape[-1])
|
||||||
|
|
||||||
|
# if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
|
||||||
|
# hidden_states = rearrange(hidden_states, "B (T H W) C -> B T H W C", T=N_t, H=N_h, W=N_w)
|
||||||
|
# hidden_states = context_parallel_util.split_cp_2d(hidden_states, seq_dim_hw=(2, 3), split_hw=self.cp_split_hw)
|
||||||
|
# hidden_states = rearrange(hidden_states, "B T H W C -> B (T H W) C")
|
||||||
|
|
||||||
|
# blocks
|
||||||
|
kv_cache_dict_ret = {}
|
||||||
|
for i, block in enumerate(self.blocks):
|
||||||
|
block_outputs = gradient_checkpoint_forward(
|
||||||
|
block,
|
||||||
|
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||||
|
x=hidden_states,
|
||||||
|
y=encoder_hidden_states,
|
||||||
|
t=t,
|
||||||
|
y_seqlen=y_seqlens,
|
||||||
|
latent_shape=(N_t, N_h, N_w),
|
||||||
|
num_cond_latents=num_cond_latents,
|
||||||
|
return_kv=return_kv,
|
||||||
|
kv_cache=kv_cache_dict.get(i, None),
|
||||||
|
skip_crs_attn=skip_crs_attn,
|
||||||
|
)
|
||||||
|
|
||||||
|
if return_kv:
|
||||||
|
hidden_states, kv_cache = block_outputs
|
||||||
|
if offload_kv_cache:
|
||||||
|
kv_cache_dict_ret[i] = (kv_cache[0].cpu(), kv_cache[1].cpu())
|
||||||
|
else:
|
||||||
|
kv_cache_dict_ret[i] = (kv_cache[0].contiguous(), kv_cache[1].contiguous())
|
||||||
|
else:
|
||||||
|
hidden_states = block_outputs
|
||||||
|
|
||||||
|
hidden_states = self.final_layer(hidden_states, t, (N_t, N_h, N_w)) # [B, N, C=T_p*H_p*W_p*C_out]
|
||||||
|
|
||||||
|
# if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
|
||||||
|
# hidden_states = context_parallel_util.gather_cp_2d(hidden_states, shape=(N_t, N_h, N_w), split_hw=self.cp_split_hw)
|
||||||
|
|
||||||
|
hidden_states = self.unpatchify(hidden_states, N_t, N_h, N_w) # [B, C_out, H, W]
|
||||||
|
|
||||||
|
# cast to float32 for better accuracy
|
||||||
|
hidden_states = hidden_states.to(torch.float32)
|
||||||
|
|
||||||
|
if return_kv:
|
||||||
|
return hidden_states, kv_cache_dict_ret
|
||||||
|
else:
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
def unpatchify(self, x, N_t, N_h, N_w):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): of shape [B, N, C]
|
||||||
|
|
||||||
|
Return:
|
||||||
|
x (torch.Tensor): of shape [B, C_out, T, H, W]
|
||||||
|
"""
|
||||||
|
T_p, H_p, W_p = self.patch_size
|
||||||
|
x = rearrange(
|
||||||
|
x,
|
||||||
|
"B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
|
||||||
|
N_t=N_t,
|
||||||
|
N_h=N_h,
|
||||||
|
N_w=N_w,
|
||||||
|
T_p=T_p,
|
||||||
|
H_p=H_p,
|
||||||
|
W_p=W_p,
|
||||||
|
C_out=self.out_channels,
|
||||||
|
)
|
||||||
|
return x
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return LongCatVideoTransformer3DModelDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
class LongCatVideoTransformer3DModelDictConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_diffusers(self, state_dict):
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
return state_dict
|
||||||
|
|
||||||
@@ -1,367 +0,0 @@
|
|||||||
import torch
|
|
||||||
from .sd_unet import SDUNet
|
|
||||||
from .sdxl_unet import SDXLUNet
|
|
||||||
from .sd_text_encoder import SDTextEncoder
|
|
||||||
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
|
||||||
from .sd3_dit import SD3DiT
|
|
||||||
from .flux_dit import FluxDiT
|
|
||||||
from .hunyuan_dit import HunyuanDiT
|
|
||||||
from .cog_dit import CogDiT
|
|
||||||
from .hunyuan_video_dit import HunyuanVideoDiT
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAFromCivitai:
|
|
||||||
def __init__(self):
|
|
||||||
self.supported_model_classes = []
|
|
||||||
self.lora_prefix = []
|
|
||||||
self.renamed_lora_prefix = {}
|
|
||||||
self.special_keys = {}
|
|
||||||
|
|
||||||
|
|
||||||
def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
|
|
||||||
for key in state_dict:
|
|
||||||
if ".lora_up" in key:
|
|
||||||
return self.convert_state_dict_up_down(state_dict, lora_prefix, alpha)
|
|
||||||
return self.convert_state_dict_AB(state_dict, lora_prefix, alpha)
|
|
||||||
|
|
||||||
|
|
||||||
def convert_state_dict_up_down(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
|
|
||||||
renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "")
|
|
||||||
state_dict_ = {}
|
|
||||||
for key in state_dict:
|
|
||||||
if ".lora_up" not in key:
|
|
||||||
continue
|
|
||||||
if not key.startswith(lora_prefix):
|
|
||||||
continue
|
|
||||||
weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
|
|
||||||
weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
|
|
||||||
if len(weight_up.shape) == 4:
|
|
||||||
weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
|
|
||||||
weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
|
|
||||||
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
|
||||||
else:
|
|
||||||
lora_weight = alpha * torch.mm(weight_up, weight_down)
|
|
||||||
target_name = key.split(".")[0].replace(lora_prefix, renamed_lora_prefix).replace("_", ".") + ".weight"
|
|
||||||
for special_key in self.special_keys:
|
|
||||||
target_name = target_name.replace(special_key, self.special_keys[special_key])
|
|
||||||
state_dict_[target_name] = lora_weight.cpu()
|
|
||||||
return state_dict_
|
|
||||||
|
|
||||||
|
|
||||||
def convert_state_dict_AB(self, state_dict, lora_prefix="", alpha=1.0, device="cuda", torch_dtype=torch.float16):
|
|
||||||
state_dict_ = {}
|
|
||||||
for key in state_dict:
|
|
||||||
if ".lora_B." not in key:
|
|
||||||
continue
|
|
||||||
if not key.startswith(lora_prefix):
|
|
||||||
continue
|
|
||||||
weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
|
|
||||||
weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
|
|
||||||
if len(weight_up.shape) == 4:
|
|
||||||
weight_up = weight_up.squeeze(3).squeeze(2)
|
|
||||||
weight_down = weight_down.squeeze(3).squeeze(2)
|
|
||||||
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
|
||||||
else:
|
|
||||||
lora_weight = alpha * torch.mm(weight_up, weight_down)
|
|
||||||
keys = key.split(".")
|
|
||||||
keys.pop(keys.index("lora_B"))
|
|
||||||
target_name = ".".join(keys)
|
|
||||||
target_name = target_name[len(lora_prefix):]
|
|
||||||
state_dict_[target_name] = lora_weight.cpu()
|
|
||||||
return state_dict_
|
|
||||||
|
|
||||||
|
|
||||||
def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None):
|
|
||||||
state_dict_model = model.state_dict()
|
|
||||||
state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=alpha)
|
|
||||||
if model_resource == "diffusers":
|
|
||||||
state_dict_lora = model.__class__.state_dict_converter().from_diffusers(state_dict_lora)
|
|
||||||
elif model_resource == "civitai":
|
|
||||||
state_dict_lora = model.__class__.state_dict_converter().from_civitai(state_dict_lora)
|
|
||||||
if isinstance(state_dict_lora, tuple):
|
|
||||||
state_dict_lora = state_dict_lora[0]
|
|
||||||
if len(state_dict_lora) > 0:
|
|
||||||
print(f" {len(state_dict_lora)} tensors are updated.")
|
|
||||||
for name in state_dict_lora:
|
|
||||||
fp8=False
|
|
||||||
if state_dict_model[name].dtype == torch.float8_e4m3fn:
|
|
||||||
state_dict_model[name]= state_dict_model[name].to(state_dict_lora[name].dtype)
|
|
||||||
fp8=True
|
|
||||||
state_dict_model[name] += state_dict_lora[name].to(
|
|
||||||
dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
|
|
||||||
if fp8:
|
|
||||||
state_dict_model[name] = state_dict_model[name].to(torch.float8_e4m3fn)
|
|
||||||
model.load_state_dict(state_dict_model)
|
|
||||||
|
|
||||||
|
|
||||||
def match(self, model, state_dict_lora):
|
|
||||||
for lora_prefix, model_class in zip(self.lora_prefix, self.supported_model_classes):
|
|
||||||
if not isinstance(model, model_class):
|
|
||||||
continue
|
|
||||||
state_dict_model = model.state_dict()
|
|
||||||
for model_resource in ["diffusers", "civitai"]:
|
|
||||||
try:
|
|
||||||
state_dict_lora_ = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=1.0)
|
|
||||||
converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == "diffusers" \
|
|
||||||
else model.__class__.state_dict_converter().from_civitai
|
|
||||||
state_dict_lora_ = converter_fn(state_dict_lora_)
|
|
||||||
if isinstance(state_dict_lora_, tuple):
|
|
||||||
state_dict_lora_ = state_dict_lora_[0]
|
|
||||||
if len(state_dict_lora_) == 0:
|
|
||||||
continue
|
|
||||||
for name in state_dict_lora_:
|
|
||||||
if name not in state_dict_model:
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
return lora_prefix, model_resource
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SDLoRAFromCivitai(LoRAFromCivitai):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.supported_model_classes = [SDUNet, SDTextEncoder]
|
|
||||||
self.lora_prefix = ["lora_unet_", "lora_te_"]
|
|
||||||
self.special_keys = {
|
|
||||||
"down.blocks": "down_blocks",
|
|
||||||
"up.blocks": "up_blocks",
|
|
||||||
"mid.block": "mid_block",
|
|
||||||
"proj.in": "proj_in",
|
|
||||||
"proj.out": "proj_out",
|
|
||||||
"transformer.blocks": "transformer_blocks",
|
|
||||||
"to.q": "to_q",
|
|
||||||
"to.k": "to_k",
|
|
||||||
"to.v": "to_v",
|
|
||||||
"to.out": "to_out",
|
|
||||||
"text.model": "text_model",
|
|
||||||
"self.attn.q.proj": "self_attn.q_proj",
|
|
||||||
"self.attn.k.proj": "self_attn.k_proj",
|
|
||||||
"self.attn.v.proj": "self_attn.v_proj",
|
|
||||||
"self.attn.out.proj": "self_attn.out_proj",
|
|
||||||
"input.blocks": "model.diffusion_model.input_blocks",
|
|
||||||
"middle.block": "model.diffusion_model.middle_block",
|
|
||||||
"output.blocks": "model.diffusion_model.output_blocks",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class SDXLLoRAFromCivitai(LoRAFromCivitai):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.supported_model_classes = [SDXLUNet, SDXLTextEncoder, SDXLTextEncoder2]
|
|
||||||
self.lora_prefix = ["lora_unet_", "lora_te1_", "lora_te2_"]
|
|
||||||
self.renamed_lora_prefix = {"lora_te2_": "2"}
|
|
||||||
self.special_keys = {
|
|
||||||
"down.blocks": "down_blocks",
|
|
||||||
"up.blocks": "up_blocks",
|
|
||||||
"mid.block": "mid_block",
|
|
||||||
"proj.in": "proj_in",
|
|
||||||
"proj.out": "proj_out",
|
|
||||||
"transformer.blocks": "transformer_blocks",
|
|
||||||
"to.q": "to_q",
|
|
||||||
"to.k": "to_k",
|
|
||||||
"to.v": "to_v",
|
|
||||||
"to.out": "to_out",
|
|
||||||
"text.model": "conditioner.embedders.0.transformer.text_model",
|
|
||||||
"self.attn.q.proj": "self_attn.q_proj",
|
|
||||||
"self.attn.k.proj": "self_attn.k_proj",
|
|
||||||
"self.attn.v.proj": "self_attn.v_proj",
|
|
||||||
"self.attn.out.proj": "self_attn.out_proj",
|
|
||||||
"input.blocks": "model.diffusion_model.input_blocks",
|
|
||||||
"middle.block": "model.diffusion_model.middle_block",
|
|
||||||
"output.blocks": "model.diffusion_model.output_blocks",
|
|
||||||
"2conditioner.embedders.0.transformer.text_model.encoder.layers": "text_model.encoder.layers"
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class FluxLoRAFromCivitai(LoRAFromCivitai):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.supported_model_classes = [FluxDiT, FluxDiT]
|
|
||||||
self.lora_prefix = ["lora_unet_", "transformer."]
|
|
||||||
self.renamed_lora_prefix = {}
|
|
||||||
self.special_keys = {
|
|
||||||
"single.blocks": "single_blocks",
|
|
||||||
"double.blocks": "double_blocks",
|
|
||||||
"img.attn": "img_attn",
|
|
||||||
"img.mlp": "img_mlp",
|
|
||||||
"img.mod": "img_mod",
|
|
||||||
"txt.attn": "txt_attn",
|
|
||||||
"txt.mlp": "txt_mlp",
|
|
||||||
"txt.mod": "txt_mod",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class GeneralLoRAFromPeft:
|
|
||||||
def __init__(self):
|
|
||||||
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT]
|
|
||||||
|
|
||||||
|
|
||||||
def fetch_device_dtype_from_state_dict(self, state_dict):
|
|
||||||
device, torch_dtype = None, None
|
|
||||||
for name, param in state_dict.items():
|
|
||||||
device, torch_dtype = param.device, param.dtype
|
|
||||||
break
|
|
||||||
return device, torch_dtype
|
|
||||||
|
|
||||||
|
|
||||||
def convert_state_dict(self, state_dict, alpha=1.0, target_state_dict={}):
|
|
||||||
device, torch_dtype = self.fetch_device_dtype_from_state_dict(target_state_dict)
|
|
||||||
state_dict_ = {}
|
|
||||||
for key in state_dict:
|
|
||||||
if ".lora_B." not in key:
|
|
||||||
continue
|
|
||||||
weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
|
|
||||||
weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
|
|
||||||
if len(weight_up.shape) == 4:
|
|
||||||
weight_up = weight_up.squeeze(3).squeeze(2)
|
|
||||||
weight_down = weight_down.squeeze(3).squeeze(2)
|
|
||||||
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
|
||||||
else:
|
|
||||||
lora_weight = alpha * torch.mm(weight_up, weight_down)
|
|
||||||
keys = key.split(".")
|
|
||||||
if len(keys) > keys.index("lora_B") + 2:
|
|
||||||
keys.pop(keys.index("lora_B") + 1)
|
|
||||||
keys.pop(keys.index("lora_B"))
|
|
||||||
target_name = ".".join(keys)
|
|
||||||
if target_name not in target_state_dict:
|
|
||||||
return {}
|
|
||||||
state_dict_[target_name] = lora_weight.cpu()
|
|
||||||
return state_dict_
|
|
||||||
|
|
||||||
|
|
||||||
def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
|
|
||||||
state_dict_model = model.state_dict()
|
|
||||||
state_dict_lora = self.convert_state_dict(state_dict_lora, alpha=alpha, target_state_dict=state_dict_model)
|
|
||||||
if len(state_dict_lora) > 0:
|
|
||||||
print(f" {len(state_dict_lora)} tensors are updated.")
|
|
||||||
for name in state_dict_lora:
|
|
||||||
state_dict_model[name] += state_dict_lora[name].to(
|
|
||||||
dtype=state_dict_model[name].dtype,
|
|
||||||
device=state_dict_model[name].device
|
|
||||||
)
|
|
||||||
model.load_state_dict(state_dict_model)
|
|
||||||
|
|
||||||
|
|
||||||
def match(self, model, state_dict_lora):
|
|
||||||
for model_class in self.supported_model_classes:
|
|
||||||
if not isinstance(model, model_class):
|
|
||||||
continue
|
|
||||||
state_dict_model = model.state_dict()
|
|
||||||
try:
|
|
||||||
state_dict_lora_ = self.convert_state_dict(state_dict_lora, alpha=1.0, target_state_dict=state_dict_model)
|
|
||||||
if len(state_dict_lora_) > 0:
|
|
||||||
return "", ""
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.supported_model_classes = [HunyuanVideoDiT, HunyuanVideoDiT]
|
|
||||||
self.lora_prefix = ["diffusion_model.", "transformer."]
|
|
||||||
self.special_keys = {}
|
|
||||||
|
|
||||||
|
|
||||||
class FluxLoRAConverter:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def align_to_opensource_format(state_dict, alpha=1.0):
|
|
||||||
prefix_rename_dict = {
|
|
||||||
"single_blocks": "lora_unet_single_blocks",
|
|
||||||
"blocks": "lora_unet_double_blocks",
|
|
||||||
}
|
|
||||||
middle_rename_dict = {
|
|
||||||
"norm.linear": "modulation_lin",
|
|
||||||
"to_qkv_mlp": "linear1",
|
|
||||||
"proj_out": "linear2",
|
|
||||||
|
|
||||||
"norm1_a.linear": "img_mod_lin",
|
|
||||||
"norm1_b.linear": "txt_mod_lin",
|
|
||||||
"attn.a_to_qkv": "img_attn_qkv",
|
|
||||||
"attn.b_to_qkv": "txt_attn_qkv",
|
|
||||||
"attn.a_to_out": "img_attn_proj",
|
|
||||||
"attn.b_to_out": "txt_attn_proj",
|
|
||||||
"ff_a.0": "img_mlp_0",
|
|
||||||
"ff_a.2": "img_mlp_2",
|
|
||||||
"ff_b.0": "txt_mlp_0",
|
|
||||||
"ff_b.2": "txt_mlp_2",
|
|
||||||
}
|
|
||||||
suffix_rename_dict = {
|
|
||||||
"lora_B.weight": "lora_up.weight",
|
|
||||||
"lora_A.weight": "lora_down.weight",
|
|
||||||
}
|
|
||||||
state_dict_ = {}
|
|
||||||
for name, param in state_dict.items():
|
|
||||||
names = name.split(".")
|
|
||||||
if names[-2] != "lora_A" and names[-2] != "lora_B":
|
|
||||||
names.pop(-2)
|
|
||||||
prefix = names[0]
|
|
||||||
middle = ".".join(names[2:-2])
|
|
||||||
suffix = ".".join(names[-2:])
|
|
||||||
block_id = names[1]
|
|
||||||
if middle not in middle_rename_dict:
|
|
||||||
continue
|
|
||||||
rename = prefix_rename_dict[prefix] + "_" + block_id + "_" + middle_rename_dict[middle] + "." + suffix_rename_dict[suffix]
|
|
||||||
state_dict_[rename] = param
|
|
||||||
if rename.endswith("lora_up.weight"):
|
|
||||||
state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((alpha,))[0]
|
|
||||||
return state_dict_
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def align_to_diffsynth_format(state_dict):
|
|
||||||
rename_dict = {
|
|
||||||
"lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight",
|
|
||||||
"lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight",
|
|
||||||
"lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight",
|
|
||||||
"lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight",
|
|
||||||
"lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight",
|
|
||||||
"lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight",
|
|
||||||
"lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight",
|
|
||||||
}
|
|
||||||
def guess_block_id(name):
|
|
||||||
names = name.split("_")
|
|
||||||
for i in names:
|
|
||||||
if i.isdigit():
|
|
||||||
return i, name.replace(f"_{i}_", "_blockid_")
|
|
||||||
return None, None
|
|
||||||
state_dict_ = {}
|
|
||||||
for name, param in state_dict.items():
|
|
||||||
block_id, source_name = guess_block_id(name)
|
|
||||||
if source_name in rename_dict:
|
|
||||||
target_name = rename_dict[source_name]
|
|
||||||
target_name = target_name.replace(".blockid.", f".{block_id}.")
|
|
||||||
state_dict_[target_name] = param
|
|
||||||
else:
|
|
||||||
state_dict_[name] = param
|
|
||||||
return state_dict_
|
|
||||||
|
|
||||||
|
|
||||||
def get_lora_loaders():
|
|
||||||
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]
|
|
||||||
1408
diffsynth/models/ltx2_audio_vae.py
Normal file
1408
diffsynth/models/ltx2_audio_vae.py
Normal file
File diff suppressed because it is too large
Load Diff
371
diffsynth/models/ltx2_common.py
Normal file
371
diffsynth/models/ltx2_common.py
Normal file
@@ -0,0 +1,371 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import NamedTuple, Protocol, Tuple
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class VideoPixelShape(NamedTuple):
|
||||||
|
"""
|
||||||
|
Shape of the tensor representing the video pixel array. Assumes BGR channel format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
batch: int
|
||||||
|
frames: int
|
||||||
|
height: int
|
||||||
|
width: int
|
||||||
|
fps: float
|
||||||
|
|
||||||
|
|
||||||
|
class SpatioTemporalScaleFactors(NamedTuple):
|
||||||
|
"""
|
||||||
|
Describes the spatiotemporal downscaling between decoded video space and
|
||||||
|
the corresponding VAE latent grid.
|
||||||
|
"""
|
||||||
|
|
||||||
|
time: int
|
||||||
|
width: int
|
||||||
|
height: int
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def default(cls) -> "SpatioTemporalScaleFactors":
|
||||||
|
return cls(time=8, width=32, height=32)
|
||||||
|
|
||||||
|
|
||||||
|
VIDEO_SCALE_FACTORS = SpatioTemporalScaleFactors.default()
|
||||||
|
|
||||||
|
|
||||||
|
class VideoLatentShape(NamedTuple):
|
||||||
|
"""
|
||||||
|
Shape of the tensor representing video in VAE latent space.
|
||||||
|
The latent representation is a 5D tensor with dimensions ordered as
|
||||||
|
(batch, channels, frames, height, width). Spatial and temporal dimensions
|
||||||
|
are downscaled relative to pixel space according to the VAE's scale factors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
batch: int
|
||||||
|
channels: int
|
||||||
|
frames: int
|
||||||
|
height: int
|
||||||
|
width: int
|
||||||
|
|
||||||
|
def to_torch_shape(self) -> torch.Size:
|
||||||
|
return torch.Size([self.batch, self.channels, self.frames, self.height, self.width])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_torch_shape(shape: torch.Size) -> "VideoLatentShape":
|
||||||
|
return VideoLatentShape(
|
||||||
|
batch=shape[0],
|
||||||
|
channels=shape[1],
|
||||||
|
frames=shape[2],
|
||||||
|
height=shape[3],
|
||||||
|
width=shape[4],
|
||||||
|
)
|
||||||
|
|
||||||
|
def mask_shape(self) -> "VideoLatentShape":
|
||||||
|
return self._replace(channels=1)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_pixel_shape(
|
||||||
|
shape: VideoPixelShape,
|
||||||
|
latent_channels: int = 128,
|
||||||
|
scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS,
|
||||||
|
) -> "VideoLatentShape":
|
||||||
|
frames = (shape.frames - 1) // scale_factors[0] + 1
|
||||||
|
height = shape.height // scale_factors[1]
|
||||||
|
width = shape.width // scale_factors[2]
|
||||||
|
|
||||||
|
return VideoLatentShape(
|
||||||
|
batch=shape.batch,
|
||||||
|
channels=latent_channels,
|
||||||
|
frames=frames,
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
)
|
||||||
|
|
||||||
|
def upscale(self, scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS) -> "VideoLatentShape":
|
||||||
|
return self._replace(
|
||||||
|
channels=3,
|
||||||
|
frames=(self.frames - 1) * scale_factors.time + 1,
|
||||||
|
height=self.height * scale_factors.height,
|
||||||
|
width=self.width * scale_factors.width,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AudioLatentShape(NamedTuple):
|
||||||
|
"""
|
||||||
|
Shape of audio in VAE latent space: (batch, channels, frames, mel_bins).
|
||||||
|
mel_bins is the number of frequency bins from the mel-spectrogram encoding.
|
||||||
|
"""
|
||||||
|
|
||||||
|
batch: int
|
||||||
|
channels: int
|
||||||
|
frames: int
|
||||||
|
mel_bins: int
|
||||||
|
|
||||||
|
def to_torch_shape(self) -> torch.Size:
|
||||||
|
return torch.Size([self.batch, self.channels, self.frames, self.mel_bins])
|
||||||
|
|
||||||
|
def mask_shape(self) -> "AudioLatentShape":
|
||||||
|
return self._replace(channels=1, mel_bins=1)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_torch_shape(shape: torch.Size) -> "AudioLatentShape":
|
||||||
|
return AudioLatentShape(
|
||||||
|
batch=shape[0],
|
||||||
|
channels=shape[1],
|
||||||
|
frames=shape[2],
|
||||||
|
mel_bins=shape[3],
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_duration(
|
||||||
|
batch: int,
|
||||||
|
duration: float,
|
||||||
|
channels: int = 8,
|
||||||
|
mel_bins: int = 16,
|
||||||
|
sample_rate: int = 16000,
|
||||||
|
hop_length: int = 160,
|
||||||
|
audio_latent_downsample_factor: int = 4,
|
||||||
|
) -> "AudioLatentShape":
|
||||||
|
latents_per_second = float(sample_rate) / float(hop_length) / float(audio_latent_downsample_factor)
|
||||||
|
|
||||||
|
return AudioLatentShape(
|
||||||
|
batch=batch,
|
||||||
|
channels=channels,
|
||||||
|
frames=round(duration * latents_per_second),
|
||||||
|
mel_bins=mel_bins,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_video_pixel_shape(
|
||||||
|
shape: VideoPixelShape,
|
||||||
|
channels: int = 8,
|
||||||
|
mel_bins: int = 16,
|
||||||
|
sample_rate: int = 16000,
|
||||||
|
hop_length: int = 160,
|
||||||
|
audio_latent_downsample_factor: int = 4,
|
||||||
|
) -> "AudioLatentShape":
|
||||||
|
return AudioLatentShape.from_duration(
|
||||||
|
batch=shape.batch,
|
||||||
|
duration=float(shape.frames) / float(shape.fps),
|
||||||
|
channels=channels,
|
||||||
|
mel_bins=mel_bins,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
hop_length=hop_length,
|
||||||
|
audio_latent_downsample_factor=audio_latent_downsample_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class LatentState:
|
||||||
|
"""
|
||||||
|
State of latents during the diffusion denoising process.
|
||||||
|
Attributes:
|
||||||
|
latent: The current noisy latent tensor being denoised.
|
||||||
|
denoise_mask: Mask encoding the denoising strength for each token (1 = full denoising, 0 = no denoising).
|
||||||
|
positions: Positional indices for each latent element, used for positional embeddings.
|
||||||
|
clean_latent: Initial state of the latent before denoising, may include conditioning latents.
|
||||||
|
"""
|
||||||
|
|
||||||
|
latent: torch.Tensor
|
||||||
|
denoise_mask: torch.Tensor
|
||||||
|
positions: torch.Tensor
|
||||||
|
clean_latent: torch.Tensor
|
||||||
|
|
||||||
|
def clone(self) -> "LatentState":
|
||||||
|
return LatentState(
|
||||||
|
latent=self.latent.clone(),
|
||||||
|
denoise_mask=self.denoise_mask.clone(),
|
||||||
|
positions=self.positions.clone(),
|
||||||
|
clean_latent=self.clean_latent.clone(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class NormType(Enum):
|
||||||
|
"""Normalization layer types: GROUP (GroupNorm) or PIXEL (per-location RMS norm)."""
|
||||||
|
|
||||||
|
GROUP = "group"
|
||||||
|
PIXEL = "pixel"
|
||||||
|
|
||||||
|
|
||||||
|
class PixelNorm(nn.Module):
|
||||||
|
"""
|
||||||
|
Per-pixel (per-location) RMS normalization layer.
|
||||||
|
For each element along the chosen dimension, this layer normalizes the tensor
|
||||||
|
by the root-mean-square of its values across that dimension:
|
||||||
|
y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
dim: Dimension along which to compute the RMS (typically channels).
|
||||||
|
eps: Small constant added for numerical stability.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Apply RMS normalization along the configured dimension.
|
||||||
|
"""
|
||||||
|
# Compute mean of squared values along `dim`, keep dimensions for broadcasting.
|
||||||
|
mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True)
|
||||||
|
# Normalize by the root-mean-square (RMS).
|
||||||
|
rms = torch.sqrt(mean_sq + self.eps)
|
||||||
|
return x / rms
|
||||||
|
|
||||||
|
|
||||||
|
def build_normalization_layer(
|
||||||
|
in_channels: int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP
|
||||||
|
) -> nn.Module:
|
||||||
|
"""
|
||||||
|
Create a normalization layer based on the normalization type.
|
||||||
|
Args:
|
||||||
|
in_channels: Number of input channels
|
||||||
|
num_groups: Number of groups for group normalization
|
||||||
|
normtype: Type of normalization: "group" or "pixel"
|
||||||
|
Returns:
|
||||||
|
A normalization layer
|
||||||
|
"""
|
||||||
|
if normtype == NormType.GROUP:
|
||||||
|
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||||
|
if normtype == NormType.PIXEL:
|
||||||
|
return PixelNorm(dim=1, eps=1e-6)
|
||||||
|
raise ValueError(f"Invalid normalization type: {normtype}")
|
||||||
|
|
||||||
|
|
||||||
|
def rms_norm(x: torch.Tensor, weight: torch.Tensor | None = None, eps: float = 1e-6) -> torch.Tensor:
|
||||||
|
"""Root-mean-square (RMS) normalize `x` over its last dimension.
|
||||||
|
Thin wrapper around `torch.nn.functional.rms_norm` that infers the normalized
|
||||||
|
shape and forwards `weight` and `eps`.
|
||||||
|
"""
|
||||||
|
return torch.nn.functional.rms_norm(x, (x.shape[-1],), weight=weight, eps=eps)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Modality:
|
||||||
|
"""
|
||||||
|
Input data for a single modality (video or audio) in the transformer.
|
||||||
|
Bundles the latent tokens, timestep embeddings, positional information,
|
||||||
|
and text conditioning context for processing by the diffusion transformer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
latent: (
|
||||||
|
torch.Tensor
|
||||||
|
) # Shape: (B, T, D) where B is the batch size, T is the number of tokens, and D is input dimension
|
||||||
|
timesteps: torch.Tensor # Shape: (B, T) where T is the number of timesteps
|
||||||
|
positions: (
|
||||||
|
torch.Tensor
|
||||||
|
) # Shape: (B, 3, T) for video, where 3 is the number of dimensions and T is the number of tokens
|
||||||
|
context: torch.Tensor
|
||||||
|
enabled: bool = True
|
||||||
|
context_mask: torch.Tensor | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def to_denoised(
|
||||||
|
sample: torch.Tensor,
|
||||||
|
velocity: torch.Tensor,
|
||||||
|
sigma: float | torch.Tensor,
|
||||||
|
calc_dtype: torch.dtype = torch.float32,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Convert the sample and its denoising velocity to denoised sample.
|
||||||
|
Returns:
|
||||||
|
Denoised sample
|
||||||
|
"""
|
||||||
|
if isinstance(sigma, torch.Tensor):
|
||||||
|
sigma = sigma.to(calc_dtype)
|
||||||
|
return (sample.to(calc_dtype) - velocity.to(calc_dtype) * sigma).to(sample.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Patchifier(Protocol):
|
||||||
|
"""
|
||||||
|
Protocol for patchifiers that convert latent tensors into patches and assemble them back.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def patchify(
|
||||||
|
self,
|
||||||
|
latents: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
Convert latent tensors into flattened patch tokens.
|
||||||
|
Args:
|
||||||
|
latents: Latent tensor to patchify.
|
||||||
|
Returns:
|
||||||
|
Flattened patch tokens tensor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def unpatchify(
|
||||||
|
self,
|
||||||
|
latents: torch.Tensor,
|
||||||
|
output_shape: AudioLatentShape | VideoLatentShape,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Converts latent tensors between spatio-temporal formats and flattened sequence representations.
|
||||||
|
Args:
|
||||||
|
latents: Patch tokens that must be rearranged back into the latent grid constructed by `patchify`.
|
||||||
|
output_shape: Shape of the output tensor. Note that output_shape is either AudioLatentShape or
|
||||||
|
VideoLatentShape.
|
||||||
|
Returns:
|
||||||
|
Dense latent tensor restored from the flattened representation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def patch_size(self) -> Tuple[int, int, int]:
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
Returns the patch size as a tuple of (temporal, height, width) dimensions
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_patch_grid_bounds(
|
||||||
|
self,
|
||||||
|
output_shape: AudioLatentShape | VideoLatentShape,
|
||||||
|
device: torch.device | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
Compute metadata describing where each latent patch resides within the
|
||||||
|
grid specified by `output_shape`.
|
||||||
|
Args:
|
||||||
|
output_shape: Target grid layout for the patches.
|
||||||
|
device: Target device for the returned tensor.
|
||||||
|
Returns:
|
||||||
|
Tensor containing patch coordinate metadata such as spatial or temporal intervals.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_pixel_coords(
|
||||||
|
latent_coords: torch.Tensor,
|
||||||
|
scale_factors: SpatioTemporalScaleFactors,
|
||||||
|
causal_fix: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Map latent-space `[start, end)` coordinates to their pixel-space equivalents by scaling
|
||||||
|
each axis (frame/time, height, width) with the corresponding VAE downsampling factors.
|
||||||
|
Optionally compensate for causal encoding that keeps the first frame at unit temporal scale.
|
||||||
|
Args:
|
||||||
|
latent_coords: Tensor of latent bounds shaped `(batch, 3, num_patches, 2)`.
|
||||||
|
scale_factors: SpatioTemporalScaleFactors tuple `(temporal, height, width)` with integer scale factors applied
|
||||||
|
per axis.
|
||||||
|
causal_fix: When True, rewrites the temporal axis of the first frame so causal VAEs
|
||||||
|
that treat frame zero differently still yield non-negative timestamps.
|
||||||
|
"""
|
||||||
|
# Broadcast the VAE scale factors so they align with the `(batch, axis, patch, bound)` layout.
|
||||||
|
broadcast_shape = [1] * latent_coords.ndim
|
||||||
|
broadcast_shape[1] = -1 # axis dimension corresponds to (frame/time, height, width)
|
||||||
|
scale_tensor = torch.tensor(scale_factors, device=latent_coords.device).view(*broadcast_shape)
|
||||||
|
|
||||||
|
# Apply per-axis scaling to convert latent bounds into pixel-space coordinates.
|
||||||
|
pixel_coords = latent_coords * scale_tensor
|
||||||
|
|
||||||
|
if causal_fix:
|
||||||
|
# VAE temporal stride for the very first frame is 1 instead of `scale_factors[0]`.
|
||||||
|
# Shift and clamp to keep the first-frame timestamps causal and non-negative.
|
||||||
|
pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)
|
||||||
|
|
||||||
|
return pixel_coords
|
||||||
1451
diffsynth/models/ltx2_dit.py
Normal file
1451
diffsynth/models/ltx2_dit.py
Normal file
File diff suppressed because it is too large
Load Diff
366
diffsynth/models/ltx2_text_encoder.py
Normal file
366
diffsynth/models/ltx2_text_encoder.py
Normal file
@@ -0,0 +1,366 @@
|
|||||||
|
import torch
|
||||||
|
from transformers import Gemma3ForConditionalGeneration, Gemma3Config, AutoTokenizer
|
||||||
|
from .ltx2_dit import (LTXRopeType, generate_freq_grid_np, generate_freq_grid_pytorch, precompute_freqs_cis, Attention,
|
||||||
|
FeedForward)
|
||||||
|
from .ltx2_common import rms_norm
|
||||||
|
|
||||||
|
|
||||||
|
class LTX2TextEncoder(Gemma3ForConditionalGeneration):
|
||||||
|
def __init__(self):
|
||||||
|
config = Gemma3Config(
|
||||||
|
**{
|
||||||
|
"architectures": ["Gemma3ForConditionalGeneration"],
|
||||||
|
"boi_token_index": 255999,
|
||||||
|
"dtype": "bfloat16",
|
||||||
|
"eoi_token_index": 256000,
|
||||||
|
"eos_token_id": [1, 106],
|
||||||
|
"image_token_index": 262144,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"mm_tokens_per_image": 256,
|
||||||
|
"model_type": "gemma3",
|
||||||
|
"text_config": {
|
||||||
|
"_sliding_window_pattern": 6,
|
||||||
|
"attention_bias": False,
|
||||||
|
"attention_dropout": 0.0,
|
||||||
|
"attn_logit_softcapping": None,
|
||||||
|
"cache_implementation": "hybrid",
|
||||||
|
"dtype": "bfloat16",
|
||||||
|
"final_logit_softcapping": None,
|
||||||
|
"head_dim": 256,
|
||||||
|
"hidden_activation": "gelu_pytorch_tanh",
|
||||||
|
"hidden_size": 3840,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"intermediate_size": 15360,
|
||||||
|
"layer_types": [
|
||||||
|
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
|
||||||
|
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
|
||||||
|
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
|
||||||
|
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
|
||||||
|
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
|
||||||
|
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
|
||||||
|
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
|
||||||
|
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
|
||||||
|
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
|
||||||
|
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
|
||||||
|
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
|
||||||
|
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention"
|
||||||
|
],
|
||||||
|
"max_position_embeddings": 131072,
|
||||||
|
"model_type": "gemma3_text",
|
||||||
|
"num_attention_heads": 16,
|
||||||
|
"num_hidden_layers": 48,
|
||||||
|
"num_key_value_heads": 8,
|
||||||
|
"query_pre_attn_scalar": 256,
|
||||||
|
"rms_norm_eps": 1e-06,
|
||||||
|
"rope_local_base_freq": 10000,
|
||||||
|
"rope_scaling": {
|
||||||
|
"factor": 8.0,
|
||||||
|
"rope_type": "linear"
|
||||||
|
},
|
||||||
|
"rope_theta": 1000000,
|
||||||
|
"sliding_window": 1024,
|
||||||
|
"sliding_window_pattern": 6,
|
||||||
|
"use_bidirectional_attention": False,
|
||||||
|
"use_cache": True,
|
||||||
|
"vocab_size": 262208
|
||||||
|
},
|
||||||
|
"transformers_version": "4.57.3",
|
||||||
|
"vision_config": {
|
||||||
|
"attention_dropout": 0.0,
|
||||||
|
"dtype": "bfloat16",
|
||||||
|
"hidden_act": "gelu_pytorch_tanh",
|
||||||
|
"hidden_size": 1152,
|
||||||
|
"image_size": 896,
|
||||||
|
"intermediate_size": 4304,
|
||||||
|
"layer_norm_eps": 1e-06,
|
||||||
|
"model_type": "siglip_vision_model",
|
||||||
|
"num_attention_heads": 16,
|
||||||
|
"num_channels": 3,
|
||||||
|
"num_hidden_layers": 27,
|
||||||
|
"patch_size": 14,
|
||||||
|
"vision_use_head": False
|
||||||
|
}
|
||||||
|
})
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
|
||||||
|
class LTXVGemmaTokenizer:
|
||||||
|
"""
|
||||||
|
Tokenizer wrapper for Gemma models compatible with LTXV processes.
|
||||||
|
This class wraps HuggingFace's `AutoTokenizer` for use with Gemma text encoders,
|
||||||
|
ensuring correct settings and output formatting for downstream consumption.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, tokenizer_path: str, max_length: int = 1024):
|
||||||
|
"""
|
||||||
|
Initialize the tokenizer.
|
||||||
|
Args:
|
||||||
|
tokenizer_path (str): Path to the pretrained tokenizer files or model directory.
|
||||||
|
max_length (int, optional): Max sequence length for encoding. Defaults to 256.
|
||||||
|
"""
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
tokenizer_path, local_files_only=True, model_max_length=max_length
|
||||||
|
)
|
||||||
|
# Gemma expects left padding for chat-style prompts; for plain text it doesn't matter much.
|
||||||
|
self.tokenizer.padding_side = "left"
|
||||||
|
if self.tokenizer.pad_token is None:
|
||||||
|
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||||
|
|
||||||
|
self.max_length = max_length
|
||||||
|
|
||||||
|
def tokenize_with_weights(self, text: str, return_word_ids: bool = False) -> dict[str, list[tuple[int, int]]]:
|
||||||
|
"""
|
||||||
|
Tokenize the given text and return token IDs and attention weights.
|
||||||
|
Args:
|
||||||
|
text (str): The input string to tokenize.
|
||||||
|
return_word_ids (bool, optional): If True, includes the token's position (index) in the output tuples.
|
||||||
|
If False (default), omits the indices.
|
||||||
|
Returns:
|
||||||
|
dict[str, list[tuple[int, int]]] OR dict[str, list[tuple[int, int, int]]]:
|
||||||
|
A dictionary with a "gemma" key mapping to:
|
||||||
|
- a list of (token_id, attention_mask) tuples if return_word_ids is False;
|
||||||
|
- a list of (token_id, attention_mask, index) tuples if return_word_ids is True.
|
||||||
|
Example:
|
||||||
|
>>> tokenizer = LTXVGemmaTokenizer("path/to/tokenizer", max_length=8)
|
||||||
|
>>> tokenizer.tokenize_with_weights("hello world")
|
||||||
|
{'gemma': [(1234, 1), (5678, 1), (2, 0), ...]}
|
||||||
|
"""
|
||||||
|
text = text.strip()
|
||||||
|
encoded = self.tokenizer(
|
||||||
|
text,
|
||||||
|
padding="max_length",
|
||||||
|
max_length=self.max_length,
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
input_ids = encoded.input_ids
|
||||||
|
attention_mask = encoded.attention_mask
|
||||||
|
tuples = [
|
||||||
|
(token_id, attn, i) for i, (token_id, attn) in enumerate(zip(input_ids[0], attention_mask[0], strict=True))
|
||||||
|
]
|
||||||
|
out = {"gemma": tuples}
|
||||||
|
|
||||||
|
if not return_word_ids:
|
||||||
|
# Return only (token_id, attention_mask) pairs, omitting token position
|
||||||
|
out = {k: [(t, w) for t, w, _ in v] for k, v in out.items()}
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class GemmaFeaturesExtractorProjLinear(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Feature extractor module for Gemma models.
|
||||||
|
This module applies a single linear projection to the input tensor.
|
||||||
|
It expects a flattened feature tensor of shape (batch_size, 3840*49).
|
||||||
|
The linear layer maps this to a (batch_size, 3840) embedding.
|
||||||
|
Attributes:
|
||||||
|
aggregate_embed (torch.nn.Linear): Linear projection layer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the GemmaFeaturesExtractorProjLinear module.
|
||||||
|
The input dimension is expected to be 3840 * 49, and the output is 3840.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.aggregate_embed = torch.nn.Linear(3840 * 49, 3840, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Forward pass for the feature extractor.
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): Input tensor of shape (batch_size, 3840 * 49).
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Output tensor of shape (batch_size, 3840).
|
||||||
|
"""
|
||||||
|
return self.aggregate_embed(x)
|
||||||
|
|
||||||
|
|
||||||
|
class _BasicTransformerBlock1D(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
heads: int,
|
||||||
|
dim_head: int,
|
||||||
|
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.attn1 = Attention(
|
||||||
|
query_dim=dim,
|
||||||
|
heads=heads,
|
||||||
|
dim_head=dim_head,
|
||||||
|
rope_type=rope_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.ff = FeedForward(
|
||||||
|
dim,
|
||||||
|
dim_out=dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor | None = None,
|
||||||
|
pe: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||||
|
|
||||||
|
# 1. Normalization Before Self-Attention
|
||||||
|
norm_hidden_states = rms_norm(hidden_states)
|
||||||
|
|
||||||
|
norm_hidden_states = norm_hidden_states.squeeze(1)
|
||||||
|
|
||||||
|
# 2. Self-Attention
|
||||||
|
attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe)
|
||||||
|
|
||||||
|
hidden_states = attn_output + hidden_states
|
||||||
|
if hidden_states.ndim == 4:
|
||||||
|
hidden_states = hidden_states.squeeze(1)
|
||||||
|
|
||||||
|
# 3. Normalization before Feed-Forward
|
||||||
|
norm_hidden_states = rms_norm(hidden_states)
|
||||||
|
|
||||||
|
# 4. Feed-forward
|
||||||
|
ff_output = self.ff(norm_hidden_states)
|
||||||
|
|
||||||
|
hidden_states = ff_output + hidden_states
|
||||||
|
if hidden_states.ndim == 4:
|
||||||
|
hidden_states = hidden_states.squeeze(1)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Embeddings1DConnector(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Embeddings1DConnector applies a 1D transformer-based processing to sequential embeddings (e.g., for video, audio, or
|
||||||
|
other modalities). It supports rotary positional encoding (rope), optional causal temporal positioning, and can
|
||||||
|
substitute padded positions with learnable registers. The module is highly configurable for head size, number of
|
||||||
|
layers, and register usage.
|
||||||
|
Args:
|
||||||
|
attention_head_dim (int): Dimension of each attention head (default=128).
|
||||||
|
num_attention_heads (int): Number of attention heads (default=30).
|
||||||
|
num_layers (int): Number of transformer layers (default=2).
|
||||||
|
positional_embedding_theta (float): Scaling factor for position embedding (default=10000.0).
|
||||||
|
positional_embedding_max_pos (list[int] | None): Max positions for positional embeddings (default=[1]).
|
||||||
|
causal_temporal_positioning (bool): If True, uses causal attention (default=False).
|
||||||
|
num_learnable_registers (int | None): Number of learnable registers to replace padded tokens. If None, disables
|
||||||
|
register replacement. (default=128)
|
||||||
|
rope_type (LTXRopeType): The RoPE variant to use (default=DEFAULT_ROPE_TYPE).
|
||||||
|
double_precision_rope (bool): Use double precision rope calculation (default=False).
|
||||||
|
"""
|
||||||
|
|
||||||
|
_supports_gradient_checkpointing = True
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
attention_head_dim: int = 128,
|
||||||
|
num_attention_heads: int = 30,
|
||||||
|
num_layers: int = 2,
|
||||||
|
positional_embedding_theta: float = 10000.0,
|
||||||
|
positional_embedding_max_pos: list[int] | None = [4096],
|
||||||
|
causal_temporal_positioning: bool = False,
|
||||||
|
num_learnable_registers: int | None = 128,
|
||||||
|
rope_type: LTXRopeType = LTXRopeType.SPLIT,
|
||||||
|
double_precision_rope: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.inner_dim = num_attention_heads * attention_head_dim
|
||||||
|
self.causal_temporal_positioning = causal_temporal_positioning
|
||||||
|
self.positional_embedding_theta = positional_embedding_theta
|
||||||
|
self.positional_embedding_max_pos = (
|
||||||
|
positional_embedding_max_pos if positional_embedding_max_pos is not None else [1]
|
||||||
|
)
|
||||||
|
self.rope_type = rope_type
|
||||||
|
self.double_precision_rope = double_precision_rope
|
||||||
|
self.transformer_1d_blocks = torch.nn.ModuleList(
|
||||||
|
[
|
||||||
|
_BasicTransformerBlock1D(
|
||||||
|
dim=self.inner_dim,
|
||||||
|
heads=num_attention_heads,
|
||||||
|
dim_head=attention_head_dim,
|
||||||
|
rope_type=rope_type,
|
||||||
|
)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.num_learnable_registers = num_learnable_registers
|
||||||
|
if self.num_learnable_registers:
|
||||||
|
self.learnable_registers = torch.nn.Parameter(
|
||||||
|
torch.rand(self.num_learnable_registers, self.inner_dim, dtype=torch.bfloat16) * 2.0 - 1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
def _replace_padded_with_learnable_registers(
|
||||||
|
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
assert hidden_states.shape[1] % self.num_learnable_registers == 0, (
|
||||||
|
f"Hidden states sequence length {hidden_states.shape[1]} must be divisible by num_learnable_registers "
|
||||||
|
f"{self.num_learnable_registers}."
|
||||||
|
)
|
||||||
|
|
||||||
|
num_registers_duplications = hidden_states.shape[1] // self.num_learnable_registers
|
||||||
|
learnable_registers = torch.tile(self.learnable_registers, (num_registers_duplications, 1))
|
||||||
|
attention_mask_binary = (attention_mask.squeeze(1).squeeze(1).unsqueeze(-1) >= -9000.0).int()
|
||||||
|
|
||||||
|
non_zero_hidden_states = hidden_states[:, attention_mask_binary.squeeze().bool(), :]
|
||||||
|
non_zero_nums = non_zero_hidden_states.shape[1]
|
||||||
|
pad_length = hidden_states.shape[1] - non_zero_nums
|
||||||
|
adjusted_hidden_states = torch.nn.functional.pad(non_zero_hidden_states, pad=(0, 0, 0, pad_length), value=0)
|
||||||
|
flipped_mask = torch.flip(attention_mask_binary, dims=[1])
|
||||||
|
hidden_states = flipped_mask * adjusted_hidden_states + (1 - flipped_mask) * learnable_registers
|
||||||
|
|
||||||
|
attention_mask = torch.full_like(
|
||||||
|
attention_mask,
|
||||||
|
0.0,
|
||||||
|
dtype=attention_mask.dtype,
|
||||||
|
device=attention_mask.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
return hidden_states, attention_mask
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor | None = None,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Forward pass of Embeddings1DConnector.
|
||||||
|
Args:
|
||||||
|
hidden_states (torch.Tensor): Input tensor of embeddings (shape [batch, seq_len, feature_dim]).
|
||||||
|
attention_mask (torch.Tensor|None): Optional mask for valid tokens (shape compatible with hidden_states).
|
||||||
|
Returns:
|
||||||
|
tuple[torch.Tensor, torch.Tensor]: Processed features and the corresponding (possibly modified) mask.
|
||||||
|
"""
|
||||||
|
if self.num_learnable_registers:
|
||||||
|
hidden_states, attention_mask = self._replace_padded_with_learnable_registers(hidden_states, attention_mask)
|
||||||
|
|
||||||
|
indices_grid = torch.arange(hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device)
|
||||||
|
indices_grid = indices_grid[None, None, :]
|
||||||
|
freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch
|
||||||
|
freqs_cis = precompute_freqs_cis(
|
||||||
|
indices_grid=indices_grid,
|
||||||
|
dim=self.inner_dim,
|
||||||
|
out_dtype=hidden_states.dtype,
|
||||||
|
theta=self.positional_embedding_theta,
|
||||||
|
max_pos=self.positional_embedding_max_pos,
|
||||||
|
num_attention_heads=self.num_attention_heads,
|
||||||
|
rope_type=self.rope_type,
|
||||||
|
freq_grid_generator=freq_grid_generator,
|
||||||
|
)
|
||||||
|
|
||||||
|
for block in self.transformer_1d_blocks:
|
||||||
|
hidden_states = block(hidden_states, attention_mask=attention_mask, pe=freqs_cis)
|
||||||
|
|
||||||
|
hidden_states = rms_norm(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states, attention_mask
|
||||||
|
|
||||||
|
|
||||||
|
class LTX2TextEncoderPostModules(torch.nn.Module):
|
||||||
|
def __init__(self,):
|
||||||
|
super().__init__()
|
||||||
|
self.feature_extractor_linear = GemmaFeaturesExtractorProjLinear()
|
||||||
|
self.embeddings_connector = Embeddings1DConnector()
|
||||||
|
self.audio_embeddings_connector = Embeddings1DConnector()
|
||||||
313
diffsynth/models/ltx2_upsampler.py
Normal file
313
diffsynth/models/ltx2_upsampler.py
Normal file
@@ -0,0 +1,313 @@
|
|||||||
|
import math
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
import torch
|
||||||
|
from einops import rearrange
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from .ltx2_video_vae import LTX2VideoEncoder
|
||||||
|
|
||||||
|
class PixelShuffleND(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
N-dimensional pixel shuffle operation for upsampling tensors.
|
||||||
|
Args:
|
||||||
|
dims (int): Number of dimensions to apply pixel shuffle to.
|
||||||
|
- 1: Temporal (e.g., frames)
|
||||||
|
- 2: Spatial (e.g., height and width)
|
||||||
|
- 3: Spatiotemporal (e.g., depth, height, width)
|
||||||
|
upscale_factors (tuple[int, int, int], optional): Upscaling factors for each dimension.
|
||||||
|
For dims=1, only the first value is used.
|
||||||
|
For dims=2, the first two values are used.
|
||||||
|
For dims=3, all three values are used.
|
||||||
|
The input tensor is rearranged so that the channel dimension is split into
|
||||||
|
smaller channels and upscaling factors, and the upscaling factors are moved
|
||||||
|
into the corresponding spatial/temporal dimensions.
|
||||||
|
Note:
|
||||||
|
This operation is equivalent to the patchifier operation in for the models. Consider
|
||||||
|
using this class instead.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dims: int, upscale_factors: tuple[int, int, int] = (2, 2, 2)):
|
||||||
|
super().__init__()
|
||||||
|
assert dims in [1, 2, 3], "dims must be 1, 2, or 3"
|
||||||
|
self.dims = dims
|
||||||
|
self.upscale_factors = upscale_factors
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
if self.dims == 3:
|
||||||
|
return rearrange(
|
||||||
|
x,
|
||||||
|
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
|
||||||
|
p1=self.upscale_factors[0],
|
||||||
|
p2=self.upscale_factors[1],
|
||||||
|
p3=self.upscale_factors[2],
|
||||||
|
)
|
||||||
|
elif self.dims == 2:
|
||||||
|
return rearrange(
|
||||||
|
x,
|
||||||
|
"b (c p1 p2) h w -> b c (h p1) (w p2)",
|
||||||
|
p1=self.upscale_factors[0],
|
||||||
|
p2=self.upscale_factors[1],
|
||||||
|
)
|
||||||
|
elif self.dims == 1:
|
||||||
|
return rearrange(
|
||||||
|
x,
|
||||||
|
"b (c p1) f h w -> b c (f p1) h w",
|
||||||
|
p1=self.upscale_factors[0],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported dims: {self.dims}")
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Residual block with two convolutional layers, group normalization, and SiLU activation.
|
||||||
|
Args:
|
||||||
|
channels (int): Number of input and output channels.
|
||||||
|
mid_channels (Optional[int]): Number of channels in the intermediate convolution layer. Defaults to `channels`
|
||||||
|
if not specified.
|
||||||
|
dims (int): Dimensionality of the convolution (2 for Conv2d, 3 for Conv3d). Defaults to 3.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3):
|
||||||
|
super().__init__()
|
||||||
|
if mid_channels is None:
|
||||||
|
mid_channels = channels
|
||||||
|
|
||||||
|
conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
|
||||||
|
|
||||||
|
self.conv1 = conv(channels, mid_channels, kernel_size=3, padding=1)
|
||||||
|
self.norm1 = torch.nn.GroupNorm(32, mid_channels)
|
||||||
|
self.conv2 = conv(mid_channels, channels, kernel_size=3, padding=1)
|
||||||
|
self.norm2 = torch.nn.GroupNorm(32, channels)
|
||||||
|
self.activation = torch.nn.SiLU()
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
residual = x
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.norm1(x)
|
||||||
|
x = self.activation(x)
|
||||||
|
x = self.conv2(x)
|
||||||
|
x = self.norm2(x)
|
||||||
|
x = self.activation(x + residual)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class BlurDownsample(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel.
|
||||||
|
Applies only on H,W. Works for dims=2 or dims=3 (per-frame).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None:
|
||||||
|
super().__init__()
|
||||||
|
assert dims in (2, 3)
|
||||||
|
assert isinstance(stride, int)
|
||||||
|
assert stride >= 1
|
||||||
|
assert kernel_size >= 3
|
||||||
|
assert kernel_size % 2 == 1
|
||||||
|
self.dims = dims
|
||||||
|
self.stride = stride
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
|
||||||
|
# 5x5 separable binomial kernel using binomial coefficients [1, 4, 6, 4, 1] from
|
||||||
|
# the 4th row of Pascal's triangle. This kernel is used for anti-aliasing and
|
||||||
|
# provides a smooth approximation of a Gaussian filter (often called a "binomial filter").
|
||||||
|
# The 2D kernel is constructed as the outer product and normalized.
|
||||||
|
k = torch.tensor([math.comb(kernel_size - 1, k) for k in range(kernel_size)])
|
||||||
|
k2d = k[:, None] @ k[None, :]
|
||||||
|
k2d = (k2d / k2d.sum()).float() # shape (kernel_size, kernel_size)
|
||||||
|
self.register_buffer("kernel", k2d[None, None, :, :]) # (1, 1, kernel_size, kernel_size)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
if self.stride == 1:
|
||||||
|
return x
|
||||||
|
|
||||||
|
if self.dims == 2:
|
||||||
|
return self._apply_2d(x)
|
||||||
|
else:
|
||||||
|
# dims == 3: apply per-frame on H,W
|
||||||
|
b, _, f, _, _ = x.shape
|
||||||
|
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||||
|
x = self._apply_2d(x)
|
||||||
|
h2, w2 = x.shape[-2:]
|
||||||
|
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _apply_2d(self, x2d: torch.Tensor) -> torch.Tensor:
|
||||||
|
c = x2d.shape[1]
|
||||||
|
weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise
|
||||||
|
x2d = F.conv2d(x2d, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c)
|
||||||
|
return x2d
|
||||||
|
|
||||||
|
|
||||||
|
def _rational_for_scale(scale: float) -> Tuple[int, int]:
|
||||||
|
mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)}
|
||||||
|
if float(scale) not in mapping:
|
||||||
|
raise ValueError(f"Unsupported scale {scale}. Choose from {list(mapping.keys())}")
|
||||||
|
return mapping[float(scale)]
|
||||||
|
|
||||||
|
|
||||||
|
class SpatialRationalResampler(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased
|
||||||
|
downsample by 'den' using fixed blur + stride. Operates on H,W only.
|
||||||
|
For dims==3, work per-frame for spatial scaling (temporal axis untouched).
|
||||||
|
Args:
|
||||||
|
mid_channels (`int`): Number of intermediate channels for the convolution layer
|
||||||
|
scale (`float`): Spatial scaling factor. Supported values are:
|
||||||
|
- 0.75: Downsample by 3/4 (reduce spatial size)
|
||||||
|
- 1.5: Upsample by 3/2 (increase spatial size)
|
||||||
|
- 2.0: Upsample by 2x (double spatial size)
|
||||||
|
- 4.0: Upsample by 4x (quadruple spatial size)
|
||||||
|
Any other value will raise a ValueError.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, mid_channels: int, scale: float):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = float(scale)
|
||||||
|
self.num, self.den = _rational_for_scale(self.scale)
|
||||||
|
self.conv = torch.nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1)
|
||||||
|
self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
|
||||||
|
self.blur_down = BlurDownsample(dims=2, stride=self.den)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
b, _, f, _, _ = x.shape
|
||||||
|
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.pixel_shuffle(x)
|
||||||
|
x = self.blur_down(x)
|
||||||
|
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class LTX2LatentUpsampler(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Model to upsample VAE latents spatially and/or temporally.
|
||||||
|
Args:
|
||||||
|
in_channels (`int`): Number of channels in the input latent
|
||||||
|
mid_channels (`int`): Number of channels in the middle layers
|
||||||
|
num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
|
||||||
|
dims (`int`): Number of dimensions for convolutions (2 or 3)
|
||||||
|
spatial_upsample (`bool`): Whether to spatially upsample the latent
|
||||||
|
temporal_upsample (`bool`): Whether to temporally upsample the latent
|
||||||
|
spatial_scale (`float`): Scale factor for spatial upsampling
|
||||||
|
rational_resampler (`bool`): Whether to use a rational resampler for spatial upsampling
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int = 128,
|
||||||
|
mid_channels: int = 1024,
|
||||||
|
num_blocks_per_stage: int = 4,
|
||||||
|
dims: int = 3,
|
||||||
|
spatial_upsample: bool = True,
|
||||||
|
temporal_upsample: bool = False,
|
||||||
|
spatial_scale: float = 2.0,
|
||||||
|
rational_resampler: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.mid_channels = mid_channels
|
||||||
|
self.num_blocks_per_stage = num_blocks_per_stage
|
||||||
|
self.dims = dims
|
||||||
|
self.spatial_upsample = spatial_upsample
|
||||||
|
self.temporal_upsample = temporal_upsample
|
||||||
|
self.spatial_scale = float(spatial_scale)
|
||||||
|
self.rational_resampler = rational_resampler
|
||||||
|
|
||||||
|
conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
|
||||||
|
|
||||||
|
self.initial_conv = conv(in_channels, mid_channels, kernel_size=3, padding=1)
|
||||||
|
self.initial_norm = torch.nn.GroupNorm(32, mid_channels)
|
||||||
|
self.initial_activation = torch.nn.SiLU()
|
||||||
|
|
||||||
|
self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)])
|
||||||
|
|
||||||
|
if spatial_upsample and temporal_upsample:
|
||||||
|
self.upsampler = torch.nn.Sequential(
|
||||||
|
torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
|
||||||
|
PixelShuffleND(3),
|
||||||
|
)
|
||||||
|
elif spatial_upsample:
|
||||||
|
if rational_resampler:
|
||||||
|
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=self.spatial_scale)
|
||||||
|
else:
|
||||||
|
self.upsampler = torch.nn.Sequential(
|
||||||
|
torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
|
||||||
|
PixelShuffleND(2),
|
||||||
|
)
|
||||||
|
elif temporal_upsample:
|
||||||
|
self.upsampler = torch.nn.Sequential(
|
||||||
|
torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
|
||||||
|
PixelShuffleND(1),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Either spatial_upsample or temporal_upsample must be True")
|
||||||
|
|
||||||
|
self.post_upsample_res_blocks = torch.nn.ModuleList(
|
||||||
|
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.final_conv = conv(mid_channels, in_channels, kernel_size=3, padding=1)
|
||||||
|
|
||||||
|
def forward(self, latent: torch.Tensor) -> torch.Tensor:
|
||||||
|
b, _, f, _, _ = latent.shape
|
||||||
|
|
||||||
|
if self.dims == 2:
|
||||||
|
x = rearrange(latent, "b c f h w -> (b f) c h w")
|
||||||
|
x = self.initial_conv(x)
|
||||||
|
x = self.initial_norm(x)
|
||||||
|
x = self.initial_activation(x)
|
||||||
|
|
||||||
|
for block in self.res_blocks:
|
||||||
|
x = block(x)
|
||||||
|
|
||||||
|
x = self.upsampler(x)
|
||||||
|
|
||||||
|
for block in self.post_upsample_res_blocks:
|
||||||
|
x = block(x)
|
||||||
|
|
||||||
|
x = self.final_conv(x)
|
||||||
|
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||||
|
else:
|
||||||
|
x = self.initial_conv(latent)
|
||||||
|
x = self.initial_norm(x)
|
||||||
|
x = self.initial_activation(x)
|
||||||
|
|
||||||
|
for block in self.res_blocks:
|
||||||
|
x = block(x)
|
||||||
|
|
||||||
|
if self.temporal_upsample:
|
||||||
|
x = self.upsampler(x)
|
||||||
|
# remove the first frame after upsampling.
|
||||||
|
# This is done because the first frame encodes one pixel frame.
|
||||||
|
x = x[:, :, 1:, :, :]
|
||||||
|
elif isinstance(self.upsampler, SpatialRationalResampler):
|
||||||
|
x = self.upsampler(x)
|
||||||
|
else:
|
||||||
|
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||||
|
x = self.upsampler(x)
|
||||||
|
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||||
|
|
||||||
|
for block in self.post_upsample_res_blocks:
|
||||||
|
x = block(x)
|
||||||
|
|
||||||
|
x = self.final_conv(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def upsample_video(latent: torch.Tensor, video_encoder: LTX2VideoEncoder, upsampler: "LTX2LatentUpsampler") -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Apply upsampling to the latent representation using the provided upsampler,
|
||||||
|
with normalization and un-normalization based on the video encoder's per-channel statistics.
|
||||||
|
Args:
|
||||||
|
latent: Input latent tensor of shape [B, C, F, H, W].
|
||||||
|
video_encoder: VideoEncoder with per_channel_statistics for normalization.
|
||||||
|
upsampler: LTX2LatentUpsampler module to perform upsampling.
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Upsampled and re-normalized latent tensor.
|
||||||
|
"""
|
||||||
|
latent = video_encoder.per_channel_statistics.un_normalize(latent)
|
||||||
|
latent = upsampler(latent)
|
||||||
|
latent = video_encoder.per_channel_statistics.normalize(latent)
|
||||||
|
return latent
|
||||||
2317
diffsynth/models/ltx2_video_vae.py
Normal file
2317
diffsynth/models/ltx2_video_vae.py
Normal file
File diff suppressed because it is too large
Load Diff
112
diffsynth/models/model_loader.py
Normal file
112
diffsynth/models/model_loader.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
from ..core.loader import load_model, hash_model_file
|
||||||
|
from ..core.vram import AutoWrappedModule
|
||||||
|
from ..configs import MODEL_CONFIGS, VRAM_MANAGEMENT_MODULE_MAPS
|
||||||
|
import importlib, json, torch
|
||||||
|
|
||||||
|
|
||||||
|
class ModelPool:
|
||||||
|
def __init__(self):
|
||||||
|
self.model = []
|
||||||
|
self.model_name = []
|
||||||
|
self.model_path = []
|
||||||
|
|
||||||
|
def import_model_class(self, model_class):
|
||||||
|
split = model_class.rfind(".")
|
||||||
|
model_resource, model_class = model_class[:split], model_class[split+1:]
|
||||||
|
model_class = importlib.import_module(model_resource).__getattribute__(model_class)
|
||||||
|
return model_class
|
||||||
|
|
||||||
|
def need_to_enable_vram_management(self, vram_config):
|
||||||
|
return vram_config["offload_dtype"] is not None and vram_config["offload_device"] is not None
|
||||||
|
|
||||||
|
def fetch_module_map(self, model_class, vram_config):
|
||||||
|
if self.need_to_enable_vram_management(vram_config):
|
||||||
|
if model_class in VRAM_MANAGEMENT_MODULE_MAPS:
|
||||||
|
module_map = {self.import_model_class(source): self.import_model_class(target) for source, target in VRAM_MANAGEMENT_MODULE_MAPS[model_class].items()}
|
||||||
|
else:
|
||||||
|
module_map = {self.import_model_class(model_class): AutoWrappedModule}
|
||||||
|
else:
|
||||||
|
module_map = None
|
||||||
|
return module_map
|
||||||
|
|
||||||
|
def load_model_file(self, config, path, vram_config, vram_limit=None, state_dict=None):
|
||||||
|
model_class = self.import_model_class(config["model_class"])
|
||||||
|
model_config = config.get("extra_kwargs", {})
|
||||||
|
if "state_dict_converter" in config:
|
||||||
|
state_dict_converter = self.import_model_class(config["state_dict_converter"])
|
||||||
|
else:
|
||||||
|
state_dict_converter = None
|
||||||
|
module_map = self.fetch_module_map(config["model_class"], vram_config)
|
||||||
|
model = load_model(
|
||||||
|
model_class, path, model_config,
|
||||||
|
vram_config["computation_dtype"], vram_config["computation_device"],
|
||||||
|
state_dict_converter,
|
||||||
|
use_disk_map=True,
|
||||||
|
vram_config=vram_config, module_map=module_map, vram_limit=vram_limit,
|
||||||
|
state_dict=state_dict,
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def default_vram_config(self):
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": None,
|
||||||
|
"offload_device": None,
|
||||||
|
"onload_dtype": torch.bfloat16,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.bfloat16,
|
||||||
|
"preparing_device": "cpu",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cpu",
|
||||||
|
}
|
||||||
|
return vram_config
|
||||||
|
|
||||||
|
def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False, state_dict=None):
|
||||||
|
print(f"Loading models from: {json.dumps(path, indent=4)}")
|
||||||
|
if vram_config is None:
|
||||||
|
vram_config = self.default_vram_config()
|
||||||
|
model_hash = hash_model_file(path)
|
||||||
|
loaded = False
|
||||||
|
for config in MODEL_CONFIGS:
|
||||||
|
if config["model_hash"] == model_hash:
|
||||||
|
model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit, state_dict=state_dict)
|
||||||
|
if clear_parameters: self.clear_parameters(model)
|
||||||
|
self.model.append(model)
|
||||||
|
model_name = config["model_name"]
|
||||||
|
self.model_name.append(model_name)
|
||||||
|
self.model_path.append(path)
|
||||||
|
model_info = {"model_name": model_name, "model_class": config["model_class"], "extra_kwargs": config.get("extra_kwargs")}
|
||||||
|
print(f"Loaded model: {json.dumps(model_info, indent=4)}")
|
||||||
|
loaded = True
|
||||||
|
if not loaded:
|
||||||
|
raise ValueError(f"Cannot detect the model type. File: {path}. Model hash: {model_hash}")
|
||||||
|
|
||||||
|
def fetch_model(self, model_name, index=None):
|
||||||
|
fetched_models = []
|
||||||
|
fetched_model_paths = []
|
||||||
|
for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
|
||||||
|
if model_name == model_name_:
|
||||||
|
fetched_models.append(model)
|
||||||
|
fetched_model_paths.append(model_path)
|
||||||
|
if len(fetched_models) == 0:
|
||||||
|
print(f"No {model_name} models available. This is not an error.")
|
||||||
|
model = None
|
||||||
|
elif len(fetched_models) == 1:
|
||||||
|
print(f"Using {model_name} from {json.dumps(fetched_model_paths[0], indent=4)}.")
|
||||||
|
model = fetched_models[0]
|
||||||
|
else:
|
||||||
|
if index is None:
|
||||||
|
model = fetched_models[0]
|
||||||
|
print(f"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths[0], indent=4)}.")
|
||||||
|
elif isinstance(index, int):
|
||||||
|
model = fetched_models[:index]
|
||||||
|
print(f"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths[:index], indent=4)}.")
|
||||||
|
else:
|
||||||
|
model = fetched_models
|
||||||
|
print(f"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths, indent=4)}.")
|
||||||
|
return model
|
||||||
|
|
||||||
|
def clear_parameters(self, model: torch.nn.Module):
|
||||||
|
for name, module in model.named_children():
|
||||||
|
self.clear_parameters(module)
|
||||||
|
for name, param in model.named_parameters(recurse=False):
|
||||||
|
setattr(model, name, None)
|
||||||
@@ -1,441 +0,0 @@
|
|||||||
import os, torch, json, importlib
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from .downloader import download_models, download_customized_models, Preset_model_id, Preset_model_website
|
|
||||||
|
|
||||||
from .sd_text_encoder import SDTextEncoder
|
|
||||||
from .sd_unet import SDUNet
|
|
||||||
from .sd_vae_encoder import SDVAEEncoder
|
|
||||||
from .sd_vae_decoder import SDVAEDecoder
|
|
||||||
from .lora import get_lora_loaders
|
|
||||||
|
|
||||||
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
|
||||||
from .sdxl_unet import SDXLUNet
|
|
||||||
from .sdxl_vae_decoder import SDXLVAEDecoder
|
|
||||||
from .sdxl_vae_encoder import SDXLVAEEncoder
|
|
||||||
|
|
||||||
from .sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
|
|
||||||
from .sd3_dit import SD3DiT
|
|
||||||
from .sd3_vae_decoder import SD3VAEDecoder
|
|
||||||
from .sd3_vae_encoder import SD3VAEEncoder
|
|
||||||
|
|
||||||
from .sd_controlnet import SDControlNet
|
|
||||||
from .sdxl_controlnet import SDXLControlNetUnion
|
|
||||||
|
|
||||||
from .sd_motion import SDMotionModel
|
|
||||||
from .sdxl_motion import SDXLMotionModel
|
|
||||||
|
|
||||||
from .svd_image_encoder import SVDImageEncoder
|
|
||||||
from .svd_unet import SVDUNet
|
|
||||||
from .svd_vae_decoder import SVDVAEDecoder
|
|
||||||
from .svd_vae_encoder import SVDVAEEncoder
|
|
||||||
|
|
||||||
from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
|
|
||||||
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
|
||||||
|
|
||||||
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
|
||||||
from .hunyuan_dit import HunyuanDiT
|
|
||||||
from .hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
|
|
||||||
from .hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder
|
|
||||||
|
|
||||||
from .flux_dit import FluxDiT
|
|
||||||
from .flux_text_encoder import FluxTextEncoder2
|
|
||||||
from .flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
|
||||||
from .flux_ipadapter import FluxIpAdapter
|
|
||||||
|
|
||||||
from .cog_vae import CogVAEEncoder, CogVAEDecoder
|
|
||||||
from .cog_dit import CogDiT
|
|
||||||
|
|
||||||
from ..extensions.RIFE import IFNet
|
|
||||||
from ..extensions.ESRGAN import RRDBNet
|
|
||||||
|
|
||||||
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
|
|
||||||
from .utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
|
|
||||||
loaded_model_names, loaded_models = [], []
|
|
||||||
for model_name, model_class in zip(model_names, model_classes):
|
|
||||||
print(f" model_name: {model_name} model_class: {model_class.__name__}")
|
|
||||||
state_dict_converter = model_class.state_dict_converter()
|
|
||||||
if model_resource == "civitai":
|
|
||||||
state_dict_results = state_dict_converter.from_civitai(state_dict)
|
|
||||||
elif model_resource == "diffusers":
|
|
||||||
state_dict_results = state_dict_converter.from_diffusers(state_dict)
|
|
||||||
if isinstance(state_dict_results, tuple):
|
|
||||||
model_state_dict, extra_kwargs = state_dict_results
|
|
||||||
print(f" This model is initialized with extra kwargs: {extra_kwargs}")
|
|
||||||
else:
|
|
||||||
model_state_dict, extra_kwargs = state_dict_results, {}
|
|
||||||
torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
|
|
||||||
with init_weights_on_device():
|
|
||||||
model= model_class(**extra_kwargs)
|
|
||||||
model.load_state_dict(model_state_dict, assign=True)
|
|
||||||
model = model.to(dtype=torch_dtype, device=device)
|
|
||||||
loaded_model_names.append(model_name)
|
|
||||||
loaded_models.append(model)
|
|
||||||
return loaded_model_names, loaded_models
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
|
|
||||||
loaded_model_names, loaded_models = [], []
|
|
||||||
for model_name, model_class in zip(model_names, model_classes):
|
|
||||||
model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
|
|
||||||
if torch_dtype == torch.float16 and hasattr(model, "half"):
|
|
||||||
model = model.half()
|
|
||||||
try:
|
|
||||||
model = model.to(device=device)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
loaded_model_names.append(model_name)
|
|
||||||
loaded_models.append(model)
|
|
||||||
return loaded_model_names, loaded_models
|
|
||||||
|
|
||||||
|
|
||||||
def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
|
|
||||||
print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
|
|
||||||
base_state_dict = base_model.state_dict()
|
|
||||||
base_model.to("cpu")
|
|
||||||
del base_model
|
|
||||||
model = model_class(**extra_kwargs)
|
|
||||||
model.load_state_dict(base_state_dict, strict=False)
|
|
||||||
model.load_state_dict(state_dict, strict=False)
|
|
||||||
model.to(dtype=torch_dtype, device=device)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
|
|
||||||
loaded_model_names, loaded_models = [], []
|
|
||||||
for model_name, model_class in zip(model_names, model_classes):
|
|
||||||
while True:
|
|
||||||
for model_id in range(len(model_manager.model)):
|
|
||||||
base_model_name = model_manager.model_name[model_id]
|
|
||||||
if base_model_name == model_name:
|
|
||||||
base_model_path = model_manager.model_path[model_id]
|
|
||||||
base_model = model_manager.model[model_id]
|
|
||||||
print(f" Adding patch model to {base_model_name} ({base_model_path})")
|
|
||||||
patched_model = load_single_patch_model_from_single_file(
|
|
||||||
state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
|
|
||||||
loaded_model_names.append(base_model_name)
|
|
||||||
loaded_models.append(patched_model)
|
|
||||||
model_manager.model.pop(model_id)
|
|
||||||
model_manager.model_path.pop(model_id)
|
|
||||||
model_manager.model_name.pop(model_id)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
return loaded_model_names, loaded_models
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ModelDetectorTemplate:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def match(self, file_path="", state_dict={}):
|
|
||||||
return False
|
|
||||||
|
|
||||||
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
|
||||||
return [], []
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ModelDetectorFromSingleFile:
|
|
||||||
def __init__(self, model_loader_configs=[]):
|
|
||||||
self.keys_hash_with_shape_dict = {}
|
|
||||||
self.keys_hash_dict = {}
|
|
||||||
for metadata in model_loader_configs:
|
|
||||||
self.add_model_metadata(*metadata)
|
|
||||||
|
|
||||||
|
|
||||||
def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
|
|
||||||
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
|
|
||||||
if keys_hash is not None:
|
|
||||||
self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)
|
|
||||||
|
|
||||||
|
|
||||||
def match(self, file_path="", state_dict={}):
|
|
||||||
if os.path.isdir(file_path):
|
|
||||||
return False
|
|
||||||
if len(state_dict) == 0:
|
|
||||||
state_dict = load_state_dict(file_path)
|
|
||||||
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
|
||||||
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
|
||||||
return True
|
|
||||||
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
|
|
||||||
if keys_hash in self.keys_hash_dict:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
|
||||||
if len(state_dict) == 0:
|
|
||||||
state_dict = load_state_dict(file_path)
|
|
||||||
|
|
||||||
# Load models with strict matching
|
|
||||||
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
|
||||||
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
|
||||||
model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
|
|
||||||
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
|
|
||||||
return loaded_model_names, loaded_models
|
|
||||||
|
|
||||||
# Load models without strict matching
|
|
||||||
# (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
|
|
||||||
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
|
|
||||||
if keys_hash in self.keys_hash_dict:
|
|
||||||
model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
|
|
||||||
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
|
|
||||||
return loaded_model_names, loaded_models
|
|
||||||
|
|
||||||
return loaded_model_names, loaded_models
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
|
|
||||||
def __init__(self, model_loader_configs=[]):
|
|
||||||
super().__init__(model_loader_configs)
|
|
||||||
|
|
||||||
|
|
||||||
def match(self, file_path="", state_dict={}):
|
|
||||||
if os.path.isdir(file_path):
|
|
||||||
return False
|
|
||||||
if len(state_dict) == 0:
|
|
||||||
state_dict = load_state_dict(file_path)
|
|
||||||
splited_state_dict = split_state_dict_with_prefix(state_dict)
|
|
||||||
for sub_state_dict in splited_state_dict:
|
|
||||||
if super().match(file_path, sub_state_dict):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
|
||||||
# Split the state_dict and load from each component
|
|
||||||
splited_state_dict = split_state_dict_with_prefix(state_dict)
|
|
||||||
valid_state_dict = {}
|
|
||||||
for sub_state_dict in splited_state_dict:
|
|
||||||
if super().match(file_path, sub_state_dict):
|
|
||||||
valid_state_dict.update(sub_state_dict)
|
|
||||||
if super().match(file_path, valid_state_dict):
|
|
||||||
loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
|
|
||||||
else:
|
|
||||||
loaded_model_names, loaded_models = [], []
|
|
||||||
for sub_state_dict in splited_state_dict:
|
|
||||||
if super().match(file_path, sub_state_dict):
|
|
||||||
loaded_model_names_, loaded_models_ = super().load(file_path, valid_state_dict, device, torch_dtype)
|
|
||||||
loaded_model_names += loaded_model_names_
|
|
||||||
loaded_models += loaded_models_
|
|
||||||
return loaded_model_names, loaded_models
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ModelDetectorFromHuggingfaceFolder:
|
|
||||||
def __init__(self, model_loader_configs=[]):
|
|
||||||
self.architecture_dict = {}
|
|
||||||
for metadata in model_loader_configs:
|
|
||||||
self.add_model_metadata(*metadata)
|
|
||||||
|
|
||||||
|
|
||||||
def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
|
|
||||||
self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)
|
|
||||||
|
|
||||||
|
|
||||||
def match(self, file_path="", state_dict={}):
|
|
||||||
if os.path.isfile(file_path):
|
|
||||||
return False
|
|
||||||
file_list = os.listdir(file_path)
|
|
||||||
if "config.json" not in file_list:
|
|
||||||
return False
|
|
||||||
with open(os.path.join(file_path, "config.json"), "r") as f:
|
|
||||||
config = json.load(f)
|
|
||||||
if "architectures" not in config and "_class_name" not in config:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
|
||||||
with open(os.path.join(file_path, "config.json"), "r") as f:
|
|
||||||
config = json.load(f)
|
|
||||||
loaded_model_names, loaded_models = [], []
|
|
||||||
architectures = config["architectures"] if "architectures" in config else [config["_class_name"]]
|
|
||||||
for architecture in architectures:
|
|
||||||
huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
|
|
||||||
if redirected_architecture is not None:
|
|
||||||
architecture = redirected_architecture
|
|
||||||
model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
|
|
||||||
loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
|
|
||||||
loaded_model_names += loaded_model_names_
|
|
||||||
loaded_models += loaded_models_
|
|
||||||
return loaded_model_names, loaded_models
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ModelDetectorFromPatchedSingleFile:
|
|
||||||
def __init__(self, model_loader_configs=[]):
|
|
||||||
self.keys_hash_with_shape_dict = {}
|
|
||||||
for metadata in model_loader_configs:
|
|
||||||
self.add_model_metadata(*metadata)
|
|
||||||
|
|
||||||
|
|
||||||
def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
|
|
||||||
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def match(self, file_path="", state_dict={}):
|
|
||||||
if os.path.isdir(file_path):
|
|
||||||
return False
|
|
||||||
if len(state_dict) == 0:
|
|
||||||
state_dict = load_state_dict(file_path)
|
|
||||||
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
|
||||||
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
|
|
||||||
if len(state_dict) == 0:
|
|
||||||
state_dict = load_state_dict(file_path)
|
|
||||||
|
|
||||||
# Load models with strict matching
|
|
||||||
loaded_model_names, loaded_models = [], []
|
|
||||||
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
|
||||||
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
|
||||||
model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
|
|
||||||
loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
|
|
||||||
state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
|
|
||||||
loaded_model_names += loaded_model_names_
|
|
||||||
loaded_models += loaded_models_
|
|
||||||
return loaded_model_names, loaded_models
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ModelManager:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
torch_dtype=torch.float16,
|
|
||||||
device="cuda",
|
|
||||||
model_id_list: List[Preset_model_id] = [],
|
|
||||||
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
|
|
||||||
file_path_list: List[str] = [],
|
|
||||||
):
|
|
||||||
self.torch_dtype = torch_dtype
|
|
||||||
self.device = device
|
|
||||||
self.model = []
|
|
||||||
self.model_path = []
|
|
||||||
self.model_name = []
|
|
||||||
downloaded_files = download_models(model_id_list, downloading_priority) if len(model_id_list) > 0 else []
|
|
||||||
self.model_detector = [
|
|
||||||
ModelDetectorFromSingleFile(model_loader_configs),
|
|
||||||
ModelDetectorFromSplitedSingleFile(model_loader_configs),
|
|
||||||
ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
|
|
||||||
ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
|
|
||||||
]
|
|
||||||
self.load_models(downloaded_files + file_path_list)
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
|
|
||||||
print(f"Loading models from file: {file_path}")
|
|
||||||
if len(state_dict) == 0:
|
|
||||||
state_dict = load_state_dict(file_path)
|
|
||||||
model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device)
|
|
||||||
for model_name, model in zip(model_names, models):
|
|
||||||
self.model.append(model)
|
|
||||||
self.model_path.append(file_path)
|
|
||||||
self.model_name.append(model_name)
|
|
||||||
print(f" The following models are loaded: {model_names}.")
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
|
|
||||||
print(f"Loading models from folder: {file_path}")
|
|
||||||
model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
|
|
||||||
for model_name, model in zip(model_names, models):
|
|
||||||
self.model.append(model)
|
|
||||||
self.model_path.append(file_path)
|
|
||||||
self.model_name.append(model_name)
|
|
||||||
print(f" The following models are loaded: {model_names}.")
|
|
||||||
|
|
||||||
|
|
||||||
def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
|
|
||||||
print(f"Loading patch models from file: {file_path}")
|
|
||||||
model_names, models = load_patch_model_from_single_file(
|
|
||||||
state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
|
|
||||||
for model_name, model in zip(model_names, models):
|
|
||||||
self.model.append(model)
|
|
||||||
self.model_path.append(file_path)
|
|
||||||
self.model_name.append(model_name)
|
|
||||||
print(f" The following patched models are loaded: {model_names}.")
|
|
||||||
|
|
||||||
|
|
||||||
def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
|
|
||||||
if isinstance(file_path, list):
|
|
||||||
for file_path_ in file_path:
|
|
||||||
self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
|
|
||||||
else:
|
|
||||||
print(f"Loading LoRA models from file: {file_path}")
|
|
||||||
if len(state_dict) == 0:
|
|
||||||
state_dict = load_state_dict(file_path)
|
|
||||||
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
|
|
||||||
for lora in get_lora_loaders():
|
|
||||||
match_results = lora.match(model, state_dict)
|
|
||||||
if match_results is not None:
|
|
||||||
print(f" Adding LoRA to {model_name} ({model_path}).")
|
|
||||||
lora_prefix, model_resource = match_results
|
|
||||||
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
|
|
||||||
print(f"Loading models from: {file_path}")
|
|
||||||
if device is None: device = self.device
|
|
||||||
if torch_dtype is None: torch_dtype = self.torch_dtype
|
|
||||||
if os.path.isfile(file_path):
|
|
||||||
state_dict = load_state_dict(file_path)
|
|
||||||
else:
|
|
||||||
state_dict = None
|
|
||||||
for model_detector in self.model_detector:
|
|
||||||
if model_detector.match(file_path, state_dict):
|
|
||||||
model_names, models = model_detector.load(
|
|
||||||
file_path, state_dict,
|
|
||||||
device=device, torch_dtype=torch_dtype,
|
|
||||||
allowed_model_names=model_names, model_manager=self
|
|
||||||
)
|
|
||||||
for model_name, model in zip(model_names, models):
|
|
||||||
self.model.append(model)
|
|
||||||
self.model_path.append(file_path)
|
|
||||||
self.model_name.append(model_name)
|
|
||||||
print(f" The following models are loaded: {model_names}.")
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
print(f" We cannot detect the model type. No models are loaded.")
|
|
||||||
|
|
||||||
|
|
||||||
def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None):
|
|
||||||
for file_path in file_path_list:
|
|
||||||
self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def fetch_model(self, model_name, file_path=None, require_model_path=False):
|
|
||||||
fetched_models = []
|
|
||||||
fetched_model_paths = []
|
|
||||||
for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
|
|
||||||
if file_path is not None and file_path != model_path:
|
|
||||||
continue
|
|
||||||
if model_name == model_name_:
|
|
||||||
fetched_models.append(model)
|
|
||||||
fetched_model_paths.append(model_path)
|
|
||||||
if len(fetched_models) == 0:
|
|
||||||
print(f"No {model_name} models available.")
|
|
||||||
return None
|
|
||||||
if len(fetched_models) == 1:
|
|
||||||
print(f"Using {model_name} from {fetched_model_paths[0]}.")
|
|
||||||
else:
|
|
||||||
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
|
|
||||||
if require_model_path:
|
|
||||||
return fetched_models[0], fetched_model_paths[0]
|
|
||||||
else:
|
|
||||||
return fetched_models[0]
|
|
||||||
|
|
||||||
|
|
||||||
def to(self, device):
|
|
||||||
for model in self.model:
|
|
||||||
model.to(device)
|
|
||||||
|
|
||||||
161
diffsynth/models/nexus_gen.py
Normal file
161
diffsynth/models/nexus_gen.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
class NexusGenAutoregressiveModel(torch.nn.Module):
|
||||||
|
def __init__(self, max_length=1024, max_pixels=262640):
|
||||||
|
super(NexusGenAutoregressiveModel, self).__init__()
|
||||||
|
from .nexus_gen_ar_model import Qwen2_5_VLForConditionalGeneration
|
||||||
|
from transformers import Qwen2_5_VLConfig
|
||||||
|
self.max_length = max_length
|
||||||
|
self.max_pixels = max_pixels
|
||||||
|
model_config = Qwen2_5_VLConfig(**{
|
||||||
|
"_name_or_path": "DiffSynth-Studio/Nexus-GenV2",
|
||||||
|
"architectures": [
|
||||||
|
"Qwen2_5_VLForConditionalGeneration"
|
||||||
|
],
|
||||||
|
"attention_dropout": 0.0,
|
||||||
|
"auto_map": {
|
||||||
|
"AutoConfig": "configuration_qwen2_5_vl.Qwen2_5_VLConfig",
|
||||||
|
"AutoModel": "modeling_qwen2_5_vl.Qwen2_5_VLModel",
|
||||||
|
"AutoModelForCausalLM": "modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration"
|
||||||
|
},
|
||||||
|
"bos_token_id": 151643,
|
||||||
|
"eos_token_id": 151645,
|
||||||
|
"hidden_act": "silu",
|
||||||
|
"hidden_size": 3584,
|
||||||
|
"image_token_id": 151655,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"intermediate_size": 18944,
|
||||||
|
"max_position_embeddings": 128000,
|
||||||
|
"max_window_layers": 28,
|
||||||
|
"model_type": "qwen2_5_vl",
|
||||||
|
"num_attention_heads": 28,
|
||||||
|
"num_hidden_layers": 28,
|
||||||
|
"num_key_value_heads": 4,
|
||||||
|
"pad_token_id": 151643,
|
||||||
|
"rms_norm_eps": 1e-06,
|
||||||
|
"rope_scaling": {
|
||||||
|
"mrope_section": [
|
||||||
|
16,
|
||||||
|
24,
|
||||||
|
24
|
||||||
|
],
|
||||||
|
"rope_type": "default",
|
||||||
|
"type": "default"
|
||||||
|
},
|
||||||
|
"rope_theta": 1000000.0,
|
||||||
|
"sliding_window": 32768,
|
||||||
|
"tie_word_embeddings": False,
|
||||||
|
"torch_dtype": "bfloat16",
|
||||||
|
"transformers_version": "4.49.0",
|
||||||
|
"use_cache": False,
|
||||||
|
"use_sliding_window": False,
|
||||||
|
"video_token_id": 151656,
|
||||||
|
"vision_config": {
|
||||||
|
"hidden_size": 1280,
|
||||||
|
"in_chans": 3,
|
||||||
|
"model_type": "qwen2_5_vl",
|
||||||
|
"spatial_patch_size": 14,
|
||||||
|
"tokens_per_second": 2,
|
||||||
|
"torch_dtype": "bfloat16"
|
||||||
|
},
|
||||||
|
"vision_end_token_id": 151653,
|
||||||
|
"vision_start_token_id": 151652,
|
||||||
|
"vision_token_id": 151654,
|
||||||
|
"vocab_size": 152064
|
||||||
|
})
|
||||||
|
self.model = Qwen2_5_VLForConditionalGeneration(model_config)
|
||||||
|
self.processor = None
|
||||||
|
|
||||||
|
|
||||||
|
def load_processor(self, path):
|
||||||
|
from .nexus_gen_ar_model import Qwen2_5_VLProcessor
|
||||||
|
self.processor = Qwen2_5_VLProcessor.from_pretrained(path)
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return NexusGenAutoregressiveModelStateDictConverter()
|
||||||
|
|
||||||
|
def bound_image(self, image, max_pixels=262640):
|
||||||
|
from qwen_vl_utils import smart_resize
|
||||||
|
resized_height, resized_width = smart_resize(
|
||||||
|
image.height,
|
||||||
|
image.width,
|
||||||
|
max_pixels=max_pixels,
|
||||||
|
)
|
||||||
|
return image.resize((resized_width, resized_height))
|
||||||
|
|
||||||
|
def get_editing_msg(self, instruction):
|
||||||
|
if '<image>' not in instruction:
|
||||||
|
instruction = '<image> ' + instruction
|
||||||
|
messages = [{"role":"user", "content":instruction}, {"role":"assistant", "content":"Here is the image: <image>"}]
|
||||||
|
return messages
|
||||||
|
|
||||||
|
def get_generation_msg(self, instruction):
|
||||||
|
instruction = "Generate an image according to the following description: {}".format(instruction)
|
||||||
|
messages = [{"role":"user", "content":instruction}, {"role":"assistant", "content":"Here is an image based on the description: <image>"}]
|
||||||
|
return messages
|
||||||
|
|
||||||
|
def forward(self, instruction, ref_image=None, num_img_tokens=81):
|
||||||
|
"""
|
||||||
|
Generate target embeddings for the given instruction and reference image.
|
||||||
|
"""
|
||||||
|
if ref_image is not None:
|
||||||
|
messages = self.get_editing_msg(instruction)
|
||||||
|
images = [self.bound_image(ref_image)] + [Image.new(mode='RGB', size=(252, 252), color=(255, 255, 255))]
|
||||||
|
output_image_embeddings = self.get_target_embeddings(images, messages, self.processor, self.model, num_img_tokens)
|
||||||
|
else:
|
||||||
|
messages = self.get_generation_msg(instruction)
|
||||||
|
images = [Image.new(mode='RGB', size=(252, 252), color=(255, 255, 255))]
|
||||||
|
output_image_embeddings = self.get_target_embeddings(images, messages, self.processor, self.model, num_img_tokens)
|
||||||
|
|
||||||
|
return output_image_embeddings
|
||||||
|
|
||||||
|
def get_target_embeddings(self, images, messages, processor, model, num_img_tokens=81):
|
||||||
|
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
|
||||||
|
text = text.replace('<image>', '<|vision_start|><|image_pad|><|vision_end|>')
|
||||||
|
inputs = processor(
|
||||||
|
text=[text],
|
||||||
|
images=images,
|
||||||
|
padding=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
inputs = inputs.to(model.device)
|
||||||
|
|
||||||
|
input_embeds = model.model.embed_tokens(inputs['input_ids'])
|
||||||
|
image_embeds = model.visual(inputs['pixel_values'], grid_thw=inputs['image_grid_thw'])
|
||||||
|
ground_truth_image_embeds = image_embeds[-num_img_tokens:]
|
||||||
|
input_image_embeds = image_embeds[:-num_img_tokens]
|
||||||
|
|
||||||
|
image_mask = inputs['input_ids'] == model.config.image_token_id
|
||||||
|
indices = image_mask.cumsum(dim=1)
|
||||||
|
input_image_mask = torch.logical_and(indices <= (image_embeds.shape[0] - ground_truth_image_embeds.shape[0]), image_mask)
|
||||||
|
gt_image_mask = torch.logical_and(image_mask, ~input_image_mask)
|
||||||
|
input_image_mask = input_image_mask.unsqueeze(-1).expand_as(input_embeds)
|
||||||
|
input_embeds = input_embeds.masked_scatter(input_image_mask, input_image_embeds)
|
||||||
|
|
||||||
|
image_prefill_embeds = model.image_prefill_embeds(
|
||||||
|
torch.arange(81, device=model.device).long()
|
||||||
|
)
|
||||||
|
input_embeds = input_embeds.masked_scatter(gt_image_mask.unsqueeze(-1).expand_as(input_embeds), image_prefill_embeds)
|
||||||
|
|
||||||
|
position_ids, _ = model.get_rope_index(
|
||||||
|
inputs['input_ids'],
|
||||||
|
inputs['image_grid_thw'],
|
||||||
|
attention_mask=inputs['attention_mask'])
|
||||||
|
position_ids = position_ids.contiguous()
|
||||||
|
outputs = model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=inputs['attention_mask'], return_dict=True)
|
||||||
|
output_image_embeddings = outputs.image_embeddings[:, :-1, :]
|
||||||
|
output_image_embeddings = output_image_embeddings[gt_image_mask[:, 1:]]
|
||||||
|
return output_image_embeddings, input_image_embeds, inputs['image_grid_thw']
|
||||||
|
|
||||||
|
|
||||||
|
class NexusGenAutoregressiveModelStateDictConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
state_dict = {"model." + key: value for key, value in state_dict.items()}
|
||||||
|
return state_dict
|
||||||
1143
diffsynth/models/nexus_gen_ar_model.py
Normal file
1143
diffsynth/models/nexus_gen_ar_model.py
Normal file
File diff suppressed because it is too large
Load Diff
417
diffsynth/models/nexus_gen_projector.py
Normal file
417
diffsynth/models/nexus_gen_projector.py
Normal file
@@ -0,0 +1,417 @@
|
|||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_half(x):
|
||||||
|
"""Rotates half the hidden dims of the input."""
|
||||||
|
x1 = x[..., : x.shape[-1] // 2]
|
||||||
|
x2 = x[..., x.shape[-1] // 2 :]
|
||||||
|
return torch.cat((-x2, x1), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
|
||||||
|
mrope_section = mrope_section * 2
|
||||||
|
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
|
||||||
|
unsqueeze_dim
|
||||||
|
)
|
||||||
|
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
|
||||||
|
unsqueeze_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||||
|
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||||
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2_5_VLRotaryEmbedding(nn.Module):
|
||||||
|
def __init__(self, config, device=None):
|
||||||
|
super().__init__()
|
||||||
|
# BC: "rope_type" was originally "type"
|
||||||
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
|
else:
|
||||||
|
self.rope_type = "default"
|
||||||
|
self.max_seq_len_cached = config.max_position_embeddings
|
||||||
|
self.original_max_seq_len = config.max_position_embeddings
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
from transformers.modeling_rope_utils import _compute_default_rope_parameters
|
||||||
|
self.rope_init_fn = _compute_default_rope_parameters
|
||||||
|
|
||||||
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
|
|
||||||
|
def _dynamic_frequency_update(self, position_ids, device):
|
||||||
|
"""
|
||||||
|
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
||||||
|
1 - growing beyond the cached sequence length (allow scaling)
|
||||||
|
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
||||||
|
"""
|
||||||
|
seq_len = torch.max(position_ids) + 1
|
||||||
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
|
inv_freq, self.attention_scaling = self.rope_init_fn(
|
||||||
|
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
||||||
|
)
|
||||||
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def forward(self, x, position_ids):
|
||||||
|
if "dynamic" in self.rope_type:
|
||||||
|
self._dynamic_frequency_update(position_ids, device=x.device)
|
||||||
|
|
||||||
|
# Core RoPE block. In contrast to other models, Qwen2_5_VL has different position ids for the grids
|
||||||
|
# So we expand the inv_freq to shape (3, ...)
|
||||||
|
inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
|
||||||
|
position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
|
||||||
|
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
||||||
|
device_type = x.device.type
|
||||||
|
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||||
|
with torch.autocast(device_type=device_type, enabled=False):
|
||||||
|
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
|
||||||
|
emb = torch.cat((freqs, freqs), dim=-1)
|
||||||
|
cos = emb.cos()
|
||||||
|
sin = emb.sin()
|
||||||
|
|
||||||
|
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
||||||
|
cos = cos * self.attention_scaling
|
||||||
|
sin = sin * self.attention_scaling
|
||||||
|
|
||||||
|
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||||
|
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||||
|
"""
|
||||||
|
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||||
|
if n_rep == 1:
|
||||||
|
return hidden_states
|
||||||
|
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||||
|
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2_5_VLAttention(nn.Module):
|
||||||
|
def __init__(self, config, layer_idx: Optional[int] = None):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.head_dim = self.hidden_size // self.num_heads
|
||||||
|
self.num_key_value_heads = config.num_key_value_heads
|
||||||
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||||
|
self.is_causal = True
|
||||||
|
self.attention_dropout = config.attention_dropout
|
||||||
|
self.rope_scaling = config.rope_scaling
|
||||||
|
|
||||||
|
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
||||||
|
f" and `num_heads`: {self.num_heads})."
|
||||||
|
)
|
||||||
|
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
|
||||||
|
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||||
|
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||||
|
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
cos, sin = position_embeddings
|
||||||
|
query_states, key_states = apply_multimodal_rotary_pos_emb(
|
||||||
|
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
# Fix precision issues in Qwen2-VL float16 inference
|
||||||
|
# Replace inf values with zeros in attention weights to prevent NaN propagation
|
||||||
|
if query_states.dtype == torch.float16:
|
||||||
|
attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights)
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||||
|
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, -1)
|
||||||
|
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2MLP(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
self.config = config
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.intermediate_size = config.intermediate_size
|
||||||
|
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||||
|
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||||
|
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||||
|
self.act_fn = ACT2FN[config.hidden_act]
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||||
|
return down_proj
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2RMSNorm(nn.Module):
|
||||||
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
|
"""
|
||||||
|
Qwen2RMSNorm is equivalent to T5LayerNorm
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||||
|
self.variance_epsilon = eps
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
input_dtype = hidden_states.dtype
|
||||||
|
hidden_states = hidden_states.to(torch.float32)
|
||||||
|
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||||
|
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||||
|
return self.weight * hidden_states.to(input_dtype)
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2_5_VLDecoderLayer(nn.Module):
|
||||||
|
def __init__(self, config, layer_idx):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
|
||||||
|
self.self_attn = Qwen2_5_VLAttention(config, layer_idx)
|
||||||
|
|
||||||
|
self.mlp = Qwen2MLP(config)
|
||||||
|
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||||
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
hidden_states = self.self_attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
position_embeddings=position_embeddings,
|
||||||
|
)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class NexusGenImageEmbeddingMerger(nn.Module):
|
||||||
|
def __init__(self, num_layers=1, out_channel=4096, expand_ratio=4, device='cpu'):
|
||||||
|
super().__init__()
|
||||||
|
from transformers import Qwen2_5_VLConfig
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
config = Qwen2_5_VLConfig(**{
|
||||||
|
"_name_or_path": "DiffSynth-Studio/Nexus-GenV2",
|
||||||
|
"architectures": [
|
||||||
|
"Qwen2_5_VLForConditionalGeneration"
|
||||||
|
],
|
||||||
|
"attention_dropout": 0.0,
|
||||||
|
"auto_map": {
|
||||||
|
"AutoConfig": "configuration_qwen2_5_vl.Qwen2_5_VLConfig",
|
||||||
|
"AutoModel": "modeling_qwen2_5_vl.Qwen2_5_VLModel",
|
||||||
|
"AutoModelForCausalLM": "modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration"
|
||||||
|
},
|
||||||
|
"bos_token_id": 151643,
|
||||||
|
"eos_token_id": 151645,
|
||||||
|
"hidden_act": "silu",
|
||||||
|
"hidden_size": 3584,
|
||||||
|
"image_token_id": 151655,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"intermediate_size": 18944,
|
||||||
|
"max_position_embeddings": 128000,
|
||||||
|
"max_window_layers": 28,
|
||||||
|
"model_type": "qwen2_5_vl",
|
||||||
|
"num_attention_heads": 28,
|
||||||
|
"num_hidden_layers": 28,
|
||||||
|
"num_key_value_heads": 4,
|
||||||
|
"pad_token_id": 151643,
|
||||||
|
"rms_norm_eps": 1e-06,
|
||||||
|
"rope_scaling": {
|
||||||
|
"mrope_section": [
|
||||||
|
16,
|
||||||
|
24,
|
||||||
|
24
|
||||||
|
],
|
||||||
|
"rope_type": "default",
|
||||||
|
"type": "default"
|
||||||
|
},
|
||||||
|
"rope_theta": 1000000.0,
|
||||||
|
"sliding_window": 32768,
|
||||||
|
"tie_word_embeddings": False,
|
||||||
|
"torch_dtype": "bfloat16",
|
||||||
|
"transformers_version": "4.49.0",
|
||||||
|
"use_cache": False,
|
||||||
|
"use_sliding_window": False,
|
||||||
|
"video_token_id": 151656,
|
||||||
|
"vision_config": {
|
||||||
|
"hidden_size": 1280,
|
||||||
|
"in_chans": 3,
|
||||||
|
"model_type": "qwen2_5_vl",
|
||||||
|
"spatial_patch_size": 14,
|
||||||
|
"tokens_per_second": 2,
|
||||||
|
"torch_dtype": "bfloat16"
|
||||||
|
},
|
||||||
|
"vision_end_token_id": 151653,
|
||||||
|
"vision_start_token_id": 151652,
|
||||||
|
"vision_token_id": 151654,
|
||||||
|
"vocab_size": 152064
|
||||||
|
})
|
||||||
|
self.config = config
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.layers = nn.ModuleList([Qwen2_5_VLDecoderLayer(config, layer_idx) for layer_idx in range(num_layers)])
|
||||||
|
self.projector = nn.Sequential(Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps),
|
||||||
|
nn.Linear(config.hidden_size, out_channel * expand_ratio),
|
||||||
|
Qwen2RMSNorm(out_channel * expand_ratio, eps=config.rms_norm_eps),
|
||||||
|
ACT2FN[config.hidden_act], nn.Linear(out_channel * expand_ratio, out_channel),
|
||||||
|
Qwen2RMSNorm(out_channel, eps=config.rms_norm_eps))
|
||||||
|
self.base_grid = torch.tensor([[1, 72, 72]], device=device)
|
||||||
|
self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config, device=device)
|
||||||
|
|
||||||
|
def get_position_ids(self, image_grid_thw):
|
||||||
|
"""
|
||||||
|
Generates position ids for the input embeddings grid.
|
||||||
|
modified from the qwen2_vl mrope.
|
||||||
|
"""
|
||||||
|
batch_size = image_grid_thw.shape[0]
|
||||||
|
spatial_merge_size = self.config.vision_config.spatial_merge_size
|
||||||
|
t, h, w = (
|
||||||
|
image_grid_thw[0][0],
|
||||||
|
image_grid_thw[0][1],
|
||||||
|
image_grid_thw[0][2],
|
||||||
|
)
|
||||||
|
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||||
|
t.item(),
|
||||||
|
h.item() // spatial_merge_size,
|
||||||
|
w.item() // spatial_merge_size,
|
||||||
|
)
|
||||||
|
scale_h = self.base_grid[0][1].item() / h.item()
|
||||||
|
scale_w = self.base_grid[0][2].item() / w.item()
|
||||||
|
|
||||||
|
range_tensor = torch.arange(llm_grid_t).view(-1, 1)
|
||||||
|
expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
|
||||||
|
time_tensor = expanded_range * self.config.vision_config.tokens_per_second
|
||||||
|
t_index = time_tensor.long().flatten().to(image_grid_thw.device)
|
||||||
|
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten().to(image_grid_thw.device) * scale_h
|
||||||
|
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten().to(image_grid_thw.device) * scale_w
|
||||||
|
# 3, B, L
|
||||||
|
position_ids = torch.stack([t_index, h_index, w_index]).unsqueeze(0).repeat(batch_size, 1, 1).permute(1, 0, 2)
|
||||||
|
return position_ids
|
||||||
|
|
||||||
|
def forward(self, embeds, embeds_grid, ref_embeds=None, ref_embeds_grid=None):
|
||||||
|
position_ids = self.get_position_ids(embeds_grid)
|
||||||
|
hidden_states = embeds
|
||||||
|
if ref_embeds is not None:
|
||||||
|
position_ids_ref_embeds = self.get_position_ids(ref_embeds_grid)
|
||||||
|
position_ids = torch.cat((position_ids, position_ids_ref_embeds), dim=-1)
|
||||||
|
hidden_states = torch.cat((embeds, ref_embeds), dim=1)
|
||||||
|
|
||||||
|
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||||
|
for layer in self.layers:
|
||||||
|
hidden_states = layer(hidden_states, position_embeddings)
|
||||||
|
|
||||||
|
hidden_states = self.projector(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return NexusGenMergerStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
class NexusGenMergerStateDictConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_diffusers(self, state_dict):
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
merger_state_dict = {key.replace("embedding_merger.", ""): value for key, value in state_dict.items() if key.startswith('embedding_merger.')}
|
||||||
|
return merger_state_dict
|
||||||
|
|
||||||
|
|
||||||
|
class NexusGenAdapter(nn.Module):
|
||||||
|
"""
|
||||||
|
Adapter for Nexus-Gen generation decoder.
|
||||||
|
"""
|
||||||
|
def __init__(self, input_dim=3584, output_dim=4096):
|
||||||
|
super(NexusGenAdapter, self).__init__()
|
||||||
|
self.adapter = nn.Sequential(nn.Linear(input_dim, output_dim),
|
||||||
|
nn.LayerNorm(output_dim), nn.ReLU(),
|
||||||
|
nn.Linear(output_dim, output_dim),
|
||||||
|
nn.LayerNorm(output_dim))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.adapter(x)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return NexusGenAdapterStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
class NexusGenAdapterStateDictConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_diffusers(self, state_dict):
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
adapter_state_dict = {key: value for key, value in state_dict.items() if key.startswith('adapter.')}
|
||||||
|
return adapter_state_dict
|
||||||
@@ -1,803 +0,0 @@
|
|||||||
# The code is revised from DiT
|
|
||||||
import os
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import numpy as np
|
|
||||||
import math
|
|
||||||
from safetensors.torch import load_file
|
|
||||||
from typing import List, Optional, Tuple, Union
|
|
||||||
import torch.utils.checkpoint
|
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
|
||||||
from transformers import Phi3Config, Phi3Model
|
|
||||||
from transformers.cache_utils import Cache, DynamicCache
|
|
||||||
from transformers.utils import logging
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class Phi3Transformer(Phi3Model):
|
|
||||||
"""
|
|
||||||
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
|
|
||||||
We only modified the attention mask
|
|
||||||
Args:
|
|
||||||
config: Phi3Config
|
|
||||||
"""
|
|
||||||
def prefetch_layer(self, layer_idx: int, device: torch.device):
|
|
||||||
"Starts prefetching the next layer cache"
|
|
||||||
with torch.cuda.stream(self.prefetch_stream):
|
|
||||||
# Prefetch next layer tensors to GPU
|
|
||||||
for name, param in self.layers[layer_idx].named_parameters():
|
|
||||||
param.data = param.data.to(device, non_blocking=True)
|
|
||||||
|
|
||||||
def evict_previous_layer(self, layer_idx: int):
|
|
||||||
"Moves the previous layer cache to the CPU"
|
|
||||||
prev_layer_idx = layer_idx - 1
|
|
||||||
for name, param in self.layers[prev_layer_idx].named_parameters():
|
|
||||||
param.data = param.data.to("cpu", non_blocking=True)
|
|
||||||
|
|
||||||
def get_offlaod_layer(self, layer_idx: int, device: torch.device):
|
|
||||||
# init stream
|
|
||||||
if not hasattr(self, "prefetch_stream"):
|
|
||||||
self.prefetch_stream = torch.cuda.Stream()
|
|
||||||
|
|
||||||
# delete previous layer
|
|
||||||
torch.cuda.current_stream().synchronize()
|
|
||||||
self.evict_previous_layer(layer_idx)
|
|
||||||
|
|
||||||
# make sure the current layer is ready
|
|
||||||
torch.cuda.synchronize(self.prefetch_stream)
|
|
||||||
|
|
||||||
# load next layer
|
|
||||||
self.prefetch_layer((layer_idx + 1) % len(self.layers), device)
|
|
||||||
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.LongTensor = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
offload_model: Optional[bool] = False,
|
|
||||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
|
||||||
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
|
|
||||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
|
||||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
|
||||||
if use_cache:
|
|
||||||
logger.warning_once(
|
|
||||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
|
||||||
)
|
|
||||||
use_cache = False
|
|
||||||
|
|
||||||
# kept for BC (non `Cache` `past_key_values` inputs)
|
|
||||||
return_legacy_cache = False
|
|
||||||
if use_cache and not isinstance(past_key_values, Cache):
|
|
||||||
return_legacy_cache = True
|
|
||||||
if past_key_values is None:
|
|
||||||
past_key_values = DynamicCache()
|
|
||||||
else:
|
|
||||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
|
||||||
logger.warning_once(
|
|
||||||
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
|
|
||||||
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
|
|
||||||
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# if inputs_embeds is None:
|
|
||||||
# inputs_embeds = self.embed_tokens(input_ids)
|
|
||||||
|
|
||||||
# if cache_position is None:
|
|
||||||
# past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
|
||||||
# cache_position = torch.arange(
|
|
||||||
# past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
|
||||||
# )
|
|
||||||
# if position_ids is None:
|
|
||||||
# position_ids = cache_position.unsqueeze(0)
|
|
||||||
|
|
||||||
if attention_mask is not None and attention_mask.dim() == 3:
|
|
||||||
dtype = inputs_embeds.dtype
|
|
||||||
min_dtype = torch.finfo(dtype).min
|
|
||||||
attention_mask = (1 - attention_mask) * min_dtype
|
|
||||||
attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype)
|
|
||||||
else:
|
|
||||||
raise Exception("attention_mask parameter was unavailable or invalid")
|
|
||||||
# causal_mask = self._update_causal_mask(
|
|
||||||
# attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
|
||||||
# )
|
|
||||||
|
|
||||||
hidden_states = inputs_embeds
|
|
||||||
|
|
||||||
# decoder layers
|
|
||||||
all_hidden_states = () if output_hidden_states else None
|
|
||||||
all_self_attns = () if output_attentions else None
|
|
||||||
next_decoder_cache = None
|
|
||||||
|
|
||||||
layer_idx = -1
|
|
||||||
for decoder_layer in self.layers:
|
|
||||||
layer_idx += 1
|
|
||||||
|
|
||||||
if output_hidden_states:
|
|
||||||
all_hidden_states += (hidden_states,)
|
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
|
||||||
decoder_layer.__call__,
|
|
||||||
hidden_states,
|
|
||||||
attention_mask,
|
|
||||||
position_ids,
|
|
||||||
past_key_values,
|
|
||||||
output_attentions,
|
|
||||||
use_cache,
|
|
||||||
cache_position,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if offload_model and not self.training:
|
|
||||||
self.get_offlaod_layer(layer_idx, device=inputs_embeds.device)
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cache_position=cache_position,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
|
||||||
|
|
||||||
if use_cache:
|
|
||||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
|
||||||
|
|
||||||
if output_attentions:
|
|
||||||
all_self_attns += (layer_outputs[1],)
|
|
||||||
|
|
||||||
hidden_states = self.norm(hidden_states)
|
|
||||||
|
|
||||||
# add hidden states from the last decoder layer
|
|
||||||
if output_hidden_states:
|
|
||||||
print('************')
|
|
||||||
all_hidden_states += (hidden_states,)
|
|
||||||
|
|
||||||
next_cache = next_decoder_cache if use_cache else None
|
|
||||||
if return_legacy_cache:
|
|
||||||
next_cache = next_cache.to_legacy_cache()
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
|
||||||
return BaseModelOutputWithPast(
|
|
||||||
last_hidden_state=hidden_states,
|
|
||||||
past_key_values=next_cache,
|
|
||||||
hidden_states=all_hidden_states,
|
|
||||||
attentions=all_self_attns,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def modulate(x, shift, scale):
|
|
||||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
|
||||||
|
|
||||||
|
|
||||||
class TimestepEmbedder(nn.Module):
|
|
||||||
"""
|
|
||||||
Embeds scalar timesteps into vector representations.
|
|
||||||
"""
|
|
||||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
|
||||||
super().__init__()
|
|
||||||
self.mlp = nn.Sequential(
|
|
||||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
|
||||||
nn.SiLU(),
|
|
||||||
nn.Linear(hidden_size, hidden_size, bias=True),
|
|
||||||
)
|
|
||||||
self.frequency_embedding_size = frequency_embedding_size
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def timestep_embedding(t, dim, max_period=10000):
|
|
||||||
"""
|
|
||||||
Create sinusoidal timestep embeddings.
|
|
||||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
|
||||||
These may be fractional.
|
|
||||||
:param dim: the dimension of the output.
|
|
||||||
:param max_period: controls the minimum frequency of the embeddings.
|
|
||||||
:return: an (N, D) Tensor of positional embeddings.
|
|
||||||
"""
|
|
||||||
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
|
||||||
half = dim // 2
|
|
||||||
freqs = torch.exp(
|
|
||||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
|
||||||
).to(device=t.device)
|
|
||||||
args = t[:, None].float() * freqs[None]
|
|
||||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
|
||||||
if dim % 2:
|
|
||||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
|
||||||
return embedding
|
|
||||||
|
|
||||||
def forward(self, t, dtype=torch.float32):
|
|
||||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
|
||||||
t_emb = self.mlp(t_freq)
|
|
||||||
return t_emb
|
|
||||||
|
|
||||||
|
|
||||||
class FinalLayer(nn.Module):
|
|
||||||
"""
|
|
||||||
The final layer of DiT.
|
|
||||||
"""
|
|
||||||
def __init__(self, hidden_size, patch_size, out_channels):
|
|
||||||
super().__init__()
|
|
||||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
|
||||||
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
|
||||||
self.adaLN_modulation = nn.Sequential(
|
|
||||||
nn.SiLU(),
|
|
||||||
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x, c):
|
|
||||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
|
||||||
x = modulate(self.norm_final(x), shift, scale)
|
|
||||||
x = self.linear(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=1):
|
|
||||||
"""
|
|
||||||
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
|
|
||||||
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
|
||||||
"""
|
|
||||||
if isinstance(grid_size, int):
|
|
||||||
grid_size = (grid_size, grid_size)
|
|
||||||
|
|
||||||
grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
|
|
||||||
grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
|
|
||||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
|
||||||
grid = np.stack(grid, axis=0)
|
|
||||||
|
|
||||||
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
|
|
||||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
|
||||||
if cls_token and extra_tokens > 0:
|
|
||||||
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
|
||||||
return pos_embed
|
|
||||||
|
|
||||||
|
|
||||||
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
|
||||||
assert embed_dim % 2 == 0
|
|
||||||
|
|
||||||
# use half of dimensions to encode grid_h
|
|
||||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
|
||||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
|
||||||
|
|
||||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
|
||||||
return emb
|
|
||||||
|
|
||||||
|
|
||||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
|
||||||
"""
|
|
||||||
embed_dim: output dimension for each position
|
|
||||||
pos: a list of positions to be encoded: size (M,)
|
|
||||||
out: (M, D)
|
|
||||||
"""
|
|
||||||
assert embed_dim % 2 == 0
|
|
||||||
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
|
||||||
omega /= embed_dim / 2.
|
|
||||||
omega = 1. / 10000**omega # (D/2,)
|
|
||||||
|
|
||||||
pos = pos.reshape(-1) # (M,)
|
|
||||||
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
|
||||||
|
|
||||||
emb_sin = np.sin(out) # (M, D/2)
|
|
||||||
emb_cos = np.cos(out) # (M, D/2)
|
|
||||||
|
|
||||||
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
|
||||||
return emb
|
|
||||||
|
|
||||||
|
|
||||||
class PatchEmbedMR(nn.Module):
|
|
||||||
""" 2D Image to Patch Embedding
|
|
||||||
"""
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
patch_size: int = 2,
|
|
||||||
in_chans: int = 4,
|
|
||||||
embed_dim: int = 768,
|
|
||||||
bias: bool = True,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.proj(x)
|
|
||||||
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class OmniGenOriginalModel(nn.Module):
|
|
||||||
"""
|
|
||||||
Diffusion model with a Transformer backbone.
|
|
||||||
"""
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
transformer_config: Phi3Config,
|
|
||||||
patch_size=2,
|
|
||||||
in_channels=4,
|
|
||||||
pe_interpolation: float = 1.0,
|
|
||||||
pos_embed_max_size: int = 192,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.out_channels = in_channels
|
|
||||||
self.patch_size = patch_size
|
|
||||||
self.pos_embed_max_size = pos_embed_max_size
|
|
||||||
|
|
||||||
hidden_size = transformer_config.hidden_size
|
|
||||||
|
|
||||||
self.x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
|
|
||||||
self.input_x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
|
|
||||||
|
|
||||||
self.time_token = TimestepEmbedder(hidden_size)
|
|
||||||
self.t_embedder = TimestepEmbedder(hidden_size)
|
|
||||||
|
|
||||||
self.pe_interpolation = pe_interpolation
|
|
||||||
pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, interpolation_scale=self.pe_interpolation, base_size=64)
|
|
||||||
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)
|
|
||||||
|
|
||||||
self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
|
|
||||||
|
|
||||||
self.initialize_weights()
|
|
||||||
|
|
||||||
self.llm = Phi3Transformer(config=transformer_config)
|
|
||||||
self.llm.config.use_cache = False
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, model_name):
|
|
||||||
if not os.path.exists(model_name):
|
|
||||||
cache_folder = os.getenv('HF_HUB_CACHE')
|
|
||||||
model_name = snapshot_download(repo_id=model_name,
|
|
||||||
cache_dir=cache_folder,
|
|
||||||
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
|
|
||||||
config = Phi3Config.from_pretrained(model_name)
|
|
||||||
model = cls(config)
|
|
||||||
if os.path.exists(os.path.join(model_name, 'model.safetensors')):
|
|
||||||
print("Loading safetensors")
|
|
||||||
ckpt = load_file(os.path.join(model_name, 'model.safetensors'))
|
|
||||||
else:
|
|
||||||
ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu')
|
|
||||||
model.load_state_dict(ckpt)
|
|
||||||
return model
|
|
||||||
|
|
||||||
def initialize_weights(self):
|
|
||||||
assert not hasattr(self, "llama")
|
|
||||||
|
|
||||||
# Initialize transformer layers:
|
|
||||||
def _basic_init(module):
|
|
||||||
if isinstance(module, nn.Linear):
|
|
||||||
torch.nn.init.xavier_uniform_(module.weight)
|
|
||||||
if module.bias is not None:
|
|
||||||
nn.init.constant_(module.bias, 0)
|
|
||||||
self.apply(_basic_init)
|
|
||||||
|
|
||||||
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
|
||||||
w = self.x_embedder.proj.weight.data
|
|
||||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
|
||||||
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
|
||||||
|
|
||||||
w = self.input_x_embedder.proj.weight.data
|
|
||||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
|
||||||
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
|
||||||
|
|
||||||
|
|
||||||
# Initialize timestep embedding MLP:
|
|
||||||
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
|
||||||
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
|
||||||
nn.init.normal_(self.time_token.mlp[0].weight, std=0.02)
|
|
||||||
nn.init.normal_(self.time_token.mlp[2].weight, std=0.02)
|
|
||||||
|
|
||||||
# Zero-out output layers:
|
|
||||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
|
||||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
|
||||||
nn.init.constant_(self.final_layer.linear.weight, 0)
|
|
||||||
nn.init.constant_(self.final_layer.linear.bias, 0)
|
|
||||||
|
|
||||||
def unpatchify(self, x, h, w):
|
|
||||||
"""
|
|
||||||
x: (N, T, patch_size**2 * C)
|
|
||||||
imgs: (N, H, W, C)
|
|
||||||
"""
|
|
||||||
c = self.out_channels
|
|
||||||
|
|
||||||
x = x.reshape(shape=(x.shape[0], h//self.patch_size, w//self.patch_size, self.patch_size, self.patch_size, c))
|
|
||||||
x = torch.einsum('nhwpqc->nchpwq', x)
|
|
||||||
imgs = x.reshape(shape=(x.shape[0], c, h, w))
|
|
||||||
return imgs
|
|
||||||
|
|
||||||
|
|
||||||
def cropped_pos_embed(self, height, width):
|
|
||||||
"""Crops positional embeddings for SD3 compatibility."""
|
|
||||||
if self.pos_embed_max_size is None:
|
|
||||||
raise ValueError("`pos_embed_max_size` must be set for cropping.")
|
|
||||||
|
|
||||||
height = height // self.patch_size
|
|
||||||
width = width // self.patch_size
|
|
||||||
if height > self.pos_embed_max_size:
|
|
||||||
raise ValueError(
|
|
||||||
f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
|
|
||||||
)
|
|
||||||
if width > self.pos_embed_max_size:
|
|
||||||
raise ValueError(
|
|
||||||
f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
|
|
||||||
)
|
|
||||||
|
|
||||||
top = (self.pos_embed_max_size - height) // 2
|
|
||||||
left = (self.pos_embed_max_size - width) // 2
|
|
||||||
spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
|
|
||||||
spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
|
|
||||||
# print(top, top + height, left, left + width, spatial_pos_embed.size())
|
|
||||||
spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
|
|
||||||
return spatial_pos_embed
|
|
||||||
|
|
||||||
|
|
||||||
def patch_multiple_resolutions(self, latents, padding_latent=None, is_input_images:bool=False):
|
|
||||||
if isinstance(latents, list):
|
|
||||||
return_list = False
|
|
||||||
if padding_latent is None:
|
|
||||||
padding_latent = [None] * len(latents)
|
|
||||||
return_list = True
|
|
||||||
patched_latents, num_tokens, shapes = [], [], []
|
|
||||||
for latent, padding in zip(latents, padding_latent):
|
|
||||||
height, width = latent.shape[-2:]
|
|
||||||
if is_input_images:
|
|
||||||
latent = self.input_x_embedder(latent)
|
|
||||||
else:
|
|
||||||
latent = self.x_embedder(latent)
|
|
||||||
pos_embed = self.cropped_pos_embed(height, width)
|
|
||||||
latent = latent + pos_embed
|
|
||||||
if padding is not None:
|
|
||||||
latent = torch.cat([latent, padding], dim=-2)
|
|
||||||
patched_latents.append(latent)
|
|
||||||
|
|
||||||
num_tokens.append(pos_embed.size(1))
|
|
||||||
shapes.append([height, width])
|
|
||||||
if not return_list:
|
|
||||||
latents = torch.cat(patched_latents, dim=0)
|
|
||||||
else:
|
|
||||||
latents = patched_latents
|
|
||||||
else:
|
|
||||||
height, width = latents.shape[-2:]
|
|
||||||
if is_input_images:
|
|
||||||
latents = self.input_x_embedder(latents)
|
|
||||||
else:
|
|
||||||
latents = self.x_embedder(latents)
|
|
||||||
pos_embed = self.cropped_pos_embed(height, width)
|
|
||||||
latents = latents + pos_embed
|
|
||||||
num_tokens = latents.size(1)
|
|
||||||
shapes = [height, width]
|
|
||||||
return latents, num_tokens, shapes
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True, offload_model:bool=False):
|
|
||||||
"""
|
|
||||||
|
|
||||||
"""
|
|
||||||
input_is_list = isinstance(x, list)
|
|
||||||
x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
|
|
||||||
time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)
|
|
||||||
|
|
||||||
if input_img_latents is not None:
|
|
||||||
input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True)
|
|
||||||
if input_ids is not None:
|
|
||||||
condition_embeds = self.llm.embed_tokens(input_ids).clone()
|
|
||||||
input_img_inx = 0
|
|
||||||
for b_inx in input_image_sizes.keys():
|
|
||||||
for start_inx, end_inx in input_image_sizes[b_inx]:
|
|
||||||
condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
|
|
||||||
input_img_inx += 1
|
|
||||||
if input_img_latents is not None:
|
|
||||||
assert input_img_inx == len(input_latents)
|
|
||||||
|
|
||||||
input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
|
|
||||||
else:
|
|
||||||
input_emb = torch.cat([time_token, x], dim=1)
|
|
||||||
output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, offload_model=offload_model)
|
|
||||||
output, past_key_values = output.last_hidden_state, output.past_key_values
|
|
||||||
if input_is_list:
|
|
||||||
image_embedding = output[:, -max(num_tokens):]
|
|
||||||
time_emb = self.t_embedder(timestep, dtype=x.dtype)
|
|
||||||
x = self.final_layer(image_embedding, time_emb)
|
|
||||||
latents = []
|
|
||||||
for i in range(x.size(0)):
|
|
||||||
latent = x[i:i+1, :num_tokens[i]]
|
|
||||||
latent = self.unpatchify(latent, shapes[i][0], shapes[i][1])
|
|
||||||
latents.append(latent)
|
|
||||||
else:
|
|
||||||
image_embedding = output[:, -num_tokens:]
|
|
||||||
time_emb = self.t_embedder(timestep, dtype=x.dtype)
|
|
||||||
x = self.final_layer(image_embedding, time_emb)
|
|
||||||
latents = self.unpatchify(x, shapes[0], shapes[1])
|
|
||||||
|
|
||||||
if return_past_key_values:
|
|
||||||
return latents, past_key_values
|
|
||||||
return latents
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def forward_with_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
|
|
||||||
self.llm.config.use_cache = use_kv_cache
|
|
||||||
model_out, past_key_values = self.forward(x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, past_key_values=past_key_values, return_past_key_values=True, offload_model=offload_model)
|
|
||||||
if use_img_cfg:
|
|
||||||
cond, uncond, img_cond = torch.split(model_out, len(model_out) // 3, dim=0)
|
|
||||||
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
|
|
||||||
model_out = [cond, cond, cond]
|
|
||||||
else:
|
|
||||||
cond, uncond = torch.split(model_out, len(model_out) // 2, dim=0)
|
|
||||||
cond = uncond + cfg_scale * (cond - uncond)
|
|
||||||
model_out = [cond, cond]
|
|
||||||
|
|
||||||
return torch.cat(model_out, dim=0), past_key_values
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
|
|
||||||
self.llm.config.use_cache = use_kv_cache
|
|
||||||
if past_key_values is None:
|
|
||||||
past_key_values = [None] * len(attention_mask)
|
|
||||||
|
|
||||||
x = torch.split(x, len(x) // len(attention_mask), dim=0)
|
|
||||||
timestep = timestep.to(x[0].dtype)
|
|
||||||
timestep = torch.split(timestep, len(timestep) // len(input_ids), dim=0)
|
|
||||||
|
|
||||||
model_out, pask_key_values = [], []
|
|
||||||
for i in range(len(input_ids)):
|
|
||||||
temp_out, temp_pask_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values=past_key_values[i], return_past_key_values=True, offload_model=offload_model)
|
|
||||||
model_out.append(temp_out)
|
|
||||||
pask_key_values.append(temp_pask_key_values)
|
|
||||||
|
|
||||||
if len(model_out) == 3:
|
|
||||||
cond, uncond, img_cond = model_out
|
|
||||||
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
|
|
||||||
model_out = [cond, cond, cond]
|
|
||||||
elif len(model_out) == 2:
|
|
||||||
cond, uncond = model_out
|
|
||||||
cond = uncond + cfg_scale * (cond - uncond)
|
|
||||||
model_out = [cond, cond]
|
|
||||||
else:
|
|
||||||
return model_out[0]
|
|
||||||
|
|
||||||
return torch.cat(model_out, dim=0), pask_key_values
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class OmniGenTransformer(OmniGenOriginalModel):
|
|
||||||
def __init__(self):
|
|
||||||
config = {
|
|
||||||
"_name_or_path": "Phi-3-vision-128k-instruct",
|
|
||||||
"architectures": [
|
|
||||||
"Phi3ForCausalLM"
|
|
||||||
],
|
|
||||||
"attention_dropout": 0.0,
|
|
||||||
"bos_token_id": 1,
|
|
||||||
"eos_token_id": 2,
|
|
||||||
"hidden_act": "silu",
|
|
||||||
"hidden_size": 3072,
|
|
||||||
"initializer_range": 0.02,
|
|
||||||
"intermediate_size": 8192,
|
|
||||||
"max_position_embeddings": 131072,
|
|
||||||
"model_type": "phi3",
|
|
||||||
"num_attention_heads": 32,
|
|
||||||
"num_hidden_layers": 32,
|
|
||||||
"num_key_value_heads": 32,
|
|
||||||
"original_max_position_embeddings": 4096,
|
|
||||||
"rms_norm_eps": 1e-05,
|
|
||||||
"rope_scaling": {
|
|
||||||
"long_factor": [
|
|
||||||
1.0299999713897705,
|
|
||||||
1.0499999523162842,
|
|
||||||
1.0499999523162842,
|
|
||||||
1.0799999237060547,
|
|
||||||
1.2299998998641968,
|
|
||||||
1.2299998998641968,
|
|
||||||
1.2999999523162842,
|
|
||||||
1.4499999284744263,
|
|
||||||
1.5999999046325684,
|
|
||||||
1.6499998569488525,
|
|
||||||
1.8999998569488525,
|
|
||||||
2.859999895095825,
|
|
||||||
3.68999981880188,
|
|
||||||
5.419999599456787,
|
|
||||||
5.489999771118164,
|
|
||||||
5.489999771118164,
|
|
||||||
9.09000015258789,
|
|
||||||
11.579999923706055,
|
|
||||||
15.65999984741211,
|
|
||||||
15.769999504089355,
|
|
||||||
15.789999961853027,
|
|
||||||
18.360000610351562,
|
|
||||||
21.989999771118164,
|
|
||||||
23.079999923706055,
|
|
||||||
30.009998321533203,
|
|
||||||
32.35000228881836,
|
|
||||||
32.590003967285156,
|
|
||||||
35.56000518798828,
|
|
||||||
39.95000457763672,
|
|
||||||
53.840003967285156,
|
|
||||||
56.20000457763672,
|
|
||||||
57.95000457763672,
|
|
||||||
59.29000473022461,
|
|
||||||
59.77000427246094,
|
|
||||||
59.920005798339844,
|
|
||||||
61.190006256103516,
|
|
||||||
61.96000671386719,
|
|
||||||
62.50000762939453,
|
|
||||||
63.3700065612793,
|
|
||||||
63.48000717163086,
|
|
||||||
63.48000717163086,
|
|
||||||
63.66000747680664,
|
|
||||||
63.850006103515625,
|
|
||||||
64.08000946044922,
|
|
||||||
64.760009765625,
|
|
||||||
64.80001068115234,
|
|
||||||
64.81001281738281,
|
|
||||||
64.81001281738281
|
|
||||||
],
|
|
||||||
"short_factor": [
|
|
||||||
1.05,
|
|
||||||
1.05,
|
|
||||||
1.05,
|
|
||||||
1.1,
|
|
||||||
1.1,
|
|
||||||
1.1,
|
|
||||||
1.2500000000000002,
|
|
||||||
1.2500000000000002,
|
|
||||||
1.4000000000000004,
|
|
||||||
1.4500000000000004,
|
|
||||||
1.5500000000000005,
|
|
||||||
1.8500000000000008,
|
|
||||||
1.9000000000000008,
|
|
||||||
2.000000000000001,
|
|
||||||
2.000000000000001,
|
|
||||||
2.000000000000001,
|
|
||||||
2.000000000000001,
|
|
||||||
2.000000000000001,
|
|
||||||
2.000000000000001,
|
|
||||||
2.000000000000001,
|
|
||||||
2.000000000000001,
|
|
||||||
2.000000000000001,
|
|
||||||
2.000000000000001,
|
|
||||||
2.000000000000001,
|
|
||||||
2.000000000000001,
|
|
||||||
2.000000000000001,
|
|
||||||
2.000000000000001,
|
|
||||||
2.000000000000001,
|
|
||||||
2.000000000000001,
|
|
||||||
2.000000000000001,
|
|
||||||
2.000000000000001,
|
|
||||||
2.000000000000001,
|
|
||||||
2.1000000000000005,
|
|
||||||
2.1000000000000005,
|
|
||||||
2.2,
|
|
||||||
2.3499999999999996,
|
|
||||||
2.3499999999999996,
|
|
||||||
2.3499999999999996,
|
|
||||||
2.3499999999999996,
|
|
||||||
2.3999999999999995,
|
|
||||||
2.3999999999999995,
|
|
||||||
2.6499999999999986,
|
|
||||||
2.6999999999999984,
|
|
||||||
2.8999999999999977,
|
|
||||||
2.9499999999999975,
|
|
||||||
3.049999999999997,
|
|
||||||
3.049999999999997,
|
|
||||||
3.049999999999997
|
|
||||||
],
|
|
||||||
"type": "su"
|
|
||||||
},
|
|
||||||
"rope_theta": 10000.0,
|
|
||||||
"sliding_window": 131072,
|
|
||||||
"tie_word_embeddings": False,
|
|
||||||
"torch_dtype": "bfloat16",
|
|
||||||
"transformers_version": "4.38.1",
|
|
||||||
"use_cache": True,
|
|
||||||
"vocab_size": 32064,
|
|
||||||
"_attn_implementation": "sdpa"
|
|
||||||
}
|
|
||||||
config = Phi3Config(**config)
|
|
||||||
super().__init__(config)
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True, offload_model:bool=False):
|
|
||||||
input_is_list = isinstance(x, list)
|
|
||||||
x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
|
|
||||||
time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)
|
|
||||||
|
|
||||||
if input_img_latents is not None:
|
|
||||||
input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True)
|
|
||||||
if input_ids is not None:
|
|
||||||
condition_embeds = self.llm.embed_tokens(input_ids).clone()
|
|
||||||
input_img_inx = 0
|
|
||||||
for b_inx in input_image_sizes.keys():
|
|
||||||
for start_inx, end_inx in input_image_sizes[b_inx]:
|
|
||||||
condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
|
|
||||||
input_img_inx += 1
|
|
||||||
if input_img_latents is not None:
|
|
||||||
assert input_img_inx == len(input_latents)
|
|
||||||
|
|
||||||
input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
|
|
||||||
else:
|
|
||||||
input_emb = torch.cat([time_token, x], dim=1)
|
|
||||||
output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, offload_model=offload_model)
|
|
||||||
output, past_key_values = output.last_hidden_state, output.past_key_values
|
|
||||||
if input_is_list:
|
|
||||||
image_embedding = output[:, -max(num_tokens):]
|
|
||||||
time_emb = self.t_embedder(timestep, dtype=x.dtype)
|
|
||||||
x = self.final_layer(image_embedding, time_emb)
|
|
||||||
latents = []
|
|
||||||
for i in range(x.size(0)):
|
|
||||||
latent = x[i:i+1, :num_tokens[i]]
|
|
||||||
latent = self.unpatchify(latent, shapes[i][0], shapes[i][1])
|
|
||||||
latents.append(latent)
|
|
||||||
else:
|
|
||||||
image_embedding = output[:, -num_tokens:]
|
|
||||||
time_emb = self.t_embedder(timestep, dtype=x.dtype)
|
|
||||||
x = self.final_layer(image_embedding, time_emb)
|
|
||||||
latents = self.unpatchify(x, shapes[0], shapes[1])
|
|
||||||
|
|
||||||
if return_past_key_values:
|
|
||||||
return latents, past_key_values
|
|
||||||
return latents
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
|
|
||||||
self.llm.config.use_cache = use_kv_cache
|
|
||||||
if past_key_values is None:
|
|
||||||
past_key_values = [None] * len(attention_mask)
|
|
||||||
|
|
||||||
x = torch.split(x, len(x) // len(attention_mask), dim=0)
|
|
||||||
timestep = timestep.to(x[0].dtype)
|
|
||||||
timestep = torch.split(timestep, len(timestep) // len(input_ids), dim=0)
|
|
||||||
|
|
||||||
model_out, pask_key_values = [], []
|
|
||||||
for i in range(len(input_ids)):
|
|
||||||
temp_out, temp_pask_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values=past_key_values[i], return_past_key_values=True, offload_model=offload_model)
|
|
||||||
model_out.append(temp_out)
|
|
||||||
pask_key_values.append(temp_pask_key_values)
|
|
||||||
|
|
||||||
if len(model_out) == 3:
|
|
||||||
cond, uncond, img_cond = model_out
|
|
||||||
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
|
|
||||||
model_out = [cond, cond, cond]
|
|
||||||
elif len(model_out) == 2:
|
|
||||||
cond, uncond = model_out
|
|
||||||
cond = uncond + cfg_scale * (cond - uncond)
|
|
||||||
model_out = [cond, cond]
|
|
||||||
else:
|
|
||||||
return model_out[0]
|
|
||||||
|
|
||||||
return torch.cat(model_out, dim=0), pask_key_values
|
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def state_dict_converter():
|
|
||||||
return OmniGenTransformerStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class OmniGenTransformerStateDictConverter:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def from_diffusers(self, state_dict):
|
|
||||||
return state_dict
|
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
|
||||||
return state_dict
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user