mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
Compare commits
410 Commits
ExVideo
...
wan-lora-f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fdeae36cfb | ||
|
|
a2a720267e | ||
|
|
803ddcccc7 | ||
|
|
4cd51fecf2 | ||
|
|
3b0211a547 | ||
|
|
e88328d152 | ||
|
|
52896fa8dd | ||
|
|
c7035ad911 | ||
|
|
070811e517 | ||
|
|
7e010d88a5 | ||
|
|
4e43d4d461 | ||
|
|
d7efe7e539 | ||
|
|
633f789c47 | ||
|
|
88607f404e | ||
|
|
6d405b669c | ||
|
|
d0fed6ba72 | ||
|
|
64eaa0d76a | ||
|
|
3dc28f428f | ||
|
|
3c8a3fe2e1 | ||
|
|
e28c246bcc | ||
|
|
04d03500ff | ||
|
|
54081bdcbb | ||
|
|
d8b250607a | ||
|
|
1e58e6ef82 | ||
|
|
42cb7d96bb | ||
|
|
39890f023f | ||
|
|
e425753f79 | ||
|
|
ca40074d72 | ||
|
|
1fd3d67379 | ||
|
|
3acd9c73be | ||
|
|
32422b49ee | ||
|
|
5c4d3185fb | ||
|
|
762bcbee58 | ||
|
|
6b411ada16 | ||
|
|
a25bd74d8b | ||
|
|
fb5fc09bad | ||
|
|
3fdba19e02 | ||
|
|
4bec2983a9 | ||
|
|
03ea27893f | ||
|
|
718b45f2af | ||
|
|
63a79eeb2a | ||
|
|
e757013a14 | ||
|
|
a05f647633 | ||
|
|
7604be0301 | ||
|
|
945b43492e | ||
|
|
b548d7caf2 | ||
|
|
6e316fd825 | ||
|
|
84fb61aaaf | ||
|
|
50a9946b57 | ||
|
|
384d1a8198 | ||
|
|
a58c193d0c | ||
|
|
34a5ef8c15 | ||
|
|
41e3e4e157 | ||
|
|
e576d71908 | ||
|
|
906aadbf1b | ||
|
|
bf0bf2d5ba | ||
|
|
fe0fff1399 | ||
|
|
50fceb84d2 | ||
|
|
100da41034 | ||
|
|
c382237833 | ||
|
|
98ac191750 | ||
|
|
2f73dbe7a3 | ||
|
|
a66203a391 | ||
|
|
fab61f614b | ||
|
|
6b67a11ad6 | ||
|
|
91f77d268c | ||
|
|
eb4d5187d8 | ||
|
|
ee4b02247c | ||
|
|
da8e1fe7e4 | ||
|
|
3db824c281 | ||
|
|
df2ecafd3f | ||
|
|
217652d28e | ||
|
|
f64c766dcd | ||
|
|
076fd85556 | ||
|
|
c7912ed827 | ||
|
|
e63f9d6993 | ||
|
|
d80ef3a677 | ||
|
|
852c3d831f | ||
|
|
ceb92ee7aa | ||
|
|
3a75026176 | ||
|
|
6a92b08244 | ||
|
|
38bc785ea9 | ||
|
|
a466fdca8f | ||
|
|
f9f49e3c78 | ||
|
|
61a30673c2 | ||
|
|
a48822ec00 | ||
|
|
b6c3d2b74a | ||
|
|
5006c2176c | ||
|
|
d3d3556ff6 | ||
|
|
6fa8dbe077 | ||
|
|
a57749ef60 | ||
|
|
b5c1d33e58 | ||
|
|
34a9f82865 | ||
|
|
18dc6cb962 | ||
|
|
490d420d82 | ||
|
|
0aca943a39 | ||
|
|
c760208614 | ||
|
|
fad7aea58a | ||
|
|
b42eb1444c | ||
|
|
25a247dd3f | ||
|
|
7792017a02 | ||
|
|
0219e8d2f3 | ||
|
|
1d309a14a3 | ||
|
|
7df73ceaaf | ||
|
|
0dbb3d333f | ||
|
|
1419bec53d | ||
|
|
cf12723c89 | ||
|
|
4268f5466b | ||
|
|
b9f5a00d98 | ||
|
|
7d44dc99fb | ||
|
|
b20de1b44d | ||
|
|
366ee0f542 | ||
|
|
bed770248b | ||
|
|
020560d2b5 | ||
|
|
af7d305f00 | ||
|
|
427232cbc0 | ||
|
|
2899283c01 | ||
|
|
9cff769fbd | ||
|
|
23e33273f1 | ||
|
|
f191353cf4 | ||
|
|
66a094fc84 | ||
|
|
3681adc5ac | ||
|
|
4449faaa01 | ||
|
|
991ba162bd | ||
|
|
77d0f4d297 | ||
|
|
a834371d50 | ||
|
|
acda7d891a | ||
|
|
7434ec8fcd | ||
|
|
0699212665 | ||
|
|
f47de78b59 | ||
|
|
5fdc8039ec | ||
|
|
46d4616e23 | ||
|
|
2e597335be | ||
|
|
d346300162 | ||
|
|
1df7387f1b | ||
|
|
75d62a02d1 | ||
|
|
9db26879df | ||
|
|
7beac7972e | ||
|
|
72cac18d3e | ||
|
|
9f8112ec34 | ||
|
|
d9fad821b2 | ||
|
|
c0889c2564 | ||
|
|
913591c13e | ||
|
|
aaf13d6e4a | ||
|
|
90c07fec61 | ||
|
|
cc6c3c0807 | ||
|
|
ce2476ab9b | ||
|
|
9e70c49317 | ||
|
|
bf1c99645b | ||
|
|
c2478ff284 | ||
|
|
a60bf3cd5f | ||
|
|
34231907d0 | ||
|
|
840dab58cd | ||
|
|
d5ceca0663 | ||
|
|
8cf3422688 | ||
|
|
6f743fc4b6 | ||
|
|
991b133bff | ||
|
|
3b010043de | ||
|
|
088ea29e6e | ||
|
|
b8b135ff73 | ||
|
|
2872fdaf48 | ||
|
|
9853f83454 | ||
|
|
fd6e661203 | ||
|
|
c087f68d74 | ||
|
|
b6620f3dde | ||
|
|
3228c3e085 | ||
|
|
6cc5fd6d1e | ||
|
|
4f6d5e7074 | ||
|
|
6a999e1127 | ||
|
|
e3d89cec0c | ||
|
|
1b6e96a820 | ||
|
|
e38ccf4c2f | ||
|
|
010c801081 | ||
|
|
edc9272e55 | ||
|
|
405ca6be33 | ||
|
|
c06ea2271a | ||
|
|
0692e8b1e1 | ||
|
|
aa23356420 | ||
|
|
00a610e5ad | ||
|
|
2e39dcc0d3 | ||
|
|
03d3a26f6f | ||
|
|
309fa9cf51 | ||
|
|
65aab8adea | ||
|
|
3d48b287a3 | ||
|
|
29cebf0bec | ||
|
|
95a0f0bedc | ||
|
|
77e0617861 | ||
|
|
469a0405a1 | ||
|
|
46f191ffe7 | ||
|
|
ec7ac20def | ||
|
|
3f410b0b77 | ||
|
|
8e06cac0df | ||
|
|
e5099f4e74 | ||
|
|
447adef472 | ||
|
|
a849b05e5a | ||
|
|
b048f1b1de | ||
|
|
f7848f9560 | ||
|
|
236b56d285 | ||
|
|
42a717054a | ||
|
|
263166768e | ||
|
|
7a45b7efa7 | ||
|
|
54ed532e3e | ||
|
|
05e2028c5d | ||
|
|
79249063b8 | ||
|
|
31ebec7a72 | ||
|
|
919d399fdb | ||
|
|
32a7a1487d | ||
|
|
8c2671ce40 | ||
|
|
5d1005a7c8 | ||
|
|
b84f906964 | ||
|
|
7c0520d029 | ||
|
|
9d09121fbc | ||
|
|
7f2a5424d4 | ||
|
|
00830f0ecd | ||
|
|
fd7737af7d | ||
|
|
f2130c4c25 | ||
|
|
4f40683fd8 | ||
|
|
5fc9e53eec | ||
|
|
27e3cea285 | ||
|
|
ee770fa68f | ||
|
|
9cb4aa16eb | ||
|
|
92d990629f | ||
|
|
ba58f1bc0b | ||
|
|
02fcfd530f | ||
|
|
095e8a3de8 | ||
|
|
e17ad83fb5 | ||
|
|
e7c41151ec | ||
|
|
7f4ba62d4f | ||
|
|
71b17a3a53 | ||
|
|
d46b8b8fd7 | ||
|
|
a671070a28 | ||
|
|
4600d5351b | ||
|
|
75bba5b8e5 | ||
|
|
8d1d1536d3 | ||
|
|
a7050a185b | ||
|
|
d345541c2d | ||
|
|
bd028e4c66 | ||
|
|
d6f4fb67cc | ||
|
|
4378b540cf | ||
|
|
39ddb7c3e3 | ||
|
|
344cbd3286 | ||
|
|
d4ba173b53 | ||
|
|
c56ce656b2 | ||
|
|
9377214518 | ||
|
|
900a1c095f | ||
|
|
7e97a96840 | ||
|
|
69f272d7ba | ||
|
|
a653554bd9 | ||
|
|
6a25006544 | ||
|
|
8cfe4820f6 | ||
|
|
c8021d4224 | ||
|
|
3a64cc27b5 | ||
|
|
2edc485ec1 | ||
|
|
a6d6553cee | ||
|
|
45feef9413 | ||
|
|
105fe3961c | ||
|
|
d381c7b186 | ||
|
|
5e8334c0bf | ||
|
|
2ea8a16afb | ||
|
|
aa054db1c7 | ||
|
|
07d70a6a56 | ||
|
|
747572e62c | ||
|
|
72ed76e89e | ||
|
|
a403cb04f3 | ||
|
|
ed71184854 | ||
|
|
dfbf43e463 | ||
|
|
7d7d72dcfe | ||
|
|
540c036988 | ||
|
|
58f89ceec9 | ||
|
|
4e3a184199 | ||
|
|
22e4ae99e8 | ||
|
|
75ab786afc | ||
|
|
e5c72ba1f2 | ||
|
|
66873d7d64 | ||
|
|
a0d1d5bcea | ||
|
|
fa0fa95bb6 | ||
|
|
41ea2f811a | ||
|
|
ec352cfce2 | ||
|
|
aade874241 | ||
|
|
c01eb653d7 | ||
|
|
892f80c265 | ||
|
|
2e487a2c55 | ||
|
|
a34e3ba338 | ||
|
|
c414f4cb12 | ||
|
|
d91c603875 | ||
|
|
7f899dcfca | ||
|
|
5f12fd4346 | ||
|
|
a7197f846b | ||
|
|
ac81fa7a9f | ||
|
|
091df1f1e7 | ||
|
|
a9fbfa108f | ||
|
|
44a8bf4143 | ||
|
|
3da8aa257b | ||
|
|
884dd749a0 | ||
|
|
c697591d6e | ||
|
|
0b706e03e7 | ||
|
|
447e75cd06 | ||
|
|
7f76c8809c | ||
|
|
cde1f81df6 | ||
|
|
c21ed1e478 | ||
|
|
a8cb4a21d1 | ||
|
|
0b9e673fa2 | ||
|
|
d242af8e22 | ||
|
|
76bd931d79 | ||
|
|
995f3374f1 | ||
|
|
1887885274 | ||
|
|
ce43cf412d | ||
|
|
d1712f0594 | ||
|
|
416b73b8c0 | ||
|
|
4654aa0cab | ||
|
|
6f9d8f465a | ||
|
|
e5e55345dc | ||
|
|
8d6eb6d41a | ||
|
|
1118e67cec | ||
|
|
d70cd04b15 | ||
|
|
3d1db23224 | ||
|
|
a488810693 | ||
|
|
0b066d3cb4 | ||
|
|
d154bee18a | ||
|
|
3a8694b642 | ||
|
|
fe485b3fa1 | ||
|
|
e70eaa6a31 | ||
|
|
27ef67306d | ||
|
|
547aca3db2 | ||
|
|
5f7360e2ce | ||
|
|
23f9675218 | ||
|
|
ef1e82076c | ||
|
|
65d4588cc7 | ||
|
|
0488f90c8f | ||
|
|
03d91f6618 | ||
|
|
ae5e4b67dc | ||
|
|
a6c6e33d88 | ||
|
|
79d9bf7109 | ||
|
|
66e1b382cd | ||
|
|
66f1ff43e9 | ||
|
|
d6d14859e3 | ||
|
|
4478bb9bbe | ||
|
|
a6aaf9da2a | ||
|
|
aa908ae0c2 | ||
|
|
778a2d8f84 | ||
|
|
508baabf9a | ||
|
|
80aa4d8e19 | ||
|
|
99e11112a7 | ||
|
|
1116e6dbc7 | ||
|
|
d1ac96c1ab | ||
|
|
abe88c899e | ||
|
|
b1709fcbdb | ||
|
|
ec877bf490 | ||
|
|
a8f1812acf | ||
|
|
6877b460c4 | ||
|
|
f189f9f1be | ||
|
|
6f79fd6d77 | ||
|
|
60d7bb52d6 | ||
|
|
65a2a0643a | ||
|
|
bc5f151dfa | ||
|
|
5cd6ed0096 | ||
|
|
be84b35bfd | ||
|
|
d9fc30ffd0 | ||
|
|
8f59d00d9e | ||
|
|
3d8ff39aed | ||
|
|
b5c194df43 | ||
|
|
8680f92b60 | ||
|
|
05c97bc755 | ||
|
|
db88d60750 | ||
|
|
40c6da8075 | ||
|
|
3981b8084f | ||
|
|
9dfb7c1c37 | ||
|
|
9ed54c188e | ||
|
|
6a47a346b1 | ||
|
|
e3f8a576cf | ||
|
|
0aff733a92 | ||
|
|
9471bff8a4 | ||
|
|
3f8eea4687 | ||
|
|
b1b2d50c0d | ||
|
|
9c6607f78d | ||
|
|
2a4709e572 | ||
|
|
04f3fce3b0 | ||
|
|
be9c3524a5 | ||
|
|
c3d899dd48 | ||
|
|
6e03ee2a75 | ||
|
|
979a8814f1 | ||
|
|
8be4fad330 | ||
|
|
8113f95278 | ||
|
|
9ca6c646df | ||
|
|
466b37994e | ||
|
|
518c6d6ac3 | ||
|
|
9920b8d975 | ||
|
|
237daa2048 | ||
|
|
e9af28e6a3 | ||
|
|
996515c7ca | ||
|
|
c2ccc39e3c | ||
|
|
ad24b93431 | ||
|
|
bd5fc32d79 | ||
|
|
03cefe8f58 | ||
|
|
64339f7089 | ||
|
|
0b1704976a | ||
|
|
0af60b9c73 | ||
|
|
280f0eacc0 | ||
|
|
03cba5e59e | ||
|
|
fa0ea0e1a4 | ||
|
|
40d24b8907 | ||
|
|
1bf02f439f | ||
|
|
0489c62550 | ||
|
|
ad98602da3 | ||
|
|
fb12ac316a | ||
|
|
e9ec2f2706 | ||
|
|
00f294454b | ||
|
|
0465d940c7 | ||
|
|
2c549598d0 | ||
|
|
7d33082d70 |
29
.github/workflows/publish.yaml
vendored
Normal file
29
.github/workflows/publish.yaml
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
name: release
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v**'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-publish
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build-n-publish:
|
||||
runs-on: ubuntu-20.04
|
||||
#if: startsWith(github.event.ref, 'refs/tags')
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Install wheel
|
||||
run: pip install wheel && pip install -r requirements.txt
|
||||
- name: Build DiffSynth
|
||||
run: python setup.py sdist bdist_wheel
|
||||
- name: Publish package to PyPI
|
||||
run: |
|
||||
pip install twine
|
||||
twine upload dist/* --skip-existing -u __token__ -p ${{ secrets.PYPI_API_TOKEN }}
|
||||
278
README.md
278
README.md
@@ -1,92 +1,194 @@
|
||||
# DiffSynth Studio
|
||||
[](https://pypi.org/project/DiffSynth/)
|
||||
[](https://github.com/modelscope/DiffSynth-Studio/blob/master/LICENSE)
|
||||
[](https://github.com/modelscope/DiffSynth-Studio/issues)
|
||||
[](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
|
||||
[](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
|
||||
|
||||
<p align="center">
|
||||
<a href="https://trendshift.io/repositories/10946" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10946" alt="modelscope%2FDiffSynth-Studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
</p>
|
||||
|
||||
Document: https://diffsynth-studio.readthedocs.io/zh-cn/latest/index.html
|
||||
|
||||
## Introduction
|
||||
|
||||
DiffSynth Studio is a Diffusion engine. We have restructured architectures including Text Encoder, UNet, VAE, among others, maintaining compatibility with models from the open-source community while enhancing computational performance. We provide many interesting features. Enjoy the magic of Diffusion models!
|
||||
Welcome to the magic world of Diffusion models!
|
||||
|
||||
## Roadmap
|
||||
DiffSynth consists of two open-source projects:
|
||||
* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): Focused on aggressive technological exploration. Targeted at academia. Provides more cutting-edge technical support and novel inference capabilities.
|
||||
* [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.
|
||||
|
||||
DiffSynth-Studio is an open-source project aimed at exploring innovations in AIGC technology. We have integrated numerous open-source Diffusion models, including FLUX and Wan, among others. Through this open-source project, we hope to connect models within the open-source community and explore new technologies based on diffusion models.
|
||||
|
||||
Until now, DiffSynth-Studio has supported the following models:
|
||||
|
||||
* [Wan-Video](https://github.com/Wan-Video/Wan2.1)
|
||||
* [StepVideo](https://github.com/stepfun-ai/Step-Video-T2V)
|
||||
* [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [HunyuanVideo-I2V]()
|
||||
* [CogVideoX](https://huggingface.co/THUDM/CogVideoX-5b)
|
||||
* [FLUX](https://huggingface.co/black-forest-labs/FLUX.1-dev)
|
||||
* [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
|
||||
* [Kolors](https://huggingface.co/Kwai-Kolors/Kolors)
|
||||
* [Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
|
||||
* [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt)
|
||||
* [Hunyuan-DiT](https://github.com/Tencent/HunyuanDiT)
|
||||
* [RIFE](https://github.com/hzwer/ECCV2022-RIFE)
|
||||
* [ESRGAN](https://github.com/xinntao/ESRGAN)
|
||||
* [Ip-Adapter](https://github.com/tencent-ailab/IP-Adapter)
|
||||
* [AnimateDiff](https://github.com/guoyww/animatediff/)
|
||||
* [ControlNet](https://github.com/lllyasviel/ControlNet)
|
||||
* [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
||||
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
|
||||
|
||||
## News
|
||||
- **March 31, 2025** We support InfiniteYou, an identity preserving method for FLUX. Please refer to [./examples/InfiniteYou/](./examples/InfiniteYou/) for more details.
|
||||
|
||||
- **March 25, 2025** 🔥🔥🔥 Our new open-source project, [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine), is now open-sourced! Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.
|
||||
|
||||
- **March 13, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of HunyuanVideo open-sourced by Tencent. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
|
||||
|
||||
- **February 25, 2025** We support Wan-Video, a collection of SOTA video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).
|
||||
|
||||
- **February 17, 2025** We support [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary)! State-of-the-art video synthesis model! See [./examples/stepvideo](./examples/stepvideo/).
|
||||
|
||||
- **December 31, 2024** We propose EliGen, a novel framework for precise entity-level controlled text-to-image generation, complemented by an inpainting fusion pipeline to extend its capabilities to image inpainting tasks. EliGen seamlessly integrates with existing community models, such as IP-Adapter and In-Context LoRA, enhancing its versatility. For more details, see [./examples/EntityControl](./examples/EntityControl/).
|
||||
- Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
|
||||
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
|
||||
- Online Demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
|
||||
- Training Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
|
||||
|
||||
- **December 19, 2024** We implement advanced VRAM management for HunyuanVideo, making it possible to generate videos at a resolution of 129x720x1280 using 24GB of VRAM, or at 129x512x384 resolution with just 6GB of VRAM. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
|
||||
|
||||
- **December 18, 2024** We propose ArtAug, an approach designed to improve text-to-image synthesis models through synthesis-understanding interactions. We have trained an ArtAug enhancement module for FLUX.1-dev in the format of LoRA. This model integrates the aesthetic understanding of Qwen2-VL-72B into FLUX.1-dev, leading to an improvement in the quality of generated images.
|
||||
- Paper: https://arxiv.org/abs/2412.12888
|
||||
- Examples: https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/ArtAug
|
||||
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
|
||||
- Demo: [ModelScope](https://modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0), HuggingFace (Coming soon)
|
||||
|
||||
- **October 25, 2024** We provide extensive FLUX ControlNet support. This project supports many different ControlNet models that can be freely combined, even if their structures differ. Additionally, ControlNet models are compatible with high-resolution refinement and partition control techniques, enabling very powerful controllable image generation. See [`./examples/ControlNet/`](./examples/ControlNet/).
|
||||
|
||||
- **October 8, 2024.** We release the extended LoRA based on CogVideoX-5B and ExVideo. You can download this model from [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) or [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1).
|
||||
|
||||
- **August 22, 2024.** CogVideoX-5B is supported in this project. See [here](/examples/video_synthesis/). We provide several interesting features for this text-to-video model, including
|
||||
- Text to video
|
||||
- Video editing
|
||||
- Self-upscaling
|
||||
- Video interpolation
|
||||
|
||||
- **August 22, 2024.** We have implemented an interesting painter that supports all text-to-image models. Now you can create stunning images using the painter, with assistance from AI!
|
||||
- Use it in our [WebUI](#usage-in-webui).
|
||||
|
||||
- **August 21, 2024.** FLUX is supported in DiffSynth-Studio.
|
||||
- Enable CFG and highres-fix to improve visual quality. See [here](/examples/image_synthesis/README.md)
|
||||
- LoRA, ControlNet, and additional models will be available soon.
|
||||
|
||||
- **June 21, 2024.** We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
|
||||
- [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
|
||||
- Source code is released in this repo. See [`examples/ExVideo`](./examples/ExVideo/).
|
||||
- Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
|
||||
- Technical report is released on [arXiv](https://arxiv.org/abs/2406.14130).
|
||||
- You can try ExVideo in this [Demo](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1)!
|
||||
|
||||
- **June 13, 2024.** DiffSynth Studio is transferred to ModelScope. The developers have transitioned from "I" to "we". Of course, I will still participate in development and maintenance.
|
||||
|
||||
- **Jan 29, 2024.** We propose Diffutoon, a fantastic solution for toon shading.
|
||||
- [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
|
||||
- The source codes are released in this project.
|
||||
- The technical report (IJCAI 2024) is released on [arXiv](https://arxiv.org/abs/2401.16224).
|
||||
|
||||
- **Dec 8, 2023.** We decide to develop a new Project, aiming to release the potential of diffusion models, especially in video synthesis. The development of this project is started.
|
||||
|
||||
- **Nov 15, 2023.** We propose FastBlend, a powerful video deflickering algorithm.
|
||||
- The sd-webui extension is released on [GitHub](https://github.com/Artiprocher/sd-webui-fastblend).
|
||||
- Demo videos are shown on Bilibili, including three tasks.
|
||||
- [Video deflickering](https://www.bilibili.com/video/BV1d94y1W7PE)
|
||||
- [Video interpolation](https://www.bilibili.com/video/BV1Lw411m71p)
|
||||
- [Image-driven video rendering](https://www.bilibili.com/video/BV1RB4y1Z7LF)
|
||||
- The technical report is released on [arXiv](https://arxiv.org/abs/2311.09265).
|
||||
- An unofficial ComfyUI extension developed by other users is released on [GitHub](https://github.com/AInseven/ComfyUI-fastblend).
|
||||
|
||||
- **Oct 1, 2023.** We release an early version of this project, namely FastSDXL. A try for building a diffusion engine.
|
||||
- The source codes are released on [GitHub](https://github.com/Artiprocher/FastSDXL).
|
||||
- FastSDXL includes a trainable OLSS scheduler for efficiency improvement.
|
||||
- The original repo of OLSS is [here](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler).
|
||||
- The technical report (CIKM 2023) is released on [arXiv](https://arxiv.org/abs/2305.14677).
|
||||
- A demo video is shown on [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj).
|
||||
- Since OLSS requires additional training, we don't implement it in this project.
|
||||
|
||||
- **Aug 29, 2023.** We propose DiffSynth, a video synthesis framework.
|
||||
- [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/).
|
||||
- The source codes are released in [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth).
|
||||
- The technical report (ECML PKDD 2024) is released on [arXiv](https://arxiv.org/abs/2308.03463).
|
||||
|
||||
* Aug 29, 2023. We propose DiffSynth, a video synthesis framework.
|
||||
* [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/).
|
||||
* The source codes are released in [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth).
|
||||
* The technical report (ECML PKDD 2024) is released on [arXiv](https://arxiv.org/abs/2308.03463).
|
||||
* Oct 1, 2023. We release an early version of this project, namely FastSDXL. A try for building a diffusion engine.
|
||||
* The source codes are released on [GitHub](https://github.com/Artiprocher/FastSDXL).
|
||||
* FastSDXL includes a trainable OLSS scheduler for efficiency improvement.
|
||||
* The original repo of OLSS is [here](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler).
|
||||
* The technical report (CIKM 2023) is released on [arXiv](https://arxiv.org/abs/2305.14677).
|
||||
* A demo video is shown on [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj).
|
||||
* Since OLSS requires additional training, we don't implement it in this project.
|
||||
* Nov 15, 2023. We propose FastBlend, a powerful video deflickering algorithm.
|
||||
* The sd-webui extension is released on [GitHub](https://github.com/Artiprocher/sd-webui-fastblend).
|
||||
* Demo videos are shown on Bilibili, including three tasks.
|
||||
* [Video deflickering](https://www.bilibili.com/video/BV1d94y1W7PE)
|
||||
* [Video interpolation](https://www.bilibili.com/video/BV1Lw411m71p)
|
||||
* [Image-driven video rendering](https://www.bilibili.com/video/BV1RB4y1Z7LF)
|
||||
* The technical report is released on [arXiv](https://arxiv.org/abs/2311.09265).
|
||||
* An unofficial ComfyUI extension developed by other users is released on [GitHub](https://github.com/AInseven/ComfyUI-fastblend).
|
||||
* Dec 8, 2023. We decide to develop a new Project, aiming to release the potential of diffusion models, especially in video synthesis. The development of this project is started.
|
||||
* Jan 29, 2024. We propose Diffutoon, a fantastic solution for toon shading.
|
||||
* [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/).
|
||||
* The source codes are released in this project.
|
||||
* The technical report (IJCAI 2024) is released on [arXiv](https://arxiv.org/abs/2401.16224).
|
||||
* June 13, 2024. DiffSynth Studio is transfered to ModelScope. The developers have transitioned from "I" to "we". Of course, I will still participate in development and maintenance.
|
||||
* June 21, 2024. We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
|
||||
* [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/).
|
||||
* Source code is released in this repo. See [`examples/ExVideo`](./examples/ExVideo/).
|
||||
* Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
|
||||
* Technical report is released on [arXiv](https://arxiv.org/abs/2406.14130).
|
||||
* Until now, DiffSynth Studio has supported the following models:
|
||||
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
|
||||
* [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
||||
* [ControlNet](https://github.com/lllyasviel/ControlNet)
|
||||
* [AnimateDiff](https://github.com/guoyww/animatediff/)
|
||||
* [Ip-Adapter](https://github.com/tencent-ailab/IP-Adapter)
|
||||
* [ESRGAN](https://github.com/xinntao/ESRGAN)
|
||||
* [RIFE](https://github.com/hzwer/ECCV2022-RIFE)
|
||||
* [Hunyuan-DiT](https://github.com/Tencent/HunyuanDiT)
|
||||
* [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt)
|
||||
* [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
|
||||
|
||||
## Installation
|
||||
|
||||
Create Python environment:
|
||||
Install from source code (recommended):
|
||||
|
||||
```
|
||||
conda env create -f environment.yml
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
We find that sometimes `conda` cannot install `cupy` correctly, please install it manually. See [this document](https://docs.cupy.dev/en/stable/install.html) for more details.
|
||||
|
||||
Enter the Python environment:
|
||||
Or install from pypi (There is a delay in the update. If you want to experience the latest features, please do not use this installation method.):
|
||||
|
||||
```
|
||||
conda activate DiffSynthStudio
|
||||
pip install diffsynth
|
||||
```
|
||||
|
||||
If you encounter issues during installation, it may be caused by the packages we depend on. Please refer to the documentation of the package that caused the problem.
|
||||
|
||||
* [torch](https://pytorch.org/get-started/locally/)
|
||||
* [sentencepiece](https://github.com/google/sentencepiece)
|
||||
* [cmake](https://cmake.org)
|
||||
* [cupy](https://docs.cupy.dev/en/stable/install.html)
|
||||
|
||||
## Usage (in Python code)
|
||||
|
||||
The Python examples are in [`examples`](./examples/). We provide an overview here.
|
||||
|
||||
### Long Video Synthesis
|
||||
### Download Models
|
||||
|
||||
We trained an extended video synthesis model, which can generate 128 frames. [`examples/ExVideo`](./examples/ExVideo/)
|
||||
Download the pre-set models. Model IDs can be found in [config file](/diffsynth/configs/model_config.py).
|
||||
|
||||
```python
|
||||
from diffsynth import download_models
|
||||
|
||||
download_models(["FLUX.1-dev", "Kolors"])
|
||||
```
|
||||
|
||||
Download your own models.
|
||||
|
||||
```python
|
||||
from diffsynth.models.downloader import download_from_huggingface, download_from_modelscope
|
||||
|
||||
# From Modelscope (recommended)
|
||||
download_from_modelscope("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.fp16.bin", "models/kolors/Kolors/vae")
|
||||
# From Huggingface
|
||||
download_from_huggingface("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.fp16.safetensors", "models/kolors/Kolors/vae")
|
||||
```
|
||||
|
||||
### Video Synthesis
|
||||
|
||||
#### Text-to-video using CogVideoX-5B
|
||||
|
||||
CogVideoX-5B is released by ZhiPu. We provide an improved pipeline, supporting text-to-video, video editing, self-upscaling and video interpolation. [`examples/video_synthesis`](./examples/video_synthesis/)
|
||||
|
||||
The video on the left is generated using the original text-to-video pipeline, while the video on the right is the result after editing and frame interpolation.
|
||||
|
||||
https://github.com/user-attachments/assets/26b044c1-4a60-44a4-842f-627ff289d006
|
||||
|
||||
#### Long Video Synthesis
|
||||
|
||||
We trained extended video synthesis models, which can generate 128 frames. [`examples/ExVideo`](./examples/ExVideo/)
|
||||
|
||||
https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
|
||||
|
||||
### Image Synthesis
|
||||
https://github.com/user-attachments/assets/321ee04b-8c17-479e-8a95-8cbcf21f8d7e
|
||||
|
||||
Generate high-resolution images, by breaking the limitation of diffusion models! [`examples/image_synthesis`](./examples/image_synthesis/)
|
||||
|
||||
|512*512|1024*1024|2048*2048|4096*4096|
|
||||
|-|-|-|-|
|
||||
|||||
|
||||
|
||||
|1024*1024|2048*2048|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
### Toon Shading
|
||||
#### Toon Shading
|
||||
|
||||
Render realistic videos in a flatten style and enable video editing features. [`examples/Diffutoon`](./examples/Diffutoon/)
|
||||
|
||||
@@ -94,32 +196,60 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-47
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/20528af5-5100-474a-8cdc-440b9efdd86c
|
||||
|
||||
### Video Stylization
|
||||
#### Video Stylization
|
||||
|
||||
Video stylization without video models. [`examples/diffsynth`](./examples/diffsynth/)
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
|
||||
|
||||
### Chinese Models
|
||||
### Image Synthesis
|
||||
|
||||
Use Hunyuan-DiT to generate images with Chinese prompts. We also support LoRA fine-tuning of this model. [`examples/hunyuan_dit`](./examples/hunyuan_dit/)
|
||||
Generate high-resolution images, by breaking the limitation of diffusion models! [`examples/image_synthesis`](./examples/image_synthesis/).
|
||||
|
||||
Prompt: 少女手捧鲜花,坐在公园的长椅上,夕阳的余晖洒在少女的脸庞,整个画面充满诗意的美感
|
||||
LoRA fine-tuning is supported in [`examples/train`](./examples/train/).
|
||||
|
||||
|1024x1024|2048x2048 (highres-fix)|
|
||||
|FLUX|Stable Diffusion 3|
|
||||
|-|-|
|
||||
|||
|
||||
|||
|
||||
|
||||
Prompt: 一只小狗蹦蹦跳跳,周围是姹紫嫣红的鲜花,远处是山脉
|
||||
|
||||
|Without LoRA|With LoRA|
|
||||
|Kolors|Hunyuan-DiT|
|
||||
|-|-|
|
||||
|||
|
||||
|||
|
||||
|
||||
|Stable Diffusion|Stable Diffusion XL|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
## Usage (in WebUI)
|
||||
|
||||
Create stunning images using the painter, with assistance from AI!
|
||||
|
||||
https://github.com/user-attachments/assets/95265d21-cdd6-4125-a7cb-9fbcf6ceb7b0
|
||||
|
||||
**This video is not rendered in real-time.**
|
||||
|
||||
Before launching the WebUI, please download models to the folder `./models`. See [here](#download-models).
|
||||
|
||||
* `Gradio` version
|
||||
|
||||
```
|
||||
python -m streamlit run DiffSynth_Studio.py
|
||||
pip install gradio
|
||||
```
|
||||
|
||||
```
|
||||
python apps/gradio/DiffSynth_Studio.py
|
||||
```
|
||||
|
||||

|
||||
|
||||
* `Streamlit` version
|
||||
|
||||
```
|
||||
pip install streamlit streamlit-drawable-canvas
|
||||
```
|
||||
|
||||
```
|
||||
python -m streamlit run apps/streamlit/DiffSynth_Studio.py
|
||||
```
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/93085557-73f3-4eee-a205-9829591ef954
|
||||
|
||||
252
apps/gradio/DiffSynth_Studio.py
Normal file
252
apps/gradio/DiffSynth_Studio.py
Normal file
@@ -0,0 +1,252 @@
|
||||
import gradio as gr
|
||||
from diffsynth import ModelManager, SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline, FluxImagePipeline
|
||||
import os, torch
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
config = {
|
||||
"model_config": {
|
||||
"Stable Diffusion": {
|
||||
"model_folder": "models/stable_diffusion",
|
||||
"pipeline_class": SDImagePipeline,
|
||||
"default_parameters": {
|
||||
"cfg_scale": 7.0,
|
||||
"height": 512,
|
||||
"width": 512,
|
||||
}
|
||||
},
|
||||
"Stable Diffusion XL": {
|
||||
"model_folder": "models/stable_diffusion_xl",
|
||||
"pipeline_class": SDXLImagePipeline,
|
||||
"default_parameters": {
|
||||
"cfg_scale": 7.0,
|
||||
}
|
||||
},
|
||||
"Stable Diffusion 3": {
|
||||
"model_folder": "models/stable_diffusion_3",
|
||||
"pipeline_class": SD3ImagePipeline,
|
||||
"default_parameters": {
|
||||
"cfg_scale": 7.0,
|
||||
}
|
||||
},
|
||||
"Stable Diffusion XL Turbo": {
|
||||
"model_folder": "models/stable_diffusion_xl_turbo",
|
||||
"pipeline_class": SDXLImagePipeline,
|
||||
"default_parameters": {
|
||||
"negative_prompt": "",
|
||||
"cfg_scale": 1.0,
|
||||
"num_inference_steps": 1,
|
||||
"height": 512,
|
||||
"width": 512,
|
||||
}
|
||||
},
|
||||
"Kolors": {
|
||||
"model_folder": "models/kolors",
|
||||
"pipeline_class": SDXLImagePipeline,
|
||||
"default_parameters": {
|
||||
"cfg_scale": 7.0,
|
||||
}
|
||||
},
|
||||
"HunyuanDiT": {
|
||||
"model_folder": "models/HunyuanDiT",
|
||||
"pipeline_class": HunyuanDiTImagePipeline,
|
||||
"default_parameters": {
|
||||
"cfg_scale": 7.0,
|
||||
}
|
||||
},
|
||||
"FLUX": {
|
||||
"model_folder": "models/FLUX",
|
||||
"pipeline_class": FluxImagePipeline,
|
||||
"default_parameters": {
|
||||
"cfg_scale": 1.0,
|
||||
}
|
||||
}
|
||||
},
|
||||
"max_num_painter_layers": 8,
|
||||
"max_num_model_cache": 1,
|
||||
}
|
||||
|
||||
|
||||
def load_model_list(model_type):
|
||||
if model_type is None:
|
||||
return []
|
||||
folder = config["model_config"][model_type]["model_folder"]
|
||||
file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
|
||||
if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
|
||||
file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
|
||||
file_list = sorted(file_list)
|
||||
return file_list
|
||||
|
||||
|
||||
def load_model(model_type, model_path):
|
||||
global model_dict
|
||||
model_key = f"{model_type}:{model_path}"
|
||||
if model_key in model_dict:
|
||||
return model_dict[model_key]
|
||||
model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
|
||||
model_manager = ModelManager()
|
||||
if model_type == "HunyuanDiT":
|
||||
model_manager.load_models([
|
||||
os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
|
||||
os.path.join(model_path, "mt5/pytorch_model.bin"),
|
||||
os.path.join(model_path, "model/pytorch_model_ema.pt"),
|
||||
os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
|
||||
])
|
||||
elif model_type == "Kolors":
|
||||
model_manager.load_models([
|
||||
os.path.join(model_path, "text_encoder"),
|
||||
os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
|
||||
os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
|
||||
])
|
||||
elif model_type == "FLUX":
|
||||
model_manager.torch_dtype = torch.bfloat16
|
||||
file_list = [
|
||||
os.path.join(model_path, "text_encoder/model.safetensors"),
|
||||
os.path.join(model_path, "text_encoder_2"),
|
||||
]
|
||||
for file_name in os.listdir(model_path):
|
||||
if file_name.endswith(".safetensors"):
|
||||
file_list.append(os.path.join(model_path, file_name))
|
||||
model_manager.load_models(file_list)
|
||||
else:
|
||||
model_manager.load_model(model_path)
|
||||
pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
|
||||
while len(model_dict) + 1 > config["max_num_model_cache"]:
|
||||
key = next(iter(model_dict.keys()))
|
||||
model_manager_to_release, _ = model_dict[key]
|
||||
model_manager_to_release.to("cpu")
|
||||
del model_dict[key]
|
||||
torch.cuda.empty_cache()
|
||||
model_dict[model_key] = model_manager, pipe
|
||||
return model_manager, pipe
|
||||
|
||||
|
||||
model_dict = {}
|
||||
|
||||
with gr.Blocks() as app:
|
||||
gr.Markdown("# DiffSynth-Studio Painter")
|
||||
with gr.Row():
|
||||
with gr.Column(scale=382, min_width=100):
|
||||
|
||||
with gr.Accordion(label="Model"):
|
||||
model_type = gr.Dropdown(choices=[i for i in config["model_config"]], label="Model type")
|
||||
model_path = gr.Dropdown(choices=[], interactive=True, label="Model path")
|
||||
|
||||
@gr.on(inputs=model_type, outputs=model_path, triggers=model_type.change)
|
||||
def model_type_to_model_path(model_type):
|
||||
return gr.Dropdown(choices=load_model_list(model_type))
|
||||
|
||||
with gr.Accordion(label="Prompt"):
|
||||
prompt = gr.Textbox(label="Prompt", lines=3)
|
||||
negative_prompt = gr.Textbox(label="Negative prompt", lines=1)
|
||||
cfg_scale = gr.Slider(minimum=1.0, maximum=10.0, value=7.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
|
||||
embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=0.0, step=0.1, interactive=True, label="Embedded guidance scale (only for FLUX)")
|
||||
|
||||
with gr.Accordion(label="Image"):
|
||||
num_inference_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, interactive=True, label="Inference steps")
|
||||
height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
|
||||
width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
|
||||
with gr.Column():
|
||||
use_fixed_seed = gr.Checkbox(value=True, interactive=False, label="Use fixed seed")
|
||||
seed = gr.Number(minimum=0, maximum=10**9, value=0, interactive=True, label="Random seed", show_label=False)
|
||||
|
||||
@gr.on(
|
||||
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
|
||||
outputs=[prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
|
||||
triggers=model_path.change
|
||||
)
|
||||
def model_path_to_default_params(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width):
|
||||
load_model(model_type, model_path)
|
||||
cfg_scale = config["model_config"][model_type]["default_parameters"].get("cfg_scale", cfg_scale)
|
||||
embedded_guidance = config["model_config"][model_type]["default_parameters"].get("embedded_guidance", embedded_guidance)
|
||||
num_inference_steps = config["model_config"][model_type]["default_parameters"].get("num_inference_steps", num_inference_steps)
|
||||
height = config["model_config"][model_type]["default_parameters"].get("height", height)
|
||||
width = config["model_config"][model_type]["default_parameters"].get("width", width)
|
||||
return prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width
|
||||
|
||||
|
||||
with gr.Column(scale=618, min_width=100):
|
||||
with gr.Accordion(label="Painter"):
|
||||
enable_local_prompt_list = []
|
||||
local_prompt_list = []
|
||||
mask_scale_list = []
|
||||
canvas_list = []
|
||||
for painter_layer_id in range(config["max_num_painter_layers"]):
|
||||
with gr.Tab(label=f"Layer {painter_layer_id}"):
|
||||
enable_local_prompt = gr.Checkbox(label="Enable", value=False, key=f"enable_local_prompt_{painter_layer_id}")
|
||||
local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
|
||||
mask_scale = gr.Slider(minimum=0.0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Mask scale", key=f"mask_scale_{painter_layer_id}")
|
||||
canvas = gr.ImageEditor(canvas_size=(512, 1), sources=None, layers=False, interactive=True, image_mode="RGBA",
|
||||
brush=gr.Brush(default_size=100, default_color="#000000", colors=["#000000"]),
|
||||
label="Painter", key=f"canvas_{painter_layer_id}")
|
||||
@gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear, enable_local_prompt.change], show_progress="hidden")
|
||||
def resize_canvas(height, width, canvas):
|
||||
h, w = canvas["background"].shape[:2]
|
||||
if h != height or width != w:
|
||||
return np.ones((height, width, 3), dtype=np.uint8) * 255
|
||||
else:
|
||||
return canvas
|
||||
|
||||
enable_local_prompt_list.append(enable_local_prompt)
|
||||
local_prompt_list.append(local_prompt)
|
||||
mask_scale_list.append(mask_scale)
|
||||
canvas_list.append(canvas)
|
||||
with gr.Accordion(label="Results"):
|
||||
run_button = gr.Button(value="Generate", variant="primary")
|
||||
output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
output_to_painter_button = gr.Button(value="Set as painter's background")
|
||||
with gr.Column():
|
||||
output_to_input_button = gr.Button(value="Set as input image")
|
||||
painter_background = gr.State(None)
|
||||
input_background = gr.State(None)
|
||||
@gr.on(
|
||||
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed] + enable_local_prompt_list + local_prompt_list + mask_scale_list + canvas_list,
|
||||
outputs=[output_image],
|
||||
triggers=run_button.click
|
||||
)
|
||||
def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed, *args, progress=gr.Progress()):
|
||||
_, pipe = load_model(model_type, model_path)
|
||||
input_params = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"cfg_scale": cfg_scale,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"progress_bar_cmd": progress.tqdm,
|
||||
}
|
||||
if isinstance(pipe, FluxImagePipeline):
|
||||
input_params["embedded_guidance"] = embedded_guidance
|
||||
enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list = (
|
||||
args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
|
||||
args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
|
||||
args[2 * config["max_num_painter_layers"]: 3 * config["max_num_painter_layers"]],
|
||||
args[3 * config["max_num_painter_layers"]: 4 * config["max_num_painter_layers"]]
|
||||
)
|
||||
local_prompts, masks, mask_scales = [], [], []
|
||||
for enable_local_prompt, local_prompt, mask_scale, canvas in zip(
|
||||
enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list
|
||||
):
|
||||
if enable_local_prompt:
|
||||
local_prompts.append(local_prompt)
|
||||
masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
|
||||
mask_scales.append(mask_scale)
|
||||
input_params.update({
|
||||
"local_prompts": local_prompts,
|
||||
"masks": masks,
|
||||
"mask_scales": mask_scales,
|
||||
})
|
||||
torch.manual_seed(seed)
|
||||
image = pipe(**input_params)
|
||||
return image
|
||||
|
||||
@gr.on(inputs=[output_image] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
|
||||
def send_output_to_painter_background(output_image, *canvas_list):
|
||||
for canvas in canvas_list:
|
||||
h, w = canvas["background"].shape[:2]
|
||||
canvas["background"] = output_image.resize((w, h))
|
||||
return tuple(canvas_list)
|
||||
app.launch()
|
||||
390
apps/gradio/entity_level_control.py
Normal file
390
apps/gradio/entity_level_control.py
Normal file
@@ -0,0 +1,390 @@
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import random
|
||||
import json
|
||||
import gradio as gr
|
||||
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/entity_control/*")
|
||||
example_json = 'data/examples/eligen/entity_control/ui_examples.json'
|
||||
with open(example_json, 'r') as f:
|
||||
examples = json.load(f)['examples']
|
||||
|
||||
for idx in range(len(examples)):
|
||||
example_id = examples[idx]['example_id']
|
||||
entity_prompts = examples[idx]['local_prompt_list']
|
||||
examples[idx]['mask_lists'] = [Image.open(f"data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
|
||||
|
||||
def create_canvas_data(background, masks):
|
||||
if background.shape[-1] == 3:
|
||||
background = np.dstack([background, np.full(background.shape[:2], 255, dtype=np.uint8)])
|
||||
layers = []
|
||||
for mask in masks:
|
||||
if mask is not None:
|
||||
mask_single_channel = mask if mask.ndim == 2 else mask[..., 0]
|
||||
layer = np.zeros((mask_single_channel.shape[0], mask_single_channel.shape[1], 4), dtype=np.uint8)
|
||||
layer[..., -1] = mask_single_channel
|
||||
layers.append(layer)
|
||||
else:
|
||||
layers.append(np.zeros_like(background))
|
||||
|
||||
composite = background.copy()
|
||||
for layer in layers:
|
||||
if layer.size > 0:
|
||||
composite = np.where(layer[..., -1:] > 0, layer, composite)
|
||||
return {
|
||||
"background": background,
|
||||
"layers": layers,
|
||||
"composite": composite,
|
||||
}
|
||||
|
||||
def load_example(load_example_button):
|
||||
example_idx = int(load_example_button.split()[-1]) - 1
|
||||
example = examples[example_idx]
|
||||
result = [
|
||||
50,
|
||||
example["global_prompt"],
|
||||
example["negative_prompt"],
|
||||
example["seed"],
|
||||
*example["local_prompt_list"],
|
||||
]
|
||||
num_entities = len(example["local_prompt_list"])
|
||||
result += [""] * (config["max_num_painter_layers"] - num_entities)
|
||||
masks = []
|
||||
for mask in example["mask_lists"]:
|
||||
mask_single_channel = np.array(mask.convert("L"))
|
||||
masks.append(mask_single_channel)
|
||||
for _ in range(config["max_num_painter_layers"] - len(masks)):
|
||||
blank_mask = np.zeros_like(masks[0]) if masks else np.zeros((512, 512), dtype=np.uint8)
|
||||
masks.append(blank_mask)
|
||||
background = np.ones((masks[0].shape[0], masks[0].shape[1], 4), dtype=np.uint8) * 255
|
||||
canvas_data_list = []
|
||||
for mask in masks:
|
||||
canvas_data = create_canvas_data(background, [mask])
|
||||
canvas_data_list.append(canvas_data)
|
||||
result.extend(canvas_data_list)
|
||||
return result
|
||||
|
||||
def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'):
|
||||
save_dir = os.path.join('workdirs/tmp_mask', random_dir)
|
||||
print(f'save to {save_dir}')
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
for i, mask in enumerate(masks):
|
||||
save_path = os.path.join(save_dir, f'{i}.png')
|
||||
mask.save(save_path)
|
||||
sample = {
|
||||
"global_prompt": global_prompt,
|
||||
"mask_prompts": mask_prompts,
|
||||
"seed": seed,
|
||||
}
|
||||
with open(os.path.join(save_dir, f"prompts.json"), 'w') as f:
|
||||
json.dump(sample, f, indent=4)
|
||||
|
||||
def visualize_masks(image, masks, mask_prompts, font_size=35, use_random_colors=False):
|
||||
# Create a blank image for overlays
|
||||
overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
|
||||
colors = [
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
]
|
||||
# Generate random colors for each mask
|
||||
if use_random_colors:
|
||||
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
|
||||
# Font settings
|
||||
try:
|
||||
font = ImageFont.truetype("arial", font_size) # Adjust as needed
|
||||
except IOError:
|
||||
font = ImageFont.load_default(font_size)
|
||||
# Overlay each mask onto the overlay image
|
||||
for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
|
||||
if mask is None:
|
||||
continue
|
||||
# Convert mask to RGBA mode
|
||||
mask_rgba = mask.convert('RGBA')
|
||||
mask_data = mask_rgba.getdata()
|
||||
new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
|
||||
mask_rgba.putdata(new_data)
|
||||
# Draw the mask prompt text on the mask
|
||||
draw = ImageDraw.Draw(mask_rgba)
|
||||
mask_bbox = mask.getbbox() # Get the bounding box of the mask
|
||||
if mask_bbox is None:
|
||||
continue
|
||||
text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
|
||||
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
|
||||
# Alpha composite the overlay with this mask
|
||||
overlay = Image.alpha_composite(overlay, mask_rgba)
|
||||
# Composite the overlay onto the original image
|
||||
result = Image.alpha_composite(image.convert('RGBA'), overlay)
|
||||
return result
|
||||
|
||||
config = {
|
||||
"model_config": {
|
||||
"FLUX": {
|
||||
"model_folder": "models/FLUX",
|
||||
"pipeline_class": FluxImagePipeline,
|
||||
"default_parameters": {
|
||||
"cfg_scale": 3.0,
|
||||
"embedded_guidance": 3.5,
|
||||
"num_inference_steps": 30,
|
||||
}
|
||||
},
|
||||
},
|
||||
"max_num_painter_layers": 8,
|
||||
"max_num_model_cache": 1,
|
||||
}
|
||||
|
||||
model_dict = {}
|
||||
|
||||
def load_model(model_type='FLUX', model_path='FLUX.1-dev'):
|
||||
global model_dict
|
||||
model_key = f"{model_type}:{model_path}"
|
||||
if model_key in model_dict:
|
||||
return model_dict[model_key]
|
||||
model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
|
||||
model_manager.load_lora(
|
||||
download_customized_models(
|
||||
model_id="DiffSynth-Studio/Eligen",
|
||||
origin_file_path="model_bf16.safetensors",
|
||||
local_dir="models/lora/entity_control",
|
||||
),
|
||||
lora_alpha=1,
|
||||
)
|
||||
pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
|
||||
model_dict[model_key] = model_manager, pipe
|
||||
return model_manager, pipe
|
||||
|
||||
|
||||
with gr.Blocks() as app:
|
||||
gr.Markdown(
|
||||
"""## EliGen: Entity-Level Controllable Text-to-Image Model
|
||||
1. On the left, input the **global prompt** for the overall image, such as "a person stands by the river."
|
||||
2. On the right, input the **local prompt** for each entity, such as "person," and draw the corresponding mask in the **Entity Mask Painter**. Generally, solid rectangular masks yield better results.
|
||||
3. Click the **Generate** button to create the image. By selecting different **random seeds**, you can generate diverse images.
|
||||
4. **You can directly click the "Load Example" button on any sample at the bottom to load example inputs.**
|
||||
"""
|
||||
)
|
||||
|
||||
loading_status = gr.Textbox(label="Loading Model...", value="Loading model... Please wait...", visible=True)
|
||||
main_interface = gr.Column(visible=False)
|
||||
|
||||
def initialize_model():
|
||||
try:
|
||||
load_model()
|
||||
return {
|
||||
loading_status: gr.update(value="Model loaded successfully!", visible=False),
|
||||
main_interface: gr.update(visible=True),
|
||||
}
|
||||
except Exception as e:
|
||||
print(f'Failed to load model with error: {e}')
|
||||
return {
|
||||
loading_status: gr.update(value=f"Failed to load model: {str(e)}", visible=True),
|
||||
main_interface: gr.update(visible=True),
|
||||
}
|
||||
|
||||
app.load(initialize_model, inputs=None, outputs=[loading_status, main_interface])
|
||||
|
||||
with main_interface:
|
||||
with gr.Row():
|
||||
local_prompt_list = []
|
||||
canvas_list = []
|
||||
random_mask_dir = gr.State(f'{random.randint(0, 1000000):08d}')
|
||||
with gr.Column(scale=382, min_width=100):
|
||||
model_type = gr.State('FLUX')
|
||||
model_path = gr.State('FLUX.1-dev')
|
||||
with gr.Accordion(label="Global prompt"):
|
||||
prompt = gr.Textbox(label="Global Prompt", lines=3)
|
||||
negative_prompt = gr.Textbox(label="Negative prompt", value="worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw, blur,", lines=3)
|
||||
with gr.Accordion(label="Inference Options", open=True):
|
||||
seed = gr.Number(minimum=0, maximum=10**9, value=42, interactive=True, label="Random seed", show_label=True)
|
||||
num_inference_steps = gr.Slider(minimum=1, maximum=100, value=30, step=1, interactive=True, label="Inference steps")
|
||||
cfg_scale = gr.Slider(minimum=2.0, maximum=10.0, value=3.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
|
||||
embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=3.5, step=0.1, interactive=True, label="Embedded guidance scale")
|
||||
height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
|
||||
width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
|
||||
with gr.Accordion(label="Inpaint Input Image", open=False):
|
||||
input_image = gr.Image(sources=None, show_label=False, interactive=True, type="pil")
|
||||
background_weight = gr.Slider(minimum=0.0, maximum=1000., value=0., step=1, interactive=False, label="background_weight", visible=False)
|
||||
|
||||
with gr.Column():
|
||||
reset_input_button = gr.Button(value="Reset Inpaint Input")
|
||||
send_input_to_painter = gr.Button(value="Set as painter's background")
|
||||
@gr.on(inputs=[input_image], outputs=[input_image], triggers=reset_input_button.click)
|
||||
def reset_input_image(input_image):
|
||||
return None
|
||||
|
||||
with gr.Column(scale=618, min_width=100):
|
||||
with gr.Accordion(label="Entity Painter"):
|
||||
for painter_layer_id in range(config["max_num_painter_layers"]):
|
||||
with gr.Tab(label=f"Entity {painter_layer_id}"):
|
||||
local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
|
||||
canvas = gr.ImageEditor(
|
||||
canvas_size=(512, 512),
|
||||
sources=None,
|
||||
layers=False,
|
||||
interactive=True,
|
||||
image_mode="RGBA",
|
||||
brush=gr.Brush(
|
||||
default_size=50,
|
||||
default_color="#000000",
|
||||
colors=["#000000"],
|
||||
),
|
||||
label="Entity Mask Painter",
|
||||
key=f"canvas_{painter_layer_id}",
|
||||
width=width,
|
||||
height=height,
|
||||
)
|
||||
@gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear], show_progress="hidden")
|
||||
def resize_canvas(height, width, canvas):
|
||||
h, w = canvas["background"].shape[:2]
|
||||
if h != height or width != w:
|
||||
return np.ones((height, width, 3), dtype=np.uint8) * 255
|
||||
else:
|
||||
return canvas
|
||||
local_prompt_list.append(local_prompt)
|
||||
canvas_list.append(canvas)
|
||||
with gr.Accordion(label="Results"):
|
||||
run_button = gr.Button(value="Generate", variant="primary")
|
||||
output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
output_to_painter_button = gr.Button(value="Set as painter's background")
|
||||
with gr.Column():
|
||||
return_with_mask = gr.Checkbox(value=False, interactive=True, label="show result with mask painting")
|
||||
output_to_input_button = gr.Button(value="Set as input image", visible=False, interactive=False)
|
||||
real_output = gr.State(None)
|
||||
mask_out = gr.State(None)
|
||||
|
||||
@gr.on(
|
||||
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir] + local_prompt_list + canvas_list,
|
||||
outputs=[output_image, real_output, mask_out],
|
||||
triggers=run_button.click
|
||||
)
|
||||
def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()):
|
||||
_, pipe = load_model(model_type, model_path)
|
||||
input_params = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"cfg_scale": cfg_scale,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"progress_bar_cmd": progress.tqdm,
|
||||
}
|
||||
if isinstance(pipe, FluxImagePipeline):
|
||||
input_params["embedded_guidance"] = embedded_guidance
|
||||
if input_image is not None:
|
||||
input_params["input_image"] = input_image.resize((width, height)).convert("RGB")
|
||||
input_params["enable_eligen_inpaint"] = True
|
||||
|
||||
local_prompt_list, canvas_list = (
|
||||
args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
|
||||
args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
|
||||
)
|
||||
local_prompts, masks = [], []
|
||||
for local_prompt, canvas in zip(local_prompt_list, canvas_list):
|
||||
if isinstance(local_prompt, str) and len(local_prompt) > 0:
|
||||
local_prompts.append(local_prompt)
|
||||
masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
|
||||
entity_masks = None if len(masks) == 0 else masks
|
||||
entity_prompts = None if len(local_prompts) == 0 else local_prompts
|
||||
input_params.update({
|
||||
"eligen_entity_prompts": entity_prompts,
|
||||
"eligen_entity_masks": entity_masks,
|
||||
})
|
||||
torch.manual_seed(seed)
|
||||
# save_mask_prompts(masks, local_prompts, prompt, seed, random_mask_dir)
|
||||
image = pipe(**input_params)
|
||||
masks = [mask.resize(image.size) for mask in masks]
|
||||
image_with_mask = visualize_masks(image, masks, local_prompts)
|
||||
|
||||
real_output = gr.State(image)
|
||||
mask_out = gr.State(image_with_mask)
|
||||
|
||||
if return_with_mask:
|
||||
return image_with_mask, real_output, mask_out
|
||||
return image, real_output, mask_out
|
||||
|
||||
@gr.on(inputs=[input_image] + canvas_list, outputs=canvas_list, triggers=send_input_to_painter.click)
|
||||
def send_input_to_painter_background(input_image, *canvas_list):
|
||||
if input_image is None:
|
||||
return tuple(canvas_list)
|
||||
for canvas in canvas_list:
|
||||
h, w = canvas["background"].shape[:2]
|
||||
canvas["background"] = input_image.resize((w, h))
|
||||
return tuple(canvas_list)
|
||||
@gr.on(inputs=[real_output] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
|
||||
def send_output_to_painter_background(real_output, *canvas_list):
|
||||
if real_output is None:
|
||||
return tuple(canvas_list)
|
||||
for canvas in canvas_list:
|
||||
h, w = canvas["background"].shape[:2]
|
||||
canvas["background"] = real_output.value.resize((w, h))
|
||||
return tuple(canvas_list)
|
||||
@gr.on(inputs=[return_with_mask, real_output, mask_out], outputs=[output_image], triggers=[return_with_mask.change], show_progress="hidden")
|
||||
def show_output(return_with_mask, real_output, mask_out):
|
||||
if return_with_mask:
|
||||
return mask_out.value
|
||||
else:
|
||||
return real_output.value
|
||||
@gr.on(inputs=[real_output], outputs=[input_image], triggers=output_to_input_button.click)
|
||||
def send_output_to_pipe_input(real_output):
|
||||
return real_output.value
|
||||
|
||||
with gr.Column():
|
||||
gr.Markdown("## Examples")
|
||||
for i in range(0, len(examples), 2):
|
||||
with gr.Row():
|
||||
if i < len(examples):
|
||||
example = examples[i]
|
||||
with gr.Column():
|
||||
example_image = gr.Image(
|
||||
value=f"data/examples/eligen/entity_control/example_{example['example_id']}/example_image.png",
|
||||
label=example["description"],
|
||||
interactive=False,
|
||||
width=1024,
|
||||
height=512
|
||||
)
|
||||
load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
|
||||
load_example_button.click(
|
||||
load_example,
|
||||
inputs=[load_example_button],
|
||||
outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
|
||||
)
|
||||
|
||||
if i + 1 < len(examples):
|
||||
example = examples[i + 1]
|
||||
with gr.Column():
|
||||
example_image = gr.Image(
|
||||
value=f"data/examples/eligen/entity_control/example_{example['example_id']}/example_image.png",
|
||||
label=example["description"],
|
||||
interactive=False,
|
||||
width=1024,
|
||||
height=512
|
||||
)
|
||||
load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
|
||||
load_example_button.click(
|
||||
load_example,
|
||||
inputs=[load_example_button],
|
||||
outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
|
||||
)
|
||||
app.config["show_progress"] = "hidden"
|
||||
app.launch()
|
||||
@@ -1,7 +1,7 @@
|
||||
# Set web page format
|
||||
import streamlit as st
|
||||
st.set_page_config(layout="wide")
|
||||
# Diasble virtual VRAM on windows system
|
||||
# Disable virtual VRAM on windows system
|
||||
import torch
|
||||
torch.cuda.set_per_process_memory_fraction(0.999, 0)
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import torch, os, io
|
||||
import torch, os, io, json, time
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import streamlit as st
|
||||
st.set_page_config(layout="wide")
|
||||
from streamlit_drawable_canvas import st_canvas
|
||||
from diffsynth.models import ModelManager
|
||||
from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline, HunyuanDiTImagePipeline
|
||||
from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline, FluxImagePipeline
|
||||
from diffsynth.data.video import crop_and_resize
|
||||
|
||||
|
||||
@@ -20,6 +20,11 @@ config = {
|
||||
"pipeline_class": SDXLImagePipeline,
|
||||
"fixed_parameters": {}
|
||||
},
|
||||
"Stable Diffusion 3": {
|
||||
"model_folder": "models/stable_diffusion_3",
|
||||
"pipeline_class": SD3ImagePipeline,
|
||||
"fixed_parameters": {}
|
||||
},
|
||||
"Stable Diffusion XL Turbo": {
|
||||
"model_folder": "models/stable_diffusion_xl_turbo",
|
||||
"pipeline_class": SDXLImagePipeline,
|
||||
@@ -31,6 +36,11 @@ config = {
|
||||
"width": 512,
|
||||
}
|
||||
},
|
||||
"Kolors": {
|
||||
"model_folder": "models/kolors",
|
||||
"pipeline_class": SDXLImagePipeline,
|
||||
"fixed_parameters": {}
|
||||
},
|
||||
"HunyuanDiT": {
|
||||
"model_folder": "models/HunyuanDiT",
|
||||
"pipeline_class": HunyuanDiTImagePipeline,
|
||||
@@ -39,13 +49,20 @@ config = {
|
||||
"width": 1024,
|
||||
}
|
||||
},
|
||||
"FLUX": {
|
||||
"model_folder": "models/FLUX",
|
||||
"pipeline_class": FluxImagePipeline,
|
||||
"fixed_parameters": {
|
||||
"cfg_scale": 1.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def load_model_list(model_type):
|
||||
folder = config[model_type]["model_folder"]
|
||||
file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
|
||||
if model_type == "HunyuanDiT":
|
||||
if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
|
||||
file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
|
||||
file_list = sorted(file_list)
|
||||
return file_list
|
||||
@@ -69,6 +86,22 @@ def load_model(model_type, model_path):
|
||||
os.path.join(model_path, "model/pytorch_model_ema.pt"),
|
||||
os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
|
||||
])
|
||||
elif model_type == "Kolors":
|
||||
model_manager.load_models([
|
||||
os.path.join(model_path, "text_encoder"),
|
||||
os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
|
||||
os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
|
||||
])
|
||||
elif model_type == "FLUX":
|
||||
model_manager.torch_dtype = torch.bfloat16
|
||||
file_list = [
|
||||
os.path.join(model_path, "text_encoder/model.safetensors"),
|
||||
os.path.join(model_path, "text_encoder_2"),
|
||||
]
|
||||
for file_name in os.listdir(model_path):
|
||||
if file_name.endswith(".safetensors"):
|
||||
file_list.append(os.path.join(model_path, file_name))
|
||||
model_manager.load_models(file_list)
|
||||
else:
|
||||
model_manager.load_model(model_path)
|
||||
pipeline = config[model_type]["pipeline_class"].from_model_manager(model_manager)
|
||||
@@ -239,6 +272,48 @@ with column_input:
|
||||
key="canvas"
|
||||
)
|
||||
|
||||
num_painter_layer = st.number_input("Number of painter layers", min_value=0, max_value=10, step=1, value=0)
|
||||
local_prompts, masks, mask_scales = [], [], []
|
||||
white_board = Image.fromarray(np.ones((512, 512, 3), dtype=np.uint8) * 255)
|
||||
painter_layers_json_data = []
|
||||
for painter_tab_id in range(num_painter_layer):
|
||||
with st.expander(f"Painter layer {painter_tab_id}", expanded=True):
|
||||
enable_local_prompt = st.checkbox(f"Enable prompt {painter_tab_id}", value=True)
|
||||
local_prompt = st.text_area(f"Prompt {painter_tab_id}")
|
||||
mask_scale = st.slider(f"Mask scale {painter_tab_id}", min_value=0.0, max_value=3.0, value=1.0)
|
||||
stroke_width = st.slider(f"Stroke width {painter_tab_id}", min_value=1, max_value=300, value=100)
|
||||
canvas_result_local = st_canvas(
|
||||
fill_color="#000000",
|
||||
stroke_width=stroke_width,
|
||||
stroke_color="#000000",
|
||||
background_color="rgba(255, 255, 255, 0)",
|
||||
background_image=white_board,
|
||||
update_streamlit=True,
|
||||
height=512,
|
||||
width=512,
|
||||
drawing_mode="freedraw",
|
||||
key=f"canvas_{painter_tab_id}"
|
||||
)
|
||||
if canvas_result_local.json_data is not None:
|
||||
painter_layers_json_data.append(canvas_result_local.json_data.copy())
|
||||
painter_layers_json_data[-1]["prompt"] = local_prompt
|
||||
if enable_local_prompt:
|
||||
local_prompts.append(local_prompt)
|
||||
if canvas_result_local.image_data is not None:
|
||||
mask = apply_stroke_to_image(canvas_result_local.image_data, white_board)
|
||||
else:
|
||||
mask = white_board
|
||||
mask = Image.fromarray(255 - np.array(mask))
|
||||
masks.append(mask)
|
||||
mask_scales.append(mask_scale)
|
||||
save_painter_layers = st.button("Save painter layers")
|
||||
if save_painter_layers:
|
||||
os.makedirs("data/painter_layers", exist_ok=True)
|
||||
json_file_path = f"data/painter_layers/{time.time_ns()}.json"
|
||||
with open(json_file_path, "w") as f:
|
||||
json.dump(painter_layers_json_data, f, indent=4)
|
||||
st.markdown(f"Painter layers are saved in {json_file_path}.")
|
||||
|
||||
|
||||
with column_output:
|
||||
run_button = st.button("Generate image", type="primary")
|
||||
@@ -266,6 +341,7 @@ with column_output:
|
||||
progress_bar_st = st.progress(0.0)
|
||||
image = pipeline(
|
||||
prompt, negative_prompt=negative_prompt,
|
||||
local_prompts=local_prompts, masks=masks, mask_scales=mask_scales,
|
||||
cfg_scale=cfg_scale, num_inference_steps=num_inference_steps,
|
||||
height=height, width=width,
|
||||
input_image=input_image, denoising_strength=denoising_strength,
|
||||
@@ -1,6 +1,6 @@
|
||||
from .data import *
|
||||
from .models import *
|
||||
from .prompts import *
|
||||
from .prompters import *
|
||||
from .schedulers import *
|
||||
from .pipelines import *
|
||||
from .controlnets import *
|
||||
|
||||
0
diffsynth/configs/__init__.py
Normal file
0
diffsynth/configs/__init__.py
Normal file
800
diffsynth/configs/model_config.py
Normal file
800
diffsynth/configs/model_config.py
Normal file
@@ -0,0 +1,800 @@
|
||||
from typing_extensions import Literal, TypeAlias
|
||||
|
||||
from ..models.sd_text_encoder import SDTextEncoder
|
||||
from ..models.sd_unet import SDUNet
|
||||
from ..models.sd_vae_encoder import SDVAEEncoder
|
||||
from ..models.sd_vae_decoder import SDVAEDecoder
|
||||
|
||||
from ..models.sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
||||
from ..models.sdxl_unet import SDXLUNet
|
||||
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
|
||||
from ..models.sdxl_vae_encoder import SDXLVAEEncoder
|
||||
|
||||
from ..models.sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
|
||||
from ..models.sd3_dit import SD3DiT
|
||||
from ..models.sd3_vae_decoder import SD3VAEDecoder
|
||||
from ..models.sd3_vae_encoder import SD3VAEEncoder
|
||||
|
||||
from ..models.sd_controlnet import SDControlNet
|
||||
from ..models.sdxl_controlnet import SDXLControlNetUnion
|
||||
|
||||
from ..models.sd_motion import SDMotionModel
|
||||
from ..models.sdxl_motion import SDXLMotionModel
|
||||
|
||||
from ..models.svd_image_encoder import SVDImageEncoder
|
||||
from ..models.svd_unet import SVDUNet
|
||||
from ..models.svd_vae_decoder import SVDVAEDecoder
|
||||
from ..models.svd_vae_encoder import SVDVAEEncoder
|
||||
|
||||
from ..models.sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
|
||||
from ..models.sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
||||
|
||||
from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
||||
from ..models.hunyuan_dit import HunyuanDiT
|
||||
|
||||
from ..models.flux_dit import FluxDiT
|
||||
from ..models.flux_text_encoder import FluxTextEncoder2
|
||||
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
||||
from ..models.flux_controlnet import FluxControlNet
|
||||
from ..models.flux_ipadapter import FluxIpAdapter
|
||||
from ..models.flux_infiniteyou import InfiniteYouImageProjector
|
||||
|
||||
from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
|
||||
from ..models.cog_dit import CogDiT
|
||||
|
||||
from ..models.omnigen import OmniGenTransformer
|
||||
|
||||
from ..models.hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
|
||||
from ..models.hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder
|
||||
|
||||
from ..extensions.RIFE import IFNet
|
||||
from ..extensions.ESRGAN import RRDBNet
|
||||
|
||||
from ..models.hunyuan_video_dit import HunyuanVideoDiT
|
||||
|
||||
from ..models.stepvideo_vae import StepVideoVAE
|
||||
from ..models.stepvideo_dit import StepVideoModel
|
||||
|
||||
from ..models.wan_video_dit import WanModel
|
||||
from ..models.wan_video_text_encoder import WanTextEncoder
|
||||
from ..models.wan_video_image_encoder import WanImageEncoder
|
||||
from ..models.wan_video_vae import WanVideoVAE
|
||||
|
||||
|
||||
model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
# The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
|
||||
(None, "091b0e30e77c76626b3ba62acdf95343", ["sd_controlnet"], [SDControlNet], "civitai"),
|
||||
(None, "4a6c8306a27d916dea81263c8c88f450", ["hunyuan_dit_clip_text_encoder"], [HunyuanDiTCLIPTextEncoder], "civitai"),
|
||||
(None, "f4aec400fe394297961218c768004521", ["hunyuan_dit"], [HunyuanDiT], "civitai"),
|
||||
(None, "9e6e58043a5a2e332803ed42f6ee7181", ["hunyuan_dit_t5_text_encoder"], [HunyuanDiTT5TextEncoder], "civitai"),
|
||||
(None, "13115dd45a6e1c39860f91ab073b8a78", ["sdxl_vae_encoder", "sdxl_vae_decoder"], [SDXLVAEEncoder, SDXLVAEDecoder], "diffusers"),
|
||||
(None, "d78aa6797382a6d455362358a3295ea9", ["sd_ipadapter_clip_image_encoder"], [IpAdapterCLIPImageEmbedder], "diffusers"),
|
||||
(None, "e291636cc15e803186b47404262ef812", ["sd_ipadapter"], [SDIpAdapter], "civitai"),
|
||||
(None, "399c81f2f8de8d1843d0127a00f3c224", ["sdxl_ipadapter_clip_image_encoder"], [IpAdapterXLCLIPImageEmbedder], "diffusers"),
|
||||
(None, "a64eac9aa0db4b9602213bc0131281c7", ["sdxl_ipadapter"], [SDXLIpAdapter], "civitai"),
|
||||
(None, "52817e4fdd89df154f02749ca6f692ac", ["sdxl_unet"], [SDXLUNet], "diffusers"),
|
||||
(None, "03343c606f16d834d6411d0902b53636", ["sd_text_encoder", "sd_unet", "sd_vae_decoder", "sd_vae_encoder"], [SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder], "civitai"),
|
||||
(None, "d4ba77a7ece070679b4a987f58f201e9", ["sd_text_encoder"], [SDTextEncoder], "civitai"),
|
||||
(None, "d0c89e55c5a57cf3981def0cb1c9e65a", ["sd_vae_decoder", "sd_vae_encoder"], [SDVAEDecoder, SDVAEEncoder], "civitai"),
|
||||
(None, "3926bf373b39a67eeafd7901478a47a7", ["sd_unet"], [SDUNet], "civitai"),
|
||||
(None, "1e0c39ec176b9007c05f76d52b554a4d", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
||||
(None, "d9e0290829ba8d98e28e1a2b1407db4a", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
||||
(None, "5072d0b24e406b49507abe861cf97691", ["sd3_text_encoder_3"], [SD3TextEncoder3], "civitai"),
|
||||
(None, "4cf64a799d04260df438c6f33c9a047e", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"),
|
||||
(None, "d9b008a867c498ab12ad24042eff8e3f", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"), # SDXL-Turbo
|
||||
(None, "025bb7452e531a3853d951d77c63f032", ["sdxl_text_encoder", "sdxl_text_encoder_2"], [SDXLTextEncoder, SDXLTextEncoder2], "civitai"),
|
||||
(None, "298997b403a4245c04102c9f36aac348", ["sdxl_unet"], [SDXLUNet], "civitai"),
|
||||
(None, "2a07abce74b4bdc696b76254ab474da6", ["svd_image_encoder", "svd_unet", "svd_vae_decoder", "svd_vae_encoder"], [SVDImageEncoder, SVDUNet, SVDVAEDecoder, SVDVAEEncoder], "civitai"),
|
||||
(None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
|
||||
(None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
|
||||
(None, "31d2d9614fba60511fc9bf2604aa01f7", ["sdxl_controlnet"], [SDXLControlNetUnion], "diffusers"),
|
||||
(None, "94eefa3dac9cec93cb1ebaf1747d7b78", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
|
||||
(None, "1aafa3cc91716fb6b300cc1cd51b85a3", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "diffusers"),
|
||||
(None, "21ea55f476dfc4fd135587abb59dfe5d", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "civitai"),
|
||||
(None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
|
||||
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
|
||||
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
|
||||
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
|
||||
(None, "61cbcbc7ac11f169c5949223efa960d1", ["omnigen_transformer"], [OmniGenTransformer], "diffusers"),
|
||||
(None, "78d18b9101345ff695f312e7e62538c0", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||
(None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||
(None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||
(None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||
(None, "7f9583eb8ba86642abb9a21a4b2c9e16", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||
(None, "c07c0f04f5ff55e86b4e937c7a40d481", ["infiniteyou_image_projector"], [InfiniteYouImageProjector], "diffusers"),
|
||||
(None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"),
|
||||
(None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
||||
(None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
|
||||
(None, "77ff18050dbc23f50382e45d51a779fe", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
||||
(None, "5da81baee73198a7c19e6d2fe8b5148e", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
|
||||
(None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder", "hunyuan_video_vae_encoder"], [HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder], "diffusers"),
|
||||
(None, "b9588f02e78f5ccafc9d7c0294e46308", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
|
||||
(None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
|
||||
(None, "68beaf8429b7c11aa8ca05b1bd0058bd", ["stepvideo_vae"], [StepVideoVAE], "civitai"),
|
||||
(None, "5c0216a2132b082c10cb7a0e0377e681", ["stepvideo_dit"], [StepVideoModel], "civitai"),
|
||||
(None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
|
||||
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
|
||||
(None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
|
||||
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
||||
(None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
||||
]
|
||||
huggingface_model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
# The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
|
||||
("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
|
||||
("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
|
||||
("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
|
||||
("Qwen2ForCausalLM", "transformers.models.qwen2.modeling_qwen2", "qwen_prompt", None),
|
||||
# ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
|
||||
("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
|
||||
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
|
||||
("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
|
||||
("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
|
||||
("LlavaForConditionalGeneration", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoMLLMEncoder"),
|
||||
("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
|
||||
]
|
||||
patch_model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
# The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
|
||||
("9a4ab6869ac9b7d6e31f9854e397c867", ["svd_unet"], [SVDUNet], {"add_positional_conv": 128}),
|
||||
]
|
||||
|
||||
preset_models_on_huggingface = {
|
||||
"HunyuanDiT": [
|
||||
("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
|
||||
("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
|
||||
("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
|
||||
("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
|
||||
],
|
||||
"stable-video-diffusion-img2vid-xt": [
|
||||
("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
|
||||
],
|
||||
"ExVideo-SVD-128f-v1": [
|
||||
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
|
||||
],
|
||||
# Stable Diffusion
|
||||
"StableDiffusion_v15": [
|
||||
("benjamin-paine/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
|
||||
],
|
||||
"DreamShaper_8": [
|
||||
("Yntec/Dreamshaper8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
|
||||
],
|
||||
# Textual Inversion
|
||||
"TextualInversion_VeryBadImageNegative_v1.3": [
|
||||
("gemasai/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
|
||||
],
|
||||
# Stable Diffusion XL
|
||||
"StableDiffusionXL_v1": [
|
||||
("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
|
||||
],
|
||||
"BluePencilXL_v200": [
|
||||
("frankjoshua/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
|
||||
],
|
||||
"StableDiffusionXL_Turbo": [
|
||||
("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
|
||||
],
|
||||
# Stable Diffusion 3
|
||||
"StableDiffusion3": [
|
||||
("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
|
||||
],
|
||||
"StableDiffusion3_without_T5": [
|
||||
("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
|
||||
],
|
||||
# ControlNet
|
||||
"ControlNet_v11f1p_sd15_depth": [
|
||||
("lllyasviel/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
|
||||
("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
||||
],
|
||||
"ControlNet_v11p_sd15_softedge": [
|
||||
("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
|
||||
("lllyasviel/Annotators", "ControlNetHED.pth", "models/Annotators")
|
||||
],
|
||||
"ControlNet_v11f1e_sd15_tile": [
|
||||
("lllyasviel/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
|
||||
],
|
||||
"ControlNet_v11p_sd15_lineart": [
|
||||
("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
|
||||
("lllyasviel/Annotators", "sk_model.pth", "models/Annotators"),
|
||||
("lllyasviel/Annotators", "sk_model2.pth", "models/Annotators")
|
||||
],
|
||||
"ControlNet_union_sdxl_promax": [
|
||||
("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
|
||||
("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
||||
],
|
||||
# AnimateDiff
|
||||
"AnimateDiff_v2": [
|
||||
("guoyww/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
|
||||
],
|
||||
"AnimateDiff_xl_beta": [
|
||||
("guoyww/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
|
||||
],
|
||||
|
||||
# Qwen Prompt
|
||||
"QwenPrompt": [
|
||||
("Qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("Qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("Qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("Qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("Qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("Qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("Qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("Qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
],
|
||||
# Beautiful Prompt
|
||||
"BeautifulPrompt": [
|
||||
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
],
|
||||
# Omost prompt
|
||||
"OmostPrompt":[
|
||||
("lllyasviel/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("lllyasviel/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("lllyasviel/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("lllyasviel/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("lllyasviel/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("lllyasviel/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("lllyasviel/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("lllyasviel/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
],
|
||||
# Translator
|
||||
"opus-mt-zh-en": [
|
||||
("Helsinki-NLP/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
|
||||
("Helsinki-NLP/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
|
||||
("Helsinki-NLP/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
|
||||
("Helsinki-NLP/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
|
||||
("Helsinki-NLP/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
|
||||
("Helsinki-NLP/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
|
||||
("Helsinki-NLP/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
|
||||
("Helsinki-NLP/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
|
||||
],
|
||||
# IP-Adapter
|
||||
"IP-Adapter-SD": [
|
||||
("h94/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
|
||||
("h94/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
|
||||
],
|
||||
"IP-Adapter-SDXL": [
|
||||
("h94/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
|
||||
("h94/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
|
||||
],
|
||||
"SDXL-vae-fp16-fix": [
|
||||
("madebyollin/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
|
||||
],
|
||||
# Kolors
|
||||
"Kolors": [
|
||||
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
|
||||
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
|
||||
],
|
||||
# FLUX
|
||||
"FLUX.1-dev": [
|
||||
("black-forest-labs/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
|
||||
("black-forest-labs/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
||||
("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
|
||||
],
|
||||
"InstantX/FLUX.1-dev-IP-Adapter": {
|
||||
"file_list": [
|
||||
("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
|
||||
("google/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
||||
("google/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
||||
],
|
||||
"load_path": [
|
||||
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
|
||||
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
|
||||
],
|
||||
},
|
||||
# RIFE
|
||||
"RIFE": [
|
||||
("AlexWortega/RIFE", "flownet.pkl", "models/RIFE"),
|
||||
],
|
||||
# CogVideo
|
||||
"CogVideoX-5B": [
|
||||
("THUDM/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
||||
("THUDM/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
||||
("THUDM/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
||||
("THUDM/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
||||
("THUDM/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
||||
("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
||||
("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
||||
("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
||||
("THUDM/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
|
||||
],
|
||||
# Stable Diffusion 3.5
|
||||
"StableDiffusion3.5-large": [
|
||||
("stabilityai/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
|
||||
("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
|
||||
("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
|
||||
("stabilityai/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
|
||||
],
|
||||
}
|
||||
preset_models_on_modelscope = {
|
||||
# Hunyuan DiT
|
||||
"HunyuanDiT": [
|
||||
("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
|
||||
("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
|
||||
("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
|
||||
("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
|
||||
],
|
||||
# Stable Video Diffusion
|
||||
"stable-video-diffusion-img2vid-xt": [
|
||||
("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
|
||||
],
|
||||
# ExVideo
|
||||
"ExVideo-SVD-128f-v1": [
|
||||
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
|
||||
],
|
||||
"ExVideo-CogVideoX-LoRA-129f-v1": [
|
||||
("ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1", "ExVideo-CogVideoX-LoRA-129f-v1.safetensors", "models/lora"),
|
||||
],
|
||||
# Stable Diffusion
|
||||
"StableDiffusion_v15": [
|
||||
("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
|
||||
],
|
||||
"DreamShaper_8": [
|
||||
("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
|
||||
],
|
||||
"AingDiffusion_v12": [
|
||||
("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
|
||||
],
|
||||
"Flat2DAnimerge_v45Sharp": [
|
||||
("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
|
||||
],
|
||||
# Textual Inversion
|
||||
"TextualInversion_VeryBadImageNegative_v1.3": [
|
||||
("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
|
||||
],
|
||||
# Stable Diffusion XL
|
||||
"StableDiffusionXL_v1": [
|
||||
("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
|
||||
],
|
||||
"BluePencilXL_v200": [
|
||||
("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
|
||||
],
|
||||
"StableDiffusionXL_Turbo": [
|
||||
("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
|
||||
],
|
||||
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0": [
|
||||
("sd_lora/zyd232_ChineseInkStyle_SDXL_v1_0", "zyd232_ChineseInkStyle_SDXL_v1_0.safetensors", "models/lora"),
|
||||
],
|
||||
# Stable Diffusion 3
|
||||
"StableDiffusion3": [
|
||||
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
|
||||
],
|
||||
"StableDiffusion3_without_T5": [
|
||||
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
|
||||
],
|
||||
# ControlNet
|
||||
"ControlNet_v11f1p_sd15_depth": [
|
||||
("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
|
||||
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
||||
],
|
||||
"ControlNet_v11p_sd15_softedge": [
|
||||
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
|
||||
("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
|
||||
],
|
||||
"ControlNet_v11f1e_sd15_tile": [
|
||||
("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
|
||||
],
|
||||
"ControlNet_v11p_sd15_lineart": [
|
||||
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
|
||||
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
|
||||
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
|
||||
],
|
||||
"ControlNet_union_sdxl_promax": [
|
||||
("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
|
||||
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
||||
],
|
||||
"Annotators:Depth": [
|
||||
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
|
||||
],
|
||||
"Annotators:Softedge": [
|
||||
("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators"),
|
||||
],
|
||||
"Annotators:Lineart": [
|
||||
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
|
||||
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators"),
|
||||
],
|
||||
"Annotators:Normal": [
|
||||
("sd_lora/Annotators", "scannet.pt", "models/Annotators"),
|
||||
],
|
||||
"Annotators:Openpose": [
|
||||
("sd_lora/Annotators", "body_pose_model.pth", "models/Annotators"),
|
||||
("sd_lora/Annotators", "facenet.pth", "models/Annotators"),
|
||||
("sd_lora/Annotators", "hand_pose_model.pth", "models/Annotators"),
|
||||
],
|
||||
# AnimateDiff
|
||||
"AnimateDiff_v2": [
|
||||
("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
|
||||
],
|
||||
"AnimateDiff_xl_beta": [
|
||||
("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
|
||||
],
|
||||
# RIFE
|
||||
"RIFE": [
|
||||
("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
|
||||
],
|
||||
# Qwen Prompt
|
||||
"QwenPrompt": {
|
||||
"file_list": [
|
||||
("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
],
|
||||
"load_path": [
|
||||
"models/QwenPrompt/qwen2-1.5b-instruct",
|
||||
],
|
||||
},
|
||||
# Beautiful Prompt
|
||||
"BeautifulPrompt": {
|
||||
"file_list": [
|
||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
],
|
||||
"load_path": [
|
||||
"models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
|
||||
],
|
||||
},
|
||||
# Omost prompt
|
||||
"OmostPrompt": {
|
||||
"file_list": [
|
||||
("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
],
|
||||
"load_path": [
|
||||
"models/OmostPrompt/omost-llama-3-8b-4bits",
|
||||
],
|
||||
},
|
||||
# Translator
|
||||
"opus-mt-zh-en": {
|
||||
"file_list": [
|
||||
("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
|
||||
],
|
||||
"load_path": [
|
||||
"models/translator/opus-mt-zh-en",
|
||||
],
|
||||
},
|
||||
# IP-Adapter
|
||||
"IP-Adapter-SD": [
|
||||
("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
|
||||
("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
|
||||
],
|
||||
"IP-Adapter-SDXL": [
|
||||
("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
|
||||
("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
|
||||
],
|
||||
# Kolors
|
||||
"Kolors": {
|
||||
"file_list": [
|
||||
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
|
||||
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
|
||||
],
|
||||
"load_path": [
|
||||
"models/kolors/Kolors/text_encoder",
|
||||
"models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
|
||||
"models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",
|
||||
],
|
||||
},
|
||||
"SDXL-vae-fp16-fix": [
|
||||
("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
|
||||
],
|
||||
# FLUX
|
||||
"FLUX.1-dev": {
|
||||
"file_list": [
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
||||
("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
|
||||
],
|
||||
"load_path": [
|
||||
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
||||
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
||||
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
||||
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
|
||||
],
|
||||
},
|
||||
"FLUX.1-schnell": {
|
||||
"file_list": [
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
||||
("AI-ModelScope/FLUX.1-schnell", "flux1-schnell.safetensors", "models/FLUX/FLUX.1-schnell"),
|
||||
],
|
||||
"load_path": [
|
||||
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
||||
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
||||
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
||||
"models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors"
|
||||
],
|
||||
},
|
||||
"InstantX/FLUX.1-dev-Controlnet-Union-alpha": [
|
||||
("InstantX/FLUX.1-dev-Controlnet-Union-alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha"),
|
||||
],
|
||||
"jasperai/Flux.1-dev-Controlnet-Depth": [
|
||||
("jasperai/Flux.1-dev-Controlnet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Depth"),
|
||||
],
|
||||
"jasperai/Flux.1-dev-Controlnet-Surface-Normals": [
|
||||
("jasperai/Flux.1-dev-Controlnet-Surface-Normals", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals"),
|
||||
],
|
||||
"jasperai/Flux.1-dev-Controlnet-Upscaler": [
|
||||
("jasperai/Flux.1-dev-Controlnet-Upscaler", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler"),
|
||||
],
|
||||
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha": [
|
||||
("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha"),
|
||||
],
|
||||
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta": [
|
||||
("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"),
|
||||
],
|
||||
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth": [
|
||||
("Shakker-Labs/FLUX.1-dev-ControlNet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Depth"),
|
||||
],
|
||||
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [
|
||||
("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"),
|
||||
],
|
||||
"InstantX/FLUX.1-dev-IP-Adapter": {
|
||||
"file_list": [
|
||||
("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
|
||||
("AI-ModelScope/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
||||
("AI-ModelScope/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
||||
],
|
||||
"load_path": [
|
||||
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
|
||||
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
|
||||
],
|
||||
},
|
||||
"InfiniteYou":{
|
||||
"file_list":[
|
||||
("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
|
||||
("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
|
||||
("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/image_proj_model.bin", "models/InfiniteYou"),
|
||||
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/1k3d68.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
|
||||
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/2d106det.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
|
||||
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/genderage.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
|
||||
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/glintr100.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
|
||||
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/scrfd_10g_bnkps.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
|
||||
],
|
||||
"load_path":[
|
||||
[
|
||||
"models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors",
|
||||
"models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors"
|
||||
],
|
||||
"models/InfiniteYou/image_proj_model.bin",
|
||||
],
|
||||
},
|
||||
# ESRGAN
|
||||
"ESRGAN_x4": [
|
||||
("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
|
||||
],
|
||||
# RIFE
|
||||
"RIFE": [
|
||||
("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
|
||||
],
|
||||
# Omnigen
|
||||
"OmniGen-v1": {
|
||||
"file_list": [
|
||||
("BAAI/OmniGen-v1", "vae/diffusion_pytorch_model.safetensors", "models/OmniGen/OmniGen-v1/vae"),
|
||||
("BAAI/OmniGen-v1", "model.safetensors", "models/OmniGen/OmniGen-v1"),
|
||||
("BAAI/OmniGen-v1", "config.json", "models/OmniGen/OmniGen-v1"),
|
||||
("BAAI/OmniGen-v1", "special_tokens_map.json", "models/OmniGen/OmniGen-v1"),
|
||||
("BAAI/OmniGen-v1", "tokenizer_config.json", "models/OmniGen/OmniGen-v1"),
|
||||
("BAAI/OmniGen-v1", "tokenizer.json", "models/OmniGen/OmniGen-v1"),
|
||||
],
|
||||
"load_path": [
|
||||
"models/OmniGen/OmniGen-v1/vae/diffusion_pytorch_model.safetensors",
|
||||
"models/OmniGen/OmniGen-v1/model.safetensors",
|
||||
]
|
||||
},
|
||||
# CogVideo
|
||||
"CogVideoX-5B": {
|
||||
"file_list": [
|
||||
("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
||||
("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
||||
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
||||
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
||||
("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
||||
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
||||
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
||||
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
||||
("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
|
||||
],
|
||||
"load_path": [
|
||||
"models/CogVideo/CogVideoX-5b/text_encoder",
|
||||
"models/CogVideo/CogVideoX-5b/transformer",
|
||||
"models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors",
|
||||
],
|
||||
},
|
||||
# Stable Diffusion 3.5
|
||||
"StableDiffusion3.5-large": [
|
||||
("AI-ModelScope/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
|
||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
|
||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
|
||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
|
||||
],
|
||||
"StableDiffusion3.5-medium": [
|
||||
("AI-ModelScope/stable-diffusion-3.5-medium", "sd3.5_medium.safetensors", "models/stable_diffusion_3"),
|
||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
|
||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
|
||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
|
||||
],
|
||||
"StableDiffusion3.5-large-turbo": [
|
||||
("AI-ModelScope/stable-diffusion-3.5-large-turbo", "sd3.5_large_turbo.safetensors", "models/stable_diffusion_3"),
|
||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
|
||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
|
||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
|
||||
],
|
||||
"HunyuanVideo":{
|
||||
"file_list": [
|
||||
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
|
||||
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
|
||||
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideo/transformers")
|
||||
],
|
||||
"load_path": [
|
||||
"models/HunyuanVideo/text_encoder/model.safetensors",
|
||||
"models/HunyuanVideo/text_encoder_2",
|
||||
"models/HunyuanVideo/vae/pytorch_model.pt",
|
||||
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
|
||||
],
|
||||
},
|
||||
"HunyuanVideoI2V":{
|
||||
"file_list": [
|
||||
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideoI2V/text_encoder"),
|
||||
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00001-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
|
||||
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00002-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
|
||||
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00003-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
|
||||
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00004-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
|
||||
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "config.json", "models/HunyuanVideoI2V/text_encoder_2"),
|
||||
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model.safetensors.index.json", "models/HunyuanVideoI2V/text_encoder_2"),
|
||||
("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/vae/pytorch_model.pt", "models/HunyuanVideoI2V/vae"),
|
||||
("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideoI2V/transformers")
|
||||
],
|
||||
"load_path": [
|
||||
"models/HunyuanVideoI2V/text_encoder/model.safetensors",
|
||||
"models/HunyuanVideoI2V/text_encoder_2",
|
||||
"models/HunyuanVideoI2V/vae/pytorch_model.pt",
|
||||
"models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
|
||||
],
|
||||
},
|
||||
"HunyuanVideo-fp8":{
|
||||
"file_list": [
|
||||
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
|
||||
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
|
||||
("DiffSynth-Studio/HunyuanVideo-safetensors", "model.fp8.safetensors", "models/HunyuanVideo/transformers")
|
||||
],
|
||||
"load_path": [
|
||||
"models/HunyuanVideo/text_encoder/model.safetensors",
|
||||
"models/HunyuanVideo/text_encoder_2",
|
||||
"models/HunyuanVideo/vae/pytorch_model.pt",
|
||||
"models/HunyuanVideo/transformers/model.fp8.safetensors"
|
||||
],
|
||||
},
|
||||
}
|
||||
Preset_model_id: TypeAlias = Literal[
|
||||
"HunyuanDiT",
|
||||
"stable-video-diffusion-img2vid-xt",
|
||||
"ExVideo-SVD-128f-v1",
|
||||
"ExVideo-CogVideoX-LoRA-129f-v1",
|
||||
"StableDiffusion_v15",
|
||||
"DreamShaper_8",
|
||||
"AingDiffusion_v12",
|
||||
"Flat2DAnimerge_v45Sharp",
|
||||
"TextualInversion_VeryBadImageNegative_v1.3",
|
||||
"StableDiffusionXL_v1",
|
||||
"BluePencilXL_v200",
|
||||
"StableDiffusionXL_Turbo",
|
||||
"ControlNet_v11f1p_sd15_depth",
|
||||
"ControlNet_v11p_sd15_softedge",
|
||||
"ControlNet_v11f1e_sd15_tile",
|
||||
"ControlNet_v11p_sd15_lineart",
|
||||
"AnimateDiff_v2",
|
||||
"AnimateDiff_xl_beta",
|
||||
"RIFE",
|
||||
"BeautifulPrompt",
|
||||
"opus-mt-zh-en",
|
||||
"IP-Adapter-SD",
|
||||
"IP-Adapter-SDXL",
|
||||
"StableDiffusion3",
|
||||
"StableDiffusion3_without_T5",
|
||||
"Kolors",
|
||||
"SDXL-vae-fp16-fix",
|
||||
"ControlNet_union_sdxl_promax",
|
||||
"FLUX.1-dev",
|
||||
"FLUX.1-schnell",
|
||||
"InstantX/FLUX.1-dev-Controlnet-Union-alpha",
|
||||
"jasperai/Flux.1-dev-Controlnet-Depth",
|
||||
"jasperai/Flux.1-dev-Controlnet-Surface-Normals",
|
||||
"jasperai/Flux.1-dev-Controlnet-Upscaler",
|
||||
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
|
||||
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
|
||||
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
|
||||
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
|
||||
"InstantX/FLUX.1-dev-IP-Adapter",
|
||||
"InfiniteYou",
|
||||
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
|
||||
"QwenPrompt",
|
||||
"OmostPrompt",
|
||||
"ESRGAN_x4",
|
||||
"RIFE",
|
||||
"OmniGen-v1",
|
||||
"CogVideoX-5B",
|
||||
"Annotators:Depth",
|
||||
"Annotators:Softedge",
|
||||
"Annotators:Lineart",
|
||||
"Annotators:Normal",
|
||||
"Annotators:Openpose",
|
||||
"StableDiffusion3.5-large",
|
||||
"StableDiffusion3.5-medium",
|
||||
"HunyuanVideo",
|
||||
"HunyuanVideo-fp8",
|
||||
"HunyuanVideoI2V",
|
||||
]
|
||||
@@ -1,2 +1,2 @@
|
||||
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager
|
||||
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager, FluxMultiControlNetManager
|
||||
from .processors import Annotator
|
||||
|
||||
@@ -4,10 +4,11 @@ from .processors import Processor_id
|
||||
|
||||
|
||||
class ControlNetConfigUnit:
|
||||
def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
|
||||
def __init__(self, processor_id: Processor_id, model_path, scale=1.0, skip_processor=False):
|
||||
self.processor_id = processor_id
|
||||
self.model_path = model_path
|
||||
self.scale = scale
|
||||
self.skip_processor = skip_processor
|
||||
|
||||
|
||||
class ControlNetUnit:
|
||||
@@ -23,6 +24,16 @@ class MultiControlNetManager:
|
||||
self.models = [unit.model for unit in controlnet_units]
|
||||
self.scales = [unit.scale for unit in controlnet_units]
|
||||
|
||||
def cpu(self):
|
||||
for model in self.models:
|
||||
model.cpu()
|
||||
|
||||
def to(self, device):
|
||||
for model in self.models:
|
||||
model.to(device)
|
||||
for processor in self.processors:
|
||||
processor.to(device)
|
||||
|
||||
def process_image(self, image, processor_id=None):
|
||||
if processor_id is None:
|
||||
processed_image = [processor(image) for processor in self.processors]
|
||||
@@ -37,13 +48,14 @@ class MultiControlNetManager:
|
||||
def __call__(
|
||||
self,
|
||||
sample, timestep, encoder_hidden_states, conditionings,
|
||||
tiled=False, tile_size=64, tile_stride=32
|
||||
tiled=False, tile_size=64, tile_stride=32, **kwargs
|
||||
):
|
||||
res_stack = None
|
||||
for conditioning, model, scale in zip(conditionings, self.models, self.scales):
|
||||
for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
|
||||
res_stack_ = model(
|
||||
sample, timestep, encoder_hidden_states, conditioning,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
||||
sample, timestep, encoder_hidden_states, conditioning, **kwargs,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
|
||||
processor_id=processor.processor_id
|
||||
)
|
||||
res_stack_ = [res * scale for res in res_stack_]
|
||||
if res_stack is None:
|
||||
@@ -51,3 +63,29 @@ class MultiControlNetManager:
|
||||
else:
|
||||
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
|
||||
return res_stack
|
||||
|
||||
|
||||
class FluxMultiControlNetManager(MultiControlNetManager):
|
||||
def __init__(self, controlnet_units=[]):
|
||||
super().__init__(controlnet_units=controlnet_units)
|
||||
|
||||
def process_image(self, image, processor_id=None):
|
||||
if processor_id is None:
|
||||
processed_image = [processor(image) for processor in self.processors]
|
||||
else:
|
||||
processed_image = [self.processors[processor_id](image)]
|
||||
return processed_image
|
||||
|
||||
def __call__(self, conditionings, **kwargs):
|
||||
res_stack, single_res_stack = None, None
|
||||
for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
|
||||
res_stack_, single_res_stack_ = model(controlnet_conditioning=conditioning, processor_id=processor.processor_id, **kwargs)
|
||||
res_stack_ = [res * scale for res in res_stack_]
|
||||
single_res_stack_ = [res * scale for res in single_res_stack_]
|
||||
if res_stack is None:
|
||||
res_stack = res_stack_
|
||||
single_res_stack = single_res_stack_
|
||||
else:
|
||||
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
|
||||
single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)]
|
||||
return res_stack, single_res_stack
|
||||
|
||||
@@ -1,39 +1,50 @@
|
||||
from typing_extensions import Literal, TypeAlias
|
||||
import warnings
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
from controlnet_aux.processor import (
|
||||
CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector
|
||||
)
|
||||
|
||||
|
||||
Processor_id: TypeAlias = Literal[
|
||||
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"
|
||||
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
|
||||
]
|
||||
|
||||
class Annotator:
|
||||
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None):
|
||||
if processor_id == "canny":
|
||||
self.processor = CannyDetector()
|
||||
elif processor_id == "depth":
|
||||
self.processor = MidasDetector.from_pretrained(model_path).to("cuda")
|
||||
elif processor_id == "softedge":
|
||||
self.processor = HEDdetector.from_pretrained(model_path).to("cuda")
|
||||
elif processor_id == "lineart":
|
||||
self.processor = LineartDetector.from_pretrained(model_path).to("cuda")
|
||||
elif processor_id == "lineart_anime":
|
||||
self.processor = LineartAnimeDetector.from_pretrained(model_path).to("cuda")
|
||||
elif processor_id == "openpose":
|
||||
self.processor = OpenposeDetector.from_pretrained(model_path).to("cuda")
|
||||
elif processor_id == "tile":
|
||||
self.processor = None
|
||||
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
|
||||
if not skip_processor:
|
||||
if processor_id == "canny":
|
||||
from controlnet_aux.processor import CannyDetector
|
||||
self.processor = CannyDetector()
|
||||
elif processor_id == "depth":
|
||||
from controlnet_aux.processor import MidasDetector
|
||||
self.processor = MidasDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "softedge":
|
||||
from controlnet_aux.processor import HEDdetector
|
||||
self.processor = HEDdetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "lineart":
|
||||
from controlnet_aux.processor import LineartDetector
|
||||
self.processor = LineartDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "lineart_anime":
|
||||
from controlnet_aux.processor import LineartAnimeDetector
|
||||
self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "openpose":
|
||||
from controlnet_aux.processor import OpenposeDetector
|
||||
self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "normal":
|
||||
from controlnet_aux.processor import NormalBaeDetector
|
||||
self.processor = NormalBaeDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "tile" or processor_id == "none" or processor_id == "inpaint":
|
||||
self.processor = None
|
||||
else:
|
||||
raise ValueError(f"Unsupported processor_id: {processor_id}")
|
||||
else:
|
||||
raise ValueError(f"Unsupported processor_id: {processor_id}")
|
||||
|
||||
self.processor = None
|
||||
|
||||
self.processor_id = processor_id
|
||||
self.detect_resolution = detect_resolution
|
||||
|
||||
def to(self,device):
|
||||
if hasattr(self.processor,"model") and hasattr(self.processor.model,"to"):
|
||||
|
||||
def __call__(self, image):
|
||||
self.processor.model.to(device)
|
||||
|
||||
def __call__(self, image, mask=None):
|
||||
width, height = image.size
|
||||
if self.processor_id == "openpose":
|
||||
kwargs = {
|
||||
|
||||
41
diffsynth/data/simple_text_image.py
Normal file
41
diffsynth/data/simple_text_image.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import torch, os, torchvision
|
||||
from torchvision import transforms
|
||||
import pandas as pd
|
||||
from PIL import Image
|
||||
|
||||
|
||||
|
||||
class TextImageDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
|
||||
self.steps_per_epoch = steps_per_epoch
|
||||
metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
|
||||
self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
|
||||
self.text = metadata["text"].to_list()
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.image_processor = transforms.Compose(
|
||||
[
|
||||
transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
|
||||
transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
data_id = torch.randint(0, len(self.path), (1,))[0]
|
||||
data_id = (data_id + index) % len(self.path) # For fixed seed.
|
||||
text = self.text[data_id]
|
||||
image = Image.open(self.path[data_id]).convert("RGB")
|
||||
target_height, target_width = self.height, self.width
|
||||
width, height = image.size
|
||||
scale = max(target_width / width, target_height / height)
|
||||
shape = [round(height*scale),round(width*scale)]
|
||||
image = torchvision.transforms.functional.resize(image,shape,interpolation=transforms.InterpolationMode.BILINEAR)
|
||||
image = self.image_processor(image)
|
||||
return {"text": text, "image": image}
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return self.steps_per_epoch
|
||||
@@ -135,8 +135,8 @@ class VideoData:
|
||||
frame.save(os.path.join(folder, f"{i}.png"))
|
||||
|
||||
|
||||
def save_video(frames, save_path, fps, quality=9):
|
||||
writer = imageio.get_writer(save_path, fps=fps, quality=quality)
|
||||
def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
|
||||
writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params)
|
||||
for frame in tqdm(frames, desc="Saving video"):
|
||||
frame = np.array(frame)
|
||||
writer.append_data(frame)
|
||||
|
||||
0
diffsynth/distributed/__init__.py
Normal file
0
diffsynth/distributed/__init__.py
Normal file
129
diffsynth/distributed/xdit_context_parallel.py
Normal file
129
diffsynth/distributed/xdit_context_parallel.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import torch
|
||||
from typing import Optional
|
||||
from einops import rearrange
|
||||
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
||||
get_sequence_parallel_world_size,
|
||||
get_sp_group)
|
||||
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
||||
|
||||
def sinusoidal_embedding_1d(dim, position):
|
||||
sinusoid = torch.outer(position.type(torch.float64), torch.pow(
|
||||
10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
|
||||
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
||||
return x.to(position.dtype)
|
||||
|
||||
def pad_freqs(original_tensor, target_len):
|
||||
seq_len, s1, s2 = original_tensor.shape
|
||||
pad_size = target_len - seq_len
|
||||
padding_tensor = torch.ones(
|
||||
pad_size,
|
||||
s1,
|
||||
s2,
|
||||
dtype=original_tensor.dtype,
|
||||
device=original_tensor.device)
|
||||
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
|
||||
return padded_tensor
|
||||
|
||||
def rope_apply(x, freqs, num_heads):
|
||||
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
||||
s_per_rank = x.shape[1]
|
||||
|
||||
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
|
||||
x.shape[0], x.shape[1], x.shape[2], -1, 2))
|
||||
|
||||
sp_size = get_sequence_parallel_world_size()
|
||||
sp_rank = get_sequence_parallel_rank()
|
||||
freqs = pad_freqs(freqs, s_per_rank * sp_size)
|
||||
freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
|
||||
|
||||
x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
|
||||
return x_out.to(x.dtype)
|
||||
|
||||
def usp_dit_forward(self,
|
||||
x: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
clip_feature: Optional[torch.Tensor] = None,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
use_gradient_checkpointing: bool = False,
|
||||
use_gradient_checkpointing_offload: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
t = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, timestep))
|
||||
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
|
||||
context = self.text_embedding(context)
|
||||
|
||||
if self.has_image_input:
|
||||
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
|
||||
clip_embdding = self.img_emb(clip_feature)
|
||||
context = torch.cat([clip_embdding, context], dim=1)
|
||||
|
||||
x, (f, h, w) = self.patchify(x)
|
||||
|
||||
freqs = torch.cat([
|
||||
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
# Context Parallel
|
||||
x = torch.chunk(
|
||||
x, get_sequence_parallel_world_size(),
|
||||
dim=1)[get_sequence_parallel_rank()]
|
||||
|
||||
for block in self.blocks:
|
||||
if self.training and use_gradient_checkpointing:
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = block(x, context, t_mod, freqs)
|
||||
|
||||
x = self.head(x, t)
|
||||
|
||||
# Context Parallel
|
||||
x = get_sp_group().all_gather(x, dim=1)
|
||||
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, (f, h, w))
|
||||
return x
|
||||
|
||||
|
||||
def usp_attn_forward(self, x, freqs):
|
||||
q = self.norm_q(self.q(x))
|
||||
k = self.norm_k(self.k(x))
|
||||
v = self.v(x)
|
||||
|
||||
q = rope_apply(q, freqs, self.num_heads)
|
||||
k = rope_apply(k, freqs, self.num_heads)
|
||||
q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
|
||||
k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
|
||||
v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
|
||||
|
||||
x = xFuserLongContextAttention()(
|
||||
None,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
)
|
||||
x = x.flatten(2)
|
||||
|
||||
del q, k, v
|
||||
torch.cuda.empty_cache()
|
||||
return self.o(x)
|
||||
@@ -41,7 +41,7 @@ class RRDB(torch.nn.Module):
|
||||
|
||||
class RRDBNet(torch.nn.Module):
|
||||
|
||||
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32):
|
||||
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, **kwargs):
|
||||
super(RRDBNet, self).__init__()
|
||||
self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
|
||||
self.body = torch.torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
|
||||
@@ -65,6 +65,21 @@ class RRDBNet(torch.nn.Module):
|
||||
feat = self.lrelu(self.conv_up2(feat))
|
||||
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return RRDBNetStateDictConverter()
|
||||
|
||||
|
||||
class RRDBNetStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict, {"upcast_to_float32": True}
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict, {"upcast_to_float32": True}
|
||||
|
||||
|
||||
class ESRGAN(torch.nn.Module):
|
||||
@@ -73,12 +88,8 @@ class ESRGAN(torch.nn.Module):
|
||||
self.model = model
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(model_path):
|
||||
model = RRDBNet()
|
||||
state_dict = torch.load(model_path, map_location="cpu")["params_ema"]
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
return ESRGAN(model)
|
||||
def from_model_manager(model_manager):
|
||||
return ESRGAN(model_manager.fetch_model("esrgan"))
|
||||
|
||||
def process_image(self, image):
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)
|
||||
@@ -96,6 +107,12 @@ class ESRGAN(torch.nn.Module):
|
||||
|
||||
@torch.no_grad()
|
||||
def upscale(self, images, batch_size=4, progress_bar=lambda x:x):
|
||||
if not isinstance(images, list):
|
||||
images = [images]
|
||||
is_single_image = True
|
||||
else:
|
||||
is_single_image = False
|
||||
|
||||
# Preprocess
|
||||
input_tensor = self.process_images(images)
|
||||
|
||||
@@ -115,4 +132,6 @@ class ESRGAN(torch.nn.Module):
|
||||
|
||||
# To images
|
||||
output_images = self.decode_images(output_tensor)
|
||||
if is_single_image:
|
||||
output_images = output_images[0]
|
||||
return output_images
|
||||
|
||||
1
diffsynth/extensions/ImageQualityMetric/BLIP/__init__.py
Normal file
1
diffsynth/extensions/ImageQualityMetric/BLIP/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .blip_pretrain import *
|
||||
77
diffsynth/extensions/ImageQualityMetric/BLIP/blip.py
Normal file
77
diffsynth/extensions/ImageQualityMetric/BLIP/blip.py
Normal file
@@ -0,0 +1,77 @@
|
||||
'''
|
||||
* Adapted from BLIP (https://github.com/salesforce/BLIP)
|
||||
'''
|
||||
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
import torch
|
||||
import os
|
||||
from urllib.parse import urlparse
|
||||
from timm.models.hub import download_cached_file
|
||||
from transformers import BertTokenizer
|
||||
from .vit import VisionTransformer, interpolate_pos_embed
|
||||
|
||||
|
||||
def default_bert():
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
project_root = os.path.abspath(os.path.join(current_dir, '../../../../'))
|
||||
model_path = os.path.join(project_root, 'models', 'QualityMetric')
|
||||
return os.path.join(model_path, "bert-base-uncased")
|
||||
|
||||
|
||||
def init_tokenizer(bert_model_path):
|
||||
tokenizer = BertTokenizer.from_pretrained(bert_model_path)
|
||||
tokenizer.add_special_tokens({'bos_token':'[DEC]'})
|
||||
tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
|
||||
tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
|
||||
return tokenizer
|
||||
|
||||
|
||||
def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
|
||||
|
||||
assert vit in ['base', 'large'], "vit parameter must be base or large"
|
||||
if vit=='base':
|
||||
vision_width = 768
|
||||
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
|
||||
num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
|
||||
drop_path_rate=0 or drop_path_rate
|
||||
)
|
||||
elif vit=='large':
|
||||
vision_width = 1024
|
||||
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
|
||||
num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
|
||||
drop_path_rate=0.1 or drop_path_rate
|
||||
)
|
||||
return visual_encoder, vision_width
|
||||
|
||||
|
||||
def is_url(url_or_filename):
|
||||
parsed = urlparse(url_or_filename)
|
||||
return parsed.scheme in ("http", "https")
|
||||
|
||||
def load_checkpoint(model,url_or_filename):
|
||||
if is_url(url_or_filename):
|
||||
cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
|
||||
checkpoint = torch.load(cached_file, map_location='cpu')
|
||||
elif os.path.isfile(url_or_filename):
|
||||
checkpoint = torch.load(url_or_filename, map_location='cpu')
|
||||
else:
|
||||
raise RuntimeError('checkpoint url or path is invalid')
|
||||
|
||||
state_dict = checkpoint['model']
|
||||
|
||||
state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
|
||||
if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
|
||||
state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
|
||||
model.visual_encoder_m)
|
||||
for key in model.state_dict().keys():
|
||||
if key in state_dict.keys():
|
||||
if state_dict[key].shape!=model.state_dict()[key].shape:
|
||||
print(key, ": ", state_dict[key].shape, ', ', model.state_dict()[key].shape)
|
||||
del state_dict[key]
|
||||
|
||||
msg = model.load_state_dict(state_dict,strict=False)
|
||||
print('load checkpoint from %s'%url_or_filename)
|
||||
return model,msg
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
'''
|
||||
* Adapted from BLIP (https://github.com/salesforce/BLIP)
|
||||
'''
|
||||
|
||||
import transformers
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
from torch import nn
|
||||
import os
|
||||
from .med import BertConfig, BertModel
|
||||
from .blip import create_vit, init_tokenizer
|
||||
|
||||
class BLIP_Pretrain(nn.Module):
|
||||
def __init__(self,
|
||||
med_config = "med_config.json",
|
||||
image_size = 224,
|
||||
vit = 'base',
|
||||
vit_grad_ckpt = False,
|
||||
vit_ckpt_layer = 0,
|
||||
embed_dim = 256,
|
||||
queue_size = 57600,
|
||||
momentum = 0.995,
|
||||
bert_model_path = ""
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
||||
image_size (int): input image size
|
||||
vit (str): model size of vision transformer
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
|
||||
|
||||
self.tokenizer = init_tokenizer(bert_model_path)
|
||||
encoder_config = BertConfig.from_json_file(med_config)
|
||||
encoder_config.encoder_width = vision_width
|
||||
self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
|
||||
|
||||
text_width = self.text_encoder.config.hidden_size
|
||||
|
||||
self.vision_proj = nn.Linear(vision_width, embed_dim)
|
||||
self.text_proj = nn.Linear(text_width, embed_dim)
|
||||
|
||||
947
diffsynth/extensions/ImageQualityMetric/BLIP/med.py
Normal file
947
diffsynth/extensions/ImageQualityMetric/BLIP/med.py
Normal file
@@ -0,0 +1,947 @@
|
||||
'''
|
||||
* Adapted from BLIP (https://github.com/salesforce/BLIP)
|
||||
* Based on huggingface code base
|
||||
* https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
|
||||
'''
|
||||
|
||||
import math
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, device, nn
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.file_utils import (
|
||||
ModelOutput,
|
||||
)
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
CausalLMOutputWithCrossAttentions,
|
||||
MaskedLMOutput,
|
||||
MultipleChoiceModelOutput,
|
||||
NextSentencePredictorOutput,
|
||||
QuestionAnsweringModelOutput,
|
||||
SequenceClassifierOutput,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from transformers.modeling_utils import (
|
||||
PreTrainedModel,
|
||||
apply_chunking_to_forward,
|
||||
find_pruneable_heads_and_indices,
|
||||
prune_linear_layer,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
from transformers.models.bert.configuration_bert import BertConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class BertEmbeddings(nn.Module):
|
||||
"""Construct the embeddings from word and position embeddings."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
||||
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
||||
|
||||
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
||||
# any TensorFlow checkpoint file
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||
|
||||
self.config = config
|
||||
|
||||
def forward(
|
||||
self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
||||
):
|
||||
if input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
else:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
|
||||
seq_length = input_shape[1]
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
|
||||
embeddings = inputs_embeds
|
||||
|
||||
if self.position_embedding_type == "absolute":
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
embeddings += position_embeddings
|
||||
embeddings = self.LayerNorm(embeddings)
|
||||
embeddings = self.dropout(embeddings)
|
||||
return embeddings
|
||||
|
||||
|
||||
class BertSelfAttention(nn.Module):
|
||||
def __init__(self, config, is_cross_attention):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
"The hidden size (%d) is not a multiple of the number of attention "
|
||||
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
||||
)
|
||||
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
if is_cross_attention:
|
||||
self.key = nn.Linear(config.encoder_width, self.all_head_size)
|
||||
self.value = nn.Linear(config.encoder_width, self.all_head_size)
|
||||
else:
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
self.save_attention = False
|
||||
|
||||
def save_attn_gradients(self, attn_gradients):
|
||||
self.attn_gradients = attn_gradients
|
||||
|
||||
def get_attn_gradients(self):
|
||||
return self.attn_gradients
|
||||
|
||||
def save_attention_map(self, attention_map):
|
||||
self.attention_map = attention_map
|
||||
|
||||
def get_attention_map(self):
|
||||
return self.attention_map
|
||||
|
||||
def transpose_for_scores(self, x):
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(*new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
past_key_value=None,
|
||||
output_attentions=False,
|
||||
):
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
# If this is instantiated as a cross-attention module, the keys
|
||||
# and values come from an encoder; the attention mask needs to be
|
||||
# such that the encoder's padding tokens are not attended to.
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
|
||||
if is_cross_attention:
|
||||
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||
attention_mask = encoder_attention_mask
|
||||
elif past_key_value is not None:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||
else:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
past_key_value = (key_layer, value_layer)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
seq_length = hidden_states.size()[1]
|
||||
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
||||
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
||||
distance = position_ids_l - position_ids_r
|
||||
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
||||
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
||||
|
||||
if self.position_embedding_type == "relative_key":
|
||||
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores
|
||||
elif self.position_embedding_type == "relative_key_query":
|
||||
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
||||
attention_scores = attention_scores + attention_mask
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
||||
|
||||
if is_cross_attention and self.save_attention:
|
||||
self.save_attention_map(attention_probs)
|
||||
attention_probs.register_hook(self.save_attn_gradients)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs_dropped = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs_dropped = attention_probs_dropped * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs_dropped, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
outputs = outputs + (past_key_value,)
|
||||
return outputs
|
||||
|
||||
|
||||
class BertSelfOutput(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states, input_tensor):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertAttention(nn.Module):
|
||||
def __init__(self, config, is_cross_attention=False):
|
||||
super().__init__()
|
||||
self.self = BertSelfAttention(config, is_cross_attention)
|
||||
self.output = BertSelfOutput(config)
|
||||
self.pruned_heads = set()
|
||||
|
||||
def prune_heads(self, heads):
|
||||
if len(heads) == 0:
|
||||
return
|
||||
heads, index = find_pruneable_heads_and_indices(
|
||||
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
||||
)
|
||||
|
||||
# Prune linear layers
|
||||
self.self.query = prune_linear_layer(self.self.query, index)
|
||||
self.self.key = prune_linear_layer(self.self.key, index)
|
||||
self.self.value = prune_linear_layer(self.self.value, index)
|
||||
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
||||
|
||||
# Update hyper params and store pruned heads
|
||||
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
past_key_value=None,
|
||||
output_attentions=False,
|
||||
):
|
||||
self_outputs = self.self(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
return outputs
|
||||
|
||||
|
||||
class BertIntermediate(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
if isinstance(config.hidden_act, str):
|
||||
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
||||
else:
|
||||
self.intermediate_act_fn = config.hidden_act
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertOutput(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states, input_tensor):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertLayer(nn.Module):
|
||||
def __init__(self, config, layer_num):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = BertAttention(config)
|
||||
self.layer_num = layer_num
|
||||
if self.config.add_cross_attention:
|
||||
self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
|
||||
self.intermediate = BertIntermediate(config)
|
||||
self.output = BertOutput(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
past_key_value=None,
|
||||
output_attentions=False,
|
||||
mode=None,
|
||||
):
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
output_attentions=output_attentions,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
|
||||
outputs = self_attention_outputs[1:-1]
|
||||
present_key_value = self_attention_outputs[-1]
|
||||
|
||||
if mode=='multimodal':
|
||||
assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
|
||||
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||
)
|
||||
outputs = (layer_output,) + outputs
|
||||
|
||||
outputs = outputs + (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.output(intermediate_output, attention_output)
|
||||
return layer_output
|
||||
|
||||
|
||||
class BertEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
past_key_values=None,
|
||||
use_cache=None,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
mode='multimodal',
|
||||
):
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
for i in range(self.config.num_hidden_layers):
|
||||
layer_module = self.layer[i]
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
if use_cache:
|
||||
logger.warn(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs, past_key_value, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(layer_module),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
mode=mode,
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
mode=mode,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[-1],)
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_decoder_cache,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
class BertPooler(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.activation = nn.Tanh()
|
||||
|
||||
def forward(self, hidden_states):
|
||||
# We "pool" the model by simply taking the hidden state corresponding
|
||||
# to the first token.
|
||||
first_token_tensor = hidden_states[:, 0]
|
||||
pooled_output = self.dense(first_token_tensor)
|
||||
pooled_output = self.activation(pooled_output)
|
||||
return pooled_output
|
||||
|
||||
|
||||
class BertPredictionHeadTransform(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
if isinstance(config.hidden_act, str):
|
||||
self.transform_act_fn = ACT2FN[config.hidden_act]
|
||||
else:
|
||||
self.transform_act_fn = config.hidden_act
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.transform_act_fn(hidden_states)
|
||||
hidden_states = self.LayerNorm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertLMPredictionHead(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.transform = BertPredictionHeadTransform(config)
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.transform(hidden_states)
|
||||
hidden_states = self.decoder(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertOnlyMLMHead(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.predictions = BertLMPredictionHead(config)
|
||||
|
||||
def forward(self, sequence_output):
|
||||
prediction_scores = self.predictions(sequence_output)
|
||||
return prediction_scores
|
||||
|
||||
|
||||
class BertPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = BertConfig
|
||||
base_model_prefix = "bert"
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
""" Initialize the weights """
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
class BertModel(BertPreTrainedModel):
|
||||
"""
|
||||
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
||||
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
||||
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
||||
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
||||
argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
|
||||
input to the forward pass.
|
||||
"""
|
||||
|
||||
def __init__(self, config, add_pooling_layer=True):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
self.embeddings = BertEmbeddings(config)
|
||||
|
||||
self.encoder = BertEncoder(config)
|
||||
|
||||
self.pooler = BertPooler(config) if add_pooling_layer else None
|
||||
|
||||
self.init_weights()
|
||||
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings.word_embeddings
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embeddings.word_embeddings = value
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
"""
|
||||
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
||||
class PreTrainedModel
|
||||
"""
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
|
||||
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
|
||||
"""
|
||||
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
||||
|
||||
Arguments:
|
||||
attention_mask (:obj:`torch.Tensor`):
|
||||
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
||||
input_shape (:obj:`Tuple[int]`):
|
||||
The shape of the input to the model.
|
||||
device: (:obj:`torch.device`):
|
||||
The device of the input to the model.
|
||||
|
||||
Returns:
|
||||
:obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
|
||||
"""
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
if attention_mask.dim() == 3:
|
||||
extended_attention_mask = attention_mask[:, None, :, :]
|
||||
elif attention_mask.dim() == 2:
|
||||
# Provided a padding mask of dimensions [batch_size, seq_length]
|
||||
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
||||
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if is_decoder:
|
||||
batch_size, seq_length = input_shape
|
||||
|
||||
seq_ids = torch.arange(seq_length, device=device)
|
||||
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
||||
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
|
||||
# causal and attention masks must have same type with pytorch version < 1.3
|
||||
causal_mask = causal_mask.to(attention_mask.dtype)
|
||||
|
||||
if causal_mask.shape[1] < attention_mask.shape[1]:
|
||||
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
|
||||
causal_mask = torch.cat(
|
||||
[
|
||||
torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
|
||||
causal_mask,
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
||||
else:
|
||||
extended_attention_mask = attention_mask[:, None, None, :]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
||||
input_shape, attention_mask.shape
|
||||
)
|
||||
)
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
return extended_attention_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
encoder_embeds=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
past_key_values=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
is_decoder=False,
|
||||
mode='multimodal',
|
||||
):
|
||||
r"""
|
||||
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||
the model is configured as a decoder.
|
||||
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||
use_cache (:obj:`bool`, `optional`):
|
||||
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||
decoding (see :obj:`past_key_values`).
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if is_decoder:
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
else:
|
||||
use_cache = False
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
batch_size, seq_length = input_shape
|
||||
device = input_ids.device
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
batch_size, seq_length = input_shape
|
||||
device = inputs_embeds.device
|
||||
elif encoder_embeds is not None:
|
||||
input_shape = encoder_embeds.size()[:-1]
|
||||
batch_size, seq_length = input_shape
|
||||
device = encoder_embeds.device
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
|
||||
|
||||
# past_key_values_length
|
||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
|
||||
device, is_decoder)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if encoder_hidden_states is not None:
|
||||
if type(encoder_hidden_states) == list:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
|
||||
else:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
|
||||
if type(encoder_attention_mask) == list:
|
||||
encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
|
||||
elif encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_extended_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
if encoder_embeds is None:
|
||||
embedding_output = self.embeddings(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
else:
|
||||
embedding_output = encoder_embeds
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
mode=mode,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
past_key_values=encoder_outputs.past_key_values,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
|
||||
class BertLMHeadModel(BertPreTrainedModel):
|
||||
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.bert = BertModel(config, add_pooling_layer=False)
|
||||
self.cls = BertOnlyMLMHead(config)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.cls.predictions.decoder
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.cls.predictions.decoder = new_embeddings
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
labels=None,
|
||||
past_key_values=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
return_logits=False,
|
||||
is_decoder=True,
|
||||
reduction='mean',
|
||||
mode='multimodal',
|
||||
):
|
||||
r"""
|
||||
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||
the model is configured as a decoder.
|
||||
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
||||
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
|
||||
ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
|
||||
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||
use_cache (:obj:`bool`, `optional`):
|
||||
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||
decoding (see :obj:`past_key_values`).
|
||||
Returns:
|
||||
Example::
|
||||
>>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
|
||||
>>> import torch
|
||||
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
||||
>>> config = BertConfig.from_pretrained("bert-base-cased")
|
||||
>>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
|
||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||
>>> outputs = model(**inputs)
|
||||
>>> prediction_logits = outputs.logits
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
if labels is not None:
|
||||
use_cache = False
|
||||
|
||||
outputs = self.bert(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
is_decoder=is_decoder,
|
||||
mode=mode,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.cls(sequence_output)
|
||||
|
||||
if return_logits:
|
||||
return prediction_scores[:, :-1, :].contiguous()
|
||||
|
||||
lm_loss = None
|
||||
if labels is not None:
|
||||
# we are doing next-token prediction; shift prediction scores and input ids by one
|
||||
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
||||
labels = labels[:, 1:].contiguous()
|
||||
loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
|
||||
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
||||
if reduction=='none':
|
||||
lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
|
||||
|
||||
if not return_dict:
|
||||
output = (prediction_scores,) + outputs[2:]
|
||||
return ((lm_loss,) + output) if lm_loss is not None else output
|
||||
|
||||
return CausalLMOutputWithCrossAttentions(
|
||||
loss=lm_loss,
|
||||
logits=prediction_scores,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
|
||||
input_shape = input_ids.shape
|
||||
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.new_ones(input_shape)
|
||||
|
||||
# cut decoder_input_ids if past is used
|
||||
if past is not None:
|
||||
input_ids = input_ids[:, -1:]
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"past_key_values": past,
|
||||
"encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
|
||||
"encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
|
||||
"is_decoder": True,
|
||||
}
|
||||
|
||||
def _reorder_cache(self, past, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past:
|
||||
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
||||
return reordered_past
|
||||
301
diffsynth/extensions/ImageQualityMetric/BLIP/vit.py
Normal file
301
diffsynth/extensions/ImageQualityMetric/BLIP/vit.py
Normal file
@@ -0,0 +1,301 @@
|
||||
'''
|
||||
* Adapted from BLIP (https://github.com/salesforce/BLIP)
|
||||
* Based on timm code base
|
||||
* https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
||||
'''
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from functools import partial
|
||||
|
||||
from timm.models.vision_transformer import _cfg, PatchEmbed
|
||||
from timm.models.registry import register_model
|
||||
from timm.models.layers import trunc_normal_, DropPath
|
||||
from timm.models.helpers import named_apply, adapt_input_conv
|
||||
|
||||
# from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
|
||||
|
||||
class Mlp(nn.Module):
|
||||
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
|
||||
"""
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
self.attn_gradients = None
|
||||
self.attention_map = None
|
||||
|
||||
def save_attn_gradients(self, attn_gradients):
|
||||
self.attn_gradients = attn_gradients
|
||||
|
||||
def get_attn_gradients(self):
|
||||
return self.attn_gradients
|
||||
|
||||
def save_attention_map(self, attention_map):
|
||||
self.attention_map = attention_map
|
||||
|
||||
def get_attention_map(self):
|
||||
return self.attention_map
|
||||
|
||||
def forward(self, x, register_hook=False):
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
if register_hook:
|
||||
self.save_attention_map(attn)
|
||||
attn.register_hook(self.save_attn_gradients)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
|
||||
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
||||
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = Attention(
|
||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
||||
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||
|
||||
# if use_grad_checkpointing:
|
||||
# self.attn = checkpoint_wrapper(self.attn)
|
||||
# self.mlp = checkpoint_wrapper(self.mlp)
|
||||
|
||||
def forward(self, x, register_hook=False):
|
||||
x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
|
||||
""" Vision Transformer
|
||||
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
|
||||
https://arxiv.org/abs/2010.11929
|
||||
"""
|
||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
||||
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
|
||||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
|
||||
use_grad_checkpointing=False, ckpt_layer=0):
|
||||
"""
|
||||
Args:
|
||||
img_size (int, tuple): input image size
|
||||
patch_size (int, tuple): patch size
|
||||
in_chans (int): number of input channels
|
||||
num_classes (int): number of classes for classification head
|
||||
embed_dim (int): embedding dimension
|
||||
depth (int): depth of transformer
|
||||
num_heads (int): number of attention heads
|
||||
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
||||
qkv_bias (bool): enable bias for qkv if True
|
||||
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
|
||||
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
|
||||
drop_rate (float): dropout rate
|
||||
attn_drop_rate (float): attention dropout rate
|
||||
drop_path_rate (float): stochastic depth rate
|
||||
norm_layer: (nn.Module): normalization layer
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
||||
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
||||
|
||||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||
self.blocks = nn.ModuleList([
|
||||
Block(
|
||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
||||
use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
|
||||
)
|
||||
for i in range(depth)])
|
||||
self.norm = norm_layer(embed_dim)
|
||||
|
||||
trunc_normal_(self.pos_embed, std=.02)
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay(self):
|
||||
return {'pos_embed', 'cls_token'}
|
||||
|
||||
def forward(self, x, register_blk=-1):
|
||||
B = x.shape[0]
|
||||
x = self.patch_embed(x)
|
||||
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
|
||||
x = x + self.pos_embed[:,:x.size(1),:]
|
||||
x = self.pos_drop(x)
|
||||
|
||||
for i,blk in enumerate(self.blocks):
|
||||
x = blk(x, register_blk==i)
|
||||
x = self.norm(x)
|
||||
|
||||
return x
|
||||
|
||||
@torch.jit.ignore()
|
||||
def load_pretrained(self, checkpoint_path, prefix=''):
|
||||
_load_weights(self, checkpoint_path, prefix)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
|
||||
""" Load weights from .npz checkpoints for official Google Brain Flax implementation
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
def _n2p(w, t=True):
|
||||
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
|
||||
w = w.flatten()
|
||||
if t:
|
||||
if w.ndim == 4:
|
||||
w = w.transpose([3, 2, 0, 1])
|
||||
elif w.ndim == 3:
|
||||
w = w.transpose([2, 0, 1])
|
||||
elif w.ndim == 2:
|
||||
w = w.transpose([1, 0])
|
||||
return torch.from_numpy(w)
|
||||
|
||||
w = np.load(checkpoint_path)
|
||||
if not prefix and 'opt/target/embedding/kernel' in w:
|
||||
prefix = 'opt/target/'
|
||||
|
||||
if hasattr(model.patch_embed, 'backbone'):
|
||||
# hybrid
|
||||
backbone = model.patch_embed.backbone
|
||||
stem_only = not hasattr(backbone, 'stem')
|
||||
stem = backbone if stem_only else backbone.stem
|
||||
stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
|
||||
stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
|
||||
stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
|
||||
if not stem_only:
|
||||
for i, stage in enumerate(backbone.stages):
|
||||
for j, block in enumerate(stage.blocks):
|
||||
bp = f'{prefix}block{i + 1}/unit{j + 1}/'
|
||||
for r in range(3):
|
||||
getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
|
||||
getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
|
||||
getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
|
||||
if block.downsample is not None:
|
||||
block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
|
||||
block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
|
||||
block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
|
||||
embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
|
||||
else:
|
||||
embed_conv_w = adapt_input_conv(
|
||||
model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
|
||||
model.patch_embed.proj.weight.copy_(embed_conv_w)
|
||||
model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
|
||||
model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
|
||||
pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
|
||||
if pos_embed_w.shape != model.pos_embed.shape:
|
||||
pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
|
||||
pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
|
||||
model.pos_embed.copy_(pos_embed_w)
|
||||
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
|
||||
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
|
||||
# if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
|
||||
# model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
|
||||
# model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
|
||||
# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
|
||||
# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
|
||||
# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
|
||||
for i, block in enumerate(model.blocks.children()):
|
||||
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
|
||||
mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
|
||||
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
|
||||
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
|
||||
block.attn.qkv.weight.copy_(torch.cat([
|
||||
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
|
||||
block.attn.qkv.bias.copy_(torch.cat([
|
||||
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
|
||||
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
|
||||
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
|
||||
for r in range(2):
|
||||
getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
|
||||
getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
|
||||
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
|
||||
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
|
||||
|
||||
|
||||
def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
|
||||
# interpolate position embedding
|
||||
embedding_size = pos_embed_checkpoint.shape[-1]
|
||||
num_patches = visual_encoder.patch_embed.num_patches
|
||||
num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
|
||||
# height (== width) for the checkpoint position embedding
|
||||
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
||||
# height (== width) for the new position embedding
|
||||
new_size = int(num_patches ** 0.5)
|
||||
|
||||
if orig_size!=new_size:
|
||||
# class_token and dist_token are kept unchanged
|
||||
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
||||
# only the position tokens are interpolated
|
||||
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
||||
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
||||
pos_tokens = torch.nn.functional.interpolate(
|
||||
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
||||
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
||||
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
||||
print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
|
||||
|
||||
return new_pos_embed
|
||||
else:
|
||||
return pos_embed_checkpoint
|
||||
148
diffsynth/extensions/ImageQualityMetric/__init__.py
Normal file
148
diffsynth/extensions/ImageQualityMetric/__init__.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from modelscope import snapshot_download
|
||||
from typing_extensions import Literal, TypeAlias
|
||||
import os
|
||||
from diffsynth.extensions.ImageQualityMetric.aesthetic import AestheticScore
|
||||
from diffsynth.extensions.ImageQualityMetric.imagereward import ImageRewardScore
|
||||
from diffsynth.extensions.ImageQualityMetric.pickscore import PickScore
|
||||
from diffsynth.extensions.ImageQualityMetric.clip import CLIPScore
|
||||
from diffsynth.extensions.ImageQualityMetric.hps import HPScore_v2
|
||||
from diffsynth.extensions.ImageQualityMetric.mps import MPScore
|
||||
|
||||
|
||||
preference_model_id: TypeAlias = Literal[
|
||||
"ImageReward",
|
||||
"Aesthetic",
|
||||
"PickScore",
|
||||
"CLIP",
|
||||
"HPSv2",
|
||||
"HPSv2.1",
|
||||
"MPS",
|
||||
]
|
||||
model_dict = {
|
||||
"ImageReward": {
|
||||
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||
"allow_file_pattern": [
|
||||
"ImageReward/ImageReward.safetensors",
|
||||
"ImageReward/med_config.json",
|
||||
"bert-base-uncased/config.json",
|
||||
"bert-base-uncased/model.safetensors",
|
||||
"bert-base-uncased/tokenizer.json",
|
||||
"bert-base-uncased/tokenizer_config.json",
|
||||
"bert-base-uncased/vocab.txt",
|
||||
],
|
||||
"load_path": {
|
||||
"imagereward": "ImageReward/ImageReward.safetensors",
|
||||
"med_config": "ImageReward/med_config.json",
|
||||
"bert_model_path": "bert-base-uncased",
|
||||
},
|
||||
"model_class": ImageRewardScore
|
||||
},
|
||||
"Aesthetic": {
|
||||
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||
"allow_file_pattern": [
|
||||
"aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors",
|
||||
"clip-vit-large-patch14/config.json",
|
||||
"clip-vit-large-patch14/merges.txt",
|
||||
"clip-vit-large-patch14/model.safetensors",
|
||||
"clip-vit-large-patch14/preprocessor_config.json",
|
||||
"clip-vit-large-patch14/special_tokens_map.json",
|
||||
"clip-vit-large-patch14/tokenizer.json",
|
||||
"clip-vit-large-patch14/tokenizer_config.json",
|
||||
"clip-vit-large-patch14/vocab.json",
|
||||
],
|
||||
"load_path": {
|
||||
"aesthetic_predictor": "aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors",
|
||||
"clip-large": "clip-vit-large-patch14",
|
||||
},
|
||||
"model_class": AestheticScore
|
||||
},
|
||||
"PickScore": {
|
||||
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||
"allow_file_pattern": [
|
||||
"PickScore_v1/*",
|
||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/config.json",
|
||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/merges.txt",
|
||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/preprocessor_config.json",
|
||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/special_tokens_map.json",
|
||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer.json",
|
||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer_config.json",
|
||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/vocab.json",
|
||||
],
|
||||
"load_path": {
|
||||
"pickscore": "PickScore_v1",
|
||||
"clip": "CLIP-ViT-H-14-laion2B-s32B-b79K",
|
||||
},
|
||||
"model_class": PickScore
|
||||
},
|
||||
"CLIP": {
|
||||
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||
"allow_file_pattern": [
|
||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin",
|
||||
"bpe_simple_vocab_16e6.txt.gz",
|
||||
],
|
||||
"load_path": {
|
||||
"open_clip": "CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin",
|
||||
"open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
|
||||
},
|
||||
"model_class": CLIPScore
|
||||
},
|
||||
"HPSv2": {
|
||||
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||
"allow_file_pattern": [
|
||||
"HPS_v2/HPS_v2_compressed.safetensors",
|
||||
"bpe_simple_vocab_16e6.txt.gz",
|
||||
],
|
||||
"load_path": {
|
||||
"hpsv2": "HPS_v2/HPS_v2_compressed.safetensors",
|
||||
"open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
|
||||
},
|
||||
"model_class": HPScore_v2,
|
||||
"extra_kwargs": {"model_version": "v2"}
|
||||
},
|
||||
"HPSv2.1": {
|
||||
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||
"allow_file_pattern": [
|
||||
"HPS_v2/HPS_v2.1_compressed.safetensors",
|
||||
"bpe_simple_vocab_16e6.txt.gz",
|
||||
],
|
||||
"load_path": {
|
||||
"hpsv2.1": "HPS_v2/HPS_v2.1_compressed.safetensors",
|
||||
"open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
|
||||
},
|
||||
"model_class": HPScore_v2,
|
||||
"extra_kwargs": {"model_version": "v21"}
|
||||
},
|
||||
"MPS": {
|
||||
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||
"allow_file_pattern": [
|
||||
"MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors",
|
||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/config.json",
|
||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/merges.txt",
|
||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/preprocessor_config.json",
|
||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/special_tokens_map.json",
|
||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer.json",
|
||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer_config.json",
|
||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/vocab.json",
|
||||
],
|
||||
"load_path": {
|
||||
"mps": "MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors",
|
||||
"clip": "CLIP-ViT-H-14-laion2B-s32B-b79K",
|
||||
},
|
||||
"model_class": MPScore
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def download_preference_model(model_name: preference_model_id, cache_dir="models"):
|
||||
metadata = model_dict[model_name]
|
||||
snapshot_download(model_id=metadata["model_id"], allow_file_pattern=metadata["allow_file_pattern"], cache_dir=cache_dir)
|
||||
load_path = metadata["load_path"]
|
||||
load_path = {key: os.path.join(cache_dir, metadata["model_id"], path) for key, path in load_path.items()}
|
||||
return load_path
|
||||
|
||||
|
||||
def load_preference_model(model_name: preference_model_id, device = "cuda", path = None):
|
||||
model_class = model_dict[model_name]["model_class"]
|
||||
extra_kwargs = model_dict[model_name].get("extra_kwargs", {})
|
||||
preference_model = model_class(device=device, path=path, **extra_kwargs)
|
||||
return preference_model
|
||||
148
diffsynth/extensions/ImageQualityMetric/aesthetic.py
Normal file
148
diffsynth/extensions/ImageQualityMetric/aesthetic.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from typing import List, Optional
|
||||
from PIL import Image
|
||||
import torch
|
||||
from transformers import AutoProcessor, AutoModel
|
||||
from safetensors.torch import load_file
|
||||
import os
|
||||
from typing import Union, List
|
||||
from .config import MODEL_PATHS
|
||||
|
||||
class MLP(torch.nn.Module):
|
||||
def __init__(self, input_size: int, xcol: str = "emb", ycol: str = "avg_rating"):
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
self.xcol = xcol
|
||||
self.ycol = ycol
|
||||
self.layers = torch.nn.Sequential(
|
||||
torch.nn.Linear(self.input_size, 1024),
|
||||
#torch.nn.ReLU(),
|
||||
torch.nn.Dropout(0.2),
|
||||
torch.nn.Linear(1024, 128),
|
||||
#torch.nn.ReLU(),
|
||||
torch.nn.Dropout(0.2),
|
||||
torch.nn.Linear(128, 64),
|
||||
#torch.nn.ReLU(),
|
||||
torch.nn.Dropout(0.1),
|
||||
torch.nn.Linear(64, 16),
|
||||
#torch.nn.ReLU(),
|
||||
torch.nn.Linear(16, 1),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.layers(x)
|
||||
|
||||
def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
|
||||
x = batch[self.xcol]
|
||||
y = batch[self.ycol].reshape(-1, 1)
|
||||
x_hat = self.layers(x)
|
||||
loss = torch.nn.functional.mse_loss(x_hat, y)
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
|
||||
x = batch[self.xcol]
|
||||
y = batch[self.ycol].reshape(-1, 1)
|
||||
x_hat = self.layers(x)
|
||||
loss = torch.nn.functional.mse_loss(x_hat, y)
|
||||
return loss
|
||||
|
||||
def configure_optimizers(self) -> torch.optim.Optimizer:
|
||||
return torch.optim.Adam(self.parameters(), lr=1e-3)
|
||||
|
||||
|
||||
class AestheticScore(torch.nn.Module):
|
||||
def __init__(self, device: torch.device, path: str = MODEL_PATHS):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.aes_model_path = path.get("aesthetic_predictor")
|
||||
# Load the MLP model
|
||||
self.model = MLP(768)
|
||||
try:
|
||||
if self.aes_model_path.endswith(".safetensors"):
|
||||
state_dict = load_file(self.aes_model_path)
|
||||
else:
|
||||
state_dict = torch.load(self.aes_model_path)
|
||||
self.model.load_state_dict(state_dict)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error loading model weights from {self.aes_model_path}: {e}")
|
||||
|
||||
self.model.to(device)
|
||||
self.model.eval()
|
||||
|
||||
# Load the CLIP model and processor
|
||||
clip_model_name = path.get('clip-large')
|
||||
self.model2 = AutoModel.from_pretrained(clip_model_name).eval().to(device)
|
||||
self.processor = AutoProcessor.from_pretrained(clip_model_name)
|
||||
|
||||
def _calculate_score(self, image: torch.Tensor) -> float:
|
||||
"""Calculate the aesthetic score for a single image.
|
||||
|
||||
Args:
|
||||
image (torch.Tensor): The processed image tensor.
|
||||
|
||||
Returns:
|
||||
float: The aesthetic score.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
# Get image embeddings
|
||||
image_embs = self.model2.get_image_features(image)
|
||||
image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
|
||||
|
||||
# Compute score
|
||||
score = self.model(image_embs).cpu().flatten().item()
|
||||
|
||||
return score
|
||||
|
||||
@torch.no_grad()
|
||||
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str = "") -> List[float]:
|
||||
"""Score the images based on their aesthetic quality.
|
||||
|
||||
Args:
|
||||
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
||||
|
||||
Returns:
|
||||
List[float]: List of scores for the images.
|
||||
"""
|
||||
try:
|
||||
if isinstance(images, (str, Image.Image)):
|
||||
# Single image
|
||||
if isinstance(images, str):
|
||||
pil_image = Image.open(images)
|
||||
else:
|
||||
pil_image = images
|
||||
|
||||
# Prepare image inputs
|
||||
image_inputs = self.processor(
|
||||
images=pil_image,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=77,
|
||||
return_tensors="pt",
|
||||
).to(self.device)
|
||||
|
||||
return [self._calculate_score(image_inputs["pixel_values"])]
|
||||
elif isinstance(images, list):
|
||||
# Multiple images
|
||||
scores = []
|
||||
for one_image in images:
|
||||
if isinstance(one_image, str):
|
||||
pil_image = Image.open(one_image)
|
||||
elif isinstance(one_image, Image.Image):
|
||||
pil_image = one_image
|
||||
else:
|
||||
raise TypeError("The type of parameter images is illegal.")
|
||||
|
||||
# Prepare image inputs
|
||||
image_inputs = self.processor(
|
||||
images=pil_image,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=77,
|
||||
return_tensors="pt",
|
||||
).to(self.device)
|
||||
|
||||
scores.append(self._calculate_score(image_inputs["pixel_values"]))
|
||||
return scores
|
||||
else:
|
||||
raise TypeError("The type of parameter images is illegal.")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error in scoring images: {e}")
|
||||
97
diffsynth/extensions/ImageQualityMetric/clip.py
Normal file
97
diffsynth/extensions/ImageQualityMetric/clip.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from typing import List, Union
|
||||
from PIL import Image
|
||||
import torch
|
||||
from .open_clip import create_model_and_transforms, get_tokenizer
|
||||
from .config import MODEL_PATHS
|
||||
|
||||
class CLIPScore(torch.nn.Module):
|
||||
def __init__(self, device: torch.device, path: str = MODEL_PATHS):
|
||||
super().__init__()
|
||||
"""Initialize the CLIPScore with a model and tokenizer.
|
||||
|
||||
Args:
|
||||
device (torch.device): The device to load the model on.
|
||||
"""
|
||||
self.device = device
|
||||
|
||||
# Create model and transforms
|
||||
self.model, _, self.preprocess_val = create_model_and_transforms(
|
||||
"ViT-H-14",
|
||||
# "laion2B-s32B-b79K",
|
||||
pretrained=path.get("open_clip"),
|
||||
precision="amp",
|
||||
device=device,
|
||||
jit=False,
|
||||
force_quick_gelu=False,
|
||||
force_custom_text=False,
|
||||
force_patch_dropout=False,
|
||||
force_image_size=None,
|
||||
pretrained_image=False,
|
||||
image_mean=None,
|
||||
image_std=None,
|
||||
light_augmentation=True,
|
||||
aug_cfg={},
|
||||
output_dict=True,
|
||||
with_score_predictor=False,
|
||||
with_region_predictor=False,
|
||||
)
|
||||
|
||||
# Initialize tokenizer
|
||||
self.tokenizer = get_tokenizer("ViT-H-14", path["open_clip_bpe"])
|
||||
self.model = self.model.to(device)
|
||||
self.model.eval()
|
||||
|
||||
def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
|
||||
"""Calculate the CLIP score for a single image and prompt.
|
||||
|
||||
Args:
|
||||
image (torch.Tensor): The processed image tensor.
|
||||
prompt (str): The prompt text.
|
||||
|
||||
Returns:
|
||||
float: The CLIP score.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
# Process the prompt
|
||||
text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
|
||||
|
||||
# Calculate the CLIP score
|
||||
outputs = self.model(image, text)
|
||||
image_features, text_features = outputs["image_features"], outputs["text_features"]
|
||||
logits_per_image = image_features @ text_features.T
|
||||
clip_score = torch.diagonal(logits_per_image).cpu().numpy()
|
||||
|
||||
return clip_score[0].item()
|
||||
|
||||
@torch.no_grad()
|
||||
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
||||
"""Score the images based on the prompt.
|
||||
|
||||
Args:
|
||||
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
||||
prompt (str): The prompt text.
|
||||
|
||||
Returns:
|
||||
List[float]: List of CLIP scores for the images.
|
||||
"""
|
||||
if isinstance(images, (str, Image.Image)):
|
||||
# Single image
|
||||
if isinstance(images, str):
|
||||
image = self.preprocess_val(Image.open(images)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||
else:
|
||||
image = self.preprocess_val(images).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||
return [self._calculate_score(image, prompt)]
|
||||
elif isinstance(images, list):
|
||||
# Multiple images
|
||||
scores = []
|
||||
for one_images in images:
|
||||
if isinstance(one_images, str):
|
||||
image = self.preprocess_val(Image.open(one_images)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||
elif isinstance(one_images, Image.Image):
|
||||
image = self.preprocess_val(one_images).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||
else:
|
||||
raise TypeError("The type of parameter images is illegal.")
|
||||
scores.append(self._calculate_score(image, prompt))
|
||||
return scores
|
||||
else:
|
||||
raise TypeError("The type of parameter images is illegal.")
|
||||
23
diffsynth/extensions/ImageQualityMetric/config.py
Normal file
23
diffsynth/extensions/ImageQualityMetric/config.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import os
|
||||
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
project_root = os.path.abspath(os.path.join(current_dir, '../../../'))
|
||||
model_path = os.path.join(project_root, 'models', 'QualityMetric')
|
||||
|
||||
|
||||
def get_model_path(model_name):
|
||||
return os.path.join(model_path, model_name)
|
||||
|
||||
|
||||
MODEL_PATHS = {
|
||||
"aesthetic_predictor": get_model_path("aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors"),
|
||||
"open_clip": get_model_path("CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin"),
|
||||
"hpsv2": get_model_path("HPS_v2/HPS_v2_compressed.safetensors"),
|
||||
"hpsv2.1": get_model_path("HPS_v2/HPS_v2.1_compressed.safetensors"),
|
||||
"imagereward": get_model_path("ImageReward/ImageReward.safetensors"),
|
||||
"med_config": get_model_path("ImageReward/med_config.json"),
|
||||
"clip": get_model_path("CLIP-ViT-H-14-laion2B-s32B-b79K"),
|
||||
"clip-large": get_model_path("clip-vit-large-patch14"),
|
||||
"mps": get_model_path("MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors"),
|
||||
"pickscore": get_model_path("PickScore_v1")
|
||||
}
|
||||
118
diffsynth/extensions/ImageQualityMetric/hps.py
Normal file
118
diffsynth/extensions/ImageQualityMetric/hps.py
Normal file
@@ -0,0 +1,118 @@
|
||||
from typing import List, Union
|
||||
from PIL import Image
|
||||
import torch
|
||||
from .open_clip import create_model_and_transforms, get_tokenizer
|
||||
from safetensors.torch import load_file
|
||||
import os
|
||||
from .config import MODEL_PATHS
|
||||
|
||||
class HPScore_v2(torch.nn.Module):
|
||||
def __init__(self, device: torch.device, path: str = MODEL_PATHS, model_version: str = "v2"):
|
||||
super().__init__()
|
||||
"""Initialize the Selector with a model and tokenizer.
|
||||
|
||||
Args:
|
||||
device (torch.device): The device to load the model on.
|
||||
model_version (str): The version of the model to load. Supports "v2" or "v21". Default is "v2".
|
||||
"""
|
||||
self.device = device
|
||||
|
||||
if model_version == "v2":
|
||||
safetensors_path = path.get("hpsv2")
|
||||
elif model_version == "v21":
|
||||
safetensors_path = path.get("hpsv2.1")
|
||||
else:
|
||||
raise ValueError(f"Unsupported model version: {model_version}. Choose 'v2' or 'v21'.")
|
||||
|
||||
# Create model and transforms
|
||||
model, _, self.preprocess_val = create_model_and_transforms(
|
||||
"ViT-H-14",
|
||||
# "laion2B-s32B-b79K",
|
||||
pretrained=path.get("open_clip"),
|
||||
precision="amp",
|
||||
device=device,
|
||||
jit=False,
|
||||
force_quick_gelu=False,
|
||||
force_custom_text=False,
|
||||
force_patch_dropout=False,
|
||||
force_image_size=None,
|
||||
pretrained_image=False,
|
||||
image_mean=None,
|
||||
image_std=None,
|
||||
light_augmentation=True,
|
||||
aug_cfg={},
|
||||
output_dict=True,
|
||||
with_score_predictor=False,
|
||||
with_region_predictor=False,
|
||||
)
|
||||
|
||||
# Load model weights
|
||||
try:
|
||||
state_dict = load_file(safetensors_path)
|
||||
model.load_state_dict(state_dict)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error loading model weights from {safetensors_path}: {e}")
|
||||
|
||||
# Initialize tokenizer and model
|
||||
self.tokenizer = get_tokenizer("ViT-H-14", path["open_clip_bpe"])
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
self.model = model
|
||||
|
||||
def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
|
||||
"""Calculate the HPS score for a single image and prompt.
|
||||
|
||||
Args:
|
||||
image (torch.Tensor): The processed image tensor.
|
||||
prompt (str): The prompt text.
|
||||
|
||||
Returns:
|
||||
float: The HPS score.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
# Process the prompt
|
||||
text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
|
||||
|
||||
# Calculate the HPS score
|
||||
outputs = self.model(image, text)
|
||||
image_features, text_features = outputs["image_features"], outputs["text_features"]
|
||||
logits_per_image = image_features @ text_features.T
|
||||
hps_score = torch.diagonal(logits_per_image).cpu().numpy()
|
||||
|
||||
return hps_score[0].item()
|
||||
|
||||
@torch.no_grad()
|
||||
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
||||
"""Score the images based on the prompt.
|
||||
|
||||
Args:
|
||||
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
||||
prompt (str): The prompt text.
|
||||
|
||||
Returns:
|
||||
List[float]: List of HPS scores for the images.
|
||||
"""
|
||||
try:
|
||||
if isinstance(images, (str, Image.Image)):
|
||||
# Single image
|
||||
if isinstance(images, str):
|
||||
image = self.preprocess_val(Image.open(images)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||
else:
|
||||
image = self.preprocess_val(images).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||
return [self._calculate_score(image, prompt)]
|
||||
elif isinstance(images, list):
|
||||
# Multiple images
|
||||
scores = []
|
||||
for one_images in images:
|
||||
if isinstance(one_images, str):
|
||||
image = self.preprocess_val(Image.open(one_images)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||
elif isinstance(one_images, Image.Image):
|
||||
image = self.preprocess_val(one_images).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||
else:
|
||||
raise TypeError("The type of parameter images is illegal.")
|
||||
scores.append(self._calculate_score(image, prompt))
|
||||
return scores
|
||||
else:
|
||||
raise TypeError("The type of parameter images is illegal.")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error in scoring images: {e}")
|
||||
212
diffsynth/extensions/ImageQualityMetric/imagereward.py
Normal file
212
diffsynth/extensions/ImageQualityMetric/imagereward.py
Normal file
@@ -0,0 +1,212 @@
|
||||
import os
|
||||
import torch
|
||||
from PIL import Image
|
||||
from typing import List, Union
|
||||
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
||||
from .BLIP.blip_pretrain import BLIP_Pretrain
|
||||
from torchvision.transforms import InterpolationMode
|
||||
from safetensors.torch import load_file
|
||||
from .config import MODEL_PATHS
|
||||
BICUBIC = InterpolationMode.BICUBIC
|
||||
|
||||
def _convert_image_to_rgb(image):
|
||||
return image.convert("RGB")
|
||||
|
||||
def _transform(n_px):
|
||||
return Compose([
|
||||
Resize(n_px, interpolation=BICUBIC),
|
||||
CenterCrop(n_px),
|
||||
_convert_image_to_rgb,
|
||||
ToTensor(),
|
||||
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
||||
])
|
||||
|
||||
class MLP(torch.nn.Module):
|
||||
def __init__(self, input_size):
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
|
||||
self.layers = torch.nn.Sequential(
|
||||
torch.nn.Linear(self.input_size, 1024),
|
||||
#nn.ReLU(),
|
||||
torch.nn.Dropout(0.2),
|
||||
torch.nn.Linear(1024, 128),
|
||||
#nn.ReLU(),
|
||||
torch.nn.Dropout(0.2),
|
||||
torch.nn.Linear(128, 64),
|
||||
#nn.ReLU(),
|
||||
torch.nn.Dropout(0.1),
|
||||
torch.nn.Linear(64, 16),
|
||||
#nn.ReLU(),
|
||||
torch.nn.Linear(16, 1)
|
||||
)
|
||||
|
||||
# initial MLP param
|
||||
for name, param in self.layers.named_parameters():
|
||||
if 'weight' in name:
|
||||
torch.nn.init.normal_(param, mean=0.0, std=1.0/(self.input_size+1))
|
||||
if 'bias' in name:
|
||||
torch.nn.init.constant_(param, val=0)
|
||||
|
||||
def forward(self, input):
|
||||
return self.layers(input)
|
||||
|
||||
class ImageReward(torch.nn.Module):
|
||||
def __init__(self, med_config, device='cpu', bert_model_path=""):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
|
||||
self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config, bert_model_path=bert_model_path)
|
||||
self.preprocess = _transform(224)
|
||||
self.mlp = MLP(768)
|
||||
|
||||
self.mean = 0.16717362830052426
|
||||
self.std = 1.0333394966054072
|
||||
|
||||
def score_grad(self, prompt_ids, prompt_attention_mask, image):
|
||||
"""Calculate the score with gradient for a single image and prompt.
|
||||
|
||||
Args:
|
||||
prompt_ids (torch.Tensor): Tokenized prompt IDs.
|
||||
prompt_attention_mask (torch.Tensor): Attention mask for the prompt.
|
||||
image (torch.Tensor): The processed image tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The reward score.
|
||||
"""
|
||||
image_embeds = self.blip.visual_encoder(image)
|
||||
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
|
||||
text_output = self.blip.text_encoder(
|
||||
prompt_ids,
|
||||
attention_mask=prompt_attention_mask,
|
||||
encoder_hidden_states=image_embeds,
|
||||
encoder_attention_mask=image_atts,
|
||||
return_dict=True,
|
||||
)
|
||||
txt_features = text_output.last_hidden_state[:, 0, :]
|
||||
rewards = self.mlp(txt_features)
|
||||
rewards = (rewards - self.mean) / self.std
|
||||
return rewards
|
||||
|
||||
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str = "") -> List[float]:
|
||||
"""Score the images based on the prompt.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt text.
|
||||
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
||||
|
||||
Returns:
|
||||
List[float]: List of scores for the images.
|
||||
"""
|
||||
if isinstance(images, (str, Image.Image)):
|
||||
# Single image
|
||||
if isinstance(images, str):
|
||||
pil_image = Image.open(images)
|
||||
else:
|
||||
pil_image = images
|
||||
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
|
||||
return [self._calculate_score(prompt, image).item()]
|
||||
elif isinstance(images, list):
|
||||
# Multiple images
|
||||
scores = []
|
||||
for one_image in images:
|
||||
if isinstance(one_image, str):
|
||||
pil_image = Image.open(one_image)
|
||||
elif isinstance(one_image, Image.Image):
|
||||
pil_image = one_image
|
||||
else:
|
||||
raise TypeError("The type of parameter images is illegal.")
|
||||
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
|
||||
scores.append(self._calculate_score(prompt, image).item())
|
||||
return scores
|
||||
else:
|
||||
raise TypeError("The type of parameter images is illegal.")
|
||||
|
||||
def _calculate_score(self, prompt: str, image: torch.Tensor) -> torch.Tensor:
|
||||
"""Calculate the score for a single image and prompt.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt text.
|
||||
image (torch.Tensor): The processed image tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The reward score.
|
||||
"""
|
||||
text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
|
||||
image_embeds = self.blip.visual_encoder(image)
|
||||
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
|
||||
text_output = self.blip.text_encoder(
|
||||
text_input.input_ids,
|
||||
attention_mask=text_input.attention_mask,
|
||||
encoder_hidden_states=image_embeds,
|
||||
encoder_attention_mask=image_atts,
|
||||
return_dict=True,
|
||||
)
|
||||
txt_features = text_output.last_hidden_state[:, 0, :].float()
|
||||
rewards = self.mlp(txt_features)
|
||||
rewards = (rewards - self.mean) / self.std
|
||||
return rewards
|
||||
|
||||
def inference_rank(self, prompt: str, generations_list: List[Union[str, Image.Image]]) -> tuple:
|
||||
"""Rank the images based on the prompt.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt text.
|
||||
generations_list (List[Union[str, Image.Image]]): List of image paths or PIL images.
|
||||
|
||||
Returns:
|
||||
tuple: (indices, rewards) where indices are the ranks and rewards are the scores.
|
||||
"""
|
||||
text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
|
||||
txt_set = []
|
||||
for generation in generations_list:
|
||||
if isinstance(generation, str):
|
||||
pil_image = Image.open(generation)
|
||||
elif isinstance(generation, Image.Image):
|
||||
pil_image = generation
|
||||
else:
|
||||
raise TypeError("The type of parameter generations_list is illegal.")
|
||||
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
|
||||
image_embeds = self.blip.visual_encoder(image)
|
||||
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
|
||||
text_output = self.blip.text_encoder(
|
||||
text_input.input_ids,
|
||||
attention_mask=text_input.attention_mask,
|
||||
encoder_hidden_states=image_embeds,
|
||||
encoder_attention_mask=image_atts,
|
||||
return_dict=True,
|
||||
)
|
||||
txt_set.append(text_output.last_hidden_state[:, 0, :])
|
||||
txt_features = torch.cat(txt_set, 0).float()
|
||||
rewards = self.mlp(txt_features)
|
||||
rewards = (rewards - self.mean) / self.std
|
||||
rewards = torch.squeeze(rewards)
|
||||
_, rank = torch.sort(rewards, dim=0, descending=True)
|
||||
_, indices = torch.sort(rank, dim=0)
|
||||
indices = indices + 1
|
||||
return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
|
||||
|
||||
|
||||
class ImageRewardScore(torch.nn.Module):
|
||||
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS):
|
||||
super().__init__()
|
||||
self.device = device if isinstance(device, torch.device) else torch.device(device)
|
||||
model_path = path.get("imagereward")
|
||||
med_config = path.get("med_config")
|
||||
state_dict = load_file(model_path)
|
||||
self.model = ImageReward(device=self.device, med_config=med_config, bert_model_path=path.get("bert_model_path")).to(self.device)
|
||||
self.model.load_state_dict(state_dict, strict=False)
|
||||
self.model.eval()
|
||||
|
||||
@torch.no_grad()
|
||||
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
||||
"""Score the images based on the prompt.
|
||||
|
||||
Args:
|
||||
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
||||
prompt (str): The prompt text.
|
||||
|
||||
Returns:
|
||||
List[float]: List of scores for the images.
|
||||
"""
|
||||
return self.model.score(images, prompt)
|
||||
129
diffsynth/extensions/ImageQualityMetric/mps.py
Normal file
129
diffsynth/extensions/ImageQualityMetric/mps.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPFeatureExtractor, CLIPImageProcessor
|
||||
from transformers import CLIPConfig
|
||||
from dataclasses import dataclass
|
||||
from transformers import CLIPModel as HFCLIPModel
|
||||
from safetensors.torch import load_file
|
||||
from torch import nn, einsum
|
||||
|
||||
from .trainer.models.base_model import BaseModelConfig
|
||||
|
||||
from transformers import CLIPConfig
|
||||
from transformers import AutoProcessor, AutoModel, AutoTokenizer
|
||||
from typing import Any, Optional, Tuple, Union, List
|
||||
import torch
|
||||
|
||||
from .trainer.models.cross_modeling import Cross_model
|
||||
from .trainer.models import clip_model
|
||||
import torch.nn.functional as F
|
||||
import gc
|
||||
import json
|
||||
from .config import MODEL_PATHS
|
||||
|
||||
class MPScore(torch.nn.Module):
|
||||
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS, condition: str = 'overall'):
|
||||
super().__init__()
|
||||
"""Initialize the MPSModel with a processor, tokenizer, and model.
|
||||
|
||||
Args:
|
||||
device (Union[str, torch.device]): The device to load the model on.
|
||||
"""
|
||||
self.device = device
|
||||
processor_name_or_path = path.get("clip")
|
||||
self.image_processor = CLIPImageProcessor.from_pretrained(processor_name_or_path)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True)
|
||||
self.model = clip_model.CLIPModel(processor_name_or_path, config_file=True)
|
||||
state_dict = load_file(path.get("mps"))
|
||||
self.model.load_state_dict(state_dict, strict=False)
|
||||
self.model.to(device)
|
||||
self.condition = condition
|
||||
|
||||
def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
|
||||
"""Calculate the reward score for a single image and prompt.
|
||||
|
||||
Args:
|
||||
image (torch.Tensor): The processed image tensor.
|
||||
prompt (str): The prompt text.
|
||||
|
||||
Returns:
|
||||
float: The reward score.
|
||||
"""
|
||||
def _tokenize(caption):
|
||||
input_ids = self.tokenizer(
|
||||
caption,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt"
|
||||
).input_ids
|
||||
return input_ids
|
||||
|
||||
text_input = _tokenize(prompt).to(self.device)
|
||||
if self.condition == 'overall':
|
||||
condition_prompt = 'light, color, clarity, tone, style, ambiance, artistry, shape, face, hair, hands, limbs, structure, instance, texture, quantity, attributes, position, number, location, word, things'
|
||||
elif self.condition == 'aesthetics':
|
||||
condition_prompt = 'light, color, clarity, tone, style, ambiance, artistry'
|
||||
elif self.condition == 'quality':
|
||||
condition_prompt = 'shape, face, hair, hands, limbs, structure, instance, texture'
|
||||
elif self.condition == 'semantic':
|
||||
condition_prompt = 'quantity, attributes, position, number, location'
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported condition: {self.condition}. Choose 'overall', 'aesthetics', 'quality', or 'semantic'.")
|
||||
condition_batch = _tokenize(condition_prompt).repeat(text_input.shape[0], 1).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
text_f, text_features = self.model.model.get_text_features(text_input)
|
||||
|
||||
image_f = self.model.model.get_image_features(image.half())
|
||||
condition_f, _ = self.model.model.get_text_features(condition_batch)
|
||||
|
||||
sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)
|
||||
sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
|
||||
sim_text_condition = sim_text_condition / sim_text_condition.max()
|
||||
mask = torch.where(sim_text_condition > 0.3, 0, float('-inf'))
|
||||
mask = mask.repeat(1, image_f.shape[1], 1)
|
||||
image_features = self.model.cross_model(image_f, text_f, mask.half())[:, 0, :]
|
||||
|
||||
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
|
||||
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
|
||||
image_score = self.model.logit_scale.exp() * text_features @ image_features.T
|
||||
|
||||
return image_score[0].cpu().numpy().item()
|
||||
|
||||
@torch.no_grad()
|
||||
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
||||
"""Score the images based on the prompt.
|
||||
|
||||
Args:
|
||||
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
||||
prompt (str): The prompt text.
|
||||
|
||||
Returns:
|
||||
List[float]: List of reward scores for the images.
|
||||
"""
|
||||
if isinstance(images, (str, Image.Image)):
|
||||
# Single image
|
||||
if isinstance(images, str):
|
||||
image = self.image_processor(Image.open(images), return_tensors="pt")["pixel_values"].to(self.device)
|
||||
else:
|
||||
image = self.image_processor(images, return_tensors="pt")["pixel_values"].to(self.device)
|
||||
return [self._calculate_score(image, prompt)]
|
||||
elif isinstance(images, list):
|
||||
# Multiple images
|
||||
scores = []
|
||||
for one_images in images:
|
||||
if isinstance(one_images, str):
|
||||
image = self.image_processor(Image.open(one_images), return_tensors="pt")["pixel_values"].to(self.device)
|
||||
elif isinstance(one_images, Image.Image):
|
||||
image = self.image_processor(one_images, return_tensors="pt")["pixel_values"].to(self.device)
|
||||
else:
|
||||
raise TypeError("The type of parameter images is illegal.")
|
||||
scores.append(self._calculate_score(image, prompt))
|
||||
return scores
|
||||
else:
|
||||
raise TypeError("The type of parameter images is illegal.")
|
||||
@@ -0,0 +1,14 @@
|
||||
from .coca_model import CoCa
|
||||
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
||||
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
|
||||
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
|
||||
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
|
||||
from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
|
||||
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
|
||||
from .openai import load_openai_model, list_openai_models
|
||||
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
|
||||
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
|
||||
from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
|
||||
from .tokenizer import SimpleTokenizer
|
||||
from .transform import image_transform, AugmentationCfg
|
||||
from .utils import freeze_batch_norm_2d
|
||||
458
diffsynth/extensions/ImageQualityMetric/open_clip/coca_model.py
Normal file
458
diffsynth/extensions/ImageQualityMetric/open_clip/coca_model.py
Normal file
@@ -0,0 +1,458 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
import numpy as np
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .transformer import (
|
||||
LayerNormFp32,
|
||||
LayerNorm,
|
||||
QuickGELU,
|
||||
MultimodalTransformer,
|
||||
)
|
||||
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
|
||||
|
||||
try:
|
||||
from transformers import (
|
||||
BeamSearchScorer,
|
||||
LogitsProcessorList,
|
||||
TopPLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
RepetitionPenaltyLogitsProcessor,
|
||||
MinLengthLogitsProcessor,
|
||||
MaxLengthCriteria,
|
||||
StoppingCriteriaList
|
||||
)
|
||||
|
||||
GENERATION_TYPES = {
|
||||
"top_k": TopKLogitsWarper,
|
||||
"top_p": TopPLogitsWarper,
|
||||
"beam_search": "beam_search"
|
||||
}
|
||||
_has_transformers = True
|
||||
except ImportError as e:
|
||||
GENERATION_TYPES = {
|
||||
"top_k": None,
|
||||
"top_p": None,
|
||||
"beam_search": "beam_search"
|
||||
}
|
||||
_has_transformers = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultimodalCfg(CLIPTextCfg):
|
||||
mlp_ratio: int = 4
|
||||
dim_head: int = 64
|
||||
heads: int = 8
|
||||
n_queries: int = 256
|
||||
attn_pooler_heads: int = 8
|
||||
|
||||
|
||||
def _build_text_decoder_tower(
|
||||
embed_dim,
|
||||
multimodal_cfg,
|
||||
quick_gelu: bool = False,
|
||||
cast_dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
|
||||
act_layer = QuickGELU if quick_gelu else nn.GELU
|
||||
norm_layer = (
|
||||
LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
|
||||
)
|
||||
|
||||
decoder = MultimodalTransformer(
|
||||
context_length=multimodal_cfg.context_length,
|
||||
width=multimodal_cfg.width,
|
||||
heads=multimodal_cfg.heads,
|
||||
layers=multimodal_cfg.layers,
|
||||
ls_init_value=multimodal_cfg.ls_init_value,
|
||||
output_dim=embed_dim,
|
||||
act_layer=act_layer,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
|
||||
return decoder
|
||||
|
||||
|
||||
class CoCa(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
multimodal_cfg: MultimodalCfg,
|
||||
text_cfg: CLIPTextCfg,
|
||||
vision_cfg: CLIPVisionCfg,
|
||||
quick_gelu: bool = False,
|
||||
cast_dtype: Optional[torch.dtype] = None,
|
||||
pad_id: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
|
||||
text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
|
||||
vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg
|
||||
|
||||
self.text = _build_text_tower(
|
||||
embed_dim=embed_dim,
|
||||
text_cfg=text_cfg,
|
||||
quick_gelu=quick_gelu,
|
||||
cast_dtype=cast_dtype,
|
||||
)
|
||||
|
||||
vocab_size = (
|
||||
text_cfg.vocab_size # for hf models
|
||||
if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
|
||||
else text_cfg.vocab_size
|
||||
)
|
||||
|
||||
self.visual = _build_vision_tower(
|
||||
embed_dim=embed_dim,
|
||||
vision_cfg=vision_cfg,
|
||||
quick_gelu=quick_gelu,
|
||||
cast_dtype=cast_dtype,
|
||||
)
|
||||
|
||||
self.text_decoder = _build_text_decoder_tower(
|
||||
vocab_size,
|
||||
multimodal_cfg=multimodal_cfg,
|
||||
quick_gelu=quick_gelu,
|
||||
cast_dtype=cast_dtype,
|
||||
)
|
||||
|
||||
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
||||
self.pad_id = pad_id
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_grad_checkpointing(self, enable=True):
|
||||
self.visual.set_grad_checkpointing(enable)
|
||||
self.text.set_grad_checkpointing(enable)
|
||||
self.text_decoder.set_grad_checkpointing(enable)
|
||||
|
||||
def _encode_image(self, images, normalize=True):
|
||||
image_latent, tokens_embs = self.visual(images)
|
||||
image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
|
||||
return image_latent, tokens_embs
|
||||
|
||||
def _encode_text(self, text, normalize=True, embed_cls=True):
|
||||
text = text[:, :-1] if embed_cls else text # make space for CLS token
|
||||
text_latent, token_emb = self.text(text)
|
||||
text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
|
||||
return text_latent, token_emb
|
||||
|
||||
def encode_image(self, images, normalize=True):
|
||||
image_latent, _ = self._encode_image(images, normalize=normalize)
|
||||
return image_latent
|
||||
|
||||
def encode_text(self, text, normalize=True, embed_cls=True):
|
||||
text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
|
||||
return text_latent
|
||||
|
||||
def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):
|
||||
text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
|
||||
if image_latent is None or image_embs is None:
|
||||
image_latent, image_embs = self._encode_image(image)
|
||||
|
||||
# TODO: add assertion to avoid bugs?
|
||||
labels = text[:, -token_embs.shape[1]:]
|
||||
|
||||
logits = self.text_decoder(image_embs, token_embs)
|
||||
return {
|
||||
"image_features": image_latent,
|
||||
"text_features": text_latent,
|
||||
"logits": logits,
|
||||
"labels": labels,
|
||||
"logit_scale": self.logit_scale.exp()
|
||||
}
|
||||
|
||||
def generate(
|
||||
self,
|
||||
image,
|
||||
text=None,
|
||||
seq_len=30,
|
||||
max_seq_len=77,
|
||||
temperature=1.,
|
||||
generation_type="beam_search",
|
||||
top_p=0.1, # keep tokens in the 1 - top_p quantile
|
||||
top_k=1, # keeps the top_k most probable tokens
|
||||
pad_token_id=None,
|
||||
eos_token_id=None,
|
||||
sot_token_id=None,
|
||||
num_beams=6,
|
||||
num_beam_groups=3,
|
||||
min_seq_len=5,
|
||||
stopping_criteria=None,
|
||||
repetition_penalty=1.0,
|
||||
fixed_output_length=False # if True output.shape == (batch_size, seq_len)
|
||||
):
|
||||
# taking many ideas and components from HuggingFace GenerationMixin
|
||||
# https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
|
||||
assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
|
||||
assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
|
||||
|
||||
with torch.no_grad():
|
||||
sot_token_id = 49406 if sot_token_id is None else sot_token_id
|
||||
eos_token_id = 49407 if eos_token_id is None else eos_token_id
|
||||
pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
|
||||
logit_processor = LogitsProcessorList(
|
||||
[
|
||||
MinLengthLogitsProcessor(min_seq_len, eos_token_id),
|
||||
RepetitionPenaltyLogitsProcessor(repetition_penalty),
|
||||
]
|
||||
)
|
||||
|
||||
if stopping_criteria is None:
|
||||
stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
|
||||
|
||||
stopping_criteria = StoppingCriteriaList(
|
||||
stopping_criteria
|
||||
)
|
||||
|
||||
device = image.device
|
||||
|
||||
if generation_type == "beam_search":
|
||||
output = self._generate_beamsearch(
|
||||
image_inputs = image,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
sot_token_id=sot_token_id,
|
||||
num_beams=num_beams,
|
||||
num_beam_groups=num_beam_groups,
|
||||
min_seq_len=min_seq_len,
|
||||
stopping_criteria=stopping_criteria,
|
||||
logit_processor=logit_processor,
|
||||
)
|
||||
if fixed_output_length and output.shape[1] < seq_len:
|
||||
return torch.cat(
|
||||
(output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id),
|
||||
dim=1
|
||||
)
|
||||
return output
|
||||
|
||||
elif generation_type == "top_p":
|
||||
logit_warper = GENERATION_TYPES[generation_type](top_p)
|
||||
elif generation_type == "top_k":
|
||||
logit_warper = GENERATION_TYPES[generation_type](top_k)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"generation_type has to be one of "
|
||||
f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
|
||||
)
|
||||
|
||||
image_latent, image_embs = self._encode_image(image)
|
||||
|
||||
if text is None:
|
||||
text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id
|
||||
|
||||
was_training = self.training
|
||||
num_dims = len(text.shape)
|
||||
|
||||
if num_dims == 1:
|
||||
text = text[None, :]
|
||||
|
||||
cur_len = text.shape[1]
|
||||
self.eval()
|
||||
out = text
|
||||
|
||||
while True:
|
||||
x = out[:, -max_seq_len:]
|
||||
cur_len = x.shape[1]
|
||||
logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1]
|
||||
mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
|
||||
sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id
|
||||
|
||||
if mask.all():
|
||||
if not fixed_output_length:
|
||||
break
|
||||
else:
|
||||
logits = logits[~mask, :]
|
||||
filtered_logits = logit_processor(x[~mask, :], logits)
|
||||
filtered_logits = logit_warper(x[~mask, :], filtered_logits)
|
||||
probs = F.softmax(filtered_logits / temperature, dim=-1)
|
||||
|
||||
if (cur_len + 1 == seq_len):
|
||||
sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
|
||||
else:
|
||||
sample[~mask, :] = torch.multinomial(probs, 1)
|
||||
|
||||
out = torch.cat((out, sample), dim=-1)
|
||||
|
||||
cur_len += 1
|
||||
|
||||
if stopping_criteria(out, None):
|
||||
break
|
||||
|
||||
if num_dims == 1:
|
||||
out = out.squeeze(0)
|
||||
|
||||
self.train(was_training)
|
||||
return out
|
||||
|
||||
def _generate_beamsearch(
|
||||
self,
|
||||
image_inputs,
|
||||
pad_token_id=None,
|
||||
eos_token_id=None,
|
||||
sot_token_id=None,
|
||||
num_beams=6,
|
||||
num_beam_groups=3,
|
||||
min_seq_len=5,
|
||||
stopping_criteria=None,
|
||||
logit_processor=None,
|
||||
logit_warper=None,
|
||||
):
|
||||
device = image_inputs.device
|
||||
batch_size = image_inputs.shape[0]
|
||||
image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
|
||||
image_latent, image_embs = self._encode_image(image_inputs)
|
||||
|
||||
input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
|
||||
input_ids = input_ids * sot_token_id
|
||||
beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
num_beams=num_beams,
|
||||
device=device,
|
||||
num_beam_groups=num_beam_groups,
|
||||
)
|
||||
# instantiate logits processors
|
||||
logits_processor = (
|
||||
LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
|
||||
if logit_processor is None
|
||||
else logit_processor
|
||||
)
|
||||
|
||||
batch_size = len(beam_scorer._beam_hyps)
|
||||
num_beams = beam_scorer.num_beams
|
||||
num_beam_groups = beam_scorer.num_beam_groups
|
||||
num_sub_beams = num_beams // num_beam_groups
|
||||
batch_beam_size, cur_len = input_ids.shape
|
||||
beam_indices = None
|
||||
|
||||
if num_beams * batch_size != batch_beam_size:
|
||||
raise ValueError(
|
||||
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
||||
)
|
||||
|
||||
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
|
||||
# initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in
|
||||
# the same group don't produce same tokens everytime.
|
||||
beam_scores[:, ::num_sub_beams] = 0
|
||||
beam_scores = beam_scores.view((batch_size * num_beams,))
|
||||
|
||||
while True:
|
||||
|
||||
# predicted tokens in cur_len step
|
||||
current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
|
||||
|
||||
# indices which will form the beams in the next time step
|
||||
reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
|
||||
|
||||
# do one decoder step on all beams of all sentences in batch
|
||||
model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
|
||||
outputs = self(
|
||||
model_inputs['images'],
|
||||
model_inputs['text'],
|
||||
embed_cls=False,
|
||||
image_latent=image_latent,
|
||||
image_embs=image_embs
|
||||
)
|
||||
|
||||
for beam_group_idx in range(num_beam_groups):
|
||||
group_start_idx = beam_group_idx * num_sub_beams
|
||||
group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
|
||||
group_size = group_end_idx - group_start_idx
|
||||
|
||||
# indices of beams of current group among all sentences in batch
|
||||
batch_group_indices = []
|
||||
|
||||
for batch_idx in range(batch_size):
|
||||
batch_group_indices.extend(
|
||||
[batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
|
||||
)
|
||||
group_input_ids = input_ids[batch_group_indices]
|
||||
|
||||
# select outputs of beams of currentg group only
|
||||
next_token_logits = outputs['logits'][batch_group_indices, -1, :]
|
||||
vocab_size = next_token_logits.shape[-1]
|
||||
|
||||
next_token_scores_processed = logits_processor(
|
||||
group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
|
||||
)
|
||||
next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
|
||||
next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
|
||||
|
||||
# reshape for beam search
|
||||
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
|
||||
|
||||
next_token_scores, next_tokens = torch.topk(
|
||||
next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
|
||||
)
|
||||
|
||||
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
|
||||
next_tokens = next_tokens % vocab_size
|
||||
|
||||
# stateless
|
||||
process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
|
||||
beam_outputs = beam_scorer.process(
|
||||
group_input_ids,
|
||||
next_token_scores,
|
||||
next_tokens,
|
||||
next_indices,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
beam_indices=process_beam_indices,
|
||||
)
|
||||
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
|
||||
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
||||
beam_idx = beam_outputs["next_beam_indices"]
|
||||
|
||||
input_ids[batch_group_indices] = group_input_ids[beam_idx]
|
||||
group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
|
||||
current_tokens[batch_group_indices] = group_input_ids[:, -1]
|
||||
|
||||
# (beam_idx // group_size) -> batch_idx
|
||||
# (beam_idx % group_size) -> offset of idx inside the group
|
||||
reordering_indices[batch_group_indices] = (
|
||||
num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
|
||||
)
|
||||
|
||||
input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
|
||||
|
||||
# increase cur_len
|
||||
cur_len = cur_len + 1
|
||||
if beam_scorer.is_done or stopping_criteria(input_ids, None):
|
||||
break
|
||||
|
||||
final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
|
||||
sequence_outputs = beam_scorer.finalize(
|
||||
input_ids,
|
||||
beam_scores,
|
||||
next_tokens,
|
||||
next_indices,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
max_length=stopping_criteria.max_length,
|
||||
beam_indices=final_beam_indices,
|
||||
)
|
||||
return sequence_outputs['sequences']
|
||||
|
||||
|
||||
def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
|
||||
if past:
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
else:
|
||||
position_ids = None
|
||||
return {
|
||||
"text": input_ids,
|
||||
"images": image_inputs,
|
||||
"past_key_values": past,
|
||||
"position_ids": position_ids,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
||||
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
|
||||
433
diffsynth/extensions/ImageQualityMetric/open_clip/factory.py
Normal file
433
diffsynth/extensions/ImageQualityMetric/open_clip/factory.py
Normal file
@@ -0,0 +1,433 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
# from turtle import forward
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
||||
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
|
||||
resize_pos_embed, get_cast_dtype
|
||||
from .coca_model import CoCa
|
||||
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
|
||||
from .openai import load_openai_model
|
||||
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf
|
||||
from .transform import image_transform, AugmentationCfg
|
||||
from .tokenizer import HFTokenizer, SimpleTokenizer
|
||||
|
||||
|
||||
HF_HUB_PREFIX = 'hf-hub:'
|
||||
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
|
||||
_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
|
||||
|
||||
|
||||
def _natural_key(string_):
|
||||
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
|
||||
|
||||
|
||||
def _rescan_model_configs():
|
||||
global _MODEL_CONFIGS
|
||||
|
||||
config_ext = ('.json',)
|
||||
config_files = []
|
||||
for config_path in _MODEL_CONFIG_PATHS:
|
||||
if config_path.is_file() and config_path.suffix in config_ext:
|
||||
config_files.append(config_path)
|
||||
elif config_path.is_dir():
|
||||
for ext in config_ext:
|
||||
config_files.extend(config_path.glob(f'*{ext}'))
|
||||
|
||||
for cf in config_files:
|
||||
with open(cf, 'r') as f:
|
||||
model_cfg = json.load(f)
|
||||
if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
|
||||
_MODEL_CONFIGS[cf.stem] = model_cfg
|
||||
|
||||
_MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
|
||||
|
||||
|
||||
_rescan_model_configs() # initial populate of model config registry
|
||||
|
||||
|
||||
def list_models():
|
||||
""" enumerate available model architectures based on config files """
|
||||
return list(_MODEL_CONFIGS.keys())
|
||||
|
||||
|
||||
def add_model_config(path):
|
||||
""" add model config path or file and update registry """
|
||||
if not isinstance(path, Path):
|
||||
path = Path(path)
|
||||
_MODEL_CONFIG_PATHS.append(path)
|
||||
_rescan_model_configs()
|
||||
|
||||
|
||||
def get_model_config(model_name):
|
||||
if model_name in _MODEL_CONFIGS:
|
||||
return deepcopy(_MODEL_CONFIGS[model_name])
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def get_tokenizer(model_name, open_clip_bpe_path=None):
|
||||
if model_name.startswith(HF_HUB_PREFIX):
|
||||
tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):])
|
||||
else:
|
||||
config = get_model_config(model_name)
|
||||
tokenizer = HFTokenizer(
|
||||
config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else SimpleTokenizer(open_clip_bpe_path)
|
||||
return tokenizer
|
||||
|
||||
|
||||
def load_state_dict(checkpoint_path: str, map_location='cpu'):
|
||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
if next(iter(state_dict.items()))[0].startswith('module'):
|
||||
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_checkpoint(model, checkpoint_path, strict=True):
|
||||
state_dict = load_state_dict(checkpoint_path)
|
||||
# detect old format and make compatible with new format
|
||||
if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
|
||||
state_dict = convert_to_custom_text_state_dict(state_dict)
|
||||
resize_pos_embed(state_dict, model)
|
||||
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
|
||||
return incompatible_keys
|
||||
|
||||
|
||||
def create_model(
|
||||
model_name: str,
|
||||
pretrained: Optional[str] = None,
|
||||
precision: str = 'fp32',
|
||||
device: Union[str, torch.device] = 'cpu',
|
||||
jit: bool = False,
|
||||
force_quick_gelu: bool = False,
|
||||
force_custom_text: bool = False,
|
||||
force_patch_dropout: Optional[float] = None,
|
||||
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
|
||||
pretrained_image: bool = False,
|
||||
pretrained_hf: bool = True,
|
||||
cache_dir: Optional[str] = None,
|
||||
output_dict: Optional[bool] = None,
|
||||
require_pretrained: bool = False,
|
||||
):
|
||||
has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
|
||||
if has_hf_hub_prefix:
|
||||
model_id = model_name[len(HF_HUB_PREFIX):]
|
||||
checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
|
||||
config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)
|
||||
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = json.load(f)
|
||||
pretrained_cfg = config['preprocess_cfg']
|
||||
model_cfg = config['model_cfg']
|
||||
else:
|
||||
model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
|
||||
checkpoint_path = None
|
||||
pretrained_cfg = {}
|
||||
model_cfg = None
|
||||
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
|
||||
if pretrained and pretrained.lower() == 'openai':
|
||||
logging.info(f'Loading pretrained {model_name} from OpenAI.')
|
||||
model = load_openai_model(
|
||||
model_name,
|
||||
precision=precision,
|
||||
device=device,
|
||||
jit=jit,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
|
||||
# to always output dict even if it is clip
|
||||
if output_dict and hasattr(model, "output_dict"):
|
||||
model.output_dict = True
|
||||
else:
|
||||
model_cfg = model_cfg or get_model_config(model_name)
|
||||
if model_cfg is not None:
|
||||
logging.info(f'Loaded {model_name} model config.')
|
||||
else:
|
||||
logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
|
||||
raise RuntimeError(f'Model config for {model_name} not found.')
|
||||
|
||||
if force_quick_gelu:
|
||||
# override for use of QuickGELU on non-OpenAI transformer models
|
||||
model_cfg["quick_gelu"] = True
|
||||
|
||||
if force_patch_dropout is not None:
|
||||
# override the default patch dropout value
|
||||
model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
|
||||
|
||||
if force_image_size is not None:
|
||||
# override model config's image size
|
||||
model_cfg["vision_cfg"]["image_size"] = force_image_size
|
||||
|
||||
if pretrained_image:
|
||||
if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
|
||||
# pretrained weight loading for timm models set via vision_cfg
|
||||
model_cfg['vision_cfg']['timm_model_pretrained'] = True
|
||||
else:
|
||||
assert False, 'pretrained image towers currently only supported for timm models'
|
||||
|
||||
cast_dtype = get_cast_dtype(precision)
|
||||
is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
|
||||
custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model
|
||||
|
||||
if custom_text:
|
||||
if is_hf_model:
|
||||
model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
|
||||
if "coca" in model_name:
|
||||
model = CoCa(**model_cfg, cast_dtype=cast_dtype)
|
||||
else:
|
||||
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
|
||||
else:
|
||||
model = CLIP(**model_cfg, cast_dtype=cast_dtype)
|
||||
|
||||
pretrained_loaded = False
|
||||
if pretrained:
|
||||
checkpoint_path = ''
|
||||
pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
|
||||
if pretrained_cfg:
|
||||
checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
|
||||
elif os.path.exists(pretrained):
|
||||
checkpoint_path = pretrained
|
||||
|
||||
if checkpoint_path:
|
||||
logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
|
||||
load_checkpoint(model, checkpoint_path)
|
||||
else:
|
||||
error_str = (
|
||||
f'Pretrained weights ({pretrained}) not found for model {model_name}.'
|
||||
f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
|
||||
logging.warning(error_str)
|
||||
raise RuntimeError(error_str)
|
||||
pretrained_loaded = True
|
||||
elif has_hf_hub_prefix:
|
||||
logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
|
||||
load_checkpoint(model, checkpoint_path)
|
||||
pretrained_loaded = True
|
||||
|
||||
if require_pretrained and not pretrained_loaded:
|
||||
# callers of create_model_from_pretrained always expect pretrained weights
|
||||
raise RuntimeError(
|
||||
f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')
|
||||
|
||||
model.to(device=device)
|
||||
if precision in ("fp16", "bf16"):
|
||||
convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)
|
||||
|
||||
# set image / mean metadata from pretrained_cfg if available, or use default
|
||||
model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
|
||||
model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
|
||||
|
||||
# to always output dict even if it is clip
|
||||
if output_dict and hasattr(model, "output_dict"):
|
||||
model.output_dict = True
|
||||
|
||||
if jit:
|
||||
model = torch.jit.script(model)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def create_loss(args):
|
||||
if args.distill:
|
||||
return DistillClipLoss(
|
||||
local_loss=args.local_loss,
|
||||
gather_with_grad=args.gather_with_grad,
|
||||
cache_labels=True,
|
||||
rank=args.rank,
|
||||
world_size=args.world_size,
|
||||
use_horovod=args.horovod,
|
||||
)
|
||||
elif "coca" in args.model.lower():
|
||||
return CoCaLoss(
|
||||
caption_loss_weight=args.coca_caption_loss_weight,
|
||||
clip_loss_weight=args.coca_contrastive_loss_weight,
|
||||
local_loss=args.local_loss,
|
||||
gather_with_grad=args.gather_with_grad,
|
||||
cache_labels=True,
|
||||
rank=args.rank,
|
||||
world_size=args.world_size,
|
||||
use_horovod=args.horovod,
|
||||
)
|
||||
return ClipLoss(
|
||||
local_loss=args.local_loss,
|
||||
gather_with_grad=args.gather_with_grad,
|
||||
cache_labels=True,
|
||||
rank=args.rank,
|
||||
world_size=args.world_size,
|
||||
use_horovod=args.horovod,
|
||||
)
|
||||
|
||||
class MLP(torch.nn.Module):
|
||||
def __init__(self, input_size):
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
self.layers = torch.nn.Sequential(
|
||||
torch.nn.Linear(self.input_size, 1024),
|
||||
torch.nn.Dropout(0.2),
|
||||
torch.nn.Linear(1024, 128),
|
||||
torch.nn.Dropout(0.2),
|
||||
torch.nn.Linear(128, 64),
|
||||
torch.nn.Dropout(0.1),
|
||||
torch.nn.Linear(64, 16),
|
||||
torch.nn.Linear(16, 1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
# class semantic_head(torch.nn.Module):
|
||||
# def __init__(self, input_size):
|
||||
# super().__init__()
|
||||
# self.input_size = input_size # for ViT-L-14 is 1024
|
||||
# self.seg_head = torch.nn.Sequential(
|
||||
# torch.nn.Linear(input_size, 128),
|
||||
# torch.nn.Dropout(0.2),
|
||||
# torch.nn.Linear(128, 64),
|
||||
# torch.nn.Dropout(0.1),
|
||||
# torch.nn.Linear(64, 16),
|
||||
# torch.nn.Linear(16, 1),
|
||||
# )
|
||||
# self.sigmoid = torch.nn.Sigmoid()
|
||||
|
||||
# def forward(self, x):
|
||||
# return self.sigmoid(self.seg_head(x))
|
||||
|
||||
def create_model_and_transforms(
|
||||
model_name: str,
|
||||
pretrained: Optional[str] = None,
|
||||
precision: str = 'fp32',
|
||||
device: Union[str, torch.device] = 'cpu',
|
||||
jit: bool = False,
|
||||
force_quick_gelu: bool = False,
|
||||
force_custom_text: bool = False,
|
||||
force_patch_dropout: Optional[float] = None,
|
||||
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
|
||||
pretrained_image: bool = False,
|
||||
pretrained_hf: bool = True,
|
||||
image_mean: Optional[Tuple[float, ...]] = None,
|
||||
image_std: Optional[Tuple[float, ...]] = None,
|
||||
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
|
||||
cache_dir: Optional[str] = None,
|
||||
light_augmentation = False,
|
||||
output_dict: Optional[bool] = None,
|
||||
with_score_predictor: bool = False,
|
||||
with_region_predictor: bool = False
|
||||
):
|
||||
model = create_model(
|
||||
model_name,
|
||||
pretrained,
|
||||
precision=precision,
|
||||
device=device,
|
||||
jit=jit,
|
||||
force_quick_gelu=force_quick_gelu,
|
||||
force_custom_text=force_custom_text,
|
||||
force_patch_dropout=force_patch_dropout,
|
||||
force_image_size=force_image_size,
|
||||
pretrained_image=pretrained_image,
|
||||
pretrained_hf=pretrained_hf,
|
||||
cache_dir=cache_dir,
|
||||
output_dict=output_dict,
|
||||
)
|
||||
|
||||
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
|
||||
image_std = image_std or getattr(model.visual, 'image_std', None)
|
||||
|
||||
if with_score_predictor:
|
||||
model.score_predictor = MLP(model.visual.proj.size(1)).to(device=device, dtype=model.visual.proj.dtype)
|
||||
|
||||
if with_region_predictor:
|
||||
# model.region_predictor = semantic_head(model.visual.proj.size(1)).to(device=device, dtype=model.visual.proj.dtype)
|
||||
model.region_predictor = torch.nn.Linear(model.visual.proj.size(0), 1).to(device=device, dtype=model.visual.proj.dtype)
|
||||
# preprocess_train = image_transform_region(
|
||||
# model.visual.image_size,
|
||||
# is_train=True,
|
||||
# mean=image_mean,
|
||||
# std=image_std
|
||||
# )
|
||||
# preprocess_val = image_transform_region(
|
||||
# model.visual.image_size,
|
||||
# is_train=False,
|
||||
# mean=image_mean,
|
||||
# std=image_std
|
||||
# )
|
||||
|
||||
if light_augmentation:
|
||||
preprocess_val = image_transform(
|
||||
model.visual.image_size,
|
||||
is_train=False,
|
||||
mean=image_mean,
|
||||
std=image_std,
|
||||
resize_longest_max=True,
|
||||
)
|
||||
preprocess_train = preprocess_val
|
||||
else:
|
||||
preprocess_train = image_transform(
|
||||
model.visual.image_size,
|
||||
is_train=True,
|
||||
mean=image_mean,
|
||||
std=image_std
|
||||
)
|
||||
preprocess_val = image_transform(
|
||||
model.visual.image_size,
|
||||
is_train=False,
|
||||
mean=image_mean,
|
||||
std=image_std
|
||||
)
|
||||
|
||||
return model, preprocess_train, preprocess_val
|
||||
|
||||
|
||||
def create_model_from_pretrained(
|
||||
model_name: str,
|
||||
pretrained: Optional[str] = None,
|
||||
precision: str = 'fp32',
|
||||
device: Union[str, torch.device] = 'cpu',
|
||||
jit: bool = False,
|
||||
force_quick_gelu: bool = False,
|
||||
force_custom_text: bool = False,
|
||||
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
|
||||
return_transform: bool = True,
|
||||
image_mean: Optional[Tuple[float, ...]] = None,
|
||||
image_std: Optional[Tuple[float, ...]] = None,
|
||||
cache_dir: Optional[str] = None,
|
||||
):
|
||||
model = create_model(
|
||||
model_name,
|
||||
pretrained,
|
||||
precision=precision,
|
||||
device=device,
|
||||
jit=jit,
|
||||
force_quick_gelu=force_quick_gelu,
|
||||
force_custom_text=force_custom_text,
|
||||
force_image_size=force_image_size,
|
||||
cache_dir=cache_dir,
|
||||
require_pretrained=True,
|
||||
)
|
||||
|
||||
if not return_transform:
|
||||
return model
|
||||
|
||||
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
|
||||
image_std = image_std or getattr(model.visual, 'image_std', None)
|
||||
preprocess = image_transform(
|
||||
model.visual.image_size,
|
||||
is_train=False,
|
||||
mean=image_mean,
|
||||
std=image_std,
|
||||
)
|
||||
|
||||
return model, preprocess
|
||||
@@ -0,0 +1,45 @@
|
||||
# HF architecture dict:
|
||||
arch_dict = {
|
||||
# https://huggingface.co/docs/transformers/model_doc/roberta#roberta
|
||||
"roberta": {
|
||||
"config_names": {
|
||||
"context_length": "max_position_embeddings",
|
||||
"vocab_size": "vocab_size",
|
||||
"width": "hidden_size",
|
||||
"heads": "num_attention_heads",
|
||||
"layers": "num_hidden_layers",
|
||||
"layer_attr": "layer",
|
||||
"token_embeddings_attr": "embeddings"
|
||||
},
|
||||
"pooler": "mean_pooler",
|
||||
},
|
||||
# https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
|
||||
"xlm-roberta": {
|
||||
"config_names": {
|
||||
"context_length": "max_position_embeddings",
|
||||
"vocab_size": "vocab_size",
|
||||
"width": "hidden_size",
|
||||
"heads": "num_attention_heads",
|
||||
"layers": "num_hidden_layers",
|
||||
"layer_attr": "layer",
|
||||
"token_embeddings_attr": "embeddings"
|
||||
},
|
||||
"pooler": "mean_pooler",
|
||||
},
|
||||
# https://huggingface.co/docs/transformers/model_doc/mt5#mt5
|
||||
"mt5": {
|
||||
"config_names": {
|
||||
# unlimited seqlen
|
||||
# https://github.com/google-research/text-to-text-transfer-transformer/issues/273
|
||||
# https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
|
||||
"context_length": "",
|
||||
"vocab_size": "vocab_size",
|
||||
"width": "d_model",
|
||||
"heads": "num_heads",
|
||||
"layers": "num_layers",
|
||||
"layer_attr": "block",
|
||||
"token_embeddings_attr": "embed_tokens"
|
||||
},
|
||||
"pooler": "mean_pooler",
|
||||
},
|
||||
}
|
||||
176
diffsynth/extensions/ImageQualityMetric/open_clip/hf_model.py
Normal file
176
diffsynth/extensions/ImageQualityMetric/open_clip/hf_model.py
Normal file
@@ -0,0 +1,176 @@
|
||||
""" huggingface model adapter
|
||||
|
||||
Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import TensorType
|
||||
|
||||
try:
|
||||
import transformers
|
||||
from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
|
||||
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
|
||||
BaseModelOutputWithPoolingAndCrossAttentions
|
||||
except ImportError as e:
|
||||
transformers = None
|
||||
|
||||
|
||||
class BaseModelOutput:
|
||||
pass
|
||||
|
||||
|
||||
class PretrainedConfig:
|
||||
pass
|
||||
|
||||
from .hf_configs import arch_dict
|
||||
|
||||
|
||||
# utils
|
||||
def _camel2snake(s):
|
||||
return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
|
||||
|
||||
|
||||
# TODO: ?last - for gpt-like models
|
||||
_POOLERS = {}
|
||||
|
||||
|
||||
def register_pooler(cls):
|
||||
"""Decorator registering pooler class"""
|
||||
_POOLERS[_camel2snake(cls.__name__)] = cls
|
||||
return cls
|
||||
|
||||
|
||||
@register_pooler
|
||||
class MeanPooler(nn.Module):
|
||||
"""Mean pooling"""
|
||||
|
||||
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
|
||||
masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
|
||||
return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
|
||||
|
||||
|
||||
@register_pooler
|
||||
class MaxPooler(nn.Module):
|
||||
"""Max pooling"""
|
||||
|
||||
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
|
||||
masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1), -torch.inf)
|
||||
return masked_output.max(1).values
|
||||
|
||||
|
||||
@register_pooler
|
||||
class ClsPooler(nn.Module):
|
||||
"""CLS token pooling"""
|
||||
|
||||
def __init__(self, use_pooler_output=True):
|
||||
super().__init__()
|
||||
self.cls_token_position = 0
|
||||
self.use_pooler_output = use_pooler_output
|
||||
|
||||
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
|
||||
if (self.use_pooler_output and
|
||||
isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
|
||||
(x.pooler_output is not None)
|
||||
):
|
||||
return x.pooler_output
|
||||
|
||||
return x.last_hidden_state[:, self.cls_token_position, :]
|
||||
|
||||
|
||||
class HFTextEncoder(nn.Module):
|
||||
"""HuggingFace model adapter"""
|
||||
output_tokens: torch.jit.Final[bool]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
output_dim: int,
|
||||
config: PretrainedConfig = None,
|
||||
pooler_type: str = None,
|
||||
proj: str = None,
|
||||
pretrained: bool = True,
|
||||
output_tokens: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.output_tokens = output_tokens
|
||||
self.output_dim = output_dim
|
||||
|
||||
# TODO: find better way to get this information
|
||||
uses_transformer_pooler = (pooler_type == "cls_pooler")
|
||||
|
||||
if transformers is None:
|
||||
raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
|
||||
if config is None:
|
||||
self.config = AutoConfig.from_pretrained(model_name_or_path)
|
||||
create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
|
||||
AutoModel.from_config, self.config)
|
||||
# TODO: do all model configs have this attribute? PretrainedConfig does so yes??
|
||||
if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
|
||||
self.transformer = create_func(model_args)
|
||||
self.transformer = self.transformer.encoder
|
||||
else:
|
||||
self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
|
||||
else:
|
||||
self.config = config
|
||||
self.transformer = AutoModel.from_config(config)
|
||||
if pooler_type is None: # get default arch pooler
|
||||
pooler_type = (arch_dict[self.config.model_type]["pooler"])
|
||||
|
||||
self.pooler = _POOLERS[pooler_type]()
|
||||
|
||||
d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
|
||||
if (d_model == output_dim) and (proj is None): # do we always need a proj?
|
||||
self.proj = nn.Identity()
|
||||
elif proj == 'linear':
|
||||
self.proj = nn.Linear(d_model, output_dim, bias=False)
|
||||
elif proj == 'mlp':
|
||||
hidden_size = (d_model + output_dim) // 2
|
||||
self.proj = nn.Sequential(
|
||||
nn.Linear(d_model, hidden_size, bias=False),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_size, output_dim, bias=False),
|
||||
)
|
||||
|
||||
def forward(self, x: TensorType):
|
||||
attn_mask = (x != self.config.pad_token_id).long()
|
||||
out = self.transformer(input_ids=x, attention_mask=attn_mask)
|
||||
pooled_out = self.pooler(out, attn_mask)
|
||||
projected = self.proj(pooled_out)
|
||||
|
||||
seq_len = out.last_hidden_state.shape[1]
|
||||
tokens = (
|
||||
out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :]
|
||||
if type(self.pooler) == ClsPooler
|
||||
else out.last_hidden_state
|
||||
)
|
||||
|
||||
if self.output_tokens:
|
||||
return projected, tokens
|
||||
return projected
|
||||
|
||||
def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
|
||||
if not unlocked_layers: # full freezing
|
||||
for n, p in self.transformer.named_parameters():
|
||||
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
|
||||
return
|
||||
|
||||
encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
|
||||
layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
|
||||
print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
|
||||
embeddings = getattr(
|
||||
self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
|
||||
modules = [embeddings, *layer_list][:-unlocked_layers]
|
||||
# freeze layers
|
||||
for module in modules:
|
||||
for n, p in module.named_parameters():
|
||||
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_grad_checkpointing(self, enable=True):
|
||||
self.transformer.gradient_checkpointing_enable()
|
||||
|
||||
def init_parameters(self):
|
||||
pass
|
||||
270
diffsynth/extensions/ImageQualityMetric/open_clip/loss.py
Normal file
270
diffsynth/extensions/ImageQualityMetric/open_clip/loss.py
Normal file
@@ -0,0 +1,270 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
try:
|
||||
import torch.distributed.nn
|
||||
from torch import distributed as dist
|
||||
|
||||
has_distributed = True
|
||||
except ImportError:
|
||||
has_distributed = False
|
||||
|
||||
try:
|
||||
import horovod.torch as hvd
|
||||
except ImportError:
|
||||
hvd = None
|
||||
|
||||
|
||||
def gather_features(
|
||||
image_features,
|
||||
text_features,
|
||||
local_loss=False,
|
||||
gather_with_grad=False,
|
||||
rank=0,
|
||||
world_size=1,
|
||||
use_horovod=False
|
||||
):
|
||||
assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
|
||||
if use_horovod:
|
||||
assert hvd is not None, 'Please install horovod'
|
||||
if gather_with_grad:
|
||||
all_image_features = hvd.allgather(image_features)
|
||||
all_text_features = hvd.allgather(text_features)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
all_image_features = hvd.allgather(image_features)
|
||||
all_text_features = hvd.allgather(text_features)
|
||||
if not local_loss:
|
||||
# ensure grads for local rank when all_* features don't have a gradient
|
||||
gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
|
||||
gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
|
||||
gathered_image_features[rank] = image_features
|
||||
gathered_text_features[rank] = text_features
|
||||
all_image_features = torch.cat(gathered_image_features, dim=0)
|
||||
all_text_features = torch.cat(gathered_text_features, dim=0)
|
||||
else:
|
||||
# We gather tensors from all gpus
|
||||
if gather_with_grad:
|
||||
all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
|
||||
all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
|
||||
else:
|
||||
gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
|
||||
gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
|
||||
dist.all_gather(gathered_image_features, image_features)
|
||||
dist.all_gather(gathered_text_features, text_features)
|
||||
if not local_loss:
|
||||
# ensure grads for local rank when all_* features don't have a gradient
|
||||
gathered_image_features[rank] = image_features
|
||||
gathered_text_features[rank] = text_features
|
||||
all_image_features = torch.cat(gathered_image_features, dim=0)
|
||||
all_text_features = torch.cat(gathered_text_features, dim=0)
|
||||
|
||||
return all_image_features, all_text_features
|
||||
|
||||
|
||||
class ClipLoss(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
local_loss=False,
|
||||
gather_with_grad=False,
|
||||
cache_labels=False,
|
||||
rank=0,
|
||||
world_size=1,
|
||||
use_horovod=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.local_loss = local_loss
|
||||
self.gather_with_grad = gather_with_grad
|
||||
self.cache_labels = cache_labels
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
self.use_horovod = use_horovod
|
||||
|
||||
# cache state
|
||||
self.prev_num_logits = 0
|
||||
self.labels = {}
|
||||
|
||||
def get_ground_truth(self, device, num_logits) -> torch.Tensor:
|
||||
# calculated ground-truth and cache if enabled
|
||||
if self.prev_num_logits != num_logits or device not in self.labels:
|
||||
labels = torch.arange(num_logits, device=device, dtype=torch.long)
|
||||
if self.world_size > 1 and self.local_loss:
|
||||
labels = labels + num_logits * self.rank
|
||||
if self.cache_labels:
|
||||
self.labels[device] = labels
|
||||
self.prev_num_logits = num_logits
|
||||
else:
|
||||
labels = self.labels[device]
|
||||
return labels
|
||||
|
||||
def get_logits(self, image_features, text_features, logit_scale):
|
||||
if self.world_size > 1:
|
||||
all_image_features, all_text_features = gather_features(
|
||||
image_features, text_features,
|
||||
self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
|
||||
|
||||
if self.local_loss:
|
||||
logits_per_image = logit_scale * image_features @ all_text_features.T
|
||||
logits_per_text = logit_scale * text_features @ all_image_features.T
|
||||
else:
|
||||
logits_per_image = logit_scale * all_image_features @ all_text_features.T
|
||||
logits_per_text = logits_per_image.T
|
||||
else:
|
||||
logits_per_image = logit_scale * image_features @ text_features.T
|
||||
logits_per_text = logit_scale * text_features @ image_features.T
|
||||
|
||||
return logits_per_image, logits_per_text
|
||||
|
||||
def forward(self, image_features, text_features, logit_scale, output_dict=False):
|
||||
device = image_features.device
|
||||
logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
|
||||
|
||||
labels = self.get_ground_truth(device, logits_per_image.shape[0])
|
||||
|
||||
total_loss = (
|
||||
F.cross_entropy(logits_per_image, labels) +
|
||||
F.cross_entropy(logits_per_text, labels)
|
||||
) / 2
|
||||
return total_loss
|
||||
|
||||
class PreferenceLoss(nn.Module):
|
||||
|
||||
def forward(self, logits_per_image, num_images, labels):
|
||||
|
||||
paired_logits_list = [logit[:,i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))]
|
||||
paired_logits = pad_sequence(paired_logits_list, batch_first=True, padding_value=-999)
|
||||
|
||||
ce_loss = F.cross_entropy(paired_logits, labels)
|
||||
return ce_loss
|
||||
|
||||
class HPSLoss(nn.Module):
|
||||
|
||||
def forward(self, text_logits, labels):
|
||||
|
||||
device = text_logits.device
|
||||
text_0_logits, text_1_logits = text_logits.chunk(2, dim=-1)
|
||||
label_0, label_1 = labels.chunk(2, dim=-1)
|
||||
|
||||
index = torch.arange(text_0_logits.shape[0], device=device, dtype=torch.long)
|
||||
text_0_logits = text_0_logits[index, index]
|
||||
text_1_logits = text_1_logits[index, index]
|
||||
text_logits = torch.stack([text_0_logits, text_1_logits], dim=-1)
|
||||
text_0_labels = torch.zeros(text_logits.shape[0], device=device, dtype=torch.long)
|
||||
text_1_labels = text_0_labels + 1
|
||||
|
||||
text_0_loss = torch.nn.functional.cross_entropy(text_logits, text_0_labels, reduction="none")
|
||||
text_1_loss = torch.nn.functional.cross_entropy(text_logits, text_1_labels, reduction="none")
|
||||
|
||||
text_loss = label_0 * text_0_loss + label_1 * text_1_loss
|
||||
|
||||
# absolute_example_weight = 1 / num_per_prompt
|
||||
# denominator = absolute_example_weight.sum()
|
||||
# weight_per_example = absolute_example_weight / denominator
|
||||
# text_loss *= weight_per_example
|
||||
|
||||
text_loss = text_loss.sum()
|
||||
return text_loss
|
||||
|
||||
class RankingLoss(nn.Module):
|
||||
|
||||
def forward(self, logits_per_image, num_images, labels, margin = 1.0):
|
||||
paired_logits_list = [logit[:,i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))]
|
||||
label_list = [label for label in labels.split(num_images.tolist())]
|
||||
# ranked_logits = [torch.index_select(paired_logits_list[i], 0, rank) for i, rank in enumerate(label_list)]
|
||||
|
||||
paired_logits = pad_sequence(paired_logits_list, batch_first=True, padding_value=-1)
|
||||
padded_labels = pad_sequence(label_list, batch_first=True, padding_value=10)
|
||||
|
||||
# regulized_logits = torch.log(torch.sigmoid(paired_logits))
|
||||
|
||||
diff = paired_logits.unsqueeze(1) - paired_logits.unsqueeze(2)
|
||||
# diff = paired_logits.unsqueeze(1) - paired_logits.unsqueeze(2)
|
||||
# diff_label = torch.clamp(padded_labels.unsqueeze(1) - padded_labels.unsqueeze(2), min=-1, max=1)
|
||||
diff_label = - (padded_labels.unsqueeze(1) - padded_labels.unsqueeze(2))
|
||||
mask = torch.triu(torch.ones(diff.shape[1], diff.shape[1]), diagonal=1).bool().detach()
|
||||
|
||||
loss = torch.clamp(margin - torch.mul(diff[:, ~mask],diff_label[:,~mask]), min=0).mean()
|
||||
return loss
|
||||
|
||||
class CoCaLoss(ClipLoss):
|
||||
def __init__(
|
||||
self,
|
||||
caption_loss_weight,
|
||||
clip_loss_weight,
|
||||
pad_id=0, # pad_token for open_clip custom tokenizer
|
||||
local_loss=False,
|
||||
gather_with_grad=False,
|
||||
cache_labels=False,
|
||||
rank=0,
|
||||
world_size=1,
|
||||
use_horovod=False,
|
||||
):
|
||||
super().__init__(
|
||||
local_loss=local_loss,
|
||||
gather_with_grad=gather_with_grad,
|
||||
cache_labels=cache_labels,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
use_horovod=use_horovod
|
||||
)
|
||||
|
||||
self.clip_loss_weight = clip_loss_weight
|
||||
self.caption_loss_weight = caption_loss_weight
|
||||
self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)
|
||||
|
||||
def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
|
||||
clip_loss = super().forward(image_features, text_features, logit_scale)
|
||||
clip_loss = self.clip_loss_weight * clip_loss
|
||||
|
||||
caption_loss = self.caption_loss(
|
||||
logits.permute(0, 2, 1),
|
||||
labels,
|
||||
)
|
||||
caption_loss = caption_loss * self.caption_loss_weight
|
||||
|
||||
if output_dict:
|
||||
return {"contrastive_loss": clip_loss, "caption_loss": caption_loss}
|
||||
|
||||
return clip_loss, caption_loss
|
||||
|
||||
|
||||
class DistillClipLoss(ClipLoss):
|
||||
|
||||
def dist_loss(self, teacher_logits, student_logits):
|
||||
return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image_features,
|
||||
text_features,
|
||||
logit_scale,
|
||||
dist_image_features,
|
||||
dist_text_features,
|
||||
dist_logit_scale,
|
||||
output_dict=False,
|
||||
):
|
||||
logits_per_image, logits_per_text = \
|
||||
self.get_logits(image_features, text_features, logit_scale)
|
||||
|
||||
dist_logits_per_image, dist_logits_per_text = \
|
||||
self.get_logits(dist_image_features, dist_text_features, dist_logit_scale)
|
||||
|
||||
labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0])
|
||||
|
||||
contrastive_loss = (
|
||||
F.cross_entropy(logits_per_image, labels) +
|
||||
F.cross_entropy(logits_per_text, labels)
|
||||
) / 2
|
||||
|
||||
distill_loss = (
|
||||
self.dist_loss(dist_logits_per_image, logits_per_image) +
|
||||
self.dist_loss(dist_logits_per_text, logits_per_text)
|
||||
) / 2
|
||||
|
||||
if output_dict:
|
||||
return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}
|
||||
|
||||
return contrastive_loss, distill_loss
|
||||
461
diffsynth/extensions/ImageQualityMetric/open_clip/model.py
Normal file
461
diffsynth/extensions/ImageQualityMetric/open_clip/model.py
Normal file
@@ -0,0 +1,461 @@
|
||||
""" CLIP Model
|
||||
|
||||
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from .hf_model import HFTextEncoder
|
||||
from .modified_resnet import ModifiedResNet
|
||||
from .timm_model import TimmModel
|
||||
from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
|
||||
from .utils import to_2tuple
|
||||
|
||||
|
||||
@dataclass
|
||||
class CLIPVisionCfg:
|
||||
layers: Union[Tuple[int, int, int, int], int] = 12
|
||||
width: int = 768
|
||||
head_width: int = 64
|
||||
mlp_ratio: float = 4.0
|
||||
patch_size: int = 16
|
||||
image_size: Union[Tuple[int, int], int] = 224
|
||||
ls_init_value: Optional[float] = None # layer scale initial value
|
||||
patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
|
||||
input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design
|
||||
global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
|
||||
attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer
|
||||
n_queries: int = 256 # n_queries for attentional pooler
|
||||
attn_pooler_heads: int = 8 # n heads for attentional_pooling
|
||||
timm_model_name: str = None # a valid model name overrides layers, width, patch_size
|
||||
timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
|
||||
timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
|
||||
timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
|
||||
timm_proj_bias: bool = False # enable bias final projection
|
||||
timm_drop: float = 0. # head dropout
|
||||
timm_drop_path: Optional[float] = None # backbone stochastic depth
|
||||
output_tokens: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class CLIPTextCfg:
|
||||
context_length: int = 77
|
||||
vocab_size: int = 49408
|
||||
width: int = 512
|
||||
heads: int = 8
|
||||
layers: int = 12
|
||||
ls_init_value: Optional[float] = None # layer scale initial value
|
||||
hf_model_name: str = None
|
||||
hf_tokenizer_name: str = None
|
||||
hf_model_pretrained: bool = True
|
||||
proj: str = 'mlp'
|
||||
pooler_type: str = 'mean_pooler'
|
||||
embed_cls: bool = False
|
||||
pad_id: int = 0
|
||||
output_tokens: bool = False
|
||||
|
||||
|
||||
def get_cast_dtype(precision: str):
|
||||
cast_dtype = None
|
||||
if precision == 'bf16':
|
||||
cast_dtype = torch.bfloat16
|
||||
elif precision == 'fp16':
|
||||
cast_dtype = torch.float16
|
||||
return cast_dtype
|
||||
|
||||
|
||||
def _build_vision_tower(
|
||||
embed_dim: int,
|
||||
vision_cfg: CLIPVisionCfg,
|
||||
quick_gelu: bool = False,
|
||||
cast_dtype: Optional[torch.dtype] = None
|
||||
):
|
||||
if isinstance(vision_cfg, dict):
|
||||
vision_cfg = CLIPVisionCfg(**vision_cfg)
|
||||
|
||||
# OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
|
||||
# memory efficient in recent PyTorch releases (>= 1.10).
|
||||
# NOTE: timm models always use native GELU regardless of quick_gelu flag.
|
||||
act_layer = QuickGELU if quick_gelu else nn.GELU
|
||||
|
||||
if vision_cfg.timm_model_name:
|
||||
visual = TimmModel(
|
||||
vision_cfg.timm_model_name,
|
||||
pretrained=vision_cfg.timm_model_pretrained,
|
||||
pool=vision_cfg.timm_pool,
|
||||
proj=vision_cfg.timm_proj,
|
||||
proj_bias=vision_cfg.timm_proj_bias,
|
||||
drop=vision_cfg.timm_drop,
|
||||
drop_path=vision_cfg.timm_drop_path,
|
||||
embed_dim=embed_dim,
|
||||
image_size=vision_cfg.image_size,
|
||||
)
|
||||
act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
|
||||
elif isinstance(vision_cfg.layers, (tuple, list)):
|
||||
vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
|
||||
visual = ModifiedResNet(
|
||||
layers=vision_cfg.layers,
|
||||
output_dim=embed_dim,
|
||||
heads=vision_heads,
|
||||
image_size=vision_cfg.image_size,
|
||||
width=vision_cfg.width,
|
||||
)
|
||||
else:
|
||||
vision_heads = vision_cfg.width // vision_cfg.head_width
|
||||
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
|
||||
visual = VisionTransformer(
|
||||
image_size=vision_cfg.image_size,
|
||||
patch_size=vision_cfg.patch_size,
|
||||
width=vision_cfg.width,
|
||||
layers=vision_cfg.layers,
|
||||
heads=vision_heads,
|
||||
mlp_ratio=vision_cfg.mlp_ratio,
|
||||
ls_init_value=vision_cfg.ls_init_value,
|
||||
patch_dropout=vision_cfg.patch_dropout,
|
||||
input_patchnorm=vision_cfg.input_patchnorm,
|
||||
global_average_pool=vision_cfg.global_average_pool,
|
||||
attentional_pool=vision_cfg.attentional_pool,
|
||||
n_queries=vision_cfg.n_queries,
|
||||
attn_pooler_heads=vision_cfg.attn_pooler_heads,
|
||||
output_tokens=vision_cfg.output_tokens,
|
||||
output_dim=embed_dim,
|
||||
act_layer=act_layer,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
|
||||
return visual
|
||||
|
||||
|
||||
def _build_text_tower(
|
||||
embed_dim: int,
|
||||
text_cfg: CLIPTextCfg,
|
||||
quick_gelu: bool = False,
|
||||
cast_dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
if isinstance(text_cfg, dict):
|
||||
text_cfg = CLIPTextCfg(**text_cfg)
|
||||
|
||||
if text_cfg.hf_model_name:
|
||||
text = HFTextEncoder(
|
||||
text_cfg.hf_model_name,
|
||||
output_dim=embed_dim,
|
||||
proj=text_cfg.proj,
|
||||
pooler_type=text_cfg.pooler_type,
|
||||
pretrained=text_cfg.hf_model_pretrained,
|
||||
output_tokens=text_cfg.output_tokens,
|
||||
)
|
||||
else:
|
||||
act_layer = QuickGELU if quick_gelu else nn.GELU
|
||||
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
|
||||
|
||||
text = TextTransformer(
|
||||
context_length=text_cfg.context_length,
|
||||
vocab_size=text_cfg.vocab_size,
|
||||
width=text_cfg.width,
|
||||
heads=text_cfg.heads,
|
||||
layers=text_cfg.layers,
|
||||
ls_init_value=text_cfg.ls_init_value,
|
||||
output_dim=embed_dim,
|
||||
embed_cls=text_cfg.embed_cls,
|
||||
output_tokens=text_cfg.output_tokens,
|
||||
pad_id=text_cfg.pad_id,
|
||||
act_layer=act_layer,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
return text
|
||||
|
||||
|
||||
class CLIP(nn.Module):
|
||||
output_dict: torch.jit.Final[bool]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
vision_cfg: CLIPVisionCfg,
|
||||
text_cfg: CLIPTextCfg,
|
||||
quick_gelu: bool = False,
|
||||
cast_dtype: Optional[torch.dtype] = None,
|
||||
output_dict: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.output_dict = output_dict
|
||||
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
|
||||
|
||||
text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
|
||||
self.transformer = text.transformer
|
||||
self.vocab_size = text.vocab_size
|
||||
self.token_embedding = text.token_embedding
|
||||
self.positional_embedding = text.positional_embedding
|
||||
self.ln_final = text.ln_final
|
||||
self.text_projection = text.text_projection
|
||||
self.register_buffer('attn_mask', text.attn_mask, persistent=False)
|
||||
|
||||
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
||||
|
||||
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
|
||||
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
|
||||
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
|
||||
|
||||
def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
|
||||
locked_layers = []
|
||||
locked_layers.append(self.token_embedding)
|
||||
self.positional_embedding.requires_grad = False
|
||||
if unlocked_layers > 0:
|
||||
locked_layers.append(self.transformer.resblocks[:-unlocked_layers])
|
||||
else:
|
||||
locked_layers.append(self.transformer)
|
||||
locked_layers.append(self.ln_final)
|
||||
self.text_projection.requires_grad = False
|
||||
|
||||
# freeze layers
|
||||
for module in locked_layers:
|
||||
for n, p in module.named_parameters():
|
||||
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_grad_checkpointing(self, enable=True):
|
||||
self.visual.set_grad_checkpointing(enable)
|
||||
self.transformer.grad_checkpointing = enable
|
||||
|
||||
def encode_image(self, image, normalize: bool = False):
|
||||
features = self.visual(image)
|
||||
return F.normalize(features, dim=-1) if normalize else features
|
||||
|
||||
def encode_text(self, text, normalize: bool = False):
|
||||
cast_dtype = self.transformer.get_cast_dtype()
|
||||
|
||||
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
|
||||
|
||||
x = x + self.positional_embedding.to(cast_dtype)
|
||||
x = x.permute(1, 0, 2) # NLD -> LND
|
||||
x = self.transformer(x, attn_mask=self.attn_mask)
|
||||
x = x.permute(1, 0, 2) # LND -> NLD
|
||||
x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
|
||||
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
||||
return F.normalize(x, dim=-1) if normalize else x
|
||||
|
||||
def forward(self, image, text):
|
||||
image_features = self.encode_image(image, normalize=True)
|
||||
text_features = self.encode_text(text, normalize=True)
|
||||
if self.output_dict:
|
||||
return {
|
||||
"image_features": image_features,
|
||||
"text_features": text_features,
|
||||
"logit_scale": self.logit_scale.exp()
|
||||
}
|
||||
return image_features, text_features, self.logit_scale.exp()
|
||||
|
||||
|
||||
class CustomTextCLIP(nn.Module):
|
||||
output_dict: torch.jit.Final[bool]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
vision_cfg: CLIPVisionCfg,
|
||||
text_cfg: CLIPTextCfg,
|
||||
quick_gelu: bool = False,
|
||||
cast_dtype: Optional[torch.dtype] = None,
|
||||
output_dict: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.output_dict = output_dict
|
||||
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
|
||||
self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
|
||||
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
||||
|
||||
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
|
||||
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
|
||||
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
|
||||
|
||||
def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
|
||||
self.text.lock(unlocked_layers, freeze_layer_norm)
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_grad_checkpointing(self, enable=True):
|
||||
self.visual.set_grad_checkpointing(enable)
|
||||
self.text.set_grad_checkpointing(enable)
|
||||
|
||||
def encode_image(self, image, normalize: bool = False):
|
||||
features = self.visual(image)
|
||||
return F.normalize(features, dim=-1) if normalize else features
|
||||
|
||||
def encode_text(self, text, normalize: bool = False):
|
||||
features = self.text(text)
|
||||
return F.normalize(features, dim=-1) if normalize else features
|
||||
|
||||
def forward(self, image, text):
|
||||
image_features = self.encode_image(image, normalize=True)
|
||||
text_features = self.encode_text(text, normalize=True)
|
||||
if self.output_dict:
|
||||
return {
|
||||
"image_features": image_features,
|
||||
"text_features": text_features,
|
||||
"logit_scale": self.logit_scale.exp()
|
||||
}
|
||||
return image_features, text_features, self.logit_scale.exp()
|
||||
|
||||
|
||||
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
|
||||
"""Convert applicable model parameters to low-precision (bf16 or fp16)"""
|
||||
|
||||
def _convert_weights(l):
|
||||
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
||||
l.weight.data = l.weight.data.to(dtype)
|
||||
if l.bias is not None:
|
||||
l.bias.data = l.bias.data.to(dtype)
|
||||
|
||||
if isinstance(l, (nn.MultiheadAttention, Attention)):
|
||||
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
||||
tensor = getattr(l, attr)
|
||||
if tensor is not None:
|
||||
tensor.data = tensor.data.to(dtype)
|
||||
|
||||
for name in ["text_projection", "proj"]:
|
||||
if hasattr(l, name):
|
||||
attr = getattr(l, name)
|
||||
if attr is not None:
|
||||
attr.data = attr.data.to(dtype)
|
||||
|
||||
model.apply(_convert_weights)
|
||||
|
||||
|
||||
convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
|
||||
|
||||
|
||||
# used to maintain checkpoint compatibility
|
||||
def convert_to_custom_text_state_dict(state_dict: dict):
|
||||
if 'text_projection' in state_dict:
|
||||
# old format state_dict, move text tower -> .text
|
||||
new_state_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
if any(k.startswith(p) for p in (
|
||||
'text_projection',
|
||||
'positional_embedding',
|
||||
'token_embedding',
|
||||
'transformer',
|
||||
'ln_final',
|
||||
)):
|
||||
k = 'text.' + k
|
||||
new_state_dict[k] = v
|
||||
return new_state_dict
|
||||
return state_dict
|
||||
|
||||
|
||||
def build_model_from_openai_state_dict(
|
||||
state_dict: dict,
|
||||
quick_gelu=True,
|
||||
cast_dtype=torch.float16,
|
||||
):
|
||||
vit = "visual.proj" in state_dict
|
||||
|
||||
if vit:
|
||||
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
||||
vision_layers = len(
|
||||
[k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
||||
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
||||
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
||||
image_size = vision_patch_size * grid_size
|
||||
else:
|
||||
counts: list = [
|
||||
len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
|
||||
vision_layers = tuple(counts)
|
||||
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
||||
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
||||
vision_patch_size = None
|
||||
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
||||
image_size = output_width * 32
|
||||
|
||||
embed_dim = state_dict["text_projection"].shape[1]
|
||||
context_length = state_dict["positional_embedding"].shape[0]
|
||||
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
||||
transformer_width = state_dict["ln_final.weight"].shape[0]
|
||||
transformer_heads = transformer_width // 64
|
||||
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
|
||||
|
||||
vision_cfg = CLIPVisionCfg(
|
||||
layers=vision_layers,
|
||||
width=vision_width,
|
||||
patch_size=vision_patch_size,
|
||||
image_size=image_size,
|
||||
)
|
||||
text_cfg = CLIPTextCfg(
|
||||
context_length=context_length,
|
||||
vocab_size=vocab_size,
|
||||
width=transformer_width,
|
||||
heads=transformer_heads,
|
||||
layers=transformer_layers,
|
||||
)
|
||||
model = CLIP(
|
||||
embed_dim,
|
||||
vision_cfg=vision_cfg,
|
||||
text_cfg=text_cfg,
|
||||
quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
|
||||
cast_dtype=cast_dtype,
|
||||
)
|
||||
|
||||
for key in ["input_resolution", "context_length", "vocab_size"]:
|
||||
state_dict.pop(key, None)
|
||||
|
||||
convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
|
||||
model.load_state_dict(state_dict)
|
||||
return model.eval()
|
||||
|
||||
|
||||
def trace_model(model, batch_size=256, device=torch.device('cpu')):
|
||||
model.eval()
|
||||
image_size = model.visual.image_size
|
||||
example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
|
||||
example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
|
||||
model = torch.jit.trace_module(
|
||||
model,
|
||||
inputs=dict(
|
||||
forward=(example_images, example_text),
|
||||
encode_text=(example_text,),
|
||||
encode_image=(example_images,)
|
||||
))
|
||||
model.visual.image_size = image_size
|
||||
return model
|
||||
|
||||
|
||||
def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
|
||||
# Rescale the grid of position embeddings when loading from state_dict
|
||||
old_pos_embed = state_dict.get('visual.positional_embedding', None)
|
||||
if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
|
||||
return
|
||||
grid_size = to_2tuple(model.visual.grid_size)
|
||||
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
|
||||
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
|
||||
if new_seq_len == old_pos_embed.shape[0]:
|
||||
return
|
||||
|
||||
if extra_tokens:
|
||||
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
|
||||
else:
|
||||
pos_emb_tok, pos_emb_img = None, old_pos_embed
|
||||
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
|
||||
|
||||
logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
|
||||
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
|
||||
pos_emb_img = F.interpolate(
|
||||
pos_emb_img,
|
||||
size=grid_size,
|
||||
mode=interpolation,
|
||||
antialias=antialias,
|
||||
align_corners=False,
|
||||
)
|
||||
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
|
||||
if pos_emb_tok is not None:
|
||||
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
|
||||
else:
|
||||
new_pos_embed = pos_emb_img
|
||||
state_dict['visual.positional_embedding'] = new_pos_embed
|
||||
@@ -0,0 +1,17 @@
|
||||
{
|
||||
"embed_dim": 1024,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": 32,
|
||||
"width": 1280,
|
||||
"head_width": 80,
|
||||
"patch_size": 14
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 1024,
|
||||
"heads": 16,
|
||||
"layers": 24
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,181 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .utils import freeze_batch_norm_2d
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1):
|
||||
super().__init__()
|
||||
|
||||
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.act1 = nn.ReLU(inplace=True)
|
||||
|
||||
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.act2 = nn.ReLU(inplace=True)
|
||||
|
||||
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
||||
|
||||
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
||||
self.act3 = nn.ReLU(inplace=True)
|
||||
|
||||
self.downsample = None
|
||||
self.stride = stride
|
||||
|
||||
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
||||
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
||||
self.downsample = nn.Sequential(OrderedDict([
|
||||
("-1", nn.AvgPool2d(stride)),
|
||||
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
||||
("1", nn.BatchNorm2d(planes * self.expansion))
|
||||
]))
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
identity = x
|
||||
|
||||
out = self.act1(self.bn1(self.conv1(x)))
|
||||
out = self.act2(self.bn2(self.conv2(out)))
|
||||
out = self.avgpool(out)
|
||||
out = self.bn3(self.conv3(out))
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.act3(out)
|
||||
return out
|
||||
|
||||
|
||||
class AttentionPool2d(nn.Module):
|
||||
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
||||
super().__init__()
|
||||
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
|
||||
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
||||
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
||||
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
||||
self.num_heads = num_heads
|
||||
|
||||
def forward(self, x):
|
||||
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
|
||||
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
||||
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
||||
x, _ = F.multi_head_attention_forward(
|
||||
query=x, key=x, value=x,
|
||||
embed_dim_to_check=x.shape[-1],
|
||||
num_heads=self.num_heads,
|
||||
q_proj_weight=self.q_proj.weight,
|
||||
k_proj_weight=self.k_proj.weight,
|
||||
v_proj_weight=self.v_proj.weight,
|
||||
in_proj_weight=None,
|
||||
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
||||
bias_k=None,
|
||||
bias_v=None,
|
||||
add_zero_attn=False,
|
||||
dropout_p=0.,
|
||||
out_proj_weight=self.c_proj.weight,
|
||||
out_proj_bias=self.c_proj.bias,
|
||||
use_separate_proj_weight=True,
|
||||
training=self.training,
|
||||
need_weights=False
|
||||
)
|
||||
|
||||
return x[0]
|
||||
|
||||
|
||||
class ModifiedResNet(nn.Module):
|
||||
"""
|
||||
A ResNet class that is similar to torchvision's but contains the following changes:
|
||||
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
||||
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
||||
- The final pooling layer is a QKV attention instead of an average pool
|
||||
"""
|
||||
|
||||
def __init__(self, layers, output_dim, heads, image_size=224, width=64):
|
||||
super().__init__()
|
||||
self.output_dim = output_dim
|
||||
self.image_size = image_size
|
||||
|
||||
# the 3-layer stem
|
||||
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(width // 2)
|
||||
self.act1 = nn.ReLU(inplace=True)
|
||||
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(width // 2)
|
||||
self.act2 = nn.ReLU(inplace=True)
|
||||
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(width)
|
||||
self.act3 = nn.ReLU(inplace=True)
|
||||
self.avgpool = nn.AvgPool2d(2)
|
||||
|
||||
# residual layers
|
||||
self._inplanes = width # this is a *mutable* variable used during construction
|
||||
self.layer1 = self._make_layer(width, layers[0])
|
||||
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
||||
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
||||
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
||||
|
||||
embed_dim = width * 32 # the ResNet feature dimension
|
||||
self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
|
||||
|
||||
self.init_parameters()
|
||||
|
||||
def _make_layer(self, planes, blocks, stride=1):
|
||||
layers = [Bottleneck(self._inplanes, planes, stride)]
|
||||
|
||||
self._inplanes = planes * Bottleneck.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(Bottleneck(self._inplanes, planes))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def init_parameters(self):
|
||||
if self.attnpool is not None:
|
||||
std = self.attnpool.c_proj.in_features ** -0.5
|
||||
nn.init.normal_(self.attnpool.q_proj.weight, std=std)
|
||||
nn.init.normal_(self.attnpool.k_proj.weight, std=std)
|
||||
nn.init.normal_(self.attnpool.v_proj.weight, std=std)
|
||||
nn.init.normal_(self.attnpool.c_proj.weight, std=std)
|
||||
|
||||
for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
|
||||
for name, param in resnet_block.named_parameters():
|
||||
if name.endswith("bn3.weight"):
|
||||
nn.init.zeros_(param)
|
||||
|
||||
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
||||
assert unlocked_groups == 0, 'partial locking not currently supported for this model'
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
if freeze_bn_stats:
|
||||
freeze_batch_norm_2d(self)
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_grad_checkpointing(self, enable=True):
|
||||
# FIXME support for non-transformer
|
||||
pass
|
||||
|
||||
def stem(self, x):
|
||||
x = self.act1(self.bn1(self.conv1(x)))
|
||||
x = self.act2(self.bn2(self.conv2(x)))
|
||||
x = self.act3(self.bn3(self.conv3(x)))
|
||||
x = self.avgpool(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.stem(x)
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
x = self.attnpool(x)
|
||||
|
||||
return x
|
||||
144
diffsynth/extensions/ImageQualityMetric/open_clip/openai.py
Normal file
144
diffsynth/extensions/ImageQualityMetric/open_clip/openai.py
Normal file
@@ -0,0 +1,144 @@
|
||||
""" OpenAI pretrained model functions
|
||||
|
||||
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
||||
"""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
|
||||
from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
|
||||
|
||||
__all__ = ["list_openai_models", "load_openai_model"]
|
||||
|
||||
|
||||
def list_openai_models() -> List[str]:
|
||||
"""Returns the names of available CLIP models"""
|
||||
return list_pretrained_models_by_tag('openai')
|
||||
|
||||
|
||||
def load_openai_model(
|
||||
name: str,
|
||||
precision: Optional[str] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
jit: bool = True,
|
||||
cache_dir: Optional[str] = None,
|
||||
):
|
||||
"""Load a CLIP model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
|
||||
precision: str
|
||||
Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
|
||||
device : Union[str, torch.device]
|
||||
The device to put the loaded model
|
||||
jit : bool
|
||||
Whether to load the optimized JIT model (default) or more hackable non-JIT model.
|
||||
cache_dir : Optional[str]
|
||||
The directory to cache the downloaded model weights
|
||||
|
||||
Returns
|
||||
-------
|
||||
model : torch.nn.Module
|
||||
The CLIP model
|
||||
preprocess : Callable[[PIL.Image], torch.Tensor]
|
||||
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
||||
"""
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if precision is None:
|
||||
precision = 'fp32' if device == 'cpu' else 'fp16'
|
||||
|
||||
if get_pretrained_url(name, 'openai'):
|
||||
model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
|
||||
elif os.path.isfile(name):
|
||||
model_path = name
|
||||
else:
|
||||
raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
|
||||
|
||||
try:
|
||||
# loading JIT archive
|
||||
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
|
||||
state_dict = None
|
||||
except RuntimeError:
|
||||
# loading saved state dict
|
||||
if jit:
|
||||
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
|
||||
jit = False
|
||||
state_dict = torch.load(model_path, map_location="cpu")
|
||||
|
||||
if not jit:
|
||||
# Build a non-jit model from the OpenAI jitted model state dict
|
||||
cast_dtype = get_cast_dtype(precision)
|
||||
try:
|
||||
model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
|
||||
except KeyError:
|
||||
sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
|
||||
model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
|
||||
|
||||
# model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
|
||||
model = model.to(device)
|
||||
if precision.startswith('amp') or precision == 'fp32':
|
||||
model.float()
|
||||
elif precision == 'bf16':
|
||||
convert_weights_to_lp(model, dtype=torch.bfloat16)
|
||||
|
||||
return model
|
||||
|
||||
# patch the device names
|
||||
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
|
||||
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
|
||||
|
||||
def patch_device(module):
|
||||
try:
|
||||
graphs = [module.graph] if hasattr(module, "graph") else []
|
||||
except RuntimeError:
|
||||
graphs = []
|
||||
|
||||
if hasattr(module, "forward1"):
|
||||
graphs.append(module.forward1.graph)
|
||||
|
||||
for graph in graphs:
|
||||
for node in graph.findAllNodes("prim::Constant"):
|
||||
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
|
||||
node.copyAttributes(device_node)
|
||||
|
||||
model.apply(patch_device)
|
||||
patch_device(model.encode_image)
|
||||
patch_device(model.encode_text)
|
||||
|
||||
# patch dtype to float32 (typically for CPU)
|
||||
if precision == 'fp32':
|
||||
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
|
||||
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
||||
float_node = float_input.node()
|
||||
|
||||
def patch_float(module):
|
||||
try:
|
||||
graphs = [module.graph] if hasattr(module, "graph") else []
|
||||
except RuntimeError:
|
||||
graphs = []
|
||||
|
||||
if hasattr(module, "forward1"):
|
||||
graphs.append(module.forward1.graph)
|
||||
|
||||
for graph in graphs:
|
||||
for node in graph.findAllNodes("aten::to"):
|
||||
inputs = list(node.inputs())
|
||||
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
|
||||
if inputs[i].node()["value"] == 5:
|
||||
inputs[i].node().copyAttributes(float_node)
|
||||
|
||||
model.apply(patch_float)
|
||||
patch_float(model.encode_image)
|
||||
patch_float(model.encode_text)
|
||||
model.float()
|
||||
|
||||
# ensure image_size attr available at consistent location for both jit and non-jit
|
||||
model.visual.image_size = model.input_resolution.item()
|
||||
return model
|
||||
376
diffsynth/extensions/ImageQualityMetric/open_clip/pretrained.py
Normal file
376
diffsynth/extensions/ImageQualityMetric/open_clip/pretrained.py
Normal file
@@ -0,0 +1,376 @@
|
||||
import hashlib
|
||||
import os
|
||||
import urllib
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import Dict, Union
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from .version import __version__
|
||||
|
||||
try:
|
||||
from huggingface_hub import hf_hub_download
|
||||
hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__)
|
||||
_has_hf_hub = True
|
||||
except ImportError:
|
||||
hf_hub_download = None
|
||||
_has_hf_hub = False
|
||||
|
||||
|
||||
def _pcfg(url='', hf_hub='', mean=None, std=None):
|
||||
return dict(
|
||||
url=url,
|
||||
hf_hub=hf_hub,
|
||||
mean=mean,
|
||||
std=std,
|
||||
)
|
||||
|
||||
|
||||
_RN50 = dict(
|
||||
openai=_pcfg(
|
||||
"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
|
||||
yfcc15m=_pcfg(
|
||||
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
|
||||
cc12m=_pcfg(
|
||||
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
|
||||
)
|
||||
|
||||
_RN50_quickgelu = dict(
|
||||
openai=_pcfg(
|
||||
"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
|
||||
yfcc15m=_pcfg(
|
||||
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
|
||||
cc12m=_pcfg(
|
||||
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
|
||||
)
|
||||
|
||||
_RN101 = dict(
|
||||
openai=_pcfg(
|
||||
"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
|
||||
yfcc15m=_pcfg(
|
||||
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
|
||||
)
|
||||
|
||||
_RN101_quickgelu = dict(
|
||||
openai=_pcfg(
|
||||
"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
|
||||
yfcc15m=_pcfg(
|
||||
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
|
||||
)
|
||||
|
||||
_RN50x4 = dict(
|
||||
openai=_pcfg(
|
||||
"https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"),
|
||||
)
|
||||
|
||||
_RN50x16 = dict(
|
||||
openai=_pcfg(
|
||||
"https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"),
|
||||
)
|
||||
|
||||
_RN50x64 = dict(
|
||||
openai=_pcfg(
|
||||
"https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"),
|
||||
)
|
||||
|
||||
_VITB32 = dict(
|
||||
openai=_pcfg(
|
||||
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
|
||||
laion400m_e31=_pcfg(
|
||||
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
|
||||
laion400m_e32=_pcfg(
|
||||
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
|
||||
laion2b_e16=_pcfg(
|
||||
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
|
||||
laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/')
|
||||
)
|
||||
|
||||
_VITB32_quickgelu = dict(
|
||||
openai=_pcfg(
|
||||
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
|
||||
laion400m_e31=_pcfg(
|
||||
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
|
||||
laion400m_e32=_pcfg(
|
||||
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
|
||||
)
|
||||
|
||||
_VITB16 = dict(
|
||||
openai=_pcfg(
|
||||
"https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
|
||||
laion400m_e31=_pcfg(
|
||||
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
|
||||
laion400m_e32=_pcfg(
|
||||
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
|
||||
# laion400m_32k=_pcfg(
|
||||
# url="",
|
||||
# mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
||||
# laion400m_64k=_pcfg(
|
||||
# url="",
|
||||
# mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
||||
laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
|
||||
)
|
||||
|
||||
_VITB16_PLUS_240 = dict(
|
||||
laion400m_e31=_pcfg(
|
||||
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
|
||||
laion400m_e32=_pcfg(
|
||||
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
|
||||
)
|
||||
|
||||
_VITL14 = dict(
|
||||
openai=_pcfg(
|
||||
"https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
|
||||
laion400m_e31=_pcfg(
|
||||
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
|
||||
laion400m_e32=_pcfg(
|
||||
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
|
||||
laion2b_s32b_b82k=_pcfg(
|
||||
hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
||||
)
|
||||
|
||||
_VITL14_336 = dict(
|
||||
openai=_pcfg(
|
||||
"https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
|
||||
)
|
||||
|
||||
_VITH14 = dict(
|
||||
laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
|
||||
)
|
||||
|
||||
_VITg14 = dict(
|
||||
laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
|
||||
laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),
|
||||
)
|
||||
|
||||
_VITbigG14 = dict(
|
||||
laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
|
||||
)
|
||||
|
||||
_robertaViTB32 = dict(
|
||||
laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'),
|
||||
)
|
||||
|
||||
_xlmRobertaBaseViTB32 = dict(
|
||||
laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'),
|
||||
)
|
||||
|
||||
_xlmRobertaLargeFrozenViTH14 = dict(
|
||||
frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'),
|
||||
)
|
||||
|
||||
_convnext_base = dict(
|
||||
laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'),
|
||||
)
|
||||
|
||||
_convnext_base_w = dict(
|
||||
laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'),
|
||||
laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'),
|
||||
laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'),
|
||||
)
|
||||
|
||||
_convnext_base_w_320 = dict(
|
||||
laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'),
|
||||
laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'),
|
||||
)
|
||||
|
||||
_convnext_large_d = dict(
|
||||
laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'),
|
||||
)
|
||||
|
||||
_convnext_large_d_320 = dict(
|
||||
laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'),
|
||||
laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'),
|
||||
)
|
||||
|
||||
_convnext_xxlarge = dict(
|
||||
laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'),
|
||||
laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'),
|
||||
laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'),
|
||||
)
|
||||
|
||||
_coca_VITB32 = dict(
|
||||
laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'),
|
||||
mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/')
|
||||
)
|
||||
|
||||
_coca_VITL14 = dict(
|
||||
laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'),
|
||||
mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/')
|
||||
)
|
||||
|
||||
|
||||
_PRETRAINED = {
|
||||
"RN50": _RN50,
|
||||
"RN50-quickgelu": _RN50_quickgelu,
|
||||
"RN101": _RN101,
|
||||
"RN101-quickgelu": _RN101_quickgelu,
|
||||
"RN50x4": _RN50x4,
|
||||
"RN50x16": _RN50x16,
|
||||
"RN50x64": _RN50x64,
|
||||
"ViT-B-32": _VITB32,
|
||||
"ViT-B-32-quickgelu": _VITB32_quickgelu,
|
||||
"ViT-B-16": _VITB16,
|
||||
"ViT-B-16-plus-240": _VITB16_PLUS_240,
|
||||
"ViT-L-14": _VITL14,
|
||||
"ViT-L-14-336": _VITL14_336,
|
||||
"ViT-H-14": _VITH14,
|
||||
"ViT-g-14": _VITg14,
|
||||
"ViT-bigG-14": _VITbigG14,
|
||||
"roberta-ViT-B-32": _robertaViTB32,
|
||||
"xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32,
|
||||
"xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14,
|
||||
"convnext_base": _convnext_base,
|
||||
"convnext_base_w": _convnext_base_w,
|
||||
"convnext_base_w_320": _convnext_base_w_320,
|
||||
"convnext_large_d": _convnext_large_d,
|
||||
"convnext_large_d_320": _convnext_large_d_320,
|
||||
"convnext_xxlarge": _convnext_xxlarge,
|
||||
"coca_ViT-B-32": _coca_VITB32,
|
||||
"coca_ViT-L-14": _coca_VITL14,
|
||||
}
|
||||
|
||||
|
||||
def _clean_tag(tag: str):
|
||||
# normalize pretrained tags
|
||||
return tag.lower().replace('-', '_')
|
||||
|
||||
|
||||
def list_pretrained(as_str: bool = False):
|
||||
""" returns list of pretrained models
|
||||
Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
|
||||
"""
|
||||
return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
|
||||
|
||||
|
||||
def list_pretrained_models_by_tag(tag: str):
|
||||
""" return all models having the specified pretrain tag """
|
||||
models = []
|
||||
tag = _clean_tag(tag)
|
||||
for k in _PRETRAINED.keys():
|
||||
if tag in _PRETRAINED[k]:
|
||||
models.append(k)
|
||||
return models
|
||||
|
||||
|
||||
def list_pretrained_tags_by_model(model: str):
|
||||
""" return all pretrain tags for the specified model architecture """
|
||||
tags = []
|
||||
if model in _PRETRAINED:
|
||||
tags.extend(_PRETRAINED[model].keys())
|
||||
return tags
|
||||
|
||||
|
||||
def is_pretrained_cfg(model: str, tag: str):
|
||||
if model not in _PRETRAINED:
|
||||
return False
|
||||
return _clean_tag(tag) in _PRETRAINED[model]
|
||||
|
||||
|
||||
def get_pretrained_cfg(model: str, tag: str):
|
||||
if model not in _PRETRAINED:
|
||||
return {}
|
||||
model_pretrained = _PRETRAINED[model]
|
||||
return model_pretrained.get(_clean_tag(tag), {})
|
||||
|
||||
|
||||
def get_pretrained_url(model: str, tag: str):
|
||||
cfg = get_pretrained_cfg(model, _clean_tag(tag))
|
||||
return cfg.get('url', '')
|
||||
|
||||
|
||||
def download_pretrained_from_url(
|
||||
url: str,
|
||||
cache_dir: Union[str, None] = None,
|
||||
):
|
||||
if not cache_dir:
|
||||
cache_dir = os.path.expanduser("~/.cache/clip")
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
filename = os.path.basename(url)
|
||||
|
||||
if 'openaipublic' in url:
|
||||
expected_sha256 = url.split("/")[-2]
|
||||
elif 'mlfoundations' in url:
|
||||
expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
|
||||
else:
|
||||
expected_sha256 = ''
|
||||
|
||||
download_target = os.path.join(cache_dir, filename)
|
||||
|
||||
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
||||
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
||||
|
||||
if os.path.isfile(download_target):
|
||||
if expected_sha256:
|
||||
if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
|
||||
return download_target
|
||||
else:
|
||||
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
||||
else:
|
||||
return download_target
|
||||
|
||||
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
||||
with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
|
||||
while True:
|
||||
buffer = source.read(8192)
|
||||
if not buffer:
|
||||
break
|
||||
|
||||
output.write(buffer)
|
||||
loop.update(len(buffer))
|
||||
|
||||
if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
|
||||
raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
|
||||
|
||||
return download_target
|
||||
|
||||
|
||||
def has_hf_hub(necessary=False):
|
||||
if not _has_hf_hub and necessary:
|
||||
# if no HF Hub module installed, and it is necessary to continue, raise error
|
||||
raise RuntimeError(
|
||||
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
|
||||
return _has_hf_hub
|
||||
|
||||
|
||||
def download_pretrained_from_hf(
|
||||
model_id: str,
|
||||
filename: str = 'open_clip_pytorch_model.bin',
|
||||
revision=None,
|
||||
cache_dir: Union[str, None] = None,
|
||||
):
|
||||
has_hf_hub(True)
|
||||
cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
|
||||
return cached_file
|
||||
|
||||
|
||||
def download_pretrained(
|
||||
cfg: Dict,
|
||||
force_hf_hub: bool = False,
|
||||
cache_dir: Union[str, None] = None,
|
||||
):
|
||||
target = ''
|
||||
if not cfg:
|
||||
return target
|
||||
|
||||
download_url = cfg.get('url', '')
|
||||
download_hf_hub = cfg.get('hf_hub', '')
|
||||
if download_hf_hub and force_hf_hub:
|
||||
# use HF hub even if url exists
|
||||
download_url = ''
|
||||
|
||||
if download_url:
|
||||
target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
|
||||
elif download_hf_hub:
|
||||
has_hf_hub(True)
|
||||
# we assume the hf_hub entries in pretrained config combine model_id + filename in
|
||||
# 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
|
||||
# use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
|
||||
model_id, filename = os.path.split(download_hf_hub)
|
||||
if filename:
|
||||
target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
|
||||
else:
|
||||
target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
|
||||
|
||||
return target
|
||||
@@ -0,0 +1,243 @@
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
from huggingface_hub import (
|
||||
create_repo,
|
||||
get_hf_file_metadata,
|
||||
hf_hub_download,
|
||||
hf_hub_url,
|
||||
repo_type_and_id_from_hf_id,
|
||||
upload_folder,
|
||||
)
|
||||
from huggingface_hub.utils import EntryNotFoundError
|
||||
_has_hf_hub = True
|
||||
except ImportError:
|
||||
_has_hf_hub = False
|
||||
|
||||
from .factory import create_model_from_pretrained, get_model_config, get_tokenizer
|
||||
from .tokenizer import HFTokenizer
|
||||
|
||||
|
||||
def save_config_for_hf(
|
||||
model,
|
||||
config_path: str,
|
||||
model_config: Optional[dict]
|
||||
):
|
||||
preprocess_cfg = {
|
||||
'mean': model.visual.image_mean,
|
||||
'std': model.visual.image_std,
|
||||
}
|
||||
hf_config = {
|
||||
'model_cfg': model_config,
|
||||
'preprocess_cfg': preprocess_cfg,
|
||||
}
|
||||
|
||||
with config_path.open('w') as f:
|
||||
json.dump(hf_config, f, indent=2)
|
||||
|
||||
|
||||
def save_for_hf(
|
||||
model,
|
||||
tokenizer: HFTokenizer,
|
||||
model_config: dict,
|
||||
save_directory: str,
|
||||
weights_filename='open_clip_pytorch_model.bin',
|
||||
config_filename='open_clip_config.json',
|
||||
):
|
||||
save_directory = Path(save_directory)
|
||||
save_directory.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
weights_path = save_directory / weights_filename
|
||||
torch.save(model.state_dict(), weights_path)
|
||||
|
||||
tokenizer.save_pretrained(save_directory)
|
||||
|
||||
config_path = save_directory / config_filename
|
||||
save_config_for_hf(model, config_path, model_config=model_config)
|
||||
|
||||
|
||||
def push_to_hf_hub(
|
||||
model,
|
||||
tokenizer,
|
||||
model_config: Optional[dict],
|
||||
repo_id: str,
|
||||
commit_message: str = 'Add model',
|
||||
token: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
private: bool = False,
|
||||
create_pr: bool = False,
|
||||
model_card: Optional[dict] = None,
|
||||
):
|
||||
if not isinstance(tokenizer, HFTokenizer):
|
||||
# default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14
|
||||
tokenizer = HFTokenizer('openai/clip-vit-large-patch14')
|
||||
|
||||
# Create repo if it doesn't exist yet
|
||||
repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
|
||||
|
||||
# Infer complete repo_id from repo_url
|
||||
# Can be different from the input `repo_id` if repo_owner was implicit
|
||||
_, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
|
||||
repo_id = f"{repo_owner}/{repo_name}"
|
||||
|
||||
# Check if README file already exist in repo
|
||||
try:
|
||||
get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
|
||||
has_readme = True
|
||||
except EntryNotFoundError:
|
||||
has_readme = False
|
||||
|
||||
# Dump model and push to Hub
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
# Save model weights and config.
|
||||
save_for_hf(
|
||||
model,
|
||||
tokenizer=tokenizer,
|
||||
model_config=model_config,
|
||||
save_directory=tmpdir,
|
||||
)
|
||||
|
||||
# Add readme if it does not exist
|
||||
if not has_readme:
|
||||
model_card = model_card or {}
|
||||
model_name = repo_id.split('/')[-1]
|
||||
readme_path = Path(tmpdir) / "README.md"
|
||||
readme_text = generate_readme(model_card, model_name)
|
||||
readme_path.write_text(readme_text)
|
||||
|
||||
# Upload model and return
|
||||
return upload_folder(
|
||||
repo_id=repo_id,
|
||||
folder_path=tmpdir,
|
||||
revision=revision,
|
||||
create_pr=create_pr,
|
||||
commit_message=commit_message,
|
||||
)
|
||||
|
||||
|
||||
def push_pretrained_to_hf_hub(
|
||||
model_name,
|
||||
pretrained: str,
|
||||
repo_id: str,
|
||||
image_mean: Optional[Tuple[float, ...]] = None,
|
||||
image_std: Optional[Tuple[float, ...]] = None,
|
||||
commit_message: str = 'Add model',
|
||||
token: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
private: bool = False,
|
||||
create_pr: bool = False,
|
||||
model_card: Optional[dict] = None,
|
||||
):
|
||||
model, preprocess_eval = create_model_from_pretrained(
|
||||
model_name,
|
||||
pretrained=pretrained,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
)
|
||||
|
||||
model_config = get_model_config(model_name)
|
||||
assert model_config
|
||||
|
||||
tokenizer = get_tokenizer(model_name)
|
||||
|
||||
push_to_hf_hub(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
model_config=model_config,
|
||||
repo_id=repo_id,
|
||||
commit_message=commit_message,
|
||||
token=token,
|
||||
revision=revision,
|
||||
private=private,
|
||||
create_pr=create_pr,
|
||||
model_card=model_card,
|
||||
)
|
||||
|
||||
|
||||
def generate_readme(model_card: dict, model_name: str):
|
||||
readme_text = "---\n"
|
||||
readme_text += "tags:\n- zero-shot-image-classification\n- clip\n"
|
||||
readme_text += "library_tag: open_clip\n"
|
||||
readme_text += f"license: {model_card.get('license', 'mit')}\n"
|
||||
if 'details' in model_card and 'Dataset' in model_card['details']:
|
||||
readme_text += 'datasets:\n'
|
||||
readme_text += f"- {model_card['details']['Dataset'].lower()}\n"
|
||||
readme_text += "---\n"
|
||||
readme_text += f"# Model card for {model_name}\n"
|
||||
if 'description' in model_card:
|
||||
readme_text += f"\n{model_card['description']}\n"
|
||||
if 'details' in model_card:
|
||||
readme_text += f"\n## Model Details\n"
|
||||
for k, v in model_card['details'].items():
|
||||
if isinstance(v, (list, tuple)):
|
||||
readme_text += f"- **{k}:**\n"
|
||||
for vi in v:
|
||||
readme_text += f" - {vi}\n"
|
||||
elif isinstance(v, dict):
|
||||
readme_text += f"- **{k}:**\n"
|
||||
for ki, vi in v.items():
|
||||
readme_text += f" - {ki}: {vi}\n"
|
||||
else:
|
||||
readme_text += f"- **{k}:** {v}\n"
|
||||
if 'usage' in model_card:
|
||||
readme_text += f"\n## Model Usage\n"
|
||||
readme_text += model_card['usage']
|
||||
readme_text += '\n'
|
||||
|
||||
if 'comparison' in model_card:
|
||||
readme_text += f"\n## Model Comparison\n"
|
||||
readme_text += model_card['comparison']
|
||||
readme_text += '\n'
|
||||
|
||||
if 'citation' in model_card:
|
||||
readme_text += f"\n## Citation\n"
|
||||
if not isinstance(model_card['citation'], (list, tuple)):
|
||||
citations = [model_card['citation']]
|
||||
else:
|
||||
citations = model_card['citation']
|
||||
for c in citations:
|
||||
readme_text += f"```bibtex\n{c}\n```\n"
|
||||
|
||||
return readme_text
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Push to Hugging Face Hub")
|
||||
parser.add_argument(
|
||||
"--model", type=str, help="Name of the model to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained", type=str,
|
||||
help="Use a pretrained CLIP model weights with the specified tag or file path.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id", type=str,
|
||||
help="Destination HF Hub repo-id ie 'organization/model_id'.",
|
||||
)
|
||||
parser.add_argument(
|
||||
'--image-mean', type=float, nargs='+', default=None, metavar='MEAN',
|
||||
help='Override default image mean value of dataset')
|
||||
parser.add_argument(
|
||||
'--image-std', type=float, nargs='+', default=None, metavar='STD',
|
||||
help='Override default image std deviation of of dataset')
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}')
|
||||
|
||||
# FIXME add support to pass model_card json / template from file via cmd line
|
||||
|
||||
push_pretrained_to_hf_hub(
|
||||
args.model,
|
||||
args.pretrained,
|
||||
args.repo_id,
|
||||
image_mean=args.image_mean, # override image mean/std if trained w/ non defaults
|
||||
image_std=args.image_std,
|
||||
)
|
||||
|
||||
print(f'{args.model} saved.')
|
||||
127
diffsynth/extensions/ImageQualityMetric/open_clip/timm_model.py
Normal file
127
diffsynth/extensions/ImageQualityMetric/open_clip/timm_model.py
Normal file
@@ -0,0 +1,127 @@
|
||||
""" timm model adapter
|
||||
|
||||
Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
|
||||
"""
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
try:
|
||||
import timm
|
||||
from timm.models.layers import Mlp, to_2tuple
|
||||
try:
|
||||
# old timm imports < 0.8.1
|
||||
from timm.models.layers.attention_pool2d import RotAttentionPool2d
|
||||
from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
|
||||
except ImportError:
|
||||
# new timm imports >= 0.8.1
|
||||
from timm.layers import RotAttentionPool2d
|
||||
from timm.layers import AttentionPool2d as AbsAttentionPool2d
|
||||
except ImportError:
|
||||
timm = None
|
||||
|
||||
from .utils import freeze_batch_norm_2d
|
||||
|
||||
|
||||
class TimmModel(nn.Module):
|
||||
""" timm model adapter
|
||||
# FIXME this adapter is a work in progress, may change in ways that break weight compat
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
embed_dim,
|
||||
image_size=224,
|
||||
pool='avg',
|
||||
proj='linear',
|
||||
proj_bias=False,
|
||||
drop=0.,
|
||||
drop_path=None,
|
||||
pretrained=False,
|
||||
):
|
||||
super().__init__()
|
||||
if timm is None:
|
||||
raise RuntimeError("Please `pip install timm` to use timm models.")
|
||||
|
||||
self.image_size = to_2tuple(image_size)
|
||||
timm_kwargs = {}
|
||||
if drop_path is not None:
|
||||
timm_kwargs['drop_path_rate'] = drop_path
|
||||
self.trunk = timm.create_model(model_name, pretrained=pretrained, **timm_kwargs)
|
||||
feat_size = self.trunk.default_cfg.get('pool_size', None)
|
||||
feature_ndim = 1 if not feat_size else 2
|
||||
if pool in ('abs_attn', 'rot_attn'):
|
||||
assert feature_ndim == 2
|
||||
# if attn pooling used, remove both classifier and default pool
|
||||
self.trunk.reset_classifier(0, global_pool='')
|
||||
else:
|
||||
# reset global pool if pool config set, otherwise leave as network default
|
||||
reset_kwargs = dict(global_pool=pool) if pool else {}
|
||||
self.trunk.reset_classifier(0, **reset_kwargs)
|
||||
prev_chs = self.trunk.num_features
|
||||
|
||||
head_layers = OrderedDict()
|
||||
if pool == 'abs_attn':
|
||||
head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
|
||||
prev_chs = embed_dim
|
||||
elif pool == 'rot_attn':
|
||||
head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
|
||||
prev_chs = embed_dim
|
||||
else:
|
||||
assert proj, 'projection layer needed if non-attention pooling is used.'
|
||||
|
||||
# NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
|
||||
if proj == 'linear':
|
||||
head_layers['drop'] = nn.Dropout(drop)
|
||||
head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
|
||||
elif proj == 'mlp':
|
||||
head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))
|
||||
|
||||
self.head = nn.Sequential(head_layers)
|
||||
|
||||
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
||||
""" lock modules
|
||||
Args:
|
||||
unlocked_groups (int): leave last n layer groups unlocked (default: 0)
|
||||
"""
|
||||
if not unlocked_groups:
|
||||
# lock full model
|
||||
for param in self.trunk.parameters():
|
||||
param.requires_grad = False
|
||||
if freeze_bn_stats:
|
||||
freeze_batch_norm_2d(self.trunk)
|
||||
else:
|
||||
# NOTE: partial freeze requires latest timm (master) branch and is subject to change
|
||||
try:
|
||||
# FIXME import here until API stable and in an official release
|
||||
from timm.models.helpers import group_parameters, group_modules
|
||||
except ImportError:
|
||||
raise RuntimeError(
|
||||
'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
|
||||
matcher = self.trunk.group_matcher()
|
||||
gparams = group_parameters(self.trunk, matcher)
|
||||
max_layer_id = max(gparams.keys())
|
||||
max_layer_id = max_layer_id - unlocked_groups
|
||||
for group_idx in range(max_layer_id + 1):
|
||||
group = gparams[group_idx]
|
||||
for param in group:
|
||||
self.trunk.get_parameter(param).requires_grad = False
|
||||
if freeze_bn_stats:
|
||||
gmodules = group_modules(self.trunk, matcher, reverse=True)
|
||||
gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
|
||||
freeze_batch_norm_2d(self.trunk, gmodules)
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_grad_checkpointing(self, enable=True):
|
||||
try:
|
||||
self.trunk.set_grad_checkpointing(enable)
|
||||
except Exception as e:
|
||||
logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')
|
||||
|
||||
def forward(self, x):
|
||||
x = self.trunk(x)
|
||||
x = self.head(x)
|
||||
return x
|
||||
211
diffsynth/extensions/ImageQualityMetric/open_clip/tokenizer.py
Normal file
211
diffsynth/extensions/ImageQualityMetric/open_clip/tokenizer.py
Normal file
@@ -0,0 +1,211 @@
|
||||
""" CLIP tokenizer
|
||||
|
||||
Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
||||
"""
|
||||
import gzip
|
||||
import html
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import Union, List
|
||||
|
||||
import ftfy
|
||||
import regex as re
|
||||
import torch
|
||||
|
||||
# https://stackoverflow.com/q/62691279
|
||||
import os
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def default_bpe():
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
project_root = os.path.abspath(os.path.join(current_dir, '../../../../'))
|
||||
quality_metric_path = os.path.join(project_root, 'models', 'QualityMetric')
|
||||
return os.path.join(quality_metric_path, "bpe_simple_vocab_16e6.txt.gz")
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def bytes_to_unicode():
|
||||
"""
|
||||
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
||||
The reversible bpe codes work on unicode strings.
|
||||
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
||||
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
||||
This is a significant percentage of your normal, say, 32K bpe vocab.
|
||||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||
"""
|
||||
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
||||
cs = bs[:]
|
||||
n = 0
|
||||
for b in range(2**8):
|
||||
if b not in bs:
|
||||
bs.append(b)
|
||||
cs.append(2**8+n)
|
||||
n += 1
|
||||
cs = [chr(n) for n in cs]
|
||||
return dict(zip(bs, cs))
|
||||
|
||||
|
||||
def get_pairs(word):
|
||||
"""Return set of symbol pairs in a word.
|
||||
Word is represented as tuple of symbols (symbols being variable-length strings).
|
||||
"""
|
||||
pairs = set()
|
||||
prev_char = word[0]
|
||||
for char in word[1:]:
|
||||
pairs.add((prev_char, char))
|
||||
prev_char = char
|
||||
return pairs
|
||||
|
||||
|
||||
def basic_clean(text):
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text))
|
||||
return text.strip()
|
||||
|
||||
|
||||
def whitespace_clean(text):
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
|
||||
class SimpleTokenizer(object):
|
||||
def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
|
||||
self.byte_encoder = bytes_to_unicode()
|
||||
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
||||
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
|
||||
merges = merges[1:49152-256-2+1]
|
||||
merges = [tuple(merge.split()) for merge in merges]
|
||||
vocab = list(bytes_to_unicode().values())
|
||||
vocab = vocab + [v+'</w>' for v in vocab]
|
||||
for merge in merges:
|
||||
vocab.append(''.join(merge))
|
||||
if not special_tokens:
|
||||
special_tokens = ['<start_of_text>', '<end_of_text>']
|
||||
else:
|
||||
special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
|
||||
vocab.extend(special_tokens)
|
||||
self.encoder = dict(zip(vocab, range(len(vocab))))
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||
self.cache = {t:t for t in special_tokens}
|
||||
special = "|".join(special_tokens)
|
||||
self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
|
||||
|
||||
self.vocab_size = len(self.encoder)
|
||||
self.all_special_ids = [self.encoder[t] for t in special_tokens]
|
||||
|
||||
def bpe(self, token):
|
||||
if token in self.cache:
|
||||
return self.cache[token]
|
||||
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
||||
pairs = get_pairs(word)
|
||||
|
||||
if not pairs:
|
||||
return token+'</w>'
|
||||
|
||||
while True:
|
||||
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
first, second = bigram
|
||||
new_word = []
|
||||
i = 0
|
||||
while i < len(word):
|
||||
try:
|
||||
j = word.index(first, i)
|
||||
new_word.extend(word[i:j])
|
||||
i = j
|
||||
except:
|
||||
new_word.extend(word[i:])
|
||||
break
|
||||
|
||||
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
||||
new_word.append(first+second)
|
||||
i += 2
|
||||
else:
|
||||
new_word.append(word[i])
|
||||
i += 1
|
||||
new_word = tuple(new_word)
|
||||
word = new_word
|
||||
if len(word) == 1:
|
||||
break
|
||||
else:
|
||||
pairs = get_pairs(word)
|
||||
word = ' '.join(word)
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
def encode(self, text):
|
||||
bpe_tokens = []
|
||||
text = whitespace_clean(basic_clean(text)).lower()
|
||||
for token in re.findall(self.pat, text):
|
||||
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
||||
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
||||
return bpe_tokens
|
||||
|
||||
def decode(self, tokens):
|
||||
text = ''.join([self.decoder[token] for token in tokens])
|
||||
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
||||
return text
|
||||
|
||||
def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
|
||||
"""
|
||||
Returns the tokenized representation of given input string(s)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
texts : Union[str, List[str]]
|
||||
An input string or a list of input strings to tokenize
|
||||
context_length : int
|
||||
The context length to use; all CLIP models use 77 as the context length
|
||||
|
||||
Returns
|
||||
-------
|
||||
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
|
||||
"""
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
|
||||
sot_token = self.encoder["<start_of_text>"]
|
||||
eot_token = self.encoder["<end_of_text>"]
|
||||
all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
|
||||
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
||||
|
||||
for i, tokens in enumerate(all_tokens):
|
||||
if len(tokens) > context_length:
|
||||
tokens = tokens[:context_length] # Truncate
|
||||
tokens[-1] = eot_token
|
||||
result[i, :len(tokens)] = torch.tensor(tokens)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
class HFTokenizer:
|
||||
"""HuggingFace tokenizer wrapper"""
|
||||
|
||||
def __init__(self, tokenizer_name: str):
|
||||
from transformers import AutoTokenizer
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||
|
||||
def save_pretrained(self, dest):
|
||||
self.tokenizer.save_pretrained(dest)
|
||||
|
||||
def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor:
|
||||
# same cleaning as for default tokenizer, except lowercasing
|
||||
# adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
texts = [whitespace_clean(basic_clean(text)) for text in texts]
|
||||
input_ids = self.tokenizer(
|
||||
texts,
|
||||
return_tensors='pt',
|
||||
max_length=context_length,
|
||||
padding='max_length',
|
||||
truncation=True,
|
||||
).input_ids
|
||||
return input_ids
|
||||
216
diffsynth/extensions/ImageQualityMetric/open_clip/transform.py
Normal file
216
diffsynth/extensions/ImageQualityMetric/open_clip/transform.py
Normal file
@@ -0,0 +1,216 @@
|
||||
import warnings
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import Any, Dict, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms.functional as F
|
||||
from functools import partial
|
||||
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
|
||||
CenterCrop
|
||||
|
||||
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
||||
|
||||
|
||||
@dataclass
|
||||
class AugmentationCfg:
|
||||
scale: Tuple[float, float] = (0.9, 1.0)
|
||||
ratio: Optional[Tuple[float, float]] = None
|
||||
color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None
|
||||
interpolation: Optional[str] = None
|
||||
re_prob: Optional[float] = None
|
||||
re_count: Optional[int] = None
|
||||
use_timm: bool = False
|
||||
|
||||
|
||||
class ResizeMaxSize(nn.Module):
|
||||
|
||||
def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
|
||||
super().__init__()
|
||||
if not isinstance(max_size, int):
|
||||
raise TypeError(f"Size should be int. Got {type(max_size)}")
|
||||
self.max_size = max_size
|
||||
self.interpolation = interpolation
|
||||
self.fn = min if fn == 'min' else min
|
||||
self.fill = fill
|
||||
|
||||
def forward(self, img):
|
||||
if isinstance(img, torch.Tensor):
|
||||
height, width = img.shape[1:]
|
||||
else:
|
||||
width, height = img.size
|
||||
scale = self.max_size / float(max(height, width))
|
||||
if scale != 1.0:
|
||||
new_size = tuple(round(dim * scale) for dim in (height, width))
|
||||
img = F.resize(img, new_size, self.interpolation)
|
||||
pad_h = self.max_size - new_size[0]
|
||||
pad_w = self.max_size - new_size[1]
|
||||
img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
|
||||
return img
|
||||
|
||||
|
||||
def _convert_to_rgb_or_rgba(image):
|
||||
if image.mode == 'RGBA':
|
||||
return image
|
||||
else:
|
||||
return image.convert('RGB')
|
||||
|
||||
# def transform_and_split(merged, transform_fn, normalize_fn):
|
||||
# transformed = transform_fn(merged)
|
||||
# crop_img, crop_label = torch.split(transformed, [3,1], dim=0)
|
||||
|
||||
# # crop_img = _convert_to_rgb(crop_img)
|
||||
# crop_img = normalize_fn(ToTensor()(crop_img))
|
||||
# return crop_img, crop_label
|
||||
|
||||
class MaskAwareNormalize(nn.Module):
|
||||
def __init__(self, mean, std):
|
||||
super().__init__()
|
||||
self.normalize = Normalize(mean=mean, std=std)
|
||||
|
||||
def forward(self, tensor):
|
||||
if tensor.shape[0] == 4:
|
||||
return torch.cat([self.normalize(tensor[:3]), tensor[3:]], dim=0)
|
||||
else:
|
||||
return self.normalize(tensor)
|
||||
|
||||
def image_transform(
|
||||
image_size: int,
|
||||
is_train: bool,
|
||||
mean: Optional[Tuple[float, ...]] = None,
|
||||
std: Optional[Tuple[float, ...]] = None,
|
||||
resize_longest_max: bool = False,
|
||||
fill_color: int = 0,
|
||||
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
|
||||
):
|
||||
mean = mean or OPENAI_DATASET_MEAN
|
||||
if not isinstance(mean, (list, tuple)):
|
||||
mean = (mean,) * 3
|
||||
|
||||
std = std or OPENAI_DATASET_STD
|
||||
if not isinstance(std, (list, tuple)):
|
||||
std = (std,) * 3
|
||||
|
||||
if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
|
||||
# for square size, pass size as int so that Resize() uses aspect preserving shortest edge
|
||||
image_size = image_size[0]
|
||||
|
||||
if isinstance(aug_cfg, dict):
|
||||
aug_cfg = AugmentationCfg(**aug_cfg)
|
||||
else:
|
||||
aug_cfg = aug_cfg or AugmentationCfg()
|
||||
normalize = MaskAwareNormalize(mean=mean, std=std)
|
||||
if is_train:
|
||||
aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
|
||||
use_timm = aug_cfg_dict.pop('use_timm', False)
|
||||
if use_timm:
|
||||
assert False, "not tested for augmentation with mask"
|
||||
from timm.data import create_transform # timm can still be optional
|
||||
if isinstance(image_size, (tuple, list)):
|
||||
assert len(image_size) >= 2
|
||||
input_size = (3,) + image_size[-2:]
|
||||
else:
|
||||
input_size = (3, image_size, image_size)
|
||||
# by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time
|
||||
aug_cfg_dict.setdefault('interpolation', 'random')
|
||||
aug_cfg_dict.setdefault('color_jitter', None) # disable by default
|
||||
train_transform = create_transform(
|
||||
input_size=input_size,
|
||||
is_training=True,
|
||||
hflip=0.,
|
||||
mean=mean,
|
||||
std=std,
|
||||
re_mode='pixel',
|
||||
**aug_cfg_dict,
|
||||
)
|
||||
else:
|
||||
train_transform = Compose([
|
||||
_convert_to_rgb_or_rgba,
|
||||
ToTensor(),
|
||||
RandomResizedCrop(
|
||||
image_size,
|
||||
scale=aug_cfg_dict.pop('scale'),
|
||||
interpolation=InterpolationMode.BICUBIC,
|
||||
),
|
||||
normalize,
|
||||
])
|
||||
if aug_cfg_dict:
|
||||
warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')
|
||||
return train_transform
|
||||
else:
|
||||
transforms = [
|
||||
_convert_to_rgb_or_rgba,
|
||||
ToTensor(),
|
||||
]
|
||||
if resize_longest_max:
|
||||
transforms.extend([
|
||||
ResizeMaxSize(image_size, fill=fill_color)
|
||||
])
|
||||
else:
|
||||
transforms.extend([
|
||||
Resize(image_size, interpolation=InterpolationMode.BICUBIC),
|
||||
CenterCrop(image_size),
|
||||
])
|
||||
transforms.extend([
|
||||
normalize,
|
||||
])
|
||||
return Compose(transforms)
|
||||
|
||||
|
||||
# def image_transform_region(
|
||||
# image_size: int,
|
||||
# is_train: bool,
|
||||
# mean: Optional[Tuple[float, ...]] = None,
|
||||
# std: Optional[Tuple[float, ...]] = None,
|
||||
# resize_longest_max: bool = False,
|
||||
# fill_color: int = 0,
|
||||
# aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
|
||||
# ):
|
||||
# mean = mean or OPENAI_DATASET_MEAN
|
||||
# if not isinstance(mean, (list, tuple)):
|
||||
# mean = (mean,) * 3
|
||||
|
||||
# std = std or OPENAI_DATASET_STD
|
||||
# if not isinstance(std, (list, tuple)):
|
||||
# std = (std,) * 3
|
||||
|
||||
# if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
|
||||
# # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
|
||||
# image_size = image_size[0]
|
||||
|
||||
# if isinstance(aug_cfg, dict):
|
||||
# aug_cfg = AugmentationCfg(**aug_cfg)
|
||||
# else:
|
||||
# aug_cfg = aug_cfg or AugmentationCfg()
|
||||
# normalize = Normalize(mean=mean, std=std)
|
||||
# if is_train:
|
||||
# aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
|
||||
|
||||
# transform = Compose([
|
||||
# RandomResizedCrop(
|
||||
# image_size,
|
||||
# scale=aug_cfg_dict.pop('scale'),
|
||||
# interpolation=InterpolationMode.BICUBIC,
|
||||
# ),
|
||||
# ])
|
||||
# train_transform = Compose([
|
||||
# partial(transform_and_split, transform_fn=transform,normalize_fn=normalize)
|
||||
# ])
|
||||
# return train_transform
|
||||
# else:
|
||||
# if resize_longest_max:
|
||||
# transform = [
|
||||
# ResizeMaxSize(image_size, fill=fill_color)
|
||||
# ]
|
||||
# val_transform = Compose([
|
||||
# partial(transform_and_split, transform_fn=transform,normalize_fn=normalize),
|
||||
# ])
|
||||
# else:
|
||||
# transform = [
|
||||
# Resize(image_size, interpolation=InterpolationMode.BICUBIC),
|
||||
# CenterCrop(image_size),
|
||||
# ]
|
||||
# val_transform = Compose([
|
||||
# partial(transform_and_split, transform_fn=transform,normalize_fn=normalize),
|
||||
# ])
|
||||
# return val_transform
|
||||
727
diffsynth/extensions/ImageQualityMetric/open_clip/transformer.py
Normal file
727
diffsynth/extensions/ImageQualityMetric/open_clip/transformer.py
Normal file
@@ -0,0 +1,727 @@
|
||||
from collections import OrderedDict
|
||||
import math
|
||||
from typing import Callable, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from .utils import to_2tuple
|
||||
|
||||
|
||||
class LayerNormFp32(nn.LayerNorm):
|
||||
"""Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
orig_type = x.dtype
|
||||
x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)
|
||||
return x.to(orig_type)
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
"""Subclass torch's LayerNorm (with cast back to input dtype)."""
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
orig_type = x.dtype
|
||||
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||
return x.to(orig_type)
|
||||
|
||||
|
||||
class QuickGELU(nn.Module):
|
||||
# NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
|
||||
def forward(self, x: torch.Tensor):
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
|
||||
class LayerScale(nn.Module):
|
||||
def __init__(self, dim, init_values=1e-5, inplace=False):
|
||||
super().__init__()
|
||||
self.inplace = inplace
|
||||
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
||||
|
||||
def forward(self, x):
|
||||
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
||||
|
||||
|
||||
class PatchDropout(nn.Module):
|
||||
"""
|
||||
https://arxiv.org/abs/2212.00794
|
||||
"""
|
||||
|
||||
def __init__(self, prob, exclude_first_token=True):
|
||||
super().__init__()
|
||||
assert 0 <= prob < 1.
|
||||
self.prob = prob
|
||||
self.exclude_first_token = exclude_first_token # exclude CLS token
|
||||
|
||||
def forward(self, x):
|
||||
if not self.training or self.prob == 0.:
|
||||
return x
|
||||
|
||||
if self.exclude_first_token:
|
||||
cls_tokens, x = x[:, :1], x[:, 1:]
|
||||
else:
|
||||
cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
|
||||
|
||||
batch = x.size()[0]
|
||||
num_tokens = x.size()[1]
|
||||
|
||||
batch_indices = torch.arange(batch)
|
||||
batch_indices = batch_indices[..., None]
|
||||
|
||||
keep_prob = 1 - self.prob
|
||||
num_patches_keep = max(1, int(num_tokens * keep_prob))
|
||||
|
||||
rand = torch.randn(batch, num_tokens)
|
||||
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
|
||||
|
||||
x = x[batch_indices, patch_indices_keep]
|
||||
|
||||
if self.exclude_first_token:
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads=8,
|
||||
qkv_bias=True,
|
||||
scaled_cosine=False,
|
||||
scale_heads=False,
|
||||
logit_scale_max=math.log(1. / 0.01),
|
||||
attn_drop=0.,
|
||||
proj_drop=0.
|
||||
):
|
||||
super().__init__()
|
||||
self.scaled_cosine = scaled_cosine
|
||||
self.scale_heads = scale_heads
|
||||
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
self.logit_scale_max = logit_scale_max
|
||||
|
||||
# keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
|
||||
self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
|
||||
if qkv_bias:
|
||||
self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
|
||||
else:
|
||||
self.in_proj_bias = None
|
||||
|
||||
if self.scaled_cosine:
|
||||
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
|
||||
else:
|
||||
self.logit_scale = None
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
if self.scale_heads:
|
||||
self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
|
||||
else:
|
||||
self.head_scale = None
|
||||
self.out_proj = nn.Linear(dim, dim)
|
||||
self.out_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
|
||||
L, N, C = x.shape
|
||||
q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
|
||||
q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
||||
k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
||||
v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
||||
|
||||
if self.logit_scale is not None:
|
||||
attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
|
||||
logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
|
||||
attn = attn.view(N, self.num_heads, L, L) * logit_scale
|
||||
attn = attn.view(-1, L, L)
|
||||
else:
|
||||
q = q * self.scale
|
||||
attn = torch.bmm(q, k.transpose(-1, -2))
|
||||
|
||||
if attn_mask is not None:
|
||||
if attn_mask.dtype == torch.bool:
|
||||
new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
|
||||
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
|
||||
attn_mask = new_attn_mask
|
||||
attn += attn_mask
|
||||
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = torch.bmm(attn, v)
|
||||
if self.head_scale is not None:
|
||||
x = x.view(N, self.num_heads, L, C) * self.head_scale
|
||||
x = x.view(-1, L, C)
|
||||
x = x.transpose(0, 1).reshape(L, N, C)
|
||||
x = self.out_proj(x)
|
||||
x = self.out_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class AttentionalPooler(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
context_dim: int,
|
||||
n_head: int = 8,
|
||||
n_queries: int = 256,
|
||||
norm_layer: Callable = LayerNorm
|
||||
):
|
||||
super().__init__()
|
||||
self.query = nn.Parameter(torch.randn(n_queries, d_model))
|
||||
self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim)
|
||||
self.ln_q = norm_layer(d_model)
|
||||
self.ln_k = norm_layer(context_dim)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND
|
||||
N = x.shape[1]
|
||||
q = self.ln_q(self.query)
|
||||
out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0]
|
||||
return out.permute(1, 0, 2) # LND -> NLD
|
||||
|
||||
def _repeat(self, query, N: int):
|
||||
return query.unsqueeze(1).repeat(1, N, 1)
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
n_head: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
ls_init_value: float = None,
|
||||
act_layer: Callable = nn.GELU,
|
||||
norm_layer: Callable = LayerNorm,
|
||||
is_cross_attention: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.ln_1 = norm_layer(d_model)
|
||||
self.attn = nn.MultiheadAttention(d_model, n_head)
|
||||
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
||||
if is_cross_attention:
|
||||
self.ln_1_kv = norm_layer(d_model)
|
||||
|
||||
self.ln_2 = norm_layer(d_model)
|
||||
mlp_width = int(d_model * mlp_ratio)
|
||||
self.mlp = nn.Sequential(OrderedDict([
|
||||
("c_fc", nn.Linear(d_model, mlp_width)),
|
||||
("gelu", act_layer()),
|
||||
("c_proj", nn.Linear(mlp_width, d_model))
|
||||
]))
|
||||
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
||||
|
||||
def attention(
|
||||
self,
|
||||
q_x: torch.Tensor,
|
||||
k_x: Optional[torch.Tensor] = None,
|
||||
v_x: Optional[torch.Tensor] = None,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
k_x = k_x if k_x is not None else q_x
|
||||
v_x = v_x if v_x is not None else q_x
|
||||
|
||||
attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
|
||||
return self.attn(
|
||||
q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask
|
||||
)[0]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
q_x: torch.Tensor,
|
||||
k_x: Optional[torch.Tensor] = None,
|
||||
v_x: Optional[torch.Tensor] = None,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
|
||||
v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
|
||||
|
||||
x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))
|
||||
x = x + self.ls_2(self.mlp(self.ln_2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class CustomResidualAttentionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
n_head: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
ls_init_value: float = None,
|
||||
act_layer: Callable = nn.GELU,
|
||||
norm_layer: Callable = LayerNorm,
|
||||
scale_cosine_attn: bool = False,
|
||||
scale_heads: bool = False,
|
||||
scale_attn: bool = False,
|
||||
scale_fc: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.ln_1 = norm_layer(d_model)
|
||||
self.attn = Attention(
|
||||
d_model, n_head,
|
||||
scaled_cosine=scale_cosine_attn,
|
||||
scale_heads=scale_heads,
|
||||
)
|
||||
self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
|
||||
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
||||
|
||||
self.ln_2 = norm_layer(d_model)
|
||||
mlp_width = int(d_model * mlp_ratio)
|
||||
self.mlp = nn.Sequential(OrderedDict([
|
||||
("c_fc", nn.Linear(d_model, mlp_width)),
|
||||
('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),
|
||||
("gelu", act_layer()),
|
||||
("c_proj", nn.Linear(mlp_width, d_model))
|
||||
]))
|
||||
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
||||
|
||||
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
||||
x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask)))
|
||||
x = x + self.ls_2(self.mlp(self.ln_2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
width: int,
|
||||
layers: int,
|
||||
heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
ls_init_value: float = None,
|
||||
act_layer: Callable = nn.GELU,
|
||||
norm_layer: Callable = LayerNorm,
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.layers = layers
|
||||
self.grad_checkpointing = False
|
||||
|
||||
self.resblocks = nn.ModuleList([
|
||||
ResidualAttentionBlock(
|
||||
width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer)
|
||||
for _ in range(layers)
|
||||
])
|
||||
|
||||
def get_cast_dtype(self) -> torch.dtype:
|
||||
return self.resblocks[0].mlp.c_fc.weight.dtype
|
||||
|
||||
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
||||
for r in self.resblocks:
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
# TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
|
||||
x = checkpoint(r, x, None, None, attn_mask)
|
||||
else:
|
||||
x = r(x, attn_mask=attn_mask)
|
||||
return x
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
|
||||
output_tokens: torch.jit.Final[bool]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size: int,
|
||||
patch_size: int,
|
||||
width: int,
|
||||
layers: int,
|
||||
heads: int,
|
||||
mlp_ratio: float,
|
||||
ls_init_value: float = None,
|
||||
global_average_pool: bool = False,
|
||||
attentional_pool: bool = False,
|
||||
n_queries: int = 256,
|
||||
attn_pooler_heads: int = 8,
|
||||
output_dim: int = 512,
|
||||
patch_dropout: float = 0.,
|
||||
input_patchnorm: bool = False,
|
||||
act_layer: Callable = nn.GELU,
|
||||
norm_layer: Callable = LayerNorm,
|
||||
output_tokens: bool = False
|
||||
):
|
||||
super().__init__()
|
||||
self.output_tokens = output_tokens
|
||||
image_height, image_width = self.image_size = to_2tuple(image_size)
|
||||
patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
|
||||
self.grid_size = (image_height // patch_height, image_width // patch_width)
|
||||
self.output_dim = output_dim
|
||||
|
||||
# whether to layernorm each patch, as done in dual patchnorm paper - https://arxiv.org/abs/2302.01327v1
|
||||
self.input_patchnorm = input_patchnorm
|
||||
|
||||
if input_patchnorm:
|
||||
patch_input_dim = patch_height * patch_width * 3
|
||||
self.patchnorm_pre_ln = LayerNorm(patch_input_dim)
|
||||
self.conv1 = nn.Linear(patch_input_dim, width)
|
||||
else:
|
||||
self.patchnorm_pre_ln = nn.Identity()
|
||||
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
||||
|
||||
# class embeddings and positional embeddings
|
||||
scale = width ** -0.5
|
||||
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
||||
self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
|
||||
|
||||
# setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
|
||||
self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
|
||||
|
||||
self.ln_pre = norm_layer(width)
|
||||
self.transformer = Transformer(
|
||||
width,
|
||||
layers,
|
||||
heads,
|
||||
mlp_ratio,
|
||||
ls_init_value=ls_init_value,
|
||||
act_layer=act_layer,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
|
||||
self.global_average_pool = global_average_pool
|
||||
if attentional_pool:
|
||||
self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries)
|
||||
self.ln_post = norm_layer(output_dim)
|
||||
self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim))
|
||||
else:
|
||||
self.attn_pool = None
|
||||
self.ln_post = norm_layer(width)
|
||||
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
||||
|
||||
self.init_parameters()
|
||||
|
||||
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if unlocked_groups != 0:
|
||||
groups = [
|
||||
[
|
||||
self.conv1,
|
||||
self.class_embedding,
|
||||
self.positional_embedding,
|
||||
self.ln_pre,
|
||||
],
|
||||
*self.transformer.resblocks[:-1],
|
||||
[
|
||||
self.transformer.resblocks[-1],
|
||||
self.ln_post,
|
||||
],
|
||||
self.proj,
|
||||
]
|
||||
|
||||
def _unlock(x):
|
||||
if isinstance(x, Sequence):
|
||||
for g in x:
|
||||
_unlock(g)
|
||||
else:
|
||||
if isinstance(x, torch.nn.Parameter):
|
||||
x.requires_grad = True
|
||||
else:
|
||||
for p in x.parameters():
|
||||
p.requires_grad = True
|
||||
|
||||
_unlock(groups[-unlocked_groups:])
|
||||
|
||||
def init_parameters(self):
|
||||
# FIXME OpenAI CLIP did not define an init for the VisualTransformer
|
||||
# TODO experiment if default PyTorch init, below, or alternate init is best.
|
||||
|
||||
# nn.init.normal_(self.class_embedding, std=self.scale)
|
||||
# nn.init.normal_(self.positional_embedding, std=self.scale)
|
||||
#
|
||||
# proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
||||
# attn_std = self.transformer.width ** -0.5
|
||||
# fc_std = (2 * self.transformer.width) ** -0.5
|
||||
# for block in self.transformer.resblocks:
|
||||
# nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
||||
# nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
||||
# nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
||||
# nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
||||
#
|
||||
# if self.text_projection is not None:
|
||||
# nn.init.normal_(self.text_projection, std=self.scale)
|
||||
pass
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_grad_checkpointing(self, enable=True):
|
||||
self.transformer.grad_checkpointing = enable
|
||||
|
||||
def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.global_average_pool:
|
||||
return x.mean(dim=1), x
|
||||
else:
|
||||
return x[:, 0], x[:, 1:]
|
||||
|
||||
def forward(self, x: torch.Tensor, skip_pool: bool = False):
|
||||
|
||||
# to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
|
||||
if self.input_patchnorm:
|
||||
# einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
|
||||
x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1], self.patch_size[1])
|
||||
x = x.permute(0, 2, 4, 1, 3, 5)
|
||||
x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1)
|
||||
x = self.patchnorm_pre_ln(x)
|
||||
x = self.conv1(x)
|
||||
else:
|
||||
x = self.conv1(x) # shape = [*, width, grid, grid]
|
||||
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
||||
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
||||
|
||||
# class embeddings and positional embeddings
|
||||
x = torch.cat(
|
||||
[self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
|
||||
x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
||||
x = x + self.positional_embedding.to(x.dtype)
|
||||
|
||||
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
|
||||
x = self.patch_dropout(x)
|
||||
x = self.ln_pre(x)
|
||||
|
||||
x = x.permute(1, 0, 2) # NLD -> LND
|
||||
x = self.transformer(x)
|
||||
x = x.permute(1, 0, 2) # LND -> NLD
|
||||
|
||||
if skip_pool:
|
||||
return x
|
||||
|
||||
if self.attn_pool is not None:
|
||||
x = self.attn_pool(x)
|
||||
x = self.ln_post(x)
|
||||
pooled, tokens = self._global_pool(x)
|
||||
else:
|
||||
pooled, tokens = self._global_pool(x)
|
||||
pooled = self.ln_post(pooled)
|
||||
|
||||
if self.proj is not None:
|
||||
pooled = pooled @ self.proj
|
||||
|
||||
if self.output_tokens:
|
||||
return pooled, tokens
|
||||
|
||||
return pooled
|
||||
|
||||
|
||||
class TextTransformer(nn.Module):
|
||||
output_tokens: torch.jit.Final[bool]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context_length: int = 77,
|
||||
vocab_size: int = 49408,
|
||||
width: int = 512,
|
||||
heads: int = 8,
|
||||
layers: int = 12,
|
||||
ls_init_value: float = None,
|
||||
output_dim: int = 512,
|
||||
act_layer: Callable = nn.GELU,
|
||||
norm_layer: Callable = LayerNorm,
|
||||
embed_cls: bool = False,
|
||||
pad_id: int = 0,
|
||||
output_tokens: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.output_tokens = output_tokens
|
||||
self.num_pos = self.context_length = context_length
|
||||
self.vocab_size = vocab_size
|
||||
self.width = width
|
||||
self.output_dim = output_dim
|
||||
self.heads = heads
|
||||
self.pad_id = pad_id
|
||||
|
||||
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
|
||||
|
||||
if embed_cls:
|
||||
self.cls_emb = nn.Parameter(torch.empty(width))
|
||||
self.num_pos += 1
|
||||
else:
|
||||
self.cls_emb = None
|
||||
|
||||
self.token_embedding = nn.Embedding(vocab_size, width)
|
||||
self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
|
||||
self.transformer = Transformer(
|
||||
width=width,
|
||||
layers=layers,
|
||||
heads=heads,
|
||||
ls_init_value=ls_init_value,
|
||||
act_layer=act_layer,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
self.ln_final = norm_layer(width)
|
||||
|
||||
self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
|
||||
|
||||
self.init_parameters()
|
||||
|
||||
def init_parameters(self):
|
||||
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
||||
nn.init.normal_(self.positional_embedding, std=0.01)
|
||||
if self.cls_emb is not None:
|
||||
nn.init.normal_(self.cls_emb, std=0.01)
|
||||
|
||||
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
||||
attn_std = self.transformer.width ** -0.5
|
||||
fc_std = (2 * self.transformer.width) ** -0.5
|
||||
for block in self.transformer.resblocks:
|
||||
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
||||
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
||||
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
||||
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
||||
|
||||
if self.text_projection is not None:
|
||||
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_grad_checkpointing(self, enable=True):
|
||||
self.transformer.grad_checkpointing = enable
|
||||
|
||||
def build_attention_mask(self):
|
||||
# lazily create causal attention mask, with full attention between the tokens
|
||||
# pytorch uses additive attention mask; fill with -inf
|
||||
mask = torch.empty(self.num_pos, self.num_pos)
|
||||
mask.fill_(float("-inf"))
|
||||
mask.triu_(1) # zero out the lower diagonal
|
||||
return mask
|
||||
|
||||
def build_cls_mask(self, text, cast_dtype: torch.dtype):
|
||||
cls_mask = (text != self.pad_id).unsqueeze(1)
|
||||
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)
|
||||
additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
|
||||
additive_mask.fill_(0)
|
||||
additive_mask.masked_fill_(~cls_mask, float("-inf"))
|
||||
additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
|
||||
return additive_mask
|
||||
|
||||
def _repeat(self, t, N: int):
|
||||
return t.reshape(1, 1, -1).repeat(N, 1, 1)
|
||||
|
||||
def forward(self, text):
|
||||
cast_dtype = self.transformer.get_cast_dtype()
|
||||
seq_len = text.shape[1]
|
||||
|
||||
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
|
||||
attn_mask = self.attn_mask
|
||||
if self.cls_emb is not None:
|
||||
seq_len += 1
|
||||
x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1)
|
||||
cls_mask = self.build_cls_mask(text, cast_dtype)
|
||||
attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len]
|
||||
|
||||
x = x + self.positional_embedding[:seq_len].to(cast_dtype)
|
||||
x = x.permute(1, 0, 2) # NLD -> LND
|
||||
x = self.transformer(x, attn_mask=attn_mask)
|
||||
x = x.permute(1, 0, 2) # LND -> NLD
|
||||
|
||||
# x.shape = [batch_size, n_ctx, transformer.width]
|
||||
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||
if self.cls_emb is not None:
|
||||
pooled, tokens = x[:, -1], x[:, :-1]
|
||||
pooled = self.ln_final(pooled)
|
||||
else:
|
||||
x = self.ln_final(x)
|
||||
pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
|
||||
|
||||
if self.text_projection is not None:
|
||||
pooled = pooled @ self.text_projection
|
||||
|
||||
if self.output_tokens:
|
||||
return pooled, tokens
|
||||
|
||||
return pooled
|
||||
|
||||
|
||||
class MultimodalTransformer(Transformer):
|
||||
def __init__(
|
||||
self,
|
||||
width: int,
|
||||
layers: int,
|
||||
heads: int,
|
||||
context_length: int = 77,
|
||||
mlp_ratio: float = 4.0,
|
||||
ls_init_value: float = None,
|
||||
act_layer: Callable = nn.GELU,
|
||||
norm_layer: Callable = LayerNorm,
|
||||
output_dim: int = 512,
|
||||
):
|
||||
|
||||
super().__init__(
|
||||
width=width,
|
||||
layers=layers,
|
||||
heads=heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
ls_init_value=ls_init_value,
|
||||
act_layer=act_layer,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
self.context_length = context_length
|
||||
self.cross_attn = nn.ModuleList([
|
||||
ResidualAttentionBlock(
|
||||
width,
|
||||
heads,
|
||||
mlp_ratio,
|
||||
ls_init_value=ls_init_value,
|
||||
act_layer=act_layer,
|
||||
norm_layer=norm_layer,
|
||||
is_cross_attention=True,
|
||||
)
|
||||
for _ in range(layers)
|
||||
])
|
||||
|
||||
self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
|
||||
|
||||
self.ln_final = norm_layer(width)
|
||||
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
|
||||
|
||||
def init_parameters(self):
|
||||
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
||||
attn_std = self.transformer.width ** -0.5
|
||||
fc_std = (2 * self.transformer.width) ** -0.5
|
||||
for block in self.transformer.resblocks:
|
||||
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
||||
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
||||
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
||||
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
||||
for block in self.transformer.cross_attn:
|
||||
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
||||
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
||||
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
||||
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
||||
|
||||
if self.text_projection is not None:
|
||||
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
||||
|
||||
def build_attention_mask(self):
|
||||
# lazily create causal attention mask, with full attention between the tokens
|
||||
# pytorch uses additive attention mask; fill with -inf
|
||||
mask = torch.empty(self.context_length, self.context_length)
|
||||
mask.fill_(float("-inf"))
|
||||
mask.triu_(1) # zero out the lower diagonal
|
||||
return mask
|
||||
|
||||
def forward(self, image_embs, text_embs):
|
||||
text_embs = text_embs.permute(1, 0, 2) # NLD -> LNDsq
|
||||
image_embs = image_embs.permute(1, 0, 2) # NLD -> LND
|
||||
seq_len = text_embs.shape[0]
|
||||
|
||||
for resblock, cross_attn in zip(self.resblocks, self.cross_attn):
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
# TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
|
||||
text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len])
|
||||
text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None)
|
||||
else:
|
||||
text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len])
|
||||
text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)
|
||||
|
||||
x = text_embs.permute(1, 0, 2) # LND -> NLD
|
||||
x = self.ln_final(x)
|
||||
|
||||
if self.text_projection is not None:
|
||||
x = x @ self.text_projection
|
||||
|
||||
return x
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_grad_checkpointing(self, enable=True):
|
||||
self.grad_checkpointing = enable
|
||||
60
diffsynth/extensions/ImageQualityMetric/open_clip/utils.py
Normal file
60
diffsynth/extensions/ImageQualityMetric/open_clip/utils.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from itertools import repeat
|
||||
import collections.abc
|
||||
|
||||
from torch import nn as nn
|
||||
from torchvision.ops.misc import FrozenBatchNorm2d
|
||||
|
||||
|
||||
def freeze_batch_norm_2d(module, module_match={}, name=''):
|
||||
"""
|
||||
Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
|
||||
itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
|
||||
returned. Otherwise, the module is walked recursively and submodules are converted in place.
|
||||
|
||||
Args:
|
||||
module (torch.nn.Module): Any PyTorch module.
|
||||
module_match (dict): Dictionary of full module names to freeze (all if empty)
|
||||
name (str): Full module name (prefix)
|
||||
|
||||
Returns:
|
||||
torch.nn.Module: Resulting module
|
||||
|
||||
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
|
||||
"""
|
||||
res = module
|
||||
is_match = True
|
||||
if module_match:
|
||||
is_match = name in module_match
|
||||
if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
|
||||
res = FrozenBatchNorm2d(module.num_features)
|
||||
res.num_features = module.num_features
|
||||
res.affine = module.affine
|
||||
if module.affine:
|
||||
res.weight.data = module.weight.data.clone().detach()
|
||||
res.bias.data = module.bias.data.clone().detach()
|
||||
res.running_mean.data = module.running_mean.data
|
||||
res.running_var.data = module.running_var.data
|
||||
res.eps = module.eps
|
||||
else:
|
||||
for child_name, child in module.named_children():
|
||||
full_child_name = '.'.join([name, child_name]) if name else child_name
|
||||
new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
|
||||
if new_child is not child:
|
||||
res.add_module(child_name, new_child)
|
||||
return res
|
||||
|
||||
|
||||
# From PyTorch internals
|
||||
def _ntuple(n):
|
||||
def parse(x):
|
||||
if isinstance(x, collections.abc.Iterable):
|
||||
return x
|
||||
return tuple(repeat(x, n))
|
||||
return parse
|
||||
|
||||
|
||||
to_1tuple = _ntuple(1)
|
||||
to_2tuple = _ntuple(2)
|
||||
to_3tuple = _ntuple(3)
|
||||
to_4tuple = _ntuple(4)
|
||||
to_ntuple = lambda n, x: _ntuple(n)(x)
|
||||
@@ -0,0 +1 @@
|
||||
__version__ = '2.16.0'
|
||||
112
diffsynth/extensions/ImageQualityMetric/pickscore.py
Normal file
112
diffsynth/extensions/ImageQualityMetric/pickscore.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoProcessor, AutoModel
|
||||
from typing import List, Union
|
||||
import os
|
||||
from .config import MODEL_PATHS
|
||||
|
||||
class PickScore(torch.nn.Module):
|
||||
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS):
|
||||
super().__init__()
|
||||
"""Initialize the Selector with a processor and model.
|
||||
|
||||
Args:
|
||||
device (Union[str, torch.device]): The device to load the model on.
|
||||
"""
|
||||
self.device = device if isinstance(device, torch.device) else torch.device(device)
|
||||
processor_name_or_path = path.get("clip")
|
||||
model_pretrained_name_or_path = path.get("pickscore")
|
||||
self.processor = AutoProcessor.from_pretrained(processor_name_or_path)
|
||||
self.model = AutoModel.from_pretrained(model_pretrained_name_or_path).eval().to(self.device)
|
||||
|
||||
def _calculate_score(self, image: torch.Tensor, prompt: str, softmax: bool = False) -> float:
|
||||
"""Calculate the score for a single image and prompt.
|
||||
|
||||
Args:
|
||||
image (torch.Tensor): The processed image tensor.
|
||||
prompt (str): The prompt text.
|
||||
softmax (bool): Whether to apply softmax to the scores.
|
||||
|
||||
Returns:
|
||||
float: The score for the image.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
# Prepare text inputs
|
||||
text_inputs = self.processor(
|
||||
text=prompt,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=77,
|
||||
return_tensors="pt",
|
||||
).to(self.device)
|
||||
|
||||
# Embed images and text
|
||||
image_embs = self.model.get_image_features(pixel_values=image)
|
||||
image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
|
||||
text_embs = self.model.get_text_features(**text_inputs)
|
||||
text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
|
||||
|
||||
# Compute score
|
||||
score = (text_embs @ image_embs.T)[0]
|
||||
if softmax:
|
||||
# Apply logit scale and softmax
|
||||
score = torch.softmax(self.model.logit_scale.exp() * score, dim=-1)
|
||||
|
||||
return score.cpu().item()
|
||||
|
||||
@torch.no_grad()
|
||||
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str, softmax: bool = False) -> List[float]:
|
||||
"""Score the images based on the prompt.
|
||||
|
||||
Args:
|
||||
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
||||
prompt (str): The prompt text.
|
||||
softmax (bool): Whether to apply softmax to the scores.
|
||||
|
||||
Returns:
|
||||
List[float]: List of scores for the images.
|
||||
"""
|
||||
try:
|
||||
if isinstance(images, (str, Image.Image)):
|
||||
# Single image
|
||||
if isinstance(images, str):
|
||||
pil_image = Image.open(images)
|
||||
else:
|
||||
pil_image = images
|
||||
|
||||
# Prepare image inputs
|
||||
image_inputs = self.processor(
|
||||
images=pil_image,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=77,
|
||||
return_tensors="pt",
|
||||
).to(self.device)
|
||||
|
||||
return [self._calculate_score(image_inputs["pixel_values"], prompt, softmax)]
|
||||
elif isinstance(images, list):
|
||||
# Multiple images
|
||||
scores = []
|
||||
for one_image in images:
|
||||
if isinstance(one_image, str):
|
||||
pil_image = Image.open(one_image)
|
||||
elif isinstance(one_image, Image.Image):
|
||||
pil_image = one_image
|
||||
else:
|
||||
raise TypeError("The type of parameter images is illegal.")
|
||||
|
||||
# Prepare image inputs
|
||||
image_inputs = self.processor(
|
||||
images=pil_image,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=77,
|
||||
return_tensors="pt",
|
||||
).to(self.device)
|
||||
|
||||
scores.append(self._calculate_score(image_inputs["pixel_values"], prompt, softmax))
|
||||
return scores
|
||||
else:
|
||||
raise TypeError("The type of parameter images is illegal.")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error in scoring images: {e}")
|
||||
@@ -0,0 +1 @@
|
||||
from .models import *
|
||||
@@ -0,0 +1,3 @@
|
||||
from .base_model import *
|
||||
from .clip_model import *
|
||||
from .cross_modeling import *
|
||||
@@ -0,0 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelConfig:
|
||||
pass
|
||||
@@ -0,0 +1,146 @@
|
||||
from dataclasses import dataclass
|
||||
from transformers import CLIPModel as HFCLIPModel
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from torch import nn, einsum
|
||||
|
||||
from .base_model import BaseModelConfig
|
||||
|
||||
from transformers import CLIPConfig
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
import torch
|
||||
|
||||
from .cross_modeling import Cross_model
|
||||
|
||||
import json, os
|
||||
|
||||
class XCLIPModel(HFCLIPModel):
|
||||
def __init__(self, config: CLIPConfig):
|
||||
super().__init__(config)
|
||||
|
||||
def get_text_features(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> torch.FloatTensor:
|
||||
|
||||
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
text_outputs = self.text_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
# pooled_output = text_outputs[1]
|
||||
# text_features = self.text_projection(pooled_output)
|
||||
last_hidden_state = text_outputs[0]
|
||||
text_features = self.text_projection(last_hidden_state)
|
||||
|
||||
pooled_output = text_outputs[1]
|
||||
text_features_EOS = self.text_projection(pooled_output)
|
||||
|
||||
|
||||
# del last_hidden_state, text_outputs
|
||||
# gc.collect()
|
||||
|
||||
return text_features, text_features_EOS
|
||||
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> torch.FloatTensor:
|
||||
|
||||
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
vision_outputs = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
# pooled_output = vision_outputs[1] # pooled_output
|
||||
# image_features = self.visual_projection(pooled_output)
|
||||
last_hidden_state = vision_outputs[0]
|
||||
image_features = self.visual_projection(last_hidden_state)
|
||||
|
||||
return image_features
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClipModelConfig(BaseModelConfig):
|
||||
_target_: str = "diffsynth.extensions.QualityMetric.trainer.models.clip_model.CLIPModel"
|
||||
pretrained_model_name_or_path: str ="checkpoints/clip-vit-base-patch32"
|
||||
|
||||
|
||||
class CLIPModel(nn.Module):
|
||||
def __init__(self, ckpt, config_file=False):
|
||||
super().__init__()
|
||||
if config_file is None:
|
||||
self.model = XCLIPModel.from_pretrained(ckpt)
|
||||
else:
|
||||
with open(os.path.join(ckpt, "config.json"), "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
config = CLIPConfig(**config)
|
||||
self.model = XCLIPModel._from_config(config)
|
||||
self.cross_model = Cross_model(dim=1024, layer_num=4, heads=16)
|
||||
|
||||
def get_text_features(self, *args, **kwargs):
|
||||
return self.model.get_text_features(*args, **kwargs)
|
||||
|
||||
def get_image_features(self, *args, **kwargs):
|
||||
return self.model.get_image_features(*args, **kwargs)
|
||||
|
||||
def forward(self, text_inputs=None, image_inputs=None, condition_inputs=None):
|
||||
outputs = ()
|
||||
|
||||
text_f, text_EOS = self.model.get_text_features(text_inputs) # B*77*1024
|
||||
outputs += text_EOS,
|
||||
|
||||
image_f = self.model.get_image_features(image_inputs.half()) # 2B*257*1024
|
||||
condition_f, _ = self.model.get_text_features(condition_inputs) # B*5*1024
|
||||
|
||||
sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)
|
||||
sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
|
||||
sim_text_condition = sim_text_condition / sim_text_condition.max()
|
||||
mask = torch.where(sim_text_condition > 0.01, 0, float('-inf')) # B*1*77
|
||||
|
||||
mask = mask.repeat(1,image_f.shape[1],1) # B*257*77
|
||||
bc = int(image_f.shape[0]/2)
|
||||
|
||||
sim0 = self.cross_model(image_f[:bc,:,:], text_f,mask.half())
|
||||
sim1 = self.cross_model(image_f[bc:,:,:], text_f,mask.half())
|
||||
outputs += sim0[:,0,:],
|
||||
outputs += sim1[:,0,:],
|
||||
|
||||
return outputs
|
||||
|
||||
@property
|
||||
def logit_scale(self):
|
||||
return self.model.logit_scale
|
||||
|
||||
def save(self, path):
|
||||
self.model.save_pretrained(path)
|
||||
|
||||
@@ -0,0 +1,292 @@
|
||||
import torch
|
||||
from torch import einsum, nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
# normalization
|
||||
# they use layernorm without bias, something that pytorch does not offer
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
self.register_buffer("bias", torch.zeros(dim))
|
||||
|
||||
def forward(self, x):
|
||||
return F.layer_norm(x, x.shape[-1:], self.weight, self.bias)
|
||||
|
||||
# residual
|
||||
|
||||
|
||||
class Residual(nn.Module):
|
||||
def __init__(self, fn):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
return self.fn(x, *args, **kwargs) + x
|
||||
|
||||
|
||||
# rotary positional embedding
|
||||
# https://arxiv.org/abs/2104.09864
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
|
||||
def forward(self, max_seq_len, *, device):
|
||||
seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
|
||||
freqs = einsum("i , j -> i j", seq, self.inv_freq)
|
||||
return torch.cat((freqs, freqs), dim=-1)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x = rearrange(x, "... (j d) -> ... j d", j=2)
|
||||
x1, x2 = x.unbind(dim=-2)
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(pos, t):
|
||||
return (t * pos.cos()) + (rotate_half(t) * pos.sin())
|
||||
|
||||
|
||||
# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
|
||||
# https://arxiv.org/abs/2002.05202
|
||||
|
||||
|
||||
class SwiGLU(nn.Module):
|
||||
def forward(self, x):
|
||||
x, gate = x.chunk(2, dim=-1)
|
||||
return F.silu(gate) * x
|
||||
|
||||
|
||||
# parallel attention and feedforward with residual
|
||||
# discovered by Wang et al + EleutherAI from GPT-J fame
|
||||
|
||||
class ParallelTransformerBlock(nn.Module):
|
||||
def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
|
||||
super().__init__()
|
||||
self.norm = LayerNorm(dim)
|
||||
|
||||
attn_inner_dim = dim_head * heads
|
||||
ff_inner_dim = dim * ff_mult
|
||||
self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head**-0.5
|
||||
self.rotary_emb = RotaryEmbedding(dim_head)
|
||||
|
||||
self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
|
||||
self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
|
||||
|
||||
self.ff_out = nn.Sequential(
|
||||
SwiGLU(),
|
||||
nn.Linear(ff_inner_dim, dim, bias=False)
|
||||
)
|
||||
|
||||
self.register_buffer("pos_emb", None, persistent=False)
|
||||
|
||||
|
||||
def get_rotary_embedding(self, n, device):
|
||||
if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
|
||||
return self.pos_emb[:n]
|
||||
|
||||
pos_emb = self.rotary_emb(n, device=device)
|
||||
self.register_buffer("pos_emb", pos_emb, persistent=False)
|
||||
return pos_emb
|
||||
|
||||
def forward(self, x, attn_mask=None):
|
||||
"""
|
||||
einstein notation
|
||||
b - batch
|
||||
h - heads
|
||||
n, i, j - sequence length (base sequence length, source, target)
|
||||
d - feature dimension
|
||||
"""
|
||||
|
||||
n, device, h = x.shape[1], x.device, self.heads
|
||||
|
||||
# pre layernorm
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
# attention queries, keys, values, and feedforward inner
|
||||
|
||||
q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)
|
||||
|
||||
# split heads
|
||||
# they use multi-query single-key-value attention, yet another Noam Shazeer paper
|
||||
# they found no performance loss past a certain scale, and more efficient decoding obviously
|
||||
# https://arxiv.org/abs/1911.02150
|
||||
|
||||
q = rearrange(q, "b n (h d) -> b h n d", h=h)
|
||||
|
||||
# rotary embeddings
|
||||
|
||||
positions = self.get_rotary_embedding(n, device)
|
||||
q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))
|
||||
|
||||
# scale
|
||||
|
||||
q = q * self.scale
|
||||
|
||||
# similarity
|
||||
|
||||
sim = einsum("b h i d, b j d -> b h i j", q, k)
|
||||
|
||||
|
||||
# extra attention mask - for masking out attention from text CLS token to padding
|
||||
|
||||
if exists(attn_mask):
|
||||
attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j')
|
||||
sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)
|
||||
|
||||
# attention
|
||||
|
||||
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
|
||||
attn = sim.softmax(dim=-1)
|
||||
|
||||
# aggregate values
|
||||
|
||||
out = einsum("b h i j, b j d -> b h i d", attn, v)
|
||||
|
||||
# merge heads
|
||||
|
||||
out = rearrange(out, "b h n d -> b n (h d)")
|
||||
return self.attn_out(out) + self.ff_out(ff)
|
||||
|
||||
# cross attention - using multi-query + one-headed key / values as in PaLM w/ optional parallel feedforward
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
context_dim=None,
|
||||
dim_head=64,
|
||||
heads=12,
|
||||
parallel_ff=False,
|
||||
ff_mult=4,
|
||||
norm_context=False
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
inner_dim = heads * dim_head
|
||||
context_dim = default(context_dim, dim)
|
||||
|
||||
self.norm = LayerNorm(dim)
|
||||
self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity()
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
||||
self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
||||
|
||||
# whether to have parallel feedforward
|
||||
|
||||
ff_inner_dim = ff_mult * dim
|
||||
|
||||
self.ff = nn.Sequential(
|
||||
nn.Linear(dim, ff_inner_dim * 2, bias=False),
|
||||
SwiGLU(),
|
||||
nn.Linear(ff_inner_dim, dim, bias=False)
|
||||
) if parallel_ff else None
|
||||
|
||||
def forward(self, x, context, mask):
|
||||
"""
|
||||
einstein notation
|
||||
b - batch
|
||||
h - heads
|
||||
n, i, j - sequence length (base sequence length, source, target)
|
||||
d - feature dimension
|
||||
"""
|
||||
|
||||
# pre-layernorm, for queries and context
|
||||
|
||||
x = self.norm(x)
|
||||
context = self.context_norm(context)
|
||||
|
||||
# get queries
|
||||
|
||||
q = self.to_q(x)
|
||||
q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
|
||||
|
||||
# scale
|
||||
|
||||
q = q * self.scale
|
||||
|
||||
# get key / values
|
||||
|
||||
k, v = self.to_kv(context).chunk(2, dim=-1)
|
||||
|
||||
# query / key similarity
|
||||
|
||||
sim = einsum('b h i d, b j d -> b h i j', q, k)
|
||||
|
||||
# attention
|
||||
mask = mask.unsqueeze(1).repeat(1,self.heads,1,1)
|
||||
sim = sim + mask # context mask
|
||||
sim = sim - sim.amax(dim=-1, keepdim=True)
|
||||
attn = sim.softmax(dim=-1)
|
||||
|
||||
# aggregate
|
||||
|
||||
out = einsum('b h i j, b j d -> b h i d', attn, v)
|
||||
|
||||
# merge and combine heads
|
||||
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
out = self.to_out(out)
|
||||
|
||||
# add parallel feedforward (for multimodal layers)
|
||||
|
||||
if exists(self.ff):
|
||||
out = out + self.ff(x)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Cross_model(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim=512,
|
||||
layer_num=4,
|
||||
dim_head=64,
|
||||
heads=8,
|
||||
ff_mult=4
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
|
||||
for ind in range(layer_num):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Residual(CrossAttention(dim=dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult)),
|
||||
Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult))
|
||||
]))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query_tokens,
|
||||
context_tokens,
|
||||
mask
|
||||
):
|
||||
|
||||
for cross_attn, self_attn_ff in self.layers:
|
||||
query_tokens = cross_attn(query_tokens, context_tokens,mask)
|
||||
query_tokens = self_attn_ff(query_tokens)
|
||||
|
||||
return query_tokens
|
||||
@@ -58,7 +58,7 @@ class IFBlock(nn.Module):
|
||||
|
||||
|
||||
class IFNet(nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, **kwargs):
|
||||
super(IFNet, self).__init__()
|
||||
self.block0 = IFBlock(7+4, c=90)
|
||||
self.block1 = IFBlock(7+4, c=90)
|
||||
@@ -99,7 +99,8 @@ class IFNet(nn.Module):
|
||||
merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
|
||||
return flow_list, mask_list[2], merged
|
||||
|
||||
def state_dict_converter(self):
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return IFNetStateDictConverter()
|
||||
|
||||
|
||||
@@ -112,7 +113,7 @@ class IFNetStateDictConverter:
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
return self.from_diffusers(state_dict), {"upcast_to_float32": True}
|
||||
|
||||
|
||||
class RIFEInterpolater:
|
||||
@@ -124,7 +125,7 @@ class RIFEInterpolater:
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager):
|
||||
return RIFEInterpolater(model_manager.RIFE, device=model_manager.device)
|
||||
return RIFEInterpolater(model_manager.fetch_model("rife"), device=model_manager.device)
|
||||
|
||||
def process_image(self, image):
|
||||
width, height = image.size
|
||||
@@ -202,7 +203,7 @@ class RIFESmoother(RIFEInterpolater):
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager):
|
||||
return RIFESmoother(model_manager.RIFE, device=model_manager.device)
|
||||
return RIFEInterpolater(model_manager.fetch_model("rife"), device=model_manager.device)
|
||||
|
||||
def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
|
||||
output_tensor = []
|
||||
|
||||
0
diffsynth/extensions/__init__.py
Normal file
0
diffsynth/extensions/__init__.py
Normal file
@@ -1,482 +1 @@
|
||||
import torch, os
|
||||
from safetensors import safe_open
|
||||
|
||||
from .sd_text_encoder import SDTextEncoder
|
||||
from .sd_unet import SDUNet
|
||||
from .sd_vae_encoder import SDVAEEncoder
|
||||
from .sd_vae_decoder import SDVAEDecoder
|
||||
from .sd_lora import SDLoRA
|
||||
|
||||
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
||||
from .sdxl_unet import SDXLUNet
|
||||
from .sdxl_vae_decoder import SDXLVAEDecoder
|
||||
from .sdxl_vae_encoder import SDXLVAEEncoder
|
||||
|
||||
from .sd_controlnet import SDControlNet
|
||||
|
||||
from .sd_motion import SDMotionModel
|
||||
from .sdxl_motion import SDXLMotionModel
|
||||
|
||||
from .svd_image_encoder import SVDImageEncoder
|
||||
from .svd_unet import SVDUNet
|
||||
from .svd_vae_decoder import SVDVAEDecoder
|
||||
from .svd_vae_encoder import SVDVAEEncoder
|
||||
|
||||
from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
|
||||
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
||||
|
||||
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
||||
from .hunyuan_dit import HunyuanDiT
|
||||
|
||||
|
||||
class ModelManager:
|
||||
def __init__(self, torch_dtype=torch.float16, device="cuda"):
|
||||
self.torch_dtype = torch_dtype
|
||||
self.device = device
|
||||
self.model = {}
|
||||
self.model_path = {}
|
||||
self.textual_inversion_dict = {}
|
||||
|
||||
def is_stable_video_diffusion(self, state_dict):
|
||||
param_name = "model.diffusion_model.output_blocks.9.1.time_stack.0.norm_in.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_RIFE(self, state_dict):
|
||||
param_name = "block_tea.convblock3.0.1.weight"
|
||||
return param_name in state_dict or ("module." + param_name) in state_dict
|
||||
|
||||
def is_beautiful_prompt(self, state_dict):
|
||||
param_name = "transformer.h.9.self_attention.query_key_value.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_stabe_diffusion_xl(self, state_dict):
|
||||
param_name = "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_stable_diffusion(self, state_dict):
|
||||
if self.is_stabe_diffusion_xl(state_dict):
|
||||
return False
|
||||
param_name = "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_controlnet(self, state_dict):
|
||||
param_name = "control_model.time_embed.0.weight"
|
||||
param_name_2 = "mid_block.resnets.1.time_emb_proj.weight" # For controlnets in diffusers format
|
||||
return param_name in state_dict or param_name_2 in state_dict
|
||||
|
||||
def is_animatediff(self, state_dict):
|
||||
param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_animatediff_xl(self, state_dict):
|
||||
param_name = "up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.ff_norm.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_sd_lora(self, state_dict):
|
||||
param_name = "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_ff_net_2.lora_up.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_translator(self, state_dict):
|
||||
param_name = "model.encoder.layers.5.self_attn_layer_norm.weight"
|
||||
return param_name in state_dict and len(state_dict) == 254
|
||||
|
||||
def is_ipadapter(self, state_dict):
|
||||
return "image_proj" in state_dict and "ip_adapter" in state_dict and state_dict["image_proj"]["proj.weight"].shape == torch.Size([3072, 1024])
|
||||
|
||||
def is_ipadapter_image_encoder(self, state_dict):
|
||||
param_name = "vision_model.encoder.layers.31.self_attn.v_proj.weight"
|
||||
return param_name in state_dict and len(state_dict) == 521
|
||||
|
||||
def is_ipadapter_xl(self, state_dict):
|
||||
return "image_proj" in state_dict and "ip_adapter" in state_dict and state_dict["image_proj"]["proj.weight"].shape == torch.Size([8192, 1280])
|
||||
|
||||
def is_ipadapter_xl_image_encoder(self, state_dict):
|
||||
param_name = "vision_model.encoder.layers.47.self_attn.v_proj.weight"
|
||||
return param_name in state_dict and len(state_dict) == 777
|
||||
|
||||
def is_hunyuan_dit_clip_text_encoder(self, state_dict):
|
||||
param_name = "bert.encoder.layer.23.attention.output.dense.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_hunyuan_dit_t5_text_encoder(self, state_dict):
|
||||
param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_hunyuan_dit(self, state_dict):
|
||||
param_name = "final_layer.adaLN_modulation.1.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_diffusers_vae(self, state_dict):
|
||||
param_name = "quant_conv.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_ExVideo_StableVideoDiffusion(self, state_dict):
|
||||
param_name = "blocks.185.positional_embedding.embeddings"
|
||||
return param_name in state_dict
|
||||
|
||||
def load_stable_video_diffusion(self, state_dict, components=None, file_path="", add_positional_conv=None):
|
||||
component_dict = {
|
||||
"image_encoder": SVDImageEncoder,
|
||||
"unet": SVDUNet,
|
||||
"vae_decoder": SVDVAEDecoder,
|
||||
"vae_encoder": SVDVAEEncoder,
|
||||
}
|
||||
if components is None:
|
||||
components = ["image_encoder", "unet", "vae_decoder", "vae_encoder"]
|
||||
for component in components:
|
||||
if component == "unet":
|
||||
self.model[component] = component_dict[component](add_positional_conv=add_positional_conv)
|
||||
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict, add_positional_conv=add_positional_conv), strict=False)
|
||||
else:
|
||||
self.model[component] = component_dict[component]()
|
||||
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
||||
self.model[component].to(self.torch_dtype).to(self.device)
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_stable_diffusion(self, state_dict, components=None, file_path=""):
|
||||
component_dict = {
|
||||
"text_encoder": SDTextEncoder,
|
||||
"unet": SDUNet,
|
||||
"vae_decoder": SDVAEDecoder,
|
||||
"vae_encoder": SDVAEEncoder,
|
||||
"refiner": SDXLUNet,
|
||||
}
|
||||
if components is None:
|
||||
components = ["text_encoder", "unet", "vae_decoder", "vae_encoder"]
|
||||
for component in components:
|
||||
if component == "text_encoder":
|
||||
# Add additional token embeddings to text encoder
|
||||
token_embeddings = [state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"]]
|
||||
for keyword in self.textual_inversion_dict:
|
||||
_, embeddings = self.textual_inversion_dict[keyword]
|
||||
token_embeddings.append(embeddings.to(dtype=token_embeddings[0].dtype))
|
||||
token_embeddings = torch.concat(token_embeddings, dim=0)
|
||||
state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"] = token_embeddings
|
||||
self.model[component] = component_dict[component](vocab_size=token_embeddings.shape[0])
|
||||
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
||||
self.model[component].to(self.torch_dtype).to(self.device)
|
||||
else:
|
||||
self.model[component] = component_dict[component]()
|
||||
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
||||
self.model[component].to(self.torch_dtype).to(self.device)
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_stable_diffusion_xl(self, state_dict, components=None, file_path=""):
|
||||
component_dict = {
|
||||
"text_encoder": SDXLTextEncoder,
|
||||
"text_encoder_2": SDXLTextEncoder2,
|
||||
"unet": SDXLUNet,
|
||||
"vae_decoder": SDXLVAEDecoder,
|
||||
"vae_encoder": SDXLVAEEncoder,
|
||||
}
|
||||
if components is None:
|
||||
components = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder"]
|
||||
for component in components:
|
||||
self.model[component] = component_dict[component]()
|
||||
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
||||
if component in ["vae_decoder", "vae_encoder"]:
|
||||
# These two model will output nan when float16 is enabled.
|
||||
# The precision problem happens in the last three resnet blocks.
|
||||
# I do not know how to solve this problem.
|
||||
self.model[component].to(torch.float32).to(self.device)
|
||||
else:
|
||||
self.model[component].to(self.torch_dtype).to(self.device)
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_controlnet(self, state_dict, file_path=""):
|
||||
component = "controlnet"
|
||||
if component not in self.model:
|
||||
self.model[component] = []
|
||||
self.model_path[component] = []
|
||||
model = SDControlNet()
|
||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component].append(model)
|
||||
self.model_path[component].append(file_path)
|
||||
|
||||
def load_animatediff(self, state_dict, file_path=""):
|
||||
component = "motion_modules"
|
||||
model = SDMotionModel()
|
||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_animatediff_xl(self, state_dict, file_path=""):
|
||||
component = "motion_modules_xl"
|
||||
model = SDXLMotionModel()
|
||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_beautiful_prompt(self, state_dict, file_path=""):
|
||||
component = "beautiful_prompt"
|
||||
from transformers import AutoModelForCausalLM
|
||||
model_folder = os.path.dirname(file_path)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_folder, state_dict=state_dict, local_files_only=True, torch_dtype=self.torch_dtype
|
||||
).to(self.device).eval()
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_RIFE(self, state_dict, file_path=""):
|
||||
component = "RIFE"
|
||||
from ..extensions.RIFE import IFNet
|
||||
model = IFNet().eval()
|
||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
||||
model.to(torch.float32).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_sd_lora(self, state_dict, alpha):
|
||||
SDLoRA().add_lora_to_text_encoder(self.model["text_encoder"], state_dict, alpha=alpha, device=self.device)
|
||||
SDLoRA().add_lora_to_unet(self.model["unet"], state_dict, alpha=alpha, device=self.device)
|
||||
|
||||
def load_translator(self, state_dict, file_path=""):
|
||||
# This model is lightweight, we do not place it on GPU.
|
||||
component = "translator"
|
||||
from transformers import AutoModelForSeq2SeqLM
|
||||
model_folder = os.path.dirname(file_path)
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(model_folder).eval()
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_ipadapter(self, state_dict, file_path=""):
|
||||
component = "ipadapter"
|
||||
model = SDIpAdapter()
|
||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_ipadapter_image_encoder(self, state_dict, file_path=""):
|
||||
component = "ipadapter_image_encoder"
|
||||
model = IpAdapterCLIPImageEmbedder()
|
||||
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_ipadapter_xl(self, state_dict, file_path=""):
|
||||
component = "ipadapter_xl"
|
||||
model = SDXLIpAdapter()
|
||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_ipadapter_xl_image_encoder(self, state_dict, file_path=""):
|
||||
component = "ipadapter_xl_image_encoder"
|
||||
model = IpAdapterXLCLIPImageEmbedder()
|
||||
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_hunyuan_dit_clip_text_encoder(self, state_dict, file_path=""):
|
||||
component = "hunyuan_dit_clip_text_encoder"
|
||||
model = HunyuanDiTCLIPTextEncoder()
|
||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_hunyuan_dit_t5_text_encoder(self, state_dict, file_path=""):
|
||||
component = "hunyuan_dit_t5_text_encoder"
|
||||
model = HunyuanDiTT5TextEncoder()
|
||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_hunyuan_dit(self, state_dict, file_path=""):
|
||||
component = "hunyuan_dit"
|
||||
model = HunyuanDiT()
|
||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_diffusers_vae(self, state_dict, file_path=""):
|
||||
# TODO: detect SD and SDXL
|
||||
component = "vae_encoder"
|
||||
model = SDXLVAEEncoder()
|
||||
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
component = "vae_decoder"
|
||||
model = SDXLVAEDecoder()
|
||||
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_ExVideo_StableVideoDiffusion(self, state_dict, file_path=""):
|
||||
unet_state_dict = self.model["unet"].state_dict()
|
||||
self.model["unet"].to("cpu")
|
||||
del self.model["unet"]
|
||||
add_positional_conv = state_dict["blocks.185.positional_embedding.embeddings"].shape[0]
|
||||
self.model["unet"] = SVDUNet(add_positional_conv=add_positional_conv)
|
||||
self.model["unet"].load_state_dict(unet_state_dict, strict=False)
|
||||
self.model["unet"].load_state_dict(state_dict, strict=False)
|
||||
self.model["unet"].to(self.torch_dtype).to(self.device)
|
||||
|
||||
def search_for_embeddings(self, state_dict):
|
||||
embeddings = []
|
||||
for k in state_dict:
|
||||
if isinstance(state_dict[k], torch.Tensor):
|
||||
embeddings.append(state_dict[k])
|
||||
elif isinstance(state_dict[k], dict):
|
||||
embeddings += self.search_for_embeddings(state_dict[k])
|
||||
return embeddings
|
||||
|
||||
def load_textual_inversions(self, folder):
|
||||
# Store additional tokens here
|
||||
self.textual_inversion_dict = {}
|
||||
|
||||
# Load every textual inversion file
|
||||
for file_name in os.listdir(folder):
|
||||
if file_name.endswith(".txt"):
|
||||
continue
|
||||
keyword = os.path.splitext(file_name)[0]
|
||||
state_dict = load_state_dict(os.path.join(folder, file_name))
|
||||
|
||||
# Search for embeddings
|
||||
for embeddings in self.search_for_embeddings(state_dict):
|
||||
if len(embeddings.shape) == 2 and embeddings.shape[1] == 768:
|
||||
tokens = [f"{keyword}_{i}" for i in range(embeddings.shape[0])]
|
||||
self.textual_inversion_dict[keyword] = (tokens, embeddings)
|
||||
break
|
||||
|
||||
def load_model(self, file_path, components=None, lora_alphas=[]):
|
||||
state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype)
|
||||
if self.is_stable_video_diffusion(state_dict):
|
||||
self.load_stable_video_diffusion(state_dict, file_path=file_path)
|
||||
elif self.is_animatediff(state_dict):
|
||||
self.load_animatediff(state_dict, file_path=file_path)
|
||||
elif self.is_animatediff_xl(state_dict):
|
||||
self.load_animatediff_xl(state_dict, file_path=file_path)
|
||||
elif self.is_controlnet(state_dict):
|
||||
self.load_controlnet(state_dict, file_path=file_path)
|
||||
elif self.is_stabe_diffusion_xl(state_dict):
|
||||
self.load_stable_diffusion_xl(state_dict, components=components, file_path=file_path)
|
||||
elif self.is_stable_diffusion(state_dict):
|
||||
self.load_stable_diffusion(state_dict, components=components, file_path=file_path)
|
||||
elif self.is_sd_lora(state_dict):
|
||||
self.load_sd_lora(state_dict, alpha=lora_alphas.pop(0))
|
||||
elif self.is_beautiful_prompt(state_dict):
|
||||
self.load_beautiful_prompt(state_dict, file_path=file_path)
|
||||
elif self.is_RIFE(state_dict):
|
||||
self.load_RIFE(state_dict, file_path=file_path)
|
||||
elif self.is_translator(state_dict):
|
||||
self.load_translator(state_dict, file_path=file_path)
|
||||
elif self.is_ipadapter(state_dict):
|
||||
self.load_ipadapter(state_dict, file_path=file_path)
|
||||
elif self.is_ipadapter_image_encoder(state_dict):
|
||||
self.load_ipadapter_image_encoder(state_dict, file_path=file_path)
|
||||
elif self.is_ipadapter_xl(state_dict):
|
||||
self.load_ipadapter_xl(state_dict, file_path=file_path)
|
||||
elif self.is_ipadapter_xl_image_encoder(state_dict):
|
||||
self.load_ipadapter_xl_image_encoder(state_dict, file_path=file_path)
|
||||
elif self.is_hunyuan_dit_clip_text_encoder(state_dict):
|
||||
self.load_hunyuan_dit_clip_text_encoder(state_dict, file_path=file_path)
|
||||
elif self.is_hunyuan_dit_t5_text_encoder(state_dict):
|
||||
self.load_hunyuan_dit_t5_text_encoder(state_dict, file_path=file_path)
|
||||
elif self.is_hunyuan_dit(state_dict):
|
||||
self.load_hunyuan_dit(state_dict, file_path=file_path)
|
||||
elif self.is_diffusers_vae(state_dict):
|
||||
self.load_diffusers_vae(state_dict, file_path=file_path)
|
||||
elif self.is_ExVideo_StableVideoDiffusion(state_dict):
|
||||
self.load_ExVideo_StableVideoDiffusion(state_dict, file_path=file_path)
|
||||
|
||||
def load_models(self, file_path_list, lora_alphas=[]):
|
||||
for file_path in file_path_list:
|
||||
self.load_model(file_path, lora_alphas=lora_alphas)
|
||||
|
||||
def to(self, device):
|
||||
for component in self.model:
|
||||
if isinstance(self.model[component], list):
|
||||
for model in self.model[component]:
|
||||
model.to(device)
|
||||
else:
|
||||
self.model[component].to(device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_model_with_model_path(self, model_path):
|
||||
for component in self.model_path:
|
||||
if isinstance(self.model_path[component], str):
|
||||
if os.path.samefile(self.model_path[component], model_path):
|
||||
return self.model[component]
|
||||
elif isinstance(self.model_path[component], list):
|
||||
for i, model_path_ in enumerate(self.model_path[component]):
|
||||
if os.path.samefile(model_path_, model_path):
|
||||
return self.model[component][i]
|
||||
raise ValueError(f"Please load model {model_path} before you use it.")
|
||||
|
||||
def __getattr__(self, __name):
|
||||
if __name in self.model:
|
||||
return self.model[__name]
|
||||
else:
|
||||
return super.__getattribute__(__name)
|
||||
|
||||
|
||||
def load_state_dict(file_path, torch_dtype=None):
|
||||
if file_path.endswith(".safetensors"):
|
||||
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
|
||||
else:
|
||||
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
|
||||
|
||||
|
||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
|
||||
state_dict = {}
|
||||
with safe_open(file_path, framework="pt", device="cpu") as f:
|
||||
for k in f.keys():
|
||||
state_dict[k] = f.get_tensor(k)
|
||||
if torch_dtype is not None:
|
||||
state_dict[k] = state_dict[k].to(torch_dtype)
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_state_dict_from_bin(file_path, torch_dtype=None):
|
||||
state_dict = torch.load(file_path, map_location="cpu")
|
||||
if torch_dtype is not None:
|
||||
for i in state_dict:
|
||||
if isinstance(state_dict[i], torch.Tensor):
|
||||
state_dict[i] = state_dict[i].to(torch_dtype)
|
||||
return state_dict
|
||||
|
||||
|
||||
def search_parameter(param, state_dict):
|
||||
for name, param_ in state_dict.items():
|
||||
if param.numel() == param_.numel():
|
||||
if param.shape == param_.shape:
|
||||
if torch.dist(param, param_) < 1e-6:
|
||||
return name
|
||||
else:
|
||||
if torch.dist(param.flatten(), param_.flatten()) < 1e-6:
|
||||
return name
|
||||
return None
|
||||
|
||||
|
||||
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
|
||||
matched_keys = set()
|
||||
with torch.no_grad():
|
||||
for name in source_state_dict:
|
||||
rename = search_parameter(source_state_dict[name], target_state_dict)
|
||||
if rename is not None:
|
||||
print(f'"{name}": "{rename}",')
|
||||
matched_keys.add(rename)
|
||||
elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
|
||||
length = source_state_dict[name].shape[0] // 3
|
||||
rename = []
|
||||
for i in range(3):
|
||||
rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
|
||||
if None not in rename:
|
||||
print(f'"{name}": {rename},')
|
||||
for rename_ in rename:
|
||||
matched_keys.add(rename_)
|
||||
for name in target_state_dict:
|
||||
if name not in matched_keys:
|
||||
print("Cannot find", name, target_state_dict[name].shape)
|
||||
from .model_manager import *
|
||||
|
||||
408
diffsynth/models/cog_dit.py
Normal file
408
diffsynth/models/cog_dit.py
Normal file
@@ -0,0 +1,408 @@
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from .sd3_dit import TimestepEmbeddings
|
||||
from .attention import Attention
|
||||
from .utils import load_state_dict_from_folder
|
||||
from .tiler import TileWorker2Dto3D
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
||||
class CogPatchify(torch.nn.Module):
|
||||
def __init__(self, dim_in, dim_out, patch_size) -> None:
|
||||
super().__init__()
|
||||
self.proj = torch.nn.Conv3d(dim_in, dim_out, kernel_size=(1, patch_size, patch_size), stride=(1, patch_size, patch_size))
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.proj(hidden_states)
|
||||
hidden_states = rearrange(hidden_states, "B C T H W -> B (T H W) C")
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
class CogAdaLayerNorm(torch.nn.Module):
|
||||
def __init__(self, dim, dim_cond, single=False):
|
||||
super().__init__()
|
||||
self.single = single
|
||||
self.linear = torch.nn.Linear(dim_cond, dim * (2 if single else 6))
|
||||
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=True, eps=1e-5)
|
||||
|
||||
|
||||
def forward(self, hidden_states, prompt_emb, emb):
|
||||
emb = self.linear(torch.nn.functional.silu(emb))
|
||||
if self.single:
|
||||
shift, scale = emb.unsqueeze(1).chunk(2, dim=2)
|
||||
hidden_states = self.norm(hidden_states) * (1 + scale) + shift
|
||||
return hidden_states
|
||||
else:
|
||||
shift_a, scale_a, gate_a, shift_b, scale_b, gate_b = emb.unsqueeze(1).chunk(6, dim=2)
|
||||
hidden_states = self.norm(hidden_states) * (1 + scale_a) + shift_a
|
||||
prompt_emb = self.norm(prompt_emb) * (1 + scale_b) + shift_b
|
||||
return hidden_states, prompt_emb, gate_a, gate_b
|
||||
|
||||
|
||||
|
||||
class CogDiTBlock(torch.nn.Module):
|
||||
def __init__(self, dim, dim_cond, num_heads):
|
||||
super().__init__()
|
||||
self.norm1 = CogAdaLayerNorm(dim, dim_cond)
|
||||
self.attn1 = Attention(q_dim=dim, num_heads=48, head_dim=dim//num_heads, bias_q=True, bias_kv=True, bias_out=True)
|
||||
self.norm_q = torch.nn.LayerNorm((dim//num_heads,), eps=1e-06, elementwise_affine=True)
|
||||
self.norm_k = torch.nn.LayerNorm((dim//num_heads,), eps=1e-06, elementwise_affine=True)
|
||||
|
||||
self.norm2 = CogAdaLayerNorm(dim, dim_cond)
|
||||
self.ff = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim, dim*4),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(dim*4, dim)
|
||||
)
|
||||
|
||||
|
||||
def apply_rotary_emb(self, x, freqs_cis):
|
||||
cos, sin = freqs_cis # [S, D]
|
||||
cos = cos[None, None]
|
||||
sin = sin[None, None]
|
||||
cos, sin = cos.to(x.device), sin.to(x.device)
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||
return out
|
||||
|
||||
|
||||
def process_qkv(self, q, k, v, image_rotary_emb, text_seq_length):
|
||||
q = self.norm_q(q)
|
||||
k = self.norm_k(k)
|
||||
q[:, :, text_seq_length:] = self.apply_rotary_emb(q[:, :, text_seq_length:], image_rotary_emb)
|
||||
k[:, :, text_seq_length:] = self.apply_rotary_emb(k[:, :, text_seq_length:], image_rotary_emb)
|
||||
return q, k, v
|
||||
|
||||
|
||||
def forward(self, hidden_states, prompt_emb, time_emb, image_rotary_emb):
|
||||
# Attention
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_a, gate_b = self.norm1(
|
||||
hidden_states, prompt_emb, time_emb
|
||||
)
|
||||
attention_io = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
||||
attention_io = self.attn1(
|
||||
attention_io,
|
||||
qkv_preprocessor=lambda q, k, v: self.process_qkv(q, k, v, image_rotary_emb, prompt_emb.shape[1])
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + gate_a * attention_io[:, prompt_emb.shape[1]:]
|
||||
prompt_emb = prompt_emb + gate_b * attention_io[:, :prompt_emb.shape[1]]
|
||||
|
||||
# Feed forward
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_a, gate_b = self.norm2(
|
||||
hidden_states, prompt_emb, time_emb
|
||||
)
|
||||
ff_io = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
||||
ff_io = self.ff(ff_io)
|
||||
|
||||
hidden_states = hidden_states + gate_a * ff_io[:, prompt_emb.shape[1]:]
|
||||
prompt_emb = prompt_emb + gate_b * ff_io[:, :prompt_emb.shape[1]]
|
||||
|
||||
return hidden_states, prompt_emb
|
||||
|
||||
|
||||
|
||||
class CogDiT(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.patchify = CogPatchify(16, 3072, 2)
|
||||
self.time_embedder = TimestepEmbeddings(3072, 512)
|
||||
self.context_embedder = torch.nn.Linear(4096, 3072)
|
||||
self.blocks = torch.nn.ModuleList([CogDiTBlock(3072, 512, 48) for _ in range(42)])
|
||||
self.norm_final = torch.nn.LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
|
||||
self.norm_out = CogAdaLayerNorm(3072, 512, single=True)
|
||||
self.proj_out = torch.nn.Linear(3072, 64, bias=True)
|
||||
|
||||
|
||||
def get_resize_crop_region_for_grid(self, src, tgt_width, tgt_height):
|
||||
tw = tgt_width
|
||||
th = tgt_height
|
||||
h, w = src
|
||||
r = h / w
|
||||
if r > (th / tw):
|
||||
resize_height = th
|
||||
resize_width = int(round(th / h * w))
|
||||
else:
|
||||
resize_width = tw
|
||||
resize_height = int(round(tw / w * h))
|
||||
|
||||
crop_top = int(round((th - resize_height) / 2.0))
|
||||
crop_left = int(round((tw - resize_width) / 2.0))
|
||||
|
||||
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
||||
|
||||
|
||||
def get_3d_rotary_pos_embed(
|
||||
self, embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
|
||||
):
|
||||
start, stop = crops_coords
|
||||
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
|
||||
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
|
||||
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
|
||||
|
||||
# Compute dimensions for each axis
|
||||
dim_t = embed_dim // 4
|
||||
dim_h = embed_dim // 8 * 3
|
||||
dim_w = embed_dim // 8 * 3
|
||||
|
||||
# Temporal frequencies
|
||||
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
|
||||
grid_t = torch.from_numpy(grid_t).float()
|
||||
freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
|
||||
freqs_t = freqs_t.repeat_interleave(2, dim=-1)
|
||||
|
||||
# Spatial frequencies for height and width
|
||||
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
|
||||
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
|
||||
grid_h = torch.from_numpy(grid_h).float()
|
||||
grid_w = torch.from_numpy(grid_w).float()
|
||||
freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
|
||||
freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
|
||||
freqs_h = freqs_h.repeat_interleave(2, dim=-1)
|
||||
freqs_w = freqs_w.repeat_interleave(2, dim=-1)
|
||||
|
||||
# Broadcast and concatenate tensors along specified dimension
|
||||
def broadcast(tensors, dim=-1):
|
||||
num_tensors = len(tensors)
|
||||
shape_lens = {len(t.shape) for t in tensors}
|
||||
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
|
||||
shape_len = list(shape_lens)[0]
|
||||
dim = (dim + shape_len) if dim < 0 else dim
|
||||
dims = list(zip(*(list(t.shape) for t in tensors)))
|
||||
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
||||
assert all(
|
||||
[*(len(set(t[1])) <= 2 for t in expandable_dims)]
|
||||
), "invalid dimensions for broadcastable concatenation"
|
||||
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
|
||||
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
|
||||
expanded_dims.insert(dim, (dim, dims[dim]))
|
||||
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
|
||||
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
|
||||
return torch.cat(tensors, dim=dim)
|
||||
|
||||
freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
|
||||
|
||||
t, h, w, d = freqs.shape
|
||||
freqs = freqs.view(t * h * w, d)
|
||||
|
||||
# Generate sine and cosine components
|
||||
sin = freqs.sin()
|
||||
cos = freqs.cos()
|
||||
|
||||
if use_real:
|
||||
return cos, sin
|
||||
else:
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def prepare_rotary_positional_embeddings(
|
||||
self,
|
||||
height: int,
|
||||
width: int,
|
||||
num_frames: int,
|
||||
device: torch.device,
|
||||
):
|
||||
grid_height = height // 2
|
||||
grid_width = width // 2
|
||||
base_size_width = 720 // (8 * 2)
|
||||
base_size_height = 480 // (8 * 2)
|
||||
|
||||
grid_crops_coords = self.get_resize_crop_region_for_grid(
|
||||
(grid_height, grid_width), base_size_width, base_size_height
|
||||
)
|
||||
freqs_cos, freqs_sin = self.get_3d_rotary_pos_embed(
|
||||
embed_dim=64,
|
||||
crops_coords=grid_crops_coords,
|
||||
grid_size=(grid_height, grid_width),
|
||||
temporal_size=num_frames,
|
||||
use_real=True,
|
||||
)
|
||||
|
||||
freqs_cos = freqs_cos.to(device=device)
|
||||
freqs_sin = freqs_sin.to(device=device)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
def unpatchify(self, hidden_states, height, width):
|
||||
hidden_states = rearrange(hidden_states, "B (T H W) (C P Q) -> B C T (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def build_mask(self, T, H, W, dtype, device, is_bound):
|
||||
t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W)
|
||||
h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W)
|
||||
w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W)
|
||||
border_width = (H + W) // 4
|
||||
pad = torch.ones_like(h) * border_width
|
||||
mask = torch.stack([
|
||||
pad if is_bound[0] else t + 1,
|
||||
pad if is_bound[1] else T - t,
|
||||
pad if is_bound[2] else h + 1,
|
||||
pad if is_bound[3] else H - h,
|
||||
pad if is_bound[4] else w + 1,
|
||||
pad if is_bound[5] else W - w
|
||||
]).min(dim=0).values
|
||||
mask = mask.clip(1, border_width)
|
||||
mask = (mask / border_width).to(dtype=dtype, device=device)
|
||||
mask = rearrange(mask, "T H W -> 1 1 T H W")
|
||||
return mask
|
||||
|
||||
|
||||
def tiled_forward(self, hidden_states, timestep, prompt_emb, tile_size=(60, 90), tile_stride=(30, 45)):
|
||||
B, C, T, H, W = hidden_states.shape
|
||||
value = torch.zeros((B, C, T, H, W), dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
weight = torch.zeros((B, C, T, H, W), dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
|
||||
# Split tasks
|
||||
tasks = []
|
||||
for h in range(0, H, tile_stride):
|
||||
for w in range(0, W, tile_stride):
|
||||
if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
|
||||
continue
|
||||
h_, w_ = h + tile_size, w + tile_size
|
||||
if h_ > H: h, h_ = max(H - tile_size, 0), H
|
||||
if w_ > W: w, w_ = max(W - tile_size, 0), W
|
||||
tasks.append((h, h_, w, w_))
|
||||
|
||||
# Run
|
||||
for hl, hr, wl, wr in tasks:
|
||||
mask = self.build_mask(
|
||||
value.shape[2], (hr-hl), (wr-wl),
|
||||
hidden_states.dtype, hidden_states.device,
|
||||
is_bound=(True, True, hl==0, hr>=H, wl==0, wr>=W)
|
||||
)
|
||||
model_output = self.forward(hidden_states[:, :, :, hl:hr, wl:wr], timestep, prompt_emb)
|
||||
value[:, :, :, hl:hr, wl:wr] += model_output * mask
|
||||
weight[:, :, :, hl:hr, wl:wr] += mask
|
||||
value = value / weight
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def forward(self, hidden_states, timestep, prompt_emb, image_rotary_emb=None, tiled=False, tile_size=90, tile_stride=30, use_gradient_checkpointing=False):
|
||||
if tiled:
|
||||
return TileWorker2Dto3D().tiled_forward(
|
||||
forward_fn=lambda x: self.forward(x, timestep, prompt_emb),
|
||||
model_input=hidden_states,
|
||||
tile_size=tile_size, tile_stride=tile_stride,
|
||||
tile_device=hidden_states.device, tile_dtype=hidden_states.dtype,
|
||||
computation_device=self.context_embedder.weight.device, computation_dtype=self.context_embedder.weight.dtype
|
||||
)
|
||||
num_frames, height, width = hidden_states.shape[-3:]
|
||||
if image_rotary_emb is None:
|
||||
image_rotary_emb = self.prepare_rotary_positional_embeddings(height, width, num_frames, device=self.context_embedder.weight.device)
|
||||
hidden_states = self.patchify(hidden_states)
|
||||
time_emb = self.time_embedder(timestep, dtype=hidden_states.dtype)
|
||||
prompt_emb = self.context_embedder(prompt_emb)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
for block in self.blocks:
|
||||
if self.training and use_gradient_checkpointing:
|
||||
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, prompt_emb, time_emb, image_rotary_emb,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, time_emb, image_rotary_emb)
|
||||
|
||||
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
||||
hidden_states = self.norm_final(hidden_states)
|
||||
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
|
||||
hidden_states = self.norm_out(hidden_states, prompt_emb, time_emb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = self.unpatchify(hidden_states, height, width)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return CogDiTStateDictConverter()
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(file_path, torch_dtype=torch.bfloat16):
|
||||
model = CogDiT().to(torch_dtype)
|
||||
state_dict = load_state_dict_from_folder(file_path, torch_dtype=torch_dtype)
|
||||
state_dict = CogDiT.state_dict_converter().from_diffusers(state_dict)
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
class CogDiTStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
rename_dict = {
|
||||
"patch_embed.proj.weight": "patchify.proj.weight",
|
||||
"patch_embed.proj.bias": "patchify.proj.bias",
|
||||
"patch_embed.text_proj.weight": "context_embedder.weight",
|
||||
"patch_embed.text_proj.bias": "context_embedder.bias",
|
||||
"time_embedding.linear_1.weight": "time_embedder.timestep_embedder.0.weight",
|
||||
"time_embedding.linear_1.bias": "time_embedder.timestep_embedder.0.bias",
|
||||
"time_embedding.linear_2.weight": "time_embedder.timestep_embedder.2.weight",
|
||||
"time_embedding.linear_2.bias": "time_embedder.timestep_embedder.2.bias",
|
||||
|
||||
"norm_final.weight": "norm_final.weight",
|
||||
"norm_final.bias": "norm_final.bias",
|
||||
"norm_out.linear.weight": "norm_out.linear.weight",
|
||||
"norm_out.linear.bias": "norm_out.linear.bias",
|
||||
"norm_out.norm.weight": "norm_out.norm.weight",
|
||||
"norm_out.norm.bias": "norm_out.norm.bias",
|
||||
"proj_out.weight": "proj_out.weight",
|
||||
"proj_out.bias": "proj_out.bias",
|
||||
}
|
||||
suffix_dict = {
|
||||
"norm1.linear.weight": "norm1.linear.weight",
|
||||
"norm1.linear.bias": "norm1.linear.bias",
|
||||
"norm1.norm.weight": "norm1.norm.weight",
|
||||
"norm1.norm.bias": "norm1.norm.bias",
|
||||
"attn1.norm_q.weight": "norm_q.weight",
|
||||
"attn1.norm_q.bias": "norm_q.bias",
|
||||
"attn1.norm_k.weight": "norm_k.weight",
|
||||
"attn1.norm_k.bias": "norm_k.bias",
|
||||
"attn1.to_q.weight": "attn1.to_q.weight",
|
||||
"attn1.to_q.bias": "attn1.to_q.bias",
|
||||
"attn1.to_k.weight": "attn1.to_k.weight",
|
||||
"attn1.to_k.bias": "attn1.to_k.bias",
|
||||
"attn1.to_v.weight": "attn1.to_v.weight",
|
||||
"attn1.to_v.bias": "attn1.to_v.bias",
|
||||
"attn1.to_out.0.weight": "attn1.to_out.weight",
|
||||
"attn1.to_out.0.bias": "attn1.to_out.bias",
|
||||
"norm2.linear.weight": "norm2.linear.weight",
|
||||
"norm2.linear.bias": "norm2.linear.bias",
|
||||
"norm2.norm.weight": "norm2.norm.weight",
|
||||
"norm2.norm.bias": "norm2.norm.bias",
|
||||
"ff.net.0.proj.weight": "ff.0.weight",
|
||||
"ff.net.0.proj.bias": "ff.0.bias",
|
||||
"ff.net.2.weight": "ff.2.weight",
|
||||
"ff.net.2.bias": "ff.2.bias",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name in rename_dict:
|
||||
if name == "patch_embed.proj.weight":
|
||||
param = param.unsqueeze(2)
|
||||
state_dict_[rename_dict[name]] = param
|
||||
else:
|
||||
names = name.split(".")
|
||||
if names[0] == "transformer_blocks":
|
||||
suffix = ".".join(names[2:])
|
||||
state_dict_[f"blocks.{names[1]}." + suffix_dict[suffix]] = param
|
||||
return state_dict_
|
||||
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
518
diffsynth/models/cog_vae.py
Normal file
518
diffsynth/models/cog_vae.py
Normal file
@@ -0,0 +1,518 @@
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from .tiler import TileWorker2Dto3D
|
||||
|
||||
|
||||
|
||||
class Downsample3D(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int = 3,
|
||||
stride: int = 2,
|
||||
padding: int = 0,
|
||||
compress_time: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor:
|
||||
if self.compress_time:
|
||||
batch_size, channels, frames, height, width = x.shape
|
||||
|
||||
# (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
|
||||
x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
|
||||
|
||||
if x.shape[-1] % 2 == 1:
|
||||
x_first, x_rest = x[..., 0], x[..., 1:]
|
||||
if x_rest.shape[-1] > 0:
|
||||
# (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
|
||||
x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
|
||||
|
||||
x = torch.cat([x_first[..., None], x_rest], dim=-1)
|
||||
# (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
|
||||
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
|
||||
else:
|
||||
# (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
|
||||
x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
|
||||
# (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
|
||||
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
|
||||
|
||||
# Pad the tensor
|
||||
pad = (0, 1, 0, 1)
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
batch_size, channels, frames, height, width = x.shape
|
||||
# (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
|
||||
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
|
||||
x = self.conv(x)
|
||||
# (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
|
||||
x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class Upsample3D(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int = 3,
|
||||
stride: int = 1,
|
||||
padding: int = 1,
|
||||
compress_time: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, inputs: torch.Tensor, xq: torch.Tensor) -> torch.Tensor:
|
||||
if self.compress_time:
|
||||
if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
|
||||
# split first frame
|
||||
x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
|
||||
|
||||
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0)
|
||||
x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0)
|
||||
x_first = x_first[:, :, None, :, :]
|
||||
inputs = torch.cat([x_first, x_rest], dim=2)
|
||||
elif inputs.shape[2] > 1:
|
||||
inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
|
||||
else:
|
||||
inputs = inputs.squeeze(2)
|
||||
inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
|
||||
inputs = inputs[:, :, None, :, :]
|
||||
else:
|
||||
# only interpolate 2D
|
||||
b, c, t, h, w = inputs.shape
|
||||
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||
inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
|
||||
inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
|
||||
|
||||
b, c, t, h, w = inputs.shape
|
||||
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||
inputs = self.conv(inputs)
|
||||
inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
|
||||
class CogVideoXSpatialNorm3D(torch.nn.Module):
|
||||
def __init__(self, f_channels, zq_channels, groups):
|
||||
super().__init__()
|
||||
self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
|
||||
self.conv_y = torch.nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
||||
self.conv_b = torch.nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
||||
|
||||
|
||||
def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
|
||||
if f.shape[2] > 1 and f.shape[2] % 2 == 1:
|
||||
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
|
||||
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
|
||||
z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
|
||||
z_first = torch.nn.functional.interpolate(z_first, size=f_first_size)
|
||||
z_rest = torch.nn.functional.interpolate(z_rest, size=f_rest_size)
|
||||
zq = torch.cat([z_first, z_rest], dim=2)
|
||||
else:
|
||||
zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:])
|
||||
|
||||
norm_f = self.norm_layer(f)
|
||||
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
|
||||
return new_f
|
||||
|
||||
|
||||
|
||||
class Resnet3DBlock(torch.nn.Module):
|
||||
def __init__(self, in_channels, out_channels, spatial_norm_dim, groups, eps=1e-6, use_conv_shortcut=False):
|
||||
super().__init__()
|
||||
self.nonlinearity = torch.nn.SiLU()
|
||||
if spatial_norm_dim is None:
|
||||
self.norm1 = torch.nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
|
||||
self.norm2 = torch.nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
|
||||
else:
|
||||
self.norm1 = CogVideoXSpatialNorm3D(in_channels, spatial_norm_dim, groups)
|
||||
self.norm2 = CogVideoXSpatialNorm3D(out_channels, spatial_norm_dim, groups)
|
||||
|
||||
self.conv1 = CachedConv3d(in_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
|
||||
|
||||
self.conv2 = CachedConv3d(out_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
|
||||
|
||||
if in_channels != out_channels:
|
||||
if use_conv_shortcut:
|
||||
self.conv_shortcut = CachedConv3d(in_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
|
||||
else:
|
||||
self.conv_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1)
|
||||
else:
|
||||
self.conv_shortcut = lambda x: x
|
||||
|
||||
|
||||
def forward(self, hidden_states, zq):
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm1(hidden_states, zq) if isinstance(self.norm1, CogVideoXSpatialNorm3D) else self.norm1(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
|
||||
hidden_states = self.norm2(hidden_states, zq) if isinstance(self.norm2, CogVideoXSpatialNorm3D) else self.norm2(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
hidden_states = hidden_states + self.conv_shortcut(residual)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
class CachedConv3d(torch.nn.Conv3d):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
|
||||
super().__init__(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
self.cached_tensor = None
|
||||
|
||||
|
||||
def clear_cache(self):
|
||||
self.cached_tensor = None
|
||||
|
||||
|
||||
def forward(self, input: torch.Tensor, use_cache = True) -> torch.Tensor:
|
||||
if use_cache:
|
||||
if self.cached_tensor is None:
|
||||
self.cached_tensor = torch.concat([input[:, :, :1]] * 2, dim=2)
|
||||
input = torch.concat([self.cached_tensor, input], dim=2)
|
||||
self.cached_tensor = input[:, :, -2:]
|
||||
return super().forward(input)
|
||||
|
||||
|
||||
|
||||
class CogVAEDecoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.7
|
||||
self.conv_in = CachedConv3d(16, 512, kernel_size=3, stride=1, padding=(0, 1, 1))
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
Resnet3DBlock(512, 512, 16, 32),
|
||||
Resnet3DBlock(512, 512, 16, 32),
|
||||
Resnet3DBlock(512, 512, 16, 32),
|
||||
Resnet3DBlock(512, 512, 16, 32),
|
||||
Resnet3DBlock(512, 512, 16, 32),
|
||||
Resnet3DBlock(512, 512, 16, 32),
|
||||
Upsample3D(512, 512, compress_time=True),
|
||||
Resnet3DBlock(512, 256, 16, 32),
|
||||
Resnet3DBlock(256, 256, 16, 32),
|
||||
Resnet3DBlock(256, 256, 16, 32),
|
||||
Resnet3DBlock(256, 256, 16, 32),
|
||||
Upsample3D(256, 256, compress_time=True),
|
||||
Resnet3DBlock(256, 256, 16, 32),
|
||||
Resnet3DBlock(256, 256, 16, 32),
|
||||
Resnet3DBlock(256, 256, 16, 32),
|
||||
Resnet3DBlock(256, 256, 16, 32),
|
||||
Upsample3D(256, 256, compress_time=False),
|
||||
Resnet3DBlock(256, 128, 16, 32),
|
||||
Resnet3DBlock(128, 128, 16, 32),
|
||||
Resnet3DBlock(128, 128, 16, 32),
|
||||
Resnet3DBlock(128, 128, 16, 32),
|
||||
])
|
||||
|
||||
self.norm_out = CogVideoXSpatialNorm3D(128, 16, 32)
|
||||
self.conv_act = torch.nn.SiLU()
|
||||
self.conv_out = CachedConv3d(128, 3, kernel_size=3, stride=1, padding=(0, 1, 1))
|
||||
|
||||
|
||||
def forward(self, sample):
|
||||
sample = sample / self.scaling_factor
|
||||
hidden_states = self.conv_in(sample)
|
||||
|
||||
for block in self.blocks:
|
||||
hidden_states = block(hidden_states, sample)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, sample)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def decode_video(self, sample, tiled=True, tile_size=(60, 90), tile_stride=(30, 45), progress_bar=lambda x:x):
|
||||
if tiled:
|
||||
B, C, T, H, W = sample.shape
|
||||
return TileWorker2Dto3D().tiled_forward(
|
||||
forward_fn=lambda x: self.decode_small_video(x),
|
||||
model_input=sample,
|
||||
tile_size=tile_size, tile_stride=tile_stride,
|
||||
tile_device=sample.device, tile_dtype=sample.dtype,
|
||||
computation_device=sample.device, computation_dtype=sample.dtype,
|
||||
scales=(3/16, (T//2*8+T%2)/T, 8, 8),
|
||||
progress_bar=progress_bar
|
||||
)
|
||||
else:
|
||||
return self.decode_small_video(sample)
|
||||
|
||||
|
||||
def decode_small_video(self, sample):
|
||||
B, C, T, H, W = sample.shape
|
||||
computation_device = self.conv_in.weight.device
|
||||
computation_dtype = self.conv_in.weight.dtype
|
||||
value = []
|
||||
for i in range(T//2):
|
||||
tl = i*2 + T%2 - (T%2 and i==0)
|
||||
tr = i*2 + 2 + T%2
|
||||
model_input = sample[:, :, tl: tr, :, :].to(dtype=computation_dtype, device=computation_device)
|
||||
model_output = self.forward(model_input).to(dtype=sample.dtype, device=sample.device)
|
||||
value.append(model_output)
|
||||
value = torch.concat(value, dim=2)
|
||||
for name, module in self.named_modules():
|
||||
if isinstance(module, CachedConv3d):
|
||||
module.clear_cache()
|
||||
return value
|
||||
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return CogVAEDecoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class CogVAEEncoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.7
|
||||
self.conv_in = CachedConv3d(3, 128, kernel_size=3, stride=1, padding=(0, 1, 1))
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
Resnet3DBlock(128, 128, None, 32),
|
||||
Resnet3DBlock(128, 128, None, 32),
|
||||
Resnet3DBlock(128, 128, None, 32),
|
||||
Downsample3D(128, 128, compress_time=True),
|
||||
Resnet3DBlock(128, 256, None, 32),
|
||||
Resnet3DBlock(256, 256, None, 32),
|
||||
Resnet3DBlock(256, 256, None, 32),
|
||||
Downsample3D(256, 256, compress_time=True),
|
||||
Resnet3DBlock(256, 256, None, 32),
|
||||
Resnet3DBlock(256, 256, None, 32),
|
||||
Resnet3DBlock(256, 256, None, 32),
|
||||
Downsample3D(256, 256, compress_time=False),
|
||||
Resnet3DBlock(256, 512, None, 32),
|
||||
Resnet3DBlock(512, 512, None, 32),
|
||||
Resnet3DBlock(512, 512, None, 32),
|
||||
Resnet3DBlock(512, 512, None, 32),
|
||||
Resnet3DBlock(512, 512, None, 32),
|
||||
])
|
||||
|
||||
self.norm_out = torch.nn.GroupNorm(32, 512, eps=1e-06, affine=True)
|
||||
self.conv_act = torch.nn.SiLU()
|
||||
self.conv_out = CachedConv3d(512, 32, kernel_size=3, stride=1, padding=(0, 1, 1))
|
||||
|
||||
|
||||
def forward(self, sample):
|
||||
hidden_states = self.conv_in(sample)
|
||||
|
||||
for block in self.blocks:
|
||||
hidden_states = block(hidden_states, sample)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)[:, :16]
|
||||
hidden_states = hidden_states * self.scaling_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def encode_video(self, sample, tiled=True, tile_size=(60, 90), tile_stride=(30, 45), progress_bar=lambda x:x):
|
||||
if tiled:
|
||||
B, C, T, H, W = sample.shape
|
||||
return TileWorker2Dto3D().tiled_forward(
|
||||
forward_fn=lambda x: self.encode_small_video(x),
|
||||
model_input=sample,
|
||||
tile_size=(i * 8 for i in tile_size), tile_stride=(i * 8 for i in tile_stride),
|
||||
tile_device=sample.device, tile_dtype=sample.dtype,
|
||||
computation_device=sample.device, computation_dtype=sample.dtype,
|
||||
scales=(16/3, (T//4+T%2)/T, 1/8, 1/8),
|
||||
progress_bar=progress_bar
|
||||
)
|
||||
else:
|
||||
return self.encode_small_video(sample)
|
||||
|
||||
|
||||
def encode_small_video(self, sample):
|
||||
B, C, T, H, W = sample.shape
|
||||
computation_device = self.conv_in.weight.device
|
||||
computation_dtype = self.conv_in.weight.dtype
|
||||
value = []
|
||||
for i in range(T//8):
|
||||
t = i*8 + T%2 - (T%2 and i==0)
|
||||
t_ = i*8 + 8 + T%2
|
||||
model_input = sample[:, :, t: t_, :, :].to(dtype=computation_dtype, device=computation_device)
|
||||
model_output = self.forward(model_input).to(dtype=sample.dtype, device=sample.device)
|
||||
value.append(model_output)
|
||||
value = torch.concat(value, dim=2)
|
||||
for name, module in self.named_modules():
|
||||
if isinstance(module, CachedConv3d):
|
||||
module.clear_cache()
|
||||
return value
|
||||
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return CogVAEEncoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class CogVAEEncoderStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
rename_dict = {
|
||||
"encoder.conv_in.conv.weight": "conv_in.weight",
|
||||
"encoder.conv_in.conv.bias": "conv_in.bias",
|
||||
"encoder.down_blocks.0.downsamplers.0.conv.weight": "blocks.3.conv.weight",
|
||||
"encoder.down_blocks.0.downsamplers.0.conv.bias": "blocks.3.conv.bias",
|
||||
"encoder.down_blocks.1.downsamplers.0.conv.weight": "blocks.7.conv.weight",
|
||||
"encoder.down_blocks.1.downsamplers.0.conv.bias": "blocks.7.conv.bias",
|
||||
"encoder.down_blocks.2.downsamplers.0.conv.weight": "blocks.11.conv.weight",
|
||||
"encoder.down_blocks.2.downsamplers.0.conv.bias": "blocks.11.conv.bias",
|
||||
"encoder.norm_out.weight": "norm_out.weight",
|
||||
"encoder.norm_out.bias": "norm_out.bias",
|
||||
"encoder.conv_out.conv.weight": "conv_out.weight",
|
||||
"encoder.conv_out.conv.bias": "conv_out.bias",
|
||||
}
|
||||
prefix_dict = {
|
||||
"encoder.down_blocks.0.resnets.0.": "blocks.0.",
|
||||
"encoder.down_blocks.0.resnets.1.": "blocks.1.",
|
||||
"encoder.down_blocks.0.resnets.2.": "blocks.2.",
|
||||
"encoder.down_blocks.1.resnets.0.": "blocks.4.",
|
||||
"encoder.down_blocks.1.resnets.1.": "blocks.5.",
|
||||
"encoder.down_blocks.1.resnets.2.": "blocks.6.",
|
||||
"encoder.down_blocks.2.resnets.0.": "blocks.8.",
|
||||
"encoder.down_blocks.2.resnets.1.": "blocks.9.",
|
||||
"encoder.down_blocks.2.resnets.2.": "blocks.10.",
|
||||
"encoder.down_blocks.3.resnets.0.": "blocks.12.",
|
||||
"encoder.down_blocks.3.resnets.1.": "blocks.13.",
|
||||
"encoder.down_blocks.3.resnets.2.": "blocks.14.",
|
||||
"encoder.mid_block.resnets.0.": "blocks.15.",
|
||||
"encoder.mid_block.resnets.1.": "blocks.16.",
|
||||
}
|
||||
suffix_dict = {
|
||||
"norm1.norm_layer.weight": "norm1.norm_layer.weight",
|
||||
"norm1.norm_layer.bias": "norm1.norm_layer.bias",
|
||||
"norm1.conv_y.conv.weight": "norm1.conv_y.weight",
|
||||
"norm1.conv_y.conv.bias": "norm1.conv_y.bias",
|
||||
"norm1.conv_b.conv.weight": "norm1.conv_b.weight",
|
||||
"norm1.conv_b.conv.bias": "norm1.conv_b.bias",
|
||||
"norm2.norm_layer.weight": "norm2.norm_layer.weight",
|
||||
"norm2.norm_layer.bias": "norm2.norm_layer.bias",
|
||||
"norm2.conv_y.conv.weight": "norm2.conv_y.weight",
|
||||
"norm2.conv_y.conv.bias": "norm2.conv_y.bias",
|
||||
"norm2.conv_b.conv.weight": "norm2.conv_b.weight",
|
||||
"norm2.conv_b.conv.bias": "norm2.conv_b.bias",
|
||||
"conv1.conv.weight": "conv1.weight",
|
||||
"conv1.conv.bias": "conv1.bias",
|
||||
"conv2.conv.weight": "conv2.weight",
|
||||
"conv2.conv.bias": "conv2.bias",
|
||||
"conv_shortcut.weight": "conv_shortcut.weight",
|
||||
"conv_shortcut.bias": "conv_shortcut.bias",
|
||||
"norm1.weight": "norm1.weight",
|
||||
"norm1.bias": "norm1.bias",
|
||||
"norm2.weight": "norm2.weight",
|
||||
"norm2.bias": "norm2.bias",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name in rename_dict:
|
||||
state_dict_[rename_dict[name]] = param
|
||||
else:
|
||||
for prefix in prefix_dict:
|
||||
if name.startswith(prefix):
|
||||
suffix = name[len(prefix):]
|
||||
state_dict_[prefix_dict[prefix] + suffix_dict[suffix]] = param
|
||||
return state_dict_
|
||||
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
|
||||
|
||||
|
||||
class CogVAEDecoderStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
rename_dict = {
|
||||
"decoder.conv_in.conv.weight": "conv_in.weight",
|
||||
"decoder.conv_in.conv.bias": "conv_in.bias",
|
||||
"decoder.up_blocks.0.upsamplers.0.conv.weight": "blocks.6.conv.weight",
|
||||
"decoder.up_blocks.0.upsamplers.0.conv.bias": "blocks.6.conv.bias",
|
||||
"decoder.up_blocks.1.upsamplers.0.conv.weight": "blocks.11.conv.weight",
|
||||
"decoder.up_blocks.1.upsamplers.0.conv.bias": "blocks.11.conv.bias",
|
||||
"decoder.up_blocks.2.upsamplers.0.conv.weight": "blocks.16.conv.weight",
|
||||
"decoder.up_blocks.2.upsamplers.0.conv.bias": "blocks.16.conv.bias",
|
||||
"decoder.norm_out.norm_layer.weight": "norm_out.norm_layer.weight",
|
||||
"decoder.norm_out.norm_layer.bias": "norm_out.norm_layer.bias",
|
||||
"decoder.norm_out.conv_y.conv.weight": "norm_out.conv_y.weight",
|
||||
"decoder.norm_out.conv_y.conv.bias": "norm_out.conv_y.bias",
|
||||
"decoder.norm_out.conv_b.conv.weight": "norm_out.conv_b.weight",
|
||||
"decoder.norm_out.conv_b.conv.bias": "norm_out.conv_b.bias",
|
||||
"decoder.conv_out.conv.weight": "conv_out.weight",
|
||||
"decoder.conv_out.conv.bias": "conv_out.bias"
|
||||
}
|
||||
prefix_dict = {
|
||||
"decoder.mid_block.resnets.0.": "blocks.0.",
|
||||
"decoder.mid_block.resnets.1.": "blocks.1.",
|
||||
"decoder.up_blocks.0.resnets.0.": "blocks.2.",
|
||||
"decoder.up_blocks.0.resnets.1.": "blocks.3.",
|
||||
"decoder.up_blocks.0.resnets.2.": "blocks.4.",
|
||||
"decoder.up_blocks.0.resnets.3.": "blocks.5.",
|
||||
"decoder.up_blocks.1.resnets.0.": "blocks.7.",
|
||||
"decoder.up_blocks.1.resnets.1.": "blocks.8.",
|
||||
"decoder.up_blocks.1.resnets.2.": "blocks.9.",
|
||||
"decoder.up_blocks.1.resnets.3.": "blocks.10.",
|
||||
"decoder.up_blocks.2.resnets.0.": "blocks.12.",
|
||||
"decoder.up_blocks.2.resnets.1.": "blocks.13.",
|
||||
"decoder.up_blocks.2.resnets.2.": "blocks.14.",
|
||||
"decoder.up_blocks.2.resnets.3.": "blocks.15.",
|
||||
"decoder.up_blocks.3.resnets.0.": "blocks.17.",
|
||||
"decoder.up_blocks.3.resnets.1.": "blocks.18.",
|
||||
"decoder.up_blocks.3.resnets.2.": "blocks.19.",
|
||||
"decoder.up_blocks.3.resnets.3.": "blocks.20.",
|
||||
}
|
||||
suffix_dict = {
|
||||
"norm1.norm_layer.weight": "norm1.norm_layer.weight",
|
||||
"norm1.norm_layer.bias": "norm1.norm_layer.bias",
|
||||
"norm1.conv_y.conv.weight": "norm1.conv_y.weight",
|
||||
"norm1.conv_y.conv.bias": "norm1.conv_y.bias",
|
||||
"norm1.conv_b.conv.weight": "norm1.conv_b.weight",
|
||||
"norm1.conv_b.conv.bias": "norm1.conv_b.bias",
|
||||
"norm2.norm_layer.weight": "norm2.norm_layer.weight",
|
||||
"norm2.norm_layer.bias": "norm2.norm_layer.bias",
|
||||
"norm2.conv_y.conv.weight": "norm2.conv_y.weight",
|
||||
"norm2.conv_y.conv.bias": "norm2.conv_y.bias",
|
||||
"norm2.conv_b.conv.weight": "norm2.conv_b.weight",
|
||||
"norm2.conv_b.conv.bias": "norm2.conv_b.bias",
|
||||
"conv1.conv.weight": "conv1.weight",
|
||||
"conv1.conv.bias": "conv1.bias",
|
||||
"conv2.conv.weight": "conv2.weight",
|
||||
"conv2.conv.bias": "conv2.bias",
|
||||
"conv_shortcut.weight": "conv_shortcut.weight",
|
||||
"conv_shortcut.bias": "conv_shortcut.bias",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name in rename_dict:
|
||||
state_dict_[rename_dict[name]] = param
|
||||
else:
|
||||
for prefix in prefix_dict:
|
||||
if name.startswith(prefix):
|
||||
suffix = name[len(prefix):]
|
||||
state_dict_[prefix_dict[prefix] + suffix_dict[suffix]] = param
|
||||
return state_dict_
|
||||
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
|
||||
111
diffsynth/models/downloader.py
Normal file
111
diffsynth/models/downloader.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from huggingface_hub import hf_hub_download
|
||||
from modelscope import snapshot_download
|
||||
import os, shutil
|
||||
from typing_extensions import Literal, TypeAlias
|
||||
from typing import List
|
||||
from ..configs.model_config import preset_models_on_huggingface, preset_models_on_modelscope, Preset_model_id
|
||||
|
||||
|
||||
def download_from_modelscope(model_id, origin_file_path, local_dir):
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
file_name = os.path.basename(origin_file_path)
|
||||
if file_name in os.listdir(local_dir):
|
||||
print(f" {file_name} has been already in {local_dir}.")
|
||||
else:
|
||||
print(f" Start downloading {os.path.join(local_dir, file_name)}")
|
||||
snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
|
||||
downloaded_file_path = os.path.join(local_dir, origin_file_path)
|
||||
target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
|
||||
if downloaded_file_path != target_file_path:
|
||||
shutil.move(downloaded_file_path, target_file_path)
|
||||
shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
|
||||
|
||||
|
||||
def download_from_huggingface(model_id, origin_file_path, local_dir):
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
file_name = os.path.basename(origin_file_path)
|
||||
if file_name in os.listdir(local_dir):
|
||||
print(f" {file_name} has been already in {local_dir}.")
|
||||
else:
|
||||
print(f" Start downloading {os.path.join(local_dir, file_name)}")
|
||||
hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
|
||||
downloaded_file_path = os.path.join(local_dir, origin_file_path)
|
||||
target_file_path = os.path.join(local_dir, file_name)
|
||||
if downloaded_file_path != target_file_path:
|
||||
shutil.move(downloaded_file_path, target_file_path)
|
||||
shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
|
||||
|
||||
|
||||
Preset_model_website: TypeAlias = Literal[
|
||||
"HuggingFace",
|
||||
"ModelScope",
|
||||
]
|
||||
website_to_preset_models = {
|
||||
"HuggingFace": preset_models_on_huggingface,
|
||||
"ModelScope": preset_models_on_modelscope,
|
||||
}
|
||||
website_to_download_fn = {
|
||||
"HuggingFace": download_from_huggingface,
|
||||
"ModelScope": download_from_modelscope,
|
||||
}
|
||||
|
||||
|
||||
def download_customized_models(
|
||||
model_id,
|
||||
origin_file_path,
|
||||
local_dir,
|
||||
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
|
||||
):
|
||||
downloaded_files = []
|
||||
for website in downloading_priority:
|
||||
# Check if the file is downloaded.
|
||||
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
|
||||
if file_to_download in downloaded_files:
|
||||
continue
|
||||
# Download
|
||||
website_to_download_fn[website](model_id, origin_file_path, local_dir)
|
||||
if os.path.basename(origin_file_path) in os.listdir(local_dir):
|
||||
downloaded_files.append(file_to_download)
|
||||
return downloaded_files
|
||||
|
||||
|
||||
def download_models(
|
||||
model_id_list: List[Preset_model_id] = [],
|
||||
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
|
||||
):
|
||||
print(f"Downloading models: {model_id_list}")
|
||||
downloaded_files = []
|
||||
load_files = []
|
||||
|
||||
for model_id in model_id_list:
|
||||
for website in downloading_priority:
|
||||
if model_id in website_to_preset_models[website]:
|
||||
|
||||
# Parse model metadata
|
||||
model_metadata = website_to_preset_models[website][model_id]
|
||||
if isinstance(model_metadata, list):
|
||||
file_data = model_metadata
|
||||
else:
|
||||
file_data = model_metadata.get("file_list", [])
|
||||
|
||||
# Try downloading the model from this website.
|
||||
model_files = []
|
||||
for model_id, origin_file_path, local_dir in file_data:
|
||||
# Check if the file is downloaded.
|
||||
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
|
||||
if file_to_download in downloaded_files:
|
||||
continue
|
||||
# Download
|
||||
website_to_download_fn[website](model_id, origin_file_path, local_dir)
|
||||
if os.path.basename(origin_file_path) in os.listdir(local_dir):
|
||||
downloaded_files.append(file_to_download)
|
||||
model_files.append(file_to_download)
|
||||
|
||||
# If the model is successfully downloaded, break.
|
||||
if len(model_files) > 0:
|
||||
if isinstance(model_metadata, dict) and "load_path" in model_metadata:
|
||||
model_files = model_metadata["load_path"]
|
||||
load_files.extend(model_files)
|
||||
break
|
||||
|
||||
return load_files
|
||||
329
diffsynth/models/flux_controlnet.py
Normal file
329
diffsynth/models/flux_controlnet.py
Normal file
@@ -0,0 +1,329 @@
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock, RMSNorm
|
||||
from .utils import hash_state_dict_keys, init_weights_on_device
|
||||
|
||||
|
||||
|
||||
class FluxControlNet(torch.nn.Module):
|
||||
def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0):
|
||||
super().__init__()
|
||||
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
|
||||
self.time_embedder = TimestepEmbeddings(256, 3072)
|
||||
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
|
||||
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
|
||||
self.context_embedder = torch.nn.Linear(4096, 3072)
|
||||
self.x_embedder = torch.nn.Linear(64, 3072)
|
||||
|
||||
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_joint_blocks)])
|
||||
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(num_single_blocks)])
|
||||
|
||||
self.controlnet_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_joint_blocks)])
|
||||
self.controlnet_single_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_single_blocks)])
|
||||
|
||||
self.mode_dict = mode_dict
|
||||
self.controlnet_mode_embedder = torch.nn.Embedding(num_mode, 3072) if len(mode_dict) > 0 else None
|
||||
self.controlnet_x_embedder = torch.nn.Linear(64 + additional_input_dim, 3072)
|
||||
|
||||
|
||||
def prepare_image_ids(self, latents):
|
||||
batch_size, _, height, width = latents.shape
|
||||
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
|
||||
latent_image_ids = latent_image_ids.reshape(
|
||||
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
return latent_image_ids
|
||||
|
||||
|
||||
def patchify(self, hidden_states):
|
||||
hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def align_res_stack_to_original_blocks(self, res_stack, num_blocks, hidden_states):
|
||||
if len(res_stack) == 0:
|
||||
return [torch.zeros_like(hidden_states)] * num_blocks
|
||||
interval = (num_blocks + len(res_stack) - 1) // len(res_stack)
|
||||
aligned_res_stack = [res_stack[block_id // interval] for block_id in range(num_blocks)]
|
||||
return aligned_res_stack
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
controlnet_conditioning,
|
||||
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
|
||||
processor_id=None,
|
||||
tiled=False, tile_size=128, tile_stride=64,
|
||||
**kwargs
|
||||
):
|
||||
if image_ids is None:
|
||||
image_ids = self.prepare_image_ids(hidden_states)
|
||||
|
||||
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
|
||||
if self.guidance_embedder is not None:
|
||||
guidance = guidance * 1000
|
||||
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
|
||||
prompt_emb = self.context_embedder(prompt_emb)
|
||||
if self.controlnet_mode_embedder is not None: # Different from FluxDiT
|
||||
processor_id = torch.tensor([self.mode_dict[processor_id]], dtype=torch.int)
|
||||
processor_id = repeat(processor_id, "D -> B D", B=1).to(text_ids.device)
|
||||
prompt_emb = torch.concat([self.controlnet_mode_embedder(processor_id), prompt_emb], dim=1)
|
||||
text_ids = torch.cat([text_ids[:, :1], text_ids], dim=1)
|
||||
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||
|
||||
hidden_states = self.patchify(hidden_states)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
controlnet_conditioning = self.patchify(controlnet_conditioning) # Different from FluxDiT
|
||||
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_conditioning) # Different from FluxDiT
|
||||
|
||||
controlnet_res_stack = []
|
||||
for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
|
||||
controlnet_res_stack.append(controlnet_block(hidden_states))
|
||||
|
||||
controlnet_single_res_stack = []
|
||||
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
||||
for block, controlnet_block in zip(self.single_blocks, self.controlnet_single_blocks):
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
|
||||
controlnet_single_res_stack.append(controlnet_block(hidden_states[:, prompt_emb.shape[1]:]))
|
||||
|
||||
controlnet_res_stack = self.align_res_stack_to_original_blocks(controlnet_res_stack, 19, hidden_states[:, prompt_emb.shape[1]:])
|
||||
controlnet_single_res_stack = self.align_res_stack_to_original_blocks(controlnet_single_res_stack, 38, hidden_states[:, prompt_emb.shape[1]:])
|
||||
|
||||
return controlnet_res_stack, controlnet_single_res_stack
|
||||
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxControlNetStateDictConverter()
|
||||
|
||||
def quantize(self):
|
||||
def cast_to(weight, dtype=None, device=None, copy=False):
|
||||
if device is None or weight.device == device:
|
||||
if not copy:
|
||||
if dtype is None or weight.dtype == dtype:
|
||||
return weight
|
||||
return weight.to(dtype=dtype, copy=copy)
|
||||
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight)
|
||||
return r
|
||||
|
||||
def cast_weight(s, input=None, dtype=None, device=None):
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
if device is None:
|
||||
device = input.device
|
||||
weight = cast_to(s.weight, dtype, device)
|
||||
return weight
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
if bias_dtype is None:
|
||||
bias_dtype = dtype
|
||||
if device is None:
|
||||
device = input.device
|
||||
bias = None
|
||||
weight = cast_to(s.weight, dtype, device)
|
||||
bias = cast_to(s.bias, bias_dtype, device)
|
||||
return weight, bias
|
||||
|
||||
class quantized_layer:
|
||||
class QLinear(torch.nn.Linear):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self,input,**kwargs):
|
||||
weight,bias= cast_bias_weight(self,input)
|
||||
return torch.nn.functional.linear(input,weight,bias)
|
||||
|
||||
class QRMSNorm(torch.nn.Module):
|
||||
def __init__(self, module):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
|
||||
def forward(self,hidden_states,**kwargs):
|
||||
weight= cast_weight(self.module,hidden_states)
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
|
||||
hidden_states = hidden_states.to(input_dtype) * weight
|
||||
return hidden_states
|
||||
|
||||
class QEmbedding(torch.nn.Embedding):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self,input,**kwargs):
|
||||
weight= cast_weight(self,input)
|
||||
return torch.nn.functional.embedding(
|
||||
input, weight, self.padding_idx, self.max_norm,
|
||||
self.norm_type, self.scale_grad_by_freq, self.sparse)
|
||||
|
||||
def replace_layer(model):
|
||||
for name, module in model.named_children():
|
||||
if isinstance(module,quantized_layer.QRMSNorm):
|
||||
continue
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
with init_weights_on_device():
|
||||
new_layer = quantized_layer.QLinear(module.in_features,module.out_features)
|
||||
new_layer.weight = module.weight
|
||||
if module.bias is not None:
|
||||
new_layer.bias = module.bias
|
||||
setattr(model, name, new_layer)
|
||||
elif isinstance(module, RMSNorm):
|
||||
if hasattr(module,"quantized"):
|
||||
continue
|
||||
module.quantized= True
|
||||
new_layer = quantized_layer.QRMSNorm(module)
|
||||
setattr(model, name, new_layer)
|
||||
elif isinstance(module,torch.nn.Embedding):
|
||||
rows, cols = module.weight.shape
|
||||
new_layer = quantized_layer.QEmbedding(
|
||||
num_embeddings=rows,
|
||||
embedding_dim=cols,
|
||||
_weight=module.weight,
|
||||
# _freeze=module.freeze,
|
||||
padding_idx=module.padding_idx,
|
||||
max_norm=module.max_norm,
|
||||
norm_type=module.norm_type,
|
||||
scale_grad_by_freq=module.scale_grad_by_freq,
|
||||
sparse=module.sparse)
|
||||
setattr(model, name, new_layer)
|
||||
else:
|
||||
replace_layer(module)
|
||||
|
||||
replace_layer(self)
|
||||
|
||||
|
||||
|
||||
class FluxControlNetStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
hash_value = hash_state_dict_keys(state_dict)
|
||||
global_rename_dict = {
|
||||
"context_embedder": "context_embedder",
|
||||
"x_embedder": "x_embedder",
|
||||
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
|
||||
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
|
||||
"time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
|
||||
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
|
||||
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
|
||||
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
|
||||
"norm_out.linear": "final_norm_out.linear",
|
||||
"proj_out": "final_proj_out",
|
||||
}
|
||||
rename_dict = {
|
||||
"proj_out": "proj_out",
|
||||
"norm1.linear": "norm1_a.linear",
|
||||
"norm1_context.linear": "norm1_b.linear",
|
||||
"attn.to_q": "attn.a_to_q",
|
||||
"attn.to_k": "attn.a_to_k",
|
||||
"attn.to_v": "attn.a_to_v",
|
||||
"attn.to_out.0": "attn.a_to_out",
|
||||
"attn.add_q_proj": "attn.b_to_q",
|
||||
"attn.add_k_proj": "attn.b_to_k",
|
||||
"attn.add_v_proj": "attn.b_to_v",
|
||||
"attn.to_add_out": "attn.b_to_out",
|
||||
"ff.net.0.proj": "ff_a.0",
|
||||
"ff.net.2": "ff_a.2",
|
||||
"ff_context.net.0.proj": "ff_b.0",
|
||||
"ff_context.net.2": "ff_b.2",
|
||||
"attn.norm_q": "attn.norm_q_a",
|
||||
"attn.norm_k": "attn.norm_k_a",
|
||||
"attn.norm_added_q": "attn.norm_q_b",
|
||||
"attn.norm_added_k": "attn.norm_k_b",
|
||||
}
|
||||
rename_dict_single = {
|
||||
"attn.to_q": "a_to_q",
|
||||
"attn.to_k": "a_to_k",
|
||||
"attn.to_v": "a_to_v",
|
||||
"attn.norm_q": "norm_q_a",
|
||||
"attn.norm_k": "norm_k_a",
|
||||
"norm.linear": "norm.linear",
|
||||
"proj_mlp": "proj_in_besides_attn",
|
||||
"proj_out": "proj_out",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name.endswith(".weight") or name.endswith(".bias"):
|
||||
suffix = ".weight" if name.endswith(".weight") else ".bias"
|
||||
prefix = name[:-len(suffix)]
|
||||
if prefix in global_rename_dict:
|
||||
state_dict_[global_rename_dict[prefix] + suffix] = param
|
||||
elif prefix.startswith("transformer_blocks."):
|
||||
names = prefix.split(".")
|
||||
names[0] = "blocks"
|
||||
middle = ".".join(names[2:])
|
||||
if middle in rename_dict:
|
||||
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
|
||||
state_dict_[name_] = param
|
||||
elif prefix.startswith("single_transformer_blocks."):
|
||||
names = prefix.split(".")
|
||||
names[0] = "single_blocks"
|
||||
middle = ".".join(names[2:])
|
||||
if middle in rename_dict_single:
|
||||
name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
|
||||
state_dict_[name_] = param
|
||||
else:
|
||||
state_dict_[name] = param
|
||||
else:
|
||||
state_dict_[name] = param
|
||||
for name in list(state_dict_.keys()):
|
||||
if ".proj_in_besides_attn." in name:
|
||||
name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
|
||||
param = torch.concat([
|
||||
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
|
||||
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
|
||||
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
|
||||
state_dict_[name],
|
||||
], dim=0)
|
||||
state_dict_[name_] = param
|
||||
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
|
||||
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
|
||||
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
|
||||
state_dict_.pop(name)
|
||||
for name in list(state_dict_.keys()):
|
||||
for component in ["a", "b"]:
|
||||
if f".{component}_to_q." in name:
|
||||
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
|
||||
param = torch.concat([
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
|
||||
], dim=0)
|
||||
state_dict_[name_] = param
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
|
||||
if hash_value == "78d18b9101345ff695f312e7e62538c0":
|
||||
extra_kwargs = {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}}
|
||||
elif hash_value == "b001c89139b5f053c715fe772362dd2a":
|
||||
extra_kwargs = {"num_single_blocks": 0}
|
||||
elif hash_value == "52357cb26250681367488a8954c271e8":
|
||||
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
|
||||
elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
|
||||
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
|
||||
elif hash_value == "7f9583eb8ba86642abb9a21a4b2c9e16":
|
||||
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 10}
|
||||
else:
|
||||
extra_kwargs = {}
|
||||
return state_dict_, extra_kwargs
|
||||
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
742
diffsynth/models/flux_dit.py
Normal file
742
diffsynth/models/flux_dit.py
Normal file
@@ -0,0 +1,742 @@
|
||||
import torch
|
||||
from .sd3_dit import TimestepEmbeddings, AdaLayerNorm, RMSNorm
|
||||
from einops import rearrange
|
||||
from .tiler import TileWorker
|
||||
from .utils import init_weights_on_device
|
||||
|
||||
def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0):
|
||||
batch_size, num_tokens = hidden_states.shape[0:2]
|
||||
ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, num_tokens, -1)
|
||||
hidden_states = hidden_states + scale * ip_hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class RoPEEmbedding(torch.nn.Module):
|
||||
def __init__(self, dim, theta, axes_dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
|
||||
def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
||||
assert dim % 2 == 0, "The dimension must be even."
|
||||
|
||||
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
||||
omega = 1.0 / (theta**scale)
|
||||
|
||||
batch_size, seq_length = pos.shape
|
||||
out = torch.einsum("...n,d->...nd", pos, omega)
|
||||
cos_out = torch.cos(out)
|
||||
sin_out = torch.sin(out)
|
||||
|
||||
stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
|
||||
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
|
||||
return out.float()
|
||||
|
||||
|
||||
def forward(self, ids):
|
||||
n_axes = ids.shape[-1]
|
||||
emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
|
||||
return emb.unsqueeze(1)
|
||||
|
||||
|
||||
|
||||
class FluxJointAttention(torch.nn.Module):
|
||||
def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.only_out_a = only_out_a
|
||||
|
||||
self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
|
||||
self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)
|
||||
|
||||
self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
|
||||
self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
|
||||
self.norm_q_b = RMSNorm(head_dim, eps=1e-6)
|
||||
self.norm_k_b = RMSNorm(head_dim, eps=1e-6)
|
||||
|
||||
self.a_to_out = torch.nn.Linear(dim_a, dim_a)
|
||||
if not only_out_a:
|
||||
self.b_to_out = torch.nn.Linear(dim_b, dim_b)
|
||||
|
||||
|
||||
def apply_rope(self, xq, xk, freqs_cis):
|
||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
|
||||
batch_size = hidden_states_a.shape[0]
|
||||
|
||||
# Part A
|
||||
qkv_a = self.a_to_qkv(hidden_states_a)
|
||||
qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
||||
q_a, k_a, v_a = qkv_a.chunk(3, dim=1)
|
||||
q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)
|
||||
|
||||
# Part B
|
||||
qkv_b = self.b_to_qkv(hidden_states_b)
|
||||
qkv_b = qkv_b.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
||||
q_b, k_b, v_b = qkv_b.chunk(3, dim=1)
|
||||
q_b, k_b = self.norm_q_b(q_b), self.norm_k_b(k_b)
|
||||
|
||||
q = torch.concat([q_b, q_a], dim=2)
|
||||
k = torch.concat([k_b, k_a], dim=2)
|
||||
v = torch.concat([v_b, v_a], dim=2)
|
||||
|
||||
q, k = self.apply_rope(q, k, image_rotary_emb)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
|
||||
if ipadapter_kwargs_list is not None:
|
||||
hidden_states_a = interact_with_ipadapter(hidden_states_a, q_a, **ipadapter_kwargs_list)
|
||||
hidden_states_a = self.a_to_out(hidden_states_a)
|
||||
if self.only_out_a:
|
||||
return hidden_states_a
|
||||
else:
|
||||
hidden_states_b = self.b_to_out(hidden_states_b)
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
|
||||
class FluxJointTransformerBlock(torch.nn.Module):
|
||||
def __init__(self, dim, num_attention_heads):
|
||||
super().__init__()
|
||||
self.norm1_a = AdaLayerNorm(dim)
|
||||
self.norm1_b = AdaLayerNorm(dim)
|
||||
|
||||
self.attn = FluxJointAttention(dim, dim, num_attention_heads, dim // num_attention_heads)
|
||||
|
||||
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_a = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim, dim*4),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(dim*4, dim)
|
||||
)
|
||||
|
||||
self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_b = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim, dim*4),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(dim*4, dim)
|
||||
)
|
||||
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
|
||||
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
|
||||
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
|
||||
|
||||
# Attention
|
||||
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
|
||||
|
||||
# Part A
|
||||
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
|
||||
norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
|
||||
hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
|
||||
|
||||
# Part B
|
||||
hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
|
||||
norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
|
||||
hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
|
||||
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
|
||||
class FluxSingleAttention(torch.nn.Module):
|
||||
def __init__(self, dim_a, dim_b, num_heads, head_dim):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
|
||||
self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
|
||||
|
||||
self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
|
||||
self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
|
||||
|
||||
|
||||
def apply_rope(self, xq, xk, freqs_cis):
|
||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
|
||||
|
||||
def forward(self, hidden_states, image_rotary_emb):
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
qkv_a = self.a_to_qkv(hidden_states)
|
||||
qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
||||
q_a, k_a, v = qkv_a.chunk(3, dim=1)
|
||||
q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)
|
||||
|
||||
q, k = self.apply_rope(q_a, k_a, image_rotary_emb)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
class AdaLayerNormSingle(torch.nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.silu = torch.nn.SiLU()
|
||||
self.linear = torch.nn.Linear(dim, 3 * dim, bias=True)
|
||||
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
|
||||
def forward(self, x, emb):
|
||||
emb = self.linear(self.silu(emb))
|
||||
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
|
||||
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
||||
return x, gate_msa
|
||||
|
||||
|
||||
|
||||
class FluxSingleTransformerBlock(torch.nn.Module):
|
||||
def __init__(self, dim, num_attention_heads):
|
||||
super().__init__()
|
||||
self.num_heads = num_attention_heads
|
||||
self.head_dim = dim // num_attention_heads
|
||||
self.dim = dim
|
||||
|
||||
self.norm = AdaLayerNormSingle(dim)
|
||||
self.to_qkv_mlp = torch.nn.Linear(dim, dim * (3 + 4))
|
||||
self.norm_q_a = RMSNorm(self.head_dim, eps=1e-6)
|
||||
self.norm_k_a = RMSNorm(self.head_dim, eps=1e-6)
|
||||
|
||||
self.proj_out = torch.nn.Linear(dim * 5, dim)
|
||||
|
||||
|
||||
def apply_rope(self, xq, xk, freqs_cis):
|
||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
|
||||
|
||||
def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
||||
q, k, v = qkv.chunk(3, dim=1)
|
||||
q, k = self.norm_q_a(q), self.norm_k_a(k)
|
||||
|
||||
q, k = self.apply_rope(q, k, image_rotary_emb)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
if ipadapter_kwargs_list is not None:
|
||||
hidden_states = interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs_list)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
|
||||
residual = hidden_states_a
|
||||
norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
|
||||
hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
|
||||
attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]
|
||||
|
||||
attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
|
||||
mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")
|
||||
|
||||
hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
||||
hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a)
|
||||
hidden_states_a = residual + hidden_states_a
|
||||
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
|
||||
class AdaLayerNormContinuous(torch.nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.silu = torch.nn.SiLU()
|
||||
self.linear = torch.nn.Linear(dim, dim * 2, bias=True)
|
||||
self.norm = torch.nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
|
||||
|
||||
def forward(self, x, conditioning):
|
||||
emb = self.linear(self.silu(conditioning))
|
||||
scale, shift = torch.chunk(emb, 2, dim=1)
|
||||
x = self.norm(x) * (1 + scale)[:, None] + shift[:, None]
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class FluxDiT(torch.nn.Module):
|
||||
def __init__(self, disable_guidance_embedder=False):
|
||||
super().__init__()
|
||||
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
|
||||
self.time_embedder = TimestepEmbeddings(256, 3072)
|
||||
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
|
||||
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
|
||||
self.context_embedder = torch.nn.Linear(4096, 3072)
|
||||
self.x_embedder = torch.nn.Linear(64, 3072)
|
||||
|
||||
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(19)])
|
||||
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])
|
||||
|
||||
self.final_norm_out = AdaLayerNormContinuous(3072)
|
||||
self.final_proj_out = torch.nn.Linear(3072, 64)
|
||||
|
||||
|
||||
def patchify(self, hidden_states):
|
||||
hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def unpatchify(self, hidden_states, height, width):
|
||||
hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def prepare_image_ids(self, latents):
|
||||
batch_size, _, height, width = latents.shape
|
||||
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
|
||||
latent_image_ids = latent_image_ids.reshape(
|
||||
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
return latent_image_ids
|
||||
|
||||
|
||||
def tiled_forward(
|
||||
self,
|
||||
hidden_states,
|
||||
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
|
||||
tile_size=128, tile_stride=64,
|
||||
**kwargs
|
||||
):
|
||||
# Due to the global positional embedding, we cannot implement layer-wise tiled forward.
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: self.forward(x, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None),
|
||||
hidden_states,
|
||||
tile_size,
|
||||
tile_stride,
|
||||
tile_device=hidden_states.device,
|
||||
tile_dtype=hidden_states.dtype
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len):
|
||||
N = len(entity_masks)
|
||||
batch_size = entity_masks[0].shape[0]
|
||||
total_seq_len = N * prompt_seq_len + image_seq_len
|
||||
patched_masks = [self.patchify(entity_masks[i]) for i in range(N)]
|
||||
attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)
|
||||
|
||||
image_start = N * prompt_seq_len
|
||||
image_end = N * prompt_seq_len + image_seq_len
|
||||
# prompt-image mask
|
||||
for i in range(N):
|
||||
prompt_start = i * prompt_seq_len
|
||||
prompt_end = (i + 1) * prompt_seq_len
|
||||
image_mask = torch.sum(patched_masks[i], dim=-1) > 0
|
||||
image_mask = image_mask.unsqueeze(1).repeat(1, prompt_seq_len, 1)
|
||||
# prompt update with image
|
||||
attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
|
||||
# image update with prompt
|
||||
attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
|
||||
# prompt-prompt mask
|
||||
for i in range(N):
|
||||
for j in range(N):
|
||||
if i != j:
|
||||
prompt_start_i = i * prompt_seq_len
|
||||
prompt_end_i = (i + 1) * prompt_seq_len
|
||||
prompt_start_j = j * prompt_seq_len
|
||||
prompt_end_j = (j + 1) * prompt_seq_len
|
||||
attention_mask[:, prompt_start_i:prompt_end_i, prompt_start_j:prompt_end_j] = False
|
||||
|
||||
attention_mask = attention_mask.float()
|
||||
attention_mask[attention_mask == 0] = float('-inf')
|
||||
attention_mask[attention_mask == 1] = 0
|
||||
return attention_mask
|
||||
|
||||
|
||||
def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids):
|
||||
repeat_dim = hidden_states.shape[1]
|
||||
max_masks = 0
|
||||
attention_mask = None
|
||||
prompt_embs = [prompt_emb]
|
||||
if entity_masks is not None:
|
||||
# entity_masks
|
||||
batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1]
|
||||
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
|
||||
entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
|
||||
# global mask
|
||||
global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
entity_masks = entity_masks + [global_mask] # append global to last
|
||||
# attention mask
|
||||
attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1])
|
||||
attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
# embds: n_masks * b * seq * d
|
||||
local_embs = [entity_prompt_emb[:, i, None].squeeze(1) for i in range(max_masks)]
|
||||
prompt_embs = local_embs + prompt_embs # append global to last
|
||||
prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs]
|
||||
prompt_emb = torch.cat(prompt_embs, dim=1)
|
||||
|
||||
# positional embedding
|
||||
text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1)
|
||||
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||
return prompt_emb, image_rotary_emb, attention_mask
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
|
||||
tiled=False, tile_size=128, tile_stride=64, entity_prompt_emb=None, entity_masks=None,
|
||||
use_gradient_checkpointing=False,
|
||||
**kwargs
|
||||
):
|
||||
if tiled:
|
||||
return self.tiled_forward(
|
||||
hidden_states,
|
||||
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
|
||||
tile_size=tile_size, tile_stride=tile_stride,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if image_ids is None:
|
||||
image_ids = self.prepare_image_ids(hidden_states)
|
||||
|
||||
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
|
||||
if self.guidance_embedder is not None:
|
||||
guidance = guidance * 1000
|
||||
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
|
||||
|
||||
height, width = hidden_states.shape[-2:]
|
||||
hidden_states = self.patchify(hidden_states)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
if entity_prompt_emb is not None and entity_masks is not None:
|
||||
prompt_emb, image_rotary_emb, attention_mask = self.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
|
||||
else:
|
||||
prompt_emb = self.context_embedder(prompt_emb)
|
||||
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||
attention_mask = None
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
for block in self.blocks:
|
||||
if self.training and use_gradient_checkpointing:
|
||||
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
|
||||
|
||||
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
||||
for block in self.single_blocks:
|
||||
if self.training and use_gradient_checkpointing:
|
||||
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
|
||||
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
|
||||
|
||||
hidden_states = self.final_norm_out(hidden_states, conditioning)
|
||||
hidden_states = self.final_proj_out(hidden_states)
|
||||
hidden_states = self.unpatchify(hidden_states, height, width)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def quantize(self):
|
||||
def cast_to(weight, dtype=None, device=None, copy=False):
|
||||
if device is None or weight.device == device:
|
||||
if not copy:
|
||||
if dtype is None or weight.dtype == dtype:
|
||||
return weight
|
||||
return weight.to(dtype=dtype, copy=copy)
|
||||
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight)
|
||||
return r
|
||||
|
||||
def cast_weight(s, input=None, dtype=None, device=None):
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
if device is None:
|
||||
device = input.device
|
||||
weight = cast_to(s.weight, dtype, device)
|
||||
return weight
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
if bias_dtype is None:
|
||||
bias_dtype = dtype
|
||||
if device is None:
|
||||
device = input.device
|
||||
bias = None
|
||||
weight = cast_to(s.weight, dtype, device)
|
||||
bias = cast_to(s.bias, bias_dtype, device)
|
||||
return weight, bias
|
||||
|
||||
class quantized_layer:
|
||||
class Linear(torch.nn.Linear):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self,input,**kwargs):
|
||||
weight,bias= cast_bias_weight(self,input)
|
||||
return torch.nn.functional.linear(input,weight,bias)
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, module):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
|
||||
def forward(self,hidden_states,**kwargs):
|
||||
weight= cast_weight(self.module,hidden_states)
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
|
||||
hidden_states = hidden_states.to(input_dtype) * weight
|
||||
return hidden_states
|
||||
|
||||
def replace_layer(model):
|
||||
for name, module in model.named_children():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
with init_weights_on_device():
|
||||
new_layer = quantized_layer.Linear(module.in_features,module.out_features)
|
||||
new_layer.weight = module.weight
|
||||
if module.bias is not None:
|
||||
new_layer.bias = module.bias
|
||||
# del module
|
||||
setattr(model, name, new_layer)
|
||||
elif isinstance(module, RMSNorm):
|
||||
if hasattr(module,"quantized"):
|
||||
continue
|
||||
module.quantized= True
|
||||
new_layer = quantized_layer.RMSNorm(module)
|
||||
setattr(model, name, new_layer)
|
||||
else:
|
||||
replace_layer(module)
|
||||
|
||||
replace_layer(self)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxDiTStateDictConverter()
|
||||
|
||||
|
||||
class FluxDiTStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
global_rename_dict = {
|
||||
"context_embedder": "context_embedder",
|
||||
"x_embedder": "x_embedder",
|
||||
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
|
||||
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
|
||||
"time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
|
||||
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
|
||||
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
|
||||
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
|
||||
"norm_out.linear": "final_norm_out.linear",
|
||||
"proj_out": "final_proj_out",
|
||||
}
|
||||
rename_dict = {
|
||||
"proj_out": "proj_out",
|
||||
"norm1.linear": "norm1_a.linear",
|
||||
"norm1_context.linear": "norm1_b.linear",
|
||||
"attn.to_q": "attn.a_to_q",
|
||||
"attn.to_k": "attn.a_to_k",
|
||||
"attn.to_v": "attn.a_to_v",
|
||||
"attn.to_out.0": "attn.a_to_out",
|
||||
"attn.add_q_proj": "attn.b_to_q",
|
||||
"attn.add_k_proj": "attn.b_to_k",
|
||||
"attn.add_v_proj": "attn.b_to_v",
|
||||
"attn.to_add_out": "attn.b_to_out",
|
||||
"ff.net.0.proj": "ff_a.0",
|
||||
"ff.net.2": "ff_a.2",
|
||||
"ff_context.net.0.proj": "ff_b.0",
|
||||
"ff_context.net.2": "ff_b.2",
|
||||
"attn.norm_q": "attn.norm_q_a",
|
||||
"attn.norm_k": "attn.norm_k_a",
|
||||
"attn.norm_added_q": "attn.norm_q_b",
|
||||
"attn.norm_added_k": "attn.norm_k_b",
|
||||
}
|
||||
rename_dict_single = {
|
||||
"attn.to_q": "a_to_q",
|
||||
"attn.to_k": "a_to_k",
|
||||
"attn.to_v": "a_to_v",
|
||||
"attn.norm_q": "norm_q_a",
|
||||
"attn.norm_k": "norm_k_a",
|
||||
"norm.linear": "norm.linear",
|
||||
"proj_mlp": "proj_in_besides_attn",
|
||||
"proj_out": "proj_out",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name.endswith(".weight") or name.endswith(".bias"):
|
||||
suffix = ".weight" if name.endswith(".weight") else ".bias"
|
||||
prefix = name[:-len(suffix)]
|
||||
if prefix in global_rename_dict:
|
||||
state_dict_[global_rename_dict[prefix] + suffix] = param
|
||||
elif prefix.startswith("transformer_blocks."):
|
||||
names = prefix.split(".")
|
||||
names[0] = "blocks"
|
||||
middle = ".".join(names[2:])
|
||||
if middle in rename_dict:
|
||||
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
|
||||
state_dict_[name_] = param
|
||||
elif prefix.startswith("single_transformer_blocks."):
|
||||
names = prefix.split(".")
|
||||
names[0] = "single_blocks"
|
||||
middle = ".".join(names[2:])
|
||||
if middle in rename_dict_single:
|
||||
name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
|
||||
state_dict_[name_] = param
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
for name in list(state_dict_.keys()):
|
||||
if "single_blocks." in name and ".a_to_q." in name:
|
||||
mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
|
||||
if mlp is None:
|
||||
mlp = torch.zeros(4 * state_dict_[name].shape[0],
|
||||
*state_dict_[name].shape[1:],
|
||||
dtype=state_dict_[name].dtype)
|
||||
else:
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
|
||||
param = torch.concat([
|
||||
state_dict_.pop(name),
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
|
||||
mlp,
|
||||
], dim=0)
|
||||
name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
|
||||
state_dict_[name_] = param
|
||||
for name in list(state_dict_.keys()):
|
||||
for component in ["a", "b"]:
|
||||
if f".{component}_to_q." in name:
|
||||
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
|
||||
param = torch.concat([
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
|
||||
], dim=0)
|
||||
state_dict_[name_] = param
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
"time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias",
|
||||
"time_in.in_layer.weight": "time_embedder.timestep_embedder.0.weight",
|
||||
"time_in.out_layer.bias": "time_embedder.timestep_embedder.2.bias",
|
||||
"time_in.out_layer.weight": "time_embedder.timestep_embedder.2.weight",
|
||||
"txt_in.bias": "context_embedder.bias",
|
||||
"txt_in.weight": "context_embedder.weight",
|
||||
"vector_in.in_layer.bias": "pooled_text_embedder.0.bias",
|
||||
"vector_in.in_layer.weight": "pooled_text_embedder.0.weight",
|
||||
"vector_in.out_layer.bias": "pooled_text_embedder.2.bias",
|
||||
"vector_in.out_layer.weight": "pooled_text_embedder.2.weight",
|
||||
"final_layer.linear.bias": "final_proj_out.bias",
|
||||
"final_layer.linear.weight": "final_proj_out.weight",
|
||||
"guidance_in.in_layer.bias": "guidance_embedder.timestep_embedder.0.bias",
|
||||
"guidance_in.in_layer.weight": "guidance_embedder.timestep_embedder.0.weight",
|
||||
"guidance_in.out_layer.bias": "guidance_embedder.timestep_embedder.2.bias",
|
||||
"guidance_in.out_layer.weight": "guidance_embedder.timestep_embedder.2.weight",
|
||||
"img_in.bias": "x_embedder.bias",
|
||||
"img_in.weight": "x_embedder.weight",
|
||||
"final_layer.adaLN_modulation.1.weight": "final_norm_out.linear.weight",
|
||||
"final_layer.adaLN_modulation.1.bias": "final_norm_out.linear.bias",
|
||||
}
|
||||
suffix_rename_dict = {
|
||||
"img_attn.norm.key_norm.scale": "attn.norm_k_a.weight",
|
||||
"img_attn.norm.query_norm.scale": "attn.norm_q_a.weight",
|
||||
"img_attn.proj.bias": "attn.a_to_out.bias",
|
||||
"img_attn.proj.weight": "attn.a_to_out.weight",
|
||||
"img_attn.qkv.bias": "attn.a_to_qkv.bias",
|
||||
"img_attn.qkv.weight": "attn.a_to_qkv.weight",
|
||||
"img_mlp.0.bias": "ff_a.0.bias",
|
||||
"img_mlp.0.weight": "ff_a.0.weight",
|
||||
"img_mlp.2.bias": "ff_a.2.bias",
|
||||
"img_mlp.2.weight": "ff_a.2.weight",
|
||||
"img_mod.lin.bias": "norm1_a.linear.bias",
|
||||
"img_mod.lin.weight": "norm1_a.linear.weight",
|
||||
"txt_attn.norm.key_norm.scale": "attn.norm_k_b.weight",
|
||||
"txt_attn.norm.query_norm.scale": "attn.norm_q_b.weight",
|
||||
"txt_attn.proj.bias": "attn.b_to_out.bias",
|
||||
"txt_attn.proj.weight": "attn.b_to_out.weight",
|
||||
"txt_attn.qkv.bias": "attn.b_to_qkv.bias",
|
||||
"txt_attn.qkv.weight": "attn.b_to_qkv.weight",
|
||||
"txt_mlp.0.bias": "ff_b.0.bias",
|
||||
"txt_mlp.0.weight": "ff_b.0.weight",
|
||||
"txt_mlp.2.bias": "ff_b.2.bias",
|
||||
"txt_mlp.2.weight": "ff_b.2.weight",
|
||||
"txt_mod.lin.bias": "norm1_b.linear.bias",
|
||||
"txt_mod.lin.weight": "norm1_b.linear.weight",
|
||||
|
||||
"linear1.bias": "to_qkv_mlp.bias",
|
||||
"linear1.weight": "to_qkv_mlp.weight",
|
||||
"linear2.bias": "proj_out.bias",
|
||||
"linear2.weight": "proj_out.weight",
|
||||
"modulation.lin.bias": "norm.linear.bias",
|
||||
"modulation.lin.weight": "norm.linear.weight",
|
||||
"norm.key_norm.scale": "norm_k_a.weight",
|
||||
"norm.query_norm.scale": "norm_q_a.weight",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name.startswith("model.diffusion_model."):
|
||||
name = name[len("model.diffusion_model."):]
|
||||
names = name.split(".")
|
||||
if name in rename_dict:
|
||||
rename = rename_dict[name]
|
||||
if name.startswith("final_layer.adaLN_modulation.1."):
|
||||
param = torch.concat([param[3072:], param[:3072]], dim=0)
|
||||
state_dict_[rename] = param
|
||||
elif names[0] == "double_blocks":
|
||||
rename = f"blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
|
||||
state_dict_[rename] = param
|
||||
elif names[0] == "single_blocks":
|
||||
if ".".join(names[2:]) in suffix_rename_dict:
|
||||
rename = f"single_blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
|
||||
state_dict_[rename] = param
|
||||
else:
|
||||
pass
|
||||
if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_:
|
||||
return state_dict_, {"disable_guidance_embedder": True}
|
||||
else:
|
||||
return state_dict_
|
||||
128
diffsynth/models/flux_infiniteyou.py
Normal file
128
diffsynth/models/flux_infiniteyou.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
# FFN
|
||||
def FeedForward(dim, mult=4):
|
||||
inner_dim = int(dim * mult)
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, inner_dim, bias=False),
|
||||
nn.GELU(),
|
||||
nn.Linear(inner_dim, dim, bias=False),
|
||||
)
|
||||
|
||||
|
||||
def reshape_tensor(x, heads):
|
||||
bs, length, width = x.shape
|
||||
#(bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
||||
x = x.view(bs, length, heads, -1)
|
||||
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
||||
x = x.transpose(1, 2)
|
||||
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
||||
x = x.reshape(bs, heads, length, -1)
|
||||
return x
|
||||
|
||||
|
||||
class PerceiverAttention(nn.Module):
|
||||
|
||||
def __init__(self, *, dim, dim_head=64, heads=8):
|
||||
super().__init__()
|
||||
self.scale = dim_head**-0.5
|
||||
self.dim_head = dim_head
|
||||
self.heads = heads
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
||||
|
||||
def forward(self, x, latents):
|
||||
"""
|
||||
Args:
|
||||
x (torch.Tensor): image features
|
||||
shape (b, n1, D)
|
||||
latent (torch.Tensor): latent features
|
||||
shape (b, n2, D)
|
||||
"""
|
||||
x = self.norm1(x)
|
||||
latents = self.norm2(latents)
|
||||
|
||||
b, l, _ = latents.shape
|
||||
|
||||
q = self.to_q(latents)
|
||||
kv_input = torch.cat((x, latents), dim=-2)
|
||||
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
||||
|
||||
q = reshape_tensor(q, self.heads)
|
||||
k = reshape_tensor(k, self.heads)
|
||||
v = reshape_tensor(v, self.heads)
|
||||
|
||||
# attention
|
||||
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
||||
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
||||
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
out = weight @ v
|
||||
|
||||
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class InfiniteYouImageProjector(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim=1280,
|
||||
depth=4,
|
||||
dim_head=64,
|
||||
heads=20,
|
||||
num_queries=8,
|
||||
embedding_dim=512,
|
||||
output_dim=4096,
|
||||
ff_mult=4,
|
||||
):
|
||||
super().__init__()
|
||||
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
|
||||
self.proj_in = nn.Linear(embedding_dim, dim)
|
||||
|
||||
self.proj_out = nn.Linear(dim, output_dim)
|
||||
self.norm_out = nn.LayerNorm(output_dim)
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(
|
||||
nn.ModuleList([
|
||||
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
||||
FeedForward(dim=dim, mult=ff_mult),
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
latents = self.latents.repeat(x.size(0), 1, 1)
|
||||
|
||||
x = self.proj_in(x)
|
||||
|
||||
for attn, ff in self.layers:
|
||||
latents = attn(x, latents) + latents
|
||||
latents = ff(latents) + latents
|
||||
|
||||
latents = self.proj_out(latents)
|
||||
return self.norm_out(latents)
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxInfiniteYouImageProjectorStateDictConverter()
|
||||
|
||||
|
||||
class FluxInfiniteYouImageProjectorStateDictConverter:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict['image_proj']
|
||||
94
diffsynth/models/flux_ipadapter.py
Normal file
94
diffsynth/models/flux_ipadapter.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from .svd_image_encoder import SVDImageEncoder
|
||||
from .sd3_dit import RMSNorm
|
||||
from transformers import CLIPImageProcessor
|
||||
import torch
|
||||
|
||||
|
||||
class MLPProjModel(torch.nn.Module):
|
||||
def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
|
||||
super().__init__()
|
||||
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.num_tokens = num_tokens
|
||||
|
||||
self.proj = torch.nn.Sequential(
|
||||
torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
|
||||
torch.nn.GELU(),
|
||||
torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
|
||||
)
|
||||
self.norm = torch.nn.LayerNorm(cross_attention_dim)
|
||||
|
||||
def forward(self, id_embeds):
|
||||
x = self.proj(id_embeds)
|
||||
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
class IpAdapterModule(torch.nn.Module):
|
||||
def __init__(self, num_attention_heads, attention_head_dim, input_dim):
|
||||
super().__init__()
|
||||
self.num_heads = num_attention_heads
|
||||
self.head_dim = attention_head_dim
|
||||
output_dim = num_attention_heads * attention_head_dim
|
||||
self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
|
||||
self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
|
||||
self.norm_added_k = RMSNorm(attention_head_dim, eps=1e-5, elementwise_affine=False)
|
||||
|
||||
|
||||
def forward(self, hidden_states):
|
||||
batch_size = hidden_states.shape[0]
|
||||
# ip_k
|
||||
ip_k = self.to_k_ip(hidden_states)
|
||||
ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
ip_k = self.norm_added_k(ip_k)
|
||||
# ip_v
|
||||
ip_v = self.to_v_ip(hidden_states)
|
||||
ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
return ip_k, ip_v
|
||||
|
||||
|
||||
class FluxIpAdapter(torch.nn.Module):
|
||||
def __init__(self, num_attention_heads=24, attention_head_dim=128, cross_attention_dim=4096, num_tokens=128, num_blocks=57):
|
||||
super().__init__()
|
||||
self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(num_attention_heads, attention_head_dim, cross_attention_dim) for _ in range(num_blocks)])
|
||||
self.image_proj = MLPProjModel(cross_attention_dim=cross_attention_dim, id_embeddings_dim=1152, num_tokens=num_tokens)
|
||||
self.set_adapter()
|
||||
|
||||
def set_adapter(self):
|
||||
self.call_block_id = {i:i for i in range(len(self.ipadapter_modules))}
|
||||
|
||||
def forward(self, hidden_states, scale=1.0):
|
||||
hidden_states = self.image_proj(hidden_states)
|
||||
hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
|
||||
ip_kv_dict = {}
|
||||
for block_id in self.call_block_id:
|
||||
ipadapter_id = self.call_block_id[block_id]
|
||||
ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
|
||||
ip_kv_dict[block_id] = {
|
||||
"ip_k": ip_k,
|
||||
"ip_v": ip_v,
|
||||
"scale": scale
|
||||
}
|
||||
return ip_kv_dict
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxIpAdapterStateDictConverter()
|
||||
|
||||
|
||||
class FluxIpAdapterStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict["ip_adapter"]:
|
||||
name_ = 'ipadapter_modules.' + name
|
||||
state_dict_[name_] = state_dict["ip_adapter"][name]
|
||||
for name in state_dict["image_proj"]:
|
||||
name_ = "image_proj." + name
|
||||
state_dict_[name_] = state_dict["image_proj"][name]
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
32
diffsynth/models/flux_text_encoder.py
Normal file
32
diffsynth/models/flux_text_encoder.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5Config
|
||||
from .sd_text_encoder import SDTextEncoder
|
||||
|
||||
|
||||
|
||||
class FluxTextEncoder2(T5EncoderModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.eval()
|
||||
|
||||
def forward(self, input_ids):
|
||||
outputs = super().forward(input_ids=input_ids)
|
||||
prompt_emb = outputs.last_hidden_state
|
||||
return prompt_emb
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxTextEncoder2StateDictConverter()
|
||||
|
||||
|
||||
|
||||
class FluxTextEncoder2StateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = state_dict
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
303
diffsynth/models/flux_vae.py
Normal file
303
diffsynth/models/flux_vae.py
Normal file
@@ -0,0 +1,303 @@
|
||||
from .sd3_vae_encoder import SD3VAEEncoder, SDVAEEncoderStateDictConverter
|
||||
from .sd3_vae_decoder import SD3VAEDecoder, SDVAEDecoderStateDictConverter
|
||||
|
||||
|
||||
class FluxVAEEncoder(SD3VAEEncoder):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.3611
|
||||
self.shift_factor = 0.1159
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxVAEEncoderStateDictConverter()
|
||||
|
||||
|
||||
class FluxVAEDecoder(SD3VAEDecoder):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.3611
|
||||
self.shift_factor = 0.1159
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxVAEDecoderStateDictConverter()
|
||||
|
||||
|
||||
class FluxVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
"encoder.conv_in.bias": "conv_in.bias",
|
||||
"encoder.conv_in.weight": "conv_in.weight",
|
||||
"encoder.conv_out.bias": "conv_out.bias",
|
||||
"encoder.conv_out.weight": "conv_out.weight",
|
||||
"encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
|
||||
"encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
|
||||
"encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
|
||||
"encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
|
||||
"encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
|
||||
"encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
|
||||
"encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
|
||||
"encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
|
||||
"encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
|
||||
"encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
|
||||
"encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
|
||||
"encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
|
||||
"encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
|
||||
"encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
|
||||
"encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
|
||||
"encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
|
||||
"encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
|
||||
"encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
|
||||
"encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
|
||||
"encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
|
||||
"encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
|
||||
"encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
|
||||
"encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
|
||||
"encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
|
||||
"encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
|
||||
"encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
|
||||
"encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
|
||||
"encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
|
||||
"encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
|
||||
"encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
|
||||
"encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
|
||||
"encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
|
||||
"encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
|
||||
"encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
|
||||
"encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
|
||||
"encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
|
||||
"encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
|
||||
"encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
|
||||
"encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
|
||||
"encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
|
||||
"encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
|
||||
"encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
|
||||
"encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
|
||||
"encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
|
||||
"encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
|
||||
"encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
|
||||
"encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
|
||||
"encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
|
||||
"encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
|
||||
"encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
|
||||
"encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
|
||||
"encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
|
||||
"encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
|
||||
"encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
|
||||
"encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
|
||||
"encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
|
||||
"encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
|
||||
"encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
|
||||
"encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
|
||||
"encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
|
||||
"encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
|
||||
"encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
|
||||
"encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
|
||||
"encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
|
||||
"encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
|
||||
"encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
|
||||
"encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
|
||||
"encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
|
||||
"encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
|
||||
"encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
|
||||
"encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
|
||||
"encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
|
||||
"encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
|
||||
"encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
|
||||
"encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
|
||||
"encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
|
||||
"encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
|
||||
"encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
|
||||
"encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
|
||||
"encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
|
||||
"encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
|
||||
"encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
|
||||
"encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
|
||||
"encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
|
||||
"encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
|
||||
"encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
|
||||
"encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
|
||||
"encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
|
||||
"encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
|
||||
"encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
|
||||
"encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
|
||||
"encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
|
||||
"encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
|
||||
"encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
|
||||
"encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
|
||||
"encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
|
||||
"encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
|
||||
"encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
|
||||
"encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
|
||||
"encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
|
||||
"encoder.norm_out.bias": "conv_norm_out.bias",
|
||||
"encoder.norm_out.weight": "conv_norm_out.weight",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if "transformer_blocks" in rename_dict[name]:
|
||||
param = param.squeeze()
|
||||
state_dict_[rename_dict[name]] = param
|
||||
return state_dict_
|
||||
|
||||
|
||||
|
||||
class FluxVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
"decoder.conv_in.bias": "conv_in.bias",
|
||||
"decoder.conv_in.weight": "conv_in.weight",
|
||||
"decoder.conv_out.bias": "conv_out.bias",
|
||||
"decoder.conv_out.weight": "conv_out.weight",
|
||||
"decoder.mid.attn_1.k.bias": "blocks.1.transformer_blocks.0.to_k.bias",
|
||||
"decoder.mid.attn_1.k.weight": "blocks.1.transformer_blocks.0.to_k.weight",
|
||||
"decoder.mid.attn_1.norm.bias": "blocks.1.norm.bias",
|
||||
"decoder.mid.attn_1.norm.weight": "blocks.1.norm.weight",
|
||||
"decoder.mid.attn_1.proj_out.bias": "blocks.1.transformer_blocks.0.to_out.bias",
|
||||
"decoder.mid.attn_1.proj_out.weight": "blocks.1.transformer_blocks.0.to_out.weight",
|
||||
"decoder.mid.attn_1.q.bias": "blocks.1.transformer_blocks.0.to_q.bias",
|
||||
"decoder.mid.attn_1.q.weight": "blocks.1.transformer_blocks.0.to_q.weight",
|
||||
"decoder.mid.attn_1.v.bias": "blocks.1.transformer_blocks.0.to_v.bias",
|
||||
"decoder.mid.attn_1.v.weight": "blocks.1.transformer_blocks.0.to_v.weight",
|
||||
"decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
|
||||
"decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
|
||||
"decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
|
||||
"decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
|
||||
"decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
|
||||
"decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
|
||||
"decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
|
||||
"decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
|
||||
"decoder.mid.block_2.conv1.bias": "blocks.2.conv1.bias",
|
||||
"decoder.mid.block_2.conv1.weight": "blocks.2.conv1.weight",
|
||||
"decoder.mid.block_2.conv2.bias": "blocks.2.conv2.bias",
|
||||
"decoder.mid.block_2.conv2.weight": "blocks.2.conv2.weight",
|
||||
"decoder.mid.block_2.norm1.bias": "blocks.2.norm1.bias",
|
||||
"decoder.mid.block_2.norm1.weight": "blocks.2.norm1.weight",
|
||||
"decoder.mid.block_2.norm2.bias": "blocks.2.norm2.bias",
|
||||
"decoder.mid.block_2.norm2.weight": "blocks.2.norm2.weight",
|
||||
"decoder.norm_out.bias": "conv_norm_out.bias",
|
||||
"decoder.norm_out.weight": "conv_norm_out.weight",
|
||||
"decoder.up.0.block.0.conv1.bias": "blocks.15.conv1.bias",
|
||||
"decoder.up.0.block.0.conv1.weight": "blocks.15.conv1.weight",
|
||||
"decoder.up.0.block.0.conv2.bias": "blocks.15.conv2.bias",
|
||||
"decoder.up.0.block.0.conv2.weight": "blocks.15.conv2.weight",
|
||||
"decoder.up.0.block.0.nin_shortcut.bias": "blocks.15.conv_shortcut.bias",
|
||||
"decoder.up.0.block.0.nin_shortcut.weight": "blocks.15.conv_shortcut.weight",
|
||||
"decoder.up.0.block.0.norm1.bias": "blocks.15.norm1.bias",
|
||||
"decoder.up.0.block.0.norm1.weight": "blocks.15.norm1.weight",
|
||||
"decoder.up.0.block.0.norm2.bias": "blocks.15.norm2.bias",
|
||||
"decoder.up.0.block.0.norm2.weight": "blocks.15.norm2.weight",
|
||||
"decoder.up.0.block.1.conv1.bias": "blocks.16.conv1.bias",
|
||||
"decoder.up.0.block.1.conv1.weight": "blocks.16.conv1.weight",
|
||||
"decoder.up.0.block.1.conv2.bias": "blocks.16.conv2.bias",
|
||||
"decoder.up.0.block.1.conv2.weight": "blocks.16.conv2.weight",
|
||||
"decoder.up.0.block.1.norm1.bias": "blocks.16.norm1.bias",
|
||||
"decoder.up.0.block.1.norm1.weight": "blocks.16.norm1.weight",
|
||||
"decoder.up.0.block.1.norm2.bias": "blocks.16.norm2.bias",
|
||||
"decoder.up.0.block.1.norm2.weight": "blocks.16.norm2.weight",
|
||||
"decoder.up.0.block.2.conv1.bias": "blocks.17.conv1.bias",
|
||||
"decoder.up.0.block.2.conv1.weight": "blocks.17.conv1.weight",
|
||||
"decoder.up.0.block.2.conv2.bias": "blocks.17.conv2.bias",
|
||||
"decoder.up.0.block.2.conv2.weight": "blocks.17.conv2.weight",
|
||||
"decoder.up.0.block.2.norm1.bias": "blocks.17.norm1.bias",
|
||||
"decoder.up.0.block.2.norm1.weight": "blocks.17.norm1.weight",
|
||||
"decoder.up.0.block.2.norm2.bias": "blocks.17.norm2.bias",
|
||||
"decoder.up.0.block.2.norm2.weight": "blocks.17.norm2.weight",
|
||||
"decoder.up.1.block.0.conv1.bias": "blocks.11.conv1.bias",
|
||||
"decoder.up.1.block.0.conv1.weight": "blocks.11.conv1.weight",
|
||||
"decoder.up.1.block.0.conv2.bias": "blocks.11.conv2.bias",
|
||||
"decoder.up.1.block.0.conv2.weight": "blocks.11.conv2.weight",
|
||||
"decoder.up.1.block.0.nin_shortcut.bias": "blocks.11.conv_shortcut.bias",
|
||||
"decoder.up.1.block.0.nin_shortcut.weight": "blocks.11.conv_shortcut.weight",
|
||||
"decoder.up.1.block.0.norm1.bias": "blocks.11.norm1.bias",
|
||||
"decoder.up.1.block.0.norm1.weight": "blocks.11.norm1.weight",
|
||||
"decoder.up.1.block.0.norm2.bias": "blocks.11.norm2.bias",
|
||||
"decoder.up.1.block.0.norm2.weight": "blocks.11.norm2.weight",
|
||||
"decoder.up.1.block.1.conv1.bias": "blocks.12.conv1.bias",
|
||||
"decoder.up.1.block.1.conv1.weight": "blocks.12.conv1.weight",
|
||||
"decoder.up.1.block.1.conv2.bias": "blocks.12.conv2.bias",
|
||||
"decoder.up.1.block.1.conv2.weight": "blocks.12.conv2.weight",
|
||||
"decoder.up.1.block.1.norm1.bias": "blocks.12.norm1.bias",
|
||||
"decoder.up.1.block.1.norm1.weight": "blocks.12.norm1.weight",
|
||||
"decoder.up.1.block.1.norm2.bias": "blocks.12.norm2.bias",
|
||||
"decoder.up.1.block.1.norm2.weight": "blocks.12.norm2.weight",
|
||||
"decoder.up.1.block.2.conv1.bias": "blocks.13.conv1.bias",
|
||||
"decoder.up.1.block.2.conv1.weight": "blocks.13.conv1.weight",
|
||||
"decoder.up.1.block.2.conv2.bias": "blocks.13.conv2.bias",
|
||||
"decoder.up.1.block.2.conv2.weight": "blocks.13.conv2.weight",
|
||||
"decoder.up.1.block.2.norm1.bias": "blocks.13.norm1.bias",
|
||||
"decoder.up.1.block.2.norm1.weight": "blocks.13.norm1.weight",
|
||||
"decoder.up.1.block.2.norm2.bias": "blocks.13.norm2.bias",
|
||||
"decoder.up.1.block.2.norm2.weight": "blocks.13.norm2.weight",
|
||||
"decoder.up.1.upsample.conv.bias": "blocks.14.conv.bias",
|
||||
"decoder.up.1.upsample.conv.weight": "blocks.14.conv.weight",
|
||||
"decoder.up.2.block.0.conv1.bias": "blocks.7.conv1.bias",
|
||||
"decoder.up.2.block.0.conv1.weight": "blocks.7.conv1.weight",
|
||||
"decoder.up.2.block.0.conv2.bias": "blocks.7.conv2.bias",
|
||||
"decoder.up.2.block.0.conv2.weight": "blocks.7.conv2.weight",
|
||||
"decoder.up.2.block.0.norm1.bias": "blocks.7.norm1.bias",
|
||||
"decoder.up.2.block.0.norm1.weight": "blocks.7.norm1.weight",
|
||||
"decoder.up.2.block.0.norm2.bias": "blocks.7.norm2.bias",
|
||||
"decoder.up.2.block.0.norm2.weight": "blocks.7.norm2.weight",
|
||||
"decoder.up.2.block.1.conv1.bias": "blocks.8.conv1.bias",
|
||||
"decoder.up.2.block.1.conv1.weight": "blocks.8.conv1.weight",
|
||||
"decoder.up.2.block.1.conv2.bias": "blocks.8.conv2.bias",
|
||||
"decoder.up.2.block.1.conv2.weight": "blocks.8.conv2.weight",
|
||||
"decoder.up.2.block.1.norm1.bias": "blocks.8.norm1.bias",
|
||||
"decoder.up.2.block.1.norm1.weight": "blocks.8.norm1.weight",
|
||||
"decoder.up.2.block.1.norm2.bias": "blocks.8.norm2.bias",
|
||||
"decoder.up.2.block.1.norm2.weight": "blocks.8.norm2.weight",
|
||||
"decoder.up.2.block.2.conv1.bias": "blocks.9.conv1.bias",
|
||||
"decoder.up.2.block.2.conv1.weight": "blocks.9.conv1.weight",
|
||||
"decoder.up.2.block.2.conv2.bias": "blocks.9.conv2.bias",
|
||||
"decoder.up.2.block.2.conv2.weight": "blocks.9.conv2.weight",
|
||||
"decoder.up.2.block.2.norm1.bias": "blocks.9.norm1.bias",
|
||||
"decoder.up.2.block.2.norm1.weight": "blocks.9.norm1.weight",
|
||||
"decoder.up.2.block.2.norm2.bias": "blocks.9.norm2.bias",
|
||||
"decoder.up.2.block.2.norm2.weight": "blocks.9.norm2.weight",
|
||||
"decoder.up.2.upsample.conv.bias": "blocks.10.conv.bias",
|
||||
"decoder.up.2.upsample.conv.weight": "blocks.10.conv.weight",
|
||||
"decoder.up.3.block.0.conv1.bias": "blocks.3.conv1.bias",
|
||||
"decoder.up.3.block.0.conv1.weight": "blocks.3.conv1.weight",
|
||||
"decoder.up.3.block.0.conv2.bias": "blocks.3.conv2.bias",
|
||||
"decoder.up.3.block.0.conv2.weight": "blocks.3.conv2.weight",
|
||||
"decoder.up.3.block.0.norm1.bias": "blocks.3.norm1.bias",
|
||||
"decoder.up.3.block.0.norm1.weight": "blocks.3.norm1.weight",
|
||||
"decoder.up.3.block.0.norm2.bias": "blocks.3.norm2.bias",
|
||||
"decoder.up.3.block.0.norm2.weight": "blocks.3.norm2.weight",
|
||||
"decoder.up.3.block.1.conv1.bias": "blocks.4.conv1.bias",
|
||||
"decoder.up.3.block.1.conv1.weight": "blocks.4.conv1.weight",
|
||||
"decoder.up.3.block.1.conv2.bias": "blocks.4.conv2.bias",
|
||||
"decoder.up.3.block.1.conv2.weight": "blocks.4.conv2.weight",
|
||||
"decoder.up.3.block.1.norm1.bias": "blocks.4.norm1.bias",
|
||||
"decoder.up.3.block.1.norm1.weight": "blocks.4.norm1.weight",
|
||||
"decoder.up.3.block.1.norm2.bias": "blocks.4.norm2.bias",
|
||||
"decoder.up.3.block.1.norm2.weight": "blocks.4.norm2.weight",
|
||||
"decoder.up.3.block.2.conv1.bias": "blocks.5.conv1.bias",
|
||||
"decoder.up.3.block.2.conv1.weight": "blocks.5.conv1.weight",
|
||||
"decoder.up.3.block.2.conv2.bias": "blocks.5.conv2.bias",
|
||||
"decoder.up.3.block.2.conv2.weight": "blocks.5.conv2.weight",
|
||||
"decoder.up.3.block.2.norm1.bias": "blocks.5.norm1.bias",
|
||||
"decoder.up.3.block.2.norm1.weight": "blocks.5.norm1.weight",
|
||||
"decoder.up.3.block.2.norm2.bias": "blocks.5.norm2.bias",
|
||||
"decoder.up.3.block.2.norm2.weight": "blocks.5.norm2.weight",
|
||||
"decoder.up.3.upsample.conv.bias": "blocks.6.conv.bias",
|
||||
"decoder.up.3.upsample.conv.weight": "blocks.6.conv.weight",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if "transformer_blocks" in rename_dict[name]:
|
||||
param = param.squeeze()
|
||||
state_dict_[rename_dict[name]] = param
|
||||
return state_dict_
|
||||
@@ -1,5 +1,4 @@
|
||||
from .attention import Attention
|
||||
from .tiler import TileWorker
|
||||
from einops import repeat, rearrange
|
||||
import math
|
||||
import torch
|
||||
@@ -399,7 +398,8 @@ class HunyuanDiT(torch.nn.Module):
|
||||
hidden_states, _ = hidden_states.chunk(2, dim=1)
|
||||
return hidden_states
|
||||
|
||||
def state_dict_converter(self):
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return HunyuanDiTStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -79,7 +79,8 @@ class HunyuanDiTCLIPTextEncoder(BertModel):
|
||||
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
|
||||
return prompt_emb
|
||||
|
||||
def state_dict_converter(self):
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return HunyuanDiTCLIPTextEncoderStateDictConverter()
|
||||
|
||||
|
||||
@@ -131,7 +132,8 @@ class HunyuanDiTT5TextEncoder(T5EncoderModel):
|
||||
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
|
||||
return prompt_emb
|
||||
|
||||
def state_dict_converter(self):
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return HunyuanDiTT5TextEncoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
920
diffsynth/models/hunyuan_video_dit.py
Normal file
920
diffsynth/models/hunyuan_video_dit.py
Normal file
@@ -0,0 +1,920 @@
|
||||
import torch
|
||||
from .sd3_dit import TimestepEmbeddings, RMSNorm
|
||||
from .utils import init_weights_on_device
|
||||
from einops import rearrange, repeat
|
||||
from tqdm import tqdm
|
||||
from typing import Union, Tuple, List
|
||||
from .utils import hash_state_dict_keys
|
||||
|
||||
|
||||
def HunyuanVideoRope(latents):
|
||||
def _to_tuple(x, dim=2):
|
||||
if isinstance(x, int):
|
||||
return (x,) * dim
|
||||
elif len(x) == dim:
|
||||
return x
|
||||
else:
|
||||
raise ValueError(f"Expected length {dim} or int, but got {x}")
|
||||
|
||||
|
||||
def get_meshgrid_nd(start, *args, dim=2):
|
||||
"""
|
||||
Get n-D meshgrid with start, stop and num.
|
||||
|
||||
Args:
|
||||
start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
|
||||
step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
|
||||
should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
|
||||
n-tuples.
|
||||
*args: See above.
|
||||
dim (int): Dimension of the meshgrid. Defaults to 2.
|
||||
|
||||
Returns:
|
||||
grid (np.ndarray): [dim, ...]
|
||||
"""
|
||||
if len(args) == 0:
|
||||
# start is grid_size
|
||||
num = _to_tuple(start, dim=dim)
|
||||
start = (0,) * dim
|
||||
stop = num
|
||||
elif len(args) == 1:
|
||||
# start is start, args[0] is stop, step is 1
|
||||
start = _to_tuple(start, dim=dim)
|
||||
stop = _to_tuple(args[0], dim=dim)
|
||||
num = [stop[i] - start[i] for i in range(dim)]
|
||||
elif len(args) == 2:
|
||||
# start is start, args[0] is stop, args[1] is num
|
||||
start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
|
||||
stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
|
||||
num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
|
||||
else:
|
||||
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
|
||||
|
||||
# PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
|
||||
axis_grid = []
|
||||
for i in range(dim):
|
||||
a, b, n = start[i], stop[i], num[i]
|
||||
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
|
||||
axis_grid.append(g)
|
||||
grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
|
||||
grid = torch.stack(grid, dim=0) # [dim, W, H, D]
|
||||
|
||||
return grid
|
||||
|
||||
|
||||
def get_1d_rotary_pos_embed(
|
||||
dim: int,
|
||||
pos: Union[torch.FloatTensor, int],
|
||||
theta: float = 10000.0,
|
||||
use_real: bool = False,
|
||||
theta_rescale_factor: float = 1.0,
|
||||
interpolation_factor: float = 1.0,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Precompute the frequency tensor for complex exponential (cis) with given dimensions.
|
||||
(Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
|
||||
|
||||
This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
|
||||
and the end index 'end'. The 'theta' parameter scales the frequencies.
|
||||
The returned tensor contains complex values in complex64 data type.
|
||||
|
||||
Args:
|
||||
dim (int): Dimension of the frequency tensor.
|
||||
pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
|
||||
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
|
||||
use_real (bool, optional): If True, return real part and imaginary part separately.
|
||||
Otherwise, return complex numbers.
|
||||
theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
|
||||
|
||||
Returns:
|
||||
freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
|
||||
freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
|
||||
"""
|
||||
if isinstance(pos, int):
|
||||
pos = torch.arange(pos).float()
|
||||
|
||||
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
||||
# has some connection to NTK literature
|
||||
if theta_rescale_factor != 1.0:
|
||||
theta *= theta_rescale_factor ** (dim / (dim - 2))
|
||||
|
||||
freqs = 1.0 / (
|
||||
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
|
||||
) # [D/2]
|
||||
# assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
|
||||
freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
|
||||
if use_real:
|
||||
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
|
||||
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
|
||||
return freqs_cos, freqs_sin
|
||||
else:
|
||||
freqs_cis = torch.polar(
|
||||
torch.ones_like(freqs), freqs
|
||||
) # complex64 # [S, D/2]
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def get_nd_rotary_pos_embed(
|
||||
rope_dim_list,
|
||||
start,
|
||||
*args,
|
||||
theta=10000.0,
|
||||
use_real=False,
|
||||
theta_rescale_factor: Union[float, List[float]] = 1.0,
|
||||
interpolation_factor: Union[float, List[float]] = 1.0,
|
||||
):
|
||||
"""
|
||||
This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
|
||||
|
||||
Args:
|
||||
rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
|
||||
sum(rope_dim_list) should equal to head_dim of attention layer.
|
||||
start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
|
||||
args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
|
||||
*args: See above.
|
||||
theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
|
||||
use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
||||
Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real
|
||||
part and an imaginary part separately.
|
||||
theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
|
||||
|
||||
Returns:
|
||||
pos_embed (torch.Tensor): [HW, D/2]
|
||||
"""
|
||||
|
||||
grid = get_meshgrid_nd(
|
||||
start, *args, dim=len(rope_dim_list)
|
||||
) # [3, W, H, D] / [2, W, H]
|
||||
|
||||
if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
|
||||
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
|
||||
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
|
||||
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
|
||||
assert len(theta_rescale_factor) == len(
|
||||
rope_dim_list
|
||||
), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
|
||||
|
||||
if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
|
||||
interpolation_factor = [interpolation_factor] * len(rope_dim_list)
|
||||
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
|
||||
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
|
||||
assert len(interpolation_factor) == len(
|
||||
rope_dim_list
|
||||
), "len(interpolation_factor) should equal to len(rope_dim_list)"
|
||||
|
||||
# use 1/ndim of dimensions to encode grid_axis
|
||||
embs = []
|
||||
for i in range(len(rope_dim_list)):
|
||||
emb = get_1d_rotary_pos_embed(
|
||||
rope_dim_list[i],
|
||||
grid[i].reshape(-1),
|
||||
theta,
|
||||
use_real=use_real,
|
||||
theta_rescale_factor=theta_rescale_factor[i],
|
||||
interpolation_factor=interpolation_factor[i],
|
||||
) # 2 x [WHD, rope_dim_list[i]]
|
||||
embs.append(emb)
|
||||
|
||||
if use_real:
|
||||
cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
|
||||
sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
|
||||
return cos, sin
|
||||
else:
|
||||
emb = torch.cat(embs, dim=1) # (WHD, D/2)
|
||||
return emb
|
||||
|
||||
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
|
||||
[16, 56, 56],
|
||||
[latents.shape[2], latents.shape[3] // 2, latents.shape[4] // 2],
|
||||
theta=256,
|
||||
use_real=True,
|
||||
theta_rescale_factor=1,
|
||||
)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
class PatchEmbed(torch.nn.Module):
|
||||
def __init__(self, patch_size=(1, 2, 2), in_channels=16, embed_dim=3072):
|
||||
super().__init__()
|
||||
self.proj = torch.nn.Conv3d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class IndividualTokenRefinerBlock(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, num_heads=24):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.norm1 = torch.nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
|
||||
self.self_attn_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
|
||||
self.self_attn_proj = torch.nn.Linear(hidden_size, hidden_size)
|
||||
|
||||
self.norm2 = torch.nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
|
||||
self.mlp = torch.nn.Sequential(
|
||||
torch.nn.Linear(hidden_size, hidden_size * 4),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size * 4, hidden_size)
|
||||
)
|
||||
self.adaLN_modulation = torch.nn.Sequential(
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size, hidden_size * 2, device="cuda", dtype=torch.bfloat16),
|
||||
)
|
||||
|
||||
def forward(self, x, c, attn_mask=None):
|
||||
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
|
||||
norm_x = self.norm1(x)
|
||||
qkv = self.self_attn_qkv(norm_x)
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
|
||||
attn = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
attn = rearrange(attn, "B H L D -> B L (H D)")
|
||||
|
||||
x = x + self.self_attn_proj(attn) * gate_msa.unsqueeze(1)
|
||||
x = x + self.mlp(self.norm2(x)) * gate_mlp.unsqueeze(1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SingleTokenRefiner(torch.nn.Module):
|
||||
def __init__(self, in_channels=4096, hidden_size=3072, depth=2):
|
||||
super().__init__()
|
||||
self.input_embedder = torch.nn.Linear(in_channels, hidden_size, bias=True)
|
||||
self.t_embedder = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
|
||||
self.c_embedder = torch.nn.Sequential(
|
||||
torch.nn.Linear(in_channels, hidden_size),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size, hidden_size)
|
||||
)
|
||||
self.blocks = torch.nn.ModuleList([IndividualTokenRefinerBlock(hidden_size=hidden_size) for _ in range(depth)])
|
||||
|
||||
def forward(self, x, t, mask=None):
|
||||
timestep_aware_representations = self.t_embedder(t, dtype=torch.float32)
|
||||
|
||||
mask_float = mask.float().unsqueeze(-1)
|
||||
context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
|
||||
context_aware_representations = self.c_embedder(context_aware_representations)
|
||||
c = timestep_aware_representations + context_aware_representations
|
||||
|
||||
x = self.input_embedder(x)
|
||||
|
||||
mask = mask.to(device=x.device, dtype=torch.bool)
|
||||
mask = repeat(mask, "B L -> B 1 D L", D=mask.shape[-1])
|
||||
mask = mask & mask.transpose(2, 3)
|
||||
mask[:, :, :, 0] = True
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, c, mask)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ModulateDiT(torch.nn.Module):
|
||||
def __init__(self, hidden_size, factor=6):
|
||||
super().__init__()
|
||||
self.act = torch.nn.SiLU()
|
||||
self.linear = torch.nn.Linear(hidden_size, factor * hidden_size)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(self.act(x))
|
||||
|
||||
|
||||
def modulate(x, shift=None, scale=None, tr_shift=None, tr_scale=None, tr_token=None):
|
||||
if tr_shift is not None:
|
||||
x_zero = x[:, :tr_token] * (1 + tr_scale.unsqueeze(1)) + tr_shift.unsqueeze(1)
|
||||
x_orig = x[:, tr_token:] * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
x = torch.concat((x_zero, x_orig), dim=1)
|
||||
return x
|
||||
if scale is None and shift is None:
|
||||
return x
|
||||
elif shift is None:
|
||||
return x * (1 + scale.unsqueeze(1))
|
||||
elif scale is None:
|
||||
return x + shift.unsqueeze(1)
|
||||
else:
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
def reshape_for_broadcast(
|
||||
freqs_cis,
|
||||
x: torch.Tensor,
|
||||
head_first=False,
|
||||
):
|
||||
ndim = x.ndim
|
||||
assert 0 <= 1 < ndim
|
||||
|
||||
if isinstance(freqs_cis, tuple):
|
||||
# freqs_cis: (cos, sin) in real space
|
||||
if head_first:
|
||||
assert freqs_cis[0].shape == (
|
||||
x.shape[-2],
|
||||
x.shape[-1],
|
||||
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
|
||||
shape = [
|
||||
d if i == ndim - 2 or i == ndim - 1 else 1
|
||||
for i, d in enumerate(x.shape)
|
||||
]
|
||||
else:
|
||||
assert freqs_cis[0].shape == (
|
||||
x.shape[1],
|
||||
x.shape[-1],
|
||||
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
|
||||
else:
|
||||
# freqs_cis: values in complex space
|
||||
if head_first:
|
||||
assert freqs_cis.shape == (
|
||||
x.shape[-2],
|
||||
x.shape[-1],
|
||||
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
|
||||
shape = [
|
||||
d if i == ndim - 2 or i == ndim - 1 else 1
|
||||
for i, d in enumerate(x.shape)
|
||||
]
|
||||
else:
|
||||
assert freqs_cis.shape == (
|
||||
x.shape[1],
|
||||
x.shape[-1],
|
||||
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis.view(*shape)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x_real, x_imag = (
|
||||
x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
|
||||
) # [B, S, H, D//2]
|
||||
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis,
|
||||
head_first: bool = False,
|
||||
):
|
||||
xk_out = None
|
||||
if isinstance(freqs_cis, tuple):
|
||||
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
|
||||
cos, sin = cos.to(xq.device), sin.to(xq.device)
|
||||
# real * cos - imag * sin
|
||||
# imag * cos + real * sin
|
||||
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
|
||||
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
|
||||
else:
|
||||
# view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
|
||||
xq_ = torch.view_as_complex(
|
||||
xq.float().reshape(*xq.shape[:-1], -1, 2)
|
||||
) # [B, S, H, D//2]
|
||||
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
|
||||
xq.device
|
||||
) # [S, D//2] --> [1, S, 1, D//2]
|
||||
# (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
|
||||
# view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
|
||||
xk_ = torch.view_as_complex(
|
||||
xk.float().reshape(*xk.shape[:-1], -1, 2)
|
||||
) # [B, S, H, D//2]
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
|
||||
|
||||
return xq_out, xk_out
|
||||
|
||||
|
||||
def attention(q, k, v):
|
||||
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
|
||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
x = x.transpose(1, 2).flatten(2, 3)
|
||||
return x
|
||||
|
||||
|
||||
def apply_gate(x, gate, tr_gate=None, tr_token=None):
|
||||
if tr_gate is not None:
|
||||
x_zero = x[:, :tr_token] * tr_gate.unsqueeze(1)
|
||||
x_orig = x[:, tr_token:] * gate.unsqueeze(1)
|
||||
return torch.concat((x_zero, x_orig), dim=1)
|
||||
else:
|
||||
return x * gate.unsqueeze(1)
|
||||
|
||||
|
||||
class MMDoubleStreamBlockComponent(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
||||
super().__init__()
|
||||
self.heads_num = heads_num
|
||||
|
||||
self.mod = ModulateDiT(hidden_size)
|
||||
self.norm1 = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
self.to_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
|
||||
self.norm_q = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
||||
self.norm_k = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
||||
self.to_out = torch.nn.Linear(hidden_size, hidden_size)
|
||||
|
||||
self.norm2 = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.ff = torch.nn.Sequential(
|
||||
torch.nn.Linear(hidden_size, hidden_size * mlp_width_ratio),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size)
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, conditioning, freqs_cis=None, token_replace_vec=None, tr_token=None):
|
||||
mod1_shift, mod1_scale, mod1_gate, mod2_shift, mod2_scale, mod2_gate = self.mod(conditioning).chunk(6, dim=-1)
|
||||
if token_replace_vec is not None:
|
||||
assert tr_token is not None
|
||||
tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = self.mod(token_replace_vec).chunk(6, dim=-1)
|
||||
else:
|
||||
tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None, None, None
|
||||
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale,
|
||||
tr_shift=tr_mod1_shift, tr_scale=tr_mod1_scale, tr_token=tr_token)
|
||||
qkv = self.to_qkv(norm_hidden_states)
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
||||
|
||||
q = self.norm_q(q)
|
||||
k = self.norm_k(k)
|
||||
|
||||
if freqs_cis is not None:
|
||||
q, k = apply_rotary_emb(q, k, freqs_cis, head_first=False)
|
||||
return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate), (tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate)
|
||||
|
||||
def process_ff(self, hidden_states, attn_output, mod, mod_tr=None, tr_token=None):
|
||||
mod1_gate, mod2_shift, mod2_scale, mod2_gate = mod
|
||||
if mod_tr is not None:
|
||||
tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = mod_tr
|
||||
else:
|
||||
tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None
|
||||
hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod1_gate, tr_mod1_gate, tr_token)
|
||||
x = self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale, tr_shift=tr_mod2_shift, tr_scale=tr_mod2_scale, tr_token=tr_token))
|
||||
hidden_states = hidden_states + apply_gate(x, mod2_gate, tr_mod2_gate, tr_token)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MMDoubleStreamBlock(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
||||
super().__init__()
|
||||
self.component_a = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
|
||||
self.component_b = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis, token_replace_vec=None, tr_token=None, split_token=71):
|
||||
(q_a, k_a, v_a), mod_a, mod_tr = self.component_a(hidden_states_a, conditioning, freqs_cis, token_replace_vec, tr_token)
|
||||
(q_b, k_b, v_b), mod_b, _ = self.component_b(hidden_states_b, conditioning, freqs_cis=None)
|
||||
|
||||
q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous()
|
||||
k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous()
|
||||
v_a, v_b = torch.concat([v_a, v_b[:, :split_token]], dim=1), v_b[:, split_token:].contiguous()
|
||||
attn_output_a = attention(q_a, k_a, v_a)
|
||||
attn_output_b = attention(q_b, k_b, v_b)
|
||||
attn_output_a, attn_output_b = attn_output_a[:, :-split_token].contiguous(), torch.concat([attn_output_a[:, -split_token:], attn_output_b], dim=1)
|
||||
|
||||
hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a, mod_tr, tr_token)
|
||||
hidden_states_b = self.component_b.process_ff(hidden_states_b, attn_output_b, mod_b)
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
class MMSingleStreamBlockOriginal(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.heads_num = heads_num
|
||||
self.mlp_hidden_dim = hidden_size * mlp_width_ratio
|
||||
|
||||
self.linear1 = torch.nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
||||
self.linear2 = torch.nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
||||
|
||||
self.q_norm = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
||||
self.k_norm = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
||||
|
||||
self.pre_norm = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
self.mlp_act = torch.nn.GELU(approximate="tanh")
|
||||
self.modulation = ModulateDiT(hidden_size, factor=3)
|
||||
|
||||
def forward(self, x, vec, freqs_cis=None, txt_len=256):
|
||||
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
|
||||
x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
q_a, q_b = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
|
||||
k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
|
||||
q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)
|
||||
q = torch.cat((q_a, q_b), dim=1)
|
||||
k = torch.cat((k_a, k_b), dim=1)
|
||||
|
||||
attn_output_a = attention(q[:, :-185].contiguous(), k[:, :-185].contiguous(), v[:, :-185].contiguous())
|
||||
attn_output_b = attention(q[:, -185:].contiguous(), k[:, -185:].contiguous(), v[:, -185:].contiguous())
|
||||
attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)
|
||||
|
||||
output = self.linear2(torch.cat((attn_output, self.mlp_act(mlp)), 2))
|
||||
return x + output * mod_gate.unsqueeze(1)
|
||||
|
||||
|
||||
class MMSingleStreamBlock(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
||||
super().__init__()
|
||||
self.heads_num = heads_num
|
||||
|
||||
self.mod = ModulateDiT(hidden_size, factor=3)
|
||||
self.norm = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
self.to_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
|
||||
self.norm_q = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
||||
self.norm_k = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
||||
self.to_out = torch.nn.Linear(hidden_size, hidden_size)
|
||||
|
||||
self.ff = torch.nn.Sequential(
|
||||
torch.nn.Linear(hidden_size, hidden_size * mlp_width_ratio),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size, bias=False)
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256, token_replace_vec=None, tr_token=None, split_token=71):
|
||||
mod_shift, mod_scale, mod_gate = self.mod(conditioning).chunk(3, dim=-1)
|
||||
if token_replace_vec is not None:
|
||||
assert tr_token is not None
|
||||
tr_mod_shift, tr_mod_scale, tr_mod_gate = self.mod(token_replace_vec).chunk(3, dim=-1)
|
||||
else:
|
||||
tr_mod_shift, tr_mod_scale, tr_mod_gate = None, None, None
|
||||
|
||||
norm_hidden_states = self.norm(hidden_states)
|
||||
norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale,
|
||||
tr_shift=tr_mod_shift, tr_scale=tr_mod_scale, tr_token=tr_token)
|
||||
qkv = self.to_qkv(norm_hidden_states)
|
||||
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
||||
|
||||
q = self.norm_q(q)
|
||||
k = self.norm_k(k)
|
||||
|
||||
q_a, q_b = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
|
||||
k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
|
||||
q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)
|
||||
|
||||
v_len = txt_len - split_token
|
||||
q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous()
|
||||
k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous()
|
||||
v_a, v_b = v[:, :-v_len].contiguous(), v[:, -v_len:].contiguous()
|
||||
|
||||
attn_output_a = attention(q_a, k_a, v_a)
|
||||
attn_output_b = attention(q_b, k_b, v_b)
|
||||
attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)
|
||||
|
||||
hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod_gate, tr_mod_gate, tr_token)
|
||||
hidden_states = hidden_states + apply_gate(self.ff(norm_hidden_states), mod_gate, tr_mod_gate, tr_token)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FinalLayer(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, patch_size=(1, 2, 2), out_channels=16):
|
||||
super().__init__()
|
||||
|
||||
self.norm_final = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = torch.nn.Linear(hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels)
|
||||
|
||||
self.adaLN_modulation = torch.nn.Sequential(torch.nn.SiLU(), torch.nn.Linear(hidden_size, 2 * hidden_size))
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift=shift, scale=scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class HunyuanVideoDiT(torch.nn.Module):
|
||||
def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40, guidance_embed=True):
|
||||
super().__init__()
|
||||
self.img_in = PatchEmbed(in_channels=in_channels, embed_dim=hidden_size)
|
||||
self.txt_in = SingleTokenRefiner(in_channels=text_dim, hidden_size=hidden_size)
|
||||
self.time_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
|
||||
self.vector_in = torch.nn.Sequential(
|
||||
torch.nn.Linear(768, hidden_size),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size, hidden_size)
|
||||
)
|
||||
self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu") if guidance_embed else None
|
||||
self.double_blocks = torch.nn.ModuleList([MMDoubleStreamBlock(hidden_size) for _ in range(num_double_blocks)])
|
||||
self.single_blocks = torch.nn.ModuleList([MMSingleStreamBlock(hidden_size) for _ in range(num_single_blocks)])
|
||||
self.final_layer = FinalLayer(hidden_size)
|
||||
|
||||
# TODO: remove these parameters
|
||||
self.dtype = torch.bfloat16
|
||||
self.patch_size = [1, 2, 2]
|
||||
self.hidden_size = 3072
|
||||
self.heads_num = 24
|
||||
self.rope_dim_list = [16, 56, 56]
|
||||
|
||||
def unpatchify(self, x, T, H, W):
|
||||
x = rearrange(x, "B (T H W) (C pT pH pW) -> B C (T pT) (H pH) (W pW)", H=H, W=W, pT=1, pH=2, pW=2)
|
||||
return x
|
||||
|
||||
def enable_block_wise_offload(self, warm_device="cuda", cold_device="cpu"):
|
||||
self.warm_device = warm_device
|
||||
self.cold_device = cold_device
|
||||
self.to(self.cold_device)
|
||||
|
||||
def load_models_to_device(self, loadmodel_names=[], device="cpu"):
|
||||
for model_name in loadmodel_names:
|
||||
model = getattr(self, model_name)
|
||||
if model is not None:
|
||||
model.to(device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def prepare_freqs(self, latents):
|
||||
return HunyuanVideoRope(latents)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
prompt_emb: torch.Tensor = None,
|
||||
text_mask: torch.Tensor = None,
|
||||
pooled_prompt_emb: torch.Tensor = None,
|
||||
freqs_cos: torch.Tensor = None,
|
||||
freqs_sin: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
**kwargs
|
||||
):
|
||||
B, C, T, H, W = x.shape
|
||||
|
||||
vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb)
|
||||
if self.guidance_in is not None:
|
||||
vec += self.guidance_in(guidance * 1000, dtype=torch.float32)
|
||||
img = self.img_in(x)
|
||||
txt = self.txt_in(prompt_emb, t, text_mask)
|
||||
|
||||
for block in tqdm(self.double_blocks, desc="Double stream blocks"):
|
||||
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
|
||||
|
||||
x = torch.concat([img, txt], dim=1)
|
||||
for block in tqdm(self.single_blocks, desc="Single stream blocks"):
|
||||
x = block(x, vec, (freqs_cos, freqs_sin))
|
||||
|
||||
img = x[:, :-256]
|
||||
img = self.final_layer(img, vec)
|
||||
img = self.unpatchify(img, T=T//1, H=H//2, W=W//2)
|
||||
return img
|
||||
|
||||
|
||||
def enable_auto_offload(self, dtype=torch.bfloat16, device="cuda"):
|
||||
def cast_to(weight, dtype=None, device=None, copy=False):
|
||||
if device is None or weight.device == device:
|
||||
if not copy:
|
||||
if dtype is None or weight.dtype == dtype:
|
||||
return weight
|
||||
return weight.to(dtype=dtype, copy=copy)
|
||||
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight)
|
||||
return r
|
||||
|
||||
def cast_weight(s, input=None, dtype=None, device=None):
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
if device is None:
|
||||
device = input.device
|
||||
weight = cast_to(s.weight, dtype, device)
|
||||
return weight
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
if bias_dtype is None:
|
||||
bias_dtype = dtype
|
||||
if device is None:
|
||||
device = input.device
|
||||
weight = cast_to(s.weight, dtype, device)
|
||||
bias = cast_to(s.bias, bias_dtype, device) if s.bias is not None else None
|
||||
return weight, bias
|
||||
|
||||
class quantized_layer:
|
||||
class Linear(torch.nn.Linear):
|
||||
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
def block_forward_(self, x, i, j, dtype, device):
|
||||
weight_ = cast_to(
|
||||
self.weight[j * self.block_size: (j + 1) * self.block_size, i * self.block_size: (i + 1) * self.block_size],
|
||||
dtype=dtype, device=device
|
||||
)
|
||||
if self.bias is None or i > 0:
|
||||
bias_ = None
|
||||
else:
|
||||
bias_ = cast_to(self.bias[j * self.block_size: (j + 1) * self.block_size], dtype=dtype, device=device)
|
||||
x_ = x[..., i * self.block_size: (i + 1) * self.block_size]
|
||||
y_ = torch.nn.functional.linear(x_, weight_, bias_)
|
||||
del x_, weight_, bias_
|
||||
torch.cuda.empty_cache()
|
||||
return y_
|
||||
|
||||
def block_forward(self, x, **kwargs):
|
||||
# This feature can only reduce 2GB VRAM, so we disable it.
|
||||
y = torch.zeros(x.shape[:-1] + (self.out_features,), dtype=x.dtype, device=x.device)
|
||||
for i in range((self.in_features + self.block_size - 1) // self.block_size):
|
||||
for j in range((self.out_features + self.block_size - 1) // self.block_size):
|
||||
y[..., j * self.block_size: (j + 1) * self.block_size] += self.block_forward_(x, i, j, dtype=x.dtype, device=x.device)
|
||||
return y
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
||||
return torch.nn.functional.linear(x, weight, bias)
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, module, dtype=torch.bfloat16, device="cuda"):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
def forward(self, hidden_states, **kwargs):
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
|
||||
hidden_states = hidden_states.to(input_dtype)
|
||||
if self.module.weight is not None:
|
||||
weight = cast_weight(self.module, hidden_states, dtype=torch.bfloat16, device="cuda")
|
||||
hidden_states = hidden_states * weight
|
||||
return hidden_states
|
||||
|
||||
class Conv3d(torch.nn.Conv3d):
|
||||
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
def forward(self, x):
|
||||
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
||||
return torch.nn.functional.conv3d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
class LayerNorm(torch.nn.LayerNorm):
|
||||
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
def forward(self, x):
|
||||
if self.weight is not None and self.bias is not None:
|
||||
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
||||
return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)
|
||||
else:
|
||||
return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||
|
||||
def replace_layer(model, dtype=torch.bfloat16, device="cuda"):
|
||||
for name, module in model.named_children():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
with init_weights_on_device():
|
||||
new_layer = quantized_layer.Linear(
|
||||
module.in_features, module.out_features, bias=module.bias is not None,
|
||||
dtype=dtype, device=device
|
||||
)
|
||||
new_layer.load_state_dict(module.state_dict(), assign=True)
|
||||
setattr(model, name, new_layer)
|
||||
elif isinstance(module, torch.nn.Conv3d):
|
||||
with init_weights_on_device():
|
||||
new_layer = quantized_layer.Conv3d(
|
||||
module.in_channels, module.out_channels, kernel_size=module.kernel_size, stride=module.stride,
|
||||
dtype=dtype, device=device
|
||||
)
|
||||
new_layer.load_state_dict(module.state_dict(), assign=True)
|
||||
setattr(model, name, new_layer)
|
||||
elif isinstance(module, RMSNorm):
|
||||
new_layer = quantized_layer.RMSNorm(
|
||||
module,
|
||||
dtype=dtype, device=device
|
||||
)
|
||||
setattr(model, name, new_layer)
|
||||
elif isinstance(module, torch.nn.LayerNorm):
|
||||
with init_weights_on_device():
|
||||
new_layer = quantized_layer.LayerNorm(
|
||||
module.normalized_shape, elementwise_affine=module.elementwise_affine, eps=module.eps,
|
||||
dtype=dtype, device=device
|
||||
)
|
||||
new_layer.load_state_dict(module.state_dict(), assign=True)
|
||||
setattr(model, name, new_layer)
|
||||
else:
|
||||
replace_layer(module, dtype=dtype, device=device)
|
||||
|
||||
replace_layer(self, dtype=dtype, device=device)
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return HunyuanVideoDiTStateDictConverter()
|
||||
|
||||
|
||||
class HunyuanVideoDiTStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
origin_hash_key = hash_state_dict_keys(state_dict, with_shape=True)
|
||||
if "module" in state_dict:
|
||||
state_dict = state_dict["module"]
|
||||
direct_dict = {
|
||||
"img_in.proj": "img_in.proj",
|
||||
"time_in.mlp.0": "time_in.timestep_embedder.0",
|
||||
"time_in.mlp.2": "time_in.timestep_embedder.2",
|
||||
"vector_in.in_layer": "vector_in.0",
|
||||
"vector_in.out_layer": "vector_in.2",
|
||||
"guidance_in.mlp.0": "guidance_in.timestep_embedder.0",
|
||||
"guidance_in.mlp.2": "guidance_in.timestep_embedder.2",
|
||||
"txt_in.input_embedder": "txt_in.input_embedder",
|
||||
"txt_in.t_embedder.mlp.0": "txt_in.t_embedder.timestep_embedder.0",
|
||||
"txt_in.t_embedder.mlp.2": "txt_in.t_embedder.timestep_embedder.2",
|
||||
"txt_in.c_embedder.linear_1": "txt_in.c_embedder.0",
|
||||
"txt_in.c_embedder.linear_2": "txt_in.c_embedder.2",
|
||||
"final_layer.linear": "final_layer.linear",
|
||||
"final_layer.adaLN_modulation.1": "final_layer.adaLN_modulation.1",
|
||||
}
|
||||
txt_suffix_dict = {
|
||||
"norm1": "norm1",
|
||||
"self_attn_qkv": "self_attn_qkv",
|
||||
"self_attn_proj": "self_attn_proj",
|
||||
"norm2": "norm2",
|
||||
"mlp.fc1": "mlp.0",
|
||||
"mlp.fc2": "mlp.2",
|
||||
"adaLN_modulation.1": "adaLN_modulation.1",
|
||||
}
|
||||
double_suffix_dict = {
|
||||
"img_mod.linear": "component_a.mod.linear",
|
||||
"img_attn_qkv": "component_a.to_qkv",
|
||||
"img_attn_q_norm": "component_a.norm_q",
|
||||
"img_attn_k_norm": "component_a.norm_k",
|
||||
"img_attn_proj": "component_a.to_out",
|
||||
"img_mlp.fc1": "component_a.ff.0",
|
||||
"img_mlp.fc2": "component_a.ff.2",
|
||||
"txt_mod.linear": "component_b.mod.linear",
|
||||
"txt_attn_qkv": "component_b.to_qkv",
|
||||
"txt_attn_q_norm": "component_b.norm_q",
|
||||
"txt_attn_k_norm": "component_b.norm_k",
|
||||
"txt_attn_proj": "component_b.to_out",
|
||||
"txt_mlp.fc1": "component_b.ff.0",
|
||||
"txt_mlp.fc2": "component_b.ff.2",
|
||||
}
|
||||
single_suffix_dict = {
|
||||
"linear1": ["to_qkv", "ff.0"],
|
||||
"linear2": ["to_out", "ff.2"],
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
"modulation.linear": "mod.linear",
|
||||
}
|
||||
# single_suffix_dict = {
|
||||
# "linear1": "linear1",
|
||||
# "linear2": "linear2",
|
||||
# "q_norm": "q_norm",
|
||||
# "k_norm": "k_norm",
|
||||
# "modulation.linear": "modulation.linear",
|
||||
# }
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
names = name.split(".")
|
||||
direct_name = ".".join(names[:-1])
|
||||
if direct_name in direct_dict:
|
||||
name_ = direct_dict[direct_name] + "." + names[-1]
|
||||
state_dict_[name_] = param
|
||||
elif names[0] == "double_blocks":
|
||||
prefix = ".".join(names[:2])
|
||||
suffix = ".".join(names[2:-1])
|
||||
name_ = prefix + "." + double_suffix_dict[suffix] + "." + names[-1]
|
||||
state_dict_[name_] = param
|
||||
elif names[0] == "single_blocks":
|
||||
prefix = ".".join(names[:2])
|
||||
suffix = ".".join(names[2:-1])
|
||||
if isinstance(single_suffix_dict[suffix], list):
|
||||
if suffix == "linear1":
|
||||
name_a, name_b = single_suffix_dict[suffix]
|
||||
param_a, param_b = torch.split(param, (3072*3, 3072*4), dim=0)
|
||||
state_dict_[prefix + "." + name_a + "." + names[-1]] = param_a
|
||||
state_dict_[prefix + "." + name_b + "." + names[-1]] = param_b
|
||||
elif suffix == "linear2":
|
||||
if names[-1] == "weight":
|
||||
name_a, name_b = single_suffix_dict[suffix]
|
||||
param_a, param_b = torch.split(param, (3072*1, 3072*4), dim=-1)
|
||||
state_dict_[prefix + "." + name_a + "." + names[-1]] = param_a
|
||||
state_dict_[prefix + "." + name_b + "." + names[-1]] = param_b
|
||||
else:
|
||||
name_a, name_b = single_suffix_dict[suffix]
|
||||
state_dict_[prefix + "." + name_a + "." + names[-1]] = param
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
name_ = prefix + "." + single_suffix_dict[suffix] + "." + names[-1]
|
||||
state_dict_[name_] = param
|
||||
elif names[0] == "txt_in":
|
||||
prefix = ".".join(names[:4]).replace(".individual_token_refiner.", ".")
|
||||
suffix = ".".join(names[4:-1])
|
||||
name_ = prefix + "." + txt_suffix_dict[suffix] + "." + names[-1]
|
||||
state_dict_[name_] = param
|
||||
else:
|
||||
pass
|
||||
|
||||
return state_dict_
|
||||
68
diffsynth/models/hunyuan_video_text_encoder.py
Normal file
68
diffsynth/models/hunyuan_video_text_encoder.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from transformers import LlamaModel, LlamaConfig, DynamicCache, LlavaForConditionalGeneration
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
|
||||
|
||||
class HunyuanVideoLLMEncoder(LlamaModel):
|
||||
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__(config)
|
||||
self.auto_offload = False
|
||||
|
||||
def enable_auto_offload(self, **kwargs):
|
||||
self.auto_offload = True
|
||||
|
||||
def forward(self, input_ids, attention_mask, hidden_state_skip_layer=2):
|
||||
embed_tokens = deepcopy(self.embed_tokens).to(input_ids.device) if self.auto_offload else self.embed_tokens
|
||||
inputs_embeds = embed_tokens(input_ids)
|
||||
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, None, False)
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
rotary_emb = deepcopy(self.rotary_emb).to(input_ids.device) if self.auto_offload else self.rotary_emb
|
||||
position_embeddings = rotary_emb(hidden_states, position_ids)
|
||||
|
||||
# decoder layers
|
||||
for layer_id, decoder_layer in enumerate(self.layers):
|
||||
if self.auto_offload:
|
||||
decoder_layer = deepcopy(decoder_layer).to(hidden_states.device)
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=False,
|
||||
use_cache=True,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
if layer_id + hidden_state_skip_layer + 1 >= len(self.layers):
|
||||
break
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoMLLMEncoder(LlavaForConditionalGeneration):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.auto_offload = False
|
||||
|
||||
def enable_auto_offload(self, **kwargs):
|
||||
self.auto_offload = True
|
||||
|
||||
# TODO: implement the low VRAM inference for MLLM.
|
||||
def forward(self, input_ids, pixel_values, attention_mask, hidden_state_skip_layer=2):
|
||||
outputs = super().forward(input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
pixel_values=pixel_values)
|
||||
hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
|
||||
return hidden_state
|
||||
507
diffsynth/models/hunyuan_video_vae_decoder.py
Normal file
507
diffsynth/models/hunyuan_video_vae_decoder.py
Normal file
@@ -0,0 +1,507 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from einops import repeat
|
||||
|
||||
|
||||
class CausalConv3d(nn.Module):
|
||||
|
||||
def __init__(self, in_channel, out_channel, kernel_size, stride=1, dilation=1, pad_mode='replicate', **kwargs):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
self.time_causal_padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size - 1, 0
|
||||
) # W, H, T
|
||||
self.conv = nn.Conv3d(in_channel, out_channel, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class UpsampleCausal3D(nn.Module):
|
||||
|
||||
def __init__(self, channels, use_conv=False, out_channels=None, kernel_size=None, bias=True, upsample_factor=(2, 2, 2)):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.upsample_factor = upsample_factor
|
||||
self.conv = None
|
||||
if use_conv:
|
||||
kernel_size = 3 if kernel_size is None else kernel_size
|
||||
self.conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
||||
dtype = hidden_states.dtype
|
||||
if dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
|
||||
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
||||
if hidden_states.shape[0] >= 64:
|
||||
hidden_states = hidden_states.contiguous()
|
||||
|
||||
# interpolate
|
||||
B, C, T, H, W = hidden_states.shape
|
||||
first_h, other_h = hidden_states.split((1, T - 1), dim=2)
|
||||
if T > 1:
|
||||
other_h = F.interpolate(other_h, scale_factor=self.upsample_factor, mode="nearest")
|
||||
first_h = F.interpolate(first_h.squeeze(2), scale_factor=self.upsample_factor[1:], mode="nearest").unsqueeze(2)
|
||||
hidden_states = torch.cat((first_h, other_h), dim=2) if T > 1 else first_h
|
||||
|
||||
# If the input is bfloat16, we cast back to bfloat16
|
||||
if dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
|
||||
if self.conv:
|
||||
hidden_states = self.conv(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ResnetBlockCausal3D(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels=None, dropout=0.0, groups=32, eps=1e-6, conv_shortcut_bias=True):
|
||||
super().__init__()
|
||||
self.pre_norm = True
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)
|
||||
|
||||
self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
|
||||
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, stride=1)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.nonlinearity = nn.SiLU()
|
||||
|
||||
self.conv_shortcut = None
|
||||
if in_channels != out_channels:
|
||||
self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, bias=conv_shortcut_bias)
|
||||
|
||||
def forward(self, input_tensor):
|
||||
hidden_states = input_tensor
|
||||
# conv1
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
|
||||
# conv2
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
# shortcut
|
||||
if self.conv_shortcut is not None:
|
||||
input_tensor = (self.conv_shortcut(input_tensor))
|
||||
# shortcut and scale
|
||||
output_tensor = input_tensor + hidden_states
|
||||
|
||||
return output_tensor
|
||||
|
||||
|
||||
def prepare_causal_attention_mask(n_frame, n_hw, dtype, device, batch_size=None):
|
||||
seq_len = n_frame * n_hw
|
||||
mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
|
||||
for i in range(seq_len):
|
||||
i_frame = i // n_hw
|
||||
mask[i, :(i_frame + 1) * n_hw] = 0
|
||||
if batch_size is not None:
|
||||
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
|
||||
return mask
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
num_heads,
|
||||
head_dim,
|
||||
num_groups=32,
|
||||
dropout=0.0,
|
||||
eps=1e-6,
|
||||
bias=True,
|
||||
residual_connection=True):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.residual_connection = residual_connection
|
||||
dim_inner = head_dim * num_heads
|
||||
self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
self.to_q = nn.Linear(in_channels, dim_inner, bias=bias)
|
||||
self.to_k = nn.Linear(in_channels, dim_inner, bias=bias)
|
||||
self.to_v = nn.Linear(in_channels, dim_inner, bias=bias)
|
||||
self.to_out = nn.Sequential(nn.Linear(dim_inner, in_channels, bias=bias), nn.Dropout(dropout))
|
||||
|
||||
def forward(self, input_tensor, attn_mask=None):
|
||||
hidden_states = self.group_norm(input_tensor.transpose(1, 2)).transpose(1, 2)
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
q = self.to_q(hidden_states)
|
||||
k = self.to_k(hidden_states)
|
||||
v = self.to_v(hidden_states)
|
||||
|
||||
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.view(batch_size, self.num_heads, -1, attn_mask.shape[-1])
|
||||
hidden_states = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = self.to_out(hidden_states)
|
||||
if self.residual_connection:
|
||||
output_tensor = input_tensor + hidden_states
|
||||
return output_tensor
|
||||
|
||||
|
||||
class UNetMidBlockCausal3D(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, dropout=0.0, num_layers=1, eps=1e-6, num_groups=32, attention_head_dim=None):
|
||||
super().__init__()
|
||||
resnets = [
|
||||
ResnetBlockCausal3D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
dropout=dropout,
|
||||
groups=num_groups,
|
||||
eps=eps,
|
||||
)
|
||||
]
|
||||
attentions = []
|
||||
attention_head_dim = attention_head_dim or in_channels
|
||||
|
||||
for _ in range(num_layers):
|
||||
attentions.append(
|
||||
Attention(
|
||||
in_channels,
|
||||
num_heads=in_channels // attention_head_dim,
|
||||
head_dim=attention_head_dim,
|
||||
num_groups=num_groups,
|
||||
dropout=dropout,
|
||||
eps=eps,
|
||||
bias=True,
|
||||
residual_connection=True,
|
||||
))
|
||||
|
||||
resnets.append(
|
||||
ResnetBlockCausal3D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
dropout=dropout,
|
||||
groups=num_groups,
|
||||
eps=eps,
|
||||
))
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.resnets[0](hidden_states)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
B, C, T, H, W = hidden_states.shape
|
||||
hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
|
||||
attn_mask = prepare_causal_attention_mask(T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B)
|
||||
hidden_states = attn(hidden_states, attn_mask=attn_mask)
|
||||
hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class UpDecoderBlockCausal3D(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
dropout=0.0,
|
||||
num_layers=1,
|
||||
eps=1e-6,
|
||||
num_groups=32,
|
||||
add_upsample=True,
|
||||
upsample_scale_factor=(2, 2, 2),
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
for i in range(num_layers):
|
||||
cur_in_channel = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlockCausal3D(
|
||||
in_channels=cur_in_channel,
|
||||
out_channels=out_channels,
|
||||
groups=num_groups,
|
||||
dropout=dropout,
|
||||
eps=eps,
|
||||
))
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
self.upsamplers = None
|
||||
if add_upsample:
|
||||
self.upsamplers = nn.ModuleList([
|
||||
UpsampleCausal3D(
|
||||
out_channels,
|
||||
use_conv=True,
|
||||
out_channels=out_channels,
|
||||
upsample_factor=upsample_scale_factor,
|
||||
)
|
||||
])
|
||||
|
||||
def forward(self, hidden_states):
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DecoderCausal3D(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=16,
|
||||
out_channels=3,
|
||||
eps=1e-6,
|
||||
dropout=0.0,
|
||||
block_out_channels=[128, 256, 512, 512],
|
||||
layers_per_block=2,
|
||||
num_groups=32,
|
||||
time_compression_ratio=4,
|
||||
spatial_compression_ratio=8,
|
||||
gradient_checkpointing=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers_per_block = layers_per_block
|
||||
|
||||
self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlockCausal3D(
|
||||
in_channels=block_out_channels[-1],
|
||||
dropout=dropout,
|
||||
eps=eps,
|
||||
num_groups=num_groups,
|
||||
attention_head_dim=block_out_channels[-1],
|
||||
)
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
for i in range(len(block_out_channels)):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
|
||||
num_time_upsample_layers = int(np.log2(time_compression_ratio))
|
||||
|
||||
add_spatial_upsample = bool(i < num_spatial_upsample_layers)
|
||||
add_time_upsample = bool(i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block)
|
||||
|
||||
upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
|
||||
upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
|
||||
upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
|
||||
|
||||
up_block = UpDecoderBlockCausal3D(
|
||||
in_channels=prev_output_channel,
|
||||
out_channels=output_channel,
|
||||
dropout=dropout,
|
||||
num_layers=layers_per_block + 1,
|
||||
eps=eps,
|
||||
num_groups=num_groups,
|
||||
add_upsample=bool(add_spatial_upsample or add_time_upsample),
|
||||
upsample_scale_factor=upsample_scale_factor,
|
||||
)
|
||||
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups, eps=eps)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
|
||||
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
# middle
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block),
|
||||
hidden_states,
|
||||
use_reentrant=False,
|
||||
)
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(up_block),
|
||||
hidden_states,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
# middle
|
||||
hidden_states = self.mid_block(hidden_states)
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = up_block(hidden_states)
|
||||
# post-process
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoVAEDecoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=16,
|
||||
out_channels=3,
|
||||
eps=1e-6,
|
||||
dropout=0.0,
|
||||
block_out_channels=[128, 256, 512, 512],
|
||||
layers_per_block=2,
|
||||
num_groups=32,
|
||||
time_compression_ratio=4,
|
||||
spatial_compression_ratio=8,
|
||||
gradient_checkpointing=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.decoder = DecoderCausal3D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
eps=eps,
|
||||
dropout=dropout,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
num_groups=num_groups,
|
||||
time_compression_ratio=time_compression_ratio,
|
||||
spatial_compression_ratio=spatial_compression_ratio,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
)
|
||||
self.post_quant_conv = nn.Conv3d(in_channels, in_channels, kernel_size=1)
|
||||
self.scaling_factor = 0.476986
|
||||
|
||||
|
||||
def forward(self, latents):
|
||||
latents = latents / self.scaling_factor
|
||||
latents = self.post_quant_conv(latents)
|
||||
dec = self.decoder(latents)
|
||||
return dec
|
||||
|
||||
|
||||
def build_1d_mask(self, length, left_bound, right_bound, border_width):
|
||||
x = torch.ones((length,))
|
||||
if not left_bound:
|
||||
x[:border_width] = (torch.arange(border_width) + 1) / border_width
|
||||
if not right_bound:
|
||||
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
|
||||
return x
|
||||
|
||||
|
||||
def build_mask(self, data, is_bound, border_width):
|
||||
_, _, T, H, W = data.shape
|
||||
t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
|
||||
h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1])
|
||||
w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2])
|
||||
|
||||
t = repeat(t, "T -> T H W", T=T, H=H, W=W)
|
||||
h = repeat(h, "H -> T H W", T=T, H=H, W=W)
|
||||
w = repeat(w, "W -> T H W", T=T, H=H, W=W)
|
||||
|
||||
mask = torch.stack([t, h, w]).min(dim=0).values
|
||||
mask = rearrange(mask, "T H W -> 1 1 T H W")
|
||||
return mask
|
||||
|
||||
|
||||
def tile_forward(self, hidden_states, tile_size, tile_stride):
|
||||
B, C, T, H, W = hidden_states.shape
|
||||
size_t, size_h, size_w = tile_size
|
||||
stride_t, stride_h, stride_w = tile_stride
|
||||
|
||||
# Split tasks
|
||||
tasks = []
|
||||
for t in range(0, T, stride_t):
|
||||
if (t-stride_t >= 0 and t-stride_t+size_t >= T): continue
|
||||
for h in range(0, H, stride_h):
|
||||
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
|
||||
for w in range(0, W, stride_w):
|
||||
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
|
||||
t_, h_, w_ = t + size_t, h + size_h, w + size_w
|
||||
tasks.append((t, t_, h, h_, w, w_))
|
||||
|
||||
# Run
|
||||
torch_dtype = self.post_quant_conv.weight.dtype
|
||||
data_device = hidden_states.device
|
||||
computation_device = self.post_quant_conv.weight.device
|
||||
|
||||
weight = torch.zeros((1, 1, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
|
||||
values = torch.zeros((B, 3, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
|
||||
|
||||
for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
|
||||
hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
|
||||
hidden_states_batch = self.forward(hidden_states_batch).to(data_device)
|
||||
if t > 0:
|
||||
hidden_states_batch = hidden_states_batch[:, :, 1:]
|
||||
|
||||
mask = self.build_mask(
|
||||
hidden_states_batch,
|
||||
is_bound=(t==0, t_>=T, h==0, h_>=H, w==0, w_>=W),
|
||||
border_width=((size_t - stride_t) * 4, (size_h - stride_h) * 8, (size_w - stride_w) * 8)
|
||||
).to(dtype=torch_dtype, device=data_device)
|
||||
|
||||
target_t = 0 if t==0 else t * 4 + 1
|
||||
target_h = h * 8
|
||||
target_w = w * 8
|
||||
values[
|
||||
:,
|
||||
:,
|
||||
target_t: target_t + hidden_states_batch.shape[2],
|
||||
target_h: target_h + hidden_states_batch.shape[3],
|
||||
target_w: target_w + hidden_states_batch.shape[4],
|
||||
] += hidden_states_batch * mask
|
||||
weight[
|
||||
:,
|
||||
:,
|
||||
target_t: target_t + hidden_states_batch.shape[2],
|
||||
target_h: target_h + hidden_states_batch.shape[3],
|
||||
target_w: target_w + hidden_states_batch.shape[4],
|
||||
] += mask
|
||||
return values / weight
|
||||
|
||||
|
||||
def decode_video(self, latents, tile_size=(17, 32, 32), tile_stride=(12, 24, 24)):
|
||||
latents = latents.to(self.post_quant_conv.weight.dtype)
|
||||
return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return HunyuanVideoVAEDecoderStateDictConverter()
|
||||
|
||||
|
||||
class HunyuanVideoVAEDecoderStateDictConverter:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name.startswith('decoder.') or name.startswith('post_quant_conv.'):
|
||||
state_dict_[name] = state_dict[name]
|
||||
return state_dict_
|
||||
307
diffsynth/models/hunyuan_video_vae_encoder.py
Normal file
307
diffsynth/models/hunyuan_video_vae_encoder.py
Normal file
@@ -0,0 +1,307 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from .hunyuan_video_vae_decoder import CausalConv3d, ResnetBlockCausal3D, UNetMidBlockCausal3D
|
||||
|
||||
|
||||
class DownsampleCausal3D(nn.Module):
|
||||
|
||||
def __init__(self, channels, out_channels, kernel_size=3, bias=True, stride=2):
|
||||
super().__init__()
|
||||
self.conv = CausalConv3d(channels, out_channels, kernel_size, stride=stride, bias=bias)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.conv(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DownEncoderBlockCausal3D(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
dropout=0.0,
|
||||
num_layers=1,
|
||||
eps=1e-6,
|
||||
num_groups=32,
|
||||
add_downsample=True,
|
||||
downsample_stride=2,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
resnets = []
|
||||
for i in range(num_layers):
|
||||
cur_in_channel = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlockCausal3D(
|
||||
in_channels=cur_in_channel,
|
||||
out_channels=out_channels,
|
||||
groups=num_groups,
|
||||
dropout=dropout,
|
||||
eps=eps,
|
||||
))
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
self.downsamplers = None
|
||||
if add_downsample:
|
||||
self.downsamplers = nn.ModuleList([DownsampleCausal3D(
|
||||
out_channels,
|
||||
out_channels,
|
||||
stride=downsample_stride,
|
||||
)])
|
||||
|
||||
def forward(self, hidden_states):
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class EncoderCausal3D(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 16,
|
||||
eps=1e-6,
|
||||
dropout=0.0,
|
||||
block_out_channels=[128, 256, 512, 512],
|
||||
layers_per_block=2,
|
||||
num_groups=32,
|
||||
time_compression_ratio: int = 4,
|
||||
spatial_compression_ratio: int = 8,
|
||||
gradient_checkpointing=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
for i in range(len(block_out_channels)):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
|
||||
num_time_downsample_layers = int(np.log2(time_compression_ratio))
|
||||
|
||||
add_spatial_downsample = bool(i < num_spatial_downsample_layers)
|
||||
add_time_downsample = bool(i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block)
|
||||
|
||||
downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
|
||||
downsample_stride_T = (2,) if add_time_downsample else (1,)
|
||||
downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
|
||||
down_block = DownEncoderBlockCausal3D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
dropout=dropout,
|
||||
num_layers=layers_per_block,
|
||||
eps=eps,
|
||||
num_groups=num_groups,
|
||||
add_downsample=bool(add_spatial_downsample or add_time_downsample),
|
||||
downsample_stride=downsample_stride,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlockCausal3D(
|
||||
in_channels=block_out_channels[-1],
|
||||
dropout=dropout,
|
||||
eps=eps,
|
||||
num_groups=num_groups,
|
||||
attention_head_dim=block_out_channels[-1],
|
||||
)
|
||||
# out
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups, eps=eps)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3)
|
||||
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
# down
|
||||
for down_block in self.down_blocks:
|
||||
torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(down_block),
|
||||
hidden_states,
|
||||
use_reentrant=False,
|
||||
)
|
||||
# middle
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block),
|
||||
hidden_states,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
# down
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = down_block(hidden_states)
|
||||
# middle
|
||||
hidden_states = self.mid_block(hidden_states)
|
||||
# post-process
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoVAEEncoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=3,
|
||||
out_channels=16,
|
||||
eps=1e-6,
|
||||
dropout=0.0,
|
||||
block_out_channels=[128, 256, 512, 512],
|
||||
layers_per_block=2,
|
||||
num_groups=32,
|
||||
time_compression_ratio=4,
|
||||
spatial_compression_ratio=8,
|
||||
gradient_checkpointing=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = EncoderCausal3D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
eps=eps,
|
||||
dropout=dropout,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
num_groups=num_groups,
|
||||
time_compression_ratio=time_compression_ratio,
|
||||
spatial_compression_ratio=spatial_compression_ratio,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
)
|
||||
self.quant_conv = nn.Conv3d(2 * out_channels, 2 * out_channels, kernel_size=1)
|
||||
self.scaling_factor = 0.476986
|
||||
|
||||
|
||||
def forward(self, images):
|
||||
latents = self.encoder(images)
|
||||
latents = self.quant_conv(latents)
|
||||
latents = latents[:, :16]
|
||||
latents = latents * self.scaling_factor
|
||||
return latents
|
||||
|
||||
|
||||
def build_1d_mask(self, length, left_bound, right_bound, border_width):
|
||||
x = torch.ones((length,))
|
||||
if not left_bound:
|
||||
x[:border_width] = (torch.arange(border_width) + 1) / border_width
|
||||
if not right_bound:
|
||||
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
|
||||
return x
|
||||
|
||||
|
||||
def build_mask(self, data, is_bound, border_width):
|
||||
_, _, T, H, W = data.shape
|
||||
t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
|
||||
h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1])
|
||||
w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2])
|
||||
|
||||
t = repeat(t, "T -> T H W", T=T, H=H, W=W)
|
||||
h = repeat(h, "H -> T H W", T=T, H=H, W=W)
|
||||
w = repeat(w, "W -> T H W", T=T, H=H, W=W)
|
||||
|
||||
mask = torch.stack([t, h, w]).min(dim=0).values
|
||||
mask = rearrange(mask, "T H W -> 1 1 T H W")
|
||||
return mask
|
||||
|
||||
|
||||
def tile_forward(self, hidden_states, tile_size, tile_stride):
|
||||
B, C, T, H, W = hidden_states.shape
|
||||
size_t, size_h, size_w = tile_size
|
||||
stride_t, stride_h, stride_w = tile_stride
|
||||
|
||||
# Split tasks
|
||||
tasks = []
|
||||
for t in range(0, T, stride_t):
|
||||
if (t-stride_t >= 0 and t-stride_t+size_t >= T): continue
|
||||
for h in range(0, H, stride_h):
|
||||
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
|
||||
for w in range(0, W, stride_w):
|
||||
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
|
||||
t_, h_, w_ = t + size_t, h + size_h, w + size_w
|
||||
tasks.append((t, t_, h, h_, w, w_))
|
||||
|
||||
# Run
|
||||
torch_dtype = self.quant_conv.weight.dtype
|
||||
data_device = hidden_states.device
|
||||
computation_device = self.quant_conv.weight.device
|
||||
|
||||
weight = torch.zeros((1, 1, (T - 1) // 4 + 1, H // 8, W // 8), dtype=torch_dtype, device=data_device)
|
||||
values = torch.zeros((B, 16, (T - 1) // 4 + 1, H // 8, W // 8), dtype=torch_dtype, device=data_device)
|
||||
|
||||
for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
|
||||
hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
|
||||
hidden_states_batch = self.forward(hidden_states_batch).to(data_device)
|
||||
if t > 0:
|
||||
hidden_states_batch = hidden_states_batch[:, :, 1:]
|
||||
|
||||
mask = self.build_mask(
|
||||
hidden_states_batch,
|
||||
is_bound=(t==0, t_>=T, h==0, h_>=H, w==0, w_>=W),
|
||||
border_width=((size_t - stride_t) // 4, (size_h - stride_h) // 8, (size_w - stride_w) // 8)
|
||||
).to(dtype=torch_dtype, device=data_device)
|
||||
|
||||
target_t = 0 if t==0 else t // 4 + 1
|
||||
target_h = h // 8
|
||||
target_w = w // 8
|
||||
values[
|
||||
:,
|
||||
:,
|
||||
target_t: target_t + hidden_states_batch.shape[2],
|
||||
target_h: target_h + hidden_states_batch.shape[3],
|
||||
target_w: target_w + hidden_states_batch.shape[4],
|
||||
] += hidden_states_batch * mask
|
||||
weight[
|
||||
:,
|
||||
:,
|
||||
target_t: target_t + hidden_states_batch.shape[2],
|
||||
target_h: target_h + hidden_states_batch.shape[3],
|
||||
target_w: target_w + hidden_states_batch.shape[4],
|
||||
] += mask
|
||||
return values / weight
|
||||
|
||||
|
||||
def encode_video(self, latents, tile_size=(65, 256, 256), tile_stride=(48, 192, 192)):
|
||||
latents = latents.to(self.quant_conv.weight.dtype)
|
||||
return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return HunyuanVideoVAEEncoderStateDictConverter()
|
||||
|
||||
|
||||
class HunyuanVideoVAEEncoderStateDictConverter:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name.startswith('encoder.') or name.startswith('quant_conv.'):
|
||||
state_dict_[name] = state_dict[name]
|
||||
return state_dict_
|
||||
1551
diffsynth/models/kolors_text_encoder.py
Normal file
1551
diffsynth/models/kolors_text_encoder.py
Normal file
File diff suppressed because one or more lines are too long
371
diffsynth/models/lora.py
Normal file
371
diffsynth/models/lora.py
Normal file
@@ -0,0 +1,371 @@
|
||||
import torch
|
||||
from .sd_unet import SDUNet
|
||||
from .sdxl_unet import SDXLUNet
|
||||
from .sd_text_encoder import SDTextEncoder
|
||||
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
||||
from .sd3_dit import SD3DiT
|
||||
from .flux_dit import FluxDiT
|
||||
from .hunyuan_dit import HunyuanDiT
|
||||
from .cog_dit import CogDiT
|
||||
from .hunyuan_video_dit import HunyuanVideoDiT
|
||||
from .wan_video_dit import WanModel
|
||||
|
||||
|
||||
|
||||
class LoRAFromCivitai:
|
||||
def __init__(self):
|
||||
self.supported_model_classes = []
|
||||
self.lora_prefix = []
|
||||
self.renamed_lora_prefix = {}
|
||||
self.special_keys = {}
|
||||
|
||||
|
||||
def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
|
||||
for key in state_dict:
|
||||
if ".lora_up" in key:
|
||||
return self.convert_state_dict_up_down(state_dict, lora_prefix, alpha)
|
||||
return self.convert_state_dict_AB(state_dict, lora_prefix, alpha)
|
||||
|
||||
|
||||
def convert_state_dict_up_down(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
|
||||
renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "")
|
||||
state_dict_ = {}
|
||||
for key in state_dict:
|
||||
if ".lora_up" not in key:
|
||||
continue
|
||||
if not key.startswith(lora_prefix):
|
||||
continue
|
||||
weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
|
||||
weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
|
||||
if len(weight_up.shape) == 4:
|
||||
weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
|
||||
weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down)
|
||||
target_name = key.split(".")[0].replace(lora_prefix, renamed_lora_prefix).replace("_", ".") + ".weight"
|
||||
for special_key in self.special_keys:
|
||||
target_name = target_name.replace(special_key, self.special_keys[special_key])
|
||||
state_dict_[target_name] = lora_weight.cpu()
|
||||
return state_dict_
|
||||
|
||||
|
||||
def convert_state_dict_AB(self, state_dict, lora_prefix="", alpha=1.0, device="cuda", torch_dtype=torch.float16):
|
||||
state_dict_ = {}
|
||||
for key in state_dict:
|
||||
if ".lora_B." not in key:
|
||||
continue
|
||||
if not key.startswith(lora_prefix):
|
||||
continue
|
||||
weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
|
||||
weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
|
||||
if len(weight_up.shape) == 4:
|
||||
weight_up = weight_up.squeeze(3).squeeze(2)
|
||||
weight_down = weight_down.squeeze(3).squeeze(2)
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down)
|
||||
keys = key.split(".")
|
||||
keys.pop(keys.index("lora_B"))
|
||||
target_name = ".".join(keys)
|
||||
target_name = target_name[len(lora_prefix):]
|
||||
state_dict_[target_name] = lora_weight.cpu()
|
||||
return state_dict_
|
||||
|
||||
|
||||
def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None):
|
||||
state_dict_model = model.state_dict()
|
||||
state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=alpha)
|
||||
if model_resource == "diffusers":
|
||||
state_dict_lora = model.__class__.state_dict_converter().from_diffusers(state_dict_lora)
|
||||
elif model_resource == "civitai":
|
||||
state_dict_lora = model.__class__.state_dict_converter().from_civitai(state_dict_lora)
|
||||
if isinstance(state_dict_lora, tuple):
|
||||
state_dict_lora = state_dict_lora[0]
|
||||
if len(state_dict_lora) > 0:
|
||||
print(f" {len(state_dict_lora)} tensors are updated.")
|
||||
for name in state_dict_lora:
|
||||
fp8=False
|
||||
if state_dict_model[name].dtype == torch.float8_e4m3fn:
|
||||
state_dict_model[name]= state_dict_model[name].to(state_dict_lora[name].dtype)
|
||||
fp8=True
|
||||
state_dict_model[name] += state_dict_lora[name].to(
|
||||
dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
|
||||
if fp8:
|
||||
state_dict_model[name] = state_dict_model[name].to(torch.float8_e4m3fn)
|
||||
model.load_state_dict(state_dict_model)
|
||||
|
||||
|
||||
def match(self, model, state_dict_lora):
|
||||
for lora_prefix, model_class in zip(self.lora_prefix, self.supported_model_classes):
|
||||
if not isinstance(model, model_class):
|
||||
continue
|
||||
state_dict_model = model.state_dict()
|
||||
for model_resource in ["diffusers", "civitai"]:
|
||||
try:
|
||||
state_dict_lora_ = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=1.0)
|
||||
converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == "diffusers" \
|
||||
else model.__class__.state_dict_converter().from_civitai
|
||||
state_dict_lora_ = converter_fn(state_dict_lora_)
|
||||
if isinstance(state_dict_lora_, tuple):
|
||||
state_dict_lora_ = state_dict_lora_[0]
|
||||
if len(state_dict_lora_) == 0:
|
||||
continue
|
||||
for name in state_dict_lora_:
|
||||
if name not in state_dict_model:
|
||||
break
|
||||
else:
|
||||
return lora_prefix, model_resource
|
||||
except:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
|
||||
class SDLoRAFromCivitai(LoRAFromCivitai):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.supported_model_classes = [SDUNet, SDTextEncoder]
|
||||
self.lora_prefix = ["lora_unet_", "lora_te_"]
|
||||
self.special_keys = {
|
||||
"down.blocks": "down_blocks",
|
||||
"up.blocks": "up_blocks",
|
||||
"mid.block": "mid_block",
|
||||
"proj.in": "proj_in",
|
||||
"proj.out": "proj_out",
|
||||
"transformer.blocks": "transformer_blocks",
|
||||
"to.q": "to_q",
|
||||
"to.k": "to_k",
|
||||
"to.v": "to_v",
|
||||
"to.out": "to_out",
|
||||
"text.model": "text_model",
|
||||
"self.attn.q.proj": "self_attn.q_proj",
|
||||
"self.attn.k.proj": "self_attn.k_proj",
|
||||
"self.attn.v.proj": "self_attn.v_proj",
|
||||
"self.attn.out.proj": "self_attn.out_proj",
|
||||
"input.blocks": "model.diffusion_model.input_blocks",
|
||||
"middle.block": "model.diffusion_model.middle_block",
|
||||
"output.blocks": "model.diffusion_model.output_blocks",
|
||||
}
|
||||
|
||||
|
||||
class SDXLLoRAFromCivitai(LoRAFromCivitai):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.supported_model_classes = [SDXLUNet, SDXLTextEncoder, SDXLTextEncoder2]
|
||||
self.lora_prefix = ["lora_unet_", "lora_te1_", "lora_te2_"]
|
||||
self.renamed_lora_prefix = {"lora_te2_": "2"}
|
||||
self.special_keys = {
|
||||
"down.blocks": "down_blocks",
|
||||
"up.blocks": "up_blocks",
|
||||
"mid.block": "mid_block",
|
||||
"proj.in": "proj_in",
|
||||
"proj.out": "proj_out",
|
||||
"transformer.blocks": "transformer_blocks",
|
||||
"to.q": "to_q",
|
||||
"to.k": "to_k",
|
||||
"to.v": "to_v",
|
||||
"to.out": "to_out",
|
||||
"text.model": "conditioner.embedders.0.transformer.text_model",
|
||||
"self.attn.q.proj": "self_attn.q_proj",
|
||||
"self.attn.k.proj": "self_attn.k_proj",
|
||||
"self.attn.v.proj": "self_attn.v_proj",
|
||||
"self.attn.out.proj": "self_attn.out_proj",
|
||||
"input.blocks": "model.diffusion_model.input_blocks",
|
||||
"middle.block": "model.diffusion_model.middle_block",
|
||||
"output.blocks": "model.diffusion_model.output_blocks",
|
||||
"2conditioner.embedders.0.transformer.text_model.encoder.layers": "text_model.encoder.layers"
|
||||
}
|
||||
|
||||
|
||||
class FluxLoRAFromCivitai(LoRAFromCivitai):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.supported_model_classes = [FluxDiT, FluxDiT]
|
||||
self.lora_prefix = ["lora_unet_", "transformer."]
|
||||
self.renamed_lora_prefix = {}
|
||||
self.special_keys = {
|
||||
"single.blocks": "single_blocks",
|
||||
"double.blocks": "double_blocks",
|
||||
"img.attn": "img_attn",
|
||||
"img.mlp": "img_mlp",
|
||||
"img.mod": "img_mod",
|
||||
"txt.attn": "txt_attn",
|
||||
"txt.mlp": "txt_mlp",
|
||||
"txt.mod": "txt_mod",
|
||||
}
|
||||
|
||||
|
||||
|
||||
class GeneralLoRAFromPeft:
|
||||
def __init__(self):
|
||||
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT, WanModel]
|
||||
|
||||
|
||||
def get_name_dict(self, lora_state_dict):
|
||||
lora_name_dict = {}
|
||||
for key in lora_state_dict:
|
||||
if ".lora_B." not in key:
|
||||
continue
|
||||
keys = key.split(".")
|
||||
if len(keys) > keys.index("lora_B") + 2:
|
||||
keys.pop(keys.index("lora_B") + 1)
|
||||
keys.pop(keys.index("lora_B"))
|
||||
if keys[0] == "diffusion_model":
|
||||
keys.pop(0)
|
||||
target_name = ".".join(keys)
|
||||
lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
|
||||
return lora_name_dict
|
||||
|
||||
|
||||
def match(self, model: torch.nn.Module, state_dict_lora):
|
||||
lora_name_dict = self.get_name_dict(state_dict_lora)
|
||||
model_name_dict = {name: None for name, _ in model.named_parameters()}
|
||||
matched_num = sum([i in model_name_dict for i in lora_name_dict])
|
||||
if matched_num == len(lora_name_dict):
|
||||
return "", ""
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def fetch_device_and_dtype(self, state_dict):
|
||||
device, dtype = None, None
|
||||
for name, param in state_dict.items():
|
||||
device, dtype = param.device, param.dtype
|
||||
break
|
||||
computation_device = device
|
||||
computation_dtype = dtype
|
||||
if computation_device == torch.device("cpu"):
|
||||
if torch.cuda.is_available():
|
||||
computation_device = torch.device("cuda")
|
||||
if computation_dtype == torch.float8_e4m3fn:
|
||||
computation_dtype = torch.float32
|
||||
return device, dtype, computation_device, computation_dtype
|
||||
|
||||
|
||||
def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
|
||||
state_dict_model = model.state_dict()
|
||||
device, dtype, computation_device, computation_dtype = self.fetch_device_and_dtype(state_dict_model)
|
||||
lora_name_dict = self.get_name_dict(state_dict_lora)
|
||||
for name in lora_name_dict:
|
||||
weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=computation_device, dtype=computation_dtype)
|
||||
weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=computation_device, dtype=computation_dtype)
|
||||
if len(weight_up.shape) == 4:
|
||||
weight_up = weight_up.squeeze(3).squeeze(2)
|
||||
weight_down = weight_down.squeeze(3).squeeze(2)
|
||||
weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
weight_lora = alpha * torch.mm(weight_up, weight_down)
|
||||
weight_model = state_dict_model[name].to(device=computation_device, dtype=computation_dtype)
|
||||
weight_patched = weight_model + weight_lora
|
||||
state_dict_model[name] = weight_patched.to(device=device, dtype=dtype)
|
||||
print(f" {len(lora_name_dict)} tensors are updated.")
|
||||
model.load_state_dict(state_dict_model)
|
||||
|
||||
|
||||
|
||||
class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.supported_model_classes = [HunyuanVideoDiT, HunyuanVideoDiT]
|
||||
self.lora_prefix = ["diffusion_model.", "transformer."]
|
||||
self.special_keys = {}
|
||||
|
||||
|
||||
class FluxLoRAConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def align_to_opensource_format(state_dict, alpha=1.0):
|
||||
prefix_rename_dict = {
|
||||
"single_blocks": "lora_unet_single_blocks",
|
||||
"blocks": "lora_unet_double_blocks",
|
||||
}
|
||||
middle_rename_dict = {
|
||||
"norm.linear": "modulation_lin",
|
||||
"to_qkv_mlp": "linear1",
|
||||
"proj_out": "linear2",
|
||||
|
||||
"norm1_a.linear": "img_mod_lin",
|
||||
"norm1_b.linear": "txt_mod_lin",
|
||||
"attn.a_to_qkv": "img_attn_qkv",
|
||||
"attn.b_to_qkv": "txt_attn_qkv",
|
||||
"attn.a_to_out": "img_attn_proj",
|
||||
"attn.b_to_out": "txt_attn_proj",
|
||||
"ff_a.0": "img_mlp_0",
|
||||
"ff_a.2": "img_mlp_2",
|
||||
"ff_b.0": "txt_mlp_0",
|
||||
"ff_b.2": "txt_mlp_2",
|
||||
}
|
||||
suffix_rename_dict = {
|
||||
"lora_B.weight": "lora_up.weight",
|
||||
"lora_A.weight": "lora_down.weight",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
names = name.split(".")
|
||||
if names[-2] != "lora_A" and names[-2] != "lora_B":
|
||||
names.pop(-2)
|
||||
prefix = names[0]
|
||||
middle = ".".join(names[2:-2])
|
||||
suffix = ".".join(names[-2:])
|
||||
block_id = names[1]
|
||||
if middle not in middle_rename_dict:
|
||||
continue
|
||||
rename = prefix_rename_dict[prefix] + "_" + block_id + "_" + middle_rename_dict[middle] + "." + suffix_rename_dict[suffix]
|
||||
state_dict_[rename] = param
|
||||
if rename.endswith("lora_up.weight"):
|
||||
state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((alpha,))[0]
|
||||
return state_dict_
|
||||
|
||||
@staticmethod
|
||||
def align_to_diffsynth_format(state_dict):
|
||||
rename_dict = {
|
||||
"lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight",
|
||||
"lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight",
|
||||
"lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight",
|
||||
"lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight",
|
||||
"lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight",
|
||||
"lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight",
|
||||
"lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight",
|
||||
}
|
||||
def guess_block_id(name):
|
||||
names = name.split("_")
|
||||
for i in names:
|
||||
if i.isdigit():
|
||||
return i, name.replace(f"_{i}_", "_blockid_")
|
||||
return None, None
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
block_id, source_name = guess_block_id(name)
|
||||
if source_name in rename_dict:
|
||||
target_name = rename_dict[source_name]
|
||||
target_name = target_name.replace(".blockid.", f".{block_id}.")
|
||||
state_dict_[target_name] = param
|
||||
else:
|
||||
state_dict_[name] = param
|
||||
return state_dict_
|
||||
|
||||
|
||||
def get_lora_loaders():
|
||||
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]
|
||||
454
diffsynth/models/model_manager.py
Normal file
454
diffsynth/models/model_manager.py
Normal file
@@ -0,0 +1,454 @@
|
||||
import os, torch, json, importlib
|
||||
from typing import List
|
||||
|
||||
from .downloader import download_models, download_customized_models, Preset_model_id, Preset_model_website
|
||||
|
||||
from .sd_text_encoder import SDTextEncoder
|
||||
from .sd_unet import SDUNet
|
||||
from .sd_vae_encoder import SDVAEEncoder
|
||||
from .sd_vae_decoder import SDVAEDecoder
|
||||
from .lora import get_lora_loaders
|
||||
|
||||
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
||||
from .sdxl_unet import SDXLUNet
|
||||
from .sdxl_vae_decoder import SDXLVAEDecoder
|
||||
from .sdxl_vae_encoder import SDXLVAEEncoder
|
||||
|
||||
from .sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
|
||||
from .sd3_dit import SD3DiT
|
||||
from .sd3_vae_decoder import SD3VAEDecoder
|
||||
from .sd3_vae_encoder import SD3VAEEncoder
|
||||
|
||||
from .sd_controlnet import SDControlNet
|
||||
from .sdxl_controlnet import SDXLControlNetUnion
|
||||
|
||||
from .sd_motion import SDMotionModel
|
||||
from .sdxl_motion import SDXLMotionModel
|
||||
|
||||
from .svd_image_encoder import SVDImageEncoder
|
||||
from .svd_unet import SVDUNet
|
||||
from .svd_vae_decoder import SVDVAEDecoder
|
||||
from .svd_vae_encoder import SVDVAEEncoder
|
||||
|
||||
from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
|
||||
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
||||
|
||||
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
||||
from .hunyuan_dit import HunyuanDiT
|
||||
from .hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
|
||||
from .hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder
|
||||
|
||||
from .flux_dit import FluxDiT
|
||||
from .flux_text_encoder import FluxTextEncoder2
|
||||
from .flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
||||
from .flux_ipadapter import FluxIpAdapter
|
||||
|
||||
from .cog_vae import CogVAEEncoder, CogVAEDecoder
|
||||
from .cog_dit import CogDiT
|
||||
|
||||
from ..extensions.RIFE import IFNet
|
||||
from ..extensions.ESRGAN import RRDBNet
|
||||
|
||||
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
|
||||
from .utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix
|
||||
|
||||
|
||||
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
|
||||
loaded_model_names, loaded_models = [], []
|
||||
for model_name, model_class in zip(model_names, model_classes):
|
||||
print(f" model_name: {model_name} model_class: {model_class.__name__}")
|
||||
state_dict_converter = model_class.state_dict_converter()
|
||||
if model_resource == "civitai":
|
||||
state_dict_results = state_dict_converter.from_civitai(state_dict)
|
||||
elif model_resource == "diffusers":
|
||||
state_dict_results = state_dict_converter.from_diffusers(state_dict)
|
||||
if isinstance(state_dict_results, tuple):
|
||||
model_state_dict, extra_kwargs = state_dict_results
|
||||
print(f" This model is initialized with extra kwargs: {extra_kwargs}")
|
||||
else:
|
||||
model_state_dict, extra_kwargs = state_dict_results, {}
|
||||
torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
|
||||
with init_weights_on_device():
|
||||
model = model_class(**extra_kwargs)
|
||||
if hasattr(model, "eval"):
|
||||
model = model.eval()
|
||||
model.load_state_dict(model_state_dict, assign=True)
|
||||
model = model.to(dtype=torch_dtype, device=device)
|
||||
loaded_model_names.append(model_name)
|
||||
loaded_models.append(model)
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
|
||||
loaded_model_names, loaded_models = [], []
|
||||
for model_name, model_class in zip(model_names, model_classes):
|
||||
if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
|
||||
else:
|
||||
model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
|
||||
if torch_dtype == torch.float16 and hasattr(model, "half"):
|
||||
model = model.half()
|
||||
try:
|
||||
model = model.to(device=device)
|
||||
except:
|
||||
pass
|
||||
loaded_model_names.append(model_name)
|
||||
loaded_models.append(model)
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
|
||||
print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
|
||||
base_state_dict = base_model.state_dict()
|
||||
base_model.to("cpu")
|
||||
del base_model
|
||||
model = model_class(**extra_kwargs)
|
||||
model.load_state_dict(base_state_dict, strict=False)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
model.to(dtype=torch_dtype, device=device)
|
||||
return model
|
||||
|
||||
|
||||
def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
|
||||
loaded_model_names, loaded_models = [], []
|
||||
for model_name, model_class in zip(model_names, model_classes):
|
||||
while True:
|
||||
for model_id in range(len(model_manager.model)):
|
||||
base_model_name = model_manager.model_name[model_id]
|
||||
if base_model_name == model_name:
|
||||
base_model_path = model_manager.model_path[model_id]
|
||||
base_model = model_manager.model[model_id]
|
||||
print(f" Adding patch model to {base_model_name} ({base_model_path})")
|
||||
patched_model = load_single_patch_model_from_single_file(
|
||||
state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
|
||||
loaded_model_names.append(base_model_name)
|
||||
loaded_models.append(patched_model)
|
||||
model_manager.model.pop(model_id)
|
||||
model_manager.model_path.pop(model_id)
|
||||
model_manager.model_name.pop(model_id)
|
||||
break
|
||||
else:
|
||||
break
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
|
||||
class ModelDetectorTemplate:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def match(self, file_path="", state_dict={}):
|
||||
return False
|
||||
|
||||
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
||||
return [], []
|
||||
|
||||
|
||||
|
||||
class ModelDetectorFromSingleFile:
|
||||
def __init__(self, model_loader_configs=[]):
|
||||
self.keys_hash_with_shape_dict = {}
|
||||
self.keys_hash_dict = {}
|
||||
for metadata in model_loader_configs:
|
||||
self.add_model_metadata(*metadata)
|
||||
|
||||
|
||||
def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
|
||||
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
|
||||
if keys_hash is not None:
|
||||
self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)
|
||||
|
||||
|
||||
def match(self, file_path="", state_dict={}):
|
||||
if isinstance(file_path, str) and os.path.isdir(file_path):
|
||||
return False
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
||||
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
||||
return True
|
||||
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
|
||||
if keys_hash in self.keys_hash_dict:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
|
||||
# Load models with strict matching
|
||||
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
||||
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
||||
model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
|
||||
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
# Load models without strict matching
|
||||
# (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
|
||||
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
|
||||
if keys_hash in self.keys_hash_dict:
|
||||
model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
|
||||
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
|
||||
class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
|
||||
def __init__(self, model_loader_configs=[]):
|
||||
super().__init__(model_loader_configs)
|
||||
|
||||
|
||||
def match(self, file_path="", state_dict={}):
|
||||
if isinstance(file_path, str) and os.path.isdir(file_path):
|
||||
return False
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
splited_state_dict = split_state_dict_with_prefix(state_dict)
|
||||
for sub_state_dict in splited_state_dict:
|
||||
if super().match(file_path, sub_state_dict):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
||||
# Split the state_dict and load from each component
|
||||
splited_state_dict = split_state_dict_with_prefix(state_dict)
|
||||
valid_state_dict = {}
|
||||
for sub_state_dict in splited_state_dict:
|
||||
if super().match(file_path, sub_state_dict):
|
||||
valid_state_dict.update(sub_state_dict)
|
||||
if super().match(file_path, valid_state_dict):
|
||||
loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
|
||||
else:
|
||||
loaded_model_names, loaded_models = [], []
|
||||
for sub_state_dict in splited_state_dict:
|
||||
if super().match(file_path, sub_state_dict):
|
||||
loaded_model_names_, loaded_models_ = super().load(file_path, valid_state_dict, device, torch_dtype)
|
||||
loaded_model_names += loaded_model_names_
|
||||
loaded_models += loaded_models_
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
|
||||
class ModelDetectorFromHuggingfaceFolder:
|
||||
def __init__(self, model_loader_configs=[]):
|
||||
self.architecture_dict = {}
|
||||
for metadata in model_loader_configs:
|
||||
self.add_model_metadata(*metadata)
|
||||
|
||||
|
||||
def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
|
||||
self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)
|
||||
|
||||
|
||||
def match(self, file_path="", state_dict={}):
|
||||
if not isinstance(file_path, str) or os.path.isfile(file_path):
|
||||
return False
|
||||
file_list = os.listdir(file_path)
|
||||
if "config.json" not in file_list:
|
||||
return False
|
||||
with open(os.path.join(file_path, "config.json"), "r") as f:
|
||||
config = json.load(f)
|
||||
if "architectures" not in config and "_class_name" not in config:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
||||
with open(os.path.join(file_path, "config.json"), "r") as f:
|
||||
config = json.load(f)
|
||||
loaded_model_names, loaded_models = [], []
|
||||
architectures = config["architectures"] if "architectures" in config else [config["_class_name"]]
|
||||
for architecture in architectures:
|
||||
huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
|
||||
if redirected_architecture is not None:
|
||||
architecture = redirected_architecture
|
||||
model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
|
||||
loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
|
||||
loaded_model_names += loaded_model_names_
|
||||
loaded_models += loaded_models_
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
|
||||
class ModelDetectorFromPatchedSingleFile:
|
||||
def __init__(self, model_loader_configs=[]):
|
||||
self.keys_hash_with_shape_dict = {}
|
||||
for metadata in model_loader_configs:
|
||||
self.add_model_metadata(*metadata)
|
||||
|
||||
|
||||
def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
|
||||
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)
|
||||
|
||||
|
||||
def match(self, file_path="", state_dict={}):
|
||||
if not isinstance(file_path, str) or os.path.isdir(file_path):
|
||||
return False
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
||||
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
|
||||
# Load models with strict matching
|
||||
loaded_model_names, loaded_models = [], []
|
||||
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
||||
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
||||
model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
|
||||
loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
|
||||
state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
|
||||
loaded_model_names += loaded_model_names_
|
||||
loaded_models += loaded_models_
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
|
||||
class ModelManager:
|
||||
def __init__(
|
||||
self,
|
||||
torch_dtype=torch.float16,
|
||||
device="cuda",
|
||||
model_id_list: List[Preset_model_id] = [],
|
||||
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
|
||||
file_path_list: List[str] = [],
|
||||
):
|
||||
self.torch_dtype = torch_dtype
|
||||
self.device = device
|
||||
self.model = []
|
||||
self.model_path = []
|
||||
self.model_name = []
|
||||
downloaded_files = download_models(model_id_list, downloading_priority) if len(model_id_list) > 0 else []
|
||||
self.model_detector = [
|
||||
ModelDetectorFromSingleFile(model_loader_configs),
|
||||
ModelDetectorFromSplitedSingleFile(model_loader_configs),
|
||||
ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
|
||||
ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
|
||||
]
|
||||
self.load_models(downloaded_files + file_path_list)
|
||||
|
||||
|
||||
def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
|
||||
print(f"Loading models from file: {file_path}")
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device)
|
||||
for model_name, model in zip(model_names, models):
|
||||
self.model.append(model)
|
||||
self.model_path.append(file_path)
|
||||
self.model_name.append(model_name)
|
||||
print(f" The following models are loaded: {model_names}.")
|
||||
|
||||
|
||||
def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
|
||||
print(f"Loading models from folder: {file_path}")
|
||||
model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
|
||||
for model_name, model in zip(model_names, models):
|
||||
self.model.append(model)
|
||||
self.model_path.append(file_path)
|
||||
self.model_name.append(model_name)
|
||||
print(f" The following models are loaded: {model_names}.")
|
||||
|
||||
|
||||
def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
|
||||
print(f"Loading patch models from file: {file_path}")
|
||||
model_names, models = load_patch_model_from_single_file(
|
||||
state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
|
||||
for model_name, model in zip(model_names, models):
|
||||
self.model.append(model)
|
||||
self.model_path.append(file_path)
|
||||
self.model_name.append(model_name)
|
||||
print(f" The following patched models are loaded: {model_names}.")
|
||||
|
||||
|
||||
def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
|
||||
if isinstance(file_path, list):
|
||||
for file_path_ in file_path:
|
||||
self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
|
||||
else:
|
||||
print(f"Loading LoRA models from file: {file_path}")
|
||||
is_loaded = False
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
|
||||
for lora in get_lora_loaders():
|
||||
match_results = lora.match(model, state_dict)
|
||||
if match_results is not None:
|
||||
print(f" Adding LoRA to {model_name} ({model_path}).")
|
||||
lora_prefix, model_resource = match_results
|
||||
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
|
||||
is_loaded = True
|
||||
break
|
||||
if not is_loaded:
|
||||
print(f" Cannot load LoRA: {file_path}")
|
||||
|
||||
|
||||
def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
|
||||
print(f"Loading models from: {file_path}")
|
||||
if device is None: device = self.device
|
||||
if torch_dtype is None: torch_dtype = self.torch_dtype
|
||||
if isinstance(file_path, list):
|
||||
state_dict = {}
|
||||
for path in file_path:
|
||||
state_dict.update(load_state_dict(path))
|
||||
elif os.path.isfile(file_path):
|
||||
state_dict = load_state_dict(file_path)
|
||||
else:
|
||||
state_dict = None
|
||||
for model_detector in self.model_detector:
|
||||
if model_detector.match(file_path, state_dict):
|
||||
model_names, models = model_detector.load(
|
||||
file_path, state_dict,
|
||||
device=device, torch_dtype=torch_dtype,
|
||||
allowed_model_names=model_names, model_manager=self
|
||||
)
|
||||
for model_name, model in zip(model_names, models):
|
||||
self.model.append(model)
|
||||
self.model_path.append(file_path)
|
||||
self.model_name.append(model_name)
|
||||
print(f" The following models are loaded: {model_names}.")
|
||||
break
|
||||
else:
|
||||
print(f" We cannot detect the model type. No models are loaded.")
|
||||
|
||||
|
||||
def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None):
|
||||
for file_path in file_path_list:
|
||||
self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)
|
||||
|
||||
|
||||
def fetch_model(self, model_name, file_path=None, require_model_path=False):
|
||||
fetched_models = []
|
||||
fetched_model_paths = []
|
||||
for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
|
||||
if file_path is not None and file_path != model_path:
|
||||
continue
|
||||
if model_name == model_name_:
|
||||
fetched_models.append(model)
|
||||
fetched_model_paths.append(model_path)
|
||||
if len(fetched_models) == 0:
|
||||
print(f"No {model_name} models available.")
|
||||
return None
|
||||
if len(fetched_models) == 1:
|
||||
print(f"Using {model_name} from {fetched_model_paths[0]}.")
|
||||
else:
|
||||
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
|
||||
if require_model_path:
|
||||
return fetched_models[0], fetched_model_paths[0]
|
||||
else:
|
||||
return fetched_models[0]
|
||||
|
||||
|
||||
def to(self, device):
|
||||
for model in self.model:
|
||||
model.to(device)
|
||||
|
||||
803
diffsynth/models/omnigen.py
Normal file
803
diffsynth/models/omnigen.py
Normal file
@@ -0,0 +1,803 @@
|
||||
# The code is revised from DiT
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import math
|
||||
from safetensors.torch import load_file
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import torch.utils.checkpoint
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers import Phi3Config, Phi3Model
|
||||
from transformers.cache_utils import Cache, DynamicCache
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Phi3Transformer(Phi3Model):
|
||||
"""
|
||||
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
|
||||
We only modified the attention mask
|
||||
Args:
|
||||
config: Phi3Config
|
||||
"""
|
||||
def prefetch_layer(self, layer_idx: int, device: torch.device):
|
||||
"Starts prefetching the next layer cache"
|
||||
with torch.cuda.stream(self.prefetch_stream):
|
||||
# Prefetch next layer tensors to GPU
|
||||
for name, param in self.layers[layer_idx].named_parameters():
|
||||
param.data = param.data.to(device, non_blocking=True)
|
||||
|
||||
def evict_previous_layer(self, layer_idx: int):
|
||||
"Moves the previous layer cache to the CPU"
|
||||
prev_layer_idx = layer_idx - 1
|
||||
for name, param in self.layers[prev_layer_idx].named_parameters():
|
||||
param.data = param.data.to("cpu", non_blocking=True)
|
||||
|
||||
def get_offlaod_layer(self, layer_idx: int, device: torch.device):
|
||||
# init stream
|
||||
if not hasattr(self, "prefetch_stream"):
|
||||
self.prefetch_stream = torch.cuda.Stream()
|
||||
|
||||
# delete previous layer
|
||||
torch.cuda.current_stream().synchronize()
|
||||
self.evict_previous_layer(layer_idx)
|
||||
|
||||
# make sure the current layer is ready
|
||||
torch.cuda.synchronize(self.prefetch_stream)
|
||||
|
||||
# load next layer
|
||||
self.prefetch_layer((layer_idx + 1) % len(self.layers), device)
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
offload_model: Optional[bool] = False,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# kept for BC (non `Cache` `past_key_values` inputs)
|
||||
return_legacy_cache = False
|
||||
if use_cache and not isinstance(past_key_values, Cache):
|
||||
return_legacy_cache = True
|
||||
if past_key_values is None:
|
||||
past_key_values = DynamicCache()
|
||||
else:
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
logger.warning_once(
|
||||
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
|
||||
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
|
||||
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
|
||||
)
|
||||
|
||||
# if inputs_embeds is None:
|
||||
# inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
# if cache_position is None:
|
||||
# past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
# cache_position = torch.arange(
|
||||
# past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
# )
|
||||
# if position_ids is None:
|
||||
# position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
if attention_mask is not None and attention_mask.dim() == 3:
|
||||
dtype = inputs_embeds.dtype
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
attention_mask = (1 - attention_mask) * min_dtype
|
||||
attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype)
|
||||
else:
|
||||
raise Exception("attention_mask parameter was unavailable or invalid")
|
||||
# causal_mask = self._update_causal_mask(
|
||||
# attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
# )
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
layer_idx = -1
|
||||
for decoder_layer in self.layers:
|
||||
layer_idx += 1
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
)
|
||||
else:
|
||||
if offload_model and not self.training:
|
||||
self.get_offlaod_layer(layer_idx, device=inputs_embeds.device)
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
print('************')
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if return_legacy_cache:
|
||||
next_cache = next_cache.to_legacy_cache()
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
"""
|
||||
Embeds scalar timesteps into vector representations.
|
||||
"""
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an (N, D) Tensor of positional embeddings.
|
||||
"""
|
||||
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
||||
).to(device=t.device)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t, dtype=torch.float32):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of DiT.
|
||||
"""
|
||||
def __init__(self, hidden_size, patch_size, out_channels):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=1):
|
||||
"""
|
||||
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
|
||||
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||
"""
|
||||
if isinstance(grid_size, int):
|
||||
grid_size = (grid_size, grid_size)
|
||||
|
||||
grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
|
||||
grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0)
|
||||
|
||||
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||
if cls_token and extra_tokens > 0:
|
||||
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||
assert embed_dim % 2 == 0
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
"""
|
||||
embed_dim: output dimension for each position
|
||||
pos: a list of positions to be encoded: size (M,)
|
||||
out: (M, D)
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
||||
omega /= embed_dim / 2.
|
||||
omega = 1. / 10000**omega # (D/2,)
|
||||
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
||||
|
||||
emb_sin = np.sin(out) # (M, D/2)
|
||||
emb_cos = np.cos(out) # (M, D/2)
|
||||
|
||||
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||
return emb
|
||||
|
||||
|
||||
class PatchEmbedMR(nn.Module):
|
||||
""" 2D Image to Patch Embedding
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
in_chans: int = 4,
|
||||
embed_dim: int = 768,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
|
||||
return x
|
||||
|
||||
|
||||
class OmniGenOriginalModel(nn.Module):
|
||||
"""
|
||||
Diffusion model with a Transformer backbone.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
transformer_config: Phi3Config,
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
pe_interpolation: float = 1.0,
|
||||
pos_embed_max_size: int = 192,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.patch_size = patch_size
|
||||
self.pos_embed_max_size = pos_embed_max_size
|
||||
|
||||
hidden_size = transformer_config.hidden_size
|
||||
|
||||
self.x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
|
||||
self.input_x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
|
||||
|
||||
self.time_token = TimestepEmbedder(hidden_size)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size)
|
||||
|
||||
self.pe_interpolation = pe_interpolation
|
||||
pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, interpolation_scale=self.pe_interpolation, base_size=64)
|
||||
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)
|
||||
|
||||
self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
|
||||
|
||||
self.initialize_weights()
|
||||
|
||||
self.llm = Phi3Transformer(config=transformer_config)
|
||||
self.llm.config.use_cache = False
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_name):
|
||||
if not os.path.exists(model_name):
|
||||
cache_folder = os.getenv('HF_HUB_CACHE')
|
||||
model_name = snapshot_download(repo_id=model_name,
|
||||
cache_dir=cache_folder,
|
||||
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
|
||||
config = Phi3Config.from_pretrained(model_name)
|
||||
model = cls(config)
|
||||
if os.path.exists(os.path.join(model_name, 'model.safetensors')):
|
||||
print("Loading safetensors")
|
||||
ckpt = load_file(os.path.join(model_name, 'model.safetensors'))
|
||||
else:
|
||||
ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu')
|
||||
model.load_state_dict(ckpt)
|
||||
return model
|
||||
|
||||
def initialize_weights(self):
|
||||
assert not hasattr(self, "llama")
|
||||
|
||||
# Initialize transformer layers:
|
||||
def _basic_init(module):
|
||||
if isinstance(module, nn.Linear):
|
||||
torch.nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.constant_(module.bias, 0)
|
||||
self.apply(_basic_init)
|
||||
|
||||
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
||||
w = self.x_embedder.proj.weight.data
|
||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
||||
|
||||
w = self.input_x_embedder.proj.weight.data
|
||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
||||
|
||||
|
||||
# Initialize timestep embedding MLP:
|
||||
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
||||
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
||||
nn.init.normal_(self.time_token.mlp[0].weight, std=0.02)
|
||||
nn.init.normal_(self.time_token.mlp[2].weight, std=0.02)
|
||||
|
||||
# Zero-out output layers:
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
||||
nn.init.constant_(self.final_layer.linear.weight, 0)
|
||||
nn.init.constant_(self.final_layer.linear.bias, 0)
|
||||
|
||||
def unpatchify(self, x, h, w):
|
||||
"""
|
||||
x: (N, T, patch_size**2 * C)
|
||||
imgs: (N, H, W, C)
|
||||
"""
|
||||
c = self.out_channels
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], h//self.patch_size, w//self.patch_size, self.patch_size, self.patch_size, c))
|
||||
x = torch.einsum('nhwpqc->nchpwq', x)
|
||||
imgs = x.reshape(shape=(x.shape[0], c, h, w))
|
||||
return imgs
|
||||
|
||||
|
||||
def cropped_pos_embed(self, height, width):
|
||||
"""Crops positional embeddings for SD3 compatibility."""
|
||||
if self.pos_embed_max_size is None:
|
||||
raise ValueError("`pos_embed_max_size` must be set for cropping.")
|
||||
|
||||
height = height // self.patch_size
|
||||
width = width // self.patch_size
|
||||
if height > self.pos_embed_max_size:
|
||||
raise ValueError(
|
||||
f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
|
||||
)
|
||||
if width > self.pos_embed_max_size:
|
||||
raise ValueError(
|
||||
f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
|
||||
)
|
||||
|
||||
top = (self.pos_embed_max_size - height) // 2
|
||||
left = (self.pos_embed_max_size - width) // 2
|
||||
spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
|
||||
spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
|
||||
# print(top, top + height, left, left + width, spatial_pos_embed.size())
|
||||
spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
|
||||
return spatial_pos_embed
|
||||
|
||||
|
||||
def patch_multiple_resolutions(self, latents, padding_latent=None, is_input_images:bool=False):
|
||||
if isinstance(latents, list):
|
||||
return_list = False
|
||||
if padding_latent is None:
|
||||
padding_latent = [None] * len(latents)
|
||||
return_list = True
|
||||
patched_latents, num_tokens, shapes = [], [], []
|
||||
for latent, padding in zip(latents, padding_latent):
|
||||
height, width = latent.shape[-2:]
|
||||
if is_input_images:
|
||||
latent = self.input_x_embedder(latent)
|
||||
else:
|
||||
latent = self.x_embedder(latent)
|
||||
pos_embed = self.cropped_pos_embed(height, width)
|
||||
latent = latent + pos_embed
|
||||
if padding is not None:
|
||||
latent = torch.cat([latent, padding], dim=-2)
|
||||
patched_latents.append(latent)
|
||||
|
||||
num_tokens.append(pos_embed.size(1))
|
||||
shapes.append([height, width])
|
||||
if not return_list:
|
||||
latents = torch.cat(patched_latents, dim=0)
|
||||
else:
|
||||
latents = patched_latents
|
||||
else:
|
||||
height, width = latents.shape[-2:]
|
||||
if is_input_images:
|
||||
latents = self.input_x_embedder(latents)
|
||||
else:
|
||||
latents = self.x_embedder(latents)
|
||||
pos_embed = self.cropped_pos_embed(height, width)
|
||||
latents = latents + pos_embed
|
||||
num_tokens = latents.size(1)
|
||||
shapes = [height, width]
|
||||
return latents, num_tokens, shapes
|
||||
|
||||
|
||||
def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True, offload_model:bool=False):
|
||||
"""
|
||||
|
||||
"""
|
||||
input_is_list = isinstance(x, list)
|
||||
x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
|
||||
time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)
|
||||
|
||||
if input_img_latents is not None:
|
||||
input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True)
|
||||
if input_ids is not None:
|
||||
condition_embeds = self.llm.embed_tokens(input_ids).clone()
|
||||
input_img_inx = 0
|
||||
for b_inx in input_image_sizes.keys():
|
||||
for start_inx, end_inx in input_image_sizes[b_inx]:
|
||||
condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
|
||||
input_img_inx += 1
|
||||
if input_img_latents is not None:
|
||||
assert input_img_inx == len(input_latents)
|
||||
|
||||
input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
|
||||
else:
|
||||
input_emb = torch.cat([time_token, x], dim=1)
|
||||
output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, offload_model=offload_model)
|
||||
output, past_key_values = output.last_hidden_state, output.past_key_values
|
||||
if input_is_list:
|
||||
image_embedding = output[:, -max(num_tokens):]
|
||||
time_emb = self.t_embedder(timestep, dtype=x.dtype)
|
||||
x = self.final_layer(image_embedding, time_emb)
|
||||
latents = []
|
||||
for i in range(x.size(0)):
|
||||
latent = x[i:i+1, :num_tokens[i]]
|
||||
latent = self.unpatchify(latent, shapes[i][0], shapes[i][1])
|
||||
latents.append(latent)
|
||||
else:
|
||||
image_embedding = output[:, -num_tokens:]
|
||||
time_emb = self.t_embedder(timestep, dtype=x.dtype)
|
||||
x = self.final_layer(image_embedding, time_emb)
|
||||
latents = self.unpatchify(x, shapes[0], shapes[1])
|
||||
|
||||
if return_past_key_values:
|
||||
return latents, past_key_values
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def forward_with_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
|
||||
self.llm.config.use_cache = use_kv_cache
|
||||
model_out, past_key_values = self.forward(x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, past_key_values=past_key_values, return_past_key_values=True, offload_model=offload_model)
|
||||
if use_img_cfg:
|
||||
cond, uncond, img_cond = torch.split(model_out, len(model_out) // 3, dim=0)
|
||||
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
|
||||
model_out = [cond, cond, cond]
|
||||
else:
|
||||
cond, uncond = torch.split(model_out, len(model_out) // 2, dim=0)
|
||||
cond = uncond + cfg_scale * (cond - uncond)
|
||||
model_out = [cond, cond]
|
||||
|
||||
return torch.cat(model_out, dim=0), past_key_values
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
|
||||
self.llm.config.use_cache = use_kv_cache
|
||||
if past_key_values is None:
|
||||
past_key_values = [None] * len(attention_mask)
|
||||
|
||||
x = torch.split(x, len(x) // len(attention_mask), dim=0)
|
||||
timestep = timestep.to(x[0].dtype)
|
||||
timestep = torch.split(timestep, len(timestep) // len(input_ids), dim=0)
|
||||
|
||||
model_out, pask_key_values = [], []
|
||||
for i in range(len(input_ids)):
|
||||
temp_out, temp_pask_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values=past_key_values[i], return_past_key_values=True, offload_model=offload_model)
|
||||
model_out.append(temp_out)
|
||||
pask_key_values.append(temp_pask_key_values)
|
||||
|
||||
if len(model_out) == 3:
|
||||
cond, uncond, img_cond = model_out
|
||||
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
|
||||
model_out = [cond, cond, cond]
|
||||
elif len(model_out) == 2:
|
||||
cond, uncond = model_out
|
||||
cond = uncond + cfg_scale * (cond - uncond)
|
||||
model_out = [cond, cond]
|
||||
else:
|
||||
return model_out[0]
|
||||
|
||||
return torch.cat(model_out, dim=0), pask_key_values
|
||||
|
||||
|
||||
|
||||
class OmniGenTransformer(OmniGenOriginalModel):
|
||||
def __init__(self):
|
||||
config = {
|
||||
"_name_or_path": "Phi-3-vision-128k-instruct",
|
||||
"architectures": [
|
||||
"Phi3ForCausalLM"
|
||||
],
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 1,
|
||||
"eos_token_id": 2,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 3072,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 8192,
|
||||
"max_position_embeddings": 131072,
|
||||
"model_type": "phi3",
|
||||
"num_attention_heads": 32,
|
||||
"num_hidden_layers": 32,
|
||||
"num_key_value_heads": 32,
|
||||
"original_max_position_embeddings": 4096,
|
||||
"rms_norm_eps": 1e-05,
|
||||
"rope_scaling": {
|
||||
"long_factor": [
|
||||
1.0299999713897705,
|
||||
1.0499999523162842,
|
||||
1.0499999523162842,
|
||||
1.0799999237060547,
|
||||
1.2299998998641968,
|
||||
1.2299998998641968,
|
||||
1.2999999523162842,
|
||||
1.4499999284744263,
|
||||
1.5999999046325684,
|
||||
1.6499998569488525,
|
||||
1.8999998569488525,
|
||||
2.859999895095825,
|
||||
3.68999981880188,
|
||||
5.419999599456787,
|
||||
5.489999771118164,
|
||||
5.489999771118164,
|
||||
9.09000015258789,
|
||||
11.579999923706055,
|
||||
15.65999984741211,
|
||||
15.769999504089355,
|
||||
15.789999961853027,
|
||||
18.360000610351562,
|
||||
21.989999771118164,
|
||||
23.079999923706055,
|
||||
30.009998321533203,
|
||||
32.35000228881836,
|
||||
32.590003967285156,
|
||||
35.56000518798828,
|
||||
39.95000457763672,
|
||||
53.840003967285156,
|
||||
56.20000457763672,
|
||||
57.95000457763672,
|
||||
59.29000473022461,
|
||||
59.77000427246094,
|
||||
59.920005798339844,
|
||||
61.190006256103516,
|
||||
61.96000671386719,
|
||||
62.50000762939453,
|
||||
63.3700065612793,
|
||||
63.48000717163086,
|
||||
63.48000717163086,
|
||||
63.66000747680664,
|
||||
63.850006103515625,
|
||||
64.08000946044922,
|
||||
64.760009765625,
|
||||
64.80001068115234,
|
||||
64.81001281738281,
|
||||
64.81001281738281
|
||||
],
|
||||
"short_factor": [
|
||||
1.05,
|
||||
1.05,
|
||||
1.05,
|
||||
1.1,
|
||||
1.1,
|
||||
1.1,
|
||||
1.2500000000000002,
|
||||
1.2500000000000002,
|
||||
1.4000000000000004,
|
||||
1.4500000000000004,
|
||||
1.5500000000000005,
|
||||
1.8500000000000008,
|
||||
1.9000000000000008,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.1000000000000005,
|
||||
2.1000000000000005,
|
||||
2.2,
|
||||
2.3499999999999996,
|
||||
2.3499999999999996,
|
||||
2.3499999999999996,
|
||||
2.3499999999999996,
|
||||
2.3999999999999995,
|
||||
2.3999999999999995,
|
||||
2.6499999999999986,
|
||||
2.6999999999999984,
|
||||
2.8999999999999977,
|
||||
2.9499999999999975,
|
||||
3.049999999999997,
|
||||
3.049999999999997,
|
||||
3.049999999999997
|
||||
],
|
||||
"type": "su"
|
||||
},
|
||||
"rope_theta": 10000.0,
|
||||
"sliding_window": 131072,
|
||||
"tie_word_embeddings": False,
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.38.1",
|
||||
"use_cache": True,
|
||||
"vocab_size": 32064,
|
||||
"_attn_implementation": "sdpa"
|
||||
}
|
||||
config = Phi3Config(**config)
|
||||
super().__init__(config)
|
||||
|
||||
|
||||
def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True, offload_model:bool=False):
|
||||
input_is_list = isinstance(x, list)
|
||||
x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
|
||||
time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)
|
||||
|
||||
if input_img_latents is not None:
|
||||
input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True)
|
||||
if input_ids is not None:
|
||||
condition_embeds = self.llm.embed_tokens(input_ids).clone()
|
||||
input_img_inx = 0
|
||||
for b_inx in input_image_sizes.keys():
|
||||
for start_inx, end_inx in input_image_sizes[b_inx]:
|
||||
condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
|
||||
input_img_inx += 1
|
||||
if input_img_latents is not None:
|
||||
assert input_img_inx == len(input_latents)
|
||||
|
||||
input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
|
||||
else:
|
||||
input_emb = torch.cat([time_token, x], dim=1)
|
||||
output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, offload_model=offload_model)
|
||||
output, past_key_values = output.last_hidden_state, output.past_key_values
|
||||
if input_is_list:
|
||||
image_embedding = output[:, -max(num_tokens):]
|
||||
time_emb = self.t_embedder(timestep, dtype=x.dtype)
|
||||
x = self.final_layer(image_embedding, time_emb)
|
||||
latents = []
|
||||
for i in range(x.size(0)):
|
||||
latent = x[i:i+1, :num_tokens[i]]
|
||||
latent = self.unpatchify(latent, shapes[i][0], shapes[i][1])
|
||||
latents.append(latent)
|
||||
else:
|
||||
image_embedding = output[:, -num_tokens:]
|
||||
time_emb = self.t_embedder(timestep, dtype=x.dtype)
|
||||
x = self.final_layer(image_embedding, time_emb)
|
||||
latents = self.unpatchify(x, shapes[0], shapes[1])
|
||||
|
||||
if return_past_key_values:
|
||||
return latents, past_key_values
|
||||
return latents
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
|
||||
self.llm.config.use_cache = use_kv_cache
|
||||
if past_key_values is None:
|
||||
past_key_values = [None] * len(attention_mask)
|
||||
|
||||
x = torch.split(x, len(x) // len(attention_mask), dim=0)
|
||||
timestep = timestep.to(x[0].dtype)
|
||||
timestep = torch.split(timestep, len(timestep) // len(input_ids), dim=0)
|
||||
|
||||
model_out, pask_key_values = [], []
|
||||
for i in range(len(input_ids)):
|
||||
temp_out, temp_pask_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values=past_key_values[i], return_past_key_values=True, offload_model=offload_model)
|
||||
model_out.append(temp_out)
|
||||
pask_key_values.append(temp_pask_key_values)
|
||||
|
||||
if len(model_out) == 3:
|
||||
cond, uncond, img_cond = model_out
|
||||
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
|
||||
model_out = [cond, cond, cond]
|
||||
elif len(model_out) == 2:
|
||||
cond, uncond = model_out
|
||||
cond = uncond + cfg_scale * (cond - uncond)
|
||||
model_out = [cond, cond]
|
||||
else:
|
||||
return model_out[0]
|
||||
|
||||
return torch.cat(model_out, dim=0), pask_key_values
|
||||
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return OmniGenTransformerStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class OmniGenTransformerStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
551
diffsynth/models/sd3_dit.py
Normal file
551
diffsynth/models/sd3_dit.py
Normal file
@@ -0,0 +1,551 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from .svd_unet import TemporalTimesteps
|
||||
from .tiler import TileWorker
|
||||
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim, eps, elementwise_affine=True):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
if elementwise_affine:
|
||||
self.weight = torch.nn.Parameter(torch.ones((dim,)))
|
||||
else:
|
||||
self.weight = None
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
||||
hidden_states = hidden_states.to(input_dtype)
|
||||
if self.weight is not None:
|
||||
hidden_states = hidden_states * self.weight
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
class PatchEmbed(torch.nn.Module):
|
||||
def __init__(self, patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192):
|
||||
super().__init__()
|
||||
self.pos_embed_max_size = pos_embed_max_size
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.proj = torch.nn.Conv2d(in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size)
|
||||
self.pos_embed = torch.nn.Parameter(torch.zeros(1, self.pos_embed_max_size, self.pos_embed_max_size, embed_dim))
|
||||
|
||||
def cropped_pos_embed(self, height, width):
|
||||
height = height // self.patch_size
|
||||
width = width // self.patch_size
|
||||
top = (self.pos_embed_max_size - height) // 2
|
||||
left = (self.pos_embed_max_size - width) // 2
|
||||
spatial_pos_embed = self.pos_embed[:, top : top + height, left : left + width, :].flatten(1, 2)
|
||||
return spatial_pos_embed
|
||||
|
||||
def forward(self, latent):
|
||||
height, width = latent.shape[-2:]
|
||||
latent = self.proj(latent)
|
||||
latent = latent.flatten(2).transpose(1, 2)
|
||||
pos_embed = self.cropped_pos_embed(height, width)
|
||||
return latent + pos_embed
|
||||
|
||||
|
||||
|
||||
class TimestepEmbeddings(torch.nn.Module):
|
||||
def __init__(self, dim_in, dim_out, computation_device=None):
|
||||
super().__init__()
|
||||
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device)
|
||||
self.timestep_embedder = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, timestep, dtype):
|
||||
time_emb = self.time_proj(timestep).to(dtype)
|
||||
time_emb = self.timestep_embedder(time_emb)
|
||||
return time_emb
|
||||
|
||||
|
||||
|
||||
class AdaLayerNorm(torch.nn.Module):
|
||||
def __init__(self, dim, single=False, dual=False):
|
||||
super().__init__()
|
||||
self.single = single
|
||||
self.dual = dual
|
||||
self.linear = torch.nn.Linear(dim, dim * [[6, 2][single], 9][dual])
|
||||
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
def forward(self, x, emb):
|
||||
emb = self.linear(torch.nn.functional.silu(emb))
|
||||
if self.single:
|
||||
scale, shift = emb.unsqueeze(1).chunk(2, dim=2)
|
||||
x = self.norm(x) * (1 + scale) + shift
|
||||
return x
|
||||
elif self.dual:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.unsqueeze(1).chunk(9, dim=2)
|
||||
norm_x = self.norm(x)
|
||||
x = norm_x * (1 + scale_msa) + shift_msa
|
||||
norm_x2 = norm_x * (1 + scale_msa2) + shift_msa2
|
||||
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_x2, gate_msa2
|
||||
else:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)
|
||||
x = self.norm(x) * (1 + scale_msa) + shift_msa
|
||||
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
||||
|
||||
|
||||
|
||||
class JointAttention(torch.nn.Module):
|
||||
def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False, use_rms_norm=False):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.only_out_a = only_out_a
|
||||
|
||||
self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
|
||||
self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)
|
||||
|
||||
self.a_to_out = torch.nn.Linear(dim_a, dim_a)
|
||||
if not only_out_a:
|
||||
self.b_to_out = torch.nn.Linear(dim_b, dim_b)
|
||||
|
||||
if use_rms_norm:
|
||||
self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
|
||||
self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
|
||||
self.norm_q_b = RMSNorm(head_dim, eps=1e-6)
|
||||
self.norm_k_b = RMSNorm(head_dim, eps=1e-6)
|
||||
else:
|
||||
self.norm_q_a = None
|
||||
self.norm_k_a = None
|
||||
self.norm_q_b = None
|
||||
self.norm_k_b = None
|
||||
|
||||
|
||||
def process_qkv(self, hidden_states, to_qkv, norm_q, norm_k):
|
||||
batch_size = hidden_states.shape[0]
|
||||
qkv = to_qkv(hidden_states)
|
||||
qkv = qkv.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
||||
q, k, v = qkv.chunk(3, dim=1)
|
||||
if norm_q is not None:
|
||||
q = norm_q(q)
|
||||
if norm_k is not None:
|
||||
k = norm_k(k)
|
||||
return q, k, v
|
||||
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b):
|
||||
batch_size = hidden_states_a.shape[0]
|
||||
|
||||
qa, ka, va = self.process_qkv(hidden_states_a, self.a_to_qkv, self.norm_q_a, self.norm_k_a)
|
||||
qb, kb, vb = self.process_qkv(hidden_states_b, self.b_to_qkv, self.norm_q_b, self.norm_k_b)
|
||||
q = torch.concat([qa, qb], dim=2)
|
||||
k = torch.concat([ka, kb], dim=2)
|
||||
v = torch.concat([va, vb], dim=2)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
hidden_states_a, hidden_states_b = hidden_states[:, :hidden_states_a.shape[1]], hidden_states[:, hidden_states_a.shape[1]:]
|
||||
hidden_states_a = self.a_to_out(hidden_states_a)
|
||||
if self.only_out_a:
|
||||
return hidden_states_a
|
||||
else:
|
||||
hidden_states_b = self.b_to_out(hidden_states_b)
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
|
||||
class SingleAttention(torch.nn.Module):
|
||||
def __init__(self, dim_a, num_heads, head_dim, use_rms_norm=False):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
|
||||
self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
|
||||
self.a_to_out = torch.nn.Linear(dim_a, dim_a)
|
||||
|
||||
if use_rms_norm:
|
||||
self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
|
||||
self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
|
||||
else:
|
||||
self.norm_q_a = None
|
||||
self.norm_k_a = None
|
||||
|
||||
|
||||
def process_qkv(self, hidden_states, to_qkv, norm_q, norm_k):
|
||||
batch_size = hidden_states.shape[0]
|
||||
qkv = to_qkv(hidden_states)
|
||||
qkv = qkv.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
||||
q, k, v = qkv.chunk(3, dim=1)
|
||||
if norm_q is not None:
|
||||
q = norm_q(q)
|
||||
if norm_k is not None:
|
||||
k = norm_k(k)
|
||||
return q, k, v
|
||||
|
||||
|
||||
def forward(self, hidden_states_a):
|
||||
batch_size = hidden_states_a.shape[0]
|
||||
q, k, v = self.process_qkv(hidden_states_a, self.a_to_qkv, self.norm_q_a, self.norm_k_a)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
hidden_states = self.a_to_out(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
class DualTransformerBlock(torch.nn.Module):
|
||||
def __init__(self, dim, num_attention_heads, use_rms_norm=False):
|
||||
super().__init__()
|
||||
self.norm1_a = AdaLayerNorm(dim, dual=True)
|
||||
self.norm1_b = AdaLayerNorm(dim)
|
||||
|
||||
self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)
|
||||
self.attn2 = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)
|
||||
|
||||
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_a = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim, dim*4),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(dim*4, dim)
|
||||
)
|
||||
|
||||
self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_b = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim, dim*4),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(dim*4, dim)
|
||||
)
|
||||
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb):
|
||||
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a, norm_hidden_states_a_2, gate_msa_a_2 = self.norm1_a(hidden_states_a, emb=temb)
|
||||
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
|
||||
|
||||
# Attention
|
||||
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b)
|
||||
|
||||
# Part A
|
||||
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
|
||||
hidden_states_a = hidden_states_a + gate_msa_a_2 * self.attn2(norm_hidden_states_a_2)
|
||||
norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
|
||||
hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
|
||||
|
||||
# Part B
|
||||
hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
|
||||
norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
|
||||
hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
|
||||
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
|
||||
class JointTransformerBlock(torch.nn.Module):
|
||||
def __init__(self, dim, num_attention_heads, use_rms_norm=False, dual=False):
|
||||
super().__init__()
|
||||
self.norm1_a = AdaLayerNorm(dim, dual=dual)
|
||||
self.norm1_b = AdaLayerNorm(dim)
|
||||
|
||||
self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)
|
||||
if dual:
|
||||
self.attn2 = SingleAttention(dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)
|
||||
|
||||
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_a = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim, dim*4),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(dim*4, dim)
|
||||
)
|
||||
|
||||
self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_b = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim, dim*4),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(dim*4, dim)
|
||||
)
|
||||
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb):
|
||||
if self.norm1_a.dual:
|
||||
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a, norm_hidden_states_a_2, gate_msa_a_2 = self.norm1_a(hidden_states_a, emb=temb)
|
||||
else:
|
||||
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
|
||||
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
|
||||
|
||||
# Attention
|
||||
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b)
|
||||
|
||||
# Part A
|
||||
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
|
||||
if self.norm1_a.dual:
|
||||
hidden_states_a = hidden_states_a + gate_msa_a_2 * self.attn2(norm_hidden_states_a_2)
|
||||
norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
|
||||
hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
|
||||
|
||||
# Part B
|
||||
hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
|
||||
norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
|
||||
hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
|
||||
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
|
||||
class JointTransformerFinalBlock(torch.nn.Module):
|
||||
def __init__(self, dim, num_attention_heads, use_rms_norm=False):
|
||||
super().__init__()
|
||||
self.norm1_a = AdaLayerNorm(dim)
|
||||
self.norm1_b = AdaLayerNorm(dim, single=True)
|
||||
|
||||
self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, only_out_a=True, use_rms_norm=use_rms_norm)
|
||||
|
||||
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_a = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim, dim*4),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(dim*4, dim)
|
||||
)
|
||||
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb):
|
||||
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
|
||||
norm_hidden_states_b = self.norm1_b(hidden_states_b, emb=temb)
|
||||
|
||||
# Attention
|
||||
attn_output_a = self.attn(norm_hidden_states_a, norm_hidden_states_b)
|
||||
|
||||
# Part A
|
||||
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
|
||||
norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
|
||||
hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
|
||||
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
|
||||
class SD3DiT(torch.nn.Module):
|
||||
def __init__(self, embed_dim=1536, num_layers=24, use_rms_norm=False, num_dual_blocks=0, pos_embed_max_size=192):
|
||||
super().__init__()
|
||||
self.pos_embedder = PatchEmbed(patch_size=2, in_channels=16, embed_dim=embed_dim, pos_embed_max_size=pos_embed_max_size)
|
||||
self.time_embedder = TimestepEmbeddings(256, embed_dim)
|
||||
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(2048, embed_dim), torch.nn.SiLU(), torch.nn.Linear(embed_dim, embed_dim))
|
||||
self.context_embedder = torch.nn.Linear(4096, embed_dim)
|
||||
self.blocks = torch.nn.ModuleList([JointTransformerBlock(embed_dim, embed_dim//64, use_rms_norm=use_rms_norm, dual=True) for _ in range(num_dual_blocks)]
|
||||
+ [JointTransformerBlock(embed_dim, embed_dim//64, use_rms_norm=use_rms_norm) for _ in range(num_layers-1-num_dual_blocks)]
|
||||
+ [JointTransformerFinalBlock(embed_dim, embed_dim//64, use_rms_norm=use_rms_norm)])
|
||||
self.norm_out = AdaLayerNorm(embed_dim, single=True)
|
||||
self.proj_out = torch.nn.Linear(embed_dim, 64)
|
||||
|
||||
def tiled_forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size=128, tile_stride=64):
|
||||
# Due to the global positional embedding, we cannot implement layer-wise tiled forward.
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: self.forward(x, timestep, prompt_emb, pooled_prompt_emb),
|
||||
hidden_states,
|
||||
tile_size,
|
||||
tile_stride,
|
||||
tile_device=hidden_states.device,
|
||||
tile_dtype=hidden_states.dtype
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tiled=False, tile_size=128, tile_stride=64, use_gradient_checkpointing=False):
|
||||
if tiled:
|
||||
return self.tiled_forward(hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size, tile_stride)
|
||||
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
|
||||
prompt_emb = self.context_embedder(prompt_emb)
|
||||
|
||||
height, width = hidden_states.shape[-2:]
|
||||
hidden_states = self.pos_embedder(hidden_states)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
for block in self.blocks:
|
||||
if self.training and use_gradient_checkpointing:
|
||||
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, prompt_emb, conditioning,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, conditioning)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = rearrange(hidden_states, "B (H W) (P Q C) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
|
||||
return hidden_states
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SD3DiTStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class SD3DiTStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def infer_architecture(self, state_dict):
|
||||
embed_dim = state_dict["blocks.0.ff_a.0.weight"].shape[1]
|
||||
num_layers = 100
|
||||
while num_layers > 0 and f"blocks.{num_layers-1}.ff_a.0.bias" not in state_dict:
|
||||
num_layers -= 1
|
||||
use_rms_norm = "blocks.0.attn.norm_q_a.weight" in state_dict
|
||||
num_dual_blocks = 0
|
||||
while f"blocks.{num_dual_blocks}.attn2.a_to_out.bias" in state_dict:
|
||||
num_dual_blocks += 1
|
||||
pos_embed_max_size = state_dict["pos_embedder.pos_embed"].shape[1]
|
||||
return {
|
||||
"embed_dim": embed_dim,
|
||||
"num_layers": num_layers,
|
||||
"use_rms_norm": use_rms_norm,
|
||||
"num_dual_blocks": num_dual_blocks,
|
||||
"pos_embed_max_size": pos_embed_max_size
|
||||
}
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
rename_dict = {
|
||||
"context_embedder": "context_embedder",
|
||||
"pos_embed.pos_embed": "pos_embedder.pos_embed",
|
||||
"pos_embed.proj": "pos_embedder.proj",
|
||||
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
|
||||
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
|
||||
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
|
||||
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
|
||||
"norm_out.linear": "norm_out.linear",
|
||||
"proj_out": "proj_out",
|
||||
|
||||
"norm1.linear": "norm1_a.linear",
|
||||
"norm1_context.linear": "norm1_b.linear",
|
||||
"attn.to_q": "attn.a_to_q",
|
||||
"attn.to_k": "attn.a_to_k",
|
||||
"attn.to_v": "attn.a_to_v",
|
||||
"attn.to_out.0": "attn.a_to_out",
|
||||
"attn.add_q_proj": "attn.b_to_q",
|
||||
"attn.add_k_proj": "attn.b_to_k",
|
||||
"attn.add_v_proj": "attn.b_to_v",
|
||||
"attn.to_add_out": "attn.b_to_out",
|
||||
"ff.net.0.proj": "ff_a.0",
|
||||
"ff.net.2": "ff_a.2",
|
||||
"ff_context.net.0.proj": "ff_b.0",
|
||||
"ff_context.net.2": "ff_b.2",
|
||||
|
||||
"attn.norm_q": "attn.norm_q_a",
|
||||
"attn.norm_k": "attn.norm_k_a",
|
||||
"attn.norm_added_q": "attn.norm_q_b",
|
||||
"attn.norm_added_k": "attn.norm_k_b",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name in rename_dict:
|
||||
if name == "pos_embed.pos_embed":
|
||||
param = param.reshape((1, 192, 192, param.shape[-1]))
|
||||
state_dict_[rename_dict[name]] = param
|
||||
elif name.endswith(".weight") or name.endswith(".bias"):
|
||||
suffix = ".weight" if name.endswith(".weight") else ".bias"
|
||||
prefix = name[:-len(suffix)]
|
||||
if prefix in rename_dict:
|
||||
state_dict_[rename_dict[prefix] + suffix] = param
|
||||
elif prefix.startswith("transformer_blocks."):
|
||||
names = prefix.split(".")
|
||||
names[0] = "blocks"
|
||||
middle = ".".join(names[2:])
|
||||
if middle in rename_dict:
|
||||
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
|
||||
state_dict_[name_] = param
|
||||
merged_keys = [name for name in state_dict_ if ".a_to_q." in name or ".b_to_q." in name]
|
||||
for key in merged_keys:
|
||||
param = torch.concat([
|
||||
state_dict_[key.replace("to_q", "to_q")],
|
||||
state_dict_[key.replace("to_q", "to_k")],
|
||||
state_dict_[key.replace("to_q", "to_v")],
|
||||
], dim=0)
|
||||
name = key.replace("to_q", "to_qkv")
|
||||
state_dict_.pop(key.replace("to_q", "to_q"))
|
||||
state_dict_.pop(key.replace("to_q", "to_k"))
|
||||
state_dict_.pop(key.replace("to_q", "to_v"))
|
||||
state_dict_[name] = param
|
||||
return state_dict_, self.infer_architecture(state_dict_)
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
"model.diffusion_model.context_embedder.bias": "context_embedder.bias",
|
||||
"model.diffusion_model.context_embedder.weight": "context_embedder.weight",
|
||||
"model.diffusion_model.final_layer.linear.bias": "proj_out.bias",
|
||||
"model.diffusion_model.final_layer.linear.weight": "proj_out.weight",
|
||||
|
||||
"model.diffusion_model.pos_embed": "pos_embedder.pos_embed",
|
||||
"model.diffusion_model.t_embedder.mlp.0.bias": "time_embedder.timestep_embedder.0.bias",
|
||||
"model.diffusion_model.t_embedder.mlp.0.weight": "time_embedder.timestep_embedder.0.weight",
|
||||
"model.diffusion_model.t_embedder.mlp.2.bias": "time_embedder.timestep_embedder.2.bias",
|
||||
"model.diffusion_model.t_embedder.mlp.2.weight": "time_embedder.timestep_embedder.2.weight",
|
||||
"model.diffusion_model.x_embedder.proj.bias": "pos_embedder.proj.bias",
|
||||
"model.diffusion_model.x_embedder.proj.weight": "pos_embedder.proj.weight",
|
||||
"model.diffusion_model.y_embedder.mlp.0.bias": "pooled_text_embedder.0.bias",
|
||||
"model.diffusion_model.y_embedder.mlp.0.weight": "pooled_text_embedder.0.weight",
|
||||
"model.diffusion_model.y_embedder.mlp.2.bias": "pooled_text_embedder.2.bias",
|
||||
"model.diffusion_model.y_embedder.mlp.2.weight": "pooled_text_embedder.2.weight",
|
||||
|
||||
"model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1.weight": "blocks.23.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1.bias": "blocks.23.norm1_b.linear.bias",
|
||||
"model.diffusion_model.final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
|
||||
"model.diffusion_model.final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
|
||||
}
|
||||
for i in range(40):
|
||||
rename_dict.update({
|
||||
f"model.diffusion_model.joint_blocks.{i}.context_block.adaLN_modulation.1.bias": f"blocks.{i}.norm1_b.linear.bias",
|
||||
f"model.diffusion_model.joint_blocks.{i}.context_block.adaLN_modulation.1.weight": f"blocks.{i}.norm1_b.linear.weight",
|
||||
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.proj.bias": f"blocks.{i}.attn.b_to_out.bias",
|
||||
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.proj.weight": f"blocks.{i}.attn.b_to_out.weight",
|
||||
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.qkv.bias": [f'blocks.{i}.attn.b_to_q.bias', f'blocks.{i}.attn.b_to_k.bias', f'blocks.{i}.attn.b_to_v.bias'],
|
||||
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.qkv.weight": [f'blocks.{i}.attn.b_to_q.weight', f'blocks.{i}.attn.b_to_k.weight', f'blocks.{i}.attn.b_to_v.weight'],
|
||||
f"model.diffusion_model.joint_blocks.{i}.context_block.mlp.fc1.bias": f"blocks.{i}.ff_b.0.bias",
|
||||
f"model.diffusion_model.joint_blocks.{i}.context_block.mlp.fc1.weight": f"blocks.{i}.ff_b.0.weight",
|
||||
f"model.diffusion_model.joint_blocks.{i}.context_block.mlp.fc2.bias": f"blocks.{i}.ff_b.2.bias",
|
||||
f"model.diffusion_model.joint_blocks.{i}.context_block.mlp.fc2.weight": f"blocks.{i}.ff_b.2.weight",
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.adaLN_modulation.1.bias": f"blocks.{i}.norm1_a.linear.bias",
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.adaLN_modulation.1.weight": f"blocks.{i}.norm1_a.linear.weight",
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.proj.bias": f"blocks.{i}.attn.a_to_out.bias",
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.proj.weight": f"blocks.{i}.attn.a_to_out.weight",
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.qkv.bias": [f'blocks.{i}.attn.a_to_q.bias', f'blocks.{i}.attn.a_to_k.bias', f'blocks.{i}.attn.a_to_v.bias'],
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.qkv.weight": [f'blocks.{i}.attn.a_to_q.weight', f'blocks.{i}.attn.a_to_k.weight', f'blocks.{i}.attn.a_to_v.weight'],
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.mlp.fc1.bias": f"blocks.{i}.ff_a.0.bias",
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.mlp.fc1.weight": f"blocks.{i}.ff_a.0.weight",
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.mlp.fc2.bias": f"blocks.{i}.ff_a.2.bias",
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.mlp.fc2.weight": f"blocks.{i}.ff_a.2.weight",
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.ln_q.weight": f"blocks.{i}.attn.norm_q_a.weight",
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.ln_k.weight": f"blocks.{i}.attn.norm_k_a.weight",
|
||||
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.ln_q.weight": f"blocks.{i}.attn.norm_q_b.weight",
|
||||
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.ln_k.weight": f"blocks.{i}.attn.norm_k_b.weight",
|
||||
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.ln_q.weight": f"blocks.{i}.attn2.norm_q_a.weight",
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.ln_k.weight": f"blocks.{i}.attn2.norm_k_a.weight",
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.qkv.weight": f"blocks.{i}.attn2.a_to_qkv.weight",
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.qkv.bias": f"blocks.{i}.attn2.a_to_qkv.bias",
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.proj.weight": f"blocks.{i}.attn2.a_to_out.weight",
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.proj.bias": f"blocks.{i}.attn2.a_to_out.bias",
|
||||
})
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if name == "model.diffusion_model.pos_embed":
|
||||
pos_embed_max_size = int(param.shape[1] ** 0.5 + 0.4)
|
||||
param = param.reshape((1, pos_embed_max_size, pos_embed_max_size, param.shape[-1]))
|
||||
if isinstance(rename_dict[name], str):
|
||||
state_dict_[rename_dict[name]] = param
|
||||
else:
|
||||
name_ = rename_dict[name][0].replace(".a_to_q.", ".a_to_qkv.").replace(".b_to_q.", ".b_to_qkv.")
|
||||
state_dict_[name_] = param
|
||||
extra_kwargs = self.infer_architecture(state_dict_)
|
||||
num_layers = extra_kwargs["num_layers"]
|
||||
for name in [
|
||||
f"blocks.{num_layers-1}.norm1_b.linear.weight", f"blocks.{num_layers-1}.norm1_b.linear.bias", "norm_out.linear.weight", "norm_out.linear.bias",
|
||||
]:
|
||||
param = state_dict_[name]
|
||||
dim = param.shape[0] // 2
|
||||
param = torch.concat([param[dim:], param[:dim]], axis=0)
|
||||
state_dict_[name] = param
|
||||
return state_dict_, self.infer_architecture(state_dict_)
|
||||
1120
diffsynth/models/sd3_text_encoder.py
Normal file
1120
diffsynth/models/sd3_text_encoder.py
Normal file
File diff suppressed because it is too large
Load Diff
81
diffsynth/models/sd3_vae_decoder.py
Normal file
81
diffsynth/models/sd3_vae_decoder.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import torch
|
||||
from .sd_vae_decoder import VAEAttentionBlock, SDVAEDecoderStateDictConverter
|
||||
from .sd_unet import ResnetBlock, UpSampler
|
||||
from .tiler import TileWorker
|
||||
|
||||
|
||||
|
||||
class SD3VAEDecoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.scaling_factor = 1.5305 # Different from SD 1.x
|
||||
self.shift_factor = 0.0609 # Different from SD 1.x
|
||||
self.conv_in = torch.nn.Conv2d(16, 512, kernel_size=3, padding=1) # Different from SD 1.x
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
# UNetMidBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
# UpDecoderBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
UpSampler(512),
|
||||
# UpDecoderBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
UpSampler(512),
|
||||
# UpDecoderBlock2D
|
||||
ResnetBlock(512, 256, eps=1e-6),
|
||||
ResnetBlock(256, 256, eps=1e-6),
|
||||
ResnetBlock(256, 256, eps=1e-6),
|
||||
UpSampler(256),
|
||||
# UpDecoderBlock2D
|
||||
ResnetBlock(256, 128, eps=1e-6),
|
||||
ResnetBlock(128, 128, eps=1e-6),
|
||||
ResnetBlock(128, 128, eps=1e-6),
|
||||
])
|
||||
|
||||
self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6)
|
||||
self.conv_act = torch.nn.SiLU()
|
||||
self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
|
||||
|
||||
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: self.forward(x),
|
||||
sample,
|
||||
tile_size,
|
||||
tile_stride,
|
||||
tile_device=sample.device,
|
||||
tile_dtype=sample.dtype
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
||||
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
||||
if tiled:
|
||||
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
# 1. pre-process
|
||||
hidden_states = sample / self.scaling_factor + self.shift_factor
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
time_emb = None
|
||||
text_emb = None
|
||||
res_stack = None
|
||||
|
||||
# 2. blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
|
||||
# 3. output
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDVAEDecoderStateDictConverter()
|
||||
95
diffsynth/models/sd3_vae_encoder.py
Normal file
95
diffsynth/models/sd3_vae_encoder.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import torch
|
||||
from .sd_unet import ResnetBlock, DownSampler
|
||||
from .sd_vae_encoder import VAEAttentionBlock, SDVAEEncoderStateDictConverter
|
||||
from .tiler import TileWorker
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class SD3VAEEncoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.scaling_factor = 1.5305 # Different from SD 1.x
|
||||
self.shift_factor = 0.0609 # Different from SD 1.x
|
||||
self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
# DownEncoderBlock2D
|
||||
ResnetBlock(128, 128, eps=1e-6),
|
||||
ResnetBlock(128, 128, eps=1e-6),
|
||||
DownSampler(128, padding=0, extra_padding=True),
|
||||
# DownEncoderBlock2D
|
||||
ResnetBlock(128, 256, eps=1e-6),
|
||||
ResnetBlock(256, 256, eps=1e-6),
|
||||
DownSampler(256, padding=0, extra_padding=True),
|
||||
# DownEncoderBlock2D
|
||||
ResnetBlock(256, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
DownSampler(512, padding=0, extra_padding=True),
|
||||
# DownEncoderBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
# UNetMidBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
])
|
||||
|
||||
self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
|
||||
self.conv_act = torch.nn.SiLU()
|
||||
self.conv_out = torch.nn.Conv2d(512, 32, kernel_size=3, padding=1)
|
||||
|
||||
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: self.forward(x),
|
||||
sample,
|
||||
tile_size,
|
||||
tile_stride,
|
||||
tile_device=sample.device,
|
||||
tile_dtype=sample.dtype
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
||||
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
||||
if tiled:
|
||||
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
# 1. pre-process
|
||||
hidden_states = self.conv_in(sample)
|
||||
time_emb = None
|
||||
text_emb = None
|
||||
res_stack = None
|
||||
|
||||
# 2. blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
|
||||
# 3. output
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
hidden_states = hidden_states[:, :16]
|
||||
hidden_states = (hidden_states - self.shift_factor) * self.scaling_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
def encode_video(self, sample, batch_size=8):
|
||||
B = sample.shape[0]
|
||||
hidden_states = []
|
||||
|
||||
for i in range(0, sample.shape[2], batch_size):
|
||||
|
||||
j = min(i + batch_size, sample.shape[2])
|
||||
sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W")
|
||||
|
||||
hidden_states_batch = self(sample_batch)
|
||||
hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)
|
||||
|
||||
hidden_states.append(hidden_states_batch)
|
||||
|
||||
hidden_states = torch.concat(hidden_states, dim=2)
|
||||
return hidden_states
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDVAEEncoderStateDictConverter()
|
||||
@@ -97,9 +97,10 @@ class SDControlNet(torch.nn.Module):
|
||||
self,
|
||||
sample, timestep, encoder_hidden_states, conditioning,
|
||||
tiled=False, tile_size=64, tile_stride=32,
|
||||
**kwargs
|
||||
):
|
||||
# 1. time
|
||||
time_emb = self.time_proj(timestep[None]).to(sample.dtype)
|
||||
time_emb = self.time_proj(timestep).to(sample.dtype)
|
||||
time_emb = self.time_embedding(time_emb)
|
||||
time_emb = time_emb.repeat(sample.shape[0], 1)
|
||||
|
||||
@@ -134,7 +135,8 @@ class SDControlNet(torch.nn.Module):
|
||||
|
||||
return controlnet_res_stack
|
||||
|
||||
def state_dict_converter(self):
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDControlNetStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ class SDIpAdapter(torch.nn.Module):
|
||||
|
||||
def set_less_adapter(self):
|
||||
# IP-Adapter for SD v1.5 doesn't support this feature.
|
||||
self.set_full_adapter(self)
|
||||
self.set_full_adapter()
|
||||
|
||||
def forward(self, hidden_states, scale=1.0):
|
||||
hidden_states = self.image_proj(hidden_states)
|
||||
@@ -47,7 +47,8 @@ class SDIpAdapter(torch.nn.Module):
|
||||
}
|
||||
return ip_kv_dict
|
||||
|
||||
def state_dict_converter(self):
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDIpAdapterStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
import torch
|
||||
from .sd_unet import SDUNetStateDictConverter, SDUNet
|
||||
from .sd_text_encoder import SDTextEncoderStateDictConverter, SDTextEncoder
|
||||
|
||||
|
||||
class SDLoRA:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0, device="cuda"):
|
||||
special_keys = {
|
||||
"down.blocks": "down_blocks",
|
||||
"up.blocks": "up_blocks",
|
||||
"mid.block": "mid_block",
|
||||
"proj.in": "proj_in",
|
||||
"proj.out": "proj_out",
|
||||
"transformer.blocks": "transformer_blocks",
|
||||
"to.q": "to_q",
|
||||
"to.k": "to_k",
|
||||
"to.v": "to_v",
|
||||
"to.out": "to_out",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for key in state_dict:
|
||||
if ".lora_up" not in key:
|
||||
continue
|
||||
if not key.startswith(lora_prefix):
|
||||
continue
|
||||
weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
|
||||
weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
|
||||
if len(weight_up.shape) == 4:
|
||||
weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
|
||||
weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down)
|
||||
target_name = key.split(".")[0].replace("_", ".")[len(lora_prefix):] + ".weight"
|
||||
for special_key in special_keys:
|
||||
target_name = target_name.replace(special_key, special_keys[special_key])
|
||||
state_dict_[target_name] = lora_weight.cpu()
|
||||
return state_dict_
|
||||
|
||||
def add_lora_to_unet(self, unet: SDUNet, state_dict_lora, alpha=1.0, device="cuda"):
|
||||
state_dict_unet = unet.state_dict()
|
||||
state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_unet_", alpha=alpha, device=device)
|
||||
state_dict_lora = SDUNetStateDictConverter().from_diffusers(state_dict_lora)
|
||||
if len(state_dict_lora) > 0:
|
||||
for name in state_dict_lora:
|
||||
state_dict_unet[name] += state_dict_lora[name].to(device=device)
|
||||
unet.load_state_dict(state_dict_unet)
|
||||
|
||||
def add_lora_to_text_encoder(self, text_encoder: SDTextEncoder, state_dict_lora, alpha=1.0, device="cuda"):
|
||||
state_dict_text_encoder = text_encoder.state_dict()
|
||||
state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_te_", alpha=alpha, device=device)
|
||||
state_dict_lora = SDTextEncoderStateDictConverter().from_diffusers(state_dict_lora)
|
||||
if len(state_dict_lora) > 0:
|
||||
for name in state_dict_lora:
|
||||
state_dict_text_encoder[name] += state_dict_lora[name].to(device=device)
|
||||
text_encoder.load_state_dict(state_dict_text_encoder)
|
||||
|
||||
@@ -144,7 +144,8 @@ class SDMotionModel(torch.nn.Module):
|
||||
def forward(self):
|
||||
pass
|
||||
|
||||
def state_dict_converter(self):
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDMotionModelStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -71,7 +71,8 @@ class SDTextEncoder(torch.nn.Module):
|
||||
embeds = self.final_layer_norm(embeds)
|
||||
return embeds
|
||||
|
||||
def state_dict_converter(self):
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDTextEncoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -323,7 +323,7 @@ class SDUNet(torch.nn.Module):
|
||||
|
||||
def forward(self, sample, timestep, encoder_hidden_states, **kwargs):
|
||||
# 1. time
|
||||
time_emb = self.time_proj(timestep[None]).to(sample.dtype)
|
||||
time_emb = self.time_proj(timestep).to(sample.dtype)
|
||||
time_emb = self.time_embedding(time_emb)
|
||||
|
||||
# 2. pre-process
|
||||
@@ -342,7 +342,8 @@ class SDUNet(torch.nn.Module):
|
||||
|
||||
return hidden_states
|
||||
|
||||
def state_dict_converter(self):
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDUNetStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -90,6 +90,8 @@ class SDVAEDecoder(torch.nn.Module):
|
||||
return hidden_states
|
||||
|
||||
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
||||
original_dtype = sample.dtype
|
||||
sample = sample.to(dtype=next(iter(self.parameters())).dtype)
|
||||
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
||||
if tiled:
|
||||
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
||||
@@ -110,10 +112,12 @@ class SDVAEDecoder(torch.nn.Module):
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
hidden_states = hidden_states.to(original_dtype)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def state_dict_converter(self):
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDVAEDecoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -50,6 +50,8 @@ class SDVAEEncoder(torch.nn.Module):
|
||||
return hidden_states
|
||||
|
||||
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
||||
original_dtype = sample.dtype
|
||||
sample = sample.to(dtype=next(iter(self.parameters())).dtype)
|
||||
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
||||
if tiled:
|
||||
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
||||
@@ -71,6 +73,7 @@ class SDVAEEncoder(torch.nn.Module):
|
||||
hidden_states = self.quant_conv(hidden_states)
|
||||
hidden_states = hidden_states[:, :4]
|
||||
hidden_states *= self.scaling_factor
|
||||
hidden_states = hidden_states.to(original_dtype)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -91,7 +94,8 @@ class SDVAEEncoder(torch.nn.Module):
|
||||
hidden_states = torch.concat(hidden_states, dim=2)
|
||||
return hidden_states
|
||||
|
||||
def state_dict_converter(self):
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDVAEEncoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
318
diffsynth/models/sdxl_controlnet.py
Normal file
318
diffsynth/models/sdxl_controlnet.py
Normal file
@@ -0,0 +1,318 @@
|
||||
import torch
|
||||
from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler
|
||||
from .sdxl_unet import SDXLUNet
|
||||
from .tiler import TileWorker
|
||||
from .sd_controlnet import ControlNetConditioningLayer
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
|
||||
class QuickGELU(torch.nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
|
||||
|
||||
class ResidualAttentionBlock(torch.nn.Module):
|
||||
|
||||
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
|
||||
super().__init__()
|
||||
|
||||
self.attn = torch.nn.MultiheadAttention(d_model, n_head)
|
||||
self.ln_1 = torch.nn.LayerNorm(d_model)
|
||||
self.mlp = torch.nn.Sequential(OrderedDict([
|
||||
("c_fc", torch.nn.Linear(d_model, d_model * 4)),
|
||||
("gelu", QuickGELU()),
|
||||
("c_proj", torch.nn.Linear(d_model * 4, d_model))
|
||||
]))
|
||||
self.ln_2 = torch.nn.LayerNorm(d_model)
|
||||
self.attn_mask = attn_mask
|
||||
|
||||
def attention(self, x: torch.Tensor):
|
||||
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
||||
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = x + self.attention(self.ln_1(x))
|
||||
x = x + self.mlp(self.ln_2(x))
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class SDXLControlNetUnion(torch.nn.Module):
|
||||
def __init__(self, global_pool=False):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(320)
|
||||
self.time_embedding = torch.nn.Sequential(
|
||||
torch.nn.Linear(320, 1280),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(1280, 1280)
|
||||
)
|
||||
self.add_time_proj = Timesteps(256)
|
||||
self.add_time_embedding = torch.nn.Sequential(
|
||||
torch.nn.Linear(2816, 1280),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(1280, 1280)
|
||||
)
|
||||
self.control_type_proj = Timesteps(256)
|
||||
self.control_type_embedding = torch.nn.Sequential(
|
||||
torch.nn.Linear(256 * 8, 1280),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(1280, 1280)
|
||||
)
|
||||
self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
|
||||
|
||||
self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320))
|
||||
self.controlnet_transformer = ResidualAttentionBlock(320, 8)
|
||||
self.task_embedding = torch.nn.Parameter(torch.randn(8, 320))
|
||||
self.spatial_ch_projs = torch.nn.Linear(320, 320)
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
# DownBlock2D
|
||||
ResnetBlock(320, 320, 1280),
|
||||
PushBlock(),
|
||||
ResnetBlock(320, 320, 1280),
|
||||
PushBlock(),
|
||||
DownSampler(320),
|
||||
PushBlock(),
|
||||
# CrossAttnDownBlock2D
|
||||
ResnetBlock(320, 640, 1280),
|
||||
AttentionBlock(10, 64, 640, 2, 2048),
|
||||
PushBlock(),
|
||||
ResnetBlock(640, 640, 1280),
|
||||
AttentionBlock(10, 64, 640, 2, 2048),
|
||||
PushBlock(),
|
||||
DownSampler(640),
|
||||
PushBlock(),
|
||||
# CrossAttnDownBlock2D
|
||||
ResnetBlock(640, 1280, 1280),
|
||||
AttentionBlock(20, 64, 1280, 10, 2048),
|
||||
PushBlock(),
|
||||
ResnetBlock(1280, 1280, 1280),
|
||||
AttentionBlock(20, 64, 1280, 10, 2048),
|
||||
PushBlock(),
|
||||
# UNetMidBlock2DCrossAttn
|
||||
ResnetBlock(1280, 1280, 1280),
|
||||
AttentionBlock(20, 64, 1280, 10, 2048),
|
||||
ResnetBlock(1280, 1280, 1280),
|
||||
PushBlock()
|
||||
])
|
||||
|
||||
self.controlnet_blocks = torch.nn.ModuleList([
|
||||
torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
|
||||
torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
|
||||
torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
|
||||
torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
|
||||
torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
|
||||
torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
|
||||
torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
|
||||
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
|
||||
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
|
||||
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
|
||||
])
|
||||
|
||||
self.global_pool = global_pool
|
||||
|
||||
# 0 -- openpose
|
||||
# 1 -- depth
|
||||
# 2 -- hed/pidi/scribble/ted
|
||||
# 3 -- canny/lineart/anime_lineart/mlsd
|
||||
# 4 -- normal
|
||||
# 5 -- segment
|
||||
# 6 -- tile
|
||||
# 7 -- repaint
|
||||
self.task_id = {
|
||||
"openpose": 0,
|
||||
"depth": 1,
|
||||
"softedge": 2,
|
||||
"canny": 3,
|
||||
"lineart": 3,
|
||||
"lineart_anime": 3,
|
||||
"tile": 6,
|
||||
"inpaint": 7
|
||||
}
|
||||
|
||||
|
||||
def fuse_condition_to_input(self, hidden_states, task_id, conditioning):
|
||||
controlnet_cond = self.controlnet_conv_in(conditioning)
|
||||
feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
|
||||
feat_seq = feat_seq + self.task_embedding[task_id]
|
||||
x = torch.stack([feat_seq, torch.mean(hidden_states, dim=(2, 3))], dim=1)
|
||||
x = self.controlnet_transformer(x)
|
||||
|
||||
alpha = self.spatial_ch_projs(x[:,0]).unsqueeze(-1).unsqueeze(-1)
|
||||
controlnet_cond_fuser = controlnet_cond + alpha
|
||||
|
||||
hidden_states = hidden_states + controlnet_cond_fuser
|
||||
return hidden_states
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample, timestep, encoder_hidden_states,
|
||||
conditioning, processor_id, add_time_id, add_text_embeds,
|
||||
tiled=False, tile_size=64, tile_stride=32,
|
||||
unet:SDXLUNet=None,
|
||||
**kwargs
|
||||
):
|
||||
task_id = self.task_id[processor_id]
|
||||
|
||||
# 1. time
|
||||
t_emb = self.time_proj(timestep).to(sample.dtype)
|
||||
t_emb = self.time_embedding(t_emb)
|
||||
|
||||
time_embeds = self.add_time_proj(add_time_id)
|
||||
time_embeds = time_embeds.reshape((add_text_embeds.shape[0], -1))
|
||||
add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1)
|
||||
add_embeds = add_embeds.to(sample.dtype)
|
||||
if unet is not None and unet.is_kolors:
|
||||
add_embeds = unet.add_time_embedding(add_embeds)
|
||||
else:
|
||||
add_embeds = self.add_time_embedding(add_embeds)
|
||||
|
||||
control_type = torch.zeros((sample.shape[0], 8), dtype=sample.dtype, device=sample.device)
|
||||
control_type[:, task_id] = 1
|
||||
control_embeds = self.control_type_proj(control_type.flatten())
|
||||
control_embeds = control_embeds.reshape((sample.shape[0], -1))
|
||||
control_embeds = control_embeds.to(sample.dtype)
|
||||
control_embeds = self.control_type_embedding(control_embeds)
|
||||
time_emb = t_emb + add_embeds + control_embeds
|
||||
|
||||
# 2. pre-process
|
||||
height, width = sample.shape[2], sample.shape[3]
|
||||
hidden_states = self.conv_in(sample)
|
||||
hidden_states = self.fuse_condition_to_input(hidden_states, task_id, conditioning)
|
||||
text_emb = encoder_hidden_states
|
||||
if unet is not None and unet.is_kolors:
|
||||
text_emb = unet.text_intermediate_proj(text_emb)
|
||||
res_stack = [hidden_states]
|
||||
|
||||
# 3. blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
if tiled and not isinstance(block, PushBlock):
|
||||
_, _, inter_height, _ = hidden_states.shape
|
||||
resize_scale = inter_height / height
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: block(x, time_emb, text_emb, res_stack)[0],
|
||||
hidden_states,
|
||||
int(tile_size * resize_scale),
|
||||
int(tile_stride * resize_scale),
|
||||
tile_device=hidden_states.device,
|
||||
tile_dtype=hidden_states.dtype
|
||||
)
|
||||
else:
|
||||
hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
|
||||
# 4. ControlNet blocks
|
||||
controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]
|
||||
|
||||
# pool
|
||||
if self.global_pool:
|
||||
controlnet_res_stack = [res.mean(dim=(2, 3), keepdim=True) for res in controlnet_res_stack]
|
||||
|
||||
return controlnet_res_stack
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDXLControlNetUnionStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class SDXLControlNetUnionStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
# architecture
|
||||
block_types = [
|
||||
"ResnetBlock", "PushBlock", "ResnetBlock", "PushBlock", "DownSampler", "PushBlock",
|
||||
"ResnetBlock", "AttentionBlock", "PushBlock", "ResnetBlock", "AttentionBlock", "PushBlock", "DownSampler", "PushBlock",
|
||||
"ResnetBlock", "AttentionBlock", "PushBlock", "ResnetBlock", "AttentionBlock", "PushBlock",
|
||||
"ResnetBlock", "AttentionBlock", "ResnetBlock", "PushBlock"
|
||||
]
|
||||
|
||||
# controlnet_rename_dict
|
||||
controlnet_rename_dict = {
|
||||
"controlnet_cond_embedding.conv_in.weight": "controlnet_conv_in.blocks.0.weight",
|
||||
"controlnet_cond_embedding.conv_in.bias": "controlnet_conv_in.blocks.0.bias",
|
||||
"controlnet_cond_embedding.blocks.0.weight": "controlnet_conv_in.blocks.2.weight",
|
||||
"controlnet_cond_embedding.blocks.0.bias": "controlnet_conv_in.blocks.2.bias",
|
||||
"controlnet_cond_embedding.blocks.1.weight": "controlnet_conv_in.blocks.4.weight",
|
||||
"controlnet_cond_embedding.blocks.1.bias": "controlnet_conv_in.blocks.4.bias",
|
||||
"controlnet_cond_embedding.blocks.2.weight": "controlnet_conv_in.blocks.6.weight",
|
||||
"controlnet_cond_embedding.blocks.2.bias": "controlnet_conv_in.blocks.6.bias",
|
||||
"controlnet_cond_embedding.blocks.3.weight": "controlnet_conv_in.blocks.8.weight",
|
||||
"controlnet_cond_embedding.blocks.3.bias": "controlnet_conv_in.blocks.8.bias",
|
||||
"controlnet_cond_embedding.blocks.4.weight": "controlnet_conv_in.blocks.10.weight",
|
||||
"controlnet_cond_embedding.blocks.4.bias": "controlnet_conv_in.blocks.10.bias",
|
||||
"controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight",
|
||||
"controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias",
|
||||
"controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight",
|
||||
"controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
|
||||
"control_add_embedding.linear_1.weight": "control_type_embedding.0.weight",
|
||||
"control_add_embedding.linear_1.bias": "control_type_embedding.0.bias",
|
||||
"control_add_embedding.linear_2.weight": "control_type_embedding.2.weight",
|
||||
"control_add_embedding.linear_2.bias": "control_type_embedding.2.bias",
|
||||
}
|
||||
|
||||
# Rename each parameter
|
||||
name_list = sorted([name for name in state_dict])
|
||||
rename_dict = {}
|
||||
block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
|
||||
last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
|
||||
for name in name_list:
|
||||
names = name.split(".")
|
||||
if names[0] in ["conv_in", "conv_norm_out", "conv_out", "task_embedding", "spatial_ch_projs"]:
|
||||
pass
|
||||
elif name in controlnet_rename_dict:
|
||||
names = controlnet_rename_dict[name].split(".")
|
||||
elif names[0] == "controlnet_down_blocks":
|
||||
names[0] = "controlnet_blocks"
|
||||
elif names[0] == "controlnet_mid_block":
|
||||
names = ["controlnet_blocks", "9", names[-1]]
|
||||
elif names[0] in ["time_embedding", "add_embedding"]:
|
||||
if names[0] == "add_embedding":
|
||||
names[0] = "add_time_embedding"
|
||||
names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
|
||||
elif names[0] == "control_add_embedding":
|
||||
names[0] = "control_type_embedding"
|
||||
elif names[0] == "transformer_layes":
|
||||
names[0] = "controlnet_transformer"
|
||||
names.pop(1)
|
||||
elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
|
||||
if names[0] == "mid_block":
|
||||
names.insert(1, "0")
|
||||
block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
|
||||
block_type_with_id = ".".join(names[:4])
|
||||
if block_type_with_id != last_block_type_with_id[block_type]:
|
||||
block_id[block_type] += 1
|
||||
last_block_type_with_id[block_type] = block_type_with_id
|
||||
while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
|
||||
block_id[block_type] += 1
|
||||
block_type_with_id = ".".join(names[:4])
|
||||
names = ["blocks", str(block_id[block_type])] + names[4:]
|
||||
if "ff" in names:
|
||||
ff_index = names.index("ff")
|
||||
component = ".".join(names[ff_index:ff_index+3])
|
||||
component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
|
||||
names = names[:ff_index] + [component] + names[ff_index+3:]
|
||||
if "to_out" in names:
|
||||
names.pop(names.index("to_out") + 1)
|
||||
else:
|
||||
print(name, state_dict[name].shape)
|
||||
# raise ValueError(f"Unknown parameters: {name}")
|
||||
rename_dict[name] = ".".join(names)
|
||||
|
||||
# Convert state_dict
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name not in rename_dict:
|
||||
continue
|
||||
if ".proj_in." in name or ".proj_out." in name:
|
||||
param = param.squeeze()
|
||||
state_dict_[rename_dict[name]] = param
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
@@ -96,7 +96,8 @@ class SDXLIpAdapter(torch.nn.Module):
|
||||
}
|
||||
return ip_kv_dict
|
||||
|
||||
def state_dict_converter(self):
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDXLIpAdapterStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -49,7 +49,8 @@ class SDXLMotionModel(torch.nn.Module):
|
||||
def forward(self):
|
||||
pass
|
||||
|
||||
def state_dict_converter(self):
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDMotionModelStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -36,7 +36,8 @@ class SDXLTextEncoder(torch.nn.Module):
|
||||
break
|
||||
return embeds
|
||||
|
||||
def state_dict_converter(self):
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDXLTextEncoderStateDictConverter()
|
||||
|
||||
|
||||
@@ -80,7 +81,8 @@ class SDXLTextEncoder2(torch.nn.Module):
|
||||
pooled_embeds = self.text_projection(pooled_embeds)
|
||||
return pooled_embeds, hidden_states
|
||||
|
||||
def state_dict_converter(self):
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDXLTextEncoder2StateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, PopBlock
|
||||
|
||||
|
||||
class SDXLUNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, is_kolors=False):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(320)
|
||||
self.time_embedding = torch.nn.Sequential(
|
||||
@@ -13,11 +13,12 @@ class SDXLUNet(torch.nn.Module):
|
||||
)
|
||||
self.add_time_proj = Timesteps(256)
|
||||
self.add_time_embedding = torch.nn.Sequential(
|
||||
torch.nn.Linear(2816, 1280),
|
||||
torch.nn.Linear(5632 if is_kolors else 2816, 1280),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(1280, 1280)
|
||||
)
|
||||
self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
|
||||
self.text_intermediate_proj = torch.nn.Linear(4096, 2048) if is_kolors else None
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
# DownBlock2D
|
||||
@@ -82,13 +83,17 @@ class SDXLUNet(torch.nn.Module):
|
||||
self.conv_act = torch.nn.SiLU()
|
||||
self.conv_out = torch.nn.Conv2d(320, 4, kernel_size=3, padding=1)
|
||||
|
||||
self.is_kolors = is_kolors
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample, timestep, encoder_hidden_states, add_time_id, add_text_embeds,
|
||||
tiled=False, tile_size=64, tile_stride=8, **kwargs
|
||||
tiled=False, tile_size=64, tile_stride=8,
|
||||
use_gradient_checkpointing=False,
|
||||
**kwargs
|
||||
):
|
||||
# 1. time
|
||||
t_emb = self.time_proj(timestep[None]).to(sample.dtype)
|
||||
t_emb = self.time_proj(timestep).to(sample.dtype)
|
||||
t_emb = self.time_embedding(t_emb)
|
||||
|
||||
time_embeds = self.add_time_proj(add_time_id)
|
||||
@@ -102,15 +107,26 @@ class SDXLUNet(torch.nn.Module):
|
||||
# 2. pre-process
|
||||
height, width = sample.shape[2], sample.shape[3]
|
||||
hidden_states = self.conv_in(sample)
|
||||
text_emb = encoder_hidden_states
|
||||
text_emb = encoder_hidden_states if self.text_intermediate_proj is None else self.text_intermediate_proj(encoder_hidden_states)
|
||||
res_stack = [hidden_states]
|
||||
|
||||
# 3. blocks
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
for i, block in enumerate(self.blocks):
|
||||
hidden_states, time_emb, text_emb, res_stack = block(
|
||||
hidden_states, time_emb, text_emb, res_stack,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
||||
)
|
||||
if self.training and use_gradient_checkpointing and not (isinstance(block, PushBlock) or isinstance(block, PopBlock)):
|
||||
hidden_states, time_emb, text_emb, res_stack = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, time_emb, text_emb, res_stack,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states, time_emb, text_emb, res_stack = block(
|
||||
hidden_states, time_emb, text_emb, res_stack,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
||||
)
|
||||
|
||||
# 4. output
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
@@ -119,7 +135,8 @@ class SDXLUNet(torch.nn.Module):
|
||||
|
||||
return hidden_states
|
||||
|
||||
def state_dict_converter(self):
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDXLUNetStateDictConverter()
|
||||
|
||||
|
||||
@@ -148,6 +165,8 @@ class SDXLUNetStateDictConverter:
|
||||
names = name.split(".")
|
||||
if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
|
||||
pass
|
||||
elif names[0] in ["encoder_hid_proj"]:
|
||||
names[0] = "text_intermediate_proj"
|
||||
elif names[0] in ["time_embedding", "add_embedding"]:
|
||||
if names[0] == "add_embedding":
|
||||
names[0] = "add_time_embedding"
|
||||
@@ -181,7 +200,10 @@ class SDXLUNetStateDictConverter:
|
||||
if ".proj_in." in name or ".proj_out." in name:
|
||||
param = param.squeeze()
|
||||
state_dict_[rename_dict[name]] = param
|
||||
return state_dict_
|
||||
if "text_intermediate_proj.weight" in state_dict_:
|
||||
return state_dict_, {"is_kolors": True}
|
||||
else:
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
@@ -1873,4 +1895,7 @@ class SDXLUNetStateDictConverter:
|
||||
if ".proj_in." in name or ".proj_out." in name:
|
||||
param = param.squeeze()
|
||||
state_dict_[rename_dict[name]] = param
|
||||
return state_dict_
|
||||
if "text_intermediate_proj.weight" in state_dict_:
|
||||
return state_dict_, {"is_kolors": True}
|
||||
else:
|
||||
return state_dict_
|
||||
|
||||
@@ -2,14 +2,23 @@ from .sd_vae_decoder import SDVAEDecoder, SDVAEDecoderStateDictConverter
|
||||
|
||||
|
||||
class SDXLVAEDecoder(SDVAEDecoder):
|
||||
def __init__(self):
|
||||
def __init__(self, upcast_to_float32=True):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.13025
|
||||
|
||||
def state_dict_converter(self):
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDXLVAEDecoderStateDictConverter()
|
||||
|
||||
|
||||
class SDXLVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict = super().from_diffusers(state_dict)
|
||||
return state_dict, {"upcast_to_float32": True}
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
state_dict = super().from_civitai(state_dict)
|
||||
return state_dict, {"upcast_to_float32": True}
|
||||
|
||||
@@ -2,14 +2,23 @@ from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder
|
||||
|
||||
|
||||
class SDXLVAEEncoder(SDVAEEncoder):
|
||||
def __init__(self):
|
||||
def __init__(self, upcast_to_float32=True):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.13025
|
||||
|
||||
def state_dict_converter(self):
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDXLVAEEncoderStateDictConverter()
|
||||
|
||||
|
||||
class SDXLVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict = super().from_diffusers(state_dict)
|
||||
return state_dict, {"upcast_to_float32": True}
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
state_dict = super().from_civitai(state_dict)
|
||||
return state_dict, {"upcast_to_float32": True}
|
||||
|
||||
940
diffsynth/models/stepvideo_dit.py
Normal file
940
diffsynth/models/stepvideo_dit.py
Normal file
@@ -0,0 +1,940 @@
|
||||
# Copyright 2025 StepFun Inc. All Rights Reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
# ==============================================================================
|
||||
from typing import Dict, Optional, Tuple, Union, List
|
||||
import torch, math
|
||||
from torch import nn
|
||||
from einops import rearrange, repeat
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
elementwise_affine=True,
|
||||
eps: float = 1e-6,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
"""
|
||||
Initialize the RMSNorm normalization layer.
|
||||
|
||||
Args:
|
||||
dim (int): The dimension of the input tensor.
|
||||
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
||||
|
||||
Attributes:
|
||||
eps (float): A small value added to the denominator for numerical stability.
|
||||
weight (nn.Parameter): Learnable scaling parameter.
|
||||
|
||||
"""
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
if elementwise_affine:
|
||||
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
|
||||
|
||||
def _norm(self, x):
|
||||
"""
|
||||
Apply the RMSNorm normalization to the input tensor.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The normalized tensor.
|
||||
|
||||
"""
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass through the RMSNorm layer.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The output tensor after applying RMSNorm.
|
||||
|
||||
"""
|
||||
output = self._norm(x.float()).type_as(x)
|
||||
if hasattr(self, "weight"):
|
||||
output = output * self.weight
|
||||
return output
|
||||
|
||||
|
||||
ACTIVATION_FUNCTIONS = {
|
||||
"swish": nn.SiLU(),
|
||||
"silu": nn.SiLU(),
|
||||
"mish": nn.Mish(),
|
||||
"gelu": nn.GELU(),
|
||||
"relu": nn.ReLU(),
|
||||
}
|
||||
|
||||
|
||||
def get_activation(act_fn: str) -> nn.Module:
|
||||
"""Helper function to get activation function from string.
|
||||
|
||||
Args:
|
||||
act_fn (str): Name of activation function.
|
||||
|
||||
Returns:
|
||||
nn.Module: Activation function.
|
||||
"""
|
||||
|
||||
act_fn = act_fn.lower()
|
||||
if act_fn in ACTIVATION_FUNCTIONS:
|
||||
return ACTIVATION_FUNCTIONS[act_fn]
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation function: {act_fn}")
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
timesteps: torch.Tensor,
|
||||
embedding_dim: int,
|
||||
flip_sin_to_cos: bool = False,
|
||||
downscale_freq_shift: float = 1,
|
||||
scale: float = 1,
|
||||
max_period: int = 10000,
|
||||
):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
||||
|
||||
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
|
||||
embeddings. :return: an [N x dim] Tensor of positional embeddings.
|
||||
"""
|
||||
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
exponent = -math.log(max_period) * torch.arange(
|
||||
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
||||
)
|
||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||
|
||||
emb = torch.exp(exponent)
|
||||
emb = timesteps[:, None].float() * emb[None, :]
|
||||
|
||||
# scale embeddings
|
||||
emb = scale * emb
|
||||
|
||||
# concat sine and cosine embeddings
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
||||
|
||||
# flip sine and cosine embeddings
|
||||
if flip_sin_to_cos:
|
||||
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
||||
|
||||
# zero pad
|
||||
if embedding_dim % 2 == 1:
|
||||
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
||||
return emb
|
||||
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.flip_sin_to_cos = flip_sin_to_cos
|
||||
self.downscale_freq_shift = downscale_freq_shift
|
||||
|
||||
def forward(self, timesteps):
|
||||
t_emb = get_timestep_embedding(
|
||||
timesteps,
|
||||
self.num_channels,
|
||||
flip_sin_to_cos=self.flip_sin_to_cos,
|
||||
downscale_freq_shift=self.downscale_freq_shift,
|
||||
)
|
||||
return t_emb
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
time_embed_dim: int,
|
||||
act_fn: str = "silu",
|
||||
out_dim: int = None,
|
||||
post_act_fn: Optional[str] = None,
|
||||
cond_proj_dim=None,
|
||||
sample_proj_bias=True
|
||||
):
|
||||
super().__init__()
|
||||
linear_cls = nn.Linear
|
||||
|
||||
self.linear_1 = linear_cls(
|
||||
in_channels,
|
||||
time_embed_dim,
|
||||
bias=sample_proj_bias,
|
||||
)
|
||||
|
||||
if cond_proj_dim is not None:
|
||||
self.cond_proj = linear_cls(
|
||||
cond_proj_dim,
|
||||
in_channels,
|
||||
bias=False,
|
||||
)
|
||||
else:
|
||||
self.cond_proj = None
|
||||
|
||||
self.act = get_activation(act_fn)
|
||||
|
||||
if out_dim is not None:
|
||||
time_embed_dim_out = out_dim
|
||||
else:
|
||||
time_embed_dim_out = time_embed_dim
|
||||
|
||||
self.linear_2 = linear_cls(
|
||||
time_embed_dim,
|
||||
time_embed_dim_out,
|
||||
bias=sample_proj_bias,
|
||||
)
|
||||
|
||||
if post_act_fn is None:
|
||||
self.post_act = None
|
||||
else:
|
||||
self.post_act = get_activation(post_act_fn)
|
||||
|
||||
def forward(self, sample, condition=None):
|
||||
if condition is not None:
|
||||
sample = sample + self.cond_proj(condition)
|
||||
sample = self.linear_1(sample)
|
||||
|
||||
if self.act is not None:
|
||||
sample = self.act(sample)
|
||||
|
||||
sample = self.linear_2(sample)
|
||||
|
||||
if self.post_act is not None:
|
||||
sample = self.post_act(sample)
|
||||
return sample
|
||||
|
||||
|
||||
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
|
||||
super().__init__()
|
||||
|
||||
self.outdim = size_emb_dim
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
self.use_additional_conditions = use_additional_conditions
|
||||
if self.use_additional_conditions:
|
||||
self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
|
||||
self.nframe_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
self.fps_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(self, timestep, resolution=None, nframe=None, fps=None):
|
||||
hidden_dtype = timestep.dtype
|
||||
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
|
||||
|
||||
if self.use_additional_conditions:
|
||||
batch_size = timestep.shape[0]
|
||||
resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
|
||||
resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
|
||||
nframe_emb = self.additional_condition_proj(nframe.flatten()).to(hidden_dtype)
|
||||
nframe_emb = self.nframe_embedder(nframe_emb).reshape(batch_size, -1)
|
||||
conditioning = timesteps_emb + resolution_emb + nframe_emb
|
||||
|
||||
if fps is not None:
|
||||
fps_emb = self.additional_condition_proj(fps.flatten()).to(hidden_dtype)
|
||||
fps_emb = self.fps_embedder(fps_emb).reshape(batch_size, -1)
|
||||
conditioning = conditioning + fps_emb
|
||||
else:
|
||||
conditioning = timesteps_emb
|
||||
|
||||
return conditioning
|
||||
|
||||
|
||||
class AdaLayerNormSingle(nn.Module):
|
||||
r"""
|
||||
Norm layer adaptive layer norm single (adaLN-single).
|
||||
|
||||
As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
|
||||
|
||||
Parameters:
|
||||
embedding_dim (`int`): The size of each embedding vector.
|
||||
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
|
||||
"""
|
||||
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, time_step_rescale=1000):
|
||||
super().__init__()
|
||||
|
||||
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
||||
embedding_dim, size_emb_dim=embedding_dim // 2, use_additional_conditions=use_additional_conditions
|
||||
)
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
|
||||
|
||||
self.time_step_rescale = time_step_rescale ## timestep usually in [0, 1], we rescale it to [0,1000] for stability
|
||||
|
||||
def forward(
|
||||
self,
|
||||
timestep: torch.Tensor,
|
||||
added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
embedded_timestep = self.emb(timestep*self.time_step_rescale, **added_cond_kwargs)
|
||||
|
||||
out = self.linear(self.silu(embedded_timestep))
|
||||
|
||||
return out, embedded_timestep
|
||||
|
||||
|
||||
class PixArtAlphaTextProjection(nn.Module):
|
||||
"""
|
||||
Projects caption embeddings. Also handles dropout for classifier-free guidance.
|
||||
|
||||
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
||||
"""
|
||||
|
||||
def __init__(self, in_features, hidden_size):
|
||||
super().__init__()
|
||||
self.linear_1 = nn.Linear(
|
||||
in_features,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
)
|
||||
self.act_1 = nn.GELU(approximate="tanh")
|
||||
self.linear_2 = nn.Linear(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def forward(self, caption):
|
||||
hidden_states = self.linear_1(caption)
|
||||
hidden_states = self.act_1(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def attn_processor(self, attn_type):
|
||||
if attn_type == 'torch':
|
||||
return self.torch_attn_func
|
||||
elif attn_type == 'parallel':
|
||||
return self.parallel_attn_func
|
||||
else:
|
||||
raise Exception('Not supported attention type...')
|
||||
|
||||
def torch_attn_func(
|
||||
self,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
attn_mask=None,
|
||||
causal=False,
|
||||
drop_rate=0.0,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
if attn_mask is not None and attn_mask.dtype != torch.bool:
|
||||
attn_mask = attn_mask.to(q.dtype)
|
||||
|
||||
if attn_mask is not None and attn_mask.ndim == 3: ## no head
|
||||
n_heads = q.shape[2]
|
||||
attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)
|
||||
|
||||
q, k, v = map(lambda x: rearrange(x, 'b s h d -> b h s d'), (q, k, v))
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.to(q.device)
|
||||
x = torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
|
||||
)
|
||||
x = rearrange(x, 'b h s d -> b s h d')
|
||||
return x
|
||||
|
||||
|
||||
class RoPE1D:
|
||||
def __init__(self, freq=1e4, F0=1.0, scaling_factor=1.0):
|
||||
self.base = freq
|
||||
self.F0 = F0
|
||||
self.scaling_factor = scaling_factor
|
||||
self.cache = {}
|
||||
|
||||
def get_cos_sin(self, D, seq_len, device, dtype):
|
||||
if (D, seq_len, device, dtype) not in self.cache:
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
|
||||
t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
|
||||
freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
|
||||
freqs = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = freqs.cos() # (Seq, Dim)
|
||||
sin = freqs.sin()
|
||||
self.cache[D, seq_len, device, dtype] = (cos, sin)
|
||||
return self.cache[D, seq_len, device, dtype]
|
||||
|
||||
@staticmethod
|
||||
def rotate_half(x):
|
||||
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
def apply_rope1d(self, tokens, pos1d, cos, sin):
|
||||
assert pos1d.ndim == 2
|
||||
cos = torch.nn.functional.embedding(pos1d, cos)[:, :, None, :]
|
||||
sin = torch.nn.functional.embedding(pos1d, sin)[:, :, None, :]
|
||||
return (tokens * cos) + (self.rotate_half(tokens) * sin)
|
||||
|
||||
def __call__(self, tokens, positions):
|
||||
"""
|
||||
input:
|
||||
* tokens: batch_size x ntokens x nheads x dim
|
||||
* positions: batch_size x ntokens (t position of each token)
|
||||
output:
|
||||
* tokens after applying RoPE2D (batch_size x ntokens x nheads x dim)
|
||||
"""
|
||||
D = tokens.size(3)
|
||||
assert positions.ndim == 2 # Batch, Seq
|
||||
cos, sin = self.get_cos_sin(D, int(positions.max()) + 1, tokens.device, tokens.dtype)
|
||||
tokens = self.apply_rope1d(tokens, positions, cos, sin)
|
||||
return tokens
|
||||
|
||||
|
||||
class RoPE3D(RoPE1D):
|
||||
def __init__(self, freq=1e4, F0=1.0, scaling_factor=1.0):
|
||||
super(RoPE3D, self).__init__(freq, F0, scaling_factor)
|
||||
self.position_cache = {}
|
||||
|
||||
def get_mesh_3d(self, rope_positions, bsz):
|
||||
f, h, w = rope_positions
|
||||
|
||||
if f"{f}-{h}-{w}" not in self.position_cache:
|
||||
x = torch.arange(f, device='cpu')
|
||||
y = torch.arange(h, device='cpu')
|
||||
z = torch.arange(w, device='cpu')
|
||||
self.position_cache[f"{f}-{h}-{w}"] = torch.cartesian_prod(x, y, z).view(1, f*h*w, 3).expand(bsz, -1, 3)
|
||||
return self.position_cache[f"{f}-{h}-{w}"]
|
||||
|
||||
def __call__(self, tokens, rope_positions, ch_split, parallel=False):
|
||||
"""
|
||||
input:
|
||||
* tokens: batch_size x ntokens x nheads x dim
|
||||
* rope_positions: list of (f, h, w)
|
||||
output:
|
||||
* tokens after applying RoPE2D (batch_size x ntokens x nheads x dim)
|
||||
"""
|
||||
assert sum(ch_split) == tokens.size(-1);
|
||||
|
||||
mesh_grid = self.get_mesh_3d(rope_positions, bsz=tokens.shape[0])
|
||||
out = []
|
||||
for i, (D, x) in enumerate(zip(ch_split, torch.split(tokens, ch_split, dim=-1))):
|
||||
cos, sin = self.get_cos_sin(D, int(mesh_grid.max()) + 1, tokens.device, tokens.dtype)
|
||||
|
||||
if parallel:
|
||||
pass
|
||||
else:
|
||||
mesh = mesh_grid[:, :, i].clone()
|
||||
x = self.apply_rope1d(x, mesh.to(tokens.device), cos, sin)
|
||||
out.append(x)
|
||||
|
||||
tokens = torch.cat(out, dim=-1)
|
||||
return tokens
|
||||
|
||||
|
||||
class SelfAttention(Attention):
|
||||
def __init__(self, hidden_dim, head_dim, bias=False, with_rope=True, with_qk_norm=True, attn_type='torch'):
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.n_heads = hidden_dim // head_dim
|
||||
|
||||
self.wqkv = nn.Linear(hidden_dim, hidden_dim*3, bias=bias)
|
||||
self.wo = nn.Linear(hidden_dim, hidden_dim, bias=bias)
|
||||
|
||||
self.with_rope = with_rope
|
||||
self.with_qk_norm = with_qk_norm
|
||||
if self.with_qk_norm:
|
||||
self.q_norm = RMSNorm(head_dim, elementwise_affine=True)
|
||||
self.k_norm = RMSNorm(head_dim, elementwise_affine=True)
|
||||
|
||||
if self.with_rope:
|
||||
self.rope_3d = RoPE3D(freq=1e4, F0=1.0, scaling_factor=1.0)
|
||||
self.rope_ch_split = [64, 32, 32]
|
||||
|
||||
self.core_attention = self.attn_processor(attn_type=attn_type)
|
||||
self.parallel = attn_type=='parallel'
|
||||
|
||||
def apply_rope3d(self, x, fhw_positions, rope_ch_split, parallel=True):
|
||||
x = self.rope_3d(x, fhw_positions, rope_ch_split, parallel)
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
cu_seqlens=None,
|
||||
max_seqlen=None,
|
||||
rope_positions=None,
|
||||
attn_mask=None
|
||||
):
|
||||
xqkv = self.wqkv(x)
|
||||
xqkv = xqkv.view(*x.shape[:-1], self.n_heads, 3*self.head_dim)
|
||||
|
||||
xq, xk, xv = torch.split(xqkv, [self.head_dim]*3, dim=-1) ## seq_len, n, dim
|
||||
|
||||
if self.with_qk_norm:
|
||||
xq = self.q_norm(xq)
|
||||
xk = self.k_norm(xk)
|
||||
|
||||
if self.with_rope:
|
||||
xq = self.apply_rope3d(xq, rope_positions, self.rope_ch_split, parallel=self.parallel)
|
||||
xk = self.apply_rope3d(xk, rope_positions, self.rope_ch_split, parallel=self.parallel)
|
||||
|
||||
output = self.core_attention(
|
||||
xq,
|
||||
xk,
|
||||
xv,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
attn_mask=attn_mask
|
||||
)
|
||||
output = rearrange(output, 'b s h d -> b s (h d)')
|
||||
output = self.wo(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class CrossAttention(Attention):
|
||||
def __init__(self, hidden_dim, head_dim, bias=False, with_qk_norm=True, attn_type='torch'):
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.n_heads = hidden_dim // head_dim
|
||||
|
||||
self.wq = nn.Linear(hidden_dim, hidden_dim, bias=bias)
|
||||
self.wkv = nn.Linear(hidden_dim, hidden_dim*2, bias=bias)
|
||||
self.wo = nn.Linear(hidden_dim, hidden_dim, bias=bias)
|
||||
|
||||
self.with_qk_norm = with_qk_norm
|
||||
if self.with_qk_norm:
|
||||
self.q_norm = RMSNorm(head_dim, elementwise_affine=True)
|
||||
self.k_norm = RMSNorm(head_dim, elementwise_affine=True)
|
||||
|
||||
self.core_attention = self.attn_processor(attn_type=attn_type)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attn_mask=None
|
||||
):
|
||||
xq = self.wq(x)
|
||||
xq = xq.view(*xq.shape[:-1], self.n_heads, self.head_dim)
|
||||
|
||||
xkv = self.wkv(encoder_hidden_states)
|
||||
xkv = xkv.view(*xkv.shape[:-1], self.n_heads, 2*self.head_dim)
|
||||
|
||||
xk, xv = torch.split(xkv, [self.head_dim]*2, dim=-1) ## seq_len, n, dim
|
||||
|
||||
if self.with_qk_norm:
|
||||
xq = self.q_norm(xq)
|
||||
xk = self.k_norm(xk)
|
||||
|
||||
output = self.core_attention(
|
||||
xq,
|
||||
xk,
|
||||
xv,
|
||||
attn_mask=attn_mask
|
||||
)
|
||||
|
||||
output = rearrange(output, 'b s h d -> b s (h d)')
|
||||
output = self.wo(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class GELU(nn.Module):
|
||||
r"""
|
||||
GELU activation function with tanh approximation support with `approximate="tanh"`.
|
||||
|
||||
Parameters:
|
||||
dim_in (`int`): The number of channels in the input.
|
||||
dim_out (`int`): The number of channels in the output.
|
||||
approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
|
||||
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
|
||||
"""
|
||||
|
||||
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
|
||||
self.approximate = approximate
|
||||
|
||||
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
||||
return torch.nn.functional.gelu(gate, approximate=self.approximate)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.proj(hidden_states)
|
||||
hidden_states = self.gelu(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
inner_dim: Optional[int] = None,
|
||||
dim_out: Optional[int] = None,
|
||||
mult: int = 4,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = dim*mult if inner_dim is None else inner_dim
|
||||
dim_out = dim if dim_out is None else dim_out
|
||||
self.net = nn.ModuleList([
|
||||
GELU(dim, inner_dim, approximate="tanh", bias=bias),
|
||||
nn.Identity(),
|
||||
nn.Linear(inner_dim, dim_out, bias=bias)
|
||||
])
|
||||
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||
for module in self.net:
|
||||
hidden_states = module(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def modulate(x, scale, shift):
|
||||
x = x * (1 + scale) + shift
|
||||
return x
|
||||
|
||||
|
||||
def gate(x, gate):
|
||||
x = gate * x
|
||||
return x
|
||||
|
||||
|
||||
class StepVideoTransformerBlock(nn.Module):
|
||||
r"""
|
||||
A basic Transformer block.
|
||||
|
||||
Parameters:
|
||||
dim (`int`): The number of channels in the input and output.
|
||||
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm (:
|
||||
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
||||
attention_bias (:
|
||||
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
||||
only_cross_attention (`bool`, *optional*):
|
||||
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
||||
double_self_attention (`bool`, *optional*):
|
||||
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
||||
upcast_attention (`bool`, *optional*):
|
||||
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
||||
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use learnable elementwise affine parameters for normalization.
|
||||
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
||||
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
|
||||
final_dropout (`bool` *optional*, defaults to False):
|
||||
Whether to apply a final dropout after the last feed-forward layer.
|
||||
attention_type (`str`, *optional*, defaults to `"default"`):
|
||||
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
|
||||
positional_embeddings (`str`, *optional*, defaults to `None`):
|
||||
The type of positional embeddings to apply to.
|
||||
num_positional_embeddings (`int`, *optional*, defaults to `None`):
|
||||
The maximum number of positional embeddings to apply.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
attention_head_dim: int,
|
||||
norm_eps: float = 1e-5,
|
||||
ff_inner_dim: Optional[int] = None,
|
||||
ff_bias: bool = False,
|
||||
attention_type: str = 'parallel'
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.norm1 = nn.LayerNorm(dim, eps=norm_eps)
|
||||
self.attn1 = SelfAttention(dim, attention_head_dim, bias=False, with_rope=True, with_qk_norm=True, attn_type=attention_type)
|
||||
|
||||
self.norm2 = nn.LayerNorm(dim, eps=norm_eps)
|
||||
self.attn2 = CrossAttention(dim, attention_head_dim, bias=False, with_qk_norm=True, attn_type='torch')
|
||||
|
||||
self.ff = FeedForward(dim=dim, inner_dim=ff_inner_dim, dim_out=dim, bias=ff_bias)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) /dim**0.5)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
attn_mask = None,
|
||||
rope_positions: list = None,
|
||||
) -> torch.Tensor:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
torch.clone(chunk) for chunk in (self.scale_shift_table[None].to(dtype=q.dtype, device=q.device) + timestep.reshape(-1, 6, self.dim)).chunk(6, dim=1)
|
||||
)
|
||||
|
||||
scale_shift_q = modulate(self.norm1(q), scale_msa, shift_msa)
|
||||
|
||||
attn_q = self.attn1(
|
||||
scale_shift_q,
|
||||
rope_positions=rope_positions
|
||||
)
|
||||
|
||||
q = gate(attn_q, gate_msa) + q
|
||||
|
||||
attn_q = self.attn2(
|
||||
q,
|
||||
kv,
|
||||
attn_mask
|
||||
)
|
||||
|
||||
q = attn_q + q
|
||||
|
||||
scale_shift_q = modulate(self.norm2(q), scale_mlp, shift_mlp)
|
||||
|
||||
ff_output = self.ff(scale_shift_q)
|
||||
|
||||
q = gate(ff_output, gate_mlp) + q
|
||||
|
||||
return q
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""2D Image to Patch Embedding"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size=64,
|
||||
in_channels=3,
|
||||
embed_dim=768,
|
||||
layer_norm=False,
|
||||
flatten=True,
|
||||
bias=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.flatten = flatten
|
||||
self.layer_norm = layer_norm
|
||||
|
||||
self.proj = nn.Conv2d(
|
||||
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
|
||||
)
|
||||
|
||||
def forward(self, latent):
|
||||
latent = self.proj(latent).to(latent.dtype)
|
||||
if self.flatten:
|
||||
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
if self.layer_norm:
|
||||
latent = self.norm(latent)
|
||||
|
||||
return latent
|
||||
|
||||
|
||||
class StepVideoModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 48,
|
||||
attention_head_dim: int = 128,
|
||||
in_channels: int = 64,
|
||||
out_channels: Optional[int] = 64,
|
||||
num_layers: int = 48,
|
||||
dropout: float = 0.0,
|
||||
patch_size: int = 1,
|
||||
norm_type: str = "ada_norm_single",
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-6,
|
||||
use_additional_conditions: Optional[bool] = False,
|
||||
caption_channels: Optional[Union[int, List, Tuple]] = [6144, 1024],
|
||||
attention_type: Optional[str] = "torch",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Set some common variables used across the board.
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.out_channels = in_channels if out_channels is None else out_channels
|
||||
|
||||
self.use_additional_conditions = use_additional_conditions
|
||||
|
||||
self.pos_embed = PatchEmbed(
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=self.inner_dim,
|
||||
)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
StepVideoTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
attention_head_dim=attention_head_dim,
|
||||
attention_type=attention_type
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 3. Output blocks.
|
||||
self.norm_out = nn.LayerNorm(self.inner_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
|
||||
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels)
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim, use_additional_conditions=self.use_additional_conditions
|
||||
)
|
||||
|
||||
if isinstance(caption_channels, int):
|
||||
caption_channel = caption_channels
|
||||
else:
|
||||
caption_channel, clip_channel = caption_channels
|
||||
self.clip_projection = nn.Linear(clip_channel, self.inner_dim)
|
||||
|
||||
self.caption_norm = nn.LayerNorm(caption_channel, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
|
||||
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=caption_channel, hidden_size=self.inner_dim
|
||||
)
|
||||
|
||||
self.parallel = attention_type=='parallel'
|
||||
|
||||
def patchfy(self, hidden_states):
|
||||
hidden_states = rearrange(hidden_states, 'b f c h w -> (b f) c h w')
|
||||
hidden_states = self.pos_embed(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def prepare_attn_mask(self, encoder_attention_mask, encoder_hidden_states, q_seqlen):
|
||||
kv_seqlens = encoder_attention_mask.sum(dim=1).int()
|
||||
mask = torch.zeros([len(kv_seqlens), q_seqlen, max(kv_seqlens)], dtype=torch.bool, device=encoder_attention_mask.device)
|
||||
encoder_hidden_states = encoder_hidden_states[:,: max(kv_seqlens)]
|
||||
for i, kv_len in enumerate(kv_seqlens):
|
||||
mask[i, :, :kv_len] = 1
|
||||
return encoder_hidden_states, mask
|
||||
|
||||
|
||||
def block_forward(
|
||||
self,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
timestep=None,
|
||||
rope_positions=None,
|
||||
attn_mask=None,
|
||||
parallel=True
|
||||
):
|
||||
for block in tqdm(self.transformer_blocks, desc="Transformer blocks"):
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
timestep=timestep,
|
||||
attn_mask=attn_mask,
|
||||
rope_positions=rope_positions
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states_2: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
fps: torch.Tensor=None,
|
||||
return_dict: bool = False,
|
||||
):
|
||||
assert hidden_states.ndim==5; "hidden_states's shape should be (bsz, f, ch, h ,w)"
|
||||
|
||||
bsz, frame, _, height, width = hidden_states.shape
|
||||
height, width = height // self.patch_size, width // self.patch_size
|
||||
|
||||
hidden_states = self.patchfy(hidden_states)
|
||||
len_frame = hidden_states.shape[1]
|
||||
|
||||
if self.use_additional_conditions:
|
||||
added_cond_kwargs = {
|
||||
"resolution": torch.tensor([(height, width)]*bsz, device=hidden_states.device, dtype=hidden_states.dtype),
|
||||
"nframe": torch.tensor([frame]*bsz, device=hidden_states.device, dtype=hidden_states.dtype),
|
||||
"fps": fps
|
||||
}
|
||||
else:
|
||||
added_cond_kwargs = {}
|
||||
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep, added_cond_kwargs=added_cond_kwargs
|
||||
)
|
||||
|
||||
encoder_hidden_states = self.caption_projection(self.caption_norm(encoder_hidden_states))
|
||||
|
||||
if encoder_hidden_states_2 is not None and hasattr(self, 'clip_projection'):
|
||||
clip_embedding = self.clip_projection(encoder_hidden_states_2)
|
||||
encoder_hidden_states = torch.cat([clip_embedding, encoder_hidden_states], dim=1)
|
||||
|
||||
hidden_states = rearrange(hidden_states, '(b f) l d-> b (f l) d', b=bsz, f=frame, l=len_frame).contiguous()
|
||||
encoder_hidden_states, attn_mask = self.prepare_attn_mask(encoder_attention_mask, encoder_hidden_states, q_seqlen=frame*len_frame)
|
||||
|
||||
hidden_states = self.block_forward(
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
timestep=timestep,
|
||||
rope_positions=[frame, height, width],
|
||||
attn_mask=attn_mask,
|
||||
parallel=self.parallel
|
||||
)
|
||||
|
||||
hidden_states = rearrange(hidden_states, 'b (f l) d -> (b f) l d', b=bsz, f=frame, l=len_frame)
|
||||
|
||||
embedded_timestep = repeat(embedded_timestep, 'b d -> (b f) d', f=frame).contiguous()
|
||||
|
||||
shift, scale = (self.scale_shift_table[None].to(dtype=embedded_timestep.dtype, device=embedded_timestep.device) + embedded_timestep[:, None]).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
# Modulation
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# unpatchify
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
|
||||
)
|
||||
|
||||
hidden_states = rearrange(hidden_states, 'n h w p q c -> n c h p w q')
|
||||
output = hidden_states.reshape(
|
||||
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
|
||||
)
|
||||
|
||||
output = rearrange(output, '(b f) c h w -> b f c h w', f=frame)
|
||||
|
||||
if return_dict:
|
||||
return {'x': output}
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return StepVideoDiTStateDictConverter()
|
||||
|
||||
|
||||
class StepVideoDiTStateDictConverter:
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
|
||||
|
||||
553
diffsynth/models/stepvideo_text_encoder.py
Normal file
553
diffsynth/models/stepvideo_text_encoder.py
Normal file
@@ -0,0 +1,553 @@
|
||||
# Copyright 2025 StepFun Inc. All Rights Reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
# ==============================================================================
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .stepvideo_dit import RMSNorm
|
||||
from safetensors.torch import load_file
|
||||
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
|
||||
from einops import rearrange
|
||||
import json
|
||||
from typing import List
|
||||
from functools import wraps
|
||||
import warnings
|
||||
|
||||
|
||||
|
||||
class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
|
||||
def __init__(self, device=None):
|
||||
self.device = device
|
||||
|
||||
def __torch_function__(self, func, types, args=(), kwargs=None):
|
||||
kwargs = kwargs or {}
|
||||
if getattr(func, '__module__', None) == 'torch.nn.init':
|
||||
if 'tensor' in kwargs:
|
||||
return kwargs['tensor']
|
||||
else:
|
||||
return args[0]
|
||||
if self.device is not None and func in torch.utils._device._device_constructors() and kwargs.get('device') is None:
|
||||
kwargs['device'] = self.device
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
def with_empty_init(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
with EmptyInitOnDevice('cpu'):
|
||||
return func(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
||||
|
||||
class LLaMaEmbedding(nn.Module):
|
||||
"""Language model embeddings.
|
||||
|
||||
Arguments:
|
||||
hidden_size: hidden size
|
||||
vocab_size: vocabulary size
|
||||
max_sequence_length: maximum size of sequence. This
|
||||
is used for positional embedding
|
||||
embedding_dropout_prob: dropout probability for embeddings
|
||||
init_method: weight initialization method
|
||||
num_tokentypes: size of the token-type embeddings. 0 value
|
||||
will ignore this embedding
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
cfg,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = cfg.hidden_size
|
||||
self.params_dtype = cfg.params_dtype
|
||||
self.fp32_residual_connection = cfg.fp32_residual_connection
|
||||
self.embedding_weights_in_fp32 = cfg.embedding_weights_in_fp32
|
||||
self.word_embeddings = torch.nn.Embedding(
|
||||
cfg.padded_vocab_size, self.hidden_size,
|
||||
)
|
||||
self.embedding_dropout = torch.nn.Dropout(cfg.hidden_dropout)
|
||||
|
||||
def forward(self, input_ids):
|
||||
# Embeddings.
|
||||
if self.embedding_weights_in_fp32:
|
||||
self.word_embeddings = self.word_embeddings.to(torch.float32)
|
||||
embeddings = self.word_embeddings(input_ids)
|
||||
if self.embedding_weights_in_fp32:
|
||||
embeddings = embeddings.to(self.params_dtype)
|
||||
self.word_embeddings = self.word_embeddings.to(self.params_dtype)
|
||||
|
||||
# Data format change to avoid explicit transposes : [b s h] --> [s b h].
|
||||
embeddings = embeddings.transpose(0, 1).contiguous()
|
||||
|
||||
# If the input flag for fp32 residual connection is set, convert for float.
|
||||
if self.fp32_residual_connection:
|
||||
embeddings = embeddings.float()
|
||||
|
||||
# Dropout.
|
||||
embeddings = self.embedding_dropout(embeddings)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
|
||||
class StepChatTokenizer:
|
||||
"""Step Chat Tokenizer"""
|
||||
|
||||
def __init__(
|
||||
self, model_file, name="StepChatTokenizer",
|
||||
bot_token="<|BOT|>", # Begin of Turn
|
||||
eot_token="<|EOT|>", # End of Turn
|
||||
call_start_token="<|CALL_START|>", # Call Start
|
||||
call_end_token="<|CALL_END|>", # Call End
|
||||
think_start_token="<|THINK_START|>", # Think Start
|
||||
think_end_token="<|THINK_END|>", # Think End
|
||||
mask_start_token="<|MASK_1e69f|>", # Mask start
|
||||
mask_end_token="<|UNMASK_1e69f|>", # Mask end
|
||||
):
|
||||
import sentencepiece
|
||||
|
||||
self._tokenizer = sentencepiece.SentencePieceProcessor(model_file=model_file)
|
||||
|
||||
self._vocab = {}
|
||||
self._inv_vocab = {}
|
||||
|
||||
self._special_tokens = {}
|
||||
self._inv_special_tokens = {}
|
||||
|
||||
self._t5_tokens = []
|
||||
|
||||
for idx in range(self._tokenizer.get_piece_size()):
|
||||
text = self._tokenizer.id_to_piece(idx)
|
||||
self._inv_vocab[idx] = text
|
||||
self._vocab[text] = idx
|
||||
|
||||
if self._tokenizer.is_control(idx) or self._tokenizer.is_unknown(idx):
|
||||
self._special_tokens[text] = idx
|
||||
self._inv_special_tokens[idx] = text
|
||||
|
||||
self._unk_id = self._tokenizer.unk_id()
|
||||
self._bos_id = self._tokenizer.bos_id()
|
||||
self._eos_id = self._tokenizer.eos_id()
|
||||
|
||||
for token in [
|
||||
bot_token, eot_token, call_start_token, call_end_token,
|
||||
think_start_token, think_end_token
|
||||
]:
|
||||
assert token in self._vocab, f"Token '{token}' not found in tokenizer"
|
||||
assert token in self._special_tokens, f"Token '{token}' is not a special token"
|
||||
|
||||
for token in [mask_start_token, mask_end_token]:
|
||||
assert token in self._vocab, f"Token '{token}' not found in tokenizer"
|
||||
|
||||
self._bot_id = self._tokenizer.piece_to_id(bot_token)
|
||||
self._eot_id = self._tokenizer.piece_to_id(eot_token)
|
||||
self._call_start_id = self._tokenizer.piece_to_id(call_start_token)
|
||||
self._call_end_id = self._tokenizer.piece_to_id(call_end_token)
|
||||
self._think_start_id = self._tokenizer.piece_to_id(think_start_token)
|
||||
self._think_end_id = self._tokenizer.piece_to_id(think_end_token)
|
||||
self._mask_start_id = self._tokenizer.piece_to_id(mask_start_token)
|
||||
self._mask_end_id = self._tokenizer.piece_to_id(mask_end_token)
|
||||
|
||||
self._underline_id = self._tokenizer.piece_to_id("\u2581")
|
||||
|
||||
@property
|
||||
def vocab(self):
|
||||
return self._vocab
|
||||
|
||||
@property
|
||||
def inv_vocab(self):
|
||||
return self._inv_vocab
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return self._tokenizer.vocab_size()
|
||||
|
||||
def tokenize(self, text: str) -> List[int]:
|
||||
return self._tokenizer.encode_as_ids(text)
|
||||
|
||||
def detokenize(self, token_ids: List[int]) -> str:
|
||||
return self._tokenizer.decode_ids(token_ids)
|
||||
|
||||
|
||||
class Tokens:
|
||||
def __init__(self, input_ids, cu_input_ids, attention_mask, cu_seqlens, max_seq_len) -> None:
|
||||
self.input_ids = input_ids
|
||||
self.attention_mask = attention_mask
|
||||
self.cu_input_ids = cu_input_ids
|
||||
self.cu_seqlens = cu_seqlens
|
||||
self.max_seq_len = max_seq_len
|
||||
def to(self, device):
|
||||
self.input_ids = self.input_ids.to(device)
|
||||
self.attention_mask = self.attention_mask.to(device)
|
||||
self.cu_input_ids = self.cu_input_ids.to(device)
|
||||
self.cu_seqlens = self.cu_seqlens.to(device)
|
||||
return self
|
||||
|
||||
class Wrapped_StepChatTokenizer(StepChatTokenizer):
|
||||
def __call__(self, text, max_length=320, padding="max_length", truncation=True, return_tensors="pt"):
|
||||
# [bos, ..., eos, pad, pad, ..., pad]
|
||||
self.BOS = 1
|
||||
self.EOS = 2
|
||||
self.PAD = 2
|
||||
out_tokens = []
|
||||
attn_mask = []
|
||||
if len(text) == 0:
|
||||
part_tokens = [self.BOS] + [self.EOS]
|
||||
valid_size = len(part_tokens)
|
||||
if len(part_tokens) < max_length:
|
||||
part_tokens += [self.PAD] * (max_length - valid_size)
|
||||
out_tokens.append(part_tokens)
|
||||
attn_mask.append([1]*valid_size+[0]*(max_length-valid_size))
|
||||
else:
|
||||
for part in text:
|
||||
part_tokens = self.tokenize(part)
|
||||
part_tokens = part_tokens[:(max_length - 2)] # leave 2 space for bos and eos
|
||||
part_tokens = [self.BOS] + part_tokens + [self.EOS]
|
||||
valid_size = len(part_tokens)
|
||||
if len(part_tokens) < max_length:
|
||||
part_tokens += [self.PAD] * (max_length - valid_size)
|
||||
out_tokens.append(part_tokens)
|
||||
attn_mask.append([1]*valid_size+[0]*(max_length-valid_size))
|
||||
|
||||
out_tokens = torch.tensor(out_tokens, dtype=torch.long)
|
||||
attn_mask = torch.tensor(attn_mask, dtype=torch.long)
|
||||
|
||||
# padding y based on tp size
|
||||
padded_len = 0
|
||||
padded_flag = True if padded_len > 0 else False
|
||||
if padded_flag:
|
||||
pad_tokens = torch.tensor([[self.PAD] * max_length], device=out_tokens.device)
|
||||
pad_attn_mask = torch.tensor([[1]*padded_len+[0]*(max_length-padded_len)], device=attn_mask.device)
|
||||
out_tokens = torch.cat([out_tokens, pad_tokens], dim=0)
|
||||
attn_mask = torch.cat([attn_mask, pad_attn_mask], dim=0)
|
||||
|
||||
# cu_seqlens
|
||||
cu_out_tokens = out_tokens.masked_select(attn_mask != 0).unsqueeze(0)
|
||||
seqlen = attn_mask.sum(dim=1).tolist()
|
||||
cu_seqlens = torch.cumsum(torch.tensor([0]+seqlen), 0).to(device=out_tokens.device,dtype=torch.int32)
|
||||
max_seq_len = max(seqlen)
|
||||
return Tokens(out_tokens, cu_out_tokens, attn_mask, cu_seqlens, max_seq_len)
|
||||
|
||||
|
||||
|
||||
def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=True,
|
||||
return_attn_probs=False, tp_group_rank=0, tp_group_size=1):
|
||||
softmax_scale = q.size(-1) ** (-0.5) if softmax_scale is None else softmax_scale
|
||||
if hasattr(torch.ops.Optimus, "fwd"):
|
||||
results = torch.ops.Optimus.fwd(q, k, v, None, dropout_p, softmax_scale, causal, return_attn_probs, None, tp_group_rank, tp_group_size)[0]
|
||||
else:
|
||||
warnings.warn("Cannot load `torch.ops.Optimus.fwd`. Using `torch.nn.functional.scaled_dot_product_attention` instead.")
|
||||
results = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True, scale=softmax_scale).transpose(1, 2)
|
||||
return results
|
||||
|
||||
|
||||
class FlashSelfAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
attention_dropout=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.dropout_p = attention_dropout
|
||||
|
||||
|
||||
def forward(self, q, k, v, cu_seqlens=None, max_seq_len=None):
|
||||
if cu_seqlens is None:
|
||||
output = flash_attn_func(q, k, v, dropout_p=self.dropout_p)
|
||||
else:
|
||||
raise ValueError('cu_seqlens is not supported!')
|
||||
|
||||
return output
|
||||
|
||||
|
||||
|
||||
def safediv(n, d):
|
||||
q, r = divmod(n, d)
|
||||
assert r == 0
|
||||
return q
|
||||
|
||||
|
||||
class MultiQueryAttention(nn.Module):
|
||||
def __init__(self, cfg, layer_id=None):
|
||||
super().__init__()
|
||||
|
||||
self.head_dim = cfg.hidden_size // cfg.num_attention_heads
|
||||
self.max_seq_len = cfg.seq_length
|
||||
self.use_flash_attention = cfg.use_flash_attn
|
||||
assert self.use_flash_attention, 'FlashAttention is required!'
|
||||
|
||||
self.n_groups = cfg.num_attention_groups
|
||||
self.tp_size = 1
|
||||
self.n_local_heads = cfg.num_attention_heads
|
||||
self.n_local_groups = self.n_groups
|
||||
|
||||
self.wqkv = nn.Linear(
|
||||
cfg.hidden_size,
|
||||
cfg.hidden_size + self.head_dim * 2 * self.n_groups,
|
||||
bias=False,
|
||||
)
|
||||
self.wo = nn.Linear(
|
||||
cfg.hidden_size,
|
||||
cfg.hidden_size,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
assert self.use_flash_attention, 'non-Flash attention not supported yet.'
|
||||
self.core_attention = FlashSelfAttention(attention_dropout=cfg.attention_dropout)
|
||||
|
||||
self.layer_id = layer_id
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: Optional[torch.Tensor],
|
||||
cu_seqlens: Optional[torch.Tensor],
|
||||
max_seq_len: Optional[torch.Tensor],
|
||||
):
|
||||
seqlen, bsz, dim = x.shape
|
||||
xqkv = self.wqkv(x)
|
||||
|
||||
xq, xkv = torch.split(
|
||||
xqkv,
|
||||
(dim // self.tp_size,
|
||||
self.head_dim*2*self.n_groups // self.tp_size
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
# gather on 1st dimension
|
||||
xq = xq.view(seqlen, bsz, self.n_local_heads, self.head_dim)
|
||||
xkv = xkv.view(seqlen, bsz, self.n_local_groups, 2 * self.head_dim)
|
||||
xk, xv = xkv.chunk(2, -1)
|
||||
|
||||
# rotary embedding + flash attn
|
||||
xq = rearrange(xq, "s b h d -> b s h d")
|
||||
xk = rearrange(xk, "s b h d -> b s h d")
|
||||
xv = rearrange(xv, "s b h d -> b s h d")
|
||||
|
||||
q_per_kv = self.n_local_heads // self.n_local_groups
|
||||
if q_per_kv > 1:
|
||||
b, s, h, d = xk.size()
|
||||
if h == 1:
|
||||
xk = xk.expand(b, s, q_per_kv, d)
|
||||
xv = xv.expand(b, s, q_per_kv, d)
|
||||
else:
|
||||
''' To cover the cases where h > 1, we have
|
||||
the following implementation, which is equivalent to:
|
||||
xk = xk.repeat_interleave(q_per_kv, dim=-2)
|
||||
xv = xv.repeat_interleave(q_per_kv, dim=-2)
|
||||
but can avoid calling aten::item() that involves cpu.
|
||||
'''
|
||||
idx = torch.arange(q_per_kv * h, device=xk.device).reshape(q_per_kv, -1).permute(1, 0).flatten()
|
||||
xk = torch.index_select(xk.repeat(1, 1, q_per_kv, 1), 2, idx).contiguous()
|
||||
xv = torch.index_select(xv.repeat(1, 1, q_per_kv, 1), 2, idx).contiguous()
|
||||
|
||||
if self.use_flash_attention:
|
||||
output = self.core_attention(xq, xk, xv,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seq_len=max_seq_len)
|
||||
# reduce-scatter only support first dimension now
|
||||
output = rearrange(output, "b s h d -> s b (h d)").contiguous()
|
||||
else:
|
||||
xq, xk, xv = [
|
||||
rearrange(x, "b s ... -> s b ...").contiguous()
|
||||
for x in (xq, xk, xv)
|
||||
]
|
||||
output = self.core_attention(xq, xk, xv, mask)
|
||||
output = self.wo(output)
|
||||
return output
|
||||
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cfg,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
layer_id: int,
|
||||
multiple_of: int=256,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
def swiglu(x):
|
||||
x = torch.chunk(x, 2, dim=-1)
|
||||
return F.silu(x[0]) * x[1]
|
||||
self.swiglu = swiglu
|
||||
|
||||
self.w1 = nn.Linear(
|
||||
dim,
|
||||
2 * hidden_dim,
|
||||
bias=False,
|
||||
)
|
||||
self.w2 = nn.Linear(
|
||||
hidden_dim,
|
||||
dim,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.swiglu(self.w1(x))
|
||||
output = self.w2(x)
|
||||
return output
|
||||
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self, cfg, layer_id: int
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.n_heads = cfg.num_attention_heads
|
||||
self.dim = cfg.hidden_size
|
||||
self.head_dim = cfg.hidden_size // cfg.num_attention_heads
|
||||
self.attention = MultiQueryAttention(
|
||||
cfg,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
self.feed_forward = FeedForward(
|
||||
cfg,
|
||||
dim=cfg.hidden_size,
|
||||
hidden_dim=cfg.ffn_hidden_size,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
self.layer_id = layer_id
|
||||
self.attention_norm = RMSNorm(
|
||||
cfg.hidden_size,
|
||||
eps=cfg.layernorm_epsilon,
|
||||
)
|
||||
self.ffn_norm = RMSNorm(
|
||||
cfg.hidden_size,
|
||||
eps=cfg.layernorm_epsilon,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: Optional[torch.Tensor],
|
||||
cu_seqlens: Optional[torch.Tensor],
|
||||
max_seq_len: Optional[torch.Tensor],
|
||||
):
|
||||
residual = self.attention.forward(
|
||||
self.attention_norm(x), mask,
|
||||
cu_seqlens, max_seq_len
|
||||
)
|
||||
h = x + residual
|
||||
ffn_res = self.feed_forward.forward(self.ffn_norm(h))
|
||||
out = h + ffn_res
|
||||
return out
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
max_seq_size=8192,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_layers = config.num_layers
|
||||
self.layers = self._build_layers(config)
|
||||
|
||||
def _build_layers(self, config):
|
||||
layers = torch.nn.ModuleList()
|
||||
for layer_id in range(self.num_layers):
|
||||
layers.append(
|
||||
TransformerBlock(
|
||||
config,
|
||||
layer_id=layer_id + 1 ,
|
||||
)
|
||||
)
|
||||
return layers
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
cu_seqlens=None,
|
||||
max_seq_len=None,
|
||||
):
|
||||
|
||||
if max_seq_len is not None and not isinstance(max_seq_len, torch.Tensor):
|
||||
max_seq_len = torch.tensor(max_seq_len, dtype=torch.int32, device="cpu")
|
||||
|
||||
for lid, layer in enumerate(self.layers):
|
||||
hidden_states = layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
cu_seqlens,
|
||||
max_seq_len,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Step1Model(PreTrainedModel):
|
||||
config_class=PretrainedConfig
|
||||
@with_empty_init
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
):
|
||||
super().__init__(config)
|
||||
self.tok_embeddings = LLaMaEmbedding(config)
|
||||
self.transformer = Transformer(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
):
|
||||
|
||||
hidden_states = self.tok_embeddings(input_ids)
|
||||
|
||||
hidden_states = self.transformer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
class STEP1TextEncoder(torch.nn.Module):
|
||||
def __init__(self, model_dir, max_length=320):
|
||||
super(STEP1TextEncoder, self).__init__()
|
||||
self.max_length = max_length
|
||||
self.text_tokenizer = Wrapped_StepChatTokenizer(os.path.join(model_dir, 'step1_chat_tokenizer.model'))
|
||||
text_encoder = Step1Model.from_pretrained(model_dir)
|
||||
self.text_encoder = text_encoder.eval().to(torch.bfloat16)
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(path, torch_dtype=torch.bfloat16):
|
||||
model = STEP1TextEncoder(path).to(torch_dtype)
|
||||
return model
|
||||
|
||||
@torch.no_grad
|
||||
def forward(self, prompts, with_mask=True, max_length=None, device="cuda"):
|
||||
self.device = device
|
||||
with torch.no_grad(), torch.amp.autocast(dtype=torch.bfloat16, device_type=device):
|
||||
if type(prompts) is str:
|
||||
prompts = [prompts]
|
||||
|
||||
txt_tokens = self.text_tokenizer(
|
||||
prompts, max_length=max_length or self.max_length, padding="max_length", truncation=True, return_tensors="pt"
|
||||
)
|
||||
y = self.text_encoder(
|
||||
txt_tokens.input_ids.to(self.device),
|
||||
attention_mask=txt_tokens.attention_mask.to(self.device) if with_mask else None
|
||||
)
|
||||
y_mask = txt_tokens.attention_mask
|
||||
return y.transpose(0,1), y_mask
|
||||
|
||||
1132
diffsynth/models/stepvideo_vae.py
Normal file
1132
diffsynth/models/stepvideo_vae.py
Normal file
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user