Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-19 14:58:12 +00:00)

Compare commits: 589 commits, wan-lora-f ... docs
BIN .github/workflows/logo.gif (vendored, new file)
Binary file not shown. After: Size: 146 KiB
4 .github/workflows/publish.yaml (vendored)

@@ -20,9 +20,9 @@ jobs:
       with:
         python-version: '3.10'
     - name: Install wheel
-      run: pip install wheel && pip install -r requirements.txt
+      run: pip install wheel==0.44.0 && pip install -r requirements.txt
     - name: Build DiffSynth
-      run: python setup.py sdist bdist_wheel
+      run: python -m build
     - name: Publish package to PyPI
       run: |
         pip install twine
175 .gitignore (vendored, new file)

@@ -0,0 +1,175 @@
/data
/models
/scripts
/diffusers
*.pkl
*.safetensors
*.pth
*.ckpt
*.pt
*.bin
*.DS_Store
*.msc
*.mv
log*.txt

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
922 README_zh.md (new file)

@@ -0,0 +1,922 @@
# DiffSynth-Studio

<a href="https://github.com/modelscope/DiffSynth-Studio"><img src=".github/workflows/logo.gif" title="Logo" style="max-width:100%;" width="55" /></a> <a href="https://trendshift.io/repositories/10946" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10946" alt="modelscope%2FDiffSynth-Studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a></p>

[](https://pypi.org/project/DiffSynth/)
[](https://github.com/modelscope/DiffSynth-Studio/blob/master/LICENSE)
[](https://github.com/modelscope/DiffSynth-Studio/issues)
[](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
[](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)

[Switch to English](./README.md)

## Introduction

> DiffSynth-Studio documentation: [Chinese version](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/), [English version](https://diffsynth-studio-doc.readthedocs.io/en/latest/)

Welcome to the magical world of Diffusion models! DiffSynth-Studio is an open-source Diffusion model engine developed and maintained by the [ModelScope](https://www.modelscope.cn/) team. We hope to incubate technical innovation through framework building, gather the strength of the open-source community, and explore the boundaries of generative model technology!

DiffSynth currently consists of two open-source projects:

* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): focused on aggressive technical exploration, aimed at academia, offering support for more cutting-edge model capabilities.
* [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): focused on stable model deployment, aimed at industry, offering higher computational performance and more stable functionality.

[DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) and [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) are the core engines of the ModelScope AIGC zone. You are welcome to try the product features we have carefully built:

* ModelScope AIGC zone (for users in China): https://modelscope.cn/aigc/home
* ModelScope Civision (for global users): https://modelscope.ai/civision/home

We believe that a well-built open-source code framework lowers the barrier to technical exploration. On top of this codebase we have produced quite a few [interesting techniques](#创新成果). Perhaps you also have many bold ideas of your own; with DiffSynth-Studio you can realize them quickly. To that end, we have prepared detailed documentation for developers. We hope these documents help developers understand how Diffusion models work, and we look forward to pushing the boundaries of the technology together with you.

## Update History

> DiffSynth-Studio has gone through a major version update, and some legacy features are no longer maintained. If you need the old functionality, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) before the major update.

> The project currently has a limited number of developers, and most of the work is handled by [Artiprocher](https://github.com/Artiprocher), so new features progress slowly and issues are answered and resolved at a limited pace. We are very sorry about this and ask for developers' understanding.

- **February 10, 2026** Added inference support for the [LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) audio-video generation model; see the [documentation](docs/zh/Model_Details/LTX-2.md). Training support will follow.

- **February 2, 2026** The first Research Tutorial document is online, walking you through training a small 0.1B text-to-image model from scratch; see the [documentation](/docs/zh/Research_Tutorial/train_from_scratch.md) and the [model](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel). We hope DiffSynth-Studio can become a more powerful Diffusion model training framework.

- **January 27, 2026** [Z-Image](https://modelscope.cn/models/Tongyi-MAI/Z-Image) is released, together with our [Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L) model, which you can try directly in the [ModelScope studio](https://modelscope.cn/studios/DiffSynth-Studio/Z-Image-i2L); see the [documentation](/docs/zh/Model_Details/Z-Image.md).

- **January 19, 2026** Added support for the [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) and [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) models, including full training and inference. The [documentation](/docs/zh/Model_Details/FLUX2.md) and [example code](/examples/flux2/) are now available.

- **January 12, 2026** We trained and open-sourced a text-guided layer-splitting model ([model link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)). Given an image and a text description, the model splits out the image layers related to the description. For more details, see our blog ([Chinese](https://modelscope.cn/learn/4938), [English](https://huggingface.co/blog/kelseye/qwen-image-layered-control)).

- **December 24, 2025** We trained an In-Context Editing LoRA model on top of Qwen-Image-Edit-2511 ([model link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)). The model takes three images A, B, and C as input, infers the change from A to B on its own, and applies that change to C to produce image D. For more details, see our blog ([Chinese](https://mp.weixin.qq.com/s/41aEiN3lXKGCJs1-we4Q2g), [English](https://huggingface.co/blog/kelseye/qwen-image-edit-2511-icedit-lora)).

- **December 9, 2025** Using DiffSynth-Studio 2.0 we trained a crazy model: [Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L) (Image to LoRA). It takes images as input and outputs a LoRA. Although this version still has plenty of room for improvement in generalization and detail preservation, we open-source it to inspire more innovative research. For details, see our [blog](https://huggingface.co/blog/kelseye/qwen-image-i2l).

- **December 4, 2025** DiffSynth-Studio 2.0 is released, with many new features:
    - The [documentation](/docs/zh/README.md) is online and is still being improved and updated.
    - The [VRAM management](/docs/zh/Pipeline_Usage/VRAM_management.md) module is upgraded, supporting layer-level disk offload that frees both RAM and VRAM.
    - New model support:
        - Z-Image Turbo: [model](https://www.modelscope.ai/models/Tongyi-MAI/Z-Image-Turbo), [docs](/docs/zh/Model_Details/Z-Image.md), [code](/examples/z_image/)
        - FLUX.2-dev: [model](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev), [docs](/docs/zh/Model_Details/FLUX2.md), [code](/examples/flux2/)
    - Training framework upgrades:
        - [Split training](/docs/zh/Training/Split_Training.md): the training process can automatically be split into a data-processing stage and a training stage (even when training a ControlNet or any other model). Text encoding, VAE encoding, and other computations that need no gradients run in the data-processing stage, and everything else runs in the training stage. This is faster and needs less VRAM.
        - [Differential LoRA training](/docs/zh/Training/Differential_LoRA.md): the training technique we used in [ArtAug](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), now available for LoRA training of any model.
        - [FP8 training](/docs/zh/Training/FP8_Precision.md): during training, FP8 can be applied to any non-trained model, i.e. a model whose gradients are disabled or only affect LoRA weights.

<details>
<summary>More</summary>

- **November 4, 2025** Added support for the [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) model, trained on Wan 2.1, which generates motion following a reference video.

- **October 30, 2025** Added support for the [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) model, which supports text-to-video, image-to-video, and video continuation. In this project it reuses the Wan framework for inference and training.

- **October 27, 2025** Added support for the [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) model, another member of the Wan model ecosystem.

- **September 23, 2025** [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) is released! This model was jointly developed and open-sourced by us and the Taotian experience design team. Built on Qwen-Image and designed for e-commerce poster scenarios, it supports precise regional layout control. See [our example code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py).

- **September 9, 2025** Our training framework now supports multiple training modes, currently adapted to Qwen-Image. Besides the standard SFT training mode, Direct Distill is supported; see [our example code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh). This feature is experimental, and we will keep extending it toward more complete model training functionality.

- **August 28, 2025** We added support for Wan2.2-S2V, an audio-driven cinematic video generation model. See [./examples/wanvideo/](./examples/wanvideo/).

- **August 21, 2025** [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) is released! Compared with V1, the training dataset was changed to [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset), so the generated images better match Qwen-Image's own distribution and style. See [our example code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py).

- **August 21, 2025** We open-sourced the [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) structure-control LoRA model. It follows the In-Context technical route and supports multiple kinds of structural control conditions, including canny, depth, lineart, softedge, normal, and openpose. See [our example code](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py).

- **August 20, 2025** We open-sourced the [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) model, which improves Qwen-Image-Edit's editing quality on low-resolution inputs. See [our example code](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py).

- **August 19, 2025** 🔥 Qwen-Image-Edit is open-sourced. Welcome, new member of the image editing model family!

- **August 18, 2025** We trained and open-sourced an inpainting ControlNet for Qwen-Image, [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint), with a lightweight architecture. See [our example code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py).

- **August 15, 2025** We open-sourced the [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset), an image dataset generated with the Qwen-Image model containing 160,000 `1024 x 1024` images. It includes general, English text rendering, and Chinese text rendering subsets, and every image is annotated with a caption, entities, and structural control images. Developers can use it to train models such as ControlNet and EliGen for Qwen-Image; we aim to advance the technology through open source!

- **August 13, 2025** We trained and open-sourced a depth ControlNet for Qwen-Image, [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth), with a lightweight architecture. See [our example code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py).

- **August 12, 2025** We trained and open-sourced a canny ControlNet for Qwen-Image, [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny), with a lightweight architecture. See [our example code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py).

- **August 11, 2025** We open-sourced the distilled acceleration model [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) for Qwen-Image. It follows the same training procedure as [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) but is structured as a LoRA, so it is more compatible with other open-source ecosystem models.

- **August 7, 2025** We open-sourced the entity-control LoRA model [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) for Qwen-Image. Qwen-Image-EliGen enables entity-level controllable text-to-image generation. Technical details are in the [paper](https://arxiv.org/abs/2501.01097). Training dataset: [EliGenTrainSet](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet).

- **August 5, 2025** We open-sourced the distilled acceleration model [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) for Qwen-Image, achieving roughly a 5x speedup.

- **August 4, 2025** 🔥 Qwen-Image is open-sourced. Welcome, new member of the image generation model family!

- **August 1, 2025** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) is open-sourced, a text-to-image model focused on aesthetic photography. We provided full support from day one, including low-VRAM layer-by-layer offload, LoRA training, and full training. See [./examples/flux/](./examples/flux/) for details.

- **July 28, 2025** Wan 2.2 is open-sourced. We provided full support from day one, including low-VRAM layer-by-layer offload, FP8 quantization, sequence parallelism, LoRA training, and full training. See [./examples/wanvideo/](./examples/wanvideo/) for details.

- **July 11, 2025** We propose Nexus-Gen, a unified framework that combines the language reasoning abilities of large language models (LLMs) with the image generation abilities of diffusion models. It supports seamless image understanding, generation, and editing.
    - Paper: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
    - GitHub repository: https://github.com/modelscope/Nexus-Gen
    - Models: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
    - Training dataset: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
    - Online demo: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)

- **June 15, 2025** The official ModelScope evaluation framework [EvalScope](https://github.com/modelscope/evalscope) now supports text-to-image generation evaluation. See the [best practice](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html) guide to try it out.

- **March 25, 2025** Our new open-source project [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) is now available! Focused on stable model deployment and aimed at industry, it offers better engineering support, higher computational performance, and more stable functionality.

- **March 31, 2025** We support InfiniteYou, a face-identity preservation method for FLUX. See [./examples/InfiniteYou/](./examples/InfiniteYou/) for details.

- **March 13, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of Tencent's open-source HunyuanVideo. See [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for details.

- **February 25, 2025** We support Wan-Video, a family of state-of-the-art video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).

- **February 17, 2025** We support [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary), an advanced video synthesis model! See [./examples/stepvideo](./examples/stepvideo/).

- **December 31, 2024** We propose EliGen, a new framework for precise entity-level controlled text-to-image generation, complemented by an inpainting fusion pipeline that extends its capabilities to image inpainting. EliGen integrates seamlessly with existing community models such as IP-Adapter and In-Context LoRA, improving its versatility. See [./examples/EntityControl](./examples/EntityControl/) for details.
    - Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
    - Models: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
    - Online demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
    - Training dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)

- **December 19, 2024** We implemented advanced VRAM management for HunyuanVideo, making it possible to generate videos at 129x720x1280 resolution with 24GB of VRAM, or at 129x512x384 resolution with only 6GB. See [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for details.

- **December 18, 2024** We propose ArtAug, a method for improving text-to-image models through synthesis-understanding interaction. We trained an ArtAug enhancement module for FLUX.1-dev in LoRA format. The model distills the aesthetic understanding of Qwen2-VL-72B into FLUX.1-dev, improving the quality of generated images.
    - Paper: https://arxiv.org/abs/2412.12888
    - Examples: https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/ArtAug
    - Models: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
    - Demo: [ModelScope](https://modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0), HuggingFace (coming soon)

- **October 25, 2024** We provide extensive FLUX ControlNet support. The project supports many different ControlNet models, which can be combined freely even if their structures differ. In addition, ControlNet models are compatible with high-resolution refinement and partition control techniques, enabling very powerful controllable image generation. See [`./examples/ControlNet/`](./examples/ControlNet/).

- **October 8, 2024** We release an extended LoRA based on CogVideoX-5B and ExVideo. You can download this model from [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) or [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1).

- **August 22, 2024** CogVideoX-5B is now supported in this project. See [here](/examples/video_synthesis/). We provide several interesting features for this text-to-video model, including:
    - Text to video
    - Video editing
    - Self upscaling
    - Video frame interpolation

- **August 22, 2024** We implemented a fun painter feature that supports all text-to-image models. You can now create stunning images with a brush, assisted by AI!
    - Use it in our [WebUI](#usage-in-webui).

- **August 21, 2024** DiffSynth-Studio now supports FLUX.
    - Enable CFG and the high-resolution fix for better visual quality. See [here](/examples/image_synthesis/README.md).
    - LoRA, ControlNet, and other add-on models are coming soon.

- **June 21, 2024** We propose ExVideo, a post-training fine-tuning technique aimed at enhancing the capabilities of video generation models. We extended Stable Video Diffusion to generate long videos of up to 128 frames.
    - [Project page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
    - The source code is released in this repository. See [`examples/ExVideo`](./examples/ExVideo/).
    - Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
    - The technical report is released on [arXiv](https://arxiv.org/abs/2406.14130).
    - You can try ExVideo in this [demo](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1)!

- **June 13, 2024** DiffSynth Studio has been transferred to ModelScope. The developers have also shifted from "me" to "us". Of course, I will still take part in future development and maintenance.

- **January 29, 2024** We propose Diffutoon, an excellent toon-shading solution.
    - [Project page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
    - The source code is released in this project.
    - The technical report (IJCAI 2024) is released on [arXiv](https://arxiv.org/abs/2401.16224).

- **December 8, 2023** We decided to start a new project aimed at unleashing the potential of diffusion models, especially for video synthesis. Development of this project officially began.

- **November 15, 2023** We propose FastBlend, a powerful video deflickering algorithm.
    - The sd-webui extension is released on [GitHub](https://github.com/Artiprocher/sd-webui-fastblend).
    - Demo videos are shown on Bilibili, covering three tasks:
        - [Video deflickering](https://www.bilibili.com/video/BV1d94y1W7PE)
        - [Video frame interpolation](https://www.bilibili.com/video/BV1Lw411m71p)
        - [Image-driven video rendering](https://www.bilibili.com/video/BV1RB4y1Z7LF)
    - The technical report is released on [arXiv](https://arxiv.org/abs/2311.09265).
    - An unofficial ComfyUI extension developed by other users is released on [GitHub](https://github.com/AInseven/ComfyUI-fastblend).

- **October 1, 2023** We released an early version of this project named FastSDXL, an initial attempt at building a diffusion engine.
    - The source code is released on [GitHub](https://github.com/Artiprocher/FastSDXL).
    - FastSDXL includes a trainable OLSS scheduler for improved efficiency.
    - The original repository of OLSS is [here](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler).
    - The technical report (CIKM 2023) is released on [arXiv](https://arxiv.org/abs/2305.14677).
    - A demo video is available on [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj).
    - Since OLSS requires additional training, we did not implement it in this project.

- **August 29, 2023** We propose DiffSynth, a video synthesis framework.
    - [Project page](https://ecnu-cilab.github.io/DiffSynth.github.io/)
    - The source code is released in [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth).
    - The technical report (ECML PKDD 2024) is released on [arXiv](https://arxiv.org/abs/2308.03463).

</details>

## Installation

Install from source (recommended):

```
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```

For other installation methods, and for installation on non-NVIDIA GPUs, see the [installation documentation](/docs/zh/Pipeline_Usage/Setup.md).

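As a quick sanity check after the editable install, you can confirm that the package resolves from the cloned source; this is a minimal sketch, not part of the original instructions, and it assumes only the import paths used in the quick-start examples below (no model weights are downloaded by the imports themselves):

```python
# Minimal check that the editable install of DiffSynth-Studio is importable.
import diffsynth
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig

print("DiffSynth imported from:", diffsynth.__file__)
print("Pipeline class available:", ZImagePipeline.__name__, ModelConfig.__name__)
```
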
## Core Framework

DiffSynth-Studio redesigns the inference and training pipelines for mainstream Diffusion models (including FLUX, Wan, and more), enabling efficient VRAM management and flexible model training.

<details>
<summary>Environment variable configuration</summary>

> Before running model inference or training, you can configure the model download source and other options via [environment variables](/docs/zh/Pipeline_Usage/Environment_Variables.md).
>
> By default, this project downloads models from ModelScope. Users outside China can download models from ModelScope's international site with the following configuration:
>
> ```python
> import os
> os.environ["MODELSCOPE_DOMAIN"] = "www.modelscope.ai"
> ```
>
> To download from other sites, modify the [DIFFSYNTH_DOWNLOAD_SOURCE environment variable](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_download_source).

</details>

### Image Generation Models



#### Z-Image: [/docs/zh/Model_Details/Z-Image.md](/docs/zh/Model_Details/Z-Image.md)

<details>

<summary>Quick start</summary>

Run the following code to quickly load the [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) model and run inference. FP8 quantization causes noticeable image-quality degradation, so we do not recommend enabling any quantization for the Z-Image Turbo model; only CPU offload is recommended, and it can run with as little as 8 GB of VRAM.

```python
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
import torch

vram_config = {
    "offload_dtype": torch.bfloat16,
    "offload_device": "cpu",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cpu",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = ZImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors", **vram_config),
        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
image = pipe(prompt=prompt, seed=42, rand_device="cuda")
image.save("image.jpg")
```
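
Once the weights are loaded, the pipeline object can be reused for multiple generations. A small usage sketch (not from the original README; it assumes `pipe` and `prompt` are defined as in the quick-start above and uses the same call signature) sweeps several seeds for the same prompt:

```python
# Reuse the already-loaded pipeline from the quick-start above to compare seeds.
for seed in (0, 1, 2, 3):
    image = pipe(prompt=prompt, seed=seed, rand_device="cuda")
    image.save(f"image_seed{seed}.jpg")
```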

</details>

<details>

<summary>Example code</summary>

Example code for Z-Image is located at [/examples/z_image/](/examples/z_image/).

|Model ID|Inference|Low-VRAM Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|-|-|-|-|-|-|-|
|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|

</details>

#### FLUX.2: [/docs/zh/Model_Details/FLUX2.md](/docs/zh/Model_Details/FLUX2.md)

<details>

<summary>Quick start</summary>

Run the following code to quickly load the [black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev) model and run inference. VRAM management is enabled: the framework automatically controls how model parameters are loaded according to the remaining VRAM, and it can run with as little as 10 GB of VRAM.

```python
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch

vram_config = {
    "offload_dtype": "disk",
    "offload_device": "disk",
    "onload_dtype": torch.float8_e4m3fn,
    "onload_device": "cpu",
    "preparing_dtype": torch.float8_e4m3fn,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "High resolution. A dreamy underwater portrait of a serene young woman in a flowing blue dress. Her hair floats softly around her face, strands delicately suspended in the water. Clear, shimmering light filters through, casting gentle highlights, while tiny bubbles rise around her. Her expression is calm, her features finely detailed—creating a tranquil, ethereal scene."
image = pipe(prompt, seed=42, rand_device="cuda", num_inference_steps=50)
image.save("image.jpg")
```

</details>

<details>

<summary>Example code</summary>

Example code for FLUX.2 is located at [/examples/flux2/](/examples/flux2/).

|Model ID|Inference|Low-VRAM Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|-|-|-|-|-|-|-|
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|

</details>

#### Qwen-Image: [/docs/zh/Model_Details/Qwen-Image.md](/docs/zh/Model_Details/Qwen-Image.md)

<details>

<summary>Quick start</summary>

Run the following code to quickly load the [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) model and run inference. VRAM management is enabled: the framework automatically controls how model parameters are loaded according to the remaining VRAM, and it can run with as little as 8 GB of VRAM.

```python
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch

vram_config = {
    "offload_dtype": "disk",
    "offload_device": "disk",
    "onload_dtype": torch.float8_e4m3fn,
    "onload_device": "cpu",
    "preparing_dtype": torch.float8_e4m3fn,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(prompt, seed=0, num_inference_steps=40)
image.save("image.jpg")
```

</details>

<details>

<summary>Model lineage</summary>

```mermaid
graph LR;
    Qwen/Qwen-Image-->Qwen/Qwen-Image-Edit;
    Qwen/Qwen-Image-Edit-->Qwen/Qwen-Image-Edit-2509;
    Qwen/Qwen-Image-->EliGen-Series;
    EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen;
    DiffSynth-Studio/Qwen-Image-EliGen-->DiffSynth-Studio/Qwen-Image-EliGen-V2;
    EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen-Poster;
    Qwen/Qwen-Image-->Distill-Series;
    Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-Full;
    Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-LoRA;
    Qwen/Qwen-Image-->ControlNet-Series;
    ControlNet-Series-->Blockwise-ControlNet-Series;
    Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny;
    Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth;
    Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint;
    ControlNet-Series-->DiffSynth-Studio/Qwen-Image-In-Context-Control-Union;
    Qwen/Qwen-Image-->DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix;
```

</details>

<details>

<summary>Example code</summary>

Example code for Qwen-Image is located at [/examples/qwen_image/](/examples/qwen_image/).

|Model ID|Inference|Low-VRAM Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|-|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-|
</details>

#### FLUX.1: [/docs/zh/Model_Details/FLUX.md](/docs/zh/Model_Details/FLUX.md)

<details>

<summary>Quick Start</summary>

Run the following code to quickly load the [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) model and run inference. VRAM management is enabled: the framework automatically controls how model parameters are loaded based on the remaining VRAM, so inference runs on as little as 8 GB of VRAM.

```python
import torch
from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig

# Offloaded weights are kept in FP8 (float8_e4m3fn) on the CPU; computation runs in BF16 on the GPU.
vram_config = {
    "offload_dtype": torch.float8_e4m3fn,
    "offload_device": "cpu",
    "onload_dtype": torch.float8_e4m3fn,
    "onload_device": "cpu",
    "preparing_dtype": torch.float8_e4m3fn,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = FluxImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config),
    ],
    # VRAM budget in GiB: total GPU memory minus a 1 GiB safety margin.
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 1,
)
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
image = pipe(prompt=prompt, seed=0)
image.save("image.jpg")
```

</details>

<details>

<summary>Model Lineage</summary>

```mermaid
graph LR;
FLUX.1-Series-->black-forest-labs/FLUX.1-dev;
FLUX.1-Series-->black-forest-labs/FLUX.1-Krea-dev;
FLUX.1-Series-->black-forest-labs/FLUX.1-Kontext-dev;
black-forest-labs/FLUX.1-dev-->FLUX.1-dev-ControlNet-Series;
FLUX.1-dev-ControlNet-Series-->alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta;
FLUX.1-dev-ControlNet-Series-->InstantX/FLUX.1-dev-Controlnet-Union-alpha;
FLUX.1-dev-ControlNet-Series-->jasperai/Flux.1-dev-Controlnet-Upscaler;
black-forest-labs/FLUX.1-dev-->InstantX/FLUX.1-dev-IP-Adapter;
black-forest-labs/FLUX.1-dev-->ByteDance/InfiniteYou;
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Eligen;
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev;
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev;
black-forest-labs/FLUX.1-dev-->ostris/Flex.2-preview;
black-forest-labs/FLUX.1-dev-->stepfun-ai/Step1X-Edit;
Qwen/Qwen2.5-VL-7B-Instruct-->stepfun-ai/Step1X-Edit;
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Nexus-GenV2;
Qwen/Qwen2.5-VL-7B-Instruct-->DiffSynth-Studio/Nexus-GenV2;
```

</details>

<details>

<summary>Example Code</summary>

Example code for FLUX.1 is located at [/examples/flux/](/examples/flux/).

|Model ID|Extra Parameters|Inference|Low-VRAM Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|-|-|-|-|-|-|-|-|
|[black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py)|
|[black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](/examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)|
|[black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)|
|[alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|
|[InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|
|[jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|
|[InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|
|[ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
|[DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)|
|[DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
|[DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
|[stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](/examples/flux/model_inference/Step1X-Edit.py)|[code](/examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](/examples/flux/model_training/full/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](/examples/flux/model_training/lora/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_lora/Step1X-Edit.py)|
|[ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](/examples/flux/model_inference/FLEX.2-preview.py)|[code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](/examples/flux/model_training/full/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](/examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py)|
|[DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](/examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](/examples/flux/model_training/full/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](/examples/flux/model_training/lora/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_lora/Nexus-Gen.py)|
</details>
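
The entries in the "Extra Parameters" column are passed as keyword arguments to the pipeline call. As a minimal sketch (not taken from the repository's examples): it assumes `pipe` is a `FluxImagePipeline` loaded with the FLUX.1-Kontext-dev weights as in the linked inference script, and that `kontext_images` accepts a PIL image.

```python
from PIL import Image

# Assumption: `pipe` was built like the Quick Start above, but with the
# FLUX.1-Kontext-dev weights from the linked model_inference example.
reference = Image.open("input.jpg").convert("RGB")  # hypothetical input file
image = pipe(
    prompt="Make the dress blue while keeping everything else unchanged",
    kontext_images=reference,  # extra parameter listed in the table above (assumed PIL image or list of images)
    seed=0,
)
image.save("edited.jpg")
```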

### Video Generation Models

https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314

#### LTX-2: [/docs/zh/Model_Details/LTX-2.md](/docs/zh/Model_Details/LTX-2.md)

<details>

<summary>Quick Start</summary>

Run the following code to quickly load the [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) model and run inference. VRAM management is enabled: the framework automatically controls how model parameters are loaded based on the remaining VRAM, so inference runs on as little as 8 GB of VRAM.

```python
import torch
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2

# Offloaded weights are kept in FP8 (float8_e5m2) on the CPU; computation runs in BF16 on the GPU.
vram_config = {
    "offload_dtype": torch.float8_e5m2,
    "offload_device": "cpu",
    "onload_dtype": torch.float8_e5m2,
    "onload_device": "cpu",
    "preparing_dtype": torch.float8_e5m2,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = LTX2AudioVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    # Gemma-3 text encoder, the LTX-2 DiT, and the 2x spatial upscaler.
    model_configs=[
        ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
        ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
        ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
    # Distilled LoRA applied in the second (upscaling) stage.
    stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
    # VRAM budget in GiB: total GPU memory minus a 0.5 GiB safety margin.
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)

prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\""
negative_prompt = (
    "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
    "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
    "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
    "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
    "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
    "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
    "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
    "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
    "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
    "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
    "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)
# Two-stage output resolution: 1024x1536, 121 frames.
height, width, num_frames = 512 * 2, 768 * 2, 121
video, audio = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    seed=43,
    height=height,
    width=width,
    num_frames=num_frames,
    tiled=True,
    use_two_stage_pipeline=True,
)
write_video_audio_ltx2(
    video=video,
    audio=audio,
    output_path='ltx2_twostage.mp4',
    fps=24,
    audio_sample_rate=24000,
)
```

</details>

<details>

<summary>Example Code</summary>

Example code for LTX-2 is located at [/examples/ltx2/](/examples/ltx2/).

|Model ID|Extra Parameters|Inference|Low-VRAM Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|-|-|-|-|-|-|-|-|
|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|-|-|-|-|
|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|
|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-In.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-In.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Up.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Up.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Down.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Down.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Static](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Static)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Static.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Static.py)|-|-|-|-|
</details>
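
For the image-to-audio-video (I2AV) variants, the extra parameter `input_images` from the table above supplies the conditioning frame. A minimal sketch (an assumption, not copied from the linked scripts): it reuses the `pipe` and `negative_prompt` from the Quick Start and assumes `input_images` takes a list of PIL images.

```python
from PIL import Image

# Assumption: `pipe` is the LTX2AudioVideoPipeline built in the Quick Start above.
first_frame = Image.open("first_frame.jpg").convert("RGB")  # hypothetical input file
video, audio = pipe(
    prompt="A girl smiles at the camera and says: \"Hello!\"",
    negative_prompt=negative_prompt,
    input_images=[first_frame],  # extra parameter listed in the table above (assumed list of PIL images)
    seed=0,
    height=1024, width=1536, num_frames=121,
    tiled=True,
    use_two_stage_pipeline=True,
)
write_video_audio_ltx2(video=video, audio=audio, output_path="ltx2_i2av.mp4", fps=24, audio_sample_rate=24000)
```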

#### Wan: [/docs/zh/Model_Details/Wan.md](/docs/zh/Model_Details/Wan.md)

<details>

<summary>Quick Start</summary>

Run the following code to quickly load the [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) model and run inference. VRAM management is enabled: the framework automatically controls how model parameters are loaded based on the remaining VRAM, so inference runs on as little as 8 GB of VRAM.

```python
import torch
from diffsynth.utils.data import save_video, VideoData
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig

# Offloaded weights are paged out to disk ("disk" offload); computation runs in BF16 on the GPU.
vram_config = {
    "offload_dtype": "disk",
    "offload_device": "disk",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cpu",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config),
        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config),
        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
    # VRAM budget in GiB: total GPU memory minus a 2 GiB safety margin.
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
)

# Chinese prompt (a puppy running across a sunny meadow) with a quality-related negative prompt.
video = pipe(
    prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    seed=0, tiled=True,
)
save_video(video, "video.mp4", fps=15, quality=5)
```

</details>

<details>

<summary>Model Lineage</summary>

```mermaid
graph LR;
Wan-Series-->Wan2.1-Series;
Wan-Series-->Wan2.2-Series;
Wan2.1-Series-->Wan-AI/Wan2.1-T2V-1.3B;
Wan2.1-Series-->Wan-AI/Wan2.1-T2V-14B;
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-I2V-14B-480P;
Wan-AI/Wan2.1-I2V-14B-480P-->Wan-AI/Wan2.1-I2V-14B-720P;
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-FLF2V-14B-720P;
Wan-AI/Wan2.1-T2V-1.3B-->iic/VACE-Wan2.1-1.3B-Preview;
iic/VACE-Wan2.1-1.3B-Preview-->Wan-AI/Wan2.1-VACE-1.3B;
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-VACE-14B;
Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-1.3B-Series;
Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-InP;
Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-Control;
Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-14B-Series;
Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-InP;
Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-Control;
Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-V1.1-1.3B-Series;
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control;
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-InP;
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera;
Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-V1.1-14B-Series;
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control;
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-InP;
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control-Camera;
Wan-AI/Wan2.1-T2V-1.3B-->DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1;
Wan-AI/Wan2.1-T2V-14B-->krea/krea-realtime-video;
Wan-AI/Wan2.1-T2V-14B-->meituan-longcat/LongCat-Video;
Wan-AI/Wan2.1-I2V-14B-720P-->ByteDance/Video-As-Prompt-Wan2.1-14B;
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-Animate-14B;
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-S2V-14B;
Wan2.2-Series-->Wan-AI/Wan2.2-T2V-A14B;
Wan2.2-Series-->Wan-AI/Wan2.2-I2V-A14B;
Wan2.2-Series-->Wan-AI/Wan2.2-TI2V-5B;
Wan-AI/Wan2.2-T2V-A14B-->Wan2.2-Fun-Series;
Wan2.2-Fun-Series-->PAI/Wan2.2-VACE-Fun-A14B;
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-InP;
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control;
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control-Camera;
```

</details>

<details>

<summary>Example Code</summary>

Example code for Wan is located at [/examples/wanvideo/](/examples/wanvideo/).

|Model ID|Extra Parameters|Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|
|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](/examples/wanvideo/model_inference/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|
|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](/examples/wanvideo/model_inference/LongCat-Video.py)|[code](/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|
|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|
|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
</details>
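
Extra parameters are likewise passed as keyword arguments to the Wan pipeline call. Below is a minimal image-to-video sketch (an assumption, not copied from the linked scripts): it presumes `pipe` is a `WanVideoPipeline` loaded with one of the I2V checkpoints as in the linked inference scripts, and that `input_image` accepts a PIL image.

```python
from PIL import Image

# Assumption: `pipe` is a WanVideoPipeline loaded with an I2V checkpoint
# (see the linked model_inference scripts for the exact ModelConfig list).
first_frame = Image.open("first_frame.jpg").convert("RGB")  # hypothetical input file
video = pipe(
    prompt="A puppy runs across a sunny meadow, photorealistic.",
    input_image=first_frame,  # extra parameter listed in the table above (assumed to be a PIL image)
    seed=0, tiled=True,
)
save_video(video, "i2v.mp4", fps=15, quality=5)
```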

## Research Innovations

DiffSynth-Studio is more than an engineering-oriented model framework; it also serves as an incubator for research innovations.

<details>

<summary>Spectral Evolution Search: Efficient Inference-Time Scaling for Reward-Aligned Image Generation</summary>

- Paper: [Spectral Evolution Search: Efficient Inference-Time Scaling for Reward-Aligned Image Generation](https://arxiv.org/abs/2602.03208)
- Example code: coming soon

|FLUX.1-dev|FLUX.1-dev + SES|Qwen-Image|Qwen-Image + SES|
|-|-|-|-|
|||||
</details>

<details>

<summary>VIRAL: Analogy-Based Visual In-Context Reasoning with DiT Models</summary>

- Paper: [VIRAL: Visual In-Context Reasoning via Analogy in Diffusion Transformers](https://arxiv.org/abs/2602.03210)
- Example code: [/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-ICEdit.py](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-ICEdit.py)
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)

|Example 1|Example 2|Query|Output|
|-|-|-|-|
|||||
</details>

<details>

<summary>AttriCtrl: Attribute Intensity Control for Image Generation Models</summary>

- Paper: [AttriCtrl: Fine-Grained Control of Aesthetic Attribute Intensity in Diffusion Models](https://arxiv.org/abs/2508.02151)
- Example code: [/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py](/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py)
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev)

|brightness scale = 0.1|brightness scale = 0.3|brightness scale = 0.5|brightness scale = 0.7|brightness scale = 0.9|
|-|-|-|-|-|
||||||
</details>

<details>

<summary>AutoLoRA: Automated LoRA Retrieval and Fusion</summary>

- Paper: [AutoLoRA: Automatic LoRA Retrieval and Fine-Grained Gated Fusion for Text-to-Image Generation](https://arxiv.org/abs/2508.02107)
- Example code: [/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)

||[LoRA 1](https://modelscope.cn/models/cancel13/cxsk)|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2)|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL)|
|-|-|-|-|-|
|[LoRA 1](https://modelscope.cn/models/cancel13/cxsk) |||||
|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2) |||||
|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) |||||
|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL) |||||
</details>

<details>

<summary>Nexus-Gen: Unified Image Understanding, Generation, and Editing</summary>

- Project page: https://github.com/modelscope/Nexus-Gen
- Paper: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
- Models: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
- Dataset: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
- Online demo: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
</details>

<details>

<summary>ArtAug: Aesthetic Enhancement for Image Generation Models</summary>

- Details: [./examples/ArtAug/](./examples/ArtAug/)
- Paper: [ArtAug: Enhancing Text-to-Image Generation through Synthesis-Understanding Interaction](https://arxiv.org/abs/2412.12888)
- Models: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
- Online demo: [ModelScope AIGC Tab](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0)

|FLUX.1-dev|FLUX.1-dev + ArtAug LoRA|
|-|-|
|||
</details>

<details>

<summary>EliGen: Precise Regional Control for Image Generation</summary>

- Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
- Example code: [/examples/flux/model_inference/FLUX.1-dev-EliGen.py](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)
- Models: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
- Online demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
- Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)

|Entity control regions|Generated image|
|-|-|
|||
</details>
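
As a minimal sketch of how the entity-level control described above is driven from the FLUX pipeline (an assumption based on the `eligen_entity_prompts` and `eligen_entity_masks` parameters listed in the FLUX example-code table, not copied from the linked script):

```python
from PIL import Image

# Assumption: `pipe` is a FluxImagePipeline with the EliGen weights loaded,
# following /examples/flux/model_inference/FLUX.1-dev-EliGen.py.
entity_prompts = ["a silver-haired girl", "a red paper umbrella"]                # one prompt per entity
entity_masks = [Image.open("mask_girl.png"), Image.open("mask_umbrella.png")]    # hypothetical mask files, one region per entity
image = pipe(
    prompt="a girl holding an umbrella in the rain, cinematic lighting",
    eligen_entity_prompts=entity_prompts,  # parameters listed in the FLUX example-code table above
    eligen_entity_masks=entity_masks,
    seed=0,
)
image.save("eligen.jpg")
```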

<details>

<summary>ExVideo: Extended Training for Video Generation Models</summary>

- Project page: [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
- Paper: [ExVideo: Extending Video Diffusion Models via Parameter-Efficient Post-Tuning](https://arxiv.org/abs/2406.14130)
- Example code: see the [legacy version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/ExVideo)
- Models: [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)

https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
</details>

<details>

<summary>Diffutoon: High-Resolution Anime-Style Video Rendering</summary>

- Project page: [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
- Paper: [Diffutoon: High-Resolution Editable Toon Shading via Diffusion Models](https://arxiv.org/abs/2401.16224)
- Example code: see the [legacy version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/Diffutoon)

https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd
</details>

<details>

<summary>DiffSynth: The Original Version of This Project</summary>

- Project page: [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/)
- Paper: [DiffSynth: Latent In-Iteration Deflickering for Realistic Video Synthesis](https://arxiv.org/abs/2308.03463)
- Example code: see the [legacy version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/diffsynth)

https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
</details>
|
||||||
@@ -1,252 +0,0 @@
|
|||||||
import gradio as gr
|
|
||||||
from diffsynth import ModelManager, SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline, FluxImagePipeline
|
|
||||||
import os, torch
|
|
||||||
from PIL import Image
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
config = {
|
|
||||||
"model_config": {
|
|
||||||
"Stable Diffusion": {
|
|
||||||
"model_folder": "models/stable_diffusion",
|
|
||||||
"pipeline_class": SDImagePipeline,
|
|
||||||
"default_parameters": {
|
|
||||||
"cfg_scale": 7.0,
|
|
||||||
"height": 512,
|
|
||||||
"width": 512,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"Stable Diffusion XL": {
|
|
||||||
"model_folder": "models/stable_diffusion_xl",
|
|
||||||
"pipeline_class": SDXLImagePipeline,
|
|
||||||
"default_parameters": {
|
|
||||||
"cfg_scale": 7.0,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"Stable Diffusion 3": {
|
|
||||||
"model_folder": "models/stable_diffusion_3",
|
|
||||||
"pipeline_class": SD3ImagePipeline,
|
|
||||||
"default_parameters": {
|
|
||||||
"cfg_scale": 7.0,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"Stable Diffusion XL Turbo": {
|
|
||||||
"model_folder": "models/stable_diffusion_xl_turbo",
|
|
||||||
"pipeline_class": SDXLImagePipeline,
|
|
||||||
"default_parameters": {
|
|
||||||
"negative_prompt": "",
|
|
||||||
"cfg_scale": 1.0,
|
|
||||||
"num_inference_steps": 1,
|
|
||||||
"height": 512,
|
|
||||||
"width": 512,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"Kolors": {
|
|
||||||
"model_folder": "models/kolors",
|
|
||||||
"pipeline_class": SDXLImagePipeline,
|
|
||||||
"default_parameters": {
|
|
||||||
"cfg_scale": 7.0,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"HunyuanDiT": {
|
|
||||||
"model_folder": "models/HunyuanDiT",
|
|
||||||
"pipeline_class": HunyuanDiTImagePipeline,
|
|
||||||
"default_parameters": {
|
|
||||||
"cfg_scale": 7.0,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"FLUX": {
|
|
||||||
"model_folder": "models/FLUX",
|
|
||||||
"pipeline_class": FluxImagePipeline,
|
|
||||||
"default_parameters": {
|
|
||||||
"cfg_scale": 1.0,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"max_num_painter_layers": 8,
|
|
||||||
"max_num_model_cache": 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_list(model_type):
|
|
||||||
if model_type is None:
|
|
||||||
return []
|
|
||||||
folder = config["model_config"][model_type]["model_folder"]
|
|
||||||
file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
|
|
||||||
if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
|
|
||||||
file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
|
|
||||||
file_list = sorted(file_list)
|
|
||||||
return file_list
|
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_type, model_path):
|
|
||||||
global model_dict
|
|
||||||
model_key = f"{model_type}:{model_path}"
|
|
||||||
if model_key in model_dict:
|
|
||||||
return model_dict[model_key]
|
|
||||||
model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
|
|
||||||
model_manager = ModelManager()
|
|
||||||
if model_type == "HunyuanDiT":
|
|
||||||
model_manager.load_models([
|
|
||||||
os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
|
|
||||||
os.path.join(model_path, "mt5/pytorch_model.bin"),
|
|
||||||
os.path.join(model_path, "model/pytorch_model_ema.pt"),
|
|
||||||
os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
|
|
||||||
])
|
|
||||||
elif model_type == "Kolors":
|
|
||||||
model_manager.load_models([
|
|
||||||
os.path.join(model_path, "text_encoder"),
|
|
||||||
os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
|
|
||||||
os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
])
|
|
||||||
elif model_type == "FLUX":
|
|
||||||
model_manager.torch_dtype = torch.bfloat16
|
|
||||||
file_list = [
|
|
||||||
os.path.join(model_path, "text_encoder/model.safetensors"),
|
|
||||||
os.path.join(model_path, "text_encoder_2"),
|
|
||||||
]
|
|
||||||
for file_name in os.listdir(model_path):
|
|
||||||
if file_name.endswith(".safetensors"):
|
|
||||||
file_list.append(os.path.join(model_path, file_name))
|
|
||||||
model_manager.load_models(file_list)
|
|
||||||
else:
|
|
||||||
model_manager.load_model(model_path)
|
|
||||||
pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
    while len(model_dict) + 1 > config["max_num_model_cache"]:
        key = next(iter(model_dict.keys()))
        model_manager_to_release, _ = model_dict[key]
        model_manager_to_release.to("cpu")
        del model_dict[key]
        torch.cuda.empty_cache()
    model_dict[model_key] = model_manager, pipe
    return model_manager, pipe


model_dict = {}

with gr.Blocks() as app:
    gr.Markdown("# DiffSynth-Studio Painter")
    with gr.Row():
        with gr.Column(scale=382, min_width=100):

            with gr.Accordion(label="Model"):
                model_type = gr.Dropdown(choices=[i for i in config["model_config"]], label="Model type")
                model_path = gr.Dropdown(choices=[], interactive=True, label="Model path")

                @gr.on(inputs=model_type, outputs=model_path, triggers=model_type.change)
                def model_type_to_model_path(model_type):
                    return gr.Dropdown(choices=load_model_list(model_type))

            with gr.Accordion(label="Prompt"):
                prompt = gr.Textbox(label="Prompt", lines=3)
                negative_prompt = gr.Textbox(label="Negative prompt", lines=1)
                cfg_scale = gr.Slider(minimum=1.0, maximum=10.0, value=7.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
                embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=0.0, step=0.1, interactive=True, label="Embedded guidance scale (only for FLUX)")

            with gr.Accordion(label="Image"):
                num_inference_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, interactive=True, label="Inference steps")
                height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
                width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
                with gr.Column():
                    use_fixed_seed = gr.Checkbox(value=True, interactive=False, label="Use fixed seed")
                    seed = gr.Number(minimum=0, maximum=10**9, value=0, interactive=True, label="Random seed", show_label=False)

            @gr.on(
                inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
                outputs=[prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
                triggers=model_path.change
            )
            def model_path_to_default_params(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width):
                load_model(model_type, model_path)
                cfg_scale = config["model_config"][model_type]["default_parameters"].get("cfg_scale", cfg_scale)
                embedded_guidance = config["model_config"][model_type]["default_parameters"].get("embedded_guidance", embedded_guidance)
                num_inference_steps = config["model_config"][model_type]["default_parameters"].get("num_inference_steps", num_inference_steps)
                height = config["model_config"][model_type]["default_parameters"].get("height", height)
                width = config["model_config"][model_type]["default_parameters"].get("width", width)
                return prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width

        with gr.Column(scale=618, min_width=100):
            with gr.Accordion(label="Painter"):
                enable_local_prompt_list = []
                local_prompt_list = []
                mask_scale_list = []
                canvas_list = []
                for painter_layer_id in range(config["max_num_painter_layers"]):
                    with gr.Tab(label=f"Layer {painter_layer_id}"):
                        enable_local_prompt = gr.Checkbox(label="Enable", value=False, key=f"enable_local_prompt_{painter_layer_id}")
                        local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
                        mask_scale = gr.Slider(minimum=0.0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Mask scale", key=f"mask_scale_{painter_layer_id}")
                        canvas = gr.ImageEditor(canvas_size=(512, 1), sources=None, layers=False, interactive=True, image_mode="RGBA",
                                brush=gr.Brush(default_size=100, default_color="#000000", colors=["#000000"]),
                                label="Painter", key=f"canvas_{painter_layer_id}")

                        @gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear, enable_local_prompt.change], show_progress="hidden")
                        def resize_canvas(height, width, canvas):
                            h, w = canvas["background"].shape[:2]
                            if h != height or width != w:
                                return np.ones((height, width, 3), dtype=np.uint8) * 255
                            else:
                                return canvas

                        enable_local_prompt_list.append(enable_local_prompt)
                        local_prompt_list.append(local_prompt)
                        mask_scale_list.append(mask_scale)
                        canvas_list.append(canvas)
            with gr.Accordion(label="Results"):
                run_button = gr.Button(value="Generate", variant="primary")
                output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
                with gr.Row():
                    with gr.Column():
                        output_to_painter_button = gr.Button(value="Set as painter's background")
                    with gr.Column():
                        output_to_input_button = gr.Button(value="Set as input image")
                painter_background = gr.State(None)
                input_background = gr.State(None)
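
            # Note (added annotation): generate_image below receives every per-layer control as one flat
            # *args tuple; it is sliced back into four groups (enable flags, local prompts, mask scales,
            # canvases), each of length config["max_num_painter_layers"], in the same order they were
            # concatenated in the `inputs=` list.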
            @gr.on(
                inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed] + enable_local_prompt_list + local_prompt_list + mask_scale_list + canvas_list,
                outputs=[output_image],
                triggers=run_button.click
            )
            def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed, *args, progress=gr.Progress()):
                _, pipe = load_model(model_type, model_path)
                input_params = {
                    "prompt": prompt,
                    "negative_prompt": negative_prompt,
                    "cfg_scale": cfg_scale,
                    "num_inference_steps": num_inference_steps,
                    "height": height,
                    "width": width,
                    "progress_bar_cmd": progress.tqdm,
                }
                if isinstance(pipe, FluxImagePipeline):
                    input_params["embedded_guidance"] = embedded_guidance
                enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list = (
                    args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
                    args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
                    args[2 * config["max_num_painter_layers"]: 3 * config["max_num_painter_layers"]],
                    args[3 * config["max_num_painter_layers"]: 4 * config["max_num_painter_layers"]]
                )
                local_prompts, masks, mask_scales = [], [], []
                for enable_local_prompt, local_prompt, mask_scale, canvas in zip(
                    enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list
                ):
                    if enable_local_prompt:
                        local_prompts.append(local_prompt)
                        masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
                        mask_scales.append(mask_scale)
                input_params.update({
                    "local_prompts": local_prompts,
                    "masks": masks,
                    "mask_scales": mask_scales,
                })
                torch.manual_seed(seed)
                image = pipe(**input_params)
                return image

            @gr.on(inputs=[output_image] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
            def send_output_to_painter_background(output_image, *canvas_list):
                for canvas in canvas_list:
                    h, w = canvas["background"].shape[:2]
                    canvas["background"] = output_image.resize((w, h))
                return tuple(canvas_list)

app.launch()
@@ -1,390 +0,0 @@
import os
import torch
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import random
import json
import gradio as gr
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
from modelscope import dataset_snapshot_download


dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/entity_control/*")
example_json = 'data/examples/eligen/entity_control/ui_examples.json'
with open(example_json, 'r') as f:
    examples = json.load(f)['examples']

for idx in range(len(examples)):
    example_id = examples[idx]['example_id']
    entity_prompts = examples[idx]['local_prompt_list']
    examples[idx]['mask_lists'] = [Image.open(f"data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
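

# Note (added annotation): create_canvas_data assembles the dict layout that gr.ImageEditor expects
# ("background", "layers", "composite"); each mask becomes the alpha channel of one RGBA layer, and the
# composite is the background with non-transparent layer pixels pasted on top.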
def create_canvas_data(background, masks):
    if background.shape[-1] == 3:
        background = np.dstack([background, np.full(background.shape[:2], 255, dtype=np.uint8)])
    layers = []
    for mask in masks:
        if mask is not None:
            mask_single_channel = mask if mask.ndim == 2 else mask[..., 0]
            layer = np.zeros((mask_single_channel.shape[0], mask_single_channel.shape[1], 4), dtype=np.uint8)
            layer[..., -1] = mask_single_channel
            layers.append(layer)
        else:
            layers.append(np.zeros_like(background))

    composite = background.copy()
    for layer in layers:
        if layer.size > 0:
            composite = np.where(layer[..., -1:] > 0, layer, composite)
    return {
        "background": background,
        "layers": layers,
        "composite": composite,
    }
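

# Note (added annotation): load_example parses the example index from the clicked button's label
# ("Load Example N") and returns values in the same order as the button's `outputs` list:
# inference steps, global prompt, negative prompt, seed, then one local prompt and one canvas per layer.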
def load_example(load_example_button):
    example_idx = int(load_example_button.split()[-1]) - 1
    example = examples[example_idx]
    result = [
        50,
        example["global_prompt"],
        example["negative_prompt"],
        example["seed"],
        *example["local_prompt_list"],
    ]
    num_entities = len(example["local_prompt_list"])
    result += [""] * (config["max_num_painter_layers"] - num_entities)
    masks = []
    for mask in example["mask_lists"]:
        mask_single_channel = np.array(mask.convert("L"))
        masks.append(mask_single_channel)
    for _ in range(config["max_num_painter_layers"] - len(masks)):
        blank_mask = np.zeros_like(masks[0]) if masks else np.zeros((512, 512), dtype=np.uint8)
        masks.append(blank_mask)
    background = np.ones((masks[0].shape[0], masks[0].shape[1], 4), dtype=np.uint8) * 255
    canvas_data_list = []
    for mask in masks:
        canvas_data = create_canvas_data(background, [mask])
        canvas_data_list.append(canvas_data)
    result.extend(canvas_data_list)
    return result


def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'):
    save_dir = os.path.join('workdirs/tmp_mask', random_dir)
    print(f'save to {save_dir}')
    os.makedirs(save_dir, exist_ok=True)
    for i, mask in enumerate(masks):
        save_path = os.path.join(save_dir, f'{i}.png')
        mask.save(save_path)
    sample = {
        "global_prompt": global_prompt,
        "mask_prompts": mask_prompts,
        "seed": seed,
    }
    with open(os.path.join(save_dir, f"prompts.json"), 'w') as f:
        json.dump(sample, f, indent=4)
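

# Note (added annotation): visualize_masks overlays each entity mask as a semi-transparent colored
# region, draws its prompt text near the mask's bounding box, and composites the overlay onto the
# generated image; it backs the "show result with mask painting" option in the UI.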
def visualize_masks(image, masks, mask_prompts, font_size=35, use_random_colors=False):
    # Create a blank image for overlays
    overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
    colors = [
        (165, 238, 173, 80),
        (76, 102, 221, 80),
        (221, 160, 77, 80),
        (204, 93, 71, 80),
        (145, 187, 149, 80),
        (134, 141, 172, 80),
        (157, 137, 109, 80),
        (153, 104, 95, 80),
        (165, 238, 173, 80),
        (76, 102, 221, 80),
        (221, 160, 77, 80),
        (204, 93, 71, 80),
        (145, 187, 149, 80),
        (134, 141, 172, 80),
        (157, 137, 109, 80),
        (153, 104, 95, 80),
    ]
    # Generate random colors for each mask
    if use_random_colors:
        colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
    # Font settings
    try:
        font = ImageFont.truetype("arial", font_size)  # Adjust as needed
    except IOError:
        font = ImageFont.load_default(font_size)
    # Overlay each mask onto the overlay image
    for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
        if mask is None:
            continue
        # Convert mask to RGBA mode
        mask_rgba = mask.convert('RGBA')
        mask_data = mask_rgba.getdata()
        new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
        mask_rgba.putdata(new_data)
        # Draw the mask prompt text on the mask
        draw = ImageDraw.Draw(mask_rgba)
        mask_bbox = mask.getbbox()  # Get the bounding box of the mask
        if mask_bbox is None:
            continue
        text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10)  # Adjust text position based on mask position
        draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
        # Alpha composite the overlay with this mask
        overlay = Image.alpha_composite(overlay, mask_rgba)
    # Composite the overlay onto the original image
    result = Image.alpha_composite(image.convert('RGBA'), overlay)
    return result


config = {
    "model_config": {
        "FLUX": {
            "model_folder": "models/FLUX",
            "pipeline_class": FluxImagePipeline,
            "default_parameters": {
                "cfg_scale": 3.0,
                "embedded_guidance": 3.5,
                "num_inference_steps": 30,
            }
        },
    },
    "max_num_painter_layers": 8,
    "max_num_model_cache": 1,
}

model_dict = {}


def load_model(model_type='FLUX', model_path='FLUX.1-dev'):
    global model_dict
    model_key = f"{model_type}:{model_path}"
    if model_key in model_dict:
        return model_dict[model_key]
    model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
    model_manager.load_lora(
        download_customized_models(
            model_id="DiffSynth-Studio/Eligen",
            origin_file_path="model_bf16.safetensors",
            local_dir="models/lora/entity_control",
        ),
        lora_alpha=1,
    )
    pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
    model_dict[model_key] = model_manager, pipe
    return model_manager, pipe


with gr.Blocks() as app:
    gr.Markdown(
        """## EliGen: Entity-Level Controllable Text-to-Image Model
1. On the left, input the **global prompt** for the overall image, such as "a person stands by the river."
2. On the right, input the **local prompt** for each entity, such as "person," and draw the corresponding mask in the **Entity Mask Painter**. Generally, solid rectangular masks yield better results.
3. Click the **Generate** button to create the image. By selecting different **random seeds**, you can generate diverse images.
4. **You can directly click the "Load Example" button on any sample at the bottom to load example inputs.**
"""
    )

    loading_status = gr.Textbox(label="Loading Model...", value="Loading model... Please wait...", visible=True)
    main_interface = gr.Column(visible=False)

    def initialize_model():
        try:
            load_model()
            return {
                loading_status: gr.update(value="Model loaded successfully!", visible=False),
                main_interface: gr.update(visible=True),
            }
        except Exception as e:
            print(f'Failed to load model with error: {e}')
            return {
                loading_status: gr.update(value=f"Failed to load model: {str(e)}", visible=True),
                main_interface: gr.update(visible=True),
            }

    app.load(initialize_model, inputs=None, outputs=[loading_status, main_interface])

    with main_interface:
        with gr.Row():
            local_prompt_list = []
            canvas_list = []
            random_mask_dir = gr.State(f'{random.randint(0, 1000000):08d}')
            with gr.Column(scale=382, min_width=100):
                model_type = gr.State('FLUX')
                model_path = gr.State('FLUX.1-dev')
                with gr.Accordion(label="Global prompt"):
                    prompt = gr.Textbox(label="Global Prompt", lines=3)
                    negative_prompt = gr.Textbox(label="Negative prompt", value="worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw, blur,", lines=3)
                with gr.Accordion(label="Inference Options", open=True):
                    seed = gr.Number(minimum=0, maximum=10**9, value=42, interactive=True, label="Random seed", show_label=True)
                    num_inference_steps = gr.Slider(minimum=1, maximum=100, value=30, step=1, interactive=True, label="Inference steps")
                    cfg_scale = gr.Slider(minimum=2.0, maximum=10.0, value=3.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
                    embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=3.5, step=0.1, interactive=True, label="Embedded guidance scale")
                    height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
                    width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
                with gr.Accordion(label="Inpaint Input Image", open=False):
                    input_image = gr.Image(sources=None, show_label=False, interactive=True, type="pil")
                    background_weight = gr.Slider(minimum=0.0, maximum=1000., value=0., step=1, interactive=False, label="background_weight", visible=False)

                    with gr.Column():
                        reset_input_button = gr.Button(value="Reset Inpaint Input")
                        send_input_to_painter = gr.Button(value="Set as painter's background")

                        @gr.on(inputs=[input_image], outputs=[input_image], triggers=reset_input_button.click)
                        def reset_input_image(input_image):
                            return None

            with gr.Column(scale=618, min_width=100):
                with gr.Accordion(label="Entity Painter"):
                    for painter_layer_id in range(config["max_num_painter_layers"]):
                        with gr.Tab(label=f"Entity {painter_layer_id}"):
                            local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
                            canvas = gr.ImageEditor(
                                canvas_size=(512, 512),
                                sources=None,
                                layers=False,
                                interactive=True,
                                image_mode="RGBA",
                                brush=gr.Brush(
                                    default_size=50,
                                    default_color="#000000",
                                    colors=["#000000"],
                                ),
                                label="Entity Mask Painter",
                                key=f"canvas_{painter_layer_id}",
                                width=width,
                                height=height,
                            )

                            @gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear], show_progress="hidden")
                            def resize_canvas(height, width, canvas):
                                h, w = canvas["background"].shape[:2]
                                if h != height or width != w:
                                    return np.ones((height, width, 3), dtype=np.uint8) * 255
                                else:
                                    return canvas

                            local_prompt_list.append(local_prompt)
                            canvas_list.append(canvas)
                with gr.Accordion(label="Results"):
                    run_button = gr.Button(value="Generate", variant="primary")
                    output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
                    with gr.Row():
                        with gr.Column():
                            output_to_painter_button = gr.Button(value="Set as painter's background")
                        with gr.Column():
                            return_with_mask = gr.Checkbox(value=False, interactive=True, label="show result with mask painting")
                            output_to_input_button = gr.Button(value="Set as input image", visible=False, interactive=False)
                    real_output = gr.State(None)
                    mask_out = gr.State(None)
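
                # Note (added annotation): the handler below keeps only the layers whose local prompt is
                # non-empty, takes the alpha channel of each painted canvas layer as that entity's mask, and
                # passes them to the pipeline as eligen_entity_prompts / eligen_entity_masks.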
                @gr.on(
                    inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir] + local_prompt_list + canvas_list,
                    outputs=[output_image, real_output, mask_out],
                    triggers=run_button.click
                )
                def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()):
                    _, pipe = load_model(model_type, model_path)
                    input_params = {
                        "prompt": prompt,
                        "negative_prompt": negative_prompt,
                        "cfg_scale": cfg_scale,
                        "num_inference_steps": num_inference_steps,
                        "height": height,
                        "width": width,
                        "progress_bar_cmd": progress.tqdm,
                    }
                    if isinstance(pipe, FluxImagePipeline):
                        input_params["embedded_guidance"] = embedded_guidance
                        if input_image is not None:
                            input_params["input_image"] = input_image.resize((width, height)).convert("RGB")
                            input_params["enable_eligen_inpaint"] = True

                    local_prompt_list, canvas_list = (
                        args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
                        args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
                    )
                    local_prompts, masks = [], []
                    for local_prompt, canvas in zip(local_prompt_list, canvas_list):
                        if isinstance(local_prompt, str) and len(local_prompt) > 0:
                            local_prompts.append(local_prompt)
                            masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
                    entity_masks = None if len(masks) == 0 else masks
                    entity_prompts = None if len(local_prompts) == 0 else local_prompts
                    input_params.update({
                        "eligen_entity_prompts": entity_prompts,
                        "eligen_entity_masks": entity_masks,
                    })
                    torch.manual_seed(seed)
                    # save_mask_prompts(masks, local_prompts, prompt, seed, random_mask_dir)
                    image = pipe(**input_params)
                    masks = [mask.resize(image.size) for mask in masks]
                    image_with_mask = visualize_masks(image, masks, local_prompts)

                    real_output = gr.State(image)
                    mask_out = gr.State(image_with_mask)

                    if return_with_mask:
                        return image_with_mask, real_output, mask_out
                    return image, real_output, mask_out

                @gr.on(inputs=[input_image] + canvas_list, outputs=canvas_list, triggers=send_input_to_painter.click)
                def send_input_to_painter_background(input_image, *canvas_list):
                    if input_image is None:
                        return tuple(canvas_list)
                    for canvas in canvas_list:
                        h, w = canvas["background"].shape[:2]
                        canvas["background"] = input_image.resize((w, h))
                    return tuple(canvas_list)

                @gr.on(inputs=[real_output] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
                def send_output_to_painter_background(real_output, *canvas_list):
                    if real_output is None:
                        return tuple(canvas_list)
                    for canvas in canvas_list:
                        h, w = canvas["background"].shape[:2]
                        canvas["background"] = real_output.value.resize((w, h))
                    return tuple(canvas_list)

                @gr.on(inputs=[return_with_mask, real_output, mask_out], outputs=[output_image], triggers=[return_with_mask.change], show_progress="hidden")
                def show_output(return_with_mask, real_output, mask_out):
                    if return_with_mask:
                        return mask_out.value
                    else:
                        return real_output.value

                @gr.on(inputs=[real_output], outputs=[input_image], triggers=output_to_input_button.click)
                def send_output_to_pipe_input(real_output):
                    return real_output.value

        with gr.Column():
            gr.Markdown("## Examples")
            for i in range(0, len(examples), 2):
                with gr.Row():
                    if i < len(examples):
                        example = examples[i]
                        with gr.Column():
                            example_image = gr.Image(
                                value=f"data/examples/eligen/entity_control/example_{example['example_id']}/example_image.png",
                                label=example["description"],
                                interactive=False,
                                width=1024,
                                height=512
                            )
                            load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
                            load_example_button.click(
                                load_example,
                                inputs=[load_example_button],
                                outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
                            )

                    if i + 1 < len(examples):
                        example = examples[i + 1]
                        with gr.Column():
                            example_image = gr.Image(
                                value=f"data/examples/eligen/entity_control/example_{example['example_id']}/example_image.png",
                                label=example["description"],
                                interactive=False,
                                width=1024,
                                height=512
                            )
                            load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
                            load_example_button.click(
                                load_example,
                                inputs=[load_example_button],
                                outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
                            )

app.config["show_progress"] = "hidden"
app.launch()
@@ -1,15 +0,0 @@
# Set web page format
import streamlit as st
st.set_page_config(layout="wide")
# Disable virtual VRAM on Windows systems
import torch
torch.cuda.set_per_process_memory_fraction(0.999, 0)


st.markdown("""
# DiffSynth Studio

[Source Code](https://github.com/Artiprocher/DiffSynth-Studio)

Welcome to DiffSynth Studio.
""")
@@ -1,362 +0,0 @@
|
|||||||
import torch, os, io, json, time
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
import streamlit as st
|
|
||||||
st.set_page_config(layout="wide")
|
|
||||||
from streamlit_drawable_canvas import st_canvas
|
|
||||||
from diffsynth.models import ModelManager
|
|
||||||
from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline, FluxImagePipeline
|
|
||||||
from diffsynth.data.video import crop_and_resize
|
|
||||||
|
|
||||||
|
|
||||||
config = {
|
|
||||||
"Stable Diffusion": {
|
|
||||||
"model_folder": "models/stable_diffusion",
|
|
||||||
"pipeline_class": SDImagePipeline,
|
|
||||||
"fixed_parameters": {}
|
|
||||||
},
|
|
||||||
"Stable Diffusion XL": {
|
|
||||||
"model_folder": "models/stable_diffusion_xl",
|
|
||||||
"pipeline_class": SDXLImagePipeline,
|
|
||||||
"fixed_parameters": {}
|
|
||||||
},
|
|
||||||
"Stable Diffusion 3": {
|
|
||||||
"model_folder": "models/stable_diffusion_3",
|
|
||||||
"pipeline_class": SD3ImagePipeline,
|
|
||||||
"fixed_parameters": {}
|
|
||||||
},
|
|
||||||
"Stable Diffusion XL Turbo": {
|
|
||||||
"model_folder": "models/stable_diffusion_xl_turbo",
|
|
||||||
"pipeline_class": SDXLImagePipeline,
|
|
||||||
"fixed_parameters": {
|
|
||||||
"negative_prompt": "",
|
|
||||||
"cfg_scale": 1.0,
|
|
||||||
"num_inference_steps": 1,
|
|
||||||
"height": 512,
|
|
||||||
"width": 512,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"Kolors": {
|
|
||||||
"model_folder": "models/kolors",
|
|
||||||
"pipeline_class": SDXLImagePipeline,
|
|
||||||
"fixed_parameters": {}
|
|
||||||
},
|
|
||||||
"HunyuanDiT": {
|
|
||||||
"model_folder": "models/HunyuanDiT",
|
|
||||||
"pipeline_class": HunyuanDiTImagePipeline,
|
|
||||||
"fixed_parameters": {
|
|
||||||
"height": 1024,
|
|
||||||
"width": 1024,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"FLUX": {
|
|
||||||
"model_folder": "models/FLUX",
|
|
||||||
"pipeline_class": FluxImagePipeline,
|
|
||||||
"fixed_parameters": {
|
|
||||||
"cfg_scale": 1.0,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_list(model_type):
|
|
||||||
folder = config[model_type]["model_folder"]
|
|
||||||
file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
|
|
||||||
if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
|
|
||||||
file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
|
|
||||||
file_list = sorted(file_list)
|
|
||||||
return file_list
|
|
||||||
|
|
||||||
|
|
||||||
def release_model():
|
|
||||||
if "model_manager" in st.session_state:
|
|
||||||
st.session_state["model_manager"].to("cpu")
|
|
||||||
del st.session_state["loaded_model_path"]
|
|
||||||
del st.session_state["model_manager"]
|
|
||||||
del st.session_state["pipeline"]
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_type, model_path):
|
|
||||||
model_manager = ModelManager()
|
|
||||||
if model_type == "HunyuanDiT":
|
|
||||||
model_manager.load_models([
|
|
||||||
os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
|
|
||||||
os.path.join(model_path, "mt5/pytorch_model.bin"),
|
|
||||||
os.path.join(model_path, "model/pytorch_model_ema.pt"),
|
|
||||||
os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
|
|
||||||
])
|
|
||||||
elif model_type == "Kolors":
|
|
||||||
model_manager.load_models([
|
|
||||||
os.path.join(model_path, "text_encoder"),
|
|
||||||
os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
|
|
||||||
os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
])
|
|
||||||
elif model_type == "FLUX":
|
|
||||||
model_manager.torch_dtype = torch.bfloat16
|
|
||||||
file_list = [
|
|
||||||
os.path.join(model_path, "text_encoder/model.safetensors"),
|
|
||||||
os.path.join(model_path, "text_encoder_2"),
|
|
||||||
]
|
|
||||||
for file_name in os.listdir(model_path):
|
|
||||||
if file_name.endswith(".safetensors"):
|
|
||||||
file_list.append(os.path.join(model_path, file_name))
|
|
||||||
model_manager.load_models(file_list)
|
|
||||||
else:
|
|
||||||
model_manager.load_model(model_path)
|
|
||||||
pipeline = config[model_type]["pipeline_class"].from_model_manager(model_manager)
|
|
||||||
st.session_state.loaded_model_path = model_path
|
|
||||||
st.session_state.model_manager = model_manager
|
|
||||||
st.session_state.pipeline = pipeline
|
|
||||||
return model_manager, pipeline
|
|
||||||
|
|
||||||
|
|
||||||
def use_output_image_as_input(update=True):
|
|
||||||
# Search for input image
|
|
||||||
output_image_id = 0
|
|
||||||
selected_output_image = None
|
|
||||||
while True:
|
|
||||||
if f"use_output_as_input_{output_image_id}" not in st.session_state:
|
|
||||||
break
|
|
||||||
if st.session_state[f"use_output_as_input_{output_image_id}"]:
|
|
||||||
selected_output_image = st.session_state["output_images"][output_image_id]
|
|
||||||
break
|
|
||||||
output_image_id += 1
|
|
||||||
if update and selected_output_image is not None:
|
|
||||||
st.session_state["input_image"] = selected_output_image
|
|
||||||
return selected_output_image is not None
|
|
||||||
|
|
||||||
|
|
||||||
def apply_stroke_to_image(stroke_image, image):
|
|
||||||
image = np.array(image.convert("RGB")).astype(np.float32)
|
|
||||||
height, width, _ = image.shape
|
|
||||||
|
|
||||||
stroke_image = np.array(Image.fromarray(stroke_image).resize((width, height))).astype(np.float32)
|
|
||||||
weight = stroke_image[:, :, -1:] / 255
|
|
||||||
stroke_image = stroke_image[:, :, :-1]
|
|
||||||
|
|
||||||
image = stroke_image * weight + image * (1 - weight)
|
|
||||||
image = np.clip(image, 0, 255).astype(np.uint8)
|
|
||||||
image = Image.fromarray(image)
|
|
||||||
return image
|
|
||||||
|
|
||||||
|
|
||||||
@st.cache_data
|
|
||||||
def image2bits(image):
|
|
||||||
image_byte = io.BytesIO()
|
|
||||||
image.save(image_byte, format="PNG")
|
|
||||||
image_byte = image_byte.getvalue()
|
|
||||||
return image_byte
|
|
||||||
|
|
||||||
|
|
||||||
def show_output_image(image):
|
|
||||||
st.image(image, use_column_width="always")
|
|
||||||
st.button("Use it as input image", key=f"use_output_as_input_{image_id}")
|
|
||||||
st.download_button("Download", data=image2bits(image), file_name="image.png", mime="image/png", key=f"download_output_{image_id}")
|
|
||||||
|
|
||||||
|
|
||||||
column_input, column_output = st.columns(2)
|
|
||||||
with st.sidebar:
|
|
||||||
# Select a model
|
|
||||||
with st.expander("Model", expanded=True):
|
|
||||||
model_type = st.selectbox("Model type", [model_type_ for model_type_ in config])
|
|
||||||
fixed_parameters = config[model_type]["fixed_parameters"]
|
|
||||||
model_path_list = ["None"] + load_model_list(model_type)
|
|
||||||
model_path = st.selectbox("Model path", model_path_list)
|
|
||||||
|
|
||||||
# Load the model
|
|
||||||
if model_path == "None":
|
|
||||||
# No models are selected. Release VRAM.
|
|
||||||
st.markdown("No models are selected.")
|
|
||||||
release_model()
|
|
||||||
else:
|
|
||||||
# A model is selected.
|
|
||||||
model_path = os.path.join(config[model_type]["model_folder"], model_path)
|
|
||||||
if st.session_state.get("loaded_model_path", "") != model_path:
|
|
||||||
# The loaded model is not the selected model. Reload it.
|
|
||||||
st.markdown(f"Loading model at {model_path}.")
|
|
||||||
st.markdown("Please wait a moment...")
|
|
||||||
release_model()
|
|
||||||
model_manager, pipeline = load_model(model_type, model_path)
|
|
||||||
st.markdown("Done.")
|
|
||||||
else:
|
|
||||||
# The loaded model is already the selected model. Fetch it from `st.session_state`.
|
|
||||||
st.markdown(f"Loading model at {model_path}.")
|
|
||||||
st.markdown("Please wait a moment...")
|
|
||||||
model_manager, pipeline = st.session_state.model_manager, st.session_state.pipeline
|
|
||||||
st.markdown("Done.")
|
|
||||||
|
|
||||||
# Show parameters
|
|
||||||
with st.expander("Prompt", expanded=True):
|
|
||||||
prompt = st.text_area("Positive prompt")
|
|
||||||
if "negative_prompt" in fixed_parameters:
|
|
||||||
negative_prompt = fixed_parameters["negative_prompt"]
|
|
||||||
else:
|
|
||||||
negative_prompt = st.text_area("Negative prompt")
|
|
||||||
if "cfg_scale" in fixed_parameters:
|
|
||||||
cfg_scale = fixed_parameters["cfg_scale"]
|
|
||||||
else:
|
|
||||||
cfg_scale = st.slider("Classifier-free guidance scale", min_value=1.0, max_value=10.0, value=7.5)
|
|
||||||
with st.expander("Image", expanded=True):
|
|
||||||
if "num_inference_steps" in fixed_parameters:
|
|
||||||
num_inference_steps = fixed_parameters["num_inference_steps"]
|
|
||||||
else:
|
|
||||||
num_inference_steps = st.slider("Inference steps", min_value=1, max_value=100, value=20)
|
|
||||||
if "height" in fixed_parameters:
|
|
||||||
height = fixed_parameters["height"]
|
|
||||||
else:
|
|
||||||
height = st.select_slider("Height", options=[256, 512, 768, 1024, 2048], value=512)
|
|
||||||
if "width" in fixed_parameters:
|
|
||||||
width = fixed_parameters["width"]
|
|
||||||
else:
|
|
||||||
width = st.select_slider("Width", options=[256, 512, 768, 1024, 2048], value=512)
|
|
||||||
num_images = st.number_input("Number of images", value=2)
|
|
||||||
use_fixed_seed = st.checkbox("Use fixed seed", value=False)
|
|
||||||
if use_fixed_seed:
|
|
||||||
seed = st.number_input("Random seed", min_value=0, max_value=10**9, step=1, value=0)
|
|
||||||
|
|
||||||
# Other fixed parameters
|
|
||||||
denoising_strength = 1.0
|
|
||||||
repetition = 1
|
|
||||||
|
|
||||||
|
|
||||||
# Show input image
|
|
||||||
with column_input:
|
|
||||||
with st.expander("Input image (Optional)", expanded=True):
|
|
||||||
with st.container(border=True):
|
|
||||||
column_white_board, column_upload_image = st.columns([1, 2])
|
|
||||||
with column_white_board:
|
|
||||||
create_white_board = st.button("Create white board")
|
|
||||||
delete_input_image = st.button("Delete input image")
|
|
||||||
with column_upload_image:
|
|
||||||
upload_image = st.file_uploader("Upload image", type=["png", "jpg"], key="upload_image")
|
|
||||||
|
|
||||||
if upload_image is not None:
|
|
||||||
st.session_state["input_image"] = crop_and_resize(Image.open(upload_image), height, width)
|
|
||||||
elif create_white_board:
|
|
||||||
st.session_state["input_image"] = Image.fromarray(np.ones((height, width, 3), dtype=np.uint8) * 255)
|
|
||||||
else:
|
|
||||||
use_output_image_as_input()
|
|
||||||
|
|
||||||
if delete_input_image and "input_image" in st.session_state:
|
|
||||||
del st.session_state.input_image
|
|
||||||
if delete_input_image and "upload_image" in st.session_state:
|
|
||||||
del st.session_state.upload_image
|
|
||||||
|
|
||||||
input_image = st.session_state.get("input_image", None)
|
|
||||||
if input_image is not None:
|
|
||||||
with st.container(border=True):
|
|
||||||
column_drawing_mode, column_color_1, column_color_2 = st.columns([4, 1, 1])
|
|
||||||
with column_drawing_mode:
|
|
||||||
drawing_mode = st.radio("Drawing tool", ["transform", "freedraw", "line", "rect"], horizontal=True, index=1)
|
|
||||||
with column_color_1:
|
|
||||||
stroke_color = st.color_picker("Stroke color")
|
|
||||||
with column_color_2:
|
|
||||||
fill_color = st.color_picker("Fill color")
|
|
||||||
stroke_width = st.slider("Stroke width", min_value=1, max_value=50, value=10)
|
|
||||||
with st.container(border=True):
|
|
||||||
denoising_strength = st.slider("Denoising strength", min_value=0.0, max_value=1.0, value=0.7)
|
|
||||||
repetition = st.slider("Repetition", min_value=1, max_value=8, value=1)
|
|
||||||
with st.container(border=True):
|
|
||||||
input_width, input_height = input_image.size
|
|
||||||
canvas_result = st_canvas(
|
|
||||||
fill_color=fill_color,
|
|
||||||
stroke_width=stroke_width,
|
|
||||||
stroke_color=stroke_color,
|
|
||||||
background_color="rgba(255, 255, 255, 0)",
|
|
||||||
background_image=input_image,
|
|
||||||
update_streamlit=True,
|
|
||||||
height=int(512 / input_width * input_height),
|
|
||||||
width=512,
|
|
||||||
drawing_mode=drawing_mode,
|
|
||||||
key="canvas"
|
|
||||||
)
|
|
||||||
|
|
||||||
num_painter_layer = st.number_input("Number of painter layers", min_value=0, max_value=10, step=1, value=0)
|
|
||||||
local_prompts, masks, mask_scales = [], [], []
|
|
||||||
white_board = Image.fromarray(np.ones((512, 512, 3), dtype=np.uint8) * 255)
|
|
||||||
painter_layers_json_data = []
|
|
||||||
for painter_tab_id in range(num_painter_layer):
|
|
||||||
with st.expander(f"Painter layer {painter_tab_id}", expanded=True):
|
|
||||||
enable_local_prompt = st.checkbox(f"Enable prompt {painter_tab_id}", value=True)
|
|
||||||
local_prompt = st.text_area(f"Prompt {painter_tab_id}")
|
|
||||||
mask_scale = st.slider(f"Mask scale {painter_tab_id}", min_value=0.0, max_value=3.0, value=1.0)
|
|
||||||
stroke_width = st.slider(f"Stroke width {painter_tab_id}", min_value=1, max_value=300, value=100)
|
|
||||||
canvas_result_local = st_canvas(
|
|
||||||
fill_color="#000000",
|
|
||||||
stroke_width=stroke_width,
|
|
||||||
stroke_color="#000000",
|
|
||||||
background_color="rgba(255, 255, 255, 0)",
|
|
||||||
background_image=white_board,
|
|
||||||
update_streamlit=True,
|
|
||||||
height=512,
|
|
||||||
width=512,
|
|
||||||
drawing_mode="freedraw",
|
|
||||||
key=f"canvas_{painter_tab_id}"
|
|
||||||
)
|
|
||||||
if canvas_result_local.json_data is not None:
|
|
||||||
painter_layers_json_data.append(canvas_result_local.json_data.copy())
|
|
||||||
painter_layers_json_data[-1]["prompt"] = local_prompt
|
|
||||||
if enable_local_prompt:
|
|
||||||
local_prompts.append(local_prompt)
|
|
||||||
if canvas_result_local.image_data is not None:
|
|
||||||
mask = apply_stroke_to_image(canvas_result_local.image_data, white_board)
|
|
||||||
else:
|
|
||||||
mask = white_board
|
|
||||||
mask = Image.fromarray(255 - np.array(mask))
|
|
||||||
masks.append(mask)
|
|
||||||
mask_scales.append(mask_scale)
|
|
||||||
save_painter_layers = st.button("Save painter layers")
|
|
||||||
if save_painter_layers:
|
|
||||||
os.makedirs("data/painter_layers", exist_ok=True)
|
|
||||||
json_file_path = f"data/painter_layers/{time.time_ns()}.json"
|
|
||||||
with open(json_file_path, "w") as f:
|
|
||||||
json.dump(painter_layers_json_data, f, indent=4)
|
|
||||||
st.markdown(f"Painter layers are saved in {json_file_path}.")
|
|
||||||
|
|
||||||
|
|
||||||
with column_output:
|
|
||||||
run_button = st.button("Generate image", type="primary")
|
|
||||||
auto_update = st.checkbox("Auto update", value=False)
|
|
||||||
num_image_columns = st.slider("Columns", min_value=1, max_value=8, value=2)
|
|
||||||
image_columns = st.columns(num_image_columns)
|
|
||||||
|
|
||||||
# Run
|
|
||||||
if (run_button or auto_update) and model_path != "None":
|
|
||||||
|
|
||||||
if input_image is not None:
|
|
||||||
input_image = input_image.resize((width, height))
|
|
||||||
if canvas_result.image_data is not None:
|
|
||||||
input_image = apply_stroke_to_image(canvas_result.image_data, input_image)
|
|
||||||
|
|
||||||
output_images = []
|
|
||||||
for image_id in range(num_images * repetition):
|
|
||||||
if use_fixed_seed:
|
|
||||||
torch.manual_seed(seed + image_id)
|
|
||||||
else:
|
|
||||||
torch.manual_seed(np.random.randint(0, 10**9))
|
|
||||||
if image_id >= num_images:
|
|
||||||
input_image = output_images[image_id - num_images]
|
|
||||||
with image_columns[image_id % num_image_columns]:
|
|
||||||
progress_bar_st = st.progress(0.0)
|
|
||||||
image = pipeline(
|
|
||||||
prompt, negative_prompt=negative_prompt,
|
|
||||||
local_prompts=local_prompts, masks=masks, mask_scales=mask_scales,
|
|
||||||
cfg_scale=cfg_scale, num_inference_steps=num_inference_steps,
|
|
||||||
height=height, width=width,
|
|
||||||
input_image=input_image, denoising_strength=denoising_strength,
|
|
||||||
progress_bar_st=progress_bar_st
|
|
||||||
)
|
|
||||||
output_images.append(image)
|
|
||||||
progress_bar_st.progress(1.0)
|
|
||||||
show_output_image(image)
|
|
||||||
st.session_state["output_images"] = output_images
|
|
||||||
|
|
||||||
elif "output_images" in st.session_state:
|
|
||||||
for image_id in range(len(st.session_state.output_images)):
|
|
||||||
with image_columns[image_id % num_image_columns]:
|
|
||||||
image = st.session_state.output_images[image_id]
|
|
||||||
progress_bar = st.progress(1.0)
|
|
||||||
show_output_image(image)
|
|
||||||
if "upload_image" in st.session_state and use_output_image_as_input(update=False):
|
|
||||||
st.markdown("If you want to use an output image as input image, please delete the uploaded image manually.")
|
|
||||||
@@ -1,197 +0,0 @@
|
|||||||
import streamlit as st
|
|
||||||
st.set_page_config(layout="wide")
|
|
||||||
from diffsynth import SDVideoPipelineRunner
|
|
||||||
import os
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_list(folder):
|
|
||||||
file_list = os.listdir(folder)
|
|
||||||
file_list = [i for i in file_list if i.endswith(".safetensors") or i.endswith(".pth") or i.endswith(".ckpt")]
|
|
||||||
file_list = sorted(file_list)
|
|
||||||
return file_list
|
|
||||||
|
|
||||||
|
|
||||||
def match_processor_id(model_name, supported_processor_id_list):
|
|
||||||
sorted_processor_id = [i[1] for i in sorted([(-len(i), i) for i in supported_processor_id_list])]
|
|
||||||
for processor_id in sorted_processor_id:
|
|
||||||
if processor_id in model_name:
|
|
||||||
return supported_processor_id_list.index(processor_id) + 1
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
config = {
|
|
||||||
"models": {
|
|
||||||
"model_list": [],
|
|
||||||
"textual_inversion_folder": "models/textual_inversion",
|
|
||||||
"device": "cuda",
|
|
||||||
"lora_alphas": [],
|
|
||||||
"controlnet_units": []
|
|
||||||
},
|
|
||||||
"data": {
|
|
||||||
"input_frames": None,
|
|
||||||
"controlnet_frames": [],
|
|
||||||
"output_folder": "output",
|
|
||||||
"fps": 60
|
|
||||||
},
|
|
||||||
"pipeline": {
|
|
||||||
"seed": 0,
|
|
||||||
"pipeline_inputs": {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
with st.expander("Model", expanded=True):
|
|
||||||
stable_diffusion_ckpt = st.selectbox("Stable Diffusion", ["None"] + load_model_list("models/stable_diffusion"))
|
|
||||||
if stable_diffusion_ckpt != "None":
|
|
||||||
config["models"]["model_list"].append(os.path.join("models/stable_diffusion", stable_diffusion_ckpt))
|
|
||||||
animatediff_ckpt = st.selectbox("AnimateDiff", ["None"] + load_model_list("models/AnimateDiff"))
|
|
||||||
if animatediff_ckpt != "None":
|
|
||||||
config["models"]["model_list"].append(os.path.join("models/AnimateDiff", animatediff_ckpt))
|
|
||||||
column_lora, column_lora_alpha = st.columns([2, 1])
|
|
||||||
with column_lora:
|
|
||||||
sd_lora_ckpt = st.selectbox("LoRA", ["None"] + load_model_list("models/lora"))
|
|
||||||
with column_lora_alpha:
|
|
||||||
lora_alpha = st.slider("LoRA Alpha", min_value=-4.0, max_value=4.0, value=1.0, step=0.1)
|
|
||||||
if sd_lora_ckpt != "None":
|
|
||||||
config["models"]["model_list"].append(os.path.join("models/lora", sd_lora_ckpt))
|
|
||||||
config["models"]["lora_alphas"].append(lora_alpha)
|
|
||||||
|
|
||||||
|
|
||||||
with st.expander("Data", expanded=True):
|
|
||||||
with st.container(border=True):
|
|
||||||
input_video = st.text_input("Input Video File Path (e.g., data/your_video.mp4)", value="")
|
|
||||||
column_height, column_width, column_start_frame_index, column_end_frame_index = st.columns([2, 2, 1, 1])
|
|
||||||
with column_height:
|
|
||||||
height = st.select_slider("Height", options=[256, 512, 768, 1024, 1536, 2048], value=1024)
|
|
||||||
with column_width:
|
|
||||||
width = st.select_slider("Width", options=[256, 512, 768, 1024, 1536, 2048], value=1024)
|
|
||||||
with column_start_frame_index:
|
|
||||||
start_frame_id = st.number_input("Start Frame id", value=0)
|
|
||||||
with column_end_frame_index:
|
|
||||||
end_frame_id = st.number_input("End Frame id", value=16)
|
|
||||||
if input_video != "":
|
|
||||||
config["data"]["input_frames"] = {
|
|
||||||
"video_file": input_video,
|
|
||||||
"image_folder": None,
|
|
||||||
"height": height,
|
|
||||||
"width": width,
|
|
||||||
"start_frame_id": start_frame_id,
|
|
||||||
"end_frame_id": end_frame_id
|
|
||||||
}
|
|
||||||
with st.container(border=True):
|
|
||||||
output_video = st.text_input("Output Video File Path (e.g., data/a_folder_to_save_something)", value="output")
|
|
||||||
fps = st.number_input("FPS", value=60)
|
|
||||||
config["data"]["output_folder"] = output_video
|
|
||||||
config["data"]["fps"] = fps
|
|
||||||
|
|
||||||
|
|
||||||
with st.expander("ControlNet Units", expanded=True):
|
|
||||||
supported_processor_id_list = ["canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"]
|
|
||||||
controlnet_units = st.tabs(["ControlNet Unit 0", "ControlNet Unit 1", "ControlNet Unit 2"])
|
|
||||||
for controlnet_id in range(len(controlnet_units)):
|
|
||||||
with controlnet_units[controlnet_id]:
|
|
||||||
controlnet_ckpt = st.selectbox("ControlNet", ["None"] + load_model_list("models/ControlNet"),
|
|
||||||
key=f"controlnet_ckpt_{controlnet_id}")
|
|
||||||
processor_id = st.selectbox("Processor", ["None"] + supported_processor_id_list,
|
|
||||||
index=match_processor_id(controlnet_ckpt, supported_processor_id_list),
|
|
||||||
disabled=controlnet_ckpt == "None", key=f"processor_id_{controlnet_id}")
|
|
||||||
controlnet_scale = st.slider("Scale", min_value=0.0, max_value=1.0, step=0.01, value=0.5,
|
|
||||||
disabled=controlnet_ckpt == "None", key=f"controlnet_scale_{controlnet_id}")
|
|
||||||
use_input_video_as_controlnet_input = st.checkbox("Use input video as ControlNet input", value=True,
|
|
||||||
disabled=controlnet_ckpt == "None",
|
|
||||||
key=f"use_input_video_as_controlnet_input_{controlnet_id}")
|
|
||||||
if not use_input_video_as_controlnet_input:
|
|
||||||
controlnet_input_video = st.text_input("ControlNet Input Video File Path", value="",
|
|
||||||
disabled=controlnet_ckpt == "None", key=f"controlnet_input_video_{controlnet_id}")
|
|
||||||
column_height, column_width, column_start_frame_index, column_end_frame_index = st.columns([2, 2, 1, 1])
|
|
||||||
with column_height:
|
|
||||||
height = st.select_slider("Height", options=[256, 512, 768, 1024, 1536, 2048], value=1024,
|
|
||||||
disabled=controlnet_ckpt == "None", key=f"controlnet_height_{controlnet_id}")
|
|
||||||
with column_width:
|
|
||||||
width = st.select_slider("Width", options=[256, 512, 768, 1024, 1536, 2048], value=1024,
|
|
||||||
disabled=controlnet_ckpt == "None", key=f"controlnet_width_{controlnet_id}")
|
|
||||||
with column_start_frame_index:
|
|
||||||
start_frame_id = st.number_input("Start Frame id", value=0,
|
|
||||||
disabled=controlnet_ckpt == "None", key=f"controlnet_start_frame_id_{controlnet_id}")
|
|
||||||
with column_end_frame_index:
|
|
||||||
end_frame_id = st.number_input("End Frame id", value=16,
|
|
||||||
disabled=controlnet_ckpt == "None", key=f"controlnet_end_frame_id_{controlnet_id}")
|
|
||||||
if input_video != "":
|
|
||||||
config["data"]["input_video"] = {
|
|
||||||
"video_file": input_video,
|
|
||||||
"image_folder": None,
|
|
||||||
"height": height,
|
|
||||||
"width": width,
|
|
||||||
"start_frame_id": start_frame_id,
|
|
||||||
"end_frame_id": end_frame_id
|
|
||||||
}
|
|
||||||
if controlnet_ckpt != "None":
|
|
||||||
config["models"]["model_list"].append(os.path.join("models/ControlNet", controlnet_ckpt))
|
|
||||||
config["models"]["controlnet_units"].append({
|
|
||||||
"processor_id": processor_id,
|
|
||||||
"model_path": os.path.join("models/ControlNet", controlnet_ckpt),
|
|
||||||
"scale": controlnet_scale,
|
|
||||||
})
|
|
||||||
if use_input_video_as_controlnet_input:
|
|
||||||
config["data"]["controlnet_frames"].append(config["data"]["input_frames"])
|
|
||||||
else:
|
|
||||||
config["data"]["controlnet_frames"].append({
|
|
||||||
"video_file": input_video,
|
|
||||||
"image_folder": None,
|
|
||||||
"height": height,
|
|
||||||
"width": width,
|
|
||||||
"start_frame_id": start_frame_id,
|
|
||||||
"end_frame_id": end_frame_id
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
with st.container(border=True):
|
|
||||||
with st.expander("Seed", expanded=True):
|
|
||||||
use_fixed_seed = st.checkbox("Use fixed seed", value=False)
|
|
||||||
if use_fixed_seed:
|
|
||||||
seed = st.number_input("Random seed", min_value=0, max_value=10**9, step=1, value=0)
|
|
||||||
else:
|
|
||||||
seed = np.random.randint(0, 10**9)
|
|
||||||
with st.expander("Textual Guidance", expanded=True):
|
|
||||||
prompt = st.text_area("Positive prompt")
|
|
||||||
negative_prompt = st.text_area("Negative prompt")
|
|
||||||
column_cfg_scale, column_clip_skip = st.columns(2)
|
|
||||||
with column_cfg_scale:
|
|
||||||
cfg_scale = st.slider("Classifier-free guidance scale", min_value=1.0, max_value=10.0, value=7.0)
|
|
||||||
with column_clip_skip:
|
|
||||||
clip_skip = st.slider("Clip Skip", min_value=1, max_value=4, value=1)
|
|
||||||
with st.expander("Denoising", expanded=True):
|
|
||||||
column_num_inference_steps, column_denoising_strength = st.columns(2)
|
|
||||||
with column_num_inference_steps:
|
|
||||||
num_inference_steps = st.slider("Inference steps", min_value=1, max_value=100, value=10)
|
|
||||||
with column_denoising_strength:
|
|
||||||
denoising_strength = st.slider("Denoising strength", min_value=0.0, max_value=1.0, value=1.0)
|
|
||||||
with st.expander("Efficiency", expanded=False):
|
|
||||||
animatediff_batch_size = st.slider("Animatediff batch size (sliding window size)", min_value=1, max_value=32, value=16, step=1)
|
|
||||||
animatediff_stride = st.slider("Animatediff stride",
|
|
||||||
min_value=1,
|
|
||||||
max_value=max(2, animatediff_batch_size),
|
|
||||||
value=max(1, animatediff_batch_size // 2),
|
|
||||||
step=1)
|
|
||||||
unet_batch_size = st.slider("UNet batch size", min_value=1, max_value=32, value=1, step=1)
|
|
||||||
controlnet_batch_size = st.slider("ControlNet batch size", min_value=1, max_value=32, value=1, step=1)
|
|
||||||
cross_frame_attention = st.checkbox("Enable Cross-Frame Attention", value=False)
|
|
||||||
config["pipeline"]["seed"] = seed
|
|
||||||
config["pipeline"]["pipeline_inputs"] = {
|
|
||||||
"prompt": prompt,
|
|
||||||
"negative_prompt": negative_prompt,
|
|
||||||
"cfg_scale": cfg_scale,
|
|
||||||
"clip_skip": clip_skip,
|
|
||||||
"denoising_strength": denoising_strength,
|
|
||||||
"num_inference_steps": num_inference_steps,
|
|
||||||
"animatediff_batch_size": animatediff_batch_size,
|
|
||||||
"animatediff_stride": animatediff_stride,
|
|
||||||
"unet_batch_size": unet_batch_size,
|
|
||||||
"controlnet_batch_size": controlnet_batch_size,
|
|
||||||
"cross_frame_attention": cross_frame_attention,
|
|
||||||
}
|
|
||||||
|
|
||||||
run_button = st.button("☢️Run☢️", type="primary")
|
|
||||||
if run_button:
|
|
||||||
SDVideoPipelineRunner(in_streamlit=True).run(config)
|
|
||||||
@@ -1,6 +1 @@
from .data import *
from .core import *
from .models import *
from .prompters import *
from .schedulers import *
from .pipelines import *
from .controlnets import *
@@ -0,0 +1,2 @@
from .model_configs import MODEL_CONFIGS
from .vram_management_module_maps import VRAM_MANAGEMENT_MODULE_MAPS
@@ -1,800 +0,0 @@
from typing_extensions import Literal, TypeAlias

from ..models.sd_text_encoder import SDTextEncoder
from ..models.sd_unet import SDUNet
from ..models.sd_vae_encoder import SDVAEEncoder
from ..models.sd_vae_decoder import SDVAEDecoder

from ..models.sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from ..models.sdxl_unet import SDXLUNet
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
from ..models.sdxl_vae_encoder import SDXLVAEEncoder

from ..models.sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
from ..models.sd3_dit import SD3DiT
from ..models.sd3_vae_decoder import SD3VAEDecoder
from ..models.sd3_vae_encoder import SD3VAEEncoder

from ..models.sd_controlnet import SDControlNet
from ..models.sdxl_controlnet import SDXLControlNetUnion

from ..models.sd_motion import SDMotionModel
from ..models.sdxl_motion import SDXLMotionModel

from ..models.svd_image_encoder import SVDImageEncoder
from ..models.svd_unet import SVDUNet
from ..models.svd_vae_decoder import SVDVAEDecoder
from ..models.svd_vae_encoder import SVDVAEEncoder

from ..models.sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
from ..models.sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder

from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from ..models.hunyuan_dit import HunyuanDiT

from ..models.flux_dit import FluxDiT
from ..models.flux_text_encoder import FluxTextEncoder2
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
from ..models.flux_controlnet import FluxControlNet
from ..models.flux_ipadapter import FluxIpAdapter
from ..models.flux_infiniteyou import InfiniteYouImageProjector

from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
from ..models.cog_dit import CogDiT

from ..models.omnigen import OmniGenTransformer

from ..models.hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
from ..models.hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder

from ..extensions.RIFE import IFNet
from ..extensions.ESRGAN import RRDBNet

from ..models.hunyuan_video_dit import HunyuanVideoDiT

from ..models.stepvideo_vae import StepVideoVAE
from ..models.stepvideo_dit import StepVideoModel

from ..models.wan_video_dit import WanModel
from ..models.wan_video_text_encoder import WanTextEncoder
from ..models.wan_video_image_encoder import WanImageEncoder
from ..models.wan_video_vae import WanVideoVAE


model_loader_configs = [
    # These configs are provided for detecting model type automatically.
    # The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
(None, "091b0e30e77c76626b3ba62acdf95343", ["sd_controlnet"], [SDControlNet], "civitai"),
|
|
||||||
(None, "4a6c8306a27d916dea81263c8c88f450", ["hunyuan_dit_clip_text_encoder"], [HunyuanDiTCLIPTextEncoder], "civitai"),
|
|
||||||
(None, "f4aec400fe394297961218c768004521", ["hunyuan_dit"], [HunyuanDiT], "civitai"),
|
|
||||||
(None, "9e6e58043a5a2e332803ed42f6ee7181", ["hunyuan_dit_t5_text_encoder"], [HunyuanDiTT5TextEncoder], "civitai"),
|
|
||||||
(None, "13115dd45a6e1c39860f91ab073b8a78", ["sdxl_vae_encoder", "sdxl_vae_decoder"], [SDXLVAEEncoder, SDXLVAEDecoder], "diffusers"),
|
|
||||||
(None, "d78aa6797382a6d455362358a3295ea9", ["sd_ipadapter_clip_image_encoder"], [IpAdapterCLIPImageEmbedder], "diffusers"),
|
|
||||||
(None, "e291636cc15e803186b47404262ef812", ["sd_ipadapter"], [SDIpAdapter], "civitai"),
|
|
||||||
(None, "399c81f2f8de8d1843d0127a00f3c224", ["sdxl_ipadapter_clip_image_encoder"], [IpAdapterXLCLIPImageEmbedder], "diffusers"),
|
|
||||||
(None, "a64eac9aa0db4b9602213bc0131281c7", ["sdxl_ipadapter"], [SDXLIpAdapter], "civitai"),
|
|
||||||
(None, "52817e4fdd89df154f02749ca6f692ac", ["sdxl_unet"], [SDXLUNet], "diffusers"),
|
|
||||||
(None, "03343c606f16d834d6411d0902b53636", ["sd_text_encoder", "sd_unet", "sd_vae_decoder", "sd_vae_encoder"], [SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder], "civitai"),
|
|
||||||
(None, "d4ba77a7ece070679b4a987f58f201e9", ["sd_text_encoder"], [SDTextEncoder], "civitai"),
|
|
||||||
(None, "d0c89e55c5a57cf3981def0cb1c9e65a", ["sd_vae_decoder", "sd_vae_encoder"], [SDVAEDecoder, SDVAEEncoder], "civitai"),
|
|
||||||
(None, "3926bf373b39a67eeafd7901478a47a7", ["sd_unet"], [SDUNet], "civitai"),
|
|
||||||
(None, "1e0c39ec176b9007c05f76d52b554a4d", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
|
||||||
(None, "d9e0290829ba8d98e28e1a2b1407db4a", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
|
||||||
(None, "5072d0b24e406b49507abe861cf97691", ["sd3_text_encoder_3"], [SD3TextEncoder3], "civitai"),
|
|
||||||
(None, "4cf64a799d04260df438c6f33c9a047e", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"),
|
|
||||||
(None, "d9b008a867c498ab12ad24042eff8e3f", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"), # SDXL-Turbo
|
|
||||||
(None, "025bb7452e531a3853d951d77c63f032", ["sdxl_text_encoder", "sdxl_text_encoder_2"], [SDXLTextEncoder, SDXLTextEncoder2], "civitai"),
|
|
||||||
(None, "298997b403a4245c04102c9f36aac348", ["sdxl_unet"], [SDXLUNet], "civitai"),
|
|
||||||
(None, "2a07abce74b4bdc696b76254ab474da6", ["svd_image_encoder", "svd_unet", "svd_vae_decoder", "svd_vae_encoder"], [SVDImageEncoder, SVDUNet, SVDVAEDecoder, SVDVAEEncoder], "civitai"),
|
|
||||||
(None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
|
|
||||||
(None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
|
|
||||||
(None, "31d2d9614fba60511fc9bf2604aa01f7", ["sdxl_controlnet"], [SDXLControlNetUnion], "diffusers"),
|
|
||||||
(None, "94eefa3dac9cec93cb1ebaf1747d7b78", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
|
|
||||||
(None, "1aafa3cc91716fb6b300cc1cd51b85a3", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "diffusers"),
|
|
||||||
(None, "21ea55f476dfc4fd135587abb59dfe5d", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "civitai"),
|
|
||||||
(None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai"),
|
|
||||||
(None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
|
|
||||||
(None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"),
|
|
||||||
(None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
|
|
||||||
(None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
|
|
||||||
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
|
|
||||||
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
|
|
||||||
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
|
|
||||||
(None, "61cbcbc7ac11f169c5949223efa960d1", ["omnigen_transformer"], [OmniGenTransformer], "diffusers"),
|
|
||||||
(None, "78d18b9101345ff695f312e7e62538c0", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
|
||||||
(None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
|
||||||
(None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
|
||||||
(None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
|
||||||
(None, "7f9583eb8ba86642abb9a21a4b2c9e16", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
|
||||||
(None, "c07c0f04f5ff55e86b4e937c7a40d481", ["infiniteyou_image_projector"], [InfiniteYouImageProjector], "diffusers"),
|
|
||||||
(None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"),
|
|
||||||
(None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
|
||||||
(None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
|
|
||||||
(None, "77ff18050dbc23f50382e45d51a779fe", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
|
||||||
(None, "5da81baee73198a7c19e6d2fe8b5148e", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
|
|
||||||
(None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder", "hunyuan_video_vae_encoder"], [HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder], "diffusers"),
|
|
||||||
(None, "b9588f02e78f5ccafc9d7c0294e46308", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
|
|
||||||
(None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
|
|
||||||
(None, "68beaf8429b7c11aa8ca05b1bd0058bd", ["stepvideo_vae"], [StepVideoVAE], "civitai"),
|
|
||||||
(None, "5c0216a2132b082c10cb7a0e0377e681", ["stepvideo_dit"], [StepVideoModel], "civitai"),
|
|
||||||
(None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
|
|
||||||
(None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
|
|
||||||
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
|
|
||||||
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
|
|
||||||
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
|
|
||||||
(None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
|
|
||||||
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
|
||||||
(None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
|
||||||
]
|
|
||||||
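Illustrative note (not part of the diff): the tuples above are matched against a loaded checkpoint by hashing its state-dict keys, optionally together with tensor shapes. A minimal sketch, assuming an MD5 over the sorted key names; the project's actual hashing helper may differ in detail:

import hashlib

def hash_state_dict_keys(keys):
    # Hypothetical helper: MD5 over the sorted, comma-joined key names.
    return hashlib.md5(",".join(sorted(keys)).encode("utf-8")).hexdigest()

def detect_model_config(state_dict, configs=model_loader_configs):
    keys_hash = hash_state_dict_keys(state_dict.keys())
    keys_hash_with_shape = hash_state_dict_keys(
        f"{name}:{tuple(tensor.shape)}" for name, tensor in state_dict.items()
    )
    for cfg_hash, cfg_hash_with_shape, model_names, model_classes, resource in configs:
        if keys_hash_with_shape == cfg_hash_with_shape or (cfg_hash is not None and keys_hash == cfg_hash):
            return model_names, model_classes, resource
    return None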
huggingface_model_loader_configs = [
    # These configs are provided for detecting model type automatically.
    # The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
    ("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
    ("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
    ("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
    ("Qwen2ForCausalLM", "transformers.models.qwen2.modeling_qwen2", "qwen_prompt", None),
    # ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
    ("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
    ("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
    ("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
    ("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
    ("LlavaForConditionalGeneration", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoMLLMEncoder"),
    ("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
]
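Illustrative note (not part of the diff): each tuple above maps an "architectures" entry from a Hugging Face config.json to the module that defines the class, with an optional redirected class name. A sketch of how such an entry might be resolved (the loading flow shown here is an assumption):

import importlib
import json

def resolve_huggingface_model(config_path, configs=huggingface_model_loader_configs):
    with open(config_path) as f:
        architecture = json.load(f)["architectures"][0]
    for arch, lib, model_name, redirected_arch in configs:
        if arch == architecture:
            class_name = redirected_arch if redirected_arch is not None else arch
            return model_name, getattr(importlib.import_module(lib), class_name)
    raise ValueError(f"No loader config for architecture {architecture!r}")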
patch_model_loader_configs = [
    # These configs are provided for detecting model type automatically.
    # The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
    ("9a4ab6869ac9b7d6e31f9854e397c867", ["svd_unet"], [SVDUNet], {"add_positional_conv": 128}),
]
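Illustrative note (not part of the diff): a patch entry adds constructor kwargs on top of the keys-with-shape hash match, e.g. the single entry above builds SVDUNet with add_positional_conv=128. A minimal sketch, reusing the hypothetical hash_state_dict_keys helper from the note after model_loader_configs:

def detect_patched_model(state_dict, configs=patch_model_loader_configs):
    keys_hash_with_shape = hash_state_dict_keys(
        f"{name}:{tuple(tensor.shape)}" for name, tensor in state_dict.items()
    )
    for cfg_hash_with_shape, model_names, model_classes, extra_kwargs in configs:
        if keys_hash_with_shape == cfg_hash_with_shape:
            return model_names, [cls(**extra_kwargs) for cls in model_classes]
    return None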
preset_models_on_huggingface = {
    "HunyuanDiT": [
        ("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
        ("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
        ("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
        ("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
    ],
    "stable-video-diffusion-img2vid-xt": [
        ("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
    ],
    "ExVideo-SVD-128f-v1": [
        ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
    ],
    # Stable Diffusion
    "StableDiffusion_v15": [
        ("benjamin-paine/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
    ],
    "DreamShaper_8": [
        ("Yntec/Dreamshaper8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
    ],
    # Textual Inversion
    "TextualInversion_VeryBadImageNegative_v1.3": [
        ("gemasai/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
    ],
    # Stable Diffusion XL
    "StableDiffusionXL_v1": [
        ("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
    ],
    "BluePencilXL_v200": [
        ("frankjoshua/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
    ],
    "StableDiffusionXL_Turbo": [
        ("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
    ],
    # Stable Diffusion 3
    "StableDiffusion3": [
        ("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
    ],
    "StableDiffusion3_without_T5": [
        ("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
    ],
    # ControlNet
    "ControlNet_v11f1p_sd15_depth": [
        ("lllyasviel/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
        ("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
    ],
    "ControlNet_v11p_sd15_softedge": [
        ("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
        ("lllyasviel/Annotators", "ControlNetHED.pth", "models/Annotators")
    ],
    "ControlNet_v11f1e_sd15_tile": [
        ("lllyasviel/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
    ],
    "ControlNet_v11p_sd15_lineart": [
        ("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
        ("lllyasviel/Annotators", "sk_model.pth", "models/Annotators"),
        ("lllyasviel/Annotators", "sk_model2.pth", "models/Annotators")
    ],
    "ControlNet_union_sdxl_promax": [
        ("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
        ("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
    ],
    # AnimateDiff
    "AnimateDiff_v2": [
        ("guoyww/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
    ],
    "AnimateDiff_xl_beta": [
        ("guoyww/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
    ],

    # Qwen Prompt
    "QwenPrompt": [
        ("Qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
        ("Qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
        ("Qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
        ("Qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
        ("Qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
        ("Qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
        ("Qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
        ("Qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
    ],
    # Beautiful Prompt
    "BeautifulPrompt": [
        ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
    ],
    # Omost prompt
    "OmostPrompt":[
        ("lllyasviel/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
        ("lllyasviel/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
        ("lllyasviel/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
        ("lllyasviel/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
        ("lllyasviel/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
        ("lllyasviel/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
        ("lllyasviel/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
        ("lllyasviel/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
    ],
    # Translator
    "opus-mt-zh-en": [
        ("Helsinki-NLP/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
        ("Helsinki-NLP/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
        ("Helsinki-NLP/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
        ("Helsinki-NLP/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
        ("Helsinki-NLP/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
        ("Helsinki-NLP/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
        ("Helsinki-NLP/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
        ("Helsinki-NLP/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
    ],
    # IP-Adapter
    "IP-Adapter-SD": [
        ("h94/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
        ("h94/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
    ],
    "IP-Adapter-SDXL": [
        ("h94/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
        ("h94/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
    ],
    "SDXL-vae-fp16-fix": [
        ("madebyollin/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
    ],
    # Kolors
    "Kolors": [
        ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
        ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
    ],
    # FLUX
    "FLUX.1-dev": [
        ("black-forest-labs/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
        ("black-forest-labs/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
        ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
        ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
        ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
        ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
        ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
    ],
    "InstantX/FLUX.1-dev-IP-Adapter": {
        "file_list": [
            ("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
            ("google/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
            ("google/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
        ],
        "load_path": [
            "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
            "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
        ],
    },
    # RIFE
    "RIFE": [
        ("AlexWortega/RIFE", "flownet.pkl", "models/RIFE"),
    ],
    # CogVideo
    "CogVideoX-5B": [
        ("THUDM/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
        ("THUDM/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
        ("THUDM/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
        ("THUDM/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
        ("THUDM/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
        ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
        ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
        ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
        ("THUDM/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
    ],
    # Stable Diffusion 3.5
    "StableDiffusion3.5-large": [
        ("stabilityai/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
        ("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
        ("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
        ("stabilityai/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
    ],
}
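Illustrative note (not part of the diff): a preset value is either a plain list of (repo_id, path_in_repo, local_dir) tuples or a dict that also carries the "load_path" entries handed to the model loader afterwards. A small sketch of how a downloader might normalize both shapes (helper names are hypothetical):

def iter_preset_files(preset_id, presets=preset_models_on_huggingface):
    entry = presets[preset_id]
    file_list = entry["file_list"] if isinstance(entry, dict) else entry
    for repo_id, path_in_repo, local_dir in file_list:
        yield repo_id, path_in_repo, local_dir

def preset_load_path(preset_id, presets=preset_models_on_huggingface):
    entry = presets[preset_id]
    # Plain list presets carry no explicit load_path.
    return entry.get("load_path") if isinstance(entry, dict) else None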
preset_models_on_modelscope = {
    # Hunyuan DiT
    "HunyuanDiT": [
        ("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
        ("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
        ("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
        ("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
    ],
    # Stable Video Diffusion
    "stable-video-diffusion-img2vid-xt": [
        ("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
    ],
    # ExVideo
    "ExVideo-SVD-128f-v1": [
        ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
    ],
    "ExVideo-CogVideoX-LoRA-129f-v1": [
        ("ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1", "ExVideo-CogVideoX-LoRA-129f-v1.safetensors", "models/lora"),
    ],
    # Stable Diffusion
    "StableDiffusion_v15": [
        ("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
    ],
    "DreamShaper_8": [
        ("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
    ],
    "AingDiffusion_v12": [
        ("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
    ],
    "Flat2DAnimerge_v45Sharp": [
        ("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
    ],
    # Textual Inversion
    "TextualInversion_VeryBadImageNegative_v1.3": [
        ("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
    ],
    # Stable Diffusion XL
    "StableDiffusionXL_v1": [
        ("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
    ],
    "BluePencilXL_v200": [
        ("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
    ],
    "StableDiffusionXL_Turbo": [
        ("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
    ],
    "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0": [
        ("sd_lora/zyd232_ChineseInkStyle_SDXL_v1_0", "zyd232_ChineseInkStyle_SDXL_v1_0.safetensors", "models/lora"),
    ],
    # Stable Diffusion 3
    "StableDiffusion3": [
        ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
    ],
    "StableDiffusion3_without_T5": [
        ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
    ],
    # ControlNet
    "ControlNet_v11f1p_sd15_depth": [
        ("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
        ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
    ],
    "ControlNet_v11p_sd15_softedge": [
        ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
        ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
    ],
    "ControlNet_v11f1e_sd15_tile": [
        ("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
    ],
    "ControlNet_v11p_sd15_lineart": [
        ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
        ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
        ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
    ],
    "ControlNet_union_sdxl_promax": [
        ("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
        ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
    ],
    "Annotators:Depth": [
        ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
    ],
    "Annotators:Softedge": [
        ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators"),
    ],
    "Annotators:Lineart": [
        ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
        ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators"),
    ],
    "Annotators:Normal": [
        ("sd_lora/Annotators", "scannet.pt", "models/Annotators"),
    ],
    "Annotators:Openpose": [
        ("sd_lora/Annotators", "body_pose_model.pth", "models/Annotators"),
        ("sd_lora/Annotators", "facenet.pth", "models/Annotators"),
        ("sd_lora/Annotators", "hand_pose_model.pth", "models/Annotators"),
    ],
    # AnimateDiff
    "AnimateDiff_v2": [
        ("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
    ],
    "AnimateDiff_xl_beta": [
        ("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
    ],
    # RIFE
    "RIFE": [
        ("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
    ],
    # Qwen Prompt
    "QwenPrompt": {
        "file_list": [
            ("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
            ("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
            ("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
            ("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
            ("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
            ("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
            ("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
            ("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
        ],
        "load_path": [
            "models/QwenPrompt/qwen2-1.5b-instruct",
        ],
    },
    # Beautiful Prompt
    "BeautifulPrompt": {
        "file_list": [
            ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
            ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
            ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
            ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
            ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
            ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ],
        "load_path": [
            "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
        ],
    },
    # Omost prompt
    "OmostPrompt": {
        "file_list": [
            ("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
            ("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
            ("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
            ("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
            ("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
            ("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
            ("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
            ("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
        ],
        "load_path": [
            "models/OmostPrompt/omost-llama-3-8b-4bits",
        ],
    },
    # Translator
    "opus-mt-zh-en": {
        "file_list": [
            ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
            ("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
            ("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
            ("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
            ("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
            ("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
            ("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
            ("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
        ],
        "load_path": [
            "models/translator/opus-mt-zh-en",
        ],
    },
    # IP-Adapter
    "IP-Adapter-SD": [
        ("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
        ("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
    ],
    "IP-Adapter-SDXL": [
        ("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
        ("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
    ],
    # Kolors
    "Kolors": {
        "file_list": [
            ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
            ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
        ],
        "load_path": [
            "models/kolors/Kolors/text_encoder",
            "models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
            "models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",
        ],
    },
    "SDXL-vae-fp16-fix": [
        ("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
    ],
    # FLUX
    "FLUX.1-dev": {
        "file_list": [
            ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
            ("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
        ],
        "load_path": [
            "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
            "models/FLUX/FLUX.1-dev/text_encoder_2",
            "models/FLUX/FLUX.1-dev/ae.safetensors",
            "models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
        ],
    },
    "FLUX.1-schnell": {
        "file_list": [
            ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
            ("AI-ModelScope/FLUX.1-schnell", "flux1-schnell.safetensors", "models/FLUX/FLUX.1-schnell"),
        ],
        "load_path": [
            "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
            "models/FLUX/FLUX.1-dev/text_encoder_2",
            "models/FLUX/FLUX.1-dev/ae.safetensors",
            "models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors"
        ],
    },
    "InstantX/FLUX.1-dev-Controlnet-Union-alpha": [
        ("InstantX/FLUX.1-dev-Controlnet-Union-alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha"),
    ],
    "jasperai/Flux.1-dev-Controlnet-Depth": [
        ("jasperai/Flux.1-dev-Controlnet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Depth"),
    ],
    "jasperai/Flux.1-dev-Controlnet-Surface-Normals": [
        ("jasperai/Flux.1-dev-Controlnet-Surface-Normals", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals"),
    ],
    "jasperai/Flux.1-dev-Controlnet-Upscaler": [
        ("jasperai/Flux.1-dev-Controlnet-Upscaler", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler"),
    ],
    "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha": [
        ("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha"),
    ],
    "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta": [
        ("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"),
    ],
    "Shakker-Labs/FLUX.1-dev-ControlNet-Depth": [
        ("Shakker-Labs/FLUX.1-dev-ControlNet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Depth"),
    ],
    "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [
        ("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"),
    ],
    "InstantX/FLUX.1-dev-IP-Adapter": {
        "file_list": [
            ("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
            ("AI-ModelScope/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
            ("AI-ModelScope/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
        ],
        "load_path": [
            "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
            "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
        ],
    },
    "InfiniteYou":{
        "file_list":[
            ("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
            ("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
            ("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/image_proj_model.bin", "models/InfiniteYou"),
            ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/1k3d68.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
            ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/2d106det.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
            ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/genderage.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
            ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/glintr100.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
            ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/scrfd_10g_bnkps.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
        ],
        "load_path":[
            [
                "models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors",
                "models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors"
            ],
            "models/InfiniteYou/image_proj_model.bin",
        ],
    },
    # ESRGAN
    "ESRGAN_x4": [
        ("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
    ],
    # RIFE
    "RIFE": [
        ("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
    ],
    # Omnigen
    "OmniGen-v1": {
        "file_list": [
            ("BAAI/OmniGen-v1", "vae/diffusion_pytorch_model.safetensors", "models/OmniGen/OmniGen-v1/vae"),
            ("BAAI/OmniGen-v1", "model.safetensors", "models/OmniGen/OmniGen-v1"),
            ("BAAI/OmniGen-v1", "config.json", "models/OmniGen/OmniGen-v1"),
            ("BAAI/OmniGen-v1", "special_tokens_map.json", "models/OmniGen/OmniGen-v1"),
            ("BAAI/OmniGen-v1", "tokenizer_config.json", "models/OmniGen/OmniGen-v1"),
            ("BAAI/OmniGen-v1", "tokenizer.json", "models/OmniGen/OmniGen-v1"),
        ],
        "load_path": [
            "models/OmniGen/OmniGen-v1/vae/diffusion_pytorch_model.safetensors",
            "models/OmniGen/OmniGen-v1/model.safetensors",
        ]
    },
    # CogVideo
    "CogVideoX-5B": {
        "file_list": [
            ("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
            ("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
            ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
            ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
            ("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
            ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
            ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
            ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
            ("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
        ],
        "load_path": [
            "models/CogVideo/CogVideoX-5b/text_encoder",
            "models/CogVideo/CogVideoX-5b/transformer",
            "models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors",
        ],
    },
    # Stable Diffusion 3.5
    "StableDiffusion3.5-large": [
        ("AI-ModelScope/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
    ],
    "StableDiffusion3.5-medium": [
        ("AI-ModelScope/stable-diffusion-3.5-medium", "sd3.5_medium.safetensors", "models/stable_diffusion_3"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
    ],
    "StableDiffusion3.5-large-turbo": [
        ("AI-ModelScope/stable-diffusion-3.5-large-turbo", "sd3.5_large_turbo.safetensors", "models/stable_diffusion_3"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
    ],
    "HunyuanVideo":{
        "file_list": [
            ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
            ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
            ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideo/transformers")
        ],
        "load_path": [
            "models/HunyuanVideo/text_encoder/model.safetensors",
            "models/HunyuanVideo/text_encoder_2",
            "models/HunyuanVideo/vae/pytorch_model.pt",
            "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
        ],
    },
    "HunyuanVideoI2V":{
        "file_list": [
            ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideoI2V/text_encoder"),
            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00001-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00002-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00003-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00004-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "config.json", "models/HunyuanVideoI2V/text_encoder_2"),
            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model.safetensors.index.json", "models/HunyuanVideoI2V/text_encoder_2"),
            ("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/vae/pytorch_model.pt", "models/HunyuanVideoI2V/vae"),
            ("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideoI2V/transformers")
        ],
        "load_path": [
            "models/HunyuanVideoI2V/text_encoder/model.safetensors",
            "models/HunyuanVideoI2V/text_encoder_2",
            "models/HunyuanVideoI2V/vae/pytorch_model.pt",
            "models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
        ],
    },
    "HunyuanVideo-fp8":{
        "file_list": [
            ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
            ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
            ("DiffSynth-Studio/HunyuanVideo-safetensors", "model.fp8.safetensors", "models/HunyuanVideo/transformers")
        ],
        "load_path": [
            "models/HunyuanVideo/text_encoder/model.safetensors",
            "models/HunyuanVideo/text_encoder_2",
            "models/HunyuanVideo/vae/pytorch_model.pt",
            "models/HunyuanVideo/transformers/model.fp8.safetensors"
        ],
    },
}
Preset_model_id: TypeAlias = Literal[
    "HunyuanDiT",
    "stable-video-diffusion-img2vid-xt",
    "ExVideo-SVD-128f-v1",
    "ExVideo-CogVideoX-LoRA-129f-v1",
    "StableDiffusion_v15",
    "DreamShaper_8",
    "AingDiffusion_v12",
    "Flat2DAnimerge_v45Sharp",
    "TextualInversion_VeryBadImageNegative_v1.3",
    "StableDiffusionXL_v1",
    "BluePencilXL_v200",
    "StableDiffusionXL_Turbo",
    "ControlNet_v11f1p_sd15_depth",
    "ControlNet_v11p_sd15_softedge",
    "ControlNet_v11f1e_sd15_tile",
    "ControlNet_v11p_sd15_lineart",
    "AnimateDiff_v2",
    "AnimateDiff_xl_beta",
    "RIFE",
    "BeautifulPrompt",
    "opus-mt-zh-en",
    "IP-Adapter-SD",
    "IP-Adapter-SDXL",
    "StableDiffusion3",
    "StableDiffusion3_without_T5",
    "Kolors",
    "SDXL-vae-fp16-fix",
    "ControlNet_union_sdxl_promax",
    "FLUX.1-dev",
    "FLUX.1-schnell",
    "InstantX/FLUX.1-dev-Controlnet-Union-alpha",
    "jasperai/Flux.1-dev-Controlnet-Depth",
    "jasperai/Flux.1-dev-Controlnet-Surface-Normals",
    "jasperai/Flux.1-dev-Controlnet-Upscaler",
    "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
    "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
    "Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
    "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
    "InstantX/FLUX.1-dev-IP-Adapter",
    "InfiniteYou",
    "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
    "QwenPrompt",
    "OmostPrompt",
    "ESRGAN_x4",
    "RIFE",
    "OmniGen-v1",
    "CogVideoX-5B",
    "Annotators:Depth",
    "Annotators:Softedge",
    "Annotators:Lineart",
    "Annotators:Normal",
    "Annotators:Openpose",
    "StableDiffusion3.5-large",
    "StableDiffusion3.5-medium",
    "HunyuanVideo",
    "HunyuanVideo-fp8",
    "HunyuanVideoI2V",
]
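Illustrative note (not part of the diff): the Literal alias above lets download helpers restrict their argument to the known preset names. A hypothetical signature, reusing the iter_preset_files sketch from the earlier note:

def download_preset(preset_id: Preset_model_id, source: Literal["huggingface", "modelscope"] = "modelscope"):
    presets = preset_models_on_huggingface if source == "huggingface" else preset_models_on_modelscope
    return list(iter_preset_files(preset_id, presets))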
666 diffsynth/configs/model_configs.py Normal file
@@ -0,0 +1,666 @@
qwen_image_series = [
    {
        # Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors")
        "model_hash": "0319a1cb19835fb510907dd3367c95ff",
        "model_name": "qwen_image_dit",
        "model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT",
    },
    {
        # Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
        "model_hash": "8004730443f55db63092006dd9f7110e",
        "model_name": "qwen_image_text_encoder",
        "model_class": "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.qwen_image_text_encoder.QwenImageTextEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
        "model_hash": "ed4ea5824d55ec3107b09815e318123a",
        "model_name": "qwen_image_vae",
        "model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth", origin_file_pattern="model.safetensors")
        "model_hash": "073bce9cf969e317e5662cd570c3e79c",
        "model_name": "qwen_image_blockwise_controlnet",
        "model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors")
        "model_hash": "a9e54e480a628f0b956a688a81c33bab",
        "model_name": "qwen_image_blockwise_controlnet",
        "model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet",
        "extra_kwargs": {"additional_in_dim": 4},
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors")
        "model_hash": "469c78b61e3e31bc9eec0d0af3d3f2f8",
        "model_name": "siglip2_image_encoder",
        "model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors")
        "model_hash": "5722b5c873720009de96422993b15682",
        "model_name": "dinov3_image_encoder",
        "model_class": "diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder",
    },
    {
        # Example:
        "model_hash": "a166c33455cdbd89c0888a3645ca5c0f",
        "model_name": "qwen_image_image2lora_coarse",
        "model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
    },
    {
        # Example:
        "model_hash": "a5476e691767a4da6d3a6634a10f7408",
        "model_name": "qwen_image_image2lora_fine",
        "model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
        "extra_kwargs": {"residual_length": 37*37+7, "residual_mid_dim": 64}
    },
    {
        # Example:
        "model_hash": "0aad514690602ecaff932c701cb4b0bb",
        "model_name": "qwen_image_image2lora_style",
        "model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
        "extra_kwargs": {"compress_dim": 64, "use_residual": False}
    },
    {
        # Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
        "model_hash": "8dc8cda05de16c73afa755e2c1ce2839",
        "model_name": "qwen_image_dit",
        "model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT",
        "extra_kwargs": {"use_layer3d_rope": True, "use_additional_t_cond": True}
    },
    {
        # Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
        "model_hash": "44b39ddc499e027cfb24f7878d7416b9",
        "model_name": "qwen_image_vae",
        "model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE",
        "extra_kwargs": {"image_channels": 4}
    },
]

wan_series = [
    {
        # Example: ModelConfig(model_id="krea/krea-realtime-video", origin_file_pattern="krea-realtime-video-14b.safetensors")
        "model_hash": "5ec04e02b42d2580483ad69f4e76346a",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth")
        "model_hash": "9c8818c2cbea55eca56c7b447df170da",
        "model_name": "wan_video_text_encoder",
        "model_class": "diffsynth.models.wan_video_text_encoder.WanTextEncoder",
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth")
        "model_hash": "ccc42284ea13e1ad04693284c7a09be6",
        "model_name": "wan_video_vae",
        "model_class": "diffsynth.models.wan_video_vae.WanVideoVAE",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="meituan-longcat/LongCat-Video", origin_file_pattern="dit/diffusion_pytorch_model*.safetensors")
        "model_hash": "8b27900f680d7251ce44e2dc8ae1ffef",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel",
    },
    {
        # Example: ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
        "model_hash": "5f90e66a0672219f12d9a626c8c21f61",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTFromDiffusers"
    },
    {
        # Example: ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
        "model_hash": "5f90e66a0672219f12d9a626c8c21f61",
        "model_name": "wan_video_vap",
        "model_class": "diffsynth.models.wan_video_mot.MotWanModel",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_mot.WanVideoMotStateDictConverter"
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth")
        "model_hash": "5941c53e207d62f20f9025686193c40b",
        "model_name": "wan_video_image_encoder",
        "model_class": "diffsynth.models.wan_video_image_encoder.WanImageEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_image_encoder.WanImageEncoderStateDictConverter"
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", origin_file_pattern="model.safetensors")
        "model_hash": "dbd5ec76bbf977983f972c151d545389",
        "model_name": "wan_video_motion_controller",
        "model_class": "diffsynth.models.wan_video_motion_controller.WanMotionControllerModel",
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "9269f8db9040a9d860eaca435be61814",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "3ef3b1f8e1dab83d5b71fd7b617f859f",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_image_pos_emb': True}
    },
    {
        # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "349723183fc063b2bfc10bb2835cf677",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
    },
    {
        # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "6d6ccde6845b95ad9114ab993d917893",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
    },
    {
        # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "efa44cddf936c70abd0ea28b6cbe946c",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
    },
    {
        # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "6bfcfb3b342cb286ce886889d519a77e",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
    },
    {
        # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "ac6a5aa74f4a0aab6f64eb9a72f19901",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24}
    },
    {
        # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "70ddad9d3a133785da5ea371aae09504",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': True}
    },
    {
        # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "b61c605c2adbd23124d152ed28e049ae",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24}
    },
    {
        # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "26bde73488a92e64cc20b0a7485b9e5b",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True}
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "aafcfd9672c3a2456dc46e1cb6e52c70",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
    },
    {
        # Example: ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "a61453409b67cd3246cf0c3bebad47ba",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "a61453409b67cd3246cf0c3bebad47ba",
        "model_name": "wan_video_vace",
        "model_class": "diffsynth.models.wan_video_vace.VaceWanModel",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter"
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "7a513e1f257a861512b1afd387a8ecd9",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "7a513e1f257a861512b1afd387a8ecd9",
        "model_name": "wan_video_vace",
        "model_class": "diffsynth.models.wan_video_vace.VaceWanModel",
        "extra_kwargs": {'vace_layers': (0, 5, 10, 15, 20, 25, 30, 35), 'vace_in_dim': 96, 'patch_size': (1, 2, 2), 'has_image_input': False, 'dim': 5120, 'num_heads': 40, 'ffn_dim': 13824, 'eps': 1e-06},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter"
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "31fa352acb8a1b1d33cd8764273d80a2",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter"
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "31fa352acb8a1b1d33cd8764273d80a2",
        "model_name": "wan_video_animate_adapter",
        "model_class": "diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_animate_adapter.WanAnimateAdapterStateDictConverter"
    },
    {
        # Example: ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
        "model_hash": "47dbeab5e560db3180adf51dc0232fb1",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24, 'require_clip_embedding': False}
    },
    {
        # Example: ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
        "model_hash": "2267d489f0ceb9f21836532952852ee5",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 52, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True, 'require_clip_embedding': False},
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
        "model_hash": "5b013604280dd715f8457c6ed6d6a626",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'require_clip_embedding': False}
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "966cffdcc52f9c46c391768b27637614",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit_s2v.WanS2VModel",
        "extra_kwargs": {'dim': 5120, 'in_dim': 16, 'ffn_dim': 13824, 'out_dim': 16, 'text_dim': 4096, 'freq_dim': 256, 'eps': 1e-06, 'patch_size': (1, 2, 2), 'num_heads': 40, 'num_layers': 40, 'cond_dim': 16, 'audio_dim': 1024, 'num_audio_token': 4}
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "1f5ab7703c6fc803fdded85ff040c316",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 3072, 'ffn_dim': 14336, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 48, 'num_heads': 24, 'num_layers': 30, 'eps': 1e-06, 'seperated_timestep': True, 'require_clip_embedding': False, 'require_vae_embedding': False, 'fuse_vae_embedding_in_latents': True}
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth")
        "model_hash": "e1de6c02cdac79f8b739f4d3698cd216",
        "model_name": "wan_video_vae",
        "model_class": "diffsynth.models.wan_video_vae.WanVideoVAE38",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors")
        "model_hash": "06be60f3a4526586d8431cd038a71486",
        "model_name": "wans2v_audio_encoder",
        "model_class": "diffsynth.models.wav2vec.WanS2VAudioEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wans2v_audio_encoder.WanS2VAudioEncoderStateDictConverter",
    },
]

flux_series = [
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors")
        "model_hash": "a29710fea6dddb0314663ee823598e50",
        "model_name": "flux_dit",
        "model_class": "diffsynth.models.flux_dit.FluxDiT",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
    },
    {
        # Supported due to historical reasons.
        "model_hash": "605c56eab23e9e2af863ad8f0813a25d",
        "model_name": "flux_dit",
        "model_class": "diffsynth.models.flux_dit.FluxDiT",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverterFromDiffusers",
    },
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors")
        "model_hash": "94eefa3dac9cec93cb1ebaf1747d7b78",
        "model_name": "flux_text_encoder_clip",
        "model_class": "diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_clip.FluxTextEncoderClipStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors")
        "model_hash": "22540b49eaedbc2f2784b2091a234c7c",
        "model_name": "flux_text_encoder_t5",
        "model_class": "diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_t5.FluxTextEncoderT5StateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors")
        "model_hash": "21ea55f476dfc4fd135587abb59dfe5d",
        "model_name": "flux_vae_encoder",
        "model_class": "diffsynth.models.flux_vae.FluxVAEEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors")
        "model_hash": "21ea55f476dfc4fd135587abb59dfe5d",
        "model_name": "flux_vae_decoder",
        "model_class": "diffsynth.models.flux_vae.FluxVAEDecoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors")
        "model_hash": "d02f41c13549fa5093d3521f62a5570a",
        "model_name": "flux_dit",
        "model_class": "diffsynth.models.flux_dit.FluxDiT",
        "extra_kwargs": {'input_dim': 196, 'num_blocks': 8},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/AttriCtrl-FLUX.1-Dev", origin_file_pattern="models/brightness.safetensors")
        "model_hash": "0629116fce1472503a66992f96f3eb1a",
        "model_name": "flux_value_controller",
        "model_class": "diffsynth.models.flux_value_control.SingleValueEncoder",
    },
    {
        # Example: ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors")
        "model_hash": "52357cb26250681367488a8954c271e8",
        "model_name": "flux_controlnet",
        "model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
        "extra_kwargs": {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4},
    },
    {
        # Example: ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors")
        "model_hash": "78d18b9101345ff695f312e7e62538c0",
        "model_name": "flux_controlnet",
        "model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
        "extra_kwargs": {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}},
    },
    {
        # Example: ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors")
        "model_hash": "b001c89139b5f053c715fe772362dd2a",
        "model_name": "flux_controlnet",
        "model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
        "extra_kwargs": {"num_single_blocks": 0},
    },
    {
        # Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin")
        "model_hash": "c07c0f04f5ff55e86b4e937c7a40d481",
        "model_name": "infiniteyou_image_projector",
        "model_class": "diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_infiniteyou.FluxInfiniteYouImageProjectorStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors")
        "model_hash": "7f9583eb8ba86642abb9a21a4b2c9e16",
        "model_name": "flux_controlnet",
        "model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
        "extra_kwargs": {"num_joint_blocks": 4, "num_single_blocks": 10},
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors")
        "model_hash": "77c2e4dd2440269eb33bfaa0d004f6ab",
        "model_name": "flux_lora_encoder",
        "model_class": "diffsynth.models.flux_lora_encoder.FluxLoRAEncoder",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev", origin_file_pattern="model.safetensors")
        "model_hash": "30143afb2dea73d1ac580e0787628f8c",
        "model_name": "flux_lora_patcher",
        "model_class": "diffsynth.models.flux_lora_patcher.FluxLoraPatcher",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors")
        "model_hash": "2bd19e845116e4f875a0a048e27fc219",
        "model_name": "nexus_gen_llm",
        "model_class": "diffsynth.models.nexus_gen.NexusGenAutoregressiveModel",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen.NexusGenAutoregressiveModelStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin")
        "model_hash": "63c969fd37cce769a90aa781fbff5f81",
        "model_name": "nexus_gen_editing_adapter",
        "model_class": "diffsynth.models.nexus_gen_projector.NexusGenImageEmbeddingMerger",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenMergerStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin")
        "model_hash": "63c969fd37cce769a90aa781fbff5f81",
        "model_name": "flux_dit",
        "model_class": "diffsynth.models.flux_dit.FluxDiT",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin")
        "model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d",
        "model_name": "nexus_gen_generation_adapter",
        "model_class": "diffsynth.models.nexus_gen_projector.NexusGenAdapter",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenAdapterStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin")
        "model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d",
        "model_name": "flux_dit",
        "model_class": "diffsynth.models.flux_dit.FluxDiT",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin")
        "model_hash": "4daaa66cc656a8fe369908693dad0a35",
        "model_name": "flux_ipadapter",
        "model_class": "diffsynth.models.flux_ipadapter.FluxIpAdapter",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.FluxIpAdapterStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="google/siglip-so400m-patch14-384", origin_file_pattern="model.safetensors")
        "model_hash": "04d8c1e20a1f1b25f7434f111992a33f",
        "model_name": "siglip_vision_model",
        "model_class": "diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.SiglipStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
        "model_hash": "d30fb9e02b1dbf4e509142f05cf7dd50",
        "model_name": "step1x_connector",
        "model_class": "diffsynth.models.step1x_connector.Qwen2Connector",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.step1x_connector.Qwen2ConnectorStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
        "model_hash": "d30fb9e02b1dbf4e509142f05cf7dd50",
        "model_name": "flux_dit",
        "model_class": "diffsynth.models.flux_dit.FluxDiT",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
        "extra_kwargs": {"disable_guidance_embedder": True},
    },
    {
        # Example: ModelConfig(model_id="MAILAND/majicflus_v1", origin_file_pattern="majicflus_v134.safetensors")
        "model_hash": "3394f306c4cbf04334b712bf5aaed95f",
        "model_name": "flux_dit",
        "model_class": "diffsynth.models.flux_dit.FluxDiT",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
    },
]

flux2_series = [
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors")
        "model_hash": "28fca3d8e5bf2a2d1271748a773f6757",
        "model_name": "flux2_text_encoder",
        "model_class": "diffsynth.models.flux2_text_encoder.Flux2TextEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux2_text_encoder.Flux2TextEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors")
        "model_hash": "d38e1d5c5aec3b0a11e79327ac6e3b0f",
        "model_name": "flux2_dit",
        "model_class": "diffsynth.models.flux2_dit.Flux2DiT",
    },
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
        "model_hash": "c54288e3ee12ca215898840682337b95",
        "model_name": "flux2_vae",
        "model_class": "diffsynth.models.flux2_vae.Flux2VAE",
    },
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors")
        "model_hash": "3bde7b817fec8143028b6825a63180df",
        "model_name": "flux2_dit",
        "model_class": "diffsynth.models.flux2_dit.Flux2DiT",
        "extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 7680, "num_attention_heads": 24, "num_layers": 5, "num_single_layers": 20}
    },
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors")
        "model_hash": "9195f3ea256fcd0ae6d929c203470754",
        "model_name": "z_image_text_encoder",
        "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
        "extra_kwargs": {"model_size": "8B"},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors")
        "model_hash": "39c6fc48f07bebecedbbaa971ff466c8",
        "model_name": "flux2_dit",
        "model_class": "diffsynth.models.flux2_dit.Flux2DiT",
        "extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 12288, "num_attention_heads": 32, "num_layers": 8, "num_single_layers": 24}
    },
]

z_image_series = [
    {
        # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors")
        "model_hash": "fc3a8a1247fe185ce116ccbe0e426c28",
        "model_name": "z_image_dit",
        "model_class": "diffsynth.models.z_image_dit.ZImageDiT",
    },
    {
        # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors")
        "model_hash": "0f050f62a88876fea6eae0a18dac5a2e",
        "model_name": "z_image_text_encoder",
        "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
    },
    {
        # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors")
        "model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3",
        "model_name": "flux_vae_encoder",
        "model_class": "diffsynth.models.flux_vae.FluxVAEEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverterDiffusers",
        "extra_kwargs": {"use_conv_attention": False},
    },
    {
        # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors")
        "model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3",
        "model_name": "flux_vae_decoder",
        "model_class": "diffsynth.models.flux_vae.FluxVAEDecoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers",
        "extra_kwargs": {"use_conv_attention": False},
    },
    {
        # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors")
        "model_hash": "aa3563718e5c3ecde3dfbb020ca61180",
        "model_name": "z_image_dit",
        "model_class": "diffsynth.models.z_image_dit.ZImageDiT",
        "extra_kwargs": {"siglip_feat_dim": 1152},
    },
    {
        # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors")
        "model_hash": "89d48e420f45cff95115a9f3e698d44a",
        "model_name": "siglip_vision_model_428m",
        "model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M",
    },
    {
        # Example: ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors")
        "model_hash": "1677708d40029ab380a95f6c731a57d7",
        "model_name": "z_image_controlnet",
        "model_class": "diffsynth.models.z_image_controlnet.ZImageControlNet",
    },
    {
        # Example: ???
        "model_hash": "9510cb8cd1dd34ee0e4f111c24905510",
        "model_name": "z_image_image2lora_style",
        "model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel",
        "extra_kwargs": {"compress_dim": 128},
    },
    {
        # Example: ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors")
        "model_hash": "1392adecee344136041e70553f875f31",
        "model_name": "z_image_text_encoder",
        "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
        "extra_kwargs": {"model_size": "0.6B"},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
    },
]

ltx2_series = [
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
        "model_hash": "aca7b0bbf8415e9c98360750268915fc",
        "model_name": "ltx2_dit",
        "model_class": "diffsynth.models.ltx2_dit.LTXModel",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
        "model_hash": "aca7b0bbf8415e9c98360750268915fc",
        "model_name": "ltx2_video_vae_encoder",
        "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
        "model_hash": "aca7b0bbf8415e9c98360750268915fc",
        "model_name": "ltx2_video_vae_decoder",
        "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
        "model_hash": "aca7b0bbf8415e9c98360750268915fc",
        "model_name": "ltx2_audio_vae_decoder",
        "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
        "model_hash": "aca7b0bbf8415e9c98360750268915fc",
        "model_name": "ltx2_audio_vocoder",
        "model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
    },
    # { # not used currently
    #     # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
    #     "model_hash": "aca7b0bbf8415e9c98360750268915fc",
    #     "model_name": "ltx2_audio_vae_encoder",
    #     "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
    #     "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
    # },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
        "model_hash": "aca7b0bbf8415e9c98360750268915fc",
        "model_name": "ltx2_text_encoder_post_modules",
        "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors")
        "model_hash": "33917f31c4a79196171154cca39f165e",
        "model_name": "ltx2_text_encoder",
        "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
        "model_hash": "c79c458c6e99e0e14d47e676761732d2",
        "model_name": "ltx2_latent_upsampler",
        "model_class": "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler",
    },
]

MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series
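# Each entry above pairs a weight-file fingerprint ("model_hash") with the model
# it should be loaded as: "model_name" is the role inside the pipeline,
# "model_class" is the implementing class, optional "extra_kwargs" are passed to
# that class's constructor, and an optional "state_dict_converter" rewrites the
# checkpoint's keys before loading.
#
# A minimal sketch (illustrative only, not part of this file; the actual loader
# may differ) of how one entry could be turned into a model instance, assuming
# the dotted paths resolve via importlib:
#
#     import importlib
#
#     def build_model(entry: dict):
#         module_name, class_name = entry["model_class"].rsplit(".", 1)
#         model_cls = getattr(importlib.import_module(module_name), class_name)
#         return model_cls(**entry.get("extra_kwargs", {}))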

246  diffsynth/configs/vram_management_module_maps.py  Normal file
@@ -0,0 +1,246 @@
flux_general_vram_config = {
    "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
    "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
    "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    "diffsynth.models.general_modules.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    "diffsynth.models.flux_lora_encoder.LoRALayerBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
    "diffsynth.models.flux_lora_patcher.LoraMerger": "diffsynth.core.vram.layers.AutoWrappedModule",
}
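
# flux_general_vram_config above is shared by several FLUX-family models below.
# VRAM_MANAGEMENT_MODULE_MAPS maps each supported model class to the layer types
# that get swapped for the AutoWrapped* variants (AutoWrappedLinear /
# AutoWrappedModule / AutoWrappedNonRecurseModule), which presumably lets the
# VRAM manager offload those layers between devices on demand.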

VRAM_MANAGEMENT_MODULE_MAPS = {
    "diffsynth.models.qwen_image_dit.QwenImageDiT": {
        "diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionPatchEmbed": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.qwen_image_vae.QwenImageVAE": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.qwen_image_vae.QwenImageRMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.qwen_image_controlnet.BlockWiseControlBlock": {
        "diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    },
    "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder": {
        "transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    },
    "diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder": {
        "transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTLayerScale": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTRopePositionEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    },
    "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    },
    "diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter": {
        "diffsynth.models.wan_video_animate_adapter.FaceEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_animate_adapter.EqualLinear": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_animate_adapter.ConvLayer": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_animate_adapter.FusedLeakyReLU": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_animate_adapter.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_dit_s2v.WanS2VModel": {
        "diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_dit_s2v.WanS2VDiTBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_dit_s2v.CausalAudioEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_dit.WanModel": {
        "diffsynth.models.wan_video_dit.MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule",
        "diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_image_encoder.WanImageEncoder": {
        "diffsynth.models.wan_video_image_encoder.VisionTransformer": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_mot.MotWanModel": {
        "diffsynth.models.wan_video_mot.MotWanAttentionBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_motion_controller.WanMotionControllerModel": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    },
    "diffsynth.models.wan_video_text_encoder.WanTextEncoder": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_text_encoder.T5RelativeEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_text_encoder.T5LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_vace.VaceWanModel": {
        "diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_vae.WanVideoVAE": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_vae.RMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_vae.CausalConv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_vae.Upsample": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_vae.WanVideoVAE38": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_vae.RMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_vae.CausalConv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_vae.Upsample": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wav2vec.WanS2VAudioEncoder": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.longcat_video_dit.RMSNorm_FP32": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.longcat_video_dit.LayerNorm_FP32": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.flux_dit.FluxDiT": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "diffsynth.models.flux_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip": flux_general_vram_config,
    "diffsynth.models.flux_vae.FluxVAEEncoder": flux_general_vram_config,
    "diffsynth.models.flux_vae.FluxVAEDecoder": flux_general_vram_config,
    "diffsynth.models.flux_controlnet.FluxControlNet": flux_general_vram_config,
    "diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector": flux_general_vram_config,
    "diffsynth.models.flux_ipadapter.FluxIpAdapter": flux_general_vram_config,
    "diffsynth.models.flux_lora_patcher.FluxLoraPatcher": flux_general_vram_config,
    "diffsynth.models.step1x_connector.Qwen2Connector": flux_general_vram_config,
    "diffsynth.models.flux_lora_encoder.FluxLoRAEncoder": flux_general_vram_config,
    "diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.t5.modeling_t5.T5LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.t5.modeling_t5.T5DenseActDense": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.t5.modeling_t5.T5DenseGatedActDense": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M": {
        "transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.siglip.modeling_siglip.SiglipEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.MultiheadAttention": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.flux2_dit.Flux2DiT": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.flux2_text_encoder.Flux2TextEncoder": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.mistral.modeling_mistral.MistralRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.flux2_vae.Flux2VAE": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.z_image_text_encoder.ZImageTextEncoder": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.z_image_dit.ZImageDiT": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.z_image_controlnet.ZImageControlNet": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    },
    "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M": {
        "transformers.models.siglip2.modeling_siglip2.Siglip2VisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.siglip2.modeling_siglip2.Siglip2MultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    },
    "diffsynth.models.ltx2_dit.LTXModel": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler": {
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder": {
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder": {
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder": {
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.ltx2_audio_vae.LTX2Vocoder": {
        "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.ltx2_text_encoder.Embeddings1DConnector": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.ltx2_text_encoder.LTX2TextEncoder": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "transformers.models.gemma3.modeling_gemma3.Gemma3MultiModalProjector": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
}

@@ -1,2 +0,0 @@
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager, FluxMultiControlNetManager
from .processors import Annotator
@@ -1,91 +0,0 @@
|
|||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from .processors import Processor_id
|
|
||||||
|
|
||||||
|
|
||||||
class ControlNetConfigUnit:
|
|
||||||
def __init__(self, processor_id: Processor_id, model_path, scale=1.0, skip_processor=False):
|
|
||||||
self.processor_id = processor_id
|
|
||||||
self.model_path = model_path
|
|
||||||
self.scale = scale
|
|
||||||
self.skip_processor = skip_processor
|
|
||||||
|
|
||||||
|
|
||||||
class ControlNetUnit:
|
|
||||||
def __init__(self, processor, model, scale=1.0):
|
|
||||||
self.processor = processor
|
|
||||||
self.model = model
|
|
||||||
self.scale = scale
|
|
||||||
|
|
||||||
|
|
||||||
class MultiControlNetManager:
|
|
||||||
def __init__(self, controlnet_units=[]):
|
|
||||||
self.processors = [unit.processor for unit in controlnet_units]
|
|
||||||
self.models = [unit.model for unit in controlnet_units]
|
|
||||||
self.scales = [unit.scale for unit in controlnet_units]
|
|
||||||
|
|
||||||
def cpu(self):
|
|
||||||
for model in self.models:
|
|
||||||
model.cpu()
|
|
||||||
|
|
||||||
def to(self, device):
|
|
||||||
for model in self.models:
|
|
||||||
model.to(device)
|
|
||||||
for processor in self.processors:
|
|
||||||
processor.to(device)
|
|
||||||
|
|
||||||
def process_image(self, image, processor_id=None):
|
|
||||||
if processor_id is None:
|
|
||||||
processed_image = [processor(image) for processor in self.processors]
|
|
||||||
else:
|
|
||||||
processed_image = [self.processors[processor_id](image)]
|
|
||||||
processed_image = torch.concat([
|
|
||||||
torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
|
|
||||||
for image_ in processed_image
|
|
||||||
], dim=0)
|
|
||||||
return processed_image
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
sample, timestep, encoder_hidden_states, conditionings,
|
|
||||||
tiled=False, tile_size=64, tile_stride=32, **kwargs
|
|
||||||
):
|
|
||||||
res_stack = None
|
|
||||||
for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
|
|
||||||
res_stack_ = model(
|
|
||||||
sample, timestep, encoder_hidden_states, conditioning, **kwargs,
|
|
||||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
|
|
||||||
processor_id=processor.processor_id
|
|
||||||
)
|
|
||||||
res_stack_ = [res * scale for res in res_stack_]
|
|
||||||
if res_stack is None:
|
|
||||||
res_stack = res_stack_
|
|
||||||
else:
|
|
||||||
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
|
|
||||||
return res_stack
|
|
||||||
|
|
||||||
|
|
||||||
class FluxMultiControlNetManager(MultiControlNetManager):
|
|
||||||
def __init__(self, controlnet_units=[]):
|
|
||||||
super().__init__(controlnet_units=controlnet_units)
|
|
||||||
|
|
||||||
def process_image(self, image, processor_id=None):
|
|
||||||
if processor_id is None:
|
|
||||||
processed_image = [processor(image) for processor in self.processors]
|
|
||||||
else:
|
|
||||||
processed_image = [self.processors[processor_id](image)]
|
|
||||||
return processed_image
|
|
||||||
|
|
||||||
def __call__(self, conditionings, **kwargs):
|
|
||||||
res_stack, single_res_stack = None, None
|
|
||||||
for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
|
|
||||||
res_stack_, single_res_stack_ = model(controlnet_conditioning=conditioning, processor_id=processor.processor_id, **kwargs)
|
|
||||||
res_stack_ = [res * scale for res in res_stack_]
|
|
||||||
single_res_stack_ = [res * scale for res in single_res_stack_]
|
|
||||||
if res_stack is None:
|
|
||||||
res_stack = res_stack_
|
|
||||||
single_res_stack = single_res_stack_
|
|
||||||
else:
|
|
||||||
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
|
|
||||||
single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)]
|
|
||||||
return res_stack, single_res_stack
|
|
||||||
6
diffsynth/core/__init__.py
Normal file
@@ -0,0 +1,6 @@
from .attention import *
from .data import *
from .gradient import *
from .loader import *
from .vram import *
from .device import *
1
diffsynth/core/attention/__init__.py
Normal file
@@ -0,0 +1 @@
from .attention import attention_forward
121
diffsynth/core/attention/attention.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
import torch, os
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
import flash_attn_interface
|
||||||
|
FLASH_ATTN_3_AVAILABLE = True
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
FLASH_ATTN_3_AVAILABLE = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import flash_attn
|
||||||
|
FLASH_ATTN_2_AVAILABLE = True
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
FLASH_ATTN_2_AVAILABLE = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
from sageattention import sageattn
|
||||||
|
SAGE_ATTN_AVAILABLE = True
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
SAGE_ATTN_AVAILABLE = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import xformers.ops as xops
|
||||||
|
XFORMERS_AVAILABLE = True
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
XFORMERS_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_attention_priority():
|
||||||
|
if os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION') is not None:
|
||||||
|
return os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION').lower()
|
||||||
|
elif FLASH_ATTN_3_AVAILABLE:
|
||||||
|
return "flash_attention_3"
|
||||||
|
elif FLASH_ATTN_2_AVAILABLE:
|
||||||
|
return "flash_attention_2"
|
||||||
|
elif SAGE_ATTN_AVAILABLE:
|
||||||
|
return "sage_attention"
|
||||||
|
elif XFORMERS_AVAILABLE:
|
||||||
|
return "xformers"
|
||||||
|
else:
|
||||||
|
return "torch"
|
||||||
|
|
||||||
|
|
||||||
|
ATTENTION_IMPLEMENTATION = initialize_attention_priority()
|
||||||
|
|
||||||
|
|
||||||
|
def rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", required_in_pattern="b n s d", dims=None):
|
||||||
|
dims = {} if dims is None else dims
|
||||||
|
if q_pattern != required_in_pattern:
|
||||||
|
q = rearrange(q, f"{q_pattern} -> {required_in_pattern}", **dims)
|
||||||
|
if k_pattern != required_in_pattern:
|
||||||
|
k = rearrange(k, f"{k_pattern} -> {required_in_pattern}", **dims)
|
||||||
|
if v_pattern != required_in_pattern:
|
||||||
|
v = rearrange(v, f"{v_pattern} -> {required_in_pattern}", **dims)
|
||||||
|
return q, k, v
|
||||||
|
|
||||||
|
|
||||||
|
def rearrange_out(out: torch.Tensor, out_pattern="b n s d", required_out_pattern="b n s d", dims=None):
|
||||||
|
dims = {} if dims is None else dims
|
||||||
|
if out_pattern != required_out_pattern:
|
||||||
|
out = rearrange(out, f"{required_out_pattern} -> {out_pattern}", **dims)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def torch_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None):
|
||||||
|
required_in_pattern, required_out_pattern= "b n s d", "b n s d"
|
||||||
|
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
|
||||||
|
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask, scale=scale)
|
||||||
|
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def flash_attention_3(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
|
||||||
|
required_in_pattern, required_out_pattern= "b s n d", "b s n d"
|
||||||
|
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
|
||||||
|
out = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=scale)
|
||||||
|
if isinstance(out, tuple):
|
||||||
|
out = out[0]
|
||||||
|
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def flash_attention_2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
|
||||||
|
required_in_pattern, required_out_pattern= "b s n d", "b s n d"
|
||||||
|
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
|
||||||
|
out = flash_attn.flash_attn_func(q, k, v, softmax_scale=scale)
|
||||||
|
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def sage_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
|
||||||
|
required_in_pattern, required_out_pattern= "b n s d", "b n s d"
|
||||||
|
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
|
||||||
|
out = sageattn(q, k, v, sm_scale=scale)
|
||||||
|
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def xformers_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
|
||||||
|
required_in_pattern, required_out_pattern= "b s n d", "b s n d"
|
||||||
|
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
|
||||||
|
out = xops.memory_efficient_attention(q, k, v, scale=scale)
|
||||||
|
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None, compatibility_mode=False):
|
||||||
|
if compatibility_mode or (attn_mask is not None):
|
||||||
|
return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, attn_mask=attn_mask, scale=scale)
|
||||||
|
else:
|
||||||
|
if ATTENTION_IMPLEMENTATION == "flash_attention_3":
|
||||||
|
return flash_attention_3(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
|
||||||
|
elif ATTENTION_IMPLEMENTATION == "flash_attention_2":
|
||||||
|
return flash_attention_2(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
|
||||||
|
elif ATTENTION_IMPLEMENTATION == "sage_attention":
|
||||||
|
return sage_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
|
||||||
|
elif ATTENTION_IMPLEMENTATION == "xformers":
|
||||||
|
return xformers_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
|
||||||
|
else:
|
||||||
|
return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
|
||||||
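A minimal usage sketch of the dispatcher above (assuming the package layout introduced in this diff): `attention_forward` picks FlashAttention-3/2, SageAttention, xformers, or plain torch SDPA depending on what is installed, and the choice can be forced with the `DIFFSYNTH_ATTENTION_IMPLEMENTATION` environment variable. The tensors below are dummy data.

```python
# Hedged sketch, not part of the diff: call the backend-agnostic attention entry point.
import torch
from diffsynth.core.attention import attention_forward

q = torch.randn(1, 8, 128, 64)  # "b n s d": batch, heads, sequence, head_dim
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

# compatibility_mode=True forces the torch SDPA path, which keeps this CPU/float32 demo runnable;
# without it, the fastest available backend is used (an attention mask also forces the torch path).
out = attention_forward(q, k, v, q_pattern="b n s d", out_pattern="b s n d", compatibility_mode=True)
print(out.shape)  # torch.Size([1, 128, 8, 64])
```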
1
diffsynth/core/data/__init__.py
Normal file
@@ -0,0 +1 @@
from .unified_dataset import UnifiedDataset
220
diffsynth/core/data/operators.py
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
import torch, torchvision, imageio, os
|
||||||
|
import imageio.v3 as iio
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
class DataProcessingPipeline:
|
||||||
|
def __init__(self, operators=None):
|
||||||
|
self.operators: list[DataProcessingOperator] = [] if operators is None else operators
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
for operator in self.operators:
|
||||||
|
data = operator(data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def __rshift__(self, pipe):
|
||||||
|
if isinstance(pipe, DataProcessingOperator):
|
||||||
|
pipe = DataProcessingPipeline([pipe])
|
||||||
|
return DataProcessingPipeline(self.operators + pipe.operators)
|
||||||
|
|
||||||
|
|
||||||
|
class DataProcessingOperator:
|
||||||
|
def __call__(self, data):
|
||||||
|
raise NotImplementedError("DataProcessingOperator cannot be called directly.")
|
||||||
|
|
||||||
|
def __rshift__(self, pipe):
|
||||||
|
if isinstance(pipe, DataProcessingOperator):
|
||||||
|
pipe = DataProcessingPipeline([pipe])
|
||||||
|
return DataProcessingPipeline([self]).__rshift__(pipe)
|
||||||
|
|
||||||
|
|
||||||
|
class DataProcessingOperatorRaw(DataProcessingOperator):
|
||||||
|
def __call__(self, data):
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class ToInt(DataProcessingOperator):
|
||||||
|
def __call__(self, data):
|
||||||
|
return int(data)
|
||||||
|
|
||||||
|
|
||||||
|
class ToFloat(DataProcessingOperator):
|
||||||
|
def __call__(self, data):
|
||||||
|
return float(data)
|
||||||
|
|
||||||
|
|
||||||
|
class ToStr(DataProcessingOperator):
|
||||||
|
def __init__(self, none_value=""):
|
||||||
|
self.none_value = none_value
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
if data is None: data = self.none_value
|
||||||
|
return str(data)
|
||||||
|
|
||||||
|
|
||||||
|
class LoadImage(DataProcessingOperator):
|
||||||
|
def __init__(self, convert_RGB=True, convert_RGBA=False):
|
||||||
|
self.convert_RGB = convert_RGB
|
||||||
|
self.convert_RGBA = convert_RGBA
|
||||||
|
|
||||||
|
def __call__(self, data: str):
|
||||||
|
image = Image.open(data)
|
||||||
|
if self.convert_RGB: image = image.convert("RGB")
|
||||||
|
if self.convert_RGBA: image = image.convert("RGBA")
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
class ImageCropAndResize(DataProcessingOperator):
|
||||||
|
def __init__(self, height=None, width=None, max_pixels=None, height_division_factor=1, width_division_factor=1):
|
||||||
|
self.height = height
|
||||||
|
self.width = width
|
||||||
|
self.max_pixels = max_pixels
|
||||||
|
self.height_division_factor = height_division_factor
|
||||||
|
self.width_division_factor = width_division_factor
|
||||||
|
|
||||||
|
def crop_and_resize(self, image, target_height, target_width):
|
||||||
|
width, height = image.size
|
||||||
|
scale = max(target_width / width, target_height / height)
|
||||||
|
image = torchvision.transforms.functional.resize(
|
||||||
|
image,
|
||||||
|
(round(height*scale), round(width*scale)),
|
||||||
|
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
|
||||||
|
)
|
||||||
|
image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
|
||||||
|
return image
|
||||||
|
|
||||||
|
def get_height_width(self, image):
|
||||||
|
if self.height is None or self.width is None:
|
||||||
|
width, height = image.size
|
||||||
|
if width * height > self.max_pixels:
|
||||||
|
scale = (width * height / self.max_pixels) ** 0.5
|
||||||
|
height, width = int(height / scale), int(width / scale)
|
||||||
|
height = height // self.height_division_factor * self.height_division_factor
|
||||||
|
width = width // self.width_division_factor * self.width_division_factor
|
||||||
|
else:
|
||||||
|
height, width = self.height, self.width
|
||||||
|
return height, width
|
||||||
|
|
||||||
|
def __call__(self, data: Image.Image):
|
||||||
|
image = self.crop_and_resize(data, *self.get_height_width(data))
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
class ToList(DataProcessingOperator):
|
||||||
|
def __call__(self, data):
|
||||||
|
return [data]
|
||||||
|
|
||||||
|
|
||||||
|
class LoadVideo(DataProcessingOperator):
|
||||||
|
def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
|
||||||
|
self.num_frames = num_frames
|
||||||
|
self.time_division_factor = time_division_factor
|
||||||
|
self.time_division_remainder = time_division_remainder
|
||||||
|
# frame_processor is built into the video loader for efficiency.
|
||||||
|
self.frame_processor = frame_processor
|
||||||
|
|
||||||
|
def get_num_frames(self, reader):
|
||||||
|
num_frames = self.num_frames
|
||||||
|
if int(reader.count_frames()) < num_frames:
|
||||||
|
num_frames = int(reader.count_frames())
|
||||||
|
while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
|
||||||
|
num_frames -= 1
|
||||||
|
return num_frames
|
||||||
|
|
||||||
|
def __call__(self, data: str):
|
||||||
|
reader = imageio.get_reader(data)
|
||||||
|
num_frames = self.get_num_frames(reader)
|
||||||
|
frames = []
|
||||||
|
for frame_id in range(num_frames):
|
||||||
|
frame = reader.get_data(frame_id)
|
||||||
|
frame = Image.fromarray(frame)
|
||||||
|
frame = self.frame_processor(frame)
|
||||||
|
frames.append(frame)
|
||||||
|
reader.close()
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
class SequencialProcess(DataProcessingOperator):
|
||||||
|
def __init__(self, operator=lambda x: x):
|
||||||
|
self.operator = operator
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
return [self.operator(i) for i in data]
|
||||||
|
|
||||||
|
|
||||||
|
class LoadGIF(DataProcessingOperator):
|
||||||
|
def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
|
||||||
|
self.num_frames = num_frames
|
||||||
|
self.time_division_factor = time_division_factor
|
||||||
|
self.time_division_remainder = time_division_remainder
|
||||||
|
# frame_processor is built into the GIF loader for efficiency.
|
||||||
|
self.frame_processor = frame_processor
|
||||||
|
|
||||||
|
def get_num_frames(self, path):
|
||||||
|
num_frames = self.num_frames
|
||||||
|
images = iio.imread(path, mode="RGB")
|
||||||
|
if len(images) < num_frames:
|
||||||
|
num_frames = len(images)
|
||||||
|
while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
|
||||||
|
num_frames -= 1
|
||||||
|
return num_frames
|
||||||
|
|
||||||
|
def __call__(self, data: str):
|
||||||
|
num_frames = self.get_num_frames(data)
|
||||||
|
frames = []
|
||||||
|
images = iio.imread(data, mode="RGB")
|
||||||
|
for img in images:
|
||||||
|
frame = Image.fromarray(img)
|
||||||
|
frame = self.frame_processor(frame)
|
||||||
|
frames.append(frame)
|
||||||
|
if len(frames) >= num_frames:
|
||||||
|
break
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
class RouteByExtensionName(DataProcessingOperator):
|
||||||
|
def __init__(self, operator_map):
|
||||||
|
self.operator_map = operator_map
|
||||||
|
|
||||||
|
def __call__(self, data: str):
|
||||||
|
file_ext_name = data.split(".")[-1].lower()
|
||||||
|
for ext_names, operator in self.operator_map:
|
||||||
|
if ext_names is None or file_ext_name in ext_names:
|
||||||
|
return operator(data)
|
||||||
|
raise ValueError(f"Unsupported file: {data}")
|
||||||
|
|
||||||
|
|
||||||
|
class RouteByType(DataProcessingOperator):
|
||||||
|
def __init__(self, operator_map):
|
||||||
|
self.operator_map = operator_map
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
for dtype, operator in self.operator_map:
|
||||||
|
if dtype is None or isinstance(data, dtype):
|
||||||
|
return operator(data)
|
||||||
|
raise ValueError(f"Unsupported data: {data}")
|
||||||
|
|
||||||
|
|
||||||
|
class LoadTorchPickle(DataProcessingOperator):
|
||||||
|
def __init__(self, map_location="cpu"):
|
||||||
|
self.map_location = map_location
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
return torch.load(data, map_location=self.map_location, weights_only=False)
|
||||||
|
|
||||||
|
|
||||||
|
class ToAbsolutePath(DataProcessingOperator):
|
||||||
|
def __init__(self, base_path=""):
|
||||||
|
self.base_path = base_path
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
return os.path.join(self.base_path, data)
|
||||||
|
|
||||||
|
|
||||||
|
class LoadAudio(DataProcessingOperator):
|
||||||
|
def __init__(self, sr=16000):
|
||||||
|
self.sr = sr
|
||||||
|
def __call__(self, data: str):
|
||||||
|
import librosa
|
||||||
|
input_audio, sample_rate = librosa.load(data, sr=self.sr)
|
||||||
|
return input_audio
|
||||||
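The operators above compose with `>>`, which builds a `DataProcessingPipeline` that applies each step in order. A short illustrative sketch (the directory and file name are placeholders):

```python
# Hedged sketch, not part of the diff: chain operators into a pipeline with `>>`.
from diffsynth.core.data.operators import ToAbsolutePath, LoadImage, ImageCropAndResize

pipe = ToAbsolutePath("./data") >> LoadImage() >> ImageCropAndResize(height=512, width=512)
image = pipe("images/cat.jpg")  # -> PIL image, resized then center-cropped to 512x512
```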
116
diffsynth/core/data/unified_dataset.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
from .operators import *
|
||||||
|
import torch, json, pandas
|
||||||
|
|
||||||
|
|
||||||
|
class UnifiedDataset(torch.utils.data.Dataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_path=None, metadata_path=None,
|
||||||
|
repeat=1,
|
||||||
|
data_file_keys=tuple(),
|
||||||
|
main_data_operator=lambda x: x,
|
||||||
|
special_operator_map=None,
|
||||||
|
max_data_items=None,
|
||||||
|
):
|
||||||
|
self.base_path = base_path
|
||||||
|
self.metadata_path = metadata_path
|
||||||
|
self.repeat = repeat
|
||||||
|
self.data_file_keys = data_file_keys
|
||||||
|
self.main_data_operator = main_data_operator
|
||||||
|
self.cached_data_operator = LoadTorchPickle()
|
||||||
|
self.special_operator_map = {} if special_operator_map is None else special_operator_map
|
||||||
|
self.max_data_items = max_data_items
|
||||||
|
self.data = []
|
||||||
|
self.cached_data = []
|
||||||
|
self.load_from_cache = metadata_path is None
|
||||||
|
self.load_metadata(metadata_path)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def default_image_operator(
|
||||||
|
base_path="",
|
||||||
|
max_pixels=1920*1080, height=None, width=None,
|
||||||
|
height_division_factor=16, width_division_factor=16,
|
||||||
|
):
|
||||||
|
return RouteByType(operator_map=[
|
||||||
|
(str, ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)),
|
||||||
|
(list, SequencialProcess(ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor))),
|
||||||
|
])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def default_video_operator(
|
||||||
|
base_path="",
|
||||||
|
max_pixels=1920*1080, height=None, width=None,
|
||||||
|
height_division_factor=16, width_division_factor=16,
|
||||||
|
num_frames=81, time_division_factor=4, time_division_remainder=1,
|
||||||
|
):
|
||||||
|
return RouteByType(operator_map=[
|
||||||
|
(str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[
|
||||||
|
(("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()),
|
||||||
|
(("gif",), LoadGIF(
|
||||||
|
num_frames, time_division_factor, time_division_remainder,
|
||||||
|
frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
|
||||||
|
)),
|
||||||
|
(("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo(
|
||||||
|
num_frames, time_division_factor, time_division_remainder,
|
||||||
|
frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
|
||||||
|
)),
|
||||||
|
])),
|
||||||
|
])
|
||||||
|
|
||||||
|
def search_for_cached_data_files(self, path):
|
||||||
|
for file_name in os.listdir(path):
|
||||||
|
subpath = os.path.join(path, file_name)
|
||||||
|
if os.path.isdir(subpath):
|
||||||
|
self.search_for_cached_data_files(subpath)
|
||||||
|
elif subpath.endswith(".pth"):
|
||||||
|
self.cached_data.append(subpath)
|
||||||
|
|
||||||
|
def load_metadata(self, metadata_path):
|
||||||
|
if metadata_path is None:
|
||||||
|
print("No metadata_path. Searching for cached data files.")
|
||||||
|
self.search_for_cached_data_files(self.base_path)
|
||||||
|
print(f"{len(self.cached_data)} cached data files found.")
|
||||||
|
elif metadata_path.endswith(".json"):
|
||||||
|
with open(metadata_path, "r") as f:
|
||||||
|
metadata = json.load(f)
|
||||||
|
self.data = metadata
|
||||||
|
elif metadata_path.endswith(".jsonl"):
|
||||||
|
metadata = []
|
||||||
|
with open(metadata_path, 'r') as f:
|
||||||
|
for line in f:
|
||||||
|
metadata.append(json.loads(line.strip()))
|
||||||
|
self.data = metadata
|
||||||
|
else:
|
||||||
|
metadata = pandas.read_csv(metadata_path)
|
||||||
|
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
|
||||||
|
|
||||||
|
def __getitem__(self, data_id):
|
||||||
|
if self.load_from_cache:
|
||||||
|
data = self.cached_data[data_id % len(self.cached_data)]
|
||||||
|
data = self.cached_data_operator(data)
|
||||||
|
else:
|
||||||
|
data = self.data[data_id % len(self.data)].copy()
|
||||||
|
for key in self.data_file_keys:
|
||||||
|
if key in data:
|
||||||
|
if key in self.special_operator_map:
|
||||||
|
data[key] = self.special_operator_map[key](data[key])
|
||||||
|
elif key in self.data_file_keys:
|
||||||
|
data[key] = self.main_data_operator(data[key])
|
||||||
|
return data
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
if self.max_data_items is not None:
|
||||||
|
return self.max_data_items
|
||||||
|
elif self.load_from_cache:
|
||||||
|
return len(self.cached_data) * self.repeat
|
||||||
|
else:
|
||||||
|
return len(self.data) * self.repeat
|
||||||
|
|
||||||
|
def check_data_equal(self, data1, data2):
|
||||||
|
# Debug only
|
||||||
|
if len(data1) != len(data2):
|
||||||
|
return False
|
||||||
|
for k in data1:
|
||||||
|
if data1[k] != data2[k]:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
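A hedged sketch of constructing the dataset above; the metadata file name and the `image` column are illustrative assumptions, not a required layout:

```python
# Hedged sketch, not part of the diff.
from diffsynth.core.data import UnifiedDataset

dataset = UnifiedDataset(
    base_path="./data",
    metadata_path="./data/metadata.csv",   # .json and .jsonl metadata are also accepted
    data_file_keys=("image",),             # keys routed through the data operator
    main_data_operator=UnifiedDataset.default_image_operator(base_path="./data", max_pixels=1024 * 1024),
)
sample = dataset[0]  # dict whose "image" entry has been loaded, resized, and cropped
```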
2
diffsynth/core/device/__init__.py
Normal file
@@ -0,0 +1,2 @@
from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type, get_device_name
from .npu_compatible_device import IS_NPU_AVAILABLE, IS_CUDA_AVAILABLE
107
diffsynth/core/device/npu_compatible_device.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
import importlib
|
||||||
|
import torch
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def is_torch_npu_available():
|
||||||
|
return importlib.util.find_spec("torch_npu") is not None
|
||||||
|
|
||||||
|
|
||||||
|
IS_CUDA_AVAILABLE = torch.cuda.is_available()
|
||||||
|
IS_NPU_AVAILABLE = is_torch_npu_available() and torch.npu.is_available()
|
||||||
|
|
||||||
|
if IS_NPU_AVAILABLE:
|
||||||
|
import torch_npu
|
||||||
|
|
||||||
|
torch.npu.config.allow_internal_format = False
|
||||||
|
|
||||||
|
|
||||||
|
def get_device_type() -> str:
|
||||||
|
"""Get device type based on current machine, currently only support CPU, CUDA, NPU."""
|
||||||
|
if IS_CUDA_AVAILABLE:
|
||||||
|
device = "cuda"
|
||||||
|
elif IS_NPU_AVAILABLE:
|
||||||
|
device = "npu"
|
||||||
|
else:
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
|
return device
|
||||||
|
|
||||||
|
|
||||||
|
def get_torch_device() -> Any:
|
||||||
|
"""Get torch attribute based on device type, e.g. torch.cuda or torch.npu"""
|
||||||
|
device_name = get_device_type()
|
||||||
|
|
||||||
|
try:
|
||||||
|
return getattr(torch, device_name)
|
||||||
|
except AttributeError:
|
||||||
|
print(f"Device namespace '{device_name}' not found in torch, try to load 'torch.cuda'.")
|
||||||
|
return torch.cuda
|
||||||
|
|
||||||
|
|
||||||
|
def get_device_id() -> int:
|
||||||
|
"""Get current device id based on device type."""
|
||||||
|
return get_torch_device().current_device()
|
||||||
|
|
||||||
|
|
||||||
|
def get_device_name() -> str:
|
||||||
|
"""Get current device name based on device type."""
|
||||||
|
return f"{get_device_type()}:{get_device_id()}"
|
||||||
|
|
||||||
|
|
||||||
|
def synchronize() -> None:
|
||||||
|
"""Execute torch synchronize operation."""
|
||||||
|
get_torch_device().synchronize()
|
||||||
|
|
||||||
|
|
||||||
|
def empty_cache() -> None:
|
||||||
|
"""Execute torch empty cache operation."""
|
||||||
|
get_torch_device().empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
def get_nccl_backend() -> str:
|
||||||
|
"""Return distributed communication backend type based on device type."""
|
||||||
|
if IS_CUDA_AVAILABLE:
|
||||||
|
return "nccl"
|
||||||
|
elif IS_NPU_AVAILABLE:
|
||||||
|
return "hccl"
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"No available distributed communication backend found on device type {get_device_type()}.")
|
||||||
|
|
||||||
|
|
||||||
|
def enable_high_precision_for_bf16():
|
||||||
|
"""
|
||||||
|
Set high accumulation dtype for matmul and reduction.
|
||||||
|
"""
|
||||||
|
if IS_CUDA_AVAILABLE:
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = False
|
||||||
|
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
|
||||||
|
|
||||||
|
if IS_NPU_AVAILABLE:
|
||||||
|
torch.npu.matmul.allow_tf32 = False
|
||||||
|
torch.npu.matmul.allow_bf16_reduced_precision_reduction = False
|
||||||
|
|
||||||
|
|
||||||
|
def parse_device_type(device):
|
||||||
|
if isinstance(device, str):
|
||||||
|
if device.startswith("cuda"):
|
||||||
|
return "cuda"
|
||||||
|
elif device.startswith("npu"):
|
||||||
|
return "npu"
|
||||||
|
else:
|
||||||
|
return "cpu"
|
||||||
|
elif isinstance(device, torch.device):
|
||||||
|
return device.type
|
||||||
|
|
||||||
|
|
||||||
|
def parse_nccl_backend(device_type):
|
||||||
|
if device_type == "cuda":
|
||||||
|
return "nccl"
|
||||||
|
elif device_type == "npu":
|
||||||
|
return "hccl"
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"No available distributed communication backend found on device type {device_type}.")
|
||||||
|
|
||||||
|
|
||||||
|
def get_available_device_type():
|
||||||
|
return get_device_type()
|
||||||
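These helpers keep the rest of the codebase device-agnostic across CUDA and Ascend NPU; a short sketch:

```python
# Hedged sketch, not part of the diff: normalise a device string and pick the matching backend.
from diffsynth.core.device import parse_device_type, parse_nccl_backend, get_available_device_type

print(get_available_device_type())         # "cuda", "npu", or "cpu" depending on the machine
device_type = parse_device_type("cuda:0")  # -> "cuda"
print(parse_nccl_backend(device_type))     # -> "nccl" ("hccl" on Ascend NPU)
```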
1
diffsynth/core/gradient/__init__.py
Normal file
@@ -0,0 +1 @@
from .gradient_checkpoint import gradient_checkpoint_forward
34
diffsynth/core/gradient/gradient_checkpoint.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs, **kwargs):
|
||||||
|
return module(*inputs, **kwargs)
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
|
||||||
|
def gradient_checkpoint_forward(
|
||||||
|
model,
|
||||||
|
use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if use_gradient_checkpointing_offload:
|
||||||
|
with torch.autograd.graph.save_on_cpu():
|
||||||
|
model_output = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(model),
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
use_reentrant=False,
|
||||||
|
)
|
||||||
|
elif use_gradient_checkpointing:
|
||||||
|
model_output = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(model),
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
use_reentrant=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model_output = model(*args, **kwargs)
|
||||||
|
return model_output
|
||||||
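A minimal sketch of the wrapper above with a toy MLP (the module is illustrative only): activations are recomputed during the backward pass, and can additionally be stashed in CPU memory when offloading is enabled.

```python
# Hedged sketch, not part of the diff.
import torch
from diffsynth.core.gradient import gradient_checkpoint_forward

block = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU(), torch.nn.Linear(64, 64))
x = torch.randn(4, 64, requires_grad=True)

y = gradient_checkpoint_forward(block, True, False, x)  # checkpointing on, CPU offload off
y.sum().backward()
```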
3
diffsynth/core/loader/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .file import load_state_dict, hash_state_dict_keys, hash_model_file
from .model import load_model, load_model_with_disk_offload
from .config import ModelConfig
119
diffsynth/core/loader/config.py
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
import torch, glob, os
|
||||||
|
from typing import Optional, Union, Dict
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from modelscope import snapshot_download
|
||||||
|
from huggingface_hub import snapshot_download as hf_snapshot_download
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelConfig:
|
||||||
|
path: Union[str, list[str]] = None
|
||||||
|
model_id: str = None
|
||||||
|
origin_file_pattern: Union[str, list[str]] = None
|
||||||
|
download_source: str = None
|
||||||
|
local_model_path: str = None
|
||||||
|
skip_download: bool = None
|
||||||
|
offload_device: Optional[Union[str, torch.device]] = None
|
||||||
|
offload_dtype: Optional[torch.dtype] = None
|
||||||
|
onload_device: Optional[Union[str, torch.device]] = None
|
||||||
|
onload_dtype: Optional[torch.dtype] = None
|
||||||
|
preparing_device: Optional[Union[str, torch.device]] = None
|
||||||
|
preparing_dtype: Optional[torch.dtype] = None
|
||||||
|
computation_device: Optional[Union[str, torch.device]] = None
|
||||||
|
computation_dtype: Optional[torch.dtype] = None
|
||||||
|
clear_parameters: bool = False
|
||||||
|
state_dict: Dict[str, torch.Tensor] = None
|
||||||
|
|
||||||
|
def check_input(self):
|
||||||
|
if self.path is None and self.model_id is None:
|
||||||
|
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""")
|
||||||
|
|
||||||
|
def parse_original_file_pattern(self):
|
||||||
|
if self.origin_file_pattern in [None, "", "./"]:
|
||||||
|
return "*"
|
||||||
|
elif self.origin_file_pattern.endswith("/"):
|
||||||
|
return self.origin_file_pattern + "*"
|
||||||
|
else:
|
||||||
|
return self.origin_file_pattern
|
||||||
|
|
||||||
|
def parse_download_source(self):
|
||||||
|
if self.download_source is None:
|
||||||
|
if os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE') is not None:
|
||||||
|
return os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE')
|
||||||
|
else:
|
||||||
|
return "modelscope"
|
||||||
|
else:
|
||||||
|
return self.download_source
|
||||||
|
|
||||||
|
def parse_skip_download(self):
|
||||||
|
if self.skip_download is None:
|
||||||
|
if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD') is not None:
|
||||||
|
if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "true":
|
||||||
|
return True
|
||||||
|
elif os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "false":
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return self.skip_download
|
||||||
|
|
||||||
|
def download(self):
|
||||||
|
origin_file_pattern = self.parse_original_file_pattern()
|
||||||
|
downloaded_files = glob.glob(origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
|
||||||
|
download_source = self.parse_download_source()
|
||||||
|
if download_source.lower() == "modelscope":
|
||||||
|
snapshot_download(
|
||||||
|
self.model_id,
|
||||||
|
local_dir=os.path.join(self.local_model_path, self.model_id),
|
||||||
|
allow_file_pattern=origin_file_pattern,
|
||||||
|
ignore_file_pattern=downloaded_files,
|
||||||
|
local_files_only=False
|
||||||
|
)
|
||||||
|
elif download_source.lower() == "huggingface":
|
||||||
|
hf_snapshot_download(
|
||||||
|
self.model_id,
|
||||||
|
local_dir=os.path.join(self.local_model_path, self.model_id),
|
||||||
|
allow_patterns=origin_file_pattern,
|
||||||
|
ignore_patterns=downloaded_files,
|
||||||
|
local_files_only=False
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("`download_source` should be `modelscope` or `huggingface`.")
|
||||||
|
|
||||||
|
def require_downloading(self):
|
||||||
|
if self.path is not None:
|
||||||
|
return False
|
||||||
|
skip_download = self.parse_skip_download()
|
||||||
|
return not skip_download
|
||||||
|
|
||||||
|
def reset_local_model_path(self):
|
||||||
|
if os.environ.get('DIFFSYNTH_MODEL_BASE_PATH') is not None:
|
||||||
|
self.local_model_path = os.environ.get('DIFFSYNTH_MODEL_BASE_PATH')
|
||||||
|
elif self.local_model_path is None:
|
||||||
|
self.local_model_path = "./models"
|
||||||
|
|
||||||
|
def download_if_necessary(self):
|
||||||
|
self.check_input()
|
||||||
|
self.reset_local_model_path()
|
||||||
|
if self.require_downloading():
|
||||||
|
self.download()
|
||||||
|
if self.path is None:
|
||||||
|
if self.origin_file_pattern in [None, "", "./"]:
|
||||||
|
self.path = os.path.join(self.local_model_path, self.model_id)
|
||||||
|
else:
|
||||||
|
self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
|
||||||
|
if isinstance(self.path, list) and len(self.path) == 1:
|
||||||
|
self.path = self.path[0]
|
||||||
|
|
||||||
|
def vram_config(self):
|
||||||
|
return {
|
||||||
|
"offload_device": self.offload_device,
|
||||||
|
"offload_dtype": self.offload_dtype,
|
||||||
|
"onload_device": self.onload_device,
|
||||||
|
"onload_dtype": self.onload_dtype,
|
||||||
|
"preparing_device": self.preparing_device,
|
||||||
|
"preparing_dtype": self.preparing_dtype,
|
||||||
|
"computation_device": self.computation_device,
|
||||||
|
"computation_dtype": self.computation_dtype,
|
||||||
|
}
|
||||||
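A hedged sketch of how `ModelConfig` is typically used; the model id and file pattern below are placeholders, not a guaranteed repository layout:

```python
# Hedged sketch, not part of the diff.
from diffsynth.core.loader import ModelConfig

config = ModelConfig(model_id="SomeOrg/SomeModel", origin_file_pattern="transformer/*.safetensors")
config.download_if_necessary()  # downloads under ./models (or $DIFFSYNTH_MODEL_BASE_PATH) and fills config.path
print(config.path)
```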
130
diffsynth/core/loader/file.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
from safetensors import safe_open
|
||||||
|
import torch, hashlib
|
||||||
|
|
||||||
|
|
||||||
|
def load_state_dict(file_path, torch_dtype=None, device="cpu", pin_memory=False, verbose=0):
|
||||||
|
if isinstance(file_path, list):
|
||||||
|
state_dict = {}
|
||||||
|
for file_path_ in file_path:
|
||||||
|
state_dict.update(load_state_dict(file_path_, torch_dtype, device, pin_memory=pin_memory, verbose=verbose))
|
||||||
|
else:
|
||||||
|
if verbose >= 1:
|
||||||
|
print(f"Loading file [started]: {file_path}")
|
||||||
|
if file_path.endswith(".safetensors"):
|
||||||
|
state_dict = load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
|
||||||
|
else:
|
||||||
|
state_dict = load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
|
||||||
|
# If the state dict is loaded into CPU memory, `pin_memory=True` makes the later `model.to("cuda")` faster.
|
||||||
|
if pin_memory:
|
||||||
|
for i in state_dict:
|
||||||
|
state_dict[i] = state_dict[i].pin_memory()
|
||||||
|
if verbose >= 1:
|
||||||
|
print(f"Loading file [done]: {file_path}")
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
|
||||||
|
state_dict = {}
|
||||||
|
with safe_open(file_path, framework="pt", device=str(device)) as f:
|
||||||
|
for k in f.keys():
|
||||||
|
state_dict[k] = f.get_tensor(k)
|
||||||
|
if torch_dtype is not None:
|
||||||
|
state_dict[k] = state_dict[k].to(torch_dtype)
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
|
||||||
|
state_dict = torch.load(file_path, map_location=device, weights_only=True)
|
||||||
|
if len(state_dict) == 1:
|
||||||
|
if "state_dict" in state_dict:
|
||||||
|
state_dict = state_dict["state_dict"]
|
||||||
|
elif "module" in state_dict:
|
||||||
|
state_dict = state_dict["module"]
|
||||||
|
elif "model_state" in state_dict:
|
||||||
|
state_dict = state_dict["model_state"]
|
||||||
|
if torch_dtype is not None:
|
||||||
|
for i in state_dict:
|
||||||
|
if isinstance(state_dict[i], torch.Tensor):
|
||||||
|
state_dict[i] = state_dict[i].to(torch_dtype)
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
|
||||||
|
keys = []
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
if isinstance(key, str):
|
||||||
|
if isinstance(value, torch.Tensor):
|
||||||
|
if with_shape:
|
||||||
|
shape = "_".join(map(str, list(value.shape)))
|
||||||
|
keys.append(key + ":" + shape)
|
||||||
|
keys.append(key)
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
|
||||||
|
keys.sort()
|
||||||
|
keys_str = ",".join(keys)
|
||||||
|
return keys_str
|
||||||
|
|
||||||
|
|
||||||
|
def hash_state_dict_keys(state_dict, with_shape=True):
|
||||||
|
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
|
||||||
|
keys_str = keys_str.encode(encoding="UTF-8")
|
||||||
|
return hashlib.md5(keys_str).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def load_keys_dict(file_path):
|
||||||
|
if isinstance(file_path, list):
|
||||||
|
state_dict = {}
|
||||||
|
for file_path_ in file_path:
|
||||||
|
state_dict.update(load_keys_dict(file_path_))
|
||||||
|
return state_dict
|
||||||
|
if file_path.endswith(".safetensors"):
|
||||||
|
return load_keys_dict_from_safetensors(file_path)
|
||||||
|
else:
|
||||||
|
return load_keys_dict_from_bin(file_path)
|
||||||
|
|
||||||
|
|
||||||
|
def load_keys_dict_from_safetensors(file_path):
|
||||||
|
keys_dict = {}
|
||||||
|
with safe_open(file_path, framework="pt", device="cpu") as f:
|
||||||
|
for k in f.keys():
|
||||||
|
keys_dict[k] = f.get_slice(k).get_shape()
|
||||||
|
return keys_dict
|
||||||
|
|
||||||
|
|
||||||
|
def convert_state_dict_to_keys_dict(state_dict):
|
||||||
|
keys_dict = {}
|
||||||
|
for k, v in state_dict.items():
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
keys_dict[k] = list(v.shape)
|
||||||
|
else:
|
||||||
|
keys_dict[k] = convert_state_dict_to_keys_dict(v)
|
||||||
|
return keys_dict
|
||||||
|
|
||||||
|
|
||||||
|
def load_keys_dict_from_bin(file_path):
|
||||||
|
state_dict = load_state_dict_from_bin(file_path)
|
||||||
|
keys_dict = convert_state_dict_to_keys_dict(state_dict)
|
||||||
|
return keys_dict
|
||||||
|
|
||||||
|
|
||||||
|
def convert_keys_dict_to_single_str(state_dict, with_shape=True):
|
||||||
|
keys = []
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
if isinstance(key, str):
|
||||||
|
if isinstance(value, dict):
|
||||||
|
keys.append(key + "|" + convert_keys_dict_to_single_str(value, with_shape=with_shape))
|
||||||
|
else:
|
||||||
|
if with_shape:
|
||||||
|
shape = "_".join(map(str, list(value)))
|
||||||
|
keys.append(key + ":" + shape)
|
||||||
|
keys.append(key)
|
||||||
|
keys.sort()
|
||||||
|
keys_str = ",".join(keys)
|
||||||
|
return keys_str
|
||||||
|
|
||||||
|
|
||||||
|
def hash_model_file(path, with_shape=True):
|
||||||
|
keys_dict = load_keys_dict(path)
|
||||||
|
keys_str = convert_keys_dict_to_single_str(keys_dict, with_shape=with_shape)
|
||||||
|
keys_str = keys_str.encode(encoding="UTF-8")
|
||||||
|
return hashlib.md5(keys_str).hexdigest()
|
||||||
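The hashing helpers identify a checkpoint by the sorted `key:shape` structure of its tensors rather than by their values; a short sketch (the file path is a placeholder):

```python
# Hedged sketch, not part of the diff.
from diffsynth.core.loader import load_state_dict, hash_model_file

print(hash_model_file("model.safetensors"))                      # md5 over the sorted "key:shape" strings
state_dict = load_state_dict("model.safetensors", device="cpu")  # full tensor load, optionally pinned
```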
105
diffsynth/core/loader/model.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
from ..vram.initialization import skip_model_initialization
|
||||||
|
from ..vram.disk_map import DiskMap
|
||||||
|
from ..vram.layers import enable_vram_management
|
||||||
|
from .file import load_state_dict
|
||||||
|
import torch
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||||
|
from transformers.utils import ContextManagers
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
|
||||||
|
config = {} if config is None else config
|
||||||
|
with ContextManagers(get_init_context(torch_dtype=torch_dtype, device=device)):
|
||||||
|
model = model_class(**config)
|
||||||
|
# What is `module_map`?
|
||||||
|
# This is a module mapping table for VRAM management.
|
||||||
|
if module_map is not None:
|
||||||
|
devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"], vram_config["computation_device"]]
|
||||||
|
device = [d for d in devices if d != "disk"][0]
|
||||||
|
dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
|
||||||
|
dtype = [d for d in dtypes if d != "disk"][0]
|
||||||
|
if vram_config["offload_device"] != "disk":
|
||||||
|
if state_dict is None: state_dict = DiskMap(path, device, torch_dtype=dtype)
|
||||||
|
if state_dict_converter is not None:
|
||||||
|
state_dict = state_dict_converter(state_dict)
|
||||||
|
else:
|
||||||
|
state_dict = {i: state_dict[i] for i in state_dict}
|
||||||
|
model.load_state_dict(state_dict, assign=True)
|
||||||
|
model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None, vram_limit=vram_limit)
|
||||||
|
else:
|
||||||
|
disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
|
||||||
|
model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=vram_limit)
|
||||||
|
else:
|
||||||
|
# Why do we use `DiskMap`?
|
||||||
|
# Sometimes a model file contains multiple models,
|
||||||
|
# and DiskMap can load only the parameters of a single model,
|
||||||
|
# avoiding the need to load all parameters in the file.
|
||||||
|
if state_dict is not None:
|
||||||
|
pass
|
||||||
|
elif use_disk_map:
|
||||||
|
state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
|
||||||
|
else:
|
||||||
|
state_dict = load_state_dict(path, torch_dtype, device)
|
||||||
|
# Why do we use `state_dict_converter`?
|
||||||
|
# Some models are saved in complex formats,
|
||||||
|
# and we need to convert the state dict into the appropriate format.
|
||||||
|
if state_dict_converter is not None:
|
||||||
|
state_dict = state_dict_converter(state_dict)
|
||||||
|
else:
|
||||||
|
state_dict = {i: state_dict[i] for i in state_dict}
|
||||||
|
# Why does DeepSpeed ZeRO Stage 3 need to be handled separately?
|
||||||
|
# Because at this stage, model parameters are partitioned across multiple GPUs.
|
||||||
|
# Loading them directly could lead to excessive GPU memory consumption.
|
||||||
|
if is_deepspeed_zero3_enabled():
|
||||||
|
from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model
|
||||||
|
_load_state_dict_into_zero3_model(model, state_dict)
|
||||||
|
else:
|
||||||
|
model.load_state_dict(state_dict, assign=True)
|
||||||
|
# Why do we call `to()`?
|
||||||
|
# Because some models override the behavior of `to()`,
|
||||||
|
# especially those from libraries like Transformers.
|
||||||
|
model = model.to(dtype=torch_dtype, device=device)
|
||||||
|
if hasattr(model, "eval"):
|
||||||
|
model = model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, module_map=None):
|
||||||
|
if isinstance(path, str):
|
||||||
|
path = [path]
|
||||||
|
config = {} if config is None else config
|
||||||
|
with skip_model_initialization():
|
||||||
|
model = model_class(**config)
|
||||||
|
if hasattr(model, "eval"):
|
||||||
|
model = model.eval()
|
||||||
|
disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": "disk",
|
||||||
|
"offload_device": "disk",
|
||||||
|
"onload_dtype": "disk",
|
||||||
|
"onload_device": "disk",
|
||||||
|
"preparing_dtype": torch.float8_e4m3fn,
|
||||||
|
"preparing_device": device,
|
||||||
|
"computation_dtype": torch_dtype,
|
||||||
|
"computation_device": device,
|
||||||
|
}
|
||||||
|
enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def get_init_context(torch_dtype, device):
|
||||||
|
if is_deepspeed_zero3_enabled():
|
||||||
|
from transformers.modeling_utils import set_zero3_state
|
||||||
|
import deepspeed
|
||||||
|
# Why do we use "deepspeed.zero.Init"?
|
||||||
|
# Model weights can be partitioned on the CPU side,
|
||||||
|
# and the partitioned shards are then loaded onto each accelerator.
|
||||||
|
init_contexts = [deepspeed.zero.Init(remote_device=device, dtype=torch_dtype), set_zero3_state()]
|
||||||
|
else:
|
||||||
|
# Why do we use `skip_model_initialization`?
|
||||||
|
# It skips the random initialization of model parameters,
|
||||||
|
# thereby speeding up model loading and avoiding excessive memory usage.
|
||||||
|
init_contexts = [skip_model_initialization()]
|
||||||
|
|
||||||
|
return init_contexts
|
||||||
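A self-contained sketch of `load_model` with a trivial module (purely illustrative): the module is first built on the meta device via `skip_model_initialization`, then its parameters are assigned from the checkpoint.

```python
# Hedged sketch, not part of the diff.
import torch
from diffsynth.core.loader import load_model

torch.save(torch.nn.Linear(8, 8).state_dict(), "tiny_linear.pth")  # toy checkpoint for the demo
model = load_model(
    torch.nn.Linear, "tiny_linear.pth",
    config={"in_features": 8, "out_features": 8},
    torch_dtype=torch.float32, device="cpu",
)
```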
30
diffsynth/core/npu_patch/npu_fused_operator.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
import torch
|
||||||
|
from ..device.npu_compatible_device import get_device_type
|
||||||
|
try:
|
||||||
|
import torch_npu
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def rms_norm_forward_npu(self, hidden_states):
|
||||||
|
"npu rms fused operator for RMSNorm.forward from diffsynth\models\general_modules.py"
|
||||||
|
if hidden_states.dtype != self.weight.dtype:
|
||||||
|
hidden_states = hidden_states.to(self.weight.dtype)
|
||||||
|
return torch_npu.npu_rms_norm(hidden_states, self.weight, self.eps)[0]
|
||||||
|
|
||||||
|
|
||||||
|
def rms_norm_forward_transformers_npu(self, hidden_states):
|
||||||
|
"npu rms fused operator for transformers"
|
||||||
|
if hidden_states.dtype != self.weight.dtype:
|
||||||
|
hidden_states = hidden_states.to(self.weight.dtype)
|
||||||
|
return torch_npu.npu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
|
||||||
|
|
||||||
|
|
||||||
|
def rotary_emb_Zimage_npu(self, x_in: torch.Tensor, freqs_cis: torch.Tensor):
|
||||||
|
"npu rope fused operator for Zimage"
|
||||||
|
with torch.amp.autocast(get_device_type(), enabled=False):
|
||||||
|
freqs_cis = freqs_cis.unsqueeze(2)
|
||||||
|
cos, sin = torch.chunk(torch.view_as_real(freqs_cis), 2, dim=-1)
|
||||||
|
cos = cos.expand(-1, -1, -1, -1, 2).flatten(-2)
|
||||||
|
sin = sin.expand(-1, -1, -1, -1, 2).flatten(-2)
|
||||||
|
return torch_npu.npu_rotary_mul(x_in, cos, sin, rotary_mode="interleave").to(x_in)
|
||||||
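On Ascend NPU these fused kernels are intended to replace the eager implementations by monkey patching; a hedged sketch (the target class location follows the docstring above and is an assumption, and the snippet is only meaningful with `torch_npu` installed):

```python
# Hedged sketch, not part of the diff; requires an NPU machine with torch_npu.
from diffsynth.core.npu_patch.npu_fused_operator import rms_norm_forward_npu
from diffsynth.models.general_modules import RMSNorm  # assumed module path, see the docstring above

RMSNorm.forward = rms_norm_forward_npu  # route RMSNorm through torch_npu's fused kernel
```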
2
diffsynth/core/vram/__init__.py
Normal file
@@ -0,0 +1,2 @@
from .initialization import skip_model_initialization
from .layers import *
93
diffsynth/core/vram/disk_map.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
from safetensors import safe_open
|
||||||
|
import torch, os
|
||||||
|
|
||||||
|
|
||||||
|
class SafetensorsCompatibleTensor:
|
||||||
|
def __init__(self, tensor):
|
||||||
|
self.tensor = tensor
|
||||||
|
|
||||||
|
def get_shape(self):
|
||||||
|
return list(self.tensor.shape)
|
||||||
|
|
||||||
|
|
||||||
|
class SafetensorsCompatibleBinaryLoader:
|
||||||
|
def __init__(self, path, device):
|
||||||
|
print("Detected non-safetensors files, which may cause slower loading. It's recommended to convert it to a safetensors file.")
|
||||||
|
self.state_dict = torch.load(path, weights_only=True, map_location=device)
|
||||||
|
|
||||||
|
def keys(self):
|
||||||
|
return self.state_dict.keys()
|
||||||
|
|
||||||
|
def get_tensor(self, name):
|
||||||
|
return self.state_dict[name]
|
||||||
|
|
||||||
|
def get_slice(self, name):
|
||||||
|
return SafetensorsCompatibleTensor(self.state_dict[name])
|
||||||
|
|
||||||
|
|
||||||
|
class DiskMap:
|
||||||
|
|
||||||
|
def __init__(self, path, device, torch_dtype=None, state_dict_converter=None, buffer_size=10**9):
|
||||||
|
self.path = path if isinstance(path, list) else [path]
|
||||||
|
self.device = device
|
||||||
|
self.torch_dtype = torch_dtype
|
||||||
|
if os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE') is not None:
|
||||||
|
self.buffer_size = int(os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE'))
|
||||||
|
else:
|
||||||
|
self.buffer_size = buffer_size
|
||||||
|
self.files = []
|
||||||
|
self.flush_files()
|
||||||
|
self.name_map = {}
|
||||||
|
for file_id, file in enumerate(self.files):
|
||||||
|
for name in file.keys():
|
||||||
|
self.name_map[name] = file_id
|
||||||
|
self.rename_dict = self.fetch_rename_dict(state_dict_converter)
|
||||||
|
|
||||||
|
def flush_files(self):
|
||||||
|
if len(self.files) == 0:
|
||||||
|
for path in self.path:
|
||||||
|
if path.endswith(".safetensors"):
|
||||||
|
self.files.append(safe_open(path, framework="pt", device=str(self.device)))
|
||||||
|
else:
|
||||||
|
self.files.append(SafetensorsCompatibleBinaryLoader(path, device=self.device))
|
||||||
|
else:
|
||||||
|
for i, path in enumerate(self.path):
|
||||||
|
if path.endswith(".safetensors"):
|
||||||
|
self.files[i] = safe_open(path, framework="pt", device=str(self.device))
|
||||||
|
self.num_params = 0
|
||||||
|
|
||||||
|
def __getitem__(self, name):
|
||||||
|
if self.rename_dict is not None: name = self.rename_dict[name]
|
||||||
|
file_id = self.name_map[name]
|
||||||
|
param = self.files[file_id].get_tensor(name)
|
||||||
|
if self.torch_dtype is not None and isinstance(param, torch.Tensor):
|
||||||
|
param = param.to(self.torch_dtype)
|
||||||
|
if isinstance(param, torch.Tensor) and param.device.type == "cpu":
|
||||||
|
param = param.clone()
|
||||||
|
if isinstance(param, torch.Tensor):
|
||||||
|
self.num_params += param.numel()
|
||||||
|
if self.num_params > self.buffer_size:
|
||||||
|
self.flush_files()
|
||||||
|
return param
|
||||||
|
|
||||||
|
def fetch_rename_dict(self, state_dict_converter):
|
||||||
|
if state_dict_converter is None:
|
||||||
|
return None
|
||||||
|
state_dict = {}
|
||||||
|
for file in self.files:
|
||||||
|
for name in file.keys():
|
||||||
|
state_dict[name] = name
|
||||||
|
state_dict = state_dict_converter(state_dict)
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
if self.rename_dict is not None:
|
||||||
|
return self.rename_dict.__iter__()
|
||||||
|
else:
|
||||||
|
return self.name_map.__iter__()
|
||||||
|
|
||||||
|
def __contains__(self, x):
|
||||||
|
if self.rename_dict is not None:
|
||||||
|
return x in self.rename_dict
|
||||||
|
else:
|
||||||
|
return x in self.name_map
|
||||||
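`DiskMap` exposes a read-only, dict-like view over a checkpoint that fetches one tensor at a time and periodically reopens the files to release memory; a short sketch (the path is a placeholder):

```python
# Hedged sketch, not part of the diff.
from diffsynth.core.vram.disk_map import DiskMap

weights = DiskMap("model.safetensors", device="cpu")
for name in weights:        # iterates parameter names without materialising the whole checkpoint
    first = weights[name]   # the tensor is read from disk only when indexed
    break
```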
21
diffsynth/core/vram/initialization.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
import torch
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def skip_model_initialization(device=torch.device("meta")):
|
||||||
|
|
||||||
|
def register_empty_parameter(module, name, param):
|
||||||
|
old_register_parameter(module, name, param)
|
||||||
|
if param is not None:
|
||||||
|
param_cls = type(module._parameters[name])
|
||||||
|
kwargs = module._parameters[name].__dict__
|
||||||
|
kwargs["requires_grad"] = param.requires_grad
|
||||||
|
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
|
||||||
|
|
||||||
|
old_register_parameter = torch.nn.Module.register_parameter
|
||||||
|
torch.nn.Module.register_parameter = register_empty_parameter
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
torch.nn.Module.register_parameter = old_register_parameter
|
||||||
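A minimal sketch of the context manager above: modules built inside it end up with parameters on the meta device, so no real weight memory is held after construction.

    import torch
    from diffsynth.core.vram.initialization import skip_model_initialization  # path assumed

    with skip_model_initialization():
        layer = torch.nn.Linear(4096, 4096)
    print(layer.weight.device)  # meta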
diffsynth/core/vram/layers.py (new file, 479 lines)
@@ -0,0 +1,479 @@
import torch, copy
|
||||||
|
from typing import Union
|
||||||
|
from .initialization import skip_model_initialization
|
||||||
|
from .disk_map import DiskMap
|
||||||
|
from ..device import parse_device_type, get_device_name, IS_NPU_AVAILABLE
|
||||||
|
|
||||||
|
|
||||||
|
class AutoTorchModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
offload_dtype: torch.dtype = None,
|
||||||
|
offload_device: Union[str, torch.device] = None,
|
||||||
|
onload_dtype: torch.dtype = None,
|
||||||
|
onload_device: Union[str, torch.device] = None,
|
||||||
|
preparing_dtype: torch.dtype = None,
|
||||||
|
preparing_device: Union[str, torch.device] = None,
|
||||||
|
computation_dtype: torch.dtype = None,
|
||||||
|
computation_device: Union[str, torch.device] = None,
|
||||||
|
vram_limit: float = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.set_dtype_and_device(
|
||||||
|
offload_dtype,
|
||||||
|
offload_device,
|
||||||
|
onload_dtype,
|
||||||
|
onload_device,
|
||||||
|
preparing_dtype,
|
||||||
|
preparing_device,
|
||||||
|
computation_dtype,
|
||||||
|
computation_device,
|
||||||
|
vram_limit,
|
||||||
|
)
|
||||||
|
self.state = 0
|
||||||
|
self.name = ""
|
||||||
|
self.computation_device_type = parse_device_type(self.computation_device)
|
||||||
|
|
||||||
|
def set_dtype_and_device(
|
||||||
|
self,
|
||||||
|
offload_dtype: torch.dtype = None,
|
||||||
|
offload_device: Union[str, torch.device] = None,
|
||||||
|
onload_dtype: torch.dtype = None,
|
||||||
|
onload_device: Union[str, torch.device] = None,
|
||||||
|
preparing_dtype: torch.dtype = None,
|
||||||
|
preparing_device: Union[str, torch.device] = None,
|
||||||
|
computation_dtype: torch.dtype = None,
|
||||||
|
computation_device: Union[str, torch.device] = None,
|
||||||
|
vram_limit: float = None,
|
||||||
|
):
|
||||||
|
self.offload_dtype = offload_dtype or computation_dtype
|
||||||
|
self.offload_device = offload_device or computation_device
|
||||||
|
self.onload_dtype = onload_dtype or computation_dtype
|
||||||
|
self.onload_device = onload_device or computation_device
|
||||||
|
self.preparing_dtype = preparing_dtype or computation_dtype
|
||||||
|
self.preparing_device = preparing_device or computation_device
|
||||||
|
self.computation_dtype = computation_dtype
|
||||||
|
self.computation_device = computation_device
|
||||||
|
self.vram_limit = vram_limit
|
||||||
|
|
||||||
|
def cast_to(self, weight, dtype, device):
|
||||||
|
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||||
|
r.copy_(weight)
|
||||||
|
return r
|
||||||
|
|
||||||
|
def check_free_vram(self):
|
||||||
|
device = self.computation_device if not IS_NPU_AVAILABLE else get_device_name()
|
||||||
|
gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(device)
|
||||||
|
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
|
||||||
|
return used_memory < self.vram_limit
|
||||||
|
|
||||||
|
def offload(self):
|
||||||
|
if self.state != 0:
|
||||||
|
self.to(dtype=self.offload_dtype, device=self.offload_device)
|
||||||
|
self.state = 0
|
||||||
|
|
||||||
|
def onload(self):
|
||||||
|
if self.state != 1:
|
||||||
|
self.to(dtype=self.onload_dtype, device=self.onload_device)
|
||||||
|
self.state = 1
|
||||||
|
|
||||||
|
def param_name(self, name):
|
||||||
|
if self.name == "":
|
||||||
|
return name
|
||||||
|
else:
|
||||||
|
return self.name + "." + name
|
||||||
|
|
||||||
|
|
||||||
|
class AutoWrappedModule(AutoTorchModule):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
module: torch.nn.Module,
|
||||||
|
offload_dtype: torch.dtype = None,
|
||||||
|
offload_device: Union[str, torch.device] = None,
|
||||||
|
onload_dtype: torch.dtype = None,
|
||||||
|
onload_device: Union[str, torch.device] = None,
|
||||||
|
preparing_dtype: torch.dtype = None,
|
||||||
|
preparing_device: Union[str, torch.device] = None,
|
||||||
|
computation_dtype: torch.dtype = None,
|
||||||
|
computation_device: Union[str, torch.device] = None,
|
||||||
|
vram_limit: float = None,
|
||||||
|
name: str = "",
|
||||||
|
disk_map: DiskMap = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
offload_dtype,
|
||||||
|
offload_device,
|
||||||
|
onload_dtype,
|
||||||
|
onload_device,
|
||||||
|
preparing_dtype,
|
||||||
|
preparing_device,
|
||||||
|
computation_dtype,
|
||||||
|
computation_device,
|
||||||
|
vram_limit,
|
||||||
|
)
|
||||||
|
self.module = module
|
||||||
|
if offload_dtype == "disk":
|
||||||
|
self.name = name
|
||||||
|
self.disk_map = disk_map
|
||||||
|
self.required_params = [name for name, _ in self.module.named_parameters()]
|
||||||
|
self.disk_offload = True
|
||||||
|
else:
|
||||||
|
self.disk_offload = False
|
||||||
|
|
||||||
|
def load_from_disk(self, torch_dtype, device, copy_module=False):
|
||||||
|
if copy_module:
|
||||||
|
module = copy.deepcopy(self.module)
|
||||||
|
else:
|
||||||
|
module = self.module
|
||||||
|
state_dict = {}
|
||||||
|
for name in self.required_params:
|
||||||
|
param = self.disk_map[self.param_name(name)]
|
||||||
|
param = param.to(dtype=torch_dtype, device=device)
|
||||||
|
state_dict[name] = param
|
||||||
|
module.load_state_dict(state_dict, assign=True)
|
||||||
|
module.to(dtype=torch_dtype, device=device)
|
||||||
|
return module
|
||||||
|
|
||||||
|
def offload_to_disk(self, model: torch.nn.Module):
|
||||||
|
for buf in model.buffers():
|
||||||
|
# If some parameters are registered as buffers (not in the state dict),
# we cannot offload this module directly; recurse into its children instead.
|
||||||
|
for children in model.children():
|
||||||
|
self.offload_to_disk(children)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
model.to("meta")
|
||||||
|
|
||||||
|
def offload(self):
|
||||||
|
# offload / onload / preparing -> offload
|
||||||
|
if self.state != 0:
|
||||||
|
if self.disk_offload:
|
||||||
|
self.offload_to_disk(self.module)
|
||||||
|
else:
|
||||||
|
self.to(dtype=self.offload_dtype, device=self.offload_device)
|
||||||
|
self.state = 0
|
||||||
|
|
||||||
|
def onload(self):
|
||||||
|
# offload / onload / preparing -> onload
|
||||||
|
if self.state < 1:
|
||||||
|
if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk":
|
||||||
|
self.load_from_disk(self.onload_dtype, self.onload_device)
|
||||||
|
elif self.onload_device != "disk":
|
||||||
|
self.to(dtype=self.onload_dtype, device=self.onload_device)
|
||||||
|
self.state = 1
|
||||||
|
|
||||||
|
def preparing(self):
|
||||||
|
# onload / preparing -> preparing
|
||||||
|
if self.state != 2:
|
||||||
|
if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk":
|
||||||
|
self.load_from_disk(self.preparing_dtype, self.preparing_device)
|
||||||
|
elif self.preparing_device != "disk":
|
||||||
|
self.to(dtype=self.preparing_dtype, device=self.preparing_device)
|
||||||
|
self.state = 2
|
||||||
|
|
||||||
|
def cast_to(self, module, dtype, device):
|
||||||
|
return copy.deepcopy(module).to(dtype=dtype, device=device)
|
||||||
|
|
||||||
|
def computation(self):
|
||||||
|
# onload / preparing -> computation (temporary)
|
||||||
|
if self.state == 2:
|
||||||
|
torch_dtype, device = self.preparing_dtype, self.preparing_device
|
||||||
|
else:
|
||||||
|
torch_dtype, device = self.onload_dtype, self.onload_device
|
||||||
|
if torch_dtype == self.computation_dtype and device == self.computation_device:
|
||||||
|
module = self.module
|
||||||
|
elif self.disk_offload and device == "disk":
|
||||||
|
module = self.load_from_disk(self.computation_dtype, self.computation_device, copy_module=True)
|
||||||
|
else:
|
||||||
|
module = self.cast_to(self.module, dtype=self.computation_dtype, device=self.computation_device)
|
||||||
|
return module
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
if self.state == 1 and (self.vram_limit is None or self.check_free_vram()):
|
||||||
|
self.preparing()
|
||||||
|
module = self.computation()
|
||||||
|
return module(*args, **kwargs)
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
if name in self.__dict__ or name == "module":
|
||||||
|
return super().__getattr__(name)
|
||||||
|
else:
|
||||||
|
return getattr(self.module, name)
|
||||||
|
|
||||||
|
|
||||||
|
class AutoWrappedNonRecurseModule(AutoWrappedModule):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
module: torch.nn.Module,
|
||||||
|
offload_dtype: torch.dtype = None,
|
||||||
|
offload_device: Union[str, torch.device] = None,
|
||||||
|
onload_dtype: torch.dtype = None,
|
||||||
|
onload_device: Union[str, torch.device] = None,
|
||||||
|
preparing_dtype: torch.dtype = None,
|
||||||
|
preparing_device: Union[str, torch.device] = None,
|
||||||
|
computation_dtype: torch.dtype = None,
|
||||||
|
computation_device: Union[str, torch.device] = None,
|
||||||
|
vram_limit: float = None,
|
||||||
|
name: str = "",
|
||||||
|
disk_map: DiskMap = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
module,
|
||||||
|
offload_dtype,
|
||||||
|
offload_device,
|
||||||
|
onload_dtype,
|
||||||
|
onload_device,
|
||||||
|
preparing_dtype,
|
||||||
|
preparing_device,
|
||||||
|
computation_dtype,
|
||||||
|
computation_device,
|
||||||
|
vram_limit,
|
||||||
|
name,
|
||||||
|
disk_map,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
if self.disk_offload:
|
||||||
|
self.required_params = [name for name, _ in self.module.named_parameters(recurse=False)]
|
||||||
|
|
||||||
|
def load_from_disk(self, torch_dtype, device, copy_module=False):
|
||||||
|
if copy_module:
|
||||||
|
module = copy.deepcopy(self.module)
|
||||||
|
else:
|
||||||
|
module = self.module
|
||||||
|
state_dict = {}
|
||||||
|
for name in self.required_params:
|
||||||
|
param = self.disk_map[self.param_name(name)]
|
||||||
|
param = param.to(dtype=torch_dtype, device=device)
|
||||||
|
state_dict[name] = param
|
||||||
|
module.load_state_dict(state_dict, assign=True, strict=False)
|
||||||
|
return module
|
||||||
|
|
||||||
|
def offload_to_disk(self, model: torch.nn.Module):
|
||||||
|
for name in self.required_params:
|
||||||
|
getattr(self, name).to("meta")
|
||||||
|
|
||||||
|
def cast_to(self, module, dtype, device):
|
||||||
|
# Parameter casting is implemented in the model architecture.
|
||||||
|
return module
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
if name in self.__dict__ or name == "module":
|
||||||
|
return super().__getattr__(name)
|
||||||
|
else:
|
||||||
|
return getattr(self.module, name)
|
||||||
|
|
||||||
|
|
||||||
|
class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
module: torch.nn.Linear,
|
||||||
|
offload_dtype: torch.dtype = None,
|
||||||
|
offload_device: Union[str, torch.device] = None,
|
||||||
|
onload_dtype: torch.dtype = None,
|
||||||
|
onload_device: Union[str, torch.device] = None,
|
||||||
|
preparing_dtype: torch.dtype = None,
|
||||||
|
preparing_device: Union[str, torch.device] = None,
|
||||||
|
computation_dtype: torch.dtype = None,
|
||||||
|
computation_device: Union[str, torch.device] = None,
|
||||||
|
vram_limit: float = None,
|
||||||
|
name: str = "",
|
||||||
|
disk_map: DiskMap = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
with skip_model_initialization():
|
||||||
|
super().__init__(
|
||||||
|
in_features=module.in_features,
|
||||||
|
out_features=module.out_features,
|
||||||
|
bias=module.bias is not None,
|
||||||
|
)
|
||||||
|
self.set_dtype_and_device(
|
||||||
|
offload_dtype,
|
||||||
|
offload_device,
|
||||||
|
onload_dtype,
|
||||||
|
onload_device,
|
||||||
|
preparing_dtype,
|
||||||
|
preparing_device,
|
||||||
|
computation_dtype,
|
||||||
|
computation_device,
|
||||||
|
vram_limit,
|
||||||
|
)
|
||||||
|
self.weight = module.weight
|
||||||
|
self.bias = module.bias
|
||||||
|
self.state = 0
|
||||||
|
self.name = name
|
||||||
|
self.lora_A_weights = []
|
||||||
|
self.lora_B_weights = []
|
||||||
|
self.lora_merger = None
|
||||||
|
self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
|
||||||
|
self.computation_device_type = parse_device_type(self.computation_device)
|
||||||
|
|
||||||
|
if offload_dtype == "disk":
|
||||||
|
self.disk_map = disk_map
|
||||||
|
self.disk_offload = True
|
||||||
|
else:
|
||||||
|
self.disk_offload = False
|
||||||
|
|
||||||
|
def fp8_linear(
|
||||||
|
self,
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
bias: torch.Tensor = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
device = input.device
|
||||||
|
origin_dtype = input.dtype
|
||||||
|
origin_shape = input.shape
|
||||||
|
input = input.reshape(-1, origin_shape[-1])
|
||||||
|
|
||||||
|
x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
|
||||||
|
fp8_max = 448.0
|
||||||
|
# For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
|
||||||
|
# To avoid overflow and ensure numerical compatibility during FP8 computation,
|
||||||
|
# we scale down the input by 2.0 in advance.
|
||||||
|
# This scaling will be compensated later during the final result scaling.
|
||||||
|
if self.computation_dtype == torch.float8_e4m3fnuz:
|
||||||
|
fp8_max = fp8_max / 2.0
|
||||||
|
scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
|
||||||
|
scale_b = torch.ones((weight.shape[0], 1)).to(device=device)
|
||||||
|
input = input / (scale_a + 1e-8)
|
||||||
|
input = input.to(self.computation_dtype)
|
||||||
|
weight = weight.to(self.computation_dtype)
|
||||||
|
bias = bias.to(torch.bfloat16)
|
||||||
|
|
||||||
|
result = torch._scaled_mm(
|
||||||
|
input,
|
||||||
|
weight.T,
|
||||||
|
scale_a=scale_a,
|
||||||
|
scale_b=scale_b.T,
|
||||||
|
bias=bias,
|
||||||
|
out_dtype=origin_dtype,
|
||||||
|
)
|
||||||
|
new_shape = origin_shape[:-1] + result.shape[-1:]
|
||||||
|
result = result.reshape(new_shape)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def load_from_disk(self, torch_dtype, device, assign=True):
|
||||||
|
weight = self.disk_map[self.name + ".weight"].to(dtype=torch_dtype, device=device)
|
||||||
|
bias = None if self.bias is None else self.disk_map[self.name + ".bias"].to(dtype=torch_dtype, device=device)
|
||||||
|
if assign:
|
||||||
|
state_dict = {"weight": weight}
|
||||||
|
if bias is not None: state_dict["bias"] = bias
|
||||||
|
self.load_state_dict(state_dict, assign=True)
|
||||||
|
return weight, bias
|
||||||
|
|
||||||
|
def offload(self):
|
||||||
|
# offload / onload / preparing -> offload
|
||||||
|
if self.state != 0:
|
||||||
|
if self.disk_offload:
|
||||||
|
self.to("meta")
|
||||||
|
else:
|
||||||
|
self.to(dtype=self.offload_dtype, device=self.offload_device)
|
||||||
|
self.state = 0
|
||||||
|
|
||||||
|
def onload(self):
|
||||||
|
# offload / onload / preparing -> onload
|
||||||
|
if self.state < 1:
|
||||||
|
if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk":
|
||||||
|
self.load_from_disk(self.onload_dtype, self.onload_device)
|
||||||
|
elif self.onload_device != "disk":
|
||||||
|
self.to(dtype=self.onload_dtype, device=self.onload_device)
|
||||||
|
self.state = 1
|
||||||
|
|
||||||
|
def preparing(self):
|
||||||
|
# onload / preparing -> preparing
|
||||||
|
if self.state != 2:
|
||||||
|
if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk":
|
||||||
|
self.load_from_disk(self.preparing_dtype, self.preparing_device)
|
||||||
|
elif self.preparing_device != "disk":
|
||||||
|
self.to(dtype=self.preparing_dtype, device=self.preparing_device)
|
||||||
|
self.state = 2
|
||||||
|
|
||||||
|
def computation(self):
|
||||||
|
# onload / preparing -> computation (temporary)
|
||||||
|
if self.state == 2:
|
||||||
|
torch_dtype, device = self.preparing_dtype, self.preparing_device
|
||||||
|
else:
|
||||||
|
torch_dtype, device = self.onload_dtype, self.onload_device
|
||||||
|
if torch_dtype == self.computation_dtype and device == self.computation_device:
|
||||||
|
weight, bias = self.weight, self.bias
|
||||||
|
elif self.disk_offload and device == "disk":
|
||||||
|
weight, bias = self.load_from_disk(self.computation_dtype, self.computation_device, assign=False)
|
||||||
|
else:
|
||||||
|
weight = self.cast_to(self.weight, self.computation_dtype, self.computation_device)
|
||||||
|
bias = None if self.bias is None else self.cast_to(self.bias, self.computation_dtype, self.computation_device)
|
||||||
|
return weight, bias
|
||||||
|
|
||||||
|
def linear_forward(self, x, weight, bias):
|
||||||
|
if self.enable_fp8:
|
||||||
|
out = self.fp8_linear(x, weight, bias)
|
||||||
|
else:
|
||||||
|
out = torch.nn.functional.linear(x, weight, bias)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def lora_forward(self, x, out):
|
||||||
|
if self.lora_merger is None:
|
||||||
|
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
|
||||||
|
out = out + x @ lora_A.T @ lora_B.T
|
||||||
|
else:
|
||||||
|
lora_output = []
|
||||||
|
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
|
||||||
|
lora_output.append(x @ lora_A.T @ lora_B.T)
|
||||||
|
lora_output = torch.stack(lora_output)
|
||||||
|
out = self.lora_merger(out, lora_output)
|
||||||
|
return out
|
||||||
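For reference, the low-rank update in lora_forward composes two small matrices; a shape sketch with illustrative sizes (rank 16, 512 features):

    import torch
    x = torch.randn(2, 77, 512)       # (..., in_features)
    lora_A = torch.randn(16, 512)     # (rank, in_features)
    lora_B = torch.randn(512, 16)     # (out_features, rank)
    delta = x @ lora_A.T @ lora_B.T   # (..., out_features), added onto the base linear output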
|
|
||||||
|
def forward(self, x, *args, **kwargs):
|
||||||
|
if self.state == 1 and (self.vram_limit is None or self.check_free_vram()):
|
||||||
|
self.preparing()
|
||||||
|
weight, bias = self.computation()
|
||||||
|
out = self.linear_forward(x, weight, bias)
|
||||||
|
if len(self.lora_A_weights) > 0:
|
||||||
|
out = self.lora_forward(x, out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, name_prefix="", disk_map=None, **kwargs):
|
||||||
|
if isinstance(model, AutoWrappedNonRecurseModule):
|
||||||
|
model = model.module
|
||||||
|
for name, module in model.named_children():
|
||||||
|
layer_name = name if name_prefix == "" else name_prefix + "." + name
|
||||||
|
for source_module, target_module in module_map.items():
|
||||||
|
if isinstance(module, source_module):
|
||||||
|
module_ = target_module(module, **vram_config, vram_limit=vram_limit, name=layer_name, disk_map=disk_map, **kwargs)
|
||||||
|
if isinstance(module_, AutoWrappedNonRecurseModule):
|
||||||
|
enable_vram_management_recursively(module_, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)
|
||||||
|
setattr(model, name, module_)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
enable_vram_management_recursively(module, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def fill_vram_config(model, vram_config):
|
||||||
|
vram_config_ = vram_config.copy()
|
||||||
|
vram_config_["onload_dtype"] = vram_config["computation_dtype"]
|
||||||
|
vram_config_["onload_device"] = vram_config["computation_device"]
|
||||||
|
vram_config_["preparing_dtype"] = vram_config["computation_dtype"]
|
||||||
|
vram_config_["preparing_device"] = vram_config["computation_device"]
|
||||||
|
for k in vram_config:
|
||||||
|
if vram_config[k] != vram_config_[k]:
|
||||||
|
print(f"No fine-grained VRAM configuration is provided for {model.__class__.__name__}. [`onload`, `preparing`, `computation`] will be the same state. `vram_config` is set to {vram_config_}")
|
||||||
|
break
|
||||||
|
return vram_config_
|
||||||
|
|
||||||
|
|
||||||
|
def enable_vram_management(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, disk_map=None, **kwargs):
|
||||||
|
for source_module, target_module in module_map.items():
|
||||||
|
# If no fine-grained VRAM configuration is provided, the entire model will be managed uniformly.
|
||||||
|
if isinstance(model, source_module):
|
||||||
|
vram_config = fill_vram_config(model, vram_config)
|
||||||
|
model = target_module(model, **vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
enable_vram_management_recursively(model, module_map, vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs)
|
||||||
|
# `vram_management_enabled` is a flag that allows the pipeline to determine whether VRAM management is enabled.
|
||||||
|
model.vram_management_enabled = True
|
||||||
|
return model
|
||||||
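A minimal sketch of wrapping a toy model with the VRAM-management layers above (editor's example, not part of this diff; the import path is assumed and a CUDA device is assumed for computation_device). The pipeline later drives the offload()/onload()/preparing() state machine on the wrapped layers.

    import torch
    from diffsynth.core.vram.layers import (
        AutoWrappedLinear, AutoWrappedModule, enable_vram_management,  # path assumed
    )

    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.LayerNorm(64))
    model = enable_vram_management(
        model,
        module_map={torch.nn.Linear: AutoWrappedLinear, torch.nn.LayerNorm: AutoWrappedModule},
        vram_config=dict(
            offload_dtype=torch.bfloat16, offload_device="cpu",
            onload_dtype=torch.bfloat16, onload_device="cuda",
            preparing_dtype=torch.bfloat16, preparing_device="cuda",
            computation_dtype=torch.bfloat16, computation_device="cuda",
        ),
    )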
@@ -1 +0,0 @@
from .video import VideoData, save_video, save_frames
@@ -1,41 +0,0 @@
import torch, os, torchvision
|
|
||||||
from torchvision import transforms
|
|
||||||
import pandas as pd
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TextImageDataset(torch.utils.data.Dataset):
|
|
||||||
def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
|
|
||||||
self.steps_per_epoch = steps_per_epoch
|
|
||||||
metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
|
|
||||||
self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
|
|
||||||
self.text = metadata["text"].to_list()
|
|
||||||
self.height = height
|
|
||||||
self.width = width
|
|
||||||
self.image_processor = transforms.Compose(
|
|
||||||
[
|
|
||||||
transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
|
|
||||||
transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
|
|
||||||
transforms.ToTensor(),
|
|
||||||
transforms.Normalize([0.5], [0.5]),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def __getitem__(self, index):
|
|
||||||
data_id = torch.randint(0, len(self.path), (1,))[0]
|
|
||||||
data_id = (data_id + index) % len(self.path) # For fixed seed.
|
|
||||||
text = self.text[data_id]
|
|
||||||
image = Image.open(self.path[data_id]).convert("RGB")
|
|
||||||
target_height, target_width = self.height, self.width
|
|
||||||
width, height = image.size
|
|
||||||
scale = max(target_width / width, target_height / height)
|
|
||||||
shape = [round(height*scale),round(width*scale)]
|
|
||||||
image = torchvision.transforms.functional.resize(image,shape,interpolation=transforms.InterpolationMode.BILINEAR)
|
|
||||||
image = self.image_processor(image)
|
|
||||||
return {"text": text, "image": image}
|
|
||||||
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return self.steps_per_epoch
|
|
||||||
diffsynth/diffusion/__init__.py (new file, 6 lines)
@@ -0,0 +1,6 @@
from .flow_match import FlowMatchScheduler
|
||||||
|
from .training_module import DiffusionTrainingModule
|
||||||
|
from .logger import ModelLogger
|
||||||
|
from .runner import launch_training_task, launch_data_process_task
|
||||||
|
from .parsers import *
|
||||||
|
from .loss import *
|
||||||
diffsynth/diffusion/base_pipeline.py (new file, 459 lines)
@@ -0,0 +1,459 @@
from PIL import Image
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from einops import repeat, reduce
|
||||||
|
from typing import Union
|
||||||
|
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
|
||||||
|
from ..core.device.npu_compatible_device import get_device_type
|
||||||
|
from ..utils.lora import GeneralLoRALoader
|
||||||
|
from ..models.model_loader import ModelPool
|
||||||
|
from ..utils.controlnet import ControlNetInput
|
||||||
|
from ..core.device import get_device_name, IS_NPU_AVAILABLE
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineUnit:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
seperate_cfg: bool = False,
|
||||||
|
take_over: bool = False,
|
||||||
|
input_params: tuple[str] = None,
|
||||||
|
output_params: tuple[str] = None,
|
||||||
|
input_params_posi: dict[str, str] = None,
|
||||||
|
input_params_nega: dict[str, str] = None,
|
||||||
|
onload_model_names: tuple[str] = None
|
||||||
|
):
|
||||||
|
self.seperate_cfg = seperate_cfg
|
||||||
|
self.take_over = take_over
|
||||||
|
self.input_params = input_params
|
||||||
|
self.output_params = output_params
|
||||||
|
self.input_params_posi = input_params_posi
|
||||||
|
self.input_params_nega = input_params_nega
|
||||||
|
self.onload_model_names = onload_model_names
|
||||||
|
|
||||||
|
def fetch_input_params(self):
|
||||||
|
params = []
|
||||||
|
if self.input_params is not None:
|
||||||
|
for param in self.input_params:
|
||||||
|
params.append(param)
|
||||||
|
if self.input_params_posi is not None:
|
||||||
|
for _, param in self.input_params_posi.items():
|
||||||
|
params.append(param)
|
||||||
|
if self.input_params_nega is not None:
|
||||||
|
for _, param in self.input_params_nega.items():
|
||||||
|
params.append(param)
|
||||||
|
params = sorted(list(set(params)))
|
||||||
|
return params
|
||||||
|
|
||||||
|
def fetch_output_params(self):
|
||||||
|
params = []
|
||||||
|
if self.output_params is not None:
|
||||||
|
for param in self.output_params:
|
||||||
|
params.append(param)
|
||||||
|
return params
|
||||||
|
|
||||||
|
def process(self, pipe, **kwargs) -> dict:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def post_process(self, pipe, **kwargs) -> dict:
|
||||||
|
return {}
|
||||||
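To illustrate how these hooks are consumed by the unit runner further down, a hypothetical PipelineUnit subclass (editor's sketch; the text_encoder model name and prompt_emb key are placeholders, not from this diff):

    class EncodePromptUnit(PipelineUnit):
        def __init__(self):
            super().__init__(
                seperate_cfg=True,
                input_params_posi={"prompt": "prompt"},
                input_params_nega={"prompt": "negative_prompt"},
                onload_model_names=("text_encoder",),
            )

        def process(self, pipe, prompt):
            pipe.load_models_to_device(self.onload_model_names)
            return {"prompt_emb": pipe.text_encoder(prompt)}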
|
|
||||||
|
|
||||||
|
class BasePipeline(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
device=get_device_type(), torch_dtype=torch.float16,
|
||||||
|
height_division_factor=64, width_division_factor=64,
|
||||||
|
time_division_factor=None, time_division_remainder=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
# The device and torch_dtype are used for storing intermediate variables, not models.
|
||||||
|
self.device = device
|
||||||
|
self.torch_dtype = torch_dtype
|
||||||
|
self.device_type = parse_device_type(device)
|
||||||
|
# The following parameters are used for shape check.
|
||||||
|
self.height_division_factor = height_division_factor
|
||||||
|
self.width_division_factor = width_division_factor
|
||||||
|
self.time_division_factor = time_division_factor
|
||||||
|
self.time_division_remainder = time_division_remainder
|
||||||
|
# VRAM management
|
||||||
|
self.vram_management_enabled = False
|
||||||
|
# Pipeline Unit Runner
|
||||||
|
self.unit_runner = PipelineUnitRunner()
|
||||||
|
# LoRA Loader
|
||||||
|
self.lora_loader = GeneralLoRALoader
|
||||||
|
|
||||||
|
|
||||||
|
def to(self, *args, **kwargs):
|
||||||
|
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
|
||||||
|
if device is not None:
|
||||||
|
self.device = device
|
||||||
|
if dtype is not None:
|
||||||
|
self.torch_dtype = dtype
|
||||||
|
super().to(*args, **kwargs)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
def check_resize_height_width(self, height, width, num_frames=None):
|
||||||
|
# Shape check
|
||||||
|
if height % self.height_division_factor != 0:
|
||||||
|
height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
|
||||||
|
print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
|
||||||
|
if width % self.width_division_factor != 0:
|
||||||
|
width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
|
||||||
|
print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
|
||||||
|
if num_frames is None:
|
||||||
|
return height, width
|
||||||
|
else:
|
||||||
|
if num_frames % self.time_division_factor != self.time_division_remainder:
|
||||||
|
num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
|
||||||
|
print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
|
||||||
|
return height, width, num_frames
|
||||||
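A quick worked check of the rounding rule above, assuming a division factor of 64:

    height_division_factor = 64
    height = 720
    rounded = (height + height_division_factor - 1) // height_division_factor * height_division_factor
    assert rounded == 768  # 720 is rounded up to the next multiple of 64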
|
|
||||||
|
|
||||||
|
def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1):
|
||||||
|
# Transform a PIL.Image to torch.Tensor
|
||||||
|
image = torch.Tensor(np.array(image, dtype=np.float32))
|
||||||
|
image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
|
||||||
|
image = image * ((max_value - min_value) / 255) + min_value
|
||||||
|
image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {}))
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1):
|
||||||
|
# Transform a list of PIL.Image to torch.Tensor
|
||||||
|
video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video]
|
||||||
|
video = torch.stack(video, dim=pattern.index("T") // 2)
|
||||||
|
return video
|
||||||
|
|
||||||
|
|
||||||
|
def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1):
|
||||||
|
# Transform a torch.Tensor to PIL.Image
|
||||||
|
if pattern != "H W C":
|
||||||
|
vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean")
|
||||||
|
image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255)
|
||||||
|
image = image.to(device="cpu", dtype=torch.uint8)
|
||||||
|
image = Image.fromarray(image.numpy())
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1):
|
||||||
|
# Transform a torch.Tensor to list of PIL.Image
|
||||||
|
if pattern != "T H W C":
|
||||||
|
vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean")
|
||||||
|
video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
|
||||||
|
return video
|
||||||
|
|
||||||
|
|
||||||
|
def load_models_to_device(self, model_names):
|
||||||
|
if self.vram_management_enabled:
|
||||||
|
# offload models
|
||||||
|
for name, model in self.named_children():
|
||||||
|
if name not in model_names:
|
||||||
|
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
|
||||||
|
if hasattr(model, "offload"):
|
||||||
|
model.offload()
|
||||||
|
else:
|
||||||
|
for module in model.modules():
|
||||||
|
if hasattr(module, "offload"):
|
||||||
|
module.offload()
|
||||||
|
getattr(torch, self.device_type).empty_cache()
|
||||||
|
# onload models
|
||||||
|
for name, model in self.named_children():
|
||||||
|
if name in model_names:
|
||||||
|
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
|
||||||
|
if hasattr(model, "onload"):
|
||||||
|
model.onload()
|
||||||
|
else:
|
||||||
|
for module in model.modules():
|
||||||
|
if hasattr(module, "onload"):
|
||||||
|
module.onload()
|
||||||
|
|
||||||
|
|
||||||
|
def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None):
|
||||||
|
# Initialize Gaussian noise
|
||||||
|
generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed)
|
||||||
|
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype)
|
||||||
|
noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
|
||||||
|
return noise
|
||||||
|
|
||||||
|
|
||||||
|
def get_vram(self):
|
||||||
|
device = self.device if not IS_NPU_AVAILABLE else get_device_name()
|
||||||
|
return getattr(torch, self.device_type).mem_get_info(device)[1] / (1024 ** 3)
|
||||||
|
|
||||||
|
def get_module(self, model, name):
|
||||||
|
if "." in name:
|
||||||
|
name, suffix = name[:name.index(".")], name[name.index(".") + 1:]
|
||||||
|
if name.isdigit():
|
||||||
|
return self.get_module(model[int(name)], suffix)
|
||||||
|
else:
|
||||||
|
return self.get_module(getattr(model, name), suffix)
|
||||||
|
else:
|
||||||
|
return getattr(model, name)
|
||||||
|
|
||||||
|
def freeze_except(self, model_names):
|
||||||
|
self.eval()
|
||||||
|
self.requires_grad_(False)
|
||||||
|
for name in model_names:
|
||||||
|
module = self.get_module(self, name)
|
||||||
|
if module is None:
|
||||||
|
print(f"No {name} models in the pipeline. We cannot enable training on the model. If this occurs during the data processing stage, it is normal.")
|
||||||
|
continue
|
||||||
|
module.train()
|
||||||
|
module.requires_grad_(True)
|
||||||
|
|
||||||
|
|
||||||
|
def blend_with_mask(self, base, addition, mask):
|
||||||
|
return base * (1 - mask) + addition * mask
|
||||||
|
|
||||||
|
|
||||||
|
def step(self, scheduler, latents, progress_id, noise_pred, input_latents=None, inpaint_mask=None, **kwargs):
|
||||||
|
timestep = scheduler.timesteps[progress_id]
|
||||||
|
if inpaint_mask is not None:
|
||||||
|
noise_pred_expected = scheduler.return_to_timestep(scheduler.timesteps[progress_id], latents, input_latents)
|
||||||
|
noise_pred = self.blend_with_mask(noise_pred_expected, noise_pred, inpaint_mask)
|
||||||
|
latents_next = scheduler.step(noise_pred, timestep, latents)
|
||||||
|
return latents_next
|
||||||
|
|
||||||
|
|
||||||
|
def split_pipeline_units(self, model_names: list[str]):
|
||||||
|
return PipelineUnitGraph().split_pipeline_units(self.units, model_names)
|
||||||
|
|
||||||
|
|
||||||
|
def flush_vram_management_device(self, device):
|
||||||
|
for module in self.modules():
|
||||||
|
if isinstance(module, AutoTorchModule):
|
||||||
|
module.offload_device = device
|
||||||
|
module.onload_device = device
|
||||||
|
module.preparing_device = device
|
||||||
|
module.computation_device = device
|
||||||
|
|
||||||
|
|
||||||
|
def load_lora(
|
||||||
|
self,
|
||||||
|
module: torch.nn.Module,
|
||||||
|
lora_config: Union[ModelConfig, str] = None,
|
||||||
|
alpha=1,
|
||||||
|
hotload=None,
|
||||||
|
state_dict=None,
|
||||||
|
verbose=1,
|
||||||
|
):
|
||||||
|
if state_dict is None:
|
||||||
|
if isinstance(lora_config, str):
|
||||||
|
lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
|
||||||
|
else:
|
||||||
|
lora_config.download_if_necessary()
|
||||||
|
lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
|
||||||
|
else:
|
||||||
|
lora = state_dict
|
||||||
|
lora_loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device)
|
||||||
|
lora = lora_loader.convert_state_dict(lora)
|
||||||
|
if hotload is None:
|
||||||
|
hotload = hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")
|
||||||
|
if hotload:
|
||||||
|
if not (hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")):
|
||||||
|
raise ValueError("VRAM Management is not enabled. LoRA hotloading is not supported.")
|
||||||
|
updated_num = 0
|
||||||
|
for _, module in module.named_modules():
|
||||||
|
if isinstance(module, AutoWrappedLinear):
|
||||||
|
name = module.name
|
||||||
|
lora_a_name = f'{name}.lora_A.weight'
|
||||||
|
lora_b_name = f'{name}.lora_B.weight'
|
||||||
|
if lora_a_name in lora and lora_b_name in lora:
|
||||||
|
updated_num += 1
|
||||||
|
module.lora_A_weights.append(lora[lora_a_name] * alpha)
|
||||||
|
module.lora_B_weights.append(lora[lora_b_name])
|
||||||
|
if verbose >= 1:
|
||||||
|
print(f"{updated_num} tensors are patched by LoRA. You can use `pipe.clear_lora()` to clear all LoRA layers.")
|
||||||
|
else:
|
||||||
|
lora_loader.fuse_lora_to_base_model(module, lora, alpha=alpha)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_lora(self, verbose=1):
|
||||||
|
cleared_num = 0
|
||||||
|
for name, module in self.named_modules():
|
||||||
|
if isinstance(module, AutoWrappedLinear):
|
||||||
|
if hasattr(module, "lora_A_weights"):
|
||||||
|
if len(module.lora_A_weights) > 0:
|
||||||
|
cleared_num += 1
|
||||||
|
module.lora_A_weights.clear()
|
||||||
|
if hasattr(module, "lora_B_weights"):
|
||||||
|
module.lora_B_weights.clear()
|
||||||
|
if verbose >= 1:
|
||||||
|
print(f"{cleared_num} LoRA layers are cleared.")
|
||||||
|
|
||||||
|
|
||||||
|
def download_and_load_models(self, model_configs: list[ModelConfig] = [], vram_limit: float = None):
|
||||||
|
model_pool = ModelPool()
|
||||||
|
for model_config in model_configs:
|
||||||
|
model_config.download_if_necessary()
|
||||||
|
vram_config = model_config.vram_config()
|
||||||
|
vram_config["computation_dtype"] = vram_config["computation_dtype"] or self.torch_dtype
|
||||||
|
vram_config["computation_device"] = vram_config["computation_device"] or self.device
|
||||||
|
model_pool.auto_load_model(
|
||||||
|
model_config.path,
|
||||||
|
vram_config=vram_config,
|
||||||
|
vram_limit=vram_limit,
|
||||||
|
clear_parameters=model_config.clear_parameters,
|
||||||
|
state_dict=model_config.state_dict,
|
||||||
|
)
|
||||||
|
return model_pool
|
||||||
|
|
||||||
|
|
||||||
|
def check_vram_management_state(self):
|
||||||
|
vram_management_enabled = False
|
||||||
|
for module in self.children():
|
||||||
|
if hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled"):
|
||||||
|
vram_management_enabled = True
|
||||||
|
return vram_management_enabled
|
||||||
|
|
||||||
|
|
||||||
|
def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
|
||||||
|
if inputs_shared.get("positive_only_lora", None) is not None:
|
||||||
|
self.clear_lora(verbose=0)
|
||||||
|
self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
|
||||||
|
noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
|
||||||
|
if cfg_scale != 1.0:
|
||||||
|
if inputs_shared.get("positive_only_lora", None) is not None:
|
||||||
|
self.clear_lora(verbose=0)
|
||||||
|
noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
|
||||||
|
if isinstance(noise_pred_posi, tuple):
|
||||||
|
# Handle different latent output types separately, e.g. video and audio latents.
|
||||||
|
noise_pred = tuple(
|
||||||
|
n_nega + cfg_scale * (n_posi - n_nega)
|
||||||
|
for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||||
|
else:
|
||||||
|
noise_pred = noise_pred_posi
|
||||||
|
return noise_pred
|
||||||
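The blend in cfg_guided_model_fn is the standard classifier-free guidance formula; a toy-tensor sketch:

    import torch
    noise_pred_posi = torch.randn(1, 16, 32, 32)
    noise_pred_nega = torch.randn(1, 16, 32, 32)
    cfg_scale = 5.0
    noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
    # cfg_scale == 1.0 reduces to the positive prediction alone, which is why the negative pass is skipped then.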
|
|
||||||
|
|
||||||
|
class PipelineUnitGraph:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def build_edges(self, units: list[PipelineUnit]):
|
||||||
|
# Establish dependencies between units
# so that downstream computation units that depend on them can be found.
|
||||||
|
last_compute_unit_id = {}
|
||||||
|
edges = []
|
||||||
|
for unit_id, unit in enumerate(units):
|
||||||
|
for input_param in unit.fetch_input_params():
|
||||||
|
if input_param in last_compute_unit_id:
|
||||||
|
edges.append((last_compute_unit_id[input_param], unit_id))
|
||||||
|
for output_param in unit.fetch_output_params():
|
||||||
|
last_compute_unit_id[output_param] = unit_id
|
||||||
|
return edges
|
||||||
|
|
||||||
|
def build_chains(self, units: list[PipelineUnit]):
|
||||||
|
# Establish updating chains for each variable
|
||||||
|
# to track their computation process.
|
||||||
|
params = sum([unit.fetch_input_params() + unit.fetch_output_params() for unit in units], [])
|
||||||
|
params = sorted(list(set(params)))
|
||||||
|
chains = {param: [] for param in params}
|
||||||
|
for unit_id, unit in enumerate(units):
|
||||||
|
for param in unit.fetch_output_params():
|
||||||
|
chains[param].append(unit_id)
|
||||||
|
return chains
|
||||||
|
|
||||||
|
def search_direct_unit_ids(self, units: list[PipelineUnit], model_names: list[str]):
|
||||||
|
# Search for units that directly participate in the model's computation.
|
||||||
|
related_unit_ids = []
|
||||||
|
for unit_id, unit in enumerate(units):
|
||||||
|
for model_name in model_names:
|
||||||
|
if unit.onload_model_names is not None and model_name in unit.onload_model_names:
|
||||||
|
related_unit_ids.append(unit_id)
|
||||||
|
break
|
||||||
|
return related_unit_ids
|
||||||
|
|
||||||
|
def search_related_unit_ids(self, edges, start_unit_ids, direction="target"):
|
||||||
|
# Search for subsequent related computation units.
|
||||||
|
related_unit_ids = [unit_id for unit_id in start_unit_ids]
|
||||||
|
while True:
|
||||||
|
neighbors = []
|
||||||
|
for source, target in edges:
|
||||||
|
if direction == "target" and source in related_unit_ids and target not in related_unit_ids:
|
||||||
|
neighbors.append(target)
|
||||||
|
elif direction == "source" and source not in related_unit_ids and target in related_unit_ids:
|
||||||
|
neighbors.append(source)
|
||||||
|
neighbors = sorted(list(set(neighbors)))
|
||||||
|
if len(neighbors) == 0:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
related_unit_ids.extend(neighbors)
|
||||||
|
related_unit_ids = sorted(list(set(related_unit_ids)))
|
||||||
|
return related_unit_ids
|
||||||
|
|
||||||
|
def search_updating_unit_ids(self, units: list[PipelineUnit], chains, related_unit_ids):
|
||||||
|
# If the input parameters of this subgraph are updated outside the subgraph,
|
||||||
|
# search for the units where these updates occur.
|
||||||
|
first_compute_unit_id = {}
|
||||||
|
for unit_id in related_unit_ids:
|
||||||
|
for param in units[unit_id].fetch_input_params():
|
||||||
|
if param not in first_compute_unit_id:
|
||||||
|
first_compute_unit_id[param] = unit_id
|
||||||
|
updating_unit_ids = []
|
||||||
|
for param in first_compute_unit_id:
|
||||||
|
unit_id = first_compute_unit_id[param]
|
||||||
|
chain = chains[param]
|
||||||
|
if unit_id in chain and chain.index(unit_id) != len(chain) - 1:
|
||||||
|
for unit_id_ in chain[chain.index(unit_id) + 1:]:
|
||||||
|
if unit_id_ not in related_unit_ids:
|
||||||
|
updating_unit_ids.append(unit_id_)
|
||||||
|
related_unit_ids.extend(updating_unit_ids)
|
||||||
|
related_unit_ids = sorted(list(set(related_unit_ids)))
|
||||||
|
return related_unit_ids
|
||||||
|
|
||||||
|
def split_pipeline_units(self, units: list[PipelineUnit], model_names: list[str]):
|
||||||
|
# Split the computation graph,
|
||||||
|
# separating all model-related computations.
|
||||||
|
related_unit_ids = self.search_direct_unit_ids(units, model_names)
|
||||||
|
edges = self.build_edges(units)
|
||||||
|
chains = self.build_chains(units)
|
||||||
|
while True:
|
||||||
|
num_related_unit_ids = len(related_unit_ids)
|
||||||
|
related_unit_ids = self.search_related_unit_ids(edges, related_unit_ids, "target")
|
||||||
|
related_unit_ids = self.search_updating_unit_ids(units, chains, related_unit_ids)
|
||||||
|
if len(related_unit_ids) == num_related_unit_ids:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
num_related_unit_ids = len(related_unit_ids)
|
||||||
|
related_units = [units[i] for i in related_unit_ids]
|
||||||
|
unrelated_units = [units[i] for i in range(len(units)) if i not in related_unit_ids]
|
||||||
|
return related_units, unrelated_units
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineUnitRunner:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __call__(self, unit: PipelineUnit, pipe: BasePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict]:
|
||||||
|
if unit.take_over:
|
||||||
|
# Let the pipeline unit take over this function.
|
||||||
|
inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega)
|
||||||
|
elif unit.seperate_cfg:
|
||||||
|
# Positive side
|
||||||
|
processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()}
|
||||||
|
if unit.input_params is not None:
|
||||||
|
for name in unit.input_params:
|
||||||
|
processor_inputs[name] = inputs_shared.get(name)
|
||||||
|
processor_outputs = unit.process(pipe, **processor_inputs)
|
||||||
|
inputs_posi.update(processor_outputs)
|
||||||
|
# Negative side
|
||||||
|
if inputs_shared["cfg_scale"] != 1:
|
||||||
|
processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()}
|
||||||
|
if unit.input_params is not None:
|
||||||
|
for name in unit.input_params:
|
||||||
|
processor_inputs[name] = inputs_shared.get(name)
|
||||||
|
processor_outputs = unit.process(pipe, **processor_inputs)
|
||||||
|
inputs_nega.update(processor_outputs)
|
||||||
|
else:
|
||||||
|
inputs_nega.update(processor_outputs)
|
||||||
|
else:
|
||||||
|
processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params}
|
||||||
|
processor_outputs = unit.process(pipe, **processor_inputs)
|
||||||
|
inputs_shared.update(processor_outputs)
|
||||||
|
return inputs_shared, inputs_posi, inputs_nega
|
||||||
diffsynth/diffusion/flow_match.py (new file, 236 lines)
@@ -0,0 +1,236 @@
import torch, math
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
|
||||||
|
class FlowMatchScheduler():
|
||||||
|
|
||||||
|
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning"] = "FLUX.1"):
|
||||||
|
self.set_timesteps_fn = {
|
||||||
|
"FLUX.1": FlowMatchScheduler.set_timesteps_flux,
|
||||||
|
"Wan": FlowMatchScheduler.set_timesteps_wan,
|
||||||
|
"Qwen-Image": FlowMatchScheduler.set_timesteps_qwen_image,
|
||||||
|
"FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
|
||||||
|
"Z-Image": FlowMatchScheduler.set_timesteps_z_image,
|
||||||
|
"LTX-2": FlowMatchScheduler.set_timesteps_ltx2,
|
||||||
|
"Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning,
|
||||||
|
}.get(template, FlowMatchScheduler.set_timesteps_flux)
|
||||||
|
self.num_train_timesteps = 1000
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_timesteps_flux(num_inference_steps=100, denoising_strength=1.0, shift=None):
|
||||||
|
sigma_min = 0.003/1.002
|
||||||
|
sigma_max = 1.0
|
||||||
|
shift = 3 if shift is None else shift
|
||||||
|
num_train_timesteps = 1000
|
||||||
|
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||||
|
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
|
||||||
|
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||||
|
timesteps = sigmas * num_train_timesteps
|
||||||
|
return sigmas, timesteps
|
||||||
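A one-line worked example of the sigma shift used above (editor's note):

    # With shift=3, a mid-schedule sigma of 0.5 becomes 3 * 0.5 / (1 + (3 - 1) * 0.5) = 0.75,
    # i.e. the shift pushes the schedule toward the high-noise end.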
|
|
||||||
|
@staticmethod
|
||||||
|
def set_timesteps_wan(num_inference_steps=100, denoising_strength=1.0, shift=None):
|
||||||
|
sigma_min = 0.0
|
||||||
|
sigma_max = 1.0
|
||||||
|
shift = 5 if shift is None else shift
|
||||||
|
num_train_timesteps = 1000
|
||||||
|
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||||
|
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
|
||||||
|
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||||
|
timesteps = sigmas * num_train_timesteps
|
||||||
|
return sigmas, timesteps
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _calculate_shift_qwen_image(image_seq_len, base_seq_len=256, max_seq_len=8192, base_shift=0.5, max_shift=0.9):
|
||||||
|
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||||
|
b = base_shift - m * base_seq_len
|
||||||
|
mu = image_seq_len * m + b
|
||||||
|
return mu
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_timesteps_qwen_image(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None):
|
||||||
|
sigma_min = 0.0
|
||||||
|
sigma_max = 1.0
|
||||||
|
num_train_timesteps = 1000
|
||||||
|
shift_terminal = 0.02
|
||||||
|
# Sigmas
|
||||||
|
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||||
|
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
|
||||||
|
# Mu
|
||||||
|
if exponential_shift_mu is not None:
|
||||||
|
mu = exponential_shift_mu
|
||||||
|
elif dynamic_shift_len is not None:
|
||||||
|
mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len)
|
||||||
|
else:
|
||||||
|
mu = 0.8
|
||||||
|
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
|
||||||
|
# Shift terminal
|
||||||
|
one_minus_z = 1 - sigmas
|
||||||
|
scale_factor = one_minus_z[-1] / (1 - shift_terminal)
|
||||||
|
sigmas = 1 - (one_minus_z / scale_factor)
|
||||||
|
# Timesteps
|
||||||
|
timesteps = sigmas * num_train_timesteps
|
||||||
|
return sigmas, timesteps
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_timesteps_qwen_image_lightning(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None):
|
||||||
|
sigma_min = 0.0
|
||||||
|
sigma_max = 1.0
|
||||||
|
num_train_timesteps = 1000
|
||||||
|
base_shift = math.log(3)
|
||||||
|
max_shift = math.log(3)
|
||||||
|
# Sigmas
|
||||||
|
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||||
|
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
|
||||||
|
# Mu
|
||||||
|
if exponential_shift_mu is not None:
|
||||||
|
mu = exponential_shift_mu
|
||||||
|
elif dynamic_shift_len is not None:
|
||||||
|
mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len, base_shift=base_shift, max_shift=max_shift)
|
||||||
|
else:
|
||||||
|
mu = 0.8
|
||||||
|
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
|
||||||
|
# Timesteps
|
||||||
|
timesteps = sigmas * num_train_timesteps
|
||||||
|
return sigmas, timesteps
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def compute_empirical_mu(image_seq_len, num_steps):
|
||||||
|
a1, b1 = 8.73809524e-05, 1.89833333
|
||||||
|
a2, b2 = 0.00016927, 0.45666666
|
||||||
|
|
||||||
|
if image_seq_len > 4300:
|
||||||
|
mu = a2 * image_seq_len + b2
|
||||||
|
return float(mu)
|
||||||
|
|
||||||
|
m_200 = a2 * image_seq_len + b2
|
||||||
|
m_10 = a1 * image_seq_len + b1
|
||||||
|
|
||||||
|
a = (m_200 - m_10) / 190.0
|
||||||
|
b = m_200 - 200.0 * a
|
||||||
|
mu = a * num_steps + b
|
||||||
|
|
||||||
|
return float(mu)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None):
|
||||||
|
sigma_min = 1 / num_inference_steps
|
||||||
|
sigma_max = 1.0
|
||||||
|
num_train_timesteps = 1000
|
||||||
|
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||||
|
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
|
||||||
|
if dynamic_shift_len is None:
|
||||||
|
# If you ask me why I set mu=0.8,
|
||||||
|
# I can only say that it yields better training results.
|
||||||
|
mu = 0.8
|
||||||
|
else:
|
||||||
|
mu = FlowMatchScheduler.compute_empirical_mu(dynamic_shift_len, num_inference_steps)
|
||||||
|
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
|
||||||
|
timesteps = sigmas * num_train_timesteps
|
||||||
|
return sigmas, timesteps
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None):
|
||||||
|
sigma_min = 0.0
|
||||||
|
sigma_max = 1.0
|
||||||
|
shift = 3 if shift is None else shift
|
||||||
|
num_train_timesteps = 1000
|
||||||
|
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||||
|
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
|
||||||
|
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||||
|
timesteps = sigmas * num_train_timesteps
|
||||||
|
if target_timesteps is not None:
|
||||||
|
target_timesteps = target_timesteps.to(dtype=timesteps.dtype, device=timesteps.device)
|
||||||
|
for timestep in target_timesteps:
|
||||||
|
timestep_id = torch.argmin((timesteps - timestep).abs())
|
||||||
|
timesteps[timestep_id] = timestep
|
||||||
|
return sigmas, timesteps
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, terminal=0.1, special_case=None):
|
||||||
|
num_train_timesteps = 1000
|
||||||
|
if special_case == "stage2":
|
||||||
|
sigmas = torch.Tensor([0.909375, 0.725, 0.421875])
|
||||||
|
elif special_case == "ditilled_stage1":
|
||||||
|
sigmas = torch.Tensor([1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875])
|
||||||
|
else:
|
||||||
|
dynamic_shift_len = dynamic_shift_len or 4096
|
||||||
|
sigma_shift = FlowMatchScheduler._calculate_shift_qwen_image(
|
||||||
|
image_seq_len=dynamic_shift_len,
|
||||||
|
base_seq_len=1024,
|
||||||
|
max_seq_len=4096,
|
||||||
|
base_shift=0.95,
|
||||||
|
max_shift=2.05,
|
||||||
|
)
|
||||||
|
sigma_min = 0.0
|
||||||
|
sigma_max = 1.0
|
||||||
|
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||||
|
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
|
||||||
|
sigmas = math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1))
|
||||||
|
# Shift terminal
|
||||||
|
one_minus_z = 1.0 - sigmas
|
||||||
|
scale_factor = one_minus_z[-1] / (1 - terminal)
|
||||||
|
sigmas = 1.0 - (one_minus_z / scale_factor)
|
||||||
|
timesteps = sigmas * num_train_timesteps
|
||||||
|
return sigmas, timesteps
|
||||||
|
|
||||||
|
def set_training_weight(self):
|
||||||
|
steps = 1000
|
||||||
|
x = self.timesteps
|
||||||
|
y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
|
||||||
|
y_shifted = y - y.min()
|
||||||
|
bsmntw_weighing = y_shifted * (steps / y_shifted.sum())
|
||||||
|
if len(self.timesteps) != 1000:
|
||||||
|
# This is an empirical formula.
|
||||||
|
bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps)
|
||||||
|
bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1]
|
||||||
|
self.linear_timesteps_weights = bsmntw_weighing
|
||||||
|
|
||||||
|
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
|
||||||
|
self.sigmas, self.timesteps = self.set_timesteps_fn(
|
||||||
|
num_inference_steps=num_inference_steps,
|
||||||
|
denoising_strength=denoising_strength,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
if training:
|
||||||
|
self.set_training_weight()
|
||||||
|
self.training = True
|
||||||
|
else:
|
||||||
|
self.training = False
|
||||||
|
|
||||||
|
def step(self, model_output, timestep, sample, to_final=False, **kwargs):
|
||||||
|
if isinstance(timestep, torch.Tensor):
|
||||||
|
timestep = timestep.cpu()
|
||||||
|
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||||
|
sigma = self.sigmas[timestep_id]
|
||||||
|
if to_final or timestep_id + 1 >= len(self.timesteps):
|
||||||
|
sigma_ = 0
|
||||||
|
else:
|
||||||
|
sigma_ = self.sigmas[timestep_id + 1]
|
||||||
|
prev_sample = sample + model_output * (sigma_ - sigma)
|
||||||
|
return prev_sample
|
||||||
|
|
||||||
|
def return_to_timestep(self, timestep, sample, sample_stablized):
|
||||||
|
if isinstance(timestep, torch.Tensor):
|
||||||
|
timestep = timestep.cpu()
|
||||||
|
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||||
|
sigma = self.sigmas[timestep_id]
|
||||||
|
model_output = (sample - sample_stablized) / sigma
|
||||||
|
return model_output
|
||||||
|
|
||||||
|
def add_noise(self, original_samples, noise, timestep):
|
||||||
|
if isinstance(timestep, torch.Tensor):
|
||||||
|
timestep = timestep.cpu()
|
||||||
|
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||||
|
sigma = self.sigmas[timestep_id]
|
||||||
|
sample = (1 - sigma) * original_samples + sigma * noise
|
||||||
|
return sample
|
||||||
|
|
||||||
|
def training_target(self, sample, noise, timestep):
|
||||||
|
target = noise - sample
|
||||||
|
return target
|
||||||
|
|
||||||
|
def training_weight(self, timestep):
|
||||||
|
timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
|
||||||
|
weights = self.linear_timesteps_weights[timestep_id]
|
||||||
|
return weights
|
||||||
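An illustrative denoising loop with the scheduler above (editor's sketch; the latent shape and the zero "prediction" are placeholders for a real DiT output):

    import torch
    from diffsynth.diffusion import FlowMatchScheduler  # re-exported by the __init__.py added above

    scheduler = FlowMatchScheduler(template="Wan")
    scheduler.set_timesteps(num_inference_steps=30, shift=5)
    latents = torch.randn(1, 16, 9, 60, 104)
    for timestep in scheduler.timesteps:
        noise_pred = torch.zeros_like(latents)  # stand-in for the real model prediction
        latents = scheduler.step(noise_pred, timestep, latents)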
diffsynth/diffusion/logger.py (new file, 43 lines)
@@ -0,0 +1,43 @@
import os, torch
|
||||||
|
from accelerate import Accelerator
|
||||||
|
|
||||||
|
|
||||||
|
class ModelLogger:
|
||||||
|
def __init__(self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x:x):
|
||||||
|
self.output_path = output_path
|
||||||
|
self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
|
||||||
|
self.state_dict_converter = state_dict_converter
|
||||||
|
self.num_steps = 0
|
||||||
|
|
||||||
|
|
||||||
|
def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None, **kwargs):
|
||||||
|
self.num_steps += 1
|
||||||
|
if save_steps is not None and self.num_steps % save_steps == 0:
|
||||||
|
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id):
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
state_dict = accelerator.get_state_dict(model)
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
|
||||||
|
state_dict = self.state_dict_converter(state_dict)
|
||||||
|
os.makedirs(self.output_path, exist_ok=True)
|
||||||
|
path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
|
||||||
|
accelerator.save(state_dict, path, safe_serialization=True)
|
||||||
|
|
||||||
|
|
||||||
|
def on_training_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None):
|
||||||
|
if save_steps is not None and self.num_steps % save_steps != 0:
|
||||||
|
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
def save_model(self, accelerator: Accelerator, model: torch.nn.Module, file_name):
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
state_dict = accelerator.get_state_dict(model)
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
|
||||||
|
state_dict = self.state_dict_converter(state_dict)
|
||||||
|
os.makedirs(self.output_path, exist_ok=True)
|
||||||
|
path = os.path.join(self.output_path, file_name)
|
||||||
|
accelerator.save(state_dict, path, safe_serialization=True)
|
||||||
126
diffsynth/diffusion/loss.py
Normal file
126
diffsynth/diffusion/loss.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
from .base_pipeline import BasePipeline
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
|
||||||
|
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
|
||||||
|
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
|
||||||
|
|
||||||
|
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
|
||||||
|
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||||
|
|
||||||
|
noise = torch.randn_like(inputs["input_latents"])
|
||||||
|
inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
|
||||||
|
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
|
||||||
|
|
||||||
|
if "first_frame_latents" in inputs:
|
||||||
|
inputs["latents"][:, :, 0:1] = inputs["first_frame_latents"]
|
||||||
|
|
||||||
|
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
||||||
|
noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)
|
||||||
|
|
||||||
|
if "first_frame_latents" in inputs:
|
||||||
|
noise_pred = noise_pred[:, :, 1:]
|
||||||
|
training_target = training_target[:, :, 1:]
|
||||||
|
|
||||||
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||||
|
loss = loss * pipe.scheduler.training_weight(timestep)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def DirectDistillLoss(pipe: BasePipeline, **inputs):
|
||||||
|
pipe.scheduler.set_timesteps(inputs["num_inference_steps"])
|
||||||
|
pipe.scheduler.training = True
|
||||||
|
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
||||||
|
for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
|
||||||
|
timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||||
|
noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep, progress_id=progress_id)
|
||||||
|
inputs["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs)
|
||||||
|
loss = torch.nn.functional.mse_loss(inputs["latents"].float(), inputs["input_latents"].float())
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
class TrajectoryImitationLoss(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.initialized = False
|
||||||
|
|
||||||
|
def initialize(self, device):
|
||||||
|
import lpips # TODO: remove it
|
||||||
|
self.loss_fn = lpips.LPIPS(net='alex').to(device)
|
||||||
|
self.initialized = True
|
||||||
|
|
||||||
|
def fetch_trajectory(self, pipe: BasePipeline, timesteps_student, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
|
||||||
|
trajectory = [inputs_shared["latents"].clone()]
|
||||||
|
|
||||||
|
pipe.scheduler.set_timesteps(num_inference_steps, target_timesteps=timesteps_student)
|
||||||
|
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
||||||
|
for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
|
||||||
|
timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||||
|
noise_pred = pipe.cfg_guided_model_fn(
|
||||||
|
pipe.model_fn, cfg_scale,
|
||||||
|
inputs_shared, inputs_posi, inputs_nega,
|
||||||
|
**models, timestep=timestep, progress_id=progress_id
|
||||||
|
)
|
||||||
|
inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)
|
||||||
|
|
||||||
|
trajectory.append(inputs_shared["latents"].clone())
|
||||||
|
return pipe.scheduler.timesteps, trajectory
|
||||||
|
|
||||||
|
def align_trajectory(self, pipe: BasePipeline, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
|
||||||
|
loss = 0
|
||||||
|
pipe.scheduler.set_timesteps(num_inference_steps, training=True)
|
||||||
|
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
||||||
|
for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
|
||||||
|
timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||||
|
|
||||||
|
progress_id_teacher = torch.argmin((timesteps_teacher - timestep).abs())
|
||||||
|
inputs_shared["latents"] = trajectory_teacher[progress_id_teacher]
|
||||||
|
|
||||||
|
noise_pred = pipe.cfg_guided_model_fn(
|
||||||
|
pipe.model_fn, cfg_scale,
|
||||||
|
inputs_shared, inputs_posi, inputs_nega,
|
||||||
|
**models, timestep=timestep, progress_id=progress_id
|
||||||
|
)
|
||||||
|
|
||||||
|
sigma = pipe.scheduler.sigmas[progress_id]
|
||||||
|
sigma_ = 0 if progress_id + 1 >= len(pipe.scheduler.timesteps) else pipe.scheduler.sigmas[progress_id + 1]
|
||||||
|
if progress_id + 1 >= len(pipe.scheduler.timesteps):
|
||||||
|
latents_ = trajectory_teacher[-1]
|
||||||
|
else:
|
||||||
|
progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs())
|
||||||
|
latents_ = trajectory_teacher[progress_id_teacher]
|
||||||
|
|
||||||
|
target = (latents_ - inputs_shared["latents"]) / (sigma_ - sigma)
|
||||||
|
loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def compute_regularization(self, pipe: BasePipeline, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
|
||||||
|
inputs_shared["latents"] = trajectory_teacher[0]
|
||||||
|
pipe.scheduler.set_timesteps(num_inference_steps)
|
||||||
|
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
||||||
|
for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
|
||||||
|
timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||||
|
noise_pred = pipe.cfg_guided_model_fn(
|
||||||
|
pipe.model_fn, cfg_scale,
|
||||||
|
inputs_shared, inputs_posi, inputs_nega,
|
||||||
|
**models, timestep=timestep, progress_id=progress_id
|
||||||
|
)
|
||||||
|
inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)
|
||||||
|
|
||||||
|
image_pred = pipe.vae_decoder(inputs_shared["latents"])
|
||||||
|
image_real = pipe.vae_decoder(trajectory_teacher[-1])
|
||||||
|
loss = self.loss_fn(image_pred.float(), image_real.float())
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def forward(self, pipe: BasePipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||||
|
if not self.initialized:
|
||||||
|
self.initialize(pipe.device)
|
||||||
|
with torch.no_grad():
|
||||||
|
pipe.scheduler.set_timesteps(8)
|
||||||
|
timesteps_teacher, trajectory_teacher = self.fetch_trajectory(inputs_shared["teacher"], pipe.scheduler.timesteps, inputs_shared, inputs_posi, inputs_nega, 50, 2)
|
||||||
|
timesteps_teacher = timesteps_teacher.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||||
|
loss_1 = self.align_trajectory(pipe, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
|
||||||
|
loss_2 = self.compute_regularization(pipe, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
|
||||||
|
loss = loss_1 + loss_2
|
||||||
|
return loss
|
||||||
70
diffsynth/diffusion/parsers.py
Normal file
70
diffsynth/diffusion/parsers.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
def add_dataset_base_config(parser: argparse.ArgumentParser):
|
||||||
|
parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.")
|
||||||
|
parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.")
|
||||||
|
parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.")
|
||||||
|
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
|
||||||
|
parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in the metadata. Comma-separated.")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
def add_image_size_config(parser: argparse.ArgumentParser):
|
||||||
|
parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||||
|
parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||||
|
parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
def add_video_size_config(parser: argparse.ArgumentParser):
|
||||||
|
parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||||
|
parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||||
|
parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.")
|
||||||
|
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
def add_model_config(parser: argparse.ArgumentParser):
|
||||||
|
parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.")
|
||||||
|
parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")
|
||||||
|
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
|
||||||
|
parser.add_argument("--fp8_models", default=None, help="Models with FP8 precision, comma-separated.")
|
||||||
|
parser.add_argument("--offload_models", default=None, help="Models with offload, comma-separated. Only used in splited training.")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
def add_training_config(parser: argparse.ArgumentParser):
|
||||||
|
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
|
||||||
|
parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
|
||||||
|
parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.")
|
||||||
|
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
|
||||||
|
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
|
||||||
|
parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
def add_output_config(parser: argparse.ArgumentParser):
|
||||||
|
parser.add_argument("--output_path", type=str, default="./models", help="Output save path.")
|
||||||
|
parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.")
|
||||||
|
parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
def add_lora_config(parser: argparse.ArgumentParser):
|
||||||
|
parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
|
||||||
|
parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
|
||||||
|
parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
|
||||||
|
parser.add_argument("--lora_checkpoint", type=str, default=None, help="Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.")
|
||||||
|
parser.add_argument("--preset_lora_path", type=str, default=None, help="Path to the preset LoRA checkpoint. If provided, this LoRA will be fused to the base model.")
|
||||||
|
parser.add_argument("--preset_lora_model", type=str, default=None, help="Which model the preset LoRA is fused to.")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
def add_gradient_config(parser: argparse.ArgumentParser):
|
||||||
|
parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
|
||||||
|
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
|
||||||
|
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
def add_general_config(parser: argparse.ArgumentParser):
|
||||||
|
parser = add_dataset_base_config(parser)
|
||||||
|
parser = add_model_config(parser)
|
||||||
|
parser = add_training_config(parser)
|
||||||
|
parser = add_output_config(parser)
|
||||||
|
parser = add_lora_config(parser)
|
||||||
|
parser = add_gradient_config(parser)
|
||||||
|
return parser
|
||||||
72
diffsynth/diffusion/runner.py
Normal file
72
diffsynth/diffusion/runner.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
import os, torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
from accelerate import Accelerator
|
||||||
|
from .training_module import DiffusionTrainingModule
|
||||||
|
from .logger import ModelLogger
|
||||||
|
|
||||||
|
|
||||||
|
def launch_training_task(
|
||||||
|
accelerator: Accelerator,
|
||||||
|
dataset: torch.utils.data.Dataset,
|
||||||
|
model: DiffusionTrainingModule,
|
||||||
|
model_logger: ModelLogger,
|
||||||
|
learning_rate: float = 1e-5,
|
||||||
|
weight_decay: float = 1e-2,
|
||||||
|
num_workers: int = 1,
|
||||||
|
save_steps: int = None,
|
||||||
|
num_epochs: int = 1,
|
||||||
|
args = None,
|
||||||
|
):
|
||||||
|
if args is not None:
|
||||||
|
learning_rate = args.learning_rate
|
||||||
|
weight_decay = args.weight_decay
|
||||||
|
num_workers = args.dataset_num_workers
|
||||||
|
save_steps = args.save_steps
|
||||||
|
num_epochs = args.num_epochs
|
||||||
|
|
||||||
|
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
|
||||||
|
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
|
||||||
|
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
|
||||||
|
model.to(device=accelerator.device)
|
||||||
|
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
|
||||||
|
|
||||||
|
for epoch_id in range(num_epochs):
|
||||||
|
for data in tqdm(dataloader):
|
||||||
|
with accelerator.accumulate(model):
|
||||||
|
optimizer.zero_grad()
|
||||||
|
if dataset.load_from_cache:
|
||||||
|
loss = model({}, inputs=data)
|
||||||
|
else:
|
||||||
|
loss = model(data)
|
||||||
|
accelerator.backward(loss)
|
||||||
|
optimizer.step()
|
||||||
|
model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
|
||||||
|
scheduler.step()
|
||||||
|
if save_steps is None:
|
||||||
|
model_logger.on_epoch_end(accelerator, model, epoch_id)
|
||||||
|
model_logger.on_training_end(accelerator, model, save_steps)
|
||||||
|
|
||||||
|
|
||||||
|
def launch_data_process_task(
|
||||||
|
accelerator: Accelerator,
|
||||||
|
dataset: torch.utils.data.Dataset,
|
||||||
|
model: DiffusionTrainingModule,
|
||||||
|
model_logger: ModelLogger,
|
||||||
|
num_workers: int = 8,
|
||||||
|
args = None,
|
||||||
|
):
|
||||||
|
if args is not None:
|
||||||
|
num_workers = args.dataset_num_workers
|
||||||
|
|
||||||
|
dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
|
||||||
|
model.to(device=accelerator.device)
|
||||||
|
model, dataloader = accelerator.prepare(model, dataloader)
|
||||||
|
|
||||||
|
for data_id, data in enumerate(tqdm(dataloader)):
|
||||||
|
with accelerator.accumulate(model):
|
||||||
|
with torch.no_grad():
|
||||||
|
folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
|
||||||
|
os.makedirs(folder, exist_ok=True)
|
||||||
|
save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f"{data_id}.pth")
|
||||||
|
data = model(data)
|
||||||
|
torch.save(data, save_path)
|
||||||
263
diffsynth/diffusion/training_module.py
Normal file
263
diffsynth/diffusion/training_module.py
Normal file
@@ -0,0 +1,263 @@
|
|||||||
|
import torch, json, os
|
||||||
|
from ..core import ModelConfig, load_state_dict
|
||||||
|
from ..utils.controlnet import ControlNetInput
|
||||||
|
from peft import LoraConfig, inject_adapter_in_model
|
||||||
|
|
||||||
|
|
||||||
|
class DiffusionTrainingModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
|
||||||
|
def to(self, *args, **kwargs):
|
||||||
|
for name, model in self.named_children():
|
||||||
|
model.to(*args, **kwargs)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
def trainable_modules(self):
|
||||||
|
trainable_modules = filter(lambda p: p.requires_grad, self.parameters())
|
||||||
|
return trainable_modules
|
||||||
|
|
||||||
|
|
||||||
|
def trainable_param_names(self):
|
||||||
|
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.named_parameters()))
|
||||||
|
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
|
||||||
|
return trainable_param_names
|
||||||
|
|
||||||
|
|
||||||
|
def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None, upcast_dtype=None):
|
||||||
|
if lora_alpha is None:
|
||||||
|
lora_alpha = lora_rank
|
||||||
|
if isinstance(target_modules, list) and len(target_modules) == 1:
|
||||||
|
target_modules = target_modules[0]
|
||||||
|
lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules)
|
||||||
|
model = inject_adapter_in_model(lora_config, model)
|
||||||
|
if upcast_dtype is not None:
|
||||||
|
for param in model.parameters():
|
||||||
|
if param.requires_grad:
|
||||||
|
param.data = param.to(upcast_dtype)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def mapping_lora_state_dict(self, state_dict):
|
||||||
|
new_state_dict = {}
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
if "lora_A.weight" in key or "lora_B.weight" in key:
|
||||||
|
new_key = key.replace("lora_A.weight", "lora_A.default.weight").replace("lora_B.weight", "lora_B.default.weight")
|
||||||
|
new_state_dict[new_key] = value
|
||||||
|
elif "lora_A.default.weight" in key or "lora_B.default.weight" in key:
|
||||||
|
new_state_dict[key] = value
|
||||||
|
return new_state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def export_trainable_state_dict(self, state_dict, remove_prefix=None):
|
||||||
|
trainable_param_names = self.trainable_param_names()
|
||||||
|
state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names}
|
||||||
|
if remove_prefix is not None:
|
||||||
|
state_dict_ = {}
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
if name.startswith(remove_prefix):
|
||||||
|
name = name[len(remove_prefix):]
|
||||||
|
state_dict_[name] = param
|
||||||
|
state_dict = state_dict_
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def transfer_data_to_device(self, data, device, torch_float_dtype=None):
|
||||||
|
if data is None:
|
||||||
|
return data
|
||||||
|
elif isinstance(data, torch.Tensor):
|
||||||
|
data = data.to(device)
|
||||||
|
if torch_float_dtype is not None and data.dtype in [torch.float, torch.float16, torch.bfloat16]:
|
||||||
|
data = data.to(torch_float_dtype)
|
||||||
|
return data
|
||||||
|
elif isinstance(data, tuple):
|
||||||
|
data = tuple(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data)
|
||||||
|
return data
|
||||||
|
elif isinstance(data, list):
|
||||||
|
data = list(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data)
|
||||||
|
return data
|
||||||
|
elif isinstance(data, dict):
|
||||||
|
data = {i: self.transfer_data_to_device(data[i], device, torch_float_dtype) for i in data}
|
||||||
|
return data
|
||||||
|
else:
|
||||||
|
return data
|
||||||
|
|
||||||
|
def parse_vram_config(self, fp8=False, offload=False, device="cpu"):
|
||||||
|
if fp8:
|
||||||
|
return {
|
||||||
|
"offload_dtype": torch.float8_e4m3fn,
|
||||||
|
"offload_device": device,
|
||||||
|
"onload_dtype": torch.float8_e4m3fn,
|
||||||
|
"onload_device": device,
|
||||||
|
"preparing_dtype": torch.float8_e4m3fn,
|
||||||
|
"preparing_device": device,
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": device,
|
||||||
|
}
|
||||||
|
elif offload:
|
||||||
|
return {
|
||||||
|
"offload_dtype": "disk",
|
||||||
|
"offload_device": "disk",
|
||||||
|
"onload_dtype": "disk",
|
||||||
|
"onload_device": "disk",
|
||||||
|
"preparing_dtype": torch.bfloat16,
|
||||||
|
"preparing_device": device,
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": device,
|
||||||
|
"clear_parameters": True,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def parse_model_configs(self, model_paths, model_id_with_origin_paths, fp8_models=None, offload_models=None, device="cpu"):
|
||||||
|
fp8_models = [] if fp8_models is None else fp8_models.split(",")
|
||||||
|
offload_models = [] if offload_models is None else offload_models.split(",")
|
||||||
|
model_configs = []
|
||||||
|
if model_paths is not None:
|
||||||
|
model_paths = json.loads(model_paths)
|
||||||
|
for path in model_paths:
|
||||||
|
vram_config = self.parse_vram_config(
|
||||||
|
fp8=path in fp8_models,
|
||||||
|
offload=path in offload_models,
|
||||||
|
device=device
|
||||||
|
)
|
||||||
|
model_configs.append(ModelConfig(path=path, **vram_config))
|
||||||
|
if model_id_with_origin_paths is not None:
|
||||||
|
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
|
||||||
|
for model_id_with_origin_path in model_id_with_origin_paths:
|
||||||
|
vram_config = self.parse_vram_config(
|
||||||
|
fp8=model_id_with_origin_path in fp8_models,
|
||||||
|
offload=model_id_with_origin_path in offload_models,
|
||||||
|
device=device
|
||||||
|
)
|
||||||
|
config = self.parse_path_or_model_id(model_id_with_origin_path)
|
||||||
|
model_configs.append(ModelConfig(model_id=config.model_id, origin_file_pattern=config.origin_file_pattern, **vram_config))
|
||||||
|
return model_configs
|
||||||
|
|
||||||
|
|
||||||
|
def parse_path_or_model_id(self, model_id_with_origin_path, default_value=None):
|
||||||
|
if model_id_with_origin_path is None:
|
||||||
|
return default_value
|
||||||
|
elif os.path.exists(model_id_with_origin_path):
|
||||||
|
return ModelConfig(path=model_id_with_origin_path)
|
||||||
|
else:
|
||||||
|
if ":" not in model_id_with_origin_path:
|
||||||
|
raise ValueError(f"Failed to parse model config: {model_id_with_origin_path}. This is neither a valid path nor in the format of `model_id/origin_file_pattern`.")
|
||||||
|
split_id = model_id_with_origin_path.rfind(":")
|
||||||
|
model_id = model_id_with_origin_path[:split_id]
|
||||||
|
origin_file_pattern = model_id_with_origin_path[split_id + 1:]
|
||||||
|
return ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern)
|
||||||
|
|
||||||
|
|
||||||
|
def auto_detect_lora_target_modules(
|
||||||
|
self,
|
||||||
|
model: torch.nn.Module,
|
||||||
|
search_for_linear=False,
|
||||||
|
linear_detector=lambda x: min(x.weight.shape) >= 512,
|
||||||
|
block_list_detector=lambda x: isinstance(x, torch.nn.ModuleList) and len(x) > 1,
|
||||||
|
name_prefix="",
|
||||||
|
):
|
||||||
|
lora_target_modules = []
|
||||||
|
if search_for_linear:
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
module_name = name_prefix + ["", "."][name_prefix != ""] + name
|
||||||
|
if isinstance(module, torch.nn.Linear) and linear_detector(module):
|
||||||
|
lora_target_modules.append(module_name)
|
||||||
|
else:
|
||||||
|
for name, module in model.named_children():
|
||||||
|
module_name = name_prefix + ["", "."][name_prefix != ""] + name
|
||||||
|
lora_target_modules += self.auto_detect_lora_target_modules(
|
||||||
|
module,
|
||||||
|
search_for_linear=block_list_detector(module),
|
||||||
|
linear_detector=linear_detector,
|
||||||
|
block_list_detector=block_list_detector,
|
||||||
|
name_prefix=module_name,
|
||||||
|
)
|
||||||
|
return lora_target_modules
|
||||||
|
|
||||||
|
|
||||||
|
def parse_lora_target_modules(self, model, lora_target_modules):
|
||||||
|
if lora_target_modules == "":
|
||||||
|
print("No LoRA target modules specified. The framework will automatically search for them.")
|
||||||
|
lora_target_modules = self.auto_detect_lora_target_modules(model)
|
||||||
|
print(f"LoRA will be patched at {lora_target_modules}.")
|
||||||
|
else:
|
||||||
|
lora_target_modules = lora_target_modules.split(",")
|
||||||
|
return lora_target_modules
|
||||||
|
|
||||||
|
|
||||||
|
def switch_pipe_to_training_mode(
|
||||||
|
self,
|
||||||
|
pipe,
|
||||||
|
trainable_models=None,
|
||||||
|
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
|
||||||
|
preset_lora_path=None, preset_lora_model=None,
|
||||||
|
task="sft",
|
||||||
|
):
|
||||||
|
# Scheduler
|
||||||
|
pipe.scheduler.set_timesteps(1000, training=True)
|
||||||
|
|
||||||
|
# Freeze untrainable models
|
||||||
|
pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
|
||||||
|
|
||||||
|
# Preset LoRA
|
||||||
|
if preset_lora_path is not None:
|
||||||
|
pipe.load_lora(getattr(pipe, preset_lora_model), preset_lora_path)
|
||||||
|
|
||||||
|
# FP8
|
||||||
|
# FP8 relies on a model-specific memory management scheme.
|
||||||
|
# It is delegated to the subclass.
|
||||||
|
|
||||||
|
# Add LoRA to the base models
|
||||||
|
if lora_base_model is not None and not task.endswith(":data_process"):
|
||||||
|
if (not hasattr(pipe, lora_base_model)) or getattr(pipe, lora_base_model) is None:
|
||||||
|
print(f"No {lora_base_model} models in the pipeline. We cannot patch LoRA on the model. If this occurs during the data processing stage, it is normal.")
|
||||||
|
return
|
||||||
|
model = self.add_lora_to_model(
|
||||||
|
getattr(pipe, lora_base_model),
|
||||||
|
target_modules=self.parse_lora_target_modules(getattr(pipe, lora_base_model), lora_target_modules),
|
||||||
|
lora_rank=lora_rank,
|
||||||
|
upcast_dtype=pipe.torch_dtype,
|
||||||
|
)
|
||||||
|
if lora_checkpoint is not None:
|
||||||
|
state_dict = load_state_dict(lora_checkpoint)
|
||||||
|
state_dict = self.mapping_lora_state_dict(state_dict)
|
||||||
|
load_result = model.load_state_dict(state_dict, strict=False)
|
||||||
|
print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
|
||||||
|
if len(load_result[1]) > 0:
|
||||||
|
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
|
||||||
|
setattr(pipe, lora_base_model, model)
|
||||||
|
|
||||||
|
|
||||||
|
def split_pipeline_units(self, task, pipe, trainable_models=None, lora_base_model=None):
|
||||||
|
models_require_backward = []
|
||||||
|
if trainable_models is not None:
|
||||||
|
models_require_backward += trainable_models.split(",")
|
||||||
|
if lora_base_model is not None:
|
||||||
|
models_require_backward += [lora_base_model]
|
||||||
|
if task.endswith(":data_process"):
|
||||||
|
_, pipe.units = pipe.split_pipeline_units(models_require_backward)
|
||||||
|
elif task.endswith(":train"):
|
||||||
|
pipe.units, _ = pipe.split_pipeline_units(models_require_backward)
|
||||||
|
return pipe
|
||||||
|
|
||||||
|
def parse_extra_inputs(self, data, extra_inputs, inputs_shared):
|
||||||
|
controlnet_keys_map = (
|
||||||
|
("blockwise_controlnet_", "blockwise_controlnet_inputs",),
|
||||||
|
("controlnet_", "controlnet_inputs"),
|
||||||
|
)
|
||||||
|
controlnet_inputs = {}
|
||||||
|
for extra_input in extra_inputs:
|
||||||
|
for prefix, name in controlnet_keys_map:
|
||||||
|
if extra_input.startswith(prefix):
|
||||||
|
if name not in controlnet_inputs:
|
||||||
|
controlnet_inputs[name] = {}
|
||||||
|
controlnet_inputs[name][extra_input.replace(prefix, "")] = data[extra_input]
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
inputs_shared[extra_input] = data[extra_input]
|
||||||
|
for name, params in controlnet_inputs.items():
|
||||||
|
inputs_shared[name] = [ControlNetInput(**params)]
|
||||||
|
return inputs_shared
|
||||||
@@ -1,137 +0,0 @@
|
|||||||
import torch
|
|
||||||
from einops import repeat
|
|
||||||
from PIL import Image
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
class ResidualDenseBlock(torch.nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, num_feat=64, num_grow_ch=32):
|
|
||||||
super(ResidualDenseBlock, self).__init__()
|
|
||||||
self.conv1 = torch.nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
|
|
||||||
self.conv2 = torch.nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
|
|
||||||
self.conv3 = torch.nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
|
|
||||||
self.conv4 = torch.nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
|
|
||||||
self.conv5 = torch.nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
|
|
||||||
self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x1 = self.lrelu(self.conv1(x))
|
|
||||||
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
|
||||||
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
|
||||||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
|
||||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
|
||||||
return x5 * 0.2 + x
|
|
||||||
|
|
||||||
|
|
||||||
class RRDB(torch.nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, num_feat, num_grow_ch=32):
|
|
||||||
super(RRDB, self).__init__()
|
|
||||||
self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
|
|
||||||
self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
|
|
||||||
self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
out = self.rdb1(x)
|
|
||||||
out = self.rdb2(out)
|
|
||||||
out = self.rdb3(out)
|
|
||||||
return out * 0.2 + x
|
|
||||||
|
|
||||||
|
|
||||||
class RRDBNet(torch.nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, **kwargs):
|
|
||||||
super(RRDBNet, self).__init__()
|
|
||||||
self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
|
|
||||||
self.body = torch.torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
|
|
||||||
self.conv_body = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
|
||||||
# upsample
|
|
||||||
self.conv_up1 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
|
||||||
self.conv_up2 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
|
||||||
self.conv_hr = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
|
||||||
self.conv_last = torch.nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
|
||||||
self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
feat = x
|
|
||||||
feat = self.conv_first(feat)
|
|
||||||
body_feat = self.conv_body(self.body(feat))
|
|
||||||
feat = feat + body_feat
|
|
||||||
# upsample
|
|
||||||
feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
|
|
||||||
feat = self.lrelu(self.conv_up1(feat))
|
|
||||||
feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
|
|
||||||
feat = self.lrelu(self.conv_up2(feat))
|
|
||||||
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
|
|
||||||
return out
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def state_dict_converter():
|
|
||||||
return RRDBNetStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
class RRDBNetStateDictConverter:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def from_diffusers(self, state_dict):
|
|
||||||
return state_dict, {"upcast_to_float32": True}
|
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
|
||||||
return state_dict, {"upcast_to_float32": True}
|
|
||||||
|
|
||||||
|
|
||||||
class ESRGAN(torch.nn.Module):
|
|
||||||
def __init__(self, model):
|
|
||||||
super().__init__()
|
|
||||||
self.model = model
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_model_manager(model_manager):
|
|
||||||
return ESRGAN(model_manager.fetch_model("esrgan"))
|
|
||||||
|
|
||||||
def process_image(self, image):
|
|
||||||
image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)
|
|
||||||
return image
|
|
||||||
|
|
||||||
def process_images(self, images):
|
|
||||||
images = [self.process_image(image) for image in images]
|
|
||||||
images = torch.stack(images)
|
|
||||||
return images
|
|
||||||
|
|
||||||
def decode_images(self, images):
|
|
||||||
images = (images.permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
|
|
||||||
images = [Image.fromarray(image) for image in images]
|
|
||||||
return images
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def upscale(self, images, batch_size=4, progress_bar=lambda x:x):
|
|
||||||
if not isinstance(images, list):
|
|
||||||
images = [images]
|
|
||||||
is_single_image = True
|
|
||||||
else:
|
|
||||||
is_single_image = False
|
|
||||||
|
|
||||||
# Preprocess
|
|
||||||
input_tensor = self.process_images(images)
|
|
||||||
|
|
||||||
# Interpolate
|
|
||||||
output_tensor = []
|
|
||||||
for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
|
|
||||||
batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
|
|
||||||
batch_input_tensor = input_tensor[batch_id: batch_id_]
|
|
||||||
batch_input_tensor = batch_input_tensor.to(
|
|
||||||
device=self.model.conv_first.weight.device,
|
|
||||||
dtype=self.model.conv_first.weight.dtype)
|
|
||||||
batch_output_tensor = self.model(batch_input_tensor)
|
|
||||||
output_tensor.append(batch_output_tensor.cpu())
|
|
||||||
|
|
||||||
# Output
|
|
||||||
output_tensor = torch.concat(output_tensor, dim=0)
|
|
||||||
|
|
||||||
# To images
|
|
||||||
output_images = self.decode_images(output_tensor)
|
|
||||||
if is_single_image:
|
|
||||||
output_images = output_images[0]
|
|
||||||
return output_images
|
|
||||||
@@ -1,63 +0,0 @@
|
|||||||
from .runners.fast import TableManager, PyramidPatchMatcher
|
|
||||||
from PIL import Image
|
|
||||||
import numpy as np
|
|
||||||
import cupy as cp
|
|
||||||
|
|
||||||
|
|
||||||
class FastBlendSmoother:
|
|
||||||
def __init__(self):
|
|
||||||
self.batch_size = 8
|
|
||||||
self.window_size = 64
|
|
||||||
self.ebsynth_config = {
|
|
||||||
"minimum_patch_size": 5,
|
|
||||||
"threads_per_block": 8,
|
|
||||||
"num_iter": 5,
|
|
||||||
"gpu_id": 0,
|
|
||||||
"guide_weight": 10.0,
|
|
||||||
"initialize": "identity",
|
|
||||||
"tracking_window_size": 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_model_manager(model_manager):
|
|
||||||
# TODO: fetch GPU ID from model_manager
|
|
||||||
return FastBlendSmoother()
|
|
||||||
|
|
||||||
def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config):
|
|
||||||
frames_guide = [np.array(frame) for frame in frames_guide]
|
|
||||||
frames_style = [np.array(frame) for frame in frames_style]
|
|
||||||
table_manager = TableManager()
|
|
||||||
patch_match_engine = PyramidPatchMatcher(
|
|
||||||
image_height=frames_style[0].shape[0],
|
|
||||||
image_width=frames_style[0].shape[1],
|
|
||||||
channel=3,
|
|
||||||
**ebsynth_config
|
|
||||||
)
|
|
||||||
# left part
|
|
||||||
table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="FastBlend Step 1/4")
|
|
||||||
table_l = table_manager.remapping_table_to_blending_table(table_l)
|
|
||||||
table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="FastBlend Step 2/4")
|
|
||||||
# right part
|
|
||||||
table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="FastBlend Step 3/4")
|
|
||||||
table_r = table_manager.remapping_table_to_blending_table(table_r)
|
|
||||||
table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="FastBlend Step 4/4")[::-1]
|
|
||||||
# merge
|
|
||||||
frames = []
|
|
||||||
for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
|
|
||||||
weight_m = -1
|
|
||||||
weight = weight_l + weight_m + weight_r
|
|
||||||
frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
|
|
||||||
frames.append(frame)
|
|
||||||
frames = [Image.fromarray(frame.clip(0, 255).astype("uint8")) for frame in frames]
|
|
||||||
return frames
|
|
||||||
|
|
||||||
def __call__(self, rendered_frames, original_frames=None, **kwargs):
|
|
||||||
frames = self.run(
|
|
||||||
original_frames, rendered_frames,
|
|
||||||
self.batch_size, self.window_size, self.ebsynth_config
|
|
||||||
)
|
|
||||||
mempool = cp.get_default_memory_pool()
|
|
||||||
pinned_mempool = cp.get_default_pinned_memory_pool()
|
|
||||||
mempool.free_all_blocks()
|
|
||||||
pinned_mempool.free_all_blocks()
|
|
||||||
return frames
|
|
||||||
@@ -1,397 +0,0 @@
|
|||||||
from .runners import AccurateModeRunner, FastModeRunner, BalancedModeRunner, InterpolationModeRunner, InterpolationModeSingleFrameRunner
|
|
||||||
from .data import VideoData, get_video_fps, save_video, search_for_images
|
|
||||||
import os
|
|
||||||
import gradio as gr
|
|
||||||
|
|
||||||
|
|
||||||
def check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder):
|
|
||||||
frames_guide = VideoData(video_guide, video_guide_folder)
|
|
||||||
frames_style = VideoData(video_style, video_style_folder)
|
|
||||||
message = ""
|
|
||||||
if len(frames_guide) < len(frames_style):
|
|
||||||
message += f"The number of frames mismatches. Only the first {len(frames_guide)} frames of style video will be used.\n"
|
|
||||||
frames_style.set_length(len(frames_guide))
|
|
||||||
elif len(frames_guide) > len(frames_style):
|
|
||||||
message += f"The number of frames mismatches. Only the first {len(frames_style)} frames of guide video will be used.\n"
|
|
||||||
frames_guide.set_length(len(frames_style))
|
|
||||||
height_guide, width_guide = frames_guide.shape()
|
|
||||||
height_style, width_style = frames_style.shape()
|
|
||||||
if height_guide != height_style or width_guide != width_style:
|
|
||||||
message += f"The shape of frames mismatches. The frames in style video will be resized to (height: {height_guide}, width: {width_guide})\n"
|
|
||||||
frames_style.set_shape(height_guide, width_guide)
|
|
||||||
return frames_guide, frames_style, message
|
|
||||||
|
|
||||||
|
|
||||||
def smooth_video(
|
|
||||||
video_guide,
|
|
||||||
video_guide_folder,
|
|
||||||
video_style,
|
|
||||||
video_style_folder,
|
|
||||||
mode,
|
|
||||||
window_size,
|
|
||||||
batch_size,
|
|
||||||
tracking_window_size,
|
|
||||||
output_path,
|
|
||||||
fps,
|
|
||||||
minimum_patch_size,
|
|
||||||
num_iter,
|
|
||||||
guide_weight,
|
|
||||||
initialize,
|
|
||||||
progress = None,
|
|
||||||
):
|
|
||||||
# input
|
|
||||||
frames_guide, frames_style, message = check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder)
|
|
||||||
if len(message) > 0:
|
|
||||||
print(message)
|
|
||||||
# output
|
|
||||||
if output_path == "":
|
|
||||||
if video_style is None:
|
|
||||||
output_path = os.path.join(video_style_folder, "output")
|
|
||||||
else:
|
|
||||||
output_path = os.path.join(os.path.split(video_style)[0], "output")
|
|
||||||
os.makedirs(output_path, exist_ok=True)
|
|
||||||
print("No valid output_path. Your video will be saved here:", output_path)
|
|
||||||
elif not os.path.exists(output_path):
|
|
||||||
os.makedirs(output_path, exist_ok=True)
|
|
||||||
print("Your video will be saved here:", output_path)
|
|
||||||
frames_path = os.path.join(output_path, "frames")
|
|
||||||
video_path = os.path.join(output_path, "video.mp4")
|
|
||||||
os.makedirs(frames_path, exist_ok=True)
|
|
||||||
# process
|
|
||||||
if mode == "Fast" or mode == "Balanced":
|
|
||||||
tracking_window_size = 0
|
|
||||||
ebsynth_config = {
|
|
||||||
"minimum_patch_size": minimum_patch_size,
|
|
||||||
"threads_per_block": 8,
|
|
||||||
"num_iter": num_iter,
|
|
||||||
"gpu_id": 0,
|
|
||||||
"guide_weight": guide_weight,
|
|
||||||
"initialize": initialize,
|
|
||||||
"tracking_window_size": tracking_window_size,
|
|
||||||
}
|
|
||||||
if mode == "Fast":
|
|
||||||
FastModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
|
|
||||||
elif mode == "Balanced":
|
|
||||||
BalancedModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
|
|
||||||
elif mode == "Accurate":
|
|
||||||
AccurateModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
|
|
||||||
# output
|
|
||||||
try:
|
|
||||||
fps = int(fps)
|
|
||||||
except:
|
|
||||||
fps = get_video_fps(video_style) if video_style is not None else 30
|
|
||||||
print("Fps:", fps)
|
|
||||||
print("Saving video...")
|
|
||||||
video_path = save_video(frames_path, video_path, num_frames=len(frames_style), fps=fps)
|
|
||||||
print("Success!")
|
|
||||||
print("Your frames are here:", frames_path)
|
|
||||||
print("Your video is here:", video_path)
|
|
||||||
return output_path, fps, video_path
|
|
||||||
|
|
||||||
|
|
||||||
class KeyFrameMatcher:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def extract_number_from_filename(self, file_name):
|
|
||||||
result = []
|
|
||||||
number = -1
|
|
||||||
for i in file_name:
|
|
||||||
if ord(i)>=ord("0") and ord(i)<=ord("9"):
|
|
||||||
if number == -1:
|
|
||||||
number = 0
|
|
||||||
number = number*10 + ord(i) - ord("0")
|
|
||||||
else:
|
|
||||||
if number != -1:
|
|
||||||
result.append(number)
|
|
||||||
number = -1
|
|
||||||
if number != -1:
|
|
||||||
result.append(number)
|
|
||||||
result = tuple(result)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def extract_number_from_filenames(self, file_names):
|
|
||||||
numbers = [self.extract_number_from_filename(file_name) for file_name in file_names]
|
|
||||||
min_length = min(len(i) for i in numbers)
|
|
||||||
for i in range(min_length-1, -1, -1):
|
|
||||||
if len(set(number[i] for number in numbers))==len(file_names):
|
|
||||||
return [number[i] for number in numbers]
|
|
||||||
return list(range(len(file_names)))
|
|
||||||
|
|
||||||
def match_using_filename(self, file_names_a, file_names_b):
|
|
||||||
file_names_b_set = set(file_names_b)
|
|
||||||
matched_file_name = []
|
|
||||||
for file_name in file_names_a:
|
|
||||||
if file_name not in file_names_b_set:
|
|
||||||
matched_file_name.append(None)
|
|
||||||
else:
|
|
||||||
matched_file_name.append(file_name)
|
|
||||||
return matched_file_name
|
|
||||||
|
|
||||||
def match_using_numbers(self, file_names_a, file_names_b):
|
|
||||||
numbers_a = self.extract_number_from_filenames(file_names_a)
|
|
||||||
numbers_b = self.extract_number_from_filenames(file_names_b)
|
|
||||||
numbers_b_dict = {number: file_name for number, file_name in zip(numbers_b, file_names_b)}
|
|
||||||
matched_file_name = []
|
|
||||||
for number in numbers_a:
|
|
||||||
if number in numbers_b_dict:
|
|
||||||
matched_file_name.append(numbers_b_dict[number])
|
|
||||||
else:
|
|
||||||
matched_file_name.append(None)
|
|
||||||
return matched_file_name
|
|
||||||
|
|
||||||
def match_filenames(self, file_names_a, file_names_b):
|
|
||||||
matched_file_name = self.match_using_filename(file_names_a, file_names_b)
|
|
||||||
if sum([i is not None for i in matched_file_name]) > 0:
|
|
||||||
return matched_file_name
|
|
||||||
matched_file_name = self.match_using_numbers(file_names_a, file_names_b)
|
|
||||||
return matched_file_name
|
|
||||||
|
|
||||||
|
|
||||||
def detect_frames(frames_path, keyframes_path):
|
|
||||||
if not os.path.exists(frames_path) and not os.path.exists(keyframes_path):
|
|
||||||
return "Please input the directory of guide video and rendered frames"
|
|
||||||
elif not os.path.exists(frames_path):
|
|
||||||
return "Please input the directory of guide video"
|
|
||||||
elif not os.path.exists(keyframes_path):
|
|
||||||
return "Please input the directory of rendered frames"
|
|
||||||
frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
|
|
||||||
keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
|
|
||||||
if len(frames)==0:
|
|
||||||
return f"No images detected in {frames_path}"
|
|
||||||
if len(keyframes)==0:
|
|
||||||
return f"No images detected in {keyframes_path}"
|
|
||||||
matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
|
|
||||||
max_filename_length = max([len(i) for i in frames])
|
|
||||||
if sum([i is not None for i in matched_keyframes])==0:
|
|
||||||
message = ""
|
|
||||||
for frame, matched_keyframe in zip(frames, matched_keyframes):
|
|
||||||
message += frame + " " * (max_filename_length - len(frame) + 1)
|
|
||||||
message += "--> No matched keyframes\n"
|
|
||||||
else:
|
|
||||||
message = ""
|
|
||||||
for frame, matched_keyframe in zip(frames, matched_keyframes):
|
|
||||||
message += frame + " " * (max_filename_length - len(frame) + 1)
|
|
||||||
if matched_keyframe is None:
|
|
||||||
message += "--> [to be rendered]\n"
|
|
||||||
else:
|
|
||||||
message += f"--> {matched_keyframe}\n"
|
|
||||||
return message
|
|
||||||
|
|
||||||
|
|
||||||
def check_input_for_interpolating(frames_path, keyframes_path):
|
|
||||||
# search for images
|
|
||||||
frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
|
|
||||||
keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
|
|
||||||
# match frames
|
|
||||||
matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
|
|
||||||
file_list = [file_name for file_name in matched_keyframes if file_name is not None]
|
|
||||||
index_style = [i for i, file_name in enumerate(matched_keyframes) if file_name is not None]
|
|
||||||
frames_guide = VideoData(None, frames_path)
|
|
||||||
frames_style = VideoData(None, keyframes_path, file_list=file_list)
|
|
||||||
# match shape
|
|
||||||
message = ""
|
|
||||||
height_guide, width_guide = frames_guide.shape()
|
|
||||||
height_style, width_style = frames_style.shape()
|
|
||||||
if height_guide != height_style or width_guide != width_style:
|
|
||||||
message += f"The shape of frames mismatches. The rendered keyframes will be resized to (height: {height_guide}, width: {width_guide})\n"
|
|
||||||
frames_style.set_shape(height_guide, width_guide)
|
|
||||||
return frames_guide, frames_style, index_style, message
|
|
||||||
|
|
||||||
|
|
||||||
def interpolate_video(
|
|
||||||
frames_path,
|
|
||||||
keyframes_path,
|
|
||||||
output_path,
|
|
||||||
fps,
|
|
||||||
batch_size,
|
|
||||||
tracking_window_size,
|
|
||||||
minimum_patch_size,
|
|
||||||
num_iter,
|
|
||||||
guide_weight,
|
|
||||||
initialize,
|
|
||||||
progress = None,
|
|
||||||
):
|
|
||||||
# input
|
|
||||||
frames_guide, frames_style, index_style, message = check_input_for_interpolating(frames_path, keyframes_path)
|
|
||||||
if len(message) > 0:
|
|
||||||
print(message)
|
|
||||||
# output
|
|
||||||
if output_path == "":
|
|
||||||
output_path = os.path.join(keyframes_path, "output")
|
|
||||||
os.makedirs(output_path, exist_ok=True)
|
|
||||||
print("No valid output_path. Your video will be saved here:", output_path)
|
|
||||||
elif not os.path.exists(output_path):
|
|
||||||
os.makedirs(output_path, exist_ok=True)
|
|
||||||
print("Your video will be saved here:", output_path)
|
|
||||||
output_frames_path = os.path.join(output_path, "frames")
|
|
||||||
output_video_path = os.path.join(output_path, "video.mp4")
|
|
||||||
os.makedirs(output_frames_path, exist_ok=True)
|
|
||||||
# process
|
|
||||||
ebsynth_config = {
|
|
||||||
"minimum_patch_size": minimum_patch_size,
|
|
||||||
"threads_per_block": 8,
|
|
||||||
"num_iter": num_iter,
|
|
||||||
"gpu_id": 0,
|
|
||||||
"guide_weight": guide_weight,
|
|
||||||
"initialize": initialize,
|
|
||||||
"tracking_window_size": tracking_window_size
|
|
||||||
}
|
|
||||||
if len(index_style)==1:
|
|
||||||
InterpolationModeSingleFrameRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
|
|
||||||
else:
|
|
||||||
InterpolationModeRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
|
|
||||||
try:
|
|
||||||
fps = int(fps)
|
|
||||||
except:
|
|
||||||
fps = 30
|
|
||||||
print("Fps:", fps)
|
|
||||||
print("Saving video...")
|
|
||||||
video_path = save_video(output_frames_path, output_video_path, num_frames=len(frames_guide), fps=fps)
|
|
||||||
print("Success!")
|
|
||||||
print("Your frames are here:", output_frames_path)
|
|
||||||
print("Your video is here:", video_path)
|
|
||||||
return output_path, fps, video_path
|
|
||||||
|
|
||||||
|
|
||||||
def on_ui_tabs():
|
|
||||||
with gr.Blocks(analytics_enabled=False) as ui_component:
|
|
||||||
with gr.Tab("Blend"):
|
|
||||||
gr.Markdown("""
|
|
||||||
# Blend
|
|
||||||
|
|
||||||
Given a guide video and a style video, this algorithm will make the style video fluent according to the motion features of the guide video. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/208d902d-6aba-48d7-b7d5-cd120ebd306d) to see the example. Note that this extension doesn't support long videos. Please use short videos (e.g., several seconds). The algorithm is mainly designed for 512*512 resolution. Please use a larger `Minimum patch size` for higher resolution.
|
|
||||||
""")
|
|
||||||
with gr.Row():
|
|
||||||
with gr.Column():
|
|
||||||
with gr.Tab("Guide video"):
|
                                video_guide = gr.Video(label="Guide video")
                            with gr.Tab("Guide video (images format)"):
                                video_guide_folder = gr.Textbox(label="Guide video (images format)", value="")
                        with gr.Column():
                            with gr.Tab("Style video"):
                                video_style = gr.Video(label="Style video")
                            with gr.Tab("Style video (images format)"):
                                video_style_folder = gr.Textbox(label="Style video (images format)", value="")
                with gr.Column():
                    output_path = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of style video")
                    fps = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
                    video_output = gr.Video(label="Output video", interactive=False, show_share_button=True)
                    btn = gr.Button(value="Blend")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("# Settings")
                    mode = gr.Radio(["Fast", "Balanced", "Accurate"], label="Inference mode", value="Fast", interactive=True)
                    window_size = gr.Slider(label="Sliding window size", value=15, minimum=1, maximum=1000, step=1, interactive=True)
                    batch_size = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
                    tracking_window_size = gr.Slider(label="Tracking window size (only for accurate mode)", value=0, minimum=0, maximum=10, step=1, interactive=True)
                    gr.Markdown("## Advanced Settings")
                    minimum_patch_size = gr.Slider(label="Minimum patch size (odd number)", value=5, minimum=5, maximum=99, step=2, interactive=True)
                    num_iter = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
                    guide_weight = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
                    initialize = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
                with gr.Column():
                    gr.Markdown("""
# Reference

* Output directory: the directory to save the video.
* Inference mode

|Mode|Time|Memory|Quality|Frame by frame output|Description|
|-|-|-|-|-|-|
|Fast|■|■■■|■■|No|Blend the frames using a tree-like data structure, which requires much RAM but is fast.|
|Balanced|■■|■|■■|Yes|Blend the frames naively.|
|Accurate|■■■|■|■■■|Yes|Blend the frames and align them together for higher video quality. Performance is best when [batch size] >= [sliding window size] * 2 + 1.|

* Sliding window size: our algorithm blends the frames in a sliding window. If the size is n, each frame is blended with the previous n frames and the next n frames. A large sliding window can make the video more fluent, but sometimes blurry.
* Batch size: a larger batch size makes the program faster but requires more VRAM.
* Tracking window size (only for accurate mode): the size of the window in which our algorithm tracks moving objects. Empirically, 1 is enough.
* Advanced settings
    * Minimum patch size (odd number): the minimum patch size used for patch matching. (Default: 5)
    * Number of iterations: the number of iterations of patch matching. (Default: 5)
    * Guide weight: a parameter that determines how strongly the motion features are applied to the style video. (Default: 10)
    * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
""")
            btn.click(
                smooth_video,
                inputs=[
                    video_guide,
                    video_guide_folder,
                    video_style,
                    video_style_folder,
                    mode,
                    window_size,
                    batch_size,
                    tracking_window_size,
                    output_path,
                    fps,
                    minimum_patch_size,
                    num_iter,
                    guide_weight,
                    initialize
                ],
                outputs=[output_path, fps, video_output]
            )
        with gr.Tab("Interpolate"):
            gr.Markdown("""
# Interpolate

Given a guide video and some rendered keyframes, this algorithm will render the remaining frames. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/3490c5b4-8f67-478f-86de-f9adc2ace16a) to see an example. The algorithm is experimental and has only been tested at 512*512 resolution.
""")
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        with gr.Column():
                            video_guide_folder_ = gr.Textbox(label="Guide video (images format)", value="")
                        with gr.Column():
                            rendered_keyframes_ = gr.Textbox(label="Rendered keyframes (images format)", value="")
                    with gr.Row():
                        detected_frames = gr.Textbox(label="Detected frames", value="Please input the directory of guide video and rendered frames", lines=9, max_lines=9, interactive=False)
                    video_guide_folder_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
                    rendered_keyframes_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
                with gr.Column():
                    output_path_ = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of rendered keyframes")
                    fps_ = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
                    video_output_ = gr.Video(label="Output video", interactive=False, show_share_button=True)
                    btn_ = gr.Button(value="Interpolate")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("# Settings")
                    batch_size_ = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
                    tracking_window_size_ = gr.Slider(label="Tracking window size", value=0, minimum=0, maximum=10, step=1, interactive=True)
                    gr.Markdown("## Advanced Settings")
                    minimum_patch_size_ = gr.Slider(label="Minimum patch size (odd number, larger is better)", value=15, minimum=5, maximum=99, step=2, interactive=True)
                    num_iter_ = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
                    guide_weight_ = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
                    initialize_ = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
                with gr.Column():
                    gr.Markdown("""
# Reference

* Output directory: the directory to save the video.
* Batch size: a larger batch size makes the program faster but requires more VRAM.
* Tracking window size: the size of the window in which our algorithm tracks moving objects. Empirically, 1 is enough.
* Advanced settings
    * Minimum patch size (odd number): the minimum patch size used for patch matching. **This parameter should be larger than that in blending. (Default: 15)**
    * Number of iterations: the number of iterations of patch matching. (Default: 5)
    * Guide weight: a parameter that determines how strongly the motion features are applied to the style video. (Default: 10)
    * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
""")
            btn_.click(
                interpolate_video,
                inputs=[
                    video_guide_folder_,
                    rendered_keyframes_,
                    output_path_,
                    fps_,
                    batch_size_,
                    tracking_window_size_,
                    minimum_patch_size_,
                    num_iter_,
                    guide_weight_,
                    initialize_,
                ],
                outputs=[output_path_, fps_, video_output_]
            )

    return [(ui_component, "FastBlend", "FastBlend_ui")]
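# A minimal illustrative sketch (not taken from the file above) of the
# sliding-window rule described in the Reference text: with window size n,
# frame t is blended with its previous n and next n frames, clipped to the
# video length.
def blend_window(t, n, num_frames):
    return max(t - n, 0), min(t + n + 1, num_frames)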
@@ -1,119 +0,0 @@
import cupy as cp


remapping_kernel = cp.RawKernel(r'''
extern "C" __global__
void remap(
    const int height,
    const int width,
    const int channel,
    const int patch_size,
    const int pad_size,
    const float* source_style,
    const int* nnf,
    float* target_style
) {
    const int r = (patch_size - 1) / 2;
    const int x = blockDim.x * blockIdx.x + threadIdx.x;
    const int y = blockDim.y * blockIdx.y + threadIdx.y;
    if (x >= height or y >= width) return;
    const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
    const int pid = (x + pad_size) * (width + pad_size * 2) + (y + pad_size);
    const int min_px = x < r ? -x : -r;
    const int max_px = x + r > height - 1 ? height - 1 - x : r;
    const int min_py = y < r ? -y : -r;
    const int max_py = y + r > width - 1 ? width - 1 - y : r;
    int num = 0;
    for (int px = min_px; px <= max_px; px++){
        for (int py = min_py; py <= max_py; py++){
            const int nid = (x + px) * width + y + py;
            const int x_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 0] - px;
            const int y_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 1] - py;
            if (x_ < 0 or y_ < 0 or x_ >= height or y_ >= width) continue;
            const int pid_ = (x_ + pad_size) * (width + pad_size * 2) + (y_ + pad_size);
            num++;
            for (int c = 0; c < channel; c++){
                target_style[z + pid * channel + c] += source_style[z + pid_ * channel + c];
            }
        }
    }
    for (int c = 0; c < channel; c++){
        target_style[z + pid * channel + c] /= num;
    }
}
''', 'remap')


patch_error_kernel = cp.RawKernel(r'''
extern "C" __global__
void patch_error(
    const int height,
    const int width,
    const int channel,
    const int patch_size,
    const int pad_size,
    const float* source,
    const int* nnf,
    const float* target,
    float* error
) {
    const int r = (patch_size - 1) / 2;
    const int x = blockDim.x * blockIdx.x + threadIdx.x;
    const int y = blockDim.y * blockIdx.y + threadIdx.y;
    const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
    if (x >= height or y >= width) return;
    const int x_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 0];
    const int y_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 1];
    float e = 0;
    for (int px = -r; px <= r; px++){
        for (int py = -r; py <= r; py++){
            const int pid = (x + pad_size + px) * (width + pad_size * 2) + y + pad_size + py;
            const int pid_ = (x_ + pad_size + px) * (width + pad_size * 2) + y_ + pad_size + py;
            for (int c = 0; c < channel; c++){
                const float diff = target[z + pid * channel + c] - source[z + pid_ * channel + c];
                e += diff * diff;
            }
        }
    }
    error[blockIdx.z * height * width + x * width + y] = e;
}
''', 'patch_error')


pairwise_patch_error_kernel = cp.RawKernel(r'''
extern "C" __global__
void pairwise_patch_error(
    const int height,
    const int width,
    const int channel,
    const int patch_size,
    const int pad_size,
    const float* source_a,
    const int* nnf_a,
    const float* source_b,
    const int* nnf_b,
    float* error
) {
    const int r = (patch_size - 1) / 2;
    const int x = blockDim.x * blockIdx.x + threadIdx.x;
    const int y = blockDim.y * blockIdx.y + threadIdx.y;
    const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
    if (x >= height or y >= width) return;
    const int z_nnf = blockIdx.z * height * width * 2 + (x * width + y) * 2;
    const int x_a = nnf_a[z_nnf + 0];
    const int y_a = nnf_a[z_nnf + 1];
    const int x_b = nnf_b[z_nnf + 0];
    const int y_b = nnf_b[z_nnf + 1];
    float e = 0;
    for (int px = -r; px <= r; px++){
        for (int py = -r; py <= r; py++){
            const int pid_a = (x_a + pad_size + px) * (width + pad_size * 2) + y_a + pad_size + py;
            const int pid_b = (x_b + pad_size + px) * (width + pad_size * 2) + y_b + pad_size + py;
            for (int c = 0; c < channel; c++){
                const float diff = source_a[z + pid_a * channel + c] - source_b[z + pid_b * channel + c];
                e += diff * diff;
            }
        }
    }
    error[blockIdx.z * height * width + x * width + y] = e;
}
''', 'pairwise_patch_error')
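# Rough NumPy reference (an illustrative sketch, not part of the file above) of
# what the `remap` kernel computes for a single image: every pixel inside the
# patch around (x, y) votes with its own NNF match, shifted back by the offset,
# and the votes are averaged. Assumed shapes: source_style (H, W, C), nnf (H, W, 2).
import numpy as np

def remap_reference(source_style, nnf, patch_size):
    height, width, channel = source_style.shape
    r = (patch_size - 1) // 2
    target = np.zeros((height, width, channel), dtype=np.float32)
    for x in range(height):
        for y in range(width):
            votes, num = np.zeros(channel, dtype=np.float32), 0
            for px in range(max(-x, -r), min(height - 1 - x, r) + 1):
                for py in range(max(-y, -r), min(width - 1 - y, r) + 1):
                    x_, y_ = nnf[x + px, y + py, 0] - px, nnf[x + px, y + py, 1] - py
                    if 0 <= x_ < height and 0 <= y_ < width:
                        votes += source_style[x_, y_]
                        num += 1
            # like the kernel, this assumes at least one vote lands in range
            target[x, y] = votes / num
    return target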
@@ -1,146 +0,0 @@
import imageio, os
import numpy as np
from PIL import Image


def read_video(file_name):
    reader = imageio.get_reader(file_name)
    video = []
    for frame in reader:
        frame = np.array(frame)
        video.append(frame)
    reader.close()
    return video


def get_video_fps(file_name):
    reader = imageio.get_reader(file_name)
    fps = reader.get_meta_data()["fps"]
    reader.close()
    return fps


def save_video(frames_path, video_path, num_frames, fps):
    writer = imageio.get_writer(video_path, fps=fps, quality=9)
    for i in range(num_frames):
        frame = np.array(Image.open(os.path.join(frames_path, "%05d.png" % i)))
        writer.append_data(frame)
    writer.close()
    return video_path


class LowMemoryVideo:
    def __init__(self, file_name):
        self.reader = imageio.get_reader(file_name)

    def __len__(self):
        return self.reader.count_frames()

    def __getitem__(self, item):
        return np.array(self.reader.get_data(item))

    def __del__(self):
        self.reader.close()


def split_file_name(file_name):
    result = []
    number = -1
    for i in file_name:
        if ord(i)>=ord("0") and ord(i)<=ord("9"):
            if number == -1:
                number = 0
            number = number*10 + ord(i) - ord("0")
        else:
            if number != -1:
                result.append(number)
                number = -1
            result.append(i)
    if number != -1:
        result.append(number)
    result = tuple(result)
    return result


def search_for_images(folder):
    file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
    file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
    file_list = [i[1] for i in sorted(file_list)]
    file_list = [os.path.join(folder, i) for i in file_list]
    return file_list


def read_images(folder):
    file_list = search_for_images(folder)
    frames = [np.array(Image.open(i)) for i in file_list]
    return frames


class LowMemoryImageFolder:
    def __init__(self, folder, file_list=None):
        if file_list is None:
            self.file_list = search_for_images(folder)
        else:
            self.file_list = [os.path.join(folder, file_name) for file_name in file_list]

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, item):
        return np.array(Image.open(self.file_list[item]))

    def __del__(self):
        pass


class VideoData:
    def __init__(self, video_file, image_folder, **kwargs):
        if video_file is not None:
            self.data_type = "video"
            self.data = LowMemoryVideo(video_file, **kwargs)
        elif image_folder is not None:
            self.data_type = "images"
            self.data = LowMemoryImageFolder(image_folder, **kwargs)
        else:
            raise ValueError("Cannot open video or image folder")
        self.length = None
        self.height = None
        self.width = None

    def raw_data(self):
        frames = []
        for i in range(self.__len__()):
            frames.append(self.__getitem__(i))
        return frames

    def set_length(self, length):
        self.length = length

    def set_shape(self, height, width):
        self.height = height
        self.width = width

    def __len__(self):
        if self.length is None:
            return len(self.data)
        else:
            return self.length

    def shape(self):
        if self.height is not None and self.width is not None:
            return self.height, self.width
        else:
            height, width, _ = self.__getitem__(0).shape
            return height, width

    def __getitem__(self, item):
        frame = self.data.__getitem__(item)
        height, width, _ = frame.shape
        if self.height is not None and self.width is not None:
            if self.height != height or self.width != width:
                frame = Image.fromarray(frame).resize((self.width, self.height))
                frame = np.array(frame)
        return frame

    def __del__(self):
        pass
@@ -1,298 +0,0 @@
from .cupy_kernels import remapping_kernel, patch_error_kernel, pairwise_patch_error_kernel
import numpy as np
import cupy as cp
import cv2


class PatchMatcher:
    def __init__(
        self, height, width, channel, minimum_patch_size,
        threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
        random_search_steps=3, random_search_range=4,
        use_mean_target_style=False, use_pairwise_patch_error=False,
        tracking_window_size=0
    ):
        self.height = height
        self.width = width
        self.channel = channel
        self.minimum_patch_size = minimum_patch_size
        self.threads_per_block = threads_per_block
        self.num_iter = num_iter
        self.gpu_id = gpu_id
        self.guide_weight = guide_weight
        self.random_search_steps = random_search_steps
        self.random_search_range = random_search_range
        self.use_mean_target_style = use_mean_target_style
        self.use_pairwise_patch_error = use_pairwise_patch_error
        self.tracking_window_size = tracking_window_size

        self.patch_size_list = [minimum_patch_size + i*2 for i in range(num_iter)][::-1]
        self.pad_size = self.patch_size_list[0] // 2
        self.grid = (
            (height + threads_per_block - 1) // threads_per_block,
            (width + threads_per_block - 1) // threads_per_block
        )
        self.block = (threads_per_block, threads_per_block)

    def pad_image(self, image):
        return cp.pad(image, ((0, 0), (self.pad_size, self.pad_size), (self.pad_size, self.pad_size), (0, 0)))

    def unpad_image(self, image):
        return image[:, self.pad_size: -self.pad_size, self.pad_size: -self.pad_size, :]

    def apply_nnf_to_image(self, nnf, source):
        batch_size = source.shape[0]
        target = cp.zeros((batch_size, self.height + self.pad_size * 2, self.width + self.pad_size * 2, self.channel), dtype=cp.float32)
        remapping_kernel(
            self.grid + (batch_size,),
            self.block,
            (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target)
        )
        return target

    def get_patch_error(self, source, nnf, target):
        batch_size = source.shape[0]
        error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
        patch_error_kernel(
            self.grid + (batch_size,),
            self.block,
            (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target, error)
        )
        return error

    def get_pairwise_patch_error(self, source, nnf):
        batch_size = source.shape[0]//2
        error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
        source_a, nnf_a = source[0::2].copy(), nnf[0::2].copy()
        source_b, nnf_b = source[1::2].copy(), nnf[1::2].copy()
        pairwise_patch_error_kernel(
            self.grid + (batch_size,),
            self.block,
            (self.height, self.width, self.channel, self.patch_size, self.pad_size, source_a, nnf_a, source_b, nnf_b, error)
        )
        error = error.repeat(2, axis=0)
        return error

    def get_error(self, source_guide, target_guide, source_style, target_style, nnf):
        error_guide = self.get_patch_error(source_guide, nnf, target_guide)
        if self.use_mean_target_style:
            target_style = self.apply_nnf_to_image(nnf, source_style)
            target_style = target_style.mean(axis=0, keepdims=True)
            target_style = target_style.repeat(source_guide.shape[0], axis=0)
        if self.use_pairwise_patch_error:
            error_style = self.get_pairwise_patch_error(source_style, nnf)
        else:
            error_style = self.get_patch_error(source_style, nnf, target_style)
        error = error_guide * self.guide_weight + error_style
        return error

    def clamp_bound(self, nnf):
        nnf[:,:,:,0] = cp.clip(nnf[:,:,:,0], 0, self.height-1)
        nnf[:,:,:,1] = cp.clip(nnf[:,:,:,1], 0, self.width-1)
        return nnf

    def random_step(self, nnf, r):
        batch_size = nnf.shape[0]
        step = cp.random.randint(-r, r+1, size=(batch_size, self.height, self.width, 2), dtype=cp.int32)
        upd_nnf = self.clamp_bound(nnf + step)
        return upd_nnf

    def neighboor_step(self, nnf, d):
        if d==0:
            upd_nnf = cp.concatenate([nnf[:, :1, :], nnf[:, :-1, :]], axis=1)
            upd_nnf[:, :, :, 0] += 1
        elif d==1:
            upd_nnf = cp.concatenate([nnf[:, :, :1], nnf[:, :, :-1]], axis=2)
            upd_nnf[:, :, :, 1] += 1
        elif d==2:
            upd_nnf = cp.concatenate([nnf[:, 1:, :], nnf[:, -1:, :]], axis=1)
            upd_nnf[:, :, :, 0] -= 1
        elif d==3:
            upd_nnf = cp.concatenate([nnf[:, :, 1:], nnf[:, :, -1:]], axis=2)
            upd_nnf[:, :, :, 1] -= 1
        upd_nnf = self.clamp_bound(upd_nnf)
        return upd_nnf

    def shift_nnf(self, nnf, d):
        if d>0:
            d = min(nnf.shape[0], d)
            upd_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
        else:
            d = max(-nnf.shape[0], d)
            upd_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
        return upd_nnf

    def track_step(self, nnf, d):
        if self.use_pairwise_patch_error:
            upd_nnf = cp.zeros_like(nnf)
            upd_nnf[0::2] = self.shift_nnf(nnf[0::2], d)
            upd_nnf[1::2] = self.shift_nnf(nnf[1::2], d)
        else:
            upd_nnf = self.shift_nnf(nnf, d)
        return upd_nnf

    def C(self, n, m):
        # not used
        c = 1
        for i in range(1, n+1):
            c *= i
        for i in range(1, m+1):
            c //= i
        for i in range(1, n-m+1):
            c //= i
        return c

    def bezier_step(self, nnf, r):
        # not used
        n = r * 2 - 1
        upd_nnf = cp.zeros(shape=nnf.shape, dtype=cp.float32)
        for i, d in enumerate(list(range(-r, 0)) + list(range(1, r+1))):
            if d>0:
                ctl_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
            elif d<0:
                ctl_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
            upd_nnf += ctl_nnf * (self.C(n, i) / 2**n)
        upd_nnf = self.clamp_bound(upd_nnf).astype(nnf.dtype)
        return upd_nnf

    def update(self, source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf):
        upd_err = self.get_error(source_guide, target_guide, source_style, target_style, upd_nnf)
        upd_idx = (upd_err < err)
        nnf[upd_idx] = upd_nnf[upd_idx]
        err[upd_idx] = upd_err[upd_idx]
        return nnf, err

    def propagation(self, source_guide, target_guide, source_style, target_style, nnf, err):
        for d in cp.random.permutation(4):
            upd_nnf = self.neighboor_step(nnf, d)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
        return nnf, err

    def random_search(self, source_guide, target_guide, source_style, target_style, nnf, err):
        for i in range(self.random_search_steps):
            upd_nnf = self.random_step(nnf, self.random_search_range)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
        return nnf, err

    def track(self, source_guide, target_guide, source_style, target_style, nnf, err):
        for d in range(1, self.tracking_window_size + 1):
            upd_nnf = self.track_step(nnf, d)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
            upd_nnf = self.track_step(nnf, -d)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
        return nnf, err

    def iteration(self, source_guide, target_guide, source_style, target_style, nnf, err):
        nnf, err = self.propagation(source_guide, target_guide, source_style, target_style, nnf, err)
        nnf, err = self.random_search(source_guide, target_guide, source_style, target_style, nnf, err)
        nnf, err = self.track(source_guide, target_guide, source_style, target_style, nnf, err)
        return nnf, err

    def estimate_nnf(self, source_guide, target_guide, source_style, nnf):
        with cp.cuda.Device(self.gpu_id):
            source_guide = self.pad_image(source_guide)
            target_guide = self.pad_image(target_guide)
            source_style = self.pad_image(source_style)
            for it in range(self.num_iter):
                self.patch_size = self.patch_size_list[it]
                target_style = self.apply_nnf_to_image(nnf, source_style)
                err = self.get_error(source_guide, target_guide, source_style, target_style, nnf)
                nnf, err = self.iteration(source_guide, target_guide, source_style, target_style, nnf, err)
            target_style = self.unpad_image(self.apply_nnf_to_image(nnf, source_style))
        return nnf, target_style


class PyramidPatchMatcher:
    def __init__(
        self, image_height, image_width, channel, minimum_patch_size,
        threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
        use_mean_target_style=False, use_pairwise_patch_error=False,
        tracking_window_size=0,
        initialize="identity"
    ):
        maximum_patch_size = minimum_patch_size + (num_iter - 1) * 2
        self.pyramid_level = int(np.log2(min(image_height, image_width) / maximum_patch_size))
        self.pyramid_heights = []
        self.pyramid_widths = []
        self.patch_matchers = []
        self.minimum_patch_size = minimum_patch_size
        self.num_iter = num_iter
        self.gpu_id = gpu_id
        self.initialize = initialize
        for level in range(self.pyramid_level):
            height = image_height//(2**(self.pyramid_level - 1 - level))
            width = image_width//(2**(self.pyramid_level - 1 - level))
            self.pyramid_heights.append(height)
            self.pyramid_widths.append(width)
            self.patch_matchers.append(PatchMatcher(
                height, width, channel, minimum_patch_size=minimum_patch_size,
                threads_per_block=threads_per_block, num_iter=num_iter, gpu_id=gpu_id, guide_weight=guide_weight,
                use_mean_target_style=use_mean_target_style, use_pairwise_patch_error=use_pairwise_patch_error,
                tracking_window_size=tracking_window_size
            ))

    def resample_image(self, images, level):
        height, width = self.pyramid_heights[level], self.pyramid_widths[level]
        images = images.get()
        images_resample = []
        for image in images:
            image_resample = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
            images_resample.append(image_resample)
        images_resample = cp.array(np.stack(images_resample), dtype=cp.float32)
        return images_resample

    def initialize_nnf(self, batch_size):
        if self.initialize == "random":
            height, width = self.pyramid_heights[0], self.pyramid_widths[0]
            nnf = cp.stack([
                cp.random.randint(0, height, (batch_size, height, width), dtype=cp.int32),
                cp.random.randint(0, width, (batch_size, height, width), dtype=cp.int32)
            ], axis=3)
        elif self.initialize == "identity":
            height, width = self.pyramid_heights[0], self.pyramid_widths[0]
            nnf = cp.stack([
                cp.repeat(cp.arange(height), width).reshape(height, width),
                cp.tile(cp.arange(width), height).reshape(height, width)
            ], axis=2)
            nnf = cp.stack([nnf] * batch_size)
        else:
            raise NotImplementedError()
        return nnf

    def update_nnf(self, nnf, level):
        # upscale
        nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2
        nnf[:,[i for i in range(nnf.shape[0]) if i&1],:,0] += 1
        nnf[:,:,[i for i in range(nnf.shape[0]) if i&1],1] += 1
        # check if scale is 2
        height, width = self.pyramid_heights[level], self.pyramid_widths[level]
        if height != nnf.shape[0] * 2 or width != nnf.shape[1] * 2:
            nnf = nnf.get().astype(np.float32)
            nnf = [cv2.resize(n, (width, height), interpolation=cv2.INTER_LINEAR) for n in nnf]
            nnf = cp.array(np.stack(nnf), dtype=cp.int32)
            nnf = self.patch_matchers[level].clamp_bound(nnf)
        return nnf

    def apply_nnf_to_image(self, nnf, image):
        with cp.cuda.Device(self.gpu_id):
            image = self.patch_matchers[-1].pad_image(image)
            image = self.patch_matchers[-1].apply_nnf_to_image(nnf, image)
        return image

    def estimate_nnf(self, source_guide, target_guide, source_style):
        with cp.cuda.Device(self.gpu_id):
            if not isinstance(source_guide, cp.ndarray):
                source_guide = cp.array(source_guide, dtype=cp.float32)
            if not isinstance(target_guide, cp.ndarray):
                target_guide = cp.array(target_guide, dtype=cp.float32)
            if not isinstance(source_style, cp.ndarray):
                source_style = cp.array(source_style, dtype=cp.float32)
            for level in range(self.pyramid_level):
                nnf = self.initialize_nnf(source_guide.shape[0]) if level==0 else self.update_nnf(nnf, level)
                source_guide_ = self.resample_image(source_guide, level)
                target_guide_ = self.resample_image(target_guide, level)
                source_style_ = self.resample_image(source_style, level)
                nnf, target_style = self.patch_matchers[level].estimate_nnf(
                    source_guide_, target_guide_, source_style_, nnf
                )
        return nnf.get(), target_style.get()
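# Hedged usage sketch (dummy shapes and values; requires a CUDA GPU with cupy):
# this mirrors how the mode runners below drive the matcher, estimating an NNF
# that remaps a batch of stylized source frames onto a target guide frame.
import numpy as np

engine = PyramidPatchMatcher(
    image_height=512, image_width=512, channel=3,
    minimum_patch_size=5, num_iter=5, guide_weight=10.0,
)
source_guide = np.random.rand(4, 512, 512, 3).astype(np.float32)  # guide frames i..j
target_guide = np.random.rand(4, 512, 512, 3).astype(np.float32)  # target frame, repeated
source_style = np.random.rand(4, 512, 512, 3).astype(np.float32)  # stylized frames i..j
nnf, target_style = engine.estimate_nnf(source_guide, target_guide, source_style)
# nnf: (4, 512, 512, 2) integer offsets; target_style: (4, 512, 512, 3) remapped frames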
@@ -1,4 +0,0 @@
from .accurate import AccurateModeRunner
from .fast import FastModeRunner
from .balanced import BalancedModeRunner
from .interpolation import InterpolationModeRunner, InterpolationModeSingleFrameRunner
@@ -1,35 +0,0 @@
from ..patch_match import PyramidPatchMatcher
import os
import numpy as np
from PIL import Image
from tqdm import tqdm


class AccurateModeRunner:
    def __init__(self):
        pass

    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Accurate Mode", save_path=None):
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            use_mean_target_style=True,
            **ebsynth_config
        )
        # run
        n = len(frames_style)
        for target in tqdm(range(n), desc=desc):
            l, r = max(target - window_size, 0), min(target + window_size + 1, n)
            remapped_frames = []
            for i in range(l, r, batch_size):
                j = min(i + batch_size, r)
                source_guide = np.stack([frames_guide[source] for source in range(i, j)])
                target_guide = np.stack([frames_guide[target]] * (j - i))
                source_style = np.stack([frames_style[source] for source in range(i, j)])
                _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
                remapped_frames.append(target_style)
            frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
            frame = frame.clip(0, 255).astype("uint8")
            if save_path is not None:
                Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
@@ -1,46 +0,0 @@
from ..patch_match import PyramidPatchMatcher
import os
import numpy as np
from PIL import Image
from tqdm import tqdm


class BalancedModeRunner:
    def __init__(self):
        pass

    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Balanced Mode", save_path=None):
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            **ebsynth_config
        )
        # tasks
        n = len(frames_style)
        tasks = []
        for target in range(n):
            for source in range(target - window_size, target + window_size + 1):
                if source >= 0 and source < n and source != target:
                    tasks.append((source, target))
        # run
        frames = [(None, 1) for i in range(n)]
        for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
            tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
            source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
            target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
            source_style = np.stack([frames_style[source] for source, target in tasks_batch])
            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
            for (source, target), result in zip(tasks_batch, target_style):
                frame, weight = frames[target]
                if frame is None:
                    frame = frames_style[target]
                frames[target] = (
                    frame * (weight / (weight + 1)) + result / (weight + 1),
                    weight + 1
                )
                if weight + 1 == min(n, target + window_size + 1) - max(0, target - window_size):
                    frame = frame.clip(0, 255).astype("uint8")
                    if save_path is not None:
                        Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
                    frames[target] = (None, 1)
@@ -1,141 +0,0 @@
from ..patch_match import PyramidPatchMatcher
import functools, os
import numpy as np
from PIL import Image
from tqdm import tqdm


class TableManager:
    def __init__(self):
        pass

    def task_list(self, n):
        tasks = []
        max_level = 1
        while (1<<max_level)<=n:
            max_level += 1
        for i in range(n):
            j = i
            for level in range(max_level):
                if i&(1<<level):
                    continue
                j |= 1<<level
                if j>=n:
                    break
                meta_data = {
                    "source": i,
                    "target": j,
                    "level": level + 1
                }
                tasks.append(meta_data)
        tasks.sort(key=functools.cmp_to_key(lambda u, v: u["level"]-v["level"]))
        return tasks

    def build_remapping_table(self, frames_guide, frames_style, patch_match_engine, batch_size, desc=""):
        n = len(frames_guide)
        tasks = self.task_list(n)
        remapping_table = [[(frames_style[i], 1)] for i in range(n)]
        for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
            tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
            source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
            target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
            source_style = np.stack([frames_style[task["source"]] for task in tasks_batch])
            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
            for task, result in zip(tasks_batch, target_style):
                target, level = task["target"], task["level"]
                if len(remapping_table[target])==level:
                    remapping_table[target].append((result, 1))
                else:
                    frame, weight = remapping_table[target][level]
                    remapping_table[target][level] = (
                        frame * (weight / (weight + 1)) + result / (weight + 1),
                        weight + 1
                    )
        return remapping_table

    def remapping_table_to_blending_table(self, table):
        for i in range(len(table)):
            for j in range(1, len(table[i])):
                frame_1, weight_1 = table[i][j-1]
                frame_2, weight_2 = table[i][j]
                frame = (frame_1 + frame_2) / 2
                weight = weight_1 + weight_2
                table[i][j] = (frame, weight)
        return table

    def tree_query(self, leftbound, rightbound):
        node_list = []
        node_index = rightbound
        while node_index>=leftbound:
            node_level = 0
            while (1<<node_level)&node_index and node_index-(1<<node_level+1)+1>=leftbound:
                node_level += 1
            node_list.append((node_index, node_level))
            node_index -= 1<<node_level
        return node_list

    def process_window_sum(self, frames_guide, blending_table, patch_match_engine, window_size, batch_size, desc=""):
        n = len(blending_table)
        tasks = []
        frames_result = []
        for target in range(n):
            node_list = self.tree_query(max(target-window_size, 0), target)
            for source, level in node_list:
                if source!=target:
                    meta_data = {
                        "source": source,
                        "target": target,
                        "level": level
                    }
                    tasks.append(meta_data)
                else:
                    frames_result.append(blending_table[target][level])
        for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
            tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
            source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
            target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
            source_style = np.stack([blending_table[task["source"]][task["level"]][0] for task in tasks_batch])
            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
            for task, frame_2 in zip(tasks_batch, target_style):
                source, target, level = task["source"], task["target"], task["level"]
                frame_1, weight_1 = frames_result[target]
                weight_2 = blending_table[source][level][1]
                weight = weight_1 + weight_2
                frame = frame_1 * (weight_1 / weight) + frame_2 * (weight_2 / weight)
                frames_result[target] = (frame, weight)
        return frames_result


class FastModeRunner:
    def __init__(self):
        pass

    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, save_path=None):
        frames_guide = frames_guide.raw_data()
        frames_style = frames_style.raw_data()
        table_manager = TableManager()
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            **ebsynth_config
        )
        # left part
        table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="Fast Mode Step 1/4")
        table_l = table_manager.remapping_table_to_blending_table(table_l)
        table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 2/4")
        # right part
        table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="Fast Mode Step 3/4")
        table_r = table_manager.remapping_table_to_blending_table(table_r)
        table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 4/4")[::-1]
        # merge
        frames = []
        for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
            weight_m = -1
            weight = weight_l + weight_m + weight_r
            frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
            frames.append(frame)
        frames = [frame.clip(0, 255).astype("uint8") for frame in frames]
        if save_path is not None:
            for target, frame in enumerate(frames):
                Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
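# Note on the "tree-like data structure" used by Fast mode: blending_table[i][k]
# holds the blend of the 2**k frames ending at frame i, and tree_query(l, r)
# decomposes the window [l, r] into such power-of-two segments, e.g.
#     TableManager().tree_query(3, 11)  ->  [(11, 2), (7, 2), (3, 0)]
# i.e. frames 8-11, 4-7 and 3, so each window is assembled from O(log n)
# precomputed pieces instead of n individual remaps.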
@@ -1,121 +0,0 @@
from ..patch_match import PyramidPatchMatcher
import os
import numpy as np
from PIL import Image
from tqdm import tqdm


class InterpolationModeRunner:
    def __init__(self):
        pass

    def get_index_dict(self, index_style):
        index_dict = {}
        for i, index in enumerate(index_style):
            index_dict[index] = i
        return index_dict

    def get_weight(self, l, m, r):
        weight_l, weight_r = abs(m - r), abs(m - l)
        if weight_l + weight_r == 0:
            weight_l, weight_r = 0.5, 0.5
        else:
            weight_l, weight_r = weight_l / (weight_l + weight_r), weight_r / (weight_l + weight_r)
        return weight_l, weight_r

    def get_task_group(self, index_style, n):
        task_group = []
        index_style = sorted(index_style)
        # first frame
        if index_style[0]>0:
            tasks = []
            for m in range(index_style[0]):
                tasks.append((index_style[0], m, index_style[0]))
            task_group.append(tasks)
        # middle frames
        for l, r in zip(index_style[:-1], index_style[1:]):
            tasks = []
            for m in range(l, r):
                tasks.append((l, m, r))
            task_group.append(tasks)
        # last frame
        tasks = []
        for m in range(index_style[-1], n):
            tasks.append((index_style[-1], m, index_style[-1]))
        task_group.append(tasks)
        return task_group

    def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            use_mean_target_style=False,
            use_pairwise_patch_error=True,
            **ebsynth_config
        )
        # task
        index_dict = self.get_index_dict(index_style)
        task_group = self.get_task_group(index_style, len(frames_guide))
        # run
        for tasks in task_group:
            index_start, index_end = min([i[1] for i in tasks]), max([i[1] for i in tasks])
            for batch_id in tqdm(range(0, len(tasks), batch_size), desc=f"Rendering frames {index_start}...{index_end}"):
                tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
                source_guide, target_guide, source_style = [], [], []
                for l, m, r in tasks_batch:
                    # l -> m
                    source_guide.append(frames_guide[l])
                    target_guide.append(frames_guide[m])
                    source_style.append(frames_style[index_dict[l]])
                    # r -> m
                    source_guide.append(frames_guide[r])
                    target_guide.append(frames_guide[m])
                    source_style.append(frames_style[index_dict[r]])
                source_guide = np.stack(source_guide)
                target_guide = np.stack(target_guide)
                source_style = np.stack(source_style)
                _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
                if save_path is not None:
                    for frame_l, frame_r, (l, m, r) in zip(target_style[0::2], target_style[1::2], tasks_batch):
                        weight_l, weight_r = self.get_weight(l, m, r)
                        frame = frame_l * weight_l + frame_r * weight_r
                        frame = frame.clip(0, 255).astype("uint8")
                        Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % m))


class InterpolationModeSingleFrameRunner:
    def __init__(self):
        pass

    def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
        # check input
        tracking_window_size = ebsynth_config["tracking_window_size"]
        if tracking_window_size * 2 >= batch_size:
            raise ValueError("batch_size should be larger than track_window_size * 2")
        frame_style = frames_style[0]
        frame_guide = frames_guide[index_style[0]]
        patch_match_engine = PyramidPatchMatcher(
            image_height=frame_style.shape[0],
            image_width=frame_style.shape[1],
            channel=3,
            **ebsynth_config
        )
        # run
        frame_id, n = 0, len(frames_guide)
        for i in tqdm(range(0, n, batch_size - tracking_window_size * 2), desc=f"Rendering frames 0...{n}"):
            if i + batch_size > n:
                l, r = max(n - batch_size, 0), n
            else:
                l, r = i, i + batch_size
            source_guide = np.stack([frame_guide] * (r-l))
            target_guide = np.stack([frames_guide[i] for i in range(l, r)])
            source_style = np.stack([frame_style] * (r-l))
            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
            for i, frame in zip(range(l, r), target_style):
                if i==frame_id:
                    frame = frame.clip(0, 255).astype("uint8")
                    Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % frame_id))
                    frame_id += 1
            if r < n and r-frame_id <= tracking_window_size:
                break
@@ -1 +0,0 @@
from .blip_pretrain import *
@@ -1,77 +0,0 @@
'''
 * Adapted from BLIP (https://github.com/salesforce/BLIP)
'''

import warnings
warnings.filterwarnings("ignore")

import torch
import os
from urllib.parse import urlparse
from timm.models.hub import download_cached_file
from transformers import BertTokenizer
from .vit import VisionTransformer, interpolate_pos_embed


def default_bert():
    current_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.abspath(os.path.join(current_dir, '../../../../'))
    model_path = os.path.join(project_root, 'models', 'QualityMetric')
    return os.path.join(model_path, "bert-base-uncased")


def init_tokenizer(bert_model_path):
    tokenizer = BertTokenizer.from_pretrained(bert_model_path)
    tokenizer.add_special_tokens({'bos_token':'[DEC]'})
    tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
    tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
    return tokenizer


def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):

    assert vit in ['base', 'large'], "vit parameter must be base or large"
    if vit=='base':
        vision_width = 768
        visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
                                           num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
                                           drop_path_rate=0 or drop_path_rate
                                           )
    elif vit=='large':
        vision_width = 1024
        visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
                                           num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
                                           drop_path_rate=0.1 or drop_path_rate
                                           )
    return visual_encoder, vision_width


def is_url(url_or_filename):
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")

def load_checkpoint(model, url_or_filename):
    if is_url(url_or_filename):
        cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
        checkpoint = torch.load(cached_file, map_location='cpu')
    elif os.path.isfile(url_or_filename):
        checkpoint = torch.load(url_or_filename, map_location='cpu')
    else:
        raise RuntimeError('checkpoint url or path is invalid')

    state_dict = checkpoint['model']

    state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'], model.visual_encoder)
    if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
        state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
                                                                         model.visual_encoder_m)
    for key in model.state_dict().keys():
        if key in state_dict.keys():
            if state_dict[key].shape != model.state_dict()[key].shape:
                print(key, ": ", state_dict[key].shape, ', ', model.state_dict()[key].shape)
                del state_dict[key]

    msg = model.load_state_dict(state_dict, strict=False)
    print('load checkpoint from %s' % url_or_filename)
    return model, msg

@@ -1,44 +0,0 @@
'''
 * Adapted from BLIP (https://github.com/salesforce/BLIP)
'''

import transformers
transformers.logging.set_verbosity_error()

from torch import nn
import os
from .med import BertConfig, BertModel
from .blip import create_vit, init_tokenizer

class BLIP_Pretrain(nn.Module):
    def __init__(self,
                 med_config = "med_config.json",
                 image_size = 224,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,
                 embed_dim = 256,
                 queue_size = 57600,
                 momentum = 0.995,
                 bert_model_path = ""
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, 0)

        self.tokenizer = init_tokenizer(bert_model_path)
        encoder_config = BertConfig.from_json_file(med_config)
        encoder_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)

        text_width = self.text_encoder.config.hidden_size

        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)

@@ -1,947 +0,0 @@
'''
 * Adapted from BLIP (https://github.com/salesforce/BLIP)
 * Based on huggingface code base
 * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
'''

import math
from typing import Tuple

import torch
from torch import Tensor, device, nn
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers.activations import ACT2FN
from transformers.file_utils import (
    ModelOutput,
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    NextSentencePredictorOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from transformers.modeling_utils import (
    PreTrainedModel,
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
from transformers.utils import logging
from transformers.models.bert.configuration_bert import BertConfig


logger = logging.get_logger(__name__)


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

        self.config = config

    def forward(
        self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds

        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class BertSelfAttention(nn.Module):
    def __init__(self, config, is_cross_attention):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            self.key = nn.Linear(config.encoder_width, self.all_head_size)
            self.value = nn.Linear(config.encoder_width, self.all_head_size)
        else:
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        if is_cross_attention and self.save_attention:
self.save_attention_map(attention_probs)
|
|
||||||
attention_probs.register_hook(self.save_attn_gradients)
|
|
||||||
|
|
||||||
# This is actually dropping out entire tokens to attend to, which might
|
|
||||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
|
||||||
attention_probs_dropped = self.dropout(attention_probs)
|
|
||||||
|
|
||||||
# Mask heads if we want to
|
|
||||||
if head_mask is not None:
|
|
||||||
attention_probs_dropped = attention_probs_dropped * head_mask
|
|
||||||
|
|
||||||
context_layer = torch.matmul(attention_probs_dropped, value_layer)
|
|
||||||
|
|
||||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
|
||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
|
||||||
|
|
||||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
|
||||||
|
|
||||||
outputs = outputs + (past_key_value,)
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
|
|
||||||
class BertSelfOutput(nn.Module):
|
|
||||||
def __init__(self, config):
|
|
||||||
super().__init__()
|
|
||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
|
||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
|
||||||
hidden_states = self.dense(hidden_states)
|
|
||||||
hidden_states = self.dropout(hidden_states)
|
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class BertAttention(nn.Module):
|
|
||||||
def __init__(self, config, is_cross_attention=False):
|
|
||||||
super().__init__()
|
|
||||||
self.self = BertSelfAttention(config, is_cross_attention)
|
|
||||||
self.output = BertSelfOutput(config)
|
|
||||||
self.pruned_heads = set()
|
|
||||||
|
|
||||||
def prune_heads(self, heads):
|
|
||||||
if len(heads) == 0:
|
|
||||||
return
|
|
||||||
heads, index = find_pruneable_heads_and_indices(
|
|
||||||
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prune linear layers
|
|
||||||
self.self.query = prune_linear_layer(self.self.query, index)
|
|
||||||
self.self.key = prune_linear_layer(self.self.key, index)
|
|
||||||
self.self.value = prune_linear_layer(self.self.value, index)
|
|
||||||
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
|
||||||
|
|
||||||
# Update hyper params and store pruned heads
|
|
||||||
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
|
||||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
|
||||||
self.pruned_heads = self.pruned_heads.union(heads)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=None,
|
|
||||||
head_mask=None,
|
|
||||||
encoder_hidden_states=None,
|
|
||||||
encoder_attention_mask=None,
|
|
||||||
past_key_value=None,
|
|
||||||
output_attentions=False,
|
|
||||||
):
|
|
||||||
self_outputs = self.self(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask,
|
|
||||||
head_mask,
|
|
||||||
encoder_hidden_states,
|
|
||||||
encoder_attention_mask,
|
|
||||||
past_key_value,
|
|
||||||
output_attentions,
|
|
||||||
)
|
|
||||||
attention_output = self.output(self_outputs[0], hidden_states)
|
|
||||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
|
|
||||||
class BertIntermediate(nn.Module):
|
|
||||||
def __init__(self, config):
|
|
||||||
super().__init__()
|
|
||||||
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
|
||||||
if isinstance(config.hidden_act, str):
|
|
||||||
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
|
||||||
else:
|
|
||||||
self.intermediate_act_fn = config.hidden_act
|
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
|
||||||
hidden_states = self.dense(hidden_states)
|
|
||||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class BertOutput(nn.Module):
|
|
||||||
def __init__(self, config):
|
|
||||||
super().__init__()
|
|
||||||
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
|
||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
|
||||||
hidden_states = self.dense(hidden_states)
|
|
||||||
hidden_states = self.dropout(hidden_states)
|
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class BertLayer(nn.Module):
|
|
||||||
def __init__(self, config, layer_num):
|
|
||||||
super().__init__()
|
|
||||||
self.config = config
|
|
||||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
|
||||||
self.seq_len_dim = 1
|
|
||||||
self.attention = BertAttention(config)
|
|
||||||
self.layer_num = layer_num
|
|
||||||
if self.config.add_cross_attention:
|
|
||||||
self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
|
|
||||||
self.intermediate = BertIntermediate(config)
|
|
||||||
self.output = BertOutput(config)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=None,
|
|
||||||
head_mask=None,
|
|
||||||
encoder_hidden_states=None,
|
|
||||||
encoder_attention_mask=None,
|
|
||||||
past_key_value=None,
|
|
||||||
output_attentions=False,
|
|
||||||
mode=None,
|
|
||||||
):
|
|
||||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
|
||||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
|
||||||
self_attention_outputs = self.attention(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask,
|
|
||||||
head_mask,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
past_key_value=self_attn_past_key_value,
|
|
||||||
)
|
|
||||||
attention_output = self_attention_outputs[0]
|
|
||||||
|
|
||||||
outputs = self_attention_outputs[1:-1]
|
|
||||||
present_key_value = self_attention_outputs[-1]
|
|
||||||
|
|
||||||
if mode=='multimodal':
|
|
||||||
assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
|
|
||||||
|
|
||||||
cross_attention_outputs = self.crossattention(
|
|
||||||
attention_output,
|
|
||||||
attention_mask,
|
|
||||||
head_mask,
|
|
||||||
encoder_hidden_states,
|
|
||||||
encoder_attention_mask,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
)
|
|
||||||
attention_output = cross_attention_outputs[0]
|
|
||||||
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
|
||||||
layer_output = apply_chunking_to_forward(
|
|
||||||
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
|
||||||
)
|
|
||||||
outputs = (layer_output,) + outputs
|
|
||||||
|
|
||||||
outputs = outputs + (present_key_value,)
|
|
||||||
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
def feed_forward_chunk(self, attention_output):
|
|
||||||
intermediate_output = self.intermediate(attention_output)
|
|
||||||
layer_output = self.output(intermediate_output, attention_output)
|
|
||||||
return layer_output
|
|
||||||
|
|
||||||
|
|
||||||
class BertEncoder(nn.Module):
|
|
||||||
def __init__(self, config):
|
|
||||||
super().__init__()
|
|
||||||
self.config = config
|
|
||||||
self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
|
|
||||||
self.gradient_checkpointing = False
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=None,
|
|
||||||
head_mask=None,
|
|
||||||
encoder_hidden_states=None,
|
|
||||||
encoder_attention_mask=None,
|
|
||||||
past_key_values=None,
|
|
||||||
use_cache=None,
|
|
||||||
output_attentions=False,
|
|
||||||
output_hidden_states=False,
|
|
||||||
return_dict=True,
|
|
||||||
mode='multimodal',
|
|
||||||
):
|
|
||||||
all_hidden_states = () if output_hidden_states else None
|
|
||||||
all_self_attentions = () if output_attentions else None
|
|
||||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
|
||||||
|
|
||||||
next_decoder_cache = () if use_cache else None
|
|
||||||
|
|
||||||
for i in range(self.config.num_hidden_layers):
|
|
||||||
layer_module = self.layer[i]
|
|
||||||
if output_hidden_states:
|
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
|
||||||
|
|
||||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
|
||||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
|
||||||
|
|
||||||
if use_cache:
|
|
||||||
logger.warn(
|
|
||||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
|
||||||
)
|
|
||||||
use_cache = False
|
|
||||||
|
|
||||||
def create_custom_forward(module):
|
|
||||||
def custom_forward(*inputs):
|
|
||||||
return module(*inputs, past_key_value, output_attentions)
|
|
||||||
|
|
||||||
return custom_forward
|
|
||||||
|
|
||||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
|
||||||
create_custom_forward(layer_module),
|
|
||||||
hidden_states,
|
|
||||||
attention_mask,
|
|
||||||
layer_head_mask,
|
|
||||||
encoder_hidden_states,
|
|
||||||
encoder_attention_mask,
|
|
||||||
mode=mode,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = layer_module(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask,
|
|
||||||
layer_head_mask,
|
|
||||||
encoder_hidden_states,
|
|
||||||
encoder_attention_mask,
|
|
||||||
past_key_value,
|
|
||||||
output_attentions,
|
|
||||||
mode=mode,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
|
||||||
if use_cache:
|
|
||||||
next_decoder_cache += (layer_outputs[-1],)
|
|
||||||
if output_attentions:
|
|
||||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
|
||||||
|
|
||||||
if output_hidden_states:
|
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
return tuple(
|
|
||||||
v
|
|
||||||
for v in [
|
|
||||||
hidden_states,
|
|
||||||
next_decoder_cache,
|
|
||||||
all_hidden_states,
|
|
||||||
all_self_attentions,
|
|
||||||
all_cross_attentions,
|
|
||||||
]
|
|
||||||
if v is not None
|
|
||||||
)
|
|
||||||
return BaseModelOutputWithPastAndCrossAttentions(
|
|
||||||
last_hidden_state=hidden_states,
|
|
||||||
past_key_values=next_decoder_cache,
|
|
||||||
hidden_states=all_hidden_states,
|
|
||||||
attentions=all_self_attentions,
|
|
||||||
cross_attentions=all_cross_attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BertPooler(nn.Module):
|
|
||||||
def __init__(self, config):
|
|
||||||
super().__init__()
|
|
||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
|
||||||
self.activation = nn.Tanh()
|
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
|
||||||
# We "pool" the model by simply taking the hidden state corresponding
|
|
||||||
# to the first token.
|
|
||||||
first_token_tensor = hidden_states[:, 0]
|
|
||||||
pooled_output = self.dense(first_token_tensor)
|
|
||||||
pooled_output = self.activation(pooled_output)
|
|
||||||
return pooled_output
|
|
||||||
|
|
||||||
|
|
||||||
class BertPredictionHeadTransform(nn.Module):
|
|
||||||
def __init__(self, config):
|
|
||||||
super().__init__()
|
|
||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
|
||||||
if isinstance(config.hidden_act, str):
|
|
||||||
self.transform_act_fn = ACT2FN[config.hidden_act]
|
|
||||||
else:
|
|
||||||
self.transform_act_fn = config.hidden_act
|
|
||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
|
||||||
hidden_states = self.dense(hidden_states)
|
|
||||||
hidden_states = self.transform_act_fn(hidden_states)
|
|
||||||
hidden_states = self.LayerNorm(hidden_states)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class BertLMPredictionHead(nn.Module):
|
|
||||||
def __init__(self, config):
|
|
||||||
super().__init__()
|
|
||||||
self.transform = BertPredictionHeadTransform(config)
|
|
||||||
|
|
||||||
# The output weights are the same as the input embeddings, but there is
|
|
||||||
# an output-only bias for each token.
|
|
||||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
|
||||||
|
|
||||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
|
||||||
|
|
||||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
|
||||||
self.decoder.bias = self.bias
|
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
|
||||||
hidden_states = self.transform(hidden_states)
|
|
||||||
hidden_states = self.decoder(hidden_states)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class BertOnlyMLMHead(nn.Module):
|
|
||||||
def __init__(self, config):
|
|
||||||
super().__init__()
|
|
||||||
self.predictions = BertLMPredictionHead(config)
|
|
||||||
|
|
||||||
def forward(self, sequence_output):
|
|
||||||
prediction_scores = self.predictions(sequence_output)
|
|
||||||
return prediction_scores
|
|
||||||
|
|
||||||
|
|
||||||
class BertPreTrainedModel(PreTrainedModel):
|
|
||||||
"""
|
|
||||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
|
||||||
models.
|
|
||||||
"""
|
|
||||||
|
|
||||||
config_class = BertConfig
|
|
||||||
base_model_prefix = "bert"
|
|
||||||
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
|
||||||
|
|
||||||
def _init_weights(self, module):
|
|
||||||
""" Initialize the weights """
|
|
||||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
|
||||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
|
||||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
|
||||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
|
||||||
elif isinstance(module, nn.LayerNorm):
|
|
||||||
module.bias.data.zero_()
|
|
||||||
module.weight.data.fill_(1.0)
|
|
||||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
|
||||||
module.bias.data.zero_()
|
|
||||||
|
|
||||||
|
|
||||||
class BertModel(BertPreTrainedModel):
|
|
||||||
"""
|
|
||||||
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
|
||||||
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
|
||||||
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
|
||||||
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
|
||||||
argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
|
|
||||||
input to the forward pass.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config, add_pooling_layer=True):
|
|
||||||
super().__init__(config)
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
self.embeddings = BertEmbeddings(config)
|
|
||||||
|
|
||||||
self.encoder = BertEncoder(config)
|
|
||||||
|
|
||||||
self.pooler = BertPooler(config) if add_pooling_layer else None
|
|
||||||
|
|
||||||
self.init_weights()
|
|
||||||
|
|
||||||
|
|
||||||
def get_input_embeddings(self):
|
|
||||||
return self.embeddings.word_embeddings
|
|
||||||
|
|
||||||
def set_input_embeddings(self, value):
|
|
||||||
self.embeddings.word_embeddings = value
|
|
||||||
|
|
||||||
def _prune_heads(self, heads_to_prune):
|
|
||||||
"""
|
|
||||||
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
|
||||||
class PreTrainedModel
|
|
||||||
"""
|
|
||||||
for layer, heads in heads_to_prune.items():
|
|
||||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
|
||||||
|
|
||||||
|
|
||||||
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
|
|
||||||
"""
|
|
||||||
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
attention_mask (:obj:`torch.Tensor`):
|
|
||||||
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
|
||||||
input_shape (:obj:`Tuple[int]`):
|
|
||||||
The shape of the input to the model.
|
|
||||||
device: (:obj:`torch.device`):
|
|
||||||
The device of the input to the model.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
:obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
|
|
||||||
"""
|
|
||||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
|
||||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
|
||||||
if attention_mask.dim() == 3:
|
|
||||||
extended_attention_mask = attention_mask[:, None, :, :]
|
|
||||||
elif attention_mask.dim() == 2:
|
|
||||||
# Provided a padding mask of dimensions [batch_size, seq_length]
|
|
||||||
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
|
||||||
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
|
||||||
if is_decoder:
|
|
||||||
batch_size, seq_length = input_shape
|
|
||||||
|
|
||||||
seq_ids = torch.arange(seq_length, device=device)
|
|
||||||
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
|
||||||
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
|
|
||||||
# causal and attention masks must have same type with pytorch version < 1.3
|
|
||||||
causal_mask = causal_mask.to(attention_mask.dtype)
|
|
||||||
|
|
||||||
if causal_mask.shape[1] < attention_mask.shape[1]:
|
|
||||||
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
|
|
||||||
causal_mask = torch.cat(
|
|
||||||
[
|
|
||||||
torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
|
|
||||||
causal_mask,
|
|
||||||
],
|
|
||||||
axis=-1,
|
|
||||||
)
|
|
||||||
|
|
||||||
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
|
||||||
else:
|
|
||||||
extended_attention_mask = attention_mask[:, None, None, :]
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
|
||||||
input_shape, attention_mask.shape
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
|
||||||
# masked positions, this operation will create a tensor which is 0.0 for
|
|
||||||
# positions we want to attend and -10000.0 for masked positions.
|
|
||||||
# Since we are adding it to the raw scores before the softmax, this is
|
|
||||||
# effectively the same as removing these entirely.
|
|
||||||
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
|
||||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
|
||||||
return extended_attention_mask
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input_ids=None,
|
|
||||||
attention_mask=None,
|
|
||||||
position_ids=None,
|
|
||||||
head_mask=None,
|
|
||||||
inputs_embeds=None,
|
|
||||||
encoder_embeds=None,
|
|
||||||
encoder_hidden_states=None,
|
|
||||||
encoder_attention_mask=None,
|
|
||||||
past_key_values=None,
|
|
||||||
use_cache=None,
|
|
||||||
output_attentions=None,
|
|
||||||
output_hidden_states=None,
|
|
||||||
return_dict=None,
|
|
||||||
is_decoder=False,
|
|
||||||
mode='multimodal',
|
|
||||||
):
|
|
||||||
r"""
|
|
||||||
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
|
||||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
|
||||||
the model is configured as a decoder.
|
|
||||||
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
|
||||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
|
||||||
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
|
||||||
- 1 for tokens that are **not masked**,
|
|
||||||
- 0 for tokens that are **masked**.
|
|
||||||
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
|
||||||
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
|
||||||
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
|
||||||
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
|
||||||
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
|
||||||
use_cache (:obj:`bool`, `optional`):
|
|
||||||
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
|
||||||
decoding (see :obj:`past_key_values`).
|
|
||||||
"""
|
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
|
|
||||||
if is_decoder:
|
|
||||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
|
||||||
else:
|
|
||||||
use_cache = False
|
|
||||||
|
|
||||||
if input_ids is not None and inputs_embeds is not None:
|
|
||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
|
||||||
elif input_ids is not None:
|
|
||||||
input_shape = input_ids.size()
|
|
||||||
batch_size, seq_length = input_shape
|
|
||||||
device = input_ids.device
|
|
||||||
elif inputs_embeds is not None:
|
|
||||||
input_shape = inputs_embeds.size()[:-1]
|
|
||||||
batch_size, seq_length = input_shape
|
|
||||||
device = inputs_embeds.device
|
|
||||||
elif encoder_embeds is not None:
|
|
||||||
input_shape = encoder_embeds.size()[:-1]
|
|
||||||
batch_size, seq_length = input_shape
|
|
||||||
device = encoder_embeds.device
|
|
||||||
else:
|
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
|
|
||||||
|
|
||||||
# past_key_values_length
|
|
||||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
|
||||||
|
|
||||||
if attention_mask is None:
|
|
||||||
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
|
||||||
|
|
||||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
|
||||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
|
||||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
|
|
||||||
device, is_decoder)
|
|
||||||
|
|
||||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
|
||||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
|
||||||
if encoder_hidden_states is not None:
|
|
||||||
if type(encoder_hidden_states) == list:
|
|
||||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
|
|
||||||
else:
|
|
||||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
|
||||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
|
||||||
|
|
||||||
if type(encoder_attention_mask) == list:
|
|
||||||
encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
|
|
||||||
elif encoder_attention_mask is None:
|
|
||||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
|
||||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
|
||||||
else:
|
|
||||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
|
||||||
else:
|
|
||||||
encoder_extended_attention_mask = None
|
|
||||||
|
|
||||||
# Prepare head mask if needed
|
|
||||||
# 1.0 in head_mask indicate we keep the head
|
|
||||||
# attention_probs has shape bsz x n_heads x N x N
|
|
||||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
|
||||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
|
||||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
|
||||||
|
|
||||||
if encoder_embeds is None:
|
|
||||||
embedding_output = self.embeddings(
|
|
||||||
input_ids=input_ids,
|
|
||||||
position_ids=position_ids,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
past_key_values_length=past_key_values_length,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
embedding_output = encoder_embeds
|
|
||||||
|
|
||||||
encoder_outputs = self.encoder(
|
|
||||||
embedding_output,
|
|
||||||
attention_mask=extended_attention_mask,
|
|
||||||
head_mask=head_mask,
|
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
|
||||||
encoder_attention_mask=encoder_extended_attention_mask,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
mode=mode,
|
|
||||||
)
|
|
||||||
sequence_output = encoder_outputs[0]
|
|
||||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
|
||||||
|
|
||||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
|
||||||
last_hidden_state=sequence_output,
|
|
||||||
pooler_output=pooled_output,
|
|
||||||
past_key_values=encoder_outputs.past_key_values,
|
|
||||||
hidden_states=encoder_outputs.hidden_states,
|
|
||||||
attentions=encoder_outputs.attentions,
|
|
||||||
cross_attentions=encoder_outputs.cross_attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class BertLMHeadModel(BertPreTrainedModel):
|
|
||||||
|
|
||||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
|
||||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
|
||||||
|
|
||||||
def __init__(self, config):
|
|
||||||
super().__init__(config)
|
|
||||||
|
|
||||||
self.bert = BertModel(config, add_pooling_layer=False)
|
|
||||||
self.cls = BertOnlyMLMHead(config)
|
|
||||||
|
|
||||||
self.init_weights()
|
|
||||||
|
|
||||||
def get_output_embeddings(self):
|
|
||||||
return self.cls.predictions.decoder
|
|
||||||
|
|
||||||
def set_output_embeddings(self, new_embeddings):
|
|
||||||
self.cls.predictions.decoder = new_embeddings
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input_ids=None,
|
|
||||||
attention_mask=None,
|
|
||||||
position_ids=None,
|
|
||||||
head_mask=None,
|
|
||||||
inputs_embeds=None,
|
|
||||||
encoder_hidden_states=None,
|
|
||||||
encoder_attention_mask=None,
|
|
||||||
labels=None,
|
|
||||||
past_key_values=None,
|
|
||||||
use_cache=None,
|
|
||||||
output_attentions=None,
|
|
||||||
output_hidden_states=None,
|
|
||||||
return_dict=None,
|
|
||||||
return_logits=False,
|
|
||||||
is_decoder=True,
|
|
||||||
reduction='mean',
|
|
||||||
mode='multimodal',
|
|
||||||
):
|
|
||||||
r"""
|
|
||||||
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
|
||||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
|
||||||
the model is configured as a decoder.
|
|
||||||
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
|
||||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
|
||||||
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
|
||||||
- 1 for tokens that are **not masked**,
|
|
||||||
- 0 for tokens that are **masked**.
|
|
||||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
|
||||||
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
|
||||||
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
|
|
||||||
ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
|
|
||||||
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
|
||||||
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
|
||||||
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
|
||||||
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
|
||||||
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
|
||||||
use_cache (:obj:`bool`, `optional`):
|
|
||||||
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
|
||||||
decoding (see :obj:`past_key_values`).
|
|
||||||
Returns:
|
|
||||||
Example::
|
|
||||||
>>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
|
|
||||||
>>> import torch
|
|
||||||
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
|
||||||
>>> config = BertConfig.from_pretrained("bert-base-cased")
|
|
||||||
>>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
|
|
||||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
|
||||||
>>> outputs = model(**inputs)
|
|
||||||
>>> prediction_logits = outputs.logits
|
|
||||||
"""
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
if labels is not None:
|
|
||||||
use_cache = False
|
|
||||||
|
|
||||||
outputs = self.bert(
|
|
||||||
input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
head_mask=head_mask,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
is_decoder=is_decoder,
|
|
||||||
mode=mode,
|
|
||||||
)
|
|
||||||
|
|
||||||
sequence_output = outputs[0]
|
|
||||||
prediction_scores = self.cls(sequence_output)
|
|
||||||
|
|
||||||
if return_logits:
|
|
||||||
return prediction_scores[:, :-1, :].contiguous()
|
|
||||||
|
|
||||||
lm_loss = None
|
|
||||||
if labels is not None:
|
|
||||||
# we are doing next-token prediction; shift prediction scores and input ids by one
|
|
||||||
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
|
||||||
labels = labels[:, 1:].contiguous()
|
|
||||||
loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
|
|
||||||
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
|
||||||
if reduction=='none':
|
|
||||||
lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
output = (prediction_scores,) + outputs[2:]
|
|
||||||
return ((lm_loss,) + output) if lm_loss is not None else output
|
|
||||||
|
|
||||||
return CausalLMOutputWithCrossAttentions(
|
|
||||||
loss=lm_loss,
|
|
||||||
logits=prediction_scores,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
cross_attentions=outputs.cross_attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
|
|
||||||
input_shape = input_ids.shape
|
|
||||||
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
|
||||||
if attention_mask is None:
|
|
||||||
attention_mask = input_ids.new_ones(input_shape)
|
|
||||||
|
|
||||||
# cut decoder_input_ids if past is used
|
|
||||||
if past is not None:
|
|
||||||
input_ids = input_ids[:, -1:]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"input_ids": input_ids,
|
|
||||||
"attention_mask": attention_mask,
|
|
||||||
"past_key_values": past,
|
|
||||||
"encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
|
|
||||||
"encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
|
|
||||||
"is_decoder": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
def _reorder_cache(self, past, beam_idx):
|
|
||||||
reordered_past = ()
|
|
||||||
for layer_past in past:
|
|
||||||
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
|
||||||
return reordered_past
|
|
||||||
@@ -1,301 +0,0 @@
|
|||||||
'''
|
|
||||||
* Adapted from BLIP (https://github.com/salesforce/BLIP)
|
|
||||||
* Based on timm code base
|
|
||||||
* https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
|
||||||
'''
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
from timm.models.vision_transformer import _cfg, PatchEmbed
|
|
||||||
from timm.models.registry import register_model
|
|
||||||
from timm.models.layers import trunc_normal_, DropPath
|
|
||||||
from timm.models.helpers import named_apply, adapt_input_conv
|
|
||||||
|
|
||||||
# from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
|
|
||||||
|
|
||||||
class Mlp(nn.Module):
|
|
||||||
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
|
|
||||||
"""
|
|
||||||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
|
||||||
super().__init__()
|
|
||||||
out_features = out_features or in_features
|
|
||||||
hidden_features = hidden_features or in_features
|
|
||||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
|
||||||
self.act = act_layer()
|
|
||||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
|
||||||
self.drop = nn.Dropout(drop)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.fc1(x)
|
|
||||||
x = self.act(x)
|
|
||||||
x = self.drop(x)
|
|
||||||
x = self.fc2(x)
|
|
||||||
x = self.drop(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module):
|
|
||||||
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
|
||||||
super().__init__()
|
|
||||||
self.num_heads = num_heads
|
|
||||||
head_dim = dim // num_heads
|
|
||||||
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
|
||||||
self.scale = qk_scale or head_dim ** -0.5
|
|
||||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
|
||||||
self.attn_drop = nn.Dropout(attn_drop)
|
|
||||||
self.proj = nn.Linear(dim, dim)
|
|
||||||
self.proj_drop = nn.Dropout(proj_drop)
|
|
||||||
self.attn_gradients = None
|
|
||||||
self.attention_map = None
|
|
||||||
|
|
||||||
def save_attn_gradients(self, attn_gradients):
|
|
||||||
self.attn_gradients = attn_gradients
|
|
||||||
|
|
||||||
def get_attn_gradients(self):
|
|
||||||
return self.attn_gradients
|
|
||||||
|
|
||||||
def save_attention_map(self, attention_map):
|
|
||||||
self.attention_map = attention_map
|
|
||||||
|
|
||||||
def get_attention_map(self):
|
|
||||||
return self.attention_map
|
|
||||||
|
|
||||||
def forward(self, x, register_hook=False):
|
|
||||||
B, N, C = x.shape
|
|
||||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
|
||||||
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
|
||||||
|
|
||||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
|
||||||
attn = attn.softmax(dim=-1)
|
|
||||||
attn = self.attn_drop(attn)
|
|
||||||
|
|
||||||
if register_hook:
|
|
||||||
self.save_attention_map(attn)
|
|
||||||
attn.register_hook(self.save_attn_gradients)
|
|
||||||
|
|
||||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
|
||||||
x = self.proj(x)
|
|
||||||
x = self.proj_drop(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Block(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
|
||||||
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
|
|
||||||
super().__init__()
|
|
||||||
self.norm1 = norm_layer(dim)
|
|
||||||
self.attn = Attention(
|
|
||||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
|
||||||
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
|
||||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
|
||||||
self.norm2 = norm_layer(dim)
|
|
||||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
|
||||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
|
||||||
|
|
||||||
# if use_grad_checkpointing:
|
|
||||||
# self.attn = checkpoint_wrapper(self.attn)
|
|
||||||
# self.mlp = checkpoint_wrapper(self.mlp)
|
|
||||||
|
|
||||||
def forward(self, x, register_hook=False):
|
|
||||||
x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
|
|
||||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class VisionTransformer(nn.Module):
|
|
||||||
""" Vision Transformer
|
|
||||||
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
|
|
||||||
https://arxiv.org/abs/2010.11929
|
|
||||||
"""
|
|
||||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
|
||||||
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
|
|
||||||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
|
|
||||||
use_grad_checkpointing=False, ckpt_layer=0):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
img_size (int, tuple): input image size
|
|
||||||
patch_size (int, tuple): patch size
|
|
||||||
in_chans (int): number of input channels
|
|
||||||
num_classes (int): number of classes for classification head
|
|
||||||
embed_dim (int): embedding dimension
|
|
||||||
depth (int): depth of transformer
|
|
||||||
num_heads (int): number of attention heads
|
|
||||||
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
|
||||||
qkv_bias (bool): enable bias for qkv if True
|
|
||||||
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
|
|
||||||
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
|
|
||||||
drop_rate (float): dropout rate
|
|
||||||
attn_drop_rate (float): attention dropout rate
|
|
||||||
drop_path_rate (float): stochastic depth rate
|
|
||||||
norm_layer: (nn.Module): normalization layer
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
|
||||||
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
|
||||||
|
|
||||||
self.patch_embed = PatchEmbed(
|
|
||||||
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
|
||||||
|
|
||||||
num_patches = self.patch_embed.num_patches
|
|
||||||
|
|
||||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
|
||||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
|
||||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
|
||||||
|
|
||||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
|
||||||
self.blocks = nn.ModuleList([
|
|
||||||
Block(
|
|
||||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
|
||||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
|
||||||
use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
|
|
||||||
)
|
|
||||||
for i in range(depth)])
|
|
||||||
self.norm = norm_layer(embed_dim)
|
|
||||||
|
|
||||||
trunc_normal_(self.pos_embed, std=.02)
|
|
||||||
trunc_normal_(self.cls_token, std=.02)
|
|
||||||
self.apply(self._init_weights)
|
|
||||||
|
|
||||||
def _init_weights(self, m):
|
|
||||||
if isinstance(m, nn.Linear):
|
|
||||||
trunc_normal_(m.weight, std=.02)
|
|
||||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.LayerNorm):
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
nn.init.constant_(m.weight, 1.0)
|
|
||||||
|
|
||||||
@torch.jit.ignore
|
|
||||||
def no_weight_decay(self):
|
|
||||||
return {'pos_embed', 'cls_token'}
|
|
||||||
|
|
||||||
def forward(self, x, register_blk=-1):
|
|
||||||
B = x.shape[0]
|
|
||||||
x = self.patch_embed(x)
|
|
||||||
|
|
||||||
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
|
||||||
x = torch.cat((cls_tokens, x), dim=1)
|
|
||||||
|
|
||||||
x = x + self.pos_embed[:,:x.size(1),:]
|
|
||||||
x = self.pos_drop(x)
|
|
||||||
|
|
||||||
for i,blk in enumerate(self.blocks):
|
|
||||||
x = blk(x, register_blk==i)
|
|
||||||
x = self.norm(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
@torch.jit.ignore()
|
|
||||||
def load_pretrained(self, checkpoint_path, prefix=''):
|
|
||||||
_load_weights(self, checkpoint_path, prefix)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
|
|
||||||
""" Load weights from .npz checkpoints for official Google Brain Flax implementation
|
|
||||||
"""
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
def _n2p(w, t=True):
|
|
||||||
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
|
|
||||||
w = w.flatten()
|
|
||||||
if t:
|
|
||||||
if w.ndim == 4:
|
|
||||||
w = w.transpose([3, 2, 0, 1])
|
|
||||||
elif w.ndim == 3:
|
|
||||||
w = w.transpose([2, 0, 1])
|
|
||||||
elif w.ndim == 2:
|
|
||||||
w = w.transpose([1, 0])
|
|
||||||
return torch.from_numpy(w)
|
|
||||||
|
|
||||||
w = np.load(checkpoint_path)
|
|
||||||
if not prefix and 'opt/target/embedding/kernel' in w:
|
|
||||||
prefix = 'opt/target/'
|
|
||||||
|
|
||||||
if hasattr(model.patch_embed, 'backbone'):
|
|
||||||
# hybrid
|
|
||||||
backbone = model.patch_embed.backbone
|
|
||||||
stem_only = not hasattr(backbone, 'stem')
|
|
||||||
stem = backbone if stem_only else backbone.stem
|
|
||||||
stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
|
|
||||||
stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
|
|
||||||
stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
|
|
||||||
if not stem_only:
|
|
||||||
for i, stage in enumerate(backbone.stages):
|
|
||||||
for j, block in enumerate(stage.blocks):
|
|
||||||
bp = f'{prefix}block{i + 1}/unit{j + 1}/'
|
|
||||||
for r in range(3):
|
|
||||||
getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
|
|
||||||
getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
|
|
||||||
getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
|
|
||||||
if block.downsample is not None:
|
|
||||||
block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
|
|
||||||
block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
|
|
||||||
block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
|
|
||||||
embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
|
|
||||||
else:
|
|
||||||
embed_conv_w = adapt_input_conv(
|
|
||||||
model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
|
|
||||||
model.patch_embed.proj.weight.copy_(embed_conv_w)
|
|
||||||
model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
|
|
||||||
model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
|
|
||||||
pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
|
|
||||||
if pos_embed_w.shape != model.pos_embed.shape:
|
|
||||||
pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
|
|
||||||
pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
|
|
||||||
model.pos_embed.copy_(pos_embed_w)
|
|
||||||
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
|
|
||||||
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
|
|
||||||
# if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
|
|
||||||
# model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
|
|
||||||
# model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
|
|
||||||
# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
|
|
||||||
# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
|
|
||||||
# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
|
|
||||||
for i, block in enumerate(model.blocks.children()):
|
|
||||||
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
|
|
||||||
mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
|
|
||||||
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
|
|
||||||
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
|
|
||||||
block.attn.qkv.weight.copy_(torch.cat([
|
|
||||||
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
|
|
||||||
block.attn.qkv.bias.copy_(torch.cat([
|
|
||||||
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
|
|
||||||
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
|
|
||||||
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
|
|
||||||
for r in range(2):
|
|
||||||
getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
|
|
||||||
getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
|
|
||||||
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
|
|
||||||
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
|
|
||||||
|
|
||||||
|
|
||||||
def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
|
|
||||||
# interpolate position embedding
|
|
||||||
embedding_size = pos_embed_checkpoint.shape[-1]
|
|
||||||
num_patches = visual_encoder.patch_embed.num_patches
|
|
||||||
num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
|
|
||||||
# height (== width) for the checkpoint position embedding
|
|
||||||
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
|
||||||
# height (== width) for the new position embedding
|
|
||||||
new_size = int(num_patches ** 0.5)
|
|
||||||
|
|
||||||
if orig_size!=new_size:
|
|
||||||
# class_token and dist_token are kept unchanged
|
|
||||||
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
|
||||||
# only the position tokens are interpolated
|
|
||||||
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
|
||||||
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
|
||||||
pos_tokens = torch.nn.functional.interpolate(
|
|
||||||
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
|
||||||
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
|
||||||
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
|
||||||
print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
|
|
||||||
|
|
||||||
return new_pos_embed
|
|
||||||
else:
|
|
||||||
return pos_embed_checkpoint
|
|
||||||
@@ -1,148 +0,0 @@
|
|||||||
from modelscope import snapshot_download
|
|
||||||
from typing_extensions import Literal, TypeAlias
|
|
||||||
import os
|
|
||||||
from diffsynth.extensions.ImageQualityMetric.aesthetic import AestheticScore
|
|
||||||
from diffsynth.extensions.ImageQualityMetric.imagereward import ImageRewardScore
|
|
||||||
from diffsynth.extensions.ImageQualityMetric.pickscore import PickScore
|
|
||||||
from diffsynth.extensions.ImageQualityMetric.clip import CLIPScore
|
|
||||||
from diffsynth.extensions.ImageQualityMetric.hps import HPScore_v2
|
|
||||||
from diffsynth.extensions.ImageQualityMetric.mps import MPScore
|
|
||||||
|
|
||||||
|
|
||||||
preference_model_id: TypeAlias = Literal[
|
|
||||||
"ImageReward",
|
|
||||||
"Aesthetic",
|
|
||||||
"PickScore",
|
|
||||||
"CLIP",
|
|
||||||
"HPSv2",
|
|
||||||
"HPSv2.1",
|
|
||||||
"MPS",
|
|
||||||
]
|
|
||||||
model_dict = {
|
|
||||||
"ImageReward": {
|
|
||||||
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
|
||||||
"allow_file_pattern": [
|
|
||||||
"ImageReward/ImageReward.safetensors",
|
|
||||||
"ImageReward/med_config.json",
|
|
||||||
"bert-base-uncased/config.json",
|
|
||||||
"bert-base-uncased/model.safetensors",
|
|
||||||
"bert-base-uncased/tokenizer.json",
|
|
||||||
"bert-base-uncased/tokenizer_config.json",
|
|
||||||
"bert-base-uncased/vocab.txt",
|
|
||||||
],
|
|
||||||
"load_path": {
|
|
||||||
"imagereward": "ImageReward/ImageReward.safetensors",
|
|
||||||
"med_config": "ImageReward/med_config.json",
|
|
||||||
"bert_model_path": "bert-base-uncased",
|
|
||||||
},
|
|
||||||
"model_class": ImageRewardScore
|
|
||||||
},
|
|
||||||
"Aesthetic": {
|
|
||||||
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
|
||||||
"allow_file_pattern": [
|
|
||||||
"aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors",
|
|
||||||
"clip-vit-large-patch14/config.json",
|
|
||||||
"clip-vit-large-patch14/merges.txt",
|
|
||||||
"clip-vit-large-patch14/model.safetensors",
|
|
||||||
"clip-vit-large-patch14/preprocessor_config.json",
|
|
||||||
"clip-vit-large-patch14/special_tokens_map.json",
|
|
||||||
"clip-vit-large-patch14/tokenizer.json",
|
|
||||||
"clip-vit-large-patch14/tokenizer_config.json",
|
|
||||||
"clip-vit-large-patch14/vocab.json",
|
|
||||||
],
|
|
||||||
"load_path": {
|
|
||||||
"aesthetic_predictor": "aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors",
|
|
||||||
"clip-large": "clip-vit-large-patch14",
|
|
||||||
},
|
|
||||||
"model_class": AestheticScore
|
|
||||||
},
|
|
||||||
"PickScore": {
|
|
||||||
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
|
||||||
"allow_file_pattern": [
|
|
||||||
"PickScore_v1/*",
|
|
||||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/config.json",
|
|
||||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/merges.txt",
|
|
||||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/preprocessor_config.json",
|
|
||||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/special_tokens_map.json",
|
|
||||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer.json",
|
|
||||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer_config.json",
|
|
||||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/vocab.json",
|
|
||||||
],
|
|
||||||
"load_path": {
|
|
||||||
"pickscore": "PickScore_v1",
|
|
||||||
"clip": "CLIP-ViT-H-14-laion2B-s32B-b79K",
|
|
||||||
},
|
|
||||||
"model_class": PickScore
|
|
||||||
},
|
|
||||||
"CLIP": {
|
|
||||||
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
|
||||||
"allow_file_pattern": [
|
|
||||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin",
|
|
||||||
"bpe_simple_vocab_16e6.txt.gz",
|
|
||||||
],
|
|
||||||
"load_path": {
|
|
||||||
"open_clip": "CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin",
|
|
||||||
"open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
|
|
||||||
},
|
|
||||||
"model_class": CLIPScore
|
|
||||||
},
|
|
||||||
"HPSv2": {
|
|
||||||
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
|
||||||
"allow_file_pattern": [
|
|
||||||
"HPS_v2/HPS_v2_compressed.safetensors",
|
|
||||||
"bpe_simple_vocab_16e6.txt.gz",
|
|
||||||
],
|
|
||||||
"load_path": {
|
|
||||||
"hpsv2": "HPS_v2/HPS_v2_compressed.safetensors",
|
|
||||||
"open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
|
|
||||||
},
|
|
||||||
"model_class": HPScore_v2,
|
|
||||||
"extra_kwargs": {"model_version": "v2"}
|
|
||||||
},
|
|
||||||
"HPSv2.1": {
|
|
||||||
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
|
||||||
"allow_file_pattern": [
|
|
||||||
"HPS_v2/HPS_v2.1_compressed.safetensors",
|
|
||||||
"bpe_simple_vocab_16e6.txt.gz",
|
|
||||||
],
|
|
||||||
"load_path": {
|
|
||||||
"hpsv2.1": "HPS_v2/HPS_v2.1_compressed.safetensors",
|
|
||||||
"open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
|
|
||||||
},
|
|
||||||
"model_class": HPScore_v2,
|
|
||||||
"extra_kwargs": {"model_version": "v21"}
|
|
||||||
},
|
|
||||||
"MPS": {
|
|
||||||
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
|
||||||
"allow_file_pattern": [
|
|
||||||
"MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors",
|
|
||||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/config.json",
|
|
||||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/merges.txt",
|
|
||||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/preprocessor_config.json",
|
|
||||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/special_tokens_map.json",
|
|
||||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer.json",
|
|
||||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer_config.json",
|
|
||||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/vocab.json",
|
|
||||||
],
|
|
||||||
"load_path": {
|
|
||||||
"mps": "MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors",
|
|
||||||
"clip": "CLIP-ViT-H-14-laion2B-s32B-b79K",
|
|
||||||
},
|
|
||||||
"model_class": MPScore
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def download_preference_model(model_name: preference_model_id, cache_dir="models"):
|
|
||||||
metadata = model_dict[model_name]
|
|
||||||
snapshot_download(model_id=metadata["model_id"], allow_file_pattern=metadata["allow_file_pattern"], cache_dir=cache_dir)
|
|
||||||
load_path = metadata["load_path"]
|
|
||||||
load_path = {key: os.path.join(cache_dir, metadata["model_id"], path) for key, path in load_path.items()}
|
|
||||||
return load_path
|
|
||||||
|
|
||||||
|
|
||||||
def load_preference_model(model_name: preference_model_id, device = "cuda", path = None):
|
|
||||||
model_class = model_dict[model_name]["model_class"]
|
|
||||||
extra_kwargs = model_dict[model_name].get("extra_kwargs", {})
|
|
||||||
preference_model = model_class(device=device, path=path, **extra_kwargs)
|
|
||||||
return preference_model
|
|
||||||
@@ -1,148 +0,0 @@
|
|||||||
from typing import List, Optional
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
from transformers import AutoProcessor, AutoModel
|
|
||||||
from safetensors.torch import load_file
|
|
||||||
import os
|
|
||||||
from typing import Union, List
|
|
||||||
from .config import MODEL_PATHS
|
|
||||||
|
|
||||||
class MLP(torch.nn.Module):
|
|
||||||
def __init__(self, input_size: int, xcol: str = "emb", ycol: str = "avg_rating"):
|
|
||||||
super().__init__()
|
|
||||||
self.input_size = input_size
|
|
||||||
self.xcol = xcol
|
|
||||||
self.ycol = ycol
|
|
||||||
self.layers = torch.nn.Sequential(
|
|
||||||
torch.nn.Linear(self.input_size, 1024),
|
|
||||||
#torch.nn.ReLU(),
|
|
||||||
torch.nn.Dropout(0.2),
|
|
||||||
torch.nn.Linear(1024, 128),
|
|
||||||
#torch.nn.ReLU(),
|
|
||||||
torch.nn.Dropout(0.2),
|
|
||||||
torch.nn.Linear(128, 64),
|
|
||||||
#torch.nn.ReLU(),
|
|
||||||
torch.nn.Dropout(0.1),
|
|
||||||
torch.nn.Linear(64, 16),
|
|
||||||
#torch.nn.ReLU(),
|
|
||||||
torch.nn.Linear(16, 1),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
return self.layers(x)
|
|
||||||
|
|
||||||
def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
|
|
||||||
x = batch[self.xcol]
|
|
||||||
y = batch[self.ycol].reshape(-1, 1)
|
|
||||||
x_hat = self.layers(x)
|
|
||||||
loss = torch.nn.functional.mse_loss(x_hat, y)
|
|
||||||
return loss
|
|
||||||
|
|
||||||
def validation_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
|
|
||||||
x = batch[self.xcol]
|
|
||||||
y = batch[self.ycol].reshape(-1, 1)
|
|
||||||
x_hat = self.layers(x)
|
|
||||||
loss = torch.nn.functional.mse_loss(x_hat, y)
|
|
||||||
return loss
|
|
||||||
|
|
||||||
def configure_optimizers(self) -> torch.optim.Optimizer:
|
|
||||||
return torch.optim.Adam(self.parameters(), lr=1e-3)
|
|
||||||
|
|
||||||
|
|
||||||
class AestheticScore(torch.nn.Module):
|
|
||||||
def __init__(self, device: torch.device, path: str = MODEL_PATHS):
|
|
||||||
super().__init__()
|
|
||||||
self.device = device
|
|
||||||
self.aes_model_path = path.get("aesthetic_predictor")
|
|
||||||
# Load the MLP model
|
|
||||||
self.model = MLP(768)
|
|
||||||
try:
|
|
||||||
if self.aes_model_path.endswith(".safetensors"):
|
|
||||||
state_dict = load_file(self.aes_model_path)
|
|
||||||
else:
|
|
||||||
state_dict = torch.load(self.aes_model_path)
|
|
||||||
self.model.load_state_dict(state_dict)
|
|
||||||
except Exception as e:
|
|
||||||
raise ValueError(f"Error loading model weights from {self.aes_model_path}: {e}")
|
|
||||||
|
|
||||||
self.model.to(device)
|
|
||||||
self.model.eval()
|
|
||||||
|
|
||||||
# Load the CLIP model and processor
|
|
||||||
clip_model_name = path.get('clip-large')
|
|
||||||
self.model2 = AutoModel.from_pretrained(clip_model_name).eval().to(device)
|
|
||||||
self.processor = AutoProcessor.from_pretrained(clip_model_name)
|
|
||||||
|
|
||||||
def _calculate_score(self, image: torch.Tensor) -> float:
|
|
||||||
"""Calculate the aesthetic score for a single image.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image (torch.Tensor): The processed image tensor.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
float: The aesthetic score.
|
|
||||||
"""
|
|
||||||
with torch.no_grad():
|
|
||||||
# Get image embeddings
|
|
||||||
image_embs = self.model2.get_image_features(image)
|
|
||||||
image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
|
|
||||||
|
|
||||||
# Compute score
|
|
||||||
score = self.model(image_embs).cpu().flatten().item()
|
|
||||||
|
|
||||||
return score
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str = "") -> List[float]:
|
|
||||||
"""Score the images based on their aesthetic quality.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[float]: List of scores for the images.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if isinstance(images, (str, Image.Image)):
|
|
||||||
# Single image
|
|
||||||
if isinstance(images, str):
|
|
||||||
pil_image = Image.open(images)
|
|
||||||
else:
|
|
||||||
pil_image = images
|
|
||||||
|
|
||||||
# Prepare image inputs
|
|
||||||
image_inputs = self.processor(
|
|
||||||
images=pil_image,
|
|
||||||
padding=True,
|
|
||||||
truncation=True,
|
|
||||||
max_length=77,
|
|
||||||
return_tensors="pt",
|
|
||||||
).to(self.device)
|
|
||||||
|
|
||||||
return [self._calculate_score(image_inputs["pixel_values"])]
|
|
||||||
elif isinstance(images, list):
|
|
||||||
# Multiple images
|
|
||||||
scores = []
|
|
||||||
for one_image in images:
|
|
||||||
if isinstance(one_image, str):
|
|
||||||
pil_image = Image.open(one_image)
|
|
||||||
elif isinstance(one_image, Image.Image):
|
|
||||||
pil_image = one_image
|
|
||||||
else:
|
|
||||||
raise TypeError("The type of parameter images is illegal.")
|
|
||||||
|
|
||||||
# Prepare image inputs
|
|
||||||
image_inputs = self.processor(
|
|
||||||
images=pil_image,
|
|
||||||
padding=True,
|
|
||||||
truncation=True,
|
|
||||||
max_length=77,
|
|
||||||
return_tensors="pt",
|
|
||||||
).to(self.device)
|
|
||||||
|
|
||||||
scores.append(self._calculate_score(image_inputs["pixel_values"]))
|
|
||||||
return scores
|
|
||||||
else:
|
|
||||||
raise TypeError("The type of parameter images is illegal.")
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"Error in scoring images: {e}")
|
|
||||||
@@ -1,97 +0,0 @@
|
|||||||
from typing import List, Union
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
from .open_clip import create_model_and_transforms, get_tokenizer
|
|
||||||
from .config import MODEL_PATHS
|
|
||||||
|
|
||||||
class CLIPScore(torch.nn.Module):
|
|
||||||
def __init__(self, device: torch.device, path: str = MODEL_PATHS):
|
|
||||||
super().__init__()
|
|
||||||
"""Initialize the CLIPScore with a model and tokenizer.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
device (torch.device): The device to load the model on.
|
|
||||||
"""
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
# Create model and transforms
|
|
||||||
self.model, _, self.preprocess_val = create_model_and_transforms(
|
|
||||||
"ViT-H-14",
|
|
||||||
# "laion2B-s32B-b79K",
|
|
||||||
pretrained=path.get("open_clip"),
|
|
||||||
precision="amp",
|
|
||||||
device=device,
|
|
||||||
jit=False,
|
|
||||||
force_quick_gelu=False,
|
|
||||||
force_custom_text=False,
|
|
||||||
force_patch_dropout=False,
|
|
||||||
force_image_size=None,
|
|
||||||
pretrained_image=False,
|
|
||||||
image_mean=None,
|
|
||||||
image_std=None,
|
|
||||||
light_augmentation=True,
|
|
||||||
aug_cfg={},
|
|
||||||
output_dict=True,
|
|
||||||
with_score_predictor=False,
|
|
||||||
with_region_predictor=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize tokenizer
|
|
||||||
self.tokenizer = get_tokenizer("ViT-H-14", path["open_clip_bpe"])
|
|
||||||
self.model = self.model.to(device)
|
|
||||||
self.model.eval()
|
|
||||||
|
|
||||||
def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
|
|
||||||
"""Calculate the CLIP score for a single image and prompt.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image (torch.Tensor): The processed image tensor.
|
|
||||||
prompt (str): The prompt text.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
float: The CLIP score.
|
|
||||||
"""
|
|
||||||
with torch.no_grad():
|
|
||||||
# Process the prompt
|
|
||||||
text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
|
|
||||||
|
|
||||||
# Calculate the CLIP score
|
|
||||||
outputs = self.model(image, text)
|
|
||||||
image_features, text_features = outputs["image_features"], outputs["text_features"]
|
|
||||||
logits_per_image = image_features @ text_features.T
|
|
||||||
clip_score = torch.diagonal(logits_per_image).cpu().numpy()
|
|
||||||
|
|
||||||
return clip_score[0].item()
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
|
||||||
"""Score the images based on the prompt.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
|
||||||
prompt (str): The prompt text.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[float]: List of CLIP scores for the images.
|
|
||||||
"""
|
|
||||||
if isinstance(images, (str, Image.Image)):
|
|
||||||
# Single image
|
|
||||||
if isinstance(images, str):
|
|
||||||
image = self.preprocess_val(Image.open(images)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
|
||||||
else:
|
|
||||||
image = self.preprocess_val(images).unsqueeze(0).to(device=self.device, non_blocking=True)
|
|
||||||
return [self._calculate_score(image, prompt)]
|
|
||||||
elif isinstance(images, list):
|
|
||||||
# Multiple images
|
|
||||||
scores = []
|
|
||||||
for one_images in images:
|
|
||||||
if isinstance(one_images, str):
|
|
||||||
image = self.preprocess_val(Image.open(one_images)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
|
||||||
elif isinstance(one_images, Image.Image):
|
|
||||||
image = self.preprocess_val(one_images).unsqueeze(0).to(device=self.device, non_blocking=True)
|
|
||||||
else:
|
|
||||||
raise TypeError("The type of parameter images is illegal.")
|
|
||||||
scores.append(self._calculate_score(image, prompt))
|
|
||||||
return scores
|
|
||||||
else:
|
|
||||||
raise TypeError("The type of parameter images is illegal.")
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
import os
|
|
||||||
|
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
project_root = os.path.abspath(os.path.join(current_dir, '../../../'))
|
|
||||||
model_path = os.path.join(project_root, 'models', 'QualityMetric')
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_path(model_name):
|
|
||||||
return os.path.join(model_path, model_name)
|
|
||||||
|
|
||||||
|
|
||||||
MODEL_PATHS = {
|
|
||||||
"aesthetic_predictor": get_model_path("aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors"),
|
|
||||||
"open_clip": get_model_path("CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin"),
|
|
||||||
"hpsv2": get_model_path("HPS_v2/HPS_v2_compressed.safetensors"),
|
|
||||||
"hpsv2.1": get_model_path("HPS_v2/HPS_v2.1_compressed.safetensors"),
|
|
||||||
"imagereward": get_model_path("ImageReward/ImageReward.safetensors"),
|
|
||||||
"med_config": get_model_path("ImageReward/med_config.json"),
|
|
||||||
"clip": get_model_path("CLIP-ViT-H-14-laion2B-s32B-b79K"),
|
|
||||||
"clip-large": get_model_path("clip-vit-large-patch14"),
|
|
||||||
"mps": get_model_path("MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors"),
|
|
||||||
"pickscore": get_model_path("PickScore_v1")
|
|
||||||
}
|
|
||||||
@@ -1,118 +0,0 @@
|
|||||||
from typing import List, Union
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
from .open_clip import create_model_and_transforms, get_tokenizer
|
|
||||||
from safetensors.torch import load_file
|
|
||||||
import os
|
|
||||||
from .config import MODEL_PATHS
|
|
||||||
|
|
||||||
class HPScore_v2(torch.nn.Module):
|
|
||||||
def __init__(self, device: torch.device, path: str = MODEL_PATHS, model_version: str = "v2"):
|
|
||||||
super().__init__()
|
|
||||||
"""Initialize the Selector with a model and tokenizer.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
device (torch.device): The device to load the model on.
|
|
||||||
model_version (str): The version of the model to load. Supports "v2" or "v21". Default is "v2".
|
|
||||||
"""
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
if model_version == "v2":
|
|
||||||
safetensors_path = path.get("hpsv2")
|
|
||||||
elif model_version == "v21":
|
|
||||||
safetensors_path = path.get("hpsv2.1")
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported model version: {model_version}. Choose 'v2' or 'v21'.")
|
|
||||||
|
|
||||||
# Create model and transforms
|
|
||||||
model, _, self.preprocess_val = create_model_and_transforms(
|
|
||||||
"ViT-H-14",
|
|
||||||
# "laion2B-s32B-b79K",
|
|
||||||
pretrained=path.get("open_clip"),
|
|
||||||
precision="amp",
|
|
||||||
device=device,
|
|
||||||
jit=False,
|
|
||||||
force_quick_gelu=False,
|
|
||||||
force_custom_text=False,
|
|
||||||
force_patch_dropout=False,
|
|
||||||
force_image_size=None,
|
|
||||||
pretrained_image=False,
|
|
||||||
image_mean=None,
|
|
||||||
image_std=None,
|
|
||||||
light_augmentation=True,
|
|
||||||
aug_cfg={},
|
|
||||||
output_dict=True,
|
|
||||||
with_score_predictor=False,
|
|
||||||
with_region_predictor=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load model weights
|
|
||||||
try:
|
|
||||||
state_dict = load_file(safetensors_path)
|
|
||||||
model.load_state_dict(state_dict)
|
|
||||||
except Exception as e:
|
|
||||||
raise ValueError(f"Error loading model weights from {safetensors_path}: {e}")
|
|
||||||
|
|
||||||
# Initialize tokenizer and model
|
|
||||||
self.tokenizer = get_tokenizer("ViT-H-14", path["open_clip_bpe"])
|
|
||||||
model = model.to(device)
|
|
||||||
model.eval()
|
|
||||||
self.model = model
|
|
||||||
|
|
||||||
def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
|
|
||||||
"""Calculate the HPS score for a single image and prompt.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image (torch.Tensor): The processed image tensor.
|
|
||||||
prompt (str): The prompt text.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
float: The HPS score.
|
|
||||||
"""
|
|
||||||
with torch.no_grad():
|
|
||||||
# Process the prompt
|
|
||||||
text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
|
|
||||||
|
|
||||||
# Calculate the HPS score
|
|
||||||
outputs = self.model(image, text)
|
|
||||||
image_features, text_features = outputs["image_features"], outputs["text_features"]
|
|
||||||
logits_per_image = image_features @ text_features.T
|
|
||||||
hps_score = torch.diagonal(logits_per_image).cpu().numpy()
|
|
||||||
|
|
||||||
return hps_score[0].item()
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
|
||||||
"""Score the images based on the prompt.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
|
||||||
prompt (str): The prompt text.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[float]: List of HPS scores for the images.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if isinstance(images, (str, Image.Image)):
|
|
||||||
# Single image
|
|
||||||
if isinstance(images, str):
|
|
||||||
image = self.preprocess_val(Image.open(images)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
|
||||||
else:
|
|
||||||
image = self.preprocess_val(images).unsqueeze(0).to(device=self.device, non_blocking=True)
|
|
||||||
return [self._calculate_score(image, prompt)]
|
|
||||||
elif isinstance(images, list):
|
|
||||||
# Multiple images
|
|
||||||
scores = []
|
|
||||||
for one_images in images:
|
|
||||||
if isinstance(one_images, str):
|
|
||||||
image = self.preprocess_val(Image.open(one_images)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
|
||||||
elif isinstance(one_images, Image.Image):
|
|
||||||
image = self.preprocess_val(one_images).unsqueeze(0).to(device=self.device, non_blocking=True)
|
|
||||||
else:
|
|
||||||
raise TypeError("The type of parameter images is illegal.")
|
|
||||||
scores.append(self._calculate_score(image, prompt))
|
|
||||||
return scores
|
|
||||||
else:
|
|
||||||
raise TypeError("The type of parameter images is illegal.")
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"Error in scoring images: {e}")
|
|
||||||
@@ -1,212 +0,0 @@
|
|||||||
import os
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
from typing import List, Union
|
|
||||||
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
|
||||||
from .BLIP.blip_pretrain import BLIP_Pretrain
|
|
||||||
from torchvision.transforms import InterpolationMode
|
|
||||||
from safetensors.torch import load_file
|
|
||||||
from .config import MODEL_PATHS
|
|
||||||
BICUBIC = InterpolationMode.BICUBIC
|
|
||||||
|
|
||||||
def _convert_image_to_rgb(image):
|
|
||||||
return image.convert("RGB")
|
|
||||||
|
|
||||||
def _transform(n_px):
|
|
||||||
return Compose([
|
|
||||||
Resize(n_px, interpolation=BICUBIC),
|
|
||||||
CenterCrop(n_px),
|
|
||||||
_convert_image_to_rgb,
|
|
||||||
ToTensor(),
|
|
||||||
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
|
||||||
])
|
|
||||||
|
|
||||||
class MLP(torch.nn.Module):
|
|
||||||
def __init__(self, input_size):
|
|
||||||
super().__init__()
|
|
||||||
self.input_size = input_size
|
|
||||||
|
|
||||||
self.layers = torch.nn.Sequential(
|
|
||||||
torch.nn.Linear(self.input_size, 1024),
|
|
||||||
#nn.ReLU(),
|
|
||||||
torch.nn.Dropout(0.2),
|
|
||||||
torch.nn.Linear(1024, 128),
|
|
||||||
#nn.ReLU(),
|
|
||||||
torch.nn.Dropout(0.2),
|
|
||||||
torch.nn.Linear(128, 64),
|
|
||||||
#nn.ReLU(),
|
|
||||||
torch.nn.Dropout(0.1),
|
|
||||||
torch.nn.Linear(64, 16),
|
|
||||||
#nn.ReLU(),
|
|
||||||
torch.nn.Linear(16, 1)
|
|
||||||
)
|
|
||||||
|
|
||||||
# initial MLP param
|
|
||||||
for name, param in self.layers.named_parameters():
|
|
||||||
if 'weight' in name:
|
|
||||||
torch.nn.init.normal_(param, mean=0.0, std=1.0/(self.input_size+1))
|
|
||||||
if 'bias' in name:
|
|
||||||
torch.nn.init.constant_(param, val=0)
|
|
||||||
|
|
||||||
def forward(self, input):
|
|
||||||
return self.layers(input)
|
|
||||||
|
|
||||||
class ImageReward(torch.nn.Module):
|
|
||||||
def __init__(self, med_config, device='cpu', bert_model_path=""):
|
|
||||||
super().__init__()
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config, bert_model_path=bert_model_path)
|
|
||||||
self.preprocess = _transform(224)
|
|
||||||
self.mlp = MLP(768)
|
|
||||||
|
|
||||||
self.mean = 0.16717362830052426
|
|
||||||
self.std = 1.0333394966054072
|
|
||||||
|
|
||||||
def score_grad(self, prompt_ids, prompt_attention_mask, image):
|
|
||||||
"""Calculate the score with gradient for a single image and prompt.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt_ids (torch.Tensor): Tokenized prompt IDs.
|
|
||||||
prompt_attention_mask (torch.Tensor): Attention mask for the prompt.
|
|
||||||
image (torch.Tensor): The processed image tensor.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: The reward score.
|
|
||||||
"""
|
|
||||||
image_embeds = self.blip.visual_encoder(image)
|
|
||||||
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
|
|
||||||
text_output = self.blip.text_encoder(
|
|
||||||
prompt_ids,
|
|
||||||
attention_mask=prompt_attention_mask,
|
|
||||||
encoder_hidden_states=image_embeds,
|
|
||||||
encoder_attention_mask=image_atts,
|
|
||||||
return_dict=True,
|
|
||||||
)
|
|
||||||
txt_features = text_output.last_hidden_state[:, 0, :]
|
|
||||||
rewards = self.mlp(txt_features)
|
|
||||||
rewards = (rewards - self.mean) / self.std
|
|
||||||
return rewards
|
|
||||||
|
|
||||||
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str = "") -> List[float]:
|
|
||||||
"""Score the images based on the prompt.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt (str): The prompt text.
|
|
||||||
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[float]: List of scores for the images.
|
|
||||||
"""
|
|
||||||
if isinstance(images, (str, Image.Image)):
|
|
||||||
# Single image
|
|
||||||
if isinstance(images, str):
|
|
||||||
pil_image = Image.open(images)
|
|
||||||
else:
|
|
||||||
pil_image = images
|
|
||||||
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
|
|
||||||
return [self._calculate_score(prompt, image).item()]
|
|
||||||
elif isinstance(images, list):
|
|
||||||
# Multiple images
|
|
||||||
scores = []
|
|
||||||
for one_image in images:
|
|
||||||
if isinstance(one_image, str):
|
|
||||||
pil_image = Image.open(one_image)
|
|
||||||
elif isinstance(one_image, Image.Image):
|
|
||||||
pil_image = one_image
|
|
||||||
else:
|
|
||||||
raise TypeError("The type of parameter images is illegal.")
|
|
||||||
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
|
|
||||||
scores.append(self._calculate_score(prompt, image).item())
|
|
||||||
return scores
|
|
||||||
else:
|
|
||||||
raise TypeError("The type of parameter images is illegal.")
|
|
||||||
|
|
||||||
def _calculate_score(self, prompt: str, image: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""Calculate the score for a single image and prompt.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt (str): The prompt text.
|
|
||||||
image (torch.Tensor): The processed image tensor.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: The reward score.
|
|
||||||
"""
|
|
||||||
text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
|
|
||||||
image_embeds = self.blip.visual_encoder(image)
|
|
||||||
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
|
|
||||||
text_output = self.blip.text_encoder(
|
|
||||||
text_input.input_ids,
|
|
||||||
attention_mask=text_input.attention_mask,
|
|
||||||
encoder_hidden_states=image_embeds,
|
|
||||||
encoder_attention_mask=image_atts,
|
|
||||||
return_dict=True,
|
|
||||||
)
|
|
||||||
txt_features = text_output.last_hidden_state[:, 0, :].float()
|
|
||||||
rewards = self.mlp(txt_features)
|
|
||||||
rewards = (rewards - self.mean) / self.std
|
|
||||||
return rewards
|
|
||||||
|
|
||||||
def inference_rank(self, prompt: str, generations_list: List[Union[str, Image.Image]]) -> tuple:
|
|
||||||
"""Rank the images based on the prompt.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt (str): The prompt text.
|
|
||||||
generations_list (List[Union[str, Image.Image]]): List of image paths or PIL images.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (indices, rewards) where indices are the ranks and rewards are the scores.
|
|
||||||
"""
|
|
||||||
text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
|
|
||||||
txt_set = []
|
|
||||||
for generation in generations_list:
|
|
||||||
if isinstance(generation, str):
|
|
||||||
pil_image = Image.open(generation)
|
|
||||||
elif isinstance(generation, Image.Image):
|
|
||||||
pil_image = generation
|
|
||||||
else:
|
|
||||||
raise TypeError("The type of parameter generations_list is illegal.")
|
|
||||||
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
|
|
||||||
image_embeds = self.blip.visual_encoder(image)
|
|
||||||
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
|
|
||||||
text_output = self.blip.text_encoder(
|
|
||||||
text_input.input_ids,
|
|
||||||
attention_mask=text_input.attention_mask,
|
|
||||||
encoder_hidden_states=image_embeds,
|
|
||||||
encoder_attention_mask=image_atts,
|
|
||||||
return_dict=True,
|
|
||||||
)
|
|
||||||
txt_set.append(text_output.last_hidden_state[:, 0, :])
|
|
||||||
txt_features = torch.cat(txt_set, 0).float()
|
|
||||||
rewards = self.mlp(txt_features)
|
|
||||||
rewards = (rewards - self.mean) / self.std
|
|
||||||
rewards = torch.squeeze(rewards)
|
|
||||||
_, rank = torch.sort(rewards, dim=0, descending=True)
|
|
||||||
_, indices = torch.sort(rank, dim=0)
|
|
||||||
indices = indices + 1
|
|
||||||
return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
|
|
||||||
|
|
||||||
|
|
||||||
class ImageRewardScore(torch.nn.Module):
|
|
||||||
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS):
|
|
||||||
super().__init__()
|
|
||||||
self.device = device if isinstance(device, torch.device) else torch.device(device)
|
|
||||||
model_path = path.get("imagereward")
|
|
||||||
med_config = path.get("med_config")
|
|
||||||
state_dict = load_file(model_path)
|
|
||||||
self.model = ImageReward(device=self.device, med_config=med_config, bert_model_path=path.get("bert_model_path")).to(self.device)
|
|
||||||
self.model.load_state_dict(state_dict, strict=False)
|
|
||||||
self.model.eval()
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
|
||||||
"""Score the images based on the prompt.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
|
||||||
prompt (str): The prompt text.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[float]: List of scores for the images.
|
|
||||||
"""
|
|
||||||
return self.model.score(images, prompt)
|
|
||||||
@@ -1,129 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
from io import BytesIO
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
from transformers import CLIPFeatureExtractor, CLIPImageProcessor
|
|
||||||
from transformers import CLIPConfig
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from transformers import CLIPModel as HFCLIPModel
|
|
||||||
from safetensors.torch import load_file
|
|
||||||
from torch import nn, einsum
|
|
||||||
|
|
||||||
from .trainer.models.base_model import BaseModelConfig
|
|
||||||
|
|
||||||
from transformers import CLIPConfig
|
|
||||||
from transformers import AutoProcessor, AutoModel, AutoTokenizer
|
|
||||||
from typing import Any, Optional, Tuple, Union, List
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from .trainer.models.cross_modeling import Cross_model
|
|
||||||
from .trainer.models import clip_model
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import gc
|
|
||||||
import json
|
|
||||||
from .config import MODEL_PATHS
|
|
||||||
|
|
||||||
class MPScore(torch.nn.Module):
|
|
||||||
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS, condition: str = 'overall'):
|
|
||||||
super().__init__()
|
|
||||||
"""Initialize the MPSModel with a processor, tokenizer, and model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
device (Union[str, torch.device]): The device to load the model on.
|
|
||||||
"""
|
|
||||||
self.device = device
|
|
||||||
processor_name_or_path = path.get("clip")
|
|
||||||
self.image_processor = CLIPImageProcessor.from_pretrained(processor_name_or_path)
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True)
|
|
||||||
self.model = clip_model.CLIPModel(processor_name_or_path, config_file=True)
|
|
||||||
state_dict = load_file(path.get("mps"))
|
|
||||||
self.model.load_state_dict(state_dict, strict=False)
|
|
||||||
self.model.to(device)
|
|
||||||
self.condition = condition
|
|
||||||
|
|
||||||
def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
|
|
||||||
"""Calculate the reward score for a single image and prompt.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image (torch.Tensor): The processed image tensor.
|
|
||||||
prompt (str): The prompt text.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
float: The reward score.
|
|
||||||
"""
|
|
||||||
def _tokenize(caption):
|
|
||||||
input_ids = self.tokenizer(
|
|
||||||
caption,
|
|
||||||
max_length=self.tokenizer.model_max_length,
|
|
||||||
padding="max_length",
|
|
||||||
truncation=True,
|
|
||||||
return_tensors="pt"
|
|
||||||
).input_ids
|
|
||||||
return input_ids
|
|
||||||
|
|
||||||
text_input = _tokenize(prompt).to(self.device)
|
|
||||||
if self.condition == 'overall':
|
|
||||||
condition_prompt = 'light, color, clarity, tone, style, ambiance, artistry, shape, face, hair, hands, limbs, structure, instance, texture, quantity, attributes, position, number, location, word, things'
|
|
||||||
elif self.condition == 'aesthetics':
|
|
||||||
condition_prompt = 'light, color, clarity, tone, style, ambiance, artistry'
|
|
||||||
elif self.condition == 'quality':
|
|
||||||
condition_prompt = 'shape, face, hair, hands, limbs, structure, instance, texture'
|
|
||||||
elif self.condition == 'semantic':
|
|
||||||
condition_prompt = 'quantity, attributes, position, number, location'
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unsupported condition: {self.condition}. Choose 'overall', 'aesthetics', 'quality', or 'semantic'.")
|
|
||||||
condition_batch = _tokenize(condition_prompt).repeat(text_input.shape[0], 1).to(self.device)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
text_f, text_features = self.model.model.get_text_features(text_input)
|
|
||||||
|
|
||||||
image_f = self.model.model.get_image_features(image.half())
|
|
||||||
condition_f, _ = self.model.model.get_text_features(condition_batch)
|
|
||||||
|
|
||||||
sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)
|
|
||||||
sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
|
|
||||||
sim_text_condition = sim_text_condition / sim_text_condition.max()
|
|
||||||
mask = torch.where(sim_text_condition > 0.3, 0, float('-inf'))
|
|
||||||
mask = mask.repeat(1, image_f.shape[1], 1)
|
|
||||||
image_features = self.model.cross_model(image_f, text_f, mask.half())[:, 0, :]
|
|
||||||
|
|
||||||
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
|
|
||||||
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
|
|
||||||
image_score = self.model.logit_scale.exp() * text_features @ image_features.T
|
|
||||||
|
|
||||||
return image_score[0].cpu().numpy().item()
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
|
||||||
"""Score the images based on the prompt.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
|
||||||
prompt (str): The prompt text.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[float]: List of reward scores for the images.
|
|
||||||
"""
|
|
||||||
if isinstance(images, (str, Image.Image)):
|
|
||||||
# Single image
|
|
||||||
if isinstance(images, str):
|
|
||||||
image = self.image_processor(Image.open(images), return_tensors="pt")["pixel_values"].to(self.device)
|
|
||||||
else:
|
|
||||||
image = self.image_processor(images, return_tensors="pt")["pixel_values"].to(self.device)
|
|
||||||
return [self._calculate_score(image, prompt)]
|
|
||||||
elif isinstance(images, list):
|
|
||||||
# Multiple images
|
|
||||||
scores = []
|
|
||||||
for one_images in images:
|
|
||||||
if isinstance(one_images, str):
|
|
||||||
image = self.image_processor(Image.open(one_images), return_tensors="pt")["pixel_values"].to(self.device)
|
|
||||||
elif isinstance(one_images, Image.Image):
|
|
||||||
image = self.image_processor(one_images, return_tensors="pt")["pixel_values"].to(self.device)
|
|
||||||
else:
|
|
||||||
raise TypeError("The type of parameter images is illegal.")
|
|
||||||
scores.append(self._calculate_score(image, prompt))
|
|
||||||
return scores
|
|
||||||
else:
|
|
||||||
raise TypeError("The type of parameter images is illegal.")
|
|
||||||
@@ -1,14 +0,0 @@
|
|||||||
from .coca_model import CoCa
|
|
||||||
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
|
||||||
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
|
|
||||||
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
|
|
||||||
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
|
|
||||||
from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
|
|
||||||
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
|
|
||||||
from .openai import load_openai_model, list_openai_models
|
|
||||||
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
|
|
||||||
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
|
|
||||||
from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
|
|
||||||
from .tokenizer import SimpleTokenizer
|
|
||||||
from .transform import image_transform, AugmentationCfg
|
|
||||||
from .utils import freeze_batch_norm_2d
|
|
||||||
@@ -1,458 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from torch.nn import functional as F
|
|
||||||
import numpy as np
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
from .transformer import (
|
|
||||||
LayerNormFp32,
|
|
||||||
LayerNorm,
|
|
||||||
QuickGELU,
|
|
||||||
MultimodalTransformer,
|
|
||||||
)
|
|
||||||
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
|
|
||||||
|
|
||||||
try:
|
|
||||||
from transformers import (
|
|
||||||
BeamSearchScorer,
|
|
||||||
LogitsProcessorList,
|
|
||||||
TopPLogitsWarper,
|
|
||||||
TopKLogitsWarper,
|
|
||||||
RepetitionPenaltyLogitsProcessor,
|
|
||||||
MinLengthLogitsProcessor,
|
|
||||||
MaxLengthCriteria,
|
|
||||||
StoppingCriteriaList
|
|
||||||
)
|
|
||||||
|
|
||||||
GENERATION_TYPES = {
|
|
||||||
"top_k": TopKLogitsWarper,
|
|
||||||
"top_p": TopPLogitsWarper,
|
|
||||||
"beam_search": "beam_search"
|
|
||||||
}
|
|
||||||
_has_transformers = True
|
|
||||||
except ImportError as e:
|
|
||||||
GENERATION_TYPES = {
|
|
||||||
"top_k": None,
|
|
||||||
"top_p": None,
|
|
||||||
"beam_search": "beam_search"
|
|
||||||
}
|
|
||||||
_has_transformers = False
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MultimodalCfg(CLIPTextCfg):
|
|
||||||
mlp_ratio: int = 4
|
|
||||||
dim_head: int = 64
|
|
||||||
heads: int = 8
|
|
||||||
n_queries: int = 256
|
|
||||||
attn_pooler_heads: int = 8
|
|
||||||
|
|
||||||
|
|
||||||
def _build_text_decoder_tower(
|
|
||||||
embed_dim,
|
|
||||||
multimodal_cfg,
|
|
||||||
quick_gelu: bool = False,
|
|
||||||
cast_dtype: Optional[torch.dtype] = None,
|
|
||||||
):
|
|
||||||
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
|
|
||||||
act_layer = QuickGELU if quick_gelu else nn.GELU
|
|
||||||
norm_layer = (
|
|
||||||
LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
|
|
||||||
)
|
|
||||||
|
|
||||||
decoder = MultimodalTransformer(
|
|
||||||
context_length=multimodal_cfg.context_length,
|
|
||||||
width=multimodal_cfg.width,
|
|
||||||
heads=multimodal_cfg.heads,
|
|
||||||
layers=multimodal_cfg.layers,
|
|
||||||
ls_init_value=multimodal_cfg.ls_init_value,
|
|
||||||
output_dim=embed_dim,
|
|
||||||
act_layer=act_layer,
|
|
||||||
norm_layer=norm_layer,
|
|
||||||
)
|
|
||||||
|
|
||||||
return decoder
|
|
||||||
|
|
||||||
|
|
||||||
class CoCa(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
embed_dim,
|
|
||||||
multimodal_cfg: MultimodalCfg,
|
|
||||||
text_cfg: CLIPTextCfg,
|
|
||||||
vision_cfg: CLIPVisionCfg,
|
|
||||||
quick_gelu: bool = False,
|
|
||||||
cast_dtype: Optional[torch.dtype] = None,
|
|
||||||
pad_id: int = 0,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
|
|
||||||
text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
|
|
||||||
vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg
|
|
||||||
|
|
||||||
self.text = _build_text_tower(
|
|
||||||
embed_dim=embed_dim,
|
|
||||||
text_cfg=text_cfg,
|
|
||||||
quick_gelu=quick_gelu,
|
|
||||||
cast_dtype=cast_dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
vocab_size = (
|
|
||||||
text_cfg.vocab_size # for hf models
|
|
||||||
if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
|
|
||||||
else text_cfg.vocab_size
|
|
||||||
)
|
|
||||||
|
|
||||||
self.visual = _build_vision_tower(
|
|
||||||
embed_dim=embed_dim,
|
|
||||||
vision_cfg=vision_cfg,
|
|
||||||
quick_gelu=quick_gelu,
|
|
||||||
cast_dtype=cast_dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.text_decoder = _build_text_decoder_tower(
|
|
||||||
vocab_size,
|
|
||||||
multimodal_cfg=multimodal_cfg,
|
|
||||||
quick_gelu=quick_gelu,
|
|
||||||
cast_dtype=cast_dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
|
||||||
self.pad_id = pad_id
|
|
||||||
|
|
||||||
@torch.jit.ignore
|
|
||||||
def set_grad_checkpointing(self, enable=True):
|
|
||||||
self.visual.set_grad_checkpointing(enable)
|
|
||||||
self.text.set_grad_checkpointing(enable)
|
|
||||||
self.text_decoder.set_grad_checkpointing(enable)
|
|
||||||
|
|
||||||
def _encode_image(self, images, normalize=True):
|
|
||||||
image_latent, tokens_embs = self.visual(images)
|
|
||||||
image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
|
|
||||||
return image_latent, tokens_embs
|
|
||||||
|
|
||||||
def _encode_text(self, text, normalize=True, embed_cls=True):
|
|
||||||
text = text[:, :-1] if embed_cls else text # make space for CLS token
|
|
||||||
text_latent, token_emb = self.text(text)
|
|
||||||
text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
|
|
||||||
return text_latent, token_emb
|
|
||||||
|
|
||||||
def encode_image(self, images, normalize=True):
|
|
||||||
image_latent, _ = self._encode_image(images, normalize=normalize)
|
|
||||||
return image_latent
|
|
||||||
|
|
||||||
def encode_text(self, text, normalize=True, embed_cls=True):
|
|
||||||
text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
|
|
||||||
return text_latent
|
|
||||||
|
|
||||||
def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):
|
|
||||||
text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
|
|
||||||
if image_latent is None or image_embs is None:
|
|
||||||
image_latent, image_embs = self._encode_image(image)
|
|
||||||
|
|
||||||
# TODO: add assertion to avoid bugs?
|
|
||||||
labels = text[:, -token_embs.shape[1]:]
|
|
||||||
|
|
||||||
logits = self.text_decoder(image_embs, token_embs)
|
|
||||||
return {
|
|
||||||
"image_features": image_latent,
|
|
||||||
"text_features": text_latent,
|
|
||||||
"logits": logits,
|
|
||||||
"labels": labels,
|
|
||||||
"logit_scale": self.logit_scale.exp()
|
|
||||||
}
|
|
||||||
|
|
||||||
def generate(
|
|
||||||
self,
|
|
||||||
image,
|
|
||||||
text=None,
|
|
||||||
seq_len=30,
|
|
||||||
max_seq_len=77,
|
|
||||||
temperature=1.,
|
|
||||||
generation_type="beam_search",
|
|
||||||
top_p=0.1, # keep tokens in the 1 - top_p quantile
|
|
||||||
top_k=1, # keeps the top_k most probable tokens
|
|
||||||
pad_token_id=None,
|
|
||||||
eos_token_id=None,
|
|
||||||
sot_token_id=None,
|
|
||||||
num_beams=6,
|
|
||||||
num_beam_groups=3,
|
|
||||||
min_seq_len=5,
|
|
||||||
stopping_criteria=None,
|
|
||||||
repetition_penalty=1.0,
|
|
||||||
fixed_output_length=False # if True output.shape == (batch_size, seq_len)
|
|
||||||
):
|
|
||||||
# taking many ideas and components from HuggingFace GenerationMixin
|
|
||||||
# https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
|
|
||||||
assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
|
|
||||||
assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
sot_token_id = 49406 if sot_token_id is None else sot_token_id
|
|
||||||
eos_token_id = 49407 if eos_token_id is None else eos_token_id
|
|
||||||
pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
|
|
||||||
logit_processor = LogitsProcessorList(
|
|
||||||
[
|
|
||||||
MinLengthLogitsProcessor(min_seq_len, eos_token_id),
|
|
||||||
RepetitionPenaltyLogitsProcessor(repetition_penalty),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
if stopping_criteria is None:
|
|
||||||
stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
|
|
||||||
|
|
||||||
stopping_criteria = StoppingCriteriaList(
|
|
||||||
stopping_criteria
|
|
||||||
)
|
|
||||||
|
|
||||||
device = image.device
|
|
||||||
|
|
||||||
if generation_type == "beam_search":
|
|
||||||
output = self._generate_beamsearch(
|
|
||||||
image_inputs = image,
|
|
||||||
pad_token_id=pad_token_id,
|
|
||||||
eos_token_id=eos_token_id,
|
|
||||||
sot_token_id=sot_token_id,
|
|
||||||
num_beams=num_beams,
|
|
||||||
num_beam_groups=num_beam_groups,
|
|
||||||
min_seq_len=min_seq_len,
|
|
||||||
stopping_criteria=stopping_criteria,
|
|
||||||
logit_processor=logit_processor,
|
|
||||||
)
|
|
||||||
if fixed_output_length and output.shape[1] < seq_len:
|
|
||||||
return torch.cat(
|
|
||||||
(output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id),
|
|
||||||
dim=1
|
|
||||||
)
|
|
||||||
return output
|
|
||||||
|
|
||||||
elif generation_type == "top_p":
|
|
||||||
logit_warper = GENERATION_TYPES[generation_type](top_p)
|
|
||||||
elif generation_type == "top_k":
|
|
||||||
logit_warper = GENERATION_TYPES[generation_type](top_k)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"generation_type has to be one of "
|
|
||||||
f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
|
|
||||||
)
|
|
||||||
|
|
||||||
image_latent, image_embs = self._encode_image(image)
|
|
||||||
|
|
||||||
if text is None:
|
|
||||||
text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id
|
|
||||||
|
|
||||||
was_training = self.training
|
|
||||||
num_dims = len(text.shape)
|
|
||||||
|
|
||||||
if num_dims == 1:
|
|
||||||
text = text[None, :]
|
|
||||||
|
|
||||||
cur_len = text.shape[1]
|
|
||||||
self.eval()
|
|
||||||
out = text
|
|
||||||
|
|
||||||
while True:
|
|
||||||
x = out[:, -max_seq_len:]
|
|
||||||
cur_len = x.shape[1]
|
|
||||||
logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1]
|
|
||||||
mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
|
|
||||||
sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id
|
|
||||||
|
|
||||||
if mask.all():
|
|
||||||
if not fixed_output_length:
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
logits = logits[~mask, :]
|
|
||||||
filtered_logits = logit_processor(x[~mask, :], logits)
|
|
||||||
filtered_logits = logit_warper(x[~mask, :], filtered_logits)
|
|
||||||
probs = F.softmax(filtered_logits / temperature, dim=-1)
|
|
||||||
|
|
||||||
if (cur_len + 1 == seq_len):
|
|
||||||
sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
|
|
||||||
else:
|
|
||||||
sample[~mask, :] = torch.multinomial(probs, 1)
|
|
||||||
|
|
||||||
out = torch.cat((out, sample), dim=-1)
|
|
||||||
|
|
||||||
cur_len += 1
|
|
||||||
|
|
||||||
if stopping_criteria(out, None):
|
|
||||||
break
|
|
||||||
|
|
||||||
if num_dims == 1:
|
|
||||||
out = out.squeeze(0)
|
|
||||||
|
|
||||||
self.train(was_training)
|
|
||||||
return out
|
|
||||||
|
|
||||||
def _generate_beamsearch(
|
|
||||||
self,
|
|
||||||
image_inputs,
|
|
||||||
pad_token_id=None,
|
|
||||||
eos_token_id=None,
|
|
||||||
sot_token_id=None,
|
|
||||||
num_beams=6,
|
|
||||||
num_beam_groups=3,
|
|
||||||
min_seq_len=5,
|
|
||||||
stopping_criteria=None,
|
|
||||||
logit_processor=None,
|
|
||||||
logit_warper=None,
|
|
||||||
):
|
|
||||||
device = image_inputs.device
|
|
||||||
batch_size = image_inputs.shape[0]
|
|
||||||
image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
|
|
||||||
image_latent, image_embs = self._encode_image(image_inputs)
|
|
||||||
|
|
||||||
input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
|
|
||||||
input_ids = input_ids * sot_token_id
|
|
||||||
beam_scorer = BeamSearchScorer(
|
|
||||||
batch_size=batch_size,
|
|
||||||
num_beams=num_beams,
|
|
||||||
device=device,
|
|
||||||
num_beam_groups=num_beam_groups,
|
|
||||||
)
|
|
||||||
# instantiate logits processors
|
|
||||||
logits_processor = (
|
|
||||||
LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
|
|
||||||
if logit_processor is None
|
|
||||||
else logit_processor
|
|
||||||
)
|
|
||||||
|
|
||||||
batch_size = len(beam_scorer._beam_hyps)
|
|
||||||
num_beams = beam_scorer.num_beams
|
|
||||||
num_beam_groups = beam_scorer.num_beam_groups
|
|
||||||
num_sub_beams = num_beams // num_beam_groups
|
|
||||||
batch_beam_size, cur_len = input_ids.shape
|
|
||||||
beam_indices = None
|
|
||||||
|
|
||||||
if num_beams * batch_size != batch_beam_size:
|
|
||||||
raise ValueError(
|
|
||||||
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
|
||||||
)
|
|
||||||
|
|
||||||
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
|
|
||||||
# initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in
|
|
||||||
# the same group don't produce same tokens everytime.
|
|
||||||
beam_scores[:, ::num_sub_beams] = 0
|
|
||||||
beam_scores = beam_scores.view((batch_size * num_beams,))
|
|
||||||
|
|
||||||
while True:
|
|
||||||
|
|
||||||
# predicted tokens in cur_len step
|
|
||||||
current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
|
|
||||||
|
|
||||||
# indices which will form the beams in the next time step
|
|
||||||
reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
|
|
||||||
|
|
||||||
# do one decoder step on all beams of all sentences in batch
|
|
||||||
model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
|
|
||||||
outputs = self(
|
|
||||||
model_inputs['images'],
|
|
||||||
model_inputs['text'],
|
|
||||||
embed_cls=False,
|
|
||||||
image_latent=image_latent,
|
|
||||||
image_embs=image_embs
|
|
||||||
)
|
|
||||||
|
|
||||||
for beam_group_idx in range(num_beam_groups):
|
|
||||||
group_start_idx = beam_group_idx * num_sub_beams
|
|
||||||
group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
|
|
||||||
group_size = group_end_idx - group_start_idx
|
|
||||||
|
|
||||||
# indices of beams of current group among all sentences in batch
|
|
||||||
batch_group_indices = []
|
|
||||||
|
|
||||||
for batch_idx in range(batch_size):
|
|
||||||
batch_group_indices.extend(
|
|
||||||
[batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
|
|
||||||
)
|
|
||||||
group_input_ids = input_ids[batch_group_indices]
|
|
||||||
|
|
||||||
# select outputs of beams of currentg group only
|
|
||||||
next_token_logits = outputs['logits'][batch_group_indices, -1, :]
|
|
||||||
vocab_size = next_token_logits.shape[-1]
|
|
||||||
|
|
||||||
next_token_scores_processed = logits_processor(
|
|
||||||
group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
|
|
||||||
)
|
|
||||||
next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
|
|
||||||
next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
|
|
||||||
|
|
||||||
# reshape for beam search
|
|
||||||
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
|
|
||||||
|
|
||||||
next_token_scores, next_tokens = torch.topk(
|
|
||||||
next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
|
|
||||||
)
|
|
||||||
|
|
||||||
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
|
|
||||||
next_tokens = next_tokens % vocab_size
|
|
||||||
|
|
||||||
# stateless
|
|
||||||
process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
|
|
||||||
beam_outputs = beam_scorer.process(
|
|
||||||
group_input_ids,
|
|
||||||
next_token_scores,
|
|
||||||
next_tokens,
|
|
||||||
next_indices,
|
|
||||||
pad_token_id=pad_token_id,
|
|
||||||
eos_token_id=eos_token_id,
|
|
||||||
beam_indices=process_beam_indices,
|
|
||||||
)
|
|
||||||
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
|
|
||||||
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
|
||||||
beam_idx = beam_outputs["next_beam_indices"]
|
|
||||||
|
|
||||||
input_ids[batch_group_indices] = group_input_ids[beam_idx]
|
|
||||||
group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
|
|
||||||
current_tokens[batch_group_indices] = group_input_ids[:, -1]
|
|
||||||
|
|
||||||
# (beam_idx // group_size) -> batch_idx
|
|
||||||
# (beam_idx % group_size) -> offset of idx inside the group
|
|
||||||
reordering_indices[batch_group_indices] = (
|
|
||||||
num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
|
|
||||||
)
|
|
||||||
|
|
||||||
input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
|
|
||||||
|
|
||||||
# increase cur_len
|
|
||||||
cur_len = cur_len + 1
|
|
||||||
if beam_scorer.is_done or stopping_criteria(input_ids, None):
|
|
||||||
break
|
|
||||||
|
|
||||||
final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
|
|
||||||
sequence_outputs = beam_scorer.finalize(
|
|
||||||
input_ids,
|
|
||||||
beam_scores,
|
|
||||||
next_tokens,
|
|
||||||
next_indices,
|
|
||||||
pad_token_id=pad_token_id,
|
|
||||||
eos_token_id=eos_token_id,
|
|
||||||
max_length=stopping_criteria.max_length,
|
|
||||||
beam_indices=final_beam_indices,
|
|
||||||
)
|
|
||||||
return sequence_outputs['sequences']
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
|
|
||||||
if past:
|
|
||||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
|
||||||
|
|
||||||
attention_mask = kwargs.get("attention_mask", None)
|
|
||||||
position_ids = kwargs.get("position_ids", None)
|
|
||||||
|
|
||||||
if attention_mask is not None and position_ids is None:
|
|
||||||
# create position_ids on the fly for batch generation
|
|
||||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
|
||||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
|
||||||
else:
|
|
||||||
position_ids = None
|
|
||||||
return {
|
|
||||||
"text": input_ids,
|
|
||||||
"images": image_inputs,
|
|
||||||
"past_key_values": past,
|
|
||||||
"position_ids": position_ids,
|
|
||||||
"attention_mask": attention_mask,
|
|
||||||
}
|
|
||||||
@@ -1,2 +0,0 @@
|
|||||||
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
|
||||||
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
|
|
||||||
@@ -1,433 +0,0 @@
|
|||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import pathlib
|
|
||||||
import re
|
|
||||||
from copy import deepcopy
|
|
||||||
from pathlib import Path
|
|
||||||
# from turtle import forward
|
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
|
||||||
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
|
|
||||||
resize_pos_embed, get_cast_dtype
|
|
||||||
from .coca_model import CoCa
|
|
||||||
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
|
|
||||||
from .openai import load_openai_model
|
|
||||||
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf
|
|
||||||
from .transform import image_transform, AugmentationCfg
|
|
||||||
from .tokenizer import HFTokenizer, SimpleTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
HF_HUB_PREFIX = 'hf-hub:'
|
|
||||||
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
|
|
||||||
_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
|
|
||||||
|
|
||||||
|
|
||||||
def _natural_key(string_):
    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]


def _rescan_model_configs():
    global _MODEL_CONFIGS

    config_ext = ('.json',)
    config_files = []
    for config_path in _MODEL_CONFIG_PATHS:
        if config_path.is_file() and config_path.suffix in config_ext:
            config_files.append(config_path)
        elif config_path.is_dir():
            for ext in config_ext:
                config_files.extend(config_path.glob(f'*{ext}'))

    for cf in config_files:
        with open(cf, 'r') as f:
            model_cfg = json.load(f)
            if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
                _MODEL_CONFIGS[cf.stem] = model_cfg

    _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}


_rescan_model_configs()  # initial populate of model config registry


def list_models():
    """ enumerate available model architectures based on config files """
    return list(_MODEL_CONFIGS.keys())


def add_model_config(path):
    """ add model config path or file and update registry """
    if not isinstance(path, Path):
        path = Path(path)
    _MODEL_CONFIG_PATHS.append(path)
    _rescan_model_configs()


def get_model_config(model_name):
    if model_name in _MODEL_CONFIGS:
        return deepcopy(_MODEL_CONFIGS[model_name])
    else:
        return None


def get_tokenizer(model_name, open_clip_bpe_path=None):
    if model_name.startswith(HF_HUB_PREFIX):
        tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):])
    else:
        config = get_model_config(model_name)
        tokenizer = HFTokenizer(
            config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else SimpleTokenizer(open_clip_bpe_path)
    return tokenizer


def load_state_dict(checkpoint_path: str, map_location='cpu'):
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    if next(iter(state_dict.items()))[0].startswith('module'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}
    return state_dict


def load_checkpoint(model, checkpoint_path, strict=True):
    state_dict = load_state_dict(checkpoint_path)
    # detect old format and make compatible with new format
    if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
        state_dict = convert_to_custom_text_state_dict(state_dict)
    resize_pos_embed(state_dict, model)
    incompatible_keys = model.load_state_dict(state_dict, strict=strict)
    return incompatible_keys


def create_model(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        require_pretrained: bool = False,
):
    has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
    if has_hf_hub_prefix:
        model_id = model_name[len(HF_HUB_PREFIX):]
        checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
        config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)

        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        pretrained_cfg = config['preprocess_cfg']
        model_cfg = config['model_cfg']
    else:
        model_name = model_name.replace('/', '-')  # for callers using old naming with / in ViT names
        checkpoint_path = None
        pretrained_cfg = {}
        model_cfg = None

    if isinstance(device, str):
        device = torch.device(device)

    if pretrained and pretrained.lower() == 'openai':
        logging.info(f'Loading pretrained {model_name} from OpenAI.')
        model = load_openai_model(
            model_name,
            precision=precision,
            device=device,
            jit=jit,
            cache_dir=cache_dir,
        )

        # to always output dict even if it is clip
        if output_dict and hasattr(model, "output_dict"):
            model.output_dict = True
    else:
        model_cfg = model_cfg or get_model_config(model_name)
        if model_cfg is not None:
            logging.info(f'Loaded {model_name} model config.')
        else:
            logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
            raise RuntimeError(f'Model config for {model_name} not found.')

        if force_quick_gelu:
            # override for use of QuickGELU on non-OpenAI transformer models
            model_cfg["quick_gelu"] = True

        if force_patch_dropout is not None:
            # override the default patch dropout value
            model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout

        if force_image_size is not None:
            # override model config's image size
            model_cfg["vision_cfg"]["image_size"] = force_image_size

        if pretrained_image:
            if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
                # pretrained weight loading for timm models set via vision_cfg
                model_cfg['vision_cfg']['timm_model_pretrained'] = True
            else:
                assert False, 'pretrained image towers currently only supported for timm models'

        cast_dtype = get_cast_dtype(precision)
        is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
        custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model

        if custom_text:
            if is_hf_model:
                model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
            if "coca" in model_name:
                model = CoCa(**model_cfg, cast_dtype=cast_dtype)
            else:
                model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
        else:
            model = CLIP(**model_cfg, cast_dtype=cast_dtype)

        pretrained_loaded = False
        if pretrained:
            checkpoint_path = ''
            pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
            if pretrained_cfg:
                checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
            elif os.path.exists(pretrained):
                checkpoint_path = pretrained

            if checkpoint_path:
                logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
                load_checkpoint(model, checkpoint_path)
            else:
                error_str = (
                    f'Pretrained weights ({pretrained}) not found for model {model_name}.'
                    f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
                logging.warning(error_str)
                raise RuntimeError(error_str)
            pretrained_loaded = True
        elif has_hf_hub_prefix:
            logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
            load_checkpoint(model, checkpoint_path)
            pretrained_loaded = True

        if require_pretrained and not pretrained_loaded:
            # callers of create_model_from_pretrained always expect pretrained weights
            raise RuntimeError(
                f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')

        model.to(device=device)
        if precision in ("fp16", "bf16"):
            convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)

        # set image / mean metadata from pretrained_cfg if available, or use default
        model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
        model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD

        # to always output dict even if it is clip
        if output_dict and hasattr(model, "output_dict"):
            model.output_dict = True

    if jit:
        model = torch.jit.script(model)

    return model


def create_loss(args):
    if args.distill:
        return DistillClipLoss(
            local_loss=args.local_loss,
            gather_with_grad=args.gather_with_grad,
            cache_labels=True,
            rank=args.rank,
            world_size=args.world_size,
            use_horovod=args.horovod,
        )
    elif "coca" in args.model.lower():
        return CoCaLoss(
            caption_loss_weight=args.coca_caption_loss_weight,
            clip_loss_weight=args.coca_contrastive_loss_weight,
            local_loss=args.local_loss,
            gather_with_grad=args.gather_with_grad,
            cache_labels=True,
            rank=args.rank,
            world_size=args.world_size,
            use_horovod=args.horovod,
        )
    return ClipLoss(
        local_loss=args.local_loss,
        gather_with_grad=args.gather_with_grad,
        cache_labels=True,
        rank=args.rank,
        world_size=args.world_size,
        use_horovod=args.horovod,
    )

class MLP(torch.nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(self.input_size, 1024),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(1024, 128),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(128, 64),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(64, 16),
            torch.nn.Linear(16, 1)
        )

    def forward(self, x):
        return self.layers(x)

# class semantic_head(torch.nn.Module):
#     def __init__(self, input_size):
#         super().__init__()
#         self.input_size = input_size # for ViT-L-14 is 1024
#         self.seg_head = torch.nn.Sequential(
#             torch.nn.Linear(input_size, 128),
#             torch.nn.Dropout(0.2),
#             torch.nn.Linear(128, 64),
#             torch.nn.Dropout(0.1),
#             torch.nn.Linear(64, 16),
#             torch.nn.Linear(16, 1),
#         )
#         self.sigmoid = torch.nn.Sigmoid()

#     def forward(self, x):
#         return self.sigmoid(self.seg_head(x))

def create_model_and_transforms(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
        cache_dir: Optional[str] = None,
        light_augmentation = False,
        output_dict: Optional[bool] = None,
        with_score_predictor: bool = False,
        with_region_predictor: bool = False
):
    model = create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_patch_dropout=force_patch_dropout,
        force_image_size=force_image_size,
        pretrained_image=pretrained_image,
        pretrained_hf=pretrained_hf,
        cache_dir=cache_dir,
        output_dict=output_dict,
    )

    image_mean = image_mean or getattr(model.visual, 'image_mean', None)
    image_std = image_std or getattr(model.visual, 'image_std', None)

    if with_score_predictor:
        model.score_predictor = MLP(model.visual.proj.size(1)).to(device=device, dtype=model.visual.proj.dtype)

    if with_region_predictor:
        # model.region_predictor = semantic_head(model.visual.proj.size(1)).to(device=device, dtype=model.visual.proj.dtype)
        model.region_predictor = torch.nn.Linear(model.visual.proj.size(0), 1).to(device=device, dtype=model.visual.proj.dtype)
    # preprocess_train = image_transform_region(
    #     model.visual.image_size,
    #     is_train=True,
    #     mean=image_mean,
    #     std=image_std
    # )
    # preprocess_val = image_transform_region(
    #     model.visual.image_size,
    #     is_train=False,
    #     mean=image_mean,
    #     std=image_std
    # )

    if light_augmentation:
        preprocess_val = image_transform(
            model.visual.image_size,
            is_train=False,
            mean=image_mean,
            std=image_std,
            resize_longest_max=True,
        )
        preprocess_train = preprocess_val
    else:
        preprocess_train = image_transform(
            model.visual.image_size,
            is_train=True,
            mean=image_mean,
            std=image_std
        )
        preprocess_val = image_transform(
            model.visual.image_size,
            is_train=False,
            mean=image_mean,
            std=image_std
        )

    return model, preprocess_train, preprocess_val


def create_model_from_pretrained(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        return_transform: bool = True,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        cache_dir: Optional[str] = None,
):
    model = create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_image_size=force_image_size,
        cache_dir=cache_dir,
        require_pretrained=True,
    )

    if not return_transform:
        return model

    image_mean = image_mean or getattr(model.visual, 'image_mean', None)
    image_std = image_std or getattr(model.visual, 'image_std', None)
    preprocess = image_transform(
        model.visual.image_size,
        is_train=False,
        mean=image_mean,
        std=image_std,
    )

    return model, preprocess
@@ -1,45 +0,0 @@
# HF architecture dict:
arch_dict = {
    # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
    "roberta": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
            "layer_attr": "layer",
            "token_embeddings_attr": "embeddings"
        },
        "pooler": "mean_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
    "xlm-roberta": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
            "layer_attr": "layer",
            "token_embeddings_attr": "embeddings"
        },
        "pooler": "mean_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
    "mt5": {
        "config_names": {
            # unlimited seqlen
            # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
            # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
            "context_length": "",
            "vocab_size": "vocab_size",
            "width": "d_model",
            "heads": "num_heads",
            "layers": "num_layers",
            "layer_attr": "block",
            "token_embeddings_attr": "embed_tokens"
        },
        "pooler": "mean_pooler",
    },
}
@@ -1,176 +0,0 @@
""" huggingface model adapter

Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
"""

import re

import torch
import torch.nn as nn
from torch import TensorType

try:
    import transformers
    from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
    from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
        BaseModelOutputWithPoolingAndCrossAttentions
except ImportError as e:
    transformers = None


    class BaseModelOutput:
        pass


    class PretrainedConfig:
        pass

from .hf_configs import arch_dict


# utils
def _camel2snake(s):
    return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()


# TODO: ?last - for gpt-like models
_POOLERS = {}


def register_pooler(cls):
    """Decorator registering pooler class"""
    _POOLERS[_camel2snake(cls.__name__)] = cls
    return cls


@register_pooler
class MeanPooler(nn.Module):
    """Mean pooling"""

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
        return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)


@register_pooler
class MaxPooler(nn.Module):
    """Max pooling"""

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1) == 0, -torch.inf)
        return masked_output.max(1).values


@register_pooler
class ClsPooler(nn.Module):
    """CLS token pooling"""

    def __init__(self, use_pooler_output=True):
        super().__init__()
        self.cls_token_position = 0
        self.use_pooler_output = use_pooler_output

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        if (self.use_pooler_output and
                isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
                (x.pooler_output is not None)
        ):
            return x.pooler_output

        return x.last_hidden_state[:, self.cls_token_position, :]


class HFTextEncoder(nn.Module):
    """HuggingFace model adapter"""
    output_tokens: torch.jit.Final[bool]

    def __init__(
            self,
            model_name_or_path: str,
            output_dim: int,
            config: PretrainedConfig = None,
            pooler_type: str = None,
            proj: str = None,
            pretrained: bool = True,
            output_tokens: bool = False,
    ):
        super().__init__()
        self.output_tokens = output_tokens
        self.output_dim = output_dim

        # TODO: find better way to get this information
        uses_transformer_pooler = (pooler_type == "cls_pooler")

        if transformers is None:
            raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
        if config is None:
            self.config = AutoConfig.from_pretrained(model_name_or_path)
            create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
                AutoModel.from_config, self.config)
            # TODO: do all model configs have this attribute? PretrainedConfig does so yes??
            if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
                self.transformer = create_func(model_args)
                self.transformer = self.transformer.encoder
            else:
                self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
        else:
            self.config = config
            self.transformer = AutoModel.from_config(config)
        if pooler_type is None:  # get default arch pooler
            pooler_type = (arch_dict[self.config.model_type]["pooler"])

        self.pooler = _POOLERS[pooler_type]()

        d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
        if (d_model == output_dim) and (proj is None):  # do we always need a proj?
            self.proj = nn.Identity()
        elif proj == 'linear':
            self.proj = nn.Linear(d_model, output_dim, bias=False)
        elif proj == 'mlp':
            hidden_size = (d_model + output_dim) // 2
            self.proj = nn.Sequential(
                nn.Linear(d_model, hidden_size, bias=False),
                nn.GELU(),
                nn.Linear(hidden_size, output_dim, bias=False),
            )

    def forward(self, x: TensorType):
        attn_mask = (x != self.config.pad_token_id).long()
        out = self.transformer(input_ids=x, attention_mask=attn_mask)
        pooled_out = self.pooler(out, attn_mask)
        projected = self.proj(pooled_out)

        seq_len = out.last_hidden_state.shape[1]
        tokens = (
            out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :]
            if type(self.pooler) == ClsPooler
            else out.last_hidden_state
        )

        if self.output_tokens:
            return projected, tokens
        return projected

    def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        if not unlocked_layers:  # full freezing
            for n, p in self.transformer.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
            return

        encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
        layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
        print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
        embeddings = getattr(
            self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
        modules = [embeddings, *layer_list][:-unlocked_layers]
        # freeze layers
        for module in modules:
            for n, p in module.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.transformer.gradient_checkpointing_enable()

    def init_parameters(self):
        pass
@@ -1,270 +0,0 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.utils.rnn import pad_sequence

try:
    import torch.distributed.nn
    from torch import distributed as dist

    has_distributed = True
except ImportError:
    has_distributed = False

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None


def gather_features(
        image_features,
        text_features,
        local_loss=False,
        gather_with_grad=False,
        rank=0,
        world_size=1,
        use_horovod=False
):
    assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
    if use_horovod:
        assert hvd is not None, 'Please install horovod'
        if gather_with_grad:
            all_image_features = hvd.allgather(image_features)
            all_text_features = hvd.allgather(text_features)
        else:
            with torch.no_grad():
                all_image_features = hvd.allgather(image_features)
                all_text_features = hvd.allgather(text_features)
            if not local_loss:
                # ensure grads for local rank when all_* features don't have a gradient
                gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
                gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
                gathered_image_features[rank] = image_features
                gathered_text_features[rank] = text_features
                all_image_features = torch.cat(gathered_image_features, dim=0)
                all_text_features = torch.cat(gathered_text_features, dim=0)
    else:
        # We gather tensors from all gpus
        if gather_with_grad:
            all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
            all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
        else:
            gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
            gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
            dist.all_gather(gathered_image_features, image_features)
            dist.all_gather(gathered_text_features, text_features)
            if not local_loss:
                # ensure grads for local rank when all_* features don't have a gradient
                gathered_image_features[rank] = image_features
                gathered_text_features[rank] = text_features
            all_image_features = torch.cat(gathered_image_features, dim=0)
            all_text_features = torch.cat(gathered_text_features, dim=0)

    return all_image_features, all_text_features


class ClipLoss(nn.Module):

    def __init__(
            self,
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod

        # cache state
        self.prev_num_logits = 0
        self.labels = {}

    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        # calculated ground-truth and cache if enabled
        if self.prev_num_logits != num_logits or device not in self.labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)
            if self.world_size > 1 and self.local_loss:
                labels = labels + num_logits * self.rank
            if self.cache_labels:
                self.labels[device] = labels
                self.prev_num_logits = num_logits
        else:
            labels = self.labels[device]
        return labels

    def get_logits(self, image_features, text_features, logit_scale):
        if self.world_size > 1:
            all_image_features, all_text_features = gather_features(
                image_features, text_features,
                self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)

            if self.local_loss:
                logits_per_image = logit_scale * image_features @ all_text_features.T
                logits_per_text = logit_scale * text_features @ all_image_features.T
            else:
                logits_per_image = logit_scale * all_image_features @ all_text_features.T
                logits_per_text = logits_per_image.T
        else:
            logits_per_image = logit_scale * image_features @ text_features.T
            logits_per_text = logit_scale * text_features @ image_features.T

        return logits_per_image, logits_per_text

    def forward(self, image_features, text_features, logit_scale, output_dict=False):
        device = image_features.device
        logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)

        labels = self.get_ground_truth(device, logits_per_image.shape[0])

        total_loss = (
            F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)
        ) / 2
        return total_loss


class PreferenceLoss(nn.Module):

    def forward(self, logits_per_image, num_images, labels):

        paired_logits_list = [logit[:,i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))]
        paired_logits = pad_sequence(paired_logits_list, batch_first=True, padding_value=-999)

        ce_loss = F.cross_entropy(paired_logits, labels)
        return ce_loss


class HPSLoss(nn.Module):

    def forward(self, text_logits, labels):

        device = text_logits.device
        text_0_logits, text_1_logits = text_logits.chunk(2, dim=-1)
        label_0, label_1 = labels.chunk(2, dim=-1)

        index = torch.arange(text_0_logits.shape[0], device=device, dtype=torch.long)
        text_0_logits = text_0_logits[index, index]
        text_1_logits = text_1_logits[index, index]
        text_logits = torch.stack([text_0_logits, text_1_logits], dim=-1)
        text_0_labels = torch.zeros(text_logits.shape[0], device=device, dtype=torch.long)
        text_1_labels = text_0_labels + 1

        text_0_loss = torch.nn.functional.cross_entropy(text_logits, text_0_labels, reduction="none")
        text_1_loss = torch.nn.functional.cross_entropy(text_logits, text_1_labels, reduction="none")

        text_loss = label_0 * text_0_loss + label_1 * text_1_loss

        # absolute_example_weight = 1 / num_per_prompt
        # denominator = absolute_example_weight.sum()
        # weight_per_example = absolute_example_weight / denominator
        # text_loss *= weight_per_example

        text_loss = text_loss.sum()
        return text_loss


class RankingLoss(nn.Module):

    def forward(self, logits_per_image, num_images, labels, margin = 1.0):
        paired_logits_list = [logit[:,i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))]
        label_list = [label for label in labels.split(num_images.tolist())]
        # ranked_logits = [torch.index_select(paired_logits_list[i], 0, rank) for i, rank in enumerate(label_list)]

        paired_logits = pad_sequence(paired_logits_list, batch_first=True, padding_value=-1)
        padded_labels = pad_sequence(label_list, batch_first=True, padding_value=10)

        # regulized_logits = torch.log(torch.sigmoid(paired_logits))

        diff = paired_logits.unsqueeze(1) - paired_logits.unsqueeze(2)
        # diff = paired_logits.unsqueeze(1) - paired_logits.unsqueeze(2)
        # diff_label = torch.clamp(padded_labels.unsqueeze(1) - padded_labels.unsqueeze(2), min=-1, max=1)
        diff_label = - (padded_labels.unsqueeze(1) - padded_labels.unsqueeze(2))
        mask = torch.triu(torch.ones(diff.shape[1], diff.shape[1]), diagonal=1).bool().detach()

        loss = torch.clamp(margin - torch.mul(diff[:, ~mask],diff_label[:,~mask]), min=0).mean()
        return loss


class CoCaLoss(ClipLoss):
    def __init__(
            self,
            caption_loss_weight,
            clip_loss_weight,
            pad_id=0,  # pad_token for open_clip custom tokenizer
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        super().__init__(
            local_loss=local_loss,
            gather_with_grad=gather_with_grad,
            cache_labels=cache_labels,
            rank=rank,
            world_size=world_size,
            use_horovod=use_horovod
        )

        self.clip_loss_weight = clip_loss_weight
        self.caption_loss_weight = caption_loss_weight
        self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)

    def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
        clip_loss = super().forward(image_features, text_features, logit_scale)
        clip_loss = self.clip_loss_weight * clip_loss

        caption_loss = self.caption_loss(
            logits.permute(0, 2, 1),
            labels,
        )
        caption_loss = caption_loss * self.caption_loss_weight

        if output_dict:
            return {"contrastive_loss": clip_loss, "caption_loss": caption_loss}

        return clip_loss, caption_loss


class DistillClipLoss(ClipLoss):

    def dist_loss(self, teacher_logits, student_logits):
        return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0)

    def forward(
            self,
            image_features,
            text_features,
            logit_scale,
            dist_image_features,
            dist_text_features,
            dist_logit_scale,
            output_dict=False,
    ):
        logits_per_image, logits_per_text = \
            self.get_logits(image_features, text_features, logit_scale)

        dist_logits_per_image, dist_logits_per_text = \
            self.get_logits(dist_image_features, dist_text_features, dist_logit_scale)

        labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0])

        contrastive_loss = (
            F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)
        ) / 2

        distill_loss = (
            self.dist_loss(dist_logits_per_image, logits_per_image) +
            self.dist_loss(dist_logits_per_text, logits_per_text)
        ) / 2

        if output_dict:
            return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}

        return contrastive_loss, distill_loss
@@ -1,461 +0,0 @@
""" CLIP Model

Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
from dataclasses import dataclass
import logging
import math
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint

from .hf_model import HFTextEncoder
from .modified_resnet import ModifiedResNet
from .timm_model import TimmModel
from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
from .utils import to_2tuple


@dataclass
class CLIPVisionCfg:
    layers: Union[Tuple[int, int, int, int], int] = 12
    width: int = 768
    head_width: int = 64
    mlp_ratio: float = 4.0
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224
    ls_init_value: Optional[float] = None  # layer scale initial value
    patch_dropout: float = 0.  # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
    input_patchnorm: bool = False  # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design
    global_average_pool: bool = False  # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
    attentional_pool: bool = False  # whether to use attentional pooler in the last embedding layer
    n_queries: int = 256  # n_queries for attentional pooler
    attn_pooler_heads: int = 8  # n heads for attentional_pooling
    timm_model_name: str = None  # a valid model name overrides layers, width, patch_size
    timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
    timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
    timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')
    timm_proj_bias: bool = False  # enable bias final projection
    timm_drop: float = 0.  # head dropout
    timm_drop_path: Optional[float] = None  # backbone stochastic depth
    output_tokens: bool = False


@dataclass
class CLIPTextCfg:
    context_length: int = 77
    vocab_size: int = 49408
    width: int = 512
    heads: int = 8
    layers: int = 12
    ls_init_value: Optional[float] = None  # layer scale initial value
    hf_model_name: str = None
    hf_tokenizer_name: str = None
    hf_model_pretrained: bool = True
    proj: str = 'mlp'
    pooler_type: str = 'mean_pooler'
    embed_cls: bool = False
    pad_id: int = 0
    output_tokens: bool = False


def get_cast_dtype(precision: str):
    cast_dtype = None
    if precision == 'bf16':
        cast_dtype = torch.bfloat16
    elif precision == 'fp16':
        cast_dtype = torch.float16
    return cast_dtype


def _build_vision_tower(
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None
):
    if isinstance(vision_cfg, dict):
        vision_cfg = CLIPVisionCfg(**vision_cfg)

    # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
    # memory efficient in recent PyTorch releases (>= 1.10).
    # NOTE: timm models always use native GELU regardless of quick_gelu flag.
    act_layer = QuickGELU if quick_gelu else nn.GELU

    if vision_cfg.timm_model_name:
        visual = TimmModel(
            vision_cfg.timm_model_name,
            pretrained=vision_cfg.timm_model_pretrained,
            pool=vision_cfg.timm_pool,
            proj=vision_cfg.timm_proj,
            proj_bias=vision_cfg.timm_proj_bias,
            drop=vision_cfg.timm_drop,
            drop_path=vision_cfg.timm_drop_path,
            embed_dim=embed_dim,
            image_size=vision_cfg.image_size,
        )
        act_layer = nn.GELU  # so that text transformer doesn't use QuickGELU w/ timm models
    elif isinstance(vision_cfg.layers, (tuple, list)):
        vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
        visual = ModifiedResNet(
            layers=vision_cfg.layers,
            output_dim=embed_dim,
            heads=vision_heads,
            image_size=vision_cfg.image_size,
            width=vision_cfg.width,
        )
    else:
        vision_heads = vision_cfg.width // vision_cfg.head_width
        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
        visual = VisionTransformer(
            image_size=vision_cfg.image_size,
            patch_size=vision_cfg.patch_size,
            width=vision_cfg.width,
            layers=vision_cfg.layers,
            heads=vision_heads,
            mlp_ratio=vision_cfg.mlp_ratio,
            ls_init_value=vision_cfg.ls_init_value,
            patch_dropout=vision_cfg.patch_dropout,
            input_patchnorm=vision_cfg.input_patchnorm,
            global_average_pool=vision_cfg.global_average_pool,
            attentional_pool=vision_cfg.attentional_pool,
            n_queries=vision_cfg.n_queries,
            attn_pooler_heads=vision_cfg.attn_pooler_heads,
            output_tokens=vision_cfg.output_tokens,
            output_dim=embed_dim,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )

    return visual


def _build_text_tower(
        embed_dim: int,
        text_cfg: CLIPTextCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    if isinstance(text_cfg, dict):
        text_cfg = CLIPTextCfg(**text_cfg)

    if text_cfg.hf_model_name:
        text = HFTextEncoder(
            text_cfg.hf_model_name,
            output_dim=embed_dim,
            proj=text_cfg.proj,
            pooler_type=text_cfg.pooler_type,
            pretrained=text_cfg.hf_model_pretrained,
            output_tokens=text_cfg.output_tokens,
        )
    else:
        act_layer = QuickGELU if quick_gelu else nn.GELU
        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm

        text = TextTransformer(
            context_length=text_cfg.context_length,
            vocab_size=text_cfg.vocab_size,
            width=text_cfg.width,
            heads=text_cfg.heads,
            layers=text_cfg.layers,
            ls_init_value=text_cfg.ls_init_value,
            output_dim=embed_dim,
            embed_cls=text_cfg.embed_cls,
            output_tokens=text_cfg.output_tokens,
            pad_id=text_cfg.pad_id,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )
    return text


class CLIP(nn.Module):
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)

        text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.transformer = text.transformer
        self.vocab_size = text.vocab_size
        self.token_embedding = text.token_embedding
        self.positional_embedding = text.positional_embedding
        self.ln_final = text.ln_final
        self.text_projection = text.text_projection
        self.register_buffer('attn_mask', text.attn_mask, persistent=False)

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        locked_layers = []
        locked_layers.append(self.token_embedding)
        self.positional_embedding.requires_grad = False
        if unlocked_layers > 0:
            locked_layers.append(self.transformer.resblocks[:-unlocked_layers])
        else:
            locked_layers.append(self.transformer)
        locked_layers.append(self.ln_final)
        self.text_projection.requires_grad = False

        # freeze layers
        for module in locked_layers:
            for n, p in module.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.visual.set_grad_checkpointing(enable)
        self.transformer.grad_checkpointing = enable

    def encode_image(self, image, normalize: bool = False):
        features = self.visual(image)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        cast_dtype = self.transformer.get_cast_dtype()

        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.to(cast_dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x, attn_mask=self.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return F.normalize(x, dim=-1) if normalize else x

    def forward(self, image, text):
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        if self.output_dict:
            return {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp()
            }
        return image_features, text_features, self.logit_scale.exp()


class CustomTextCLIP(nn.Module):
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        self.text.lock(unlocked_layers, freeze_layer_norm)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)

    def encode_image(self, image, normalize: bool = False):
        features = self.visual(image)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        features = self.text(text)
        return F.normalize(features, dim=-1) if normalize else features

    def forward(self, image, text):
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        if self.output_dict:
            return {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp()
            }
        return image_features, text_features, self.logit_scale.exp()


def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
    """Convert applicable model parameters to low-precision (bf16 or fp16)"""

    def _convert_weights(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.to(dtype)
            if l.bias is not None:
                l.bias.data = l.bias.data.to(dtype)

        if isinstance(l, (nn.MultiheadAttention, Attention)):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.to(dtype)

        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.to(dtype)

    model.apply(_convert_weights)


convert_weights_to_fp16 = convert_weights_to_lp  # backwards compat


# used to maintain checkpoint compatibility
def convert_to_custom_text_state_dict(state_dict: dict):
    if 'text_projection' in state_dict:
        # old format state_dict, move text tower -> .text
        new_state_dict = {}
        for k, v in state_dict.items():
            if any(k.startswith(p) for p in (
                'text_projection',
                'positional_embedding',
                'token_embedding',
                'transformer',
                'ln_final',
            )):
                k = 'text.' + k
            new_state_dict[k] = v
        return new_state_dict
    return state_dict


def build_model_from_openai_state_dict(
        state_dict: dict,
        quick_gelu=True,
        cast_dtype=torch.float16,
):
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len(
            [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_size = vision_patch_size * grid_size
    else:
        counts: list = [
            len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_size = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))

    vision_cfg = CLIPVisionCfg(
        layers=vision_layers,
        width=vision_width,
        patch_size=vision_patch_size,
        image_size=image_size,
    )
    text_cfg = CLIPTextCfg(
        context_length=context_length,
        vocab_size=vocab_size,
        width=transformer_width,
        heads=transformer_heads,
        layers=transformer_layers,
    )
    model = CLIP(
        embed_dim,
        vision_cfg=vision_cfg,
        text_cfg=text_cfg,
        quick_gelu=quick_gelu,  # OpenAI models were trained with QuickGELU
        cast_dtype=cast_dtype,
    )

    for key in ["input_resolution", "context_length", "vocab_size"]:
        state_dict.pop(key, None)

    convert_weights_to_fp16(model)  # OpenAI state dicts are partially converted to float16
    model.load_state_dict(state_dict)
    return model.eval()


def trace_model(model, batch_size=256, device=torch.device('cpu')):
    model.eval()
    image_size = model.visual.image_size
    example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
    example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
    model = torch.jit.trace_module(
        model,
        inputs=dict(
            forward=(example_images, example_text),
            encode_text=(example_text,),
            encode_image=(example_images,)
        ))
    model.visual.image_size = image_size
    return model


def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
    # Rescale the grid of position embeddings when loading from state_dict
    old_pos_embed = state_dict.get('visual.positional_embedding', None)
    if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
        return
    grid_size = to_2tuple(model.visual.grid_size)
    extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)
    new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
    if new_seq_len == old_pos_embed.shape[0]:
        return

    if extra_tokens:
        pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
    else:
        pos_emb_tok, pos_emb_img = None, old_pos_embed
    old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))

    logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
    pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
    pos_emb_img = F.interpolate(
        pos_emb_img,
        size=grid_size,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
    if pos_emb_tok is not None:
        new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
    else:
        new_pos_embed = pos_emb_img
    state_dict['visual.positional_embedding'] = new_pos_embed
@@ -1,17 +0,0 @@
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 224,
        "layers": 32,
        "width": 1280,
        "head_width": 80,
        "patch_size": 14
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1024,
        "heads": 16,
        "layers": 24
    }
}
@@ -1,181 +0,0 @@
from collections import OrderedDict

import torch
from torch import nn
from torch.nn import functional as F

from .utils import freeze_batch_norm_2d


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.act1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.act2 = nn.ReLU(inplace=True)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.act3 = nn.ReLU(inplace=True)

        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.act1(self.bn1(self.conv1(x)))
        out = self.act2(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.act3(out)
        return out


class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x, key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0.,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )

        return x[0]


class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, image_size=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.image_size = image_size

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.act1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.act2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.act3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)

        self.init_parameters()

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def init_parameters(self):
        if self.attnpool is not None:
            std = self.attnpool.c_proj.in_features ** -0.5
            nn.init.normal_(self.attnpool.q_proj.weight, std=std)
            nn.init.normal_(self.attnpool.k_proj.weight, std=std)
            nn.init.normal_(self.attnpool.v_proj.weight, std=std)
            nn.init.normal_(self.attnpool.c_proj.weight, std=std)

        for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
            for name, param in resnet_block.named_parameters():
                if name.endswith("bn3.weight"):
                    nn.init.zeros_(param)

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        assert unlocked_groups == 0, 'partial locking not currently supported for this model'
        for param in self.parameters():
            param.requires_grad = False
        if freeze_bn_stats:
            freeze_batch_norm_2d(self)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        # FIXME support for non-transformer
        pass

    def stem(self, x):
        x = self.act1(self.bn1(self.conv1(x)))
        x = self.act2(self.bn2(self.conv2(x)))
        x = self.act3(self.bn3(self.conv3(x)))
        x = self.avgpool(x)
        return x

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x
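# --- Illustrative sketch (not part of the original file) ---
# ModifiedResNet is the CLIP "RN50"-style image tower. The numbers below are the standard RN50
# configuration; the feature width before pooling is width * 32 = 2048, and AttentionPool2d
# projects it down to output_dim.
import torch

rn50_visual = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, heads=32, image_size=224, width=64)
images = torch.randn(2, 3, 224, 224)
features = rn50_visual(images)
print(features.shape)   # torch.Size([2, 1024])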
@@ -1,144 +0,0 @@
""" OpenAI pretrained model functions

Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""

import os
import warnings
from typing import List, Optional, Union

import torch

from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url

__all__ = ["list_openai_models", "load_openai_model"]


def list_openai_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list_pretrained_models_by_tag('openai')


def load_openai_model(
        name: str,
        precision: Optional[str] = None,
        device: Optional[Union[str, torch.device]] = None,
        jit: bool = True,
        cache_dir: Optional[str] = None,
):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
    precision: str
        Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
    device : Union[str, torch.device]
        The device to put the loaded model
    jit : bool
        Whether to load the optimized JIT model (default) or more hackable non-JIT model.
    cache_dir : Optional[str]
        The directory to cache the downloaded model weights

    Returns
    -------
    model : torch.nn.Module
        The CLIP model
    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if precision is None:
        precision = 'fp32' if device == 'cpu' else 'fp16'

    if get_pretrained_url(name, 'openai'):
        model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        # Build a non-jit model from the OpenAI jitted model state dict
        cast_dtype = get_cast_dtype(precision)
        try:
            model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
        except KeyError:
            sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
            model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)

        # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
        model = model.to(device)
        if precision.startswith('amp') or precision == 'fp32':
            model.float()
        elif precision == 'bf16':
            convert_weights_to_lp(model, dtype=torch.bfloat16)

        return model

    # patch the device names
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 (typically for CPU)
    if precision == 'fp32':
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)
        model.float()

    # ensure image_size attr available at consistent location for both jit and non-jit
    model.visual.image_size = model.input_resolution.item()
    return model
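# --- Illustrative usage sketch (not part of the original file) ---
# Loads the OpenAI RN50 weights as a plain (non-JIT) fp32 module on CPU and encodes a dummy
# image. The 'RN50'/'openai' entry comes from the _PRETRAINED table in pretrained.py below.
import torch

model = load_openai_model('RN50', precision='fp32', device='cpu', jit=False)
image = torch.randn(1, 3, model.visual.image_size, model.visual.image_size)
with torch.no_grad():
    image_features = model.encode_image(image)   # (1, 1024) embedding for RN50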
@@ -1,376 +0,0 @@
import hashlib
import os
import urllib
import warnings
from functools import partial
from typing import Dict, Union

from tqdm import tqdm

from .version import __version__

try:
    from huggingface_hub import hf_hub_download
    hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__)
    _has_hf_hub = True
except ImportError:
    hf_hub_download = None
    _has_hf_hub = False


def _pcfg(url='', hf_hub='', mean=None, std=None):
    return dict(
        url=url,
        hf_hub=hf_hub,
        mean=mean,
        std=std,
    )

_RN50 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
    yfcc15m=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
    cc12m=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
)

_RN50_quickgelu = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
    yfcc15m=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
    cc12m=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
)

_RN101 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
    yfcc15m=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
)

_RN101_quickgelu = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
    yfcc15m=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
)

_RN50x4 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"),
)

_RN50x16 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"),
)

_RN50x64 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"),
)

_VITB32 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
    laion400m_e31=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
    laion400m_e32=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
    laion2b_e16=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
    laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/')
)

_VITB32_quickgelu = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
    laion400m_e31=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
    laion400m_e32=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
)

_VITB16 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
    laion400m_e31=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
    laion400m_e32=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
    # laion400m_32k=_pcfg(
    #     url="",
    #     mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    # laion400m_64k=_pcfg(
    #     url="",
    #     mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
)

_VITB16_PLUS_240 = dict(
    laion400m_e31=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
    laion400m_e32=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
)

_VITL14 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
    laion400m_e31=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
    laion400m_e32=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
    laion2b_s32b_b82k=_pcfg(
        hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
)

_VITL14_336 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
)

_VITH14 = dict(
    laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
)

_VITg14 = dict(
    laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
    laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),
)

_VITbigG14 = dict(
    laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
)

_robertaViTB32 = dict(
    laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'),
)

_xlmRobertaBaseViTB32 = dict(
    laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'),
)

_xlmRobertaLargeFrozenViTH14 = dict(
    frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'),
)

_convnext_base = dict(
    laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'),
)

_convnext_base_w = dict(
    laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'),
    laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'),
    laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'),
)

_convnext_base_w_320 = dict(
    laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'),
    laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'),
)

_convnext_large_d = dict(
    laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'),
)

_convnext_large_d_320 = dict(
    laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'),
    laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'),
)

_convnext_xxlarge = dict(
    laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'),
    laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'),
    laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'),
)

_coca_VITB32 = dict(
    laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'),
    mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/')
)

_coca_VITL14 = dict(
    laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'),
    mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/')
)

_PRETRAINED = {
    "RN50": _RN50,
    "RN50-quickgelu": _RN50_quickgelu,
    "RN101": _RN101,
    "RN101-quickgelu": _RN101_quickgelu,
    "RN50x4": _RN50x4,
    "RN50x16": _RN50x16,
    "RN50x64": _RN50x64,
    "ViT-B-32": _VITB32,
    "ViT-B-32-quickgelu": _VITB32_quickgelu,
    "ViT-B-16": _VITB16,
    "ViT-B-16-plus-240": _VITB16_PLUS_240,
    "ViT-L-14": _VITL14,
    "ViT-L-14-336": _VITL14_336,
    "ViT-H-14": _VITH14,
    "ViT-g-14": _VITg14,
    "ViT-bigG-14": _VITbigG14,
    "roberta-ViT-B-32": _robertaViTB32,
    "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32,
    "xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14,
    "convnext_base": _convnext_base,
    "convnext_base_w": _convnext_base_w,
    "convnext_base_w_320": _convnext_base_w_320,
    "convnext_large_d": _convnext_large_d,
    "convnext_large_d_320": _convnext_large_d_320,
    "convnext_xxlarge": _convnext_xxlarge,
    "coca_ViT-B-32": _coca_VITB32,
    "coca_ViT-L-14": _coca_VITL14,
}


def _clean_tag(tag: str):
    # normalize pretrained tags
    return tag.lower().replace('-', '_')


def list_pretrained(as_str: bool = False):
    """ returns list of pretrained models
    Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
    """
    return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]


def list_pretrained_models_by_tag(tag: str):
    """ return all models having the specified pretrain tag """
    models = []
    tag = _clean_tag(tag)
    for k in _PRETRAINED.keys():
        if tag in _PRETRAINED[k]:
            models.append(k)
    return models


def list_pretrained_tags_by_model(model: str):
    """ return all pretrain tags for the specified model architecture """
    tags = []
    if model in _PRETRAINED:
        tags.extend(_PRETRAINED[model].keys())
    return tags


def is_pretrained_cfg(model: str, tag: str):
    if model not in _PRETRAINED:
        return False
    return _clean_tag(tag) in _PRETRAINED[model]


def get_pretrained_cfg(model: str, tag: str):
    if model not in _PRETRAINED:
        return {}
    model_pretrained = _PRETRAINED[model]
    return model_pretrained.get(_clean_tag(tag), {})


def get_pretrained_url(model: str, tag: str):
    cfg = get_pretrained_cfg(model, _clean_tag(tag))
    return cfg.get('url', '')

def download_pretrained_from_url(
        url: str,
        cache_dir: Union[str, None] = None,
):
    if not cache_dir:
        cache_dir = os.path.expanduser("~/.cache/clip")
    os.makedirs(cache_dir, exist_ok=True)
    filename = os.path.basename(url)

    if 'openaipublic' in url:
        expected_sha256 = url.split("/")[-2]
    elif 'mlfoundations' in url:
        expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
    else:
        expected_sha256 = ''

    download_target = os.path.join(cache_dir, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if expected_sha256:
            if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
                return download_target
            else:
                warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
        else:
            return download_target

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target


def has_hf_hub(necessary=False):
    if not _has_hf_hub and necessary:
        # if no HF Hub module installed, and it is necessary to continue, raise error
        raise RuntimeError(
            'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
    return _has_hf_hub


def download_pretrained_from_hf(
        model_id: str,
        filename: str = 'open_clip_pytorch_model.bin',
        revision=None,
        cache_dir: Union[str, None] = None,
):
    has_hf_hub(True)
    cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
    return cached_file


def download_pretrained(
        cfg: Dict,
        force_hf_hub: bool = False,
        cache_dir: Union[str, None] = None,
):
    target = ''
    if not cfg:
        return target

    download_url = cfg.get('url', '')
    download_hf_hub = cfg.get('hf_hub', '')
    if download_hf_hub and force_hf_hub:
        # use HF hub even if url exists
        download_url = ''

    if download_url:
        target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
    elif download_hf_hub:
        has_hf_hub(True)
        # we assume the hf_hub entries in pretrained config combine model_id + filename in
        # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
        # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
        model_id, filename = os.path.split(download_hf_hub)
        if filename:
            target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
        else:
            target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)

    return target
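# --- Illustrative usage sketch (not part of the original file) ---
# The helpers above expose the _PRETRAINED table: list the available (model, tag) pairs, look up
# a config, and fetch the weights either from a URL or from the Hugging Face Hub.
print(list_pretrained(as_str=True)[:3])            # e.g. ['RN50:openai', 'RN50:yfcc15m', 'RN50:cc12m']
print(list_pretrained_tags_by_model('ViT-B-32'))

cfg = get_pretrained_cfg('ViT-B-32', 'laion2b_s34b_b79k')
checkpoint_path = download_pretrained(cfg)         # resolves to a cached local file path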
@@ -1,243 +0,0 @@
import argparse
import json
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, Tuple

import torch

try:
    from huggingface_hub import (
        create_repo,
        get_hf_file_metadata,
        hf_hub_download,
        hf_hub_url,
        repo_type_and_id_from_hf_id,
        upload_folder,
    )
    from huggingface_hub.utils import EntryNotFoundError
    _has_hf_hub = True
except ImportError:
    _has_hf_hub = False

from .factory import create_model_from_pretrained, get_model_config, get_tokenizer
from .tokenizer import HFTokenizer


def save_config_for_hf(
        model,
        config_path: str,
        model_config: Optional[dict]
):
    preprocess_cfg = {
        'mean': model.visual.image_mean,
        'std': model.visual.image_std,
    }
    hf_config = {
        'model_cfg': model_config,
        'preprocess_cfg': preprocess_cfg,
    }

    with config_path.open('w') as f:
        json.dump(hf_config, f, indent=2)


def save_for_hf(
        model,
        tokenizer: HFTokenizer,
        model_config: dict,
        save_directory: str,
        weights_filename='open_clip_pytorch_model.bin',
        config_filename='open_clip_config.json',
):
    save_directory = Path(save_directory)
    save_directory.mkdir(exist_ok=True, parents=True)

    weights_path = save_directory / weights_filename
    torch.save(model.state_dict(), weights_path)

    tokenizer.save_pretrained(save_directory)

    config_path = save_directory / config_filename
    save_config_for_hf(model, config_path, model_config=model_config)


def push_to_hf_hub(
        model,
        tokenizer,
        model_config: Optional[dict],
        repo_id: str,
        commit_message: str = 'Add model',
        token: Optional[str] = None,
        revision: Optional[str] = None,
        private: bool = False,
        create_pr: bool = False,
        model_card: Optional[dict] = None,
):
    if not isinstance(tokenizer, HFTokenizer):
        # default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14
        tokenizer = HFTokenizer('openai/clip-vit-large-patch14')

    # Create repo if it doesn't exist yet
    repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)

    # Infer complete repo_id from repo_url
    # Can be different from the input `repo_id` if repo_owner was implicit
    _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
    repo_id = f"{repo_owner}/{repo_name}"

    # Check if a README file already exists in the repo
    try:
        get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
        has_readme = True
    except EntryNotFoundError:
        has_readme = False

    # Dump model and push to Hub
    with TemporaryDirectory() as tmpdir:
        # Save model weights and config.
        save_for_hf(
            model,
            tokenizer=tokenizer,
            model_config=model_config,
            save_directory=tmpdir,
        )

        # Add readme if it does not exist
        if not has_readme:
            model_card = model_card or {}
            model_name = repo_id.split('/')[-1]
            readme_path = Path(tmpdir) / "README.md"
            readme_text = generate_readme(model_card, model_name)
            readme_path.write_text(readme_text)

        # Upload model and return
        return upload_folder(
            repo_id=repo_id,
            folder_path=tmpdir,
            revision=revision,
            create_pr=create_pr,
            commit_message=commit_message,
        )


def push_pretrained_to_hf_hub(
        model_name,
        pretrained: str,
        repo_id: str,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        commit_message: str = 'Add model',
        token: Optional[str] = None,
        revision: Optional[str] = None,
        private: bool = False,
        create_pr: bool = False,
        model_card: Optional[dict] = None,
):
    model, preprocess_eval = create_model_from_pretrained(
        model_name,
        pretrained=pretrained,
        image_mean=image_mean,
        image_std=image_std,
    )

    model_config = get_model_config(model_name)
    assert model_config

    tokenizer = get_tokenizer(model_name)

    push_to_hf_hub(
        model=model,
        tokenizer=tokenizer,
        model_config=model_config,
        repo_id=repo_id,
        commit_message=commit_message,
        token=token,
        revision=revision,
        private=private,
        create_pr=create_pr,
        model_card=model_card,
    )

def generate_readme(model_card: dict, model_name: str):
    readme_text = "---\n"
    readme_text += "tags:\n- zero-shot-image-classification\n- clip\n"
    readme_text += "library_tag: open_clip\n"
    readme_text += f"license: {model_card.get('license', 'mit')}\n"
    if 'details' in model_card and 'Dataset' in model_card['details']:
        readme_text += 'datasets:\n'
        readme_text += f"- {model_card['details']['Dataset'].lower()}\n"
    readme_text += "---\n"
    readme_text += f"# Model card for {model_name}\n"
    if 'description' in model_card:
        readme_text += f"\n{model_card['description']}\n"
    if 'details' in model_card:
        readme_text += f"\n## Model Details\n"
        for k, v in model_card['details'].items():
            if isinstance(v, (list, tuple)):
                readme_text += f"- **{k}:**\n"
                for vi in v:
                    readme_text += f"  - {vi}\n"
            elif isinstance(v, dict):
                readme_text += f"- **{k}:**\n"
                for ki, vi in v.items():
                    readme_text += f"  - {ki}: {vi}\n"
            else:
                readme_text += f"- **{k}:** {v}\n"
    if 'usage' in model_card:
        readme_text += f"\n## Model Usage\n"
        readme_text += model_card['usage']
        readme_text += '\n'

    if 'comparison' in model_card:
        readme_text += f"\n## Model Comparison\n"
        readme_text += model_card['comparison']
        readme_text += '\n'

    if 'citation' in model_card:
        readme_text += f"\n## Citation\n"
        if not isinstance(model_card['citation'], (list, tuple)):
            citations = [model_card['citation']]
        else:
            citations = model_card['citation']
        for c in citations:
            readme_text += f"```bibtex\n{c}\n```\n"

    return readme_text


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Push to Hugging Face Hub")
    parser.add_argument(
        "--model", type=str, help="Name of the model to use.",
    )
    parser.add_argument(
        "--pretrained", type=str,
        help="Use a pretrained CLIP model weights with the specified tag or file path.",
    )
    parser.add_argument(
        "--repo-id", type=str,
        help="Destination HF Hub repo-id ie 'organization/model_id'.",
    )
    parser.add_argument(
        '--image-mean', type=float, nargs='+', default=None, metavar='MEAN',
        help='Override default image mean value of dataset')
    parser.add_argument(
        '--image-std', type=float, nargs='+', default=None, metavar='STD',
        help='Override default image std deviation of dataset')
    args = parser.parse_args()

    print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}')

    # FIXME add support to pass model_card json / template from file via cmd line

    push_pretrained_to_hf_hub(
        args.model,
        args.pretrained,
        args.repo_id,
        image_mean=args.image_mean,  # override image mean/std if trained w/ non defaults
        image_std=args.image_std,
    )

    print(f'{args.model} saved.')
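# --- Illustrative usage sketch (not part of the original file) ---
# The same flow as the __main__ block above, invoked directly from Python. The model name, tag
# and destination repo below are placeholders, and pushing requires a valid Hugging Face token.
push_pretrained_to_hf_hub(
    'ViT-B-32',                   # architecture name known to get_model_config()
    'laion2b_s34b_b79k',          # pretrained tag from pretrained.py
    'my-org/my-clip-vit-b-32',    # hypothetical destination repo_id
    private=True,
)
# Equivalent CLI call (assuming this file is importable as open_clip.push_to_hf_hub):
#   python -m open_clip.push_to_hf_hub --model ViT-B-32 --pretrained laion2b_s34b_b79k --repo-id my-org/my-clip-vit-b-32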
@@ -1,127 +0,0 @@
""" timm model adapter

Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
"""
import logging
from collections import OrderedDict

import torch
import torch.nn as nn

try:
    import timm
    from timm.models.layers import Mlp, to_2tuple
    try:
        # old timm imports < 0.8.1
        from timm.models.layers.attention_pool2d import RotAttentionPool2d
        from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
    except ImportError:
        # new timm imports >= 0.8.1
        from timm.layers import RotAttentionPool2d
        from timm.layers import AttentionPool2d as AbsAttentionPool2d
except ImportError:
    timm = None

from .utils import freeze_batch_norm_2d


class TimmModel(nn.Module):
    """ timm model adapter
    # FIXME this adapter is a work in progress, may change in ways that break weight compat
    """

    def __init__(
            self,
            model_name,
            embed_dim,
            image_size=224,
            pool='avg',
            proj='linear',
            proj_bias=False,
            drop=0.,
            drop_path=None,
            pretrained=False,
    ):
        super().__init__()
        if timm is None:
            raise RuntimeError("Please `pip install timm` to use timm models.")

        self.image_size = to_2tuple(image_size)
        timm_kwargs = {}
        if drop_path is not None:
            timm_kwargs['drop_path_rate'] = drop_path
        self.trunk = timm.create_model(model_name, pretrained=pretrained, **timm_kwargs)
        feat_size = self.trunk.default_cfg.get('pool_size', None)
        feature_ndim = 1 if not feat_size else 2
        if pool in ('abs_attn', 'rot_attn'):
            assert feature_ndim == 2
            # if attn pooling used, remove both classifier and default pool
            self.trunk.reset_classifier(0, global_pool='')
        else:
            # reset global pool if pool config set, otherwise leave as network default
            reset_kwargs = dict(global_pool=pool) if pool else {}
            self.trunk.reset_classifier(0, **reset_kwargs)
        prev_chs = self.trunk.num_features

        head_layers = OrderedDict()
        if pool == 'abs_attn':
            head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
            prev_chs = embed_dim
        elif pool == 'rot_attn':
            head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
            prev_chs = embed_dim
        else:
            assert proj, 'projection layer needed if non-attention pooling is used.'

        # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
        if proj == 'linear':
            head_layers['drop'] = nn.Dropout(drop)
            head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
        elif proj == 'mlp':
            head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))

        self.head = nn.Sequential(head_layers)

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """ lock modules
        Args:
            unlocked_groups (int): leave last n layer groups unlocked (default: 0)
        """
        if not unlocked_groups:
            # lock full model
            for param in self.trunk.parameters():
                param.requires_grad = False
            if freeze_bn_stats:
                freeze_batch_norm_2d(self.trunk)
        else:
            # NOTE: partial freeze requires latest timm (master) branch and is subject to change
            try:
                # FIXME import here until API stable and in an official release
                from timm.models.helpers import group_parameters, group_modules
            except ImportError:
                raise RuntimeError(
                    'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
            matcher = self.trunk.group_matcher()
            gparams = group_parameters(self.trunk, matcher)
            max_layer_id = max(gparams.keys())
            max_layer_id = max_layer_id - unlocked_groups
            for group_idx in range(max_layer_id + 1):
                group = gparams[group_idx]
                for param in group:
                    self.trunk.get_parameter(param).requires_grad = False
            if freeze_bn_stats:
                gmodules = group_modules(self.trunk, matcher, reverse=True)
                gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
                freeze_batch_norm_2d(self.trunk, gmodules)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        try:
            self.trunk.set_grad_checkpointing(enable)
        except Exception as e:
            logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')

    def forward(self, x):
        x = self.trunk(x)
        x = self.head(x)
        return x
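# --- Illustrative sketch (not part of the original file) ---
# TimmModel wraps any timm backbone as a CLIP image tower: the timm trunk produces pooled
# features and the head projects them to the CLIP embedding width. The backbone name below is a
# standard timm model and assumes `pip install timm`.
import torch

tower = TimmModel('convnext_base', embed_dim=512, image_size=224, pool='avg', proj='linear')
images = torch.randn(2, 3, 224, 224)
embeddings = tower(images)
print(embeddings.shape)   # torch.Size([2, 512])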
@@ -1,211 +0,0 @@
""" CLIP tokenizer

Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
import gzip
import html
import os
from functools import lru_cache
from typing import Union, List

import ftfy
import regex as re
import torch

# https://stackoverflow.com/q/62691279
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"


@lru_cache()
def default_bpe():
    current_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.abspath(os.path.join(current_dir, '../../../../'))
    quality_metric_path = os.path.join(project_root, 'models', 'QualityMetric')
    return os.path.join(quality_metric_path, "bpe_simple_vocab_16e6.txt.gz")


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        if not special_tokens:
            special_tokens = ['<start_of_text>', '<end_of_text>']
        else:
            special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
        vocab.extend(special_tokens)
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {t: t for t in special_tokens}
        special = "|".join(special_tokens)
        self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

        self.vocab_size = len(self.encoder)
        self.all_special_ids = [self.encoder[t] for t in special_tokens]

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token + '</w>'

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text

    def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
        """
        Returns the tokenized representation of given input string(s)

        Parameters
        ----------
        texts : Union[str, List[str]]
            An input string or a list of input strings to tokenize
        context_length : int
            The context length to use; all CLIP models use 77 as the context length

        Returns
        -------
        A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
        """
        if isinstance(texts, str):
            texts = [texts]

        sot_token = self.encoder["<start_of_text>"]
        eot_token = self.encoder["<end_of_text>"]
        all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

        for i, tokens in enumerate(all_tokens):
            if len(tokens) > context_length:
                tokens = tokens[:context_length]  # Truncate
                tokens[-1] = eot_token
            result[i, :len(tokens)] = torch.tensor(tokens)

        return result


class HFTokenizer:
    """HuggingFace tokenizer wrapper"""

    def __init__(self, tokenizer_name: str):
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def save_pretrained(self, dest):
        self.tokenizer.save_pretrained(dest)

    def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor:
        # same cleaning as for default tokenizer, except lowercasing
        # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
        if isinstance(texts, str):
            texts = [texts]
        texts = [whitespace_clean(basic_clean(text)) for text in texts]
        input_ids = self.tokenizer(
            texts,
            return_tensors='pt',
            max_length=context_length,
            padding='max_length',
            truncation=True,
        ).input_ids
        return input_ids
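# --- Illustrative usage sketch (not part of the original file) ---
# SimpleTokenizer reads its BPE vocabulary from models/QualityMetric/ (see default_bpe() above),
# so bpe_simple_vocab_16e6.txt.gz must already be present there. The prompts are arbitrary.
tokenizer = SimpleTokenizer()
tokens = tokenizer(["a photo of a cat", "a photo of a dog"])   # default context_length=77
print(tokens.shape)   # torch.Size([2, 77]): <start_of_text> + BPE ids + <end_of_text>, zero-padded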
@@ -1,216 +0,0 @@
|
|||||||
import warnings
|
|
||||||
from dataclasses import dataclass, asdict
|
|
||||||
from typing import Any, Dict, Optional, Sequence, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torchvision.transforms.functional as F
|
|
||||||
from functools import partial
|
|
||||||
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
|
|
||||||
CenterCrop
|
|
||||||
|
|
||||||
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AugmentationCfg:
|
|
||||||
scale: Tuple[float, float] = (0.9, 1.0)
|
|
||||||
ratio: Optional[Tuple[float, float]] = None
|
|
||||||
color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None
|
|
||||||
interpolation: Optional[str] = None
|
|
||||||
re_prob: Optional[float] = None
|
|
||||||
re_count: Optional[int] = None
|
|
||||||
use_timm: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
class ResizeMaxSize(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
|
|
||||||
super().__init__()
|
|
||||||
if not isinstance(max_size, int):
|
|
||||||
raise TypeError(f"Size should be int. Got {type(max_size)}")
|
|
||||||
self.max_size = max_size
|
|
||||||
self.interpolation = interpolation
|
|
||||||
self.fn = min if fn == 'min' else min
|
|
||||||
self.fill = fill
|
|
||||||
|
|
||||||
def forward(self, img):
|
|
||||||
if isinstance(img, torch.Tensor):
|
|
||||||
height, width = img.shape[1:]
|
|
||||||
else:
|
|
||||||
width, height = img.size
|
|
||||||
scale = self.max_size / float(max(height, width))
|
|
||||||
if scale != 1.0:
|
|
||||||
new_size = tuple(round(dim * scale) for dim in (height, width))
|
|
||||||
img = F.resize(img, new_size, self.interpolation)
|
|
||||||
pad_h = self.max_size - new_size[0]
|
|
||||||
pad_w = self.max_size - new_size[1]
|
|
||||||
img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_to_rgb_or_rgba(image):
|
|
||||||
if image.mode == 'RGBA':
|
|
||||||
return image
|
|
||||||
else:
|
|
||||||
return image.convert('RGB')
|
|
||||||
|
|
||||||
# def transform_and_split(merged, transform_fn, normalize_fn):
|
|
||||||
# transformed = transform_fn(merged)
|
|
||||||
# crop_img, crop_label = torch.split(transformed, [3,1], dim=0)
|
|
||||||
|
|
||||||
# # crop_img = _convert_to_rgb(crop_img)
|
|
||||||
# crop_img = normalize_fn(ToTensor()(crop_img))
|
|
||||||
# return crop_img, crop_label
|
|
||||||
|
|
||||||
class MaskAwareNormalize(nn.Module):
|
|
||||||
def __init__(self, mean, std):
|
|
||||||
super().__init__()
|
|
||||||
self.normalize = Normalize(mean=mean, std=std)
|
|
||||||
|
|
||||||
def forward(self, tensor):
|
|
||||||
if tensor.shape[0] == 4:
|
|
||||||
return torch.cat([self.normalize(tensor[:3]), tensor[3:]], dim=0)
|
|
||||||
else:
|
|
||||||
return self.normalize(tensor)
|
|
||||||
|
|
||||||
def image_transform(
|
|
||||||
image_size: int,
|
|
||||||
is_train: bool,
|
|
||||||
mean: Optional[Tuple[float, ...]] = None,
|
|
||||||
std: Optional[Tuple[float, ...]] = None,
|
|
||||||
resize_longest_max: bool = False,
|
|
||||||
fill_color: int = 0,
|
|
||||||
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
|
|
||||||
):
|
|
||||||
mean = mean or OPENAI_DATASET_MEAN
|
|
||||||
if not isinstance(mean, (list, tuple)):
|
|
||||||
mean = (mean,) * 3
|
|
||||||
|
|
||||||
std = std or OPENAI_DATASET_STD
|
|
||||||
if not isinstance(std, (list, tuple)):
|
|
||||||
std = (std,) * 3
|
|
||||||
|
|
||||||
if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
|
|
||||||
# for square size, pass size as int so that Resize() uses aspect preserving shortest edge
|
|
||||||
image_size = image_size[0]
|
|
||||||
|
|
||||||
if isinstance(aug_cfg, dict):
|
|
||||||
aug_cfg = AugmentationCfg(**aug_cfg)
|
|
||||||
else:
|
|
||||||
aug_cfg = aug_cfg or AugmentationCfg()
|
|
||||||
normalize = MaskAwareNormalize(mean=mean, std=std)
|
|
||||||
if is_train:
|
|
||||||
aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
|
|
||||||
use_timm = aug_cfg_dict.pop('use_timm', False)
|
|
||||||
if use_timm:
|
|
||||||
assert False, "not tested for augmentation with mask"
|
|
||||||
from timm.data import create_transform # timm can still be optional
|
|
||||||
if isinstance(image_size, (tuple, list)):
|
|
||||||
assert len(image_size) >= 2
|
|
||||||
input_size = (3,) + image_size[-2:]
|
|
||||||
else:
|
|
||||||
input_size = (3, image_size, image_size)
|
|
||||||
# by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time
|
|
||||||
aug_cfg_dict.setdefault('interpolation', 'random')
|
|
||||||
aug_cfg_dict.setdefault('color_jitter', None) # disable by default
|
|
||||||
train_transform = create_transform(
|
|
||||||
input_size=input_size,
|
|
||||||
is_training=True,
|
|
||||||
hflip=0.,
|
|
||||||
mean=mean,
|
|
||||||
std=std,
|
|
||||||
re_mode='pixel',
|
|
||||||
**aug_cfg_dict,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
train_transform = Compose([
|
|
||||||
_convert_to_rgb_or_rgba,
|
|
||||||
ToTensor(),
|
|
||||||
RandomResizedCrop(
|
|
||||||
image_size,
|
|
||||||
scale=aug_cfg_dict.pop('scale'),
|
|
||||||
interpolation=InterpolationMode.BICUBIC,
|
|
||||||
),
|
|
||||||
normalize,
|
|
||||||
])
|
|
||||||
if aug_cfg_dict:
|
|
||||||
warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')
|
|
||||||
return train_transform
|
|
||||||
else:
|
|
||||||
transforms = [
|
|
||||||
_convert_to_rgb_or_rgba,
|
|
||||||
ToTensor(),
|
|
||||||
]
|
|
||||||
if resize_longest_max:
|
|
||||||
transforms.extend([
|
|
||||||
ResizeMaxSize(image_size, fill=fill_color)
|
|
||||||
])
|
|
||||||
else:
|
|
||||||
transforms.extend([
|
|
||||||
Resize(image_size, interpolation=InterpolationMode.BICUBIC),
|
|
||||||
CenterCrop(image_size),
|
|
||||||
])
|
|
||||||
transforms.extend([
|
|
||||||
normalize,
|
|
||||||
])
|
|
||||||
return Compose(transforms)
|
|
||||||
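# Usage sketch (argument values are illustrative assumptions): build an eval-time
# transform that pads the longest side up to a square instead of center-cropping,
# then apply it to a PIL image.
#
#   preprocess = image_transform(
#       image_size=224,
#       is_train=False,
#       resize_longest_max=True,
#       fill_color=0,
#   )
#   tensor = preprocess(pil_image)   # pil_image: PIL.Image -> normalized CxHxW tensor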
|
|
||||||
|
|
||||||
# def image_transform_region(
|
|
||||||
# image_size: int,
|
|
||||||
# is_train: bool,
|
|
||||||
# mean: Optional[Tuple[float, ...]] = None,
|
|
||||||
# std: Optional[Tuple[float, ...]] = None,
|
|
||||||
# resize_longest_max: bool = False,
|
|
||||||
# fill_color: int = 0,
|
|
||||||
# aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
|
|
||||||
# ):
|
|
||||||
# mean = mean or OPENAI_DATASET_MEAN
|
|
||||||
# if not isinstance(mean, (list, tuple)):
|
|
||||||
# mean = (mean,) * 3
|
|
||||||
|
|
||||||
# std = std or OPENAI_DATASET_STD
|
|
||||||
# if not isinstance(std, (list, tuple)):
|
|
||||||
# std = (std,) * 3
|
|
||||||
|
|
||||||
# if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
|
|
||||||
# # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
|
|
||||||
# image_size = image_size[0]
|
|
||||||
|
|
||||||
# if isinstance(aug_cfg, dict):
|
|
||||||
# aug_cfg = AugmentationCfg(**aug_cfg)
|
|
||||||
# else:
|
|
||||||
# aug_cfg = aug_cfg or AugmentationCfg()
|
|
||||||
# normalize = Normalize(mean=mean, std=std)
|
|
||||||
# if is_train:
|
|
||||||
# aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
|
|
||||||
|
|
||||||
# transform = Compose([
|
|
||||||
# RandomResizedCrop(
|
|
||||||
# image_size,
|
|
||||||
# scale=aug_cfg_dict.pop('scale'),
|
|
||||||
# interpolation=InterpolationMode.BICUBIC,
|
|
||||||
# ),
|
|
||||||
# ])
|
|
||||||
# train_transform = Compose([
|
|
||||||
# partial(transform_and_split, transform_fn=transform,normalize_fn=normalize)
|
|
||||||
# ])
|
|
||||||
# return train_transform
|
|
||||||
# else:
|
|
||||||
# if resize_longest_max:
|
|
||||||
# transform = [
|
|
||||||
# ResizeMaxSize(image_size, fill=fill_color)
|
|
||||||
# ]
|
|
||||||
# val_transform = Compose([
|
|
||||||
# partial(transform_and_split, transform_fn=transform,normalize_fn=normalize),
|
|
||||||
# ])
|
|
||||||
# else:
|
|
||||||
# transform = [
|
|
||||||
# Resize(image_size, interpolation=InterpolationMode.BICUBIC),
|
|
||||||
# CenterCrop(image_size),
|
|
||||||
# ]
|
|
||||||
# val_transform = Compose([
|
|
||||||
# partial(transform_and_split, transform_fn=transform,normalize_fn=normalize),
|
|
||||||
# ])
|
|
||||||
# return val_transform
|
|
||||||
@@ -1,727 +0,0 @@
|
|||||||
from collections import OrderedDict
|
|
||||||
import math
|
|
||||||
from typing import Callable, Optional, Sequence, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from torch.nn import functional as F
|
|
||||||
from torch.utils.checkpoint import checkpoint
|
|
||||||
|
|
||||||
from .utils import to_2tuple
|
|
||||||
|
|
||||||
|
|
||||||
class LayerNormFp32(nn.LayerNorm):
|
|
||||||
"""Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
|
||||||
orig_type = x.dtype
|
|
||||||
x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)
|
|
||||||
return x.to(orig_type)
|
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm(nn.LayerNorm):
|
|
||||||
"""Subclass torch's LayerNorm (with cast back to input dtype)."""
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
|
||||||
orig_type = x.dtype
|
|
||||||
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
|
||||||
return x.to(orig_type)
|
|
||||||
|
|
||||||
|
|
||||||
class QuickGELU(nn.Module):
|
|
||||||
# NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
|
|
||||||
def forward(self, x: torch.Tensor):
|
|
||||||
return x * torch.sigmoid(1.702 * x)
|
|
||||||
|
|
||||||
|
|
||||||
class LayerScale(nn.Module):
|
|
||||||
def __init__(self, dim, init_values=1e-5, inplace=False):
|
|
||||||
super().__init__()
|
|
||||||
self.inplace = inplace
|
|
||||||
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
|
||||||
|
|
||||||
|
|
||||||
class PatchDropout(nn.Module):
|
|
||||||
"""
|
|
||||||
https://arxiv.org/abs/2212.00794
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, prob, exclude_first_token=True):
|
|
||||||
super().__init__()
|
|
||||||
assert 0 <= prob < 1.
|
|
||||||
self.prob = prob
|
|
||||||
self.exclude_first_token = exclude_first_token # exclude CLS token
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if not self.training or self.prob == 0.:
|
|
||||||
return x
|
|
||||||
|
|
||||||
if self.exclude_first_token:
|
|
||||||
cls_tokens, x = x[:, :1], x[:, 1:]
|
|
||||||
else:
|
|
||||||
cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
|
|
||||||
|
|
||||||
batch = x.size()[0]
|
|
||||||
num_tokens = x.size()[1]
|
|
||||||
|
|
||||||
batch_indices = torch.arange(batch)
|
|
||||||
batch_indices = batch_indices[..., None]
|
|
||||||
|
|
||||||
keep_prob = 1 - self.prob
|
|
||||||
num_patches_keep = max(1, int(num_tokens * keep_prob))
|
|
||||||
|
|
||||||
rand = torch.randn(batch, num_tokens)
|
|
||||||
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
|
|
||||||
|
|
||||||
x = x[batch_indices, patch_indices_keep]
|
|
||||||
|
|
||||||
if self.exclude_first_token:
|
|
||||||
x = torch.cat((cls_tokens, x), dim=1)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim,
|
|
||||||
num_heads=8,
|
|
||||||
qkv_bias=True,
|
|
||||||
scaled_cosine=False,
|
|
||||||
scale_heads=False,
|
|
||||||
logit_scale_max=math.log(1. / 0.01),
|
|
||||||
attn_drop=0.,
|
|
||||||
proj_drop=0.
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.scaled_cosine = scaled_cosine
|
|
||||||
self.scale_heads = scale_heads
|
|
||||||
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.head_dim = dim // num_heads
|
|
||||||
self.scale = self.head_dim ** -0.5
|
|
||||||
self.logit_scale_max = logit_scale_max
|
|
||||||
|
|
||||||
# keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
|
|
||||||
self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
|
|
||||||
if qkv_bias:
|
|
||||||
self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
|
|
||||||
else:
|
|
||||||
self.in_proj_bias = None
|
|
||||||
|
|
||||||
if self.scaled_cosine:
|
|
||||||
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
|
|
||||||
else:
|
|
||||||
self.logit_scale = None
|
|
||||||
self.attn_drop = nn.Dropout(attn_drop)
|
|
||||||
if self.scale_heads:
|
|
||||||
self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
|
|
||||||
else:
|
|
||||||
self.head_scale = None
|
|
||||||
self.out_proj = nn.Linear(dim, dim)
|
|
||||||
self.out_drop = nn.Dropout(proj_drop)
|
|
||||||
|
|
||||||
def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
|
|
||||||
L, N, C = x.shape
|
|
||||||
q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
|
|
||||||
q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
|
||||||
k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
|
||||||
v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
|
||||||
|
|
||||||
if self.logit_scale is not None:
|
|
||||||
attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
|
|
||||||
logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
|
|
||||||
attn = attn.view(N, self.num_heads, L, L) * logit_scale
|
|
||||||
attn = attn.view(-1, L, L)
|
|
||||||
else:
|
|
||||||
q = q * self.scale
|
|
||||||
attn = torch.bmm(q, k.transpose(-1, -2))
|
|
||||||
|
|
||||||
if attn_mask is not None:
|
|
||||||
if attn_mask.dtype == torch.bool:
|
|
||||||
new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
|
|
||||||
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
|
|
||||||
attn_mask = new_attn_mask
|
|
||||||
attn += attn_mask
|
|
||||||
|
|
||||||
attn = attn.softmax(dim=-1)
|
|
||||||
attn = self.attn_drop(attn)
|
|
||||||
|
|
||||||
x = torch.bmm(attn, v)
|
|
||||||
if self.head_scale is not None:
|
|
||||||
x = x.view(N, self.num_heads, L, C) * self.head_scale
|
|
||||||
x = x.view(-1, L, C)
|
|
||||||
x = x.transpose(0, 1).reshape(L, N, C)
|
|
||||||
x = self.out_proj(x)
|
|
||||||
x = self.out_drop(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class AttentionalPooler(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
d_model: int,
|
|
||||||
context_dim: int,
|
|
||||||
n_head: int = 8,
|
|
||||||
n_queries: int = 256,
|
|
||||||
norm_layer: Callable = LayerNorm
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.query = nn.Parameter(torch.randn(n_queries, d_model))
|
|
||||||
self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim)
|
|
||||||
self.ln_q = norm_layer(d_model)
|
|
||||||
self.ln_k = norm_layer(context_dim)
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
|
||||||
x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND
|
|
||||||
N = x.shape[1]
|
|
||||||
q = self.ln_q(self.query)
|
|
||||||
out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0]
|
|
||||||
return out.permute(1, 0, 2) # LND -> NLD
|
|
||||||
|
|
||||||
def _repeat(self, query, N: int):
|
|
||||||
return query.unsqueeze(1).repeat(1, N, 1)
|
|
||||||
|
|
||||||
|
|
||||||
class ResidualAttentionBlock(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
d_model: int,
|
|
||||||
n_head: int,
|
|
||||||
mlp_ratio: float = 4.0,
|
|
||||||
ls_init_value: float = None,
|
|
||||||
act_layer: Callable = nn.GELU,
|
|
||||||
norm_layer: Callable = LayerNorm,
|
|
||||||
is_cross_attention: bool = False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.ln_1 = norm_layer(d_model)
|
|
||||||
self.attn = nn.MultiheadAttention(d_model, n_head)
|
|
||||||
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
|
||||||
if is_cross_attention:
|
|
||||||
self.ln_1_kv = norm_layer(d_model)
|
|
||||||
|
|
||||||
self.ln_2 = norm_layer(d_model)
|
|
||||||
mlp_width = int(d_model * mlp_ratio)
|
|
||||||
self.mlp = nn.Sequential(OrderedDict([
|
|
||||||
("c_fc", nn.Linear(d_model, mlp_width)),
|
|
||||||
("gelu", act_layer()),
|
|
||||||
("c_proj", nn.Linear(mlp_width, d_model))
|
|
||||||
]))
|
|
||||||
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
|
||||||
|
|
||||||
def attention(
|
|
||||||
self,
|
|
||||||
q_x: torch.Tensor,
|
|
||||||
k_x: Optional[torch.Tensor] = None,
|
|
||||||
v_x: Optional[torch.Tensor] = None,
|
|
||||||
attn_mask: Optional[torch.Tensor] = None,
|
|
||||||
):
|
|
||||||
k_x = k_x if k_x is not None else q_x
|
|
||||||
v_x = v_x if v_x is not None else q_x
|
|
||||||
|
|
||||||
attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
|
|
||||||
return self.attn(
|
|
||||||
q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
q_x: torch.Tensor,
|
|
||||||
k_x: Optional[torch.Tensor] = None,
|
|
||||||
v_x: Optional[torch.Tensor] = None,
|
|
||||||
attn_mask: Optional[torch.Tensor] = None,
|
|
||||||
):
|
|
||||||
k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
|
|
||||||
v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
|
|
||||||
|
|
||||||
x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))
|
|
||||||
x = x + self.ls_2(self.mlp(self.ln_2(x)))
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class CustomResidualAttentionBlock(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
d_model: int,
|
|
||||||
n_head: int,
|
|
||||||
mlp_ratio: float = 4.0,
|
|
||||||
ls_init_value: float = None,
|
|
||||||
act_layer: Callable = nn.GELU,
|
|
||||||
norm_layer: Callable = LayerNorm,
|
|
||||||
scale_cosine_attn: bool = False,
|
|
||||||
scale_heads: bool = False,
|
|
||||||
scale_attn: bool = False,
|
|
||||||
scale_fc: bool = False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.ln_1 = norm_layer(d_model)
|
|
||||||
self.attn = Attention(
|
|
||||||
d_model, n_head,
|
|
||||||
scaled_cosine=scale_cosine_attn,
|
|
||||||
scale_heads=scale_heads,
|
|
||||||
)
|
|
||||||
self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
|
|
||||||
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
|
||||||
|
|
||||||
self.ln_2 = norm_layer(d_model)
|
|
||||||
mlp_width = int(d_model * mlp_ratio)
|
|
||||||
self.mlp = nn.Sequential(OrderedDict([
|
|
||||||
("c_fc", nn.Linear(d_model, mlp_width)),
|
|
||||||
('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),
|
|
||||||
("gelu", act_layer()),
|
|
||||||
("c_proj", nn.Linear(mlp_width, d_model))
|
|
||||||
]))
|
|
||||||
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
|
||||||
x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask)))
|
|
||||||
x = x + self.ls_2(self.mlp(self.ln_2(x)))
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Transformer(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
width: int,
|
|
||||||
layers: int,
|
|
||||||
heads: int,
|
|
||||||
mlp_ratio: float = 4.0,
|
|
||||||
ls_init_value: float = None,
|
|
||||||
act_layer: Callable = nn.GELU,
|
|
||||||
norm_layer: Callable = LayerNorm,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.width = width
|
|
||||||
self.layers = layers
|
|
||||||
self.grad_checkpointing = False
|
|
||||||
|
|
||||||
self.resblocks = nn.ModuleList([
|
|
||||||
ResidualAttentionBlock(
|
|
||||||
width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer)
|
|
||||||
for _ in range(layers)
|
|
||||||
])
|
|
||||||
|
|
||||||
def get_cast_dtype(self) -> torch.dtype:
|
|
||||||
return self.resblocks[0].mlp.c_fc.weight.dtype
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
|
||||||
for r in self.resblocks:
|
|
||||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
|
||||||
# TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
|
|
||||||
x = checkpoint(r, x, None, None, attn_mask)
|
|
||||||
else:
|
|
||||||
x = r(x, attn_mask=attn_mask)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class VisionTransformer(nn.Module):
|
|
||||||
output_tokens: torch.jit.Final[bool]
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
image_size: int,
|
|
||||||
patch_size: int,
|
|
||||||
width: int,
|
|
||||||
layers: int,
|
|
||||||
heads: int,
|
|
||||||
mlp_ratio: float,
|
|
||||||
ls_init_value: float = None,
|
|
||||||
global_average_pool: bool = False,
|
|
||||||
attentional_pool: bool = False,
|
|
||||||
n_queries: int = 256,
|
|
||||||
attn_pooler_heads: int = 8,
|
|
||||||
output_dim: int = 512,
|
|
||||||
patch_dropout: float = 0.,
|
|
||||||
input_patchnorm: bool = False,
|
|
||||||
act_layer: Callable = nn.GELU,
|
|
||||||
norm_layer: Callable = LayerNorm,
|
|
||||||
output_tokens: bool = False
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.output_tokens = output_tokens
|
|
||||||
image_height, image_width = self.image_size = to_2tuple(image_size)
|
|
||||||
patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
|
|
||||||
self.grid_size = (image_height // patch_height, image_width // patch_width)
|
|
||||||
self.output_dim = output_dim
|
|
||||||
|
|
||||||
# whether to layernorm each patch, as done in dual patchnorm paper - https://arxiv.org/abs/2302.01327v1
|
|
||||||
self.input_patchnorm = input_patchnorm
|
|
||||||
|
|
||||||
if input_patchnorm:
|
|
||||||
patch_input_dim = patch_height * patch_width * 3
|
|
||||||
self.patchnorm_pre_ln = LayerNorm(patch_input_dim)
|
|
||||||
self.conv1 = nn.Linear(patch_input_dim, width)
|
|
||||||
else:
|
|
||||||
self.patchnorm_pre_ln = nn.Identity()
|
|
||||||
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
|
||||||
|
|
||||||
# class embeddings and positional embeddings
|
|
||||||
scale = width ** -0.5
|
|
||||||
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
|
||||||
self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
|
|
||||||
|
|
||||||
# setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
|
|
||||||
self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
|
|
||||||
|
|
||||||
self.ln_pre = norm_layer(width)
|
|
||||||
self.transformer = Transformer(
|
|
||||||
width,
|
|
||||||
layers,
|
|
||||||
heads,
|
|
||||||
mlp_ratio,
|
|
||||||
ls_init_value=ls_init_value,
|
|
||||||
act_layer=act_layer,
|
|
||||||
norm_layer=norm_layer,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.global_average_pool = global_average_pool
|
|
||||||
if attentional_pool:
|
|
||||||
self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries)
|
|
||||||
self.ln_post = norm_layer(output_dim)
|
|
||||||
self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim))
|
|
||||||
else:
|
|
||||||
self.attn_pool = None
|
|
||||||
self.ln_post = norm_layer(width)
|
|
||||||
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
|
||||||
|
|
||||||
self.init_parameters()
|
|
||||||
|
|
||||||
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
|
||||||
for param in self.parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
|
|
||||||
if unlocked_groups != 0:
|
|
||||||
groups = [
|
|
||||||
[
|
|
||||||
self.conv1,
|
|
||||||
self.class_embedding,
|
|
||||||
self.positional_embedding,
|
|
||||||
self.ln_pre,
|
|
||||||
],
|
|
||||||
*self.transformer.resblocks[:-1],
|
|
||||||
[
|
|
||||||
self.transformer.resblocks[-1],
|
|
||||||
self.ln_post,
|
|
||||||
],
|
|
||||||
self.proj,
|
|
||||||
]
|
|
||||||
|
|
||||||
def _unlock(x):
|
|
||||||
if isinstance(x, Sequence):
|
|
||||||
for g in x:
|
|
||||||
_unlock(g)
|
|
||||||
else:
|
|
||||||
if isinstance(x, torch.nn.Parameter):
|
|
||||||
x.requires_grad = True
|
|
||||||
else:
|
|
||||||
for p in x.parameters():
|
|
||||||
p.requires_grad = True
|
|
||||||
|
|
||||||
_unlock(groups[-unlocked_groups:])
|
|
||||||
|
|
||||||
def init_parameters(self):
|
|
||||||
# FIXME OpenAI CLIP did not define an init for the VisualTransformer
|
|
||||||
# TODO experiment if default PyTorch init, below, or alternate init is best.
|
|
||||||
|
|
||||||
# nn.init.normal_(self.class_embedding, std=self.scale)
|
|
||||||
# nn.init.normal_(self.positional_embedding, std=self.scale)
|
|
||||||
#
|
|
||||||
# proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
|
||||||
# attn_std = self.transformer.width ** -0.5
|
|
||||||
# fc_std = (2 * self.transformer.width) ** -0.5
|
|
||||||
# for block in self.transformer.resblocks:
|
|
||||||
# nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
|
||||||
# nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
|
||||||
# nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
|
||||||
# nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
|
||||||
#
|
|
||||||
# if self.text_projection is not None:
|
|
||||||
# nn.init.normal_(self.text_projection, std=self.scale)
|
|
||||||
pass
|
|
||||||
|
|
||||||
@torch.jit.ignore
|
|
||||||
def set_grad_checkpointing(self, enable=True):
|
|
||||||
self.transformer.grad_checkpointing = enable
|
|
||||||
|
|
||||||
def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
if self.global_average_pool:
|
|
||||||
return x.mean(dim=1), x
|
|
||||||
else:
|
|
||||||
return x[:, 0], x[:, 1:]
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, skip_pool: bool = False):
|
|
||||||
|
|
||||||
# to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
|
|
||||||
if self.input_patchnorm:
|
|
||||||
# einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
|
|
||||||
x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1], self.patch_size[1])
|
|
||||||
x = x.permute(0, 2, 4, 1, 3, 5)
|
|
||||||
x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1)
|
|
||||||
x = self.patchnorm_pre_ln(x)
|
|
||||||
x = self.conv1(x)
|
|
||||||
else:
|
|
||||||
x = self.conv1(x) # shape = [*, width, grid, grid]
|
|
||||||
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
|
||||||
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
|
||||||
|
|
||||||
# class embeddings and positional embeddings
|
|
||||||
x = torch.cat(
|
|
||||||
[self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
|
|
||||||
x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
|
||||||
x = x + self.positional_embedding.to(x.dtype)
|
|
||||||
|
|
||||||
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
|
|
||||||
x = self.patch_dropout(x)
|
|
||||||
x = self.ln_pre(x)
|
|
||||||
|
|
||||||
x = x.permute(1, 0, 2) # NLD -> LND
|
|
||||||
x = self.transformer(x)
|
|
||||||
x = x.permute(1, 0, 2) # LND -> NLD
|
|
||||||
|
|
||||||
if skip_pool:
|
|
||||||
return x
|
|
||||||
|
|
||||||
if self.attn_pool is not None:
|
|
||||||
x = self.attn_pool(x)
|
|
||||||
x = self.ln_post(x)
|
|
||||||
pooled, tokens = self._global_pool(x)
|
|
||||||
else:
|
|
||||||
pooled, tokens = self._global_pool(x)
|
|
||||||
pooled = self.ln_post(pooled)
|
|
||||||
|
|
||||||
if self.proj is not None:
|
|
||||||
pooled = pooled @ self.proj
|
|
||||||
|
|
||||||
if self.output_tokens:
|
|
||||||
return pooled, tokens
|
|
||||||
|
|
||||||
return pooled
|
|
||||||
|
|
||||||
|
|
||||||
class TextTransformer(nn.Module):
|
|
||||||
output_tokens: torch.jit.Final[bool]
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
context_length: int = 77,
|
|
||||||
vocab_size: int = 49408,
|
|
||||||
width: int = 512,
|
|
||||||
heads: int = 8,
|
|
||||||
layers: int = 12,
|
|
||||||
ls_init_value: float = None,
|
|
||||||
output_dim: int = 512,
|
|
||||||
act_layer: Callable = nn.GELU,
|
|
||||||
norm_layer: Callable = LayerNorm,
|
|
||||||
embed_cls: bool = False,
|
|
||||||
pad_id: int = 0,
|
|
||||||
output_tokens: bool = False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.output_tokens = output_tokens
|
|
||||||
self.num_pos = self.context_length = context_length
|
|
||||||
self.vocab_size = vocab_size
|
|
||||||
self.width = width
|
|
||||||
self.output_dim = output_dim
|
|
||||||
self.heads = heads
|
|
||||||
self.pad_id = pad_id
|
|
||||||
|
|
||||||
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
|
|
||||||
|
|
||||||
if embed_cls:
|
|
||||||
self.cls_emb = nn.Parameter(torch.empty(width))
|
|
||||||
self.num_pos += 1
|
|
||||||
else:
|
|
||||||
self.cls_emb = None
|
|
||||||
|
|
||||||
self.token_embedding = nn.Embedding(vocab_size, width)
|
|
||||||
self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
|
|
||||||
self.transformer = Transformer(
|
|
||||||
width=width,
|
|
||||||
layers=layers,
|
|
||||||
heads=heads,
|
|
||||||
ls_init_value=ls_init_value,
|
|
||||||
act_layer=act_layer,
|
|
||||||
norm_layer=norm_layer,
|
|
||||||
)
|
|
||||||
self.ln_final = norm_layer(width)
|
|
||||||
|
|
||||||
self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
|
|
||||||
|
|
||||||
self.init_parameters()
|
|
||||||
|
|
||||||
def init_parameters(self):
|
|
||||||
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
|
||||||
nn.init.normal_(self.positional_embedding, std=0.01)
|
|
||||||
if self.cls_emb is not None:
|
|
||||||
nn.init.normal_(self.cls_emb, std=0.01)
|
|
||||||
|
|
||||||
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
|
||||||
attn_std = self.transformer.width ** -0.5
|
|
||||||
fc_std = (2 * self.transformer.width) ** -0.5
|
|
||||||
for block in self.transformer.resblocks:
|
|
||||||
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
|
||||||
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
|
||||||
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
|
||||||
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
|
||||||
|
|
||||||
if self.text_projection is not None:
|
|
||||||
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
|
||||||
|
|
||||||
@torch.jit.ignore
|
|
||||||
def set_grad_checkpointing(self, enable=True):
|
|
||||||
self.transformer.grad_checkpointing = enable
|
|
||||||
|
|
||||||
def build_attention_mask(self):
|
|
||||||
# lazily create causal attention mask, with full attention between the tokens
|
|
||||||
# pytorch uses additive attention mask; fill with -inf
|
|
||||||
mask = torch.empty(self.num_pos, self.num_pos)
|
|
||||||
mask.fill_(float("-inf"))
|
|
||||||
mask.triu_(1) # zero out the lower diagonal
|
|
||||||
return mask
|
|
||||||
|
|
||||||
def build_cls_mask(self, text, cast_dtype: torch.dtype):
|
|
||||||
cls_mask = (text != self.pad_id).unsqueeze(1)
|
|
||||||
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)
|
|
||||||
additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
|
|
||||||
additive_mask.fill_(0)
|
|
||||||
additive_mask.masked_fill_(~cls_mask, float("-inf"))
|
|
||||||
additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
|
|
||||||
return additive_mask
|
|
||||||
|
|
||||||
def _repeat(self, t, N: int):
|
|
||||||
return t.reshape(1, 1, -1).repeat(N, 1, 1)
|
|
||||||
|
|
||||||
def forward(self, text):
|
|
||||||
cast_dtype = self.transformer.get_cast_dtype()
|
|
||||||
seq_len = text.shape[1]
|
|
||||||
|
|
||||||
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
|
|
||||||
attn_mask = self.attn_mask
|
|
||||||
if self.cls_emb is not None:
|
|
||||||
seq_len += 1
|
|
||||||
x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1)
|
|
||||||
cls_mask = self.build_cls_mask(text, cast_dtype)
|
|
||||||
attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len]
|
|
||||||
|
|
||||||
x = x + self.positional_embedding[:seq_len].to(cast_dtype)
|
|
||||||
x = x.permute(1, 0, 2) # NLD -> LND
|
|
||||||
x = self.transformer(x, attn_mask=attn_mask)
|
|
||||||
x = x.permute(1, 0, 2) # LND -> NLD
|
|
||||||
|
|
||||||
# x.shape = [batch_size, n_ctx, transformer.width]
|
|
||||||
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
|
||||||
if self.cls_emb is not None:
|
|
||||||
pooled, tokens = x[:, -1], x[:, :-1]
|
|
||||||
pooled = self.ln_final(pooled)
|
|
||||||
else:
|
|
||||||
x = self.ln_final(x)
|
|
||||||
pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
|
|
||||||
|
|
||||||
if self.text_projection is not None:
|
|
||||||
pooled = pooled @ self.text_projection
|
|
||||||
|
|
||||||
if self.output_tokens:
|
|
||||||
return pooled, tokens
|
|
||||||
|
|
||||||
return pooled
|
|
||||||
|
|
||||||
|
|
||||||
class MultimodalTransformer(Transformer):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
width: int,
|
|
||||||
layers: int,
|
|
||||||
heads: int,
|
|
||||||
context_length: int = 77,
|
|
||||||
mlp_ratio: float = 4.0,
|
|
||||||
ls_init_value: float = None,
|
|
||||||
act_layer: Callable = nn.GELU,
|
|
||||||
norm_layer: Callable = LayerNorm,
|
|
||||||
output_dim: int = 512,
|
|
||||||
):
|
|
||||||
|
|
||||||
super().__init__(
|
|
||||||
width=width,
|
|
||||||
layers=layers,
|
|
||||||
heads=heads,
|
|
||||||
mlp_ratio=mlp_ratio,
|
|
||||||
ls_init_value=ls_init_value,
|
|
||||||
act_layer=act_layer,
|
|
||||||
norm_layer=norm_layer,
|
|
||||||
)
|
|
||||||
self.context_length = context_length
|
|
||||||
self.cross_attn = nn.ModuleList([
|
|
||||||
ResidualAttentionBlock(
|
|
||||||
width,
|
|
||||||
heads,
|
|
||||||
mlp_ratio,
|
|
||||||
ls_init_value=ls_init_value,
|
|
||||||
act_layer=act_layer,
|
|
||||||
norm_layer=norm_layer,
|
|
||||||
is_cross_attention=True,
|
|
||||||
)
|
|
||||||
for _ in range(layers)
|
|
||||||
])
|
|
||||||
|
|
||||||
self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
|
|
||||||
|
|
||||||
self.ln_final = norm_layer(width)
|
|
||||||
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
|
|
||||||
|
|
||||||
def init_parameters(self):
|
|
||||||
        proj_std = (self.width ** -0.5) * ((2 * self.layers) ** -0.5)
|
|
||||||
        attn_std = self.width ** -0.5
|
|
||||||
        fc_std = (2 * self.width) ** -0.5
|
|
||||||
        for block in self.resblocks:
|
|
||||||
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
|
||||||
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
|
||||||
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
|
||||||
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
|
||||||
        for block in self.cross_attn:
|
|
||||||
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
|
||||||
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
|
||||||
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
|
||||||
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
|
||||||
|
|
||||||
if self.text_projection is not None:
|
|
||||||
            nn.init.normal_(self.text_projection, std=self.width ** -0.5)
|
|
||||||
|
|
||||||
def build_attention_mask(self):
|
|
||||||
# lazily create causal attention mask, with full attention between the tokens
|
|
||||||
# pytorch uses additive attention mask; fill with -inf
|
|
||||||
mask = torch.empty(self.context_length, self.context_length)
|
|
||||||
mask.fill_(float("-inf"))
|
|
||||||
mask.triu_(1) # zero out the lower diagonal
|
|
||||||
return mask
|
|
||||||
|
|
||||||
def forward(self, image_embs, text_embs):
|
|
||||||
        text_embs = text_embs.permute(1, 0, 2)  # NLD -> LND
|
|
||||||
image_embs = image_embs.permute(1, 0, 2) # NLD -> LND
|
|
||||||
seq_len = text_embs.shape[0]
|
|
||||||
|
|
||||||
for resblock, cross_attn in zip(self.resblocks, self.cross_attn):
|
|
||||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
|
||||||
# TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
|
|
||||||
text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len])
|
|
||||||
text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None)
|
|
||||||
else:
|
|
||||||
text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len])
|
|
||||||
text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)
|
|
||||||
|
|
||||||
x = text_embs.permute(1, 0, 2) # LND -> NLD
|
|
||||||
x = self.ln_final(x)
|
|
||||||
|
|
||||||
if self.text_projection is not None:
|
|
||||||
x = x @ self.text_projection
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
@torch.jit.ignore
|
|
||||||
def set_grad_checkpointing(self, enable=True):
|
|
||||||
self.grad_checkpointing = enable
|
|
||||||
@@ -1,60 +0,0 @@
|
|||||||
from itertools import repeat
|
|
||||||
import collections.abc
|
|
||||||
|
|
||||||
from torch import nn as nn
|
|
||||||
from torchvision.ops.misc import FrozenBatchNorm2d
|
|
||||||
|
|
||||||
|
|
||||||
def freeze_batch_norm_2d(module, module_match={}, name=''):
|
|
||||||
"""
|
|
||||||
Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
|
|
||||||
itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
|
|
||||||
returned. Otherwise, the module is walked recursively and submodules are converted in place.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
module (torch.nn.Module): Any PyTorch module.
|
|
||||||
module_match (dict): Dictionary of full module names to freeze (all if empty)
|
|
||||||
name (str): Full module name (prefix)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.nn.Module: Resulting module
|
|
||||||
|
|
||||||
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
|
|
||||||
"""
|
|
||||||
res = module
|
|
||||||
is_match = True
|
|
||||||
if module_match:
|
|
||||||
is_match = name in module_match
|
|
||||||
if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
|
|
||||||
res = FrozenBatchNorm2d(module.num_features)
|
|
||||||
res.num_features = module.num_features
|
|
||||||
res.affine = module.affine
|
|
||||||
if module.affine:
|
|
||||||
res.weight.data = module.weight.data.clone().detach()
|
|
||||||
res.bias.data = module.bias.data.clone().detach()
|
|
||||||
res.running_mean.data = module.running_mean.data
|
|
||||||
res.running_var.data = module.running_var.data
|
|
||||||
res.eps = module.eps
|
|
||||||
else:
|
|
||||||
for child_name, child in module.named_children():
|
|
||||||
full_child_name = '.'.join([name, child_name]) if name else child_name
|
|
||||||
new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
|
|
||||||
if new_child is not child:
|
|
||||||
res.add_module(child_name, new_child)
|
|
||||||
return res
|
|
||||||
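# Usage sketch (torchvision backbone is an assumption, not part of this file):
# freeze every BatchNorm2d in a model so its running statistics and affine
# parameters stay fixed during fine-tuning.
#
#   import torchvision
#   model = torchvision.models.resnet50(weights=None)
#   model = freeze_batch_norm_2d(model)  # every BN layer becomes FrozenBatchNorm2d
#   # or restrict freezing to specific submodules by full name:
#   model = freeze_batch_norm_2d(model, module_match={'layer1.0.bn1': True})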
|
|
||||||
|
|
||||||
# From PyTorch internals
|
|
||||||
def _ntuple(n):
|
|
||||||
def parse(x):
|
|
||||||
if isinstance(x, collections.abc.Iterable):
|
|
||||||
return x
|
|
||||||
return tuple(repeat(x, n))
|
|
||||||
return parse
|
|
||||||
|
|
||||||
|
|
||||||
to_1tuple = _ntuple(1)
|
|
||||||
to_2tuple = _ntuple(2)
|
|
||||||
to_3tuple = _ntuple(3)
|
|
||||||
to_4tuple = _ntuple(4)
|
|
||||||
to_ntuple = lambda n, x: _ntuple(n)(x)
|
|
||||||
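# Example (illustrative): the _ntuple helpers broadcast a scalar into an n-tuple
# but pass iterables through unchanged, which is how image_size / patch_size
# arguments can be given either as an int or as an (h, w) pair.
#
#   to_2tuple(224)         # -> (224, 224)
#   to_2tuple((224, 320))  # -> (224, 320), returned as-is
#   to_ntuple(3, 7)        # -> (7, 7, 7)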
@@ -1 +0,0 @@
|
|||||||
__version__ = '2.16.0'
|
|
||||||
@@ -1,112 +0,0 @@
|
|||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
from transformers import AutoProcessor, AutoModel
|
|
||||||
from typing import List, Union
|
|
||||||
import os
|
|
||||||
from .config import MODEL_PATHS
|
|
||||||
|
|
||||||
class PickScore(torch.nn.Module):
|
|
||||||
    def __init__(self, device: Union[str, torch.device], path: dict = MODEL_PATHS):
|
|
||||||
super().__init__()
|
|
||||||
"""Initialize the Selector with a processor and model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
device (Union[str, torch.device]): The device to load the model on.
|
|
||||||
"""
|
|
||||||
self.device = device if isinstance(device, torch.device) else torch.device(device)
|
|
||||||
processor_name_or_path = path.get("clip")
|
|
||||||
model_pretrained_name_or_path = path.get("pickscore")
|
|
||||||
self.processor = AutoProcessor.from_pretrained(processor_name_or_path)
|
|
||||||
self.model = AutoModel.from_pretrained(model_pretrained_name_or_path).eval().to(self.device)
|
|
||||||
|
|
||||||
def _calculate_score(self, image: torch.Tensor, prompt: str, softmax: bool = False) -> float:
|
|
||||||
"""Calculate the score for a single image and prompt.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image (torch.Tensor): The processed image tensor.
|
|
||||||
prompt (str): The prompt text.
|
|
||||||
softmax (bool): Whether to apply softmax to the scores.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
float: The score for the image.
|
|
||||||
"""
|
|
||||||
with torch.no_grad():
|
|
||||||
# Prepare text inputs
|
|
||||||
text_inputs = self.processor(
|
|
||||||
text=prompt,
|
|
||||||
padding=True,
|
|
||||||
truncation=True,
|
|
||||||
max_length=77,
|
|
||||||
return_tensors="pt",
|
|
||||||
).to(self.device)
|
|
||||||
|
|
||||||
# Embed images and text
|
|
||||||
image_embs = self.model.get_image_features(pixel_values=image)
|
|
||||||
image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
|
|
||||||
text_embs = self.model.get_text_features(**text_inputs)
|
|
||||||
text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
|
|
||||||
|
|
||||||
# Compute score
|
|
||||||
score = (text_embs @ image_embs.T)[0]
|
|
||||||
if softmax:
|
|
||||||
# Apply logit scale and softmax
|
|
||||||
score = torch.softmax(self.model.logit_scale.exp() * score, dim=-1)
|
|
||||||
|
|
||||||
return score.cpu().item()
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str, softmax: bool = False) -> List[float]:
|
|
||||||
"""Score the images based on the prompt.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
|
||||||
prompt (str): The prompt text.
|
|
||||||
softmax (bool): Whether to apply softmax to the scores.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[float]: List of scores for the images.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if isinstance(images, (str, Image.Image)):
|
|
||||||
# Single image
|
|
||||||
if isinstance(images, str):
|
|
||||||
pil_image = Image.open(images)
|
|
||||||
else:
|
|
||||||
pil_image = images
|
|
||||||
|
|
||||||
# Prepare image inputs
|
|
||||||
image_inputs = self.processor(
|
|
||||||
images=pil_image,
|
|
||||||
padding=True,
|
|
||||||
truncation=True,
|
|
||||||
max_length=77,
|
|
||||||
return_tensors="pt",
|
|
||||||
).to(self.device)
|
|
||||||
|
|
||||||
return [self._calculate_score(image_inputs["pixel_values"], prompt, softmax)]
|
|
||||||
elif isinstance(images, list):
|
|
||||||
# Multiple images
|
|
||||||
scores = []
|
|
||||||
for one_image in images:
|
|
||||||
if isinstance(one_image, str):
|
|
||||||
pil_image = Image.open(one_image)
|
|
||||||
elif isinstance(one_image, Image.Image):
|
|
||||||
pil_image = one_image
|
|
||||||
else:
|
|
||||||
raise TypeError("The type of parameter images is illegal.")
|
|
||||||
|
|
||||||
# Prepare image inputs
|
|
||||||
image_inputs = self.processor(
|
|
||||||
images=pil_image,
|
|
||||||
padding=True,
|
|
||||||
truncation=True,
|
|
||||||
max_length=77,
|
|
||||||
return_tensors="pt",
|
|
||||||
).to(self.device)
|
|
||||||
|
|
||||||
scores.append(self._calculate_score(image_inputs["pixel_values"], prompt, softmax))
|
|
||||||
return scores
|
|
||||||
else:
|
|
||||||
raise TypeError("The type of parameter images is illegal.")
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"Error in scoring images: {e}")
|
|
||||||
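# Usage sketch (file paths and device are assumptions, not part of the original
# file): score one or more images against a prompt; a higher score means a
# better prompt/image match according to the PickScore reward model.
#
#   scorer = PickScore(device="cuda")
#   scores = scorer.score(["a.png", "b.png"], prompt="a corgi wearing sunglasses")
#   best_idx = max(range(len(scores)), key=scores.__getitem__)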
@@ -1 +0,0 @@
|
|||||||
from .models import *
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
from .base_model import *
|
|
||||||
from .clip_model import *
|
|
||||||
from .cross_modeling import *
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BaseModelConfig:
|
|
||||||
pass
|
|
||||||
@@ -1,146 +0,0 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
from transformers import CLIPModel as HFCLIPModel
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
|
||||||
from torch import nn, einsum
|
|
||||||
|
|
||||||
from .base_model import BaseModelConfig
|
|
||||||
|
|
||||||
from transformers import CLIPConfig
|
|
||||||
from typing import Any, Optional, Tuple, Union
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from .cross_modeling import Cross_model
|
|
||||||
|
|
||||||
import json, os
|
|
||||||
|
|
||||||
class XCLIPModel(HFCLIPModel):
|
|
||||||
def __init__(self, config: CLIPConfig):
|
|
||||||
super().__init__(config)
|
|
||||||
|
|
||||||
def get_text_features(
|
|
||||||
self,
|
|
||||||
input_ids: Optional[torch.Tensor] = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.Tensor] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
) -> torch.FloatTensor:
|
|
||||||
|
|
||||||
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
|
|
||||||
text_outputs = self.text_model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
)
|
|
||||||
|
|
||||||
# pooled_output = text_outputs[1]
|
|
||||||
# text_features = self.text_projection(pooled_output)
|
|
||||||
last_hidden_state = text_outputs[0]
|
|
||||||
text_features = self.text_projection(last_hidden_state)
|
|
||||||
|
|
||||||
pooled_output = text_outputs[1]
|
|
||||||
text_features_EOS = self.text_projection(pooled_output)
|
|
||||||
|
|
||||||
|
|
||||||
# del last_hidden_state, text_outputs
|
|
||||||
# gc.collect()
|
|
||||||
|
|
||||||
return text_features, text_features_EOS
|
|
||||||
|
|
||||||
def get_image_features(
|
|
||||||
self,
|
|
||||||
pixel_values: Optional[torch.FloatTensor] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
) -> torch.FloatTensor:
|
|
||||||
|
|
||||||
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
|
|
||||||
vision_outputs = self.vision_model(
|
|
||||||
pixel_values=pixel_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
)
|
|
||||||
|
|
||||||
# pooled_output = vision_outputs[1] # pooled_output
|
|
||||||
# image_features = self.visual_projection(pooled_output)
|
|
||||||
last_hidden_state = vision_outputs[0]
|
|
||||||
image_features = self.visual_projection(last_hidden_state)
|
|
||||||
|
|
||||||
return image_features
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ClipModelConfig(BaseModelConfig):
|
|
||||||
_target_: str = "diffsynth.extensions.QualityMetric.trainer.models.clip_model.CLIPModel"
|
|
||||||
pretrained_model_name_or_path: str ="checkpoints/clip-vit-base-patch32"
|
|
||||||
|
|
||||||
|
|
||||||
class CLIPModel(nn.Module):
|
|
||||||
def __init__(self, ckpt, config_file=False):
|
|
||||||
super().__init__()
|
|
||||||
if config_file is None:
|
|
||||||
self.model = XCLIPModel.from_pretrained(ckpt)
|
|
||||||
else:
|
|
||||||
with open(os.path.join(ckpt, "config.json"), "r", encoding="utf-8") as f:
|
|
||||||
config = json.load(f)
|
|
||||||
config = CLIPConfig(**config)
|
|
||||||
self.model = XCLIPModel._from_config(config)
|
|
||||||
self.cross_model = Cross_model(dim=1024, layer_num=4, heads=16)
|
|
||||||
|
|
||||||
def get_text_features(self, *args, **kwargs):
|
|
||||||
return self.model.get_text_features(*args, **kwargs)
|
|
||||||
|
|
||||||
def get_image_features(self, *args, **kwargs):
|
|
||||||
return self.model.get_image_features(*args, **kwargs)
|
|
||||||
|
|
||||||
def forward(self, text_inputs=None, image_inputs=None, condition_inputs=None):
|
|
||||||
outputs = ()
|
|
||||||
|
|
||||||
text_f, text_EOS = self.model.get_text_features(text_inputs) # B*77*1024
|
|
||||||
outputs += text_EOS,
|
|
||||||
|
|
||||||
image_f = self.model.get_image_features(image_inputs.half()) # 2B*257*1024
|
|
||||||
condition_f, _ = self.model.get_text_features(condition_inputs) # B*5*1024
|
|
||||||
|
|
||||||
sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)
|
|
||||||
sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
|
|
||||||
sim_text_condition = sim_text_condition / sim_text_condition.max()
|
|
||||||
mask = torch.where(sim_text_condition > 0.01, 0, float('-inf')) # B*1*77
|
|
||||||
|
|
||||||
mask = mask.repeat(1,image_f.shape[1],1) # B*257*77
|
|
||||||
bc = int(image_f.shape[0]/2)
|
|
||||||
|
|
||||||
sim0 = self.cross_model(image_f[:bc,:,:], text_f,mask.half())
|
|
||||||
sim1 = self.cross_model(image_f[bc:,:,:], text_f,mask.half())
|
|
||||||
outputs += sim0[:,0,:],
|
|
||||||
outputs += sim1[:,0,:],
|
|
||||||
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
@property
|
|
||||||
def logit_scale(self):
|
|
||||||
return self.model.logit_scale
|
|
||||||
|
|
||||||
def save(self, path):
|
|
||||||
self.model.save_pretrained(path)
|
|
||||||
|
|
||||||
@@ -1,292 +0,0 @@
|
|||||||
import torch
|
|
||||||
from torch import einsum, nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
|
|
||||||
# helper functions
|
|
||||||
|
|
||||||
def exists(val):
|
|
||||||
return val is not None
|
|
||||||
|
|
||||||
def default(val, d):
|
|
||||||
return val if exists(val) else d
|
|
||||||
|
|
||||||
# normalization
|
|
||||||
# they use layernorm without bias, something that pytorch does not offer
|
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm(nn.Module):
|
|
||||||
def __init__(self, dim):
|
|
||||||
super().__init__()
|
|
||||||
self.weight = nn.Parameter(torch.ones(dim))
|
|
||||||
self.register_buffer("bias", torch.zeros(dim))
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return F.layer_norm(x, x.shape[-1:], self.weight, self.bias)
|
|
||||||
|
|
||||||
# residual
|
|
||||||
|
|
||||||
|
|
||||||
class Residual(nn.Module):
|
|
||||||
def __init__(self, fn):
|
|
||||||
super().__init__()
|
|
||||||
self.fn = fn
|
|
||||||
|
|
||||||
def forward(self, x, *args, **kwargs):
|
|
||||||
return self.fn(x, *args, **kwargs) + x
|
|
||||||
|
|
||||||
|
|
||||||
# rotary positional embedding
|
|
||||||
# https://arxiv.org/abs/2104.09864
|
|
||||||
|
|
||||||
|
|
||||||
class RotaryEmbedding(nn.Module):
|
|
||||||
def __init__(self, dim):
|
|
||||||
super().__init__()
|
|
||||||
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
|
|
||||||
self.register_buffer("inv_freq", inv_freq)
|
|
||||||
|
|
||||||
def forward(self, max_seq_len, *, device):
|
|
||||||
seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
|
|
||||||
freqs = einsum("i , j -> i j", seq, self.inv_freq)
|
|
||||||
return torch.cat((freqs, freqs), dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x):
|
|
||||||
x = rearrange(x, "... (j d) -> ... j d", j=2)
|
|
||||||
x1, x2 = x.unbind(dim=-2)
|
|
||||||
return torch.cat((-x2, x1), dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb(pos, t):
|
|
||||||
return (t * pos.cos()) + (rotate_half(t) * pos.sin())
|
|
||||||
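# Sketch (shapes are illustrative, following the multi-query attention used
# below): rotary position embeddings are applied multiplicatively to queries and
# keys right before the dot-product, encoding relative position without a
# learned positional table.
#
#   rot = RotaryEmbedding(dim=64)
#   pos = rot(128, device=torch.device('cpu'))   # (128, 64)
#   q = torch.randn(2, 8, 128, 64)               # b h n d
#   k = torch.randn(2, 128, 64)                  # single-head keys (multi-query)
#   q, k = map(lambda t: apply_rotary_pos_emb(pos, t), (q, k))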
|
|
||||||
|
|
||||||
# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
|
|
||||||
# https://arxiv.org/abs/2002.05202
|
|
||||||
|
|
||||||
|
|
||||||
class SwiGLU(nn.Module):
|
|
||||||
def forward(self, x):
|
|
||||||
x, gate = x.chunk(2, dim=-1)
|
|
||||||
return F.silu(gate) * x
|
|
||||||
|
|
||||||
|
|
||||||
# parallel attention and feedforward with residual
|
|
||||||
# discovered by Wang et al + EleutherAI from GPT-J fame
|
|
||||||
|
|
||||||
class ParallelTransformerBlock(nn.Module):
|
|
||||||
def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
|
|
||||||
super().__init__()
|
|
||||||
self.norm = LayerNorm(dim)
|
|
||||||
|
|
||||||
attn_inner_dim = dim_head * heads
|
|
||||||
ff_inner_dim = dim * ff_mult
|
|
||||||
self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))
|
|
||||||
|
|
||||||
self.heads = heads
|
|
||||||
self.scale = dim_head**-0.5
|
|
||||||
self.rotary_emb = RotaryEmbedding(dim_head)
|
|
||||||
|
|
||||||
self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
|
|
||||||
self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
|
|
||||||
|
|
||||||
self.ff_out = nn.Sequential(
|
|
||||||
SwiGLU(),
|
|
||||||
nn.Linear(ff_inner_dim, dim, bias=False)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.register_buffer("pos_emb", None, persistent=False)
|
|
||||||
|
|
||||||
|
|
||||||
def get_rotary_embedding(self, n, device):
|
|
||||||
if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
|
|
||||||
return self.pos_emb[:n]
|
|
||||||
|
|
||||||
pos_emb = self.rotary_emb(n, device=device)
|
|
||||||
self.register_buffer("pos_emb", pos_emb, persistent=False)
|
|
||||||
return pos_emb
|
|
||||||
|
|
||||||
def forward(self, x, attn_mask=None):
|
|
||||||
"""
|
|
||||||
einstein notation
|
|
||||||
b - batch
|
|
||||||
h - heads
|
|
||||||
n, i, j - sequence length (base sequence length, source, target)
|
|
||||||
d - feature dimension
|
|
||||||
"""
|
|
||||||
|
|
||||||
n, device, h = x.shape[1], x.device, self.heads
|
|
||||||
|
|
||||||
# pre layernorm
|
|
||||||
|
|
||||||
x = self.norm(x)
|
|
||||||
|
|
||||||
# attention queries, keys, values, and feedforward inner
|
|
||||||
|
|
||||||
q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)
|
|
||||||
|
|
||||||
# split heads
|
|
||||||
# they use multi-query single-key-value attention, yet another Noam Shazeer paper
|
|
||||||
# they found no performance loss past a certain scale, and more efficient decoding obviously
|
|
||||||
# https://arxiv.org/abs/1911.02150
|
|
||||||
|
|
||||||
q = rearrange(q, "b n (h d) -> b h n d", h=h)
|
|
||||||
|
|
||||||
# rotary embeddings
|
|
||||||
|
|
||||||
positions = self.get_rotary_embedding(n, device)
|
|
||||||
q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))
|
|
||||||
|
|
||||||
# scale
|
|
||||||
|
|
||||||
q = q * self.scale
|
|
||||||
|
|
||||||
# similarity
|
|
||||||
|
|
||||||
sim = einsum("b h i d, b j d -> b h i j", q, k)
|
|
||||||
|
|
||||||
|
|
||||||
# extra attention mask - for masking out attention from text CLS token to padding
|
|
||||||
|
|
||||||
if exists(attn_mask):
|
|
||||||
attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j')
|
|
||||||
sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)
|
|
||||||
|
|
||||||
# attention
|
|
||||||
|
|
||||||
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
|
|
||||||
attn = sim.softmax(dim=-1)
|
|
||||||
|
|
||||||
# aggregate values
|
|
||||||
|
|
||||||
out = einsum("b h i j, b j d -> b h i d", attn, v)
|
|
||||||
|
|
||||||
# merge heads
|
|
||||||
|
|
||||||
out = rearrange(out, "b h n d -> b n (h d)")
|
|
||||||
return self.attn_out(out) + self.ff_out(ff)
|
|
||||||
|
|
||||||
# cross attention - using multi-query + one-headed key / values as in PaLM w/ optional parallel feedforward
|
|
||||||
|
|
||||||
class CrossAttention(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim,
|
|
||||||
*,
|
|
||||||
context_dim=None,
|
|
||||||
dim_head=64,
|
|
||||||
heads=12,
|
|
||||||
parallel_ff=False,
|
|
||||||
ff_mult=4,
|
|
||||||
norm_context=False
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.heads = heads
|
|
||||||
self.scale = dim_head ** -0.5
|
|
||||||
inner_dim = heads * dim_head
|
|
||||||
context_dim = default(context_dim, dim)
|
|
||||||
|
|
||||||
self.norm = LayerNorm(dim)
|
|
||||||
self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity()
|
|
||||||
|
|
||||||
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
|
||||||
self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
|
|
||||||
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
|
||||||
|
|
||||||
# whether to have parallel feedforward
|
|
||||||
|
|
||||||
ff_inner_dim = ff_mult * dim
|
|
||||||
|
|
||||||
self.ff = nn.Sequential(
|
|
||||||
nn.Linear(dim, ff_inner_dim * 2, bias=False),
|
|
||||||
SwiGLU(),
|
|
||||||
nn.Linear(ff_inner_dim, dim, bias=False)
|
|
||||||
) if parallel_ff else None
|
|
||||||
|
|
||||||
def forward(self, x, context, mask):
|
|
||||||
"""
|
|
||||||
einstein notation
|
|
||||||
b - batch
|
|
||||||
h - heads
|
|
||||||
n, i, j - sequence length (base sequence length, source, target)
|
|
||||||
d - feature dimension
|
|
||||||
"""
|
|
||||||
|
|
||||||
# pre-layernorm, for queries and context
|
|
||||||
|
|
||||||
x = self.norm(x)
|
|
||||||
context = self.context_norm(context)
|
|
||||||
|
|
||||||
# get queries
|
|
||||||
|
|
||||||
q = self.to_q(x)
|
|
||||||
q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
|
|
||||||
|
|
||||||
# scale
|
|
||||||
|
|
||||||
q = q * self.scale
|
|
||||||
|
|
||||||
# get key / values
|
|
||||||
|
|
||||||
k, v = self.to_kv(context).chunk(2, dim=-1)
|
|
||||||
|
|
||||||
# query / key similarity
|
|
||||||
|
|
||||||
sim = einsum('b h i d, b j d -> b h i j', q, k)
|
|
||||||
|
|
||||||
# attention
|
|
||||||
mask = mask.unsqueeze(1).repeat(1,self.heads,1,1)
|
|
||||||
sim = sim + mask # context mask
|
|
||||||
sim = sim - sim.amax(dim=-1, keepdim=True)
|
|
||||||
attn = sim.softmax(dim=-1)
|
|
||||||
|
|
||||||
# aggregate
|
|
||||||
|
|
||||||
out = einsum('b h i j, b j d -> b h i d', attn, v)
|
|
||||||
|
|
||||||
# merge and combine heads
|
|
||||||
|
|
||||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
|
||||||
out = self.to_out(out)
|
|
||||||
|
|
||||||
# add parallel feedforward (for multimodal layers)
|
|
||||||
|
|
||||||
if exists(self.ff):
|
|
||||||
out = out + self.ff(x)
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class Cross_model(nn.Module):
    def __init__(
        self,
        dim=512,
        layer_num=4,
        dim_head=64,
        heads=8,
        ff_mult=4
    ):
        super().__init__()

        self.layers = nn.ModuleList([])

        for ind in range(layer_num):
            self.layers.append(nn.ModuleList([
                Residual(CrossAttention(dim=dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult)),
                Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult))
            ]))

    def forward(
        self,
        query_tokens,
        context_tokens,
        mask
    ):

        for cross_attn, self_attn_ff in self.layers:
            query_tokens = cross_attn(query_tokens, context_tokens, mask)
            query_tokens = self_attn_ff(query_tokens)

        return query_tokens
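
# A minimal usage sketch (not part of the original file), assuming the helpers defined
# earlier in this module (Residual, ParallelTransformerBlock, LayerNorm, SwiGLU, default,
# exists) are in scope. The mask is additive: 0 keeps a position, a large negative value
# masks it out.
_model = Cross_model(dim=512, layer_num=4, dim_head=64, heads=8)
_query_tokens = torch.randn(2, 32, 512)
_context_tokens = torch.randn(2, 77, 512)
_mask = torch.zeros(2, 32, 77)
_out = _model(_query_tokens, _context_tokens, _mask)   # [2, 32, 512]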
@@ -1,242 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image


def warp(tenInput, tenFlow, device):
    backwarp_tenGrid = {}
    k = (str(tenFlow.device), str(tenFlow.size()))
    if k not in backwarp_tenGrid:
        tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
            1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
        tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
            1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
        backwarp_tenGrid[k] = torch.cat(
            [tenHorizontal, tenVertical], 1).to(device)

    tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
                         tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)

    g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
    return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
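
# Hedged sanity check (not part of the original file): with an all-zero flow field the
# backward warp is an identity mapping, so the output should match the input image.
_img = torch.rand(1, 3, 32, 32)
_zero_flow = torch.zeros(1, 2, 32, 32)
assert torch.allclose(warp(_img, _zero_flow, device="cpu"), _img, atol=1e-4)
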
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
        nn.PReLU(out_planes)
    )

class IFBlock(nn.Module):
    def __init__(self, in_planes, c=64):
        super(IFBlock, self).__init__()
        self.conv0 = nn.Sequential(conv(in_planes, c//2, 3, 2, 1), conv(c//2, c, 3, 2, 1),)
        self.convblock0 = nn.Sequential(conv(c, c), conv(c, c))
        self.convblock1 = nn.Sequential(conv(c, c), conv(c, c))
        self.convblock2 = nn.Sequential(conv(c, c), conv(c, c))
        self.convblock3 = nn.Sequential(conv(c, c), conv(c, c))
        self.conv1 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 4, 4, 2, 1))
        self.conv2 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 1, 4, 2, 1))

    def forward(self, x, flow, scale=1):
        x = F.interpolate(x, scale_factor=1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
        flow = F.interpolate(flow, scale_factor=1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 1. / scale
        feat = self.conv0(torch.cat((x, flow), 1))
        feat = self.convblock0(feat) + feat
        feat = self.convblock1(feat) + feat
        feat = self.convblock2(feat) + feat
        feat = self.convblock3(feat) + feat
        flow = self.conv1(feat)
        mask = self.conv2(feat)
        flow = F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * scale
        mask = F.interpolate(mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
        return flow, mask

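# Hedged usage sketch (not part of the original file): one IFBlock pass at half resolution.
# The input is 7 image/mask channels plus a 4-channel bidirectional flow; the outputs are a
# refined 4-channel flow and a 1-channel blending mask at the original resolution.
_block = IFBlock(7 + 4, c=32)
_x = torch.randn(1, 7, 64, 64)        # img0 (3) + img1 (3) + mask (1)
_flow = torch.zeros(1, 4, 64, 64)
_flow_out, _mask_out = _block(_x, _flow, scale=2)   # [1, 4, 64, 64], [1, 1, 64, 64]
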
class IFNet(nn.Module):
    def __init__(self, **kwargs):
        super(IFNet, self).__init__()
        self.block0 = IFBlock(7+4, c=90)
        self.block1 = IFBlock(7+4, c=90)
        self.block2 = IFBlock(7+4, c=90)
        self.block_tea = IFBlock(10+4, c=90)

    def forward(self, x, scale_list=[4, 2, 1], training=False):
        if not training:
            channel = x.shape[1] // 2
            img0 = x[:, :channel]
            img1 = x[:, channel:]
        flow_list = []
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = (x[:, :4]).detach() * 0
        mask = (x[:, :1]).detach() * 0
        block = [self.block0, self.block1, self.block2]
        for i in range(3):
            f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
            f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i])
            flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
            mask = mask + (m0 + (-m1)) / 2
            mask_list.append(mask)
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2], device=x.device)
            warped_img1 = warp(img1, flow[:, 2:4], device=x.device)
            merged.append((warped_img0, warped_img1))
        '''
        c0 = self.contextnet(img0, flow[:, :2])
        c1 = self.contextnet(img1, flow[:, 2:4])
        tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
        res = tmp[:, 1:4] * 2 - 1
        '''
        for i in range(3):
            mask_list[i] = torch.sigmoid(mask_list[i])
            merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
        return flow_list, mask_list[2], merged

    @staticmethod
    def state_dict_converter():
        return IFNetStateDictConverter()

class IFNetStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        state_dict_ = {k.replace("module.", ""): v for k, v in state_dict.items()}
        return state_dict_

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict), {"upcast_to_float32": True}

class RIFEInterpolater:
    def __init__(self, model, device="cuda"):
        self.model = model
        self.device = device
        # IFNet does not support float16
        self.torch_dtype = torch.float32

    @staticmethod
    def from_model_manager(model_manager):
        return RIFEInterpolater(model_manager.fetch_model("rife"), device=model_manager.device)

    def process_image(self, image):
        width, height = image.size
        if width % 32 != 0 or height % 32 != 0:
            # Round the resolution up to a multiple of 32 so the pyramid blocks align.
            width = (width + 31) // 32 * 32
            height = (height + 31) // 32 * 32
            image = image.resize((width, height))
        image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1)
        return image

    def process_images(self, images):
        images = [self.process_image(image) for image in images]
        images = torch.stack(images)
        return images

    def decode_images(self, images):
        images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
        images = [Image.fromarray(image) for image in images]
        return images

    def add_interpolated_images(self, images, interpolated_images):
        output_images = []
        for image, interpolated_image in zip(images, interpolated_images):
            output_images.append(image)
            output_images.append(interpolated_image)
        output_images.append(images[-1])
        return output_images

    @torch.no_grad()
    def interpolate_(self, images, scale=1.0):
        input_tensor = self.process_images(images)
        input_tensor = torch.cat((input_tensor[:-1], input_tensor[1:]), dim=1)
        input_tensor = input_tensor.to(device=self.device, dtype=self.torch_dtype)
        flow, mask, merged = self.model(input_tensor, [4/scale, 2/scale, 1/scale])
        output_images = self.decode_images(merged[2].cpu())
        if output_images[0].size != images[0].size:
            output_images = [image.resize(images[0].size) for image in output_images]
        return output_images

    @torch.no_grad()
    def interpolate(self, images, scale=1.0, batch_size=4, num_iter=1, progress_bar=lambda x: x):
        # Preprocess
        processed_images = self.process_images(images)

        for iter in range(num_iter):
            # Input
            input_tensor = torch.cat((processed_images[:-1], processed_images[1:]), dim=1)

            # Interpolate
            output_tensor = []
            for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
                batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
                batch_input_tensor = input_tensor[batch_id: batch_id_]
                batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
                flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
                output_tensor.append(merged[2].cpu())

            # Output
            output_tensor = torch.concat(output_tensor, dim=0).clip(0, 1)
            processed_images = self.add_interpolated_images(processed_images, output_tensor)
            processed_images = torch.stack(processed_images)

        # To images
        output_images = self.decode_images(processed_images)
        if output_images[0].size != images[0].size:
            output_images = [image.resize(images[0].size) for image in output_images]
        return output_images

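# Hedged usage sketch (not part of the original file): doubling the frame count of a short
# frame list on CPU with a randomly initialised IFNet, so the output only illustrates the
# shapes and call flow, not real interpolation quality. In practice the IFNet weights would
# be loaded through the model manager instead of random initialisation.
_rife = RIFEInterpolater(IFNet(), device="cpu")
_frames = [Image.fromarray(np.zeros((64, 64, 3), dtype=np.uint8)) for _ in range(2)]
_out_frames = _rife.interpolate(_frames, num_iter=1)   # 2 frames in, 3 frames out
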
class RIFESmoother(RIFEInterpolater):
    def __init__(self, model, device="cuda"):
        super(RIFESmoother, self).__init__(model, device=device)

    @staticmethod
    def from_model_manager(model_manager):
        return RIFESmoother(model_manager.fetch_model("rife"), device=model_manager.device)

    def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
        output_tensor = []
        for batch_id in range(0, input_tensor.shape[0], batch_size):
            batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
            batch_input_tensor = input_tensor[batch_id: batch_id_]
            batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
            flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
            output_tensor.append(merged[2].cpu())
        output_tensor = torch.concat(output_tensor, dim=0)
        return output_tensor

    @torch.no_grad()
    def __call__(self, rendered_frames, scale=1.0, batch_size=4, num_iter=1, **kwargs):
        # Preprocess
        processed_images = self.process_images(rendered_frames)

        for iter in range(num_iter):
            # Input
            input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1)

            # Interpolate
            output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)

            # Blend
            input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1)
            output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)

            # Add to frames
            processed_images[1:-1] = output_tensor

        # To images
        output_images = self.decode_images(processed_images)
        if output_images[0].size != rendered_frames[0].size:
            output_images = [image.resize(rendered_frames[0].size) for image in output_images]
        return output_images
@@ -1 +0,0 @@
|
|||||||
from .model_manager import *
|
|
||||||
@@ -1,89 +0,0 @@
import torch
from einops import rearrange


def low_version_attention(query, key, value, attn_bias=None):
    scale = 1 / query.shape[-1] ** 0.5
    query = query * scale
    attn = torch.matmul(query, key.transpose(-2, -1))
    if attn_bias is not None:
        attn = attn + attn_bias
    attn = attn.softmax(-1)
    return attn @ value
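
# Hedged sanity check (not part of the original file): the manual fallback above matches
# PyTorch's fused scaled_dot_product_attention up to numerical tolerance.
_q, _k, _v = torch.randn(2, 4, 16), torch.randn(2, 4, 16), torch.randn(2, 4, 16)
assert torch.allclose(
    low_version_attention(_q, _k, _v),
    torch.nn.functional.scaled_dot_product_attention(_q, _k, _v),
    atol=1e-4,
)
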
class Attention(torch.nn.Module):

    def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
        super().__init__()
        dim_inner = head_dim * num_heads
        kv_dim = kv_dim if kv_dim is not None else q_dim
        self.num_heads = num_heads
        self.head_dim = head_dim

        self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
        self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
        self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
        self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)

    def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
        batch_size = q.shape[0]
        ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
        hidden_states = hidden_states + scale * ip_hidden_states
        return hidden_states

    def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        batch_size = encoder_hidden_states.shape[0]

        q = self.to_q(hidden_states)
        k = self.to_k(encoder_hidden_states)
        v = self.to_v(encoder_hidden_states)

        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        if qkv_preprocessor is not None:
            q, k, v = qkv_preprocessor(q, k, v)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        if ipadapter_kwargs is not None:
            hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)

        hidden_states = self.to_out(hidden_states)

        return hidden_states

    def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        q = self.to_q(hidden_states)
        k = self.to_k(encoder_hidden_states)
        v = self.to_v(encoder_hidden_states)

        q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
        k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
        v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)

        if attn_mask is not None:
            hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
        else:
            import xformers.ops as xops
            hidden_states = xops.memory_efficient_attention(q, k, v)
        hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)

        hidden_states = hidden_states.to(q.dtype)
        hidden_states = self.to_out(hidden_states)

        return hidden_states

    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
        return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)
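
# Hedged usage sketch (not part of the original file): self-attention and cross-attention
# with the Attention module above. All shapes are illustrative.
_x = torch.randn(2, 10, 64)                                          # queries
_ctx = torch.randn(2, 7, 32)                                         # cross-attention context
_self_out = Attention(q_dim=64, num_heads=4, head_dim=16)(_x)        # [2, 10, 64]
_cross_out = Attention(q_dim=64, num_heads=4, head_dim=16, kv_dim=32)(_x, encoder_hidden_states=_ctx)   # [2, 10, 64]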
@@ -1,408 +0,0 @@
import torch
from einops import rearrange, repeat
from .sd3_dit import TimestepEmbeddings
from .attention import Attention
from .utils import load_state_dict_from_folder
from .tiler import TileWorker2Dto3D
import numpy as np


class CogPatchify(torch.nn.Module):
    def __init__(self, dim_in, dim_out, patch_size) -> None:
        super().__init__()
        self.proj = torch.nn.Conv3d(dim_in, dim_out, kernel_size=(1, patch_size, patch_size), stride=(1, patch_size, patch_size))

    def forward(self, hidden_states):
        hidden_states = self.proj(hidden_states)
        hidden_states = rearrange(hidden_states, "B C T H W -> B (T H W) C")
        return hidden_states

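# Hedged usage sketch (not part of the original file), with small illustrative dimensions:
# patchify turns a video latent [B, C, T, H, W] into a token sequence of length T*(H/p)*(W/p).
_patchify = CogPatchify(dim_in=4, dim_out=8, patch_size=2)
_latent = torch.randn(1, 4, 3, 16, 16)
_tokens = _patchify(_latent)          # [1, 3 * 8 * 8, 8]
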
class CogAdaLayerNorm(torch.nn.Module):
    def __init__(self, dim, dim_cond, single=False):
        super().__init__()
        self.single = single
        self.linear = torch.nn.Linear(dim_cond, dim * (2 if single else 6))
        self.norm = torch.nn.LayerNorm(dim, elementwise_affine=True, eps=1e-5)

    def forward(self, hidden_states, prompt_emb, emb):
        emb = self.linear(torch.nn.functional.silu(emb))
        if self.single:
            shift, scale = emb.unsqueeze(1).chunk(2, dim=2)
            hidden_states = self.norm(hidden_states) * (1 + scale) + shift
            return hidden_states
        else:
            shift_a, scale_a, gate_a, shift_b, scale_b, gate_b = emb.unsqueeze(1).chunk(6, dim=2)
            hidden_states = self.norm(hidden_states) * (1 + scale_a) + shift_a
            prompt_emb = self.norm(prompt_emb) * (1 + scale_b) + shift_b
            return hidden_states, prompt_emb, gate_a, gate_b

class CogDiTBlock(torch.nn.Module):
    def __init__(self, dim, dim_cond, num_heads):
        super().__init__()
        self.norm1 = CogAdaLayerNorm(dim, dim_cond)
        self.attn1 = Attention(q_dim=dim, num_heads=48, head_dim=dim//num_heads, bias_q=True, bias_kv=True, bias_out=True)
        self.norm_q = torch.nn.LayerNorm((dim//num_heads,), eps=1e-06, elementwise_affine=True)
        self.norm_k = torch.nn.LayerNorm((dim//num_heads,), eps=1e-06, elementwise_affine=True)

        self.norm2 = CogAdaLayerNorm(dim, dim_cond)
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(dim, dim*4),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(dim*4, dim)
        )

    def apply_rotary_emb(self, x, freqs_cis):
        cos, sin = freqs_cis  # [S, D]
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)
        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
        return out

    def process_qkv(self, q, k, v, image_rotary_emb, text_seq_length):
        q = self.norm_q(q)
        k = self.norm_k(k)
        q[:, :, text_seq_length:] = self.apply_rotary_emb(q[:, :, text_seq_length:], image_rotary_emb)
        k[:, :, text_seq_length:] = self.apply_rotary_emb(k[:, :, text_seq_length:], image_rotary_emb)
        return q, k, v

    def forward(self, hidden_states, prompt_emb, time_emb, image_rotary_emb):
        # Attention
        norm_hidden_states, norm_encoder_hidden_states, gate_a, gate_b = self.norm1(
            hidden_states, prompt_emb, time_emb
        )
        attention_io = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
        attention_io = self.attn1(
            attention_io,
            qkv_preprocessor=lambda q, k, v: self.process_qkv(q, k, v, image_rotary_emb, prompt_emb.shape[1])
        )

        hidden_states = hidden_states + gate_a * attention_io[:, prompt_emb.shape[1]:]
        prompt_emb = prompt_emb + gate_b * attention_io[:, :prompt_emb.shape[1]]

        # Feed forward
        norm_hidden_states, norm_encoder_hidden_states, gate_a, gate_b = self.norm2(
            hidden_states, prompt_emb, time_emb
        )
        ff_io = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
        ff_io = self.ff(ff_io)

        hidden_states = hidden_states + gate_a * ff_io[:, prompt_emb.shape[1]:]
        prompt_emb = prompt_emb + gate_b * ff_io[:, :prompt_emb.shape[1]]

        return hidden_states, prompt_emb

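# Hedged sketch (not part of the original file): the rotate-half rotary embedding used by
# apply_rotary_emb above, reproduced on a toy tensor. All names and sizes here are
# illustrative only.
_S, _D = 5, 8
_pos = torch.arange(_S).float()
_inv_freq = 1.0 / (10000 ** (torch.arange(0, _D, 2).float() / _D))
_freqs = torch.einsum("n,f->nf", _pos, _inv_freq).repeat_interleave(2, dim=-1)   # [S, D]
_cos, _sin = _freqs.cos(), _freqs.sin()
_xr = torch.randn(1, 2, _S, _D)                                                  # [B, H, S, D]
_x_real, _x_imag = _xr.reshape(*_xr.shape[:-1], -1, 2).unbind(-1)
_x_rotated = torch.stack([-_x_imag, _x_real], dim=-1).flatten(3)
_rotary_out = _xr * _cos + _x_rotated * _sin   # each (even, odd) pair is rotated by the position angle
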
class CogDiT(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.patchify = CogPatchify(16, 3072, 2)
        self.time_embedder = TimestepEmbeddings(3072, 512)
        self.context_embedder = torch.nn.Linear(4096, 3072)
        self.blocks = torch.nn.ModuleList([CogDiTBlock(3072, 512, 48) for _ in range(42)])
        self.norm_final = torch.nn.LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
        self.norm_out = CogAdaLayerNorm(3072, 512, single=True)
        self.proj_out = torch.nn.Linear(3072, 64, bias=True)

    def get_resize_crop_region_for_grid(self, src, tgt_width, tgt_height):
        tw = tgt_width
        th = tgt_height
        h, w = src
        r = h / w
        if r > (th / tw):
            resize_height = th
            resize_width = int(round(th / h * w))
        else:
            resize_width = tw
            resize_height = int(round(tw / w * h))

        crop_top = int(round((th - resize_height) / 2.0))
        crop_left = int(round((tw - resize_width) / 2.0))

        return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)

    def get_3d_rotary_pos_embed(
        self, embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
    ):
        start, stop = crops_coords
        grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
        grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
        grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)

        # Compute dimensions for each axis
        dim_t = embed_dim // 4
        dim_h = embed_dim // 8 * 3
        dim_w = embed_dim // 8 * 3

        # Temporal frequencies
        freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
        grid_t = torch.from_numpy(grid_t).float()
        freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
        freqs_t = freqs_t.repeat_interleave(2, dim=-1)

        # Spatial frequencies for height and width
        freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
        freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
        grid_h = torch.from_numpy(grid_h).float()
        grid_w = torch.from_numpy(grid_w).float()
        freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
        freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
        freqs_h = freqs_h.repeat_interleave(2, dim=-1)
        freqs_w = freqs_w.repeat_interleave(2, dim=-1)

        # Broadcast and concatenate tensors along specified dimension
        def broadcast(tensors, dim=-1):
            num_tensors = len(tensors)
            shape_lens = {len(t.shape) for t in tensors}
            assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
            shape_len = list(shape_lens)[0]
            dim = (dim + shape_len) if dim < 0 else dim
            dims = list(zip(*(list(t.shape) for t in tensors)))
            expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
            assert all(
                [*(len(set(t[1])) <= 2 for t in expandable_dims)]
            ), "invalid dimensions for broadcastable concatenation"
            max_dims = [(t[0], max(t[1])) for t in expandable_dims]
            expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
            expanded_dims.insert(dim, (dim, dims[dim]))
            expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
            tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
            return torch.cat(tensors, dim=dim)

        freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)

        t, h, w, d = freqs.shape
        freqs = freqs.view(t * h * w, d)

        # Generate sine and cosine components
        sin = freqs.sin()
        cos = freqs.cos()

        if use_real:
            return cos, sin
        else:
            freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
            return freqs_cis

    def prepare_rotary_positional_embeddings(
        self,
        height: int,
        width: int,
        num_frames: int,
        device: torch.device,
    ):
        grid_height = height // 2
        grid_width = width // 2
        base_size_width = 720 // (8 * 2)
        base_size_height = 480 // (8 * 2)

        grid_crops_coords = self.get_resize_crop_region_for_grid(
            (grid_height, grid_width), base_size_width, base_size_height
        )
        freqs_cos, freqs_sin = self.get_3d_rotary_pos_embed(
            embed_dim=64,
            crops_coords=grid_crops_coords,
            grid_size=(grid_height, grid_width),
            temporal_size=num_frames,
            use_real=True,
        )

        freqs_cos = freqs_cos.to(device=device)
        freqs_sin = freqs_sin.to(device=device)
        return freqs_cos, freqs_sin

    def unpatchify(self, hidden_states, height, width):
        hidden_states = rearrange(hidden_states, "B (T H W) (C P Q) -> B C T (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
        return hidden_states

    def build_mask(self, T, H, W, dtype, device, is_bound):
        t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W)
        h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W)
        w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W)
        border_width = (H + W) // 4
        pad = torch.ones_like(h) * border_width
        mask = torch.stack([
            pad if is_bound[0] else t + 1,
            pad if is_bound[1] else T - t,
            pad if is_bound[2] else h + 1,
            pad if is_bound[3] else H - h,
            pad if is_bound[4] else w + 1,
            pad if is_bound[5] else W - w
        ]).min(dim=0).values
        mask = mask.clip(1, border_width)
        mask = (mask / border_width).to(dtype=dtype, device=device)
        mask = rearrange(mask, "T H W -> 1 1 T H W")
        return mask

    def tiled_forward(self, hidden_states, timestep, prompt_emb, tile_size=(60, 90), tile_stride=(30, 45)):
        B, C, T, H, W = hidden_states.shape
        value = torch.zeros((B, C, T, H, W), dtype=hidden_states.dtype, device=hidden_states.device)
        weight = torch.zeros((B, C, T, H, W), dtype=hidden_states.dtype, device=hidden_states.device)

        # Split tasks
        tasks = []
        for h in range(0, H, tile_stride):
            for w in range(0, W, tile_stride):
                if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
                    continue
                h_, w_ = h + tile_size, w + tile_size
                if h_ > H: h, h_ = max(H - tile_size, 0), H
                if w_ > W: w, w_ = max(W - tile_size, 0), W
                tasks.append((h, h_, w, w_))

        # Run
        for hl, hr, wl, wr in tasks:
            mask = self.build_mask(
                value.shape[2], (hr-hl), (wr-wl),
                hidden_states.dtype, hidden_states.device,
                is_bound=(True, True, hl==0, hr>=H, wl==0, wr>=W)
            )
            model_output = self.forward(hidden_states[:, :, :, hl:hr, wl:wr], timestep, prompt_emb)
            value[:, :, :, hl:hr, wl:wr] += model_output * mask
            weight[:, :, :, hl:hr, wl:wr] += mask
        value = value / weight

        return value

    def forward(self, hidden_states, timestep, prompt_emb, image_rotary_emb=None, tiled=False, tile_size=90, tile_stride=30, use_gradient_checkpointing=False):
        if tiled:
            return TileWorker2Dto3D().tiled_forward(
                forward_fn=lambda x: self.forward(x, timestep, prompt_emb),
                model_input=hidden_states,
                tile_size=tile_size, tile_stride=tile_stride,
                tile_device=hidden_states.device, tile_dtype=hidden_states.dtype,
                computation_device=self.context_embedder.weight.device, computation_dtype=self.context_embedder.weight.dtype
            )
        num_frames, height, width = hidden_states.shape[-3:]
        if image_rotary_emb is None:
            image_rotary_emb = self.prepare_rotary_positional_embeddings(height, width, num_frames, device=self.context_embedder.weight.device)
        hidden_states = self.patchify(hidden_states)
        time_emb = self.time_embedder(timestep, dtype=hidden_states.dtype)
        prompt_emb = self.context_embedder(prompt_emb)

        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward

        for block in self.blocks:
            if self.training and use_gradient_checkpointing:
                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states, prompt_emb, time_emb, image_rotary_emb,
                    use_reentrant=False,
                )
            else:
                hidden_states, prompt_emb = block(hidden_states, prompt_emb, time_emb, image_rotary_emb)

        hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
        hidden_states = self.norm_final(hidden_states)
        hidden_states = hidden_states[:, prompt_emb.shape[1]:]
        hidden_states = self.norm_out(hidden_states, prompt_emb, time_emb)
        hidden_states = self.proj_out(hidden_states)
        hidden_states = self.unpatchify(hidden_states, height, width)

        return hidden_states

    @staticmethod
    def state_dict_converter():
        return CogDiTStateDictConverter()

    @staticmethod
    def from_pretrained(file_path, torch_dtype=torch.bfloat16):
        model = CogDiT().to(torch_dtype)
        state_dict = load_state_dict_from_folder(file_path, torch_dtype=torch_dtype)
        state_dict = CogDiT.state_dict_converter().from_diffusers(state_dict)
        model.load_state_dict(state_dict)
        return model

class CogDiTStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        rename_dict = {
            "patch_embed.proj.weight": "patchify.proj.weight",
            "patch_embed.proj.bias": "patchify.proj.bias",
            "patch_embed.text_proj.weight": "context_embedder.weight",
            "patch_embed.text_proj.bias": "context_embedder.bias",
            "time_embedding.linear_1.weight": "time_embedder.timestep_embedder.0.weight",
            "time_embedding.linear_1.bias": "time_embedder.timestep_embedder.0.bias",
            "time_embedding.linear_2.weight": "time_embedder.timestep_embedder.2.weight",
            "time_embedding.linear_2.bias": "time_embedder.timestep_embedder.2.bias",

            "norm_final.weight": "norm_final.weight",
            "norm_final.bias": "norm_final.bias",
            "norm_out.linear.weight": "norm_out.linear.weight",
            "norm_out.linear.bias": "norm_out.linear.bias",
            "norm_out.norm.weight": "norm_out.norm.weight",
            "norm_out.norm.bias": "norm_out.norm.bias",
            "proj_out.weight": "proj_out.weight",
            "proj_out.bias": "proj_out.bias",
        }
        suffix_dict = {
            "norm1.linear.weight": "norm1.linear.weight",
            "norm1.linear.bias": "norm1.linear.bias",
            "norm1.norm.weight": "norm1.norm.weight",
            "norm1.norm.bias": "norm1.norm.bias",
            "attn1.norm_q.weight": "norm_q.weight",
            "attn1.norm_q.bias": "norm_q.bias",
            "attn1.norm_k.weight": "norm_k.weight",
            "attn1.norm_k.bias": "norm_k.bias",
            "attn1.to_q.weight": "attn1.to_q.weight",
            "attn1.to_q.bias": "attn1.to_q.bias",
            "attn1.to_k.weight": "attn1.to_k.weight",
            "attn1.to_k.bias": "attn1.to_k.bias",
            "attn1.to_v.weight": "attn1.to_v.weight",
            "attn1.to_v.bias": "attn1.to_v.bias",
            "attn1.to_out.0.weight": "attn1.to_out.weight",
            "attn1.to_out.0.bias": "attn1.to_out.bias",
            "norm2.linear.weight": "norm2.linear.weight",
            "norm2.linear.bias": "norm2.linear.bias",
            "norm2.norm.weight": "norm2.norm.weight",
            "norm2.norm.bias": "norm2.norm.bias",
            "ff.net.0.proj.weight": "ff.0.weight",
            "ff.net.0.proj.bias": "ff.0.bias",
            "ff.net.2.weight": "ff.2.weight",
            "ff.net.2.bias": "ff.2.bias",
        }
        state_dict_ = {}
        for name, param in state_dict.items():
            if name in rename_dict:
                if name == "patch_embed.proj.weight":
                    param = param.unsqueeze(2)
                state_dict_[rename_dict[name]] = param
            else:
                names = name.split(".")
                if names[0] == "transformer_blocks":
                    suffix = ".".join(names[2:])
                    state_dict_[f"blocks.{names[1]}." + suffix_dict[suffix]] = param
        return state_dict_

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)
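
# Hedged sketch (not part of the original file): how the converter above maps
# Diffusers-style keys onto this module's names. The tensors are dummies; only the key
# renaming is being illustrated.
_dummy = {
    "patch_embed.proj.weight": torch.zeros(1, 1),                   # gains a temporal dim via unsqueeze(2)
    "transformer_blocks.0.attn1.to_q.weight": torch.zeros(1),
    "transformer_blocks.0.ff.net.0.proj.weight": torch.zeros(1),
}
_converted = CogDiTStateDictConverter().from_diffusers(_dummy)
# -> keys: "patchify.proj.weight", "blocks.0.attn1.to_q.weight", "blocks.0.ff.0.weight"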