Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-19 06:48:12 +00:00)

Compare commits: `ExVideo` ... `value-cont` (527 commits)
`.github/workflows/publish.yaml` (vendored, new file, 29 lines)

@@ -0,0 +1,29 @@
```yaml
name: release

on:
  push:
    tags:
      - 'v**'

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}-publish
  cancel-in-progress: true

jobs:
  build-n-publish:
    runs-on: ubuntu-20.04
    #if: startsWith(github.event.ref, 'refs/tags')
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python 3.10
      uses: actions/setup-python@v2
      with:
        python-version: '3.10'
    - name: Install wheel
      run: pip install wheel==0.44.0 && pip install -r requirements.txt
    - name: Build DiffSynth
      run: python setup.py sdist bdist_wheel
    - name: Publish package to PyPI
      run: |
        pip install twine
        twine upload dist/* --skip-existing -u __token__ -p ${{ secrets.PYPI_API_TOKEN }}
```
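The `tags: - 'v**'` trigger means this release workflow only runs when a version tag (for example, a tag such as `v1.0.0`) is pushed, and the `concurrency` group with `cancel-in-progress: true` cancels an older publish run for the same ref if a newer one starts.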
Deleted file, 267 lines (a PyTorch Lightning training script for AnimateDiff-style motion modules on top of Stable Diffusion):

@@ -1,267 +0,0 @@
```python
import torch, json, os, imageio
from torchvision.transforms import v2
from einops import rearrange
import lightning as pl
from diffsynth import ModelManager, EnhancedDDIMScheduler, SDVideoPipeline, SDUNet, load_state_dict, SDMotionModel


def lets_dance(
    unet: SDUNet,
    motion_modules: SDMotionModel,
    sample,
    timestep,
    encoder_hidden_states,
    use_gradient_checkpointing=False,
):
    # 1. ControlNet (skip)
    # 2. time
    time_emb = unet.time_proj(timestep[None]).to(sample.dtype)
    time_emb = unet.time_embedding(time_emb)

    # 3. pre-process
    hidden_states = unet.conv_in(sample)
    text_emb = encoder_hidden_states
    res_stack = [hidden_states]

    # 4. blocks
    def create_custom_forward(module):
        def custom_forward(*inputs):
            return module(*inputs)
        return custom_forward
    for block_id, block in enumerate(unet.blocks):
        # 4.1 UNet
        if use_gradient_checkpointing:
            hidden_states, time_emb, text_emb, res_stack = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                hidden_states, time_emb, text_emb, res_stack,
                use_reentrant=False,
            )
        else:
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
        # 4.2 AnimateDiff
        if block_id in motion_modules.call_block_id:
            motion_module_id = motion_modules.call_block_id[block_id]
            if use_gradient_checkpointing:
                hidden_states, time_emb, text_emb, res_stack = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(motion_modules.motion_modules[motion_module_id]),
                    hidden_states, time_emb, text_emb, res_stack,
                    use_reentrant=False,
                )
            else:
                hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](hidden_states, time_emb, text_emb, res_stack)

    # 5. output
    hidden_states = unet.conv_norm_out(hidden_states)
    hidden_states = unet.conv_act(hidden_states)
    hidden_states = unet.conv_out(hidden_states)

    return hidden_states


class TextVideoDataset(torch.utils.data.Dataset):
    def __init__(self, base_path, metadata_path, steps_per_epoch=10000, training_shapes=[(128, 1, 128, 512, 512)]):
        with open(metadata_path, "r") as f:
            metadata = json.load(f)
        self.path = [os.path.join(base_path, i["path"]) for i in metadata]
        self.text = [i["text"] for i in metadata]
        self.steps_per_epoch = steps_per_epoch
        self.training_shapes = training_shapes

        self.frame_process = []
        for max_num_frames, interval, num_frames, height, width in training_shapes:
            self.frame_process.append(v2.Compose([
                v2.Resize(size=max(height, width), antialias=True),
                v2.CenterCrop(size=(height, width)),
                v2.Normalize(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5]),
            ]))

    def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
        reader = imageio.get_reader(file_path)
        if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
            reader.close()
            return None

        frames = []
        for frame_id in range(num_frames):
            frame = reader.get_data(start_frame_id + frame_id * interval)
            frame = torch.tensor(frame, dtype=torch.float32)
            frame = rearrange(frame, "H W C -> 1 C H W")
            frame = frame_process(frame)
            frames.append(frame)
        reader.close()

        frames = torch.concat(frames, dim=0)
        frames = rearrange(frames, "T C H W -> C T H W")

        return frames

    def load_video(self, file_path, training_shape_id):
        data = {}
        max_num_frames, interval, num_frames, height, width = self.training_shapes[training_shape_id]
        frame_process = self.frame_process[training_shape_id]
        start_frame_id = torch.randint(0, max_num_frames - (num_frames - 1) * interval, (1,))[0]
        frames = self.load_frames_using_imageio(file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process)
        if frames is None:
            return None
        else:
            data[f"frames_{training_shape_id}"] = frames
            data[f"start_frame_id_{training_shape_id}"] = start_frame_id
        return data

    def __getitem__(self, index):
        video_data = {}
        for training_shape_id in range(len(self.training_shapes)):
            while True:
                data_id = torch.randint(0, len(self.path), (1,))[0]
                data_id = (data_id + index) % len(self.path) # For fixed seed.
                text = self.text[data_id]
                if isinstance(text, list):
                    text = text[torch.randint(0, len(text), (1,))[0]]
                video_file = self.path[data_id]
                try:
                    data = self.load_video(video_file, training_shape_id)
                except:
                    data = None
                if data is not None:
                    data[f"text_{training_shape_id}"] = text
                    break
            video_data.update(data)
        return video_data

    def __len__(self):
        return self.steps_per_epoch


class LightningModel(pl.LightningModule):
    def __init__(self, learning_rate=1e-5, sd_ckpt_path=None):
        super().__init__()
        # Load models
        model_manager = ModelManager(torch_dtype=torch.float16, device="cpu")
        model_manager.load_stable_diffusion(load_state_dict(sd_ckpt_path))

        # Initialize motion modules
        model_manager.model["motion_modules"] = SDMotionModel().to(dtype=self.dtype, device=self.device)

        # Build pipeline
        self.pipe = SDVideoPipeline.from_model_manager(model_manager)
        self.pipe.vae_encoder.eval()
        self.pipe.vae_encoder.requires_grad_(False)

        self.pipe.vae_decoder.eval()
        self.pipe.vae_decoder.requires_grad_(False)

        self.pipe.text_encoder.eval()
        self.pipe.text_encoder.requires_grad_(False)

        self.pipe.unet.eval()
        self.pipe.unet.requires_grad_(False)

        self.pipe.motion_modules.train()
        self.pipe.motion_modules.requires_grad_(True)

        # Reset the scheduler
        self.pipe.scheduler = EnhancedDDIMScheduler(beta_schedule="scaled_linear")
        self.pipe.scheduler.set_timesteps(1000)

        # Other parameters
        self.learning_rate = learning_rate

    def encode_video_with_vae(self, video):
        video = video.to(device=self.device, dtype=self.dtype)
        video = video.unsqueeze(0)
        latents = self.pipe.vae_encoder.encode_video(video, batch_size=16)
        latents = rearrange(latents[0], "C T H W -> T C H W")
        return latents

    def calculate_loss(self, prompt, frames):
        with torch.no_grad():
            # Call video encoder
            latents = self.encode_video_with_vae(frames)

            # Call text encoder
            prompt_embs = self.pipe.prompter.encode_prompt(self.pipe.text_encoder, prompt, device=self.device, max_length=77)
            prompt_embs = prompt_embs.repeat(latents.shape[0], 1, 1)

            # Call scheduler
            timestep = torch.randint(0, len(self.pipe.scheduler.timesteps), (1,), device=self.device)[0]
            noise = torch.randn_like(latents)
            noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)

        # Calculate loss
        model_pred = lets_dance(
            self.pipe.unet, self.pipe.motion_modules,
            sample=noisy_latents, encoder_hidden_states=prompt_embs, timestep=timestep
        )
        loss = torch.nn.functional.mse_loss(model_pred.float(), noise.float(), reduction="mean")
        return loss

    def training_step(self, batch, batch_idx):
        # Loss
        frames = batch["frames_0"][0]
        prompt = batch["text_0"][0]
        loss = self.calculate_loss(prompt, frames)

        # Record log
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.pipe.motion_modules.parameters(), lr=self.learning_rate)
        return optimizer

    def on_save_checkpoint(self, checkpoint):
        trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.motion_modules.named_parameters()))
        trainable_param_names = [named_param[0] for named_param in trainable_param_names]
        checkpoint["trainable_param_names"] = trainable_param_names


if __name__ == '__main__':
    # dataset and data loader
    dataset = TextVideoDataset(
        "/data/zhongjie/datasets/opensoraplan/data/processed",
        "/data/zhongjie/datasets/opensoraplan/data/processed/metadata.json",
        training_shapes=[(16, 1, 16, 512, 512)],
        steps_per_epoch=7*10000,
    )
    train_loader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        batch_size=1,
        num_workers=4
    )

    # model
    model = LightningModel(
        learning_rate=1e-5,
        sd_ckpt_path="models/stable_diffusion/v1-5-pruned-emaonly.safetensors",
    )

    # train
    trainer = pl.Trainer(
        max_epochs=100000,
        accelerator="gpu",
        devices="auto",
        strategy="deepspeed_stage_1",
        precision="16-mixed",
        default_root_dir="/data/zhongjie/models/train_extended_animatediff",
        accumulate_grad_batches=1,
        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)]
    )
    trainer.fit(
        model=model,
        train_dataloaders=train_loader,
        ckpt_path=None
    )
```
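The deleted script's `TextVideoDataset` reads `metadata.json` as a list of records, each with a `path` (video file relative to `base_path`) and a `text` caption that may be either a single string or a list of candidate captions. A minimal sketch of producing such a file (the file names below are illustrative, not taken from this diff):

```python
import json

# Illustrative entries: each record needs the video path (relative to the
# dataset's base_path) and a caption; a list of captions is also accepted,
# in which case __getitem__ samples one of them at random.
metadata = [
    {"path": "videos/clip_0001.mp4", "text": "a corgi running on the beach"},
    {"path": "videos/clip_0002.mp4", "text": ["a city street at night", "neon signs reflected on a wet road"]},
]

with open("metadata.json", "w") as f:
    json.dump(metadata, f, indent=2, ensure_ascii=False)
```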
`README.md` (280 lines changed)

````diff
@@ -1,92 +1,196 @@
 # DiffSynth Studio
+[](https://pypi.org/project/DiffSynth/)
+[](https://github.com/modelscope/DiffSynth-Studio/blob/master/LICENSE)
+[](https://github.com/modelscope/DiffSynth-Studio/issues)
+[](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
+[](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
+
+<p align="center">
+<a href="https://trendshift.io/repositories/10946" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10946" alt="modelscope%2FDiffSynth-Studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
+</p>
+
+Document: https://diffsynth-studio.readthedocs.io/zh-cn/latest/index.html
+
 ## Introduction
-DiffSynth Studio is a Diffusion engine. We have restructured architectures including Text Encoder, UNet, VAE, among others, maintaining compatibility with models from the open-source community while enhancing computational performance. We provide many interesting features. Enjoy the magic of Diffusion models!
-
-## Roadmap
-* Aug 29, 2023. We propose DiffSynth, a video synthesis framework.
-  * [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/).
-  * The source codes are released in [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth).
-  * The technical report (ECML PKDD 2024) is released on [arXiv](https://arxiv.org/abs/2308.03463).
-* Oct 1, 2023. We release an early version of this project, namely FastSDXL. A try for building a diffusion engine.
-  * The source codes are released on [GitHub](https://github.com/Artiprocher/FastSDXL).
-  * FastSDXL includes a trainable OLSS scheduler for efficiency improvement.
-    * The original repo of OLSS is [here](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler).
-    * The technical report (CIKM 2023) is released on [arXiv](https://arxiv.org/abs/2305.14677).
-    * A demo video is shown on [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj).
-    * Since OLSS requires additional training, we don't implement it in this project.
-* Nov 15, 2023. We propose FastBlend, a powerful video deflickering algorithm.
-  * The sd-webui extension is released on [GitHub](https://github.com/Artiprocher/sd-webui-fastblend).
-  * Demo videos are shown on Bilibili, including three tasks.
-    * [Video deflickering](https://www.bilibili.com/video/BV1d94y1W7PE)
-    * [Video interpolation](https://www.bilibili.com/video/BV1Lw411m71p)
-    * [Image-driven video rendering](https://www.bilibili.com/video/BV1RB4y1Z7LF)
-  * The technical report is released on [arXiv](https://arxiv.org/abs/2311.09265).
-  * An unofficial ComfyUI extension developed by other users is released on [GitHub](https://github.com/AInseven/ComfyUI-fastblend).
-* Dec 8, 2023. We decide to develop a new Project, aiming to release the potential of diffusion models, especially in video synthesis. The development of this project is started.
-* Jan 29, 2024. We propose Diffutoon, a fantastic solution for toon shading.
-  * [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/).
-  * The source codes are released in this project.
-  * The technical report (IJCAI 2024) is released on [arXiv](https://arxiv.org/abs/2401.16224).
-* June 13, 2024. DiffSynth Studio is transfered to ModelScope. The developers have transitioned from "I" to "we". Of course, I will still participate in development and maintenance.
-* June 21, 2024. We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
-  * [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/).
-  * Source code is released in this repo. See [`examples/ExVideo`](./examples/ExVideo/).
-  * Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
-  * Technical report is released on [arXiv](https://arxiv.org/abs/2406.14130).
-* Until now, DiffSynth Studio has supported the following models:
-  * [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
-  * [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
-  * [ControlNet](https://github.com/lllyasviel/ControlNet)
-  * [AnimateDiff](https://github.com/guoyww/animatediff/)
-  * [Ip-Adapter](https://github.com/tencent-ailab/IP-Adapter)
-  * [ESRGAN](https://github.com/xinntao/ESRGAN)
-  * [RIFE](https://github.com/hzwer/ECCV2022-RIFE)
-  * [Hunyuan-DiT](https://github.com/Tencent/HunyuanDiT)
-  * [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt)
-  * [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
+Welcome to the magic world of Diffusion models!
+
+DiffSynth consists of two open-source projects:
+* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): Focused on aggressive technological exploration. Targeted at academia. Provides more cutting-edge technical support and novel inference capabilities.
+* [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.
+
+DiffSynth-Studio is an open-source project aimed at exploring innovations in AIGC technology. We have integrated numerous open-source Diffusion models, including FLUX and Wan, among others. Through this open-source project, we hope to connect models within the open-source community and explore new technologies based on diffusion models.
+
+Until now, DiffSynth-Studio has supported the following models:
+* [Wan-Video](https://github.com/Wan-Video/Wan2.1)
+* [StepVideo](https://github.com/stepfun-ai/Step-Video-T2V)
+* [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [HunyuanVideo-I2V]()
+* [CogVideoX](https://huggingface.co/THUDM/CogVideoX-5b)
+* [FLUX](https://huggingface.co/black-forest-labs/FLUX.1-dev)
+* [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
+* [Kolors](https://huggingface.co/Kwai-Kolors/Kolors)
+* [Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
+* [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt)
+* [Hunyuan-DiT](https://github.com/Tencent/HunyuanDiT)
+* [RIFE](https://github.com/hzwer/ECCV2022-RIFE)
+* [ESRGAN](https://github.com/xinntao/ESRGAN)
+* [Ip-Adapter](https://github.com/tencent-ailab/IP-Adapter)
+* [AnimateDiff](https://github.com/guoyww/animatediff/)
+* [ControlNet](https://github.com/lllyasviel/ControlNet)
+* [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
+
+## News
+- **June 15, 2025** ModelScope's official evaluation framework, [EvalScope](https://github.com/modelscope/evalscope), now supports text-to-image generation evaluation. Try it with the [Best Practices](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html) guide.
+- **March 31, 2025** We support InfiniteYou, an identity preserving method for FLUX. Please refer to [./examples/InfiniteYou/](./examples/InfiniteYou/) for more details.
+- **March 25, 2025** 🔥🔥🔥 Our new open-source project, [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine), is now open-sourced! Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.
+- **March 13, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of HunyuanVideo open-sourced by Tencent. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
+- **February 25, 2025** We support Wan-Video, a collection of SOTA video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).
+- **February 17, 2025** We support [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary)! State-of-the-art video synthesis model! See [./examples/stepvideo](./examples/stepvideo/).
+- **December 31, 2024** We propose EliGen, a novel framework for precise entity-level controlled text-to-image generation, complemented by an inpainting fusion pipeline to extend its capabilities to image inpainting tasks. EliGen seamlessly integrates with existing community models, such as IP-Adapter and In-Context LoRA, enhancing its versatility. For more details, see [./examples/EntityControl](./examples/EntityControl/).
+  - Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
+  - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
+  - Online Demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
+  - Training Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
+- **December 19, 2024** We implement advanced VRAM management for HunyuanVideo, making it possible to generate videos at a resolution of 129x720x1280 using 24GB of VRAM, or at 129x512x384 resolution with just 6GB of VRAM. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
+- **December 18, 2024** We propose ArtAug, an approach designed to improve text-to-image synthesis models through synthesis-understanding interactions. We have trained an ArtAug enhancement module for FLUX.1-dev in the format of LoRA. This model integrates the aesthetic understanding of Qwen2-VL-72B into FLUX.1-dev, leading to an improvement in the quality of generated images.
+  - Paper: https://arxiv.org/abs/2412.12888
+  - Examples: https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/ArtAug
+  - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
+  - Demo: [ModelScope](https://modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0), HuggingFace (Coming soon)
+- **October 25, 2024** We provide extensive FLUX ControlNet support. This project supports many different ControlNet models that can be freely combined, even if their structures differ. Additionally, ControlNet models are compatible with high-resolution refinement and partition control techniques, enabling very powerful controllable image generation. See [`./examples/ControlNet/`](./examples/ControlNet/).
+- **October 8, 2024.** We release the extended LoRA based on CogVideoX-5B and ExVideo. You can download this model from [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) or [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1).
+- **August 22, 2024.** CogVideoX-5B is supported in this project. See [here](/examples/video_synthesis/). We provide several interesting features for this text-to-video model, including
+  - Text to video
+  - Video editing
+  - Self-upscaling
+  - Video interpolation
+- **August 22, 2024.** We have implemented an interesting painter that supports all text-to-image models. Now you can create stunning images using the painter, with assistance from AI!
+  - Use it in our [WebUI](#usage-in-webui).
+- **August 21, 2024.** FLUX is supported in DiffSynth-Studio.
+  - Enable CFG and highres-fix to improve visual quality. See [here](/examples/image_synthesis/README.md)
+  - LoRA, ControlNet, and additional models will be available soon.
+- **June 21, 2024.** We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
+  - [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
+  - Source code is released in this repo. See [`examples/ExVideo`](./examples/ExVideo/).
+  - Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
+  - Technical report is released on [arXiv](https://arxiv.org/abs/2406.14130).
+  - You can try ExVideo in this [Demo](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1)!
+- **June 13, 2024.** DiffSynth Studio is transferred to ModelScope. The developers have transitioned from "I" to "we". Of course, I will still participate in development and maintenance.
+- **Jan 29, 2024.** We propose Diffutoon, a fantastic solution for toon shading.
+  - [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
+  - The source codes are released in this project.
+  - The technical report (IJCAI 2024) is released on [arXiv](https://arxiv.org/abs/2401.16224).
+- **Dec 8, 2023.** We decide to develop a new Project, aiming to release the potential of diffusion models, especially in video synthesis. The development of this project is started.
+- **Nov 15, 2023.** We propose FastBlend, a powerful video deflickering algorithm.
+  - The sd-webui extension is released on [GitHub](https://github.com/Artiprocher/sd-webui-fastblend).
+  - Demo videos are shown on Bilibili, including three tasks.
+    - [Video deflickering](https://www.bilibili.com/video/BV1d94y1W7PE)
+    - [Video interpolation](https://www.bilibili.com/video/BV1Lw411m71p)
+    - [Image-driven video rendering](https://www.bilibili.com/video/BV1RB4y1Z7LF)
+  - The technical report is released on [arXiv](https://arxiv.org/abs/2311.09265).
+  - An unofficial ComfyUI extension developed by other users is released on [GitHub](https://github.com/AInseven/ComfyUI-fastblend).
+- **Oct 1, 2023.** We release an early version of this project, namely FastSDXL. A try for building a diffusion engine.
+  - The source codes are released on [GitHub](https://github.com/Artiprocher/FastSDXL).
+  - FastSDXL includes a trainable OLSS scheduler for efficiency improvement.
+    - The original repo of OLSS is [here](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler).
+    - The technical report (CIKM 2023) is released on [arXiv](https://arxiv.org/abs/2305.14677).
+    - A demo video is shown on [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj).
+    - Since OLSS requires additional training, we don't implement it in this project.
+- **Aug 29, 2023.** We propose DiffSynth, a video synthesis framework.
+  - [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/).
+  - The source codes are released in [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth).
+  - The technical report (ECML PKDD 2024) is released on [arXiv](https://arxiv.org/abs/2308.03463).
 
 ## Installation
-Create Python environment:
-```
-conda env create -f environment.yml
-```
-We find that sometimes `conda` cannot install `cupy` correctly, please install it manually. See [this document](https://docs.cupy.dev/en/stable/install.html) for more details.
-Enter the Python environment:
-```
-conda activate DiffSynthStudio
-```
+Install from source code (recommended):
+```
+git clone https://github.com/modelscope/DiffSynth-Studio.git
+cd DiffSynth-Studio
+pip install -e .
+```
+Or install from pypi (There is a delay in the update. If you want to experience the latest features, please do not use this installation method.):
+```
+pip install diffsynth
+```
+If you encounter issues during installation, it may be caused by the packages we depend on. Please refer to the documentation of the package that caused the problem.
+* [torch](https://pytorch.org/get-started/locally/)
+* [sentencepiece](https://github.com/google/sentencepiece)
+* [cmake](https://cmake.org)
+* [cupy](https://docs.cupy.dev/en/stable/install.html)
 
 ## Usage (in Python code)
 The Python examples are in [`examples`](./examples/). We provide an overview here.
-### Long Video Synthesis
-We trained an extended video synthesis model, which can generate 128 frames. [`examples/ExVideo`](./examples/ExVideo/)
+### Download Models
+Download the pre-set models. Model IDs can be found in [config file](/diffsynth/configs/model_config.py).
+```python
+from diffsynth import download_models
+download_models(["FLUX.1-dev", "Kolors"])
+```
+Download your own models.
+```python
+from diffsynth.models.downloader import download_from_huggingface, download_from_modelscope
+# From Modelscope (recommended)
+download_from_modelscope("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.fp16.bin", "models/kolors/Kolors/vae")
+# From Huggingface
+download_from_huggingface("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.fp16.safetensors", "models/kolors/Kolors/vae")
+```
+### Video Synthesis
+#### Text-to-video using CogVideoX-5B
+CogVideoX-5B is released by ZhiPu. We provide an improved pipeline, supporting text-to-video, video editing, self-upscaling and video interpolation. [`examples/video_synthesis`](./examples/video_synthesis/)
+The video on the left is generated using the original text-to-video pipeline, while the video on the right is the result after editing and frame interpolation.
+https://github.com/user-attachments/assets/26b044c1-4a60-44a4-842f-627ff289d006
+#### Long Video Synthesis
+We trained extended video synthesis models, which can generate 128 frames. [`examples/ExVideo`](./examples/ExVideo/)
 https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
-### Image Synthesis
-Generate high-resolution images, by breaking the limitation of diffusion models! [`examples/image_synthesis`](./examples/image_synthesis/)
-|512*512|1024*1024|2048*2048|4096*4096|
-|-|-|-|-|
-|||||
-|1024*1024|2048*2048|
-|-|-|
-|||
-### Toon Shading
+https://github.com/user-attachments/assets/321ee04b-8c17-479e-8a95-8cbcf21f8d7e
+#### Toon Shading
 Render realistic videos in a flatten style and enable video editing features. [`examples/Diffutoon`](./examples/Diffutoon/)

@@ -94,32 +198,60 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-47
 https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/20528af5-5100-474a-8cdc-440b9efdd86c
-### Video Stylization
+#### Video Stylization
 Video stylization without video models. [`examples/diffsynth`](./examples/diffsynth/)
 https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
-### Chinese Models
-Use Hunyuan-DiT to generate images with Chinese prompts. We also support LoRA fine-tuning of this model. [`examples/hunyuan_dit`](./examples/hunyuan_dit/)
-Prompt: 少女手捧鲜花，坐在公园的长椅上，夕阳的余晖洒在少女的脸庞，整个画面充满诗意的美感
-|1024x1024|2048x2048 (highres-fix)|
-|-|-|
-|||
-Prompt: 一只小狗蹦蹦跳跳，周围是姹紫嫣红的鲜花，远处是山脉
-|Without LoRA|With LoRA|
-|-|-|
-|||
+### Image Synthesis
+Generate high-resolution images, by breaking the limitation of diffusion models! [`examples/image_synthesis`](./examples/image_synthesis/).
+LoRA fine-tuning is supported in [`examples/train`](./examples/train/).
+|FLUX|Stable Diffusion 3|
+|-|-|
+|||
+|Kolors|Hunyuan-DiT|
+|-|-|
+|||
+|Stable Diffusion|Stable Diffusion XL|
+|-|-|
+|||
 
 ## Usage (in WebUI)
+Create stunning images using the painter, with assistance from AI!
+https://github.com/user-attachments/assets/95265d21-cdd6-4125-a7cb-9fbcf6ceb7b0
+**This video is not rendered in real-time.**
+Before launching the WebUI, please download models to the folder `./models`. See [here](#download-models).
+* `Gradio` version
-```
-python -m streamlit run DiffSynth_Studio.py
-```
+```
+pip install gradio
+```
+```
+python apps/gradio/DiffSynth_Studio.py
+```
+* `Streamlit` version
+```
+pip install streamlit streamlit-drawable-canvas
+```
+```
+python -m streamlit run apps/streamlit/DiffSynth_Studio.py
+```
 https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/93085557-73f3-4eee-a205-9829591ef954
````
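The README's new CogVideoX-5B section above only points at `examples/video_synthesis`. As a rough orientation, text-to-video generation follows the same `ModelManager`/`from_model_manager` pattern used by the pipelines elsewhere on this page; the sketch below is an assumption-laden illustration (the pipeline class name `CogVideoPipeline`, the model ID, the local model path, and the call arguments are not taken from this diff), not the project's official example:

```python
import torch
from diffsynth import ModelManager, download_models, save_video
# Assumption: a CogVideoX pipeline class is exposed from diffsynth, analogous to
# the SD/SDXL/FLUX image pipelines imported in apps/gradio/DiffSynth_Studio.py below.
from diffsynth import CogVideoPipeline

download_models(["CogVideoX-5B"])                   # model ID assumed; see diffsynth/configs/model_config.py
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
model_manager.load_models(["models/CogVideoX-5B"])  # local path assumed
pipe = CogVideoPipeline.from_model_manager(model_manager)

torch.manual_seed(0)  # fixed seed for reproducibility, mirroring the Gradio app
video = pipe(prompt="a corgi running on a beach at sunset", num_inference_steps=50)
save_video(video, "video.mp4", fps=8)
```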
|||||||
252
apps/gradio/DiffSynth_Studio.py
Normal file
252
apps/gradio/DiffSynth_Studio.py
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
import gradio as gr
|
||||||
|
from diffsynth import ModelManager, SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline, FluxImagePipeline
|
||||||
|
import os, torch
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"model_config": {
|
||||||
|
"Stable Diffusion": {
|
||||||
|
"model_folder": "models/stable_diffusion",
|
||||||
|
"pipeline_class": SDImagePipeline,
|
||||||
|
"default_parameters": {
|
||||||
|
"cfg_scale": 7.0,
|
||||||
|
"height": 512,
|
||||||
|
"width": 512,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"Stable Diffusion XL": {
|
||||||
|
"model_folder": "models/stable_diffusion_xl",
|
||||||
|
"pipeline_class": SDXLImagePipeline,
|
||||||
|
"default_parameters": {
|
||||||
|
"cfg_scale": 7.0,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"Stable Diffusion 3": {
|
||||||
|
"model_folder": "models/stable_diffusion_3",
|
||||||
|
"pipeline_class": SD3ImagePipeline,
|
||||||
|
"default_parameters": {
|
||||||
|
"cfg_scale": 7.0,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"Stable Diffusion XL Turbo": {
|
||||||
|
"model_folder": "models/stable_diffusion_xl_turbo",
|
||||||
|
"pipeline_class": SDXLImagePipeline,
|
||||||
|
"default_parameters": {
|
||||||
|
"negative_prompt": "",
|
||||||
|
"cfg_scale": 1.0,
|
||||||
|
"num_inference_steps": 1,
|
||||||
|
"height": 512,
|
||||||
|
"width": 512,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"Kolors": {
|
||||||
|
"model_folder": "models/kolors",
|
||||||
|
"pipeline_class": SDXLImagePipeline,
|
||||||
|
"default_parameters": {
|
||||||
|
"cfg_scale": 7.0,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"HunyuanDiT": {
|
||||||
|
"model_folder": "models/HunyuanDiT",
|
||||||
|
"pipeline_class": HunyuanDiTImagePipeline,
|
||||||
|
"default_parameters": {
|
||||||
|
"cfg_scale": 7.0,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"FLUX": {
|
||||||
|
"model_folder": "models/FLUX",
|
||||||
|
"pipeline_class": FluxImagePipeline,
|
||||||
|
"default_parameters": {
|
||||||
|
"cfg_scale": 1.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"max_num_painter_layers": 8,
|
||||||
|
"max_num_model_cache": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_list(model_type):
|
||||||
|
if model_type is None:
|
||||||
|
return []
|
||||||
|
folder = config["model_config"][model_type]["model_folder"]
|
||||||
|
file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
|
||||||
|
if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
|
||||||
|
file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
|
||||||
|
file_list = sorted(file_list)
|
||||||
|
return file_list
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(model_type, model_path):
|
||||||
|
global model_dict
|
||||||
|
model_key = f"{model_type}:{model_path}"
|
||||||
|
if model_key in model_dict:
|
||||||
|
return model_dict[model_key]
|
||||||
|
model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
|
||||||
|
model_manager = ModelManager()
|
||||||
|
if model_type == "HunyuanDiT":
|
||||||
|
model_manager.load_models([
|
||||||
|
os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
|
||||||
|
os.path.join(model_path, "mt5/pytorch_model.bin"),
|
||||||
|
os.path.join(model_path, "model/pytorch_model_ema.pt"),
|
||||||
|
os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
|
||||||
|
])
|
||||||
|
elif model_type == "Kolors":
|
||||||
|
model_manager.load_models([
|
||||||
|
os.path.join(model_path, "text_encoder"),
|
||||||
|
os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
|
||||||
|
os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
|
||||||
|
])
|
||||||
|
elif model_type == "FLUX":
|
||||||
|
model_manager.torch_dtype = torch.bfloat16
|
||||||
|
file_list = [
|
||||||
|
os.path.join(model_path, "text_encoder/model.safetensors"),
|
||||||
|
os.path.join(model_path, "text_encoder_2"),
|
||||||
|
]
|
||||||
|
for file_name in os.listdir(model_path):
|
||||||
|
if file_name.endswith(".safetensors"):
|
||||||
|
file_list.append(os.path.join(model_path, file_name))
|
||||||
|
model_manager.load_models(file_list)
|
||||||
|
else:
|
||||||
|
model_manager.load_model(model_path)
|
||||||
|
pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
|
||||||
|
while len(model_dict) + 1 > config["max_num_model_cache"]:
|
||||||
|
key = next(iter(model_dict.keys()))
|
||||||
|
model_manager_to_release, _ = model_dict[key]
|
||||||
|
model_manager_to_release.to("cpu")
|
||||||
|
del model_dict[key]
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
model_dict[model_key] = model_manager, pipe
|
||||||
|
return model_manager, pipe
|
||||||
|
|
||||||
|
|
||||||
|
model_dict = {}
|
||||||
|
|
||||||
|
with gr.Blocks() as app:
|
||||||
|
gr.Markdown("# DiffSynth-Studio Painter")
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=382, min_width=100):
|
||||||
|
|
||||||
|
with gr.Accordion(label="Model"):
|
||||||
|
model_type = gr.Dropdown(choices=[i for i in config["model_config"]], label="Model type")
|
||||||
|
model_path = gr.Dropdown(choices=[], interactive=True, label="Model path")
|
||||||
|
|
||||||
|
@gr.on(inputs=model_type, outputs=model_path, triggers=model_type.change)
|
||||||
|
def model_type_to_model_path(model_type):
|
||||||
|
return gr.Dropdown(choices=load_model_list(model_type))
|
||||||
|
|
||||||
|
with gr.Accordion(label="Prompt"):
|
||||||
|
prompt = gr.Textbox(label="Prompt", lines=3)
|
||||||
|
negative_prompt = gr.Textbox(label="Negative prompt", lines=1)
|
||||||
|
cfg_scale = gr.Slider(minimum=1.0, maximum=10.0, value=7.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
|
||||||
|
embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=0.0, step=0.1, interactive=True, label="Embedded guidance scale (only for FLUX)")
|
||||||
|
|
||||||
|
with gr.Accordion(label="Image"):
|
||||||
|
num_inference_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, interactive=True, label="Inference steps")
|
||||||
|
height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
|
||||||
|
width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
|
||||||
|
with gr.Column():
|
||||||
|
use_fixed_seed = gr.Checkbox(value=True, interactive=False, label="Use fixed seed")
|
||||||
|
seed = gr.Number(minimum=0, maximum=10**9, value=0, interactive=True, label="Random seed", show_label=False)
|
||||||
|
|
||||||
|
@gr.on(
|
||||||
|
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
|
||||||
|
outputs=[prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
|
||||||
|
triggers=model_path.change
|
||||||
|
)
|
||||||
|
def model_path_to_default_params(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width):
|
||||||
|
load_model(model_type, model_path)
|
||||||
|
cfg_scale = config["model_config"][model_type]["default_parameters"].get("cfg_scale", cfg_scale)
|
||||||
|
embedded_guidance = config["model_config"][model_type]["default_parameters"].get("embedded_guidance", embedded_guidance)
|
||||||
|
num_inference_steps = config["model_config"][model_type]["default_parameters"].get("num_inference_steps", num_inference_steps)
|
||||||
|
height = config["model_config"][model_type]["default_parameters"].get("height", height)
|
||||||
|
width = config["model_config"][model_type]["default_parameters"].get("width", width)
|
||||||
|
return prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width
|
||||||
|
|
||||||
|
|
||||||
|
with gr.Column(scale=618, min_width=100):
|
||||||
|
with gr.Accordion(label="Painter"):
|
||||||
|
enable_local_prompt_list = []
|
||||||
|
local_prompt_list = []
|
||||||
|
mask_scale_list = []
|
||||||
|
canvas_list = []
|
||||||
|
for painter_layer_id in range(config["max_num_painter_layers"]):
|
||||||
|
with gr.Tab(label=f"Layer {painter_layer_id}"):
|
||||||
|
enable_local_prompt = gr.Checkbox(label="Enable", value=False, key=f"enable_local_prompt_{painter_layer_id}")
|
||||||
|
local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
|
||||||
|
mask_scale = gr.Slider(minimum=0.0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Mask scale", key=f"mask_scale_{painter_layer_id}")
|
||||||
|
canvas = gr.ImageEditor(canvas_size=(512, 1), sources=None, layers=False, interactive=True, image_mode="RGBA",
|
||||||
|
brush=gr.Brush(default_size=100, default_color="#000000", colors=["#000000"]),
|
||||||
|
label="Painter", key=f"canvas_{painter_layer_id}")
|
||||||
|
@gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear, enable_local_prompt.change], show_progress="hidden")
|
||||||
|
def resize_canvas(height, width, canvas):
|
||||||
|
h, w = canvas["background"].shape[:2]
|
||||||
|
if h != height or width != w:
|
||||||
|
return np.ones((height, width, 3), dtype=np.uint8) * 255
|
||||||
|
else:
|
||||||
|
return canvas
|
||||||
|
|
||||||
|
enable_local_prompt_list.append(enable_local_prompt)
|
||||||
|
local_prompt_list.append(local_prompt)
|
||||||
|
mask_scale_list.append(mask_scale)
|
||||||
|
canvas_list.append(canvas)
|
||||||
|
with gr.Accordion(label="Results"):
|
||||||
|
run_button = gr.Button(value="Generate", variant="primary")
|
||||||
|
output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
output_to_painter_button = gr.Button(value="Set as painter's background")
|
||||||
|
with gr.Column():
|
||||||
|
output_to_input_button = gr.Button(value="Set as input image")
|
||||||
|
painter_background = gr.State(None)
|
||||||
|
input_background = gr.State(None)
|
||||||
|
@gr.on(
|
||||||
|
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed] + enable_local_prompt_list + local_prompt_list + mask_scale_list + canvas_list,
|
||||||
|
outputs=[output_image],
|
||||||
|
triggers=run_button.click
|
||||||
|
)
|
||||||
|
def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed, *args, progress=gr.Progress()):
|
||||||
|
_, pipe = load_model(model_type, model_path)
|
||||||
|
input_params = {
|
||||||
|
"prompt": prompt,
|
||||||
|
"negative_prompt": negative_prompt,
|
||||||
|
"cfg_scale": cfg_scale,
|
||||||
|
"num_inference_steps": num_inference_steps,
|
||||||
|
"height": height,
|
||||||
|
"width": width,
|
||||||
|
"progress_bar_cmd": progress.tqdm,
|
||||||
|
}
|
||||||
|
if isinstance(pipe, FluxImagePipeline):
|
||||||
|
input_params["embedded_guidance"] = embedded_guidance
|
||||||
|
enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list = (
|
||||||
|
args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
|
||||||
|
args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
|
||||||
|
args[2 * config["max_num_painter_layers"]: 3 * config["max_num_painter_layers"]],
|
||||||
|
args[3 * config["max_num_painter_layers"]: 4 * config["max_num_painter_layers"]]
|
||||||
|
)
|
||||||
|
local_prompts, masks, mask_scales = [], [], []
|
||||||
|
for enable_local_prompt, local_prompt, mask_scale, canvas in zip(
|
||||||
|
enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list
|
||||||
|
):
|
||||||
|
if enable_local_prompt:
|
||||||
|
local_prompts.append(local_prompt)
|
||||||
|
masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
|
||||||
|
mask_scales.append(mask_scale)
|
||||||
|
input_params.update({
|
||||||
|
"local_prompts": local_prompts,
|
||||||
|
"masks": masks,
|
||||||
|
"mask_scales": mask_scales,
|
||||||
|
})
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
image = pipe(**input_params)
|
||||||
|
return image
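# Illustrative sketch (added, not part of the original app): the same pipeline call can be
# scripted without the Gradio UI. The model type/path and the prompt below are placeholder
# assumptions; the keyword arguments mirror generate_image() above.
def generate_image_script_example():
    _, pipe = load_model("FLUX", "FLUX.1-dev")
    torch.manual_seed(42)
    return pipe(
        prompt="a cat sitting on a windowsill",
        cfg_scale=3.0, embedded_guidance=3.5,
        num_inference_steps=30, height=1024, width=1024,
    )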
|
||||||
|
|
||||||
|
@gr.on(inputs=[output_image] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
|
||||||
|
def send_output_to_painter_background(output_image, *canvas_list):
|
||||||
|
for canvas in canvas_list:
|
||||||
|
h, w = canvas["background"].shape[:2]
|
||||||
|
canvas["background"] = output_image.resize((w, h))
|
||||||
|
return tuple(canvas_list)
|
||||||
|
app.launch()
|
apps/gradio/entity_level_control.py (new file, 390 lines)
@@ -0,0 +1,390 @@
import os
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
import random
|
||||||
|
import json
|
||||||
|
import gradio as gr
|
||||||
|
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
|
|
||||||
|
|
||||||
|
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/entity_control/*")
|
||||||
|
example_json = 'data/examples/eligen/entity_control/ui_examples.json'
|
||||||
|
with open(example_json, 'r') as f:
|
||||||
|
examples = json.load(f)['examples']
|
||||||
|
|
||||||
|
for idx in range(len(examples)):
|
||||||
|
example_id = examples[idx]['example_id']
|
||||||
|
entity_prompts = examples[idx]['local_prompt_list']
|
||||||
|
examples[idx]['mask_lists'] = [Image.open(f"data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
|
||||||
|
|
||||||
|
def create_canvas_data(background, masks):
|
||||||
|
if background.shape[-1] == 3:
|
||||||
|
background = np.dstack([background, np.full(background.shape[:2], 255, dtype=np.uint8)])
|
||||||
|
layers = []
|
||||||
|
for mask in masks:
|
||||||
|
if mask is not None:
|
||||||
|
mask_single_channel = mask if mask.ndim == 2 else mask[..., 0]
|
||||||
|
layer = np.zeros((mask_single_channel.shape[0], mask_single_channel.shape[1], 4), dtype=np.uint8)
|
||||||
|
layer[..., -1] = mask_single_channel
|
||||||
|
layers.append(layer)
|
||||||
|
else:
|
||||||
|
layers.append(np.zeros_like(background))
|
||||||
|
|
||||||
|
composite = background.copy()
|
||||||
|
for layer in layers:
|
||||||
|
if layer.size > 0:
|
||||||
|
composite = np.where(layer[..., -1:] > 0, layer, composite)
|
||||||
|
return {
|
||||||
|
"background": background,
|
||||||
|
"layers": layers,
|
||||||
|
"composite": composite,
|
||||||
|
}
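# Example (added for illustration): build an editor value with one square mask layer.
# The 512x512 size and the rectangle coordinates are arbitrary placeholders.
def _demo_canvas_data():
    background = np.full((512, 512, 4), 255, dtype=np.uint8)  # white RGBA background
    mask = np.zeros((512, 512), dtype=np.uint8)
    mask[128:384, 128:384] = 255  # single-channel mask: white square marks the entity region
    return create_canvas_data(background, [mask])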
|
||||||
|
|
||||||
|
def load_example(load_example_button):
|
||||||
|
example_idx = int(load_example_button.split()[-1]) - 1
|
||||||
|
example = examples[example_idx]
|
||||||
|
result = [
|
||||||
|
50,
|
||||||
|
example["global_prompt"],
|
||||||
|
example["negative_prompt"],
|
||||||
|
example["seed"],
|
||||||
|
*example["local_prompt_list"],
|
||||||
|
]
|
||||||
|
num_entities = len(example["local_prompt_list"])
|
||||||
|
result += [""] * (config["max_num_painter_layers"] - num_entities)
|
||||||
|
masks = []
|
||||||
|
for mask in example["mask_lists"]:
|
||||||
|
mask_single_channel = np.array(mask.convert("L"))
|
||||||
|
masks.append(mask_single_channel)
|
||||||
|
for _ in range(config["max_num_painter_layers"] - len(masks)):
|
||||||
|
blank_mask = np.zeros_like(masks[0]) if masks else np.zeros((512, 512), dtype=np.uint8)
|
||||||
|
masks.append(blank_mask)
|
||||||
|
background = np.ones((masks[0].shape[0], masks[0].shape[1], 4), dtype=np.uint8) * 255
|
||||||
|
canvas_data_list = []
|
||||||
|
for mask in masks:
|
||||||
|
canvas_data = create_canvas_data(background, [mask])
|
||||||
|
canvas_data_list.append(canvas_data)
|
||||||
|
result.extend(canvas_data_list)
|
||||||
|
return result
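# Note (added for clarity): the list returned above must stay aligned with the outputs of
# load_example_button.click() near the bottom of this file: [num_inference_steps, prompt,
# negative_prompt, seed], then one local prompt per painter layer, then one canvas dict per layer.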
|
||||||
|
|
||||||
|
def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'):
|
||||||
|
save_dir = os.path.join('workdirs/tmp_mask', random_dir)
|
||||||
|
print(f'save to {save_dir}')
|
||||||
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
for i, mask in enumerate(masks):
|
||||||
|
save_path = os.path.join(save_dir, f'{i}.png')
|
||||||
|
mask.save(save_path)
|
||||||
|
sample = {
|
||||||
|
"global_prompt": global_prompt,
|
||||||
|
"mask_prompts": mask_prompts,
|
||||||
|
"seed": seed,
|
||||||
|
}
|
||||||
|
with open(os.path.join(save_dir, f"prompts.json"), 'w') as f:
|
||||||
|
json.dump(sample, f, indent=4)
|
||||||
|
|
||||||
|
def visualize_masks(image, masks, mask_prompts, font_size=35, use_random_colors=False):
|
||||||
|
# Create a blank image for overlays
|
||||||
|
overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
|
||||||
|
colors = [
|
||||||
|
(165, 238, 173, 80),
|
||||||
|
(76, 102, 221, 80),
|
||||||
|
(221, 160, 77, 80),
|
||||||
|
(204, 93, 71, 80),
|
||||||
|
(145, 187, 149, 80),
|
||||||
|
(134, 141, 172, 80),
|
||||||
|
(157, 137, 109, 80),
|
||||||
|
(153, 104, 95, 80),
|
||||||
|
(165, 238, 173, 80),
|
||||||
|
(76, 102, 221, 80),
|
||||||
|
(221, 160, 77, 80),
|
||||||
|
(204, 93, 71, 80),
|
||||||
|
(145, 187, 149, 80),
|
||||||
|
(134, 141, 172, 80),
|
||||||
|
(157, 137, 109, 80),
|
||||||
|
(153, 104, 95, 80),
|
||||||
|
]
|
||||||
|
# Generate random colors for each mask
|
||||||
|
if use_random_colors:
|
||||||
|
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
|
||||||
|
# Font settings
|
||||||
|
try:
|
||||||
|
font = ImageFont.truetype("arial", font_size) # Adjust as needed
|
||||||
|
except IOError:
|
||||||
|
font = ImageFont.load_default(font_size)
|
||||||
|
# Overlay each mask onto the overlay image
|
||||||
|
for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
|
||||||
|
if mask is None:
|
||||||
|
continue
|
||||||
|
# Convert mask to RGBA mode
|
||||||
|
mask_rgba = mask.convert('RGBA')
|
||||||
|
mask_data = mask_rgba.getdata()
|
||||||
|
new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
|
||||||
|
mask_rgba.putdata(new_data)
|
||||||
|
# Draw the mask prompt text on the mask
|
||||||
|
draw = ImageDraw.Draw(mask_rgba)
|
||||||
|
mask_bbox = mask.getbbox() # Get the bounding box of the mask
|
||||||
|
if mask_bbox is None:
|
||||||
|
continue
|
||||||
|
text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
|
||||||
|
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
|
||||||
|
# Alpha composite the overlay with this mask
|
||||||
|
overlay = Image.alpha_composite(overlay, mask_rgba)
|
||||||
|
# Composite the overlay onto the original image
|
||||||
|
result = Image.alpha_composite(image.convert('RGBA'), overlay)
|
||||||
|
return result
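# Example usage (added for illustration; mask_a and mask_b are hypothetical PIL masks):
# annotated = visualize_masks(image, [mask_a, mask_b], ["a person", "a dog"])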
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"model_config": {
|
||||||
|
"FLUX": {
|
||||||
|
"model_folder": "models/FLUX",
|
||||||
|
"pipeline_class": FluxImagePipeline,
|
||||||
|
"default_parameters": {
|
||||||
|
"cfg_scale": 3.0,
|
||||||
|
"embedded_guidance": 3.5,
|
||||||
|
"num_inference_steps": 30,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"max_num_painter_layers": 8,
|
||||||
|
"max_num_model_cache": 1,
|
||||||
|
}
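# Note (added for clarity): "max_num_painter_layers" controls how many "Entity" tabs the UI
# builds below and how the flat *args tuple is sliced back into prompt/canvas lists inside
# generate_image().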
|
||||||
|
|
||||||
|
model_dict = {}
|
||||||
|
|
||||||
|
def load_model(model_type='FLUX', model_path='FLUX.1-dev'):
|
||||||
|
global model_dict
|
||||||
|
model_key = f"{model_type}:{model_path}"
|
||||||
|
if model_key in model_dict:
|
||||||
|
return model_dict[model_key]
|
||||||
|
model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
|
||||||
|
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
|
||||||
|
model_manager.load_lora(
|
||||||
|
download_customized_models(
|
||||||
|
model_id="DiffSynth-Studio/Eligen",
|
||||||
|
origin_file_path="model_bf16.safetensors",
|
||||||
|
local_dir="models/lora/entity_control",
|
||||||
|
),
|
||||||
|
lora_alpha=1,
|
||||||
|
)
|
||||||
|
pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
|
||||||
|
model_dict[model_key] = model_manager, pipe
|
||||||
|
return model_manager, pipe
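# Illustrative sketch (added, not part of the original app): driving the EliGen-augmented
# pipeline directly. The prompt, mask rectangle and sampling settings are placeholder
# assumptions; the keyword arguments mirror generate_image() below.
def eligen_script_example():
    _, pipe = load_model()
    mask = Image.new("RGB", (1024, 1024), "black")
    ImageDraw.Draw(mask).rectangle([256, 256, 768, 768], fill="white")  # white = entity region
    torch.manual_seed(0)
    return pipe(
        prompt="a person stands by the river",
        cfg_scale=3.0, embedded_guidance=3.5, num_inference_steps=30,
        height=1024, width=1024,
        eligen_entity_prompts=["person"], eligen_entity_masks=[mask],
    )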
|
||||||
|
|
||||||
|
|
||||||
|
with gr.Blocks() as app:
|
||||||
|
gr.Markdown(
|
||||||
|
"""## EliGen: Entity-Level Controllable Text-to-Image Model
|
||||||
|
1. On the left, input the **global prompt** for the overall image, such as "a person stands by the river."
|
||||||
|
2. On the right, input the **local prompt** for each entity, such as "person," and draw the corresponding mask in the **Entity Mask Painter**. Generally, solid rectangular masks yield better results.
|
||||||
|
3. Click the **Generate** button to create the image. By selecting different **random seeds**, you can generate diverse images.
|
||||||
|
4. **You can directly click the "Load Example" button on any sample at the bottom to load example inputs.**
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
loading_status = gr.Textbox(label="Loading Model...", value="Loading model... Please wait...", visible=True)
|
||||||
|
main_interface = gr.Column(visible=False)
|
||||||
|
|
||||||
|
def initialize_model():
|
||||||
|
try:
|
||||||
|
load_model()
|
||||||
|
return {
|
||||||
|
loading_status: gr.update(value="Model loaded successfully!", visible=False),
|
||||||
|
main_interface: gr.update(visible=True),
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
print(f'Failed to load model with error: {e}')
|
||||||
|
return {
|
||||||
|
loading_status: gr.update(value=f"Failed to load model: {str(e)}", visible=True),
|
||||||
|
main_interface: gr.update(visible=True),
|
||||||
|
}
|
||||||
|
|
||||||
|
app.load(initialize_model, inputs=None, outputs=[loading_status, main_interface])
|
||||||
|
|
||||||
|
with main_interface:
|
||||||
|
with gr.Row():
|
||||||
|
local_prompt_list = []
|
||||||
|
canvas_list = []
|
||||||
|
random_mask_dir = gr.State(f'{random.randint(0, 1000000):08d}')
|
||||||
|
with gr.Column(scale=382, min_width=100):
|
||||||
|
model_type = gr.State('FLUX')
|
||||||
|
model_path = gr.State('FLUX.1-dev')
|
||||||
|
with gr.Accordion(label="Global prompt"):
|
||||||
|
prompt = gr.Textbox(label="Global Prompt", lines=3)
|
||||||
|
negative_prompt = gr.Textbox(label="Negative prompt", value="worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw, blur,", lines=3)
|
||||||
|
with gr.Accordion(label="Inference Options", open=True):
|
||||||
|
seed = gr.Number(minimum=0, maximum=10**9, value=42, interactive=True, label="Random seed", show_label=True)
|
||||||
|
num_inference_steps = gr.Slider(minimum=1, maximum=100, value=30, step=1, interactive=True, label="Inference steps")
|
||||||
|
cfg_scale = gr.Slider(minimum=2.0, maximum=10.0, value=3.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
|
||||||
|
embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=3.5, step=0.1, interactive=True, label="Embedded guidance scale")
|
||||||
|
height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
|
||||||
|
width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
|
||||||
|
with gr.Accordion(label="Inpaint Input Image", open=False):
|
||||||
|
input_image = gr.Image(sources=None, show_label=False, interactive=True, type="pil")
|
||||||
|
background_weight = gr.Slider(minimum=0.0, maximum=1000., value=0., step=1, interactive=False, label="background_weight", visible=False)
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
reset_input_button = gr.Button(value="Reset Inpaint Input")
|
||||||
|
send_input_to_painter = gr.Button(value="Set as painter's background")
|
||||||
|
@gr.on(inputs=[input_image], outputs=[input_image], triggers=reset_input_button.click)
|
||||||
|
def reset_input_image(input_image):
|
||||||
|
return None
|
||||||
|
|
||||||
|
with gr.Column(scale=618, min_width=100):
|
||||||
|
with gr.Accordion(label="Entity Painter"):
|
||||||
|
for painter_layer_id in range(config["max_num_painter_layers"]):
|
||||||
|
with gr.Tab(label=f"Entity {painter_layer_id}"):
|
||||||
|
local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
|
||||||
|
canvas = gr.ImageEditor(
|
||||||
|
canvas_size=(512, 512),
|
||||||
|
sources=None,
|
||||||
|
layers=False,
|
||||||
|
interactive=True,
|
||||||
|
image_mode="RGBA",
|
||||||
|
brush=gr.Brush(
|
||||||
|
default_size=50,
|
||||||
|
default_color="#000000",
|
||||||
|
colors=["#000000"],
|
||||||
|
),
|
||||||
|
label="Entity Mask Painter",
|
||||||
|
key=f"canvas_{painter_layer_id}",
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
)
|
||||||
|
@gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear], show_progress="hidden")
|
||||||
|
def resize_canvas(height, width, canvas):
|
||||||
|
h, w = canvas["background"].shape[:2]
|
||||||
|
if h != height or width != w:
|
||||||
|
return np.ones((height, width, 3), dtype=np.uint8) * 255
|
||||||
|
else:
|
||||||
|
return canvas
|
||||||
|
local_prompt_list.append(local_prompt)
|
||||||
|
canvas_list.append(canvas)
|
||||||
|
with gr.Accordion(label="Results"):
|
||||||
|
run_button = gr.Button(value="Generate", variant="primary")
|
||||||
|
output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
output_to_painter_button = gr.Button(value="Set as painter's background")
|
||||||
|
with gr.Column():
|
||||||
|
return_with_mask = gr.Checkbox(value=False, interactive=True, label="show result with mask painting")
|
||||||
|
output_to_input_button = gr.Button(value="Set as input image", visible=False, interactive=False)
|
||||||
|
real_output = gr.State(None)
|
||||||
|
mask_out = gr.State(None)
|
||||||
|
|
||||||
|
@gr.on(
|
||||||
|
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir] + local_prompt_list + canvas_list,
|
||||||
|
outputs=[output_image, real_output, mask_out],
|
||||||
|
triggers=run_button.click
|
||||||
|
)
|
||||||
|
def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()):
|
||||||
|
_, pipe = load_model(model_type, model_path)
|
||||||
|
input_params = {
|
||||||
|
"prompt": prompt,
|
||||||
|
"negative_prompt": negative_prompt,
|
||||||
|
"cfg_scale": cfg_scale,
|
||||||
|
"num_inference_steps": num_inference_steps,
|
||||||
|
"height": height,
|
||||||
|
"width": width,
|
||||||
|
"progress_bar_cmd": progress.tqdm,
|
||||||
|
}
|
||||||
|
if isinstance(pipe, FluxImagePipeline):
|
||||||
|
input_params["embedded_guidance"] = embedded_guidance
|
||||||
|
if input_image is not None:
|
||||||
|
input_params["input_image"] = input_image.resize((width, height)).convert("RGB")
|
||||||
|
input_params["enable_eligen_inpaint"] = True
|
||||||
|
|
||||||
|
local_prompt_list, canvas_list = (
|
||||||
|
args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
|
||||||
|
args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
|
||||||
|
)
|
||||||
|
local_prompts, masks = [], []
|
||||||
|
for local_prompt, canvas in zip(local_prompt_list, canvas_list):
|
||||||
|
if isinstance(local_prompt, str) and len(local_prompt) > 0:
|
||||||
|
local_prompts.append(local_prompt)
|
||||||
|
masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
|
||||||
|
entity_masks = None if len(masks) == 0 else masks
|
||||||
|
entity_prompts = None if len(local_prompts) == 0 else local_prompts
|
||||||
|
input_params.update({
|
||||||
|
"eligen_entity_prompts": entity_prompts,
|
||||||
|
"eligen_entity_masks": entity_masks,
|
||||||
|
})
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
# save_mask_prompts(masks, local_prompts, prompt, seed, random_mask_dir)
|
||||||
|
image = pipe(**input_params)
|
||||||
|
masks = [mask.resize(image.size) for mask in masks]
|
||||||
|
image_with_mask = visualize_masks(image, masks, local_prompts)
|
||||||
|
|
||||||
|
real_output = gr.State(image)
|
||||||
|
mask_out = gr.State(image_with_mask)
|
||||||
|
|
||||||
|
if return_with_mask:
|
||||||
|
return image_with_mask, real_output, mask_out
|
||||||
|
return image, real_output, mask_out
|
||||||
|
|
||||||
|
@gr.on(inputs=[input_image] + canvas_list, outputs=canvas_list, triggers=send_input_to_painter.click)
|
||||||
|
def send_input_to_painter_background(input_image, *canvas_list):
|
||||||
|
if input_image is None:
|
||||||
|
return tuple(canvas_list)
|
||||||
|
for canvas in canvas_list:
|
||||||
|
h, w = canvas["background"].shape[:2]
|
||||||
|
canvas["background"] = input_image.resize((w, h))
|
||||||
|
return tuple(canvas_list)
|
||||||
|
@gr.on(inputs=[real_output] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
|
||||||
|
def send_output_to_painter_background(real_output, *canvas_list):
|
||||||
|
if real_output is None:
|
||||||
|
return tuple(canvas_list)
|
||||||
|
for canvas in canvas_list:
|
||||||
|
h, w = canvas["background"].shape[:2]
|
||||||
|
canvas["background"] = real_output.value.resize((w, h))
|
||||||
|
return tuple(canvas_list)
|
||||||
|
@gr.on(inputs=[return_with_mask, real_output, mask_out], outputs=[output_image], triggers=[return_with_mask.change], show_progress="hidden")
|
||||||
|
def show_output(return_with_mask, real_output, mask_out):
|
||||||
|
if return_with_mask:
|
||||||
|
return mask_out.value
|
||||||
|
else:
|
||||||
|
return real_output.value
|
||||||
|
@gr.on(inputs=[real_output], outputs=[input_image], triggers=output_to_input_button.click)
|
||||||
|
def send_output_to_pipe_input(real_output):
|
||||||
|
return real_output.value
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
gr.Markdown("## Examples")
|
||||||
|
for i in range(0, len(examples), 2):
|
||||||
|
with gr.Row():
|
||||||
|
if i < len(examples):
|
||||||
|
example = examples[i]
|
||||||
|
with gr.Column():
|
||||||
|
example_image = gr.Image(
|
||||||
|
value=f"data/examples/eligen/entity_control/example_{example['example_id']}/example_image.png",
|
||||||
|
label=example["description"],
|
||||||
|
interactive=False,
|
||||||
|
width=1024,
|
||||||
|
height=512
|
||||||
|
)
|
||||||
|
load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
|
||||||
|
load_example_button.click(
|
||||||
|
load_example,
|
||||||
|
inputs=[load_example_button],
|
||||||
|
outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
|
||||||
|
)
|
||||||
|
|
||||||
|
if i + 1 < len(examples):
|
||||||
|
example = examples[i + 1]
|
||||||
|
with gr.Column():
|
||||||
|
example_image = gr.Image(
|
||||||
|
value=f"data/examples/eligen/entity_control/example_{example['example_id']}/example_image.png",
|
||||||
|
label=example["description"],
|
||||||
|
interactive=False,
|
||||||
|
width=1024,
|
||||||
|
height=512
|
||||||
|
)
|
||||||
|
load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
|
||||||
|
load_example_button.click(
|
||||||
|
load_example,
|
||||||
|
inputs=[load_example_button],
|
||||||
|
outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
|
||||||
|
)
|
||||||
|
app.config["show_progress"] = "hidden"
|
||||||
|
app.launch()
|
||||||
@@ -1,7 +1,7 @@
 # Set web page format
 import streamlit as st
 st.set_page_config(layout="wide")
-# Diasble virtual VRAM on windows system
+# Disable virtual VRAM on windows system
 import torch
 torch.cuda.set_per_process_memory_fraction(0.999, 0)
 
@@ -1,11 +1,11 @@
-import torch, os, io
+import torch, os, io, json, time
 import numpy as np
 from PIL import Image
 import streamlit as st
 st.set_page_config(layout="wide")
 from streamlit_drawable_canvas import st_canvas
 from diffsynth.models import ModelManager
-from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline, HunyuanDiTImagePipeline
+from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline, FluxImagePipeline
 from diffsynth.data.video import crop_and_resize
 
 
@@ -20,6 +20,11 @@ config = {
         "pipeline_class": SDXLImagePipeline,
         "fixed_parameters": {}
     },
+    "Stable Diffusion 3": {
+        "model_folder": "models/stable_diffusion_3",
+        "pipeline_class": SD3ImagePipeline,
+        "fixed_parameters": {}
+    },
     "Stable Diffusion XL Turbo": {
         "model_folder": "models/stable_diffusion_xl_turbo",
         "pipeline_class": SDXLImagePipeline,
@@ -31,6 +36,11 @@ config = {
             "width": 512,
         }
     },
+    "Kolors": {
+        "model_folder": "models/kolors",
+        "pipeline_class": SDXLImagePipeline,
+        "fixed_parameters": {}
+    },
     "HunyuanDiT": {
         "model_folder": "models/HunyuanDiT",
         "pipeline_class": HunyuanDiTImagePipeline,
@@ -39,13 +49,20 @@ config = {
             "width": 1024,
         }
     },
+    "FLUX": {
+        "model_folder": "models/FLUX",
+        "pipeline_class": FluxImagePipeline,
+        "fixed_parameters": {
+            "cfg_scale": 1.0,
+        }
+    },
 }
 
 
 def load_model_list(model_type):
     folder = config[model_type]["model_folder"]
     file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
-    if model_type == "HunyuanDiT":
+    if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
         file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
     file_list = sorted(file_list)
     return file_list
@@ -69,6 +86,22 @@ def load_model(model_type, model_path):
             os.path.join(model_path, "model/pytorch_model_ema.pt"),
             os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
         ])
+    elif model_type == "Kolors":
+        model_manager.load_models([
+            os.path.join(model_path, "text_encoder"),
+            os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
+            os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
+        ])
+    elif model_type == "FLUX":
+        model_manager.torch_dtype = torch.bfloat16
+        file_list = [
+            os.path.join(model_path, "text_encoder/model.safetensors"),
+            os.path.join(model_path, "text_encoder_2"),
+        ]
+        for file_name in os.listdir(model_path):
+            if file_name.endswith(".safetensors"):
+                file_list.append(os.path.join(model_path, file_name))
+        model_manager.load_models(file_list)
     else:
         model_manager.load_model(model_path)
     pipeline = config[model_type]["pipeline_class"].from_model_manager(model_manager)
@@ -239,6 +272,48 @@ with column_input:
         key="canvas"
     )
+
+    num_painter_layer = st.number_input("Number of painter layers", min_value=0, max_value=10, step=1, value=0)
+    local_prompts, masks, mask_scales = [], [], []
+    white_board = Image.fromarray(np.ones((512, 512, 3), dtype=np.uint8) * 255)
+    painter_layers_json_data = []
+    for painter_tab_id in range(num_painter_layer):
+        with st.expander(f"Painter layer {painter_tab_id}", expanded=True):
+            enable_local_prompt = st.checkbox(f"Enable prompt {painter_tab_id}", value=True)
+            local_prompt = st.text_area(f"Prompt {painter_tab_id}")
+            mask_scale = st.slider(f"Mask scale {painter_tab_id}", min_value=0.0, max_value=3.0, value=1.0)
+            stroke_width = st.slider(f"Stroke width {painter_tab_id}", min_value=1, max_value=300, value=100)
+            canvas_result_local = st_canvas(
+                fill_color="#000000",
+                stroke_width=stroke_width,
+                stroke_color="#000000",
+                background_color="rgba(255, 255, 255, 0)",
+                background_image=white_board,
+                update_streamlit=True,
+                height=512,
+                width=512,
+                drawing_mode="freedraw",
+                key=f"canvas_{painter_tab_id}"
+            )
+            if canvas_result_local.json_data is not None:
+                painter_layers_json_data.append(canvas_result_local.json_data.copy())
+                painter_layers_json_data[-1]["prompt"] = local_prompt
+            if enable_local_prompt:
+                local_prompts.append(local_prompt)
+                if canvas_result_local.image_data is not None:
+                    mask = apply_stroke_to_image(canvas_result_local.image_data, white_board)
+                else:
+                    mask = white_board
+                mask = Image.fromarray(255 - np.array(mask))
+                masks.append(mask)
+                mask_scales.append(mask_scale)
+    save_painter_layers = st.button("Save painter layers")
+    if save_painter_layers:
+        os.makedirs("data/painter_layers", exist_ok=True)
+        json_file_path = f"data/painter_layers/{time.time_ns()}.json"
+        with open(json_file_path, "w") as f:
+            json.dump(painter_layers_json_data, f, indent=4)
+        st.markdown(f"Painter layers are saved in {json_file_path}.")
 
 
 with column_output:
     run_button = st.button("Generate image", type="primary")
@@ -266,6 +341,7 @@ with column_output:
         progress_bar_st = st.progress(0.0)
         image = pipeline(
             prompt, negative_prompt=negative_prompt,
+            local_prompts=local_prompts, masks=masks, mask_scales=mask_scales,
             cfg_scale=cfg_scale, num_inference_steps=num_inference_steps,
             height=height, width=width,
             input_image=input_image, denoising_strength=denoising_strength,
@@ -1,6 +1,6 @@
 from .data import *
 from .models import *
-from .prompts import *
+from .prompters import *
 from .schedulers import *
 from .pipelines import *
 from .controlnets import *

diffsynth/configs/__init__.py (new file, 0 lines)
diffsynth/configs/model_config.py (new file, 823 lines)
@@ -0,0 +1,823 @@
from typing_extensions import Literal, TypeAlias
|
||||||
|
|
||||||
|
from ..models.sd_text_encoder import SDTextEncoder
|
||||||
|
from ..models.sd_unet import SDUNet
|
||||||
|
from ..models.sd_vae_encoder import SDVAEEncoder
|
||||||
|
from ..models.sd_vae_decoder import SDVAEDecoder
|
||||||
|
|
||||||
|
from ..models.sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
||||||
|
from ..models.sdxl_unet import SDXLUNet
|
||||||
|
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
|
||||||
|
from ..models.sdxl_vae_encoder import SDXLVAEEncoder
|
||||||
|
|
||||||
|
from ..models.sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
|
||||||
|
from ..models.sd3_dit import SD3DiT
|
||||||
|
from ..models.sd3_vae_decoder import SD3VAEDecoder
|
||||||
|
from ..models.sd3_vae_encoder import SD3VAEEncoder
|
||||||
|
|
||||||
|
from ..models.sd_controlnet import SDControlNet
|
||||||
|
from ..models.sdxl_controlnet import SDXLControlNetUnion
|
||||||
|
|
||||||
|
from ..models.sd_motion import SDMotionModel
|
||||||
|
from ..models.sdxl_motion import SDXLMotionModel
|
||||||
|
|
||||||
|
from ..models.svd_image_encoder import SVDImageEncoder
|
||||||
|
from ..models.svd_unet import SVDUNet
|
||||||
|
from ..models.svd_vae_decoder import SVDVAEDecoder
|
||||||
|
from ..models.svd_vae_encoder import SVDVAEEncoder
|
||||||
|
|
||||||
|
from ..models.sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
|
||||||
|
from ..models.sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
||||||
|
|
||||||
|
from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
||||||
|
from ..models.hunyuan_dit import HunyuanDiT
|
||||||
|
|
||||||
|
from ..models.flux_dit import FluxDiT
|
||||||
|
from ..models.flux_text_encoder import FluxTextEncoder2
|
||||||
|
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
||||||
|
from ..models.flux_controlnet import FluxControlNet
|
||||||
|
from ..models.flux_ipadapter import FluxIpAdapter
|
||||||
|
from ..models.flux_infiniteyou import InfiniteYouImageProjector
|
||||||
|
|
||||||
|
from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
|
||||||
|
from ..models.cog_dit import CogDiT
|
||||||
|
|
||||||
|
from ..models.omnigen import OmniGenTransformer
|
||||||
|
|
||||||
|
from ..models.hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
|
||||||
|
from ..models.hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder
|
||||||
|
|
||||||
|
from ..extensions.RIFE import IFNet
|
||||||
|
from ..extensions.ESRGAN import RRDBNet
|
||||||
|
|
||||||
|
from ..models.hunyuan_video_dit import HunyuanVideoDiT
|
||||||
|
|
||||||
|
from ..models.stepvideo_vae import StepVideoVAE
|
||||||
|
from ..models.stepvideo_dit import StepVideoModel
|
||||||
|
|
||||||
|
from ..models.wan_video_dit import WanModel
|
||||||
|
from ..models.wan_video_text_encoder import WanTextEncoder
|
||||||
|
from ..models.wan_video_image_encoder import WanImageEncoder
|
||||||
|
from ..models.wan_video_vae import WanVideoVAE
|
||||||
|
from ..models.wan_video_motion_controller import WanMotionControllerModel
|
||||||
|
from ..models.wan_video_vace import VaceWanModel
|
||||||
|
|
||||||
|
from ..models.step1x_connector import Qwen2Connector
|
||||||
|
|
||||||
|
from ..models.flux_value_control import SingleValueEncoder
|
||||||
|
|
||||||
|
|
||||||
|
model_loader_configs = [
|
||||||
|
# These configs are provided for detecting model type automatically.
|
||||||
|
# The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
|
||||||
|
(None, "091b0e30e77c76626b3ba62acdf95343", ["sd_controlnet"], [SDControlNet], "civitai"),
|
||||||
|
(None, "4a6c8306a27d916dea81263c8c88f450", ["hunyuan_dit_clip_text_encoder"], [HunyuanDiTCLIPTextEncoder], "civitai"),
|
||||||
|
(None, "f4aec400fe394297961218c768004521", ["hunyuan_dit"], [HunyuanDiT], "civitai"),
|
||||||
|
(None, "9e6e58043a5a2e332803ed42f6ee7181", ["hunyuan_dit_t5_text_encoder"], [HunyuanDiTT5TextEncoder], "civitai"),
|
||||||
|
(None, "13115dd45a6e1c39860f91ab073b8a78", ["sdxl_vae_encoder", "sdxl_vae_decoder"], [SDXLVAEEncoder, SDXLVAEDecoder], "diffusers"),
|
||||||
|
(None, "d78aa6797382a6d455362358a3295ea9", ["sd_ipadapter_clip_image_encoder"], [IpAdapterCLIPImageEmbedder], "diffusers"),
|
||||||
|
(None, "e291636cc15e803186b47404262ef812", ["sd_ipadapter"], [SDIpAdapter], "civitai"),
|
||||||
|
(None, "399c81f2f8de8d1843d0127a00f3c224", ["sdxl_ipadapter_clip_image_encoder"], [IpAdapterXLCLIPImageEmbedder], "diffusers"),
|
||||||
|
(None, "a64eac9aa0db4b9602213bc0131281c7", ["sdxl_ipadapter"], [SDXLIpAdapter], "civitai"),
|
||||||
|
(None, "52817e4fdd89df154f02749ca6f692ac", ["sdxl_unet"], [SDXLUNet], "diffusers"),
|
||||||
|
(None, "03343c606f16d834d6411d0902b53636", ["sd_text_encoder", "sd_unet", "sd_vae_decoder", "sd_vae_encoder"], [SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder], "civitai"),
|
||||||
|
(None, "d4ba77a7ece070679b4a987f58f201e9", ["sd_text_encoder"], [SDTextEncoder], "civitai"),
|
||||||
|
(None, "d0c89e55c5a57cf3981def0cb1c9e65a", ["sd_vae_decoder", "sd_vae_encoder"], [SDVAEDecoder, SDVAEEncoder], "civitai"),
|
||||||
|
(None, "3926bf373b39a67eeafd7901478a47a7", ["sd_unet"], [SDUNet], "civitai"),
|
||||||
|
(None, "1e0c39ec176b9007c05f76d52b554a4d", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
||||||
|
(None, "d9e0290829ba8d98e28e1a2b1407db4a", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
||||||
|
(None, "5072d0b24e406b49507abe861cf97691", ["sd3_text_encoder_3"], [SD3TextEncoder3], "civitai"),
|
||||||
|
(None, "4cf64a799d04260df438c6f33c9a047e", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"),
|
||||||
|
(None, "d9b008a867c498ab12ad24042eff8e3f", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"), # SDXL-Turbo
|
||||||
|
(None, "025bb7452e531a3853d951d77c63f032", ["sdxl_text_encoder", "sdxl_text_encoder_2"], [SDXLTextEncoder, SDXLTextEncoder2], "civitai"),
|
||||||
|
(None, "298997b403a4245c04102c9f36aac348", ["sdxl_unet"], [SDXLUNet], "civitai"),
|
||||||
|
(None, "2a07abce74b4bdc696b76254ab474da6", ["svd_image_encoder", "svd_unet", "svd_vae_decoder", "svd_vae_encoder"], [SVDImageEncoder, SVDUNet, SVDVAEDecoder, SVDVAEEncoder], "civitai"),
|
||||||
|
(None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
|
||||||
|
(None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
|
||||||
|
(None, "31d2d9614fba60511fc9bf2604aa01f7", ["sdxl_controlnet"], [SDXLControlNetUnion], "diffusers"),
|
||||||
|
(None, "94eefa3dac9cec93cb1ebaf1747d7b78", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
|
||||||
|
(None, "1aafa3cc91716fb6b300cc1cd51b85a3", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "diffusers"),
|
||||||
|
(None, "21ea55f476dfc4fd135587abb59dfe5d", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "civitai"),
|
||||||
|
(None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai"),
|
||||||
|
(None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
|
||||||
|
(None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"),
|
||||||
|
(None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
|
||||||
|
(None, "d02f41c13549fa5093d3521f62a5570a", ["flux_dit"], [FluxDiT], "civitai"),
|
||||||
|
(None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
|
||||||
|
(None, "3ede90c44b2c161240b659f3b8393c9d", ["flux_value_controller"], [SingleValueEncoder], "civitai"),
|
||||||
|
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
|
||||||
|
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
|
||||||
|
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
|
||||||
|
(None, "61cbcbc7ac11f169c5949223efa960d1", ["omnigen_transformer"], [OmniGenTransformer], "diffusers"),
|
||||||
|
(None, "78d18b9101345ff695f312e7e62538c0", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||||
|
(None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||||
|
(None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||||
|
(None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||||
|
(None, "7f9583eb8ba86642abb9a21a4b2c9e16", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||||
|
(None, "43ad5aaa27dd4ee01b832ed16773fa52", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||||
|
(None, "c07c0f04f5ff55e86b4e937c7a40d481", ["infiniteyou_image_projector"], [InfiniteYouImageProjector], "diffusers"),
|
||||||
|
(None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"),
|
||||||
|
(None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
||||||
|
(None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
|
||||||
|
(None, "77ff18050dbc23f50382e45d51a779fe", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
||||||
|
(None, "5da81baee73198a7c19e6d2fe8b5148e", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
|
||||||
|
(None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder", "hunyuan_video_vae_encoder"], [HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder], "diffusers"),
|
||||||
|
(None, "b9588f02e78f5ccafc9d7c0294e46308", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
|
||||||
|
(None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
|
||||||
|
(None, "68beaf8429b7c11aa8ca05b1bd0058bd", ["stepvideo_vae"], [StepVideoVAE], "civitai"),
|
||||||
|
(None, "5c0216a2132b082c10cb7a0e0377e681", ["stepvideo_dit"], [StepVideoModel], "civitai"),
|
||||||
|
(None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
|
(None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
|
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
|
(None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
|
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
|
(None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
|
(None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
|
(None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
|
(None, "70ddad9d3a133785da5ea371aae09504", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
|
(None, "26bde73488a92e64cc20b0a7485b9e5b", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
|
(None, "ac6a5aa74f4a0aab6f64eb9a72f19901", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
|
(None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
|
(None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
|
||||||
|
(None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
|
||||||
|
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
|
||||||
|
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
|
||||||
|
(None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
|
||||||
|
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
||||||
|
(None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
||||||
|
(None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"),
|
||||||
|
(None, "d30fb9e02b1dbf4e509142f05cf7dd50", ["flux_dit", "step1x_connector"], [FluxDiT, Qwen2Connector], "civitai"),
|
||||||
|
]
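# Illustrative helper (added for explanation, not part of the original file): given the hash of
# a checkpoint's state_dict keys with shapes (the second field of each entry above), look up
# which model names/classes to instantiate and which resource format the weights use.
def _lookup_model_config_example(state_dict_keys_hash_with_shape):
    for _, known_hash, model_names, model_classes, resource in model_loader_configs:
        if known_hash == state_dict_keys_hash_with_shape:
            return model_names, model_classes, resource
    return None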
|
||||||
|
huggingface_model_loader_configs = [
|
||||||
|
# These configs are provided for detecting model type automatically.
|
||||||
|
# The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
|
||||||
|
("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
|
||||||
|
("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
|
||||||
|
("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
|
||||||
|
("Qwen2ForCausalLM", "transformers.models.qwen2.modeling_qwen2", "qwen_prompt", None),
|
||||||
|
# ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
|
||||||
|
("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
|
||||||
|
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
|
||||||
|
("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
|
||||||
|
("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
|
||||||
|
("LlavaForConditionalGeneration", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoMLLMEncoder"),
|
||||||
|
("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
|
||||||
|
("Qwen2_5_VLForConditionalGeneration", "diffsynth.models.qwenvl", "qwenvl", "Qwen25VL_7b_Embedder"),
|
||||||
|
]
|
||||||
|
patch_model_loader_configs = [
|
||||||
|
# These configs are provided for detecting model type automatically.
|
||||||
|
# The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
|
||||||
|
("9a4ab6869ac9b7d6e31f9854e397c867", ["svd_unet"], [SVDUNet], {"add_positional_conv": 128}),
|
||||||
|
]
|
||||||
|
|
||||||
|
preset_models_on_huggingface = {
|
||||||
|
"HunyuanDiT": [
|
||||||
|
("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
|
||||||
|
("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
|
||||||
|
("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
|
||||||
|
("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
|
||||||
|
],
|
||||||
|
"stable-video-diffusion-img2vid-xt": [
|
||||||
|
("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
|
||||||
|
],
|
||||||
|
"ExVideo-SVD-128f-v1": [
|
||||||
|
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
|
||||||
|
],
|
||||||
|
# Stable Diffusion
|
||||||
|
"StableDiffusion_v15": [
|
||||||
|
("benjamin-paine/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
|
||||||
|
],
|
||||||
|
"DreamShaper_8": [
|
||||||
|
("Yntec/Dreamshaper8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
|
||||||
|
],
|
||||||
|
# Textual Inversion
|
||||||
|
"TextualInversion_VeryBadImageNegative_v1.3": [
|
||||||
|
("gemasai/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
|
||||||
|
],
|
||||||
|
# Stable Diffusion XL
|
||||||
|
"StableDiffusionXL_v1": [
|
||||||
|
("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
|
||||||
|
],
|
||||||
|
"BluePencilXL_v200": [
|
||||||
|
("frankjoshua/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
|
||||||
|
],
|
||||||
|
"StableDiffusionXL_Turbo": [
|
||||||
|
("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
|
||||||
|
],
|
||||||
|
# Stable Diffusion 3
|
||||||
|
"StableDiffusion3": [
|
||||||
|
("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
|
||||||
|
],
|
||||||
|
"StableDiffusion3_without_T5": [
|
||||||
|
("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
|
||||||
|
],
|
||||||
|
# ControlNet
|
||||||
|
"ControlNet_v11f1p_sd15_depth": [
|
||||||
|
("lllyasviel/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
|
||||||
|
("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
||||||
|
],
|
||||||
|
"ControlNet_v11p_sd15_softedge": [
|
||||||
|
("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
|
||||||
|
("lllyasviel/Annotators", "ControlNetHED.pth", "models/Annotators")
|
||||||
|
],
|
||||||
|
"ControlNet_v11f1e_sd15_tile": [
|
||||||
|
("lllyasviel/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
|
||||||
|
],
|
||||||
|
"ControlNet_v11p_sd15_lineart": [
|
||||||
|
("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
|
||||||
|
("lllyasviel/Annotators", "sk_model.pth", "models/Annotators"),
|
||||||
|
("lllyasviel/Annotators", "sk_model2.pth", "models/Annotators")
|
||||||
|
],
|
||||||
|
"ControlNet_union_sdxl_promax": [
|
||||||
|
("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
|
||||||
|
("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
||||||
|
],
|
||||||
|
# AnimateDiff
|
||||||
|
"AnimateDiff_v2": [
|
||||||
|
("guoyww/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
|
||||||
|
],
|
||||||
|
"AnimateDiff_xl_beta": [
|
||||||
|
("guoyww/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
|
||||||
|
],
|
||||||
|
|
||||||
|
# Qwen Prompt
|
||||||
|
"QwenPrompt": [
|
||||||
|
("Qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||||
|
("Qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||||
|
("Qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||||
|
("Qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||||
|
("Qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||||
|
("Qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||||
|
("Qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||||
|
("Qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||||
|
],
|
||||||
|
# Beautiful Prompt
|
||||||
|
"BeautifulPrompt": [
|
||||||
|
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||||
|
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||||
|
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||||
|
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||||
|
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||||
|
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||||
|
],
|
||||||
|
# Omost prompt
|
||||||
|
"OmostPrompt":[
|
||||||
|
("lllyasviel/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||||
|
("lllyasviel/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||||
|
("lllyasviel/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||||
|
("lllyasviel/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||||
|
("lllyasviel/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||||
|
("lllyasviel/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||||
|
("lllyasviel/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||||
|
("lllyasviel/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||||
|
],
|
||||||
|
# Translator
|
||||||
|
"opus-mt-zh-en": [
|
||||||
|
("Helsinki-NLP/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
|
||||||
|
("Helsinki-NLP/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
|
||||||
|
("Helsinki-NLP/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
|
||||||
|
("Helsinki-NLP/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
|
||||||
|
("Helsinki-NLP/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
|
||||||
|
("Helsinki-NLP/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
|
||||||
|
("Helsinki-NLP/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
|
||||||
|
("Helsinki-NLP/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
|
||||||
|
],
|
||||||
|
# IP-Adapter
|
||||||
|
"IP-Adapter-SD": [
|
||||||
|
("h94/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
|
||||||
|
("h94/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
|
||||||
|
],
|
||||||
|
"IP-Adapter-SDXL": [
|
||||||
|
("h94/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
|
||||||
|
("h94/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
    ],
    "SDXL-vae-fp16-fix": [
        ("madebyollin/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
    ],
    # Kolors
    "Kolors": [
        ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
        ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
    ],
    # FLUX
    "FLUX.1-dev": [
        ("black-forest-labs/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
        ("black-forest-labs/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
        ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
        ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
        ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
        ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
        ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
    ],
    "InstantX/FLUX.1-dev-IP-Adapter": {
        "file_list": [
            ("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
            ("google/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
            ("google/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
        ],
        "load_path": [
            "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
            "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
        ],
    },
    # RIFE
    "RIFE": [
        ("AlexWortega/RIFE", "flownet.pkl", "models/RIFE"),
    ],
    # CogVideo
    "CogVideoX-5B": [
        ("THUDM/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
        ("THUDM/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
        ("THUDM/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
        ("THUDM/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
        ("THUDM/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
        ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
        ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
        ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
        ("THUDM/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
    ],
    # Stable Diffusion 3.5
    "StableDiffusion3.5-large": [
        ("stabilityai/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
        ("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
        ("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
        ("stabilityai/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
    ],
}

preset_models_on_modelscope = {
    # Hunyuan DiT
    "HunyuanDiT": [
        ("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
        ("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
        ("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
        ("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
    ],
    # Stable Video Diffusion
    "stable-video-diffusion-img2vid-xt": [
        ("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
    ],
    # ExVideo
    "ExVideo-SVD-128f-v1": [
        ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
    ],
    "ExVideo-CogVideoX-LoRA-129f-v1": [
        ("ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1", "ExVideo-CogVideoX-LoRA-129f-v1.safetensors", "models/lora"),
    ],
    # Stable Diffusion
    "StableDiffusion_v15": [
        ("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
    ],
    "DreamShaper_8": [
        ("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
    ],
    "AingDiffusion_v12": [
        ("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
    ],
    "Flat2DAnimerge_v45Sharp": [
        ("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
    ],
    # Textual Inversion
    "TextualInversion_VeryBadImageNegative_v1.3": [
        ("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
    ],
    # Stable Diffusion XL
    "StableDiffusionXL_v1": [
        ("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
    ],
    "BluePencilXL_v200": [
        ("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
    ],
    "StableDiffusionXL_Turbo": [
        ("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
    ],
    "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0": [
        ("sd_lora/zyd232_ChineseInkStyle_SDXL_v1_0", "zyd232_ChineseInkStyle_SDXL_v1_0.safetensors", "models/lora"),
    ],
    # Stable Diffusion 3
    "StableDiffusion3": [
        ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
    ],
    "StableDiffusion3_without_T5": [
        ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
    ],
    # ControlNet
    "ControlNet_v11f1p_sd15_depth": [
        ("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
        ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
    ],
    "ControlNet_v11p_sd15_softedge": [
        ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
        ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
    ],
    "ControlNet_v11f1e_sd15_tile": [
        ("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
    ],
    "ControlNet_v11p_sd15_lineart": [
        ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
        ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
        ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
    ],
    "ControlNet_union_sdxl_promax": [
        ("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
        ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
    ],
    "Annotators:Depth": [
        ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
    ],
    "Annotators:Softedge": [
        ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators"),
    ],
    "Annotators:Lineart": [
        ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
        ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators"),
    ],
    "Annotators:Normal": [
        ("sd_lora/Annotators", "scannet.pt", "models/Annotators"),
    ],
    "Annotators:Openpose": [
        ("sd_lora/Annotators", "body_pose_model.pth", "models/Annotators"),
        ("sd_lora/Annotators", "facenet.pth", "models/Annotators"),
        ("sd_lora/Annotators", "hand_pose_model.pth", "models/Annotators"),
    ],
    # AnimateDiff
    "AnimateDiff_v2": [
        ("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
    ],
    "AnimateDiff_xl_beta": [
        ("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
    ],
    # RIFE
    "RIFE": [
        ("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
    ],
    # Qwen Prompt
    "QwenPrompt": {
        "file_list": [
            ("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
            ("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
            ("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
            ("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
            ("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
            ("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
            ("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
            ("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
        ],
        "load_path": [
            "models/QwenPrompt/qwen2-1.5b-instruct",
        ],
    },
    # Beautiful Prompt
    "BeautifulPrompt": {
        "file_list": [
            ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
            ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
            ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
            ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
            ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
            ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ],
        "load_path": [
            "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
        ],
    },
    # Omost prompt
    "OmostPrompt": {
        "file_list": [
            ("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
            ("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
            ("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
            ("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
            ("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
            ("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
            ("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
            ("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
        ],
        "load_path": [
            "models/OmostPrompt/omost-llama-3-8b-4bits",
        ],
    },
    # Translator
    "opus-mt-zh-en": {
        "file_list": [
            ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
            ("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
            ("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
            ("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
            ("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
            ("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
            ("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
            ("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
        ],
        "load_path": [
            "models/translator/opus-mt-zh-en",
        ],
    },
    # IP-Adapter
    "IP-Adapter-SD": [
        ("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
        ("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
    ],
    "IP-Adapter-SDXL": [
        ("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
        ("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
    ],
    # Kolors
    "Kolors": {
        "file_list": [
            ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
            ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
        ],
        "load_path": [
            "models/kolors/Kolors/text_encoder",
            "models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
            "models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",
        ],
    },
    "SDXL-vae-fp16-fix": [
        ("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
    ],
    # FLUX
    "FLUX.1-dev": {
        "file_list": [
            ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
            ("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
        ],
        "load_path": [
            "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
            "models/FLUX/FLUX.1-dev/text_encoder_2",
            "models/FLUX/FLUX.1-dev/ae.safetensors",
            "models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
        ],
    },
    "FLUX.1-schnell": {
        "file_list": [
            ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
            ("AI-ModelScope/FLUX.1-schnell", "flux1-schnell.safetensors", "models/FLUX/FLUX.1-schnell"),
        ],
        "load_path": [
            "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
            "models/FLUX/FLUX.1-dev/text_encoder_2",
            "models/FLUX/FLUX.1-dev/ae.safetensors",
            "models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors"
        ],
    },
    "InstantX/FLUX.1-dev-Controlnet-Union-alpha": [
        ("InstantX/FLUX.1-dev-Controlnet-Union-alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha"),
    ],
    "jasperai/Flux.1-dev-Controlnet-Depth": [
        ("jasperai/Flux.1-dev-Controlnet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Depth"),
    ],
    "jasperai/Flux.1-dev-Controlnet-Surface-Normals": [
        ("jasperai/Flux.1-dev-Controlnet-Surface-Normals", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals"),
    ],
    "jasperai/Flux.1-dev-Controlnet-Upscaler": [
        ("jasperai/Flux.1-dev-Controlnet-Upscaler", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler"),
    ],
    "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha": [
        ("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha"),
    ],
    "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta": [
        ("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"),
    ],
    "Shakker-Labs/FLUX.1-dev-ControlNet-Depth": [
        ("Shakker-Labs/FLUX.1-dev-ControlNet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Depth"),
    ],
    "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [
        ("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"),
    ],
    "InstantX/FLUX.1-dev-IP-Adapter": {
        "file_list": [
            ("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
            ("AI-ModelScope/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
            ("AI-ModelScope/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
        ],
        "load_path": [
            "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
            "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
        ],
    },
    "InfiniteYou":{
        "file_list":[
            ("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
            ("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
            ("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/image_proj_model.bin", "models/InfiniteYou"),
            ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/1k3d68.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
            ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/2d106det.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
            ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/genderage.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
            ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/glintr100.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
            ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/scrfd_10g_bnkps.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
        ],
        "load_path":[
            [
                "models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors",
                "models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors"
            ],
            "models/InfiniteYou/image_proj_model.bin",
        ],
    },
    # ESRGAN
    "ESRGAN_x4": [
        ("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
    ],
    # RIFE
    "RIFE": [
        ("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
    ],
    # Omnigen
    "OmniGen-v1": {
        "file_list": [
            ("BAAI/OmniGen-v1", "vae/diffusion_pytorch_model.safetensors", "models/OmniGen/OmniGen-v1/vae"),
            ("BAAI/OmniGen-v1", "model.safetensors", "models/OmniGen/OmniGen-v1"),
            ("BAAI/OmniGen-v1", "config.json", "models/OmniGen/OmniGen-v1"),
            ("BAAI/OmniGen-v1", "special_tokens_map.json", "models/OmniGen/OmniGen-v1"),
            ("BAAI/OmniGen-v1", "tokenizer_config.json", "models/OmniGen/OmniGen-v1"),
            ("BAAI/OmniGen-v1", "tokenizer.json", "models/OmniGen/OmniGen-v1"),
        ],
        "load_path": [
            "models/OmniGen/OmniGen-v1/vae/diffusion_pytorch_model.safetensors",
            "models/OmniGen/OmniGen-v1/model.safetensors",
        ]
    },
    # CogVideo
    "CogVideoX-5B": {
        "file_list": [
            ("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
            ("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
            ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
            ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
            ("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
            ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
            ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
            ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
            ("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
        ],
        "load_path": [
            "models/CogVideo/CogVideoX-5b/text_encoder",
            "models/CogVideo/CogVideoX-5b/transformer",
            "models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors",
        ],
    },
    # Stable Diffusion 3.5
    "StableDiffusion3.5-large": [
        ("AI-ModelScope/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
    ],
    "StableDiffusion3.5-medium": [
        ("AI-ModelScope/stable-diffusion-3.5-medium", "sd3.5_medium.safetensors", "models/stable_diffusion_3"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
    ],
    "StableDiffusion3.5-large-turbo": [
        ("AI-ModelScope/stable-diffusion-3.5-large-turbo", "sd3.5_large_turbo.safetensors", "models/stable_diffusion_3"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
    ],
    "HunyuanVideo":{
        "file_list": [
            ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
            ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
            ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideo/transformers")
        ],
        "load_path": [
            "models/HunyuanVideo/text_encoder/model.safetensors",
            "models/HunyuanVideo/text_encoder_2",
            "models/HunyuanVideo/vae/pytorch_model.pt",
            "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
        ],
    },
    "HunyuanVideoI2V":{
        "file_list": [
            ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideoI2V/text_encoder"),
            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00001-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00002-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00003-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00004-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "config.json", "models/HunyuanVideoI2V/text_encoder_2"),
            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model.safetensors.index.json", "models/HunyuanVideoI2V/text_encoder_2"),
            ("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/vae/pytorch_model.pt", "models/HunyuanVideoI2V/vae"),
            ("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideoI2V/transformers")
        ],
        "load_path": [
            "models/HunyuanVideoI2V/text_encoder/model.safetensors",
            "models/HunyuanVideoI2V/text_encoder_2",
            "models/HunyuanVideoI2V/vae/pytorch_model.pt",
            "models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
        ],
    },
    "HunyuanVideo-fp8":{
        "file_list": [
            ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
            ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
            ("DiffSynth-Studio/HunyuanVideo-safetensors", "model.fp8.safetensors", "models/HunyuanVideo/transformers")
        ],
        "load_path": [
            "models/HunyuanVideo/text_encoder/model.safetensors",
            "models/HunyuanVideo/text_encoder_2",
            "models/HunyuanVideo/vae/pytorch_model.pt",
            "models/HunyuanVideo/transformers/model.fp8.safetensors"
        ],
    },
}

Preset_model_id: TypeAlias = Literal[
    "HunyuanDiT",
    "stable-video-diffusion-img2vid-xt",
    "ExVideo-SVD-128f-v1",
    "ExVideo-CogVideoX-LoRA-129f-v1",
    "StableDiffusion_v15",
    "DreamShaper_8",
    "AingDiffusion_v12",
    "Flat2DAnimerge_v45Sharp",
    "TextualInversion_VeryBadImageNegative_v1.3",
    "StableDiffusionXL_v1",
    "BluePencilXL_v200",
    "StableDiffusionXL_Turbo",
    "ControlNet_v11f1p_sd15_depth",
    "ControlNet_v11p_sd15_softedge",
    "ControlNet_v11f1e_sd15_tile",
    "ControlNet_v11p_sd15_lineart",
    "AnimateDiff_v2",
    "AnimateDiff_xl_beta",
    "RIFE",
    "BeautifulPrompt",
    "opus-mt-zh-en",
    "IP-Adapter-SD",
    "IP-Adapter-SDXL",
    "StableDiffusion3",
    "StableDiffusion3_without_T5",
    "Kolors",
    "SDXL-vae-fp16-fix",
    "ControlNet_union_sdxl_promax",
    "FLUX.1-dev",
    "FLUX.1-schnell",
    "InstantX/FLUX.1-dev-Controlnet-Union-alpha",
    "jasperai/Flux.1-dev-Controlnet-Depth",
    "jasperai/Flux.1-dev-Controlnet-Surface-Normals",
    "jasperai/Flux.1-dev-Controlnet-Upscaler",
    "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
    "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
    "Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
    "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
    "InstantX/FLUX.1-dev-IP-Adapter",
    "InfiniteYou",
    "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
    "QwenPrompt",
    "OmostPrompt",
    "ESRGAN_x4",
    "RIFE",
    "OmniGen-v1",
    "CogVideoX-5B",
    "Annotators:Depth",
    "Annotators:Softedge",
    "Annotators:Lineart",
    "Annotators:Normal",
    "Annotators:Openpose",
    "StableDiffusion3.5-large",
    "StableDiffusion3.5-medium",
    "HunyuanVideo",
    "HunyuanVideo-fp8",
    "HunyuanVideoI2V",
]
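The preset tables above map a preset name either to a list of (repo_id, path_in_repo, local_dir) tuples or, for multi-part models, to a dict with a "file_list" plus the "load_path" entries handed to the model loader. Below is a minimal, hedged sketch of consuming one preset by hand; it assumes the Hugging Face table is named preset_models_on_huggingface (mirroring preset_models_on_modelscope) and relies only on huggingface_hub, while the repository's own download helper is expected to do the equivalent for both hubs.

from huggingface_hub import hf_hub_download

def fetch_preset_files(preset_id, table):
    # Dict-style presets keep their files under "file_list"; list-style presets are the list itself.
    entries = table[preset_id]
    file_list = entries["file_list"] if isinstance(entries, dict) else entries
    for repo_id, path_in_repo, local_dir in file_list:
        hf_hub_download(repo_id=repo_id, filename=path_in_repo, local_dir=local_dir)

# Example (assumed table name): fetch_preset_files("SDXL-vae-fp16-fix", preset_models_on_huggingface)
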
@@ -1,2 +1,2 @@
-from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager
+from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager, FluxMultiControlNetManager
 from .processors import Annotator

@@ -4,10 +4,11 @@ from .processors import Processor_id


 class ControlNetConfigUnit:
-    def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
+    def __init__(self, processor_id: Processor_id, model_path, scale=1.0, skip_processor=False):
         self.processor_id = processor_id
         self.model_path = model_path
         self.scale = scale
+        self.skip_processor = skip_processor


 class ControlNetUnit:

@@ -23,6 +24,16 @@ class MultiControlNetManager:
         self.models = [unit.model for unit in controlnet_units]
         self.scales = [unit.scale for unit in controlnet_units]

+    def cpu(self):
+        for model in self.models:
+            model.cpu()
+
+    def to(self, device):
+        for model in self.models:
+            model.to(device)
+        for processor in self.processors:
+            processor.to(device)
+
     def process_image(self, image, processor_id=None):
         if processor_id is None:
             processed_image = [processor(image) for processor in self.processors]

@@ -37,13 +48,14 @@ class MultiControlNetManager:
     def __call__(
         self,
         sample, timestep, encoder_hidden_states, conditionings,
-        tiled=False, tile_size=64, tile_stride=32
+        tiled=False, tile_size=64, tile_stride=32, **kwargs
     ):
         res_stack = None
-        for conditioning, model, scale in zip(conditionings, self.models, self.scales):
+        for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
             res_stack_ = model(
-                sample, timestep, encoder_hidden_states, conditioning,
-                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
+                sample, timestep, encoder_hidden_states, conditioning, **kwargs,
+                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
+                processor_id=processor.processor_id
             )
             res_stack_ = [res * scale for res in res_stack_]
             if res_stack is None:

@@ -51,3 +63,29 @@ class MultiControlNetManager:
             else:
                 res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
         return res_stack
+
+
+class FluxMultiControlNetManager(MultiControlNetManager):
+    def __init__(self, controlnet_units=[]):
+        super().__init__(controlnet_units=controlnet_units)
+
+    def process_image(self, image, processor_id=None):
+        if processor_id is None:
+            processed_image = [processor(image) for processor in self.processors]
+        else:
+            processed_image = [self.processors[processor_id](image)]
+        return processed_image
+
+    def __call__(self, conditionings, **kwargs):
+        res_stack, single_res_stack = None, None
+        for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
+            res_stack_, single_res_stack_ = model(controlnet_conditioning=conditioning, processor_id=processor.processor_id, **kwargs)
+            res_stack_ = [res * scale for res in res_stack_]
+            single_res_stack_ = [res * scale for res in single_res_stack_]
+            if res_stack is None:
+                res_stack = res_stack_
+                single_res_stack = single_res_stack_
+            else:
+                res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
+                single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)]
+        return res_stack, single_res_stack
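A hedged usage sketch for the additions above: only ControlNetConfigUnit, its new skip_processor flag and FluxMultiControlNetManager come from this diff; how a FLUX pipeline consumes the config units is an assumption, and the model paths below simply reuse the preset table locations.

# Two FLUX ControlNet branches; the inpainting branch receives an already prepared
# control image, so its annotator is skipped via skip_processor=True.
controlnet_config_units = [
    ControlNetConfigUnit(
        processor_id="depth",
        model_path="models/ControlNet/jasperai/Flux.1-dev-Controlnet-Depth/diffusion_pytorch_model.safetensors",
        scale=0.7,
    ),
    ControlNetConfigUnit(
        processor_id="inpaint",
        model_path="models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors",
        scale=0.9,
        skip_processor=True,
    ),
]
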
@@ -1,39 +1,50 @@
 from typing_extensions import Literal, TypeAlias
-import warnings
-with warnings.catch_warnings():
-    warnings.simplefilter("ignore")
-    from controlnet_aux.processor import (
-        CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector
-    )


 Processor_id: TypeAlias = Literal[
-    "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"
+    "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
 ]

 class Annotator:
-    def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None):
+    def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
-        if processor_id == "canny":
-            self.processor = CannyDetector()
-        elif processor_id == "depth":
-            self.processor = MidasDetector.from_pretrained(model_path).to("cuda")
-        elif processor_id == "softedge":
-            self.processor = HEDdetector.from_pretrained(model_path).to("cuda")
-        elif processor_id == "lineart":
-            self.processor = LineartDetector.from_pretrained(model_path).to("cuda")
-        elif processor_id == "lineart_anime":
-            self.processor = LineartAnimeDetector.from_pretrained(model_path).to("cuda")
-        elif processor_id == "openpose":
-            self.processor = OpenposeDetector.from_pretrained(model_path).to("cuda")
-        elif processor_id == "tile":
-            self.processor = None
+        if not skip_processor:
+            if processor_id == "canny":
+                from controlnet_aux.processor import CannyDetector
+                self.processor = CannyDetector()
+            elif processor_id == "depth":
+                from controlnet_aux.processor import MidasDetector
+                self.processor = MidasDetector.from_pretrained(model_path).to(device)
+            elif processor_id == "softedge":
+                from controlnet_aux.processor import HEDdetector
+                self.processor = HEDdetector.from_pretrained(model_path).to(device)
+            elif processor_id == "lineart":
+                from controlnet_aux.processor import LineartDetector
+                self.processor = LineartDetector.from_pretrained(model_path).to(device)
+            elif processor_id == "lineart_anime":
+                from controlnet_aux.processor import LineartAnimeDetector
+                self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
+            elif processor_id == "openpose":
+                from controlnet_aux.processor import OpenposeDetector
+                self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
+            elif processor_id == "normal":
+                from controlnet_aux.processor import NormalBaeDetector
+                self.processor = NormalBaeDetector.from_pretrained(model_path).to(device)
+            elif processor_id == "tile" or processor_id == "none" or processor_id == "inpaint":
+                self.processor = None
+            else:
+                raise ValueError(f"Unsupported processor_id: {processor_id}")
         else:
-            raise ValueError(f"Unsupported processor_id: {processor_id}")
+            self.processor = None

         self.processor_id = processor_id
         self.detect_resolution = detect_resolution

+    def to(self,device):
+        if hasattr(self.processor,"model") and hasattr(self.processor.model,"to"):
+            self.processor.model.to(device)
+
-    def __call__(self, image):
+    def __call__(self, image, mask=None):
         width, height = image.size
         if self.processor_id == "openpose":
             kwargs = {
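A short usage sketch for the reworked Annotator; the input image path is a placeholder. The lazy per-branch imports mean only the chosen controlnet_aux detector is loaded, and "tile", "none", "inpaint" or skip_processor=True leave self.processor as None.

from PIL import Image

annotator = Annotator(processor_id="depth", model_path="models/Annotators", device="cuda")
depth_map = annotator(Image.open("input.png"))
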
diffsynth/data/simple_text_image.py (new file, 41 lines)
@@ -0,0 +1,41 @@
import torch, os, torchvision
from torchvision import transforms
import pandas as pd
from PIL import Image


class TextImageDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
        self.steps_per_epoch = steps_per_epoch
        metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
        self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
        self.text = metadata["text"].to_list()
        self.height = height
        self.width = width
        self.image_processor = transforms.Compose(
            [
                transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
                transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )


    def __getitem__(self, index):
        data_id = torch.randint(0, len(self.path), (1,))[0]
        data_id = (data_id + index) % len(self.path) # For fixed seed.
        text = self.text[data_id]
        image = Image.open(self.path[data_id]).convert("RGB")
        target_height, target_width = self.height, self.width
        width, height = image.size
        scale = max(target_width / width, target_height / height)
        shape = [round(height*scale),round(width*scale)]
        image = torchvision.transforms.functional.resize(image,shape,interpolation=transforms.InterpolationMode.BILINEAR)
        image = self.image_processor(image)
        return {"text": text, "image": image}


    def __len__(self):
        return self.steps_per_epoch
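A minimal usage sketch for the dataset above. The metadata layout (train/metadata.csv with file_name and text columns) follows directly from the code; the dataset path and batch size are placeholders.

from torch.utils.data import DataLoader

dataset = TextImageDataset("data/my_dataset", steps_per_epoch=1000, height=1024, width=1024)
loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)
batch = next(iter(loader))
# batch["image"]: float tensor of shape (2, 3, 1024, 1024), normalized to [-1, 1]
# batch["text"]: the 2 caption strings for the sampled images
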
@@ -135,8 +135,8 @@ class VideoData:
             frame.save(os.path.join(folder, f"{i}.png"))


-def save_video(frames, save_path, fps, quality=9):
-    writer = imageio.get_writer(save_path, fps=fps, quality=quality)
+def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
+    writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params)
     for frame in tqdm(frames, desc="Saving video"):
         frame = np.array(frame)
         writer.append_data(frame)
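The new ffmpeg_params argument is forwarded straight to imageio's ffmpeg writer, so extra encoder flags can be supplied at save time; which flags are accepted depends on the installed ffmpeg backend. For example:

save_video(frames, "output.mp4", fps=30, ffmpeg_params=["-pix_fmt", "yuv420p"])
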
diffsynth/distributed/__init__.py (new file, empty)

diffsynth/distributed/xdit_context_parallel.py (new file, 129 lines)
@@ -0,0 +1,129 @@
import torch
from typing import Optional
from einops import rearrange
from xfuser.core.distributed import (get_sequence_parallel_rank,
                                     get_sequence_parallel_world_size,
                                     get_sp_group)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention

def sinusoidal_embedding_1d(dim, position):
    sinusoid = torch.outer(position.type(torch.float64), torch.pow(
        10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    return x.to(position.dtype)

def pad_freqs(original_tensor, target_len):
    seq_len, s1, s2 = original_tensor.shape
    pad_size = target_len - seq_len
    padding_tensor = torch.ones(
        pad_size,
        s1,
        s2,
        dtype=original_tensor.dtype,
        device=original_tensor.device)
    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
    return padded_tensor

def rope_apply(x, freqs, num_heads):
    x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
    s_per_rank = x.shape[1]

    x_out = torch.view_as_complex(x.to(torch.float64).reshape(
        x.shape[0], x.shape[1], x.shape[2], -1, 2))

    sp_size = get_sequence_parallel_world_size()
    sp_rank = get_sequence_parallel_rank()
    freqs = pad_freqs(freqs, s_per_rank * sp_size)
    freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]

    x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
    return x_out.to(x.dtype)

def usp_dit_forward(self,
                    x: torch.Tensor,
                    timestep: torch.Tensor,
                    context: torch.Tensor,
                    clip_feature: Optional[torch.Tensor] = None,
                    y: Optional[torch.Tensor] = None,
                    use_gradient_checkpointing: bool = False,
                    use_gradient_checkpointing_offload: bool = False,
                    **kwargs,
                    ):
    t = self.time_embedding(
        sinusoidal_embedding_1d(self.freq_dim, timestep))
    t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
    context = self.text_embedding(context)

    if self.has_image_input:
        x = torch.cat([x, y], dim=1)  # (b, c_x + c_y, f, h, w)
        clip_embdding = self.img_emb(clip_feature)
        context = torch.cat([clip_embdding, context], dim=1)

    x, (f, h, w) = self.patchify(x)

    freqs = torch.cat([
        self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
        self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
        self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
    ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)

    def create_custom_forward(module):
        def custom_forward(*inputs):
            return module(*inputs)
        return custom_forward

    # Context Parallel
    x = torch.chunk(
        x, get_sequence_parallel_world_size(),
        dim=1)[get_sequence_parallel_rank()]

    for block in self.blocks:
        if self.training and use_gradient_checkpointing:
            if use_gradient_checkpointing_offload:
                with torch.autograd.graph.save_on_cpu():
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        x, context, t_mod, freqs,
                        use_reentrant=False,
                    )
            else:
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    x, context, t_mod, freqs,
                    use_reentrant=False,
                )
        else:
            x = block(x, context, t_mod, freqs)

    x = self.head(x, t)

    # Context Parallel
    x = get_sp_group().all_gather(x, dim=1)

    # unpatchify
    x = self.unpatchify(x, (f, h, w))
    return x


def usp_attn_forward(self, x, freqs):
    q = self.norm_q(self.q(x))
    k = self.norm_k(self.k(x))
    v = self.v(x)

    q = rope_apply(q, freqs, self.num_heads)
    k = rope_apply(k, freqs, self.num_heads)
    q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
    k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
    v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)

    x = xFuserLongContextAttention()(
        None,
        query=q,
        key=k,
        value=v,
    )
    x = x.flatten(2)

    del q, k, v
    torch.cuda.empty_cache()
    return self.o(x)
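A hedged sketch of how these sequence-parallel forwards might be attached to a DiT model at runtime. Binding with types.MethodType is plain Python; the attribute names dit.blocks and block.self_attn are assumptions about the model layout, not something this file defines.

import types

def enable_unified_sequence_parallel(dit):
    # Replace the model-level forward with the context-parallel version ...
    dit.forward = types.MethodType(usp_dit_forward, dit)
    # ... and each block's self-attention forward with the xFuser long-context version.
    for block in dit.blocks:
        block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
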
@@ -41,7 +41,7 @@ class RRDB(torch.nn.Module):

 class RRDBNet(torch.nn.Module):

-    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32):
+    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, **kwargs):
         super(RRDBNet, self).__init__()
         self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
         self.body = torch.torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])

@@ -65,6 +65,21 @@ class RRDBNet(torch.nn.Module):
         feat = self.lrelu(self.conv_up2(feat))
         out = self.conv_last(self.lrelu(self.conv_hr(feat)))
         return out
+
+    @staticmethod
+    def state_dict_converter():
+        return RRDBNetStateDictConverter()
+
+
+class RRDBNetStateDictConverter:
+    def __init__(self):
+        pass
+
+    def from_diffusers(self, state_dict):
+        return state_dict, {"upcast_to_float32": True}
+
+    def from_civitai(self, state_dict):
+        return state_dict, {"upcast_to_float32": True}
+

 class ESRGAN(torch.nn.Module):

@@ -73,12 +88,8 @@ class ESRGAN(torch.nn.Module):
         self.model = model

     @staticmethod
-    def from_pretrained(model_path):
-        model = RRDBNet()
-        state_dict = torch.load(model_path, map_location="cpu")["params_ema"]
-        model.load_state_dict(state_dict)
-        model.eval()
-        return ESRGAN(model)
+    def from_model_manager(model_manager):
+        return ESRGAN(model_manager.fetch_model("esrgan"))

     def process_image(self, image):
         image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)

@@ -96,6 +107,12 @@ class ESRGAN(torch.nn.Module):

     @torch.no_grad()
     def upscale(self, images, batch_size=4, progress_bar=lambda x:x):
+        if not isinstance(images, list):
+            images = [images]
+            is_single_image = True
+        else:
+            is_single_image = False
+
         # Preprocess
         input_tensor = self.process_images(images)

@@ -115,4 +132,6 @@ class ESRGAN(torch.nn.Module):

         # To images
         output_images = self.decode_images(output_tensor)
+        if is_single_image:
+            output_images = output_images[0]
         return output_images
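A brief usage sketch for the new entry points. Only from_model_manager, fetch_model("esrgan") and the single-image handling in upscale come from this diff; how the model manager gets the RealESRGAN weights loaded beforehand is assumed.

from PIL import Image

# model_manager is assumed to have already loaded "models/ESRGAN/RealESRGAN_x4.pth",
# so that fetch_model("esrgan") returns the RRDBNet instance.
esrgan = ESRGAN.from_model_manager(model_manager)
upscaled = esrgan.upscale(Image.open("low_res.png"))      # single image in, single image out
upscaled_frames = esrgan.upscale(frames, batch_size=4)    # lists behave as before
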
|||||||
1
diffsynth/extensions/ImageQualityMetric/BLIP/__init__.py
Normal file
1
diffsynth/extensions/ImageQualityMetric/BLIP/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .blip_pretrain import *
|
||||||
77
diffsynth/extensions/ImageQualityMetric/BLIP/blip.py
Normal file
77
diffsynth/extensions/ImageQualityMetric/BLIP/blip.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
'''
|
||||||
|
* Adapted from BLIP (https://github.com/salesforce/BLIP)
|
||||||
|
'''
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import os
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
from timm.models.hub import download_cached_file
|
||||||
|
from transformers import BertTokenizer
|
||||||
|
from .vit import VisionTransformer, interpolate_pos_embed
|
||||||
|
|
||||||
|
|
||||||
|
def default_bert():
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
project_root = os.path.abspath(os.path.join(current_dir, '../../../../'))
|
||||||
|
model_path = os.path.join(project_root, 'models', 'QualityMetric')
|
||||||
|
return os.path.join(model_path, "bert-base-uncased")
|
||||||
|
|
||||||
|
|
||||||
|
def init_tokenizer(bert_model_path):
|
||||||
|
tokenizer = BertTokenizer.from_pretrained(bert_model_path)
|
||||||
|
tokenizer.add_special_tokens({'bos_token':'[DEC]'})
|
||||||
|
tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
|
||||||
|
tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
|
||||||
|
|
||||||
|
assert vit in ['base', 'large'], "vit parameter must be base or large"
|
||||||
|
if vit=='base':
|
||||||
|
vision_width = 768
|
||||||
|
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
|
||||||
|
num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
|
||||||
|
drop_path_rate=0 or drop_path_rate
|
||||||
|
)
|
||||||
|
elif vit=='large':
|
||||||
|
vision_width = 1024
|
||||||
|
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
|
||||||
|
num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
|
||||||
|
drop_path_rate=0.1 or drop_path_rate
|
||||||
|
)
|
||||||
|
return visual_encoder, vision_width
|
||||||
|
|
||||||
|
|
||||||
|
def is_url(url_or_filename):
|
||||||
|
parsed = urlparse(url_or_filename)
|
||||||
|
return parsed.scheme in ("http", "https")
|
||||||
|
|
||||||
|
def load_checkpoint(model,url_or_filename):
|
||||||
|
if is_url(url_or_filename):
|
||||||
|
cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
|
||||||
|
checkpoint = torch.load(cached_file, map_location='cpu')
|
||||||
|
elif os.path.isfile(url_or_filename):
|
||||||
|
checkpoint = torch.load(url_or_filename, map_location='cpu')
|
||||||
|
else:
|
||||||
|
raise RuntimeError('checkpoint url or path is invalid')
|
||||||
|
|
||||||
|
state_dict = checkpoint['model']
|
||||||
|
|
||||||
|
state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
|
||||||
|
if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
|
||||||
|
state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
|
||||||
|
model.visual_encoder_m)
|
||||||
|
for key in model.state_dict().keys():
|
||||||
|
if key in state_dict.keys():
|
||||||
|
if state_dict[key].shape!=model.state_dict()[key].shape:
|
||||||
|
print(key, ": ", state_dict[key].shape, ', ', model.state_dict()[key].shape)
|
||||||
|
del state_dict[key]
|
||||||
|
|
||||||
|
msg = model.load_state_dict(state_dict,strict=False)
|
||||||
|
print('load checkpoint from %s'%url_or_filename)
|
||||||
|
return model,msg
|
||||||
|
|
||||||
@@ -0,0 +1,44 @@
'''
 * Adapted from BLIP (https://github.com/salesforce/BLIP)
'''

import transformers
transformers.logging.set_verbosity_error()

from torch import nn
import os
from .med import BertConfig, BertModel
from .blip import create_vit, init_tokenizer


class BLIP_Pretrain(nn.Module):
    def __init__(self,
                 med_config = "med_config.json",
                 image_size = 224,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,
                 embed_dim = 256,
                 queue_size = 57600,
                 momentum = 0.995,
                 bert_model_path = ""
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, 0)

        self.tokenizer = init_tokenizer(bert_model_path)
        encoder_config = BertConfig.from_json_file(med_config)
        encoder_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)

        text_width = self.text_encoder.config.hidden_size

        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)
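A rough sketch of how this wrapper might be driven. Both paths below are placeholders (a med_config.json and a local BERT tokenizer directory are assumed to exist), and the weights here are untrained; it only illustrates the shapes flowing through the visual encoder and the vision projection:

import torch

blip = BLIP_Pretrain(med_config="/path/to/med_config.json", image_size=224, vit="base",
                     bert_model_path="/path/to/bert-base-uncased")   # placeholder paths
images = torch.randn(2, 3, 224, 224)
image_embeds = blip.visual_encoder(images)            # (2, num_patches + 1, vision_width)
image_feat = blip.vision_proj(image_embeds[:, 0, :])  # CLS token projected into the shared embedding space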
947
diffsynth/extensions/ImageQualityMetric/BLIP/med.py
Normal file
@@ -0,0 +1,947 @@
'''
 * Adapted from BLIP (https://github.com/salesforce/BLIP)
 * Based on huggingface code base
 * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
'''

import math
from typing import Tuple

import torch
from torch import Tensor, device, nn
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss

from transformers.activations import ACT2FN
from transformers.file_utils import (
    ModelOutput,
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    NextSentencePredictorOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from transformers.modeling_utils import (
    PreTrainedModel,
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
from transformers.utils import logging
from transformers.models.bert.configuration_bert import BertConfig


logger = logging.get_logger(__name__)

class BertEmbeddings(nn.Module):
|
||||||
|
"""Construct the embeddings from word and position embeddings."""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
||||||
|
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
||||||
|
|
||||||
|
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
||||||
|
# any TensorFlow checkpoint file
|
||||||
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
|
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||||
|
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||||
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
||||||
|
):
|
||||||
|
if input_ids is not None:
|
||||||
|
input_shape = input_ids.size()
|
||||||
|
else:
|
||||||
|
input_shape = inputs_embeds.size()[:-1]
|
||||||
|
|
||||||
|
seq_length = input_shape[1]
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.word_embeddings(input_ids)
|
||||||
|
|
||||||
|
embeddings = inputs_embeds
|
||||||
|
|
||||||
|
if self.position_embedding_type == "absolute":
|
||||||
|
position_embeddings = self.position_embeddings(position_ids)
|
||||||
|
embeddings += position_embeddings
|
||||||
|
embeddings = self.LayerNorm(embeddings)
|
||||||
|
embeddings = self.dropout(embeddings)
|
||||||
|
return embeddings
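A minimal smoke test of the embedding module above, using a toy config as a stand-in (this is not the project's med_config; the numbers are illustrative only):

import torch
from transformers.models.bert.configuration_bert import BertConfig

config = BertConfig(vocab_size=100, hidden_size=32, max_position_embeddings=16)
emb = BertEmbeddings(config)
input_ids = torch.tensor([[1, 5, 7, 2]])
out = emb(input_ids=input_ids)
print(out.shape)  # torch.Size([1, 4, 32]) -> word + absolute position embeddings, normalized and dropped out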
|
||||||
|
|
||||||
|
|
||||||
|
class BertSelfAttention(nn.Module):
|
||||||
|
def __init__(self, config, is_cross_attention):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||||
|
raise ValueError(
|
||||||
|
"The hidden size (%d) is not a multiple of the number of attention "
|
||||||
|
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.num_attention_heads = config.num_attention_heads
|
||||||
|
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||||
|
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||||
|
|
||||||
|
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
||||||
|
if is_cross_attention:
|
||||||
|
self.key = nn.Linear(config.encoder_width, self.all_head_size)
|
||||||
|
self.value = nn.Linear(config.encoder_width, self.all_head_size)
|
||||||
|
else:
|
||||||
|
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
||||||
|
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||||
|
|
||||||
|
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||||
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||||
|
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||||
|
self.max_position_embeddings = config.max_position_embeddings
|
||||||
|
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||||
|
self.save_attention = False
|
||||||
|
|
||||||
|
def save_attn_gradients(self, attn_gradients):
|
||||||
|
self.attn_gradients = attn_gradients
|
||||||
|
|
||||||
|
def get_attn_gradients(self):
|
||||||
|
return self.attn_gradients
|
||||||
|
|
||||||
|
def save_attention_map(self, attention_map):
|
||||||
|
self.attention_map = attention_map
|
||||||
|
|
||||||
|
def get_attention_map(self):
|
||||||
|
return self.attention_map
|
||||||
|
|
||||||
|
def transpose_for_scores(self, x):
|
||||||
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||||
|
x = x.view(*new_x_shape)
|
||||||
|
return x.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
|
output_attentions=False,
|
||||||
|
):
|
||||||
|
mixed_query_layer = self.query(hidden_states)
|
||||||
|
|
||||||
|
# If this is instantiated as a cross-attention module, the keys
|
||||||
|
# and values come from an encoder; the attention mask needs to be
|
||||||
|
# such that the encoder's padding tokens are not attended to.
|
||||||
|
is_cross_attention = encoder_hidden_states is not None
|
||||||
|
|
||||||
|
if is_cross_attention:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||||
|
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||||
|
attention_mask = encoder_attention_mask
|
||||||
|
elif past_key_value is not None:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||||
|
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||||
|
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||||
|
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||||
|
else:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||||
|
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||||
|
|
||||||
|
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||||
|
|
||||||
|
past_key_value = (key_layer, value_layer)
|
||||||
|
|
||||||
|
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||||
|
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||||
|
|
||||||
|
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||||
|
seq_length = hidden_states.size()[1]
|
||||||
|
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
||||||
|
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
||||||
|
distance = position_ids_l - position_ids_r
|
||||||
|
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
||||||
|
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
||||||
|
|
||||||
|
if self.position_embedding_type == "relative_key":
|
||||||
|
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||||
|
attention_scores = attention_scores + relative_position_scores
|
||||||
|
elif self.position_embedding_type == "relative_key_query":
|
||||||
|
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||||
|
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
||||||
|
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
||||||
|
|
||||||
|
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||||
|
if attention_mask is not None:
|
||||||
|
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
||||||
|
attention_scores = attention_scores + attention_mask
|
||||||
|
|
||||||
|
# Normalize the attention scores to probabilities.
|
||||||
|
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
||||||
|
|
||||||
|
if is_cross_attention and self.save_attention:
|
||||||
|
self.save_attention_map(attention_probs)
|
||||||
|
attention_probs.register_hook(self.save_attn_gradients)
|
||||||
|
|
||||||
|
# This is actually dropping out entire tokens to attend to, which might
|
||||||
|
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||||
|
attention_probs_dropped = self.dropout(attention_probs)
|
||||||
|
|
||||||
|
# Mask heads if we want to
|
||||||
|
if head_mask is not None:
|
||||||
|
attention_probs_dropped = attention_probs_dropped * head_mask
|
||||||
|
|
||||||
|
context_layer = torch.matmul(attention_probs_dropped, value_layer)
|
||||||
|
|
||||||
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
|
context_layer = context_layer.view(*new_context_layer_shape)
|
||||||
|
|
||||||
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
|
|
||||||
|
outputs = outputs + (past_key_value,)
|
||||||
|
return outputs
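The relative-position branch above mixes a learned distance embedding into the attention scores; a tiny standalone check of the einsum it relies on, with toy sizes standing in for max_position_embeddings:

import torch

b, h, L, d = 1, 2, 4, 8
query_layer = torch.randn(b, h, L, d)
distance_embedding = torch.nn.Embedding(2 * L - 1, d)   # toy table; the class uses max_position_embeddings

pos_l = torch.arange(L).view(-1, 1)
pos_r = torch.arange(L).view(1, -1)
positional_embedding = distance_embedding(pos_l - pos_r + L - 1)            # (L, L, d), one vector per query/key offset
relative_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
print(relative_scores.shape)  # torch.Size([1, 2, 4, 4]) -> added onto the raw attention scores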
|
||||||
|
|
||||||
|
|
||||||
|
class BertSelfOutput(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, input_tensor):
|
||||||
|
hidden_states = self.dense(hidden_states)
|
||||||
|
hidden_states = self.dropout(hidden_states)
|
||||||
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BertAttention(nn.Module):
|
||||||
|
def __init__(self, config, is_cross_attention=False):
|
||||||
|
super().__init__()
|
||||||
|
self.self = BertSelfAttention(config, is_cross_attention)
|
||||||
|
self.output = BertSelfOutput(config)
|
||||||
|
self.pruned_heads = set()
|
||||||
|
|
||||||
|
def prune_heads(self, heads):
|
||||||
|
if len(heads) == 0:
|
||||||
|
return
|
||||||
|
heads, index = find_pruneable_heads_and_indices(
|
||||||
|
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prune linear layers
|
||||||
|
self.self.query = prune_linear_layer(self.self.query, index)
|
||||||
|
self.self.key = prune_linear_layer(self.self.key, index)
|
||||||
|
self.self.value = prune_linear_layer(self.self.value, index)
|
||||||
|
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
||||||
|
|
||||||
|
# Update hyper params and store pruned heads
|
||||||
|
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
||||||
|
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||||
|
self.pruned_heads = self.pruned_heads.union(heads)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
|
output_attentions=False,
|
||||||
|
):
|
||||||
|
self_outputs = self.self(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
head_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
past_key_value,
|
||||||
|
output_attentions,
|
||||||
|
)
|
||||||
|
attention_output = self.output(self_outputs[0], hidden_states)
|
||||||
|
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class BertIntermediate(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||||
|
if isinstance(config.hidden_act, str):
|
||||||
|
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
||||||
|
else:
|
||||||
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states = self.dense(hidden_states)
|
||||||
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BertOutput(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||||
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, input_tensor):
|
||||||
|
hidden_states = self.dense(hidden_states)
|
||||||
|
hidden_states = self.dropout(hidden_states)
|
||||||
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BertLayer(nn.Module):
|
||||||
|
def __init__(self, config, layer_num):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||||
|
self.seq_len_dim = 1
|
||||||
|
self.attention = BertAttention(config)
|
||||||
|
self.layer_num = layer_num
|
||||||
|
if self.config.add_cross_attention:
|
||||||
|
self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
|
||||||
|
self.intermediate = BertIntermediate(config)
|
||||||
|
self.output = BertOutput(config)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
|
output_attentions=False,
|
||||||
|
mode=None,
|
||||||
|
):
|
||||||
|
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||||
|
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||||
|
self_attention_outputs = self.attention(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
head_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
past_key_value=self_attn_past_key_value,
|
||||||
|
)
|
||||||
|
attention_output = self_attention_outputs[0]
|
||||||
|
|
||||||
|
outputs = self_attention_outputs[1:-1]
|
||||||
|
present_key_value = self_attention_outputs[-1]
|
||||||
|
|
||||||
|
if mode=='multimodal':
|
||||||
|
assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
|
||||||
|
|
||||||
|
cross_attention_outputs = self.crossattention(
|
||||||
|
attention_output,
|
||||||
|
attention_mask,
|
||||||
|
head_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
|
attention_output = cross_attention_outputs[0]
|
||||||
|
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||||
|
layer_output = apply_chunking_to_forward(
|
||||||
|
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||||
|
)
|
||||||
|
outputs = (layer_output,) + outputs
|
||||||
|
|
||||||
|
outputs = outputs + (present_key_value,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def feed_forward_chunk(self, attention_output):
|
||||||
|
intermediate_output = self.intermediate(attention_output)
|
||||||
|
layer_output = self.output(intermediate_output, attention_output)
|
||||||
|
return layer_output
|
||||||
|
|
||||||
|
|
||||||
|
class BertEncoder(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
|
output_attentions=False,
|
||||||
|
output_hidden_states=False,
|
||||||
|
return_dict=True,
|
||||||
|
mode='multimodal',
|
||||||
|
):
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attentions = () if output_attentions else None
|
||||||
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||||
|
|
||||||
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
|
for i in range(self.config.num_hidden_layers):
|
||||||
|
layer_module = self.layer[i]
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||||
|
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
logger.warning(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
return module(*inputs, past_key_value, output_attentions)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(layer_module),
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
mode=mode,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
layer_outputs = layer_module(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
past_key_value,
|
||||||
|
output_attentions,
|
||||||
|
mode=mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache += (layer_outputs[-1],)
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||||
|
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(
|
||||||
|
v
|
||||||
|
for v in [
|
||||||
|
hidden_states,
|
||||||
|
next_decoder_cache,
|
||||||
|
all_hidden_states,
|
||||||
|
all_self_attentions,
|
||||||
|
all_cross_attentions,
|
||||||
|
]
|
||||||
|
if v is not None
|
||||||
|
)
|
||||||
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_decoder_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attentions,
|
||||||
|
cross_attentions=all_cross_attentions,
|
||||||
|
)
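The gradient-checkpointing branch in the loop above trades compute for memory by recomputing activations during backward. A self-contained illustration of the same torch.utils.checkpoint pattern, independent of this class:

import torch
import torch.utils.checkpoint as cp

layer = torch.nn.Linear(16, 16)
x = torch.randn(4, 16, requires_grad=True)

# Activations of `layer` are not stored; they are recomputed when backward reaches this call.
y = cp.checkpoint(layer, x).sum()
y.backward()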
|
||||||
|
|
||||||
|
|
||||||
|
class BertPooler(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
|
self.activation = nn.Tanh()
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
# We "pool" the model by simply taking the hidden state corresponding
|
||||||
|
# to the first token.
|
||||||
|
first_token_tensor = hidden_states[:, 0]
|
||||||
|
pooled_output = self.dense(first_token_tensor)
|
||||||
|
pooled_output = self.activation(pooled_output)
|
||||||
|
return pooled_output
|
||||||
|
|
||||||
|
|
||||||
|
class BertPredictionHeadTransform(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
|
if isinstance(config.hidden_act, str):
|
||||||
|
self.transform_act_fn = ACT2FN[config.hidden_act]
|
||||||
|
else:
|
||||||
|
self.transform_act_fn = config.hidden_act
|
||||||
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states = self.dense(hidden_states)
|
||||||
|
hidden_states = self.transform_act_fn(hidden_states)
|
||||||
|
hidden_states = self.LayerNorm(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BertLMPredictionHead(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.transform = BertPredictionHeadTransform(config)
|
||||||
|
|
||||||
|
# The output weights are the same as the input embeddings, but there is
|
||||||
|
# an output-only bias for each token.
|
||||||
|
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||||
|
|
||||||
|
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||||
|
|
||||||
|
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||||
|
self.decoder.bias = self.bias
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states = self.transform(hidden_states)
|
||||||
|
hidden_states = self.decoder(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BertOnlyMLMHead(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.predictions = BertLMPredictionHead(config)
|
||||||
|
|
||||||
|
def forward(self, sequence_output):
|
||||||
|
prediction_scores = self.predictions(sequence_output)
|
||||||
|
return prediction_scores
|
||||||
|
|
||||||
|
|
||||||
|
class BertPreTrainedModel(PreTrainedModel):
|
||||||
|
"""
|
||||||
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||||
|
models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
config_class = BertConfig
|
||||||
|
base_model_prefix = "bert"
|
||||||
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
|
|
||||||
|
def _init_weights(self, module):
|
||||||
|
""" Initialize the weights """
|
||||||
|
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||||
|
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||||
|
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||||
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||||
|
elif isinstance(module, nn.LayerNorm):
|
||||||
|
module.bias.data.zero_()
|
||||||
|
module.weight.data.fill_(1.0)
|
||||||
|
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||||
|
module.bias.data.zero_()
|
||||||
|
|
||||||
|
|
||||||
|
class BertModel(BertPreTrainedModel):
|
||||||
|
"""
|
||||||
|
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
||||||
|
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
||||||
|
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
||||||
|
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
||||||
|
To be used in a Seq2Seq model, the model needs to be initialized with both the :obj:`is_decoder` argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
|
||||||
|
input to the forward pass.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config, add_pooling_layer=True):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
self.embeddings = BertEmbeddings(config)
|
||||||
|
|
||||||
|
self.encoder = BertEncoder(config)
|
||||||
|
|
||||||
|
self.pooler = BertPooler(config) if add_pooling_layer else None
|
||||||
|
|
||||||
|
self.init_weights()
|
||||||
|
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.embeddings.word_embeddings
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.embeddings.word_embeddings = value
|
||||||
|
|
||||||
|
def _prune_heads(self, heads_to_prune):
|
||||||
|
"""
|
||||||
|
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
||||||
|
class PreTrainedModel
|
||||||
|
"""
|
||||||
|
for layer, heads in heads_to_prune.items():
|
||||||
|
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||||
|
|
||||||
|
|
||||||
|
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
|
||||||
|
"""
|
||||||
|
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
attention_mask (:obj:`torch.Tensor`):
|
||||||
|
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
||||||
|
input_shape (:obj:`Tuple[int]`):
|
||||||
|
The shape of the input to the model.
|
||||||
|
device: (:obj:`torch.device`):
|
||||||
|
The device of the input to the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
|
||||||
|
"""
|
||||||
|
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||||
|
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||||
|
if attention_mask.dim() == 3:
|
||||||
|
extended_attention_mask = attention_mask[:, None, :, :]
|
||||||
|
elif attention_mask.dim() == 2:
|
||||||
|
# Provided a padding mask of dimensions [batch_size, seq_length]
|
||||||
|
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
||||||
|
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||||
|
if is_decoder:
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
|
|
||||||
|
seq_ids = torch.arange(seq_length, device=device)
|
||||||
|
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
||||||
|
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
|
||||||
|
# causal and attention masks must have same type with pytorch version < 1.3
|
||||||
|
causal_mask = causal_mask.to(attention_mask.dtype)
|
||||||
|
|
||||||
|
if causal_mask.shape[1] < attention_mask.shape[1]:
|
||||||
|
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
|
||||||
|
causal_mask = torch.cat(
|
||||||
|
[
|
||||||
|
torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
|
||||||
|
causal_mask,
|
||||||
|
],
|
||||||
|
axis=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
||||||
|
else:
|
||||||
|
extended_attention_mask = attention_mask[:, None, None, :]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
||||||
|
input_shape, attention_mask.shape
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||||
|
# masked positions, this operation will create a tensor which is 0.0 for
|
||||||
|
# positions we want to attend and -10000.0 for masked positions.
|
||||||
|
# Since we are adding it to the raw scores before the softmax, this is
|
||||||
|
# effectively the same as removing these entirely.
|
||||||
|
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||||
|
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||||
|
return extended_attention_mask
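A small numeric illustration of the additive mask produced at the end of this method, with shapes reduced for readability; the values follow the (1.0 - mask) * -10000.0 rule above:

import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])           # one padded position
extended = attention_mask[:, None, None, :].float()     # broadcastable over heads and query positions
extended = (1.0 - extended) * -10000.0
print(extended)  # tensor([[[[-0., -0., -0., -10000.]]]]) -> the padded key is suppressed after softmax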
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
encoder_embeds=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
is_decoder=False,
|
||||||
|
mode='multimodal',
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||||
|
the model is configured as a decoder.
|
||||||
|
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||||
|
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
|
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
use_cache (:obj:`bool`, `optional`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`).
|
||||||
|
"""
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
if is_decoder:
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
else:
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
input_shape = input_ids.size()
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
|
device = input_ids.device
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
input_shape = inputs_embeds.size()[:-1]
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
|
device = inputs_embeds.device
|
||||||
|
elif encoder_embeds is not None:
|
||||||
|
input_shape = encoder_embeds.size()[:-1]
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
|
device = encoder_embeds.device
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
|
||||||
|
|
||||||
|
# past_key_values_length
|
||||||
|
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||||
|
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||||
|
|
||||||
|
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||||
|
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||||
|
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
|
||||||
|
device, is_decoder)
|
||||||
|
|
||||||
|
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||||
|
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||||
|
if encoder_hidden_states is not None:
|
||||||
|
if type(encoder_hidden_states) == list:
|
||||||
|
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
|
||||||
|
else:
|
||||||
|
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||||
|
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||||
|
|
||||||
|
if type(encoder_attention_mask) == list:
|
||||||
|
encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
|
||||||
|
elif encoder_attention_mask is None:
|
||||||
|
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||||
|
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||||
|
else:
|
||||||
|
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||||
|
else:
|
||||||
|
encoder_extended_attention_mask = None
|
||||||
|
|
||||||
|
# Prepare head mask if needed
|
||||||
|
# 1.0 in head_mask indicate we keep the head
|
||||||
|
# attention_probs has shape bsz x n_heads x N x N
|
||||||
|
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||||
|
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||||
|
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||||
|
|
||||||
|
if encoder_embeds is None:
|
||||||
|
embedding_output = self.embeddings(
|
||||||
|
input_ids=input_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
past_key_values_length=past_key_values_length,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
embedding_output = encoder_embeds
|
||||||
|
|
||||||
|
encoder_outputs = self.encoder(
|
||||||
|
embedding_output,
|
||||||
|
attention_mask=extended_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_extended_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
mode=mode,
|
||||||
|
)
|
||||||
|
sequence_output = encoder_outputs[0]
|
||||||
|
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||||
|
|
||||||
|
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||||
|
last_hidden_state=sequence_output,
|
||||||
|
pooler_output=pooled_output,
|
||||||
|
past_key_values=encoder_outputs.past_key_values,
|
||||||
|
hidden_states=encoder_outputs.hidden_states,
|
||||||
|
attentions=encoder_outputs.attentions,
|
||||||
|
cross_attentions=encoder_outputs.cross_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class BertLMHeadModel(BertPreTrainedModel):
|
||||||
|
|
||||||
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
|
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
self.bert = BertModel(config, add_pooling_layer=False)
|
||||||
|
self.cls = BertOnlyMLMHead(config)
|
||||||
|
|
||||||
|
self.init_weights()
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return self.cls.predictions.decoder
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.cls.predictions.decoder = new_embeddings
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
labels=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
return_logits=False,
|
||||||
|
is_decoder=True,
|
||||||
|
reduction='mean',
|
||||||
|
mode='multimodal',
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||||
|
the model is configured as a decoder.
|
||||||
|
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||||
|
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
||||||
|
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
|
||||||
|
ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
|
||||||
|
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
|
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
use_cache (:obj:`bool`, `optional`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`).
|
||||||
|
Returns:
|
||||||
|
Example::
|
||||||
|
>>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
|
||||||
|
>>> import torch
|
||||||
|
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
||||||
|
>>> config = BertConfig.from_pretrained("bert-base-cased")
|
||||||
|
>>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
|
||||||
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||||
|
>>> outputs = model(**inputs)
|
||||||
|
>>> prediction_logits = outputs.logits
|
||||||
|
"""
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
if labels is not None:
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
outputs = self.bert(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
is_decoder=is_decoder,
|
||||||
|
mode=mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
sequence_output = outputs[0]
|
||||||
|
prediction_scores = self.cls(sequence_output)
|
||||||
|
|
||||||
|
if return_logits:
|
||||||
|
return prediction_scores[:, :-1, :].contiguous()
|
||||||
|
|
||||||
|
lm_loss = None
|
||||||
|
if labels is not None:
|
||||||
|
# we are doing next-token prediction; shift prediction scores and input ids by one
|
||||||
|
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
||||||
|
labels = labels[:, 1:].contiguous()
|
||||||
|
loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
|
||||||
|
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
||||||
|
if reduction=='none':
|
||||||
|
lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (prediction_scores,) + outputs[2:]
|
||||||
|
return ((lm_loss,) + output) if lm_loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithCrossAttentions(
|
||||||
|
loss=lm_loss,
|
||||||
|
logits=prediction_scores,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
cross_attentions=outputs.cross_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
|
||||||
|
input_shape = input_ids.shape
|
||||||
|
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = input_ids.new_ones(input_shape)
|
||||||
|
|
||||||
|
# cut decoder_input_ids if past is used
|
||||||
|
if past is not None:
|
||||||
|
input_ids = input_ids[:, -1:]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"past_key_values": past,
|
||||||
|
"encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
|
||||||
|
"encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
|
||||||
|
"is_decoder": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _reorder_cache(self, past, beam_idx):
|
||||||
|
reordered_past = ()
|
||||||
|
for layer_past in past:
|
||||||
|
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
||||||
|
return reordered_past
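The shift-by-one alignment used in BertLMHeadModel.forward's loss above can be checked in isolation; a toy example with a 5-token vocabulary (all tensors below are made up):

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 5
prediction_scores = torch.randn(1, 4, vocab_size)     # logits for 4 positions
labels = torch.tensor([[2, 3, 1, 4]])

shifted = prediction_scores[:, :-1, :].contiguous()   # position t predicts token t+1
targets = labels[:, 1:].contiguous()
loss = CrossEntropyLoss(label_smoothing=0.1)(shifted.view(-1, vocab_size), targets.view(-1))
print(loss)  # scalar next-token loss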
|
||||||
301
diffsynth/extensions/ImageQualityMetric/BLIP/vit.py
Normal file
@@ -0,0 +1,301 @@
'''
 * Adapted from BLIP (https://github.com/salesforce/BLIP)
 * Based on timm code base
 * https://github.com/rwightman/pytorch-image-models/tree/master/timm
'''

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from timm.models.vision_transformer import _cfg, PatchEmbed
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_, DropPath
from timm.models.helpers import named_apply, adapt_input_conv

# from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper

class Mlp(nn.Module):
|
||||||
|
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
|
||||||
|
"""
|
||||||
|
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
||||||
|
super().__init__()
|
||||||
|
out_features = out_features or in_features
|
||||||
|
hidden_features = hidden_features or in_features
|
||||||
|
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||||
|
self.act = act_layer()
|
||||||
|
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||||
|
self.drop = nn.Dropout(drop)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.fc1(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.drop(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
x = self.drop(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
head_dim = dim // num_heads
|
||||||
|
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
||||||
|
self.scale = qk_scale or head_dim ** -0.5
|
||||||
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||||
|
self.attn_drop = nn.Dropout(attn_drop)
|
||||||
|
self.proj = nn.Linear(dim, dim)
|
||||||
|
self.proj_drop = nn.Dropout(proj_drop)
|
||||||
|
self.attn_gradients = None
|
||||||
|
self.attention_map = None
|
||||||
|
|
||||||
|
def save_attn_gradients(self, attn_gradients):
|
||||||
|
self.attn_gradients = attn_gradients
|
||||||
|
|
||||||
|
def get_attn_gradients(self):
|
||||||
|
return self.attn_gradients
|
||||||
|
|
||||||
|
def save_attention_map(self, attention_map):
|
||||||
|
self.attention_map = attention_map
|
||||||
|
|
||||||
|
def get_attention_map(self):
|
||||||
|
return self.attention_map
|
||||||
|
|
||||||
|
def forward(self, x, register_hook=False):
|
||||||
|
B, N, C = x.shape
|
||||||
|
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||||
|
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
||||||
|
|
||||||
|
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||||
|
attn = attn.softmax(dim=-1)
|
||||||
|
attn = self.attn_drop(attn)
|
||||||
|
|
||||||
|
if register_hook:
|
||||||
|
self.save_attention_map(attn)
|
||||||
|
attn.register_hook(self.save_attn_gradients)
|
||||||
|
|
||||||
|
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||||
|
x = self.proj(x)
|
||||||
|
x = self.proj_drop(x)
|
||||||
|
return x
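The qkv reshape above packs the three projections into a single tensor and splits it per head; a shape walk-through with toy sizes (B=2, N=5, C=8, 2 heads; the linear layer here is a stand-in for self.qkv):

import torch

B, N, C, num_heads = 2, 5, 8, 2
x = torch.randn(B, N, C)
qkv_proj = torch.nn.Linear(C, C * 3)

qkv = qkv_proj(x).reshape(B, N, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]                                # each: (B, num_heads, N, head_dim)
attn = (q @ k.transpose(-2, -1)) * (C // num_heads) ** -0.5    # (B, num_heads, N, N)
out = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(B, N, C)
print(out.shape)  # torch.Size([2, 5, 8])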
|
||||||
|
|
||||||
|
|
||||||
|
class Block(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
||||||
|
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
|
||||||
|
super().__init__()
|
||||||
|
self.norm1 = norm_layer(dim)
|
||||||
|
self.attn = Attention(
|
||||||
|
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
||||||
|
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
||||||
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||||
|
self.norm2 = norm_layer(dim)
|
||||||
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||||
|
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||||
|
|
||||||
|
# if use_grad_checkpointing:
|
||||||
|
# self.attn = checkpoint_wrapper(self.attn)
|
||||||
|
# self.mlp = checkpoint_wrapper(self.mlp)
|
||||||
|
|
||||||
|
def forward(self, x, register_hook=False):
|
||||||
|
x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
|
||||||
|
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class VisionTransformer(nn.Module):
|
||||||
|
""" Vision Transformer
|
||||||
|
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
|
||||||
|
https://arxiv.org/abs/2010.11929
|
||||||
|
"""
|
||||||
|
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
||||||
|
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
|
||||||
|
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
|
||||||
|
use_grad_checkpointing=False, ckpt_layer=0):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
img_size (int, tuple): input image size
|
||||||
|
patch_size (int, tuple): patch size
|
||||||
|
in_chans (int): number of input channels
|
||||||
|
num_classes (int): number of classes for classification head
|
||||||
|
embed_dim (int): embedding dimension
|
||||||
|
depth (int): depth of transformer
|
||||||
|
num_heads (int): number of attention heads
|
||||||
|
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
||||||
|
qkv_bias (bool): enable bias for qkv if True
|
||||||
|
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
|
||||||
|
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
|
||||||
|
drop_rate (float): dropout rate
|
||||||
|
attn_drop_rate (float): attention dropout rate
|
||||||
|
drop_path_rate (float): stochastic depth rate
|
||||||
|
norm_layer: (nn.Module): normalization layer
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
||||||
|
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
||||||
|
|
||||||
|
self.patch_embed = PatchEmbed(
|
||||||
|
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
||||||
|
|
||||||
|
num_patches = self.patch_embed.num_patches
|
||||||
|
|
||||||
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
||||||
|
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
||||||
|
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||||
|
|
||||||
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||||
|
self.blocks = nn.ModuleList([
|
||||||
|
Block(
|
||||||
|
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||||
|
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
||||||
|
use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
|
||||||
|
)
|
||||||
|
for i in range(depth)])
|
||||||
|
self.norm = norm_layer(embed_dim)
|
||||||
|
|
||||||
|
trunc_normal_(self.pos_embed, std=.02)
|
||||||
|
trunc_normal_(self.cls_token, std=.02)
|
||||||
|
self.apply(self._init_weights)
|
||||||
|
|
||||||
|
def _init_weights(self, m):
|
||||||
|
if isinstance(m, nn.Linear):
|
||||||
|
trunc_normal_(m.weight, std=.02)
|
||||||
|
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
elif isinstance(m, nn.LayerNorm):
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
nn.init.constant_(m.weight, 1.0)
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def no_weight_decay(self):
|
||||||
|
return {'pos_embed', 'cls_token'}
|
||||||
|
|
||||||
|
def forward(self, x, register_blk=-1):
|
||||||
|
B = x.shape[0]
|
||||||
|
x = self.patch_embed(x)
|
||||||
|
|
||||||
|
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
||||||
|
x = torch.cat((cls_tokens, x), dim=1)
|
||||||
|
|
||||||
|
x = x + self.pos_embed[:,:x.size(1),:]
|
||||||
|
x = self.pos_drop(x)
|
||||||
|
|
||||||
|
for i,blk in enumerate(self.blocks):
|
||||||
|
x = blk(x, register_blk==i)
|
||||||
|
x = self.norm(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
@torch.jit.ignore()
|
||||||
|
def load_pretrained(self, checkpoint_path, prefix=''):
|
||||||
|
_load_weights(self, checkpoint_path, prefix)
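A quick shape check of the VisionTransformer forward pass defined above; the sizes are illustrative (224/16 gives 196 patches plus the CLS token) and the weights are randomly initialized:

import torch

vit = VisionTransformer(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
images = torch.randn(1, 3, 224, 224)
tokens = vit(images)
print(tokens.shape)          # torch.Size([1, 197, 768])
cls_feature = tokens[:, 0]   # per-image feature consumed by the BLIP-style projection heads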
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
|
||||||
|
""" Load weights from .npz checkpoints for official Google Brain Flax implementation
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def _n2p(w, t=True):
|
||||||
|
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
|
||||||
|
w = w.flatten()
|
||||||
|
if t:
|
||||||
|
if w.ndim == 4:
|
||||||
|
w = w.transpose([3, 2, 0, 1])
|
||||||
|
elif w.ndim == 3:
|
||||||
|
w = w.transpose([2, 0, 1])
|
||||||
|
elif w.ndim == 2:
|
||||||
|
w = w.transpose([1, 0])
|
||||||
|
return torch.from_numpy(w)
|
||||||
|
|
||||||
|
w = np.load(checkpoint_path)
|
||||||
|
if not prefix and 'opt/target/embedding/kernel' in w:
|
||||||
|
prefix = 'opt/target/'
|
||||||
|
|
||||||
|
if hasattr(model.patch_embed, 'backbone'):
|
||||||
|
# hybrid
|
||||||
|
backbone = model.patch_embed.backbone
|
||||||
|
stem_only = not hasattr(backbone, 'stem')
|
||||||
|
stem = backbone if stem_only else backbone.stem
|
||||||
|
stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
|
||||||
|
stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
|
||||||
|
stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
|
||||||
|
if not stem_only:
|
||||||
|
for i, stage in enumerate(backbone.stages):
|
||||||
|
for j, block in enumerate(stage.blocks):
|
||||||
|
bp = f'{prefix}block{i + 1}/unit{j + 1}/'
|
||||||
|
for r in range(3):
|
||||||
|
getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
|
||||||
|
getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
|
||||||
|
getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
|
||||||
|
if block.downsample is not None:
|
||||||
|
block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
|
||||||
|
block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
|
||||||
|
block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
|
||||||
|
embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
|
||||||
|
else:
|
||||||
|
embed_conv_w = adapt_input_conv(
|
||||||
|
model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
|
||||||
|
model.patch_embed.proj.weight.copy_(embed_conv_w)
|
||||||
|
model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
|
||||||
|
model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
|
||||||
|
pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
|
||||||
|
if pos_embed_w.shape != model.pos_embed.shape:
|
||||||
|
pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
|
||||||
|
pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
|
||||||
|
model.pos_embed.copy_(pos_embed_w)
|
||||||
|
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
|
||||||
|
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
|
||||||
|
# if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
|
||||||
|
# model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
|
||||||
|
# model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
|
||||||
|
# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
|
||||||
|
# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
|
||||||
|
# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
|
||||||
|
for i, block in enumerate(model.blocks.children()):
|
||||||
|
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
|
||||||
|
mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
|
||||||
|
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
|
||||||
|
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
|
||||||
|
block.attn.qkv.weight.copy_(torch.cat([
|
||||||
|
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
|
||||||
|
block.attn.qkv.bias.copy_(torch.cat([
|
||||||
|
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
|
||||||
|
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
|
||||||
|
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
|
||||||
|
for r in range(2):
|
||||||
|
getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
|
||||||
|
getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
|
||||||
|
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
|
||||||
|
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
|
||||||
|
|
||||||
|
|
||||||
|
def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
|
||||||
|
# interpolate position embedding
|
||||||
|
embedding_size = pos_embed_checkpoint.shape[-1]
|
||||||
|
num_patches = visual_encoder.patch_embed.num_patches
|
||||||
|
num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
|
||||||
|
# height (== width) for the checkpoint position embedding
|
||||||
|
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
||||||
|
# height (== width) for the new position embedding
|
||||||
|
new_size = int(num_patches ** 0.5)
|
||||||
|
|
||||||
|
if orig_size!=new_size:
|
||||||
|
# class_token and dist_token are kept unchanged
|
||||||
|
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
||||||
|
# only the position tokens are interpolated
|
||||||
|
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
||||||
|
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
||||||
|
pos_tokens = torch.nn.functional.interpolate(
|
||||||
|
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
||||||
|
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
||||||
|
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
||||||
|
print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
|
||||||
|
|
||||||
|
return new_pos_embed
|
||||||
|
else:
|
||||||
|
return pos_embed_checkpoint
|
||||||
148
diffsynth/extensions/ImageQualityMetric/__init__.py
Normal file
148
diffsynth/extensions/ImageQualityMetric/__init__.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
from modelscope import snapshot_download
|
||||||
|
from typing_extensions import Literal, TypeAlias
|
||||||
|
import os
|
||||||
|
from diffsynth.extensions.ImageQualityMetric.aesthetic import AestheticScore
|
||||||
|
from diffsynth.extensions.ImageQualityMetric.imagereward import ImageRewardScore
|
||||||
|
from diffsynth.extensions.ImageQualityMetric.pickscore import PickScore
|
||||||
|
from diffsynth.extensions.ImageQualityMetric.clip import CLIPScore
|
||||||
|
from diffsynth.extensions.ImageQualityMetric.hps import HPScore_v2
|
||||||
|
from diffsynth.extensions.ImageQualityMetric.mps import MPScore
|
||||||
|
|
||||||
|
|
||||||
|
preference_model_id: TypeAlias = Literal[
|
||||||
|
"ImageReward",
|
||||||
|
"Aesthetic",
|
||||||
|
"PickScore",
|
||||||
|
"CLIP",
|
||||||
|
"HPSv2",
|
||||||
|
"HPSv2.1",
|
||||||
|
"MPS",
|
||||||
|
]
|
||||||
|
model_dict = {
|
||||||
|
"ImageReward": {
|
||||||
|
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||||
|
"allow_file_pattern": [
|
||||||
|
"ImageReward/ImageReward.safetensors",
|
||||||
|
"ImageReward/med_config.json",
|
||||||
|
"bert-base-uncased/config.json",
|
||||||
|
"bert-base-uncased/model.safetensors",
|
||||||
|
"bert-base-uncased/tokenizer.json",
|
||||||
|
"bert-base-uncased/tokenizer_config.json",
|
||||||
|
"bert-base-uncased/vocab.txt",
|
||||||
|
],
|
||||||
|
"load_path": {
|
||||||
|
"imagereward": "ImageReward/ImageReward.safetensors",
|
||||||
|
"med_config": "ImageReward/med_config.json",
|
||||||
|
"bert_model_path": "bert-base-uncased",
|
||||||
|
},
|
||||||
|
"model_class": ImageRewardScore
|
||||||
|
},
|
||||||
|
"Aesthetic": {
|
||||||
|
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||||
|
"allow_file_pattern": [
|
||||||
|
"aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors",
|
||||||
|
"clip-vit-large-patch14/config.json",
|
||||||
|
"clip-vit-large-patch14/merges.txt",
|
||||||
|
"clip-vit-large-patch14/model.safetensors",
|
||||||
|
"clip-vit-large-patch14/preprocessor_config.json",
|
||||||
|
"clip-vit-large-patch14/special_tokens_map.json",
|
||||||
|
"clip-vit-large-patch14/tokenizer.json",
|
||||||
|
"clip-vit-large-patch14/tokenizer_config.json",
|
||||||
|
"clip-vit-large-patch14/vocab.json",
|
||||||
|
],
|
||||||
|
"load_path": {
|
||||||
|
"aesthetic_predictor": "aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors",
|
||||||
|
"clip-large": "clip-vit-large-patch14",
|
||||||
|
},
|
||||||
|
"model_class": AestheticScore
|
||||||
|
},
|
||||||
|
"PickScore": {
|
||||||
|
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||||
|
"allow_file_pattern": [
|
||||||
|
"PickScore_v1/*",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/config.json",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/merges.txt",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/preprocessor_config.json",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/special_tokens_map.json",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer.json",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer_config.json",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/vocab.json",
|
||||||
|
],
|
||||||
|
"load_path": {
|
||||||
|
"pickscore": "PickScore_v1",
|
||||||
|
"clip": "CLIP-ViT-H-14-laion2B-s32B-b79K",
|
||||||
|
},
|
||||||
|
"model_class": PickScore
|
||||||
|
},
|
||||||
|
"CLIP": {
|
||||||
|
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||||
|
"allow_file_pattern": [
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin",
|
||||||
|
"bpe_simple_vocab_16e6.txt.gz",
|
||||||
|
],
|
||||||
|
"load_path": {
|
||||||
|
"open_clip": "CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin",
|
||||||
|
"open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
|
||||||
|
},
|
||||||
|
"model_class": CLIPScore
|
||||||
|
},
|
||||||
|
"HPSv2": {
|
||||||
|
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||||
|
"allow_file_pattern": [
|
||||||
|
"HPS_v2/HPS_v2_compressed.safetensors",
|
||||||
|
"bpe_simple_vocab_16e6.txt.gz",
|
||||||
|
],
|
||||||
|
"load_path": {
|
||||||
|
"hpsv2": "HPS_v2/HPS_v2_compressed.safetensors",
|
||||||
|
"open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
|
||||||
|
},
|
||||||
|
"model_class": HPScore_v2,
|
||||||
|
"extra_kwargs": {"model_version": "v2"}
|
||||||
|
},
|
||||||
|
"HPSv2.1": {
|
||||||
|
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||||
|
"allow_file_pattern": [
|
||||||
|
"HPS_v2/HPS_v2.1_compressed.safetensors",
|
||||||
|
"bpe_simple_vocab_16e6.txt.gz",
|
||||||
|
],
|
||||||
|
"load_path": {
|
||||||
|
"hpsv2.1": "HPS_v2/HPS_v2.1_compressed.safetensors",
|
||||||
|
"open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
|
||||||
|
},
|
||||||
|
"model_class": HPScore_v2,
|
||||||
|
"extra_kwargs": {"model_version": "v21"}
|
||||||
|
},
|
||||||
|
"MPS": {
|
||||||
|
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||||
|
"allow_file_pattern": [
|
||||||
|
"MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/config.json",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/merges.txt",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/preprocessor_config.json",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/special_tokens_map.json",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer.json",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer_config.json",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/vocab.json",
|
||||||
|
],
|
||||||
|
"load_path": {
|
||||||
|
"mps": "MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors",
|
||||||
|
"clip": "CLIP-ViT-H-14-laion2B-s32B-b79K",
|
||||||
|
},
|
||||||
|
"model_class": MPScore
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def download_preference_model(model_name: preference_model_id, cache_dir="models"):
|
||||||
|
metadata = model_dict[model_name]
|
||||||
|
snapshot_download(model_id=metadata["model_id"], allow_file_pattern=metadata["allow_file_pattern"], cache_dir=cache_dir)
|
||||||
|
load_path = metadata["load_path"]
|
||||||
|
load_path = {key: os.path.join(cache_dir, metadata["model_id"], path) for key, path in load_path.items()}
|
||||||
|
return load_path
|
||||||
|
|
||||||
|
|
||||||
|
def load_preference_model(model_name: preference_model_id, device = "cuda", path = None):
|
||||||
|
model_class = model_dict[model_name]["model_class"]
|
||||||
|
extra_kwargs = model_dict[model_name].get("extra_kwargs", {})
|
||||||
|
preference_model = model_class(device=device, path=path, **extra_kwargs)
|
||||||
|
return preference_model
|
||||||
148
diffsynth/extensions/ImageQualityMetric/aesthetic.py
Normal file
148
diffsynth/extensions/ImageQualityMetric/aesthetic.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
from PIL import Image
|
||||||
|
import torch
|
||||||
|
from transformers import AutoProcessor, AutoModel
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
import os
|
||||||
|
from typing import Union, List
|
||||||
|
from .config import MODEL_PATHS
|
||||||
|
|
||||||
|
class MLP(torch.nn.Module):
|
||||||
|
def __init__(self, input_size: int, xcol: str = "emb", ycol: str = "avg_rating"):
|
||||||
|
super().__init__()
|
||||||
|
self.input_size = input_size
|
||||||
|
self.xcol = xcol
|
||||||
|
self.ycol = ycol
|
||||||
|
self.layers = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(self.input_size, 1024),
|
||||||
|
#torch.nn.ReLU(),
|
||||||
|
torch.nn.Dropout(0.2),
|
||||||
|
torch.nn.Linear(1024, 128),
|
||||||
|
#torch.nn.ReLU(),
|
||||||
|
torch.nn.Dropout(0.2),
|
||||||
|
torch.nn.Linear(128, 64),
|
||||||
|
#torch.nn.ReLU(),
|
||||||
|
torch.nn.Dropout(0.1),
|
||||||
|
torch.nn.Linear(64, 16),
|
||||||
|
#torch.nn.ReLU(),
|
||||||
|
torch.nn.Linear(16, 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.layers(x)
|
||||||
|
|
||||||
|
def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
|
||||||
|
x = batch[self.xcol]
|
||||||
|
y = batch[self.ycol].reshape(-1, 1)
|
||||||
|
x_hat = self.layers(x)
|
||||||
|
loss = torch.nn.functional.mse_loss(x_hat, y)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def validation_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
|
||||||
|
x = batch[self.xcol]
|
||||||
|
y = batch[self.ycol].reshape(-1, 1)
|
||||||
|
x_hat = self.layers(x)
|
||||||
|
loss = torch.nn.functional.mse_loss(x_hat, y)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def configure_optimizers(self) -> torch.optim.Optimizer:
|
||||||
|
return torch.optim.Adam(self.parameters(), lr=1e-3)
|
||||||
|
|
||||||
|
|
||||||
|
class AestheticScore(torch.nn.Module):
|
||||||
|
def __init__(self, device: torch.device, path: str = MODEL_PATHS):
|
||||||
|
super().__init__()
|
||||||
|
self.device = device
|
||||||
|
self.aes_model_path = path.get("aesthetic_predictor")
|
||||||
|
# Load the MLP model
|
||||||
|
self.model = MLP(768)
|
||||||
|
try:
|
||||||
|
if self.aes_model_path.endswith(".safetensors"):
|
||||||
|
state_dict = load_file(self.aes_model_path)
|
||||||
|
else:
|
||||||
|
state_dict = torch.load(self.aes_model_path)
|
||||||
|
self.model.load_state_dict(state_dict)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Error loading model weights from {self.aes_model_path}: {e}")
|
||||||
|
|
||||||
|
self.model.to(device)
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
# Load the CLIP model and processor
|
||||||
|
clip_model_name = path.get('clip-large')
|
||||||
|
self.model2 = AutoModel.from_pretrained(clip_model_name).eval().to(device)
|
||||||
|
self.processor = AutoProcessor.from_pretrained(clip_model_name)
|
||||||
|
|
||||||
|
def _calculate_score(self, image: torch.Tensor) -> float:
|
||||||
|
"""Calculate the aesthetic score for a single image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (torch.Tensor): The processed image tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: The aesthetic score.
|
||||||
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
|
# Get image embeddings
|
||||||
|
image_embs = self.model2.get_image_features(image)
|
||||||
|
image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
# Compute score
|
||||||
|
score = self.model(image_embs).cpu().flatten().item()
|
||||||
|
|
||||||
|
return score
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str = "") -> List[float]:
|
||||||
|
"""Score the images based on their aesthetic quality.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[float]: List of scores for the images.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if isinstance(images, (str, Image.Image)):
|
||||||
|
# Single image
|
||||||
|
if isinstance(images, str):
|
||||||
|
pil_image = Image.open(images)
|
||||||
|
else:
|
||||||
|
pil_image = images
|
||||||
|
|
||||||
|
# Prepare image inputs
|
||||||
|
image_inputs = self.processor(
|
||||||
|
images=pil_image,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=77,
|
||||||
|
return_tensors="pt",
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
|
return [self._calculate_score(image_inputs["pixel_values"])]
|
||||||
|
elif isinstance(images, list):
|
||||||
|
# Multiple images
|
||||||
|
scores = []
|
||||||
|
for one_image in images:
|
||||||
|
if isinstance(one_image, str):
|
||||||
|
pil_image = Image.open(one_image)
|
||||||
|
elif isinstance(one_image, Image.Image):
|
||||||
|
pil_image = one_image
|
||||||
|
else:
|
||||||
|
raise TypeError("The type of parameter images is illegal.")
|
||||||
|
|
||||||
|
# Prepare image inputs
|
||||||
|
image_inputs = self.processor(
|
||||||
|
images=pil_image,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=77,
|
||||||
|
return_tensors="pt",
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
|
scores.append(self._calculate_score(image_inputs["pixel_values"]))
|
||||||
|
return scores
|
||||||
|
else:
|
||||||
|
raise TypeError("The type of parameter images is illegal.")
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Error in scoring images: {e}")
|
||||||
97
diffsynth/extensions/ImageQualityMetric/clip.py
Normal file
97
diffsynth/extensions/ImageQualityMetric/clip.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
from typing import List, Union
|
||||||
|
from PIL import Image
|
||||||
|
import torch
|
||||||
|
from .open_clip import create_model_and_transforms, get_tokenizer
|
||||||
|
from .config import MODEL_PATHS
|
||||||
|
|
||||||
|
class CLIPScore(torch.nn.Module):
|
||||||
|
def __init__(self, device: torch.device, path: str = MODEL_PATHS):
|
||||||
|
super().__init__()
|
||||||
|
"""Initialize the CLIPScore with a model and tokenizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device (torch.device): The device to load the model on.
|
||||||
|
"""
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
# Create model and transforms
|
||||||
|
self.model, _, self.preprocess_val = create_model_and_transforms(
|
||||||
|
"ViT-H-14",
|
||||||
|
# "laion2B-s32B-b79K",
|
||||||
|
pretrained=path.get("open_clip"),
|
||||||
|
precision="amp",
|
||||||
|
device=device,
|
||||||
|
jit=False,
|
||||||
|
force_quick_gelu=False,
|
||||||
|
force_custom_text=False,
|
||||||
|
force_patch_dropout=False,
|
||||||
|
force_image_size=None,
|
||||||
|
pretrained_image=False,
|
||||||
|
image_mean=None,
|
||||||
|
image_std=None,
|
||||||
|
light_augmentation=True,
|
||||||
|
aug_cfg={},
|
||||||
|
output_dict=True,
|
||||||
|
with_score_predictor=False,
|
||||||
|
with_region_predictor=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize tokenizer
|
||||||
|
self.tokenizer = get_tokenizer("ViT-H-14", path["open_clip_bpe"])
|
||||||
|
self.model = self.model.to(device)
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
|
||||||
|
"""Calculate the CLIP score for a single image and prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (torch.Tensor): The processed image tensor.
|
||||||
|
prompt (str): The prompt text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: The CLIP score.
|
||||||
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
|
# Process the prompt
|
||||||
|
text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
|
||||||
|
|
||||||
|
# Calculate the CLIP score
|
||||||
|
outputs = self.model(image, text)
|
||||||
|
image_features, text_features = outputs["image_features"], outputs["text_features"]
|
||||||
|
logits_per_image = image_features @ text_features.T
|
||||||
|
clip_score = torch.diagonal(logits_per_image).cpu().numpy()
|
||||||
|
|
||||||
|
return clip_score[0].item()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
||||||
|
"""Score the images based on the prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
||||||
|
prompt (str): The prompt text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[float]: List of CLIP scores for the images.
|
||||||
|
"""
|
||||||
|
if isinstance(images, (str, Image.Image)):
|
||||||
|
# Single image
|
||||||
|
if isinstance(images, str):
|
||||||
|
image = self.preprocess_val(Image.open(images)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||||
|
else:
|
||||||
|
image = self.preprocess_val(images).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||||
|
return [self._calculate_score(image, prompt)]
|
||||||
|
elif isinstance(images, list):
|
||||||
|
# Multiple images
|
||||||
|
scores = []
|
||||||
|
for one_images in images:
|
||||||
|
if isinstance(one_images, str):
|
||||||
|
image = self.preprocess_val(Image.open(one_images)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||||
|
elif isinstance(one_images, Image.Image):
|
||||||
|
image = self.preprocess_val(one_images).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||||
|
else:
|
||||||
|
raise TypeError("The type of parameter images is illegal.")
|
||||||
|
scores.append(self._calculate_score(image, prompt))
|
||||||
|
return scores
|
||||||
|
else:
|
||||||
|
raise TypeError("The type of parameter images is illegal.")
|
||||||
23
diffsynth/extensions/ImageQualityMetric/config.py
Normal file
23
diffsynth/extensions/ImageQualityMetric/config.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
project_root = os.path.abspath(os.path.join(current_dir, '../../../'))
|
||||||
|
model_path = os.path.join(project_root, 'models', 'QualityMetric')
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_path(model_name):
|
||||||
|
return os.path.join(model_path, model_name)
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_PATHS = {
|
||||||
|
"aesthetic_predictor": get_model_path("aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors"),
|
||||||
|
"open_clip": get_model_path("CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin"),
|
||||||
|
"hpsv2": get_model_path("HPS_v2/HPS_v2_compressed.safetensors"),
|
||||||
|
"hpsv2.1": get_model_path("HPS_v2/HPS_v2.1_compressed.safetensors"),
|
||||||
|
"imagereward": get_model_path("ImageReward/ImageReward.safetensors"),
|
||||||
|
"med_config": get_model_path("ImageReward/med_config.json"),
|
||||||
|
"clip": get_model_path("CLIP-ViT-H-14-laion2B-s32B-b79K"),
|
||||||
|
"clip-large": get_model_path("clip-vit-large-patch14"),
|
||||||
|
"mps": get_model_path("MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors"),
|
||||||
|
"pickscore": get_model_path("PickScore_v1")
|
||||||
|
}
|
||||||
118
diffsynth/extensions/ImageQualityMetric/hps.py
Normal file
118
diffsynth/extensions/ImageQualityMetric/hps.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
from typing import List, Union
|
||||||
|
from PIL import Image
|
||||||
|
import torch
|
||||||
|
from .open_clip import create_model_and_transforms, get_tokenizer
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
import os
|
||||||
|
from .config import MODEL_PATHS
|
||||||
|
|
||||||
|
class HPScore_v2(torch.nn.Module):
|
||||||
|
def __init__(self, device: torch.device, path: str = MODEL_PATHS, model_version: str = "v2"):
|
||||||
|
super().__init__()
|
||||||
|
"""Initialize the Selector with a model and tokenizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device (torch.device): The device to load the model on.
|
||||||
|
model_version (str): The version of the model to load. Supports "v2" or "v21". Default is "v2".
|
||||||
|
"""
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
if model_version == "v2":
|
||||||
|
safetensors_path = path.get("hpsv2")
|
||||||
|
elif model_version == "v21":
|
||||||
|
safetensors_path = path.get("hpsv2.1")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported model version: {model_version}. Choose 'v2' or 'v21'.")
|
||||||
|
|
||||||
|
# Create model and transforms
|
||||||
|
model, _, self.preprocess_val = create_model_and_transforms(
|
||||||
|
"ViT-H-14",
|
||||||
|
# "laion2B-s32B-b79K",
|
||||||
|
pretrained=path.get("open_clip"),
|
||||||
|
precision="amp",
|
||||||
|
device=device,
|
||||||
|
jit=False,
|
||||||
|
force_quick_gelu=False,
|
||||||
|
force_custom_text=False,
|
||||||
|
force_patch_dropout=False,
|
||||||
|
force_image_size=None,
|
||||||
|
pretrained_image=False,
|
||||||
|
image_mean=None,
|
||||||
|
image_std=None,
|
||||||
|
light_augmentation=True,
|
||||||
|
aug_cfg={},
|
||||||
|
output_dict=True,
|
||||||
|
with_score_predictor=False,
|
||||||
|
with_region_predictor=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load model weights
|
||||||
|
try:
|
||||||
|
state_dict = load_file(safetensors_path)
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Error loading model weights from {safetensors_path}: {e}")
|
||||||
|
|
||||||
|
# Initialize tokenizer and model
|
||||||
|
self.tokenizer = get_tokenizer("ViT-H-14", path["open_clip_bpe"])
|
||||||
|
model = model.to(device)
|
||||||
|
model.eval()
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
|
||||||
|
"""Calculate the HPS score for a single image and prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (torch.Tensor): The processed image tensor.
|
||||||
|
prompt (str): The prompt text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: The HPS score.
|
||||||
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
|
# Process the prompt
|
||||||
|
text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
|
||||||
|
|
||||||
|
# Calculate the HPS score
|
||||||
|
outputs = self.model(image, text)
|
||||||
|
image_features, text_features = outputs["image_features"], outputs["text_features"]
|
||||||
|
logits_per_image = image_features @ text_features.T
|
||||||
|
hps_score = torch.diagonal(logits_per_image).cpu().numpy()
|
||||||
|
|
||||||
|
return hps_score[0].item()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
||||||
|
"""Score the images based on the prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
||||||
|
prompt (str): The prompt text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[float]: List of HPS scores for the images.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if isinstance(images, (str, Image.Image)):
|
||||||
|
# Single image
|
||||||
|
if isinstance(images, str):
|
||||||
|
image = self.preprocess_val(Image.open(images)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||||
|
else:
|
||||||
|
image = self.preprocess_val(images).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||||
|
return [self._calculate_score(image, prompt)]
|
||||||
|
elif isinstance(images, list):
|
||||||
|
# Multiple images
|
||||||
|
scores = []
|
||||||
|
for one_images in images:
|
||||||
|
if isinstance(one_images, str):
|
||||||
|
image = self.preprocess_val(Image.open(one_images)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||||
|
elif isinstance(one_images, Image.Image):
|
||||||
|
image = self.preprocess_val(one_images).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||||
|
else:
|
||||||
|
raise TypeError("The type of parameter images is illegal.")
|
||||||
|
scores.append(self._calculate_score(image, prompt))
|
||||||
|
return scores
|
||||||
|
else:
|
||||||
|
raise TypeError("The type of parameter images is illegal.")
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Error in scoring images: {e}")
|
||||||
212
diffsynth/extensions/ImageQualityMetric/imagereward.py
Normal file
212
diffsynth/extensions/ImageQualityMetric/imagereward.py
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from typing import List, Union
|
||||||
|
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
||||||
|
from .BLIP.blip_pretrain import BLIP_Pretrain
|
||||||
|
from torchvision.transforms import InterpolationMode
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
from .config import MODEL_PATHS
|
||||||
|
BICUBIC = InterpolationMode.BICUBIC
|
||||||
|
|
||||||
|
def _convert_image_to_rgb(image):
|
||||||
|
return image.convert("RGB")
|
||||||
|
|
||||||
|
def _transform(n_px):
|
||||||
|
return Compose([
|
||||||
|
Resize(n_px, interpolation=BICUBIC),
|
||||||
|
CenterCrop(n_px),
|
||||||
|
_convert_image_to_rgb,
|
||||||
|
ToTensor(),
|
||||||
|
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
||||||
|
])
|
||||||
|
|
||||||
|
class MLP(torch.nn.Module):
|
||||||
|
def __init__(self, input_size):
|
||||||
|
super().__init__()
|
||||||
|
self.input_size = input_size
|
||||||
|
|
||||||
|
self.layers = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(self.input_size, 1024),
|
||||||
|
#nn.ReLU(),
|
||||||
|
torch.nn.Dropout(0.2),
|
||||||
|
torch.nn.Linear(1024, 128),
|
||||||
|
#nn.ReLU(),
|
||||||
|
torch.nn.Dropout(0.2),
|
||||||
|
torch.nn.Linear(128, 64),
|
||||||
|
#nn.ReLU(),
|
||||||
|
torch.nn.Dropout(0.1),
|
||||||
|
torch.nn.Linear(64, 16),
|
||||||
|
#nn.ReLU(),
|
||||||
|
torch.nn.Linear(16, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
# initial MLP param
|
||||||
|
for name, param in self.layers.named_parameters():
|
||||||
|
if 'weight' in name:
|
||||||
|
torch.nn.init.normal_(param, mean=0.0, std=1.0/(self.input_size+1))
|
||||||
|
if 'bias' in name:
|
||||||
|
torch.nn.init.constant_(param, val=0)
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
return self.layers(input)
|
||||||
|
|
||||||
|
class ImageReward(torch.nn.Module):
|
||||||
|
def __init__(self, med_config, device='cpu', bert_model_path=""):
|
||||||
|
super().__init__()
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config, bert_model_path=bert_model_path)
|
||||||
|
self.preprocess = _transform(224)
|
||||||
|
self.mlp = MLP(768)
|
||||||
|
|
||||||
|
self.mean = 0.16717362830052426
|
||||||
|
self.std = 1.0333394966054072
|
||||||
|
|
||||||
|
def score_grad(self, prompt_ids, prompt_attention_mask, image):
|
||||||
|
"""Calculate the score with gradient for a single image and prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt_ids (torch.Tensor): Tokenized prompt IDs.
|
||||||
|
prompt_attention_mask (torch.Tensor): Attention mask for the prompt.
|
||||||
|
image (torch.Tensor): The processed image tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The reward score.
|
||||||
|
"""
|
||||||
|
image_embeds = self.blip.visual_encoder(image)
|
||||||
|
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
|
||||||
|
text_output = self.blip.text_encoder(
|
||||||
|
prompt_ids,
|
||||||
|
attention_mask=prompt_attention_mask,
|
||||||
|
encoder_hidden_states=image_embeds,
|
||||||
|
encoder_attention_mask=image_atts,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
txt_features = text_output.last_hidden_state[:, 0, :]
|
||||||
|
rewards = self.mlp(txt_features)
|
||||||
|
rewards = (rewards - self.mean) / self.std
|
||||||
|
return rewards
|
||||||
|
|
||||||
|
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str = "") -> List[float]:
|
||||||
|
"""Score the images based on the prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): The prompt text.
|
||||||
|
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[float]: List of scores for the images.
|
||||||
|
"""
|
||||||
|
if isinstance(images, (str, Image.Image)):
|
||||||
|
# Single image
|
||||||
|
if isinstance(images, str):
|
||||||
|
pil_image = Image.open(images)
|
||||||
|
else:
|
||||||
|
pil_image = images
|
||||||
|
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
|
||||||
|
return [self._calculate_score(prompt, image).item()]
|
||||||
|
elif isinstance(images, list):
|
||||||
|
# Multiple images
|
||||||
|
scores = []
|
||||||
|
for one_image in images:
|
||||||
|
if isinstance(one_image, str):
|
||||||
|
pil_image = Image.open(one_image)
|
||||||
|
elif isinstance(one_image, Image.Image):
|
||||||
|
pil_image = one_image
|
||||||
|
else:
|
||||||
|
raise TypeError("The type of parameter images is illegal.")
|
||||||
|
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
|
||||||
|
scores.append(self._calculate_score(prompt, image).item())
|
||||||
|
return scores
|
||||||
|
else:
|
||||||
|
raise TypeError("The type of parameter images is illegal.")
|
||||||
|
|
||||||
|
def _calculate_score(self, prompt: str, image: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Calculate the score for a single image and prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): The prompt text.
|
||||||
|
image (torch.Tensor): The processed image tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The reward score.
|
||||||
|
"""
|
||||||
|
text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
|
||||||
|
image_embeds = self.blip.visual_encoder(image)
|
||||||
|
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
|
||||||
|
text_output = self.blip.text_encoder(
|
||||||
|
text_input.input_ids,
|
||||||
|
attention_mask=text_input.attention_mask,
|
||||||
|
encoder_hidden_states=image_embeds,
|
||||||
|
encoder_attention_mask=image_atts,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
txt_features = text_output.last_hidden_state[:, 0, :].float()
|
||||||
|
rewards = self.mlp(txt_features)
|
||||||
|
rewards = (rewards - self.mean) / self.std
|
||||||
|
return rewards
|
||||||
|
|
||||||
|
def inference_rank(self, prompt: str, generations_list: List[Union[str, Image.Image]]) -> tuple:
|
||||||
|
"""Rank the images based on the prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): The prompt text.
|
||||||
|
generations_list (List[Union[str, Image.Image]]): List of image paths or PIL images.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (indices, rewards) where indices are the ranks and rewards are the scores.
|
||||||
|
"""
|
||||||
|
text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
|
||||||
|
txt_set = []
|
||||||
|
for generation in generations_list:
|
||||||
|
if isinstance(generation, str):
|
||||||
|
pil_image = Image.open(generation)
|
||||||
|
elif isinstance(generation, Image.Image):
|
||||||
|
pil_image = generation
|
||||||
|
else:
|
||||||
|
raise TypeError("The type of parameter generations_list is illegal.")
|
||||||
|
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
|
||||||
|
image_embeds = self.blip.visual_encoder(image)
|
||||||
|
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
|
||||||
|
text_output = self.blip.text_encoder(
|
||||||
|
text_input.input_ids,
|
||||||
|
attention_mask=text_input.attention_mask,
|
||||||
|
encoder_hidden_states=image_embeds,
|
||||||
|
encoder_attention_mask=image_atts,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
txt_set.append(text_output.last_hidden_state[:, 0, :])
|
||||||
|
txt_features = torch.cat(txt_set, 0).float()
|
||||||
|
rewards = self.mlp(txt_features)
|
||||||
|
rewards = (rewards - self.mean) / self.std
|
||||||
|
rewards = torch.squeeze(rewards)
|
||||||
|
_, rank = torch.sort(rewards, dim=0, descending=True)
|
||||||
|
_, indices = torch.sort(rank, dim=0)
|
||||||
|
indices = indices + 1
|
||||||
|
return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
|
||||||
|
|
||||||
|
|
||||||
|
class ImageRewardScore(torch.nn.Module):
|
||||||
|
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS):
|
||||||
|
super().__init__()
|
||||||
|
self.device = device if isinstance(device, torch.device) else torch.device(device)
|
||||||
|
model_path = path.get("imagereward")
|
||||||
|
med_config = path.get("med_config")
|
||||||
|
state_dict = load_file(model_path)
|
||||||
|
self.model = ImageReward(device=self.device, med_config=med_config, bert_model_path=path.get("bert_model_path")).to(self.device)
|
||||||
|
self.model.load_state_dict(state_dict, strict=False)
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
||||||
|
"""Score the images based on the prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
||||||
|
prompt (str): The prompt text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[float]: List of scores for the images.
|
||||||
|
"""
|
||||||
|
return self.model.score(images, prompt)
|
||||||
129
diffsynth/extensions/ImageQualityMetric/mps.py
Normal file
129
diffsynth/extensions/ImageQualityMetric/mps.py
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from io import BytesIO
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from transformers import CLIPFeatureExtractor, CLIPImageProcessor
|
||||||
|
from transformers import CLIPConfig
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from transformers import CLIPModel as HFCLIPModel
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
from torch import nn, einsum
|
||||||
|
|
||||||
|
from .trainer.models.base_model import BaseModelConfig
|
||||||
|
|
||||||
|
from transformers import CLIPConfig
|
||||||
|
from transformers import AutoProcessor, AutoModel, AutoTokenizer
|
||||||
|
from typing import Any, Optional, Tuple, Union, List
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .trainer.models.cross_modeling import Cross_model
|
||||||
|
from .trainer.models import clip_model
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import gc
|
||||||
|
import json
|
||||||
|
from .config import MODEL_PATHS
|
||||||
|
|
||||||
|
class MPScore(torch.nn.Module):
|
||||||
|
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS, condition: str = 'overall'):
|
||||||
|
super().__init__()
|
||||||
|
"""Initialize the MPSModel with a processor, tokenizer, and model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device (Union[str, torch.device]): The device to load the model on.
|
||||||
|
"""
|
||||||
|
self.device = device
|
||||||
|
processor_name_or_path = path.get("clip")
|
||||||
|
self.image_processor = CLIPImageProcessor.from_pretrained(processor_name_or_path)
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True)
|
||||||
|
self.model = clip_model.CLIPModel(processor_name_or_path, config_file=True)
|
||||||
|
state_dict = load_file(path.get("mps"))
|
||||||
|
self.model.load_state_dict(state_dict, strict=False)
|
||||||
|
self.model.to(device)
|
||||||
|
self.condition = condition
|
||||||
|
|
||||||
|
def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
|
||||||
|
"""Calculate the reward score for a single image and prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (torch.Tensor): The processed image tensor.
|
||||||
|
prompt (str): The prompt text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: The reward score.
|
||||||
|
"""
|
||||||
|
def _tokenize(caption):
|
||||||
|
input_ids = self.tokenizer(
|
||||||
|
caption,
|
||||||
|
max_length=self.tokenizer.model_max_length,
|
||||||
|
padding="max_length",
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt"
|
||||||
|
).input_ids
|
||||||
|
return input_ids
|
||||||
|
|
||||||
|
text_input = _tokenize(prompt).to(self.device)
|
||||||
|
if self.condition == 'overall':
|
||||||
|
condition_prompt = 'light, color, clarity, tone, style, ambiance, artistry, shape, face, hair, hands, limbs, structure, instance, texture, quantity, attributes, position, number, location, word, things'
|
||||||
|
elif self.condition == 'aesthetics':
|
||||||
|
condition_prompt = 'light, color, clarity, tone, style, ambiance, artistry'
|
||||||
|
elif self.condition == 'quality':
|
||||||
|
condition_prompt = 'shape, face, hair, hands, limbs, structure, instance, texture'
|
||||||
|
elif self.condition == 'semantic':
|
||||||
|
condition_prompt = 'quantity, attributes, position, number, location'
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported condition: {self.condition}. Choose 'overall', 'aesthetics', 'quality', or 'semantic'.")
|
||||||
|
condition_batch = _tokenize(condition_prompt).repeat(text_input.shape[0], 1).to(self.device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
text_f, text_features = self.model.model.get_text_features(text_input)
|
||||||
|
|
||||||
|
image_f = self.model.model.get_image_features(image.half())
|
||||||
|
condition_f, _ = self.model.model.get_text_features(condition_batch)
|
||||||
|
|
||||||
|
sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)
|
||||||
|
sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
|
||||||
|
sim_text_condition = sim_text_condition / sim_text_condition.max()
|
||||||
|
mask = torch.where(sim_text_condition > 0.3, 0, float('-inf'))
|
||||||
|
mask = mask.repeat(1, image_f.shape[1], 1)
|
||||||
|
image_features = self.model.cross_model(image_f, text_f, mask.half())[:, 0, :]
|
||||||
|
|
||||||
|
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
|
||||||
|
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
|
||||||
|
image_score = self.model.logit_scale.exp() * text_features @ image_features.T
|
||||||
|
|
||||||
|
return image_score[0].cpu().numpy().item()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
||||||
|
"""Score the images based on the prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
||||||
|
prompt (str): The prompt text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[float]: List of reward scores for the images.
|
||||||
|
"""
|
||||||
|
if isinstance(images, (str, Image.Image)):
|
||||||
|
# Single image
|
||||||
|
if isinstance(images, str):
|
||||||
|
image = self.image_processor(Image.open(images), return_tensors="pt")["pixel_values"].to(self.device)
|
||||||
|
else:
|
||||||
|
image = self.image_processor(images, return_tensors="pt")["pixel_values"].to(self.device)
|
||||||
|
return [self._calculate_score(image, prompt)]
|
||||||
|
elif isinstance(images, list):
|
||||||
|
# Multiple images
|
||||||
|
scores = []
|
||||||
|
for one_images in images:
|
||||||
|
if isinstance(one_images, str):
|
||||||
|
image = self.image_processor(Image.open(one_images), return_tensors="pt")["pixel_values"].to(self.device)
|
||||||
|
elif isinstance(one_images, Image.Image):
|
||||||
|
image = self.image_processor(one_images, return_tensors="pt")["pixel_values"].to(self.device)
|
||||||
|
else:
|
||||||
|
raise TypeError("The type of parameter images is illegal.")
|
||||||
|
scores.append(self._calculate_score(image, prompt))
|
||||||
|
return scores
|
||||||
|
else:
|
||||||
|
raise TypeError("The type of parameter images is illegal.")
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
from .coca_model import CoCa
|
||||||
|
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
||||||
|
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
|
||||||
|
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
|
||||||
|
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
|
||||||
|
from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
|
||||||
|
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
|
||||||
|
from .openai import load_openai_model, list_openai_models
|
||||||
|
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
|
||||||
|
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
|
||||||
|
from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
|
||||||
|
from .tokenizer import SimpleTokenizer
|
||||||
|
from .transform import image_transform, AugmentationCfg
|
||||||
|
from .utils import freeze_batch_norm_2d
|
||||||
458
diffsynth/extensions/ImageQualityMetric/open_clip/coca_model.py
Normal file
458
diffsynth/extensions/ImageQualityMetric/open_clip/coca_model.py
Normal file
@@ -0,0 +1,458 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
import numpy as np
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from .transformer import (
|
||||||
|
LayerNormFp32,
|
||||||
|
LayerNorm,
|
||||||
|
QuickGELU,
|
||||||
|
MultimodalTransformer,
|
||||||
|
)
|
||||||
|
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
|
||||||
|
|
||||||
|
try:
|
||||||
|
from transformers import (
|
||||||
|
BeamSearchScorer,
|
||||||
|
LogitsProcessorList,
|
||||||
|
TopPLogitsWarper,
|
||||||
|
TopKLogitsWarper,
|
||||||
|
RepetitionPenaltyLogitsProcessor,
|
||||||
|
MinLengthLogitsProcessor,
|
||||||
|
MaxLengthCriteria,
|
||||||
|
StoppingCriteriaList
|
||||||
|
)
|
||||||
|
|
||||||
|
GENERATION_TYPES = {
|
||||||
|
"top_k": TopKLogitsWarper,
|
||||||
|
"top_p": TopPLogitsWarper,
|
||||||
|
"beam_search": "beam_search"
|
||||||
|
}
|
||||||
|
_has_transformers = True
|
||||||
|
except ImportError as e:
|
||||||
|
GENERATION_TYPES = {
|
||||||
|
"top_k": None,
|
||||||
|
"top_p": None,
|
||||||
|
"beam_search": "beam_search"
|
||||||
|
}
|
||||||
|
_has_transformers = False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MultimodalCfg(CLIPTextCfg):
|
||||||
|
mlp_ratio: int = 4
|
||||||
|
dim_head: int = 64
|
||||||
|
heads: int = 8
|
||||||
|
n_queries: int = 256
|
||||||
|
attn_pooler_heads: int = 8
|
||||||
|
|
||||||
|
|
||||||
|
def _build_text_decoder_tower(
|
||||||
|
embed_dim,
|
||||||
|
multimodal_cfg,
|
||||||
|
quick_gelu: bool = False,
|
||||||
|
cast_dtype: Optional[torch.dtype] = None,
|
||||||
|
):
|
||||||
|
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
|
||||||
|
act_layer = QuickGELU if quick_gelu else nn.GELU
|
||||||
|
norm_layer = (
|
||||||
|
LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder = MultimodalTransformer(
|
||||||
|
context_length=multimodal_cfg.context_length,
|
||||||
|
width=multimodal_cfg.width,
|
||||||
|
heads=multimodal_cfg.heads,
|
||||||
|
layers=multimodal_cfg.layers,
|
||||||
|
ls_init_value=multimodal_cfg.ls_init_value,
|
||||||
|
output_dim=embed_dim,
|
||||||
|
act_layer=act_layer,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
)
|
||||||
|
|
||||||
|
return decoder
|
||||||
|
|
||||||
|
|
||||||
|
class CoCa(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embed_dim,
|
||||||
|
multimodal_cfg: MultimodalCfg,
|
||||||
|
text_cfg: CLIPTextCfg,
|
||||||
|
vision_cfg: CLIPVisionCfg,
|
||||||
|
quick_gelu: bool = False,
|
||||||
|
cast_dtype: Optional[torch.dtype] = None,
|
||||||
|
pad_id: int = 0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
|
||||||
|
text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
|
||||||
|
vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg
|
||||||
|
|
||||||
|
self.text = _build_text_tower(
|
||||||
|
embed_dim=embed_dim,
|
||||||
|
text_cfg=text_cfg,
|
||||||
|
quick_gelu=quick_gelu,
|
||||||
|
cast_dtype=cast_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
vocab_size = (
|
||||||
|
text_cfg.vocab_size # for hf models
|
||||||
|
if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
|
||||||
|
else text_cfg.vocab_size
|
||||||
|
)
|
||||||
|
|
||||||
|
self.visual = _build_vision_tower(
|
||||||
|
embed_dim=embed_dim,
|
||||||
|
vision_cfg=vision_cfg,
|
||||||
|
quick_gelu=quick_gelu,
|
||||||
|
cast_dtype=cast_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.text_decoder = _build_text_decoder_tower(
|
||||||
|
vocab_size,
|
||||||
|
multimodal_cfg=multimodal_cfg,
|
||||||
|
quick_gelu=quick_gelu,
|
||||||
|
cast_dtype=cast_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
||||||
|
self.pad_id = pad_id
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def set_grad_checkpointing(self, enable=True):
|
||||||
|
self.visual.set_grad_checkpointing(enable)
|
||||||
|
self.text.set_grad_checkpointing(enable)
|
||||||
|
self.text_decoder.set_grad_checkpointing(enable)
|
||||||
|
|
||||||
|
def _encode_image(self, images, normalize=True):
|
||||||
|
image_latent, tokens_embs = self.visual(images)
|
||||||
|
image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
|
||||||
|
return image_latent, tokens_embs
|
||||||
|
|
||||||
|
def _encode_text(self, text, normalize=True, embed_cls=True):
|
||||||
|
text = text[:, :-1] if embed_cls else text # make space for CLS token
|
||||||
|
text_latent, token_emb = self.text(text)
|
||||||
|
text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
|
||||||
|
return text_latent, token_emb
|
||||||
|
|
||||||
|
def encode_image(self, images, normalize=True):
|
||||||
|
image_latent, _ = self._encode_image(images, normalize=normalize)
|
||||||
|
return image_latent
|
||||||
|
|
||||||
|
def encode_text(self, text, normalize=True, embed_cls=True):
|
||||||
|
text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
|
||||||
|
return text_latent
|
||||||
|
|
||||||
|
def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):
|
||||||
|
text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
|
||||||
|
if image_latent is None or image_embs is None:
|
||||||
|
image_latent, image_embs = self._encode_image(image)
|
||||||
|
|
||||||
|
# TODO: add assertion to avoid bugs?
|
||||||
|
labels = text[:, -token_embs.shape[1]:]
|
||||||
|
|
||||||
|
logits = self.text_decoder(image_embs, token_embs)
|
||||||
|
return {
|
||||||
|
"image_features": image_latent,
|
||||||
|
"text_features": text_latent,
|
||||||
|
"logits": logits,
|
||||||
|
"labels": labels,
|
||||||
|
"logit_scale": self.logit_scale.exp()
|
||||||
|
}
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
image,
|
||||||
|
text=None,
|
||||||
|
seq_len=30,
|
||||||
|
max_seq_len=77,
|
||||||
|
temperature=1.,
|
||||||
|
generation_type="beam_search",
|
||||||
|
top_p=0.1, # keep tokens in the 1 - top_p quantile
|
||||||
|
top_k=1, # keeps the top_k most probable tokens
|
||||||
|
pad_token_id=None,
|
||||||
|
eos_token_id=None,
|
||||||
|
sot_token_id=None,
|
||||||
|
num_beams=6,
|
||||||
|
num_beam_groups=3,
|
||||||
|
min_seq_len=5,
|
||||||
|
stopping_criteria=None,
|
||||||
|
repetition_penalty=1.0,
|
||||||
|
fixed_output_length=False # if True output.shape == (batch_size, seq_len)
|
||||||
|
):
|
||||||
|
# taking many ideas and components from HuggingFace GenerationMixin
|
||||||
|
# https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
|
||||||
|
assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
|
||||||
|
assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
sot_token_id = 49406 if sot_token_id is None else sot_token_id
|
||||||
|
eos_token_id = 49407 if eos_token_id is None else eos_token_id
|
||||||
|
pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
|
||||||
|
logit_processor = LogitsProcessorList(
|
||||||
|
[
|
||||||
|
MinLengthLogitsProcessor(min_seq_len, eos_token_id),
|
||||||
|
RepetitionPenaltyLogitsProcessor(repetition_penalty),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
if stopping_criteria is None:
|
||||||
|
stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
|
||||||
|
|
||||||
|
stopping_criteria = StoppingCriteriaList(
|
||||||
|
stopping_criteria
|
||||||
|
)
|
||||||
|
|
||||||
|
device = image.device
|
||||||
|
|
||||||
|
if generation_type == "beam_search":
|
||||||
|
output = self._generate_beamsearch(
|
||||||
|
image_inputs = image,
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
sot_token_id=sot_token_id,
|
||||||
|
num_beams=num_beams,
|
||||||
|
num_beam_groups=num_beam_groups,
|
||||||
|
min_seq_len=min_seq_len,
|
||||||
|
stopping_criteria=stopping_criteria,
|
||||||
|
logit_processor=logit_processor,
|
||||||
|
)
|
||||||
|
if fixed_output_length and output.shape[1] < seq_len:
|
||||||
|
return torch.cat(
|
||||||
|
(output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id),
|
||||||
|
dim=1
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
elif generation_type == "top_p":
|
||||||
|
logit_warper = GENERATION_TYPES[generation_type](top_p)
|
||||||
|
elif generation_type == "top_k":
|
||||||
|
logit_warper = GENERATION_TYPES[generation_type](top_k)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"generation_type has to be one of "
|
||||||
|
f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
|
||||||
|
)
|
||||||
|
|
||||||
|
image_latent, image_embs = self._encode_image(image)
|
||||||
|
|
||||||
|
if text is None:
|
||||||
|
text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id
|
||||||
|
|
||||||
|
was_training = self.training
|
||||||
|
num_dims = len(text.shape)
|
||||||
|
|
||||||
|
if num_dims == 1:
|
||||||
|
text = text[None, :]
|
||||||
|
|
||||||
|
cur_len = text.shape[1]
|
||||||
|
self.eval()
|
||||||
|
out = text
|
||||||
|
|
||||||
|
while True:
|
||||||
|
x = out[:, -max_seq_len:]
|
||||||
|
cur_len = x.shape[1]
|
||||||
|
logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1]
|
||||||
|
mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
|
||||||
|
sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id
|
||||||
|
|
||||||
|
if mask.all():
|
||||||
|
if not fixed_output_length:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
logits = logits[~mask, :]
|
||||||
|
filtered_logits = logit_processor(x[~mask, :], logits)
|
||||||
|
filtered_logits = logit_warper(x[~mask, :], filtered_logits)
|
||||||
|
probs = F.softmax(filtered_logits / temperature, dim=-1)
|
||||||
|
|
||||||
|
if (cur_len + 1 == seq_len):
|
||||||
|
sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
|
||||||
|
else:
|
||||||
|
sample[~mask, :] = torch.multinomial(probs, 1)
|
||||||
|
|
||||||
|
out = torch.cat((out, sample), dim=-1)
|
||||||
|
|
||||||
|
cur_len += 1
|
||||||
|
|
||||||
|
if stopping_criteria(out, None):
|
||||||
|
break
|
||||||
|
|
||||||
|
if num_dims == 1:
|
||||||
|
out = out.squeeze(0)
|
||||||
|
|
||||||
|
self.train(was_training)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def _generate_beamsearch(
|
||||||
|
self,
|
||||||
|
image_inputs,
|
||||||
|
pad_token_id=None,
|
||||||
|
eos_token_id=None,
|
||||||
|
sot_token_id=None,
|
||||||
|
num_beams=6,
|
||||||
|
num_beam_groups=3,
|
||||||
|
min_seq_len=5,
|
||||||
|
stopping_criteria=None,
|
||||||
|
logit_processor=None,
|
||||||
|
logit_warper=None,
|
||||||
|
):
|
||||||
|
device = image_inputs.device
|
||||||
|
batch_size = image_inputs.shape[0]
|
||||||
|
image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
|
||||||
|
image_latent, image_embs = self._encode_image(image_inputs)
|
||||||
|
|
||||||
|
input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
|
||||||
|
input_ids = input_ids * sot_token_id
|
||||||
|
beam_scorer = BeamSearchScorer(
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_beams=num_beams,
|
||||||
|
device=device,
|
||||||
|
num_beam_groups=num_beam_groups,
|
||||||
|
)
|
||||||
|
# instantiate logits processors
|
||||||
|
logits_processor = (
|
||||||
|
LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
|
||||||
|
if logit_processor is None
|
||||||
|
else logit_processor
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_size = len(beam_scorer._beam_hyps)
|
||||||
|
num_beams = beam_scorer.num_beams
|
||||||
|
num_beam_groups = beam_scorer.num_beam_groups
|
||||||
|
num_sub_beams = num_beams // num_beam_groups
|
||||||
|
batch_beam_size, cur_len = input_ids.shape
|
||||||
|
beam_indices = None
|
||||||
|
|
||||||
|
if num_beams * batch_size != batch_beam_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
||||||
|
)
|
||||||
|
|
||||||
|
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
|
||||||
|
# initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in
|
||||||
|
# the same group don't produce same tokens everytime.
|
||||||
|
beam_scores[:, ::num_sub_beams] = 0
|
||||||
|
beam_scores = beam_scores.view((batch_size * num_beams,))
|
||||||
|
|
||||||
|
while True:
|
||||||
|
|
||||||
|
# predicted tokens in cur_len step
|
||||||
|
current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
|
||||||
|
|
||||||
|
# indices which will form the beams in the next time step
|
||||||
|
reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
# do one decoder step on all beams of all sentences in batch
|
||||||
|
model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
|
||||||
|
outputs = self(
|
||||||
|
model_inputs['images'],
|
||||||
|
model_inputs['text'],
|
||||||
|
embed_cls=False,
|
||||||
|
image_latent=image_latent,
|
||||||
|
image_embs=image_embs
|
||||||
|
)
|
||||||
|
|
||||||
|
for beam_group_idx in range(num_beam_groups):
|
||||||
|
group_start_idx = beam_group_idx * num_sub_beams
|
||||||
|
group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
|
||||||
|
group_size = group_end_idx - group_start_idx
|
||||||
|
|
||||||
|
# indices of beams of current group among all sentences in batch
|
||||||
|
batch_group_indices = []
|
||||||
|
|
||||||
|
for batch_idx in range(batch_size):
|
||||||
|
batch_group_indices.extend(
|
||||||
|
[batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
|
||||||
|
)
|
||||||
|
group_input_ids = input_ids[batch_group_indices]
|
||||||
|
|
||||||
|
# select outputs of beams of current group only
|
||||||
|
next_token_logits = outputs['logits'][batch_group_indices, -1, :]
|
||||||
|
vocab_size = next_token_logits.shape[-1]
|
||||||
|
|
||||||
|
next_token_scores_processed = logits_processor(
|
||||||
|
group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
|
||||||
|
)
|
||||||
|
next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
|
||||||
|
next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
|
||||||
|
|
||||||
|
# reshape for beam search
|
||||||
|
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
|
||||||
|
|
||||||
|
next_token_scores, next_tokens = torch.topk(
|
||||||
|
next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
|
||||||
|
)
|
||||||
|
|
||||||
|
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
|
||||||
|
next_tokens = next_tokens % vocab_size
|
||||||
|
|
||||||
|
# stateless
|
||||||
|
process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
|
||||||
|
beam_outputs = beam_scorer.process(
|
||||||
|
group_input_ids,
|
||||||
|
next_token_scores,
|
||||||
|
next_tokens,
|
||||||
|
next_indices,
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
beam_indices=process_beam_indices,
|
||||||
|
)
|
||||||
|
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
|
||||||
|
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
||||||
|
beam_idx = beam_outputs["next_beam_indices"]
|
||||||
|
|
||||||
|
input_ids[batch_group_indices] = group_input_ids[beam_idx]
|
||||||
|
group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
|
||||||
|
current_tokens[batch_group_indices] = group_input_ids[:, -1]
|
||||||
|
|
||||||
|
# (beam_idx // group_size) -> batch_idx
|
||||||
|
# (beam_idx % group_size) -> offset of idx inside the group
|
||||||
|
reordering_indices[batch_group_indices] = (
|
||||||
|
num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
|
||||||
|
|
||||||
|
# increase cur_len
|
||||||
|
cur_len = cur_len + 1
|
||||||
|
if beam_scorer.is_done or stopping_criteria(input_ids, None):
|
||||||
|
break
|
||||||
|
|
||||||
|
final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
|
||||||
|
sequence_outputs = beam_scorer.finalize(
|
||||||
|
input_ids,
|
||||||
|
beam_scores,
|
||||||
|
next_tokens,
|
||||||
|
next_indices,
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
max_length=stopping_criteria.max_length,
|
||||||
|
beam_indices=final_beam_indices,
|
||||||
|
)
|
||||||
|
return sequence_outputs['sequences']
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
|
||||||
|
if past:
|
||||||
|
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||||
|
|
||||||
|
attention_mask = kwargs.get("attention_mask", None)
|
||||||
|
position_ids = kwargs.get("position_ids", None)
|
||||||
|
|
||||||
|
if attention_mask is not None and position_ids is None:
|
||||||
|
# create position_ids on the fly for batch generation
|
||||||
|
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||||
|
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||||
|
else:
|
||||||
|
position_ids = None
|
||||||
|
return {
|
||||||
|
"text": input_ids,
|
||||||
|
"images": image_inputs,
|
||||||
|
"past_key_values": past,
|
||||||
|
"position_ids": position_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
}
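# Illustrative sketch (not part of the diff): how prepare_inputs_for_generation above might be
# exercised on its own. The dummy tensors below are assumptions for demonstration only.
import torch
dummy_text = torch.ones((2, 5), dtype=torch.long)     # (batch, seq_len) token ids, values assumed
dummy_images = torch.randn(2, 3, 224, 224)            # (batch, C, H, W) image batch, shape assumed
inputs = prepare_inputs_for_generation(dummy_text, dummy_images)
assert inputs["text"].shape == (2, 5) and inputs["past_key_values"] is None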
diffsynth/extensions/ImageQualityMetric/open_clip/constants.py (new file, 2 lines)
@@ -0,0 +1,2 @@
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
||||||
|
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
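# Illustrative sketch (not part of the diff): these are the standard OpenAI CLIP normalization
# statistics; they are typically applied to image tensors before the image tower. The torchvision
# usage below is an assumption for demonstration, not code from this commit.
from torchvision import transforms
openai_clip_normalize = transforms.Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD)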
diffsynth/extensions/ImageQualityMetric/open_clip/factory.py (new file, 433 lines)
@@ -0,0 +1,433 @@
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import re
|
||||||
|
from copy import deepcopy
|
||||||
|
from pathlib import Path
|
||||||
|
# from turtle import forward
|
||||||
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
||||||
|
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
|
||||||
|
resize_pos_embed, get_cast_dtype
|
||||||
|
from .coca_model import CoCa
|
||||||
|
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
|
||||||
|
from .openai import load_openai_model
|
||||||
|
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf
|
||||||
|
from .transform import image_transform, AugmentationCfg
|
||||||
|
from .tokenizer import HFTokenizer, SimpleTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
HF_HUB_PREFIX = 'hf-hub:'
|
||||||
|
_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
|
||||||
|
_MODEL_CONFIGS = {}  # dictionary (model_name: config) of model architecture configs
|
||||||
|
|
||||||
|
|
||||||
|
def _natural_key(string_):
|
||||||
|
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
|
||||||
|
|
||||||
|
|
||||||
|
def _rescan_model_configs():
|
||||||
|
global _MODEL_CONFIGS
|
||||||
|
|
||||||
|
config_ext = ('.json',)
|
||||||
|
config_files = []
|
||||||
|
for config_path in _MODEL_CONFIG_PATHS:
|
||||||
|
if config_path.is_file() and config_path.suffix in config_ext:
|
||||||
|
config_files.append(config_path)
|
||||||
|
elif config_path.is_dir():
|
||||||
|
for ext in config_ext:
|
||||||
|
config_files.extend(config_path.glob(f'*{ext}'))
|
||||||
|
|
||||||
|
for cf in config_files:
|
||||||
|
with open(cf, 'r') as f:
|
||||||
|
model_cfg = json.load(f)
|
||||||
|
if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
|
||||||
|
_MODEL_CONFIGS[cf.stem] = model_cfg
|
||||||
|
|
||||||
|
_MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
|
||||||
|
|
||||||
|
|
||||||
|
_rescan_model_configs() # initial populate of model config registry
|
||||||
|
|
||||||
|
|
||||||
|
def list_models():
|
||||||
|
""" enumerate available model architectures based on config files """
|
||||||
|
return list(_MODEL_CONFIGS.keys())
|
||||||
|
|
||||||
|
|
||||||
|
def add_model_config(path):
|
||||||
|
""" add model config path or file and update registry """
|
||||||
|
if not isinstance(path, Path):
|
||||||
|
path = Path(path)
|
||||||
|
_MODEL_CONFIG_PATHS.append(path)
|
||||||
|
_rescan_model_configs()
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_config(model_name):
|
||||||
|
if model_name in _MODEL_CONFIGS:
|
||||||
|
return deepcopy(_MODEL_CONFIGS[model_name])
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_tokenizer(model_name, open_clip_bpe_path=None):
|
||||||
|
if model_name.startswith(HF_HUB_PREFIX):
|
||||||
|
tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):])
|
||||||
|
else:
|
||||||
|
config = get_model_config(model_name)
|
||||||
|
tokenizer = HFTokenizer(
|
||||||
|
config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else SimpleTokenizer(open_clip_bpe_path)
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def load_state_dict(checkpoint_path: str, map_location='cpu'):
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||||
|
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
||||||
|
state_dict = checkpoint['state_dict']
|
||||||
|
else:
|
||||||
|
state_dict = checkpoint
|
||||||
|
if next(iter(state_dict.items()))[0].startswith('module'):
|
||||||
|
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def load_checkpoint(model, checkpoint_path, strict=True):
|
||||||
|
state_dict = load_state_dict(checkpoint_path)
|
||||||
|
# detect old format and make compatible with new format
|
||||||
|
if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
|
||||||
|
state_dict = convert_to_custom_text_state_dict(state_dict)
|
||||||
|
resize_pos_embed(state_dict, model)
|
||||||
|
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
|
||||||
|
return incompatible_keys
|
||||||
|
|
||||||
|
|
||||||
|
def create_model(
|
||||||
|
model_name: str,
|
||||||
|
pretrained: Optional[str] = None,
|
||||||
|
precision: str = 'fp32',
|
||||||
|
device: Union[str, torch.device] = 'cpu',
|
||||||
|
jit: bool = False,
|
||||||
|
force_quick_gelu: bool = False,
|
||||||
|
force_custom_text: bool = False,
|
||||||
|
force_patch_dropout: Optional[float] = None,
|
||||||
|
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
|
||||||
|
pretrained_image: bool = False,
|
||||||
|
pretrained_hf: bool = True,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
output_dict: Optional[bool] = None,
|
||||||
|
require_pretrained: bool = False,
|
||||||
|
):
|
||||||
|
has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
|
||||||
|
if has_hf_hub_prefix:
|
||||||
|
model_id = model_name[len(HF_HUB_PREFIX):]
|
||||||
|
checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
|
||||||
|
config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)
|
||||||
|
|
||||||
|
with open(config_path, 'r', encoding='utf-8') as f:
|
||||||
|
config = json.load(f)
|
||||||
|
pretrained_cfg = config['preprocess_cfg']
|
||||||
|
model_cfg = config['model_cfg']
|
||||||
|
else:
|
||||||
|
model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
|
||||||
|
checkpoint_path = None
|
||||||
|
pretrained_cfg = {}
|
||||||
|
model_cfg = None
|
||||||
|
|
||||||
|
if isinstance(device, str):
|
||||||
|
device = torch.device(device)
|
||||||
|
|
||||||
|
if pretrained and pretrained.lower() == 'openai':
|
||||||
|
logging.info(f'Loading pretrained {model_name} from OpenAI.')
|
||||||
|
model = load_openai_model(
|
||||||
|
model_name,
|
||||||
|
precision=precision,
|
||||||
|
device=device,
|
||||||
|
jit=jit,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
# to always output dict even if it is clip
|
||||||
|
if output_dict and hasattr(model, "output_dict"):
|
||||||
|
model.output_dict = True
|
||||||
|
else:
|
||||||
|
model_cfg = model_cfg or get_model_config(model_name)
|
||||||
|
if model_cfg is not None:
|
||||||
|
logging.info(f'Loaded {model_name} model config.')
|
||||||
|
else:
|
||||||
|
logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
|
||||||
|
raise RuntimeError(f'Model config for {model_name} not found.')
|
||||||
|
|
||||||
|
if force_quick_gelu:
|
||||||
|
# override for use of QuickGELU on non-OpenAI transformer models
|
||||||
|
model_cfg["quick_gelu"] = True
|
||||||
|
|
||||||
|
if force_patch_dropout is not None:
|
||||||
|
# override the default patch dropout value
|
||||||
|
model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
|
||||||
|
|
||||||
|
if force_image_size is not None:
|
||||||
|
# override model config's image size
|
||||||
|
model_cfg["vision_cfg"]["image_size"] = force_image_size
|
||||||
|
|
||||||
|
if pretrained_image:
|
||||||
|
if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
|
||||||
|
# pretrained weight loading for timm models set via vision_cfg
|
||||||
|
model_cfg['vision_cfg']['timm_model_pretrained'] = True
|
||||||
|
else:
|
||||||
|
assert False, 'pretrained image towers currently only supported for timm models'
|
||||||
|
|
||||||
|
cast_dtype = get_cast_dtype(precision)
|
||||||
|
is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
|
||||||
|
custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model
|
||||||
|
|
||||||
|
if custom_text:
|
||||||
|
if is_hf_model:
|
||||||
|
model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
|
||||||
|
if "coca" in model_name:
|
||||||
|
model = CoCa(**model_cfg, cast_dtype=cast_dtype)
|
||||||
|
else:
|
||||||
|
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
|
||||||
|
else:
|
||||||
|
model = CLIP(**model_cfg, cast_dtype=cast_dtype)
|
||||||
|
|
||||||
|
pretrained_loaded = False
|
||||||
|
if pretrained:
|
||||||
|
checkpoint_path = ''
|
||||||
|
pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
|
||||||
|
if pretrained_cfg:
|
||||||
|
checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
|
||||||
|
elif os.path.exists(pretrained):
|
||||||
|
checkpoint_path = pretrained
|
||||||
|
|
||||||
|
if checkpoint_path:
|
||||||
|
logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
|
||||||
|
load_checkpoint(model, checkpoint_path)
|
||||||
|
else:
|
||||||
|
error_str = (
|
||||||
|
f'Pretrained weights ({pretrained}) not found for model {model_name}. '
f'Available pretrained tags: ({list_pretrained_tags_by_model(model_name)}).')
|
||||||
|
logging.warning(error_str)
|
||||||
|
raise RuntimeError(error_str)
|
||||||
|
pretrained_loaded = True
|
||||||
|
elif has_hf_hub_prefix:
|
||||||
|
logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
|
||||||
|
load_checkpoint(model, checkpoint_path)
|
||||||
|
pretrained_loaded = True
|
||||||
|
|
||||||
|
if require_pretrained and not pretrained_loaded:
|
||||||
|
# callers of create_model_from_pretrained always expect pretrained weights
|
||||||
|
raise RuntimeError(
|
||||||
|
f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')
|
||||||
|
|
||||||
|
model.to(device=device)
|
||||||
|
if precision in ("fp16", "bf16"):
|
||||||
|
convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)
|
||||||
|
|
||||||
|
# set image mean / std metadata from pretrained_cfg if available, or use default
|
||||||
|
model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
|
||||||
|
model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
|
||||||
|
|
||||||
|
# to always output dict even if it is clip
|
||||||
|
if output_dict and hasattr(model, "output_dict"):
|
||||||
|
model.output_dict = True
|
||||||
|
|
||||||
|
if jit:
|
||||||
|
model = torch.jit.script(model)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def create_loss(args):
|
||||||
|
if args.distill:
|
||||||
|
return DistillClipLoss(
|
||||||
|
local_loss=args.local_loss,
|
||||||
|
gather_with_grad=args.gather_with_grad,
|
||||||
|
cache_labels=True,
|
||||||
|
rank=args.rank,
|
||||||
|
world_size=args.world_size,
|
||||||
|
use_horovod=args.horovod,
|
||||||
|
)
|
||||||
|
elif "coca" in args.model.lower():
|
||||||
|
return CoCaLoss(
|
||||||
|
caption_loss_weight=args.coca_caption_loss_weight,
|
||||||
|
clip_loss_weight=args.coca_contrastive_loss_weight,
|
||||||
|
local_loss=args.local_loss,
|
||||||
|
gather_with_grad=args.gather_with_grad,
|
||||||
|
cache_labels=True,
|
||||||
|
rank=args.rank,
|
||||||
|
world_size=args.world_size,
|
||||||
|
use_horovod=args.horovod,
|
||||||
|
)
|
||||||
|
return ClipLoss(
|
||||||
|
local_loss=args.local_loss,
|
||||||
|
gather_with_grad=args.gather_with_grad,
|
||||||
|
cache_labels=True,
|
||||||
|
rank=args.rank,
|
||||||
|
world_size=args.world_size,
|
||||||
|
use_horovod=args.horovod,
|
||||||
|
)
|
||||||
|
|
||||||
|
class MLP(torch.nn.Module):
|
||||||
|
def __init__(self, input_size):
|
||||||
|
super().__init__()
|
||||||
|
self.input_size = input_size
|
||||||
|
self.layers = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(self.input_size, 1024),
|
||||||
|
torch.nn.Dropout(0.2),
|
||||||
|
torch.nn.Linear(1024, 128),
|
||||||
|
torch.nn.Dropout(0.2),
|
||||||
|
torch.nn.Linear(128, 64),
|
||||||
|
torch.nn.Dropout(0.1),
|
||||||
|
torch.nn.Linear(64, 16),
|
||||||
|
torch.nn.Linear(16, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.layers(x)
|
||||||
|
|
||||||
|
# class semantic_head(torch.nn.Module):
|
||||||
|
# def __init__(self, input_size):
|
||||||
|
# super().__init__()
|
||||||
|
# self.input_size = input_size # for ViT-L-14 is 1024
|
||||||
|
# self.seg_head = torch.nn.Sequential(
|
||||||
|
# torch.nn.Linear(input_size, 128),
|
||||||
|
# torch.nn.Dropout(0.2),
|
||||||
|
# torch.nn.Linear(128, 64),
|
||||||
|
# torch.nn.Dropout(0.1),
|
||||||
|
# torch.nn.Linear(64, 16),
|
||||||
|
# torch.nn.Linear(16, 1),
|
||||||
|
# )
|
||||||
|
# self.sigmoid = torch.nn.Sigmoid()
|
||||||
|
|
||||||
|
# def forward(self, x):
|
||||||
|
# return self.sigmoid(self.seg_head(x))
|
||||||
|
|
||||||
|
def create_model_and_transforms(
|
||||||
|
model_name: str,
|
||||||
|
pretrained: Optional[str] = None,
|
||||||
|
precision: str = 'fp32',
|
||||||
|
device: Union[str, torch.device] = 'cpu',
|
||||||
|
jit: bool = False,
|
||||||
|
force_quick_gelu: bool = False,
|
||||||
|
force_custom_text: bool = False,
|
||||||
|
force_patch_dropout: Optional[float] = None,
|
||||||
|
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
|
||||||
|
pretrained_image: bool = False,
|
||||||
|
pretrained_hf: bool = True,
|
||||||
|
image_mean: Optional[Tuple[float, ...]] = None,
|
||||||
|
image_std: Optional[Tuple[float, ...]] = None,
|
||||||
|
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
light_augmentation = False,
|
||||||
|
output_dict: Optional[bool] = None,
|
||||||
|
with_score_predictor: bool = False,
|
||||||
|
with_region_predictor: bool = False
|
||||||
|
):
|
||||||
|
model = create_model(
|
||||||
|
model_name,
|
||||||
|
pretrained,
|
||||||
|
precision=precision,
|
||||||
|
device=device,
|
||||||
|
jit=jit,
|
||||||
|
force_quick_gelu=force_quick_gelu,
|
||||||
|
force_custom_text=force_custom_text,
|
||||||
|
force_patch_dropout=force_patch_dropout,
|
||||||
|
force_image_size=force_image_size,
|
||||||
|
pretrained_image=pretrained_image,
|
||||||
|
pretrained_hf=pretrained_hf,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
output_dict=output_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
|
||||||
|
image_std = image_std or getattr(model.visual, 'image_std', None)
|
||||||
|
|
||||||
|
if with_score_predictor:
|
||||||
|
model.score_predictor = MLP(model.visual.proj.size(1)).to(device=device, dtype=model.visual.proj.dtype)
|
||||||
|
|
||||||
|
if with_region_predictor:
|
||||||
|
# model.region_predictor = semantic_head(model.visual.proj.size(1)).to(device=device, dtype=model.visual.proj.dtype)
|
||||||
|
model.region_predictor = torch.nn.Linear(model.visual.proj.size(0), 1).to(device=device, dtype=model.visual.proj.dtype)
|
||||||
|
# preprocess_train = image_transform_region(
|
||||||
|
# model.visual.image_size,
|
||||||
|
# is_train=True,
|
||||||
|
# mean=image_mean,
|
||||||
|
# std=image_std
|
||||||
|
# )
|
||||||
|
# preprocess_val = image_transform_region(
|
||||||
|
# model.visual.image_size,
|
||||||
|
# is_train=False,
|
||||||
|
# mean=image_mean,
|
||||||
|
# std=image_std
|
||||||
|
# )
|
||||||
|
|
||||||
|
if light_augmentation:
|
||||||
|
preprocess_val = image_transform(
|
||||||
|
model.visual.image_size,
|
||||||
|
is_train=False,
|
||||||
|
mean=image_mean,
|
||||||
|
std=image_std,
|
||||||
|
resize_longest_max=True,
|
||||||
|
)
|
||||||
|
preprocess_train = preprocess_val
|
||||||
|
else:
|
||||||
|
preprocess_train = image_transform(
|
||||||
|
model.visual.image_size,
|
||||||
|
is_train=True,
|
||||||
|
mean=image_mean,
|
||||||
|
std=image_std
|
||||||
|
)
|
||||||
|
preprocess_val = image_transform(
|
||||||
|
model.visual.image_size,
|
||||||
|
is_train=False,
|
||||||
|
mean=image_mean,
|
||||||
|
std=image_std
|
||||||
|
)
|
||||||
|
|
||||||
|
return model, preprocess_train, preprocess_val
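# Illustrative sketch (not part of the diff): a typical call to the factory above. The model name
# and the availability of a matching config under model_configs/ are assumptions for demonstration.
model, preprocess_train, preprocess_val = create_model_and_transforms(
    'ViT-B-32',        # assumed to exist in _MODEL_CONFIGS
    pretrained=None,   # or a pretrained tag / local checkpoint path
    precision='fp32',
    device='cpu',
)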
|
||||||
|
|
||||||
|
|
||||||
|
def create_model_from_pretrained(
|
||||||
|
model_name: str,
|
||||||
|
pretrained: Optional[str] = None,
|
||||||
|
precision: str = 'fp32',
|
||||||
|
device: Union[str, torch.device] = 'cpu',
|
||||||
|
jit: bool = False,
|
||||||
|
force_quick_gelu: bool = False,
|
||||||
|
force_custom_text: bool = False,
|
||||||
|
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
|
||||||
|
return_transform: bool = True,
|
||||||
|
image_mean: Optional[Tuple[float, ...]] = None,
|
||||||
|
image_std: Optional[Tuple[float, ...]] = None,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
):
|
||||||
|
model = create_model(
|
||||||
|
model_name,
|
||||||
|
pretrained,
|
||||||
|
precision=precision,
|
||||||
|
device=device,
|
||||||
|
jit=jit,
|
||||||
|
force_quick_gelu=force_quick_gelu,
|
||||||
|
force_custom_text=force_custom_text,
|
||||||
|
force_image_size=force_image_size,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
require_pretrained=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not return_transform:
|
||||||
|
return model
|
||||||
|
|
||||||
|
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
|
||||||
|
image_std = image_std or getattr(model.visual, 'image_std', None)
|
||||||
|
preprocess = image_transform(
|
||||||
|
model.visual.image_size,
|
||||||
|
is_train=False,
|
||||||
|
mean=image_mean,
|
||||||
|
std=image_std,
|
||||||
|
)
|
||||||
|
|
||||||
|
return model, preprocess
diffsynth/extensions/ImageQualityMetric/open_clip/hf_configs.py (new file, 45 lines)
@@ -0,0 +1,45 @@
# HF architecture dict:
|
||||||
|
arch_dict = {
|
||||||
|
# https://huggingface.co/docs/transformers/model_doc/roberta#roberta
|
||||||
|
"roberta": {
|
||||||
|
"config_names": {
|
||||||
|
"context_length": "max_position_embeddings",
|
||||||
|
"vocab_size": "vocab_size",
|
||||||
|
"width": "hidden_size",
|
||||||
|
"heads": "num_attention_heads",
|
||||||
|
"layers": "num_hidden_layers",
|
||||||
|
"layer_attr": "layer",
|
||||||
|
"token_embeddings_attr": "embeddings"
|
||||||
|
},
|
||||||
|
"pooler": "mean_pooler",
|
||||||
|
},
|
||||||
|
# https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
|
||||||
|
"xlm-roberta": {
|
||||||
|
"config_names": {
|
||||||
|
"context_length": "max_position_embeddings",
|
||||||
|
"vocab_size": "vocab_size",
|
||||||
|
"width": "hidden_size",
|
||||||
|
"heads": "num_attention_heads",
|
||||||
|
"layers": "num_hidden_layers",
|
||||||
|
"layer_attr": "layer",
|
||||||
|
"token_embeddings_attr": "embeddings"
|
||||||
|
},
|
||||||
|
"pooler": "mean_pooler",
|
||||||
|
},
|
||||||
|
# https://huggingface.co/docs/transformers/model_doc/mt5#mt5
|
||||||
|
"mt5": {
|
||||||
|
"config_names": {
|
||||||
|
# unlimited seqlen
|
||||||
|
# https://github.com/google-research/text-to-text-transfer-transformer/issues/273
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
|
||||||
|
"context_length": "",
|
||||||
|
"vocab_size": "vocab_size",
|
||||||
|
"width": "d_model",
|
||||||
|
"heads": "num_heads",
|
||||||
|
"layers": "num_layers",
|
||||||
|
"layer_attr": "block",
|
||||||
|
"token_embeddings_attr": "embed_tokens"
|
||||||
|
},
|
||||||
|
"pooler": "mean_pooler",
|
||||||
|
},
|
||||||
|
}
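# Illustrative sketch (not part of the diff): the mapping above translates generic config names
# into per-architecture HuggingFace attribute names (hf_config below is a hypothetical AutoConfig):
# width_attr = arch_dict["roberta"]["config_names"]["width"]   # -> "hidden_size"
# d_model = getattr(hf_config, width_attr)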
diffsynth/extensions/ImageQualityMetric/open_clip/hf_model.py (new file, 176 lines)
@@ -0,0 +1,176 @@
""" huggingface model adapter
|
||||||
|
|
||||||
|
Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch import TensorType
|
||||||
|
|
||||||
|
try:
|
||||||
|
import transformers
|
||||||
|
from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
|
||||||
|
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
|
||||||
|
BaseModelOutputWithPoolingAndCrossAttentions
|
||||||
|
except ImportError as e:
|
||||||
|
transformers = None
|
||||||
|
|
||||||
|
|
||||||
|
class BaseModelOutput:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class PretrainedConfig:
|
||||||
|
pass
|
||||||
|
|
||||||
|
from .hf_configs import arch_dict
|
||||||
|
|
||||||
|
|
||||||
|
# utils
|
||||||
|
def _camel2snake(s):
|
||||||
|
return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: ?last - for gpt-like models
|
||||||
|
_POOLERS = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register_pooler(cls):
|
||||||
|
"""Decorator registering pooler class"""
|
||||||
|
_POOLERS[_camel2snake(cls.__name__)] = cls
|
||||||
|
return cls
|
||||||
|
|
||||||
|
|
||||||
|
@register_pooler
|
||||||
|
class MeanPooler(nn.Module):
|
||||||
|
"""Mean pooling"""
|
||||||
|
|
||||||
|
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
|
||||||
|
masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
|
||||||
|
return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
|
||||||
|
|
||||||
|
|
||||||
|
@register_pooler
|
||||||
|
class MaxPooler(nn.Module):
|
||||||
|
"""Max pooling"""
|
||||||
|
|
||||||
|
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
|
||||||
|
masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1), -torch.inf)
|
||||||
|
return masked_output.max(1).values
|
||||||
|
|
||||||
|
|
||||||
|
@register_pooler
|
||||||
|
class ClsPooler(nn.Module):
|
||||||
|
"""CLS token pooling"""
|
||||||
|
|
||||||
|
def __init__(self, use_pooler_output=True):
|
||||||
|
super().__init__()
|
||||||
|
self.cls_token_position = 0
|
||||||
|
self.use_pooler_output = use_pooler_output
|
||||||
|
|
||||||
|
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
|
||||||
|
if (self.use_pooler_output and
|
||||||
|
isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
|
||||||
|
(x.pooler_output is not None)
|
||||||
|
):
|
||||||
|
return x.pooler_output
|
||||||
|
|
||||||
|
return x.last_hidden_state[:, self.cls_token_position, :]
|
||||||
|
|
||||||
|
|
||||||
|
class HFTextEncoder(nn.Module):
|
||||||
|
"""HuggingFace model adapter"""
|
||||||
|
output_tokens: torch.jit.Final[bool]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name_or_path: str,
|
||||||
|
output_dim: int,
|
||||||
|
config: PretrainedConfig = None,
|
||||||
|
pooler_type: str = None,
|
||||||
|
proj: str = None,
|
||||||
|
pretrained: bool = True,
|
||||||
|
output_tokens: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.output_tokens = output_tokens
|
||||||
|
self.output_dim = output_dim
|
||||||
|
|
||||||
|
# TODO: find better way to get this information
|
||||||
|
uses_transformer_pooler = (pooler_type == "cls_pooler")
|
||||||
|
|
||||||
|
if transformers is None:
|
||||||
|
raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
|
||||||
|
if config is None:
|
||||||
|
self.config = AutoConfig.from_pretrained(model_name_or_path)
|
||||||
|
create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
|
||||||
|
AutoModel.from_config, self.config)
|
||||||
|
# TODO: do all model configs have this attribute? PretrainedConfig does so yes??
|
||||||
|
if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
|
||||||
|
self.transformer = create_func(model_args)
|
||||||
|
self.transformer = self.transformer.encoder
|
||||||
|
else:
|
||||||
|
self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
|
||||||
|
else:
|
||||||
|
self.config = config
|
||||||
|
self.transformer = AutoModel.from_config(config)
|
||||||
|
if pooler_type is None: # get default arch pooler
|
||||||
|
pooler_type = (arch_dict[self.config.model_type]["pooler"])
|
||||||
|
|
||||||
|
self.pooler = _POOLERS[pooler_type]()
|
||||||
|
|
||||||
|
d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
|
||||||
|
if (d_model == output_dim) and (proj is None): # do we always need a proj?
|
||||||
|
self.proj = nn.Identity()
|
||||||
|
elif proj == 'linear':
|
||||||
|
self.proj = nn.Linear(d_model, output_dim, bias=False)
|
||||||
|
elif proj == 'mlp':
|
||||||
|
hidden_size = (d_model + output_dim) // 2
|
||||||
|
self.proj = nn.Sequential(
|
||||||
|
nn.Linear(d_model, hidden_size, bias=False),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Linear(hidden_size, output_dim, bias=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: TensorType):
|
||||||
|
attn_mask = (x != self.config.pad_token_id).long()
|
||||||
|
out = self.transformer(input_ids=x, attention_mask=attn_mask)
|
||||||
|
pooled_out = self.pooler(out, attn_mask)
|
||||||
|
projected = self.proj(pooled_out)
|
||||||
|
|
||||||
|
seq_len = out.last_hidden_state.shape[1]
|
||||||
|
tokens = (
|
||||||
|
out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :]
|
||||||
|
if type(self.pooler) == ClsPooler
|
||||||
|
else out.last_hidden_state
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.output_tokens:
|
||||||
|
return projected, tokens
|
||||||
|
return projected
|
||||||
|
|
||||||
|
def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
|
||||||
|
if not unlocked_layers: # full freezing
|
||||||
|
for n, p in self.transformer.named_parameters():
|
||||||
|
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
|
||||||
|
return
|
||||||
|
|
||||||
|
encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
|
||||||
|
layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
|
||||||
|
print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
|
||||||
|
embeddings = getattr(
|
||||||
|
self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
|
||||||
|
modules = [embeddings, *layer_list][:-unlocked_layers]
|
||||||
|
# freeze layers
|
||||||
|
for module in modules:
|
||||||
|
for n, p in module.named_parameters():
|
||||||
|
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def set_grad_checkpointing(self, enable=True):
|
||||||
|
self.transformer.gradient_checkpointing_enable()
|
||||||
|
|
||||||
|
def init_parameters(self):
|
||||||
|
pass
diffsynth/extensions/ImageQualityMetric/open_clip/loss.py (new file, 270 lines)
@@ -0,0 +1,270 @@
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch.distributed.nn
|
||||||
|
from torch import distributed as dist
|
||||||
|
|
||||||
|
has_distributed = True
|
||||||
|
except ImportError:
|
||||||
|
has_distributed = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import horovod.torch as hvd
|
||||||
|
except ImportError:
|
||||||
|
hvd = None
|
||||||
|
|
||||||
|
|
||||||
|
def gather_features(
|
||||||
|
image_features,
|
||||||
|
text_features,
|
||||||
|
local_loss=False,
|
||||||
|
gather_with_grad=False,
|
||||||
|
rank=0,
|
||||||
|
world_size=1,
|
||||||
|
use_horovod=False
|
||||||
|
):
|
||||||
|
assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
|
||||||
|
if use_horovod:
|
||||||
|
assert hvd is not None, 'Please install horovod'
|
||||||
|
if gather_with_grad:
|
||||||
|
all_image_features = hvd.allgather(image_features)
|
||||||
|
all_text_features = hvd.allgather(text_features)
|
||||||
|
else:
|
||||||
|
with torch.no_grad():
|
||||||
|
all_image_features = hvd.allgather(image_features)
|
||||||
|
all_text_features = hvd.allgather(text_features)
|
||||||
|
if not local_loss:
|
||||||
|
# ensure grads for local rank when all_* features don't have a gradient
|
||||||
|
gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
|
||||||
|
gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
|
||||||
|
gathered_image_features[rank] = image_features
|
||||||
|
gathered_text_features[rank] = text_features
|
||||||
|
all_image_features = torch.cat(gathered_image_features, dim=0)
|
||||||
|
all_text_features = torch.cat(gathered_text_features, dim=0)
|
||||||
|
else:
|
||||||
|
# We gather tensors from all gpus
|
||||||
|
if gather_with_grad:
|
||||||
|
all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
|
||||||
|
all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
|
||||||
|
else:
|
||||||
|
gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
|
||||||
|
gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
|
||||||
|
dist.all_gather(gathered_image_features, image_features)
|
||||||
|
dist.all_gather(gathered_text_features, text_features)
|
||||||
|
if not local_loss:
|
||||||
|
# ensure grads for local rank when all_* features don't have a gradient
|
||||||
|
gathered_image_features[rank] = image_features
|
||||||
|
gathered_text_features[rank] = text_features
|
||||||
|
all_image_features = torch.cat(gathered_image_features, dim=0)
|
||||||
|
all_text_features = torch.cat(gathered_text_features, dim=0)
|
||||||
|
|
||||||
|
return all_image_features, all_text_features
|
||||||
|
|
||||||
|
|
||||||
|
class ClipLoss(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
local_loss=False,
|
||||||
|
gather_with_grad=False,
|
||||||
|
cache_labels=False,
|
||||||
|
rank=0,
|
||||||
|
world_size=1,
|
||||||
|
use_horovod=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.local_loss = local_loss
|
||||||
|
self.gather_with_grad = gather_with_grad
|
||||||
|
self.cache_labels = cache_labels
|
||||||
|
self.rank = rank
|
||||||
|
self.world_size = world_size
|
||||||
|
self.use_horovod = use_horovod
|
||||||
|
|
||||||
|
# cache state
|
||||||
|
self.prev_num_logits = 0
|
||||||
|
self.labels = {}
|
||||||
|
|
||||||
|
def get_ground_truth(self, device, num_logits) -> torch.Tensor:
|
||||||
|
# calculate ground-truth and cache if enabled
|
||||||
|
if self.prev_num_logits != num_logits or device not in self.labels:
|
||||||
|
labels = torch.arange(num_logits, device=device, dtype=torch.long)
|
||||||
|
if self.world_size > 1 and self.local_loss:
|
||||||
|
labels = labels + num_logits * self.rank
|
||||||
|
if self.cache_labels:
|
||||||
|
self.labels[device] = labels
|
||||||
|
self.prev_num_logits = num_logits
|
||||||
|
else:
|
||||||
|
labels = self.labels[device]
|
||||||
|
return labels
|
||||||
|
|
||||||
|
def get_logits(self, image_features, text_features, logit_scale):
|
||||||
|
if self.world_size > 1:
|
||||||
|
all_image_features, all_text_features = gather_features(
|
||||||
|
image_features, text_features,
|
||||||
|
self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
|
||||||
|
|
||||||
|
if self.local_loss:
|
||||||
|
logits_per_image = logit_scale * image_features @ all_text_features.T
|
||||||
|
logits_per_text = logit_scale * text_features @ all_image_features.T
|
||||||
|
else:
|
||||||
|
logits_per_image = logit_scale * all_image_features @ all_text_features.T
|
||||||
|
logits_per_text = logits_per_image.T
|
||||||
|
else:
|
||||||
|
logits_per_image = logit_scale * image_features @ text_features.T
|
||||||
|
logits_per_text = logit_scale * text_features @ image_features.T
|
||||||
|
|
||||||
|
return logits_per_image, logits_per_text
|
||||||
|
|
||||||
|
def forward(self, image_features, text_features, logit_scale, output_dict=False):
|
||||||
|
device = image_features.device
|
||||||
|
logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
|
||||||
|
|
||||||
|
labels = self.get_ground_truth(device, logits_per_image.shape[0])
|
||||||
|
|
||||||
|
total_loss = (
|
||||||
|
F.cross_entropy(logits_per_image, labels) +
|
||||||
|
F.cross_entropy(logits_per_text, labels)
|
||||||
|
) / 2
|
||||||
|
return total_loss
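# Illustrative sketch (not part of the diff): the symmetric contrastive objective above on dummy
# single-process features (world_size=1); batch size, embedding dim and logit scale are assumptions.
loss_fn = ClipLoss()
img = F.normalize(torch.randn(4, 512), dim=-1)
txt = F.normalize(torch.randn(4, 512), dim=-1)
example_clip_loss = loss_fn(img, txt, logit_scale=torch.tensor(100.0))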
|
||||||
|
|
||||||
|
class PreferenceLoss(nn.Module):
|
||||||
|
|
||||||
|
def forward(self, logits_per_image, num_images, labels):
|
||||||
|
|
||||||
|
paired_logits_list = [logit[:,i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))]
|
||||||
|
paired_logits = pad_sequence(paired_logits_list, batch_first=True, padding_value=-999)
|
||||||
|
|
||||||
|
ce_loss = F.cross_entropy(paired_logits, labels)
|
||||||
|
return ce_loss
|
||||||
|
|
||||||
|
class HPSLoss(nn.Module):
|
||||||
|
|
||||||
|
def forward(self, text_logits, labels):
|
||||||
|
|
||||||
|
device = text_logits.device
|
||||||
|
text_0_logits, text_1_logits = text_logits.chunk(2, dim=-1)
|
||||||
|
label_0, label_1 = labels.chunk(2, dim=-1)
|
||||||
|
|
||||||
|
index = torch.arange(text_0_logits.shape[0], device=device, dtype=torch.long)
|
||||||
|
text_0_logits = text_0_logits[index, index]
|
||||||
|
text_1_logits = text_1_logits[index, index]
|
||||||
|
text_logits = torch.stack([text_0_logits, text_1_logits], dim=-1)
|
||||||
|
text_0_labels = torch.zeros(text_logits.shape[0], device=device, dtype=torch.long)
|
||||||
|
text_1_labels = text_0_labels + 1
|
||||||
|
|
||||||
|
text_0_loss = torch.nn.functional.cross_entropy(text_logits, text_0_labels, reduction="none")
|
||||||
|
text_1_loss = torch.nn.functional.cross_entropy(text_logits, text_1_labels, reduction="none")
|
||||||
|
|
||||||
|
text_loss = label_0 * text_0_loss + label_1 * text_1_loss
|
||||||
|
|
||||||
|
# absolute_example_weight = 1 / num_per_prompt
|
||||||
|
# denominator = absolute_example_weight.sum()
|
||||||
|
# weight_per_example = absolute_example_weight / denominator
|
||||||
|
# text_loss *= weight_per_example
|
||||||
|
|
||||||
|
text_loss = text_loss.sum()
|
||||||
|
return text_loss
|
||||||
|
|
||||||
|
class RankingLoss(nn.Module):
|
||||||
|
|
||||||
|
def forward(self, logits_per_image, num_images, labels, margin = 1.0):
|
||||||
|
paired_logits_list = [logit[:,i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))]
|
||||||
|
label_list = [label for label in labels.split(num_images.tolist())]
|
||||||
|
# ranked_logits = [torch.index_select(paired_logits_list[i], 0, rank) for i, rank in enumerate(label_list)]
|
||||||
|
|
||||||
|
paired_logits = pad_sequence(paired_logits_list, batch_first=True, padding_value=-1)
|
||||||
|
padded_labels = pad_sequence(label_list, batch_first=True, padding_value=10)
|
||||||
|
|
||||||
|
# regulized_logits = torch.log(torch.sigmoid(paired_logits))
|
||||||
|
|
||||||
|
diff = paired_logits.unsqueeze(1) - paired_logits.unsqueeze(2)
|
||||||
|
# diff = paired_logits.unsqueeze(1) - paired_logits.unsqueeze(2)
|
||||||
|
# diff_label = torch.clamp(padded_labels.unsqueeze(1) - padded_labels.unsqueeze(2), min=-1, max=1)
|
||||||
|
diff_label = - (padded_labels.unsqueeze(1) - padded_labels.unsqueeze(2))
|
||||||
|
mask = torch.triu(torch.ones(diff.shape[1], diff.shape[1]), diagonal=1).bool().detach()
|
||||||
|
|
||||||
|
loss = torch.clamp(margin - torch.mul(diff[:, ~mask],diff_label[:,~mask]), min=0).mean()
|
||||||
|
return loss
|
||||||
|
|
||||||
|
class CoCaLoss(ClipLoss):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
caption_loss_weight,
|
||||||
|
clip_loss_weight,
|
||||||
|
pad_id=0, # pad_token for open_clip custom tokenizer
|
||||||
|
local_loss=False,
|
||||||
|
gather_with_grad=False,
|
||||||
|
cache_labels=False,
|
||||||
|
rank=0,
|
||||||
|
world_size=1,
|
||||||
|
use_horovod=False,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
local_loss=local_loss,
|
||||||
|
gather_with_grad=gather_with_grad,
|
||||||
|
cache_labels=cache_labels,
|
||||||
|
rank=rank,
|
||||||
|
world_size=world_size,
|
||||||
|
use_horovod=use_horovod
|
||||||
|
)
|
||||||
|
|
||||||
|
self.clip_loss_weight = clip_loss_weight
|
||||||
|
self.caption_loss_weight = caption_loss_weight
|
||||||
|
self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)
|
||||||
|
|
||||||
|
def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
|
||||||
|
clip_loss = super().forward(image_features, text_features, logit_scale)
|
||||||
|
clip_loss = self.clip_loss_weight * clip_loss
|
||||||
|
|
||||||
|
caption_loss = self.caption_loss(
|
||||||
|
logits.permute(0, 2, 1),
|
||||||
|
labels,
|
||||||
|
)
|
||||||
|
caption_loss = caption_loss * self.caption_loss_weight
|
||||||
|
|
||||||
|
if output_dict:
|
||||||
|
return {"contrastive_loss": clip_loss, "caption_loss": caption_loss}
|
||||||
|
|
||||||
|
return clip_loss, caption_loss
|
||||||
|
|
||||||
|
|
||||||
|
class DistillClipLoss(ClipLoss):
|
||||||
|
|
||||||
|
def dist_loss(self, teacher_logits, student_logits):
|
||||||
|
return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
image_features,
|
||||||
|
text_features,
|
||||||
|
logit_scale,
|
||||||
|
dist_image_features,
|
||||||
|
dist_text_features,
|
||||||
|
dist_logit_scale,
|
||||||
|
output_dict=False,
|
||||||
|
):
|
||||||
|
logits_per_image, logits_per_text = \
|
||||||
|
self.get_logits(image_features, text_features, logit_scale)
|
||||||
|
|
||||||
|
dist_logits_per_image, dist_logits_per_text = \
|
||||||
|
self.get_logits(dist_image_features, dist_text_features, dist_logit_scale)
|
||||||
|
|
||||||
|
labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0])
|
||||||
|
|
||||||
|
contrastive_loss = (
|
||||||
|
F.cross_entropy(logits_per_image, labels) +
|
||||||
|
F.cross_entropy(logits_per_text, labels)
|
||||||
|
) / 2
|
||||||
|
|
||||||
|
distill_loss = (
|
||||||
|
self.dist_loss(dist_logits_per_image, logits_per_image) +
|
||||||
|
self.dist_loss(dist_logits_per_text, logits_per_text)
|
||||||
|
) / 2
|
||||||
|
|
||||||
|
if output_dict:
|
||||||
|
return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}
|
||||||
|
|
||||||
|
return contrastive_loss, distill_loss
diffsynth/extensions/ImageQualityMetric/open_clip/model.py (new file, 461 lines)
@@ -0,0 +1,461 @@
""" CLIP Model
|
||||||
|
|
||||||
|
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
||||||
|
"""
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn
|
||||||
|
from torch.utils.checkpoint import checkpoint
|
||||||
|
|
||||||
|
from .hf_model import HFTextEncoder
|
||||||
|
from .modified_resnet import ModifiedResNet
|
||||||
|
from .timm_model import TimmModel
|
||||||
|
from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
|
||||||
|
from .utils import to_2tuple
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CLIPVisionCfg:
|
||||||
|
layers: Union[Tuple[int, int, int, int], int] = 12
|
||||||
|
width: int = 768
|
||||||
|
head_width: int = 64
|
||||||
|
mlp_ratio: float = 4.0
|
||||||
|
patch_size: int = 16
|
||||||
|
image_size: Union[Tuple[int, int], int] = 224
|
||||||
|
ls_init_value: Optional[float] = None # layer scale initial value
|
||||||
|
patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
|
||||||
|
input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design
|
||||||
|
global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
|
||||||
|
attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer
|
||||||
|
n_queries: int = 256 # n_queries for attentional pooler
|
||||||
|
attn_pooler_heads: int = 8 # n heads for attentional_pooling
|
||||||
|
timm_model_name: str = None # a valid model name overrides layers, width, patch_size
|
||||||
|
timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
|
||||||
|
timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
|
||||||
|
timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
|
||||||
|
timm_proj_bias: bool = False # enable bias final projection
|
||||||
|
timm_drop: float = 0. # head dropout
|
||||||
|
timm_drop_path: Optional[float] = None # backbone stochastic depth
|
||||||
|
output_tokens: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CLIPTextCfg:
|
||||||
|
context_length: int = 77
|
||||||
|
vocab_size: int = 49408
|
||||||
|
width: int = 512
|
||||||
|
heads: int = 8
|
||||||
|
layers: int = 12
|
||||||
|
ls_init_value: Optional[float] = None # layer scale initial value
|
||||||
|
hf_model_name: str = None
|
||||||
|
hf_tokenizer_name: str = None
|
||||||
|
hf_model_pretrained: bool = True
|
||||||
|
proj: str = 'mlp'
|
||||||
|
pooler_type: str = 'mean_pooler'
|
||||||
|
embed_cls: bool = False
|
||||||
|
pad_id: int = 0
|
||||||
|
output_tokens: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
def get_cast_dtype(precision: str):
|
||||||
|
cast_dtype = None
|
||||||
|
if precision == 'bf16':
|
||||||
|
cast_dtype = torch.bfloat16
|
||||||
|
elif precision == 'fp16':
|
||||||
|
cast_dtype = torch.float16
|
||||||
|
return cast_dtype
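# Illustrative note (not part of the diff): get_cast_dtype('bf16') -> torch.bfloat16,
# get_cast_dtype('fp16') -> torch.float16, anything else (e.g. 'fp32') -> None (no cast).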
|
||||||
|
|
||||||
|
|
||||||
|
def _build_vision_tower(
|
||||||
|
embed_dim: int,
|
||||||
|
vision_cfg: CLIPVisionCfg,
|
||||||
|
quick_gelu: bool = False,
|
||||||
|
cast_dtype: Optional[torch.dtype] = None
|
||||||
|
):
|
||||||
|
if isinstance(vision_cfg, dict):
|
||||||
|
vision_cfg = CLIPVisionCfg(**vision_cfg)
|
||||||
|
|
||||||
|
# OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
|
||||||
|
# memory efficient in recent PyTorch releases (>= 1.10).
|
||||||
|
# NOTE: timm models always use native GELU regardless of quick_gelu flag.
|
||||||
|
act_layer = QuickGELU if quick_gelu else nn.GELU
|
||||||
|
|
||||||
|
if vision_cfg.timm_model_name:
|
||||||
|
visual = TimmModel(
|
||||||
|
vision_cfg.timm_model_name,
|
||||||
|
pretrained=vision_cfg.timm_model_pretrained,
|
||||||
|
pool=vision_cfg.timm_pool,
|
||||||
|
proj=vision_cfg.timm_proj,
|
||||||
|
proj_bias=vision_cfg.timm_proj_bias,
|
||||||
|
drop=vision_cfg.timm_drop,
|
||||||
|
drop_path=vision_cfg.timm_drop_path,
|
||||||
|
embed_dim=embed_dim,
|
||||||
|
image_size=vision_cfg.image_size,
|
||||||
|
)
|
||||||
|
act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
|
||||||
|
elif isinstance(vision_cfg.layers, (tuple, list)):
|
||||||
|
vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
|
||||||
|
visual = ModifiedResNet(
|
||||||
|
layers=vision_cfg.layers,
|
||||||
|
output_dim=embed_dim,
|
||||||
|
heads=vision_heads,
|
||||||
|
image_size=vision_cfg.image_size,
|
||||||
|
width=vision_cfg.width,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
vision_heads = vision_cfg.width // vision_cfg.head_width
|
||||||
|
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
|
||||||
|
visual = VisionTransformer(
|
||||||
|
image_size=vision_cfg.image_size,
|
||||||
|
patch_size=vision_cfg.patch_size,
|
||||||
|
width=vision_cfg.width,
|
||||||
|
layers=vision_cfg.layers,
|
||||||
|
heads=vision_heads,
|
||||||
|
mlp_ratio=vision_cfg.mlp_ratio,
|
||||||
|
ls_init_value=vision_cfg.ls_init_value,
|
||||||
|
patch_dropout=vision_cfg.patch_dropout,
|
||||||
|
input_patchnorm=vision_cfg.input_patchnorm,
|
||||||
|
global_average_pool=vision_cfg.global_average_pool,
|
||||||
|
attentional_pool=vision_cfg.attentional_pool,
|
||||||
|
n_queries=vision_cfg.n_queries,
|
||||||
|
attn_pooler_heads=vision_cfg.attn_pooler_heads,
|
||||||
|
output_tokens=vision_cfg.output_tokens,
|
||||||
|
output_dim=embed_dim,
|
||||||
|
act_layer=act_layer,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
)
|
||||||
|
|
||||||
|
return visual
|
||||||
|
|
||||||
|
|
||||||
|
def _build_text_tower(
|
||||||
|
embed_dim: int,
|
||||||
|
text_cfg: CLIPTextCfg,
|
||||||
|
quick_gelu: bool = False,
|
||||||
|
cast_dtype: Optional[torch.dtype] = None,
|
||||||
|
):
|
||||||
|
if isinstance(text_cfg, dict):
|
||||||
|
text_cfg = CLIPTextCfg(**text_cfg)
|
||||||
|
|
||||||
|
if text_cfg.hf_model_name:
|
||||||
|
text = HFTextEncoder(
|
||||||
|
text_cfg.hf_model_name,
|
||||||
|
output_dim=embed_dim,
|
||||||
|
proj=text_cfg.proj,
|
||||||
|
pooler_type=text_cfg.pooler_type,
|
||||||
|
pretrained=text_cfg.hf_model_pretrained,
|
||||||
|
output_tokens=text_cfg.output_tokens,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
act_layer = QuickGELU if quick_gelu else nn.GELU
|
||||||
|
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
|
||||||
|
|
||||||
|
text = TextTransformer(
|
||||||
|
context_length=text_cfg.context_length,
|
||||||
|
vocab_size=text_cfg.vocab_size,
|
||||||
|
width=text_cfg.width,
|
||||||
|
heads=text_cfg.heads,
|
||||||
|
layers=text_cfg.layers,
|
||||||
|
ls_init_value=text_cfg.ls_init_value,
|
||||||
|
output_dim=embed_dim,
|
||||||
|
embed_cls=text_cfg.embed_cls,
|
||||||
|
output_tokens=text_cfg.output_tokens,
|
||||||
|
pad_id=text_cfg.pad_id,
|
||||||
|
act_layer=act_layer,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
)
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
class CLIP(nn.Module):
|
||||||
|
output_dict: torch.jit.Final[bool]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embed_dim: int,
|
||||||
|
vision_cfg: CLIPVisionCfg,
|
||||||
|
text_cfg: CLIPTextCfg,
|
||||||
|
quick_gelu: bool = False,
|
||||||
|
cast_dtype: Optional[torch.dtype] = None,
|
||||||
|
output_dict: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.output_dict = output_dict
|
||||||
|
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
|
||||||
|
|
||||||
|
text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
|
||||||
|
self.transformer = text.transformer
|
||||||
|
self.vocab_size = text.vocab_size
|
||||||
|
self.token_embedding = text.token_embedding
|
||||||
|
self.positional_embedding = text.positional_embedding
|
||||||
|
self.ln_final = text.ln_final
|
||||||
|
self.text_projection = text.text_projection
|
||||||
|
self.register_buffer('attn_mask', text.attn_mask, persistent=False)
|
||||||
|
|
||||||
|
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
||||||
|
|
||||||
|
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
|
||||||
|
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
|
||||||
|
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
|
||||||
|
|
||||||
|
def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
|
||||||
|
locked_layers = []
|
||||||
|
locked_layers.append(self.token_embedding)
|
||||||
|
self.positional_embedding.requires_grad = False
|
||||||
|
if unlocked_layers > 0:
|
||||||
|
locked_layers.append(self.transformer.resblocks[:-unlocked_layers])
|
||||||
|
else:
|
||||||
|
locked_layers.append(self.transformer)
|
||||||
|
locked_layers.append(self.ln_final)
|
||||||
|
self.text_projection.requires_grad = False
|
||||||
|
|
||||||
|
# freeze layers
|
||||||
|
for module in locked_layers:
|
||||||
|
for n, p in module.named_parameters():
|
||||||
|
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def set_grad_checkpointing(self, enable=True):
|
||||||
|
self.visual.set_grad_checkpointing(enable)
|
||||||
|
self.transformer.grad_checkpointing = enable
|
||||||
|
|
||||||
|
def encode_image(self, image, normalize: bool = False):
|
||||||
|
features = self.visual(image)
|
||||||
|
return F.normalize(features, dim=-1) if normalize else features
|
||||||
|
|
||||||
|
def encode_text(self, text, normalize: bool = False):
|
||||||
|
cast_dtype = self.transformer.get_cast_dtype()
|
||||||
|
|
||||||
|
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
|
||||||
|
|
||||||
|
x = x + self.positional_embedding.to(cast_dtype)
|
||||||
|
x = x.permute(1, 0, 2) # NLD -> LND
|
||||||
|
x = self.transformer(x, attn_mask=self.attn_mask)
|
||||||
|
x = x.permute(1, 0, 2) # LND -> NLD
|
||||||
|
x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
|
||||||
|
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||||
|
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
||||||
|
return F.normalize(x, dim=-1) if normalize else x
|
||||||
|
|
||||||
|
def forward(self, image, text):
|
||||||
|
image_features = self.encode_image(image, normalize=True)
|
||||||
|
text_features = self.encode_text(text, normalize=True)
|
||||||
|
if self.output_dict:
|
||||||
|
return {
|
||||||
|
"image_features": image_features,
|
||||||
|
"text_features": text_features,
|
||||||
|
"logit_scale": self.logit_scale.exp()
|
||||||
|
}
|
||||||
|
return image_features, text_features, self.logit_scale.exp()
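# Illustrative sketch (not part of the diff): scoring image-text similarity with the class above;
# `clip_model`, `images` and `tokens` are hypothetical, suitably shaped inputs.
# image_features = clip_model.encode_image(images, normalize=True)   # (B, embed_dim)
# text_features = clip_model.encode_text(tokens, normalize=True)     # (B, embed_dim)
# similarity = clip_model.logit_scale.exp() * image_features @ text_features.T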
|
||||||
|
|
||||||
|
|
||||||
|
class CustomTextCLIP(nn.Module):
|
||||||
|
output_dict: torch.jit.Final[bool]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embed_dim: int,
|
||||||
|
vision_cfg: CLIPVisionCfg,
|
||||||
|
text_cfg: CLIPTextCfg,
|
||||||
|
quick_gelu: bool = False,
|
||||||
|
cast_dtype: Optional[torch.dtype] = None,
|
||||||
|
output_dict: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.output_dict = output_dict
|
||||||
|
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
|
||||||
|
self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
|
||||||
|
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
||||||
|
|
||||||
|
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
|
||||||
|
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
|
||||||
|
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
|
||||||
|
|
||||||
|
def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
|
||||||
|
self.text.lock(unlocked_layers, freeze_layer_norm)
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def set_grad_checkpointing(self, enable=True):
|
||||||
|
self.visual.set_grad_checkpointing(enable)
|
||||||
|
self.text.set_grad_checkpointing(enable)
|
||||||
|
|
||||||
|
def encode_image(self, image, normalize: bool = False):
|
||||||
|
features = self.visual(image)
|
||||||
|
return F.normalize(features, dim=-1) if normalize else features
|
||||||
|
|
||||||
|
def encode_text(self, text, normalize: bool = False):
|
||||||
|
features = self.text(text)
|
||||||
|
return F.normalize(features, dim=-1) if normalize else features
|
||||||
|
|
||||||
|
def forward(self, image, text):
|
||||||
|
image_features = self.encode_image(image, normalize=True)
|
||||||
|
text_features = self.encode_text(text, normalize=True)
|
||||||
|
if self.output_dict:
|
||||||
|
return {
|
||||||
|
"image_features": image_features,
|
||||||
|
"text_features": text_features,
|
||||||
|
"logit_scale": self.logit_scale.exp()
|
||||||
|
}
|
||||||
|
return image_features, text_features, self.logit_scale.exp()
|
||||||
|
|
||||||
|
|
||||||
|
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
|
||||||
|
"""Convert applicable model parameters to low-precision (bf16 or fp16)"""
|
||||||
|
|
||||||
|
def _convert_weights(l):
|
||||||
|
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
||||||
|
l.weight.data = l.weight.data.to(dtype)
|
||||||
|
if l.bias is not None:
|
||||||
|
l.bias.data = l.bias.data.to(dtype)
|
||||||
|
|
||||||
|
if isinstance(l, (nn.MultiheadAttention, Attention)):
|
||||||
|
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
||||||
|
tensor = getattr(l, attr)
|
||||||
|
if tensor is not None:
|
||||||
|
tensor.data = tensor.data.to(dtype)
|
||||||
|
|
||||||
|
for name in ["text_projection", "proj"]:
|
||||||
|
if hasattr(l, name):
|
||||||
|
attr = getattr(l, name)
|
||||||
|
if attr is not None:
|
||||||
|
attr.data = attr.data.to(dtype)
|
||||||
|
|
||||||
|
model.apply(_convert_weights)
|
||||||
|
|
||||||
|
|
||||||
|
convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
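# Example (a sketch, not from the upstream file): casting a freshly built model's
# Linear/Conv/attention weights to half precision. convert_weights_to_lp only visits those
# module types plus text_projection/proj, so LayerNorm and embedding parameters stay fp32:
#   model = CLIP(embed_dim, vision_cfg=vision_cfg, text_cfg=text_cfg)  # hypothetical cfgs
#   convert_weights_to_fp16(model)  # same as convert_weights_to_lp(model, torch.float16)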
|
||||||
|
|
||||||
|
|
||||||
|
# used to maintain checkpoint compatibility
|
||||||
|
def convert_to_custom_text_state_dict(state_dict: dict):
|
||||||
|
if 'text_projection' in state_dict:
|
||||||
|
# old format state_dict, move text tower -> .text
|
||||||
|
new_state_dict = {}
|
||||||
|
for k, v in state_dict.items():
|
||||||
|
if any(k.startswith(p) for p in (
|
||||||
|
'text_projection',
|
||||||
|
'positional_embedding',
|
||||||
|
'token_embedding',
|
||||||
|
'transformer',
|
||||||
|
'ln_final',
|
||||||
|
)):
|
||||||
|
k = 'text.' + k
|
||||||
|
new_state_dict[k] = v
|
||||||
|
return new_state_dict
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def build_model_from_openai_state_dict(
|
||||||
|
state_dict: dict,
|
||||||
|
quick_gelu=True,
|
||||||
|
cast_dtype=torch.float16,
|
||||||
|
):
|
||||||
|
vit = "visual.proj" in state_dict
|
||||||
|
|
||||||
|
if vit:
|
||||||
|
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
||||||
|
vision_layers = len(
|
||||||
|
[k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
||||||
|
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
||||||
|
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
||||||
|
image_size = vision_patch_size * grid_size
|
||||||
|
else:
|
||||||
|
counts: list = [
|
||||||
|
len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
|
||||||
|
vision_layers = tuple(counts)
|
||||||
|
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
||||||
|
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
||||||
|
vision_patch_size = None
|
||||||
|
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
||||||
|
image_size = output_width * 32
|
||||||
|
|
||||||
|
embed_dim = state_dict["text_projection"].shape[1]
|
||||||
|
context_length = state_dict["positional_embedding"].shape[0]
|
||||||
|
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
||||||
|
transformer_width = state_dict["ln_final.weight"].shape[0]
|
||||||
|
transformer_heads = transformer_width // 64
|
||||||
|
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
|
||||||
|
|
||||||
|
vision_cfg = CLIPVisionCfg(
|
||||||
|
layers=vision_layers,
|
||||||
|
width=vision_width,
|
||||||
|
patch_size=vision_patch_size,
|
||||||
|
image_size=image_size,
|
||||||
|
)
|
||||||
|
text_cfg = CLIPTextCfg(
|
||||||
|
context_length=context_length,
|
||||||
|
vocab_size=vocab_size,
|
||||||
|
width=transformer_width,
|
||||||
|
heads=transformer_heads,
|
||||||
|
layers=transformer_layers,
|
||||||
|
)
|
||||||
|
model = CLIP(
|
||||||
|
embed_dim,
|
||||||
|
vision_cfg=vision_cfg,
|
||||||
|
text_cfg=text_cfg,
|
||||||
|
quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
|
||||||
|
cast_dtype=cast_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
for key in ["input_resolution", "context_length", "vocab_size"]:
|
||||||
|
state_dict.pop(key, None)
|
||||||
|
|
||||||
|
convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
return model.eval()
|
||||||
|
|
||||||
|
|
||||||
|
def trace_model(model, batch_size=256, device=torch.device('cpu')):
|
||||||
|
model.eval()
|
||||||
|
image_size = model.visual.image_size
|
||||||
|
example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
|
||||||
|
example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
|
||||||
|
model = torch.jit.trace_module(
|
||||||
|
model,
|
||||||
|
inputs=dict(
|
||||||
|
forward=(example_images, example_text),
|
||||||
|
encode_text=(example_text,),
|
||||||
|
encode_image=(example_images,)
|
||||||
|
))
|
||||||
|
model.visual.image_size = image_size
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
|
||||||
|
# Rescale the grid of position embeddings when loading from state_dict
|
||||||
|
old_pos_embed = state_dict.get('visual.positional_embedding', None)
|
||||||
|
if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
|
||||||
|
return
|
||||||
|
grid_size = to_2tuple(model.visual.grid_size)
|
||||||
|
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
|
||||||
|
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
|
||||||
|
if new_seq_len == old_pos_embed.shape[0]:
|
||||||
|
return
|
||||||
|
|
||||||
|
if extra_tokens:
|
||||||
|
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
|
||||||
|
else:
|
||||||
|
pos_emb_tok, pos_emb_img = None, old_pos_embed
|
||||||
|
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
|
||||||
|
|
||||||
|
logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
|
||||||
|
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
|
||||||
|
pos_emb_img = F.interpolate(
|
||||||
|
pos_emb_img,
|
||||||
|
size=grid_size,
|
||||||
|
mode=interpolation,
|
||||||
|
antialias=antialias,
|
||||||
|
align_corners=False,
|
||||||
|
)
|
||||||
|
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
|
||||||
|
if pos_emb_tok is not None:
|
||||||
|
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
|
||||||
|
else:
|
||||||
|
new_pos_embed = pos_emb_img
|
||||||
|
state_dict['visual.positional_embedding'] = new_pos_embed
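# Shape walkthrough (illustrative, assuming a ViT-B/32 checkpoint): at 224px the grid is
# 224/32 = 7, so the stored embedding has 7*7 + 1 = 50 rows (one class token). Loading it
# into a 384px model gives grid_size (12, 12), so the 49 image rows are interpolated to
# 144 and the class-token row is re-attached, giving 145 rows in the new embedding.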
|
||||||
@@ -0,0 +1,17 @@
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 224,
        "layers": 32,
        "width": 1280,
        "head_width": 80,
        "patch_size": 14
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1024,
        "heads": 16,
        "layers": 24
    }
}
@@ -0,0 +1,181 @@
|
|||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from .utils import freeze_batch_norm_2d
|
||||||
|
|
||||||
|
|
||||||
|
class Bottleneck(nn.Module):
|
||||||
|
expansion = 4
|
||||||
|
|
||||||
|
def __init__(self, inplanes, planes, stride=1):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
||||||
|
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(planes)
|
||||||
|
self.act1 = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
||||||
|
self.bn2 = nn.BatchNorm2d(planes)
|
||||||
|
self.act2 = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
||||||
|
|
||||||
|
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
||||||
|
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
||||||
|
self.act3 = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
self.downsample = None
|
||||||
|
self.stride = stride
|
||||||
|
|
||||||
|
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
||||||
|
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
||||||
|
self.downsample = nn.Sequential(OrderedDict([
|
||||||
|
("-1", nn.AvgPool2d(stride)),
|
||||||
|
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
||||||
|
("1", nn.BatchNorm2d(planes * self.expansion))
|
||||||
|
]))
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
identity = x
|
||||||
|
|
||||||
|
out = self.act1(self.bn1(self.conv1(x)))
|
||||||
|
out = self.act2(self.bn2(self.conv2(out)))
|
||||||
|
out = self.avgpool(out)
|
||||||
|
out = self.bn3(self.conv3(out))
|
||||||
|
|
||||||
|
if self.downsample is not None:
|
||||||
|
identity = self.downsample(x)
|
||||||
|
|
||||||
|
out += identity
|
||||||
|
out = self.act3(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionPool2d(nn.Module):
|
||||||
|
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
||||||
|
super().__init__()
|
||||||
|
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
|
||||||
|
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
||||||
|
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
||||||
|
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
||||||
|
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
||||||
|
self.num_heads = num_heads
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
|
||||||
|
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
||||||
|
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
||||||
|
x, _ = F.multi_head_attention_forward(
|
||||||
|
query=x, key=x, value=x,
|
||||||
|
embed_dim_to_check=x.shape[-1],
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
q_proj_weight=self.q_proj.weight,
|
||||||
|
k_proj_weight=self.k_proj.weight,
|
||||||
|
v_proj_weight=self.v_proj.weight,
|
||||||
|
in_proj_weight=None,
|
||||||
|
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
||||||
|
bias_k=None,
|
||||||
|
bias_v=None,
|
||||||
|
add_zero_attn=False,
|
||||||
|
dropout_p=0.,
|
||||||
|
out_proj_weight=self.c_proj.weight,
|
||||||
|
out_proj_bias=self.c_proj.bias,
|
||||||
|
use_separate_proj_weight=True,
|
||||||
|
training=self.training,
|
||||||
|
need_weights=False
|
||||||
|
)
|
||||||
|
|
||||||
|
return x[0]
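# Shape note (illustrative): the NCHW input is flattened to (HW, N, C) and a mean token is
# prepended; attention is computed for all HW+1 positions, but only the first output row
# (the mean-token query) is kept via x[0], so the pooled result has shape
# [N, output_dim or embed_dim].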
|
||||||
|
|
||||||
|
|
||||||
|
class ModifiedResNet(nn.Module):
|
||||||
|
"""
|
||||||
|
A ResNet class that is similar to torchvision's but contains the following changes:
|
||||||
|
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
||||||
|
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
||||||
|
- The final pooling layer is a QKV attention instead of an average pool
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, layers, output_dim, heads, image_size=224, width=64):
|
||||||
|
super().__init__()
|
||||||
|
self.output_dim = output_dim
|
||||||
|
self.image_size = image_size
|
||||||
|
|
||||||
|
# the 3-layer stem
|
||||||
|
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(width // 2)
|
||||||
|
self.act1 = nn.ReLU(inplace=True)
|
||||||
|
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
||||||
|
self.bn2 = nn.BatchNorm2d(width // 2)
|
||||||
|
self.act2 = nn.ReLU(inplace=True)
|
||||||
|
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
||||||
|
self.bn3 = nn.BatchNorm2d(width)
|
||||||
|
self.act3 = nn.ReLU(inplace=True)
|
||||||
|
self.avgpool = nn.AvgPool2d(2)
|
||||||
|
|
||||||
|
# residual layers
|
||||||
|
self._inplanes = width # this is a *mutable* variable used during construction
|
||||||
|
self.layer1 = self._make_layer(width, layers[0])
|
||||||
|
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
||||||
|
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
||||||
|
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
||||||
|
|
||||||
|
embed_dim = width * 32 # the ResNet feature dimension
|
||||||
|
self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
|
||||||
|
|
||||||
|
self.init_parameters()
|
||||||
|
|
||||||
|
def _make_layer(self, planes, blocks, stride=1):
|
||||||
|
layers = [Bottleneck(self._inplanes, planes, stride)]
|
||||||
|
|
||||||
|
self._inplanes = planes * Bottleneck.expansion
|
||||||
|
for _ in range(1, blocks):
|
||||||
|
layers.append(Bottleneck(self._inplanes, planes))
|
||||||
|
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def init_parameters(self):
|
||||||
|
if self.attnpool is not None:
|
||||||
|
std = self.attnpool.c_proj.in_features ** -0.5
|
||||||
|
nn.init.normal_(self.attnpool.q_proj.weight, std=std)
|
||||||
|
nn.init.normal_(self.attnpool.k_proj.weight, std=std)
|
||||||
|
nn.init.normal_(self.attnpool.v_proj.weight, std=std)
|
||||||
|
nn.init.normal_(self.attnpool.c_proj.weight, std=std)
|
||||||
|
|
||||||
|
for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
|
||||||
|
for name, param in resnet_block.named_parameters():
|
||||||
|
if name.endswith("bn3.weight"):
|
||||||
|
nn.init.zeros_(param)
|
||||||
|
|
||||||
|
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
||||||
|
assert unlocked_groups == 0, 'partial locking not currently supported for this model'
|
||||||
|
for param in self.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
if freeze_bn_stats:
|
||||||
|
freeze_batch_norm_2d(self)
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def set_grad_checkpointing(self, enable=True):
|
||||||
|
# FIXME support for non-transformer
|
||||||
|
pass
|
||||||
|
|
||||||
|
def stem(self, x):
|
||||||
|
x = self.act1(self.bn1(self.conv1(x)))
|
||||||
|
x = self.act2(self.bn2(self.conv2(x)))
|
||||||
|
x = self.act3(self.bn3(self.conv3(x)))
|
||||||
|
x = self.avgpool(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.stem(x)
|
||||||
|
x = self.layer1(x)
|
||||||
|
x = self.layer2(x)
|
||||||
|
x = self.layer3(x)
|
||||||
|
x = self.layer4(x)
|
||||||
|
x = self.attnpool(x)
|
||||||
|
|
||||||
|
return x
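# Resolution walkthrough (illustrative, for the default image_size=224): the stem halves
# the input twice (conv1 stride 2, then AvgPool2d(2)) -> 56x56; layer2/3/4 each halve it
# again -> 28, 14, 7; attnpool then collapses the 7x7 grid to a single [N, output_dim]
# vector, which is why AttentionPool2d is built with spacial_dim = image_size // 32.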
|
||||||
diffsynth/extensions/ImageQualityMetric/open_clip/openai.py
@@ -0,0 +1,144 @@
|
|||||||
|
""" OpenAI pretrained model functions
|
||||||
|
|
||||||
|
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
|
||||||
|
from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
|
||||||
|
|
||||||
|
__all__ = ["list_openai_models", "load_openai_model"]
|
||||||
|
|
||||||
|
|
||||||
|
def list_openai_models() -> List[str]:
|
||||||
|
"""Returns the names of available CLIP models"""
|
||||||
|
return list_pretrained_models_by_tag('openai')
|
||||||
|
|
||||||
|
|
||||||
|
def load_openai_model(
|
||||||
|
name: str,
|
||||||
|
precision: Optional[str] = None,
|
||||||
|
device: Optional[Union[str, torch.device]] = None,
|
||||||
|
jit: bool = True,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""Load a CLIP model
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
|
||||||
|
precision: str
|
||||||
|
Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
|
||||||
|
device : Union[str, torch.device]
|
||||||
|
The device to put the loaded model
|
||||||
|
jit : bool
|
||||||
|
Whether to load the optimized JIT model (default) or more hackable non-JIT model.
|
||||||
|
cache_dir : Optional[str]
|
||||||
|
The directory to cache the downloaded model weights
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
model : torch.nn.Module
|
||||||
|
The CLIP model
|
||||||
|
preprocess : Callable[[PIL.Image], torch.Tensor]
|
||||||
|
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
||||||
|
"""
|
||||||
|
if device is None:
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
if precision is None:
|
||||||
|
precision = 'fp32' if device == 'cpu' else 'fp16'
|
||||||
|
|
||||||
|
if get_pretrained_url(name, 'openai'):
|
||||||
|
model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
|
||||||
|
elif os.path.isfile(name):
|
||||||
|
model_path = name
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# loading JIT archive
|
||||||
|
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
|
||||||
|
state_dict = None
|
||||||
|
except RuntimeError:
|
||||||
|
# loading saved state dict
|
||||||
|
if jit:
|
||||||
|
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
|
||||||
|
jit = False
|
||||||
|
state_dict = torch.load(model_path, map_location="cpu")
|
||||||
|
|
||||||
|
if not jit:
|
||||||
|
# Build a non-jit model from the OpenAI jitted model state dict
|
||||||
|
cast_dtype = get_cast_dtype(precision)
|
||||||
|
try:
|
||||||
|
model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
|
||||||
|
except KeyError:
|
||||||
|
sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
|
||||||
|
model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
|
||||||
|
|
||||||
|
# model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
|
||||||
|
model = model.to(device)
|
||||||
|
if precision.startswith('amp') or precision == 'fp32':
|
||||||
|
model.float()
|
||||||
|
elif precision == 'bf16':
|
||||||
|
convert_weights_to_lp(model, dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
# patch the device names
|
||||||
|
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
|
||||||
|
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
|
||||||
|
|
||||||
|
def patch_device(module):
|
||||||
|
try:
|
||||||
|
graphs = [module.graph] if hasattr(module, "graph") else []
|
||||||
|
except RuntimeError:
|
||||||
|
graphs = []
|
||||||
|
|
||||||
|
if hasattr(module, "forward1"):
|
||||||
|
graphs.append(module.forward1.graph)
|
||||||
|
|
||||||
|
for graph in graphs:
|
||||||
|
for node in graph.findAllNodes("prim::Constant"):
|
||||||
|
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
|
||||||
|
node.copyAttributes(device_node)
|
||||||
|
|
||||||
|
model.apply(patch_device)
|
||||||
|
patch_device(model.encode_image)
|
||||||
|
patch_device(model.encode_text)
|
||||||
|
|
||||||
|
# patch dtype to float32 (typically for CPU)
|
||||||
|
if precision == 'fp32':
|
||||||
|
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
|
||||||
|
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
||||||
|
float_node = float_input.node()
|
||||||
|
|
||||||
|
def patch_float(module):
|
||||||
|
try:
|
||||||
|
graphs = [module.graph] if hasattr(module, "graph") else []
|
||||||
|
except RuntimeError:
|
||||||
|
graphs = []
|
||||||
|
|
||||||
|
if hasattr(module, "forward1"):
|
||||||
|
graphs.append(module.forward1.graph)
|
||||||
|
|
||||||
|
for graph in graphs:
|
||||||
|
for node in graph.findAllNodes("aten::to"):
|
||||||
|
inputs = list(node.inputs())
|
||||||
|
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
|
||||||
|
if inputs[i].node()["value"] == 5:
|
||||||
|
inputs[i].node().copyAttributes(float_node)
|
||||||
|
|
||||||
|
model.apply(patch_float)
|
||||||
|
patch_float(model.encode_image)
|
||||||
|
patch_float(model.encode_text)
|
||||||
|
model.float()
|
||||||
|
|
||||||
|
# ensure image_size attr available at consistent location for both jit and non-jit
|
||||||
|
model.visual.image_size = model.input_resolution.item()
|
||||||
|
return model
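# Usage sketch (not part of the upstream file; model/tag names are examples):
#   model = load_openai_model("ViT-B-32", precision="fp16", device="cuda", jit=False)
#   image_features = model.encode_image(images)  # `images` is a hypothetical batch
# With jit=False the checkpoint is rebuilt via build_model_from_openai_state_dict and the
# requested precision is re-applied afterwards.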
|
||||||
diffsynth/extensions/ImageQualityMetric/open_clip/pretrained.py
@@ -0,0 +1,376 @@
|
|||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
import urllib
|
||||||
|
import warnings
|
||||||
|
from functools import partial
|
||||||
|
from typing import Dict, Union
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from .version import __version__
|
||||||
|
|
||||||
|
try:
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__)
|
||||||
|
_has_hf_hub = True
|
||||||
|
except ImportError:
|
||||||
|
hf_hub_download = None
|
||||||
|
_has_hf_hub = False
|
||||||
|
|
||||||
|
|
||||||
|
def _pcfg(url='', hf_hub='', mean=None, std=None):
|
||||||
|
return dict(
|
||||||
|
url=url,
|
||||||
|
hf_hub=hf_hub,
|
||||||
|
mean=mean,
|
||||||
|
std=std,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_RN50 = dict(
|
||||||
|
openai=_pcfg(
|
||||||
|
"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
|
||||||
|
yfcc15m=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
|
||||||
|
cc12m=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
|
||||||
|
)
|
||||||
|
|
||||||
|
_RN50_quickgelu = dict(
|
||||||
|
openai=_pcfg(
|
||||||
|
"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
|
||||||
|
yfcc15m=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
|
||||||
|
cc12m=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
|
||||||
|
)
|
||||||
|
|
||||||
|
_RN101 = dict(
|
||||||
|
openai=_pcfg(
|
||||||
|
"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
|
||||||
|
yfcc15m=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
|
||||||
|
)
|
||||||
|
|
||||||
|
_RN101_quickgelu = dict(
|
||||||
|
openai=_pcfg(
|
||||||
|
"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
|
||||||
|
yfcc15m=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
|
||||||
|
)
|
||||||
|
|
||||||
|
_RN50x4 = dict(
|
||||||
|
openai=_pcfg(
|
||||||
|
"https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"),
|
||||||
|
)
|
||||||
|
|
||||||
|
_RN50x16 = dict(
|
||||||
|
openai=_pcfg(
|
||||||
|
"https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"),
|
||||||
|
)
|
||||||
|
|
||||||
|
_RN50x64 = dict(
|
||||||
|
openai=_pcfg(
|
||||||
|
"https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"),
|
||||||
|
)
|
||||||
|
|
||||||
|
_VITB32 = dict(
|
||||||
|
openai=_pcfg(
|
||||||
|
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
|
||||||
|
laion400m_e31=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
|
||||||
|
laion400m_e32=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
|
||||||
|
laion2b_e16=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
|
||||||
|
laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/')
|
||||||
|
)
|
||||||
|
|
||||||
|
_VITB32_quickgelu = dict(
|
||||||
|
openai=_pcfg(
|
||||||
|
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
|
||||||
|
laion400m_e31=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
|
||||||
|
laion400m_e32=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
|
||||||
|
)
|
||||||
|
|
||||||
|
_VITB16 = dict(
|
||||||
|
openai=_pcfg(
|
||||||
|
"https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
|
||||||
|
laion400m_e31=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
|
||||||
|
laion400m_e32=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
|
||||||
|
# laion400m_32k=_pcfg(
|
||||||
|
# url="",
|
||||||
|
# mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
||||||
|
# laion400m_64k=_pcfg(
|
||||||
|
# url="",
|
||||||
|
# mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
||||||
|
laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
|
||||||
|
)
|
||||||
|
|
||||||
|
_VITB16_PLUS_240 = dict(
|
||||||
|
laion400m_e31=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
|
||||||
|
laion400m_e32=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
|
||||||
|
)
|
||||||
|
|
||||||
|
_VITL14 = dict(
|
||||||
|
openai=_pcfg(
|
||||||
|
"https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
|
||||||
|
laion400m_e31=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
|
||||||
|
laion400m_e32=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
|
||||||
|
laion2b_s32b_b82k=_pcfg(
|
||||||
|
hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
|
||||||
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
||||||
|
)
|
||||||
|
|
||||||
|
_VITL14_336 = dict(
|
||||||
|
openai=_pcfg(
|
||||||
|
"https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
|
||||||
|
)
|
||||||
|
|
||||||
|
_VITH14 = dict(
|
||||||
|
laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
|
||||||
|
)
|
||||||
|
|
||||||
|
_VITg14 = dict(
|
||||||
|
laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
|
||||||
|
laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),
|
||||||
|
)
|
||||||
|
|
||||||
|
_VITbigG14 = dict(
|
||||||
|
laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
|
||||||
|
)
|
||||||
|
|
||||||
|
_robertaViTB32 = dict(
|
||||||
|
laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'),
|
||||||
|
)
|
||||||
|
|
||||||
|
_xlmRobertaBaseViTB32 = dict(
|
||||||
|
laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'),
|
||||||
|
)
|
||||||
|
|
||||||
|
_xlmRobertaLargeFrozenViTH14 = dict(
|
||||||
|
frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'),
|
||||||
|
)
|
||||||
|
|
||||||
|
_convnext_base = dict(
|
||||||
|
laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'),
|
||||||
|
)
|
||||||
|
|
||||||
|
_convnext_base_w = dict(
|
||||||
|
laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'),
|
||||||
|
laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'),
|
||||||
|
laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'),
|
||||||
|
)
|
||||||
|
|
||||||
|
_convnext_base_w_320 = dict(
|
||||||
|
laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'),
|
||||||
|
laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'),
|
||||||
|
)
|
||||||
|
|
||||||
|
_convnext_large_d = dict(
|
||||||
|
laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'),
|
||||||
|
)
|
||||||
|
|
||||||
|
_convnext_large_d_320 = dict(
|
||||||
|
laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'),
|
||||||
|
laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'),
|
||||||
|
)
|
||||||
|
|
||||||
|
_convnext_xxlarge = dict(
|
||||||
|
laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'),
|
||||||
|
laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'),
|
||||||
|
laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'),
|
||||||
|
)
|
||||||
|
|
||||||
|
_coca_VITB32 = dict(
|
||||||
|
laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'),
|
||||||
|
mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/')
|
||||||
|
)
|
||||||
|
|
||||||
|
_coca_VITL14 = dict(
|
||||||
|
laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'),
|
||||||
|
mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/')
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_PRETRAINED = {
|
||||||
|
"RN50": _RN50,
|
||||||
|
"RN50-quickgelu": _RN50_quickgelu,
|
||||||
|
"RN101": _RN101,
|
||||||
|
"RN101-quickgelu": _RN101_quickgelu,
|
||||||
|
"RN50x4": _RN50x4,
|
||||||
|
"RN50x16": _RN50x16,
|
||||||
|
"RN50x64": _RN50x64,
|
||||||
|
"ViT-B-32": _VITB32,
|
||||||
|
"ViT-B-32-quickgelu": _VITB32_quickgelu,
|
||||||
|
"ViT-B-16": _VITB16,
|
||||||
|
"ViT-B-16-plus-240": _VITB16_PLUS_240,
|
||||||
|
"ViT-L-14": _VITL14,
|
||||||
|
"ViT-L-14-336": _VITL14_336,
|
||||||
|
"ViT-H-14": _VITH14,
|
||||||
|
"ViT-g-14": _VITg14,
|
||||||
|
"ViT-bigG-14": _VITbigG14,
|
||||||
|
"roberta-ViT-B-32": _robertaViTB32,
|
||||||
|
"xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32,
|
||||||
|
"xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14,
|
||||||
|
"convnext_base": _convnext_base,
|
||||||
|
"convnext_base_w": _convnext_base_w,
|
||||||
|
"convnext_base_w_320": _convnext_base_w_320,
|
||||||
|
"convnext_large_d": _convnext_large_d,
|
||||||
|
"convnext_large_d_320": _convnext_large_d_320,
|
||||||
|
"convnext_xxlarge": _convnext_xxlarge,
|
||||||
|
"coca_ViT-B-32": _coca_VITB32,
|
||||||
|
"coca_ViT-L-14": _coca_VITL14,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _clean_tag(tag: str):
|
||||||
|
# normalize pretrained tags
|
||||||
|
return tag.lower().replace('-', '_')
|
||||||
|
|
||||||
|
|
||||||
|
def list_pretrained(as_str: bool = False):
|
||||||
|
""" returns list of pretrained models
|
||||||
|
Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
|
||||||
|
"""
|
||||||
|
return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
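# Example output (illustrative; exact order follows the tables above):
#   list_pretrained(as_str=True)[:3] -> ['RN50:openai', 'RN50:yfcc15m', 'RN50:cc12m']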
|
||||||
|
|
||||||
|
|
||||||
|
def list_pretrained_models_by_tag(tag: str):
|
||||||
|
""" return all models having the specified pretrain tag """
|
||||||
|
models = []
|
||||||
|
tag = _clean_tag(tag)
|
||||||
|
for k in _PRETRAINED.keys():
|
||||||
|
if tag in _PRETRAINED[k]:
|
||||||
|
models.append(k)
|
||||||
|
return models
|
||||||
|
|
||||||
|
|
||||||
|
def list_pretrained_tags_by_model(model: str):
|
||||||
|
""" return all pretrain tags for the specified model architecture """
|
||||||
|
tags = []
|
||||||
|
if model in _PRETRAINED:
|
||||||
|
tags.extend(_PRETRAINED[model].keys())
|
||||||
|
return tags
|
||||||
|
|
||||||
|
|
||||||
|
def is_pretrained_cfg(model: str, tag: str):
|
||||||
|
if model not in _PRETRAINED:
|
||||||
|
return False
|
||||||
|
return _clean_tag(tag) in _PRETRAINED[model]
|
||||||
|
|
||||||
|
|
||||||
|
def get_pretrained_cfg(model: str, tag: str):
|
||||||
|
if model not in _PRETRAINED:
|
||||||
|
return {}
|
||||||
|
model_pretrained = _PRETRAINED[model]
|
||||||
|
return model_pretrained.get(_clean_tag(tag), {})
|
||||||
|
|
||||||
|
|
||||||
|
def get_pretrained_url(model: str, tag: str):
|
||||||
|
cfg = get_pretrained_cfg(model, _clean_tag(tag))
|
||||||
|
return cfg.get('url', '')
|
||||||
|
|
||||||
|
|
||||||
|
def download_pretrained_from_url(
|
||||||
|
url: str,
|
||||||
|
cache_dir: Union[str, None] = None,
|
||||||
|
):
|
||||||
|
if not cache_dir:
|
||||||
|
cache_dir = os.path.expanduser("~/.cache/clip")
|
||||||
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
|
filename = os.path.basename(url)
|
||||||
|
|
||||||
|
if 'openaipublic' in url:
|
||||||
|
expected_sha256 = url.split("/")[-2]
|
||||||
|
elif 'mlfoundations' in url:
|
||||||
|
expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
|
||||||
|
else:
|
||||||
|
expected_sha256 = ''
|
||||||
|
|
||||||
|
download_target = os.path.join(cache_dir, filename)
|
||||||
|
|
||||||
|
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
||||||
|
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
||||||
|
|
||||||
|
if os.path.isfile(download_target):
|
||||||
|
if expected_sha256:
|
||||||
|
if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
|
||||||
|
return download_target
|
||||||
|
else:
|
||||||
|
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
||||||
|
else:
|
||||||
|
return download_target
|
||||||
|
|
||||||
|
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
||||||
|
with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
|
||||||
|
while True:
|
||||||
|
buffer = source.read(8192)
|
||||||
|
if not buffer:
|
||||||
|
break
|
||||||
|
|
||||||
|
output.write(buffer)
|
||||||
|
loop.update(len(buffer))
|
||||||
|
|
||||||
|
if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
|
||||||
|
raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
|
||||||
|
|
||||||
|
return download_target
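# Checksum note (illustrative): for 'openaipublic' URLs the expected SHA256 prefix is the
# second-to-last path component, e.g. the RN50 URL above yields
# expected_sha256 = 'afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762';
# for 'mlfoundations' release assets it is the short hash suffix of the filename,
# e.g. 'rn50-quickgelu-yfcc15m-455df137.pt' -> '455df137'.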
|
||||||
|
|
||||||
|
|
||||||
|
def has_hf_hub(necessary=False):
|
||||||
|
if not _has_hf_hub and necessary:
|
||||||
|
# if no HF Hub module installed, and it is necessary to continue, raise error
|
||||||
|
raise RuntimeError(
|
||||||
|
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
|
||||||
|
return _has_hf_hub
|
||||||
|
|
||||||
|
|
||||||
|
def download_pretrained_from_hf(
|
||||||
|
model_id: str,
|
||||||
|
filename: str = 'open_clip_pytorch_model.bin',
|
||||||
|
revision=None,
|
||||||
|
cache_dir: Union[str, None] = None,
|
||||||
|
):
|
||||||
|
has_hf_hub(True)
|
||||||
|
cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
|
||||||
|
return cached_file
|
||||||
|
|
||||||
|
|
||||||
|
def download_pretrained(
|
||||||
|
cfg: Dict,
|
||||||
|
force_hf_hub: bool = False,
|
||||||
|
cache_dir: Union[str, None] = None,
|
||||||
|
):
|
||||||
|
target = ''
|
||||||
|
if not cfg:
|
||||||
|
return target
|
||||||
|
|
||||||
|
download_url = cfg.get('url', '')
|
||||||
|
download_hf_hub = cfg.get('hf_hub', '')
|
||||||
|
if download_hf_hub and force_hf_hub:
|
||||||
|
# use HF hub even if url exists
|
||||||
|
download_url = ''
|
||||||
|
|
||||||
|
if download_url:
|
||||||
|
target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
|
||||||
|
elif download_hf_hub:
|
||||||
|
has_hf_hub(True)
|
||||||
|
# we assume the hf_hub entries in pretrained config combine model_id + filename in
|
||||||
|
# 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
|
||||||
|
# use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
|
||||||
|
model_id, filename = os.path.split(download_hf_hub)
|
||||||
|
if filename:
|
||||||
|
target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
|
||||||
|
else:
|
||||||
|
target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
|
||||||
|
|
||||||
|
return target
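# Usage sketch (illustrative): a cfg whose 'hf_hub' entry ends with a trailing slash,
# e.g. {'hf_hub': 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K/'}, yields an empty filename from
# os.path.split, so the default 'open_clip_pytorch_model.bin' is fetched from that repo;
# a cfg with a 'url' (and force_hf_hub=False) falls back to download_pretrained_from_url.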
|
||||||
@@ -0,0 +1,243 @@
|
|||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from tempfile import TemporaryDirectory
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
try:
|
||||||
|
from huggingface_hub import (
|
||||||
|
create_repo,
|
||||||
|
get_hf_file_metadata,
|
||||||
|
hf_hub_download,
|
||||||
|
hf_hub_url,
|
||||||
|
repo_type_and_id_from_hf_id,
|
||||||
|
upload_folder,
|
||||||
|
)
|
||||||
|
from huggingface_hub.utils import EntryNotFoundError
|
||||||
|
_has_hf_hub = True
|
||||||
|
except ImportError:
|
||||||
|
_has_hf_hub = False
|
||||||
|
|
||||||
|
from .factory import create_model_from_pretrained, get_model_config, get_tokenizer
|
||||||
|
from .tokenizer import HFTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def save_config_for_hf(
|
||||||
|
model,
|
||||||
|
config_path: str,
|
||||||
|
model_config: Optional[dict]
|
||||||
|
):
|
||||||
|
preprocess_cfg = {
|
||||||
|
'mean': model.visual.image_mean,
|
||||||
|
'std': model.visual.image_std,
|
||||||
|
}
|
||||||
|
hf_config = {
|
||||||
|
'model_cfg': model_config,
|
||||||
|
'preprocess_cfg': preprocess_cfg,
|
||||||
|
}
|
||||||
|
|
||||||
|
with config_path.open('w') as f:
|
||||||
|
json.dump(hf_config, f, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
def save_for_hf(
|
||||||
|
model,
|
||||||
|
tokenizer: HFTokenizer,
|
||||||
|
model_config: dict,
|
||||||
|
save_directory: str,
|
||||||
|
weights_filename='open_clip_pytorch_model.bin',
|
||||||
|
config_filename='open_clip_config.json',
|
||||||
|
):
|
||||||
|
save_directory = Path(save_directory)
|
||||||
|
save_directory.mkdir(exist_ok=True, parents=True)
|
||||||
|
|
||||||
|
weights_path = save_directory / weights_filename
|
||||||
|
torch.save(model.state_dict(), weights_path)
|
||||||
|
|
||||||
|
tokenizer.save_pretrained(save_directory)
|
||||||
|
|
||||||
|
config_path = save_directory / config_filename
|
||||||
|
save_config_for_hf(model, config_path, model_config=model_config)
|
||||||
|
|
||||||
|
|
||||||
|
def push_to_hf_hub(
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
model_config: Optional[dict],
|
||||||
|
repo_id: str,
|
||||||
|
commit_message: str = 'Add model',
|
||||||
|
token: Optional[str] = None,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
private: bool = False,
|
||||||
|
create_pr: bool = False,
|
||||||
|
model_card: Optional[dict] = None,
|
||||||
|
):
|
||||||
|
if not isinstance(tokenizer, HFTokenizer):
|
||||||
|
# default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14
|
||||||
|
tokenizer = HFTokenizer('openai/clip-vit-large-patch14')
|
||||||
|
|
||||||
|
# Create repo if it doesn't exist yet
|
||||||
|
repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
|
||||||
|
|
||||||
|
# Infer complete repo_id from repo_url
|
||||||
|
# Can be different from the input `repo_id` if repo_owner was implicit
|
||||||
|
_, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
|
||||||
|
repo_id = f"{repo_owner}/{repo_name}"
|
||||||
|
|
||||||
|
# Check if README file already exist in repo
|
||||||
|
try:
|
||||||
|
get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
|
||||||
|
has_readme = True
|
||||||
|
except EntryNotFoundError:
|
||||||
|
has_readme = False
|
||||||
|
|
||||||
|
# Dump model and push to Hub
|
||||||
|
with TemporaryDirectory() as tmpdir:
|
||||||
|
# Save model weights and config.
|
||||||
|
save_for_hf(
|
||||||
|
model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
model_config=model_config,
|
||||||
|
save_directory=tmpdir,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add readme if it does not exist
|
||||||
|
if not has_readme:
|
||||||
|
model_card = model_card or {}
|
||||||
|
model_name = repo_id.split('/')[-1]
|
||||||
|
readme_path = Path(tmpdir) / "README.md"
|
||||||
|
readme_text = generate_readme(model_card, model_name)
|
||||||
|
readme_path.write_text(readme_text)
|
||||||
|
|
||||||
|
# Upload model and return
|
||||||
|
return upload_folder(
|
||||||
|
repo_id=repo_id,
|
||||||
|
folder_path=tmpdir,
|
||||||
|
revision=revision,
|
||||||
|
create_pr=create_pr,
|
||||||
|
commit_message=commit_message,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def push_pretrained_to_hf_hub(
|
||||||
|
model_name,
|
||||||
|
pretrained: str,
|
||||||
|
repo_id: str,
|
||||||
|
image_mean: Optional[Tuple[float, ...]] = None,
|
||||||
|
image_std: Optional[Tuple[float, ...]] = None,
|
||||||
|
commit_message: str = 'Add model',
|
||||||
|
token: Optional[str] = None,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
private: bool = False,
|
||||||
|
create_pr: bool = False,
|
||||||
|
model_card: Optional[dict] = None,
|
||||||
|
):
|
||||||
|
model, preprocess_eval = create_model_from_pretrained(
|
||||||
|
model_name,
|
||||||
|
pretrained=pretrained,
|
||||||
|
image_mean=image_mean,
|
||||||
|
image_std=image_std,
|
||||||
|
)
|
||||||
|
|
||||||
|
model_config = get_model_config(model_name)
|
||||||
|
assert model_config
|
||||||
|
|
||||||
|
tokenizer = get_tokenizer(model_name)
|
||||||
|
|
||||||
|
push_to_hf_hub(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
model_config=model_config,
|
||||||
|
repo_id=repo_id,
|
||||||
|
commit_message=commit_message,
|
||||||
|
token=token,
|
||||||
|
revision=revision,
|
||||||
|
private=private,
|
||||||
|
create_pr=create_pr,
|
||||||
|
model_card=model_card,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_readme(model_card: dict, model_name: str):
|
||||||
|
readme_text = "---\n"
|
||||||
|
readme_text += "tags:\n- zero-shot-image-classification\n- clip\n"
|
||||||
|
readme_text += "library_tag: open_clip\n"
|
||||||
|
readme_text += f"license: {model_card.get('license', 'mit')}\n"
|
||||||
|
if 'details' in model_card and 'Dataset' in model_card['details']:
|
||||||
|
readme_text += 'datasets:\n'
|
||||||
|
readme_text += f"- {model_card['details']['Dataset'].lower()}\n"
|
||||||
|
readme_text += "---\n"
|
||||||
|
readme_text += f"# Model card for {model_name}\n"
|
||||||
|
if 'description' in model_card:
|
||||||
|
readme_text += f"\n{model_card['description']}\n"
|
||||||
|
if 'details' in model_card:
|
||||||
|
readme_text += f"\n## Model Details\n"
|
||||||
|
for k, v in model_card['details'].items():
|
||||||
|
if isinstance(v, (list, tuple)):
|
||||||
|
readme_text += f"- **{k}:**\n"
|
||||||
|
for vi in v:
|
||||||
|
readme_text += f" - {vi}\n"
|
||||||
|
elif isinstance(v, dict):
|
||||||
|
readme_text += f"- **{k}:**\n"
|
||||||
|
for ki, vi in v.items():
|
||||||
|
readme_text += f" - {ki}: {vi}\n"
|
||||||
|
else:
|
||||||
|
readme_text += f"- **{k}:** {v}\n"
|
||||||
|
if 'usage' in model_card:
|
||||||
|
readme_text += f"\n## Model Usage\n"
|
||||||
|
readme_text += model_card['usage']
|
||||||
|
readme_text += '\n'
|
||||||
|
|
||||||
|
if 'comparison' in model_card:
|
||||||
|
readme_text += f"\n## Model Comparison\n"
|
||||||
|
readme_text += model_card['comparison']
|
||||||
|
readme_text += '\n'
|
||||||
|
|
||||||
|
if 'citation' in model_card:
|
||||||
|
readme_text += f"\n## Citation\n"
|
||||||
|
if not isinstance(model_card['citation'], (list, tuple)):
|
||||||
|
citations = [model_card['citation']]
|
||||||
|
else:
|
||||||
|
citations = model_card['citation']
|
||||||
|
for c in citations:
|
||||||
|
readme_text += f"```bibtex\n{c}\n```\n"
|
||||||
|
|
||||||
|
return readme_text
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Push to Hugging Face Hub")
|
||||||
|
parser.add_argument(
|
||||||
|
"--model", type=str, help="Name of the model to use.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pretrained", type=str,
|
||||||
|
help="Use a pretrained CLIP model weights with the specified tag or file path.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--repo-id", type=str,
|
||||||
|
help="Destination HF Hub repo-id ie 'organization/model_id'.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--image-mean', type=float, nargs='+', default=None, metavar='MEAN',
|
||||||
|
help='Override default image mean value of dataset')
|
||||||
|
parser.add_argument(
|
||||||
|
'--image-std', type=float, nargs='+', default=None, metavar='STD',
|
||||||
|
help='Override default image std deviation of dataset')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}')
|
||||||
|
|
||||||
|
# FIXME add support to pass model_card json / template from file via cmd line
|
||||||
|
|
||||||
|
push_pretrained_to_hf_hub(
|
||||||
|
args.model,
|
||||||
|
args.pretrained,
|
||||||
|
args.repo_id,
|
||||||
|
image_mean=args.image_mean, # override image mean/std if trained w/ non defaults
|
||||||
|
image_std=args.image_std,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f'{args.model} saved.')
|
||||||
diffsynth/extensions/ImageQualityMetric/open_clip/timm_model.py
@@ -0,0 +1,127 @@
|
|||||||
|
""" timm model adapter
|
||||||
|
|
||||||
|
Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
try:
|
||||||
|
import timm
|
||||||
|
from timm.models.layers import Mlp, to_2tuple
|
||||||
|
try:
|
||||||
|
# old timm imports < 0.8.1
|
||||||
|
from timm.models.layers.attention_pool2d import RotAttentionPool2d
|
||||||
|
from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
|
||||||
|
except ImportError:
|
||||||
|
# new timm imports >= 0.8.1
|
||||||
|
from timm.layers import RotAttentionPool2d
|
||||||
|
from timm.layers import AttentionPool2d as AbsAttentionPool2d
|
||||||
|
except ImportError:
|
||||||
|
timm = None
|
||||||
|
|
||||||
|
from .utils import freeze_batch_norm_2d
|
||||||
|
|
||||||
|
|
||||||
|
class TimmModel(nn.Module):
|
||||||
|
""" timm model adapter
|
||||||
|
# FIXME this adapter is a work in progress, may change in ways that break weight compat
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name,
|
||||||
|
embed_dim,
|
||||||
|
image_size=224,
|
||||||
|
pool='avg',
|
||||||
|
proj='linear',
|
||||||
|
proj_bias=False,
|
||||||
|
drop=0.,
|
||||||
|
drop_path=None,
|
||||||
|
pretrained=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
if timm is None:
|
||||||
|
raise RuntimeError("Please `pip install timm` to use timm models.")
|
||||||
|
|
||||||
|
self.image_size = to_2tuple(image_size)
|
||||||
|
timm_kwargs = {}
|
||||||
|
if drop_path is not None:
|
||||||
|
timm_kwargs['drop_path_rate'] = drop_path
|
||||||
|
self.trunk = timm.create_model(model_name, pretrained=pretrained, **timm_kwargs)
|
||||||
|
feat_size = self.trunk.default_cfg.get('pool_size', None)
|
||||||
|
feature_ndim = 1 if not feat_size else 2
|
||||||
|
if pool in ('abs_attn', 'rot_attn'):
|
||||||
|
assert feature_ndim == 2
|
||||||
|
# if attn pooling used, remove both classifier and default pool
|
||||||
|
self.trunk.reset_classifier(0, global_pool='')
|
||||||
|
else:
|
||||||
|
# reset global pool if pool config set, otherwise leave as network default
|
||||||
|
reset_kwargs = dict(global_pool=pool) if pool else {}
|
||||||
|
self.trunk.reset_classifier(0, **reset_kwargs)
|
||||||
|
prev_chs = self.trunk.num_features
|
||||||
|
|
||||||
|
head_layers = OrderedDict()
|
||||||
|
if pool == 'abs_attn':
|
||||||
|
head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
|
||||||
|
prev_chs = embed_dim
|
||||||
|
elif pool == 'rot_attn':
|
||||||
|
head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
|
||||||
|
prev_chs = embed_dim
|
||||||
|
else:
|
||||||
|
assert proj, 'projection layer needed if non-attention pooling is used.'
|
||||||
|
|
||||||
|
# NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
|
||||||
|
if proj == 'linear':
|
||||||
|
head_layers['drop'] = nn.Dropout(drop)
|
||||||
|
head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
|
||||||
|
elif proj == 'mlp':
|
||||||
|
head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))
|
||||||
|
|
||||||
|
self.head = nn.Sequential(head_layers)
|
||||||
|
|
||||||
|
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
||||||
|
""" lock modules
|
||||||
|
Args:
|
||||||
|
unlocked_groups (int): leave last n layer groups unlocked (default: 0)
|
||||||
|
"""
|
||||||
|
if not unlocked_groups:
|
||||||
|
# lock full model
|
||||||
|
for param in self.trunk.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
if freeze_bn_stats:
|
||||||
|
freeze_batch_norm_2d(self.trunk)
|
||||||
|
else:
|
||||||
|
# NOTE: partial freeze requires latest timm (master) branch and is subject to change
|
||||||
|
try:
|
||||||
|
# FIXME import here until API stable and in an official release
|
||||||
|
from timm.models.helpers import group_parameters, group_modules
|
||||||
|
except ImportError:
|
||||||
|
raise RuntimeError(
|
||||||
|
'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
|
||||||
|
matcher = self.trunk.group_matcher()
|
||||||
|
gparams = group_parameters(self.trunk, matcher)
|
||||||
|
max_layer_id = max(gparams.keys())
|
||||||
|
max_layer_id = max_layer_id - unlocked_groups
|
||||||
|
for group_idx in range(max_layer_id + 1):
|
||||||
|
group = gparams[group_idx]
|
||||||
|
for param in group:
|
||||||
|
self.trunk.get_parameter(param).requires_grad = False
|
||||||
|
if freeze_bn_stats:
|
||||||
|
gmodules = group_modules(self.trunk, matcher, reverse=True)
|
||||||
|
gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
|
||||||
|
freeze_batch_norm_2d(self.trunk, gmodules)
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def set_grad_checkpointing(self, enable=True):
|
||||||
|
try:
|
||||||
|
self.trunk.set_grad_checkpointing(enable)
|
||||||
|
except Exception:
|
||||||
|
logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.trunk(x)
|
||||||
|
x = self.head(x)
|
||||||
|
return x
|
||||||
diffsynth/extensions/ImageQualityMetric/open_clip/tokenizer.py
@@ -0,0 +1,211 @@
|
|||||||
|
""" CLIP tokenizer
|
||||||
|
|
||||||
|
Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
||||||
|
"""
|
||||||
|
import gzip
|
||||||
|
import html
|
||||||
|
import os
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Union, List
|
||||||
|
|
||||||
|
import ftfy
|
||||||
|
import regex as re
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# https://stackoverflow.com/q/62691279
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def default_bpe():
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
project_root = os.path.abspath(os.path.join(current_dir, '../../../../'))
|
||||||
|
quality_metric_path = os.path.join(project_root, 'models', 'QualityMetric')
|
||||||
|
return os.path.join(quality_metric_path, "bpe_simple_vocab_16e6.txt.gz")
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def bytes_to_unicode():
|
||||||
|
"""
|
||||||
|
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
||||||
|
The reversible bpe codes work on unicode strings.
|
||||||
|
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
||||||
|
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
||||||
|
This is a significant percentage of your normal, say, 32K bpe vocab.
|
||||||
|
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||||
|
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||||
|
"""
|
||||||
|
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
||||||
|
cs = bs[:]
|
||||||
|
n = 0
|
||||||
|
for b in range(2**8):
|
||||||
|
if b not in bs:
|
||||||
|
bs.append(b)
|
||||||
|
cs.append(2**8+n)
|
||||||
|
n += 1
|
||||||
|
cs = [chr(n) for n in cs]
|
||||||
|
return dict(zip(bs, cs))
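# Quick check (illustrative): printable bytes map to themselves, while bytes the BPE would
# otherwise choke on are shifted past 255, e.g.
#   table = bytes_to_unicode()
#   table[ord('A')] == 'A' and table[0] == chr(256)   # byte 0 -> 'Ā'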
|
||||||
|
|
||||||
|
|
||||||
|
def get_pairs(word):
|
||||||
|
"""Return set of symbol pairs in a word.
|
||||||
|
Word is represented as tuple of symbols (symbols being variable-length strings).
|
||||||
|
"""
|
||||||
|
pairs = set()
|
||||||
|
prev_char = word[0]
|
||||||
|
for char in word[1:]:
|
||||||
|
pairs.add((prev_char, char))
|
||||||
|
prev_char = char
|
||||||
|
return pairs
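# Example (illustrative): get_pairs(('h', 'e', 'l', 'l', 'o</w>'))
#   -> {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o</w>')}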
|
||||||
|
|
||||||
|
|
||||||
|
def basic_clean(text):
|
||||||
|
text = ftfy.fix_text(text)
|
||||||
|
text = html.unescape(html.unescape(text))
|
||||||
|
return text.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def whitespace_clean(text):
|
||||||
|
text = re.sub(r'\s+', ' ', text)
|
||||||
|
text = text.strip()
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleTokenizer(object):
|
||||||
|
def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
|
||||||
|
self.byte_encoder = bytes_to_unicode()
|
||||||
|
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
||||||
|
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
|
||||||
|
merges = merges[1:49152-256-2+1]
|
||||||
|
merges = [tuple(merge.split()) for merge in merges]
|
||||||
|
vocab = list(bytes_to_unicode().values())
|
||||||
|
vocab = vocab + [v+'</w>' for v in vocab]
|
||||||
|
for merge in merges:
|
||||||
|
vocab.append(''.join(merge))
|
||||||
|
if not special_tokens:
|
||||||
|
special_tokens = ['<start_of_text>', '<end_of_text>']
|
||||||
|
else:
|
||||||
|
special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
|
||||||
|
vocab.extend(special_tokens)
|
||||||
|
self.encoder = dict(zip(vocab, range(len(vocab))))
|
||||||
|
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||||
|
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||||
|
self.cache = {t:t for t in special_tokens}
|
||||||
|
special = "|".join(special_tokens)
|
||||||
|
self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
|
||||||
|
|
||||||
|
self.vocab_size = len(self.encoder)
|
||||||
|
self.all_special_ids = [self.encoder[t] for t in special_tokens]
|
||||||
|
|
||||||
|
def bpe(self, token):
|
||||||
|
if token in self.cache:
|
||||||
|
return self.cache[token]
|
||||||
|
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
||||||
|
pairs = get_pairs(word)
|
||||||
|
|
||||||
|
if not pairs:
|
||||||
|
return token+'</w>'
|
||||||
|
|
||||||
|
while True:
|
||||||
|
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
||||||
|
if bigram not in self.bpe_ranks:
|
||||||
|
break
|
||||||
|
first, second = bigram
|
||||||
|
new_word = []
|
||||||
|
i = 0
|
||||||
|
while i < len(word):
|
||||||
|
try:
|
||||||
|
j = word.index(first, i)
|
||||||
|
new_word.extend(word[i:j])
|
||||||
|
i = j
|
||||||
|
except ValueError:
|
||||||
|
new_word.extend(word[i:])
|
||||||
|
break
|
||||||
|
|
||||||
|
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
||||||
|
new_word.append(first+second)
|
||||||
|
i += 2
|
||||||
|
else:
|
||||||
|
new_word.append(word[i])
|
||||||
|
i += 1
|
||||||
|
new_word = tuple(new_word)
|
||||||
|
word = new_word
|
||||||
|
if len(word) == 1:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
pairs = get_pairs(word)
|
||||||
|
word = ' '.join(word)
|
||||||
|
self.cache[token] = word
|
||||||
|
return word
|
||||||
|
|
||||||
|
def encode(self, text):
|
||||||
|
bpe_tokens = []
|
||||||
|
text = whitespace_clean(basic_clean(text)).lower()
|
||||||
|
for token in re.findall(self.pat, text):
|
||||||
|
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
||||||
|
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
||||||
|
return bpe_tokens
|
||||||
|
|
||||||
|
def decode(self, tokens):
|
||||||
|
text = ''.join([self.decoder[token] for token in tokens])
|
||||||
|
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
||||||
|
return text
|
||||||
|
|
||||||
|
def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
|
||||||
|
"""
|
||||||
|
Returns the tokenized representation of given input string(s)
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
texts : Union[str, List[str]]
|
||||||
|
An input string or a list of input strings to tokenize
|
||||||
|
context_length : int
|
||||||
|
The context length to use; all CLIP models use 77 as the context length
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
|
||||||
|
"""
|
||||||
|
if isinstance(texts, str):
|
||||||
|
texts = [texts]
|
||||||
|
|
||||||
|
sot_token = self.encoder["<start_of_text>"]
|
||||||
|
eot_token = self.encoder["<end_of_text>"]
|
||||||
|
all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
|
||||||
|
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
||||||
|
|
||||||
|
for i, tokens in enumerate(all_tokens):
|
||||||
|
if len(tokens) > context_length:
|
||||||
|
tokens = tokens[:context_length] # Truncate
|
||||||
|
tokens[-1] = eot_token
|
||||||
|
result[i, :len(tokens)] = torch.tensor(tokens)
|
||||||
|
|
||||||
|
return result
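A minimal usage sketch (illustrative only); it assumes the BPE archive resolved by default_bpe() has been downloaded to models/QualityMetric/bpe_simple_vocab_16e6.txt.gz:

tokenizer = SimpleTokenizer()
ids = tokenizer(["a photo of a cat", "a photo of a dog"], context_length=77)
print(ids.shape)  # torch.Size([2, 77])
print(tokenizer.decode([t for t in ids[0].tolist() if t != 0]))  # strip zero padding before decoding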
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class HFTokenizer:
|
||||||
|
"""HuggingFace tokenizer wrapper"""
|
||||||
|
|
||||||
|
def __init__(self, tokenizer_name: str):
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||||
|
|
||||||
|
def save_pretrained(self, dest):
|
||||||
|
self.tokenizer.save_pretrained(dest)
|
||||||
|
|
||||||
|
def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor:
|
||||||
|
# same cleaning as for default tokenizer, except lowercasing
|
||||||
|
# adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
|
||||||
|
if isinstance(texts, str):
|
||||||
|
texts = [texts]
|
||||||
|
texts = [whitespace_clean(basic_clean(text)) for text in texts]
|
||||||
|
input_ids = self.tokenizer(
|
||||||
|
texts,
|
||||||
|
return_tensors='pt',
|
||||||
|
max_length=context_length,
|
||||||
|
padding='max_length',
|
||||||
|
truncation=True,
|
||||||
|
).input_ids
|
||||||
|
return input_ids
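A hypothetical usage sketch; "roberta-base" below is only a placeholder for whichever Hugging Face checkpoint the caller passes in:

hf_tok = HFTokenizer("roberta-base")
ids = hf_tok(["a photo of a cat"], context_length=77)
print(ids.shape)  # torch.Size([1, 77])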
|
||||||
216
diffsynth/extensions/ImageQualityMetric/open_clip/transform.py
Normal file
@@ -0,0 +1,216 @@
|
|||||||
|
import warnings
|
||||||
|
from dataclasses import dataclass, asdict
|
||||||
|
from typing import Any, Dict, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torchvision.transforms.functional as F
|
||||||
|
from functools import partial
|
||||||
|
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
|
||||||
|
CenterCrop
|
||||||
|
|
||||||
|
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AugmentationCfg:
|
||||||
|
scale: Tuple[float, float] = (0.9, 1.0)
|
||||||
|
ratio: Optional[Tuple[float, float]] = None
|
||||||
|
color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None
|
||||||
|
interpolation: Optional[str] = None
|
||||||
|
re_prob: Optional[float] = None
|
||||||
|
re_count: Optional[int] = None
|
||||||
|
use_timm: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class ResizeMaxSize(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
|
||||||
|
super().__init__()
|
||||||
|
if not isinstance(max_size, int):
|
||||||
|
raise TypeError(f"Size should be int. Got {type(max_size)}")
|
||||||
|
self.max_size = max_size
|
||||||
|
self.interpolation = interpolation
|
||||||
|
self.fn = min if fn == 'min' else max
|
||||||
|
self.fill = fill
|
||||||
|
|
||||||
|
def forward(self, img):
|
||||||
|
if isinstance(img, torch.Tensor):
|
||||||
|
height, width = img.shape[1:]
|
||||||
|
else:
|
||||||
|
width, height = img.size
|
||||||
|
scale = self.max_size / float(max(height, width))
|
||||||
|
if scale != 1.0:
|
||||||
|
new_size = tuple(round(dim * scale) for dim in (height, width))
|
||||||
|
img = F.resize(img, new_size, self.interpolation)
|
||||||
|
pad_h = self.max_size - new_size[0]
|
||||||
|
pad_w = self.max_size - new_size[1]
|
||||||
|
img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
|
||||||
|
return img
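A small sketch (illustrative only) of the longest-side resize plus padding performed above:

from PIL import Image
resize = ResizeMaxSize(224, fill=0)
img = Image.new('RGB', (640, 480))
out = resize(img)
print(out.size)  # (224, 224): the long side is scaled to 224 and the short side is padded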
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_to_rgb_or_rgba(image):
|
||||||
|
if image.mode == 'RGBA':
|
||||||
|
return image
|
||||||
|
else:
|
||||||
|
return image.convert('RGB')
|
||||||
|
|
||||||
|
# def transform_and_split(merged, transform_fn, normalize_fn):
|
||||||
|
# transformed = transform_fn(merged)
|
||||||
|
# crop_img, crop_label = torch.split(transformed, [3,1], dim=0)
|
||||||
|
|
||||||
|
# # crop_img = _convert_to_rgb(crop_img)
|
||||||
|
# crop_img = normalize_fn(ToTensor()(crop_img))
|
||||||
|
# return crop_img, crop_label
|
||||||
|
|
||||||
|
class MaskAwareNormalize(nn.Module):
|
||||||
|
def __init__(self, mean, std):
|
||||||
|
super().__init__()
|
||||||
|
self.normalize = Normalize(mean=mean, std=std)
|
||||||
|
|
||||||
|
def forward(self, tensor):
|
||||||
|
if tensor.shape[0] == 4:
|
||||||
|
return torch.cat([self.normalize(tensor[:3]), tensor[3:]], dim=0)
|
||||||
|
else:
|
||||||
|
return self.normalize(tensor)
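A quick sketch (illustrative only) of the four-channel behaviour: the RGB channels are normalized while a trailing mask channel passes through unchanged:

import torch
norm = MaskAwareNormalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
rgba = torch.cat([torch.rand(3, 8, 8), torch.ones(1, 8, 8)], dim=0)
out = norm(rgba)
assert torch.equal(out[3:], rgba[3:])  # mask channel untouched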
|
||||||
|
|
||||||
|
def image_transform(
|
||||||
|
image_size: int,
|
||||||
|
is_train: bool,
|
||||||
|
mean: Optional[Tuple[float, ...]] = None,
|
||||||
|
std: Optional[Tuple[float, ...]] = None,
|
||||||
|
resize_longest_max: bool = False,
|
||||||
|
fill_color: int = 0,
|
||||||
|
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
|
||||||
|
):
|
||||||
|
mean = mean or OPENAI_DATASET_MEAN
|
||||||
|
if not isinstance(mean, (list, tuple)):
|
||||||
|
mean = (mean,) * 3
|
||||||
|
|
||||||
|
std = std or OPENAI_DATASET_STD
|
||||||
|
if not isinstance(std, (list, tuple)):
|
||||||
|
std = (std,) * 3
|
||||||
|
|
||||||
|
if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
|
||||||
|
# for square size, pass size as int so that Resize() uses aspect preserving shortest edge
|
||||||
|
image_size = image_size[0]
|
||||||
|
|
||||||
|
if isinstance(aug_cfg, dict):
|
||||||
|
aug_cfg = AugmentationCfg(**aug_cfg)
|
||||||
|
else:
|
||||||
|
aug_cfg = aug_cfg or AugmentationCfg()
|
||||||
|
normalize = MaskAwareNormalize(mean=mean, std=std)
|
||||||
|
if is_train:
|
||||||
|
aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
|
||||||
|
use_timm = aug_cfg_dict.pop('use_timm', False)
|
||||||
|
if use_timm:
|
||||||
|
assert False, "not tested for augmentation with mask"
|
||||||
|
from timm.data import create_transform # timm can still be optional
|
||||||
|
if isinstance(image_size, (tuple, list)):
|
||||||
|
assert len(image_size) >= 2
|
||||||
|
input_size = (3,) + image_size[-2:]
|
||||||
|
else:
|
||||||
|
input_size = (3, image_size, image_size)
|
||||||
|
# by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time
|
||||||
|
aug_cfg_dict.setdefault('interpolation', 'random')
|
||||||
|
aug_cfg_dict.setdefault('color_jitter', None) # disable by default
|
||||||
|
train_transform = create_transform(
|
||||||
|
input_size=input_size,
|
||||||
|
is_training=True,
|
||||||
|
hflip=0.,
|
||||||
|
mean=mean,
|
||||||
|
std=std,
|
||||||
|
re_mode='pixel',
|
||||||
|
**aug_cfg_dict,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
train_transform = Compose([
|
||||||
|
_convert_to_rgb_or_rgba,
|
||||||
|
ToTensor(),
|
||||||
|
RandomResizedCrop(
|
||||||
|
image_size,
|
||||||
|
scale=aug_cfg_dict.pop('scale'),
|
||||||
|
interpolation=InterpolationMode.BICUBIC,
|
||||||
|
),
|
||||||
|
normalize,
|
||||||
|
])
|
||||||
|
if aug_cfg_dict:
|
||||||
|
warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')
|
||||||
|
return train_transform
|
||||||
|
else:
|
||||||
|
transforms = [
|
||||||
|
_convert_to_rgb_or_rgba,
|
||||||
|
ToTensor(),
|
||||||
|
]
|
||||||
|
if resize_longest_max:
|
||||||
|
transforms.extend([
|
||||||
|
ResizeMaxSize(image_size, fill=fill_color)
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
transforms.extend([
|
||||||
|
Resize(image_size, interpolation=InterpolationMode.BICUBIC),
|
||||||
|
CenterCrop(image_size),
|
||||||
|
])
|
||||||
|
transforms.extend([
|
||||||
|
normalize,
|
||||||
|
])
|
||||||
|
return Compose(transforms)
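A minimal eval-time usage sketch (illustrative only), relying on the OPENAI_DATASET_MEAN/STD defaults imported above:

from PIL import Image
preprocess = image_transform(image_size=224, is_train=False)
img = Image.new('RGB', (512, 384))
x = preprocess(img)
print(x.shape)  # torch.Size([3, 224, 224])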
|
||||||
|
|
||||||
|
|
||||||
|
# def image_transform_region(
|
||||||
|
# image_size: int,
|
||||||
|
# is_train: bool,
|
||||||
|
# mean: Optional[Tuple[float, ...]] = None,
|
||||||
|
# std: Optional[Tuple[float, ...]] = None,
|
||||||
|
# resize_longest_max: bool = False,
|
||||||
|
# fill_color: int = 0,
|
||||||
|
# aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
|
||||||
|
# ):
|
||||||
|
# mean = mean or OPENAI_DATASET_MEAN
|
||||||
|
# if not isinstance(mean, (list, tuple)):
|
||||||
|
# mean = (mean,) * 3
|
||||||
|
|
||||||
|
# std = std or OPENAI_DATASET_STD
|
||||||
|
# if not isinstance(std, (list, tuple)):
|
||||||
|
# std = (std,) * 3
|
||||||
|
|
||||||
|
# if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
|
||||||
|
# # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
|
||||||
|
# image_size = image_size[0]
|
||||||
|
|
||||||
|
# if isinstance(aug_cfg, dict):
|
||||||
|
# aug_cfg = AugmentationCfg(**aug_cfg)
|
||||||
|
# else:
|
||||||
|
# aug_cfg = aug_cfg or AugmentationCfg()
|
||||||
|
# normalize = Normalize(mean=mean, std=std)
|
||||||
|
# if is_train:
|
||||||
|
# aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
|
||||||
|
|
||||||
|
# transform = Compose([
|
||||||
|
# RandomResizedCrop(
|
||||||
|
# image_size,
|
||||||
|
# scale=aug_cfg_dict.pop('scale'),
|
||||||
|
# interpolation=InterpolationMode.BICUBIC,
|
||||||
|
# ),
|
||||||
|
# ])
|
||||||
|
# train_transform = Compose([
|
||||||
|
# partial(transform_and_split, transform_fn=transform,normalize_fn=normalize)
|
||||||
|
# ])
|
||||||
|
# return train_transform
|
||||||
|
# else:
|
||||||
|
# if resize_longest_max:
|
||||||
|
# transform = [
|
||||||
|
# ResizeMaxSize(image_size, fill=fill_color)
|
||||||
|
# ]
|
||||||
|
# val_transform = Compose([
|
||||||
|
# partial(transform_and_split, transform_fn=transform,normalize_fn=normalize),
|
||||||
|
# ])
|
||||||
|
# else:
|
||||||
|
# transform = [
|
||||||
|
# Resize(image_size, interpolation=InterpolationMode.BICUBIC),
|
||||||
|
# CenterCrop(image_size),
|
||||||
|
# ]
|
||||||
|
# val_transform = Compose([
|
||||||
|
# partial(transform_and_split, transform_fn=transform,normalize_fn=normalize),
|
||||||
|
# ])
|
||||||
|
# return val_transform
|
||||||
727
diffsynth/extensions/ImageQualityMetric/open_clip/transformer.py
Normal file
@@ -0,0 +1,727 @@
|
|||||||
|
from collections import OrderedDict
|
||||||
|
import math
|
||||||
|
from typing import Callable, Optional, Sequence, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from torch.utils.checkpoint import checkpoint
|
||||||
|
|
||||||
|
from .utils import to_2tuple
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNormFp32(nn.LayerNorm):
|
||||||
|
"""Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
orig_type = x.dtype
|
||||||
|
x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)
|
||||||
|
return x.to(orig_type)
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm(nn.LayerNorm):
|
||||||
|
"""Subclass torch's LayerNorm (with cast back to input dtype)."""
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
orig_type = x.dtype
|
||||||
|
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||||
|
return x.to(orig_type)
|
||||||
|
|
||||||
|
|
||||||
|
class QuickGELU(nn.Module):
|
||||||
|
# NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
return x * torch.sigmoid(1.702 * x)
|
||||||
|
|
||||||
|
|
||||||
|
class LayerScale(nn.Module):
|
||||||
|
def __init__(self, dim, init_values=1e-5, inplace=False):
|
||||||
|
super().__init__()
|
||||||
|
self.inplace = inplace
|
||||||
|
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
||||||
|
|
||||||
|
|
||||||
|
class PatchDropout(nn.Module):
|
||||||
|
"""
|
||||||
|
https://arxiv.org/abs/2212.00794
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, prob, exclude_first_token=True):
|
||||||
|
super().__init__()
|
||||||
|
assert 0 <= prob < 1.
|
||||||
|
self.prob = prob
|
||||||
|
self.exclude_first_token = exclude_first_token # exclude CLS token
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if not self.training or self.prob == 0.:
|
||||||
|
return x
|
||||||
|
|
||||||
|
if self.exclude_first_token:
|
||||||
|
cls_tokens, x = x[:, :1], x[:, 1:]
|
||||||
|
else:
|
||||||
|
cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
|
||||||
|
|
||||||
|
batch = x.size()[0]
|
||||||
|
num_tokens = x.size()[1]
|
||||||
|
|
||||||
|
batch_indices = torch.arange(batch)
|
||||||
|
batch_indices = batch_indices[..., None]
|
||||||
|
|
||||||
|
keep_prob = 1 - self.prob
|
||||||
|
num_patches_keep = max(1, int(num_tokens * keep_prob))
|
||||||
|
|
||||||
|
rand = torch.randn(batch, num_tokens)
|
||||||
|
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
|
||||||
|
|
||||||
|
x = x[batch_indices, patch_indices_keep]
|
||||||
|
|
||||||
|
if self.exclude_first_token:
|
||||||
|
x = torch.cat((cls_tokens, x), dim=1)
|
||||||
|
|
||||||
|
return x
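A small sketch (illustrative only) of PatchDropout behaviour: the CLS token is kept and roughly prob of the patch tokens are dropped, but only in training mode:

import torch
drop = PatchDropout(prob=0.5)
x = torch.randn(2, 1 + 196, 768)  # [batch, cls + patches, width]
drop.train()
print(drop(x).shape)  # torch.Size([2, 99, 768]): 98 of 196 patch tokens kept, plus CLS
drop.eval()
print(drop(x).shape)  # torch.Size([2, 197, 768]): identity at eval time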
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
num_heads=8,
|
||||||
|
qkv_bias=True,
|
||||||
|
scaled_cosine=False,
|
||||||
|
scale_heads=False,
|
||||||
|
logit_scale_max=math.log(1. / 0.01),
|
||||||
|
attn_drop=0.,
|
||||||
|
proj_drop=0.
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.scaled_cosine = scaled_cosine
|
||||||
|
self.scale_heads = scale_heads
|
||||||
|
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = dim // num_heads
|
||||||
|
self.scale = self.head_dim ** -0.5
|
||||||
|
self.logit_scale_max = logit_scale_max
|
||||||
|
|
||||||
|
# keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
|
||||||
|
self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
|
||||||
|
if qkv_bias:
|
||||||
|
self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
|
||||||
|
else:
|
||||||
|
self.in_proj_bias = None
|
||||||
|
|
||||||
|
if self.scaled_cosine:
|
||||||
|
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
|
||||||
|
else:
|
||||||
|
self.logit_scale = None
|
||||||
|
self.attn_drop = nn.Dropout(attn_drop)
|
||||||
|
if self.scale_heads:
|
||||||
|
self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
|
||||||
|
else:
|
||||||
|
self.head_scale = None
|
||||||
|
self.out_proj = nn.Linear(dim, dim)
|
||||||
|
self.out_drop = nn.Dropout(proj_drop)
|
||||||
|
|
||||||
|
def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
|
||||||
|
L, N, C = x.shape
|
||||||
|
q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
|
||||||
|
q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
||||||
|
k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
||||||
|
v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
||||||
|
|
||||||
|
if self.logit_scale is not None:
|
||||||
|
attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
|
||||||
|
logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
|
||||||
|
attn = attn.view(N, self.num_heads, L, L) * logit_scale
|
||||||
|
attn = attn.view(-1, L, L)
|
||||||
|
else:
|
||||||
|
q = q * self.scale
|
||||||
|
attn = torch.bmm(q, k.transpose(-1, -2))
|
||||||
|
|
||||||
|
if attn_mask is not None:
|
||||||
|
if attn_mask.dtype == torch.bool:
|
||||||
|
new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
|
||||||
|
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
|
||||||
|
attn_mask = new_attn_mask
|
||||||
|
attn += attn_mask
|
||||||
|
|
||||||
|
attn = attn.softmax(dim=-1)
|
||||||
|
attn = self.attn_drop(attn)
|
||||||
|
|
||||||
|
x = torch.bmm(attn, v)
|
||||||
|
if self.head_scale is not None:
|
||||||
|
x = x.view(N, self.num_heads, L, C) * self.head_scale
|
||||||
|
x = x.view(-1, L, C)
|
||||||
|
x = x.transpose(0, 1).reshape(L, N, C)
|
||||||
|
x = self.out_proj(x)
|
||||||
|
x = self.out_drop(x)
|
||||||
|
return x
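A shape-level sketch (illustrative only); this Attention block expects (seq_len, batch, dim) inputs rather than the batch-first convention:

import torch
attn = Attention(dim=512, num_heads=8, scaled_cosine=True)
x = torch.randn(50, 2, 512)  # 50 tokens, batch of 2
out = attn(x)
print(out.shape)  # torch.Size([50, 2, 512])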
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionalPooler(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model: int,
|
||||||
|
context_dim: int,
|
||||||
|
n_head: int = 8,
|
||||||
|
n_queries: int = 256,
|
||||||
|
norm_layer: Callable = LayerNorm
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.query = nn.Parameter(torch.randn(n_queries, d_model))
|
||||||
|
self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim)
|
||||||
|
self.ln_q = norm_layer(d_model)
|
||||||
|
self.ln_k = norm_layer(context_dim)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND
|
||||||
|
N = x.shape[1]
|
||||||
|
q = self.ln_q(self.query)
|
||||||
|
out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0]
|
||||||
|
return out.permute(1, 0, 2) # LND -> NLD
|
||||||
|
|
||||||
|
def _repeat(self, query, N: int):
|
||||||
|
return query.unsqueeze(1).repeat(1, N, 1)
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualAttentionBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model: int,
|
||||||
|
n_head: int,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
ls_init_value: float = None,
|
||||||
|
act_layer: Callable = nn.GELU,
|
||||||
|
norm_layer: Callable = LayerNorm,
|
||||||
|
is_cross_attention: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.ln_1 = norm_layer(d_model)
|
||||||
|
self.attn = nn.MultiheadAttention(d_model, n_head)
|
||||||
|
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
||||||
|
if is_cross_attention:
|
||||||
|
self.ln_1_kv = norm_layer(d_model)
|
||||||
|
|
||||||
|
self.ln_2 = norm_layer(d_model)
|
||||||
|
mlp_width = int(d_model * mlp_ratio)
|
||||||
|
self.mlp = nn.Sequential(OrderedDict([
|
||||||
|
("c_fc", nn.Linear(d_model, mlp_width)),
|
||||||
|
("gelu", act_layer()),
|
||||||
|
("c_proj", nn.Linear(mlp_width, d_model))
|
||||||
|
]))
|
||||||
|
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
||||||
|
|
||||||
|
def attention(
|
||||||
|
self,
|
||||||
|
q_x: torch.Tensor,
|
||||||
|
k_x: Optional[torch.Tensor] = None,
|
||||||
|
v_x: Optional[torch.Tensor] = None,
|
||||||
|
attn_mask: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
k_x = k_x if k_x is not None else q_x
|
||||||
|
v_x = v_x if v_x is not None else q_x
|
||||||
|
|
||||||
|
attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
|
||||||
|
return self.attn(
|
||||||
|
q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
q_x: torch.Tensor,
|
||||||
|
k_x: Optional[torch.Tensor] = None,
|
||||||
|
v_x: Optional[torch.Tensor] = None,
|
||||||
|
attn_mask: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
|
||||||
|
v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
|
||||||
|
|
||||||
|
x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))
|
||||||
|
x = x + self.ls_2(self.mlp(self.ln_2(x)))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CustomResidualAttentionBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model: int,
|
||||||
|
n_head: int,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
ls_init_value: float = None,
|
||||||
|
act_layer: Callable = nn.GELU,
|
||||||
|
norm_layer: Callable = LayerNorm,
|
||||||
|
scale_cosine_attn: bool = False,
|
||||||
|
scale_heads: bool = False,
|
||||||
|
scale_attn: bool = False,
|
||||||
|
scale_fc: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.ln_1 = norm_layer(d_model)
|
||||||
|
self.attn = Attention(
|
||||||
|
d_model, n_head,
|
||||||
|
scaled_cosine=scale_cosine_attn,
|
||||||
|
scale_heads=scale_heads,
|
||||||
|
)
|
||||||
|
self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
|
||||||
|
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
||||||
|
|
||||||
|
self.ln_2 = norm_layer(d_model)
|
||||||
|
mlp_width = int(d_model * mlp_ratio)
|
||||||
|
self.mlp = nn.Sequential(OrderedDict([
|
||||||
|
("c_fc", nn.Linear(d_model, mlp_width)),
|
||||||
|
('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),
|
||||||
|
("gelu", act_layer()),
|
||||||
|
("c_proj", nn.Linear(mlp_width, d_model))
|
||||||
|
]))
|
||||||
|
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
||||||
|
x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask)))
|
||||||
|
x = x + self.ls_2(self.mlp(self.ln_2(x)))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Transformer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
width: int,
|
||||||
|
layers: int,
|
||||||
|
heads: int,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
ls_init_value: float = None,
|
||||||
|
act_layer: Callable = nn.GELU,
|
||||||
|
norm_layer: Callable = LayerNorm,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.width = width
|
||||||
|
self.layers = layers
|
||||||
|
self.grad_checkpointing = False
|
||||||
|
|
||||||
|
self.resblocks = nn.ModuleList([
|
||||||
|
ResidualAttentionBlock(
|
||||||
|
width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer)
|
||||||
|
for _ in range(layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
def get_cast_dtype(self) -> torch.dtype:
|
||||||
|
return self.resblocks[0].mlp.c_fc.weight.dtype
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
||||||
|
for r in self.resblocks:
|
||||||
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||||
|
# TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
|
||||||
|
x = checkpoint(r, x, None, None, attn_mask)
|
||||||
|
else:
|
||||||
|
x = r(x, attn_mask=attn_mask)
|
||||||
|
return x
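A minimal sketch (illustrative only) of toggling activation checkpointing on this Transformer to trade compute for memory during training:

import torch
encoder = Transformer(width=512, layers=4, heads=8)
encoder.grad_checkpointing = True  # recompute block activations in the backward pass
x = torch.randn(77, 2, 512, requires_grad=True)  # (seq_len, batch, width)
y = encoder(x)
y.sum().backward()
print(x.grad.shape)  # torch.Size([77, 2, 512])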
|
||||||
|
|
||||||
|
|
||||||
|
class VisionTransformer(nn.Module):
|
||||||
|
output_tokens: torch.jit.Final[bool]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
image_size: int,
|
||||||
|
patch_size: int,
|
||||||
|
width: int,
|
||||||
|
layers: int,
|
||||||
|
heads: int,
|
||||||
|
mlp_ratio: float,
|
||||||
|
ls_init_value: float = None,
|
||||||
|
global_average_pool: bool = False,
|
||||||
|
attentional_pool: bool = False,
|
||||||
|
n_queries: int = 256,
|
||||||
|
attn_pooler_heads: int = 8,
|
||||||
|
output_dim: int = 512,
|
||||||
|
patch_dropout: float = 0.,
|
||||||
|
input_patchnorm: bool = False,
|
||||||
|
act_layer: Callable = nn.GELU,
|
||||||
|
norm_layer: Callable = LayerNorm,
|
||||||
|
output_tokens: bool = False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.output_tokens = output_tokens
|
||||||
|
image_height, image_width = self.image_size = to_2tuple(image_size)
|
||||||
|
patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
|
||||||
|
self.grid_size = (image_height // patch_height, image_width // patch_width)
|
||||||
|
self.output_dim = output_dim
|
||||||
|
|
||||||
|
# whether to layernorm each patch, as done in dual patchnorm paper - https://arxiv.org/abs/2302.01327v1
|
||||||
|
self.input_patchnorm = input_patchnorm
|
||||||
|
|
||||||
|
if input_patchnorm:
|
||||||
|
patch_input_dim = patch_height * patch_width * 3
|
||||||
|
self.patchnorm_pre_ln = LayerNorm(patch_input_dim)
|
||||||
|
self.conv1 = nn.Linear(patch_input_dim, width)
|
||||||
|
else:
|
||||||
|
self.patchnorm_pre_ln = nn.Identity()
|
||||||
|
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
||||||
|
|
||||||
|
# class embeddings and positional embeddings
|
||||||
|
scale = width ** -0.5
|
||||||
|
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
||||||
|
self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
|
||||||
|
|
||||||
|
# setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
|
||||||
|
self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
|
||||||
|
|
||||||
|
self.ln_pre = norm_layer(width)
|
||||||
|
self.transformer = Transformer(
|
||||||
|
width,
|
||||||
|
layers,
|
||||||
|
heads,
|
||||||
|
mlp_ratio,
|
||||||
|
ls_init_value=ls_init_value,
|
||||||
|
act_layer=act_layer,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.global_average_pool = global_average_pool
|
||||||
|
if attentional_pool:
|
||||||
|
self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries)
|
||||||
|
self.ln_post = norm_layer(output_dim)
|
||||||
|
self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim))
|
||||||
|
else:
|
||||||
|
self.attn_pool = None
|
||||||
|
self.ln_post = norm_layer(width)
|
||||||
|
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
||||||
|
|
||||||
|
self.init_parameters()
|
||||||
|
|
||||||
|
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
||||||
|
for param in self.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
if unlocked_groups != 0:
|
||||||
|
groups = [
|
||||||
|
[
|
||||||
|
self.conv1,
|
||||||
|
self.class_embedding,
|
||||||
|
self.positional_embedding,
|
||||||
|
self.ln_pre,
|
||||||
|
],
|
||||||
|
*self.transformer.resblocks[:-1],
|
||||||
|
[
|
||||||
|
self.transformer.resblocks[-1],
|
||||||
|
self.ln_post,
|
||||||
|
],
|
||||||
|
self.proj,
|
||||||
|
]
|
||||||
|
|
||||||
|
def _unlock(x):
|
||||||
|
if isinstance(x, Sequence):
|
||||||
|
for g in x:
|
||||||
|
_unlock(g)
|
||||||
|
else:
|
||||||
|
if isinstance(x, torch.nn.Parameter):
|
||||||
|
x.requires_grad = True
|
||||||
|
else:
|
||||||
|
for p in x.parameters():
|
||||||
|
p.requires_grad = True
|
||||||
|
|
||||||
|
_unlock(groups[-unlocked_groups:])
|
||||||
|
|
||||||
|
def init_parameters(self):
|
||||||
|
# FIXME OpenAI CLIP did not define an init for the VisualTransformer
|
||||||
|
# TODO experiment if default PyTorch init, below, or alternate init is best.
|
||||||
|
|
||||||
|
# nn.init.normal_(self.class_embedding, std=self.scale)
|
||||||
|
# nn.init.normal_(self.positional_embedding, std=self.scale)
|
||||||
|
#
|
||||||
|
# proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
||||||
|
# attn_std = self.transformer.width ** -0.5
|
||||||
|
# fc_std = (2 * self.transformer.width) ** -0.5
|
||||||
|
# for block in self.transformer.resblocks:
|
||||||
|
# nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
||||||
|
# nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
||||||
|
# nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
||||||
|
# nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
||||||
|
#
|
||||||
|
# if self.text_projection is not None:
|
||||||
|
# nn.init.normal_(self.text_projection, std=self.scale)
|
||||||
|
pass
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def set_grad_checkpointing(self, enable=True):
|
||||||
|
self.transformer.grad_checkpointing = enable
|
||||||
|
|
||||||
|
def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if self.global_average_pool:
|
||||||
|
return x.mean(dim=1), x
|
||||||
|
else:
|
||||||
|
return x[:, 0], x[:, 1:]
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, skip_pool: bool = False):
|
||||||
|
|
||||||
|
# to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
|
||||||
|
if self.input_patchnorm:
|
||||||
|
# einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
|
||||||
|
x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1], self.patch_size[1])
|
||||||
|
x = x.permute(0, 2, 4, 1, 3, 5)
|
||||||
|
x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1)
|
||||||
|
x = self.patchnorm_pre_ln(x)
|
||||||
|
x = self.conv1(x)
|
||||||
|
else:
|
||||||
|
x = self.conv1(x) # shape = [*, width, grid, grid]
|
||||||
|
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
||||||
|
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
||||||
|
|
||||||
|
# class embeddings and positional embeddings
|
||||||
|
x = torch.cat(
|
||||||
|
[self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
|
||||||
|
x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
||||||
|
x = x + self.positional_embedding.to(x.dtype)
|
||||||
|
|
||||||
|
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
|
||||||
|
x = self.patch_dropout(x)
|
||||||
|
x = self.ln_pre(x)
|
||||||
|
|
||||||
|
x = x.permute(1, 0, 2) # NLD -> LND
|
||||||
|
x = self.transformer(x)
|
||||||
|
x = x.permute(1, 0, 2) # LND -> NLD
|
||||||
|
|
||||||
|
if skip_pool:
|
||||||
|
return x
|
||||||
|
|
||||||
|
if self.attn_pool is not None:
|
||||||
|
x = self.attn_pool(x)
|
||||||
|
x = self.ln_post(x)
|
||||||
|
pooled, tokens = self._global_pool(x)
|
||||||
|
else:
|
||||||
|
pooled, tokens = self._global_pool(x)
|
||||||
|
pooled = self.ln_post(pooled)
|
||||||
|
|
||||||
|
if self.proj is not None:
|
||||||
|
pooled = pooled @ self.proj
|
||||||
|
|
||||||
|
if self.output_tokens:
|
||||||
|
return pooled, tokens
|
||||||
|
|
||||||
|
return pooled
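A tiny forward-pass sketch (illustrative only; the hyperparameters below are made up and are not the ones used by the metric):

import torch
vit = VisionTransformer(image_size=224, patch_size=32, width=192, layers=2, heads=3, mlp_ratio=4.0, output_dim=128, output_tokens=True)
images = torch.randn(2, 3, 224, 224)
pooled, tokens = vit(images)
print(pooled.shape, tokens.shape)  # torch.Size([2, 128]) torch.Size([2, 49, 192])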
|
||||||
|
|
||||||
|
|
||||||
|
class TextTransformer(nn.Module):
|
||||||
|
output_tokens: torch.jit.Final[bool]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
context_length: int = 77,
|
||||||
|
vocab_size: int = 49408,
|
||||||
|
width: int = 512,
|
||||||
|
heads: int = 8,
|
||||||
|
layers: int = 12,
|
||||||
|
ls_init_value: float = None,
|
||||||
|
output_dim: int = 512,
|
||||||
|
act_layer: Callable = nn.GELU,
|
||||||
|
norm_layer: Callable = LayerNorm,
|
||||||
|
embed_cls: bool = False,
|
||||||
|
pad_id: int = 0,
|
||||||
|
output_tokens: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.output_tokens = output_tokens
|
||||||
|
self.num_pos = self.context_length = context_length
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.width = width
|
||||||
|
self.output_dim = output_dim
|
||||||
|
self.heads = heads
|
||||||
|
self.pad_id = pad_id
|
||||||
|
|
||||||
|
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
|
||||||
|
|
||||||
|
if embed_cls:
|
||||||
|
self.cls_emb = nn.Parameter(torch.empty(width))
|
||||||
|
self.num_pos += 1
|
||||||
|
else:
|
||||||
|
self.cls_emb = None
|
||||||
|
|
||||||
|
self.token_embedding = nn.Embedding(vocab_size, width)
|
||||||
|
self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
|
||||||
|
self.transformer = Transformer(
|
||||||
|
width=width,
|
||||||
|
layers=layers,
|
||||||
|
heads=heads,
|
||||||
|
ls_init_value=ls_init_value,
|
||||||
|
act_layer=act_layer,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
)
|
||||||
|
self.ln_final = norm_layer(width)
|
||||||
|
|
||||||
|
self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
|
||||||
|
|
||||||
|
self.init_parameters()
|
||||||
|
|
||||||
|
def init_parameters(self):
|
||||||
|
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
||||||
|
nn.init.normal_(self.positional_embedding, std=0.01)
|
||||||
|
if self.cls_emb is not None:
|
||||||
|
nn.init.normal_(self.cls_emb, std=0.01)
|
||||||
|
|
||||||
|
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
||||||
|
attn_std = self.transformer.width ** -0.5
|
||||||
|
fc_std = (2 * self.transformer.width) ** -0.5
|
||||||
|
for block in self.transformer.resblocks:
|
||||||
|
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
||||||
|
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
||||||
|
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
||||||
|
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
||||||
|
|
||||||
|
if self.text_projection is not None:
|
||||||
|
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def set_grad_checkpointing(self, enable=True):
|
||||||
|
self.transformer.grad_checkpointing = enable
|
||||||
|
|
||||||
|
def build_attention_mask(self):
|
||||||
|
# lazily create causal attention mask, with full attention between the tokens
|
||||||
|
# pytorch uses additive attention mask; fill with -inf
|
||||||
|
mask = torch.empty(self.num_pos, self.num_pos)
|
||||||
|
mask.fill_(float("-inf"))
|
||||||
|
mask.triu_(1) # zero out the lower diagonal
|
||||||
|
return mask
|
||||||
|
|
||||||
|
def build_cls_mask(self, text, cast_dtype: torch.dtype):
|
||||||
|
cls_mask = (text != self.pad_id).unsqueeze(1)
|
||||||
|
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)
|
||||||
|
additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
|
||||||
|
additive_mask.fill_(0)
|
||||||
|
additive_mask.masked_fill_(~cls_mask, float("-inf"))
|
||||||
|
additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
|
||||||
|
return additive_mask
|
||||||
|
|
||||||
|
def _repeat(self, t, N: int):
|
||||||
|
return t.reshape(1, 1, -1).repeat(N, 1, 1)
|
||||||
|
|
||||||
|
def forward(self, text):
|
||||||
|
cast_dtype = self.transformer.get_cast_dtype()
|
||||||
|
seq_len = text.shape[1]
|
||||||
|
|
||||||
|
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
|
||||||
|
attn_mask = self.attn_mask
|
||||||
|
if self.cls_emb is not None:
|
||||||
|
seq_len += 1
|
||||||
|
x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1)
|
||||||
|
cls_mask = self.build_cls_mask(text, cast_dtype)
|
||||||
|
attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len]
|
||||||
|
|
||||||
|
x = x + self.positional_embedding[:seq_len].to(cast_dtype)
|
||||||
|
x = x.permute(1, 0, 2) # NLD -> LND
|
||||||
|
x = self.transformer(x, attn_mask=attn_mask)
|
||||||
|
x = x.permute(1, 0, 2) # LND -> NLD
|
||||||
|
|
||||||
|
# x.shape = [batch_size, n_ctx, transformer.width]
|
||||||
|
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||||
|
if self.cls_emb is not None:
|
||||||
|
pooled, tokens = x[:, -1], x[:, :-1]
|
||||||
|
pooled = self.ln_final(pooled)
|
||||||
|
else:
|
||||||
|
x = self.ln_final(x)
|
||||||
|
pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
|
||||||
|
|
||||||
|
if self.text_projection is not None:
|
||||||
|
pooled = pooled @ self.text_projection
|
||||||
|
|
||||||
|
if self.output_tokens:
|
||||||
|
return pooled, tokens
|
||||||
|
|
||||||
|
return pooled
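A minimal sketch (illustrative only) of the text tower: with embed_cls left off, the pooled feature is taken at the position of the highest token id (the EOT token):

import torch
text_model = TextTransformer(context_length=77, vocab_size=49408, width=512, heads=8, layers=2, output_dim=512)
tokens = torch.zeros(2, 77, dtype=torch.long)
tokens[:, 0] = 49406  # <start_of_text>
tokens[:, 1] = 49407  # <end_of_text>, the highest id, so it becomes the pooled position
features = text_model(tokens)
print(features.shape)  # torch.Size([2, 512])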
|
||||||
|
|
||||||
|
|
||||||
|
class MultimodalTransformer(Transformer):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
width: int,
|
||||||
|
layers: int,
|
||||||
|
heads: int,
|
||||||
|
context_length: int = 77,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
ls_init_value: float = None,
|
||||||
|
act_layer: Callable = nn.GELU,
|
||||||
|
norm_layer: Callable = LayerNorm,
|
||||||
|
output_dim: int = 512,
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
width=width,
|
||||||
|
layers=layers,
|
||||||
|
heads=heads,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
ls_init_value=ls_init_value,
|
||||||
|
act_layer=act_layer,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
)
|
||||||
|
self.context_length = context_length
|
||||||
|
self.cross_attn = nn.ModuleList([
|
||||||
|
ResidualAttentionBlock(
|
||||||
|
width,
|
||||||
|
heads,
|
||||||
|
mlp_ratio,
|
||||||
|
ls_init_value=ls_init_value,
|
||||||
|
act_layer=act_layer,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
is_cross_attention=True,
|
||||||
|
)
|
||||||
|
for _ in range(layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
|
||||||
|
|
||||||
|
self.ln_final = norm_layer(width)
|
||||||
|
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
|
||||||
|
|
||||||
|
def init_parameters(self):
|
||||||
|
proj_std = (self.width ** -0.5) * ((2 * self.layers) ** -0.5)
|
||||||
|
attn_std = self.width ** -0.5
|
||||||
|
fc_std = (2 * self.width) ** -0.5
|
||||||
|
for block in self.resblocks:
|
||||||
|
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
||||||
|
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
||||||
|
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
||||||
|
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
||||||
|
for block in self.cross_attn:
|
||||||
|
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
||||||
|
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
||||||
|
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
||||||
|
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
||||||
|
|
||||||
|
if self.text_projection is not None:
|
||||||
|
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
||||||
|
|
||||||
|
def build_attention_mask(self):
|
||||||
|
# lazily create causal attention mask, with full attention between the tokens
|
||||||
|
# pytorch uses additive attention mask; fill with -inf
|
||||||
|
mask = torch.empty(self.context_length, self.context_length)
|
||||||
|
mask.fill_(float("-inf"))
|
||||||
|
mask.triu_(1) # zero out the lower diagonal
|
||||||
|
return mask
|
||||||
|
|
||||||
|
def forward(self, image_embs, text_embs):
|
||||||
|
text_embs = text_embs.permute(1, 0, 2) # NLD -> LND
|
||||||
|
image_embs = image_embs.permute(1, 0, 2) # NLD -> LND
|
||||||
|
seq_len = text_embs.shape[0]
|
||||||
|
|
||||||
|
for resblock, cross_attn in zip(self.resblocks, self.cross_attn):
|
||||||
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||||
|
# TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
|
||||||
|
text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len])
|
||||||
|
text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None)
|
||||||
|
else:
|
||||||
|
text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len])
|
||||||
|
text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)
|
||||||
|
|
||||||
|
x = text_embs.permute(1, 0, 2) # LND -> NLD
|
||||||
|
x = self.ln_final(x)
|
||||||
|
|
||||||
|
if self.text_projection is not None:
|
||||||
|
x = x @ self.text_projection
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def set_grad_checkpointing(self, enable=True):
|
||||||
|
self.grad_checkpointing = enable
|
||||||
60
diffsynth/extensions/ImageQualityMetric/open_clip/utils.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
from itertools import repeat
|
||||||
|
import collections.abc
|
||||||
|
|
||||||
|
from torch import nn as nn
|
||||||
|
from torchvision.ops.misc import FrozenBatchNorm2d
|
||||||
|
|
||||||
|
|
||||||
|
def freeze_batch_norm_2d(module, module_match={}, name=''):
|
||||||
|
"""
|
||||||
|
Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
|
||||||
|
itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
|
||||||
|
returned. Otherwise, the module is walked recursively and submodules are converted in place.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module (torch.nn.Module): Any PyTorch module.
|
||||||
|
module_match (dict): Dictionary of full module names to freeze (all if empty)
|
||||||
|
name (str): Full module name (prefix)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.nn.Module: Resulting module
|
||||||
|
|
||||||
|
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
|
||||||
|
"""
|
||||||
|
res = module
|
||||||
|
is_match = True
|
||||||
|
if module_match:
|
||||||
|
is_match = name in module_match
|
||||||
|
if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
|
||||||
|
res = FrozenBatchNorm2d(module.num_features)
|
||||||
|
res.num_features = module.num_features
|
||||||
|
res.affine = module.affine
|
||||||
|
if module.affine:
|
||||||
|
res.weight.data = module.weight.data.clone().detach()
|
||||||
|
res.bias.data = module.bias.data.clone().detach()
|
||||||
|
res.running_mean.data = module.running_mean.data
|
||||||
|
res.running_var.data = module.running_var.data
|
||||||
|
res.eps = module.eps
|
||||||
|
else:
|
||||||
|
for child_name, child in module.named_children():
|
||||||
|
full_child_name = '.'.join([name, child_name]) if name else child_name
|
||||||
|
new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
|
||||||
|
if new_child is not child:
|
||||||
|
res.add_module(child_name, new_child)
|
||||||
|
return res
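A short sketch (illustrative only) of freezing the BatchNorm layers of a small CNN in place:

import torch.nn as nn
cnn = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
frozen = freeze_batch_norm_2d(cnn)
print(type(frozen[1]).__name__)  # FrozenBatchNorm2d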
|
||||||
|
|
||||||
|
|
||||||
|
# From PyTorch internals
|
||||||
|
def _ntuple(n):
|
||||||
|
def parse(x):
|
||||||
|
if isinstance(x, collections.abc.Iterable):
|
||||||
|
return x
|
||||||
|
return tuple(repeat(x, n))
|
||||||
|
return parse
|
||||||
|
|
||||||
|
|
||||||
|
to_1tuple = _ntuple(1)
|
||||||
|
to_2tuple = _ntuple(2)
|
||||||
|
to_3tuple = _ntuple(3)
|
||||||
|
to_4tuple = _ntuple(4)
|
||||||
|
to_ntuple = lambda n, x: _ntuple(n)(x)
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
__version__ = '2.16.0'
|
||||||
112
diffsynth/extensions/ImageQualityMetric/pickscore.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import AutoProcessor, AutoModel
|
||||||
|
from typing import List, Union
|
||||||
|
import os
|
||||||
|
from .config import MODEL_PATHS
|
||||||
|
|
||||||
|
class PickScore(torch.nn.Module):
|
||||||
|
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS):
|
||||||
|
super().__init__()
|
||||||
|
"""Initialize the Selector with a processor and model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device (Union[str, torch.device]): The device to load the model on.
|
||||||
|
"""
|
||||||
|
self.device = device if isinstance(device, torch.device) else torch.device(device)
|
||||||
|
processor_name_or_path = path.get("clip")
|
||||||
|
model_pretrained_name_or_path = path.get("pickscore")
|
||||||
|
self.processor = AutoProcessor.from_pretrained(processor_name_or_path)
|
||||||
|
self.model = AutoModel.from_pretrained(model_pretrained_name_or_path).eval().to(self.device)
|
||||||
|
|
||||||
|
def _calculate_score(self, image: torch.Tensor, prompt: str, softmax: bool = False) -> float:
|
||||||
|
"""Calculate the score for a single image and prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (torch.Tensor): The processed image tensor.
|
||||||
|
prompt (str): The prompt text.
|
||||||
|
softmax (bool): Whether to apply softmax to the scores.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: The score for the image.
|
||||||
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
|
# Prepare text inputs
|
||||||
|
text_inputs = self.processor(
|
||||||
|
text=prompt,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=77,
|
||||||
|
return_tensors="pt",
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
|
# Embed images and text
|
||||||
|
image_embs = self.model.get_image_features(pixel_values=image)
|
||||||
|
image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
|
||||||
|
text_embs = self.model.get_text_features(**text_inputs)
|
||||||
|
text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
# Compute score
|
||||||
|
score = (text_embs @ image_embs.T)[0]
|
||||||
|
if softmax:
|
||||||
|
# Apply logit scale and softmax
|
||||||
|
score = torch.softmax(self.model.logit_scale.exp() * score, dim=-1)
|
||||||
|
|
||||||
|
return score.cpu().item()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str, softmax: bool = False) -> List[float]:
|
||||||
|
"""Score the images based on the prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
||||||
|
prompt (str): The prompt text.
|
||||||
|
softmax (bool): Whether to apply softmax to the scores.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[float]: List of scores for the images.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if isinstance(images, (str, Image.Image)):
|
||||||
|
# Single image
|
||||||
|
if isinstance(images, str):
|
||||||
|
pil_image = Image.open(images)
|
||||||
|
else:
|
||||||
|
pil_image = images
|
||||||
|
|
||||||
|
# Prepare image inputs
|
||||||
|
image_inputs = self.processor(
|
||||||
|
images=pil_image,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=77,
|
||||||
|
return_tensors="pt",
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
|
return [self._calculate_score(image_inputs["pixel_values"], prompt, softmax)]
|
||||||
|
elif isinstance(images, list):
|
||||||
|
# Multiple images
|
||||||
|
scores = []
|
||||||
|
for one_image in images:
|
||||||
|
if isinstance(one_image, str):
|
||||||
|
pil_image = Image.open(one_image)
|
||||||
|
elif isinstance(one_image, Image.Image):
|
||||||
|
pil_image = one_image
|
||||||
|
else:
|
||||||
|
raise TypeError("Unsupported element type in parameter images.")
|
||||||
|
|
||||||
|
# Prepare image inputs
|
||||||
|
image_inputs = self.processor(
|
||||||
|
images=pil_image,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=77,
|
||||||
|
return_tensors="pt",
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
|
scores.append(self._calculate_score(image_inputs["pixel_values"], prompt, softmax))
|
||||||
|
return scores
|
||||||
|
else:
|
||||||
|
raise TypeError("Unsupported type for parameter images.")
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Error in scoring images: {e}")
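A hypothetical usage sketch; it assumes MODEL_PATHS points at locally downloaded CLIP-processor and PickScore checkpoints:

import torch
from PIL import Image
scorer = PickScore(device="cuda" if torch.cuda.is_available() else "cpu")
imgs = [Image.new("RGB", (512, 512)), Image.new("RGB", (512, 512))]
scores = scorer.score(imgs, prompt="a photo of an astronaut riding a horse")
print(scores)  # two floats; higher means better prompt alignment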
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
from .models import *
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
from .base_model import *
|
||||||
|
from .clip_model import *
|
||||||
|
from .cross_modeling import *
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BaseModelConfig:
|
||||||
|
pass
|
||||||
@@ -0,0 +1,146 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from transformers import CLIPModel as HFCLIPModel
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from torch import nn, einsum
|
||||||
|
|
||||||
|
from .base_model import BaseModelConfig
|
||||||
|
|
||||||
|
from transformers import CLIPConfig
|
||||||
|
from typing import Any, Optional, Tuple, Union
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .cross_modeling import Cross_model
|
||||||
|
|
||||||
|
import json, os
|
||||||
|
|
||||||
|
class XCLIPModel(HFCLIPModel):
|
||||||
|
def __init__(self, config: CLIPConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
def get_text_features(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.Tensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
|
||||||
|
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
text_outputs = self.text_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
# pooled_output = text_outputs[1]
|
||||||
|
# text_features = self.text_projection(pooled_output)
|
||||||
|
last_hidden_state = text_outputs[0]
|
||||||
|
text_features = self.text_projection(last_hidden_state)
|
||||||
|
|
||||||
|
pooled_output = text_outputs[1]
|
||||||
|
text_features_EOS = self.text_projection(pooled_output)
|
||||||
|
|
||||||
|
|
||||||
|
# del last_hidden_state, text_outputs
|
||||||
|
# gc.collect()
|
||||||
|
|
||||||
|
return text_features, text_features_EOS
|
||||||
|
|
||||||
|
def get_image_features(
|
||||||
|
self,
|
||||||
|
pixel_values: Optional[torch.FloatTensor] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
|
||||||
|
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
vision_outputs = self.vision_model(
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
# pooled_output = vision_outputs[1] # pooled_output
|
||||||
|
# image_features = self.visual_projection(pooled_output)
|
||||||
|
last_hidden_state = vision_outputs[0]
|
||||||
|
image_features = self.visual_projection(last_hidden_state)
|
||||||
|
|
||||||
|
return image_features
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ClipModelConfig(BaseModelConfig):
    _target_: str = "diffsynth.extensions.QualityMetric.trainer.models.clip_model.CLIPModel"
    pretrained_model_name_or_path: str = "checkpoints/clip-vit-base-patch32"


class CLIPModel(nn.Module):
    def __init__(self, ckpt, config_file=False):
        super().__init__()
        # Note: config_file defaults to False, so the config.json branch below is taken
        # unless config_file=None is passed explicitly.
        if config_file is None:
            self.model = XCLIPModel.from_pretrained(ckpt)
        else:
            with open(os.path.join(ckpt, "config.json"), "r", encoding="utf-8") as f:
                config = json.load(f)
            config = CLIPConfig(**config)
            self.model = XCLIPModel._from_config(config)
        self.cross_model = Cross_model(dim=1024, layer_num=4, heads=16)

    def get_text_features(self, *args, **kwargs):
        return self.model.get_text_features(*args, **kwargs)

    def get_image_features(self, *args, **kwargs):
        return self.model.get_image_features(*args, **kwargs)

    def forward(self, text_inputs=None, image_inputs=None, condition_inputs=None):
        outputs = ()

        text_f, text_EOS = self.model.get_text_features(text_inputs)  # B*77*1024
        outputs += text_EOS,

        image_f = self.model.get_image_features(image_inputs.half())  # 2B*257*1024
        condition_f, _ = self.model.get_text_features(condition_inputs)  # B*5*1024

        sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)
        sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
        sim_text_condition = sim_text_condition / sim_text_condition.max()
        mask = torch.where(sim_text_condition > 0.01, 0, float('-inf'))  # B*1*77

        mask = mask.repeat(1, image_f.shape[1], 1)  # B*257*77
        bc = int(image_f.shape[0] / 2)

        sim0 = self.cross_model(image_f[:bc, :, :], text_f, mask.half())
        sim1 = self.cross_model(image_f[bc:, :, :], text_f, mask.half())
        outputs += sim0[:, 0, :],
        outputs += sim1[:, 0, :],

        return outputs

    @property
    def logit_scale(self):
        return self.model.logit_scale

    def save(self, path):
        self.model.save_pretrained(path)

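To make the masking step in CLIPModel.forward above concrete, here is a small self-contained sketch of how the text/condition similarity is turned into an additive attention mask (random tensors stand in for real CLIP features; shapes follow the inline comments above):

import torch
from torch import einsum

text_f = torch.randn(2, 77, 1024)        # B x 77 x 1024 text tokens (placeholder)
condition_f = torch.randn(2, 5, 1024)     # B x 5 x 1024 condition tokens (placeholder)

sim = einsum('b i d, b j d -> b j i', text_f, condition_f)   # B x 5 x 77
sim = torch.max(sim, dim=1, keepdim=True)[0]                 # B x 1 x 77, best match per text token
sim = sim / sim.max()
mask = torch.where(sim > 0.01, 0.0, float('-inf'))           # 0 keeps a text token, -inf drops it
mask = mask.repeat(1, 257, 1)                                # B x 257 x 77, one row per image token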
@@ -0,0 +1,292 @@
import torch
from torch import einsum, nn
import torch.nn.functional as F
from einops import rearrange, repeat

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# normalization
# they use layernorm without bias, something that pytorch does not offer


class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.register_buffer("bias", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.weight, self.bias)

# residual


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


# rotary positional embedding
# https://arxiv.org/abs/2104.09864


class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, max_seq_len, *, device):
        seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = einsum("i , j -> i j", seq, self.inv_freq)
        return torch.cat((freqs, freqs), dim=-1)


def rotate_half(x):
    x = rearrange(x, "... (j d) -> ... j d", j=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(pos, t):
    return (t * pos.cos()) + (rotate_half(t) * pos.sin())


# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
# https://arxiv.org/abs/2002.05202


class SwiGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x

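A quick sanity-check sketch for the rotary helpers above (the batch x heads x seq x dim_head layout is an assumption for illustration, matching how the block below applies them):

rotary = RotaryEmbedding(dim=64)
pos = rotary(8, device=torch.device("cpu"))    # 8 x 64 table of rotary frequencies
q = torch.randn(1, 2, 8, 64)                   # batch x heads x seq x dim_head (assumed layout)
q_rotated = apply_rotary_pos_emb(pos, q)       # same shape, positions mixed into q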
# parallel attention and feedforward with residual
|
||||||
|
# discovered by Wang et al + EleutherAI from GPT-J fame
|
||||||
|
|
||||||
|
class ParallelTransformerBlock(nn.Module):
|
||||||
|
def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
|
||||||
|
super().__init__()
|
||||||
|
self.norm = LayerNorm(dim)
|
||||||
|
|
||||||
|
attn_inner_dim = dim_head * heads
|
||||||
|
ff_inner_dim = dim * ff_mult
|
||||||
|
self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))
|
||||||
|
|
||||||
|
self.heads = heads
|
||||||
|
self.scale = dim_head**-0.5
|
||||||
|
self.rotary_emb = RotaryEmbedding(dim_head)
|
||||||
|
|
||||||
|
self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
|
||||||
|
self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
|
||||||
|
|
||||||
|
self.ff_out = nn.Sequential(
|
||||||
|
SwiGLU(),
|
||||||
|
nn.Linear(ff_inner_dim, dim, bias=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.register_buffer("pos_emb", None, persistent=False)
|
||||||
|
|
||||||
|
|
||||||
|
def get_rotary_embedding(self, n, device):
|
||||||
|
if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
|
||||||
|
return self.pos_emb[:n]
|
||||||
|
|
||||||
|
pos_emb = self.rotary_emb(n, device=device)
|
||||||
|
self.register_buffer("pos_emb", pos_emb, persistent=False)
|
||||||
|
return pos_emb
|
||||||
|
|
||||||
|
def forward(self, x, attn_mask=None):
|
||||||
|
"""
|
||||||
|
einstein notation
|
||||||
|
b - batch
|
||||||
|
h - heads
|
||||||
|
n, i, j - sequence length (base sequence length, source, target)
|
||||||
|
d - feature dimension
|
||||||
|
"""
|
||||||
|
|
||||||
|
n, device, h = x.shape[1], x.device, self.heads
|
||||||
|
|
||||||
|
# pre layernorm
|
||||||
|
|
||||||
|
x = self.norm(x)
|
||||||
|
|
||||||
|
# attention queries, keys, values, and feedforward inner
|
||||||
|
|
||||||
|
q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)
|
||||||
|
|
||||||
|
# split heads
|
||||||
|
# they use multi-query single-key-value attention, yet another Noam Shazeer paper
|
||||||
|
# they found no performance loss past a certain scale, and more efficient decoding obviously
|
||||||
|
# https://arxiv.org/abs/1911.02150
|
||||||
|
|
||||||
|
q = rearrange(q, "b n (h d) -> b h n d", h=h)
|
||||||
|
|
||||||
|
# rotary embeddings
|
||||||
|
|
||||||
|
positions = self.get_rotary_embedding(n, device)
|
||||||
|
q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))
|
||||||
|
|
||||||
|
# scale
|
||||||
|
|
||||||
|
q = q * self.scale
|
||||||
|
|
||||||
|
# similarity
|
||||||
|
|
||||||
|
sim = einsum("b h i d, b j d -> b h i j", q, k)
|
||||||
|
|
||||||
|
|
||||||
|
# extra attention mask - for masking out attention from text CLS token to padding
|
||||||
|
|
||||||
|
if exists(attn_mask):
|
||||||
|
attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j')
|
||||||
|
sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
|
||||||
|
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
|
||||||
|
attn = sim.softmax(dim=-1)
|
||||||
|
|
||||||
|
# aggregate values
|
||||||
|
|
||||||
|
out = einsum("b h i j, b j d -> b h i d", attn, v)
|
||||||
|
|
||||||
|
# merge heads
|
||||||
|
|
||||||
|
out = rearrange(out, "b h n d -> b n (h d)")
|
||||||
|
return self.attn_out(out) + self.ff_out(ff)
|
||||||
|
|
||||||
|
# cross attention - using multi-query + one-headed key / values as in PaLM w/ optional parallel feedforward
|
||||||
|
|
||||||
|
class CrossAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
*,
|
||||||
|
context_dim=None,
|
||||||
|
dim_head=64,
|
||||||
|
heads=12,
|
||||||
|
parallel_ff=False,
|
||||||
|
ff_mult=4,
|
||||||
|
norm_context=False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.heads = heads
|
||||||
|
self.scale = dim_head ** -0.5
|
||||||
|
inner_dim = heads * dim_head
|
||||||
|
context_dim = default(context_dim, dim)
|
||||||
|
|
||||||
|
self.norm = LayerNorm(dim)
|
||||||
|
self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity()
|
||||||
|
|
||||||
|
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
||||||
|
self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
|
||||||
|
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
||||||
|
|
||||||
|
# whether to have parallel feedforward
|
||||||
|
|
||||||
|
ff_inner_dim = ff_mult * dim
|
||||||
|
|
||||||
|
self.ff = nn.Sequential(
|
||||||
|
nn.Linear(dim, ff_inner_dim * 2, bias=False),
|
||||||
|
SwiGLU(),
|
||||||
|
nn.Linear(ff_inner_dim, dim, bias=False)
|
||||||
|
) if parallel_ff else None
|
||||||
|
|
||||||
|
def forward(self, x, context, mask):
|
||||||
|
"""
|
||||||
|
einstein notation
|
||||||
|
b - batch
|
||||||
|
h - heads
|
||||||
|
n, i, j - sequence length (base sequence length, source, target)
|
||||||
|
d - feature dimension
|
||||||
|
"""
|
||||||
|
|
||||||
|
# pre-layernorm, for queries and context
|
||||||
|
|
||||||
|
x = self.norm(x)
|
||||||
|
context = self.context_norm(context)
|
||||||
|
|
||||||
|
# get queries
|
||||||
|
|
||||||
|
q = self.to_q(x)
|
||||||
|
q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
|
||||||
|
|
||||||
|
# scale
|
||||||
|
|
||||||
|
q = q * self.scale
|
||||||
|
|
||||||
|
# get key / values
|
||||||
|
|
||||||
|
k, v = self.to_kv(context).chunk(2, dim=-1)
|
||||||
|
|
||||||
|
# query / key similarity
|
||||||
|
|
||||||
|
sim = einsum('b h i d, b j d -> b h i j', q, k)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
mask = mask.unsqueeze(1).repeat(1,self.heads,1,1)
|
||||||
|
sim = sim + mask # context mask
|
||||||
|
sim = sim - sim.amax(dim=-1, keepdim=True)
|
||||||
|
attn = sim.softmax(dim=-1)
|
||||||
|
|
||||||
|
# aggregate
|
||||||
|
|
||||||
|
out = einsum('b h i j, b j d -> b h i d', attn, v)
|
||||||
|
|
||||||
|
# merge and combine heads
|
||||||
|
|
||||||
|
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||||
|
out = self.to_out(out)
|
||||||
|
|
||||||
|
# add parallel feedforward (for multimodal layers)
|
||||||
|
|
||||||
|
if exists(self.ff):
|
||||||
|
out = out + self.ff(x)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Cross_model(nn.Module):
    def __init__(
        self,
        dim=512,
        layer_num=4,
        dim_head=64,
        heads=8,
        ff_mult=4
    ):
        super().__init__()

        self.layers = nn.ModuleList([])

        for ind in range(layer_num):
            self.layers.append(nn.ModuleList([
                Residual(CrossAttention(dim=dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult)),
                Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult))
            ]))

    def forward(
        self,
        query_tokens,
        context_tokens,
        mask
    ):
        for cross_attn, self_attn_ff in self.layers:
            query_tokens = cross_attn(query_tokens, context_tokens, mask)
            query_tokens = self_attn_ff(query_tokens)

        return query_tokens
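A minimal usage sketch for Cross_model as it is instantiated by CLIPModel above (dim=1024, layer_num=4, heads=16); the tensors are random placeholders, not real CLIP features:

cross = Cross_model(dim=1024, layer_num=4, heads=16)
image_tokens = torch.randn(2, 257, 1024)       # B x 257 x 1024 vision tokens (placeholder)
text_tokens = torch.randn(2, 77, 1024)         # B x 77 x 1024 text tokens (placeholder)
mask = torch.zeros(2, 257, 77)                 # additive mask: 0 keeps a position, -inf would drop it
out = cross(image_tokens, text_tokens, mask)   # B x 257 x 1024; out[:, 0, :] is used as the score token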
@@ -58,7 +58,7 @@ class IFBlock(nn.Module):
 
 
 class IFNet(nn.Module):
-    def __init__(self):
+    def __init__(self, **kwargs):
         super(IFNet, self).__init__()
         self.block0 = IFBlock(7+4, c=90)
         self.block1 = IFBlock(7+4, c=90)
@@ -99,7 +99,8 @@ class IFNet(nn.Module):
             merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
         return flow_list, mask_list[2], merged
 
-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return IFNetStateDictConverter()
 
 
@@ -112,7 +113,7 @@ class IFNetStateDictConverter:
         return state_dict_
 
     def from_civitai(self, state_dict):
-        return self.from_diffusers(state_dict)
+        return self.from_diffusers(state_dict), {"upcast_to_float32": True}
 
 
 class RIFEInterpolater:
@@ -124,7 +125,7 @@ class RIFEInterpolater:
 
     @staticmethod
     def from_model_manager(model_manager):
-        return RIFEInterpolater(model_manager.RIFE, device=model_manager.device)
+        return RIFEInterpolater(model_manager.fetch_model("rife"), device=model_manager.device)
 
     def process_image(self, image):
         width, height = image.size
@@ -202,7 +203,7 @@ class RIFESmoother(RIFEInterpolater):
 
     @staticmethod
     def from_model_manager(model_manager):
-        return RIFESmoother(model_manager.RIFE, device=model_manager.device)
+        return RIFEInterpolater(model_manager.fetch_model("rife"), device=model_manager.device)
 
     def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
         output_tensor = []
diffsynth/extensions/__init__.py (new file, 0 lines)
diffsynth/lora/__init__.py (new file, 45 lines)
@@ -0,0 +1,45 @@
import torch


class GeneralLoRALoader:
    def __init__(self, device="cpu", torch_dtype=torch.float32):
        self.device = device
        self.torch_dtype = torch_dtype

    def get_name_dict(self, lora_state_dict):
        lora_name_dict = {}
        for key in lora_state_dict:
            if ".lora_B." not in key:
                continue
            keys = key.split(".")
            if len(keys) > keys.index("lora_B") + 2:
                keys.pop(keys.index("lora_B") + 1)
            keys.pop(keys.index("lora_B"))
            if keys[0] == "diffusion_model":
                keys.pop(0)
            keys.pop(-1)
            target_name = ".".join(keys)
            lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
        return lora_name_dict

    def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
        updated_num = 0
        lora_name_dict = self.get_name_dict(state_dict_lora)
        for name, module in model.named_modules():
            if name in lora_name_dict:
                weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=self.device, dtype=self.torch_dtype)
                weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=self.device, dtype=self.torch_dtype)
                if len(weight_up.shape) == 4:
                    weight_up = weight_up.squeeze(3).squeeze(2)
                    weight_down = weight_down.squeeze(3).squeeze(2)
                    weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
                else:
                    weight_lora = alpha * torch.mm(weight_up, weight_down)
                state_dict = module.state_dict()
                state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora
                module.load_state_dict(state_dict)
                updated_num += 1
        print(f"{updated_num} tensors are updated by LoRA.")
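A hedged smoke test for GeneralLoRALoader above: it builds a tiny model plus a synthetic rank-4 LoRA state dict whose keys follow the .lora_A./.lora_B. naming the loader expects (all names here are illustrative, not part of the repository):

import torch

model = torch.nn.ModuleDict({"proj": torch.nn.Linear(8, 8)})
lora = {
    "diffusion_model.proj.lora_A.weight": torch.randn(4, 8),   # rank x in_features
    "diffusion_model.proj.lora_B.weight": torch.randn(8, 4),   # out_features x rank
}
loader = GeneralLoRALoader(device="cpu", torch_dtype=torch.float32)
loader.load(model, lora, alpha=1.0)   # prints: 1 tensors are updated by LoRA.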
diffsynth/lora/flux_lora.py (new file, 13 lines)
@@ -0,0 +1,13 @@
import torch
from diffsynth.lora import GeneralLoRALoader
from diffsynth.models.lora import FluxLoRAFromCivitai


class FluxLoRALoader(GeneralLoRALoader):
    def __init__(self, device="cpu", torch_dtype=torch.float32):
        super().__init__(device=device, torch_dtype=torch_dtype)
        self.loader = FluxLoRAFromCivitai()

    def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
        lora_prefix, model_resource = self.loader.match(model, state_dict_lora)
        self.loader.load(model, state_dict_lora, lora_prefix, alpha=alpha, model_resource=model_resource)
@@ -1,482 +1 @@
|
|||||||
import torch, os
|
from .model_manager import *
|
||||||
from safetensors import safe_open
|
|
||||||
|
|
||||||
from .sd_text_encoder import SDTextEncoder
|
|
||||||
from .sd_unet import SDUNet
|
|
||||||
from .sd_vae_encoder import SDVAEEncoder
|
|
||||||
from .sd_vae_decoder import SDVAEDecoder
|
|
||||||
from .sd_lora import SDLoRA
|
|
||||||
|
|
||||||
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
|
||||||
from .sdxl_unet import SDXLUNet
|
|
||||||
from .sdxl_vae_decoder import SDXLVAEDecoder
|
|
||||||
from .sdxl_vae_encoder import SDXLVAEEncoder
|
|
||||||
|
|
||||||
from .sd_controlnet import SDControlNet
|
|
||||||
|
|
||||||
from .sd_motion import SDMotionModel
|
|
||||||
from .sdxl_motion import SDXLMotionModel
|
|
||||||
|
|
||||||
from .svd_image_encoder import SVDImageEncoder
|
|
||||||
from .svd_unet import SVDUNet
|
|
||||||
from .svd_vae_decoder import SVDVAEDecoder
|
|
||||||
from .svd_vae_encoder import SVDVAEEncoder
|
|
||||||
|
|
||||||
from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
|
|
||||||
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
|
||||||
|
|
||||||
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
|
||||||
from .hunyuan_dit import HunyuanDiT
|
|
||||||
|
|
||||||
|
|
||||||
class ModelManager:
|
|
||||||
def __init__(self, torch_dtype=torch.float16, device="cuda"):
|
|
||||||
self.torch_dtype = torch_dtype
|
|
||||||
self.device = device
|
|
||||||
self.model = {}
|
|
||||||
self.model_path = {}
|
|
||||||
self.textual_inversion_dict = {}
|
|
||||||
|
|
||||||
def is_stable_video_diffusion(self, state_dict):
|
|
||||||
param_name = "model.diffusion_model.output_blocks.9.1.time_stack.0.norm_in.weight"
|
|
||||||
return param_name in state_dict
|
|
||||||
|
|
||||||
def is_RIFE(self, state_dict):
|
|
||||||
param_name = "block_tea.convblock3.0.1.weight"
|
|
||||||
return param_name in state_dict or ("module." + param_name) in state_dict
|
|
||||||
|
|
||||||
def is_beautiful_prompt(self, state_dict):
|
|
||||||
param_name = "transformer.h.9.self_attention.query_key_value.weight"
|
|
||||||
return param_name in state_dict
|
|
||||||
|
|
||||||
def is_stabe_diffusion_xl(self, state_dict):
|
|
||||||
param_name = "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight"
|
|
||||||
return param_name in state_dict
|
|
||||||
|
|
||||||
def is_stable_diffusion(self, state_dict):
|
|
||||||
if self.is_stabe_diffusion_xl(state_dict):
|
|
||||||
return False
|
|
||||||
param_name = "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight"
|
|
||||||
return param_name in state_dict
|
|
||||||
|
|
||||||
def is_controlnet(self, state_dict):
|
|
||||||
param_name = "control_model.time_embed.0.weight"
|
|
||||||
param_name_2 = "mid_block.resnets.1.time_emb_proj.weight" # For controlnets in diffusers format
|
|
||||||
return param_name in state_dict or param_name_2 in state_dict
|
|
||||||
|
|
||||||
def is_animatediff(self, state_dict):
|
|
||||||
param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
|
|
||||||
return param_name in state_dict
|
|
||||||
|
|
||||||
def is_animatediff_xl(self, state_dict):
|
|
||||||
param_name = "up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.ff_norm.weight"
|
|
||||||
return param_name in state_dict
|
|
||||||
|
|
||||||
def is_sd_lora(self, state_dict):
|
|
||||||
param_name = "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_ff_net_2.lora_up.weight"
|
|
||||||
return param_name in state_dict
|
|
||||||
|
|
||||||
def is_translator(self, state_dict):
|
|
||||||
param_name = "model.encoder.layers.5.self_attn_layer_norm.weight"
|
|
||||||
return param_name in state_dict and len(state_dict) == 254
|
|
||||||
|
|
||||||
def is_ipadapter(self, state_dict):
|
|
||||||
return "image_proj" in state_dict and "ip_adapter" in state_dict and state_dict["image_proj"]["proj.weight"].shape == torch.Size([3072, 1024])
|
|
||||||
|
|
||||||
def is_ipadapter_image_encoder(self, state_dict):
|
|
||||||
param_name = "vision_model.encoder.layers.31.self_attn.v_proj.weight"
|
|
||||||
return param_name in state_dict and len(state_dict) == 521
|
|
||||||
|
|
||||||
def is_ipadapter_xl(self, state_dict):
|
|
||||||
return "image_proj" in state_dict and "ip_adapter" in state_dict and state_dict["image_proj"]["proj.weight"].shape == torch.Size([8192, 1280])
|
|
||||||
|
|
||||||
def is_ipadapter_xl_image_encoder(self, state_dict):
|
|
||||||
param_name = "vision_model.encoder.layers.47.self_attn.v_proj.weight"
|
|
||||||
return param_name in state_dict and len(state_dict) == 777
|
|
||||||
|
|
||||||
def is_hunyuan_dit_clip_text_encoder(self, state_dict):
|
|
||||||
param_name = "bert.encoder.layer.23.attention.output.dense.weight"
|
|
||||||
return param_name in state_dict
|
|
||||||
|
|
||||||
def is_hunyuan_dit_t5_text_encoder(self, state_dict):
|
|
||||||
param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
|
|
||||||
return param_name in state_dict
|
|
||||||
|
|
||||||
def is_hunyuan_dit(self, state_dict):
|
|
||||||
param_name = "final_layer.adaLN_modulation.1.weight"
|
|
||||||
return param_name in state_dict
|
|
||||||
|
|
||||||
def is_diffusers_vae(self, state_dict):
|
|
||||||
param_name = "quant_conv.weight"
|
|
||||||
return param_name in state_dict
|
|
||||||
|
|
||||||
def is_ExVideo_StableVideoDiffusion(self, state_dict):
|
|
||||||
param_name = "blocks.185.positional_embedding.embeddings"
|
|
||||||
return param_name in state_dict
|
|
||||||
|
|
||||||
def load_stable_video_diffusion(self, state_dict, components=None, file_path="", add_positional_conv=None):
|
|
||||||
component_dict = {
|
|
||||||
"image_encoder": SVDImageEncoder,
|
|
||||||
"unet": SVDUNet,
|
|
||||||
"vae_decoder": SVDVAEDecoder,
|
|
||||||
"vae_encoder": SVDVAEEncoder,
|
|
||||||
}
|
|
||||||
if components is None:
|
|
||||||
components = ["image_encoder", "unet", "vae_decoder", "vae_encoder"]
|
|
||||||
for component in components:
|
|
||||||
if component == "unet":
|
|
||||||
self.model[component] = component_dict[component](add_positional_conv=add_positional_conv)
|
|
||||||
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict, add_positional_conv=add_positional_conv), strict=False)
|
|
||||||
else:
|
|
||||||
self.model[component] = component_dict[component]()
|
|
||||||
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
|
||||||
self.model[component].to(self.torch_dtype).to(self.device)
|
|
||||||
self.model_path[component] = file_path
|
|
||||||
|
|
||||||
def load_stable_diffusion(self, state_dict, components=None, file_path=""):
|
|
||||||
component_dict = {
|
|
||||||
"text_encoder": SDTextEncoder,
|
|
||||||
"unet": SDUNet,
|
|
||||||
"vae_decoder": SDVAEDecoder,
|
|
||||||
"vae_encoder": SDVAEEncoder,
|
|
||||||
"refiner": SDXLUNet,
|
|
||||||
}
|
|
||||||
if components is None:
|
|
||||||
components = ["text_encoder", "unet", "vae_decoder", "vae_encoder"]
|
|
||||||
for component in components:
|
|
||||||
if component == "text_encoder":
|
|
||||||
# Add additional token embeddings to text encoder
|
|
||||||
token_embeddings = [state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"]]
|
|
||||||
for keyword in self.textual_inversion_dict:
|
|
||||||
_, embeddings = self.textual_inversion_dict[keyword]
|
|
||||||
token_embeddings.append(embeddings.to(dtype=token_embeddings[0].dtype))
|
|
||||||
token_embeddings = torch.concat(token_embeddings, dim=0)
|
|
||||||
state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"] = token_embeddings
|
|
||||||
self.model[component] = component_dict[component](vocab_size=token_embeddings.shape[0])
|
|
||||||
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
|
||||||
self.model[component].to(self.torch_dtype).to(self.device)
|
|
||||||
else:
|
|
||||||
self.model[component] = component_dict[component]()
|
|
||||||
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
|
||||||
self.model[component].to(self.torch_dtype).to(self.device)
|
|
||||||
self.model_path[component] = file_path
|
|
||||||
|
|
||||||
def load_stable_diffusion_xl(self, state_dict, components=None, file_path=""):
|
|
||||||
component_dict = {
|
|
||||||
"text_encoder": SDXLTextEncoder,
|
|
||||||
"text_encoder_2": SDXLTextEncoder2,
|
|
||||||
"unet": SDXLUNet,
|
|
||||||
"vae_decoder": SDXLVAEDecoder,
|
|
||||||
"vae_encoder": SDXLVAEEncoder,
|
|
||||||
}
|
|
||||||
if components is None:
|
|
||||||
components = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder"]
|
|
||||||
for component in components:
|
|
||||||
self.model[component] = component_dict[component]()
|
|
||||||
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
|
||||||
if component in ["vae_decoder", "vae_encoder"]:
|
|
||||||
# These two model will output nan when float16 is enabled.
|
|
||||||
# The precision problem happens in the last three resnet blocks.
|
|
||||||
# I do not know how to solve this problem.
|
|
||||||
self.model[component].to(torch.float32).to(self.device)
|
|
||||||
else:
|
|
||||||
self.model[component].to(self.torch_dtype).to(self.device)
|
|
||||||
self.model_path[component] = file_path
|
|
||||||
|
|
||||||
def load_controlnet(self, state_dict, file_path=""):
|
|
||||||
component = "controlnet"
|
|
||||||
if component not in self.model:
|
|
||||||
self.model[component] = []
|
|
||||||
self.model_path[component] = []
|
|
||||||
model = SDControlNet()
|
|
||||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
|
||||||
model.to(self.torch_dtype).to(self.device)
|
|
||||||
self.model[component].append(model)
|
|
||||||
self.model_path[component].append(file_path)
|
|
||||||
|
|
||||||
def load_animatediff(self, state_dict, file_path="", add_positional_conv=None):
|
|
||||||
component = "motion_modules"
|
|
||||||
model = SDMotionModel(add_positional_conv=add_positional_conv)
|
|
||||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict, add_positional_conv=add_positional_conv))
|
|
||||||
model.to(self.torch_dtype).to(self.device)
|
|
||||||
self.model[component] = model
|
|
||||||
self.model_path[component] = file_path
|
|
||||||
|
|
||||||
def load_animatediff_xl(self, state_dict, file_path=""):
|
|
||||||
component = "motion_modules_xl"
|
|
||||||
model = SDXLMotionModel()
|
|
||||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
|
||||||
model.to(self.torch_dtype).to(self.device)
|
|
||||||
self.model[component] = model
|
|
||||||
self.model_path[component] = file_path
|
|
||||||
|
|
||||||
def load_beautiful_prompt(self, state_dict, file_path=""):
|
|
||||||
component = "beautiful_prompt"
|
|
||||||
from transformers import AutoModelForCausalLM
|
|
||||||
model_folder = os.path.dirname(file_path)
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
model_folder, state_dict=state_dict, local_files_only=True, torch_dtype=self.torch_dtype
|
|
||||||
).to(self.device).eval()
|
|
||||||
self.model[component] = model
|
|
||||||
self.model_path[component] = file_path
|
|
||||||
|
|
||||||
def load_RIFE(self, state_dict, file_path=""):
|
|
||||||
component = "RIFE"
|
|
||||||
from ..extensions.RIFE import IFNet
|
|
||||||
model = IFNet().eval()
|
|
||||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
|
||||||
model.to(torch.float32).to(self.device)
|
|
||||||
self.model[component] = model
|
|
||||||
self.model_path[component] = file_path
|
|
||||||
|
|
||||||
def load_sd_lora(self, state_dict, alpha):
|
|
||||||
SDLoRA().add_lora_to_text_encoder(self.model["text_encoder"], state_dict, alpha=alpha, device=self.device)
|
|
||||||
SDLoRA().add_lora_to_unet(self.model["unet"], state_dict, alpha=alpha, device=self.device)
|
|
||||||
|
|
||||||
def load_translator(self, state_dict, file_path=""):
|
|
||||||
# This model is lightweight, we do not place it on GPU.
|
|
||||||
component = "translator"
|
|
||||||
from transformers import AutoModelForSeq2SeqLM
|
|
||||||
model_folder = os.path.dirname(file_path)
|
|
||||||
model = AutoModelForSeq2SeqLM.from_pretrained(model_folder).eval()
|
|
||||||
self.model[component] = model
|
|
||||||
self.model_path[component] = file_path
|
|
||||||
|
|
||||||
def load_ipadapter(self, state_dict, file_path=""):
|
|
||||||
component = "ipadapter"
|
|
||||||
model = SDIpAdapter()
|
|
||||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
|
||||||
model.to(self.torch_dtype).to(self.device)
|
|
||||||
self.model[component] = model
|
|
||||||
self.model_path[component] = file_path
|
|
||||||
|
|
||||||
def load_ipadapter_image_encoder(self, state_dict, file_path=""):
|
|
||||||
component = "ipadapter_image_encoder"
|
|
||||||
model = IpAdapterCLIPImageEmbedder()
|
|
||||||
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
|
||||||
model.to(self.torch_dtype).to(self.device)
|
|
||||||
self.model[component] = model
|
|
||||||
self.model_path[component] = file_path
|
|
||||||
|
|
||||||
def load_ipadapter_xl(self, state_dict, file_path=""):
|
|
||||||
component = "ipadapter_xl"
|
|
||||||
model = SDXLIpAdapter()
|
|
||||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
|
||||||
model.to(self.torch_dtype).to(self.device)
|
|
||||||
self.model[component] = model
|
|
||||||
self.model_path[component] = file_path
|
|
||||||
|
|
||||||
def load_ipadapter_xl_image_encoder(self, state_dict, file_path=""):
|
|
||||||
component = "ipadapter_xl_image_encoder"
|
|
||||||
model = IpAdapterXLCLIPImageEmbedder()
|
|
||||||
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
|
||||||
model.to(self.torch_dtype).to(self.device)
|
|
||||||
self.model[component] = model
|
|
||||||
self.model_path[component] = file_path
|
|
||||||
|
|
||||||
def load_hunyuan_dit_clip_text_encoder(self, state_dict, file_path=""):
|
|
||||||
component = "hunyuan_dit_clip_text_encoder"
|
|
||||||
model = HunyuanDiTCLIPTextEncoder()
|
|
||||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
|
||||||
model.to(self.torch_dtype).to(self.device)
|
|
||||||
self.model[component] = model
|
|
||||||
self.model_path[component] = file_path
|
|
||||||
|
|
||||||
def load_hunyuan_dit_t5_text_encoder(self, state_dict, file_path=""):
|
|
||||||
component = "hunyuan_dit_t5_text_encoder"
|
|
||||||
model = HunyuanDiTT5TextEncoder()
|
|
||||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
|
||||||
model.to(self.torch_dtype).to(self.device)
|
|
||||||
self.model[component] = model
|
|
||||||
self.model_path[component] = file_path
|
|
||||||
|
|
||||||
def load_hunyuan_dit(self, state_dict, file_path=""):
|
|
||||||
component = "hunyuan_dit"
|
|
||||||
model = HunyuanDiT()
|
|
||||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
|
||||||
model.to(self.torch_dtype).to(self.device)
|
|
||||||
self.model[component] = model
|
|
||||||
self.model_path[component] = file_path
|
|
||||||
|
|
||||||
def load_diffusers_vae(self, state_dict, file_path=""):
|
|
||||||
# TODO: detect SD and SDXL
|
|
||||||
component = "vae_encoder"
|
|
||||||
model = SDXLVAEEncoder()
|
|
||||||
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
|
||||||
model.to(self.torch_dtype).to(self.device)
|
|
||||||
self.model[component] = model
|
|
||||||
self.model_path[component] = file_path
|
|
||||||
component = "vae_decoder"
|
|
||||||
model = SDXLVAEDecoder()
|
|
||||||
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
|
||||||
model.to(self.torch_dtype).to(self.device)
|
|
||||||
self.model[component] = model
|
|
||||||
self.model_path[component] = file_path
|
|
||||||
|
|
||||||
def load_ExVideo_StableVideoDiffusion(self, state_dict, file_path=""):
|
|
||||||
unet_state_dict = self.model["unet"].state_dict()
|
|
||||||
self.model["unet"].to("cpu")
|
|
||||||
del self.model["unet"]
|
|
||||||
add_positional_conv = state_dict["blocks.185.positional_embedding.embeddings"].shape[0]
|
|
||||||
self.model["unet"] = SVDUNet(add_positional_conv=add_positional_conv)
|
|
||||||
self.model["unet"].load_state_dict(unet_state_dict, strict=False)
|
|
||||||
self.model["unet"].load_state_dict(state_dict, strict=False)
|
|
||||||
self.model["unet"].to(self.torch_dtype).to(self.device)
|
|
||||||
|
|
||||||
def search_for_embeddings(self, state_dict):
|
|
||||||
embeddings = []
|
|
||||||
for k in state_dict:
|
|
||||||
if isinstance(state_dict[k], torch.Tensor):
|
|
||||||
embeddings.append(state_dict[k])
|
|
||||||
elif isinstance(state_dict[k], dict):
|
|
||||||
embeddings += self.search_for_embeddings(state_dict[k])
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
def load_textual_inversions(self, folder):
|
|
||||||
# Store additional tokens here
|
|
||||||
self.textual_inversion_dict = {}
|
|
||||||
|
|
||||||
# Load every textual inversion file
|
|
||||||
for file_name in os.listdir(folder):
|
|
||||||
if file_name.endswith(".txt"):
|
|
||||||
continue
|
|
||||||
keyword = os.path.splitext(file_name)[0]
|
|
||||||
state_dict = load_state_dict(os.path.join(folder, file_name))
|
|
||||||
|
|
||||||
# Search for embeddings
|
|
||||||
for embeddings in self.search_for_embeddings(state_dict):
|
|
||||||
if len(embeddings.shape) == 2 and embeddings.shape[1] == 768:
|
|
||||||
tokens = [f"{keyword}_{i}" for i in range(embeddings.shape[0])]
|
|
||||||
self.textual_inversion_dict[keyword] = (tokens, embeddings)
|
|
||||||
break
|
|
||||||
|
|
||||||
def load_model(self, file_path, components=None, lora_alphas=[]):
|
|
||||||
state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype)
|
|
||||||
if self.is_stable_video_diffusion(state_dict):
|
|
||||||
self.load_stable_video_diffusion(state_dict, file_path=file_path)
|
|
||||||
elif self.is_animatediff(state_dict):
|
|
||||||
self.load_animatediff(state_dict, file_path=file_path)
|
|
||||||
elif self.is_animatediff_xl(state_dict):
|
|
||||||
self.load_animatediff_xl(state_dict, file_path=file_path)
|
|
||||||
elif self.is_controlnet(state_dict):
|
|
||||||
self.load_controlnet(state_dict, file_path=file_path)
|
|
||||||
elif self.is_stabe_diffusion_xl(state_dict):
|
|
||||||
self.load_stable_diffusion_xl(state_dict, components=components, file_path=file_path)
|
|
||||||
elif self.is_stable_diffusion(state_dict):
|
|
||||||
self.load_stable_diffusion(state_dict, components=components, file_path=file_path)
|
|
||||||
elif self.is_sd_lora(state_dict):
|
|
||||||
self.load_sd_lora(state_dict, alpha=lora_alphas.pop(0))
|
|
||||||
elif self.is_beautiful_prompt(state_dict):
|
|
||||||
self.load_beautiful_prompt(state_dict, file_path=file_path)
|
|
||||||
elif self.is_RIFE(state_dict):
|
|
||||||
self.load_RIFE(state_dict, file_path=file_path)
|
|
||||||
elif self.is_translator(state_dict):
|
|
||||||
self.load_translator(state_dict, file_path=file_path)
|
|
||||||
elif self.is_ipadapter(state_dict):
|
|
||||||
self.load_ipadapter(state_dict, file_path=file_path)
|
|
||||||
elif self.is_ipadapter_image_encoder(state_dict):
|
|
||||||
self.load_ipadapter_image_encoder(state_dict, file_path=file_path)
|
|
||||||
elif self.is_ipadapter_xl(state_dict):
|
|
||||||
self.load_ipadapter_xl(state_dict, file_path=file_path)
|
|
||||||
elif self.is_ipadapter_xl_image_encoder(state_dict):
|
|
||||||
self.load_ipadapter_xl_image_encoder(state_dict, file_path=file_path)
|
|
||||||
elif self.is_hunyuan_dit_clip_text_encoder(state_dict):
|
|
||||||
self.load_hunyuan_dit_clip_text_encoder(state_dict, file_path=file_path)
|
|
||||||
elif self.is_hunyuan_dit_t5_text_encoder(state_dict):
|
|
||||||
self.load_hunyuan_dit_t5_text_encoder(state_dict, file_path=file_path)
|
|
||||||
elif self.is_hunyuan_dit(state_dict):
|
|
||||||
self.load_hunyuan_dit(state_dict, file_path=file_path)
|
|
||||||
elif self.is_diffusers_vae(state_dict):
|
|
||||||
self.load_diffusers_vae(state_dict, file_path=file_path)
|
|
||||||
elif self.is_ExVideo_StableVideoDiffusion(state_dict):
|
|
||||||
self.load_ExVideo_StableVideoDiffusion(state_dict, file_path=file_path)
|
|
||||||
|
|
||||||
def load_models(self, file_path_list, lora_alphas=[]):
|
|
||||||
for file_path in file_path_list:
|
|
||||||
self.load_model(file_path, lora_alphas=lora_alphas)
|
|
||||||
|
|
||||||
def to(self, device):
|
|
||||||
for component in self.model:
|
|
||||||
if isinstance(self.model[component], list):
|
|
||||||
for model in self.model[component]:
|
|
||||||
model.to(device)
|
|
||||||
else:
|
|
||||||
self.model[component].to(device)
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
def get_model_with_model_path(self, model_path):
|
|
||||||
for component in self.model_path:
|
|
||||||
if isinstance(self.model_path[component], str):
|
|
||||||
if os.path.samefile(self.model_path[component], model_path):
|
|
||||||
return self.model[component]
|
|
||||||
elif isinstance(self.model_path[component], list):
|
|
||||||
for i, model_path_ in enumerate(self.model_path[component]):
|
|
||||||
if os.path.samefile(model_path_, model_path):
|
|
||||||
return self.model[component][i]
|
|
||||||
raise ValueError(f"Please load model {model_path} before you use it.")
|
|
||||||
|
|
||||||
def __getattr__(self, __name):
|
|
||||||
if __name in self.model:
|
|
||||||
return self.model[__name]
|
|
||||||
else:
|
|
||||||
return super.__getattribute__(__name)
|
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict(file_path, torch_dtype=None):
|
|
||||||
if file_path.endswith(".safetensors"):
|
|
||||||
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
|
|
||||||
else:
|
|
||||||
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
|
|
||||||
state_dict = {}
|
|
||||||
with safe_open(file_path, framework="pt", device="cpu") as f:
|
|
||||||
for k in f.keys():
|
|
||||||
state_dict[k] = f.get_tensor(k)
|
|
||||||
if torch_dtype is not None:
|
|
||||||
state_dict[k] = state_dict[k].to(torch_dtype)
|
|
||||||
return state_dict
|
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict_from_bin(file_path, torch_dtype=None):
|
|
||||||
state_dict = torch.load(file_path, map_location="cpu")
|
|
||||||
if torch_dtype is not None:
|
|
||||||
for i in state_dict:
|
|
||||||
if isinstance(state_dict[i], torch.Tensor):
|
|
||||||
state_dict[i] = state_dict[i].to(torch_dtype)
|
|
||||||
return state_dict
|
|
||||||
|
|
||||||
|
|
||||||
def search_parameter(param, state_dict):
|
|
||||||
for name, param_ in state_dict.items():
|
|
||||||
if param.numel() == param_.numel():
|
|
||||||
if param.shape == param_.shape:
|
|
||||||
if torch.dist(param, param_) < 1e-6:
|
|
||||||
return name
|
|
||||||
else:
|
|
||||||
if torch.dist(param.flatten(), param_.flatten()) < 1e-6:
|
|
||||||
return name
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
|
|
||||||
matched_keys = set()
|
|
||||||
with torch.no_grad():
|
|
||||||
for name in source_state_dict:
|
|
||||||
rename = search_parameter(source_state_dict[name], target_state_dict)
|
|
||||||
if rename is not None:
|
|
||||||
print(f'"{name}": "{rename}",')
|
|
||||||
matched_keys.add(rename)
|
|
||||||
elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
|
|
||||||
length = source_state_dict[name].shape[0] // 3
|
|
||||||
rename = []
|
|
||||||
for i in range(3):
|
|
||||||
rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
|
|
||||||
if None not in rename:
|
|
||||||
print(f'"{name}": {rename},')
|
|
||||||
for rename_ in rename:
|
|
||||||
matched_keys.add(rename_)
|
|
||||||
for name in target_state_dict:
|
|
||||||
if name not in matched_keys:
|
|
||||||
print("Cannot find", name, target_state_dict[name].shape)
|
|
||||||
|
|||||||
diffsynth/models/cog_dit.py (new file, 408 lines)
@@ -0,0 +1,408 @@
|
|||||||
|
import torch
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from .sd3_dit import TimestepEmbeddings
|
||||||
|
from .attention import Attention
|
||||||
|
from .utils import load_state_dict_from_folder
|
||||||
|
from .tiler import TileWorker2Dto3D
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class CogPatchify(torch.nn.Module):
|
||||||
|
def __init__(self, dim_in, dim_out, patch_size) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.proj = torch.nn.Conv3d(dim_in, dim_out, kernel_size=(1, patch_size, patch_size), stride=(1, patch_size, patch_size))
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states = self.proj(hidden_states)
|
||||||
|
hidden_states = rearrange(hidden_states, "B C T H W -> B (T H W) C")
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class CogAdaLayerNorm(torch.nn.Module):
|
||||||
|
def __init__(self, dim, dim_cond, single=False):
|
||||||
|
super().__init__()
|
||||||
|
self.single = single
|
||||||
|
self.linear = torch.nn.Linear(dim_cond, dim * (2 if single else 6))
|
||||||
|
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=True, eps=1e-5)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, hidden_states, prompt_emb, emb):
|
||||||
|
emb = self.linear(torch.nn.functional.silu(emb))
|
||||||
|
if self.single:
|
||||||
|
shift, scale = emb.unsqueeze(1).chunk(2, dim=2)
|
||||||
|
hidden_states = self.norm(hidden_states) * (1 + scale) + shift
|
||||||
|
return hidden_states
|
||||||
|
else:
|
||||||
|
shift_a, scale_a, gate_a, shift_b, scale_b, gate_b = emb.unsqueeze(1).chunk(6, dim=2)
|
||||||
|
hidden_states = self.norm(hidden_states) * (1 + scale_a) + shift_a
|
||||||
|
prompt_emb = self.norm(prompt_emb) * (1 + scale_b) + shift_b
|
||||||
|
return hidden_states, prompt_emb, gate_a, gate_b
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class CogDiTBlock(torch.nn.Module):
|
||||||
|
def __init__(self, dim, dim_cond, num_heads):
|
||||||
|
super().__init__()
|
||||||
|
self.norm1 = CogAdaLayerNorm(dim, dim_cond)
|
||||||
|
self.attn1 = Attention(q_dim=dim, num_heads=48, head_dim=dim//num_heads, bias_q=True, bias_kv=True, bias_out=True)
|
||||||
|
self.norm_q = torch.nn.LayerNorm((dim//num_heads,), eps=1e-06, elementwise_affine=True)
|
||||||
|
self.norm_k = torch.nn.LayerNorm((dim//num_heads,), eps=1e-06, elementwise_affine=True)
|
||||||
|
|
||||||
|
self.norm2 = CogAdaLayerNorm(dim, dim_cond)
|
||||||
|
self.ff = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(dim, dim*4),
|
||||||
|
torch.nn.GELU(approximate="tanh"),
|
||||||
|
torch.nn.Linear(dim*4, dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_emb(self, x, freqs_cis):
|
||||||
|
cos, sin = freqs_cis # [S, D]
|
||||||
|
cos = cos[None, None]
|
||||||
|
sin = sin[None, None]
|
||||||
|
cos, sin = cos.to(x.device), sin.to(x.device)
|
||||||
|
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||||
|
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||||
|
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def process_qkv(self, q, k, v, image_rotary_emb, text_seq_length):
|
||||||
|
q = self.norm_q(q)
|
||||||
|
k = self.norm_k(k)
|
||||||
|
q[:, :, text_seq_length:] = self.apply_rotary_emb(q[:, :, text_seq_length:], image_rotary_emb)
|
||||||
|
k[:, :, text_seq_length:] = self.apply_rotary_emb(k[:, :, text_seq_length:], image_rotary_emb)
|
||||||
|
return q, k, v
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, hidden_states, prompt_emb, time_emb, image_rotary_emb):
|
||||||
|
# Attention
|
||||||
|
norm_hidden_states, norm_encoder_hidden_states, gate_a, gate_b = self.norm1(
|
||||||
|
hidden_states, prompt_emb, time_emb
|
||||||
|
)
|
||||||
|
attention_io = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
||||||
|
attention_io = self.attn1(
|
||||||
|
attention_io,
|
||||||
|
qkv_preprocessor=lambda q, k, v: self.process_qkv(q, k, v, image_rotary_emb, prompt_emb.shape[1])
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = hidden_states + gate_a * attention_io[:, prompt_emb.shape[1]:]
|
||||||
|
prompt_emb = prompt_emb + gate_b * attention_io[:, :prompt_emb.shape[1]]
|
||||||
|
|
||||||
|
# Feed forward
|
||||||
|
norm_hidden_states, norm_encoder_hidden_states, gate_a, gate_b = self.norm2(
|
||||||
|
hidden_states, prompt_emb, time_emb
|
||||||
|
)
|
||||||
|
ff_io = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
||||||
|
ff_io = self.ff(ff_io)
|
||||||
|
|
||||||
|
hidden_states = hidden_states + gate_a * ff_io[:, prompt_emb.shape[1]:]
|
||||||
|
prompt_emb = prompt_emb + gate_b * ff_io[:, :prompt_emb.shape[1]]
|
||||||
|
|
||||||
|
return hidden_states, prompt_emb
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class CogDiT(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.patchify = CogPatchify(16, 3072, 2)
|
||||||
|
self.time_embedder = TimestepEmbeddings(3072, 512)
|
||||||
|
self.context_embedder = torch.nn.Linear(4096, 3072)
|
||||||
|
self.blocks = torch.nn.ModuleList([CogDiTBlock(3072, 512, 48) for _ in range(42)])
|
||||||
|
self.norm_final = torch.nn.LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
|
||||||
|
self.norm_out = CogAdaLayerNorm(3072, 512, single=True)
|
||||||
|
self.proj_out = torch.nn.Linear(3072, 64, bias=True)
|
||||||
|
|
||||||
|
|
||||||
|
def get_resize_crop_region_for_grid(self, src, tgt_width, tgt_height):
|
||||||
|
tw = tgt_width
|
||||||
|
th = tgt_height
|
||||||
|
h, w = src
|
||||||
|
r = h / w
|
||||||
|
if r > (th / tw):
|
||||||
|
resize_height = th
|
||||||
|
resize_width = int(round(th / h * w))
|
||||||
|
else:
|
||||||
|
resize_width = tw
|
||||||
|
resize_height = int(round(tw / w * h))
|
||||||
|
|
||||||
|
crop_top = int(round((th - resize_height) / 2.0))
|
||||||
|
crop_left = int(round((tw - resize_width) / 2.0))
|
||||||
|
|
||||||
|
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
||||||
|
|
||||||
|
|
||||||
|
def get_3d_rotary_pos_embed(
|
||||||
|
self, embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
|
||||||
|
):
|
||||||
|
start, stop = crops_coords
|
||||||
|
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
|
||||||
|
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
|
||||||
|
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
|
||||||
|
|
||||||
|
# Compute dimensions for each axis
|
||||||
|
dim_t = embed_dim // 4
|
||||||
|
dim_h = embed_dim // 8 * 3
|
||||||
|
dim_w = embed_dim // 8 * 3
|
||||||
|
|
||||||
|
# Temporal frequencies
|
||||||
|
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
|
||||||
|
grid_t = torch.from_numpy(grid_t).float()
|
||||||
|
freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
|
||||||
|
freqs_t = freqs_t.repeat_interleave(2, dim=-1)
|
||||||
|
|
||||||
|
# Spatial frequencies for height and width
|
||||||
|
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
|
||||||
|
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
|
||||||
|
grid_h = torch.from_numpy(grid_h).float()
|
||||||
|
grid_w = torch.from_numpy(grid_w).float()
|
||||||
|
freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
|
||||||
|
freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
|
||||||
|
freqs_h = freqs_h.repeat_interleave(2, dim=-1)
|
||||||
|
freqs_w = freqs_w.repeat_interleave(2, dim=-1)
|
||||||
|
|
||||||
|
# Broadcast and concatenate tensors along specified dimension
|
||||||
|
def broadcast(tensors, dim=-1):
|
||||||
|
num_tensors = len(tensors)
|
||||||
|
shape_lens = {len(t.shape) for t in tensors}
|
||||||
|
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
|
||||||
|
shape_len = list(shape_lens)[0]
|
||||||
|
dim = (dim + shape_len) if dim < 0 else dim
|
||||||
|
dims = list(zip(*(list(t.shape) for t in tensors)))
|
||||||
|
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
||||||
|
assert all(
|
||||||
|
[*(len(set(t[1])) <= 2 for t in expandable_dims)]
|
||||||
|
), "invalid dimensions for broadcastable concatenation"
|
||||||
|
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
|
||||||
|
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
|
||||||
|
expanded_dims.insert(dim, (dim, dims[dim]))
|
||||||
|
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
|
||||||
|
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
|
||||||
|
return torch.cat(tensors, dim=dim)
|
||||||
|
|
||||||
|
freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
|
||||||
|
|
||||||
|
t, h, w, d = freqs.shape
|
||||||
|
freqs = freqs.view(t * h * w, d)
|
||||||
|
|
||||||
|
# Generate sine and cosine components
|
||||||
|
sin = freqs.sin()
|
||||||
|
cos = freqs.cos()
|
||||||
|
|
||||||
|
if use_real:
|
||||||
|
return cos, sin
|
||||||
|
else:
|
||||||
|
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
||||||
|
return freqs_cis
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_rotary_positional_embeddings(
|
||||||
|
self,
|
||||||
|
height: int,
|
||||||
|
width: int,
|
||||||
|
num_frames: int,
|
||||||
|
device: torch.device,
|
||||||
|
):
|
||||||
|
grid_height = height // 2
|
||||||
|
grid_width = width // 2
|
||||||
|
base_size_width = 720 // (8 * 2)
|
||||||
|
base_size_height = 480 // (8 * 2)
|
||||||
|
|
||||||
|
grid_crops_coords = self.get_resize_crop_region_for_grid(
|
||||||
|
(grid_height, grid_width), base_size_width, base_size_height
|
||||||
|
)
|
||||||
|
freqs_cos, freqs_sin = self.get_3d_rotary_pos_embed(
|
||||||
|
embed_dim=64,
|
||||||
|
crops_coords=grid_crops_coords,
|
||||||
|
grid_size=(grid_height, grid_width),
|
||||||
|
temporal_size=num_frames,
|
||||||
|
use_real=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
freqs_cos = freqs_cos.to(device=device)
|
||||||
|
freqs_sin = freqs_sin.to(device=device)
|
||||||
|
return freqs_cos, freqs_sin
|
||||||
|
|
||||||
|
|
||||||
|
def unpatchify(self, hidden_states, height, width):
|
||||||
|
hidden_states = rearrange(hidden_states, "B (T H W) (C P Q) -> B C T (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
def build_mask(self, T, H, W, dtype, device, is_bound):
|
||||||
|
t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W)
|
||||||
|
h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W)
|
||||||
|
w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W)
|
||||||
|
border_width = (H + W) // 4
|
||||||
|
pad = torch.ones_like(h) * border_width
|
||||||
|
mask = torch.stack([
|
||||||
|
pad if is_bound[0] else t + 1,
|
||||||
|
pad if is_bound[1] else T - t,
|
||||||
|
pad if is_bound[2] else h + 1,
|
||||||
|
pad if is_bound[3] else H - h,
|
||||||
|
pad if is_bound[4] else w + 1,
|
||||||
|
pad if is_bound[5] else W - w
|
||||||
|
]).min(dim=0).values
|
||||||
|
mask = mask.clip(1, border_width)
|
||||||
|
mask = (mask / border_width).to(dtype=dtype, device=device)
|
||||||
|
mask = rearrange(mask, "T H W -> 1 1 T H W")
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def tiled_forward(self, hidden_states, timestep, prompt_emb, tile_size=(60, 90), tile_stride=(30, 45)):
|
||||||
|
B, C, T, H, W = hidden_states.shape
|
||||||
|
value = torch.zeros((B, C, T, H, W), dtype=hidden_states.dtype, device=hidden_states.device)
|
||||||
|
weight = torch.zeros((B, C, T, H, W), dtype=hidden_states.dtype, device=hidden_states.device)
|
||||||
|
|
||||||
|
# Split tasks
|
||||||
|
tasks = []
|
||||||
|
for h in range(0, H, tile_stride):
|
||||||
|
for w in range(0, W, tile_stride):
|
||||||
|
if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
|
||||||
|
continue
|
||||||
|
h_, w_ = h + tile_size, w + tile_size
|
||||||
|
if h_ > H: h, h_ = max(H - tile_size, 0), H
|
||||||
|
if w_ > W: w, w_ = max(W - tile_size, 0), W
|
||||||
|
tasks.append((h, h_, w, w_))
|
||||||
|
|
||||||
|
# Run
|
||||||
|
for hl, hr, wl, wr in tasks:
|
||||||
|
mask = self.build_mask(
|
||||||
|
value.shape[2], (hr-hl), (wr-wl),
|
||||||
|
hidden_states.dtype, hidden_states.device,
|
||||||
|
is_bound=(True, True, hl==0, hr>=H, wl==0, wr>=W)
|
||||||
|
)
|
||||||
|
model_output = self.forward(hidden_states[:, :, :, hl:hr, wl:wr], timestep, prompt_emb)
|
||||||
|
value[:, :, :, hl:hr, wl:wr] += model_output * mask
|
||||||
|
weight[:, :, :, hl:hr, wl:wr] += mask
|
||||||
|
value = value / weight
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, hidden_states, timestep, prompt_emb, image_rotary_emb=None, tiled=False, tile_size=90, tile_stride=30, use_gradient_checkpointing=False):
|
||||||
|
if tiled:
|
||||||
|
return TileWorker2Dto3D().tiled_forward(
|
||||||
|
forward_fn=lambda x: self.forward(x, timestep, prompt_emb),
|
||||||
|
model_input=hidden_states,
|
||||||
|
tile_size=tile_size, tile_stride=tile_stride,
|
||||||
|
tile_device=hidden_states.device, tile_dtype=hidden_states.dtype,
|
||||||
|
computation_device=self.context_embedder.weight.device, computation_dtype=self.context_embedder.weight.dtype
|
||||||
|
)
|
||||||
|
num_frames, height, width = hidden_states.shape[-3:]
|
||||||
|
if image_rotary_emb is None:
|
||||||
|
image_rotary_emb = self.prepare_rotary_positional_embeddings(height, width, num_frames, device=self.context_embedder.weight.device)
|
||||||
|
hidden_states = self.patchify(hidden_states)
|
||||||
|
time_emb = self.time_embedder(timestep, dtype=hidden_states.dtype)
|
||||||
|
prompt_emb = self.context_embedder(prompt_emb)
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
return module(*inputs)
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
for block in self.blocks:
|
||||||
|
if self.training and use_gradient_checkpointing:
|
||||||
|
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(block),
|
||||||
|
hidden_states, prompt_emb, time_emb, image_rotary_emb,
|
||||||
|
use_reentrant=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states, prompt_emb = block(hidden_states, prompt_emb, time_emb, image_rotary_emb)
|
||||||
|
|
||||||
|
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
||||||
|
hidden_states = self.norm_final(hidden_states)
|
||||||
|
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
|
||||||
|
hidden_states = self.norm_out(hidden_states, prompt_emb, time_emb)
|
||||||
|
hidden_states = self.proj_out(hidden_states)
|
||||||
|
hidden_states = self.unpatchify(hidden_states, height, width)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return CogDiTStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_pretrained(file_path, torch_dtype=torch.bfloat16):
|
||||||
|
model = CogDiT().to(torch_dtype)
|
||||||
|
state_dict = load_state_dict_from_folder(file_path, torch_dtype=torch_dtype)
|
||||||
|
state_dict = CogDiT.state_dict_converter().from_diffusers(state_dict)
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class CogDiTStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        rename_dict = {
            "patch_embed.proj.weight": "patchify.proj.weight",
            "patch_embed.proj.bias": "patchify.proj.bias",
            "patch_embed.text_proj.weight": "context_embedder.weight",
            "patch_embed.text_proj.bias": "context_embedder.bias",
            "time_embedding.linear_1.weight": "time_embedder.timestep_embedder.0.weight",
            "time_embedding.linear_1.bias": "time_embedder.timestep_embedder.0.bias",
            "time_embedding.linear_2.weight": "time_embedder.timestep_embedder.2.weight",
            "time_embedding.linear_2.bias": "time_embedder.timestep_embedder.2.bias",

            "norm_final.weight": "norm_final.weight",
            "norm_final.bias": "norm_final.bias",
            "norm_out.linear.weight": "norm_out.linear.weight",
            "norm_out.linear.bias": "norm_out.linear.bias",
            "norm_out.norm.weight": "norm_out.norm.weight",
            "norm_out.norm.bias": "norm_out.norm.bias",
            "proj_out.weight": "proj_out.weight",
            "proj_out.bias": "proj_out.bias",
        }
        suffix_dict = {
            "norm1.linear.weight": "norm1.linear.weight",
            "norm1.linear.bias": "norm1.linear.bias",
            "norm1.norm.weight": "norm1.norm.weight",
            "norm1.norm.bias": "norm1.norm.bias",
            "attn1.norm_q.weight": "norm_q.weight",
            "attn1.norm_q.bias": "norm_q.bias",
            "attn1.norm_k.weight": "norm_k.weight",
            "attn1.norm_k.bias": "norm_k.bias",
            "attn1.to_q.weight": "attn1.to_q.weight",
            "attn1.to_q.bias": "attn1.to_q.bias",
            "attn1.to_k.weight": "attn1.to_k.weight",
            "attn1.to_k.bias": "attn1.to_k.bias",
            "attn1.to_v.weight": "attn1.to_v.weight",
            "attn1.to_v.bias": "attn1.to_v.bias",
            "attn1.to_out.0.weight": "attn1.to_out.weight",
            "attn1.to_out.0.bias": "attn1.to_out.bias",
            "norm2.linear.weight": "norm2.linear.weight",
            "norm2.linear.bias": "norm2.linear.bias",
            "norm2.norm.weight": "norm2.norm.weight",
            "norm2.norm.bias": "norm2.norm.bias",
            "ff.net.0.proj.weight": "ff.0.weight",
            "ff.net.0.proj.bias": "ff.0.bias",
            "ff.net.2.weight": "ff.2.weight",
            "ff.net.2.bias": "ff.2.bias",
        }
        state_dict_ = {}
        for name, param in state_dict.items():
            if name in rename_dict:
                if name == "patch_embed.proj.weight":
                    param = param.unsqueeze(2)
                state_dict_[rename_dict[name]] = param
            else:
                names = name.split(".")
                if names[0] == "transformer_blocks":
                    suffix = ".".join(names[2:])
                    state_dict_[f"blocks.{names[1]}." + suffix_dict[suffix]] = param
        return state_dict_

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)

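# Hedged usage sketch (not part of the original file): the checkpoint folder, latent layout
# and prompt-embedding width below are illustrative assumptions only, shown to make the
# from_pretrained / forward flow above concrete.
def _demo_cog_dit_forward():
    dit = CogDiT.from_pretrained("models/CogVideoX/transformer", torch_dtype=torch.bfloat16).eval()
    latents = torch.randn(1, 16, 9, 60, 90, dtype=torch.bfloat16)   # (B, C, T, H, W) video latents, assumed layout
    prompt_emb = torch.randn(1, 226, 4096, dtype=torch.bfloat16)    # text-encoder output, assumed width
    timestep = torch.tensor([500])
    with torch.no_grad():
        noise_pred = dit(latents, timestep, prompt_emb)
    return noise_pred.shape  # expected to match the latent shape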
diffsynth/models/cog_vae.py (new file, 518 lines)
@@ -0,0 +1,518 @@
import torch
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from .tiler import TileWorker2Dto3D
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Downsample3D(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
kernel_size: int = 3,
|
||||||
|
stride: int = 2,
|
||||||
|
padding: int = 0,
|
||||||
|
compress_time: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||||
|
self.compress_time = compress_time
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor:
|
||||||
|
if self.compress_time:
|
||||||
|
batch_size, channels, frames, height, width = x.shape
|
||||||
|
|
||||||
|
# (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
|
||||||
|
x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
|
||||||
|
|
||||||
|
if x.shape[-1] % 2 == 1:
|
||||||
|
x_first, x_rest = x[..., 0], x[..., 1:]
|
||||||
|
if x_rest.shape[-1] > 0:
|
||||||
|
# (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
|
||||||
|
x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
|
||||||
|
|
||||||
|
x = torch.cat([x_first[..., None], x_rest], dim=-1)
|
||||||
|
# (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
|
||||||
|
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
|
||||||
|
else:
|
||||||
|
# (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
|
||||||
|
x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
|
||||||
|
# (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
|
||||||
|
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
|
||||||
|
|
||||||
|
# Pad the tensor
|
||||||
|
pad = (0, 1, 0, 1)
|
||||||
|
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||||
|
batch_size, channels, frames, height, width = x.shape
|
||||||
|
# (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
|
||||||
|
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
|
||||||
|
x = self.conv(x)
|
||||||
|
# (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
|
||||||
|
x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
|
||||||
|
return x
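# Hedged shape check for the temporal branch above; the channel count and frame count are
# arbitrary assumptions (CogVideoX-style clips use an odd frame count of the form 4k + 1).
def _demo_downsample3d_shapes():
    down = Downsample3D(in_channels=8, out_channels=8, compress_time=True)
    x = torch.randn(1, 8, 9, 32, 32)   # (B, C, T, H, W) with odd T
    y = down(x, None)                  # the second argument (xq) is unused by this module
    return y.shape                     # torch.Size([1, 8, 5, 16, 16]): T -> T // 2 + 1, H and W halved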
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample3D(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
kernel_size: int = 3,
|
||||||
|
stride: int = 1,
|
||||||
|
padding: int = 1,
|
||||||
|
compress_time: bool = False,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||||
|
self.compress_time = compress_time
|
||||||
|
|
||||||
|
def forward(self, inputs: torch.Tensor, xq: torch.Tensor) -> torch.Tensor:
|
||||||
|
if self.compress_time:
|
||||||
|
if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
|
||||||
|
# split first frame
|
||||||
|
x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
|
||||||
|
|
||||||
|
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0)
|
||||||
|
x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0)
|
||||||
|
x_first = x_first[:, :, None, :, :]
|
||||||
|
inputs = torch.cat([x_first, x_rest], dim=2)
|
||||||
|
elif inputs.shape[2] > 1:
|
||||||
|
inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
|
||||||
|
else:
|
||||||
|
inputs = inputs.squeeze(2)
|
||||||
|
inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
|
||||||
|
inputs = inputs[:, :, None, :, :]
|
||||||
|
else:
|
||||||
|
# only interpolate 2D
|
||||||
|
b, c, t, h, w = inputs.shape
|
||||||
|
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||||
|
inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
|
||||||
|
inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
|
||||||
|
|
||||||
|
b, c, t, h, w = inputs.shape
|
||||||
|
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||||
|
inputs = self.conv(inputs)
|
||||||
|
inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
|
||||||
|
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class CogVideoXSpatialNorm3D(torch.nn.Module):
|
||||||
|
def __init__(self, f_channels, zq_channels, groups):
|
||||||
|
super().__init__()
|
||||||
|
self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
|
||||||
|
self.conv_y = torch.nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
||||||
|
self.conv_b = torch.nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
|
||||||
|
if f.shape[2] > 1 and f.shape[2] % 2 == 1:
|
||||||
|
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
|
||||||
|
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
|
||||||
|
z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
|
||||||
|
z_first = torch.nn.functional.interpolate(z_first, size=f_first_size)
|
||||||
|
z_rest = torch.nn.functional.interpolate(z_rest, size=f_rest_size)
|
||||||
|
zq = torch.cat([z_first, z_rest], dim=2)
|
||||||
|
else:
|
||||||
|
zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:])
|
||||||
|
|
||||||
|
norm_f = self.norm_layer(f)
|
||||||
|
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
|
||||||
|
return new_f
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Resnet3DBlock(torch.nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, spatial_norm_dim, groups, eps=1e-6, use_conv_shortcut=False):
|
||||||
|
super().__init__()
|
||||||
|
self.nonlinearity = torch.nn.SiLU()
|
||||||
|
if spatial_norm_dim is None:
|
||||||
|
self.norm1 = torch.nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
|
||||||
|
self.norm2 = torch.nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
|
||||||
|
else:
|
||||||
|
self.norm1 = CogVideoXSpatialNorm3D(in_channels, spatial_norm_dim, groups)
|
||||||
|
self.norm2 = CogVideoXSpatialNorm3D(out_channels, spatial_norm_dim, groups)
|
||||||
|
|
||||||
|
self.conv1 = CachedConv3d(in_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
|
||||||
|
|
||||||
|
self.conv2 = CachedConv3d(out_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
|
||||||
|
|
||||||
|
if in_channels != out_channels:
|
||||||
|
if use_conv_shortcut:
|
||||||
|
self.conv_shortcut = CachedConv3d(in_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
|
||||||
|
else:
|
||||||
|
self.conv_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1)
|
||||||
|
else:
|
||||||
|
self.conv_shortcut = lambda x: x
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, hidden_states, zq):
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.norm1(hidden_states, zq) if isinstance(self.norm1, CogVideoXSpatialNorm3D) else self.norm1(hidden_states)
|
||||||
|
hidden_states = self.nonlinearity(hidden_states)
|
||||||
|
hidden_states = self.conv1(hidden_states)
|
||||||
|
|
||||||
|
hidden_states = self.norm2(hidden_states, zq) if isinstance(self.norm2, CogVideoXSpatialNorm3D) else self.norm2(hidden_states)
|
||||||
|
hidden_states = self.nonlinearity(hidden_states)
|
||||||
|
hidden_states = self.conv2(hidden_states)
|
||||||
|
|
||||||
|
hidden_states = hidden_states + self.conv_shortcut(residual)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class CachedConv3d(torch.nn.Conv3d):
    # Conv3d without temporal padding; the last two input frames are cached so a long clip
    # can be fed chunk by chunk while staying consistent across chunk boundaries.
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.cached_tensor = None

    def clear_cache(self):
        self.cached_tensor = None

    def forward(self, input: torch.Tensor, use_cache=True) -> torch.Tensor:
        if use_cache:
            if self.cached_tensor is None:
                # First chunk: replicate the first frame twice as a causal pre-pad.
                self.cached_tensor = torch.concat([input[:, :, :1]] * 2, dim=2)
            input = torch.concat([self.cached_tensor, input], dim=2)
            self.cached_tensor = input[:, :, -2:]
        return super().forward(input)
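# Hedged consistency check (channel sizes and the 4-frame chunking are assumptions): the
# two-frame cache lets a clip be convolved chunk by chunk with the same result as one pass.
def _demo_cached_conv3d_chunking():
    conv = CachedConv3d(4, 4, kernel_size=3, padding=(0, 1, 1))
    x = torch.randn(1, 4, 8, 16, 16)
    y_full = conv(x)                   # whole clip in one call
    conv.clear_cache()
    y_chunked = torch.cat([conv(x[:, :, :4]), conv(x[:, :, 4:])], dim=2)
    conv.clear_cache()
    return torch.allclose(y_full, y_chunked, atol=1e-5)  # True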
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class CogVAEDecoder(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.scaling_factor = 0.7
|
||||||
|
self.conv_in = CachedConv3d(16, 512, kernel_size=3, stride=1, padding=(0, 1, 1))
|
||||||
|
|
||||||
|
self.blocks = torch.nn.ModuleList([
|
||||||
|
Resnet3DBlock(512, 512, 16, 32),
|
||||||
|
Resnet3DBlock(512, 512, 16, 32),
|
||||||
|
Resnet3DBlock(512, 512, 16, 32),
|
||||||
|
Resnet3DBlock(512, 512, 16, 32),
|
||||||
|
Resnet3DBlock(512, 512, 16, 32),
|
||||||
|
Resnet3DBlock(512, 512, 16, 32),
|
||||||
|
Upsample3D(512, 512, compress_time=True),
|
||||||
|
Resnet3DBlock(512, 256, 16, 32),
|
||||||
|
Resnet3DBlock(256, 256, 16, 32),
|
||||||
|
Resnet3DBlock(256, 256, 16, 32),
|
||||||
|
Resnet3DBlock(256, 256, 16, 32),
|
||||||
|
Upsample3D(256, 256, compress_time=True),
|
||||||
|
Resnet3DBlock(256, 256, 16, 32),
|
||||||
|
Resnet3DBlock(256, 256, 16, 32),
|
||||||
|
Resnet3DBlock(256, 256, 16, 32),
|
||||||
|
Resnet3DBlock(256, 256, 16, 32),
|
||||||
|
Upsample3D(256, 256, compress_time=False),
|
||||||
|
Resnet3DBlock(256, 128, 16, 32),
|
||||||
|
Resnet3DBlock(128, 128, 16, 32),
|
||||||
|
Resnet3DBlock(128, 128, 16, 32),
|
||||||
|
Resnet3DBlock(128, 128, 16, 32),
|
||||||
|
])
|
||||||
|
|
||||||
|
self.norm_out = CogVideoXSpatialNorm3D(128, 16, 32)
|
||||||
|
self.conv_act = torch.nn.SiLU()
|
||||||
|
self.conv_out = CachedConv3d(128, 3, kernel_size=3, stride=1, padding=(0, 1, 1))
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, sample):
|
||||||
|
sample = sample / self.scaling_factor
|
||||||
|
hidden_states = self.conv_in(sample)
|
||||||
|
|
||||||
|
for block in self.blocks:
|
||||||
|
hidden_states = block(hidden_states, sample)
|
||||||
|
|
||||||
|
hidden_states = self.norm_out(hidden_states, sample)
|
||||||
|
hidden_states = self.conv_act(hidden_states)
|
||||||
|
hidden_states = self.conv_out(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
def decode_video(self, sample, tiled=True, tile_size=(60, 90), tile_stride=(30, 45), progress_bar=lambda x:x):
|
||||||
|
if tiled:
|
||||||
|
B, C, T, H, W = sample.shape
|
||||||
|
return TileWorker2Dto3D().tiled_forward(
|
||||||
|
forward_fn=lambda x: self.decode_small_video(x),
|
||||||
|
model_input=sample,
|
||||||
|
tile_size=tile_size, tile_stride=tile_stride,
|
||||||
|
tile_device=sample.device, tile_dtype=sample.dtype,
|
||||||
|
computation_device=sample.device, computation_dtype=sample.dtype,
|
||||||
|
scales=(3/16, (T//2*8+T%2)/T, 8, 8),
|
||||||
|
progress_bar=progress_bar
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self.decode_small_video(sample)
|
||||||
|
|
||||||
|
|
||||||
|
def decode_small_video(self, sample):
|
||||||
|
B, C, T, H, W = sample.shape
|
||||||
|
computation_device = self.conv_in.weight.device
|
||||||
|
computation_dtype = self.conv_in.weight.dtype
|
||||||
|
value = []
|
||||||
|
for i in range(T//2):
|
||||||
|
tl = i*2 + T%2 - (T%2 and i==0)
|
||||||
|
tr = i*2 + 2 + T%2
|
||||||
|
model_input = sample[:, :, tl: tr, :, :].to(dtype=computation_dtype, device=computation_device)
|
||||||
|
model_output = self.forward(model_input).to(dtype=sample.dtype, device=sample.device)
|
||||||
|
value.append(model_output)
|
||||||
|
value = torch.concat(value, dim=2)
|
||||||
|
for name, module in self.named_modules():
|
||||||
|
if isinstance(module, CachedConv3d):
|
||||||
|
module.clear_cache()
|
||||||
|
return value
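# Trace of the slicing used above (not part of the file): latent frames are decoded two at a
# time, and when T is odd the leading frame is folded into the first chunk.
def _demo_decode_chunk_indices(T=5):
    chunks = []
    for i in range(T // 2):
        tl = i * 2 + T % 2 - (T % 2 and i == 0)
        tr = i * 2 + 2 + T % 2
        chunks.append((tl, tr))
    return chunks  # T=4 -> [(0, 2), (2, 4)]; T=5 -> [(0, 3), (3, 5)]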
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return CogVAEDecoderStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class CogVAEEncoder(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.scaling_factor = 0.7
|
||||||
|
self.conv_in = CachedConv3d(3, 128, kernel_size=3, stride=1, padding=(0, 1, 1))
|
||||||
|
|
||||||
|
self.blocks = torch.nn.ModuleList([
|
||||||
|
Resnet3DBlock(128, 128, None, 32),
|
||||||
|
Resnet3DBlock(128, 128, None, 32),
|
||||||
|
Resnet3DBlock(128, 128, None, 32),
|
||||||
|
Downsample3D(128, 128, compress_time=True),
|
||||||
|
Resnet3DBlock(128, 256, None, 32),
|
||||||
|
Resnet3DBlock(256, 256, None, 32),
|
||||||
|
Resnet3DBlock(256, 256, None, 32),
|
||||||
|
Downsample3D(256, 256, compress_time=True),
|
||||||
|
Resnet3DBlock(256, 256, None, 32),
|
||||||
|
Resnet3DBlock(256, 256, None, 32),
|
||||||
|
Resnet3DBlock(256, 256, None, 32),
|
||||||
|
Downsample3D(256, 256, compress_time=False),
|
||||||
|
Resnet3DBlock(256, 512, None, 32),
|
||||||
|
Resnet3DBlock(512, 512, None, 32),
|
||||||
|
Resnet3DBlock(512, 512, None, 32),
|
||||||
|
Resnet3DBlock(512, 512, None, 32),
|
||||||
|
Resnet3DBlock(512, 512, None, 32),
|
||||||
|
])
|
||||||
|
|
||||||
|
self.norm_out = torch.nn.GroupNorm(32, 512, eps=1e-06, affine=True)
|
||||||
|
self.conv_act = torch.nn.SiLU()
|
||||||
|
self.conv_out = CachedConv3d(512, 32, kernel_size=3, stride=1, padding=(0, 1, 1))
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, sample):
|
||||||
|
hidden_states = self.conv_in(sample)
|
||||||
|
|
||||||
|
for block in self.blocks:
|
||||||
|
hidden_states = block(hidden_states, sample)
|
||||||
|
|
||||||
|
hidden_states = self.norm_out(hidden_states)
|
||||||
|
hidden_states = self.conv_act(hidden_states)
|
||||||
|
hidden_states = self.conv_out(hidden_states)[:, :16]
|
||||||
|
hidden_states = hidden_states * self.scaling_factor
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
def encode_video(self, sample, tiled=True, tile_size=(60, 90), tile_stride=(30, 45), progress_bar=lambda x:x):
|
||||||
|
if tiled:
|
||||||
|
B, C, T, H, W = sample.shape
|
||||||
|
return TileWorker2Dto3D().tiled_forward(
|
||||||
|
forward_fn=lambda x: self.encode_small_video(x),
|
||||||
|
model_input=sample,
|
||||||
|
tile_size=(i * 8 for i in tile_size), tile_stride=(i * 8 for i in tile_stride),
|
||||||
|
tile_device=sample.device, tile_dtype=sample.dtype,
|
||||||
|
computation_device=sample.device, computation_dtype=sample.dtype,
|
||||||
|
scales=(16/3, (T//4+T%2)/T, 1/8, 1/8),
|
||||||
|
progress_bar=progress_bar
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self.encode_small_video(sample)
|
||||||
|
|
||||||
|
|
||||||
|
def encode_small_video(self, sample):
|
||||||
|
B, C, T, H, W = sample.shape
|
||||||
|
computation_device = self.conv_in.weight.device
|
||||||
|
computation_dtype = self.conv_in.weight.dtype
|
||||||
|
value = []
|
||||||
|
for i in range(T//8):
|
||||||
|
t = i*8 + T%2 - (T%2 and i==0)
|
||||||
|
t_ = i*8 + 8 + T%2
|
||||||
|
model_input = sample[:, :, t: t_, :, :].to(dtype=computation_dtype, device=computation_device)
|
||||||
|
model_output = self.forward(model_input).to(dtype=sample.dtype, device=sample.device)
|
||||||
|
value.append(model_output)
|
||||||
|
value = torch.concat(value, dim=2)
|
||||||
|
for name, module in self.named_modules():
|
||||||
|
if isinstance(module, CachedConv3d):
|
||||||
|
module.clear_cache()
|
||||||
|
return value
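# Hedged shape check mirroring the `scales` tuple passed above (the input size is an arbitrary
# assumption): T pixel frames become T // 4 + T % 2 latent frames, spatial dims shrink by 8.
def _demo_cog_vae_encoder_shapes():
    encoder = CogVAEEncoder().eval()
    video = torch.randn(1, 3, 9, 64, 64)   # (B, 3, T, H, W) with T = 4k + 1
    with torch.no_grad():
        latents = encoder.encode_small_video(video)
    return latents.shape                    # torch.Size([1, 16, 3, 8, 8])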
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return CogVAEEncoderStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class CogVAEEncoderStateDictConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def from_diffusers(self, state_dict):
|
||||||
|
rename_dict = {
|
||||||
|
"encoder.conv_in.conv.weight": "conv_in.weight",
|
||||||
|
"encoder.conv_in.conv.bias": "conv_in.bias",
|
||||||
|
"encoder.down_blocks.0.downsamplers.0.conv.weight": "blocks.3.conv.weight",
|
||||||
|
"encoder.down_blocks.0.downsamplers.0.conv.bias": "blocks.3.conv.bias",
|
||||||
|
"encoder.down_blocks.1.downsamplers.0.conv.weight": "blocks.7.conv.weight",
|
||||||
|
"encoder.down_blocks.1.downsamplers.0.conv.bias": "blocks.7.conv.bias",
|
||||||
|
"encoder.down_blocks.2.downsamplers.0.conv.weight": "blocks.11.conv.weight",
|
||||||
|
"encoder.down_blocks.2.downsamplers.0.conv.bias": "blocks.11.conv.bias",
|
||||||
|
"encoder.norm_out.weight": "norm_out.weight",
|
||||||
|
"encoder.norm_out.bias": "norm_out.bias",
|
||||||
|
"encoder.conv_out.conv.weight": "conv_out.weight",
|
||||||
|
"encoder.conv_out.conv.bias": "conv_out.bias",
|
||||||
|
}
|
||||||
|
prefix_dict = {
|
||||||
|
"encoder.down_blocks.0.resnets.0.": "blocks.0.",
|
||||||
|
"encoder.down_blocks.0.resnets.1.": "blocks.1.",
|
||||||
|
"encoder.down_blocks.0.resnets.2.": "blocks.2.",
|
||||||
|
"encoder.down_blocks.1.resnets.0.": "blocks.4.",
|
||||||
|
"encoder.down_blocks.1.resnets.1.": "blocks.5.",
|
||||||
|
"encoder.down_blocks.1.resnets.2.": "blocks.6.",
|
||||||
|
"encoder.down_blocks.2.resnets.0.": "blocks.8.",
|
||||||
|
"encoder.down_blocks.2.resnets.1.": "blocks.9.",
|
||||||
|
"encoder.down_blocks.2.resnets.2.": "blocks.10.",
|
||||||
|
"encoder.down_blocks.3.resnets.0.": "blocks.12.",
|
||||||
|
"encoder.down_blocks.3.resnets.1.": "blocks.13.",
|
||||||
|
"encoder.down_blocks.3.resnets.2.": "blocks.14.",
|
||||||
|
"encoder.mid_block.resnets.0.": "blocks.15.",
|
||||||
|
"encoder.mid_block.resnets.1.": "blocks.16.",
|
||||||
|
}
|
||||||
|
suffix_dict = {
|
||||||
|
"norm1.norm_layer.weight": "norm1.norm_layer.weight",
|
||||||
|
"norm1.norm_layer.bias": "norm1.norm_layer.bias",
|
||||||
|
"norm1.conv_y.conv.weight": "norm1.conv_y.weight",
|
||||||
|
"norm1.conv_y.conv.bias": "norm1.conv_y.bias",
|
||||||
|
"norm1.conv_b.conv.weight": "norm1.conv_b.weight",
|
||||||
|
"norm1.conv_b.conv.bias": "norm1.conv_b.bias",
|
||||||
|
"norm2.norm_layer.weight": "norm2.norm_layer.weight",
|
||||||
|
"norm2.norm_layer.bias": "norm2.norm_layer.bias",
|
||||||
|
"norm2.conv_y.conv.weight": "norm2.conv_y.weight",
|
||||||
|
"norm2.conv_y.conv.bias": "norm2.conv_y.bias",
|
||||||
|
"norm2.conv_b.conv.weight": "norm2.conv_b.weight",
|
||||||
|
"norm2.conv_b.conv.bias": "norm2.conv_b.bias",
|
||||||
|
"conv1.conv.weight": "conv1.weight",
|
||||||
|
"conv1.conv.bias": "conv1.bias",
|
||||||
|
"conv2.conv.weight": "conv2.weight",
|
||||||
|
"conv2.conv.bias": "conv2.bias",
|
||||||
|
"conv_shortcut.weight": "conv_shortcut.weight",
|
||||||
|
"conv_shortcut.bias": "conv_shortcut.bias",
|
||||||
|
"norm1.weight": "norm1.weight",
|
||||||
|
"norm1.bias": "norm1.bias",
|
||||||
|
"norm2.weight": "norm2.weight",
|
||||||
|
"norm2.bias": "norm2.bias",
|
||||||
|
}
|
||||||
|
state_dict_ = {}
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
if name in rename_dict:
|
||||||
|
state_dict_[rename_dict[name]] = param
|
||||||
|
else:
|
||||||
|
for prefix in prefix_dict:
|
||||||
|
if name.startswith(prefix):
|
||||||
|
suffix = name[len(prefix):]
|
||||||
|
state_dict_[prefix_dict[prefix] + suffix_dict[suffix]] = param
|
||||||
|
return state_dict_
|
||||||
|
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
return self.from_diffusers(state_dict)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class CogVAEDecoderStateDictConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def from_diffusers(self, state_dict):
|
||||||
|
rename_dict = {
|
||||||
|
"decoder.conv_in.conv.weight": "conv_in.weight",
|
||||||
|
"decoder.conv_in.conv.bias": "conv_in.bias",
|
||||||
|
"decoder.up_blocks.0.upsamplers.0.conv.weight": "blocks.6.conv.weight",
|
||||||
|
"decoder.up_blocks.0.upsamplers.0.conv.bias": "blocks.6.conv.bias",
|
||||||
|
"decoder.up_blocks.1.upsamplers.0.conv.weight": "blocks.11.conv.weight",
|
||||||
|
"decoder.up_blocks.1.upsamplers.0.conv.bias": "blocks.11.conv.bias",
|
||||||
|
"decoder.up_blocks.2.upsamplers.0.conv.weight": "blocks.16.conv.weight",
|
||||||
|
"decoder.up_blocks.2.upsamplers.0.conv.bias": "blocks.16.conv.bias",
|
||||||
|
"decoder.norm_out.norm_layer.weight": "norm_out.norm_layer.weight",
|
||||||
|
"decoder.norm_out.norm_layer.bias": "norm_out.norm_layer.bias",
|
||||||
|
"decoder.norm_out.conv_y.conv.weight": "norm_out.conv_y.weight",
|
||||||
|
"decoder.norm_out.conv_y.conv.bias": "norm_out.conv_y.bias",
|
||||||
|
"decoder.norm_out.conv_b.conv.weight": "norm_out.conv_b.weight",
|
||||||
|
"decoder.norm_out.conv_b.conv.bias": "norm_out.conv_b.bias",
|
||||||
|
"decoder.conv_out.conv.weight": "conv_out.weight",
|
||||||
|
"decoder.conv_out.conv.bias": "conv_out.bias"
|
||||||
|
}
|
||||||
|
prefix_dict = {
|
||||||
|
"decoder.mid_block.resnets.0.": "blocks.0.",
|
||||||
|
"decoder.mid_block.resnets.1.": "blocks.1.",
|
||||||
|
"decoder.up_blocks.0.resnets.0.": "blocks.2.",
|
||||||
|
"decoder.up_blocks.0.resnets.1.": "blocks.3.",
|
||||||
|
"decoder.up_blocks.0.resnets.2.": "blocks.4.",
|
||||||
|
"decoder.up_blocks.0.resnets.3.": "blocks.5.",
|
||||||
|
"decoder.up_blocks.1.resnets.0.": "blocks.7.",
|
||||||
|
"decoder.up_blocks.1.resnets.1.": "blocks.8.",
|
||||||
|
"decoder.up_blocks.1.resnets.2.": "blocks.9.",
|
||||||
|
"decoder.up_blocks.1.resnets.3.": "blocks.10.",
|
||||||
|
"decoder.up_blocks.2.resnets.0.": "blocks.12.",
|
||||||
|
"decoder.up_blocks.2.resnets.1.": "blocks.13.",
|
||||||
|
"decoder.up_blocks.2.resnets.2.": "blocks.14.",
|
||||||
|
"decoder.up_blocks.2.resnets.3.": "blocks.15.",
|
||||||
|
"decoder.up_blocks.3.resnets.0.": "blocks.17.",
|
||||||
|
"decoder.up_blocks.3.resnets.1.": "blocks.18.",
|
||||||
|
"decoder.up_blocks.3.resnets.2.": "blocks.19.",
|
||||||
|
"decoder.up_blocks.3.resnets.3.": "blocks.20.",
|
||||||
|
}
|
||||||
|
suffix_dict = {
|
||||||
|
"norm1.norm_layer.weight": "norm1.norm_layer.weight",
|
||||||
|
"norm1.norm_layer.bias": "norm1.norm_layer.bias",
|
||||||
|
"norm1.conv_y.conv.weight": "norm1.conv_y.weight",
|
||||||
|
"norm1.conv_y.conv.bias": "norm1.conv_y.bias",
|
||||||
|
"norm1.conv_b.conv.weight": "norm1.conv_b.weight",
|
||||||
|
"norm1.conv_b.conv.bias": "norm1.conv_b.bias",
|
||||||
|
"norm2.norm_layer.weight": "norm2.norm_layer.weight",
|
||||||
|
"norm2.norm_layer.bias": "norm2.norm_layer.bias",
|
||||||
|
"norm2.conv_y.conv.weight": "norm2.conv_y.weight",
|
||||||
|
"norm2.conv_y.conv.bias": "norm2.conv_y.bias",
|
||||||
|
"norm2.conv_b.conv.weight": "norm2.conv_b.weight",
|
||||||
|
"norm2.conv_b.conv.bias": "norm2.conv_b.bias",
|
||||||
|
"conv1.conv.weight": "conv1.weight",
|
||||||
|
"conv1.conv.bias": "conv1.bias",
|
||||||
|
"conv2.conv.weight": "conv2.weight",
|
||||||
|
"conv2.conv.bias": "conv2.bias",
|
||||||
|
"conv_shortcut.weight": "conv_shortcut.weight",
|
||||||
|
"conv_shortcut.bias": "conv_shortcut.bias",
|
||||||
|
}
|
||||||
|
state_dict_ = {}
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
if name in rename_dict:
|
||||||
|
state_dict_[rename_dict[name]] = param
|
||||||
|
else:
|
||||||
|
for prefix in prefix_dict:
|
||||||
|
if name.startswith(prefix):
|
||||||
|
suffix = name[len(prefix):]
|
||||||
|
state_dict_[prefix_dict[prefix] + suffix_dict[suffix]] = param
|
||||||
|
return state_dict_
|
||||||
|
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
return self.from_diffusers(state_dict)
|
||||||
|
|
||||||
diffsynth/models/downloader.py (new file, 111 lines)
@@ -0,0 +1,111 @@
from huggingface_hub import hf_hub_download
|
||||||
|
from modelscope import snapshot_download
|
||||||
|
import os, shutil
|
||||||
|
from typing_extensions import Literal, TypeAlias
|
||||||
|
from typing import List
|
||||||
|
from ..configs.model_config import preset_models_on_huggingface, preset_models_on_modelscope, Preset_model_id
|
||||||
|
|
||||||
|
|
||||||
|
def download_from_modelscope(model_id, origin_file_path, local_dir):
|
||||||
|
os.makedirs(local_dir, exist_ok=True)
|
||||||
|
file_name = os.path.basename(origin_file_path)
|
||||||
|
if file_name in os.listdir(local_dir):
|
||||||
|
print(f" {file_name} has been already in {local_dir}.")
|
||||||
|
else:
|
||||||
|
print(f" Start downloading {os.path.join(local_dir, file_name)}")
|
||||||
|
snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
|
||||||
|
downloaded_file_path = os.path.join(local_dir, origin_file_path)
|
||||||
|
target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
|
||||||
|
if downloaded_file_path != target_file_path:
|
||||||
|
shutil.move(downloaded_file_path, target_file_path)
|
||||||
|
shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
|
||||||
|
|
||||||
|
|
||||||
|
def download_from_huggingface(model_id, origin_file_path, local_dir):
|
||||||
|
os.makedirs(local_dir, exist_ok=True)
|
||||||
|
file_name = os.path.basename(origin_file_path)
|
||||||
|
if file_name in os.listdir(local_dir):
|
||||||
|
print(f" {file_name} has been already in {local_dir}.")
|
||||||
|
else:
|
||||||
|
print(f" Start downloading {os.path.join(local_dir, file_name)}")
|
||||||
|
hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
|
||||||
|
downloaded_file_path = os.path.join(local_dir, origin_file_path)
|
||||||
|
target_file_path = os.path.join(local_dir, file_name)
|
||||||
|
if downloaded_file_path != target_file_path:
|
||||||
|
shutil.move(downloaded_file_path, target_file_path)
|
||||||
|
shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
|
||||||
|
|
||||||
|
|
||||||
|
Preset_model_website: TypeAlias = Literal[
|
||||||
|
"HuggingFace",
|
||||||
|
"ModelScope",
|
||||||
|
]
|
||||||
|
website_to_preset_models = {
|
||||||
|
"HuggingFace": preset_models_on_huggingface,
|
||||||
|
"ModelScope": preset_models_on_modelscope,
|
||||||
|
}
|
||||||
|
website_to_download_fn = {
|
||||||
|
"HuggingFace": download_from_huggingface,
|
||||||
|
"ModelScope": download_from_modelscope,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def download_customized_models(
|
||||||
|
model_id,
|
||||||
|
origin_file_path,
|
||||||
|
local_dir,
|
||||||
|
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
|
||||||
|
):
|
||||||
|
downloaded_files = []
|
||||||
|
for website in downloading_priority:
|
||||||
|
# Check if the file is downloaded.
|
||||||
|
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
|
||||||
|
if file_to_download in downloaded_files:
|
||||||
|
continue
|
||||||
|
# Download
|
||||||
|
website_to_download_fn[website](model_id, origin_file_path, local_dir)
|
||||||
|
if os.path.basename(origin_file_path) in os.listdir(local_dir):
|
||||||
|
downloaded_files.append(file_to_download)
|
||||||
|
return downloaded_files
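# Hedged usage sketch; the model id and file path are placeholders, not values from this diff.
def _demo_download_customized_models():
    return download_customized_models(
        model_id="some-org/some-model",
        origin_file_path="folder/model.safetensors",
        local_dir="models/custom",
        downloading_priority=["ModelScope", "HuggingFace"],
    )  # returns e.g. ["models/custom/model.safetensors"] once the file exists locally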
|
||||||
|
|
||||||
|
|
||||||
|
def download_models(
|
||||||
|
model_id_list: List[Preset_model_id] = [],
|
||||||
|
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
|
||||||
|
):
|
||||||
|
print(f"Downloading models: {model_id_list}")
|
||||||
|
downloaded_files = []
|
||||||
|
load_files = []
|
||||||
|
|
||||||
|
for model_id in model_id_list:
|
||||||
|
for website in downloading_priority:
|
||||||
|
if model_id in website_to_preset_models[website]:
|
||||||
|
|
||||||
|
# Parse model metadata
|
||||||
|
model_metadata = website_to_preset_models[website][model_id]
|
||||||
|
if isinstance(model_metadata, list):
|
||||||
|
file_data = model_metadata
|
||||||
|
else:
|
||||||
|
file_data = model_metadata.get("file_list", [])
|
||||||
|
|
||||||
|
# Try downloading the model from this website.
|
||||||
|
model_files = []
|
||||||
|
for model_id, origin_file_path, local_dir in file_data:
|
||||||
|
# Check if the file is downloaded.
|
||||||
|
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
|
||||||
|
if file_to_download in downloaded_files:
|
||||||
|
continue
|
||||||
|
# Download
|
||||||
|
website_to_download_fn[website](model_id, origin_file_path, local_dir)
|
||||||
|
if os.path.basename(origin_file_path) in os.listdir(local_dir):
|
||||||
|
downloaded_files.append(file_to_download)
|
||||||
|
model_files.append(file_to_download)
|
||||||
|
|
||||||
|
# If the model is successfully downloaded, break.
|
||||||
|
if len(model_files) > 0:
|
||||||
|
if isinstance(model_metadata, dict) and "load_path" in model_metadata:
|
||||||
|
model_files = model_metadata["load_path"]
|
||||||
|
load_files.extend(model_files)
|
||||||
|
break
|
||||||
|
|
||||||
|
return load_files
|
||||||
diffsynth/models/flux_controlnet.py (new file, 331 lines)
@@ -0,0 +1,331 @@
import torch
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock, RMSNorm
|
||||||
|
from .utils import hash_state_dict_keys, init_weights_on_device
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class FluxControlNet(torch.nn.Module):
|
||||||
|
def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0):
|
||||||
|
super().__init__()
|
||||||
|
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
|
||||||
|
self.time_embedder = TimestepEmbeddings(256, 3072)
|
||||||
|
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
|
||||||
|
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
|
||||||
|
self.context_embedder = torch.nn.Linear(4096, 3072)
|
||||||
|
self.x_embedder = torch.nn.Linear(64, 3072)
|
||||||
|
|
||||||
|
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_joint_blocks)])
|
||||||
|
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(num_single_blocks)])
|
||||||
|
|
||||||
|
self.controlnet_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_joint_blocks)])
|
||||||
|
self.controlnet_single_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_single_blocks)])
|
||||||
|
|
||||||
|
self.mode_dict = mode_dict
|
||||||
|
self.controlnet_mode_embedder = torch.nn.Embedding(num_mode, 3072) if len(mode_dict) > 0 else None
|
||||||
|
self.controlnet_x_embedder = torch.nn.Linear(64 + additional_input_dim, 3072)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_image_ids(self, latents):
|
||||||
|
batch_size, _, height, width = latents.shape
|
||||||
|
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
||||||
|
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
||||||
|
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
||||||
|
|
||||||
|
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||||
|
|
||||||
|
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
|
||||||
|
latent_image_ids = latent_image_ids.reshape(
|
||||||
|
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||||
|
)
|
||||||
|
latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
|
||||||
|
|
||||||
|
return latent_image_ids
|
||||||
|
|
||||||
|
|
||||||
|
def patchify(self, hidden_states):
|
||||||
|
hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
|
||||||
|
return hidden_states
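# Standalone check of the 2x2 packing above (the 16-channel, 64x64 latent is an assumed
# example; note that 16 * 2 * 2 = 64 matches the input width of x_embedder).
def _demo_patchify_shapes():
    latents = torch.randn(1, 16, 64, 64)   # (B, C, H, W)
    tokens = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
    return tokens.shape                     # torch.Size([1, 1024, 64])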
|
||||||
|
|
||||||
|
|
||||||
|
def align_res_stack_to_original_blocks(self, res_stack, num_blocks, hidden_states):
|
||||||
|
if len(res_stack) == 0:
|
||||||
|
return [torch.zeros_like(hidden_states)] * num_blocks
|
||||||
|
interval = (num_blocks + len(res_stack) - 1) // len(res_stack)
|
||||||
|
aligned_res_stack = [res_stack[block_id // interval] for block_id in range(num_blocks)]
|
||||||
|
return aligned_res_stack
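# Pure index arithmetic mirroring the interval computed above: a 5-entry ControlNet residual
# stack stretched over 19 transformer blocks (both counts are defaults used elsewhere in this file).
def _demo_align_res_stack_indices(num_blocks=19, res_stack_len=5):
    interval = (num_blocks + res_stack_len - 1) // res_stack_len   # ceil(19 / 5) = 4
    return [block_id // interval for block_id in range(num_blocks)]
    # [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4]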
|
||||||
|
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
controlnet_conditioning,
|
||||||
|
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
|
||||||
|
processor_id=None,
|
||||||
|
tiled=False, tile_size=128, tile_stride=64,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
if image_ids is None:
|
||||||
|
image_ids = self.prepare_image_ids(hidden_states)
|
||||||
|
|
||||||
|
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
|
||||||
|
if self.guidance_embedder is not None:
|
||||||
|
guidance = guidance * 1000
|
||||||
|
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
|
||||||
|
prompt_emb = self.context_embedder(prompt_emb)
|
||||||
|
if self.controlnet_mode_embedder is not None: # Different from FluxDiT
|
||||||
|
processor_id = torch.tensor([self.mode_dict[processor_id]], dtype=torch.int)
|
||||||
|
processor_id = repeat(processor_id, "D -> B D", B=1).to(text_ids.device)
|
||||||
|
prompt_emb = torch.concat([self.controlnet_mode_embedder(processor_id), prompt_emb], dim=1)
|
||||||
|
text_ids = torch.cat([text_ids[:, :1], text_ids], dim=1)
|
||||||
|
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||||
|
|
||||||
|
hidden_states = self.patchify(hidden_states)
|
||||||
|
hidden_states = self.x_embedder(hidden_states)
|
||||||
|
controlnet_conditioning = self.patchify(controlnet_conditioning) # Different from FluxDiT
|
||||||
|
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_conditioning) # Different from FluxDiT
|
||||||
|
|
||||||
|
controlnet_res_stack = []
|
||||||
|
for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
|
||||||
|
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
|
||||||
|
controlnet_res_stack.append(controlnet_block(hidden_states))
|
||||||
|
|
||||||
|
controlnet_single_res_stack = []
|
||||||
|
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
||||||
|
for block, controlnet_block in zip(self.single_blocks, self.controlnet_single_blocks):
|
||||||
|
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
|
||||||
|
controlnet_single_res_stack.append(controlnet_block(hidden_states[:, prompt_emb.shape[1]:]))
|
||||||
|
|
||||||
|
controlnet_res_stack = self.align_res_stack_to_original_blocks(controlnet_res_stack, 19, hidden_states[:, prompt_emb.shape[1]:])
|
||||||
|
controlnet_single_res_stack = self.align_res_stack_to_original_blocks(controlnet_single_res_stack, 38, hidden_states[:, prompt_emb.shape[1]:])
|
||||||
|
|
||||||
|
return controlnet_res_stack, controlnet_single_res_stack
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return FluxControlNetStateDictConverter()
|
||||||
|
|
||||||
|
def quantize(self):
|
||||||
|
def cast_to(weight, dtype=None, device=None, copy=False):
|
||||||
|
if device is None or weight.device == device:
|
||||||
|
if not copy:
|
||||||
|
if dtype is None or weight.dtype == dtype:
|
||||||
|
return weight
|
||||||
|
return weight.to(dtype=dtype, copy=copy)
|
||||||
|
|
||||||
|
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||||
|
r.copy_(weight)
|
||||||
|
return r
|
||||||
|
|
||||||
|
def cast_weight(s, input=None, dtype=None, device=None):
|
||||||
|
if input is not None:
|
||||||
|
if dtype is None:
|
||||||
|
dtype = input.dtype
|
||||||
|
if device is None:
|
||||||
|
device = input.device
|
||||||
|
weight = cast_to(s.weight, dtype, device)
|
||||||
|
return weight
|
||||||
|
|
||||||
|
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||||
|
if input is not None:
|
||||||
|
if dtype is None:
|
||||||
|
dtype = input.dtype
|
||||||
|
if bias_dtype is None:
|
||||||
|
bias_dtype = dtype
|
||||||
|
if device is None:
|
||||||
|
device = input.device
|
||||||
|
bias = None
|
||||||
|
weight = cast_to(s.weight, dtype, device)
|
||||||
|
bias = cast_to(s.bias, bias_dtype, device)
|
||||||
|
return weight, bias
|
||||||
|
|
||||||
|
class quantized_layer:
|
||||||
|
class QLinear(torch.nn.Linear):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def forward(self,input,**kwargs):
|
||||||
|
weight,bias= cast_bias_weight(self,input)
|
||||||
|
return torch.nn.functional.linear(input,weight,bias)
|
||||||
|
|
||||||
|
class QRMSNorm(torch.nn.Module):
|
||||||
|
def __init__(self, module):
|
||||||
|
super().__init__()
|
||||||
|
self.module = module
|
||||||
|
|
||||||
|
def forward(self,hidden_states,**kwargs):
|
||||||
|
weight= cast_weight(self.module,hidden_states)
|
||||||
|
input_dtype = hidden_states.dtype
|
||||||
|
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
||||||
|
hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
|
||||||
|
hidden_states = hidden_states.to(input_dtype) * weight
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
class QEmbedding(torch.nn.Embedding):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def forward(self,input,**kwargs):
|
||||||
|
weight= cast_weight(self,input)
|
||||||
|
return torch.nn.functional.embedding(
|
||||||
|
input, weight, self.padding_idx, self.max_norm,
|
||||||
|
self.norm_type, self.scale_grad_by_freq, self.sparse)
|
||||||
|
|
||||||
|
def replace_layer(model):
|
||||||
|
for name, module in model.named_children():
|
||||||
|
if isinstance(module,quantized_layer.QRMSNorm):
|
||||||
|
continue
|
||||||
|
if isinstance(module, torch.nn.Linear):
|
||||||
|
with init_weights_on_device():
|
||||||
|
new_layer = quantized_layer.QLinear(module.in_features,module.out_features)
|
||||||
|
new_layer.weight = module.weight
|
||||||
|
if module.bias is not None:
|
||||||
|
new_layer.bias = module.bias
|
||||||
|
setattr(model, name, new_layer)
|
||||||
|
elif isinstance(module, RMSNorm):
|
||||||
|
if hasattr(module,"quantized"):
|
||||||
|
continue
|
||||||
|
module.quantized= True
|
||||||
|
new_layer = quantized_layer.QRMSNorm(module)
|
||||||
|
setattr(model, name, new_layer)
|
||||||
|
elif isinstance(module,torch.nn.Embedding):
|
||||||
|
rows, cols = module.weight.shape
|
||||||
|
new_layer = quantized_layer.QEmbedding(
|
||||||
|
num_embeddings=rows,
|
||||||
|
embedding_dim=cols,
|
||||||
|
_weight=module.weight,
|
||||||
|
# _freeze=module.freeze,
|
||||||
|
padding_idx=module.padding_idx,
|
||||||
|
max_norm=module.max_norm,
|
||||||
|
norm_type=module.norm_type,
|
||||||
|
scale_grad_by_freq=module.scale_grad_by_freq,
|
||||||
|
sparse=module.sparse)
|
||||||
|
setattr(model, name, new_layer)
|
||||||
|
else:
|
||||||
|
replace_layer(module)
|
||||||
|
|
||||||
|
replace_layer(self)
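# Hedged usage note (not in the original): after quantize(), Linear / RMSNorm / Embedding
# layers cast their weights to the incoming activation's dtype and device at call time,
# so the module's weights can rest in a more compact dtype. The float8 storage dtype in
# this sketch is an assumption, not something this file asserts.
#
#   controlnet = FluxControlNet()
#   controlnet.to(torch.float8_e4m3fn)
#   controlnet.quantize()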
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class FluxControlNetStateDictConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_diffusers(self, state_dict):
|
||||||
|
hash_value = hash_state_dict_keys(state_dict)
|
||||||
|
global_rename_dict = {
|
||||||
|
"context_embedder": "context_embedder",
|
||||||
|
"x_embedder": "x_embedder",
|
||||||
|
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
|
||||||
|
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
|
||||||
|
"time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
|
||||||
|
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
|
||||||
|
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
|
||||||
|
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
|
||||||
|
"norm_out.linear": "final_norm_out.linear",
|
||||||
|
"proj_out": "final_proj_out",
|
||||||
|
}
|
||||||
|
rename_dict = {
|
||||||
|
"proj_out": "proj_out",
|
||||||
|
"norm1.linear": "norm1_a.linear",
|
||||||
|
"norm1_context.linear": "norm1_b.linear",
|
||||||
|
"attn.to_q": "attn.a_to_q",
|
||||||
|
"attn.to_k": "attn.a_to_k",
|
||||||
|
"attn.to_v": "attn.a_to_v",
|
||||||
|
"attn.to_out.0": "attn.a_to_out",
|
||||||
|
"attn.add_q_proj": "attn.b_to_q",
|
||||||
|
"attn.add_k_proj": "attn.b_to_k",
|
||||||
|
"attn.add_v_proj": "attn.b_to_v",
|
||||||
|
"attn.to_add_out": "attn.b_to_out",
|
||||||
|
"ff.net.0.proj": "ff_a.0",
|
||||||
|
"ff.net.2": "ff_a.2",
|
||||||
|
"ff_context.net.0.proj": "ff_b.0",
|
||||||
|
"ff_context.net.2": "ff_b.2",
|
||||||
|
"attn.norm_q": "attn.norm_q_a",
|
||||||
|
"attn.norm_k": "attn.norm_k_a",
|
||||||
|
"attn.norm_added_q": "attn.norm_q_b",
|
||||||
|
"attn.norm_added_k": "attn.norm_k_b",
|
||||||
|
}
|
||||||
|
rename_dict_single = {
|
||||||
|
"attn.to_q": "a_to_q",
|
||||||
|
"attn.to_k": "a_to_k",
|
||||||
|
"attn.to_v": "a_to_v",
|
||||||
|
"attn.norm_q": "norm_q_a",
|
||||||
|
"attn.norm_k": "norm_k_a",
|
||||||
|
"norm.linear": "norm.linear",
|
||||||
|
"proj_mlp": "proj_in_besides_attn",
|
||||||
|
"proj_out": "proj_out",
|
||||||
|
}
|
||||||
|
state_dict_ = {}
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
if name.endswith(".weight") or name.endswith(".bias"):
|
||||||
|
suffix = ".weight" if name.endswith(".weight") else ".bias"
|
||||||
|
prefix = name[:-len(suffix)]
|
||||||
|
if prefix in global_rename_dict:
|
||||||
|
state_dict_[global_rename_dict[prefix] + suffix] = param
|
||||||
|
elif prefix.startswith("transformer_blocks."):
|
||||||
|
names = prefix.split(".")
|
||||||
|
names[0] = "blocks"
|
||||||
|
middle = ".".join(names[2:])
|
||||||
|
if middle in rename_dict:
|
||||||
|
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
|
||||||
|
state_dict_[name_] = param
|
||||||
|
elif prefix.startswith("single_transformer_blocks."):
|
||||||
|
names = prefix.split(".")
|
||||||
|
names[0] = "single_blocks"
|
||||||
|
middle = ".".join(names[2:])
|
||||||
|
if middle in rename_dict_single:
|
||||||
|
name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
|
||||||
|
state_dict_[name_] = param
|
||||||
|
else:
|
||||||
|
state_dict_[name] = param
|
||||||
|
else:
|
||||||
|
state_dict_[name] = param
|
||||||
|
for name in list(state_dict_.keys()):
|
||||||
|
if ".proj_in_besides_attn." in name:
|
||||||
|
name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
|
||||||
|
param = torch.concat([
|
||||||
|
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
|
||||||
|
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
|
||||||
|
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
|
||||||
|
state_dict_[name],
|
||||||
|
], dim=0)
|
||||||
|
state_dict_[name_] = param
|
||||||
|
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
|
||||||
|
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
|
||||||
|
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
|
||||||
|
state_dict_.pop(name)
|
||||||
|
for name in list(state_dict_.keys()):
|
||||||
|
for component in ["a", "b"]:
|
||||||
|
if f".{component}_to_q." in name:
|
||||||
|
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
|
||||||
|
param = torch.concat([
|
||||||
|
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
|
||||||
|
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
|
||||||
|
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
|
||||||
|
], dim=0)
|
||||||
|
state_dict_[name_] = param
|
||||||
|
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
|
||||||
|
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
|
||||||
|
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
|
||||||
|
if hash_value == "78d18b9101345ff695f312e7e62538c0":
|
||||||
|
extra_kwargs = {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}}
|
||||||
|
elif hash_value == "b001c89139b5f053c715fe772362dd2a":
|
||||||
|
extra_kwargs = {"num_single_blocks": 0}
|
||||||
|
elif hash_value == "52357cb26250681367488a8954c271e8":
|
||||||
|
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
|
||||||
|
elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
|
||||||
|
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
|
||||||
|
elif hash_value == "7f9583eb8ba86642abb9a21a4b2c9e16":
|
||||||
|
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 10}
|
||||||
|
elif hash_value == "43ad5aaa27dd4ee01b832ed16773fa52":
|
||||||
|
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0}
|
||||||
|
else:
|
||||||
|
extra_kwargs = {}
|
||||||
|
return state_dict_, extra_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
return self.from_diffusers(state_dict)
|
||||||
diffsynth/models/flux_dit.py (new file, 746 lines)
@@ -0,0 +1,746 @@
import torch
|
||||||
|
from .sd3_dit import TimestepEmbeddings, AdaLayerNorm, RMSNorm
|
||||||
|
from einops import rearrange
|
||||||
|
from .tiler import TileWorker
|
||||||
|
from .utils import init_weights_on_device
|
||||||
|
|
||||||
|
def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0):
|
||||||
|
batch_size, num_tokens = hidden_states.shape[0:2]
|
||||||
|
ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
|
||||||
|
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, num_tokens, -1)
|
||||||
|
hidden_states = hidden_states + scale * ip_hidden_states
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class RoPEEmbedding(torch.nn.Module):
|
||||||
|
def __init__(self, dim, theta, axes_dim):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.theta = theta
|
||||||
|
self.axes_dim = axes_dim
|
||||||
|
|
||||||
|
|
||||||
|
def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
||||||
|
assert dim % 2 == 0, "The dimension must be even."
|
||||||
|
|
||||||
|
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
||||||
|
omega = 1.0 / (theta**scale)
|
||||||
|
|
||||||
|
batch_size, seq_length = pos.shape
|
||||||
|
out = torch.einsum("...n,d->...nd", pos, omega)
|
||||||
|
cos_out = torch.cos(out)
|
||||||
|
sin_out = torch.sin(out)
|
||||||
|
|
||||||
|
stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
|
||||||
|
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
|
||||||
|
return out.float()
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, ids):
|
||||||
|
n_axes = ids.shape[-1]
|
||||||
|
emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
|
||||||
|
return emb.unsqueeze(1)
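# Shape sanity check for the rotary table above (the position ids are arbitrary; the axes
# split 16 + 56 + 56 = 128 matches the per-head dim of the 3072 / 24-head blocks below).
def _demo_rope_embedding_shapes():
    rope = RoPEEmbedding(dim=3072, theta=10000, axes_dim=[16, 56, 56])
    ids = torch.zeros(1, 512 + 77, 3)     # concatenated text + image position ids, (B, L, 3)
    emb = rope(ids)
    return emb.shape                       # torch.Size([1, 1, 589, 64, 2, 2]): one 2x2 rotation per pair of channels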
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class FluxJointAttention(torch.nn.Module):
|
||||||
|
def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.only_out_a = only_out_a
|
||||||
|
|
||||||
|
self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
|
||||||
|
self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)
|
||||||
|
|
||||||
|
self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
|
||||||
|
self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
|
||||||
|
self.norm_q_b = RMSNorm(head_dim, eps=1e-6)
|
||||||
|
self.norm_k_b = RMSNorm(head_dim, eps=1e-6)
|
||||||
|
|
||||||
|
self.a_to_out = torch.nn.Linear(dim_a, dim_a)
|
||||||
|
if not only_out_a:
|
||||||
|
self.b_to_out = torch.nn.Linear(dim_b, dim_b)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rope(self, xq, xk, freqs_cis):
|
||||||
|
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
||||||
|
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
||||||
|
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||||
|
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||||
|
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||||
|
|
||||||
|
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
|
||||||
|
batch_size = hidden_states_a.shape[0]
|
||||||
|
|
||||||
|
# Part A
|
||||||
|
qkv_a = self.a_to_qkv(hidden_states_a)
|
||||||
|
qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
q_a, k_a, v_a = qkv_a.chunk(3, dim=1)
|
||||||
|
q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)
|
||||||
|
|
||||||
|
# Part B
|
||||||
|
qkv_b = self.b_to_qkv(hidden_states_b)
|
||||||
|
qkv_b = qkv_b.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
q_b, k_b, v_b = qkv_b.chunk(3, dim=1)
|
||||||
|
q_b, k_b = self.norm_q_b(q_b), self.norm_k_b(k_b)
|
||||||
|
|
||||||
|
q = torch.concat([q_b, q_a], dim=2)
|
||||||
|
k = torch.concat([k_b, k_a], dim=2)
|
||||||
|
v = torch.concat([v_b, v_a], dim=2)
|
||||||
|
|
||||||
|
q, k = self.apply_rope(q, k, image_rotary_emb)
|
||||||
|
|
||||||
|
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||||
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||||
|
hidden_states = hidden_states.to(q.dtype)
|
||||||
|
hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
|
||||||
|
if ipadapter_kwargs_list is not None:
|
||||||
|
hidden_states_a = interact_with_ipadapter(hidden_states_a, q_a, **ipadapter_kwargs_list)
|
||||||
|
hidden_states_a = self.a_to_out(hidden_states_a)
|
||||||
|
if self.only_out_a:
|
||||||
|
return hidden_states_a
|
||||||
|
else:
|
||||||
|
hidden_states_b = self.b_to_out(hidden_states_b)
|
||||||
|
return hidden_states_a, hidden_states_b
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class FluxJointTransformerBlock(torch.nn.Module):
|
||||||
|
def __init__(self, dim, num_attention_heads):
|
||||||
|
super().__init__()
|
||||||
|
self.norm1_a = AdaLayerNorm(dim)
|
||||||
|
self.norm1_b = AdaLayerNorm(dim)
|
||||||
|
|
||||||
|
self.attn = FluxJointAttention(dim, dim, num_attention_heads, dim // num_attention_heads)
|
||||||
|
|
||||||
|
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.ff_a = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(dim, dim*4),
|
||||||
|
torch.nn.GELU(approximate="tanh"),
|
||||||
|
torch.nn.Linear(dim*4, dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.ff_b = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(dim, dim*4),
|
||||||
|
torch.nn.GELU(approximate="tanh"),
|
||||||
|
torch.nn.Linear(dim*4, dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
|
||||||
|
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
|
||||||
|
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
|
||||||
|
|
||||||
|
# Attention
|
||||||
|
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
|
||||||
|
|
||||||
|
# Part A
|
||||||
|
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
|
||||||
|
norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
|
||||||
|
hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
|
||||||
|
|
||||||
|
# Part B
|
||||||
|
hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
|
||||||
|
norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
|
||||||
|
hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
|
||||||
|
|
||||||
|
return hidden_states_a, hidden_states_b
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class FluxSingleAttention(torch.nn.Module):
    def __init__(self, dim_a, dim_b, num_heads, head_dim):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim

        self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)

        self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
        self.norm_k_a = RMSNorm(head_dim, eps=1e-6)

    def apply_rope(self, xq, xk, freqs_cis):
        xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
        xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
        xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
        xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
        return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

    def forward(self, hidden_states, image_rotary_emb):
        batch_size = hidden_states.shape[0]

        qkv_a = self.a_to_qkv(hidden_states)
        qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
        q_a, k_a, v = qkv_a.chunk(3, dim=1)
        q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)

        q, k = self.apply_rope(q_a, k_a, image_rotary_emb)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)
        return hidden_states

class AdaLayerNormSingle(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.silu = torch.nn.SiLU()
        self.linear = torch.nn.Linear(dim, 3 * dim, bias=True)
        self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        emb = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa

class FluxSingleTransformerBlock(torch.nn.Module):
    def __init__(self, dim, num_attention_heads):
        super().__init__()
        self.num_heads = num_attention_heads
        self.head_dim = dim // num_attention_heads
        self.dim = dim

        self.norm = AdaLayerNormSingle(dim)
        self.to_qkv_mlp = torch.nn.Linear(dim, dim * (3 + 4))
        self.norm_q_a = RMSNorm(self.head_dim, eps=1e-6)
        self.norm_k_a = RMSNorm(self.head_dim, eps=1e-6)

        self.proj_out = torch.nn.Linear(dim * 5, dim)

    def apply_rope(self, xq, xk, freqs_cis):
        xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
        xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
        xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
        xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
        return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

    def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
        batch_size = hidden_states.shape[0]

        qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
        q, k, v = qkv.chunk(3, dim=1)
        q, k = self.norm_q_a(q), self.norm_k_a(k)

        q, k = self.apply_rope(q, k, image_rotary_emb)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)
        if ipadapter_kwargs_list is not None:
            hidden_states = interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs_list)
        return hidden_states

    def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
        residual = hidden_states_a
        norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
        hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
        attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]

        attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
        mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")

        hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)
        hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a)
        hidden_states_a = residual + hidden_states_a

        return hidden_states_a, hidden_states_b

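# Illustrative sketch of the single-stream block (assumes the RMSNorm helper defined earlier
# in this file; the real model runs 38 of these at dim=3072). The fused qkv+mlp projection is
# split inside forward, so only the first stream is updated; the second argument is returned
# untouched. The identity table again stands in for a real RoPE output.
import torch

single_block = FluxSingleTransformerBlock(dim=64, num_attention_heads=4)
tokens = torch.randn(1, 12, 64)        # concatenated text+image tokens
temb = torch.randn(1, 64)
identity_rope = torch.eye(2).reshape(1, 1, 1, 1, 2, 2).repeat(1, 1, 1, 8, 1, 1)
out, _ = single_block(tokens, None, temb, identity_rope)
# out.shape == (1, 12, 64); the None placeholder comes back unchanged.
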
class AdaLayerNormContinuous(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.silu = torch.nn.SiLU()
        self.linear = torch.nn.Linear(dim, dim * 2, bias=True)
        self.norm = torch.nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)

    def forward(self, x, conditioning):
        emb = self.linear(self.silu(conditioning))
        scale, shift = torch.chunk(emb, 2, dim=1)
        x = self.norm(x) * (1 + scale)[:, None] + shift[:, None]
        return x

class FluxDiT(torch.nn.Module):
    def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19):
        super().__init__()
        self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
        self.time_embedder = TimestepEmbeddings(256, 3072)
        self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
        self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
        self.context_embedder = torch.nn.Linear(4096, 3072)
        self.x_embedder = torch.nn.Linear(input_dim, 3072)

        self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_blocks)])
        self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])

        self.final_norm_out = AdaLayerNormContinuous(3072)
        self.final_proj_out = torch.nn.Linear(3072, 64)

        self.input_dim = input_dim

    def patchify(self, hidden_states):
        hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
        return hidden_states

    def unpatchify(self, hidden_states, height, width):
        hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
        return hidden_states

    def prepare_image_ids(self, latents):
        batch_size, _, height, width = latents.shape
        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]

        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
        latent_image_ids = latent_image_ids.reshape(
            batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
        )
        latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)

        return latent_image_ids

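# Shape bookkeeping for the packing used above: a (B, C, H, W) latent is cut into 2x2 patches
# and flattened to (B, H/2 * W/2, C*4), which is why the default input_dim=64 corresponds to
# 16 latent channels. A standalone round-trip sketch:
import torch
from einops import rearrange

latents = torch.randn(1, 16, 64, 64)
packed = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
# packed.shape == (1, 1024, 64)
restored = rearrange(packed, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=32, W=32)
assert torch.equal(latents, restored)
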
    def tiled_forward(
        self,
        hidden_states,
        timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
        tile_size=128, tile_stride=64,
        **kwargs
    ):
        # Due to the global positional embedding, we cannot implement layer-wise tiled forward.
        hidden_states = TileWorker().tiled_forward(
            lambda x: self.forward(x, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None),
            hidden_states,
            tile_size,
            tile_stride,
            tile_device=hidden_states.device,
            tile_dtype=hidden_states.dtype
        )
        return hidden_states

    def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len):
        N = len(entity_masks)
        batch_size = entity_masks[0].shape[0]
        total_seq_len = N * prompt_seq_len + image_seq_len
        patched_masks = [self.patchify(entity_masks[i]) for i in range(N)]
        attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)

        image_start = N * prompt_seq_len
        image_end = N * prompt_seq_len + image_seq_len
        # prompt-image mask
        for i in range(N):
            prompt_start = i * prompt_seq_len
            prompt_end = (i + 1) * prompt_seq_len
            image_mask = torch.sum(patched_masks[i], dim=-1) > 0
            image_mask = image_mask.unsqueeze(1).repeat(1, prompt_seq_len, 1)
            # prompt update with image
            attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
            # image update with prompt
            attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
        # prompt-prompt mask
        for i in range(N):
            for j in range(N):
                if i != j:
                    prompt_start_i = i * prompt_seq_len
                    prompt_end_i = (i + 1) * prompt_seq_len
                    prompt_start_j = j * prompt_seq_len
                    prompt_end_j = (j + 1) * prompt_seq_len
                    attention_mask[:, prompt_start_i:prompt_end_i, prompt_start_j:prompt_end_j] = False

        attention_mask = attention_mask.float()
        attention_mask[attention_mask == 0] = float('-inf')
        attention_mask[attention_mask == 1] = 0
        return attention_mask

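# The mask built above is additive: 0.0 keeps a query-key pair, float('-inf') removes it from
# the softmax. A tiny standalone check of that convention with PyTorch's
# scaled_dot_product_attention:
import torch

q = k = v = torch.randn(1, 1, 3, 4)
attn_mask = torch.zeros(1, 1, 3, 3)
attn_mask[..., 0, 2] = float("-inf")   # query 0 may not attend to key 2
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
# out.shape == (1, 1, 3, 4)
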
    def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids):
        repeat_dim = hidden_states.shape[1]
        max_masks = 0
        attention_mask = None
        prompt_embs = [prompt_emb]
        if entity_masks is not None:
            # entity_masks
            batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1]
            entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
            entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
            # global mask
            global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype)
            entity_masks = entity_masks + [global_mask]  # append global to last
            # attention mask
            attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1])
            attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype)
            attention_mask = attention_mask.unsqueeze(1)
            # embds: n_masks * b * seq * d
            local_embs = [entity_prompt_emb[:, i, None].squeeze(1) for i in range(max_masks)]
            prompt_embs = local_embs + prompt_embs  # append global to last
        prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs]
        prompt_emb = torch.cat(prompt_embs, dim=1)

        # positional embedding
        text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1)
        image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
        return prompt_emb, image_rotary_emb, attention_mask

    def forward(
        self,
        hidden_states,
        timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
        tiled=False, tile_size=128, tile_stride=64, entity_prompt_emb=None, entity_masks=None,
        use_gradient_checkpointing=False,
        **kwargs
    ):
        if tiled:
            return self.tiled_forward(
                hidden_states,
                timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
                tile_size=tile_size, tile_stride=tile_stride,
                **kwargs
            )

        if image_ids is None:
            image_ids = self.prepare_image_ids(hidden_states)

        conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
        if self.guidance_embedder is not None:
            guidance = guidance * 1000
            conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)

        height, width = hidden_states.shape[-2:]
        hidden_states = self.patchify(hidden_states)
        hidden_states = self.x_embedder(hidden_states)

        if entity_prompt_emb is not None and entity_masks is not None:
            prompt_emb, image_rotary_emb, attention_mask = self.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
        else:
            prompt_emb = self.context_embedder(prompt_emb)
            image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
            attention_mask = None

        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward

        for block in self.blocks:
            if self.training and use_gradient_checkpointing:
                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
                    use_reentrant=False,
                )
            else:
                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)

        hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
        for block in self.single_blocks:
            if self.training and use_gradient_checkpointing:
                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
                    use_reentrant=False,
                )
            else:
                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
        hidden_states = hidden_states[:, prompt_emb.shape[1]:]

        hidden_states = self.final_norm_out(hidden_states, conditioning)
        hidden_states = self.final_proj_out(hidden_states)
        hidden_states = self.unpatchify(hidden_states, height, width)

        return hidden_states

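# Rough footprint / interface sketch for the forward pass above. It assumes PyTorch >= 2.0
# (for the meta-device context) and that the RoPEEmbedding / TimestepEmbeddings helpers from
# earlier in this file construct cleanly under the default-device override.
import torch

with torch.device("meta"):
    dit = FluxDiT()                      # 19 joint blocks + 38 single blocks at width 3072
n_params = sum(p.numel() for p in dit.parameters())
print(f"{n_params / 1e9:.1f}B parameters")

# Expected input shapes, e.g. for a 1024x1024 image with the usual 8x VAE downsampling:
#   hidden_states:      (B, 16, 128, 128)   # 16 latent channels -> 64 features per 2x2 patch
#   timestep, guidance: (B,)
#   prompt_emb:         (B, text_len, 4096) # e.g. 512 T5 tokens
#   pooled_prompt_emb:  (B, 768)
#   text_ids:           (B, text_len, 3)
# The output has the same shape as hidden_states.
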
    def quantize(self):
        def cast_to(weight, dtype=None, device=None, copy=False):
            if device is None or weight.device == device:
                if not copy:
                    if dtype is None or weight.dtype == dtype:
                        return weight
                return weight.to(dtype=dtype, copy=copy)

            r = torch.empty_like(weight, dtype=dtype, device=device)
            r.copy_(weight)
            return r

        def cast_weight(s, input=None, dtype=None, device=None):
            if input is not None:
                if dtype is None:
                    dtype = input.dtype
                if device is None:
                    device = input.device
            weight = cast_to(s.weight, dtype, device)
            return weight

        def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
            if input is not None:
                if dtype is None:
                    dtype = input.dtype
                if bias_dtype is None:
                    bias_dtype = dtype
                if device is None:
                    device = input.device
            bias = None
            weight = cast_to(s.weight, dtype, device)
            bias = cast_to(s.bias, bias_dtype, device)
            return weight, bias

        class quantized_layer:
            class Linear(torch.nn.Linear):
                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)

                def forward(self, input, **kwargs):
                    weight, bias = cast_bias_weight(self, input)
                    return torch.nn.functional.linear(input, weight, bias)

            class RMSNorm(torch.nn.Module):
                def __init__(self, module):
                    super().__init__()
                    self.module = module

                def forward(self, hidden_states, **kwargs):
                    weight = cast_weight(self.module, hidden_states)
                    input_dtype = hidden_states.dtype
                    variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
                    hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
                    hidden_states = hidden_states.to(input_dtype) * weight
                    return hidden_states

        def replace_layer(model):
            for name, module in model.named_children():
                if isinstance(module, torch.nn.Linear):
                    with init_weights_on_device():
                        new_layer = quantized_layer.Linear(module.in_features, module.out_features)
                    new_layer.weight = module.weight
                    if module.bias is not None:
                        new_layer.bias = module.bias
                    # del module
                    setattr(model, name, new_layer)
                elif isinstance(module, RMSNorm):
                    if hasattr(module, "quantized"):
                        continue
                    module.quantized = True
                    new_layer = quantized_layer.RMSNorm(module)
                    setattr(model, name, new_layer)
                else:
                    replace_layer(module)

        replace_layer(self)

    @staticmethod
    def state_dict_converter():
        return FluxDiTStateDictConverter()

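# What quantize() buys: weights can live in a compact storage dtype and are cast to the
# activation dtype only inside forward. A standalone sketch of that cast path (float8 storage
# needs a recent PyTorch build; any low-precision storage dtype works the same way):
import torch

layer = torch.nn.Linear(8, 8)
layer.weight.data = layer.weight.data.to(torch.float8_e4m3fn)     # storage dtype
x = torch.randn(2, 8, dtype=torch.bfloat16)
weight = layer.weight.to(dtype=x.dtype, device=x.device)          # cast_to(...) equivalent
bias = layer.bias.to(dtype=x.dtype, device=x.device)
y = torch.nn.functional.linear(x, weight, bias)                   # y.dtype == torch.bfloat16
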
class FluxDiTStateDictConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_diffusers(self, state_dict):
|
||||||
|
global_rename_dict = {
|
||||||
|
"context_embedder": "context_embedder",
|
||||||
|
"x_embedder": "x_embedder",
|
||||||
|
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
|
||||||
|
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
|
||||||
|
"time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
|
||||||
|
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
|
||||||
|
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
|
||||||
|
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
|
||||||
|
"norm_out.linear": "final_norm_out.linear",
|
||||||
|
"proj_out": "final_proj_out",
|
||||||
|
}
|
||||||
|
rename_dict = {
|
||||||
|
"proj_out": "proj_out",
|
||||||
|
"norm1.linear": "norm1_a.linear",
|
||||||
|
"norm1_context.linear": "norm1_b.linear",
|
||||||
|
"attn.to_q": "attn.a_to_q",
|
||||||
|
"attn.to_k": "attn.a_to_k",
|
||||||
|
"attn.to_v": "attn.a_to_v",
|
||||||
|
"attn.to_out.0": "attn.a_to_out",
|
||||||
|
"attn.add_q_proj": "attn.b_to_q",
|
||||||
|
"attn.add_k_proj": "attn.b_to_k",
|
||||||
|
"attn.add_v_proj": "attn.b_to_v",
|
||||||
|
"attn.to_add_out": "attn.b_to_out",
|
||||||
|
"ff.net.0.proj": "ff_a.0",
|
||||||
|
"ff.net.2": "ff_a.2",
|
||||||
|
"ff_context.net.0.proj": "ff_b.0",
|
||||||
|
"ff_context.net.2": "ff_b.2",
|
||||||
|
"attn.norm_q": "attn.norm_q_a",
|
||||||
|
"attn.norm_k": "attn.norm_k_a",
|
||||||
|
"attn.norm_added_q": "attn.norm_q_b",
|
||||||
|
"attn.norm_added_k": "attn.norm_k_b",
|
||||||
|
}
|
||||||
|
rename_dict_single = {
|
||||||
|
"attn.to_q": "a_to_q",
|
||||||
|
"attn.to_k": "a_to_k",
|
||||||
|
"attn.to_v": "a_to_v",
|
||||||
|
"attn.norm_q": "norm_q_a",
|
||||||
|
"attn.norm_k": "norm_k_a",
|
||||||
|
"norm.linear": "norm.linear",
|
||||||
|
"proj_mlp": "proj_in_besides_attn",
|
||||||
|
"proj_out": "proj_out",
|
||||||
|
}
|
||||||
|
state_dict_ = {}
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
if name.endswith(".weight") or name.endswith(".bias"):
|
||||||
|
suffix = ".weight" if name.endswith(".weight") else ".bias"
|
||||||
|
prefix = name[:-len(suffix)]
|
||||||
|
if prefix in global_rename_dict:
|
||||||
|
state_dict_[global_rename_dict[prefix] + suffix] = param
|
||||||
|
elif prefix.startswith("transformer_blocks."):
|
||||||
|
names = prefix.split(".")
|
||||||
|
names[0] = "blocks"
|
||||||
|
middle = ".".join(names[2:])
|
||||||
|
if middle in rename_dict:
|
||||||
|
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
|
||||||
|
state_dict_[name_] = param
|
||||||
|
elif prefix.startswith("single_transformer_blocks."):
|
||||||
|
names = prefix.split(".")
|
||||||
|
names[0] = "single_blocks"
|
||||||
|
middle = ".".join(names[2:])
|
||||||
|
if middle in rename_dict_single:
|
||||||
|
name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
|
||||||
|
state_dict_[name_] = param
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
        for name in list(state_dict_.keys()):
            if "single_blocks." in name and ".a_to_q." in name:
                mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
                if mlp is None:
                    mlp = torch.zeros(4 * state_dict_[name].shape[0],
                                      *state_dict_[name].shape[1:],
                                      dtype=state_dict_[name].dtype)
                else:
                    state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
                param = torch.concat([
                    state_dict_.pop(name),
                    state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
                    state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
                    mlp,
                ], dim=0)
                name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
                state_dict_[name_] = param
        for name in list(state_dict_.keys()):
            for component in ["a", "b"]:
                if f".{component}_to_q." in name:
                    name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
                    param = torch.concat([
                        state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
                        state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
                        state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
                    ], dim=0)
                    state_dict_[name_] = param
                    state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
                    state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
                    state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
        return state_dict_

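# The loops above fuse separate to_q / to_k / to_v projections into one weight by concatenating
# along the output dimension. A standalone check of that equivalence:
import torch

to_q, to_k, to_v = (torch.nn.Linear(8, 8) for _ in range(3))
qkv_weight = torch.cat([to_q.weight, to_k.weight, to_v.weight], dim=0)   # (24, 8)
qkv_bias = torch.cat([to_q.bias, to_k.bias, to_v.bias], dim=0)
x = torch.randn(1, 8)
fused = torch.nn.functional.linear(x, qkv_weight, qkv_bias)
assert torch.allclose(fused, torch.cat([to_q(x), to_k(x), to_v(x)], dim=-1), atol=1e-6)
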
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
rename_dict = {
|
||||||
|
"time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias",
|
||||||
|
"time_in.in_layer.weight": "time_embedder.timestep_embedder.0.weight",
|
||||||
|
"time_in.out_layer.bias": "time_embedder.timestep_embedder.2.bias",
|
||||||
|
"time_in.out_layer.weight": "time_embedder.timestep_embedder.2.weight",
|
||||||
|
"txt_in.bias": "context_embedder.bias",
|
||||||
|
"txt_in.weight": "context_embedder.weight",
|
||||||
|
"vector_in.in_layer.bias": "pooled_text_embedder.0.bias",
|
||||||
|
"vector_in.in_layer.weight": "pooled_text_embedder.0.weight",
|
||||||
|
"vector_in.out_layer.bias": "pooled_text_embedder.2.bias",
|
||||||
|
"vector_in.out_layer.weight": "pooled_text_embedder.2.weight",
|
||||||
|
"final_layer.linear.bias": "final_proj_out.bias",
|
||||||
|
"final_layer.linear.weight": "final_proj_out.weight",
|
||||||
|
"guidance_in.in_layer.bias": "guidance_embedder.timestep_embedder.0.bias",
|
||||||
|
"guidance_in.in_layer.weight": "guidance_embedder.timestep_embedder.0.weight",
|
||||||
|
"guidance_in.out_layer.bias": "guidance_embedder.timestep_embedder.2.bias",
|
||||||
|
"guidance_in.out_layer.weight": "guidance_embedder.timestep_embedder.2.weight",
|
||||||
|
"img_in.bias": "x_embedder.bias",
|
||||||
|
"img_in.weight": "x_embedder.weight",
|
||||||
|
"final_layer.adaLN_modulation.1.weight": "final_norm_out.linear.weight",
|
||||||
|
"final_layer.adaLN_modulation.1.bias": "final_norm_out.linear.bias",
|
||||||
|
}
|
||||||
|
suffix_rename_dict = {
|
||||||
|
"img_attn.norm.key_norm.scale": "attn.norm_k_a.weight",
|
||||||
|
"img_attn.norm.query_norm.scale": "attn.norm_q_a.weight",
|
||||||
|
"img_attn.proj.bias": "attn.a_to_out.bias",
|
||||||
|
"img_attn.proj.weight": "attn.a_to_out.weight",
|
||||||
|
"img_attn.qkv.bias": "attn.a_to_qkv.bias",
|
||||||
|
"img_attn.qkv.weight": "attn.a_to_qkv.weight",
|
||||||
|
"img_mlp.0.bias": "ff_a.0.bias",
|
||||||
|
"img_mlp.0.weight": "ff_a.0.weight",
|
||||||
|
"img_mlp.2.bias": "ff_a.2.bias",
|
||||||
|
"img_mlp.2.weight": "ff_a.2.weight",
|
||||||
|
"img_mod.lin.bias": "norm1_a.linear.bias",
|
||||||
|
"img_mod.lin.weight": "norm1_a.linear.weight",
|
||||||
|
"txt_attn.norm.key_norm.scale": "attn.norm_k_b.weight",
|
||||||
|
"txt_attn.norm.query_norm.scale": "attn.norm_q_b.weight",
|
||||||
|
"txt_attn.proj.bias": "attn.b_to_out.bias",
|
||||||
|
"txt_attn.proj.weight": "attn.b_to_out.weight",
|
||||||
|
"txt_attn.qkv.bias": "attn.b_to_qkv.bias",
|
||||||
|
"txt_attn.qkv.weight": "attn.b_to_qkv.weight",
|
||||||
|
"txt_mlp.0.bias": "ff_b.0.bias",
|
||||||
|
"txt_mlp.0.weight": "ff_b.0.weight",
|
||||||
|
"txt_mlp.2.bias": "ff_b.2.bias",
|
||||||
|
"txt_mlp.2.weight": "ff_b.2.weight",
|
||||||
|
"txt_mod.lin.bias": "norm1_b.linear.bias",
|
||||||
|
"txt_mod.lin.weight": "norm1_b.linear.weight",
|
||||||
|
|
||||||
|
"linear1.bias": "to_qkv_mlp.bias",
|
||||||
|
"linear1.weight": "to_qkv_mlp.weight",
|
||||||
|
"linear2.bias": "proj_out.bias",
|
||||||
|
"linear2.weight": "proj_out.weight",
|
||||||
|
"modulation.lin.bias": "norm.linear.bias",
|
||||||
|
"modulation.lin.weight": "norm.linear.weight",
|
||||||
|
"norm.key_norm.scale": "norm_k_a.weight",
|
||||||
|
"norm.query_norm.scale": "norm_q_a.weight",
|
||||||
|
}
|
||||||
|
state_dict_ = {}
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
if name.startswith("model.diffusion_model."):
|
||||||
|
name = name[len("model.diffusion_model."):]
|
||||||
|
names = name.split(".")
|
||||||
|
if name in rename_dict:
|
||||||
|
rename = rename_dict[name]
|
||||||
|
if name.startswith("final_layer.adaLN_modulation.1."):
|
||||||
|
param = torch.concat([param[3072:], param[:3072]], dim=0)
|
||||||
|
state_dict_[rename] = param
|
||||||
|
elif names[0] == "double_blocks":
|
||||||
|
rename = f"blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
|
||||||
|
state_dict_[rename] = param
|
||||||
|
elif names[0] == "single_blocks":
|
||||||
|
if ".".join(names[2:]) in suffix_rename_dict:
|
||||||
|
rename = f"single_blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
|
||||||
|
state_dict_[rename] = param
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_:
|
||||||
|
return state_dict_, {"disable_guidance_embedder": True}
|
||||||
|
elif "blocks.8.attn.norm_k_a.weight" not in state_dict_:
|
||||||
|
return state_dict_, {"input_dim": 196, "num_blocks": 8}
|
||||||
|
else:
|
||||||
|
return state_dict_
|
||||||
128  diffsynth/models/flux_infiniteyou.py  Normal file
@@ -0,0 +1,128 @@
import math
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
# FFN
|
||||||
|
def FeedForward(dim, mult=4):
|
||||||
|
inner_dim = int(dim * mult)
|
||||||
|
return nn.Sequential(
|
||||||
|
nn.LayerNorm(dim),
|
||||||
|
nn.Linear(dim, inner_dim, bias=False),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Linear(inner_dim, dim, bias=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def reshape_tensor(x, heads):
|
||||||
|
bs, length, width = x.shape
|
||||||
|
#(bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
||||||
|
x = x.view(bs, length, heads, -1)
|
||||||
|
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
||||||
|
x = x.transpose(1, 2)
|
||||||
|
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
||||||
|
x = x.reshape(bs, heads, length, -1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class PerceiverAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, *, dim, dim_head=64, heads=8):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = dim_head**-0.5
|
||||||
|
self.dim_head = dim_head
|
||||||
|
self.heads = heads
|
||||||
|
inner_dim = dim_head * heads
|
||||||
|
|
||||||
|
self.norm1 = nn.LayerNorm(dim)
|
||||||
|
self.norm2 = nn.LayerNorm(dim)
|
||||||
|
|
||||||
|
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
||||||
|
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
||||||
|
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x, latents):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): image features
|
||||||
|
shape (b, n1, D)
|
||||||
|
latent (torch.Tensor): latent features
|
||||||
|
shape (b, n2, D)
|
||||||
|
"""
|
||||||
|
x = self.norm1(x)
|
||||||
|
latents = self.norm2(latents)
|
||||||
|
|
||||||
|
b, l, _ = latents.shape
|
||||||
|
|
||||||
|
q = self.to_q(latents)
|
||||||
|
kv_input = torch.cat((x, latents), dim=-2)
|
||||||
|
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
||||||
|
|
||||||
|
q = reshape_tensor(q, self.heads)
|
||||||
|
k = reshape_tensor(k, self.heads)
|
||||||
|
v = reshape_tensor(v, self.heads)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
||||||
|
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
||||||
|
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||||
|
out = weight @ v
|
||||||
|
|
||||||
|
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
||||||
|
|
||||||
|
return self.to_out(out)
|
||||||
|
|
||||||
|
|
||||||
|
class InfiniteYouImageProjector(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim=1280,
|
||||||
|
depth=4,
|
||||||
|
dim_head=64,
|
||||||
|
heads=20,
|
||||||
|
num_queries=8,
|
||||||
|
embedding_dim=512,
|
||||||
|
output_dim=4096,
|
||||||
|
ff_mult=4,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
|
||||||
|
self.proj_in = nn.Linear(embedding_dim, dim)
|
||||||
|
|
||||||
|
self.proj_out = nn.Linear(dim, output_dim)
|
||||||
|
self.norm_out = nn.LayerNorm(output_dim)
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList([])
|
||||||
|
for _ in range(depth):
|
||||||
|
self.layers.append(
|
||||||
|
nn.ModuleList([
|
||||||
|
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
||||||
|
FeedForward(dim=dim, mult=ff_mult),
|
||||||
|
]))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
|
||||||
|
latents = self.latents.repeat(x.size(0), 1, 1)
|
||||||
|
|
||||||
|
x = self.proj_in(x)
|
||||||
|
|
||||||
|
for attn, ff in self.layers:
|
||||||
|
latents = attn(x, latents) + latents
|
||||||
|
latents = ff(latents) + latents
|
||||||
|
|
||||||
|
latents = self.proj_out(latents)
|
||||||
|
return self.norm_out(latents)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return FluxInfiniteYouImageProjectorStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
class FluxInfiniteYouImageProjectorStateDictConverter:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_diffusers(self, state_dict):
|
||||||
|
return state_dict['image_proj']
|
||||||
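# Usage sketch for the resampler above, assuming a 512-d identity embedding (the default
# embedding_dim; the upstream face encoder is not shown in this file). The output width 4096
# matches the T5 text features consumed by FluxDiT's context_embedder.
import torch

projector = InfiniteYouImageProjector(num_queries=8, embedding_dim=512, output_dim=4096)
id_embeds = torch.randn(2, 1, 512)    # (batch, n1, embedding_dim)
tokens = projector(id_embeds)         # tokens.shape == (2, 8, 4096)
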
94  diffsynth/models/flux_ipadapter.py  Normal file
@@ -0,0 +1,94 @@
from .svd_image_encoder import SVDImageEncoder
|
||||||
|
from .sd3_dit import RMSNorm
|
||||||
|
from transformers import CLIPImageProcessor
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class MLPProjModel(torch.nn.Module):
|
||||||
|
def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.cross_attention_dim = cross_attention_dim
|
||||||
|
self.num_tokens = num_tokens
|
||||||
|
|
||||||
|
self.proj = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
|
||||||
|
torch.nn.GELU(),
|
||||||
|
torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
|
||||||
|
)
|
||||||
|
self.norm = torch.nn.LayerNorm(cross_attention_dim)
|
||||||
|
|
||||||
|
def forward(self, id_embeds):
|
||||||
|
x = self.proj(id_embeds)
|
||||||
|
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
|
||||||
|
x = self.norm(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class IpAdapterModule(torch.nn.Module):
|
||||||
|
def __init__(self, num_attention_heads, attention_head_dim, input_dim):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_attention_heads
|
||||||
|
self.head_dim = attention_head_dim
|
||||||
|
output_dim = num_attention_heads * attention_head_dim
|
||||||
|
self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
|
||||||
|
self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
|
||||||
|
self.norm_added_k = RMSNorm(attention_head_dim, eps=1e-5, elementwise_affine=False)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
batch_size = hidden_states.shape[0]
|
||||||
|
# ip_k
|
||||||
|
ip_k = self.to_k_ip(hidden_states)
|
||||||
|
ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
ip_k = self.norm_added_k(ip_k)
|
||||||
|
# ip_v
|
||||||
|
ip_v = self.to_v_ip(hidden_states)
|
||||||
|
ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
return ip_k, ip_v
|
||||||
|
|
||||||
|
|
||||||
|
class FluxIpAdapter(torch.nn.Module):
|
||||||
|
def __init__(self, num_attention_heads=24, attention_head_dim=128, cross_attention_dim=4096, num_tokens=128, num_blocks=57):
|
||||||
|
super().__init__()
|
||||||
|
self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(num_attention_heads, attention_head_dim, cross_attention_dim) for _ in range(num_blocks)])
|
||||||
|
self.image_proj = MLPProjModel(cross_attention_dim=cross_attention_dim, id_embeddings_dim=1152, num_tokens=num_tokens)
|
||||||
|
self.set_adapter()
|
||||||
|
|
||||||
|
def set_adapter(self):
|
||||||
|
self.call_block_id = {i:i for i in range(len(self.ipadapter_modules))}
|
||||||
|
|
||||||
|
def forward(self, hidden_states, scale=1.0):
|
||||||
|
hidden_states = self.image_proj(hidden_states)
|
||||||
|
hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
|
||||||
|
ip_kv_dict = {}
|
||||||
|
for block_id in self.call_block_id:
|
||||||
|
ipadapter_id = self.call_block_id[block_id]
|
||||||
|
ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
|
||||||
|
ip_kv_dict[block_id] = {
|
||||||
|
"ip_k": ip_k,
|
||||||
|
"ip_v": ip_v,
|
||||||
|
"scale": scale
|
||||||
|
}
|
||||||
|
return ip_kv_dict
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return FluxIpAdapterStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
class FluxIpAdapterStateDictConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_diffusers(self, state_dict):
|
||||||
|
state_dict_ = {}
|
||||||
|
for name in state_dict["ip_adapter"]:
|
||||||
|
name_ = 'ipadapter_modules.' + name
|
||||||
|
state_dict_[name_] = state_dict["ip_adapter"][name]
|
||||||
|
for name in state_dict["image_proj"]:
|
||||||
|
name_ = "image_proj." + name
|
||||||
|
state_dict_[name_] = state_dict["image_proj"][name]
|
||||||
|
return state_dict_
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
return self.from_diffusers(state_dict)
|
||||||
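# Usage sketch: image features (1152-d, matching id_embeddings_dim above) become per-block
# key/value tensors keyed by block index. Small num_tokens / num_blocks keep this example
# light; the defaults are 128 tokens and 57 blocks.
import torch

ip_adapter = FluxIpAdapter(num_tokens=4, num_blocks=2)
image_emb = torch.randn(1, 1152)
ip_kv = ip_adapter(image_emb, scale=0.7)
# ip_kv[0]["ip_k"].shape == (1, 24, 4, 128)   # (batch, heads, ip tokens, head_dim)
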
32  diffsynth/models/flux_text_encoder.py  Normal file
@@ -0,0 +1,32 @@
import torch
from transformers import T5EncoderModel, T5Config
from .sd_text_encoder import SDTextEncoder


class FluxTextEncoder2(T5EncoderModel):
    def __init__(self, config):
        super().__init__(config)
        self.eval()

    def forward(self, input_ids):
        outputs = super().forward(input_ids=input_ids)
        prompt_emb = outputs.last_hidden_state
        return prompt_emb

    @staticmethod
    def state_dict_converter():
        return FluxTextEncoder2StateDictConverter()


class FluxTextEncoder2StateDictConverter():
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        state_dict_ = state_dict
        return state_dict_

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)

303  diffsynth/models/flux_vae.py  Normal file
@@ -0,0 +1,303 @@
from .sd3_vae_encoder import SD3VAEEncoder, SDVAEEncoderStateDictConverter
|
||||||
|
from .sd3_vae_decoder import SD3VAEDecoder, SDVAEDecoderStateDictConverter
|
||||||
|
|
||||||
|
|
||||||
|
class FluxVAEEncoder(SD3VAEEncoder):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.scaling_factor = 0.3611
|
||||||
|
self.shift_factor = 0.1159
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return FluxVAEEncoderStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
class FluxVAEDecoder(SD3VAEDecoder):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.scaling_factor = 0.3611
|
||||||
|
self.shift_factor = 0.1159
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return FluxVAEDecoderStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
class FluxVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
rename_dict = {
|
||||||
|
"encoder.conv_in.bias": "conv_in.bias",
|
||||||
|
"encoder.conv_in.weight": "conv_in.weight",
|
||||||
|
"encoder.conv_out.bias": "conv_out.bias",
|
||||||
|
"encoder.conv_out.weight": "conv_out.weight",
|
||||||
|
"encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
|
||||||
|
"encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
|
||||||
|
"encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
|
||||||
|
"encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
|
||||||
|
"encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
|
||||||
|
"encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
|
||||||
|
"encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
|
||||||
|
"encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
|
||||||
|
"encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
|
||||||
|
"encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
|
||||||
|
"encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
|
||||||
|
"encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
|
||||||
|
"encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
|
||||||
|
"encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
|
||||||
|
"encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
|
||||||
|
"encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
|
||||||
|
"encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
|
||||||
|
"encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
|
||||||
|
"encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
|
||||||
|
"encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
|
||||||
|
"encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
|
||||||
|
"encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
|
||||||
|
"encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
|
||||||
|
"encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
|
||||||
|
"encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
|
||||||
|
"encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
|
||||||
|
"encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
|
||||||
|
"encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
|
||||||
|
"encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
|
||||||
|
"encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
|
||||||
|
"encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
|
||||||
|
"encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
|
||||||
|
"encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
|
||||||
|
"encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
|
||||||
|
"encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
|
||||||
|
"encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
|
||||||
|
"encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
|
||||||
|
"encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
|
||||||
|
"encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
|
||||||
|
"encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
|
||||||
|
"encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
|
||||||
|
"encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
|
||||||
|
"encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
|
||||||
|
"encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
|
||||||
|
"encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
|
||||||
|
"encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
|
||||||
|
"encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
|
||||||
|
"encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
|
||||||
|
"encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
|
||||||
|
"encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
|
||||||
|
"encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
|
||||||
|
"encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
|
||||||
|
"encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
|
||||||
|
"encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
|
||||||
|
"encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
|
||||||
|
"encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
|
||||||
|
"encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
|
||||||
|
"encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
|
||||||
|
"encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
|
||||||
|
"encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
|
||||||
|
"encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
|
||||||
|
"encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
|
||||||
|
"encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
|
||||||
|
"encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
|
||||||
|
"encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
|
||||||
|
"encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
|
||||||
|
"encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
|
||||||
|
"encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
|
||||||
|
"encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
|
||||||
|
"encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
|
||||||
|
"encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
|
||||||
|
"encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
|
||||||
|
"encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
|
||||||
|
"encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
|
||||||
|
"encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
|
||||||
|
"encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
|
||||||
|
"encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
|
||||||
|
"encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
|
||||||
|
"encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
|
||||||
|
"encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
|
||||||
|
"encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
|
||||||
|
"encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
|
||||||
|
"encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
|
||||||
|
"encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
|
||||||
|
"encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
|
||||||
|
"encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
|
||||||
|
"encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
|
||||||
|
"encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
|
||||||
|
"encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
|
||||||
|
"encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
|
||||||
|
"encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
|
||||||
|
"encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
|
||||||
|
"encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
|
||||||
|
"encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
|
||||||
|
"encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
|
||||||
|
"encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
|
||||||
|
"encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
|
||||||
|
"encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
|
||||||
|
"encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
|
||||||
|
"encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
|
||||||
|
"encoder.norm_out.bias": "conv_norm_out.bias",
|
||||||
|
"encoder.norm_out.weight": "conv_norm_out.weight",
|
||||||
|
}
|
||||||
|
state_dict_ = {}
|
||||||
|
for name in state_dict:
|
||||||
|
if name in rename_dict:
|
||||||
|
param = state_dict[name]
|
||||||
|
if "transformer_blocks" in rename_dict[name]:
|
||||||
|
param = param.squeeze()
|
||||||
|
state_dict_[rename_dict[name]] = param
|
||||||
|
return state_dict_
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class FluxVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
rename_dict = {
|
||||||
|
"decoder.conv_in.bias": "conv_in.bias",
|
||||||
|
"decoder.conv_in.weight": "conv_in.weight",
|
||||||
|
"decoder.conv_out.bias": "conv_out.bias",
|
||||||
|
"decoder.conv_out.weight": "conv_out.weight",
|
||||||
|
"decoder.mid.attn_1.k.bias": "blocks.1.transformer_blocks.0.to_k.bias",
|
||||||
|
"decoder.mid.attn_1.k.weight": "blocks.1.transformer_blocks.0.to_k.weight",
|
||||||
|
"decoder.mid.attn_1.norm.bias": "blocks.1.norm.bias",
|
||||||
|
"decoder.mid.attn_1.norm.weight": "blocks.1.norm.weight",
|
||||||
|
"decoder.mid.attn_1.proj_out.bias": "blocks.1.transformer_blocks.0.to_out.bias",
|
||||||
|
"decoder.mid.attn_1.proj_out.weight": "blocks.1.transformer_blocks.0.to_out.weight",
|
||||||
|
"decoder.mid.attn_1.q.bias": "blocks.1.transformer_blocks.0.to_q.bias",
|
||||||
|
"decoder.mid.attn_1.q.weight": "blocks.1.transformer_blocks.0.to_q.weight",
|
||||||
|
"decoder.mid.attn_1.v.bias": "blocks.1.transformer_blocks.0.to_v.bias",
|
||||||
|
"decoder.mid.attn_1.v.weight": "blocks.1.transformer_blocks.0.to_v.weight",
|
||||||
|
"decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
|
||||||
|
"decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
|
||||||
|
"decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
|
||||||
|
"decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
|
||||||
|
"decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
|
||||||
|
"decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
|
||||||
|
"decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
|
||||||
|
"decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
|
||||||
|
"decoder.mid.block_2.conv1.bias": "blocks.2.conv1.bias",
|
||||||
|
"decoder.mid.block_2.conv1.weight": "blocks.2.conv1.weight",
|
||||||
|
"decoder.mid.block_2.conv2.bias": "blocks.2.conv2.bias",
|
||||||
|
"decoder.mid.block_2.conv2.weight": "blocks.2.conv2.weight",
|
||||||
|
"decoder.mid.block_2.norm1.bias": "blocks.2.norm1.bias",
|
||||||
|
"decoder.mid.block_2.norm1.weight": "blocks.2.norm1.weight",
|
||||||
|
"decoder.mid.block_2.norm2.bias": "blocks.2.norm2.bias",
|
||||||
|
"decoder.mid.block_2.norm2.weight": "blocks.2.norm2.weight",
|
||||||
|
"decoder.norm_out.bias": "conv_norm_out.bias",
|
||||||
|
"decoder.norm_out.weight": "conv_norm_out.weight",
|
||||||
|
"decoder.up.0.block.0.conv1.bias": "blocks.15.conv1.bias",
|
||||||
|
"decoder.up.0.block.0.conv1.weight": "blocks.15.conv1.weight",
|
||||||
|
"decoder.up.0.block.0.conv2.bias": "blocks.15.conv2.bias",
|
||||||
|
"decoder.up.0.block.0.conv2.weight": "blocks.15.conv2.weight",
|
||||||
|
"decoder.up.0.block.0.nin_shortcut.bias": "blocks.15.conv_shortcut.bias",
|
||||||
|
"decoder.up.0.block.0.nin_shortcut.weight": "blocks.15.conv_shortcut.weight",
|
||||||
|
"decoder.up.0.block.0.norm1.bias": "blocks.15.norm1.bias",
|
||||||
|
"decoder.up.0.block.0.norm1.weight": "blocks.15.norm1.weight",
|
||||||
|
"decoder.up.0.block.0.norm2.bias": "blocks.15.norm2.bias",
|
||||||
|
"decoder.up.0.block.0.norm2.weight": "blocks.15.norm2.weight",
|
||||||
|
"decoder.up.0.block.1.conv1.bias": "blocks.16.conv1.bias",
|
||||||
|
"decoder.up.0.block.1.conv1.weight": "blocks.16.conv1.weight",
|
||||||
|
"decoder.up.0.block.1.conv2.bias": "blocks.16.conv2.bias",
|
||||||
|
"decoder.up.0.block.1.conv2.weight": "blocks.16.conv2.weight",
|
||||||
|
"decoder.up.0.block.1.norm1.bias": "blocks.16.norm1.bias",
|
||||||
|
"decoder.up.0.block.1.norm1.weight": "blocks.16.norm1.weight",
|
||||||
|
"decoder.up.0.block.1.norm2.bias": "blocks.16.norm2.bias",
|
||||||
|
"decoder.up.0.block.1.norm2.weight": "blocks.16.norm2.weight",
|
||||||
|
"decoder.up.0.block.2.conv1.bias": "blocks.17.conv1.bias",
|
||||||
|
"decoder.up.0.block.2.conv1.weight": "blocks.17.conv1.weight",
|
||||||
|
"decoder.up.0.block.2.conv2.bias": "blocks.17.conv2.bias",
|
||||||
|
"decoder.up.0.block.2.conv2.weight": "blocks.17.conv2.weight",
|
||||||
|
"decoder.up.0.block.2.norm1.bias": "blocks.17.norm1.bias",
|
||||||
|
"decoder.up.0.block.2.norm1.weight": "blocks.17.norm1.weight",
|
||||||
|
"decoder.up.0.block.2.norm2.bias": "blocks.17.norm2.bias",
|
||||||
|
"decoder.up.0.block.2.norm2.weight": "blocks.17.norm2.weight",
|
||||||
|
"decoder.up.1.block.0.conv1.bias": "blocks.11.conv1.bias",
|
||||||
|
"decoder.up.1.block.0.conv1.weight": "blocks.11.conv1.weight",
|
||||||
|
"decoder.up.1.block.0.conv2.bias": "blocks.11.conv2.bias",
|
||||||
|
"decoder.up.1.block.0.conv2.weight": "blocks.11.conv2.weight",
|
||||||
|
"decoder.up.1.block.0.nin_shortcut.bias": "blocks.11.conv_shortcut.bias",
|
||||||
|
"decoder.up.1.block.0.nin_shortcut.weight": "blocks.11.conv_shortcut.weight",
|
||||||
|
"decoder.up.1.block.0.norm1.bias": "blocks.11.norm1.bias",
|
||||||
|
"decoder.up.1.block.0.norm1.weight": "blocks.11.norm1.weight",
|
||||||
|
"decoder.up.1.block.0.norm2.bias": "blocks.11.norm2.bias",
|
||||||
|
"decoder.up.1.block.0.norm2.weight": "blocks.11.norm2.weight",
|
||||||
|
"decoder.up.1.block.1.conv1.bias": "blocks.12.conv1.bias",
|
||||||
|
"decoder.up.1.block.1.conv1.weight": "blocks.12.conv1.weight",
|
||||||
|
"decoder.up.1.block.1.conv2.bias": "blocks.12.conv2.bias",
|
||||||
|
"decoder.up.1.block.1.conv2.weight": "blocks.12.conv2.weight",
|
||||||
|
"decoder.up.1.block.1.norm1.bias": "blocks.12.norm1.bias",
|
||||||
|
"decoder.up.1.block.1.norm1.weight": "blocks.12.norm1.weight",
|
||||||
|
"decoder.up.1.block.1.norm2.bias": "blocks.12.norm2.bias",
|
||||||
|
"decoder.up.1.block.1.norm2.weight": "blocks.12.norm2.weight",
|
||||||
|
"decoder.up.1.block.2.conv1.bias": "blocks.13.conv1.bias",
|
||||||
|
"decoder.up.1.block.2.conv1.weight": "blocks.13.conv1.weight",
|
||||||
|
"decoder.up.1.block.2.conv2.bias": "blocks.13.conv2.bias",
|
||||||
|
"decoder.up.1.block.2.conv2.weight": "blocks.13.conv2.weight",
|
||||||
|
"decoder.up.1.block.2.norm1.bias": "blocks.13.norm1.bias",
|
||||||
|
"decoder.up.1.block.2.norm1.weight": "blocks.13.norm1.weight",
|
||||||
|
"decoder.up.1.block.2.norm2.bias": "blocks.13.norm2.bias",
|
||||||
|
"decoder.up.1.block.2.norm2.weight": "blocks.13.norm2.weight",
|
||||||
|
"decoder.up.1.upsample.conv.bias": "blocks.14.conv.bias",
|
||||||
|
"decoder.up.1.upsample.conv.weight": "blocks.14.conv.weight",
|
||||||
|
"decoder.up.2.block.0.conv1.bias": "blocks.7.conv1.bias",
|
||||||
|
"decoder.up.2.block.0.conv1.weight": "blocks.7.conv1.weight",
|
||||||
|
"decoder.up.2.block.0.conv2.bias": "blocks.7.conv2.bias",
|
||||||
|
"decoder.up.2.block.0.conv2.weight": "blocks.7.conv2.weight",
|
||||||
|
"decoder.up.2.block.0.norm1.bias": "blocks.7.norm1.bias",
|
||||||
|
"decoder.up.2.block.0.norm1.weight": "blocks.7.norm1.weight",
|
||||||
|
"decoder.up.2.block.0.norm2.bias": "blocks.7.norm2.bias",
|
||||||
|
"decoder.up.2.block.0.norm2.weight": "blocks.7.norm2.weight",
|
||||||
|
"decoder.up.2.block.1.conv1.bias": "blocks.8.conv1.bias",
|
||||||
|
"decoder.up.2.block.1.conv1.weight": "blocks.8.conv1.weight",
|
||||||
|
"decoder.up.2.block.1.conv2.bias": "blocks.8.conv2.bias",
|
||||||
|
"decoder.up.2.block.1.conv2.weight": "blocks.8.conv2.weight",
|
||||||
|
"decoder.up.2.block.1.norm1.bias": "blocks.8.norm1.bias",
|
||||||
|
"decoder.up.2.block.1.norm1.weight": "blocks.8.norm1.weight",
|
||||||
|
"decoder.up.2.block.1.norm2.bias": "blocks.8.norm2.bias",
|
||||||
|
"decoder.up.2.block.1.norm2.weight": "blocks.8.norm2.weight",
|
||||||
|
"decoder.up.2.block.2.conv1.bias": "blocks.9.conv1.bias",
|
||||||
|
"decoder.up.2.block.2.conv1.weight": "blocks.9.conv1.weight",
|
||||||
|
"decoder.up.2.block.2.conv2.bias": "blocks.9.conv2.bias",
|
||||||
|
"decoder.up.2.block.2.conv2.weight": "blocks.9.conv2.weight",
|
||||||
|
"decoder.up.2.block.2.norm1.bias": "blocks.9.norm1.bias",
|
||||||
|
"decoder.up.2.block.2.norm1.weight": "blocks.9.norm1.weight",
|
||||||
|
"decoder.up.2.block.2.norm2.bias": "blocks.9.norm2.bias",
|
||||||
|
"decoder.up.2.block.2.norm2.weight": "blocks.9.norm2.weight",
|
||||||
|
"decoder.up.2.upsample.conv.bias": "blocks.10.conv.bias",
|
||||||
|
"decoder.up.2.upsample.conv.weight": "blocks.10.conv.weight",
|
||||||
|
"decoder.up.3.block.0.conv1.bias": "blocks.3.conv1.bias",
|
||||||
|
"decoder.up.3.block.0.conv1.weight": "blocks.3.conv1.weight",
|
||||||
|
"decoder.up.3.block.0.conv2.bias": "blocks.3.conv2.bias",
|
||||||
|
"decoder.up.3.block.0.conv2.weight": "blocks.3.conv2.weight",
|
||||||
|
"decoder.up.3.block.0.norm1.bias": "blocks.3.norm1.bias",
|
||||||
|
"decoder.up.3.block.0.norm1.weight": "blocks.3.norm1.weight",
|
||||||
|
"decoder.up.3.block.0.norm2.bias": "blocks.3.norm2.bias",
|
||||||
|
"decoder.up.3.block.0.norm2.weight": "blocks.3.norm2.weight",
|
||||||
|
"decoder.up.3.block.1.conv1.bias": "blocks.4.conv1.bias",
|
||||||
|
"decoder.up.3.block.1.conv1.weight": "blocks.4.conv1.weight",
|
||||||
|
"decoder.up.3.block.1.conv2.bias": "blocks.4.conv2.bias",
|
||||||
|
"decoder.up.3.block.1.conv2.weight": "blocks.4.conv2.weight",
|
||||||
|
"decoder.up.3.block.1.norm1.bias": "blocks.4.norm1.bias",
|
||||||
|
"decoder.up.3.block.1.norm1.weight": "blocks.4.norm1.weight",
|
||||||
|
"decoder.up.3.block.1.norm2.bias": "blocks.4.norm2.bias",
|
||||||
|
"decoder.up.3.block.1.norm2.weight": "blocks.4.norm2.weight",
|
||||||
|
"decoder.up.3.block.2.conv1.bias": "blocks.5.conv1.bias",
|
||||||
|
"decoder.up.3.block.2.conv1.weight": "blocks.5.conv1.weight",
|
||||||
|
"decoder.up.3.block.2.conv2.bias": "blocks.5.conv2.bias",
|
||||||
|
"decoder.up.3.block.2.conv2.weight": "blocks.5.conv2.weight",
|
||||||
|
"decoder.up.3.block.2.norm1.bias": "blocks.5.norm1.bias",
|
||||||
|
"decoder.up.3.block.2.norm1.weight": "blocks.5.norm1.weight",
|
||||||
|
"decoder.up.3.block.2.norm2.bias": "blocks.5.norm2.bias",
|
||||||
|
"decoder.up.3.block.2.norm2.weight": "blocks.5.norm2.weight",
|
||||||
|
"decoder.up.3.upsample.conv.bias": "blocks.6.conv.bias",
|
||||||
|
"decoder.up.3.upsample.conv.weight": "blocks.6.conv.weight",
|
||||||
|
}
|
||||||
|
state_dict_ = {}
|
||||||
|
for name in state_dict:
|
||||||
|
if name in rename_dict:
|
||||||
|
param = state_dict[name]
|
||||||
|
if "transformer_blocks" in rename_dict[name]:
|
||||||
|
param = param.squeeze()
|
||||||
|
state_dict_[rename_dict[name]] = param
|
||||||
|
return state_dict_
|
||||||
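# The scaling_factor / shift_factor stored on the encoder and decoder are applied by the
# pipelines (not in this file) to normalise latents; a sketch of that bookkeeping:
import torch

scaling_factor, shift_factor = 0.3611, 0.1159
z = torch.randn(1, 16, 64, 64)                    # raw FluxVAEEncoder output (illustrative)
latents = (z - shift_factor) * scaling_factor     # what the diffusion model sees
z_back = latents / scaling_factor + shift_factor  # undone before FluxVAEDecoder
assert torch.allclose(z, z_back, atol=1e-5)
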
58
diffsynth/models/flux_value_control.py
Normal file
58
diffsynth/models/flux_value_control.py
Normal file
@@ -0,0 +1,58 @@
import torch
from diffsynth.models.svd_unet import TemporalTimesteps


class MultiValueEncoder(torch.nn.Module):
    def __init__(self, encoders=()):
        super().__init__()
        self.encoders = torch.nn.ModuleList(encoders)

    def __call__(self, values, dtype):
        emb = []
        for encoder, value in zip(self.encoders, values):
            if value is not None:
                value = value.unsqueeze(0)
                emb.append(encoder(value, dtype))
        emb = torch.concat(emb, dim=0)
        return emb


class SingleValueEncoder(torch.nn.Module):
    def __init__(self, dim_in=256, dim_out=4096, prefer_len=32, computation_device=None):
        super().__init__()
        self.prefer_len = prefer_len
        self.prefer_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device)
        self.prefer_value_embedder = torch.nn.Sequential(
            torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
        )
        self.positional_embedding = torch.nn.Parameter(
            torch.randn(self.prefer_len, dim_in)
        )
        self._initialize_weights()

    def _initialize_weights(self):
        last_linear = self.prefer_value_embedder[-1]
        torch.nn.init.zeros_(last_linear.weight)
        torch.nn.init.zeros_(last_linear.bias)

    def forward(self, value, dtype):
        emb = self.prefer_proj(value).to(dtype)
        emb = emb.expand(self.prefer_len, -1)
        emb = emb + self.positional_embedding
        emb = self.prefer_value_embedder(emb)
        return emb

    @staticmethod
    def state_dict_converter():
        return SingleValueEncoderStateDictConverter()


class SingleValueEncoderStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        return state_dict

    def from_civitai(self, state_dict):
        return state_dict
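A minimal usage sketch for the encoders above, assuming TemporalTimesteps maps an [N] tensor of values to [N, dim_in] sinusoidal features (as timestep embedders usually do); the scalar values and dtypes are illustrative, and the pipeline wiring is not part of this diff:

import torch

encoder = SingleValueEncoder()                       # dim_in=256, dim_out=4096, prefer_len=32
value = torch.tensor([0.7])                          # a single scalar control value
emb = encoder(value, dtype=torch.float32)
print(emb.shape)                                     # expected: torch.Size([32, 4096])

# MultiValueEncoder concatenates the outputs of several such encoders and skips
# entries whose value is None.
multi = MultiValueEncoder(encoders=[SingleValueEncoder(), SingleValueEncoder()])
emb = multi([torch.tensor(0.3), None], dtype=torch.float32)
print(emb.shape)                                     # expected: torch.Size([32, 4096])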
@@ -1,5 +1,4 @@
 from .attention import Attention
-from .tiler import TileWorker
 from einops import repeat, rearrange
 import math
 import torch
@@ -399,7 +398,8 @@ class HunyuanDiT(torch.nn.Module):
         hidden_states, _ = hidden_states.chunk(2, dim=1)
         return hidden_states
 
-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return HunyuanDiTStateDictConverter()
 
 
@@ -79,7 +79,8 @@ class HunyuanDiTCLIPTextEncoder(BertModel):
         prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
         return prompt_emb
 
-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return HunyuanDiTCLIPTextEncoderStateDictConverter()
 
 
@@ -131,7 +132,8 @@ class HunyuanDiTT5TextEncoder(T5EncoderModel):
         prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
         return prompt_emb
 
-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return HunyuanDiTT5TextEncoderStateDictConverter()
 
 
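All three hunks make the same change: state_dict_converter becomes a @staticmethod, so a converter can be obtained from the class without constructing the full model first. Illustrative usage (not part of the diff):

converter = HunyuanDiT.state_dict_converter()   # no model instance needed
# previously: HunyuanDiT(...).state_dict_converter()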
diffsynth/models/hunyuan_video_dit.py (new file, 920 lines)
@@ -0,0 +1,920 @@
|
|||||||
|
import torch
|
||||||
|
from .sd3_dit import TimestepEmbeddings, RMSNorm
|
||||||
|
from .utils import init_weights_on_device
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from tqdm import tqdm
|
||||||
|
from typing import Union, Tuple, List
|
||||||
|
from .utils import hash_state_dict_keys
|
||||||
|
|
||||||
|
|
||||||
|
def HunyuanVideoRope(latents):
|
||||||
|
def _to_tuple(x, dim=2):
|
||||||
|
if isinstance(x, int):
|
||||||
|
return (x,) * dim
|
||||||
|
elif len(x) == dim:
|
||||||
|
return x
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Expected length {dim} or int, but got {x}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_meshgrid_nd(start, *args, dim=2):
|
||||||
|
"""
|
||||||
|
Get n-D meshgrid with start, stop and num.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
|
||||||
|
step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
|
||||||
|
should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
|
||||||
|
n-tuples.
|
||||||
|
*args: See above.
|
||||||
|
dim (int): Dimension of the meshgrid. Defaults to 2.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
grid (np.ndarray): [dim, ...]
|
||||||
|
"""
|
||||||
|
if len(args) == 0:
|
||||||
|
# start is grid_size
|
||||||
|
num = _to_tuple(start, dim=dim)
|
||||||
|
start = (0,) * dim
|
||||||
|
stop = num
|
||||||
|
elif len(args) == 1:
|
||||||
|
# start is start, args[0] is stop, step is 1
|
||||||
|
start = _to_tuple(start, dim=dim)
|
||||||
|
stop = _to_tuple(args[0], dim=dim)
|
||||||
|
num = [stop[i] - start[i] for i in range(dim)]
|
||||||
|
elif len(args) == 2:
|
||||||
|
# start is start, args[0] is stop, args[1] is num
|
||||||
|
start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
|
||||||
|
stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
|
||||||
|
num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
|
||||||
|
else:
|
||||||
|
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
|
||||||
|
|
||||||
|
# PyTorch implementation of np.linspace(start[i], stop[i], num[i], endpoint=False)
|
||||||
|
axis_grid = []
|
||||||
|
for i in range(dim):
|
||||||
|
a, b, n = start[i], stop[i], num[i]
|
||||||
|
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
|
||||||
|
axis_grid.append(g)
|
||||||
|
grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
|
||||||
|
grid = torch.stack(grid, dim=0) # [dim, W, H, D]
|
||||||
|
|
||||||
|
return grid
|
||||||
|
|
||||||
|
|
||||||
|
def get_1d_rotary_pos_embed(
|
||||||
|
dim: int,
|
||||||
|
pos: Union[torch.FloatTensor, int],
|
||||||
|
theta: float = 10000.0,
|
||||||
|
use_real: bool = False,
|
||||||
|
theta_rescale_factor: float = 1.0,
|
||||||
|
interpolation_factor: float = 1.0,
|
||||||
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
Precompute the frequency tensor for complex exponential (cis) with given dimensions.
|
||||||
|
(Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
|
||||||
|
|
||||||
|
This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
|
||||||
|
and the end index 'end'. The 'theta' parameter scales the frequencies.
|
||||||
|
The returned tensor contains complex values in complex64 data type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dim (int): Dimension of the frequency tensor.
|
||||||
|
pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
|
||||||
|
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
|
||||||
|
use_real (bool, optional): If True, return real part and imaginary part separately.
|
||||||
|
Otherwise, return complex numbers.
|
||||||
|
theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
|
||||||
|
freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
|
||||||
|
"""
|
||||||
|
if isinstance(pos, int):
|
||||||
|
pos = torch.arange(pos).float()
|
||||||
|
|
||||||
|
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
||||||
|
# has some connection to NTK literature
|
||||||
|
if theta_rescale_factor != 1.0:
|
||||||
|
theta *= theta_rescale_factor ** (dim / (dim - 2))
|
||||||
|
|
||||||
|
freqs = 1.0 / (
|
||||||
|
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
|
||||||
|
) # [D/2]
|
||||||
|
# assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
|
||||||
|
freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
|
||||||
|
if use_real:
|
||||||
|
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
|
||||||
|
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
|
||||||
|
return freqs_cos, freqs_sin
|
||||||
|
else:
|
||||||
|
freqs_cis = torch.polar(
|
||||||
|
torch.ones_like(freqs), freqs
|
||||||
|
) # complex64 # [S, D/2]
|
||||||
|
return freqs_cis
|
||||||
|
|
||||||
|
|
||||||
|
def get_nd_rotary_pos_embed(
|
||||||
|
rope_dim_list,
|
||||||
|
start,
|
||||||
|
*args,
|
||||||
|
theta=10000.0,
|
||||||
|
use_real=False,
|
||||||
|
theta_rescale_factor: Union[float, List[float]] = 1.0,
|
||||||
|
interpolation_factor: Union[float, List[float]] = 1.0,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
This is an n-d version of precompute_freqs_cis, i.e. RoPE for tokens with an n-d structure.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
|
||||||
|
sum(rope_dim_list) should equal to head_dim of attention layer.
|
||||||
|
start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
|
||||||
|
args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
|
||||||
|
*args: See above.
|
||||||
|
theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
|
||||||
|
use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
||||||
|
Some libraries, such as TensorRT, do not support the complex64 data type, so it is useful to provide a real
|
||||||
|
part and an imaginary part separately.
|
||||||
|
theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pos_embed (torch.Tensor): [HW, D/2]
|
||||||
|
"""
|
||||||
|
|
||||||
|
grid = get_meshgrid_nd(
|
||||||
|
start, *args, dim=len(rope_dim_list)
|
||||||
|
) # [3, W, H, D] / [2, W, H]
|
||||||
|
|
||||||
|
if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
|
||||||
|
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
|
||||||
|
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
|
||||||
|
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
|
||||||
|
assert len(theta_rescale_factor) == len(
|
||||||
|
rope_dim_list
|
||||||
|
), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
|
||||||
|
|
||||||
|
if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
|
||||||
|
interpolation_factor = [interpolation_factor] * len(rope_dim_list)
|
||||||
|
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
|
||||||
|
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
|
||||||
|
assert len(interpolation_factor) == len(
|
||||||
|
rope_dim_list
|
||||||
|
), "len(interpolation_factor) should equal to len(rope_dim_list)"
|
||||||
|
|
||||||
|
# use 1/ndim of dimensions to encode grid_axis
|
||||||
|
embs = []
|
||||||
|
for i in range(len(rope_dim_list)):
|
||||||
|
emb = get_1d_rotary_pos_embed(
|
||||||
|
rope_dim_list[i],
|
||||||
|
grid[i].reshape(-1),
|
||||||
|
theta,
|
||||||
|
use_real=use_real,
|
||||||
|
theta_rescale_factor=theta_rescale_factor[i],
|
||||||
|
interpolation_factor=interpolation_factor[i],
|
||||||
|
) # 2 x [WHD, rope_dim_list[i]]
|
||||||
|
embs.append(emb)
|
||||||
|
|
||||||
|
if use_real:
|
||||||
|
cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
|
||||||
|
sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
|
||||||
|
return cos, sin
|
||||||
|
else:
|
||||||
|
emb = torch.cat(embs, dim=1) # (WHD, D/2)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
|
||||||
|
[16, 56, 56],
|
||||||
|
[latents.shape[2], latents.shape[3] // 2, latents.shape[4] // 2],
|
||||||
|
theta=256,
|
||||||
|
use_real=True,
|
||||||
|
theta_rescale_factor=1,
|
||||||
|
)
|
||||||
|
return freqs_cos, freqs_sin
|
||||||
|
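For orientation, a hedged sketch of calling HunyuanVideoRope on a dummy latent tensor. The latent layout (B, C, T, H, W) and the head-dim split [16, 56, 56] follow the code above; the concrete sizes here are made up:

import torch

latents = torch.randn(1, 16, 9, 40, 64)            # assumed: 9 frames of 40x64 latents
freqs_cos, freqs_sin = HunyuanVideoRope(latents)
# The RoPE grid covers (T, H//2, W//2) = (9, 20, 32) = 5760 positions, and each
# position gets a head_dim-sized embedding (16 + 56 + 56 = 128), so:
print(freqs_cos.shape, freqs_sin.shape)             # expected: torch.Size([5760, 128]) twice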
|
||||||
|
|
||||||
|
class PatchEmbed(torch.nn.Module):
|
||||||
|
def __init__(self, patch_size=(1, 2, 2), in_channels=16, embed_dim=3072):
|
||||||
|
super().__init__()
|
||||||
|
self.proj = torch.nn.Conv3d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.proj(x)
|
||||||
|
x = x.flatten(2).transpose(1, 2)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class IndividualTokenRefinerBlock(torch.nn.Module):
|
||||||
|
def __init__(self, hidden_size=3072, num_heads=24):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.norm1 = torch.nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
|
||||||
|
self.self_attn_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
|
||||||
|
self.self_attn_proj = torch.nn.Linear(hidden_size, hidden_size)
|
||||||
|
|
||||||
|
self.norm2 = torch.nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
|
||||||
|
self.mlp = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(hidden_size, hidden_size * 4),
|
||||||
|
torch.nn.SiLU(),
|
||||||
|
torch.nn.Linear(hidden_size * 4, hidden_size)
|
||||||
|
)
|
||||||
|
self.adaLN_modulation = torch.nn.Sequential(
|
||||||
|
torch.nn.SiLU(),
|
||||||
|
torch.nn.Linear(hidden_size, hidden_size * 2, device="cuda", dtype=torch.bfloat16),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, c, attn_mask=None):
|
||||||
|
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||||
|
|
||||||
|
norm_x = self.norm1(x)
|
||||||
|
qkv = self.self_attn_qkv(norm_x)
|
||||||
|
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||||
|
|
||||||
|
attn = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||||
|
attn = rearrange(attn, "B H L D -> B L (H D)")
|
||||||
|
|
||||||
|
x = x + self.self_attn_proj(attn) * gate_msa.unsqueeze(1)
|
||||||
|
x = x + self.mlp(self.norm2(x)) * gate_mlp.unsqueeze(1)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SingleTokenRefiner(torch.nn.Module):
|
||||||
|
def __init__(self, in_channels=4096, hidden_size=3072, depth=2):
|
||||||
|
super().__init__()
|
||||||
|
self.input_embedder = torch.nn.Linear(in_channels, hidden_size, bias=True)
|
||||||
|
self.t_embedder = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
|
||||||
|
self.c_embedder = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(in_channels, hidden_size),
|
||||||
|
torch.nn.SiLU(),
|
||||||
|
torch.nn.Linear(hidden_size, hidden_size)
|
||||||
|
)
|
||||||
|
self.blocks = torch.nn.ModuleList([IndividualTokenRefinerBlock(hidden_size=hidden_size) for _ in range(depth)])
|
||||||
|
|
||||||
|
def forward(self, x, t, mask=None):
|
||||||
|
timestep_aware_representations = self.t_embedder(t, dtype=torch.float32)
|
||||||
|
|
||||||
|
mask_float = mask.float().unsqueeze(-1)
|
||||||
|
context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
|
||||||
|
context_aware_representations = self.c_embedder(context_aware_representations)
|
||||||
|
c = timestep_aware_representations + context_aware_representations
|
||||||
|
|
||||||
|
x = self.input_embedder(x)
|
||||||
|
|
||||||
|
mask = mask.to(device=x.device, dtype=torch.bool)
|
||||||
|
mask = repeat(mask, "B L -> B 1 D L", D=mask.shape[-1])
|
||||||
|
mask = mask & mask.transpose(2, 3)
|
||||||
|
mask[:, :, :, 0] = True
|
||||||
|
|
||||||
|
for block in self.blocks:
|
||||||
|
x = block(x, c, mask)
|
||||||
|
|
||||||
|
return x
|
||||||
|
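The text mask handling above builds a pairwise attention mask from the per-token validity mask and then forces the first key column to True so that no query row is left with zero attendable keys. A small worked example with two valid tokens and one padding token:

import torch
from einops import repeat

text_mask = torch.tensor([[True, True, False]])      # 2 valid tokens, 1 padding token
m = repeat(text_mask, "B L -> B 1 D L", D=text_mask.shape[-1])
m = m & m.transpose(2, 3)                             # both query and key must be valid
m[:, :, :, 0] = True                                  # keep at least one key per query
print(m[0, 0].int())
# tensor([[1, 1, 0],
#         [1, 1, 0],
#         [1, 0, 0]])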
|
||||||
|
|
||||||
|
class ModulateDiT(torch.nn.Module):
|
||||||
|
def __init__(self, hidden_size, factor=6):
|
||||||
|
super().__init__()
|
||||||
|
self.act = torch.nn.SiLU()
|
||||||
|
self.linear = torch.nn.Linear(hidden_size, factor * hidden_size)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.linear(self.act(x))
|
||||||
|
|
||||||
|
|
||||||
|
def modulate(x, shift=None, scale=None, tr_shift=None, tr_scale=None, tr_token=None):
|
||||||
|
if tr_shift is not None:
|
||||||
|
x_zero = x[:, :tr_token] * (1 + tr_scale.unsqueeze(1)) + tr_shift.unsqueeze(1)
|
||||||
|
x_orig = x[:, tr_token:] * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||||
|
x = torch.concat((x_zero, x_orig), dim=1)
|
||||||
|
return x
|
||||||
|
if scale is None and shift is None:
|
||||||
|
return x
|
||||||
|
elif shift is None:
|
||||||
|
return x * (1 + scale.unsqueeze(1))
|
||||||
|
elif scale is None:
|
||||||
|
return x + shift.unsqueeze(1)
|
||||||
|
else:
|
||||||
|
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||||
|
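A small worked example of the adaLN-style modulate() defined above (the token-replace branch is ignored here); tensor sizes are illustrative only:

import torch

x = torch.randn(2, 10, 8)            # (batch, tokens, hidden)
shift = torch.zeros(2, 8)
scale = torch.full((2, 8), 0.5)
y = modulate(x, shift=shift, scale=scale)
# Each token is scaled by (1 + scale) and shifted, i.e. y == x * 1.5 here.
assert torch.allclose(y, x * 1.5)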
|
||||||
|
|
||||||
|
def reshape_for_broadcast(
|
||||||
|
freqs_cis,
|
||||||
|
x: torch.Tensor,
|
||||||
|
head_first=False,
|
||||||
|
):
|
||||||
|
ndim = x.ndim
|
||||||
|
assert 0 <= 1 < ndim
|
||||||
|
|
||||||
|
if isinstance(freqs_cis, tuple):
|
||||||
|
# freqs_cis: (cos, sin) in real space
|
||||||
|
if head_first:
|
||||||
|
assert freqs_cis[0].shape == (
|
||||||
|
x.shape[-2],
|
||||||
|
x.shape[-1],
|
||||||
|
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
|
||||||
|
shape = [
|
||||||
|
d if i == ndim - 2 or i == ndim - 1 else 1
|
||||||
|
for i, d in enumerate(x.shape)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
assert freqs_cis[0].shape == (
|
||||||
|
x.shape[1],
|
||||||
|
x.shape[-1],
|
||||||
|
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
|
||||||
|
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||||
|
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
|
||||||
|
else:
|
||||||
|
# freqs_cis: values in complex space
|
||||||
|
if head_first:
|
||||||
|
assert freqs_cis.shape == (
|
||||||
|
x.shape[-2],
|
||||||
|
x.shape[-1],
|
||||||
|
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
|
||||||
|
shape = [
|
||||||
|
d if i == ndim - 2 or i == ndim - 1 else 1
|
||||||
|
for i, d in enumerate(x.shape)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
assert freqs_cis.shape == (
|
||||||
|
x.shape[1],
|
||||||
|
x.shape[-1],
|
||||||
|
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
|
||||||
|
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||||
|
return freqs_cis.view(*shape)
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_half(x):
|
||||||
|
x_real, x_imag = (
|
||||||
|
x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
|
||||||
|
) # [B, S, H, D//2]
|
||||||
|
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_emb(
|
||||||
|
xq: torch.Tensor,
|
||||||
|
xk: torch.Tensor,
|
||||||
|
freqs_cis,
|
||||||
|
head_first: bool = False,
|
||||||
|
):
|
||||||
|
xk_out = None
|
||||||
|
if isinstance(freqs_cis, tuple):
|
||||||
|
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
|
||||||
|
cos, sin = cos.to(xq.device), sin.to(xq.device)
|
||||||
|
# real * cos - imag * sin
|
||||||
|
# imag * cos + real * sin
|
||||||
|
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
|
||||||
|
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
|
||||||
|
else:
|
||||||
|
# view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
|
||||||
|
xq_ = torch.view_as_complex(
|
||||||
|
xq.float().reshape(*xq.shape[:-1], -1, 2)
|
||||||
|
) # [B, S, H, D//2]
|
||||||
|
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
|
||||||
|
xq.device
|
||||||
|
) # [S, D//2] --> [1, S, 1, D//2]
|
||||||
|
# (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
|
||||||
|
# view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
|
||||||
|
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
|
||||||
|
xk_ = torch.view_as_complex(
|
||||||
|
xk.float().reshape(*xk.shape[:-1], -1, 2)
|
||||||
|
) # [B, S, H, D//2]
|
||||||
|
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
|
||||||
|
|
||||||
|
return xq_out, xk_out
|
||||||
|
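A hedged shape check for apply_rotary_emb on the real-valued (cos, sin) path used in this file. The cos/sin tensors are random placeholders (a real call would pass the frequencies from HunyuanVideoRope), and the sizes are illustrative:

import torch

B, L, H, D = 1, 64, 4, 128
xq = torch.randn(B, L, H, D)
xk = torch.randn(B, L, H, D)
cos, sin = torch.randn(L, D), torch.randn(L, D)      # per-position (L, D) tables
xq_out, xk_out = apply_rotary_emb(xq, xk, (cos, sin), head_first=False)
print(xq_out.shape, xk_out.shape)                    # both remain (B, L, H, D)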
|
||||||
|
|
||||||
|
def attention(q, k, v):
|
||||||
|
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
|
||||||
|
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||||
|
x = x.transpose(1, 2).flatten(2, 3)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def apply_gate(x, gate, tr_gate=None, tr_token=None):
|
||||||
|
if tr_gate is not None:
|
||||||
|
x_zero = x[:, :tr_token] * tr_gate.unsqueeze(1)
|
||||||
|
x_orig = x[:, tr_token:] * gate.unsqueeze(1)
|
||||||
|
return torch.concat((x_zero, x_orig), dim=1)
|
||||||
|
else:
|
||||||
|
return x * gate.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
|
class MMDoubleStreamBlockComponent(torch.nn.Module):
|
||||||
|
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
||||||
|
super().__init__()
|
||||||
|
self.heads_num = heads_num
|
||||||
|
|
||||||
|
self.mod = ModulateDiT(hidden_size)
|
||||||
|
self.norm1 = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||||
|
|
||||||
|
self.to_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
|
||||||
|
self.norm_q = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
||||||
|
self.norm_k = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
||||||
|
self.to_out = torch.nn.Linear(hidden_size, hidden_size)
|
||||||
|
|
||||||
|
self.norm2 = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.ff = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(hidden_size, hidden_size * mlp_width_ratio),
|
||||||
|
torch.nn.GELU(approximate="tanh"),
|
||||||
|
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, conditioning, freqs_cis=None, token_replace_vec=None, tr_token=None):
|
||||||
|
mod1_shift, mod1_scale, mod1_gate, mod2_shift, mod2_scale, mod2_gate = self.mod(conditioning).chunk(6, dim=-1)
|
||||||
|
if token_replace_vec is not None:
|
||||||
|
assert tr_token is not None
|
||||||
|
tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = self.mod(token_replace_vec).chunk(6, dim=-1)
|
||||||
|
else:
|
||||||
|
tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None, None, None
|
||||||
|
|
||||||
|
norm_hidden_states = self.norm1(hidden_states)
|
||||||
|
norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale,
|
||||||
|
tr_shift=tr_mod1_shift, tr_scale=tr_mod1_scale, tr_token=tr_token)
|
||||||
|
qkv = self.to_qkv(norm_hidden_states)
|
||||||
|
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
||||||
|
|
||||||
|
q = self.norm_q(q)
|
||||||
|
k = self.norm_k(k)
|
||||||
|
|
||||||
|
if freqs_cis is not None:
|
||||||
|
q, k = apply_rotary_emb(q, k, freqs_cis, head_first=False)
|
||||||
|
return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate), (tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate)
|
||||||
|
|
||||||
|
def process_ff(self, hidden_states, attn_output, mod, mod_tr=None, tr_token=None):
|
||||||
|
mod1_gate, mod2_shift, mod2_scale, mod2_gate = mod
|
||||||
|
if mod_tr is not None:
|
||||||
|
tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = mod_tr
|
||||||
|
else:
|
||||||
|
tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None
|
||||||
|
hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod1_gate, tr_mod1_gate, tr_token)
|
||||||
|
x = self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale, tr_shift=tr_mod2_shift, tr_scale=tr_mod2_scale, tr_token=tr_token))
|
||||||
|
hidden_states = hidden_states + apply_gate(x, mod2_gate, tr_mod2_gate, tr_token)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class MMDoubleStreamBlock(torch.nn.Module):
|
||||||
|
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
||||||
|
super().__init__()
|
||||||
|
self.component_a = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
|
||||||
|
self.component_b = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
|
||||||
|
|
||||||
|
def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis, token_replace_vec=None, tr_token=None, split_token=71):
|
||||||
|
(q_a, k_a, v_a), mod_a, mod_tr = self.component_a(hidden_states_a, conditioning, freqs_cis, token_replace_vec, tr_token)
|
||||||
|
(q_b, k_b, v_b), mod_b, _ = self.component_b(hidden_states_b, conditioning, freqs_cis=None)
|
||||||
|
|
||||||
|
q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous()
|
||||||
|
k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous()
|
||||||
|
v_a, v_b = torch.concat([v_a, v_b[:, :split_token]], dim=1), v_b[:, split_token:].contiguous()
|
||||||
|
attn_output_a = attention(q_a, k_a, v_a)
|
||||||
|
attn_output_b = attention(q_b, k_b, v_b)
|
||||||
|
attn_output_a, attn_output_b = attn_output_a[:, :-split_token].contiguous(), torch.concat([attn_output_a[:, -split_token:], attn_output_b], dim=1)
|
||||||
|
|
||||||
|
hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a, mod_tr, tr_token)
|
||||||
|
hidden_states_b = self.component_b.process_ff(hidden_states_b, attn_output_b, mod_b)
|
||||||
|
return hidden_states_a, hidden_states_b
|
||||||
|
|
||||||
|
|
||||||
|
class MMSingleStreamBlockOriginal(torch.nn.Module):
|
||||||
|
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.heads_num = heads_num
|
||||||
|
self.mlp_hidden_dim = hidden_size * mlp_width_ratio
|
||||||
|
|
||||||
|
self.linear1 = torch.nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
||||||
|
self.linear2 = torch.nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
||||||
|
|
||||||
|
self.q_norm = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
||||||
|
self.k_norm = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
||||||
|
|
||||||
|
self.pre_norm = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||||
|
|
||||||
|
self.mlp_act = torch.nn.GELU(approximate="tanh")
|
||||||
|
self.modulation = ModulateDiT(hidden_size, factor=3)
|
||||||
|
|
||||||
|
def forward(self, x, vec, freqs_cis=None, txt_len=256):
|
||||||
|
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
|
||||||
|
x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
|
||||||
|
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||||
|
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
||||||
|
q = self.q_norm(q)
|
||||||
|
k = self.k_norm(k)
|
||||||
|
|
||||||
|
q_a, q_b = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
|
||||||
|
k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
|
||||||
|
q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)
|
||||||
|
q = torch.cat((q_a, q_b), dim=1)
|
||||||
|
k = torch.cat((k_a, k_b), dim=1)
|
||||||
|
|
||||||
|
attn_output_a = attention(q[:, :-185].contiguous(), k[:, :-185].contiguous(), v[:, :-185].contiguous())
|
||||||
|
attn_output_b = attention(q[:, -185:].contiguous(), k[:, -185:].contiguous(), v[:, -185:].contiguous())
|
||||||
|
attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)
|
||||||
|
|
||||||
|
output = self.linear2(torch.cat((attn_output, self.mlp_act(mlp)), 2))
|
||||||
|
return x + output * mod_gate.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
|
class MMSingleStreamBlock(torch.nn.Module):
|
||||||
|
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
||||||
|
super().__init__()
|
||||||
|
self.heads_num = heads_num
|
||||||
|
|
||||||
|
self.mod = ModulateDiT(hidden_size, factor=3)
|
||||||
|
self.norm = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||||
|
|
||||||
|
self.to_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
|
||||||
|
self.norm_q = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
||||||
|
self.norm_k = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
||||||
|
self.to_out = torch.nn.Linear(hidden_size, hidden_size)
|
||||||
|
|
||||||
|
self.ff = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(hidden_size, hidden_size * mlp_width_ratio),
|
||||||
|
torch.nn.GELU(approximate="tanh"),
|
||||||
|
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size, bias=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256, token_replace_vec=None, tr_token=None, split_token=71):
|
||||||
|
mod_shift, mod_scale, mod_gate = self.mod(conditioning).chunk(3, dim=-1)
|
||||||
|
if token_replace_vec is not None:
|
||||||
|
assert tr_token is not None
|
||||||
|
tr_mod_shift, tr_mod_scale, tr_mod_gate = self.mod(token_replace_vec).chunk(3, dim=-1)
|
||||||
|
else:
|
||||||
|
tr_mod_shift, tr_mod_scale, tr_mod_gate = None, None, None
|
||||||
|
|
||||||
|
norm_hidden_states = self.norm(hidden_states)
|
||||||
|
norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale,
|
||||||
|
tr_shift=tr_mod_shift, tr_scale=tr_mod_scale, tr_token=tr_token)
|
||||||
|
qkv = self.to_qkv(norm_hidden_states)
|
||||||
|
|
||||||
|
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
||||||
|
|
||||||
|
q = self.norm_q(q)
|
||||||
|
k = self.norm_k(k)
|
||||||
|
|
||||||
|
q_a, q_b = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
|
||||||
|
k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
|
||||||
|
q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)
|
||||||
|
|
||||||
|
v_len = txt_len - split_token
|
||||||
|
q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous()
|
||||||
|
k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous()
|
||||||
|
v_a, v_b = v[:, :-v_len].contiguous(), v[:, -v_len:].contiguous()
|
||||||
|
|
||||||
|
attn_output_a = attention(q_a, k_a, v_a)
|
||||||
|
attn_output_b = attention(q_b, k_b, v_b)
|
||||||
|
attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)
|
||||||
|
|
||||||
|
hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod_gate, tr_mod_gate, tr_token)
|
||||||
|
hidden_states = hidden_states + apply_gate(self.ff(norm_hidden_states), mod_gate, tr_mod_gate, tr_token)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class FinalLayer(torch.nn.Module):
|
||||||
|
def __init__(self, hidden_size=3072, patch_size=(1, 2, 2), out_channels=16):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.norm_final = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.linear = torch.nn.Linear(hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels)
|
||||||
|
|
||||||
|
self.adaLN_modulation = torch.nn.Sequential(torch.nn.SiLU(), torch.nn.Linear(hidden_size, 2 * hidden_size))
|
||||||
|
|
||||||
|
def forward(self, x, c):
|
||||||
|
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||||
|
x = modulate(self.norm_final(x), shift=shift, scale=scale)
|
||||||
|
x = self.linear(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class HunyuanVideoDiT(torch.nn.Module):
|
||||||
|
def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40, guidance_embed=True):
|
||||||
|
super().__init__()
|
||||||
|
self.img_in = PatchEmbed(in_channels=in_channels, embed_dim=hidden_size)
|
||||||
|
self.txt_in = SingleTokenRefiner(in_channels=text_dim, hidden_size=hidden_size)
|
||||||
|
self.time_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
|
||||||
|
self.vector_in = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(768, hidden_size),
|
||||||
|
torch.nn.SiLU(),
|
||||||
|
torch.nn.Linear(hidden_size, hidden_size)
|
||||||
|
)
|
||||||
|
self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu") if guidance_embed else None
|
||||||
|
self.double_blocks = torch.nn.ModuleList([MMDoubleStreamBlock(hidden_size) for _ in range(num_double_blocks)])
|
||||||
|
self.single_blocks = torch.nn.ModuleList([MMSingleStreamBlock(hidden_size) for _ in range(num_single_blocks)])
|
||||||
|
self.final_layer = FinalLayer(hidden_size)
|
||||||
|
|
||||||
|
# TODO: remove these parameters
|
||||||
|
self.dtype = torch.bfloat16
|
||||||
|
self.patch_size = [1, 2, 2]
|
||||||
|
self.hidden_size = 3072
|
||||||
|
self.heads_num = 24
|
||||||
|
self.rope_dim_list = [16, 56, 56]
|
||||||
|
|
||||||
|
def unpatchify(self, x, T, H, W):
|
||||||
|
x = rearrange(x, "B (T H W) (C pT pH pW) -> B C (T pT) (H pH) (W pW)", H=H, W=W, pT=1, pH=2, pW=2)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def enable_block_wise_offload(self, warm_device="cuda", cold_device="cpu"):
|
||||||
|
self.warm_device = warm_device
|
||||||
|
self.cold_device = cold_device
|
||||||
|
self.to(self.cold_device)
|
||||||
|
|
||||||
|
def load_models_to_device(self, loadmodel_names=[], device="cpu"):
|
||||||
|
for model_name in loadmodel_names:
|
||||||
|
model = getattr(self, model_name)
|
||||||
|
if model is not None:
|
||||||
|
model.to(device)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def prepare_freqs(self, latents):
|
||||||
|
return HunyuanVideoRope(latents)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
t: torch.Tensor,
|
||||||
|
prompt_emb: torch.Tensor = None,
|
||||||
|
text_mask: torch.Tensor = None,
|
||||||
|
pooled_prompt_emb: torch.Tensor = None,
|
||||||
|
freqs_cos: torch.Tensor = None,
|
||||||
|
freqs_sin: torch.Tensor = None,
|
||||||
|
guidance: torch.Tensor = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
B, C, T, H, W = x.shape
|
||||||
|
|
||||||
|
vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb)
|
||||||
|
if self.guidance_in is not None:
|
||||||
|
vec += self.guidance_in(guidance * 1000, dtype=torch.float32)
|
||||||
|
img = self.img_in(x)
|
||||||
|
txt = self.txt_in(prompt_emb, t, text_mask)
|
||||||
|
|
||||||
|
for block in tqdm(self.double_blocks, desc="Double stream blocks"):
|
||||||
|
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
|
||||||
|
|
||||||
|
x = torch.concat([img, txt], dim=1)
|
||||||
|
for block in tqdm(self.single_blocks, desc="Single stream blocks"):
|
||||||
|
x = block(x, vec, (freqs_cos, freqs_sin))
|
||||||
|
|
||||||
|
img = x[:, :-256]
|
||||||
|
img = self.final_layer(img, vec)
|
||||||
|
img = self.unpatchify(img, T=T//1, H=H//2, W=W//2)
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def enable_auto_offload(self, dtype=torch.bfloat16, device="cuda"):
|
||||||
|
def cast_to(weight, dtype=None, device=None, copy=False):
|
||||||
|
if device is None or weight.device == device:
|
||||||
|
if not copy:
|
||||||
|
if dtype is None or weight.dtype == dtype:
|
||||||
|
return weight
|
||||||
|
return weight.to(dtype=dtype, copy=copy)
|
||||||
|
|
||||||
|
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||||
|
r.copy_(weight)
|
||||||
|
return r
|
||||||
|
|
||||||
|
def cast_weight(s, input=None, dtype=None, device=None):
|
||||||
|
if input is not None:
|
||||||
|
if dtype is None:
|
||||||
|
dtype = input.dtype
|
||||||
|
if device is None:
|
||||||
|
device = input.device
|
||||||
|
weight = cast_to(s.weight, dtype, device)
|
||||||
|
return weight
|
||||||
|
|
||||||
|
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||||
|
if input is not None:
|
||||||
|
if dtype is None:
|
||||||
|
dtype = input.dtype
|
||||||
|
if bias_dtype is None:
|
||||||
|
bias_dtype = dtype
|
||||||
|
if device is None:
|
||||||
|
device = input.device
|
||||||
|
weight = cast_to(s.weight, dtype, device)
|
||||||
|
bias = cast_to(s.bias, bias_dtype, device) if s.bias is not None else None
|
||||||
|
return weight, bias
|
||||||
|
|
||||||
|
class quantized_layer:
|
||||||
|
class Linear(torch.nn.Linear):
|
||||||
|
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.dtype = dtype
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
def block_forward_(self, x, i, j, dtype, device):
|
||||||
|
weight_ = cast_to(
|
||||||
|
self.weight[j * self.block_size: (j + 1) * self.block_size, i * self.block_size: (i + 1) * self.block_size],
|
||||||
|
dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
if self.bias is None or i > 0:
|
||||||
|
bias_ = None
|
||||||
|
else:
|
||||||
|
bias_ = cast_to(self.bias[j * self.block_size: (j + 1) * self.block_size], dtype=dtype, device=device)
|
||||||
|
x_ = x[..., i * self.block_size: (i + 1) * self.block_size]
|
||||||
|
y_ = torch.nn.functional.linear(x_, weight_, bias_)
|
||||||
|
del x_, weight_, bias_
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return y_
|
||||||
|
|
||||||
|
def block_forward(self, x, **kwargs):
|
||||||
|
# This feature only reduces VRAM usage by about 2GB, so it is disabled.
|
||||||
|
y = torch.zeros(x.shape[:-1] + (self.out_features,), dtype=x.dtype, device=x.device)
|
||||||
|
for i in range((self.in_features + self.block_size - 1) // self.block_size):
|
||||||
|
for j in range((self.out_features + self.block_size - 1) // self.block_size):
|
||||||
|
y[..., j * self.block_size: (j + 1) * self.block_size] += self.block_forward_(x, i, j, dtype=x.dtype, device=x.device)
|
||||||
|
return y
|
||||||
|
|
||||||
|
def forward(self, x, **kwargs):
|
||||||
|
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
||||||
|
return torch.nn.functional.linear(x, weight, bias)
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(torch.nn.Module):
|
||||||
|
def __init__(self, module, dtype=torch.bfloat16, device="cuda"):
|
||||||
|
super().__init__()
|
||||||
|
self.module = module
|
||||||
|
self.dtype = dtype
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
def forward(self, hidden_states, **kwargs):
|
||||||
|
input_dtype = hidden_states.dtype
|
||||||
|
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
||||||
|
hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
|
||||||
|
hidden_states = hidden_states.to(input_dtype)
|
||||||
|
if self.module.weight is not None:
|
||||||
|
weight = cast_weight(self.module, hidden_states, dtype=torch.bfloat16, device="cuda")
|
||||||
|
hidden_states = hidden_states * weight
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
class Conv3d(torch.nn.Conv3d):
|
||||||
|
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.dtype = dtype
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
||||||
|
return torch.nn.functional.conv3d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)
|
||||||
|
|
||||||
|
class LayerNorm(torch.nn.LayerNorm):
|
||||||
|
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.dtype = dtype
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.weight is not None and self.bias is not None:
|
||||||
|
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
||||||
|
return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)
|
||||||
|
else:
|
||||||
|
return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||||
|
|
||||||
|
def replace_layer(model, dtype=torch.bfloat16, device="cuda"):
|
||||||
|
for name, module in model.named_children():
|
||||||
|
if isinstance(module, torch.nn.Linear):
|
||||||
|
with init_weights_on_device():
|
||||||
|
new_layer = quantized_layer.Linear(
|
||||||
|
module.in_features, module.out_features, bias=module.bias is not None,
|
||||||
|
dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
new_layer.load_state_dict(module.state_dict(), assign=True)
|
||||||
|
setattr(model, name, new_layer)
|
||||||
|
elif isinstance(module, torch.nn.Conv3d):
|
||||||
|
with init_weights_on_device():
|
||||||
|
new_layer = quantized_layer.Conv3d(
|
||||||
|
module.in_channels, module.out_channels, kernel_size=module.kernel_size, stride=module.stride,
|
||||||
|
dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
new_layer.load_state_dict(module.state_dict(), assign=True)
|
||||||
|
setattr(model, name, new_layer)
|
||||||
|
elif isinstance(module, RMSNorm):
|
||||||
|
new_layer = quantized_layer.RMSNorm(
|
||||||
|
module,
|
||||||
|
dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
setattr(model, name, new_layer)
|
||||||
|
elif isinstance(module, torch.nn.LayerNorm):
|
||||||
|
with init_weights_on_device():
|
||||||
|
new_layer = quantized_layer.LayerNorm(
|
||||||
|
module.normalized_shape, elementwise_affine=module.elementwise_affine, eps=module.eps,
|
||||||
|
dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
new_layer.load_state_dict(module.state_dict(), assign=True)
|
||||||
|
setattr(model, name, new_layer)
|
||||||
|
else:
|
||||||
|
replace_layer(module, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
replace_layer(self, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return HunyuanVideoDiTStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
class HunyuanVideoDiTStateDictConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
origin_hash_key = hash_state_dict_keys(state_dict, with_shape=True)
|
||||||
|
if "module" in state_dict:
|
||||||
|
state_dict = state_dict["module"]
|
||||||
|
direct_dict = {
|
||||||
|
"img_in.proj": "img_in.proj",
|
||||||
|
"time_in.mlp.0": "time_in.timestep_embedder.0",
|
||||||
|
"time_in.mlp.2": "time_in.timestep_embedder.2",
|
||||||
|
"vector_in.in_layer": "vector_in.0",
|
||||||
|
"vector_in.out_layer": "vector_in.2",
|
||||||
|
"guidance_in.mlp.0": "guidance_in.timestep_embedder.0",
|
||||||
|
"guidance_in.mlp.2": "guidance_in.timestep_embedder.2",
|
||||||
|
"txt_in.input_embedder": "txt_in.input_embedder",
|
||||||
|
"txt_in.t_embedder.mlp.0": "txt_in.t_embedder.timestep_embedder.0",
|
||||||
|
"txt_in.t_embedder.mlp.2": "txt_in.t_embedder.timestep_embedder.2",
|
||||||
|
"txt_in.c_embedder.linear_1": "txt_in.c_embedder.0",
|
||||||
|
"txt_in.c_embedder.linear_2": "txt_in.c_embedder.2",
|
||||||
|
"final_layer.linear": "final_layer.linear",
|
||||||
|
"final_layer.adaLN_modulation.1": "final_layer.adaLN_modulation.1",
|
||||||
|
}
|
||||||
|
txt_suffix_dict = {
|
||||||
|
"norm1": "norm1",
|
||||||
|
"self_attn_qkv": "self_attn_qkv",
|
||||||
|
"self_attn_proj": "self_attn_proj",
|
||||||
|
"norm2": "norm2",
|
||||||
|
"mlp.fc1": "mlp.0",
|
||||||
|
"mlp.fc2": "mlp.2",
|
||||||
|
"adaLN_modulation.1": "adaLN_modulation.1",
|
||||||
|
}
|
||||||
|
double_suffix_dict = {
|
||||||
|
"img_mod.linear": "component_a.mod.linear",
|
||||||
|
"img_attn_qkv": "component_a.to_qkv",
|
||||||
|
"img_attn_q_norm": "component_a.norm_q",
|
||||||
|
"img_attn_k_norm": "component_a.norm_k",
|
||||||
|
"img_attn_proj": "component_a.to_out",
|
||||||
|
"img_mlp.fc1": "component_a.ff.0",
|
||||||
|
"img_mlp.fc2": "component_a.ff.2",
|
||||||
|
"txt_mod.linear": "component_b.mod.linear",
|
||||||
|
"txt_attn_qkv": "component_b.to_qkv",
|
||||||
|
"txt_attn_q_norm": "component_b.norm_q",
|
||||||
|
"txt_attn_k_norm": "component_b.norm_k",
|
||||||
|
"txt_attn_proj": "component_b.to_out",
|
||||||
|
"txt_mlp.fc1": "component_b.ff.0",
|
||||||
|
"txt_mlp.fc2": "component_b.ff.2",
|
||||||
|
}
|
||||||
|
single_suffix_dict = {
|
||||||
|
"linear1": ["to_qkv", "ff.0"],
|
||||||
|
"linear2": ["to_out", "ff.2"],
|
||||||
|
"q_norm": "norm_q",
|
||||||
|
"k_norm": "norm_k",
|
||||||
|
"modulation.linear": "mod.linear",
|
||||||
|
}
|
||||||
|
# single_suffix_dict = {
|
||||||
|
# "linear1": "linear1",
|
||||||
|
# "linear2": "linear2",
|
||||||
|
# "q_norm": "q_norm",
|
||||||
|
# "k_norm": "k_norm",
|
||||||
|
# "modulation.linear": "modulation.linear",
|
||||||
|
# }
|
||||||
|
state_dict_ = {}
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
names = name.split(".")
|
||||||
|
direct_name = ".".join(names[:-1])
|
||||||
|
if direct_name in direct_dict:
|
||||||
|
name_ = direct_dict[direct_name] + "." + names[-1]
|
||||||
|
state_dict_[name_] = param
|
||||||
|
elif names[0] == "double_blocks":
|
||||||
|
prefix = ".".join(names[:2])
|
||||||
|
suffix = ".".join(names[2:-1])
|
||||||
|
name_ = prefix + "." + double_suffix_dict[suffix] + "." + names[-1]
|
||||||
|
state_dict_[name_] = param
|
||||||
|
elif names[0] == "single_blocks":
|
||||||
|
prefix = ".".join(names[:2])
|
||||||
|
suffix = ".".join(names[2:-1])
|
||||||
|
if isinstance(single_suffix_dict[suffix], list):
|
||||||
|
if suffix == "linear1":
|
||||||
|
name_a, name_b = single_suffix_dict[suffix]
|
||||||
|
param_a, param_b = torch.split(param, (3072*3, 3072*4), dim=0)
|
||||||
|
state_dict_[prefix + "." + name_a + "." + names[-1]] = param_a
|
||||||
|
state_dict_[prefix + "." + name_b + "." + names[-1]] = param_b
|
||||||
|
elif suffix == "linear2":
|
||||||
|
if names[-1] == "weight":
|
||||||
|
name_a, name_b = single_suffix_dict[suffix]
|
||||||
|
param_a, param_b = torch.split(param, (3072*1, 3072*4), dim=-1)
|
||||||
|
state_dict_[prefix + "." + name_a + "." + names[-1]] = param_a
|
||||||
|
state_dict_[prefix + "." + name_b + "." + names[-1]] = param_b
|
||||||
|
else:
|
||||||
|
name_a, name_b = single_suffix_dict[suffix]
|
||||||
|
state_dict_[prefix + "." + name_a + "." + names[-1]] = param
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
name_ = prefix + "." + single_suffix_dict[suffix] + "." + names[-1]
|
||||||
|
state_dict_[name_] = param
|
||||||
|
elif names[0] == "txt_in":
|
||||||
|
prefix = ".".join(names[:4]).replace(".individual_token_refiner.", ".")
|
||||||
|
suffix = ".".join(names[4:-1])
|
||||||
|
name_ = prefix + "." + txt_suffix_dict[suffix] + "." + names[-1]
|
||||||
|
state_dict_[name_] = param
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return state_dict_
|
||||||
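The single-stream conversion above splits the fused linear1/linear2 weights into separate attention and feed-forward projections. A quick hedged check of the split sizes, using the hidden size 3072 and MLP ratio 4 that appear in this file:

import torch

hidden = 3072
linear1_weight = torch.randn(hidden * 3 + hidden * 4, hidden)       # fused qkv + ff.0
qkv_w, ff0_w = torch.split(linear1_weight, (hidden * 3, hidden * 4), dim=0)
print(qkv_w.shape, ff0_w.shape)      # (9216, 3072) and (12288, 3072)

linear2_weight = torch.randn(hidden, hidden + hidden * 4)           # fused to_out + ff.2
out_w, ff2_w = torch.split(linear2_weight, (hidden, hidden * 4), dim=-1)
print(out_w.shape, ff2_w.shape)      # (3072, 3072) and (3072, 12288)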
diffsynth/models/hunyuan_video_text_encoder.py (new file, 68 lines)
@@ -0,0 +1,68 @@
|
|||||||
|
from transformers import LlamaModel, LlamaConfig, DynamicCache, LlavaForConditionalGeneration
|
||||||
|
from copy import deepcopy
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class HunyuanVideoLLMEncoder(LlamaModel):
|
||||||
|
|
||||||
|
def __init__(self, config: LlamaConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.auto_offload = False
|
||||||
|
|
||||||
|
def enable_auto_offload(self, **kwargs):
|
||||||
|
self.auto_offload = True
|
||||||
|
|
||||||
|
def forward(self, input_ids, attention_mask, hidden_state_skip_layer=2):
|
||||||
|
embed_tokens = deepcopy(self.embed_tokens).to(input_ids.device) if self.auto_offload else self.embed_tokens
|
||||||
|
inputs_embeds = embed_tokens(input_ids)
|
||||||
|
|
||||||
|
past_key_values = DynamicCache()
|
||||||
|
|
||||||
|
cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
|
||||||
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
|
||||||
|
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, None, False)
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
|
# create position embeddings to be shared across the decoder layers
|
||||||
|
rotary_emb = deepcopy(self.rotary_emb).to(input_ids.device) if self.auto_offload else self.rotary_emb
|
||||||
|
position_embeddings = rotary_emb(hidden_states, position_ids)
|
||||||
|
|
||||||
|
# decoder layers
|
||||||
|
for layer_id, decoder_layer in enumerate(self.layers):
|
||||||
|
if self.auto_offload:
|
||||||
|
decoder_layer = deepcopy(decoder_layer).to(hidden_states.device)
|
||||||
|
layer_outputs = decoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=causal_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_values,
|
||||||
|
output_attentions=False,
|
||||||
|
use_cache=True,
|
||||||
|
cache_position=cache_position,
|
||||||
|
position_embeddings=position_embeddings,
|
||||||
|
)
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
if layer_id + hidden_state_skip_layer + 1 >= len(self.layers):
|
||||||
|
break
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class HunyuanVideoMLLMEncoder(LlavaForConditionalGeneration):
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.auto_offload = False
|
||||||
|
|
||||||
|
def enable_auto_offload(self, **kwargs):
|
||||||
|
self.auto_offload = True
|
||||||
|
|
||||||
|
# TODO: implement the low VRAM inference for MLLM.
|
||||||
|
def forward(self, input_ids, pixel_values, attention_mask, hidden_state_skip_layer=2):
|
||||||
|
outputs = super().forward(input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
output_hidden_states=True,
|
||||||
|
pixel_values=pixel_values)
|
||||||
|
hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
|
||||||
|
return hidden_state
|
||||||
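The LLM encoder above stops hidden_state_skip_layer layers before the end of the stack rather than running every decoder layer. A hedged illustration of the early-exit condition, assuming a 32-layer LlamaModel:

num_layers, skip = 32, 2
executed = []
for layer_id in range(num_layers):
    executed.append(layer_id)                 # the layer runs first ...
    if layer_id + skip + 1 >= num_layers:     # ... then the early-exit check fires
        break
print(executed[-1], len(executed))            # 29, 30 -> the last two layers are skipped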
diffsynth/models/hunyuan_video_vae_decoder.py (new file, 507 lines)
@@ -0,0 +1,507 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
from einops import repeat
|
||||||
|
|
||||||
|
|
||||||
|
class CausalConv3d(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_channel, out_channel, kernel_size, stride=1, dilation=1, pad_mode='replicate', **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.pad_mode = pad_mode
|
||||||
|
self.time_causal_padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size - 1, 0)  # W, H, T
|
||||||
|
self.conv = nn.Conv3d(in_channel, out_channel, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
|
||||||
|
return self.conv(x)
|
||||||
|
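A quick hedged check of the causal padding above: for kernel_size=3 the padding tuple is (1, 1, 1, 1, 2, 0), i.e. symmetric in width and height but with all temporal padding placed before the clip, which keeps the convolution causal in time. The input shape below is made up:

import torch
import torch.nn.functional as F

kernel_size = 3
pad = (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size - 1, 0)
x = torch.randn(1, 8, 5, 16, 16)               # (B, C, T, H, W)
x_padded = F.pad(x, pad, mode="replicate")
print(x_padded.shape)                           # torch.Size([1, 8, 7, 18, 18]): +2 frames in front only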
|
||||||
|
|
||||||
|
class UpsampleCausal3D(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, channels, use_conv=False, out_channels=None, kernel_size=None, bias=True, upsample_factor=(2, 2, 2)):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
self.out_channels = out_channels or channels
|
||||||
|
self.upsample_factor = upsample_factor
|
||||||
|
self.conv = None
|
||||||
|
if use_conv:
|
||||||
|
kernel_size = 3 if kernel_size is None else kernel_size
|
||||||
|
self.conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
# Cast to float32, as the 'upsample_nearest2d_out_frame' op does not support bfloat16
|
||||||
|
dtype = hidden_states.dtype
|
||||||
|
if dtype == torch.bfloat16:
|
||||||
|
hidden_states = hidden_states.to(torch.float32)
|
||||||
|
|
||||||
|
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
||||||
|
if hidden_states.shape[0] >= 64:
|
||||||
|
hidden_states = hidden_states.contiguous()
|
||||||
|
|
||||||
|
# interpolate
|
||||||
|
B, C, T, H, W = hidden_states.shape
|
||||||
|
first_h, other_h = hidden_states.split((1, T - 1), dim=2)
|
||||||
|
if T > 1:
|
||||||
|
other_h = F.interpolate(other_h, scale_factor=self.upsample_factor, mode="nearest")
|
||||||
|
first_h = F.interpolate(first_h.squeeze(2), scale_factor=self.upsample_factor[1:], mode="nearest").unsqueeze(2)
|
||||||
|
hidden_states = torch.cat((first_h, other_h), dim=2) if T > 1 else first_h
|
||||||
|
|
||||||
|
# If the input is bfloat16, we cast back to bfloat16
|
||||||
|
if dtype == torch.bfloat16:
|
||||||
|
hidden_states = hidden_states.to(dtype)
|
||||||
|
|
||||||
|
if self.conv:
|
||||||
|
hidden_states = self.conv(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class ResnetBlockCausal3D(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels=None, dropout=0.0, groups=32, eps=1e-6, conv_shortcut_bias=True):
|
||||||
|
super().__init__()
|
||||||
|
self.pre_norm = True
|
||||||
|
self.in_channels = in_channels
|
||||||
|
out_channels = in_channels if out_channels is None else out_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
|
||||||
|
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
||||||
|
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)
|
||||||
|
|
||||||
|
self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
|
||||||
|
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, stride=1)
|
||||||
|
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
self.nonlinearity = nn.SiLU()
|
||||||
|
|
||||||
|
self.conv_shortcut = None
|
||||||
|
if in_channels != out_channels:
|
||||||
|
self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, bias=conv_shortcut_bias)
|
||||||
|
|
||||||
|
def forward(self, input_tensor):
|
||||||
|
hidden_states = input_tensor
|
||||||
|
# conv1
|
||||||
|
hidden_states = self.norm1(hidden_states)
|
||||||
|
hidden_states = self.nonlinearity(hidden_states)
|
||||||
|
hidden_states = self.conv1(hidden_states)
|
||||||
|
|
||||||
|
# conv2
|
||||||
|
hidden_states = self.norm2(hidden_states)
|
||||||
|
hidden_states = self.nonlinearity(hidden_states)
|
||||||
|
hidden_states = self.dropout(hidden_states)
|
||||||
|
hidden_states = self.conv2(hidden_states)
|
||||||
|
# shortcut
|
||||||
|
if self.conv_shortcut is not None:
|
||||||
|
input_tensor = (self.conv_shortcut(input_tensor))
|
||||||
|
# shortcut and scale
|
||||||
|
output_tensor = input_tensor + hidden_states
|
||||||
|
|
||||||
|
return output_tensor
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_causal_attention_mask(n_frame, n_hw, dtype, device, batch_size=None):
|
||||||
|
seq_len = n_frame * n_hw
|
||||||
|
mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
|
||||||
|
for i in range(seq_len):
|
||||||
|
i_frame = i // n_hw
|
||||||
|
mask[i, :(i_frame + 1) * n_hw] = 0
|
||||||
|
if batch_size is not None:
|
||||||
|
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
|
||||||
|
return mask
|
||||||
|
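A worked example of the causal mask built above, with 2 frames of 2 spatial positions each: every token may attend to all tokens in its own and earlier frames (0 = allowed, -inf = blocked):

import torch

mask = prepare_causal_attention_mask(n_frame=2, n_hw=2, dtype=torch.float32, device="cpu")
print(mask)
# tensor([[0., 0., -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., 0.],
#         [0., 0., 0., 0.]])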
|
||||||
|
|
||||||
|

class Attention(nn.Module):

    def __init__(self,
                 in_channels,
                 num_heads,
                 head_dim,
                 num_groups=32,
                 dropout=0.0,
                 eps=1e-6,
                 bias=True,
                 residual_connection=True):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.residual_connection = residual_connection
        dim_inner = head_dim * num_heads
        self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=eps, affine=True)
        self.to_q = nn.Linear(in_channels, dim_inner, bias=bias)
        self.to_k = nn.Linear(in_channels, dim_inner, bias=bias)
        self.to_v = nn.Linear(in_channels, dim_inner, bias=bias)
        self.to_out = nn.Sequential(nn.Linear(dim_inner, in_channels, bias=bias), nn.Dropout(dropout))

    def forward(self, input_tensor, attn_mask=None):
        hidden_states = self.group_norm(input_tensor.transpose(1, 2)).transpose(1, 2)
        batch_size = hidden_states.shape[0]

        q = self.to_q(hidden_states)
        k = self.to_k(hidden_states)
        v = self.to_v(hidden_states)

        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        if attn_mask is not None:
            attn_mask = attn_mask.view(batch_size, self.num_heads, -1, attn_mask.shape[-1])
        hidden_states = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = self.to_out(hidden_states)
        if self.residual_connection:
            output_tensor = input_tensor + hidden_states
        else:
            output_tensor = hidden_states
        return output_tensor

class UNetMidBlockCausal3D(nn.Module):

    def __init__(self, in_channels, dropout=0.0, num_layers=1, eps=1e-6, num_groups=32, attention_head_dim=None):
        super().__init__()
        resnets = [
            ResnetBlockCausal3D(
                in_channels=in_channels,
                out_channels=in_channels,
                dropout=dropout,
                groups=num_groups,
                eps=eps,
            )
        ]
        attentions = []
        attention_head_dim = attention_head_dim or in_channels

        for _ in range(num_layers):
            attentions.append(
                Attention(
                    in_channels,
                    num_heads=in_channels // attention_head_dim,
                    head_dim=attention_head_dim,
                    num_groups=num_groups,
                    dropout=dropout,
                    eps=eps,
                    bias=True,
                    residual_connection=True,
                ))

            resnets.append(
                ResnetBlockCausal3D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    dropout=dropout,
                    groups=num_groups,
                    eps=eps,
                ))

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states):
        hidden_states = self.resnets[0](hidden_states)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            B, C, T, H, W = hidden_states.shape
            hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
            attn_mask = prepare_causal_attention_mask(T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B)
            hidden_states = attn(hidden_states, attn_mask=attn_mask)
            hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
            hidden_states = resnet(hidden_states)

        return hidden_states

class UpDecoderBlockCausal3D(nn.Module):

    def __init__(
        self,
        in_channels,
        out_channels,
        dropout=0.0,
        num_layers=1,
        eps=1e-6,
        num_groups=32,
        add_upsample=True,
        upsample_scale_factor=(2, 2, 2),
    ):
        super().__init__()
        resnets = []
        for i in range(num_layers):
            cur_in_channel = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlockCausal3D(
                    in_channels=cur_in_channel,
                    out_channels=out_channels,
                    groups=num_groups,
                    dropout=dropout,
                    eps=eps,
                ))
        self.resnets = nn.ModuleList(resnets)

        self.upsamplers = None
        if add_upsample:
            self.upsamplers = nn.ModuleList([
                UpsampleCausal3D(
                    out_channels,
                    use_conv=True,
                    out_channels=out_channels,
                    upsample_factor=upsample_scale_factor,
                )
            ])

    def forward(self, hidden_states):
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)
        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)
        return hidden_states

class DecoderCausal3D(nn.Module):

    def __init__(
        self,
        in_channels=16,
        out_channels=3,
        eps=1e-6,
        dropout=0.0,
        block_out_channels=[128, 256, 512, 512],
        layers_per_block=2,
        num_groups=32,
        time_compression_ratio=4,
        spatial_compression_ratio=8,
        gradient_checkpointing=False,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
        self.up_blocks = nn.ModuleList([])

        # mid
        self.mid_block = UNetMidBlockCausal3D(
            in_channels=block_out_channels[-1],
            dropout=dropout,
            eps=eps,
            num_groups=num_groups,
            attention_head_dim=block_out_channels[-1],
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i in range(len(block_out_channels)):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
            num_time_upsample_layers = int(np.log2(time_compression_ratio))

            add_spatial_upsample = bool(i < num_spatial_upsample_layers)
            add_time_upsample = bool(i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block)

            upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
            upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
            upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)

            up_block = UpDecoderBlockCausal3D(
                in_channels=prev_output_channel,
                out_channels=output_channel,
                dropout=dropout,
                num_layers=layers_per_block + 1,
                eps=eps,
                num_groups=num_groups,
                add_upsample=bool(add_spatial_upsample or add_time_upsample),
                upsample_scale_factor=upsample_scale_factor,
            )

            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups, eps=eps)
        self.conv_act = nn.SiLU()
        self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)

        self.gradient_checkpointing = gradient_checkpointing

    def forward(self, hidden_states):
        hidden_states = self.conv_in(hidden_states)
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):

                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # middle
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.mid_block),
                hidden_states,
                use_reentrant=False,
            )
            # up
            for up_block in self.up_blocks:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(up_block),
                    hidden_states,
                    use_reentrant=False,
                )
        else:
            # middle
            hidden_states = self.mid_block(hidden_states)
            # up
            for up_block in self.up_blocks:
                hidden_states = up_block(hidden_states)
        # post-process
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states

class HunyuanVideoVAEDecoder(nn.Module):

    def __init__(
        self,
        in_channels=16,
        out_channels=3,
        eps=1e-6,
        dropout=0.0,
        block_out_channels=[128, 256, 512, 512],
        layers_per_block=2,
        num_groups=32,
        time_compression_ratio=4,
        spatial_compression_ratio=8,
        gradient_checkpointing=False,
    ):
        super().__init__()
        self.decoder = DecoderCausal3D(
            in_channels=in_channels,
            out_channels=out_channels,
            eps=eps,
            dropout=dropout,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            num_groups=num_groups,
            time_compression_ratio=time_compression_ratio,
            spatial_compression_ratio=spatial_compression_ratio,
            gradient_checkpointing=gradient_checkpointing,
        )
        self.post_quant_conv = nn.Conv3d(in_channels, in_channels, kernel_size=1)
        self.scaling_factor = 0.476986

    def forward(self, latents):
        latents = latents / self.scaling_factor
        latents = self.post_quant_conv(latents)
        dec = self.decoder(latents)
        return dec

    def build_1d_mask(self, length, left_bound, right_bound, border_width):
        x = torch.ones((length,))
        if not left_bound:
            x[:border_width] = (torch.arange(border_width) + 1) / border_width
        if not right_bound:
            x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
        return x

    def build_mask(self, data, is_bound, border_width):
        _, _, T, H, W = data.shape
        t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
        h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1])
        w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2])

        t = repeat(t, "T -> T H W", T=T, H=H, W=W)
        h = repeat(h, "H -> T H W", T=T, H=H, W=W)
        w = repeat(w, "W -> T H W", T=T, H=H, W=W)

        mask = torch.stack([t, h, w]).min(dim=0).values
        mask = rearrange(mask, "T H W -> 1 1 T H W")
        return mask
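
build_mask gives each decoded tile a per-voxel weight that ramps up linearly across the overlap with neighbouring tiles (and stays at 1 along edges that touch the video boundary); tile_forward below then averages overlapping tiles as sum(value * mask) / sum(mask). A tiny standalone illustration of the 1D ramp (not part of the module):

# 1D cross-fade weights for an interior tile edge with a 4-voxel overlap.
import torch
border_width = 4
ramp = (torch.arange(border_width) + 1) / border_width
print(ramp)  # tensor([0.2500, 0.5000, 0.7500, 1.0000])
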

    def tile_forward(self, hidden_states, tile_size, tile_stride):
        B, C, T, H, W = hidden_states.shape
        size_t, size_h, size_w = tile_size
        stride_t, stride_h, stride_w = tile_stride

        # Split tasks
        tasks = []
        for t in range(0, T, stride_t):
            if (t-stride_t >= 0 and t-stride_t+size_t >= T): continue
            for h in range(0, H, stride_h):
                if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
                for w in range(0, W, stride_w):
                    if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
                    t_, h_, w_ = t + size_t, h + size_h, w + size_w
                    tasks.append((t, t_, h, h_, w, w_))

        # Run
        torch_dtype = self.post_quant_conv.weight.dtype
        data_device = hidden_states.device
        computation_device = self.post_quant_conv.weight.device

        weight = torch.zeros((1, 1, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
        values = torch.zeros((B, 3, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)

        for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
            hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
            hidden_states_batch = self.forward(hidden_states_batch).to(data_device)
            if t > 0:
                hidden_states_batch = hidden_states_batch[:, :, 1:]

            mask = self.build_mask(
                hidden_states_batch,
                is_bound=(t==0, t_>=T, h==0, h_>=H, w==0, w_>=W),
                border_width=((size_t - stride_t) * 4, (size_h - stride_h) * 8, (size_w - stride_w) * 8)
            ).to(dtype=torch_dtype, device=data_device)

            target_t = 0 if t==0 else t * 4 + 1
            target_h = h * 8
            target_w = w * 8
            values[
                :,
                :,
                target_t: target_t + hidden_states_batch.shape[2],
                target_h: target_h + hidden_states_batch.shape[3],
                target_w: target_w + hidden_states_batch.shape[4],
            ] += hidden_states_batch * mask
            weight[
                :,
                :,
                target_t: target_t + hidden_states_batch.shape[2],
                target_h: target_h + hidden_states_batch.shape[3],
                target_w: target_w + hidden_states_batch.shape[4],
            ] += mask
        return values / weight

    def decode_video(self, latents, tile_size=(17, 32, 32), tile_stride=(12, 24, 24)):
        latents = latents.to(self.post_quant_conv.weight.dtype)
        return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride)

    @staticmethod
    def state_dict_converter():
        return HunyuanVideoVAEDecoderStateDictConverter()


class HunyuanVideoVAEDecoderStateDictConverter:

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        state_dict_ = {}
        for name in state_dict:
            if name.startswith('decoder.') or name.startswith('post_quant_conv.'):
                state_dict_[name] = state_dict[name]
        return state_dict_
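
A minimal usage sketch for the decoder above, assuming real weights are loaded elsewhere (random initialization is used here only to check shapes; the 4x temporal and 8x spatial expansion follows from the code):

import torch
decoder = HunyuanVideoVAEDecoder().eval().to(dtype=torch.float16, device="cuda")
latents = torch.randn((1, 16, 17, 32, 32), dtype=torch.float16, device="cuda")
with torch.no_grad():
    video = decoder.decode_video(latents)
print(video.shape)  # (1, 3, (17-1)*4+1, 32*8, 32*8) = (1, 3, 65, 256, 256)
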

307 lines  diffsynth/models/hunyuan_video_vae_encoder.py  (new file)
@@ -0,0 +1,307 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import numpy as np
from tqdm import tqdm
from .hunyuan_video_vae_decoder import CausalConv3d, ResnetBlockCausal3D, UNetMidBlockCausal3D


class DownsampleCausal3D(nn.Module):

    def __init__(self, channels, out_channels, kernel_size=3, bias=True, stride=2):
        super().__init__()
        self.conv = CausalConv3d(channels, out_channels, kernel_size, stride=stride, bias=bias)

    def forward(self, hidden_states):
        hidden_states = self.conv(hidden_states)
        return hidden_states


class DownEncoderBlockCausal3D(nn.Module):

    def __init__(
        self,
        in_channels,
        out_channels,
        dropout=0.0,
        num_layers=1,
        eps=1e-6,
        num_groups=32,
        add_downsample=True,
        downsample_stride=2,
    ):
        super().__init__()
        resnets = []
        for i in range(num_layers):
            cur_in_channel = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlockCausal3D(
                    in_channels=cur_in_channel,
                    out_channels=out_channels,
                    groups=num_groups,
                    dropout=dropout,
                    eps=eps,
                ))
        self.resnets = nn.ModuleList(resnets)

        self.downsamplers = None
        if add_downsample:
            self.downsamplers = nn.ModuleList([DownsampleCausal3D(
                out_channels,
                out_channels,
                stride=downsample_stride,
            )])

    def forward(self, hidden_states):
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

        return hidden_states

class EncoderCausal3D(nn.Module):

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 16,
        eps=1e-6,
        dropout=0.0,
        block_out_channels=[128, 256, 512, 512],
        layers_per_block=2,
        num_groups=32,
        time_compression_ratio: int = 4,
        spatial_compression_ratio: int = 8,
        gradient_checkpointing=False,
    ):
        super().__init__()
        self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
        self.down_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i in range(len(block_out_channels)):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
            num_time_downsample_layers = int(np.log2(time_compression_ratio))

            add_spatial_downsample = bool(i < num_spatial_downsample_layers)
            add_time_downsample = bool(i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block)

            downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
            downsample_stride_T = (2,) if add_time_downsample else (1,)
            downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
            down_block = DownEncoderBlockCausal3D(
                in_channels=input_channel,
                out_channels=output_channel,
                dropout=dropout,
                num_layers=layers_per_block,
                eps=eps,
                num_groups=num_groups,
                add_downsample=bool(add_spatial_downsample or add_time_downsample),
                downsample_stride=downsample_stride,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlockCausal3D(
            in_channels=block_out_channels[-1],
            dropout=dropout,
            eps=eps,
            num_groups=num_groups,
            attention_head_dim=block_out_channels[-1],
        )
        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups, eps=eps)
        self.conv_act = nn.SiLU()
        self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3)

        self.gradient_checkpointing = gradient_checkpointing

    def forward(self, hidden_states):
        hidden_states = self.conv_in(hidden_states)
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):

                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # down
            for down_block in self.down_blocks:
                # keep the checkpointed output
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(down_block),
                    hidden_states,
                    use_reentrant=False,
                )
            # middle
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.mid_block),
                hidden_states,
                use_reentrant=False,
            )
        else:
            # down
            for down_block in self.down_blocks:
                hidden_states = down_block(hidden_states)
            # middle
            hidden_states = self.mid_block(hidden_states)
        # post-process
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states

class HunyuanVideoVAEEncoder(nn.Module):

    def __init__(
        self,
        in_channels=3,
        out_channels=16,
        eps=1e-6,
        dropout=0.0,
        block_out_channels=[128, 256, 512, 512],
        layers_per_block=2,
        num_groups=32,
        time_compression_ratio=4,
        spatial_compression_ratio=8,
        gradient_checkpointing=False,
    ):
        super().__init__()
        self.encoder = EncoderCausal3D(
            in_channels=in_channels,
            out_channels=out_channels,
            eps=eps,
            dropout=dropout,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            num_groups=num_groups,
            time_compression_ratio=time_compression_ratio,
            spatial_compression_ratio=spatial_compression_ratio,
            gradient_checkpointing=gradient_checkpointing,
        )
        self.quant_conv = nn.Conv3d(2 * out_channels, 2 * out_channels, kernel_size=1)
        self.scaling_factor = 0.476986

    def forward(self, images):
        latents = self.encoder(images)
        latents = self.quant_conv(latents)
        latents = latents[:, :16]
        latents = latents * self.scaling_factor
        return latents

    def build_1d_mask(self, length, left_bound, right_bound, border_width):
        x = torch.ones((length,))
        if not left_bound:
            x[:border_width] = (torch.arange(border_width) + 1) / border_width
        if not right_bound:
            x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
        return x

    def build_mask(self, data, is_bound, border_width):
        _, _, T, H, W = data.shape
        t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
        h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1])
        w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2])

        t = repeat(t, "T -> T H W", T=T, H=H, W=W)
        h = repeat(h, "H -> T H W", T=T, H=H, W=W)
        w = repeat(w, "W -> T H W", T=T, H=H, W=W)

        mask = torch.stack([t, h, w]).min(dim=0).values
        mask = rearrange(mask, "T H W -> 1 1 T H W")
        return mask

    def tile_forward(self, hidden_states, tile_size, tile_stride):
        B, C, T, H, W = hidden_states.shape
        size_t, size_h, size_w = tile_size
        stride_t, stride_h, stride_w = tile_stride

        # Split tasks
        tasks = []
        for t in range(0, T, stride_t):
            if (t-stride_t >= 0 and t-stride_t+size_t >= T): continue
            for h in range(0, H, stride_h):
                if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
                for w in range(0, W, stride_w):
                    if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
                    t_, h_, w_ = t + size_t, h + size_h, w + size_w
                    tasks.append((t, t_, h, h_, w, w_))

        # Run
        torch_dtype = self.quant_conv.weight.dtype
        data_device = hidden_states.device
        computation_device = self.quant_conv.weight.device

        weight = torch.zeros((1, 1, (T - 1) // 4 + 1, H // 8, W // 8), dtype=torch_dtype, device=data_device)
        values = torch.zeros((B, 16, (T - 1) // 4 + 1, H // 8, W // 8), dtype=torch_dtype, device=data_device)

        for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
            hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
            hidden_states_batch = self.forward(hidden_states_batch).to(data_device)
            if t > 0:
                hidden_states_batch = hidden_states_batch[:, :, 1:]

            mask = self.build_mask(
                hidden_states_batch,
                is_bound=(t==0, t_>=T, h==0, h_>=H, w==0, w_>=W),
                border_width=((size_t - stride_t) // 4, (size_h - stride_h) // 8, (size_w - stride_w) // 8)
            ).to(dtype=torch_dtype, device=data_device)

            target_t = 0 if t==0 else t // 4 + 1
            target_h = h // 8
            target_w = w // 8
            values[
                :,
                :,
                target_t: target_t + hidden_states_batch.shape[2],
                target_h: target_h + hidden_states_batch.shape[3],
                target_w: target_w + hidden_states_batch.shape[4],
            ] += hidden_states_batch * mask
            weight[
                :,
                :,
                target_t: target_t + hidden_states_batch.shape[2],
                target_h: target_h + hidden_states_batch.shape[3],
                target_w: target_w + hidden_states_batch.shape[4],
            ] += mask
        return values / weight

    def encode_video(self, latents, tile_size=(65, 256, 256), tile_stride=(48, 192, 192)):
        latents = latents.to(self.quant_conv.weight.dtype)
        return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride)

    @staticmethod
    def state_dict_converter():
        return HunyuanVideoVAEEncoderStateDictConverter()


class HunyuanVideoVAEEncoderStateDictConverter:

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        state_dict_ = {}
        for name in state_dict:
            if name.startswith('encoder.') or name.startswith('quant_conv.'):
                state_dict_[name] = state_dict[name]
        return state_dict_
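
A minimal usage sketch for the encoder, mirroring the decoder example above (weights are assumed to be loaded elsewhere; random input is used only to illustrate the shape relation):

import torch
encoder = HunyuanVideoVAEEncoder().eval().to(dtype=torch.float16, device="cuda")
video = torch.randn((1, 3, 65, 256, 256), dtype=torch.float16, device="cuda")  # pixel values in [-1, 1] in practice
with torch.no_grad():
    latents = encoder.encode_video(video)
print(latents.shape)  # (1, 16, (65-1)//4+1, 256//8, 256//8) = (1, 16, 17, 32, 32)
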

1551 lines  diffsynth/models/kolors_text_encoder.py  (new file)
File diff suppressed because one or more lines are too long

386 lines  diffsynth/models/lora.py  (new file)
@@ -0,0 +1,386 @@
import torch
from .sd_unet import SDUNet
from .sdxl_unet import SDXLUNet
from .sd_text_encoder import SDTextEncoder
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from .sd3_dit import SD3DiT
from .flux_dit import FluxDiT
from .hunyuan_dit import HunyuanDiT
from .cog_dit import CogDiT
from .hunyuan_video_dit import HunyuanVideoDiT
from .wan_video_dit import WanModel


class LoRAFromCivitai:
    def __init__(self):
        self.supported_model_classes = []
        self.lora_prefix = []
        self.renamed_lora_prefix = {}
        self.special_keys = {}

    def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
        for key in state_dict:
            if ".lora_up" in key:
                return self.convert_state_dict_up_down(state_dict, lora_prefix, alpha)
        return self.convert_state_dict_AB(state_dict, lora_prefix, alpha)

    def convert_state_dict_up_down(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
        renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "")
        state_dict_ = {}
        for key in state_dict:
            if ".lora_up" not in key:
                continue
            if not key.startswith(lora_prefix):
                continue
            weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
            weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
            if len(weight_up.shape) == 4:
                weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
                weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
                lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
            else:
                lora_weight = alpha * torch.mm(weight_up, weight_down)
            target_name = key.split(".")[0].replace(lora_prefix, renamed_lora_prefix).replace("_", ".") + ".weight"
            for special_key in self.special_keys:
                target_name = target_name.replace(special_key, self.special_keys[special_key])
            state_dict_[target_name] = lora_weight.cpu()
        return state_dict_

    def convert_state_dict_AB(self, state_dict, lora_prefix="", alpha=1.0, device="cuda", torch_dtype=torch.float16):
        state_dict_ = {}
        for key in state_dict:
            if ".lora_B." not in key:
                continue
            if not key.startswith(lora_prefix):
                continue
            weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
            weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
            if len(weight_up.shape) == 4:
                weight_up = weight_up.squeeze(3).squeeze(2)
                weight_down = weight_down.squeeze(3).squeeze(2)
                lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
            else:
                lora_weight = alpha * torch.mm(weight_up, weight_down)
            keys = key.split(".")
            keys.pop(keys.index("lora_B"))
            target_name = ".".join(keys)
            target_name = target_name[len(lora_prefix):]
            state_dict_[target_name] = lora_weight.cpu()
        return state_dict_

    def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None):
        state_dict_model = model.state_dict()
        state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=alpha)
        if model_resource == "diffusers":
            state_dict_lora = model.__class__.state_dict_converter().from_diffusers(state_dict_lora)
        elif model_resource == "civitai":
            state_dict_lora = model.__class__.state_dict_converter().from_civitai(state_dict_lora)
        if isinstance(state_dict_lora, tuple):
            state_dict_lora = state_dict_lora[0]
        if len(state_dict_lora) > 0:
            print(f" {len(state_dict_lora)} tensors are updated.")
        for name in state_dict_lora:
            fp8 = False
            if state_dict_model[name].dtype == torch.float8_e4m3fn:
                state_dict_model[name] = state_dict_model[name].to(state_dict_lora[name].dtype)
                fp8 = True
            state_dict_model[name] += state_dict_lora[name].to(
                dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
            if fp8:
                state_dict_model[name] = state_dict_model[name].to(torch.float8_e4m3fn)
        model.load_state_dict(state_dict_model)

    def match(self, model, state_dict_lora):
        for lora_prefix, model_class in zip(self.lora_prefix, self.supported_model_classes):
            if not isinstance(model, model_class):
                continue
            state_dict_model = model.state_dict()
            for model_resource in ["diffusers", "civitai"]:
                try:
                    state_dict_lora_ = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=1.0)
                    converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == "diffusers" \
                        else model.__class__.state_dict_converter().from_civitai
                    state_dict_lora_ = converter_fn(state_dict_lora_)
                    if isinstance(state_dict_lora_, tuple):
                        state_dict_lora_ = state_dict_lora_[0]
                    if len(state_dict_lora_) == 0:
                        continue
                    for name in state_dict_lora_:
                        if name not in state_dict_model:
                            break
                    else:
                        return lora_prefix, model_resource
                except:
                    pass
        return None

class SDLoRAFromCivitai(LoRAFromCivitai):
    def __init__(self):
        super().__init__()
        self.supported_model_classes = [SDUNet, SDTextEncoder]
        self.lora_prefix = ["lora_unet_", "lora_te_"]
        self.special_keys = {
            "down.blocks": "down_blocks",
            "up.blocks": "up_blocks",
            "mid.block": "mid_block",
            "proj.in": "proj_in",
            "proj.out": "proj_out",
            "transformer.blocks": "transformer_blocks",
            "to.q": "to_q",
            "to.k": "to_k",
            "to.v": "to_v",
            "to.out": "to_out",
            "text.model": "text_model",
            "self.attn.q.proj": "self_attn.q_proj",
            "self.attn.k.proj": "self_attn.k_proj",
            "self.attn.v.proj": "self_attn.v_proj",
            "self.attn.out.proj": "self_attn.out_proj",
            "input.blocks": "model.diffusion_model.input_blocks",
            "middle.block": "model.diffusion_model.middle_block",
            "output.blocks": "model.diffusion_model.output_blocks",
        }


class SDXLLoRAFromCivitai(LoRAFromCivitai):
    def __init__(self):
        super().__init__()
        self.supported_model_classes = [SDXLUNet, SDXLTextEncoder, SDXLTextEncoder2]
        self.lora_prefix = ["lora_unet_", "lora_te1_", "lora_te2_"]
        self.renamed_lora_prefix = {"lora_te2_": "2"}
        self.special_keys = {
            "down.blocks": "down_blocks",
            "up.blocks": "up_blocks",
            "mid.block": "mid_block",
            "proj.in": "proj_in",
            "proj.out": "proj_out",
            "transformer.blocks": "transformer_blocks",
            "to.q": "to_q",
            "to.k": "to_k",
            "to.v": "to_v",
            "to.out": "to_out",
            "text.model": "conditioner.embedders.0.transformer.text_model",
            "self.attn.q.proj": "self_attn.q_proj",
            "self.attn.k.proj": "self_attn.k_proj",
            "self.attn.v.proj": "self_attn.v_proj",
            "self.attn.out.proj": "self_attn.out_proj",
            "input.blocks": "model.diffusion_model.input_blocks",
            "middle.block": "model.diffusion_model.middle_block",
            "output.blocks": "model.diffusion_model.output_blocks",
            "2conditioner.embedders.0.transformer.text_model.encoder.layers": "text_model.encoder.layers"
        }


class FluxLoRAFromCivitai(LoRAFromCivitai):
    def __init__(self):
        super().__init__()
        self.supported_model_classes = [FluxDiT, FluxDiT]
        self.lora_prefix = ["lora_unet_", "transformer."]
        self.renamed_lora_prefix = {}
        self.special_keys = {
            "single.blocks": "single_blocks",
            "double.blocks": "double_blocks",
            "img.attn": "img_attn",
            "img.mlp": "img_mlp",
            "img.mod": "img_mod",
            "txt.attn": "txt_attn",
            "txt.mlp": "txt_mlp",
            "txt.mod": "txt_mod",
        }

class GeneralLoRAFromPeft:
    def __init__(self):
        self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT, WanModel]

    def get_name_dict(self, lora_state_dict):
        lora_name_dict = {}
        for key in lora_state_dict:
            if ".lora_B." not in key:
                continue
            keys = key.split(".")
            if len(keys) > keys.index("lora_B") + 2:
                keys.pop(keys.index("lora_B") + 1)
            keys.pop(keys.index("lora_B"))
            if keys[0] == "diffusion_model":
                keys.pop(0)
            target_name = ".".join(keys)
            lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
        return lora_name_dict
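
To make the key handling above concrete, here is what get_name_dict does to one typical PEFT-style key (the parameter path is purely illustrative, not taken from any particular checkpoint; get_name_dict only inspects the keys):

loader = GeneralLoRAFromPeft()
sd = {
    "diffusion_model.blocks.0.attn.to_q.lora_A.default.weight": torch.zeros(4, 320),
    "diffusion_model.blocks.0.attn.to_q.lora_B.default.weight": torch.zeros(320, 4),
}
print(loader.get_name_dict(sd))
# {'blocks.0.attn.to_q.weight': ('diffusion_model...lora_B.default.weight', 'diffusion_model...lora_A.default.weight')}
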
    def match(self, model: torch.nn.Module, state_dict_lora):
        lora_name_dict = self.get_name_dict(state_dict_lora)
        model_name_dict = {name: None for name, _ in model.named_parameters()}
        matched_num = sum([i in model_name_dict for i in lora_name_dict])
        if matched_num == len(lora_name_dict):
            return "", ""
        else:
            return None

    def fetch_device_and_dtype(self, state_dict):
        device, dtype = None, None
        for name, param in state_dict.items():
            device, dtype = param.device, param.dtype
            break
        computation_device = device
        computation_dtype = dtype
        if computation_device == torch.device("cpu"):
            if torch.cuda.is_available():
                computation_device = torch.device("cuda")
        if computation_dtype == torch.float8_e4m3fn:
            computation_dtype = torch.float32
        return device, dtype, computation_device, computation_dtype

    def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
        state_dict_model = model.state_dict()
        device, dtype, computation_device, computation_dtype = self.fetch_device_and_dtype(state_dict_model)
        lora_name_dict = self.get_name_dict(state_dict_lora)
        for name in lora_name_dict:
            weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=computation_device, dtype=computation_dtype)
            weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=computation_device, dtype=computation_dtype)
            if len(weight_up.shape) == 4:
                weight_up = weight_up.squeeze(3).squeeze(2)
                weight_down = weight_down.squeeze(3).squeeze(2)
                weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
            else:
                weight_lora = alpha * torch.mm(weight_up, weight_down)
            weight_model = state_dict_model[name].to(device=computation_device, dtype=computation_dtype)
            weight_patched = weight_model + weight_lora
            state_dict_model[name] = weight_patched.to(device=device, dtype=dtype)
        print(f" {len(lora_name_dict)} tensors are updated.")
        model.load_state_dict(state_dict_model)

class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
    def __init__(self):
        super().__init__()
        self.supported_model_classes = [HunyuanVideoDiT, HunyuanVideoDiT]
        self.lora_prefix = ["diffusion_model.", "transformer."]
        self.special_keys = {}


class FluxLoRAConverter:
    def __init__(self):
        pass

    @staticmethod
    def align_to_opensource_format(state_dict, alpha=1.0):
        prefix_rename_dict = {
            "single_blocks": "lora_unet_single_blocks",
            "blocks": "lora_unet_double_blocks",
        }
        middle_rename_dict = {
            "norm.linear": "modulation_lin",
            "to_qkv_mlp": "linear1",
            "proj_out": "linear2",

            "norm1_a.linear": "img_mod_lin",
            "norm1_b.linear": "txt_mod_lin",
            "attn.a_to_qkv": "img_attn_qkv",
            "attn.b_to_qkv": "txt_attn_qkv",
            "attn.a_to_out": "img_attn_proj",
            "attn.b_to_out": "txt_attn_proj",
            "ff_a.0": "img_mlp_0",
            "ff_a.2": "img_mlp_2",
            "ff_b.0": "txt_mlp_0",
            "ff_b.2": "txt_mlp_2",
        }
        suffix_rename_dict = {
            "lora_B.weight": "lora_up.weight",
            "lora_A.weight": "lora_down.weight",
        }
        state_dict_ = {}
        for name, param in state_dict.items():
            names = name.split(".")
            if names[-2] != "lora_A" and names[-2] != "lora_B":
                names.pop(-2)
            prefix = names[0]
            middle = ".".join(names[2:-2])
            suffix = ".".join(names[-2:])
            block_id = names[1]
            if middle not in middle_rename_dict:
                continue
            rename = prefix_rename_dict[prefix] + "_" + block_id + "_" + middle_rename_dict[middle] + "." + suffix_rename_dict[suffix]
            state_dict_[rename] = param
            if rename.endswith("lora_up.weight"):
                state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((alpha,))[0]
        return state_dict_

    @staticmethod
    def align_to_diffsynth_format(state_dict):
        rename_dict = {
            "lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight",
            "lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight",
            "lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight",
            "lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight",
            "lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight",
            "lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight",
            "lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight",
        }
        def guess_block_id(name):
            names = name.split("_")
            for i in names:
                if i.isdigit():
                    return i, name.replace(f"_{i}_", "_blockid_")
            return None, None
        state_dict_ = {}
        for name, param in state_dict.items():
            block_id, source_name = guess_block_id(name)
            if source_name in rename_dict:
                target_name = rename_dict[source_name]
                target_name = target_name.replace(".blockid.", f".{block_id}.")
                state_dict_[target_name] = param
            else:
                state_dict_[name] = param
        return state_dict_


class WanLoRAConverter:
    def __init__(self):
        pass

    @staticmethod
    def align_to_opensource_format(state_dict, **kwargs):
        state_dict = {"diffusion_model." + name.replace(".default.", "."): param for name, param in state_dict.items()}
        return state_dict

    @staticmethod
    def align_to_diffsynth_format(state_dict, **kwargs):
        state_dict = {name.replace("diffusion_model.", "").replace(".lora_A.weight", ".lora_A.default.weight").replace(".lora_B.weight", ".lora_B.default.weight"): param for name, param in state_dict.items()}
        return state_dict


def get_lora_loaders():
    return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]
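
Whatever naming convention a checkpoint uses, every loader in this file ends up applying the same update: the frozen weight W is patched to W + alpha * (B @ A), where B is the "up" / lora_B matrix and A the "down" / lora_A matrix. A toy, self-contained sketch of that merge (shapes chosen arbitrarily, not tied to any model here):

import torch
out_dim, in_dim, rank, alpha = 6, 4, 2, 1.0
W = torch.zeros(out_dim, in_dim)
A = torch.randn(rank, in_dim)    # lora_A / lora_down
B = torch.randn(out_dim, rank)   # lora_B / lora_up
W_patched = W + alpha * torch.mm(B, A)
print(W_patched.shape)  # torch.Size([6, 4]) -- same shape as the original weight
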

454 lines  diffsynth/models/model_manager.py  (new file)
@@ -0,0 +1,454 @@
import os, torch, json, importlib
from typing import List

from .downloader import download_models, download_customized_models, Preset_model_id, Preset_model_website

from .sd_text_encoder import SDTextEncoder
from .sd_unet import SDUNet
from .sd_vae_encoder import SDVAEEncoder
from .sd_vae_decoder import SDVAEDecoder
from .lora import get_lora_loaders

from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from .sdxl_unet import SDXLUNet
from .sdxl_vae_decoder import SDXLVAEDecoder
from .sdxl_vae_encoder import SDXLVAEEncoder

from .sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
from .sd3_dit import SD3DiT
from .sd3_vae_decoder import SD3VAEDecoder
from .sd3_vae_encoder import SD3VAEEncoder

from .sd_controlnet import SDControlNet
from .sdxl_controlnet import SDXLControlNetUnion

from .sd_motion import SDMotionModel
from .sdxl_motion import SDXLMotionModel

from .svd_image_encoder import SVDImageEncoder
from .svd_unet import SVDUNet
from .svd_vae_decoder import SVDVAEDecoder
from .svd_vae_encoder import SVDVAEEncoder

from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder

from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from .hunyuan_dit import HunyuanDiT
from .hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
from .hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder

from .flux_dit import FluxDiT
from .flux_text_encoder import FluxTextEncoder2
from .flux_vae import FluxVAEEncoder, FluxVAEDecoder
from .flux_ipadapter import FluxIpAdapter

from .cog_vae import CogVAEEncoder, CogVAEDecoder
from .cog_dit import CogDiT

from ..extensions.RIFE import IFNet
from ..extensions.ESRGAN import RRDBNet

from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
from .utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix

def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
    loaded_model_names, loaded_models = [], []
    for model_name, model_class in zip(model_names, model_classes):
        print(f" model_name: {model_name} model_class: {model_class.__name__}")
        state_dict_converter = model_class.state_dict_converter()
        if model_resource == "civitai":
            state_dict_results = state_dict_converter.from_civitai(state_dict)
        elif model_resource == "diffusers":
            state_dict_results = state_dict_converter.from_diffusers(state_dict)
        if isinstance(state_dict_results, tuple):
            model_state_dict, extra_kwargs = state_dict_results
            print(f" This model is initialized with extra kwargs: {extra_kwargs}")
        else:
            model_state_dict, extra_kwargs = state_dict_results, {}
        torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
        with init_weights_on_device():
            model = model_class(**extra_kwargs)
        if hasattr(model, "eval"):
            model = model.eval()
        model.load_state_dict(model_state_dict, assign=True)
        model = model.to(dtype=torch_dtype, device=device)
        loaded_model_names.append(model_name)
        loaded_models.append(model)
    return loaded_model_names, loaded_models

def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
    loaded_model_names, loaded_models = [], []
    for model_name, model_class in zip(model_names, model_classes):
        if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
            model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
        else:
            model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
        if torch_dtype == torch.float16 and hasattr(model, "half"):
            model = model.half()
        try:
            model = model.to(device=device)
        except:
            pass
        loaded_model_names.append(model_name)
        loaded_models.append(model)
    return loaded_model_names, loaded_models

def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
    print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
    base_state_dict = base_model.state_dict()
    base_model.to("cpu")
    del base_model
    model = model_class(**extra_kwargs)
    model.load_state_dict(base_state_dict, strict=False)
    model.load_state_dict(state_dict, strict=False)
    model.to(dtype=torch_dtype, device=device)
    return model

def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
    loaded_model_names, loaded_models = [], []
    for model_name, model_class in zip(model_names, model_classes):
        while True:
            for model_id in range(len(model_manager.model)):
                base_model_name = model_manager.model_name[model_id]
                if base_model_name == model_name:
                    base_model_path = model_manager.model_path[model_id]
                    base_model = model_manager.model[model_id]
                    print(f" Adding patch model to {base_model_name} ({base_model_path})")
                    patched_model = load_single_patch_model_from_single_file(
                        state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
                    loaded_model_names.append(base_model_name)
                    loaded_models.append(patched_model)
                    model_manager.model.pop(model_id)
                    model_manager.model_path.pop(model_id)
                    model_manager.model_name.pop(model_id)
                    break
            else:
                break
    return loaded_model_names, loaded_models

class ModelDetectorTemplate:
    def __init__(self):
        pass

    def match(self, file_path="", state_dict={}):
        return False

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        return [], []

class ModelDetectorFromSingleFile:
    def __init__(self, model_loader_configs=[]):
        self.keys_hash_with_shape_dict = {}
        self.keys_hash_dict = {}
        for metadata in model_loader_configs:
            self.add_model_metadata(*metadata)

    def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
        self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
        if keys_hash is not None:
            self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)

    def match(self, file_path="", state_dict={}):
        if isinstance(file_path, str) and os.path.isdir(file_path):
            return False
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            return True
        keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
        if keys_hash in self.keys_hash_dict:
            return True
        return False

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)

        # Load models with strict matching
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
            loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
            return loaded_model_names, loaded_models

        # Load models without strict matching
        # (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
        keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
        if keys_hash in self.keys_hash_dict:
            model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
            loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
            return loaded_model_names, loaded_models

        # Nothing matched
        return [], []
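
The detector keys models off a hash of the state-dict key names (optionally with tensor shapes) produced by hash_state_dict_keys, whose implementation lives in .utils and is not shown here. The idea can be sketched independently; the following is an illustration of the approach, not the repository's actual hashing:

# Sketch only: hash a state dict's key layout so a checkpoint type can be recognised.
import hashlib

def sketch_hash_keys(state_dict, with_shape=True):
    items = sorted(f"{k}:{tuple(v.shape)}" if with_shape else k for k, v in state_dict.items())
    return hashlib.md5(",".join(items).encode("utf-8")).hexdigest()
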
class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
|
||||||
|
def __init__(self, model_loader_configs=[]):
|
||||||
|
super().__init__(model_loader_configs)
|
||||||
|
|
||||||
|
|
||||||
|
def match(self, file_path="", state_dict={}):
|
||||||
|
if isinstance(file_path, str) and os.path.isdir(file_path):
|
||||||
|
return False
|
||||||
|
if len(state_dict) == 0:
|
||||||
|
state_dict = load_state_dict(file_path)
|
||||||
|
splited_state_dict = split_state_dict_with_prefix(state_dict)
|
||||||
|
for sub_state_dict in splited_state_dict:
|
||||||
|
if super().match(file_path, sub_state_dict):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
||||||
|
# Split the state_dict and load from each component
|
||||||
|
splited_state_dict = split_state_dict_with_prefix(state_dict)
|
||||||
|
valid_state_dict = {}
|
||||||
|
for sub_state_dict in splited_state_dict:
|
||||||
|
if super().match(file_path, sub_state_dict):
|
||||||
|
valid_state_dict.update(sub_state_dict)
|
||||||
|
if super().match(file_path, valid_state_dict):
|
||||||
|
loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
|
||||||
|
else:
|
||||||
|
loaded_model_names, loaded_models = [], []
|
||||||
|
for sub_state_dict in splited_state_dict:
|
||||||
|
if super().match(file_path, sub_state_dict):
|
||||||
|
loaded_model_names_, loaded_models_ = super().load(file_path, sub_state_dict, device, torch_dtype)
|
||||||
|
loaded_model_names += loaded_model_names_
|
||||||
|
loaded_models += loaded_models_
|
||||||
|
return loaded_model_names, loaded_models
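# Illustrative sketch: the splited-single-file path assumes one checkpoint may pack
# several components (e.g. "unet.*" and "text_encoder.*" keys), so it first groups
# the keys by prefix and matches each group on its own. `toy_split_by_prefix` below
# is a hypothetical stand-in for split_state_dict_with_prefix, not the real helper.
import torch

def toy_split_by_prefix(state_dict):
    groups = {}
    for name, tensor in state_dict.items():
        groups.setdefault(name.split(".")[0], {})[name] = tensor
    return list(groups.values())

toy_packed = {"unet.conv.weight": torch.zeros(1), "text_encoder.embed.weight": torch.zeros(1)}
print([sorted(d) for d in toy_split_by_prefix(toy_packed)])
# [['unet.conv.weight'], ['text_encoder.embed.weight']]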
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ModelDetectorFromHuggingfaceFolder:
|
||||||
|
def __init__(self, model_loader_configs=[]):
|
||||||
|
self.architecture_dict = {}
|
||||||
|
for metadata in model_loader_configs:
|
||||||
|
self.add_model_metadata(*metadata)
|
||||||
|
|
||||||
|
|
||||||
|
def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
|
||||||
|
self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)
|
||||||
|
|
||||||
|
|
||||||
|
def match(self, file_path="", state_dict={}):
|
||||||
|
if not isinstance(file_path, str) or os.path.isfile(file_path):
|
||||||
|
return False
|
||||||
|
file_list = os.listdir(file_path)
|
||||||
|
if "config.json" not in file_list:
|
||||||
|
return False
|
||||||
|
with open(os.path.join(file_path, "config.json"), "r") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
if "architectures" not in config and "_class_name" not in config:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
||||||
|
with open(os.path.join(file_path, "config.json"), "r") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
loaded_model_names, loaded_models = [], []
|
||||||
|
architectures = config["architectures"] if "architectures" in config else [config["_class_name"]]
|
||||||
|
for architecture in architectures:
|
||||||
|
huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
|
||||||
|
if redirected_architecture is not None:
|
||||||
|
architecture = redirected_architecture
|
||||||
|
model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
|
||||||
|
loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
|
||||||
|
loaded_model_names += loaded_model_names_
|
||||||
|
loaded_models += loaded_models_
|
||||||
|
return loaded_model_names, loaded_models
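# Illustrative sketch: load() above resolves the "architectures" entry of config.json
# to a Python class at runtime via importlib.import_module(huggingface_lib) followed
# by attribute lookup. The same mechanism, shown with the standard library so the
# snippet runs anywhere (the module and class names here are only examples):
import importlib

toy_module = importlib.import_module("collections")
toy_class = getattr(toy_module, "OrderedDict")  # analogous to fetching an architecture class from "transformers"
print(toy_class().__class__.__name__)  # OrderedDict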
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ModelDetectorFromPatchedSingleFile:
|
||||||
|
def __init__(self, model_loader_configs=[]):
|
||||||
|
self.keys_hash_with_shape_dict = {}
|
||||||
|
for metadata in model_loader_configs:
|
||||||
|
self.add_model_metadata(*metadata)
|
||||||
|
|
||||||
|
|
||||||
|
def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
|
||||||
|
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def match(self, file_path="", state_dict={}):
|
||||||
|
if not isinstance(file_path, str) or os.path.isdir(file_path):
|
||||||
|
return False
|
||||||
|
if len(state_dict) == 0:
|
||||||
|
state_dict = load_state_dict(file_path)
|
||||||
|
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
||||||
|
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
|
||||||
|
if len(state_dict) == 0:
|
||||||
|
state_dict = load_state_dict(file_path)
|
||||||
|
|
||||||
|
# Load models with strict matching
|
||||||
|
loaded_model_names, loaded_models = [], []
|
||||||
|
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
||||||
|
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
||||||
|
model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
|
||||||
|
loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
|
||||||
|
state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
|
||||||
|
loaded_model_names += loaded_model_names_
|
||||||
|
loaded_models += loaded_models_
|
||||||
|
return loaded_model_names, loaded_models
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ModelManager:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
device="cuda",
|
||||||
|
model_id_list: List[Preset_model_id] = [],
|
||||||
|
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
|
||||||
|
file_path_list: List[str] = [],
|
||||||
|
):
|
||||||
|
self.torch_dtype = torch_dtype
|
||||||
|
self.device = device
|
||||||
|
self.model = []
|
||||||
|
self.model_path = []
|
||||||
|
self.model_name = []
|
||||||
|
downloaded_files = download_models(model_id_list, downloading_priority) if len(model_id_list) > 0 else []
|
||||||
|
self.model_detector = [
|
||||||
|
ModelDetectorFromSingleFile(model_loader_configs),
|
||||||
|
ModelDetectorFromSplitedSingleFile(model_loader_configs),
|
||||||
|
ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
|
||||||
|
ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
|
||||||
|
]
|
||||||
|
self.load_models(downloaded_files + file_path_list)
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
|
||||||
|
print(f"Loading models from file: {file_path}")
|
||||||
|
if len(state_dict) == 0:
|
||||||
|
state_dict = load_state_dict(file_path)
|
||||||
|
model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device)
|
||||||
|
for model_name, model in zip(model_names, models):
|
||||||
|
self.model.append(model)
|
||||||
|
self.model_path.append(file_path)
|
||||||
|
self.model_name.append(model_name)
|
||||||
|
print(f" The following models are loaded: {model_names}.")
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
|
||||||
|
print(f"Loading models from folder: {file_path}")
|
||||||
|
model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
|
||||||
|
for model_name, model in zip(model_names, models):
|
||||||
|
self.model.append(model)
|
||||||
|
self.model_path.append(file_path)
|
||||||
|
self.model_name.append(model_name)
|
||||||
|
print(f" The following models are loaded: {model_names}.")
|
||||||
|
|
||||||
|
|
||||||
|
def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
|
||||||
|
print(f"Loading patch models from file: {file_path}")
|
||||||
|
model_names, models = load_patch_model_from_single_file(
|
||||||
|
state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
|
||||||
|
for model_name, model in zip(model_names, models):
|
||||||
|
self.model.append(model)
|
||||||
|
self.model_path.append(file_path)
|
||||||
|
self.model_name.append(model_name)
|
||||||
|
print(f" The following patched models are loaded: {model_names}.")
|
||||||
|
|
||||||
|
|
||||||
|
def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
|
||||||
|
if isinstance(file_path, list):
|
||||||
|
for file_path_ in file_path:
|
||||||
|
self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
|
||||||
|
else:
|
||||||
|
print(f"Loading LoRA models from file: {file_path}")
|
||||||
|
is_loaded = False
|
||||||
|
if len(state_dict) == 0:
|
||||||
|
state_dict = load_state_dict(file_path)
|
||||||
|
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
|
||||||
|
for lora in get_lora_loaders():
|
||||||
|
match_results = lora.match(model, state_dict)
|
||||||
|
if match_results is not None:
|
||||||
|
print(f" Adding LoRA to {model_name} ({model_path}).")
|
||||||
|
lora_prefix, model_resource = match_results
|
||||||
|
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
|
||||||
|
is_loaded = True
|
||||||
|
break
|
||||||
|
if not is_loaded:
|
||||||
|
print(f" Cannot load LoRA: {file_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
|
||||||
|
print(f"Loading models from: {file_path}")
|
||||||
|
if device is None: device = self.device
|
||||||
|
if torch_dtype is None: torch_dtype = self.torch_dtype
|
||||||
|
if isinstance(file_path, list):
|
||||||
|
state_dict = {}
|
||||||
|
for path in file_path:
|
||||||
|
state_dict.update(load_state_dict(path))
|
||||||
|
elif os.path.isfile(file_path):
|
||||||
|
state_dict = load_state_dict(file_path)
|
||||||
|
else:
|
||||||
|
state_dict = None
|
||||||
|
for model_detector in self.model_detector:
|
||||||
|
if model_detector.match(file_path, state_dict):
|
||||||
|
model_names, models = model_detector.load(
|
||||||
|
file_path, state_dict,
|
||||||
|
device=device, torch_dtype=torch_dtype,
|
||||||
|
allowed_model_names=model_names, model_manager=self
|
||||||
|
)
|
||||||
|
for model_name, model in zip(model_names, models):
|
||||||
|
self.model.append(model)
|
||||||
|
self.model_path.append(file_path)
|
||||||
|
self.model_name.append(model_name)
|
||||||
|
print(f" The following models are loaded: {model_names}.")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
print(f" We cannot detect the model type. No models are loaded.")
|
||||||
|
|
||||||
|
|
||||||
|
def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None):
|
||||||
|
for file_path in file_path_list:
|
||||||
|
self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_model(self, model_name, file_path=None, require_model_path=False):
|
||||||
|
fetched_models = []
|
||||||
|
fetched_model_paths = []
|
||||||
|
for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
|
||||||
|
if file_path is not None and file_path != model_path:
|
||||||
|
continue
|
||||||
|
if model_name == model_name_:
|
||||||
|
fetched_models.append(model)
|
||||||
|
fetched_model_paths.append(model_path)
|
||||||
|
if len(fetched_models) == 0:
|
||||||
|
print(f"No {model_name} models available.")
|
||||||
|
return None
|
||||||
|
if len(fetched_models) == 1:
|
||||||
|
print(f"Using {model_name} from {fetched_model_paths[0]}.")
|
||||||
|
else:
|
||||||
|
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
|
||||||
|
if require_model_path:
|
||||||
|
return fetched_models[0], fetched_model_paths[0]
|
||||||
|
else:
|
||||||
|
return fetched_models[0]
|
||||||
|
|
||||||
|
|
||||||
|
def to(self, device):
|
||||||
|
for model in self.model:
|
||||||
|
model.to(device)
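# Illustrative sketch of typical usage (the file path and model name below are
# hypothetical placeholders, not files shipped with the repository):
#
#   model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
#   model_manager.load_models(["models/example_checkpoint.safetensors"])
#   dit = model_manager.fetch_model("sd3_dit")  # returns None if nothing matched
#
# Each detector is tried in order; the first one whose match() succeeds loads the
# file, and fetch_model() later returns a single component by name (optionally
# filtered by the checkpoint path it came from).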
|
||||||
|
|
||||||
803
diffsynth/models/omnigen.py
Normal file
@@ -0,0 +1,803 @@
|
|||||||
|
# The code is revised from DiT
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
import math
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
|
from transformers import Phi3Config, Phi3Model
|
||||||
|
from transformers.cache_utils import Cache, DynamicCache
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Phi3Transformer(Phi3Model):
|
||||||
|
"""
|
||||||
|
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
|
||||||
|
We only modified the attention mask
|
||||||
|
Args:
|
||||||
|
config: Phi3Config
|
||||||
|
"""
|
||||||
|
def prefetch_layer(self, layer_idx: int, device: torch.device):
|
||||||
|
"Starts prefetching the next layer cache"
|
||||||
|
with torch.cuda.stream(self.prefetch_stream):
|
||||||
|
# Prefetch next layer tensors to GPU
|
||||||
|
for name, param in self.layers[layer_idx].named_parameters():
|
||||||
|
param.data = param.data.to(device, non_blocking=True)
|
||||||
|
|
||||||
|
def evict_previous_layer(self, layer_idx: int):
|
||||||
|
"Moves the previous layer cache to the CPU"
|
||||||
|
prev_layer_idx = layer_idx - 1
|
||||||
|
for name, param in self.layers[prev_layer_idx].named_parameters():
|
||||||
|
param.data = param.data.to("cpu", non_blocking=True)
|
||||||
|
|
||||||
|
def get_offlaod_layer(self, layer_idx: int, device: torch.device):
|
||||||
|
# init stream
|
||||||
|
if not hasattr(self, "prefetch_stream"):
|
||||||
|
self.prefetch_stream = torch.cuda.Stream()
|
||||||
|
|
||||||
|
# delete previous layer
|
||||||
|
torch.cuda.current_stream().synchronize()
|
||||||
|
self.evict_previous_layer(layer_idx)
|
||||||
|
|
||||||
|
# make sure the current layer is ready
|
||||||
|
torch.cuda.synchronize(self.prefetch_stream)
|
||||||
|
|
||||||
|
# load next layer
|
||||||
|
self.prefetch_layer((layer_idx + 1) % len(self.layers), device)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
offload_model: Optional[bool] = False,
|
||||||
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||||
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
if use_cache:
|
||||||
|
logger.warning_once(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
# kept for BC (non `Cache` `past_key_values` inputs)
|
||||||
|
return_legacy_cache = False
|
||||||
|
if use_cache and not isinstance(past_key_values, Cache):
|
||||||
|
return_legacy_cache = True
|
||||||
|
if past_key_values is None:
|
||||||
|
past_key_values = DynamicCache()
|
||||||
|
else:
|
||||||
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
|
logger.warning_once(
|
||||||
|
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
|
||||||
|
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
|
||||||
|
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# if inputs_embeds is None:
|
||||||
|
# inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
# if cache_position is None:
|
||||||
|
# past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
# cache_position = torch.arange(
|
||||||
|
# past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||||
|
# )
|
||||||
|
# if position_ids is None:
|
||||||
|
# position_ids = cache_position.unsqueeze(0)
|
||||||
|
|
||||||
|
if attention_mask is not None and attention_mask.dim() == 3:
|
||||||
|
dtype = inputs_embeds.dtype
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
|
attention_mask = (1 - attention_mask) * min_dtype
|
||||||
|
attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype)
|
||||||
|
else:
|
||||||
|
raise Exception("attention_mask parameter was unavailable or invalid")
|
||||||
|
# causal_mask = self._update_causal_mask(
|
||||||
|
# attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||||
|
# )
|
||||||
|
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
|
# decoder layers
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attns = () if output_attentions else None
|
||||||
|
next_decoder_cache = None
|
||||||
|
|
||||||
|
layer_idx = -1
|
||||||
|
for decoder_layer in self.layers:
|
||||||
|
layer_idx += 1
|
||||||
|
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
|
decoder_layer.__call__,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
position_ids,
|
||||||
|
past_key_values,
|
||||||
|
output_attentions,
|
||||||
|
use_cache,
|
||||||
|
cache_position,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if offload_model and not self.training:
|
||||||
|
self.get_offlaod_layer(layer_idx, device=inputs_embeds.device)
|
||||||
|
layer_outputs = decoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_values,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cache_position=cache_position,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attns += (layer_outputs[1],)
|
||||||
|
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
|
||||||
|
# add hidden states from the last decoder layer
|
||||||
|
if output_hidden_states:
|
||||||
|
print('************')
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
next_cache = next_decoder_cache if use_cache else None
|
||||||
|
if return_legacy_cache:
|
||||||
|
next_cache = next_cache.to_legacy_cache()
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||||
|
return BaseModelOutputWithPast(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attns,
|
||||||
|
)
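# Illustrative sketch: the prefetch/evict methods above keep only a sliding window
# of decoder layers on the GPU, copying the next layer in on a side CUDA stream
# while the current one computes. A minimal, self-contained version of the same
# pattern on toy Linear layers (guarded so it is skipped on CPU-only machines):
import torch

toy_layers = [torch.nn.Linear(8, 8) for _ in range(4)]
if torch.cuda.is_available():
    toy_device = torch.device("cuda")
    toy_stream = torch.cuda.Stream()
    toy_x = torch.randn(1, 8, device=toy_device)
    for i, toy_layer in enumerate(toy_layers):
        torch.cuda.current_stream().wait_stream(toy_stream)  # prefetched weights must have arrived
        toy_layer.to(toy_device)  # cheap if the layer was already prefetched
        with torch.cuda.stream(toy_stream):
            toy_layers[(i + 1) % len(toy_layers)].to(toy_device, non_blocking=True)  # warm up the next layer
        toy_x = toy_layer(toy_x)
        toy_layers[i - 1].to("cpu")  # evict the previous layer, as evict_previous_layer does
    torch.cuda.synchronize()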
|
||||||
|
|
||||||
|
|
||||||
|
def modulate(x, shift, scale):
|
||||||
|
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEmbedder(nn.Module):
|
||||||
|
"""
|
||||||
|
Embeds scalar timesteps into vector representations.
|
||||||
|
"""
|
||||||
|
def __init__(self, hidden_size, frequency_embedding_size=256):
|
||||||
|
super().__init__()
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(hidden_size, hidden_size, bias=True),
|
||||||
|
)
|
||||||
|
self.frequency_embedding_size = frequency_embedding_size
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def timestep_embedding(t, dim, max_period=10000):
|
||||||
|
"""
|
||||||
|
Create sinusoidal timestep embeddings.
|
||||||
|
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||||
|
These may be fractional.
|
||||||
|
:param dim: the dimension of the output.
|
||||||
|
:param max_period: controls the minimum frequency of the embeddings.
|
||||||
|
:return: an (N, D) Tensor of positional embeddings.
|
||||||
|
"""
|
||||||
|
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
||||||
|
half = dim // 2
|
||||||
|
freqs = torch.exp(
|
||||||
|
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
||||||
|
).to(device=t.device)
|
||||||
|
args = t[:, None].float() * freqs[None]
|
||||||
|
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||||
|
if dim % 2:
|
||||||
|
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
def forward(self, t, dtype=torch.float32):
|
||||||
|
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
||||||
|
t_emb = self.mlp(t_freq)
|
||||||
|
return t_emb
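# Illustrative sketch: timestep_embedding() maps each scalar t to a vector of
# cosines and sines at log-spaced frequencies, and the MLP then projects it to
# hidden_size. Quick shape check (the sizes below are arbitrary):
import torch

toy_embedder = TimestepEmbedder(hidden_size=64, frequency_embedding_size=16)
toy_t = torch.tensor([0.0, 250.0, 999.0])
print(toy_embedder.timestep_embedding(toy_t, 16).shape)  # torch.Size([3, 16])
print(toy_embedder(toy_t).shape)                         # torch.Size([3, 64])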
|
||||||
|
|
||||||
|
|
||||||
|
class FinalLayer(nn.Module):
|
||||||
|
"""
|
||||||
|
The final layer of DiT.
|
||||||
|
"""
|
||||||
|
def __init__(self, hidden_size, patch_size, out_channels):
|
||||||
|
super().__init__()
|
||||||
|
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
||||||
|
self.adaLN_modulation = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, c):
|
||||||
|
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||||
|
x = modulate(self.norm_final(x), shift, scale)
|
||||||
|
x = self.linear(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=1):
|
||||||
|
"""
|
||||||
|
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
|
||||||
|
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||||
|
"""
|
||||||
|
if isinstance(grid_size, int):
|
||||||
|
grid_size = (grid_size, grid_size)
|
||||||
|
|
||||||
|
grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
|
||||||
|
grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
|
||||||
|
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||||
|
grid = np.stack(grid, axis=0)
|
||||||
|
|
||||||
|
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
|
||||||
|
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||||
|
if cls_token and extra_tokens > 0:
|
||||||
|
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
||||||
|
return pos_embed
|
||||||
|
|
||||||
|
|
||||||
|
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||||
|
assert embed_dim % 2 == 0
|
||||||
|
|
||||||
|
# use half of dimensions to encode grid_h
|
||||||
|
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||||
|
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||||
|
|
||||||
|
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||||
|
"""
|
||||||
|
embed_dim: output dimension for each position
|
||||||
|
pos: a list of positions to be encoded: size (M,)
|
||||||
|
out: (M, D)
|
||||||
|
"""
|
||||||
|
assert embed_dim % 2 == 0
|
||||||
|
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
||||||
|
omega /= embed_dim / 2.
|
||||||
|
omega = 1. / 10000**omega # (D/2,)
|
||||||
|
|
||||||
|
pos = pos.reshape(-1) # (M,)
|
||||||
|
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
||||||
|
|
||||||
|
emb_sin = np.sin(out) # (M, D/2)
|
||||||
|
emb_cos = np.cos(out) # (M, D/2)
|
||||||
|
|
||||||
|
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||||
|
return emb
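# Illustrative sketch: get_2d_sincos_pos_embed() encodes the row and column
# coordinates separately with the 1-D helper above and concatenates the two halves,
# giving one embed_dim vector per grid cell:
import numpy as np

toy_pos_embed = get_2d_sincos_pos_embed(embed_dim=8, grid_size=4)
print(toy_pos_embed.shape)  # (16, 8): a 4x4 grid, 8 dimensions per position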
|
||||||
|
|
||||||
|
|
||||||
|
class PatchEmbedMR(nn.Module):
|
||||||
|
""" 2D Image to Patch Embedding
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
patch_size: int = 2,
|
||||||
|
in_chans: int = 4,
|
||||||
|
embed_dim: int = 768,
|
||||||
|
bias: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.proj(x)
|
||||||
|
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
|
||||||
|
return x
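# Illustrative sketch: PatchEmbedMR turns an (N, C, H, W) latent into a sequence of
# (H / patch_size) * (W / patch_size) tokens of size embed_dim (values are arbitrary):
import torch

toy_patchify = PatchEmbedMR(patch_size=2, in_chans=4, embed_dim=16)
toy_tokens = toy_patchify(torch.randn(1, 4, 8, 8))
print(toy_tokens.shape)  # torch.Size([1, 16, 16]): (8/2)*(8/2) = 16 tokens, 16 dims each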
|
||||||
|
|
||||||
|
|
||||||
|
class OmniGenOriginalModel(nn.Module):
|
||||||
|
"""
|
||||||
|
Diffusion model with a Transformer backbone.
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
transformer_config: Phi3Config,
|
||||||
|
patch_size=2,
|
||||||
|
in_channels=4,
|
||||||
|
pe_interpolation: float = 1.0,
|
||||||
|
pos_embed_max_size: int = 192,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = in_channels
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.pos_embed_max_size = pos_embed_max_size
|
||||||
|
|
||||||
|
hidden_size = transformer_config.hidden_size
|
||||||
|
|
||||||
|
self.x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
|
||||||
|
self.input_x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
|
||||||
|
|
||||||
|
self.time_token = TimestepEmbedder(hidden_size)
|
||||||
|
self.t_embedder = TimestepEmbedder(hidden_size)
|
||||||
|
|
||||||
|
self.pe_interpolation = pe_interpolation
|
||||||
|
pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, interpolation_scale=self.pe_interpolation, base_size=64)
|
||||||
|
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)
|
||||||
|
|
||||||
|
self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
|
||||||
|
|
||||||
|
self.initialize_weights()
|
||||||
|
|
||||||
|
self.llm = Phi3Transformer(config=transformer_config)
|
||||||
|
self.llm.config.use_cache = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, model_name):
|
||||||
|
if not os.path.exists(model_name):
|
||||||
|
cache_folder = os.getenv('HF_HUB_CACHE')
|
||||||
|
model_name = snapshot_download(repo_id=model_name,
|
||||||
|
cache_dir=cache_folder,
|
||||||
|
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
|
||||||
|
config = Phi3Config.from_pretrained(model_name)
|
||||||
|
model = cls(config)
|
||||||
|
if os.path.exists(os.path.join(model_name, 'model.safetensors')):
|
||||||
|
print("Loading safetensors")
|
||||||
|
ckpt = load_file(os.path.join(model_name, 'model.safetensors'))
|
||||||
|
else:
|
||||||
|
ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu')
|
||||||
|
model.load_state_dict(ckpt)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def initialize_weights(self):
|
||||||
|
assert not hasattr(self, "llama")
|
||||||
|
|
||||||
|
# Initialize transformer layers:
|
||||||
|
def _basic_init(module):
|
||||||
|
if isinstance(module, nn.Linear):
|
||||||
|
torch.nn.init.xavier_uniform_(module.weight)
|
||||||
|
if module.bias is not None:
|
||||||
|
nn.init.constant_(module.bias, 0)
|
||||||
|
self.apply(_basic_init)
|
||||||
|
|
||||||
|
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
||||||
|
w = self.x_embedder.proj.weight.data
|
||||||
|
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||||
|
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
||||||
|
|
||||||
|
w = self.input_x_embedder.proj.weight.data
|
||||||
|
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||||
|
nn.init.constant_(self.input_x_embedder.proj.bias, 0)
|
||||||
|
|
||||||
|
|
||||||
|
# Initialize timestep embedding MLP:
|
||||||
|
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
||||||
|
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
||||||
|
nn.init.normal_(self.time_token.mlp[0].weight, std=0.02)
|
||||||
|
nn.init.normal_(self.time_token.mlp[2].weight, std=0.02)
|
||||||
|
|
||||||
|
# Zero-out output layers:
|
||||||
|
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
||||||
|
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
||||||
|
nn.init.constant_(self.final_layer.linear.weight, 0)
|
||||||
|
nn.init.constant_(self.final_layer.linear.bias, 0)
|
||||||
|
|
||||||
|
def unpatchify(self, x, h, w):
|
||||||
|
"""
|
||||||
|
x: (N, T, patch_size**2 * C)
|
||||||
|
imgs: (N, H, W, C)
|
||||||
|
"""
|
||||||
|
c = self.out_channels
|
||||||
|
|
||||||
|
x = x.reshape(shape=(x.shape[0], h//self.patch_size, w//self.patch_size, self.patch_size, self.patch_size, c))
|
||||||
|
x = torch.einsum('nhwpqc->nchpwq', x)
|
||||||
|
imgs = x.reshape(shape=(x.shape[0], c, h, w))
|
||||||
|
return imgs
|
||||||
|
|
||||||
|
|
||||||
|
def cropped_pos_embed(self, height, width):
|
||||||
|
"""Crops positional embeddings for SD3 compatibility."""
|
||||||
|
if self.pos_embed_max_size is None:
|
||||||
|
raise ValueError("`pos_embed_max_size` must be set for cropping.")
|
||||||
|
|
||||||
|
height = height // self.patch_size
|
||||||
|
width = width // self.patch_size
|
||||||
|
if height > self.pos_embed_max_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
|
||||||
|
)
|
||||||
|
if width > self.pos_embed_max_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
|
||||||
|
)
|
||||||
|
|
||||||
|
top = (self.pos_embed_max_size - height) // 2
|
||||||
|
left = (self.pos_embed_max_size - width) // 2
|
||||||
|
spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
|
||||||
|
spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
|
||||||
|
# print(top, top + height, left, left + width, spatial_pos_embed.size())
|
||||||
|
spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
|
||||||
|
return spatial_pos_embed
|
||||||
|
|
||||||
|
|
||||||
|
def patch_multiple_resolutions(self, latents, padding_latent=None, is_input_images:bool=False):
|
||||||
|
if isinstance(latents, list):
|
||||||
|
return_list = False
|
||||||
|
if padding_latent is None:
|
||||||
|
padding_latent = [None] * len(latents)
|
||||||
|
return_list = True
|
||||||
|
patched_latents, num_tokens, shapes = [], [], []
|
||||||
|
for latent, padding in zip(latents, padding_latent):
|
||||||
|
height, width = latent.shape[-2:]
|
||||||
|
if is_input_images:
|
||||||
|
latent = self.input_x_embedder(latent)
|
||||||
|
else:
|
||||||
|
latent = self.x_embedder(latent)
|
||||||
|
pos_embed = self.cropped_pos_embed(height, width)
|
||||||
|
latent = latent + pos_embed
|
||||||
|
if padding is not None:
|
||||||
|
latent = torch.cat([latent, padding], dim=-2)
|
||||||
|
patched_latents.append(latent)
|
||||||
|
|
||||||
|
num_tokens.append(pos_embed.size(1))
|
||||||
|
shapes.append([height, width])
|
||||||
|
if not return_list:
|
||||||
|
latents = torch.cat(patched_latents, dim=0)
|
||||||
|
else:
|
||||||
|
latents = patched_latents
|
||||||
|
else:
|
||||||
|
height, width = latents.shape[-2:]
|
||||||
|
if is_input_images:
|
||||||
|
latents = self.input_x_embedder(latents)
|
||||||
|
else:
|
||||||
|
latents = self.x_embedder(latents)
|
||||||
|
pos_embed = self.cropped_pos_embed(height, width)
|
||||||
|
latents = latents + pos_embed
|
||||||
|
num_tokens = latents.size(1)
|
||||||
|
shapes = [height, width]
|
||||||
|
return latents, num_tokens, shapes
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True, offload_model:bool=False):
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
input_is_list = isinstance(x, list)
|
||||||
|
x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
|
||||||
|
time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)
|
||||||
|
|
||||||
|
if input_img_latents is not None:
|
||||||
|
input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True)
|
||||||
|
if input_ids is not None:
|
||||||
|
condition_embeds = self.llm.embed_tokens(input_ids).clone()
|
||||||
|
input_img_inx = 0
|
||||||
|
for b_inx in input_image_sizes.keys():
|
||||||
|
for start_inx, end_inx in input_image_sizes[b_inx]:
|
||||||
|
condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
|
||||||
|
input_img_inx += 1
|
||||||
|
if input_img_latents is not None:
|
||||||
|
assert input_img_inx == len(input_latents)
|
||||||
|
|
||||||
|
input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
|
||||||
|
else:
|
||||||
|
input_emb = torch.cat([time_token, x], dim=1)
|
||||||
|
output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, offload_model=offload_model)
|
||||||
|
output, past_key_values = output.last_hidden_state, output.past_key_values
|
||||||
|
if input_is_list:
|
||||||
|
image_embedding = output[:, -max(num_tokens):]
|
||||||
|
time_emb = self.t_embedder(timestep, dtype=x.dtype)
|
||||||
|
x = self.final_layer(image_embedding, time_emb)
|
||||||
|
latents = []
|
||||||
|
for i in range(x.size(0)):
|
||||||
|
latent = x[i:i+1, :num_tokens[i]]
|
||||||
|
latent = self.unpatchify(latent, shapes[i][0], shapes[i][1])
|
||||||
|
latents.append(latent)
|
||||||
|
else:
|
||||||
|
image_embedding = output[:, -num_tokens:]
|
||||||
|
time_emb = self.t_embedder(timestep, dtype=x.dtype)
|
||||||
|
x = self.final_layer(image_embedding, time_emb)
|
||||||
|
latents = self.unpatchify(x, shapes[0], shapes[1])
|
||||||
|
|
||||||
|
if return_past_key_values:
|
||||||
|
return latents, past_key_values
|
||||||
|
return latents
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def forward_with_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
|
||||||
|
self.llm.config.use_cache = use_kv_cache
|
||||||
|
model_out, past_key_values = self.forward(x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, past_key_values=past_key_values, return_past_key_values=True, offload_model=offload_model)
|
||||||
|
if use_img_cfg:
|
||||||
|
cond, uncond, img_cond = torch.split(model_out, len(model_out) // 3, dim=0)
|
||||||
|
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
|
||||||
|
model_out = [cond, cond, cond]
|
||||||
|
else:
|
||||||
|
cond, uncond = torch.split(model_out, len(model_out) // 2, dim=0)
|
||||||
|
cond = uncond + cfg_scale * (cond - uncond)
|
||||||
|
model_out = [cond, cond]
|
||||||
|
|
||||||
|
return torch.cat(model_out, dim=0), past_key_values
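# Illustrative sketch: with image guidance the batch carries three predictions
# (text-conditional, unconditional, image-conditional) and the combination above
# chains two guidance steps: unconditional -> image condition -> text+image
# condition. Toy scalar values in place of latent tensors:
toy_uncond, toy_img_cond, toy_cond = 0.0, 1.0, 3.0
toy_img_cfg_scale, toy_cfg_scale = 1.6, 2.5
toy_guided = toy_uncond + toy_img_cfg_scale * (toy_img_cond - toy_uncond) + toy_cfg_scale * (toy_cond - toy_img_cond)
print(toy_guided)  # 6.6: the prediction is pushed further in the conditional direction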
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
|
||||||
|
self.llm.config.use_cache = use_kv_cache
|
||||||
|
if past_key_values is None:
|
||||||
|
past_key_values = [None] * len(attention_mask)
|
||||||
|
|
||||||
|
x = torch.split(x, len(x) // len(attention_mask), dim=0)
|
||||||
|
timestep = timestep.to(x[0].dtype)
|
||||||
|
timestep = torch.split(timestep, len(timestep) // len(input_ids), dim=0)
|
||||||
|
|
||||||
|
model_out, pask_key_values = [], []
|
||||||
|
for i in range(len(input_ids)):
|
||||||
|
temp_out, temp_pask_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values=past_key_values[i], return_past_key_values=True, offload_model=offload_model)
|
||||||
|
model_out.append(temp_out)
|
||||||
|
pask_key_values.append(temp_pask_key_values)
|
||||||
|
|
||||||
|
if len(model_out) == 3:
|
||||||
|
cond, uncond, img_cond = model_out
|
||||||
|
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
|
||||||
|
model_out = [cond, cond, cond]
|
||||||
|
elif len(model_out) == 2:
|
||||||
|
cond, uncond = model_out
|
||||||
|
cond = uncond + cfg_scale * (cond - uncond)
|
||||||
|
model_out = [cond, cond]
|
||||||
|
else:
|
||||||
|
return model_out[0]
|
||||||
|
|
||||||
|
return torch.cat(model_out, dim=0), pask_key_values
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class OmniGenTransformer(OmniGenOriginalModel):
|
||||||
|
def __init__(self):
|
||||||
|
config = {
|
||||||
|
"_name_or_path": "Phi-3-vision-128k-instruct",
|
||||||
|
"architectures": [
|
||||||
|
"Phi3ForCausalLM"
|
||||||
|
],
|
||||||
|
"attention_dropout": 0.0,
|
||||||
|
"bos_token_id": 1,
|
||||||
|
"eos_token_id": 2,
|
||||||
|
"hidden_act": "silu",
|
||||||
|
"hidden_size": 3072,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"intermediate_size": 8192,
|
||||||
|
"max_position_embeddings": 131072,
|
||||||
|
"model_type": "phi3",
|
||||||
|
"num_attention_heads": 32,
|
||||||
|
"num_hidden_layers": 32,
|
||||||
|
"num_key_value_heads": 32,
|
||||||
|
"original_max_position_embeddings": 4096,
|
||||||
|
"rms_norm_eps": 1e-05,
|
||||||
|
"rope_scaling": {
|
||||||
|
"long_factor": [
|
||||||
|
1.0299999713897705,
|
||||||
|
1.0499999523162842,
|
||||||
|
1.0499999523162842,
|
||||||
|
1.0799999237060547,
|
||||||
|
1.2299998998641968,
|
||||||
|
1.2299998998641968,
|
||||||
|
1.2999999523162842,
|
||||||
|
1.4499999284744263,
|
||||||
|
1.5999999046325684,
|
||||||
|
1.6499998569488525,
|
||||||
|
1.8999998569488525,
|
||||||
|
2.859999895095825,
|
||||||
|
3.68999981880188,
|
||||||
|
5.419999599456787,
|
||||||
|
5.489999771118164,
|
||||||
|
5.489999771118164,
|
||||||
|
9.09000015258789,
|
||||||
|
11.579999923706055,
|
||||||
|
15.65999984741211,
|
||||||
|
15.769999504089355,
|
||||||
|
15.789999961853027,
|
||||||
|
18.360000610351562,
|
||||||
|
21.989999771118164,
|
||||||
|
23.079999923706055,
|
||||||
|
30.009998321533203,
|
||||||
|
32.35000228881836,
|
||||||
|
32.590003967285156,
|
||||||
|
35.56000518798828,
|
||||||
|
39.95000457763672,
|
||||||
|
53.840003967285156,
|
||||||
|
56.20000457763672,
|
||||||
|
57.95000457763672,
|
||||||
|
59.29000473022461,
|
||||||
|
59.77000427246094,
|
||||||
|
59.920005798339844,
|
||||||
|
61.190006256103516,
|
||||||
|
61.96000671386719,
|
||||||
|
62.50000762939453,
|
||||||
|
63.3700065612793,
|
||||||
|
63.48000717163086,
|
||||||
|
63.48000717163086,
|
||||||
|
63.66000747680664,
|
||||||
|
63.850006103515625,
|
||||||
|
64.08000946044922,
|
||||||
|
64.760009765625,
|
||||||
|
64.80001068115234,
|
||||||
|
64.81001281738281,
|
||||||
|
64.81001281738281
|
||||||
|
],
|
||||||
|
"short_factor": [
|
||||||
|
1.05,
|
||||||
|
1.05,
|
||||||
|
1.05,
|
||||||
|
1.1,
|
||||||
|
1.1,
|
||||||
|
1.1,
|
||||||
|
1.2500000000000002,
|
||||||
|
1.2500000000000002,
|
||||||
|
1.4000000000000004,
|
||||||
|
1.4500000000000004,
|
||||||
|
1.5500000000000005,
|
||||||
|
1.8500000000000008,
|
||||||
|
1.9000000000000008,
|
||||||
|
2.000000000000001,
|
||||||
|
2.000000000000001,
|
||||||
|
2.000000000000001,
|
||||||
|
2.000000000000001,
|
||||||
|
2.000000000000001,
|
||||||
|
2.000000000000001,
|
||||||
|
2.000000000000001,
|
||||||
|
2.000000000000001,
|
||||||
|
2.000000000000001,
|
||||||
|
2.000000000000001,
|
||||||
|
2.000000000000001,
|
||||||
|
2.000000000000001,
|
||||||
|
2.000000000000001,
|
||||||
|
2.000000000000001,
|
||||||
|
2.000000000000001,
|
||||||
|
2.000000000000001,
|
||||||
|
2.000000000000001,
|
||||||
|
2.000000000000001,
|
||||||
|
2.000000000000001,
|
||||||
|
2.1000000000000005,
|
||||||
|
2.1000000000000005,
|
||||||
|
2.2,
|
||||||
|
2.3499999999999996,
|
||||||
|
2.3499999999999996,
|
||||||
|
2.3499999999999996,
|
||||||
|
2.3499999999999996,
|
||||||
|
2.3999999999999995,
|
||||||
|
2.3999999999999995,
|
||||||
|
2.6499999999999986,
|
||||||
|
2.6999999999999984,
|
||||||
|
2.8999999999999977,
|
||||||
|
2.9499999999999975,
|
||||||
|
3.049999999999997,
|
||||||
|
3.049999999999997,
|
||||||
|
3.049999999999997
|
||||||
|
],
|
||||||
|
"type": "su"
|
||||||
|
},
|
||||||
|
"rope_theta": 10000.0,
|
||||||
|
"sliding_window": 131072,
|
||||||
|
"tie_word_embeddings": False,
|
||||||
|
"torch_dtype": "bfloat16",
|
||||||
|
"transformers_version": "4.38.1",
|
||||||
|
"use_cache": True,
|
||||||
|
"vocab_size": 32064,
|
||||||
|
"_attn_implementation": "sdpa"
|
||||||
|
}
|
||||||
|
config = Phi3Config(**config)
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True, offload_model:bool=False):
|
||||||
|
input_is_list = isinstance(x, list)
|
||||||
|
x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
|
||||||
|
time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)
|
||||||
|
|
||||||
|
if input_img_latents is not None:
|
||||||
|
input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True)
|
||||||
|
if input_ids is not None:
|
||||||
|
condition_embeds = self.llm.embed_tokens(input_ids).clone()
|
||||||
|
input_img_inx = 0
|
||||||
|
for b_inx in input_image_sizes.keys():
|
||||||
|
for start_inx, end_inx in input_image_sizes[b_inx]:
|
||||||
|
condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
|
||||||
|
input_img_inx += 1
|
||||||
|
if input_img_latents is not None:
|
||||||
|
assert input_img_inx == len(input_latents)
|
||||||
|
|
||||||
|
input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
|
||||||
|
else:
|
||||||
|
input_emb = torch.cat([time_token, x], dim=1)
|
||||||
|
output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, offload_model=offload_model)
|
||||||
|
output, past_key_values = output.last_hidden_state, output.past_key_values
|
||||||
|
if input_is_list:
|
||||||
|
image_embedding = output[:, -max(num_tokens):]
|
||||||
|
time_emb = self.t_embedder(timestep, dtype=x.dtype)
|
||||||
|
x = self.final_layer(image_embedding, time_emb)
|
||||||
|
latents = []
|
||||||
|
for i in range(x.size(0)):
|
||||||
|
latent = x[i:i+1, :num_tokens[i]]
|
||||||
|
latent = self.unpatchify(latent, shapes[i][0], shapes[i][1])
|
||||||
|
latents.append(latent)
|
||||||
|
else:
|
||||||
|
image_embedding = output[:, -num_tokens:]
|
||||||
|
time_emb = self.t_embedder(timestep, dtype=x.dtype)
|
||||||
|
x = self.final_layer(image_embedding, time_emb)
|
||||||
|
latents = self.unpatchify(x, shapes[0], shapes[1])
|
||||||
|
|
||||||
|
if return_past_key_values:
|
||||||
|
return latents, past_key_values
|
||||||
|
return latents
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
|
||||||
|
self.llm.config.use_cache = use_kv_cache
|
||||||
|
if past_key_values is None:
|
||||||
|
past_key_values = [None] * len(attention_mask)
|
||||||
|
|
||||||
|
x = torch.split(x, len(x) // len(attention_mask), dim=0)
|
||||||
|
timestep = timestep.to(x[0].dtype)
|
||||||
|
timestep = torch.split(timestep, len(timestep) // len(input_ids), dim=0)
|
||||||
|
|
||||||
|
model_out, pask_key_values = [], []
|
||||||
|
for i in range(len(input_ids)):
|
||||||
|
temp_out, temp_pask_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values=past_key_values[i], return_past_key_values=True, offload_model=offload_model)
|
||||||
|
model_out.append(temp_out)
|
||||||
|
pask_key_values.append(temp_pask_key_values)
|
||||||
|
|
||||||
|
if len(model_out) == 3:
|
||||||
|
cond, uncond, img_cond = model_out
|
||||||
|
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
|
||||||
|
model_out = [cond, cond, cond]
|
||||||
|
elif len(model_out) == 2:
|
||||||
|
cond, uncond = model_out
|
||||||
|
cond = uncond + cfg_scale * (cond - uncond)
|
||||||
|
model_out = [cond, cond]
|
||||||
|
else:
|
||||||
|
return model_out[0]
|
||||||
|
|
||||||
|
return torch.cat(model_out, dim=0), pask_key_values
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return OmniGenTransformerStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class OmniGenTransformerStateDictConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_diffusers(self, state_dict):
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
return state_dict
|
||||||
168
diffsynth/models/qwenvl.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen25VL_7b_Embedder(torch.nn.Module):
|
||||||
|
def __init__(self, model_path, max_length=640, dtype=torch.bfloat16, device="cuda"):
|
||||||
|
super(Qwen25VL_7b_Embedder, self).__init__()
|
||||||
|
self.max_length = max_length
|
||||||
|
self.dtype = dtype
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
|
||||||
|
|
||||||
|
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||||
|
model_path,
|
||||||
|
torch_dtype=dtype,
|
||||||
|
).to(torch.cuda.current_device())
|
||||||
|
|
||||||
|
self.model.requires_grad_(False)
|
||||||
|
self.processor = AutoProcessor.from_pretrained(
|
||||||
|
model_path, min_pixels=256 * 28 * 28, max_pixels=324 * 28 * 28
|
||||||
|
)
|
||||||
|
|
||||||
|
Qwen25VL_7b_PREFIX = '''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:
|
||||||
|
- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.
|
||||||
|
- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n
|
||||||
|
Here are examples of how to transform or refine prompts:
|
||||||
|
- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.
|
||||||
|
- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n
|
||||||
|
Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:
|
||||||
|
User Prompt:'''
|
||||||
|
|
||||||
|
self.prefix = Qwen25VL_7b_PREFIX
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_pretrained(path, torch_dtype=torch.bfloat16, device="cuda"):
|
||||||
|
return Qwen25VL_7b_Embedder(path, dtype=torch_dtype, device=device)
|
||||||
|
|
||||||
|
def forward(self, caption, ref_images):
|
||||||
|
text_list = caption
|
||||||
|
embs = torch.zeros(
|
||||||
|
len(text_list),
|
||||||
|
self.max_length,
|
||||||
|
self.model.config.hidden_size,
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
device=torch.cuda.current_device(),
|
||||||
|
)
|
||||||
|
hidden_states = torch.zeros(
|
||||||
|
len(text_list),
|
||||||
|
self.max_length,
|
||||||
|
self.model.config.hidden_size,
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
device=torch.cuda.current_device(),
|
||||||
|
)
|
||||||
|
masks = torch.zeros(
|
||||||
|
len(text_list),
|
||||||
|
self.max_length,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=torch.cuda.current_device(),
|
||||||
|
)
|
||||||
|
input_ids_list = []
|
||||||
|
attention_mask_list = []
|
||||||
|
emb_list = []
|
||||||
|
|
||||||
|
def split_string(s):
|
||||||
|
s = s.replace("“", '"').replace("”", '"').replace("'", '''"''') # use english quotes
|
||||||
|
result = []
|
||||||
|
in_quotes = False
|
||||||
|
temp = ""
|
||||||
|
|
||||||
|
for idx,char in enumerate(s):
|
||||||
|
if char == '"' and idx>155:
|
||||||
|
temp += char
|
||||||
|
if not in_quotes:
|
||||||
|
result.append(temp)
|
||||||
|
temp = ""
|
||||||
|
|
||||||
|
in_quotes = not in_quotes
|
||||||
|
continue
|
||||||
|
if in_quotes:
|
||||||
|
if char.isspace():
|
||||||
|
pass # have space token
|
||||||
|
|
||||||
|
result.append("“" + char + "”")
|
||||||
|
else:
|
||||||
|
temp += char
|
||||||
|
|
||||||
|
if temp:
|
||||||
|
result.append(temp)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
for idx, (txt, imgs) in enumerate(zip(text_list, ref_images)):
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": []}]
|
||||||
|
|
||||||
|
messages[0]["content"].append({"type": "text", "text": f"{self.prefix}"})
|
||||||
|
|
||||||
|
messages[0]["content"].append({"type": "image", "image": imgs})
|
||||||
|
|
||||||
|
# then append the text
|
||||||
|
messages[0]["content"].append({"type": "text", "text": f"{txt}"})
|
||||||
|
|
||||||
|
# Preparation for inference
|
||||||
|
text = self.processor.apply_chat_template(
|
||||||
|
messages, tokenize=False, add_generation_prompt=True, add_vision_id=True
|
||||||
|
)
|
||||||
|
|
||||||
|
image_inputs = [imgs]
|
||||||
|
|
||||||
|
inputs = self.processor(
|
||||||
|
text=[text],
|
||||||
|
images=image_inputs,
|
||||||
|
padding=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
old_inputs_ids = inputs.input_ids
|
||||||
|
text_split_list = split_string(text)
|
||||||
|
|
||||||
|
token_list = []
|
||||||
|
for text_each in text_split_list:
|
||||||
|
txt_inputs = self.processor(
|
||||||
|
text=text_each,
|
||||||
|
images=None,
|
||||||
|
videos=None,
|
||||||
|
padding=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
token_each = txt_inputs.input_ids
|
||||||
|
if token_each[0][0] == 2073 and token_each[0][-1] == 854:
|
||||||
|
token_each = token_each[:, 1:-1]
|
||||||
|
token_list.append(token_each)
|
||||||
|
else:
|
||||||
|
token_list.append(token_each)
|
||||||
|
|
||||||
|
new_txt_ids = torch.cat(token_list, dim=1).to("cuda")
|
||||||
|
|
||||||
|
new_txt_ids = new_txt_ids.to(old_inputs_ids.device)
|
||||||
|
|
||||||
|
idx1 = (old_inputs_ids == 151653).nonzero(as_tuple=True)[1][0]
|
||||||
|
idx2 = (new_txt_ids == 151653).nonzero(as_tuple=True)[1][0]
|
||||||
|
inputs.input_ids = (
|
||||||
|
torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.to("cuda")
|
||||||
|
)
|
||||||
|
inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=inputs.input_ids,
|
||||||
|
attention_mask=inputs.attention_mask,
|
||||||
|
pixel_values=inputs.pixel_values.to("cuda"),
|
||||||
|
image_grid_thw=inputs.image_grid_thw.to("cuda"),
|
||||||
|
output_hidden_states=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
emb = outputs["hidden_states"][-1]
|
||||||
|
|
||||||
|
embs[idx, : min(self.max_length, emb.shape[1] - 217)] = emb[0, 217:][
|
||||||
|
: self.max_length
|
||||||
|
]
|
||||||
|
|
||||||
|
masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones(
|
||||||
|
(min(self.max_length, emb.shape[1] - 217)),
|
||||||
|
dtype=torch.long,
|
||||||
|
device=torch.cuda.current_device(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return embs, masks
|
||||||
551
diffsynth/models/sd3_dit.py
Normal file
@@ -0,0 +1,551 @@
|
|||||||
|
import torch
|
||||||
|
from einops import rearrange
|
||||||
|
from .svd_unet import TemporalTimesteps
|
||||||
|
from .tiler import TileWorker
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(torch.nn.Module):
|
||||||
|
def __init__(self, dim, eps, elementwise_affine=True):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
if elementwise_affine:
|
||||||
|
self.weight = torch.nn.Parameter(torch.ones((dim,)))
|
||||||
|
else:
|
||||||
|
self.weight = None
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
input_dtype = hidden_states.dtype
|
||||||
|
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
||||||
|
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
||||||
|
hidden_states = hidden_states.to(input_dtype)
|
||||||
|
if self.weight is not None:
|
||||||
|
hidden_states = hidden_states * self.weight
|
||||||
|
return hidden_states
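# Illustrative sketch: unlike LayerNorm, RMSNorm does not subtract the mean; it only
# rescales by the root-mean-square over the last dimension and applies the learned
# gain. Quick equivalence check against a manual computation:
import torch

toy_norm = RMSNorm(dim=4, eps=1e-6)
toy_x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
toy_manual = toy_x / torch.sqrt(toy_x.square().mean(-1, keepdim=True) + 1e-6)
print(torch.allclose(toy_norm(toy_x), toy_manual * toy_norm.weight))  # True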
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class PatchEmbed(torch.nn.Module):
|
||||||
|
def __init__(self, patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192):
|
||||||
|
super().__init__()
|
||||||
|
self.pos_embed_max_size = pos_embed_max_size
|
||||||
|
self.patch_size = patch_size
|
||||||
|
|
||||||
|
self.proj = torch.nn.Conv2d(in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size)
|
||||||
|
self.pos_embed = torch.nn.Parameter(torch.zeros(1, self.pos_embed_max_size, self.pos_embed_max_size, embed_dim))
|
||||||
|
|
||||||
|
def cropped_pos_embed(self, height, width):
|
||||||
|
height = height // self.patch_size
|
||||||
|
width = width // self.patch_size
|
||||||
|
top = (self.pos_embed_max_size - height) // 2
|
||||||
|
left = (self.pos_embed_max_size - width) // 2
|
||||||
|
spatial_pos_embed = self.pos_embed[:, top : top + height, left : left + width, :].flatten(1, 2)
|
||||||
|
return spatial_pos_embed
|
||||||
|
|
||||||
|
def forward(self, latent):
|
||||||
|
height, width = latent.shape[-2:]
|
||||||
|
latent = self.proj(latent)
|
||||||
|
latent = latent.flatten(2).transpose(1, 2)
|
||||||
|
pos_embed = self.cropped_pos_embed(height, width)
|
||||||
|
return latent + pos_embed
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEmbeddings(torch.nn.Module):
|
||||||
|
def __init__(self, dim_in, dim_out, computation_device=None):
|
||||||
|
super().__init__()
|
||||||
|
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device)
|
||||||
|
self.timestep_embedder = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, timestep, dtype):
|
||||||
|
time_emb = self.time_proj(timestep).to(dtype)
|
||||||
|
time_emb = self.timestep_embedder(time_emb)
|
||||||
|
return time_emb
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class AdaLayerNorm(torch.nn.Module):
    def __init__(self, dim, single=False, dual=False):
        super().__init__()
        self.single = single
        self.dual = dual
        # Number of modulation parameter sets: 2 if single, 9 if dual, otherwise 6.
        self.linear = torch.nn.Linear(dim, dim * [[6, 2][single], 9][dual])
        self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        emb = self.linear(torch.nn.functional.silu(emb))
        if self.single:
            scale, shift = emb.unsqueeze(1).chunk(2, dim=2)
            x = self.norm(x) * (1 + scale) + shift
            return x
        elif self.dual:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.unsqueeze(1).chunk(9, dim=2)
            norm_x = self.norm(x)
            x = norm_x * (1 + scale_msa) + shift_msa
            norm_x2 = norm_x * (1 + scale_msa2) + shift_msa2
            return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_x2, gate_msa2
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)
            x = self.norm(x) * (1 + scale_msa) + shift_msa
            return x, gate_msa, shift_mlp, scale_mlp, gate_mlp

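A sketch of how the three AdaLayerNorm modes above differ in what they return (dimensions are illustrative assumptions): single yields only the modulated activation, the default mode adds the gates and MLP shifts/scales, and dual adds a second modulated copy for the extra self-attention branch.

dim = 1536
x = torch.randn(1, 4096, dim)
temb = torch.randn(1, dim)
out_single = AdaLayerNorm(dim, single=True)(x, temb)   # a single tensor
out_default = AdaLayerNorm(dim)(x, temb)               # 5-tuple
out_dual = AdaLayerNorm(dim, dual=True)(x, temb)       # 7-tuple
print(len(out_default), len(out_dual))                 # 5 7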
class JointAttention(torch.nn.Module):
    def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False, use_rms_norm=False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.only_out_a = only_out_a

        self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
        self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)

        self.a_to_out = torch.nn.Linear(dim_a, dim_a)
        if not only_out_a:
            self.b_to_out = torch.nn.Linear(dim_b, dim_b)

        if use_rms_norm:
            self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
            self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
            self.norm_q_b = RMSNorm(head_dim, eps=1e-6)
            self.norm_k_b = RMSNorm(head_dim, eps=1e-6)
        else:
            self.norm_q_a = None
            self.norm_k_a = None
            self.norm_q_b = None
            self.norm_k_b = None

    def process_qkv(self, hidden_states, to_qkv, norm_q, norm_k):
        batch_size = hidden_states.shape[0]
        qkv = to_qkv(hidden_states)
        qkv = qkv.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
        q, k, v = qkv.chunk(3, dim=1)
        if norm_q is not None:
            q = norm_q(q)
        if norm_k is not None:
            k = norm_k(k)
        return q, k, v

    def forward(self, hidden_states_a, hidden_states_b):
        batch_size = hidden_states_a.shape[0]

        qa, ka, va = self.process_qkv(hidden_states_a, self.a_to_qkv, self.norm_q_a, self.norm_k_a)
        qb, kb, vb = self.process_qkv(hidden_states_b, self.b_to_qkv, self.norm_q_b, self.norm_k_b)
        q = torch.concat([qa, qb], dim=2)
        k = torch.concat([ka, kb], dim=2)
        v = torch.concat([va, vb], dim=2)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)
        hidden_states_a, hidden_states_b = hidden_states[:, :hidden_states_a.shape[1]], hidden_states[:, hidden_states_a.shape[1]:]
        hidden_states_a = self.a_to_out(hidden_states_a)
        if self.only_out_a:
            return hidden_states_a
        else:
            hidden_states_b = self.b_to_out(hidden_states_b)
            return hidden_states_a, hidden_states_b

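A minimal sketch of the joint attention above; the token counts are illustrative assumptions. Image and text tokens get separate Q/K/V projections, are concatenated along the sequence axis, attended jointly, then split back and projected per stream.

attn = JointAttention(dim_a=1536, dim_b=1536, num_heads=24, head_dim=64, use_rms_norm=True)
image_tokens = torch.randn(1, 4096, 1536)   # hypothetical image token sequence
text_tokens = torch.randn(1, 333, 1536)     # hypothetical text token sequence
out_image, out_text = attn(image_tokens, text_tokens)
print(out_image.shape, out_text.shape)      # torch.Size([1, 4096, 1536]) torch.Size([1, 333, 1536])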
class SingleAttention(torch.nn.Module):
    def __init__(self, dim_a, num_heads, head_dim, use_rms_norm=False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim

        self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
        self.a_to_out = torch.nn.Linear(dim_a, dim_a)

        if use_rms_norm:
            self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
            self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
        else:
            self.norm_q_a = None
            self.norm_k_a = None

    def process_qkv(self, hidden_states, to_qkv, norm_q, norm_k):
        batch_size = hidden_states.shape[0]
        qkv = to_qkv(hidden_states)
        qkv = qkv.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
        q, k, v = qkv.chunk(3, dim=1)
        if norm_q is not None:
            q = norm_q(q)
        if norm_k is not None:
            k = norm_k(k)
        return q, k, v

    def forward(self, hidden_states_a):
        batch_size = hidden_states_a.shape[0]
        q, k, v = self.process_qkv(hidden_states_a, self.a_to_qkv, self.norm_q_a, self.norm_k_a)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)
        hidden_states = self.a_to_out(hidden_states)
        return hidden_states

class DualTransformerBlock(torch.nn.Module):
    def __init__(self, dim, num_attention_heads, use_rms_norm=False):
        super().__init__()
        self.norm1_a = AdaLayerNorm(dim, dual=True)
        self.norm1_b = AdaLayerNorm(dim)

        self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)
        # The second attention branch sees stream A alone, so it is a single-stream
        # attention (mirroring the dual path of JointTransformerBlock below).
        self.attn2 = SingleAttention(dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)

        self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_a = torch.nn.Sequential(
            torch.nn.Linear(dim, dim*4),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(dim*4, dim)
        )

        self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_b = torch.nn.Sequential(
            torch.nn.Linear(dim, dim*4),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(dim*4, dim)
        )

    def forward(self, hidden_states_a, hidden_states_b, temb):
        norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a, norm_hidden_states_a_2, gate_msa_a_2 = self.norm1_a(hidden_states_a, emb=temb)
        norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)

        # Attention
        attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b)

        # Part A
        hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
        hidden_states_a = hidden_states_a + gate_msa_a_2 * self.attn2(norm_hidden_states_a_2)
        norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
        hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)

        # Part B
        hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
        norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
        hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)

        return hidden_states_a, hidden_states_b

class JointTransformerBlock(torch.nn.Module):
    def __init__(self, dim, num_attention_heads, use_rms_norm=False, dual=False):
        super().__init__()
        self.norm1_a = AdaLayerNorm(dim, dual=dual)
        self.norm1_b = AdaLayerNorm(dim)

        self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)
        if dual:
            self.attn2 = SingleAttention(dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)

        self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_a = torch.nn.Sequential(
            torch.nn.Linear(dim, dim*4),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(dim*4, dim)
        )

        self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_b = torch.nn.Sequential(
            torch.nn.Linear(dim, dim*4),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(dim*4, dim)
        )

    def forward(self, hidden_states_a, hidden_states_b, temb):
        if self.norm1_a.dual:
            norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a, norm_hidden_states_a_2, gate_msa_a_2 = self.norm1_a(hidden_states_a, emb=temb)
        else:
            norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
        norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)

        # Attention
        attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b)

        # Part A
        hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
        if self.norm1_a.dual:
            hidden_states_a = hidden_states_a + gate_msa_a_2 * self.attn2(norm_hidden_states_a_2)
        norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
        hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)

        # Part B
        hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
        norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
        hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)

        return hidden_states_a, hidden_states_b

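A sketch of running one of the blocks above in isolation; all sizes are illustrative assumptions. In the full model, temb is the sum of the timestep embedding and the pooled prompt embedding.

block = JointTransformerBlock(dim=1536, num_attention_heads=24, use_rms_norm=True, dual=False)
hidden_a = torch.randn(1, 4096, 1536)   # image tokens
hidden_b = torch.randn(1, 333, 1536)    # text tokens
temb = torch.randn(1, 1536)             # timestep + pooled prompt conditioning
hidden_a, hidden_b = block(hidden_a, hidden_b, temb)
print(hidden_a.shape, hidden_b.shape)   # shapes are preserved by the block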
class JointTransformerFinalBlock(torch.nn.Module):
    def __init__(self, dim, num_attention_heads, use_rms_norm=False):
        super().__init__()
        self.norm1_a = AdaLayerNorm(dim)
        self.norm1_b = AdaLayerNorm(dim, single=True)

        self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, only_out_a=True, use_rms_norm=use_rms_norm)

        self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_a = torch.nn.Sequential(
            torch.nn.Linear(dim, dim*4),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(dim*4, dim)
        )

    def forward(self, hidden_states_a, hidden_states_b, temb):
        norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
        norm_hidden_states_b = self.norm1_b(hidden_states_b, emb=temb)

        # Attention
        attn_output_a = self.attn(norm_hidden_states_a, norm_hidden_states_b)

        # Part A
        hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
        norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
        hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)

        return hidden_states_a, hidden_states_b

class SD3DiT(torch.nn.Module):
    def __init__(self, embed_dim=1536, num_layers=24, use_rms_norm=False, num_dual_blocks=0, pos_embed_max_size=192):
        super().__init__()
        self.pos_embedder = PatchEmbed(patch_size=2, in_channels=16, embed_dim=embed_dim, pos_embed_max_size=pos_embed_max_size)
        self.time_embedder = TimestepEmbeddings(256, embed_dim)
        self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(2048, embed_dim), torch.nn.SiLU(), torch.nn.Linear(embed_dim, embed_dim))
        self.context_embedder = torch.nn.Linear(4096, embed_dim)
        self.blocks = torch.nn.ModuleList([JointTransformerBlock(embed_dim, embed_dim//64, use_rms_norm=use_rms_norm, dual=True) for _ in range(num_dual_blocks)]
                                          + [JointTransformerBlock(embed_dim, embed_dim//64, use_rms_norm=use_rms_norm) for _ in range(num_layers-1-num_dual_blocks)]
                                          + [JointTransformerFinalBlock(embed_dim, embed_dim//64, use_rms_norm=use_rms_norm)])
        self.norm_out = AdaLayerNorm(embed_dim, single=True)
        self.proj_out = torch.nn.Linear(embed_dim, 64)

    def tiled_forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size=128, tile_stride=64):
        # Due to the global positional embedding, we cannot implement layer-wise tiled forward.
        hidden_states = TileWorker().tiled_forward(
            lambda x: self.forward(x, timestep, prompt_emb, pooled_prompt_emb),
            hidden_states,
            tile_size,
            tile_stride,
            tile_device=hidden_states.device,
            tile_dtype=hidden_states.dtype
        )
        return hidden_states

    def forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tiled=False, tile_size=128, tile_stride=64, use_gradient_checkpointing=False):
        if tiled:
            return self.tiled_forward(hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size, tile_stride)
        conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
        prompt_emb = self.context_embedder(prompt_emb)

        height, width = hidden_states.shape[-2:]
        hidden_states = self.pos_embedder(hidden_states)

        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward

        for block in self.blocks:
            if self.training and use_gradient_checkpointing:
                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states, prompt_emb, conditioning,
                    use_reentrant=False,
                )
            else:
                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning)

        hidden_states = self.norm_out(hidden_states, conditioning)
        hidden_states = self.proj_out(hidden_states)
        hidden_states = rearrange(hidden_states, "B (H W) (P Q C) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
        return hidden_states

    @staticmethod
    def state_dict_converter():
        return SD3DiTStateDictConverter()

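A minimal end-to-end sketch for SD3DiT as defined above. The 4096-dimensional prompt embedding and 2048-dimensional pooled embedding match the projections in __init__; the latent size, timestep value, and prompt length are illustrative assumptions, and this builds the full 24-layer model.

dit = SD3DiT(embed_dim=1536, num_layers=24)
latent = torch.randn(1, 16, 128, 128)
timestep = torch.tensor([1000.0])
prompt_emb = torch.randn(1, 333, 4096)      # hypothetical sequence prompt embedding
pooled_prompt_emb = torch.randn(1, 2048)    # hypothetical pooled prompt embedding
with torch.no_grad():
    pred = dit(latent, timestep, prompt_emb, pooled_prompt_emb)
print(pred.shape)                           # torch.Size([1, 16, 128, 128])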
class SD3DiTStateDictConverter:
    def __init__(self):
        pass

    def infer_architecture(self, state_dict):
        embed_dim = state_dict["blocks.0.ff_a.0.weight"].shape[1]
        num_layers = 100
        while num_layers > 0 and f"blocks.{num_layers-1}.ff_a.0.bias" not in state_dict:
            num_layers -= 1
        use_rms_norm = "blocks.0.attn.norm_q_a.weight" in state_dict
        num_dual_blocks = 0
        while f"blocks.{num_dual_blocks}.attn2.a_to_out.bias" in state_dict:
            num_dual_blocks += 1
        pos_embed_max_size = state_dict["pos_embedder.pos_embed"].shape[1]
        return {
            "embed_dim": embed_dim,
            "num_layers": num_layers,
            "use_rms_norm": use_rms_norm,
            "num_dual_blocks": num_dual_blocks,
            "pos_embed_max_size": pos_embed_max_size
        }

def from_diffusers(self, state_dict):
|
||||||
|
rename_dict = {
|
||||||
|
"context_embedder": "context_embedder",
|
||||||
|
"pos_embed.pos_embed": "pos_embedder.pos_embed",
|
||||||
|
"pos_embed.proj": "pos_embedder.proj",
|
||||||
|
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
|
||||||
|
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
|
||||||
|
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
|
||||||
|
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
|
||||||
|
"norm_out.linear": "norm_out.linear",
|
||||||
|
"proj_out": "proj_out",
|
||||||
|
|
||||||
|
"norm1.linear": "norm1_a.linear",
|
||||||
|
"norm1_context.linear": "norm1_b.linear",
|
||||||
|
"attn.to_q": "attn.a_to_q",
|
||||||
|
"attn.to_k": "attn.a_to_k",
|
||||||
|
"attn.to_v": "attn.a_to_v",
|
||||||
|
"attn.to_out.0": "attn.a_to_out",
|
||||||
|
"attn.add_q_proj": "attn.b_to_q",
|
||||||
|
"attn.add_k_proj": "attn.b_to_k",
|
||||||
|
"attn.add_v_proj": "attn.b_to_v",
|
||||||
|
"attn.to_add_out": "attn.b_to_out",
|
||||||
|
"ff.net.0.proj": "ff_a.0",
|
||||||
|
"ff.net.2": "ff_a.2",
|
||||||
|
"ff_context.net.0.proj": "ff_b.0",
|
||||||
|
"ff_context.net.2": "ff_b.2",
|
||||||
|
|
||||||
|
"attn.norm_q": "attn.norm_q_a",
|
||||||
|
"attn.norm_k": "attn.norm_k_a",
|
||||||
|
"attn.norm_added_q": "attn.norm_q_b",
|
||||||
|
"attn.norm_added_k": "attn.norm_k_b",
|
||||||
|
}
|
||||||
|
state_dict_ = {}
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
if name in rename_dict:
|
||||||
|
if name == "pos_embed.pos_embed":
|
||||||
|
param = param.reshape((1, 192, 192, param.shape[-1]))
|
||||||
|
state_dict_[rename_dict[name]] = param
|
||||||
|
elif name.endswith(".weight") or name.endswith(".bias"):
|
||||||
|
suffix = ".weight" if name.endswith(".weight") else ".bias"
|
||||||
|
prefix = name[:-len(suffix)]
|
||||||
|
if prefix in rename_dict:
|
||||||
|
state_dict_[rename_dict[prefix] + suffix] = param
|
||||||
|
elif prefix.startswith("transformer_blocks."):
|
||||||
|
names = prefix.split(".")
|
||||||
|
names[0] = "blocks"
|
||||||
|
middle = ".".join(names[2:])
|
||||||
|
if middle in rename_dict:
|
||||||
|
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
|
||||||
|
state_dict_[name_] = param
|
||||||
|
merged_keys = [name for name in state_dict_ if ".a_to_q." in name or ".b_to_q." in name]
|
||||||
|
for key in merged_keys:
|
||||||
|
param = torch.concat([
|
||||||
|
state_dict_[key.replace("to_q", "to_q")],
|
||||||
|
state_dict_[key.replace("to_q", "to_k")],
|
||||||
|
state_dict_[key.replace("to_q", "to_v")],
|
||||||
|
], dim=0)
|
||||||
|
name = key.replace("to_q", "to_qkv")
|
||||||
|
state_dict_.pop(key.replace("to_q", "to_q"))
|
||||||
|
state_dict_.pop(key.replace("to_q", "to_k"))
|
||||||
|
state_dict_.pop(key.replace("to_q", "to_v"))
|
||||||
|
state_dict_[name] = param
|
||||||
|
return state_dict_, self.infer_architecture(state_dict_)
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
rename_dict = {
|
||||||
|
"model.diffusion_model.context_embedder.bias": "context_embedder.bias",
|
||||||
|
"model.diffusion_model.context_embedder.weight": "context_embedder.weight",
|
||||||
|
"model.diffusion_model.final_layer.linear.bias": "proj_out.bias",
|
||||||
|
"model.diffusion_model.final_layer.linear.weight": "proj_out.weight",
|
||||||
|
|
||||||
|
"model.diffusion_model.pos_embed": "pos_embedder.pos_embed",
|
||||||
|
"model.diffusion_model.t_embedder.mlp.0.bias": "time_embedder.timestep_embedder.0.bias",
|
||||||
|
"model.diffusion_model.t_embedder.mlp.0.weight": "time_embedder.timestep_embedder.0.weight",
|
||||||
|
"model.diffusion_model.t_embedder.mlp.2.bias": "time_embedder.timestep_embedder.2.bias",
|
||||||
|
"model.diffusion_model.t_embedder.mlp.2.weight": "time_embedder.timestep_embedder.2.weight",
|
||||||
|
"model.diffusion_model.x_embedder.proj.bias": "pos_embedder.proj.bias",
|
||||||
|
"model.diffusion_model.x_embedder.proj.weight": "pos_embedder.proj.weight",
|
||||||
|
"model.diffusion_model.y_embedder.mlp.0.bias": "pooled_text_embedder.0.bias",
|
||||||
|
"model.diffusion_model.y_embedder.mlp.0.weight": "pooled_text_embedder.0.weight",
|
||||||
|
"model.diffusion_model.y_embedder.mlp.2.bias": "pooled_text_embedder.2.bias",
|
||||||
|
"model.diffusion_model.y_embedder.mlp.2.weight": "pooled_text_embedder.2.weight",
|
||||||
|
|
||||||
|
"model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1.weight": "blocks.23.norm1_b.linear.weight",
|
||||||
|
"model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1.bias": "blocks.23.norm1_b.linear.bias",
|
||||||
|
"model.diffusion_model.final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
|
||||||
|
"model.diffusion_model.final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
|
||||||
|
}
|
||||||
|
for i in range(40):
|
||||||
|
rename_dict.update({
|
||||||
|
f"model.diffusion_model.joint_blocks.{i}.context_block.adaLN_modulation.1.bias": f"blocks.{i}.norm1_b.linear.bias",
|
||||||
|
f"model.diffusion_model.joint_blocks.{i}.context_block.adaLN_modulation.1.weight": f"blocks.{i}.norm1_b.linear.weight",
|
||||||
|
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.proj.bias": f"blocks.{i}.attn.b_to_out.bias",
|
||||||
|
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.proj.weight": f"blocks.{i}.attn.b_to_out.weight",
|
||||||
|
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.qkv.bias": [f'blocks.{i}.attn.b_to_q.bias', f'blocks.{i}.attn.b_to_k.bias', f'blocks.{i}.attn.b_to_v.bias'],
|
||||||
|
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.qkv.weight": [f'blocks.{i}.attn.b_to_q.weight', f'blocks.{i}.attn.b_to_k.weight', f'blocks.{i}.attn.b_to_v.weight'],
|
||||||
|
f"model.diffusion_model.joint_blocks.{i}.context_block.mlp.fc1.bias": f"blocks.{i}.ff_b.0.bias",
|
||||||
|
f"model.diffusion_model.joint_blocks.{i}.context_block.mlp.fc1.weight": f"blocks.{i}.ff_b.0.weight",
|
||||||
|
f"model.diffusion_model.joint_blocks.{i}.context_block.mlp.fc2.bias": f"blocks.{i}.ff_b.2.bias",
|
||||||
|
f"model.diffusion_model.joint_blocks.{i}.context_block.mlp.fc2.weight": f"blocks.{i}.ff_b.2.weight",
|
||||||
|
f"model.diffusion_model.joint_blocks.{i}.x_block.adaLN_modulation.1.bias": f"blocks.{i}.norm1_a.linear.bias",
|
||||||
|
f"model.diffusion_model.joint_blocks.{i}.x_block.adaLN_modulation.1.weight": f"blocks.{i}.norm1_a.linear.weight",
|
||||||
|
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.proj.bias": f"blocks.{i}.attn.a_to_out.bias",
|
||||||
|
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.proj.weight": f"blocks.{i}.attn.a_to_out.weight",
|
||||||
|
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.qkv.bias": [f'blocks.{i}.attn.a_to_q.bias', f'blocks.{i}.attn.a_to_k.bias', f'blocks.{i}.attn.a_to_v.bias'],
|
||||||
|
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.qkv.weight": [f'blocks.{i}.attn.a_to_q.weight', f'blocks.{i}.attn.a_to_k.weight', f'blocks.{i}.attn.a_to_v.weight'],
|
||||||
|
f"model.diffusion_model.joint_blocks.{i}.x_block.mlp.fc1.bias": f"blocks.{i}.ff_a.0.bias",
|
||||||
|
f"model.diffusion_model.joint_blocks.{i}.x_block.mlp.fc1.weight": f"blocks.{i}.ff_a.0.weight",
|
||||||
|
f"model.diffusion_model.joint_blocks.{i}.x_block.mlp.fc2.bias": f"blocks.{i}.ff_a.2.bias",
|
||||||
|
f"model.diffusion_model.joint_blocks.{i}.x_block.mlp.fc2.weight": f"blocks.{i}.ff_a.2.weight",
|
||||||
|
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.ln_q.weight": f"blocks.{i}.attn.norm_q_a.weight",
|
||||||
|
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.ln_k.weight": f"blocks.{i}.attn.norm_k_a.weight",
|
||||||
|
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.ln_q.weight": f"blocks.{i}.attn.norm_q_b.weight",
|
||||||
|
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.ln_k.weight": f"blocks.{i}.attn.norm_k_b.weight",
|
||||||
|
|
||||||
|
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.ln_q.weight": f"blocks.{i}.attn2.norm_q_a.weight",
|
||||||
|
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.ln_k.weight": f"blocks.{i}.attn2.norm_k_a.weight",
|
||||||
|
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.qkv.weight": f"blocks.{i}.attn2.a_to_qkv.weight",
|
||||||
|
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.qkv.bias": f"blocks.{i}.attn2.a_to_qkv.bias",
|
||||||
|
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.proj.weight": f"blocks.{i}.attn2.a_to_out.weight",
|
||||||
|
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.proj.bias": f"blocks.{i}.attn2.a_to_out.bias",
|
||||||
|
})
|
||||||
|
state_dict_ = {}
|
||||||
|
for name in state_dict:
|
||||||
|
if name in rename_dict:
|
||||||
|
param = state_dict[name]
|
||||||
|
if name == "model.diffusion_model.pos_embed":
|
||||||
|
pos_embed_max_size = int(param.shape[1] ** 0.5 + 0.4)
|
||||||
|
param = param.reshape((1, pos_embed_max_size, pos_embed_max_size, param.shape[-1]))
|
||||||
|
if isinstance(rename_dict[name], str):
|
||||||
|
state_dict_[rename_dict[name]] = param
|
||||||
|
else:
|
||||||
|
name_ = rename_dict[name][0].replace(".a_to_q.", ".a_to_qkv.").replace(".b_to_q.", ".b_to_qkv.")
|
||||||
|
state_dict_[name_] = param
|
||||||
|
extra_kwargs = self.infer_architecture(state_dict_)
|
||||||
|
num_layers = extra_kwargs["num_layers"]
|
||||||
|
for name in [
|
||||||
|
f"blocks.{num_layers-1}.norm1_b.linear.weight", f"blocks.{num_layers-1}.norm1_b.linear.bias", "norm_out.linear.weight", "norm_out.linear.bias",
|
||||||
|
]:
|
||||||
|
param = state_dict_[name]
|
||||||
|
dim = param.shape[0] // 2
|
||||||
|
param = torch.concat([param[dim:], param[:dim]], axis=0)
|
||||||
|
state_dict_[name] = param
|
||||||
|
return state_dict_, self.infer_architecture(state_dict_)
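A sketch of what infer_architecture reports for a converted state dict, using an SD3DiT built with the defaults above as a stand-in for a real checkpoint; actual checkpoints may differ.

converter = SD3DiTStateDictConverter()
state_dict = SD3DiT().state_dict()          # stand-in for a converted checkpoint
print(converter.infer_architecture(state_dict))
# {'embed_dim': 1536, 'num_layers': 24, 'use_rms_norm': False,
#  'num_dual_blocks': 0, 'pos_embed_max_size': 192}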
|
1120  diffsynth/models/sd3_text_encoder.py  (Normal file; diff suppressed because it is too large)

81  diffsynth/models/sd3_vae_decoder.py  (Normal file)
@@ -0,0 +1,81 @@
import torch
from .sd_vae_decoder import VAEAttentionBlock, SDVAEDecoderStateDictConverter
from .sd_unet import ResnetBlock, UpSampler
from .tiler import TileWorker


class SD3VAEDecoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.scaling_factor = 1.5305 # Different from SD 1.x
        self.shift_factor = 0.0609 # Different from SD 1.x
        self.conv_in = torch.nn.Conv2d(16, 512, kernel_size=3, padding=1) # Different from SD 1.x

        self.blocks = torch.nn.ModuleList([
            # UNetMidBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            # UpDecoderBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            UpSampler(512),
            # UpDecoderBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            UpSampler(512),
            # UpDecoderBlock2D
            ResnetBlock(512, 256, eps=1e-6),
            ResnetBlock(256, 256, eps=1e-6),
            ResnetBlock(256, 256, eps=1e-6),
            UpSampler(256),
            # UpDecoderBlock2D
            ResnetBlock(256, 128, eps=1e-6),
            ResnetBlock(128, 128, eps=1e-6),
            ResnetBlock(128, 128, eps=1e-6),
        ])

        self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6)
        self.conv_act = torch.nn.SiLU()
        self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)

    def tiled_forward(self, sample, tile_size=64, tile_stride=32):
        hidden_states = TileWorker().tiled_forward(
            lambda x: self.forward(x),
            sample,
            tile_size,
            tile_stride,
            tile_device=sample.device,
            tile_dtype=sample.dtype
        )
        return hidden_states

    def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
        # For VAE Decoder, we do not need to apply the tiler on each layer.
        if tiled:
            return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)

        # 1. pre-process
        hidden_states = sample / self.scaling_factor + self.shift_factor
        hidden_states = self.conv_in(hidden_states)
        time_emb = None
        text_emb = None
        res_stack = None

        # 2. blocks
        for i, block in enumerate(self.blocks):
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)

        # 3. output
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states

    @staticmethod
    def state_dict_converter():
        return SDVAEDecoderStateDictConverter()

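A decoding sketch for the SD3VAEDecoder above; the latent size is an illustrative assumption. Latents are un-scaled with scaling_factor and shift_factor, then upsampled 8x back to RGB.

decoder = SD3VAEDecoder()
latent = torch.randn(1, 16, 64, 64)
with torch.no_grad():
    image = decoder(latent)             # or decoder(latent, tiled=True) for large latents
print(image.shape)                      # torch.Size([1, 3, 512, 512])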
95  diffsynth/models/sd3_vae_encoder.py  (Normal file)
@@ -0,0 +1,95 @@
import torch
from .sd_unet import ResnetBlock, DownSampler
from .sd_vae_encoder import VAEAttentionBlock, SDVAEEncoderStateDictConverter
from .tiler import TileWorker
from einops import rearrange


class SD3VAEEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.scaling_factor = 1.5305 # Different from SD 1.x
        self.shift_factor = 0.0609 # Different from SD 1.x
        self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)

        self.blocks = torch.nn.ModuleList([
            # DownEncoderBlock2D
            ResnetBlock(128, 128, eps=1e-6),
            ResnetBlock(128, 128, eps=1e-6),
            DownSampler(128, padding=0, extra_padding=True),
            # DownEncoderBlock2D
            ResnetBlock(128, 256, eps=1e-6),
            ResnetBlock(256, 256, eps=1e-6),
            DownSampler(256, padding=0, extra_padding=True),
            # DownEncoderBlock2D
            ResnetBlock(256, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            DownSampler(512, padding=0, extra_padding=True),
            # DownEncoderBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            # UNetMidBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
        ])

        self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
        self.conv_act = torch.nn.SiLU()
        self.conv_out = torch.nn.Conv2d(512, 32, kernel_size=3, padding=1)

    def tiled_forward(self, sample, tile_size=64, tile_stride=32):
        hidden_states = TileWorker().tiled_forward(
            lambda x: self.forward(x),
            sample,
            tile_size,
            tile_stride,
            tile_device=sample.device,
            tile_dtype=sample.dtype
        )
        return hidden_states

    def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
        # For the VAE Encoder, we do not need to apply the tiler on each layer.
        if tiled:
            return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)

        # 1. pre-process
        hidden_states = self.conv_in(sample)
        time_emb = None
        text_emb = None
        res_stack = None

        # 2. blocks
        for i, block in enumerate(self.blocks):
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)

        # 3. output
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)
        hidden_states = hidden_states[:, :16]
        hidden_states = (hidden_states - self.shift_factor) * self.scaling_factor

        return hidden_states

    def encode_video(self, sample, batch_size=8):
        B = sample.shape[0]
        hidden_states = []

        for i in range(0, sample.shape[2], batch_size):
            j = min(i + batch_size, sample.shape[2])
            sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W")

            hidden_states_batch = self(sample_batch)
            hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)

            hidden_states.append(hidden_states_batch)

        hidden_states = torch.concat(hidden_states, dim=2)
        return hidden_states

    @staticmethod
    def state_dict_converter():
        return SDVAEEncoderStateDictConverter()

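An encoding sketch for the SD3VAEEncoder above; the video size is an illustrative assumption. encode_video flattens frames into the batch dimension in chunks of batch_size, encodes each chunk, and concatenates the results back along the time axis.

encoder = SD3VAEEncoder()
video = torch.randn(1, 3, 16, 512, 512)     # (B, C, T, H, W)
with torch.no_grad():
    latents = encoder.encode_video(video, batch_size=8)
print(latents.shape)                        # torch.Size([1, 16, 16, 64, 64])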
@@ -97,9 +97,10 @@ class SDControlNet(torch.nn.Module):
|
|||||||
self,
|
self,
|
||||||
sample, timestep, encoder_hidden_states, conditioning,
|
sample, timestep, encoder_hidden_states, conditioning,
|
||||||
tiled=False, tile_size=64, tile_stride=32,
|
tiled=False, tile_size=64, tile_stride=32,
|
||||||
|
**kwargs
|
||||||
):
|
):
|
||||||
# 1. time
|
# 1. time
|
||||||
time_emb = self.time_proj(timestep[None]).to(sample.dtype)
|
time_emb = self.time_proj(timestep).to(sample.dtype)
|
||||||
time_emb = self.time_embedding(time_emb)
|
time_emb = self.time_embedding(time_emb)
|
||||||
time_emb = time_emb.repeat(sample.shape[0], 1)
|
time_emb = time_emb.repeat(sample.shape[0], 1)
|
||||||
|
|
||||||
@@ -134,7 +135,8 @@ class SDControlNet(torch.nn.Module):
|
|||||||
|
|
||||||
return controlnet_res_stack
|
return controlnet_res_stack
|
||||||
|
|
||||||
def state_dict_converter(self):
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
return SDControlNetStateDictConverter()
|
return SDControlNetStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ class SDIpAdapter(torch.nn.Module):
|
|||||||
|
|
||||||
def set_less_adapter(self):
|
def set_less_adapter(self):
|
||||||
# IP-Adapter for SD v1.5 doesn't support this feature.
|
# IP-Adapter for SD v1.5 doesn't support this feature.
|
||||||
self.set_full_adapter(self)
|
self.set_full_adapter()
|
||||||
|
|
||||||
def forward(self, hidden_states, scale=1.0):
|
def forward(self, hidden_states, scale=1.0):
|
||||||
hidden_states = self.image_proj(hidden_states)
|
hidden_states = self.image_proj(hidden_states)
|
||||||
@@ -47,7 +47,8 @@ class SDIpAdapter(torch.nn.Module):
|
|||||||
}
|
}
|
||||||
return ip_kv_dict
|
return ip_kv_dict
|
||||||
|
|
||||||
def state_dict_converter(self):
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
return SDIpAdapterStateDictConverter()
|
return SDIpAdapterStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,60 +0,0 @@
|
|||||||
import torch
|
|
||||||
from .sd_unet import SDUNetStateDictConverter, SDUNet
|
|
||||||
from .sd_text_encoder import SDTextEncoderStateDictConverter, SDTextEncoder
|
|
||||||
|
|
||||||
|
|
||||||
class SDLoRA:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0, device="cuda"):
|
|
||||||
special_keys = {
|
|
||||||
"down.blocks": "down_blocks",
|
|
||||||
"up.blocks": "up_blocks",
|
|
||||||
"mid.block": "mid_block",
|
|
||||||
"proj.in": "proj_in",
|
|
||||||
"proj.out": "proj_out",
|
|
||||||
"transformer.blocks": "transformer_blocks",
|
|
||||||
"to.q": "to_q",
|
|
||||||
"to.k": "to_k",
|
|
||||||
"to.v": "to_v",
|
|
||||||
"to.out": "to_out",
|
|
||||||
}
|
|
||||||
state_dict_ = {}
|
|
||||||
for key in state_dict:
|
|
||||||
if ".lora_up" not in key:
|
|
||||||
continue
|
|
||||||
if not key.startswith(lora_prefix):
|
|
||||||
continue
|
|
||||||
weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
|
|
||||||
weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
|
|
||||||
if len(weight_up.shape) == 4:
|
|
||||||
weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
|
|
||||||
weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
|
|
||||||
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
|
||||||
else:
|
|
||||||
lora_weight = alpha * torch.mm(weight_up, weight_down)
|
|
||||||
target_name = key.split(".")[0].replace("_", ".")[len(lora_prefix):] + ".weight"
|
|
||||||
for special_key in special_keys:
|
|
||||||
target_name = target_name.replace(special_key, special_keys[special_key])
|
|
||||||
state_dict_[target_name] = lora_weight.cpu()
|
|
||||||
return state_dict_
|
|
||||||
|
|
||||||
def add_lora_to_unet(self, unet: SDUNet, state_dict_lora, alpha=1.0, device="cuda"):
|
|
||||||
state_dict_unet = unet.state_dict()
|
|
||||||
state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_unet_", alpha=alpha, device=device)
|
|
||||||
state_dict_lora = SDUNetStateDictConverter().from_diffusers(state_dict_lora)
|
|
||||||
if len(state_dict_lora) > 0:
|
|
||||||
for name in state_dict_lora:
|
|
||||||
state_dict_unet[name] += state_dict_lora[name].to(device=device)
|
|
||||||
unet.load_state_dict(state_dict_unet)
|
|
||||||
|
|
||||||
def add_lora_to_text_encoder(self, text_encoder: SDTextEncoder, state_dict_lora, alpha=1.0, device="cuda"):
|
|
||||||
state_dict_text_encoder = text_encoder.state_dict()
|
|
||||||
state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_te_", alpha=alpha, device=device)
|
|
||||||
state_dict_lora = SDTextEncoderStateDictConverter().from_diffusers(state_dict_lora)
|
|
||||||
if len(state_dict_lora) > 0:
|
|
||||||
for name in state_dict_lora:
|
|
||||||
state_dict_text_encoder[name] += state_dict_lora[name].to(device=device)
|
|
||||||
text_encoder.load_state_dict(state_dict_text_encoder)
|
|
||||||
|
|
||||||
@@ -1,28 +1,20 @@
|
|||||||
from .sd_unet import SDUNet, Attention, GEGLU
|
from .sd_unet import SDUNet, Attention, GEGLU
|
||||||
from .svd_unet import get_timestep_embedding
|
|
||||||
import torch
|
import torch
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
|
|
||||||
class TemporalTransformerBlock(torch.nn.Module):
|
class TemporalTransformerBlock(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32, add_positional_conv=None):
|
def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.add_positional_conv = add_positional_conv
|
|
||||||
|
|
||||||
# 1. Self-Attn
|
# 1. Self-Attn
|
||||||
emb = get_timestep_embedding(torch.arange(max_position_embeddings), dim, True, 0).reshape(1, max_position_embeddings, dim)
|
self.pe1 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
|
||||||
self.pe1 = torch.nn.Parameter(emb)
|
|
||||||
if add_positional_conv:
|
|
||||||
self.positional_conv_1 = torch.nn.Conv1d(dim, dim, kernel_size=3, padding=1, padding_mode="reflect")
|
|
||||||
self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True)
|
self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True)
|
||||||
self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
|
self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
|
||||||
|
|
||||||
# 2. Cross-Attn
|
# 2. Cross-Attn
|
||||||
emb = get_timestep_embedding(torch.arange(max_position_embeddings), dim, True, 0).reshape(1, max_position_embeddings, dim)
|
self.pe2 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
|
||||||
self.pe2 = torch.nn.Parameter(emb)
|
|
||||||
if add_positional_conv:
|
|
||||||
self.positional_conv_2 = torch.nn.Conv1d(dim, dim, kernel_size=3, padding=1, padding_mode="reflect")
|
|
||||||
self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True)
|
self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True)
|
||||||
self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
|
self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
|
||||||
|
|
||||||
@@ -32,47 +24,19 @@ class TemporalTransformerBlock(torch.nn.Module):
|
|||||||
self.ff = torch.nn.Linear(dim * 4, dim)
|
self.ff = torch.nn.Linear(dim * 4, dim)
|
||||||
|
|
||||||
|
|
||||||
def frame_id_to_position_id(self, frame_id, max_id, repeat_length):
|
|
||||||
if frame_id < max_id:
|
|
||||||
position_id = frame_id
|
|
||||||
else:
|
|
||||||
position_id = (frame_id - max_id) % (repeat_length * 2)
|
|
||||||
if position_id < repeat_length:
|
|
||||||
position_id = max_id - 2 - position_id
|
|
||||||
else:
|
|
||||||
position_id = max_id - 2 * repeat_length + position_id
|
|
||||||
return position_id
|
|
||||||
|
|
||||||
|
|
||||||
def positional_ids(self, num_frames):
|
|
||||||
max_id = self.pe1.shape[1]
|
|
||||||
positional_ids = torch.IntTensor([self.frame_id_to_position_id(i, max_id, max_id - 1) for i in range(num_frames)])
|
|
||||||
return positional_ids
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, hidden_states, batch_size=1):
|
def forward(self, hidden_states, batch_size=1):
|
||||||
|
|
||||||
# 1. Self-Attention
|
# 1. Self-Attention
|
||||||
norm_hidden_states = self.norm1(hidden_states)
|
norm_hidden_states = self.norm1(hidden_states)
|
||||||
norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
|
norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
|
||||||
norm_hidden_states = norm_hidden_states + self.pe1[:, self.positional_ids(norm_hidden_states.shape[1])]
|
attn_output = self.attn1(norm_hidden_states + self.pe1[:, :norm_hidden_states.shape[1]])
|
||||||
if self.add_positional_conv:
|
|
||||||
norm_hidden_states = rearrange(norm_hidden_states, "(b h) f c -> (b h) c f", b=batch_size)
|
|
||||||
norm_hidden_states = self.positional_conv_1(norm_hidden_states)
|
|
||||||
norm_hidden_states = rearrange(norm_hidden_states, "(b h) c f -> (b h) f c", b=batch_size)
|
|
||||||
attn_output = self.attn1(norm_hidden_states)
|
|
||||||
attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
|
attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
|
||||||
hidden_states = attn_output + hidden_states
|
hidden_states = attn_output + hidden_states
|
||||||
|
|
||||||
# 2. Cross-Attention
|
# 2. Cross-Attention
|
||||||
norm_hidden_states = self.norm2(hidden_states)
|
norm_hidden_states = self.norm2(hidden_states)
|
||||||
norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
|
norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
|
||||||
norm_hidden_states = norm_hidden_states + self.pe2[:, self.positional_ids(norm_hidden_states.shape[1])]
|
attn_output = self.attn2(norm_hidden_states + self.pe2[:, :norm_hidden_states.shape[1]])
|
||||||
if self.add_positional_conv:
|
|
||||||
norm_hidden_states = rearrange(norm_hidden_states, "(b h) f c -> (b h) c f", b=batch_size)
|
|
||||||
norm_hidden_states = self.positional_conv_2(norm_hidden_states)
|
|
||||||
norm_hidden_states = rearrange(norm_hidden_states, "(b h) c f -> (b h) f c", b=batch_size)
|
|
||||||
attn_output = self.attn2(norm_hidden_states)
|
|
||||||
attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
|
attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
|
||||||
hidden_states = attn_output + hidden_states
|
hidden_states = attn_output + hidden_states
|
||||||
|
|
||||||
@@ -87,7 +51,7 @@ class TemporalTransformerBlock(torch.nn.Module):
|
|||||||
|
|
||||||
class TemporalBlock(torch.nn.Module):
|
class TemporalBlock(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5, add_positional_conv=None):
|
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = num_attention_heads * attention_head_dim
|
inner_dim = num_attention_heads * attention_head_dim
|
||||||
|
|
||||||
@@ -98,9 +62,7 @@ class TemporalBlock(torch.nn.Module):
|
|||||||
TemporalTransformerBlock(
|
TemporalTransformerBlock(
|
||||||
inner_dim,
|
inner_dim,
|
||||||
num_attention_heads,
|
num_attention_heads,
|
||||||
attention_head_dim,
|
attention_head_dim
|
||||||
max_position_embeddings=32 if add_positional_conv is None else add_positional_conv,
|
|
||||||
add_positional_conv=add_positional_conv
|
|
||||||
)
|
)
|
||||||
for d in range(num_layers)
|
for d in range(num_layers)
|
||||||
])
|
])
|
||||||
@@ -130,30 +92,30 @@ class TemporalBlock(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class SDMotionModel(torch.nn.Module):
|
class SDMotionModel(torch.nn.Module):
|
||||||
def __init__(self, add_positional_conv=None):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.motion_modules = torch.nn.ModuleList([
|
self.motion_modules = torch.nn.ModuleList([
|
||||||
TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
|
TemporalBlock(8, 40, 320, eps=1e-6),
|
||||||
TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
|
TemporalBlock(8, 40, 320, eps=1e-6),
|
||||||
TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
|
TemporalBlock(8, 80, 640, eps=1e-6),
|
||||||
TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
|
TemporalBlock(8, 80, 640, eps=1e-6),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||||
TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
|
TemporalBlock(8, 80, 640, eps=1e-6),
|
||||||
TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
|
TemporalBlock(8, 80, 640, eps=1e-6),
|
||||||
TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
|
TemporalBlock(8, 80, 640, eps=1e-6),
|
||||||
TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
|
TemporalBlock(8, 40, 320, eps=1e-6),
|
||||||
TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
|
TemporalBlock(8, 40, 320, eps=1e-6),
|
||||||
TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
|
TemporalBlock(8, 40, 320, eps=1e-6),
|
||||||
])
|
])
|
||||||
self.call_block_id = {
|
self.call_block_id = {
|
||||||
1: 0,
|
1: 0,
|
||||||
@@ -182,7 +144,8 @@ class SDMotionModel(torch.nn.Module):
|
|||||||
def forward(self):
|
def forward(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def state_dict_converter(self):
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
return SDMotionModelStateDictConverter()
|
return SDMotionModelStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
@@ -190,42 +153,7 @@ class SDMotionModelStateDictConverter:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def frame_id_to_position_id(self, frame_id, max_id, repeat_length):
|
def from_diffusers(self, state_dict):
|
||||||
if frame_id < max_id:
|
|
||||||
position_id = frame_id
|
|
||||||
else:
|
|
||||||
position_id = (frame_id - max_id) % (repeat_length * 2)
|
|
||||||
if position_id < repeat_length:
|
|
||||||
position_id = max_id - 2 - position_id
|
|
||||||
else:
|
|
||||||
position_id = max_id - 2 * repeat_length + position_id
|
|
||||||
return position_id
|
|
||||||
|
|
||||||
def process_positional_conv_parameters(self, state_dict, add_positional_conv):
|
|
||||||
ids = [self.frame_id_to_position_id(i, 16, 15) for i in range(add_positional_conv)]
|
|
||||||
for i in range(21):
|
|
||||||
# Extend positional embedding
|
|
||||||
name = f"motion_modules.{i}.transformer_blocks.0.pe1"
|
|
||||||
state_dict[name] = state_dict[name][:, ids]
|
|
||||||
name = f"motion_modules.{i}.transformer_blocks.0.pe2"
|
|
||||||
state_dict[name] = state_dict[name][:, ids]
|
|
||||||
# add post convolution
|
|
||||||
dim = state_dict[f"motion_modules.{i}.transformer_blocks.0.pe1"].shape[-1]
|
|
||||||
name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_1.bias"
|
|
||||||
state_dict[name] = torch.zeros((dim,))
|
|
||||||
name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_2.bias"
|
|
||||||
state_dict[name] = torch.zeros((dim,))
|
|
||||||
name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_1.weight"
|
|
||||||
param = torch.zeros((dim, dim, 3))
|
|
||||||
param[:, :, 1] = torch.eye(dim, dim)
|
|
||||||
state_dict[name] = param
|
|
||||||
name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_2.weight"
|
|
||||||
param = torch.zeros((dim, dim, 3))
|
|
||||||
param[:, :, 1] = torch.eye(dim, dim)
|
|
||||||
state_dict[name] = param
|
|
||||||
return state_dict
|
|
||||||
|
|
||||||
def from_diffusers(self, state_dict, add_positional_conv=None):
|
|
||||||
rename_dict = {
|
rename_dict = {
|
||||||
"norm": "norm",
|
"norm": "norm",
|
||||||
"proj_in": "proj_in",
|
"proj_in": "proj_in",
|
||||||
@@ -265,9 +193,7 @@ class SDMotionModelStateDictConverter:
|
|||||||
else:
|
else:
|
||||||
rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
|
rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
|
||||||
state_dict_[rename] = state_dict[name]
|
state_dict_[rename] = state_dict[name]
|
||||||
if add_positional_conv is not None:
|
|
||||||
state_dict_ = self.process_positional_conv_parameters(state_dict_, add_positional_conv)
|
|
||||||
return state_dict_
|
return state_dict_
|
||||||
|
|
||||||
def from_civitai(self, state_dict, add_positional_conv=None):
|
def from_civitai(self, state_dict):
|
||||||
return self.from_diffusers(state_dict, add_positional_conv=add_positional_conv)
|
return self.from_diffusers(state_dict)
|
||||||
|
|||||||
@@ -1,115 +0,0 @@
|
|||||||
from .attention import Attention
|
|
||||||
from .svd_unet import get_timestep_embedding
|
|
||||||
import torch
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ExVideoMotionBlock(torch.nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, num_attention_heads, attention_head_dim, in_channels, max_position_embeddings=16, num_layers=1, add_positional_conv=None):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
emb = get_timestep_embedding(torch.arange(max_position_embeddings), in_channels, True, 0).reshape(max_position_embeddings, in_channels, 1, 1)
|
|
||||||
self.positional_embedding = torch.nn.Parameter(emb)
|
|
||||||
self.positional_conv = torch.nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) if add_positional_conv is not None else None
|
|
||||||
self.norms = torch.nn.ModuleList([torch.nn.LayerNorm(in_channels) for _ in range(num_layers)])
|
|
||||||
self.attns = torch.nn.ModuleList([Attention(q_dim=in_channels, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True) for _ in range(num_layers)])
|
|
||||||
|
|
||||||
def frame_id_to_position_id(self, frame_id, max_id, repeat_length):
|
|
||||||
if frame_id < max_id:
|
|
||||||
position_id = frame_id
|
|
||||||
else:
|
|
||||||
position_id = (frame_id - max_id) % (repeat_length * 2)
|
|
||||||
if position_id < repeat_length:
|
|
||||||
position_id = max_id - 2 - position_id
|
|
||||||
else:
|
|
||||||
position_id = max_id - 2 * repeat_length + position_id
|
|
||||||
return position_id
|
|
||||||
|
|
||||||
def positional_ids(self, num_frames):
|
|
||||||
max_id = self.positional_embedding.shape[0]
|
|
||||||
positional_ids = torch.IntTensor([self.frame_id_to_position_id(i, max_id, max_id - 1) for i in range(num_frames)])
|
|
||||||
return positional_ids
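The mapping above lets a fixed-size positional table cover arbitrarily long videos; a worked example, derived directly from frame_id_to_position_id, follows.

# With the default max_position_embeddings = 16 (so max_id = 16, repeat_length = 15):
#   frame_id    : 0 1 ... 15 | 16 17 ... 30 | 31 32 ... 45 | 46 ...
#   position_id : 0 1 ... 15 | 14 13 ...  0 |  1  2 ... 15 | 14 ...
# Long videos therefore sweep the positional embeddings back and forth
# (ping-pong) rather than repeating them periodically or clamping at the end.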
|
|
||||||
|
|
||||||
def forward(self, hidden_states, time_emb, text_emb, res_stack, batch_size=1, **kwargs):
|
|
||||||
batch, inner_dim, height, width = hidden_states.shape
|
|
||||||
residual = hidden_states
|
|
||||||
|
|
||||||
pos_emb = self.positional_ids(batch // batch_size)
|
|
||||||
pos_emb = self.positional_embedding[pos_emb]
|
|
||||||
pos_emb = pos_emb.repeat(batch_size)
|
|
||||||
hidden_states = hidden_states + pos_emb
|
|
||||||
if self.positional_conv is not None:
|
|
||||||
hidden_states = rearrange(hidden_states, "(B T) C H W -> B C T H W", B=batch_size)
|
|
||||||
hidden_states = self.positional_conv(hidden_states)
|
|
||||||
hidden_states = rearrange(hidden_states, "B C T H W -> (B H W) T C")
|
|
||||||
else:
|
|
||||||
hidden_states = rearrange(hidden_states, "(B T) C H W -> (B H W) T C", B=batch_size)
|
|
||||||
|
|
||||||
for norm, attn in zip(self.norms, self.attns):
|
|
||||||
norm_hidden_states = norm(hidden_states)
|
|
||||||
attn_output = attn(norm_hidden_states)
|
|
||||||
hidden_states = hidden_states + attn_output
|
|
||||||
|
|
||||||
hidden_states = rearrange(hidden_states, "(B H W) T C -> (B T) C H W", B=batch_size, H=height, W=width)
|
|
||||||
hidden_states = hidden_states + residual
|
|
||||||
return hidden_states, time_emb, text_emb, res_stack
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ExVideoMotionModel(torch.nn.Module):
|
|
||||||
def __init__(self, num_layers=2):
|
|
||||||
super().__init__()
|
|
||||||
self.motion_modules = torch.nn.ModuleList([
|
|
||||||
ExVideoMotionBlock(8, 40, 320, num_layers=num_layers),
|
|
||||||
ExVideoMotionBlock(8, 40, 320, num_layers=num_layers),
|
|
||||||
ExVideoMotionBlock(8, 80, 640, num_layers=num_layers),
|
|
||||||
ExVideoMotionBlock(8, 80, 640, num_layers=num_layers),
|
|
||||||
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
|
|
||||||
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
|
|
||||||
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
|
|
||||||
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
|
|
||||||
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
|
|
||||||
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
|
|
||||||
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
|
|
||||||
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
|
|
||||||
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
|
|
||||||
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
|
|
||||||
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
|
|
||||||
ExVideoMotionBlock(8, 80, 640, num_layers=num_layers),
|
|
||||||
ExVideoMotionBlock(8, 80, 640, num_layers=num_layers),
|
|
||||||
ExVideoMotionBlock(8, 80, 640, num_layers=num_layers),
|
|
||||||
ExVideoMotionBlock(8, 40, 320, num_layers=num_layers),
|
|
||||||
ExVideoMotionBlock(8, 40, 320, num_layers=num_layers),
|
|
||||||
ExVideoMotionBlock(8, 40, 320, num_layers=num_layers),
|
|
||||||
])
|
|
||||||
self.call_block_id = {
|
|
||||||
1: 0,
|
|
||||||
4: 1,
|
|
||||||
9: 2,
|
|
||||||
12: 3,
|
|
||||||
17: 4,
|
|
||||||
20: 5,
|
|
||||||
24: 6,
|
|
||||||
26: 7,
|
|
||||||
29: 8,
|
|
||||||
32: 9,
|
|
||||||
34: 10,
|
|
||||||
36: 11,
|
|
||||||
40: 12,
|
|
||||||
43: 13,
|
|
||||||
46: 14,
|
|
||||||
50: 15,
|
|
||||||
53: 16,
|
|
||||||
56: 17,
|
|
||||||
60: 18,
|
|
||||||
63: 19,
|
|
||||||
66: 20
|
|
||||||
}
|
|
||||||
|
|
||||||
def forward(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def state_dict_converter(self):
|
|
||||||
pass
|
|
||||||
@@ -71,7 +71,8 @@ class SDTextEncoder(torch.nn.Module):
         embeds = self.final_layer_norm(embeds)
         return embeds

-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SDTextEncoderStateDictConverter()

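The same `def state_dict_converter(self)` → `@staticmethod` change recurs throughout this diff. Its practical effect is that the converter can be obtained from the class itself, without first instantiating a potentially very large model. A tiny illustration (not code from this repository):

class Example:
    @staticmethod
    def state_dict_converter():
        return "converter"

print(Example.state_dict_converter())   # works without an Example() instance
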
@@ -323,7 +323,7 @@ class SDUNet(torch.nn.Module):

     def forward(self, sample, timestep, encoder_hidden_states, **kwargs):
         # 1. time
-        time_emb = self.time_proj(timestep[None]).to(sample.dtype)
+        time_emb = self.time_proj(timestep).to(sample.dtype)
         time_emb = self.time_embedding(time_emb)

         # 2. pre-process
@@ -342,7 +342,8 @@ class SDUNet(torch.nn.Module):

         return hidden_states

-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SDUNetStateDictConverter()

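Dropping the `[None]` shifts responsibility for batching to the caller: the old code wrapped a 0-d timestep into a 1-element tensor, while the updated forward expects `timestep` to arrive already batched. A quick, purely illustrative check of the shapes involved:

import torch

t_scalar = torch.tensor(999)
print(t_scalar.shape, t_scalar[None].shape)   # torch.Size([]) torch.Size([1])
t_batched = torch.tensor([999])               # what the updated forward now expects from the caller
print(t_batched.shape)                        # torch.Size([1])
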
@@ -90,6 +90,8 @@ class SDVAEDecoder(torch.nn.Module):
         return hidden_states

     def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
+        original_dtype = sample.dtype
+        sample = sample.to(dtype=next(iter(self.parameters())).dtype)
         # For VAE Decoder, we do not need to apply the tiler on each layer.
         if tiled:
             return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
@@ -110,10 +112,12 @@ class SDVAEDecoder(torch.nn.Module):
         hidden_states = self.conv_norm_out(hidden_states)
         hidden_states = self.conv_act(hidden_states)
         hidden_states = self.conv_out(hidden_states)
+        hidden_states = hidden_states.to(original_dtype)

         return hidden_states

-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SDVAEDecoderStateDictConverter()

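The added lines implement a dtype round-trip: the input is cast to whatever dtype the VAE's parameters are stored in, and the output is cast back to the caller's original dtype. A minimal sketch of the same pattern with a stand-in module (float64 stands in for fp16/bf16 so the snippet runs on any device):

import torch

layer = torch.nn.Linear(4, 4).double()                      # module weights in a non-default dtype
sample = torch.randn(1, 4)                                  # caller passes float32
original_dtype = sample.dtype
out = layer(sample.to(dtype=next(iter(layer.parameters())).dtype))
out = out.to(original_dtype)                                # hand the caller's dtype back
print(out.dtype)                                            # torch.float32
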
@@ -50,6 +50,8 @@ class SDVAEEncoder(torch.nn.Module):
         return hidden_states

     def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
+        original_dtype = sample.dtype
+        sample = sample.to(dtype=next(iter(self.parameters())).dtype)
         # For VAE Decoder, we do not need to apply the tiler on each layer.
         if tiled:
             return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
@@ -71,6 +73,7 @@ class SDVAEEncoder(torch.nn.Module):
         hidden_states = self.quant_conv(hidden_states)
         hidden_states = hidden_states[:, :4]
         hidden_states *= self.scaling_factor
+        hidden_states = hidden_states.to(original_dtype)

         return hidden_states

@@ -91,7 +94,8 @@ class SDVAEEncoder(torch.nn.Module):
         hidden_states = torch.concat(hidden_states, dim=2)
         return hidden_states

-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SDVAEEncoderStateDictConverter()

diffsynth/models/sdxl_controlnet.py (new file, 318 additions)
@@ -0,0 +1,318 @@
import torch
from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler
from .sdxl_unet import SDXLUNet
from .tiler import TileWorker
from .sd_controlnet import ControlNetConditioningLayer
from collections import OrderedDict


class QuickGELU(torch.nn.Module):

    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(torch.nn.Module):

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = torch.nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = torch.nn.LayerNorm(d_model)
        self.mlp = torch.nn.Sequential(OrderedDict([
            ("c_fc", torch.nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", torch.nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = torch.nn.LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

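Note that torch.nn.MultiheadAttention defaults to batch_first=False, i.e. inputs of shape (sequence, batch, embed_dim); since ResidualAttentionBlock does not override that, its inputs are interpreted in that layout. A small check of the default behaviour (illustrative sizes):

import torch

d_model, n_head, seq_len, batch = 320, 8, 2, 4
mha = torch.nn.MultiheadAttention(d_model, n_head)
x = torch.randn(seq_len, batch, d_model)
out, _ = mha(x, x, x, need_weights=False)
print(out.shape)   # torch.Size([2, 4, 320])
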
class SDXLControlNetUnion(torch.nn.Module):
    def __init__(self, global_pool=False):
        super().__init__()
        self.time_proj = Timesteps(320)
        self.time_embedding = torch.nn.Sequential(
            torch.nn.Linear(320, 1280),
            torch.nn.SiLU(),
            torch.nn.Linear(1280, 1280)
        )
        self.add_time_proj = Timesteps(256)
        self.add_time_embedding = torch.nn.Sequential(
            torch.nn.Linear(2816, 1280),
            torch.nn.SiLU(),
            torch.nn.Linear(1280, 1280)
        )
        self.control_type_proj = Timesteps(256)
        self.control_type_embedding = torch.nn.Sequential(
            torch.nn.Linear(256 * 8, 1280),
            torch.nn.SiLU(),
            torch.nn.Linear(1280, 1280)
        )
        self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)

        self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320))
        self.controlnet_transformer = ResidualAttentionBlock(320, 8)
        self.task_embedding = torch.nn.Parameter(torch.randn(8, 320))
        self.spatial_ch_projs = torch.nn.Linear(320, 320)

        self.blocks = torch.nn.ModuleList([
            # DownBlock2D
            ResnetBlock(320, 320, 1280),
            PushBlock(),
            ResnetBlock(320, 320, 1280),
            PushBlock(),
            DownSampler(320),
            PushBlock(),
            # CrossAttnDownBlock2D
            ResnetBlock(320, 640, 1280),
            AttentionBlock(10, 64, 640, 2, 2048),
            PushBlock(),
            ResnetBlock(640, 640, 1280),
            AttentionBlock(10, 64, 640, 2, 2048),
            PushBlock(),
            DownSampler(640),
            PushBlock(),
            # CrossAttnDownBlock2D
            ResnetBlock(640, 1280, 1280),
            AttentionBlock(20, 64, 1280, 10, 2048),
            PushBlock(),
            ResnetBlock(1280, 1280, 1280),
            AttentionBlock(20, 64, 1280, 10, 2048),
            PushBlock(),
            # UNetMidBlock2DCrossAttn
            ResnetBlock(1280, 1280, 1280),
            AttentionBlock(20, 64, 1280, 10, 2048),
            ResnetBlock(1280, 1280, 1280),
            PushBlock()
        ])

        self.controlnet_blocks = torch.nn.ModuleList([
            torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
            torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
            torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
            torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
            torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
            torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
            torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
        ])

        self.global_pool = global_pool

        # 0 -- openpose
        # 1 -- depth
        # 2 -- hed/pidi/scribble/ted
        # 3 -- canny/lineart/anime_lineart/mlsd
        # 4 -- normal
        # 5 -- segment
        # 6 -- tile
        # 7 -- repaint
        self.task_id = {
            "openpose": 0,
            "depth": 1,
            "softedge": 2,
            "canny": 3,
            "lineart": 3,
            "lineart_anime": 3,
            "tile": 6,
            "inpaint": 7
        }

    def fuse_condition_to_input(self, hidden_states, task_id, conditioning):
        controlnet_cond = self.controlnet_conv_in(conditioning)
        feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
        feat_seq = feat_seq + self.task_embedding[task_id]
        x = torch.stack([feat_seq, torch.mean(hidden_states, dim=(2, 3))], dim=1)
        x = self.controlnet_transformer(x)

        alpha = self.spatial_ch_projs(x[:,0]).unsqueeze(-1).unsqueeze(-1)
        controlnet_cond_fuser = controlnet_cond + alpha

        hidden_states = hidden_states + controlnet_cond_fuser
        return hidden_states

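    # Summary of the fusion above (descriptive comment, not part of the diff): the conditioning
    # image is encoded to latent resolution, both the conditioning features and the current
    # latent features are mean-pooled to per-image vectors, the learned embedding of the active
    # control type is added, the small ResidualAttentionBlock refines the result, and a projected
    # per-channel offset is added back onto the conditioning features before they are summed
    # into the UNet input.
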
    def forward(
        self,
        sample, timestep, encoder_hidden_states,
        conditioning, processor_id, add_time_id, add_text_embeds,
        tiled=False, tile_size=64, tile_stride=32,
        unet:SDXLUNet=None,
        **kwargs
    ):
        task_id = self.task_id[processor_id]

        # 1. time
        t_emb = self.time_proj(timestep).to(sample.dtype)
        t_emb = self.time_embedding(t_emb)

        time_embeds = self.add_time_proj(add_time_id)
        time_embeds = time_embeds.reshape((add_text_embeds.shape[0], -1))
        add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1)
        add_embeds = add_embeds.to(sample.dtype)
        if unet is not None and unet.is_kolors:
            add_embeds = unet.add_time_embedding(add_embeds)
        else:
            add_embeds = self.add_time_embedding(add_embeds)

        control_type = torch.zeros((sample.shape[0], 8), dtype=sample.dtype, device=sample.device)
        control_type[:, task_id] = 1
        control_embeds = self.control_type_proj(control_type.flatten())
        control_embeds = control_embeds.reshape((sample.shape[0], -1))
        control_embeds = control_embeds.to(sample.dtype)
        control_embeds = self.control_type_embedding(control_embeds)
        time_emb = t_emb + add_embeds + control_embeds

        # 2. pre-process
        height, width = sample.shape[2], sample.shape[3]
        hidden_states = self.conv_in(sample)
        hidden_states = self.fuse_condition_to_input(hidden_states, task_id, conditioning)
        text_emb = encoder_hidden_states
        if unet is not None and unet.is_kolors:
            text_emb = unet.text_intermediate_proj(text_emb)
        res_stack = [hidden_states]

        # 3. blocks
        for i, block in enumerate(self.blocks):
            if tiled and not isinstance(block, PushBlock):
                _, _, inter_height, _ = hidden_states.shape
                resize_scale = inter_height / height
                hidden_states = TileWorker().tiled_forward(
                    lambda x: block(x, time_emb, text_emb, res_stack)[0],
                    hidden_states,
                    int(tile_size * resize_scale),
                    int(tile_stride * resize_scale),
                    tile_device=hidden_states.device,
                    tile_dtype=hidden_states.dtype
                )
            else:
                hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack)

        # 4. ControlNet blocks
        controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]

        # pool
        if self.global_pool:
            controlnet_res_stack = [res.mean(dim=(2, 3), keepdim=True) for res in controlnet_res_stack]

        return controlnet_res_stack

    @staticmethod
    def state_dict_converter():
        return SDXLControlNetUnionStateDictConverter()

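A note on the embedding sizes used above: the 2816-dim input of add_time_embedding comes from concatenating the 1280-dim pooled text embedding with six SDXL size/crop ids, each projected to 256 dims (6 * 256 + 1280 = 2816), and the control-type branch analogously flattens an 8-way one-hot into 8 * 256 = 2048 features. A hedged, shape-only sketch (the example id values follow the usual SDXL convention and are not taken from this diff):

import torch

batch = 1
add_text_embeds = torch.randn(batch, 1280)                     # pooled text embedding
add_time_id = torch.tensor([[1024, 1024, 0, 0, 1024, 1024]])   # (orig_h, orig_w, crop_top, crop_left, target_h, target_w) -- example values
time_embeds = torch.randn(batch, add_time_id.shape[1] * 256)   # stand-in for the add_time_proj output: 6 ids * 256 dims
add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1)
print(add_embeds.shape)                                        # torch.Size([1, 2816])

control_type = torch.zeros(batch, 8)
control_type[:, 3] = 1                                         # e.g. the "canny" task id
print(control_type.shape[1] * 256)                             # 2048 -> input size of control_type_embedding
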
class SDXLControlNetUnionStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        # architecture
        block_types = [
            "ResnetBlock", "PushBlock", "ResnetBlock", "PushBlock", "DownSampler", "PushBlock",
            "ResnetBlock", "AttentionBlock", "PushBlock", "ResnetBlock", "AttentionBlock", "PushBlock", "DownSampler", "PushBlock",
            "ResnetBlock", "AttentionBlock", "PushBlock", "ResnetBlock", "AttentionBlock", "PushBlock",
            "ResnetBlock", "AttentionBlock", "ResnetBlock", "PushBlock"
        ]

        # controlnet_rename_dict
        controlnet_rename_dict = {
            "controlnet_cond_embedding.conv_in.weight": "controlnet_conv_in.blocks.0.weight",
            "controlnet_cond_embedding.conv_in.bias": "controlnet_conv_in.blocks.0.bias",
            "controlnet_cond_embedding.blocks.0.weight": "controlnet_conv_in.blocks.2.weight",
            "controlnet_cond_embedding.blocks.0.bias": "controlnet_conv_in.blocks.2.bias",
            "controlnet_cond_embedding.blocks.1.weight": "controlnet_conv_in.blocks.4.weight",
            "controlnet_cond_embedding.blocks.1.bias": "controlnet_conv_in.blocks.4.bias",
            "controlnet_cond_embedding.blocks.2.weight": "controlnet_conv_in.blocks.6.weight",
            "controlnet_cond_embedding.blocks.2.bias": "controlnet_conv_in.blocks.6.bias",
            "controlnet_cond_embedding.blocks.3.weight": "controlnet_conv_in.blocks.8.weight",
            "controlnet_cond_embedding.blocks.3.bias": "controlnet_conv_in.blocks.8.bias",
            "controlnet_cond_embedding.blocks.4.weight": "controlnet_conv_in.blocks.10.weight",
            "controlnet_cond_embedding.blocks.4.bias": "controlnet_conv_in.blocks.10.bias",
            "controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight",
            "controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias",
            "controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight",
            "controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
            "control_add_embedding.linear_1.weight": "control_type_embedding.0.weight",
            "control_add_embedding.linear_1.bias": "control_type_embedding.0.bias",
            "control_add_embedding.linear_2.weight": "control_type_embedding.2.weight",
            "control_add_embedding.linear_2.bias": "control_type_embedding.2.bias",
        }

        # Rename each parameter
        name_list = sorted([name for name in state_dict])
        rename_dict = {}
        block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
        last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
        for name in name_list:
            names = name.split(".")
            if names[0] in ["conv_in", "conv_norm_out", "conv_out", "task_embedding", "spatial_ch_projs"]:
                pass
            elif name in controlnet_rename_dict:
                names = controlnet_rename_dict[name].split(".")
            elif names[0] == "controlnet_down_blocks":
                names[0] = "controlnet_blocks"
            elif names[0] == "controlnet_mid_block":
                names = ["controlnet_blocks", "9", names[-1]]
            elif names[0] in ["time_embedding", "add_embedding"]:
                if names[0] == "add_embedding":
                    names[0] = "add_time_embedding"
                names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
            elif names[0] == "control_add_embedding":
                names[0] = "control_type_embedding"
            elif names[0] == "transformer_layes":
                names[0] = "controlnet_transformer"
                names.pop(1)
            elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
                if names[0] == "mid_block":
                    names.insert(1, "0")
                block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
                block_type_with_id = ".".join(names[:4])
                if block_type_with_id != last_block_type_with_id[block_type]:
                    block_id[block_type] += 1
                    last_block_type_with_id[block_type] = block_type_with_id
                while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
                    block_id[block_type] += 1
                block_type_with_id = ".".join(names[:4])
                names = ["blocks", str(block_id[block_type])] + names[4:]
                if "ff" in names:
                    ff_index = names.index("ff")
                    component = ".".join(names[ff_index:ff_index+3])
                    component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
                    names = names[:ff_index] + [component] + names[ff_index+3:]
                if "to_out" in names:
                    names.pop(names.index("to_out") + 1)
            else:
                print(name, state_dict[name].shape)
                # raise ValueError(f"Unknown parameters: {name}")
            rename_dict[name] = ".".join(names)

        # Convert state_dict
        state_dict_ = {}
        for name, param in state_dict.items():
            if name not in rename_dict:
                continue
            if ".proj_in." in name or ".proj_out." in name:
                param = param.squeeze()
            state_dict_[rename_dict[name]] = param
        return state_dict_

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)

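A hedged sketch of how such a converter is typically used in isolation: load a Diffusers-format checkpoint, rename its keys, and hand the result to load_state_dict. The file path below is hypothetical and safetensors is assumed to be installed; DiffSynth's own loading path goes through its ModelManager rather than calling the converter directly.

from safetensors.torch import load_file

model = SDXLControlNetUnion()
state_dict = load_file("diffusion_pytorch_model.safetensors")              # hypothetical Diffusers checkpoint
state_dict = SDXLControlNetUnionStateDictConverter().from_diffusers(state_dict)
model.load_state_dict(state_dict, strict=False)                            # strict=False: only converter-produced keys are required to match
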
@@ -96,7 +96,8 @@ class SDXLIpAdapter(torch.nn.Module):
         }
         return ip_kv_dict

-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SDXLIpAdapterStateDictConverter()

@@ -49,7 +49,8 @@ class SDXLMotionModel(torch.nn.Module):
     def forward(self):
         pass

-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SDMotionModelStateDictConverter()

@@ -36,7 +36,8 @@ class SDXLTextEncoder(torch.nn.Module):
                 break
         return embeds

-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SDXLTextEncoderStateDictConverter()
@@ -80,7 +81,8 @@ class SDXLTextEncoder2(torch.nn.Module):
         pooled_embeds = self.text_projection(pooled_embeds)
         return pooled_embeds, hidden_states

-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SDXLTextEncoder2StateDictConverter()

Some files were not shown because too many files have changed in this diff.