mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 23:08:13 +00:00
Compare commits
237 Commits
wanx_dev1
...
value-cont
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ba421a9ab9 | ||
|
|
6c30a7f080 | ||
|
|
1363a0559f | ||
|
|
9bb51fe879 | ||
|
|
d9c812818d | ||
|
|
c8e9a96196 | ||
|
|
6143af4654 | ||
|
|
9458e382b0 | ||
|
|
4f2d9226cf | ||
|
|
f688a469b1 | ||
|
|
c8ea3b3356 | ||
|
|
6e9472b470 | ||
|
|
a5c03c5272 | ||
|
|
8068ac2592 | ||
|
|
5f80e7ac5e | ||
|
|
157e0be49d | ||
|
|
3dbe271aab | ||
|
|
44e2eecdf1 | ||
|
|
8c226e83a6 | ||
|
|
009f26bb40 | ||
|
|
fcf2fbc07f | ||
|
|
b603acd36a | ||
|
|
6c8bb6438b | ||
|
|
8072d3839d | ||
|
|
c8ad643374 | ||
|
|
31f9df5e62 | ||
|
|
e2f415524a | ||
|
|
3eb7e7530e | ||
|
|
916aa54595 | ||
|
|
6ddbd43f7b | ||
|
|
a37a83ecc3 | ||
|
|
f2a0d0c85f | ||
|
|
93194f44e8 | ||
|
|
c4e5033532 | ||
|
|
cc6cd26733 | ||
|
|
1113d305d1 | ||
|
|
6d5f8b7423 | ||
|
|
1b3c204d20 | ||
|
|
1788d50f0a | ||
|
|
e7a21dbf0b | ||
|
|
3b3e1e4d44 | ||
|
|
24426e3a32 | ||
|
|
31369bab15 | ||
|
|
551721658b | ||
|
|
46f052375f | ||
|
|
c2d35a2157 | ||
|
|
4c052e42bc | ||
|
|
a88613555d | ||
|
|
c164519ef1 | ||
|
|
afff5ffb21 | ||
|
|
a8481fd5e1 | ||
|
|
8584e50309 | ||
|
|
9f3e02f167 | ||
|
|
7ad9b9aecc | ||
|
|
b6a111d3a2 | ||
|
|
bd6f2695a9 | ||
|
|
6eecc9d442 | ||
|
|
35269783d7 | ||
|
|
9534a78167 | ||
|
|
830b1b7202 | ||
|
|
436a91e0c9 | ||
|
|
40760ab88b | ||
|
|
8badd63a2d | ||
|
|
b1afff1728 | ||
|
|
6e977e1181 | ||
|
|
62f6ca2b8a | ||
|
|
4e00c109e3 | ||
|
|
8f10a9c353 | ||
|
|
a3a35acc7e | ||
|
|
675eefa07e | ||
|
|
dbef6122e9 | ||
|
|
d150bcf622 | ||
|
|
451aab0116 | ||
|
|
3edf3583b1 | ||
|
|
ef2a7abad4 | ||
|
|
32f630ff5f | ||
|
|
109a0a0d49 | ||
|
|
4f01b37a2a | ||
|
|
cc6306136c | ||
|
|
419ace37f3 | ||
|
|
ccf24c363f | ||
|
|
b7a1ac6671 | ||
|
|
e54c0a8468 | ||
|
|
5f4cb32255 | ||
|
|
7b6cf39618 | ||
|
|
bf81de0c88 | ||
|
|
b36cad6929 | ||
|
|
b161bd6dfd | ||
|
|
538cfcbb77 | ||
|
|
a4105d2c0e | ||
|
|
553b341f5f | ||
|
|
e9e24b8cf1 | ||
|
|
1b693d0028 | ||
|
|
a4c3c07229 | ||
|
|
6b24748c80 | ||
|
|
8f2f8646eb | ||
|
|
e3ac438f5a | ||
|
|
b731628112 | ||
|
|
0dc56d9dcc | ||
|
|
b925b402e2 | ||
|
|
61d9653536 | ||
|
|
53f01e72e6 | ||
|
|
55e5e373dd | ||
|
|
4a0921ada1 | ||
|
|
5129d3dc52 | ||
|
|
ee9bab80f2 | ||
|
|
cd8884c9ef | ||
|
|
46744362de | ||
|
|
0f0cdc3afc | ||
|
|
a33c63af87 | ||
|
|
3cc9764bc9 | ||
|
|
f6c6e3c640 | ||
|
|
60a9db706e | ||
|
|
a98700feb2 | ||
|
|
5418ca781e | ||
|
|
71eee780fb | ||
|
|
4864453e0a | ||
|
|
c5a32f76c2 | ||
|
|
c4ed3d3e4b | ||
|
|
803ddcccc7 | ||
|
|
4cd51fecf2 | ||
|
|
3b0211a547 | ||
|
|
e88328d152 | ||
|
|
52896fa8dd | ||
|
|
c7035ad911 | ||
|
|
070811e517 | ||
|
|
7e010d88a5 | ||
|
|
4e43d4d461 | ||
|
|
d7efe7e539 | ||
|
|
633f789c47 | ||
|
|
88607f404e | ||
|
|
6d405b669c | ||
|
|
d0fed6ba72 | ||
|
|
64eaa0d76a | ||
|
|
3dc28f428f | ||
|
|
3c8a3fe2e1 | ||
|
|
e28c246bcc | ||
|
|
04d03500ff | ||
|
|
54081bdcbb | ||
|
|
d8b250607a | ||
|
|
1e58e6ef82 | ||
|
|
42cb7d96bb | ||
|
|
39890f023f | ||
|
|
e425753f79 | ||
|
|
ca40074d72 | ||
|
|
1fd3d67379 | ||
|
|
3acd9c73be | ||
|
|
32422b49ee | ||
|
|
5c4d3185fb | ||
|
|
762bcbee58 | ||
|
|
6b411ada16 | ||
|
|
a25bd74d8b | ||
|
|
fb5fc09bad | ||
|
|
3fdba19e02 | ||
|
|
4bec2983a9 | ||
|
|
03ea27893f | ||
|
|
718b45f2af | ||
|
|
63a79eeb2a | ||
|
|
e757013a14 | ||
|
|
a05f647633 | ||
|
|
7604be0301 | ||
|
|
945b43492e | ||
|
|
b548d7caf2 | ||
|
|
6e316fd825 | ||
|
|
84fb61aaaf | ||
|
|
50a9946b57 | ||
|
|
384d1a8198 | ||
|
|
a58c193d0c | ||
|
|
34a5ef8c15 | ||
|
|
41e3e4e157 | ||
|
|
e576d71908 | ||
|
|
906aadbf1b | ||
|
|
bf0bf2d5ba | ||
|
|
fe0fff1399 | ||
|
|
50fceb84d2 | ||
|
|
100da41034 | ||
|
|
c382237833 | ||
|
|
98ac191750 | ||
|
|
2f73dbe7a3 | ||
|
|
a66203a391 | ||
|
|
fab61f614b | ||
|
|
6b67a11ad6 | ||
|
|
91f77d268c | ||
|
|
eb4d5187d8 | ||
|
|
ee4b02247c | ||
|
|
da8e1fe7e4 | ||
|
|
3db824c281 | ||
|
|
df2ecafd3f | ||
|
|
217652d28e | ||
|
|
f64c766dcd | ||
|
|
076fd85556 | ||
|
|
c7912ed827 | ||
|
|
e63f9d6993 | ||
|
|
d80ef3a677 | ||
|
|
852c3d831f | ||
|
|
ceb92ee7aa | ||
|
|
3a75026176 | ||
|
|
6a92b08244 | ||
|
|
38bc785ea9 | ||
|
|
a466fdca8f | ||
|
|
f9f49e3c78 | ||
|
|
61a30673c2 | ||
|
|
a48822ec00 | ||
|
|
b6c3d2b74a | ||
|
|
5006c2176c | ||
|
|
d3d3556ff6 | ||
|
|
6fa8dbe077 | ||
|
|
a57749ef60 | ||
|
|
b5c1d33e58 | ||
|
|
34a9f82865 | ||
|
|
18dc6cb962 | ||
|
|
490d420d82 | ||
|
|
0aca943a39 | ||
|
|
c760208614 | ||
|
|
fad7aea58a | ||
|
|
b42eb1444c | ||
|
|
25a247dd3f | ||
|
|
7792017a02 | ||
|
|
0219e8d2f3 | ||
|
|
1d309a14a3 | ||
|
|
7df73ceaaf | ||
|
|
0dbb3d333f | ||
|
|
1419bec53d | ||
|
|
cf12723c89 | ||
|
|
4268f5466b | ||
|
|
b9f5a00d98 | ||
|
|
7d44dc99fb | ||
|
|
b20de1b44d | ||
|
|
366ee0f542 | ||
|
|
bed770248b | ||
|
|
020560d2b5 | ||
|
|
af7d305f00 | ||
|
|
4449faaa01 | ||
|
|
991ba162bd | ||
|
|
77d0f4d297 | ||
|
|
a834371d50 | ||
|
|
acda7d891a |
2
.github/workflows/publish.yaml
vendored
2
.github/workflows/publish.yaml
vendored
@@ -20,7 +20,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
python-version: '3.10'
|
python-version: '3.10'
|
||||||
- name: Install wheel
|
- name: Install wheel
|
||||||
run: pip install wheel && pip install -r requirements.txt
|
run: pip install wheel==0.44.0 && pip install -r requirements.txt
|
||||||
- name: Build DiffSynth
|
- name: Build DiffSynth
|
||||||
run: python setup.py sdist bdist_wheel
|
run: python setup.py sdist bdist_wheel
|
||||||
- name: Publish package to PyPI
|
- name: Publish package to PyPI
|
||||||
|
|||||||
37
README.md
37
README.md
@@ -13,12 +13,19 @@ Document: https://diffsynth-studio.readthedocs.io/zh-cn/latest/index.html
|
|||||||
|
|
||||||
## Introduction
|
## Introduction
|
||||||
|
|
||||||
DiffSynth Studio is a Diffusion engine. We have restructured architectures including Text Encoder, UNet, VAE, among others, maintaining compatibility with models from the open-source community while enhancing computational performance. We provide many interesting features. Enjoy the magic of Diffusion models!
|
Welcome to the magic world of Diffusion models!
|
||||||
|
|
||||||
Until now, DiffSynth Studio has supported the following models:
|
DiffSynth consists of two open-source projects:
|
||||||
|
* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): Focused on aggressive technological exploration. Targeted at academia. Provides more cutting-edge technical support and novel inference capabilities.
|
||||||
|
* [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.
|
||||||
|
|
||||||
|
DiffSynth-Studio is an open-source project aimed at exploring innovations in AIGC technology. We have integrated numerous open-source Diffusion models, including FLUX and Wan, among others. Through this open-source project, we hope to connect models within the open-source community and explore new technologies based on diffusion models.
|
||||||
|
|
||||||
|
Until now, DiffSynth-Studio has supported the following models:
|
||||||
|
|
||||||
|
* [Wan-Video](https://github.com/Wan-Video/Wan2.1)
|
||||||
* [StepVideo](https://github.com/stepfun-ai/Step-Video-T2V)
|
* [StepVideo](https://github.com/stepfun-ai/Step-Video-T2V)
|
||||||
* [HunyuanVideo](https://github.com/Tencent/HunyuanVideo)
|
* [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [HunyuanVideo-I2V]()
|
||||||
* [CogVideoX](https://huggingface.co/THUDM/CogVideoX-5b)
|
* [CogVideoX](https://huggingface.co/THUDM/CogVideoX-5b)
|
||||||
* [FLUX](https://huggingface.co/black-forest-labs/FLUX.1-dev)
|
* [FLUX](https://huggingface.co/black-forest-labs/FLUX.1-dev)
|
||||||
* [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
|
* [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
|
||||||
@@ -35,12 +42,21 @@ Until now, DiffSynth Studio has supported the following models:
|
|||||||
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
|
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
|
||||||
|
|
||||||
## News
|
## News
|
||||||
|
- **June 15, 2025** ModelScope's official evaluation framework, [EvalScope](https://github.com/modelscope/evalscope), now supports text-to-image generation evaluation. Try it with the [Best Practices](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html) guide.
|
||||||
|
|
||||||
- **February 17, 2024** We support [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary)! State-of-the-art video synthesis model! See [./examples/stepvideo](./examples/stepvideo/).
|
- **March 31, 2025** We support InfiniteYou, an identity preserving method for FLUX. Please refer to [./examples/InfiniteYou/](./examples/InfiniteYou/) for more details.
|
||||||
|
|
||||||
|
- **March 25, 2025** 🔥🔥🔥 Our new open-source project, [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine), is now open-sourced! Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.
|
||||||
|
|
||||||
|
- **March 13, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of HunyuanVideo open-sourced by Tencent. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
|
||||||
|
|
||||||
|
- **February 25, 2025** We support Wan-Video, a collection of SOTA video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).
|
||||||
|
|
||||||
|
- **February 17, 2025** We support [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary)! State-of-the-art video synthesis model! See [./examples/stepvideo](./examples/stepvideo/).
|
||||||
|
|
||||||
- **December 31, 2024** We propose EliGen, a novel framework for precise entity-level controlled text-to-image generation, complemented by an inpainting fusion pipeline to extend its capabilities to image inpainting tasks. EliGen seamlessly integrates with existing community models, such as IP-Adapter and In-Context LoRA, enhancing its versatility. For more details, see [./examples/EntityControl](./examples/EntityControl/).
|
- **December 31, 2024** We propose EliGen, a novel framework for precise entity-level controlled text-to-image generation, complemented by an inpainting fusion pipeline to extend its capabilities to image inpainting tasks. EliGen seamlessly integrates with existing community models, such as IP-Adapter and In-Context LoRA, enhancing its versatility. For more details, see [./examples/EntityControl](./examples/EntityControl/).
|
||||||
- Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
|
- Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
|
||||||
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)
|
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
|
||||||
- Online Demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
|
- Online Demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
|
||||||
- Training Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
|
- Training Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
|
||||||
|
|
||||||
@@ -69,7 +85,7 @@ Until now, DiffSynth Studio has supported the following models:
|
|||||||
- Enable CFG and highres-fix to improve visual quality. See [here](/examples/image_synthesis/README.md)
|
- Enable CFG and highres-fix to improve visual quality. See [here](/examples/image_synthesis/README.md)
|
||||||
- LoRA, ControlNet, and additional models will be available soon.
|
- LoRA, ControlNet, and additional models will be available soon.
|
||||||
|
|
||||||
- **June 21, 2024.** 🔥🔥🔥 We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
|
- **June 21, 2024.** We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
|
||||||
- [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
|
- [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
|
||||||
- Source code is released in this repo. See [`examples/ExVideo`](./examples/ExVideo/).
|
- Source code is released in this repo. See [`examples/ExVideo`](./examples/ExVideo/).
|
||||||
- Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
|
- Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
|
||||||
@@ -118,12 +134,19 @@ cd DiffSynth-Studio
|
|||||||
pip install -e .
|
pip install -e .
|
||||||
```
|
```
|
||||||
|
|
||||||
Or install from pypi:
|
Or install from pypi (There is a delay in the update. If you want to experience the latest features, please do not use this installation method.):
|
||||||
|
|
||||||
```
|
```
|
||||||
pip install diffsynth
|
pip install diffsynth
|
||||||
```
|
```
|
||||||
|
|
||||||
|
If you encounter issues during installation, it may be caused by the packages we depend on. Please refer to the documentation of the package that caused the problem.
|
||||||
|
|
||||||
|
* [torch](https://pytorch.org/get-started/locally/)
|
||||||
|
* [sentencepiece](https://github.com/google/sentencepiece)
|
||||||
|
* [cmake](https://cmake.org)
|
||||||
|
* [cupy](https://docs.cupy.dev/en/stable/install.html)
|
||||||
|
|
||||||
## Usage (in Python code)
|
## Usage (in Python code)
|
||||||
|
|
||||||
The Python examples are in [`examples`](./examples/). We provide an overview here.
|
The Python examples are in [`examples`](./examples/). We provide an overview here.
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# Set web page format
|
# Set web page format
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
st.set_page_config(layout="wide")
|
st.set_page_config(layout="wide")
|
||||||
# Diasble virtual VRAM on windows system
|
# Disable virtual VRAM on windows system
|
||||||
import torch
|
import torch
|
||||||
torch.cuda.set_per_process_memory_fraction(0.999, 0)
|
torch.cuda.set_per_process_memory_fraction(0.999, 0)
|
||||||
|
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ from ..models.flux_text_encoder import FluxTextEncoder2
|
|||||||
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
||||||
from ..models.flux_controlnet import FluxControlNet
|
from ..models.flux_controlnet import FluxControlNet
|
||||||
from ..models.flux_ipadapter import FluxIpAdapter
|
from ..models.flux_ipadapter import FluxIpAdapter
|
||||||
|
from ..models.flux_infiniteyou import InfiniteYouImageProjector
|
||||||
|
|
||||||
from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
|
from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
|
||||||
from ..models.cog_dit import CogDiT
|
from ..models.cog_dit import CogDiT
|
||||||
@@ -54,7 +55,17 @@ from ..models.hunyuan_video_dit import HunyuanVideoDiT
|
|||||||
from ..models.stepvideo_vae import StepVideoVAE
|
from ..models.stepvideo_vae import StepVideoVAE
|
||||||
from ..models.stepvideo_dit import StepVideoModel
|
from ..models.stepvideo_dit import StepVideoModel
|
||||||
|
|
||||||
from ..models.wanx_vae import WanXVideoVAE
|
from ..models.wan_video_dit import WanModel
|
||||||
|
from ..models.wan_video_text_encoder import WanTextEncoder
|
||||||
|
from ..models.wan_video_image_encoder import WanImageEncoder
|
||||||
|
from ..models.wan_video_vae import WanVideoVAE
|
||||||
|
from ..models.wan_video_motion_controller import WanMotionControllerModel
|
||||||
|
from ..models.wan_video_vace import VaceWanModel
|
||||||
|
|
||||||
|
from ..models.step1x_connector import Qwen2Connector
|
||||||
|
|
||||||
|
from ..models.flux_value_control import SingleValueEncoder
|
||||||
|
|
||||||
|
|
||||||
model_loader_configs = [
|
model_loader_configs = [
|
||||||
# These configs are provided for detecting model type automatically.
|
# These configs are provided for detecting model type automatically.
|
||||||
@@ -91,6 +102,9 @@ model_loader_configs = [
|
|||||||
(None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
|
(None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
|
||||||
(None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"),
|
(None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"),
|
||||||
(None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
|
(None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
|
||||||
|
(None, "d02f41c13549fa5093d3521f62a5570a", ["flux_dit"], [FluxDiT], "civitai"),
|
||||||
|
(None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
|
||||||
|
(None, "3ede90c44b2c161240b659f3b8393c9d", ["flux_value_controller"], [SingleValueEncoder], "civitai"),
|
||||||
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
|
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
|
||||||
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
|
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
|
||||||
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
|
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
|
||||||
@@ -99,6 +113,9 @@ model_loader_configs = [
|
|||||||
(None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
(None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||||
(None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
(None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||||
(None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
(None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||||
|
(None, "7f9583eb8ba86642abb9a21a4b2c9e16", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||||
|
(None, "43ad5aaa27dd4ee01b832ed16773fa52", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||||
|
(None, "c07c0f04f5ff55e86b4e937c7a40d481", ["infiniteyou_image_projector"], [InfiniteYouImageProjector], "diffusers"),
|
||||||
(None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"),
|
(None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"),
|
||||||
(None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
(None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
||||||
(None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
|
(None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
|
||||||
@@ -109,7 +126,27 @@ model_loader_configs = [
|
|||||||
(None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
|
(None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
|
||||||
(None, "68beaf8429b7c11aa8ca05b1bd0058bd", ["stepvideo_vae"], [StepVideoVAE], "civitai"),
|
(None, "68beaf8429b7c11aa8ca05b1bd0058bd", ["stepvideo_vae"], [StepVideoVAE], "civitai"),
|
||||||
(None, "5c0216a2132b082c10cb7a0e0377e681", ["stepvideo_dit"], [StepVideoModel], "civitai"),
|
(None, "5c0216a2132b082c10cb7a0e0377e681", ["stepvideo_dit"], [StepVideoModel], "civitai"),
|
||||||
(None, "1378ea763357eea97acdef78e65d6d96", ["wanxvideo_vae"], [WanXVideoVAE], "civitai")
|
(None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
|
(None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
|
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
|
(None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
|
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
|
(None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
|
(None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
|
(None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
|
(None, "70ddad9d3a133785da5ea371aae09504", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
|
(None, "26bde73488a92e64cc20b0a7485b9e5b", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
|
(None, "ac6a5aa74f4a0aab6f64eb9a72f19901", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
|
(None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
|
(None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
|
||||||
|
(None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
|
||||||
|
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
|
||||||
|
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
|
||||||
|
(None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
|
||||||
|
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
||||||
|
(None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
||||||
|
(None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"),
|
||||||
|
(None, "d30fb9e02b1dbf4e509142f05cf7dd50", ["flux_dit", "step1x_connector"], [FluxDiT, Qwen2Connector], "civitai"),
|
||||||
]
|
]
|
||||||
huggingface_model_loader_configs = [
|
huggingface_model_loader_configs = [
|
||||||
# These configs are provided for detecting model type automatically.
|
# These configs are provided for detecting model type automatically.
|
||||||
@@ -123,7 +160,9 @@ huggingface_model_loader_configs = [
|
|||||||
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
|
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
|
||||||
("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
|
("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
|
||||||
("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
|
("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
|
||||||
|
("LlavaForConditionalGeneration", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoMLLMEncoder"),
|
||||||
("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
|
("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
|
||||||
|
("Qwen2_5_VLForConditionalGeneration", "diffsynth.models.qwenvl", "qwenvl", "Qwen25VL_7b_Embedder"),
|
||||||
]
|
]
|
||||||
patch_model_loader_configs = [
|
patch_model_loader_configs = [
|
||||||
# These configs are provided for detecting model type automatically.
|
# These configs are provided for detecting model type automatically.
|
||||||
@@ -585,6 +624,25 @@ preset_models_on_modelscope = {
|
|||||||
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
|
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
|
"InfiniteYou":{
|
||||||
|
"file_list":[
|
||||||
|
("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
|
||||||
|
("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
|
||||||
|
("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/image_proj_model.bin", "models/InfiniteYou"),
|
||||||
|
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/1k3d68.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
|
||||||
|
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/2d106det.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
|
||||||
|
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/genderage.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
|
||||||
|
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/glintr100.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
|
||||||
|
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/scrfd_10g_bnkps.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
|
||||||
|
],
|
||||||
|
"load_path":[
|
||||||
|
[
|
||||||
|
"models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors",
|
||||||
|
"models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors"
|
||||||
|
],
|
||||||
|
"models/InfiniteYou/image_proj_model.bin",
|
||||||
|
],
|
||||||
|
},
|
||||||
# ESRGAN
|
# ESRGAN
|
||||||
"ESRGAN_x4": [
|
"ESRGAN_x4": [
|
||||||
("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
|
("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
|
||||||
@@ -665,6 +723,25 @@ preset_models_on_modelscope = {
|
|||||||
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
|
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
|
"HunyuanVideoI2V":{
|
||||||
|
"file_list": [
|
||||||
|
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideoI2V/text_encoder"),
|
||||||
|
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00001-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
|
||||||
|
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00002-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
|
||||||
|
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00003-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
|
||||||
|
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00004-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
|
||||||
|
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "config.json", "models/HunyuanVideoI2V/text_encoder_2"),
|
||||||
|
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model.safetensors.index.json", "models/HunyuanVideoI2V/text_encoder_2"),
|
||||||
|
("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/vae/pytorch_model.pt", "models/HunyuanVideoI2V/vae"),
|
||||||
|
("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideoI2V/transformers")
|
||||||
|
],
|
||||||
|
"load_path": [
|
||||||
|
"models/HunyuanVideoI2V/text_encoder/model.safetensors",
|
||||||
|
"models/HunyuanVideoI2V/text_encoder_2",
|
||||||
|
"models/HunyuanVideoI2V/vae/pytorch_model.pt",
|
||||||
|
"models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
|
||||||
|
],
|
||||||
|
},
|
||||||
"HunyuanVideo-fp8":{
|
"HunyuanVideo-fp8":{
|
||||||
"file_list": [
|
"file_list": [
|
||||||
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
|
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
|
||||||
@@ -725,6 +802,7 @@ Preset_model_id: TypeAlias = Literal[
|
|||||||
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
|
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
|
||||||
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
|
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
|
||||||
"InstantX/FLUX.1-dev-IP-Adapter",
|
"InstantX/FLUX.1-dev-IP-Adapter",
|
||||||
|
"InfiniteYou",
|
||||||
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
|
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
|
||||||
"QwenPrompt",
|
"QwenPrompt",
|
||||||
"OmostPrompt",
|
"OmostPrompt",
|
||||||
@@ -741,4 +819,5 @@ Preset_model_id: TypeAlias = Literal[
|
|||||||
"StableDiffusion3.5-medium",
|
"StableDiffusion3.5-medium",
|
||||||
"HunyuanVideo",
|
"HunyuanVideo",
|
||||||
"HunyuanVideo-fp8",
|
"HunyuanVideo-fp8",
|
||||||
|
"HunyuanVideoI2V",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,10 +1,4 @@
|
|||||||
from typing_extensions import Literal, TypeAlias
|
from typing_extensions import Literal, TypeAlias
|
||||||
import warnings
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.simplefilter("ignore")
|
|
||||||
from controlnet_aux.processor import (
|
|
||||||
CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector, NormalBaeDetector
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
Processor_id: TypeAlias = Literal[
|
Processor_id: TypeAlias = Literal[
|
||||||
@@ -15,18 +9,25 @@ class Annotator:
|
|||||||
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
|
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
|
||||||
if not skip_processor:
|
if not skip_processor:
|
||||||
if processor_id == "canny":
|
if processor_id == "canny":
|
||||||
|
from controlnet_aux.processor import CannyDetector
|
||||||
self.processor = CannyDetector()
|
self.processor = CannyDetector()
|
||||||
elif processor_id == "depth":
|
elif processor_id == "depth":
|
||||||
|
from controlnet_aux.processor import MidasDetector
|
||||||
self.processor = MidasDetector.from_pretrained(model_path).to(device)
|
self.processor = MidasDetector.from_pretrained(model_path).to(device)
|
||||||
elif processor_id == "softedge":
|
elif processor_id == "softedge":
|
||||||
|
from controlnet_aux.processor import HEDdetector
|
||||||
self.processor = HEDdetector.from_pretrained(model_path).to(device)
|
self.processor = HEDdetector.from_pretrained(model_path).to(device)
|
||||||
elif processor_id == "lineart":
|
elif processor_id == "lineart":
|
||||||
|
from controlnet_aux.processor import LineartDetector
|
||||||
self.processor = LineartDetector.from_pretrained(model_path).to(device)
|
self.processor = LineartDetector.from_pretrained(model_path).to(device)
|
||||||
elif processor_id == "lineart_anime":
|
elif processor_id == "lineart_anime":
|
||||||
|
from controlnet_aux.processor import LineartAnimeDetector
|
||||||
self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
|
self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
|
||||||
elif processor_id == "openpose":
|
elif processor_id == "openpose":
|
||||||
|
from controlnet_aux.processor import OpenposeDetector
|
||||||
self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
|
self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
|
||||||
elif processor_id == "normal":
|
elif processor_id == "normal":
|
||||||
|
from controlnet_aux.processor import NormalBaeDetector
|
||||||
self.processor = NormalBaeDetector.from_pretrained(model_path).to(device)
|
self.processor = NormalBaeDetector.from_pretrained(model_path).to(device)
|
||||||
elif processor_id == "tile" or processor_id == "none" or processor_id == "inpaint":
|
elif processor_id == "tile" or processor_id == "none" or processor_id == "inpaint":
|
||||||
self.processor = None
|
self.processor = None
|
||||||
|
|||||||
0
diffsynth/distributed/__init__.py
Normal file
0
diffsynth/distributed/__init__.py
Normal file
129
diffsynth/distributed/xdit_context_parallel.py
Normal file
129
diffsynth/distributed/xdit_context_parallel.py
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
import torch
|
||||||
|
from typing import Optional
|
||||||
|
from einops import rearrange
|
||||||
|
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
||||||
|
get_sequence_parallel_world_size,
|
||||||
|
get_sp_group)
|
||||||
|
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
||||||
|
|
||||||
|
def sinusoidal_embedding_1d(dim, position):
|
||||||
|
sinusoid = torch.outer(position.type(torch.float64), torch.pow(
|
||||||
|
10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
|
||||||
|
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
||||||
|
return x.to(position.dtype)
|
||||||
|
|
||||||
|
def pad_freqs(original_tensor, target_len):
|
||||||
|
seq_len, s1, s2 = original_tensor.shape
|
||||||
|
pad_size = target_len - seq_len
|
||||||
|
padding_tensor = torch.ones(
|
||||||
|
pad_size,
|
||||||
|
s1,
|
||||||
|
s2,
|
||||||
|
dtype=original_tensor.dtype,
|
||||||
|
device=original_tensor.device)
|
||||||
|
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
|
||||||
|
return padded_tensor
|
||||||
|
|
||||||
|
def rope_apply(x, freqs, num_heads):
|
||||||
|
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
||||||
|
s_per_rank = x.shape[1]
|
||||||
|
|
||||||
|
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
|
||||||
|
x.shape[0], x.shape[1], x.shape[2], -1, 2))
|
||||||
|
|
||||||
|
sp_size = get_sequence_parallel_world_size()
|
||||||
|
sp_rank = get_sequence_parallel_rank()
|
||||||
|
freqs = pad_freqs(freqs, s_per_rank * sp_size)
|
||||||
|
freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
|
||||||
|
|
||||||
|
x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
|
||||||
|
return x_out.to(x.dtype)
|
||||||
|
|
||||||
|
def usp_dit_forward(self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
timestep: torch.Tensor,
|
||||||
|
context: torch.Tensor,
|
||||||
|
clip_feature: Optional[torch.Tensor] = None,
|
||||||
|
y: Optional[torch.Tensor] = None,
|
||||||
|
use_gradient_checkpointing: bool = False,
|
||||||
|
use_gradient_checkpointing_offload: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
t = self.time_embedding(
|
||||||
|
sinusoidal_embedding_1d(self.freq_dim, timestep))
|
||||||
|
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
|
||||||
|
context = self.text_embedding(context)
|
||||||
|
|
||||||
|
if self.has_image_input:
|
||||||
|
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
|
||||||
|
clip_embdding = self.img_emb(clip_feature)
|
||||||
|
context = torch.cat([clip_embdding, context], dim=1)
|
||||||
|
|
||||||
|
x, (f, h, w) = self.patchify(x)
|
||||||
|
|
||||||
|
freqs = torch.cat([
|
||||||
|
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||||
|
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||||
|
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||||
|
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
return module(*inputs)
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
# Context Parallel
|
||||||
|
x = torch.chunk(
|
||||||
|
x, get_sequence_parallel_world_size(),
|
||||||
|
dim=1)[get_sequence_parallel_rank()]
|
||||||
|
|
||||||
|
for block in self.blocks:
|
||||||
|
if self.training and use_gradient_checkpointing:
|
||||||
|
if use_gradient_checkpointing_offload:
|
||||||
|
with torch.autograd.graph.save_on_cpu():
|
||||||
|
x = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(block),
|
||||||
|
x, context, t_mod, freqs,
|
||||||
|
use_reentrant=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
x = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(block),
|
||||||
|
x, context, t_mod, freqs,
|
||||||
|
use_reentrant=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
x = block(x, context, t_mod, freqs)
|
||||||
|
|
||||||
|
x = self.head(x, t)
|
||||||
|
|
||||||
|
# Context Parallel
|
||||||
|
x = get_sp_group().all_gather(x, dim=1)
|
||||||
|
|
||||||
|
# unpatchify
|
||||||
|
x = self.unpatchify(x, (f, h, w))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def usp_attn_forward(self, x, freqs):
|
||||||
|
q = self.norm_q(self.q(x))
|
||||||
|
k = self.norm_k(self.k(x))
|
||||||
|
v = self.v(x)
|
||||||
|
|
||||||
|
q = rope_apply(q, freqs, self.num_heads)
|
||||||
|
k = rope_apply(k, freqs, self.num_heads)
|
||||||
|
q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
|
||||||
|
k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
|
||||||
|
v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
|
||||||
|
|
||||||
|
x = xFuserLongContextAttention()(
|
||||||
|
None,
|
||||||
|
query=q,
|
||||||
|
key=k,
|
||||||
|
value=v,
|
||||||
|
)
|
||||||
|
x = x.flatten(2)
|
||||||
|
|
||||||
|
del q, k, v
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return self.o(x)
|
||||||
1
diffsynth/extensions/ImageQualityMetric/BLIP/__init__.py
Normal file
1
diffsynth/extensions/ImageQualityMetric/BLIP/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .blip_pretrain import *
|
||||||
77
diffsynth/extensions/ImageQualityMetric/BLIP/blip.py
Normal file
77
diffsynth/extensions/ImageQualityMetric/BLIP/blip.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
'''
|
||||||
|
* Adapted from BLIP (https://github.com/salesforce/BLIP)
|
||||||
|
'''
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import os
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
from timm.models.hub import download_cached_file
|
||||||
|
from transformers import BertTokenizer
|
||||||
|
from .vit import VisionTransformer, interpolate_pos_embed
|
||||||
|
|
||||||
|
|
||||||
|
def default_bert():
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
project_root = os.path.abspath(os.path.join(current_dir, '../../../../'))
|
||||||
|
model_path = os.path.join(project_root, 'models', 'QualityMetric')
|
||||||
|
return os.path.join(model_path, "bert-base-uncased")
|
||||||
|
|
||||||
|
|
||||||
|
def init_tokenizer(bert_model_path):
|
||||||
|
tokenizer = BertTokenizer.from_pretrained(bert_model_path)
|
||||||
|
tokenizer.add_special_tokens({'bos_token':'[DEC]'})
|
||||||
|
tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
|
||||||
|
tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
|
||||||
|
|
||||||
|
assert vit in ['base', 'large'], "vit parameter must be base or large"
|
||||||
|
if vit=='base':
|
||||||
|
vision_width = 768
|
||||||
|
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
|
||||||
|
num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
|
||||||
|
drop_path_rate=0 or drop_path_rate
|
||||||
|
)
|
||||||
|
elif vit=='large':
|
||||||
|
vision_width = 1024
|
||||||
|
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
|
||||||
|
num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
|
||||||
|
drop_path_rate=0.1 or drop_path_rate
|
||||||
|
)
|
||||||
|
return visual_encoder, vision_width
|
||||||
|
|
||||||
|
|
||||||
|
def is_url(url_or_filename):
|
||||||
|
parsed = urlparse(url_or_filename)
|
||||||
|
return parsed.scheme in ("http", "https")
|
||||||
|
|
||||||
|
def load_checkpoint(model,url_or_filename):
|
||||||
|
if is_url(url_or_filename):
|
||||||
|
cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
|
||||||
|
checkpoint = torch.load(cached_file, map_location='cpu')
|
||||||
|
elif os.path.isfile(url_or_filename):
|
||||||
|
checkpoint = torch.load(url_or_filename, map_location='cpu')
|
||||||
|
else:
|
||||||
|
raise RuntimeError('checkpoint url or path is invalid')
|
||||||
|
|
||||||
|
state_dict = checkpoint['model']
|
||||||
|
|
||||||
|
state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
|
||||||
|
if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
|
||||||
|
state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
|
||||||
|
model.visual_encoder_m)
|
||||||
|
for key in model.state_dict().keys():
|
||||||
|
if key in state_dict.keys():
|
||||||
|
if state_dict[key].shape!=model.state_dict()[key].shape:
|
||||||
|
print(key, ": ", state_dict[key].shape, ', ', model.state_dict()[key].shape)
|
||||||
|
del state_dict[key]
|
||||||
|
|
||||||
|
msg = model.load_state_dict(state_dict,strict=False)
|
||||||
|
print('load checkpoint from %s'%url_or_filename)
|
||||||
|
return model,msg
|
||||||
|
|
||||||
@@ -0,0 +1,44 @@
|
|||||||
|
'''
|
||||||
|
* Adapted from BLIP (https://github.com/salesforce/BLIP)
|
||||||
|
'''
|
||||||
|
|
||||||
|
import transformers
|
||||||
|
transformers.logging.set_verbosity_error()
|
||||||
|
|
||||||
|
from torch import nn
|
||||||
|
import os
|
||||||
|
from .med import BertConfig, BertModel
|
||||||
|
from .blip import create_vit, init_tokenizer
|
||||||
|
|
||||||
|
class BLIP_Pretrain(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
med_config = "med_config.json",
|
||||||
|
image_size = 224,
|
||||||
|
vit = 'base',
|
||||||
|
vit_grad_ckpt = False,
|
||||||
|
vit_ckpt_layer = 0,
|
||||||
|
embed_dim = 256,
|
||||||
|
queue_size = 57600,
|
||||||
|
momentum = 0.995,
|
||||||
|
bert_model_path = ""
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
||||||
|
image_size (int): input image size
|
||||||
|
vit (str): model size of vision transformer
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
|
||||||
|
|
||||||
|
self.tokenizer = init_tokenizer(bert_model_path)
|
||||||
|
encoder_config = BertConfig.from_json_file(med_config)
|
||||||
|
encoder_config.encoder_width = vision_width
|
||||||
|
self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
|
||||||
|
|
||||||
|
text_width = self.text_encoder.config.hidden_size
|
||||||
|
|
||||||
|
self.vision_proj = nn.Linear(vision_width, embed_dim)
|
||||||
|
self.text_proj = nn.Linear(text_width, embed_dim)
|
||||||
|
|
||||||
947
diffsynth/extensions/ImageQualityMetric/BLIP/med.py
Normal file
947
diffsynth/extensions/ImageQualityMetric/BLIP/med.py
Normal file
@@ -0,0 +1,947 @@
|
|||||||
|
'''
|
||||||
|
* Adapted from BLIP (https://github.com/salesforce/BLIP)
|
||||||
|
* Based on huggingface code base
|
||||||
|
* https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
|
||||||
|
'''
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor, device, nn
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import CrossEntropyLoss
|
||||||
|
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
from transformers.file_utils import (
|
||||||
|
ModelOutput,
|
||||||
|
)
|
||||||
|
from transformers.modeling_outputs import (
|
||||||
|
BaseModelOutputWithPastAndCrossAttentions,
|
||||||
|
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||||
|
CausalLMOutputWithCrossAttentions,
|
||||||
|
MaskedLMOutput,
|
||||||
|
MultipleChoiceModelOutput,
|
||||||
|
NextSentencePredictorOutput,
|
||||||
|
QuestionAnsweringModelOutput,
|
||||||
|
SequenceClassifierOutput,
|
||||||
|
TokenClassifierOutput,
|
||||||
|
)
|
||||||
|
from transformers.modeling_utils import (
|
||||||
|
PreTrainedModel,
|
||||||
|
apply_chunking_to_forward,
|
||||||
|
find_pruneable_heads_and_indices,
|
||||||
|
prune_linear_layer,
|
||||||
|
)
|
||||||
|
from transformers.utils import logging
|
||||||
|
from transformers.models.bert.configuration_bert import BertConfig
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BertEmbeddings(nn.Module):
|
||||||
|
"""Construct the embeddings from word and position embeddings."""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
||||||
|
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
||||||
|
|
||||||
|
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
||||||
|
# any TensorFlow checkpoint file
|
||||||
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
|
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||||
|
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||||
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
||||||
|
):
|
||||||
|
if input_ids is not None:
|
||||||
|
input_shape = input_ids.size()
|
||||||
|
else:
|
||||||
|
input_shape = inputs_embeds.size()[:-1]
|
||||||
|
|
||||||
|
seq_length = input_shape[1]
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.word_embeddings(input_ids)
|
||||||
|
|
||||||
|
embeddings = inputs_embeds
|
||||||
|
|
||||||
|
if self.position_embedding_type == "absolute":
|
||||||
|
position_embeddings = self.position_embeddings(position_ids)
|
||||||
|
embeddings += position_embeddings
|
||||||
|
embeddings = self.LayerNorm(embeddings)
|
||||||
|
embeddings = self.dropout(embeddings)
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
class BertSelfAttention(nn.Module):
|
||||||
|
def __init__(self, config, is_cross_attention):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||||
|
raise ValueError(
|
||||||
|
"The hidden size (%d) is not a multiple of the number of attention "
|
||||||
|
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.num_attention_heads = config.num_attention_heads
|
||||||
|
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||||
|
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||||
|
|
||||||
|
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
||||||
|
if is_cross_attention:
|
||||||
|
self.key = nn.Linear(config.encoder_width, self.all_head_size)
|
||||||
|
self.value = nn.Linear(config.encoder_width, self.all_head_size)
|
||||||
|
else:
|
||||||
|
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
||||||
|
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||||
|
|
||||||
|
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||||
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||||
|
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||||
|
self.max_position_embeddings = config.max_position_embeddings
|
||||||
|
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||||
|
self.save_attention = False
|
||||||
|
|
||||||
|
def save_attn_gradients(self, attn_gradients):
|
||||||
|
self.attn_gradients = attn_gradients
|
||||||
|
|
||||||
|
def get_attn_gradients(self):
|
||||||
|
return self.attn_gradients
|
||||||
|
|
||||||
|
def save_attention_map(self, attention_map):
|
||||||
|
self.attention_map = attention_map
|
||||||
|
|
||||||
|
def get_attention_map(self):
|
||||||
|
return self.attention_map
|
||||||
|
|
||||||
|
def transpose_for_scores(self, x):
|
||||||
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||||
|
x = x.view(*new_x_shape)
|
||||||
|
return x.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
|
output_attentions=False,
|
||||||
|
):
|
||||||
|
mixed_query_layer = self.query(hidden_states)
|
||||||
|
|
||||||
|
# If this is instantiated as a cross-attention module, the keys
|
||||||
|
# and values come from an encoder; the attention mask needs to be
|
||||||
|
# such that the encoder's padding tokens are not attended to.
|
||||||
|
is_cross_attention = encoder_hidden_states is not None
|
||||||
|
|
||||||
|
if is_cross_attention:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||||
|
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||||
|
attention_mask = encoder_attention_mask
|
||||||
|
elif past_key_value is not None:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||||
|
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||||
|
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||||
|
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||||
|
else:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||||
|
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||||
|
|
||||||
|
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||||
|
|
||||||
|
past_key_value = (key_layer, value_layer)
|
||||||
|
|
||||||
|
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||||
|
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||||
|
|
||||||
|
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||||
|
seq_length = hidden_states.size()[1]
|
||||||
|
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
||||||
|
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
||||||
|
distance = position_ids_l - position_ids_r
|
||||||
|
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
||||||
|
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
||||||
|
|
||||||
|
if self.position_embedding_type == "relative_key":
|
||||||
|
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||||
|
attention_scores = attention_scores + relative_position_scores
|
||||||
|
elif self.position_embedding_type == "relative_key_query":
|
||||||
|
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||||
|
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
||||||
|
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
||||||
|
|
||||||
|
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||||
|
if attention_mask is not None:
|
||||||
|
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
||||||
|
attention_scores = attention_scores + attention_mask
|
||||||
|
|
||||||
|
# Normalize the attention scores to probabilities.
|
||||||
|
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
||||||
|
|
||||||
|
if is_cross_attention and self.save_attention:
|
||||||
|
self.save_attention_map(attention_probs)
|
||||||
|
attention_probs.register_hook(self.save_attn_gradients)
|
||||||
|
|
||||||
|
# This is actually dropping out entire tokens to attend to, which might
|
||||||
|
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||||
|
attention_probs_dropped = self.dropout(attention_probs)
|
||||||
|
|
||||||
|
# Mask heads if we want to
|
||||||
|
if head_mask is not None:
|
||||||
|
attention_probs_dropped = attention_probs_dropped * head_mask
|
||||||
|
|
||||||
|
context_layer = torch.matmul(attention_probs_dropped, value_layer)
|
||||||
|
|
||||||
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
|
context_layer = context_layer.view(*new_context_layer_shape)
|
||||||
|
|
||||||
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
|
|
||||||
|
outputs = outputs + (past_key_value,)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class BertSelfOutput(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, input_tensor):
|
||||||
|
hidden_states = self.dense(hidden_states)
|
||||||
|
hidden_states = self.dropout(hidden_states)
|
||||||
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BertAttention(nn.Module):
|
||||||
|
def __init__(self, config, is_cross_attention=False):
|
||||||
|
super().__init__()
|
||||||
|
self.self = BertSelfAttention(config, is_cross_attention)
|
||||||
|
self.output = BertSelfOutput(config)
|
||||||
|
self.pruned_heads = set()
|
||||||
|
|
||||||
|
def prune_heads(self, heads):
|
||||||
|
if len(heads) == 0:
|
||||||
|
return
|
||||||
|
heads, index = find_pruneable_heads_and_indices(
|
||||||
|
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prune linear layers
|
||||||
|
self.self.query = prune_linear_layer(self.self.query, index)
|
||||||
|
self.self.key = prune_linear_layer(self.self.key, index)
|
||||||
|
self.self.value = prune_linear_layer(self.self.value, index)
|
||||||
|
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
||||||
|
|
||||||
|
# Update hyper params and store pruned heads
|
||||||
|
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
||||||
|
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||||
|
self.pruned_heads = self.pruned_heads.union(heads)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
|
output_attentions=False,
|
||||||
|
):
|
||||||
|
self_outputs = self.self(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
head_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
past_key_value,
|
||||||
|
output_attentions,
|
||||||
|
)
|
||||||
|
attention_output = self.output(self_outputs[0], hidden_states)
|
||||||
|
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class BertIntermediate(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||||
|
if isinstance(config.hidden_act, str):
|
||||||
|
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
||||||
|
else:
|
||||||
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states = self.dense(hidden_states)
|
||||||
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BertOutput(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||||
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, input_tensor):
|
||||||
|
hidden_states = self.dense(hidden_states)
|
||||||
|
hidden_states = self.dropout(hidden_states)
|
||||||
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BertLayer(nn.Module):
|
||||||
|
def __init__(self, config, layer_num):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||||
|
self.seq_len_dim = 1
|
||||||
|
self.attention = BertAttention(config)
|
||||||
|
self.layer_num = layer_num
|
||||||
|
if self.config.add_cross_attention:
|
||||||
|
self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
|
||||||
|
self.intermediate = BertIntermediate(config)
|
||||||
|
self.output = BertOutput(config)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
|
output_attentions=False,
|
||||||
|
mode=None,
|
||||||
|
):
|
||||||
|
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||||
|
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||||
|
self_attention_outputs = self.attention(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
head_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
past_key_value=self_attn_past_key_value,
|
||||||
|
)
|
||||||
|
attention_output = self_attention_outputs[0]
|
||||||
|
|
||||||
|
outputs = self_attention_outputs[1:-1]
|
||||||
|
present_key_value = self_attention_outputs[-1]
|
||||||
|
|
||||||
|
if mode=='multimodal':
|
||||||
|
assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
|
||||||
|
|
||||||
|
cross_attention_outputs = self.crossattention(
|
||||||
|
attention_output,
|
||||||
|
attention_mask,
|
||||||
|
head_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
|
attention_output = cross_attention_outputs[0]
|
||||||
|
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||||
|
layer_output = apply_chunking_to_forward(
|
||||||
|
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||||
|
)
|
||||||
|
outputs = (layer_output,) + outputs
|
||||||
|
|
||||||
|
outputs = outputs + (present_key_value,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def feed_forward_chunk(self, attention_output):
|
||||||
|
intermediate_output = self.intermediate(attention_output)
|
||||||
|
layer_output = self.output(intermediate_output, attention_output)
|
||||||
|
return layer_output
|
||||||
|
|
||||||
|
|
||||||
|
class BertEncoder(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
|
output_attentions=False,
|
||||||
|
output_hidden_states=False,
|
||||||
|
return_dict=True,
|
||||||
|
mode='multimodal',
|
||||||
|
):
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attentions = () if output_attentions else None
|
||||||
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||||
|
|
||||||
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
|
for i in range(self.config.num_hidden_layers):
|
||||||
|
layer_module = self.layer[i]
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||||
|
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
logger.warning(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
return module(*inputs, past_key_value, output_attentions)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(layer_module),
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
mode=mode,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
layer_outputs = layer_module(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
past_key_value,
|
||||||
|
output_attentions,
|
||||||
|
mode=mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache += (layer_outputs[-1],)
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||||
|
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(
|
||||||
|
v
|
||||||
|
for v in [
|
||||||
|
hidden_states,
|
||||||
|
next_decoder_cache,
|
||||||
|
all_hidden_states,
|
||||||
|
all_self_attentions,
|
||||||
|
all_cross_attentions,
|
||||||
|
]
|
||||||
|
if v is not None
|
||||||
|
)
|
||||||
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_decoder_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attentions,
|
||||||
|
cross_attentions=all_cross_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BertPooler(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
|
self.activation = nn.Tanh()
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
# We "pool" the model by simply taking the hidden state corresponding
|
||||||
|
# to the first token.
|
||||||
|
first_token_tensor = hidden_states[:, 0]
|
||||||
|
pooled_output = self.dense(first_token_tensor)
|
||||||
|
pooled_output = self.activation(pooled_output)
|
||||||
|
return pooled_output
|
||||||
|
|
||||||
|
|
||||||
|
class BertPredictionHeadTransform(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
|
if isinstance(config.hidden_act, str):
|
||||||
|
self.transform_act_fn = ACT2FN[config.hidden_act]
|
||||||
|
else:
|
||||||
|
self.transform_act_fn = config.hidden_act
|
||||||
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states = self.dense(hidden_states)
|
||||||
|
hidden_states = self.transform_act_fn(hidden_states)
|
||||||
|
hidden_states = self.LayerNorm(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BertLMPredictionHead(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.transform = BertPredictionHeadTransform(config)
|
||||||
|
|
||||||
|
# The output weights are the same as the input embeddings, but there is
|
||||||
|
# an output-only bias for each token.
|
||||||
|
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||||
|
|
||||||
|
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||||
|
|
||||||
|
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||||
|
self.decoder.bias = self.bias
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states = self.transform(hidden_states)
|
||||||
|
hidden_states = self.decoder(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BertOnlyMLMHead(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.predictions = BertLMPredictionHead(config)
|
||||||
|
|
||||||
|
def forward(self, sequence_output):
|
||||||
|
prediction_scores = self.predictions(sequence_output)
|
||||||
|
return prediction_scores
|
||||||
|
|
||||||
|
|
||||||
|
class BertPreTrainedModel(PreTrainedModel):
|
||||||
|
"""
|
||||||
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||||
|
models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
config_class = BertConfig
|
||||||
|
base_model_prefix = "bert"
|
||||||
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
|
|
||||||
|
def _init_weights(self, module):
|
||||||
|
""" Initialize the weights """
|
||||||
|
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||||
|
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||||
|
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||||
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||||
|
elif isinstance(module, nn.LayerNorm):
|
||||||
|
module.bias.data.zero_()
|
||||||
|
module.weight.data.fill_(1.0)
|
||||||
|
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||||
|
module.bias.data.zero_()
|
||||||
|
|
||||||
|
|
||||||
|
class BertModel(BertPreTrainedModel):
|
||||||
|
"""
|
||||||
|
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
||||||
|
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
||||||
|
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
||||||
|
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
||||||
|
argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
|
||||||
|
input to the forward pass.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config, add_pooling_layer=True):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
self.embeddings = BertEmbeddings(config)
|
||||||
|
|
||||||
|
self.encoder = BertEncoder(config)
|
||||||
|
|
||||||
|
self.pooler = BertPooler(config) if add_pooling_layer else None
|
||||||
|
|
||||||
|
self.init_weights()
|
||||||
|
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.embeddings.word_embeddings
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.embeddings.word_embeddings = value
|
||||||
|
|
||||||
|
def _prune_heads(self, heads_to_prune):
|
||||||
|
"""
|
||||||
|
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
||||||
|
class PreTrainedModel
|
||||||
|
"""
|
||||||
|
for layer, heads in heads_to_prune.items():
|
||||||
|
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||||
|
|
||||||
|
|
||||||
|
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
|
||||||
|
"""
|
||||||
|
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
attention_mask (:obj:`torch.Tensor`):
|
||||||
|
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
||||||
|
input_shape (:obj:`Tuple[int]`):
|
||||||
|
The shape of the input to the model.
|
||||||
|
device: (:obj:`torch.device`):
|
||||||
|
The device of the input to the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
|
||||||
|
"""
|
||||||
|
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||||
|
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||||
|
if attention_mask.dim() == 3:
|
||||||
|
extended_attention_mask = attention_mask[:, None, :, :]
|
||||||
|
elif attention_mask.dim() == 2:
|
||||||
|
# Provided a padding mask of dimensions [batch_size, seq_length]
|
||||||
|
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
||||||
|
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||||
|
if is_decoder:
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
|
|
||||||
|
seq_ids = torch.arange(seq_length, device=device)
|
||||||
|
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
||||||
|
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
|
||||||
|
# causal and attention masks must have same type with pytorch version < 1.3
|
||||||
|
causal_mask = causal_mask.to(attention_mask.dtype)
|
||||||
|
|
||||||
|
if causal_mask.shape[1] < attention_mask.shape[1]:
|
||||||
|
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
|
||||||
|
causal_mask = torch.cat(
|
||||||
|
[
|
||||||
|
torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
|
||||||
|
causal_mask,
|
||||||
|
],
|
||||||
|
axis=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
||||||
|
else:
|
||||||
|
extended_attention_mask = attention_mask[:, None, None, :]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
||||||
|
input_shape, attention_mask.shape
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||||
|
# masked positions, this operation will create a tensor which is 0.0 for
|
||||||
|
# positions we want to attend and -10000.0 for masked positions.
|
||||||
|
# Since we are adding it to the raw scores before the softmax, this is
|
||||||
|
# effectively the same as removing these entirely.
|
||||||
|
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||||
|
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||||
|
return extended_attention_mask
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
encoder_embeds=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
is_decoder=False,
|
||||||
|
mode='multimodal',
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||||
|
the model is configured as a decoder.
|
||||||
|
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||||
|
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
|
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
use_cache (:obj:`bool`, `optional`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`).
|
||||||
|
"""
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
if is_decoder:
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
else:
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
input_shape = input_ids.size()
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
|
device = input_ids.device
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
input_shape = inputs_embeds.size()[:-1]
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
|
device = inputs_embeds.device
|
||||||
|
elif encoder_embeds is not None:
|
||||||
|
input_shape = encoder_embeds.size()[:-1]
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
|
device = encoder_embeds.device
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
|
||||||
|
|
||||||
|
# past_key_values_length
|
||||||
|
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||||
|
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||||
|
|
||||||
|
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||||
|
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||||
|
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
|
||||||
|
device, is_decoder)
|
||||||
|
|
||||||
|
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||||
|
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||||
|
if encoder_hidden_states is not None:
|
||||||
|
if type(encoder_hidden_states) == list:
|
||||||
|
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
|
||||||
|
else:
|
||||||
|
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||||
|
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||||
|
|
||||||
|
if type(encoder_attention_mask) == list:
|
||||||
|
encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
|
||||||
|
elif encoder_attention_mask is None:
|
||||||
|
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||||
|
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||||
|
else:
|
||||||
|
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||||
|
else:
|
||||||
|
encoder_extended_attention_mask = None
|
||||||
|
|
||||||
|
# Prepare head mask if needed
|
||||||
|
# 1.0 in head_mask indicate we keep the head
|
||||||
|
# attention_probs has shape bsz x n_heads x N x N
|
||||||
|
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||||
|
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||||
|
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||||
|
|
||||||
|
if encoder_embeds is None:
|
||||||
|
embedding_output = self.embeddings(
|
||||||
|
input_ids=input_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
past_key_values_length=past_key_values_length,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
embedding_output = encoder_embeds
|
||||||
|
|
||||||
|
encoder_outputs = self.encoder(
|
||||||
|
embedding_output,
|
||||||
|
attention_mask=extended_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_extended_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
mode=mode,
|
||||||
|
)
|
||||||
|
sequence_output = encoder_outputs[0]
|
||||||
|
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||||
|
|
||||||
|
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||||
|
last_hidden_state=sequence_output,
|
||||||
|
pooler_output=pooled_output,
|
||||||
|
past_key_values=encoder_outputs.past_key_values,
|
||||||
|
hidden_states=encoder_outputs.hidden_states,
|
||||||
|
attentions=encoder_outputs.attentions,
|
||||||
|
cross_attentions=encoder_outputs.cross_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class BertLMHeadModel(BertPreTrainedModel):
|
||||||
|
|
||||||
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
|
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
self.bert = BertModel(config, add_pooling_layer=False)
|
||||||
|
self.cls = BertOnlyMLMHead(config)
|
||||||
|
|
||||||
|
self.init_weights()
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return self.cls.predictions.decoder
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.cls.predictions.decoder = new_embeddings
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
labels=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
return_logits=False,
|
||||||
|
is_decoder=True,
|
||||||
|
reduction='mean',
|
||||||
|
mode='multimodal',
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||||
|
the model is configured as a decoder.
|
||||||
|
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||||
|
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
||||||
|
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
|
||||||
|
ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
|
||||||
|
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
|
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
use_cache (:obj:`bool`, `optional`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`).
|
||||||
|
Returns:
|
||||||
|
Example::
|
||||||
|
>>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
|
||||||
|
>>> import torch
|
||||||
|
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
||||||
|
>>> config = BertConfig.from_pretrained("bert-base-cased")
|
||||||
|
>>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
|
||||||
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||||
|
>>> outputs = model(**inputs)
|
||||||
|
>>> prediction_logits = outputs.logits
|
||||||
|
"""
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
if labels is not None:
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
outputs = self.bert(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
is_decoder=is_decoder,
|
||||||
|
mode=mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
sequence_output = outputs[0]
|
||||||
|
prediction_scores = self.cls(sequence_output)
|
||||||
|
|
||||||
|
if return_logits:
|
||||||
|
return prediction_scores[:, :-1, :].contiguous()
|
||||||
|
|
||||||
|
lm_loss = None
|
||||||
|
if labels is not None:
|
||||||
|
# we are doing next-token prediction; shift prediction scores and input ids by one
|
||||||
|
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
||||||
|
labels = labels[:, 1:].contiguous()
|
||||||
|
loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
|
||||||
|
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
||||||
|
if reduction=='none':
|
||||||
|
lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (prediction_scores,) + outputs[2:]
|
||||||
|
return ((lm_loss,) + output) if lm_loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithCrossAttentions(
|
||||||
|
loss=lm_loss,
|
||||||
|
logits=prediction_scores,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
cross_attentions=outputs.cross_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
|
||||||
|
input_shape = input_ids.shape
|
||||||
|
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = input_ids.new_ones(input_shape)
|
||||||
|
|
||||||
|
# cut decoder_input_ids if past is used
|
||||||
|
if past is not None:
|
||||||
|
input_ids = input_ids[:, -1:]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"past_key_values": past,
|
||||||
|
"encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
|
||||||
|
"encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
|
||||||
|
"is_decoder": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _reorder_cache(self, past, beam_idx):
|
||||||
|
reordered_past = ()
|
||||||
|
for layer_past in past:
|
||||||
|
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
||||||
|
return reordered_past
|
||||||
301
diffsynth/extensions/ImageQualityMetric/BLIP/vit.py
Normal file
301
diffsynth/extensions/ImageQualityMetric/BLIP/vit.py
Normal file
@@ -0,0 +1,301 @@
|
|||||||
|
'''
|
||||||
|
* Adapted from BLIP (https://github.com/salesforce/BLIP)
|
||||||
|
* Based on timm code base
|
||||||
|
* https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
||||||
|
'''
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from timm.models.vision_transformer import _cfg, PatchEmbed
|
||||||
|
from timm.models.registry import register_model
|
||||||
|
from timm.models.layers import trunc_normal_, DropPath
|
||||||
|
from timm.models.helpers import named_apply, adapt_input_conv
|
||||||
|
|
||||||
|
# from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
|
||||||
|
|
||||||
|
class Mlp(nn.Module):
|
||||||
|
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
|
||||||
|
"""
|
||||||
|
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
||||||
|
super().__init__()
|
||||||
|
out_features = out_features or in_features
|
||||||
|
hidden_features = hidden_features or in_features
|
||||||
|
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||||
|
self.act = act_layer()
|
||||||
|
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||||
|
self.drop = nn.Dropout(drop)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.fc1(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.drop(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
x = self.drop(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
head_dim = dim // num_heads
|
||||||
|
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
||||||
|
self.scale = qk_scale or head_dim ** -0.5
|
||||||
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||||
|
self.attn_drop = nn.Dropout(attn_drop)
|
||||||
|
self.proj = nn.Linear(dim, dim)
|
||||||
|
self.proj_drop = nn.Dropout(proj_drop)
|
||||||
|
self.attn_gradients = None
|
||||||
|
self.attention_map = None
|
||||||
|
|
||||||
|
def save_attn_gradients(self, attn_gradients):
|
||||||
|
self.attn_gradients = attn_gradients
|
||||||
|
|
||||||
|
def get_attn_gradients(self):
|
||||||
|
return self.attn_gradients
|
||||||
|
|
||||||
|
def save_attention_map(self, attention_map):
|
||||||
|
self.attention_map = attention_map
|
||||||
|
|
||||||
|
def get_attention_map(self):
|
||||||
|
return self.attention_map
|
||||||
|
|
||||||
|
def forward(self, x, register_hook=False):
|
||||||
|
B, N, C = x.shape
|
||||||
|
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||||
|
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
||||||
|
|
||||||
|
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||||
|
attn = attn.softmax(dim=-1)
|
||||||
|
attn = self.attn_drop(attn)
|
||||||
|
|
||||||
|
if register_hook:
|
||||||
|
self.save_attention_map(attn)
|
||||||
|
attn.register_hook(self.save_attn_gradients)
|
||||||
|
|
||||||
|
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||||
|
x = self.proj(x)
|
||||||
|
x = self.proj_drop(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Block(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
||||||
|
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
|
||||||
|
super().__init__()
|
||||||
|
self.norm1 = norm_layer(dim)
|
||||||
|
self.attn = Attention(
|
||||||
|
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
||||||
|
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
||||||
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||||
|
self.norm2 = norm_layer(dim)
|
||||||
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||||
|
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||||
|
|
||||||
|
# if use_grad_checkpointing:
|
||||||
|
# self.attn = checkpoint_wrapper(self.attn)
|
||||||
|
# self.mlp = checkpoint_wrapper(self.mlp)
|
||||||
|
|
||||||
|
def forward(self, x, register_hook=False):
|
||||||
|
x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
|
||||||
|
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class VisionTransformer(nn.Module):
|
||||||
|
""" Vision Transformer
|
||||||
|
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
|
||||||
|
https://arxiv.org/abs/2010.11929
|
||||||
|
"""
|
||||||
|
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
||||||
|
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
|
||||||
|
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
|
||||||
|
use_grad_checkpointing=False, ckpt_layer=0):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
img_size (int, tuple): input image size
|
||||||
|
patch_size (int, tuple): patch size
|
||||||
|
in_chans (int): number of input channels
|
||||||
|
num_classes (int): number of classes for classification head
|
||||||
|
embed_dim (int): embedding dimension
|
||||||
|
depth (int): depth of transformer
|
||||||
|
num_heads (int): number of attention heads
|
||||||
|
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
||||||
|
qkv_bias (bool): enable bias for qkv if True
|
||||||
|
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
|
||||||
|
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
|
||||||
|
drop_rate (float): dropout rate
|
||||||
|
attn_drop_rate (float): attention dropout rate
|
||||||
|
drop_path_rate (float): stochastic depth rate
|
||||||
|
norm_layer: (nn.Module): normalization layer
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
||||||
|
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
||||||
|
|
||||||
|
self.patch_embed = PatchEmbed(
|
||||||
|
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
||||||
|
|
||||||
|
num_patches = self.patch_embed.num_patches
|
||||||
|
|
||||||
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
||||||
|
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
||||||
|
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||||
|
|
||||||
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||||
|
self.blocks = nn.ModuleList([
|
||||||
|
Block(
|
||||||
|
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||||
|
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
||||||
|
use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
|
||||||
|
)
|
||||||
|
for i in range(depth)])
|
||||||
|
self.norm = norm_layer(embed_dim)
|
||||||
|
|
||||||
|
trunc_normal_(self.pos_embed, std=.02)
|
||||||
|
trunc_normal_(self.cls_token, std=.02)
|
||||||
|
self.apply(self._init_weights)
|
||||||
|
|
||||||
|
def _init_weights(self, m):
|
||||||
|
if isinstance(m, nn.Linear):
|
||||||
|
trunc_normal_(m.weight, std=.02)
|
||||||
|
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
elif isinstance(m, nn.LayerNorm):
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
nn.init.constant_(m.weight, 1.0)
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def no_weight_decay(self):
|
||||||
|
return {'pos_embed', 'cls_token'}
|
||||||
|
|
||||||
|
def forward(self, x, register_blk=-1):
|
||||||
|
B = x.shape[0]
|
||||||
|
x = self.patch_embed(x)
|
||||||
|
|
||||||
|
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
||||||
|
x = torch.cat((cls_tokens, x), dim=1)
|
||||||
|
|
||||||
|
x = x + self.pos_embed[:,:x.size(1),:]
|
||||||
|
x = self.pos_drop(x)
|
||||||
|
|
||||||
|
for i,blk in enumerate(self.blocks):
|
||||||
|
x = blk(x, register_blk==i)
|
||||||
|
x = self.norm(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
@torch.jit.ignore()
|
||||||
|
def load_pretrained(self, checkpoint_path, prefix=''):
|
||||||
|
_load_weights(self, checkpoint_path, prefix)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
|
||||||
|
""" Load weights from .npz checkpoints for official Google Brain Flax implementation
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def _n2p(w, t=True):
|
||||||
|
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
|
||||||
|
w = w.flatten()
|
||||||
|
if t:
|
||||||
|
if w.ndim == 4:
|
||||||
|
w = w.transpose([3, 2, 0, 1])
|
||||||
|
elif w.ndim == 3:
|
||||||
|
w = w.transpose([2, 0, 1])
|
||||||
|
elif w.ndim == 2:
|
||||||
|
w = w.transpose([1, 0])
|
||||||
|
return torch.from_numpy(w)
|
||||||
|
|
||||||
|
w = np.load(checkpoint_path)
|
||||||
|
if not prefix and 'opt/target/embedding/kernel' in w:
|
||||||
|
prefix = 'opt/target/'
|
||||||
|
|
||||||
|
if hasattr(model.patch_embed, 'backbone'):
|
||||||
|
# hybrid
|
||||||
|
backbone = model.patch_embed.backbone
|
||||||
|
stem_only = not hasattr(backbone, 'stem')
|
||||||
|
stem = backbone if stem_only else backbone.stem
|
||||||
|
stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
|
||||||
|
stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
|
||||||
|
stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
|
||||||
|
if not stem_only:
|
||||||
|
for i, stage in enumerate(backbone.stages):
|
||||||
|
for j, block in enumerate(stage.blocks):
|
||||||
|
bp = f'{prefix}block{i + 1}/unit{j + 1}/'
|
||||||
|
for r in range(3):
|
||||||
|
getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
|
||||||
|
getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
|
||||||
|
getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
|
||||||
|
if block.downsample is not None:
|
||||||
|
block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
|
||||||
|
block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
|
||||||
|
block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
|
||||||
|
embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
|
||||||
|
else:
|
||||||
|
embed_conv_w = adapt_input_conv(
|
||||||
|
model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
|
||||||
|
model.patch_embed.proj.weight.copy_(embed_conv_w)
|
||||||
|
model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
|
||||||
|
model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
|
||||||
|
pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
|
||||||
|
if pos_embed_w.shape != model.pos_embed.shape:
|
||||||
|
pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
|
||||||
|
pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
|
||||||
|
model.pos_embed.copy_(pos_embed_w)
|
||||||
|
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
|
||||||
|
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
|
||||||
|
# if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
|
||||||
|
# model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
|
||||||
|
# model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
|
||||||
|
# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
|
||||||
|
# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
|
||||||
|
# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
|
||||||
|
for i, block in enumerate(model.blocks.children()):
|
||||||
|
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
|
||||||
|
mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
|
||||||
|
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
|
||||||
|
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
|
||||||
|
block.attn.qkv.weight.copy_(torch.cat([
|
||||||
|
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
|
||||||
|
block.attn.qkv.bias.copy_(torch.cat([
|
||||||
|
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
|
||||||
|
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
|
||||||
|
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
|
||||||
|
for r in range(2):
|
||||||
|
getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
|
||||||
|
getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
|
||||||
|
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
|
||||||
|
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
|
||||||
|
|
||||||
|
|
||||||
|
def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
|
||||||
|
# interpolate position embedding
|
||||||
|
embedding_size = pos_embed_checkpoint.shape[-1]
|
||||||
|
num_patches = visual_encoder.patch_embed.num_patches
|
||||||
|
num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
|
||||||
|
# height (== width) for the checkpoint position embedding
|
||||||
|
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
||||||
|
# height (== width) for the new position embedding
|
||||||
|
new_size = int(num_patches ** 0.5)
|
||||||
|
|
||||||
|
if orig_size!=new_size:
|
||||||
|
# class_token and dist_token are kept unchanged
|
||||||
|
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
||||||
|
# only the position tokens are interpolated
|
||||||
|
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
||||||
|
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
||||||
|
pos_tokens = torch.nn.functional.interpolate(
|
||||||
|
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
||||||
|
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
||||||
|
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
||||||
|
print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
|
||||||
|
|
||||||
|
return new_pos_embed
|
||||||
|
else:
|
||||||
|
return pos_embed_checkpoint
|
||||||
148
diffsynth/extensions/ImageQualityMetric/__init__.py
Normal file
148
diffsynth/extensions/ImageQualityMetric/__init__.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
from modelscope import snapshot_download
|
||||||
|
from typing_extensions import Literal, TypeAlias
|
||||||
|
import os
|
||||||
|
from diffsynth.extensions.ImageQualityMetric.aesthetic import AestheticScore
|
||||||
|
from diffsynth.extensions.ImageQualityMetric.imagereward import ImageRewardScore
|
||||||
|
from diffsynth.extensions.ImageQualityMetric.pickscore import PickScore
|
||||||
|
from diffsynth.extensions.ImageQualityMetric.clip import CLIPScore
|
||||||
|
from diffsynth.extensions.ImageQualityMetric.hps import HPScore_v2
|
||||||
|
from diffsynth.extensions.ImageQualityMetric.mps import MPScore
|
||||||
|
|
||||||
|
|
||||||
|
preference_model_id: TypeAlias = Literal[
|
||||||
|
"ImageReward",
|
||||||
|
"Aesthetic",
|
||||||
|
"PickScore",
|
||||||
|
"CLIP",
|
||||||
|
"HPSv2",
|
||||||
|
"HPSv2.1",
|
||||||
|
"MPS",
|
||||||
|
]
|
||||||
|
model_dict = {
|
||||||
|
"ImageReward": {
|
||||||
|
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||||
|
"allow_file_pattern": [
|
||||||
|
"ImageReward/ImageReward.safetensors",
|
||||||
|
"ImageReward/med_config.json",
|
||||||
|
"bert-base-uncased/config.json",
|
||||||
|
"bert-base-uncased/model.safetensors",
|
||||||
|
"bert-base-uncased/tokenizer.json",
|
||||||
|
"bert-base-uncased/tokenizer_config.json",
|
||||||
|
"bert-base-uncased/vocab.txt",
|
||||||
|
],
|
||||||
|
"load_path": {
|
||||||
|
"imagereward": "ImageReward/ImageReward.safetensors",
|
||||||
|
"med_config": "ImageReward/med_config.json",
|
||||||
|
"bert_model_path": "bert-base-uncased",
|
||||||
|
},
|
||||||
|
"model_class": ImageRewardScore
|
||||||
|
},
|
||||||
|
"Aesthetic": {
|
||||||
|
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||||
|
"allow_file_pattern": [
|
||||||
|
"aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors",
|
||||||
|
"clip-vit-large-patch14/config.json",
|
||||||
|
"clip-vit-large-patch14/merges.txt",
|
||||||
|
"clip-vit-large-patch14/model.safetensors",
|
||||||
|
"clip-vit-large-patch14/preprocessor_config.json",
|
||||||
|
"clip-vit-large-patch14/special_tokens_map.json",
|
||||||
|
"clip-vit-large-patch14/tokenizer.json",
|
||||||
|
"clip-vit-large-patch14/tokenizer_config.json",
|
||||||
|
"clip-vit-large-patch14/vocab.json",
|
||||||
|
],
|
||||||
|
"load_path": {
|
||||||
|
"aesthetic_predictor": "aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors",
|
||||||
|
"clip-large": "clip-vit-large-patch14",
|
||||||
|
},
|
||||||
|
"model_class": AestheticScore
|
||||||
|
},
|
||||||
|
"PickScore": {
|
||||||
|
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||||
|
"allow_file_pattern": [
|
||||||
|
"PickScore_v1/*",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/config.json",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/merges.txt",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/preprocessor_config.json",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/special_tokens_map.json",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer.json",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer_config.json",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/vocab.json",
|
||||||
|
],
|
||||||
|
"load_path": {
|
||||||
|
"pickscore": "PickScore_v1",
|
||||||
|
"clip": "CLIP-ViT-H-14-laion2B-s32B-b79K",
|
||||||
|
},
|
||||||
|
"model_class": PickScore
|
||||||
|
},
|
||||||
|
"CLIP": {
|
||||||
|
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||||
|
"allow_file_pattern": [
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin",
|
||||||
|
"bpe_simple_vocab_16e6.txt.gz",
|
||||||
|
],
|
||||||
|
"load_path": {
|
||||||
|
"open_clip": "CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin",
|
||||||
|
"open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
|
||||||
|
},
|
||||||
|
"model_class": CLIPScore
|
||||||
|
},
|
||||||
|
"HPSv2": {
|
||||||
|
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||||
|
"allow_file_pattern": [
|
||||||
|
"HPS_v2/HPS_v2_compressed.safetensors",
|
||||||
|
"bpe_simple_vocab_16e6.txt.gz",
|
||||||
|
],
|
||||||
|
"load_path": {
|
||||||
|
"hpsv2": "HPS_v2/HPS_v2_compressed.safetensors",
|
||||||
|
"open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
|
||||||
|
},
|
||||||
|
"model_class": HPScore_v2,
|
||||||
|
"extra_kwargs": {"model_version": "v2"}
|
||||||
|
},
|
||||||
|
"HPSv2.1": {
|
||||||
|
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||||
|
"allow_file_pattern": [
|
||||||
|
"HPS_v2/HPS_v2.1_compressed.safetensors",
|
||||||
|
"bpe_simple_vocab_16e6.txt.gz",
|
||||||
|
],
|
||||||
|
"load_path": {
|
||||||
|
"hpsv2.1": "HPS_v2/HPS_v2.1_compressed.safetensors",
|
||||||
|
"open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
|
||||||
|
},
|
||||||
|
"model_class": HPScore_v2,
|
||||||
|
"extra_kwargs": {"model_version": "v21"}
|
||||||
|
},
|
||||||
|
"MPS": {
|
||||||
|
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||||
|
"allow_file_pattern": [
|
||||||
|
"MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/config.json",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/merges.txt",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/preprocessor_config.json",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/special_tokens_map.json",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer.json",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer_config.json",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/vocab.json",
|
||||||
|
],
|
||||||
|
"load_path": {
|
||||||
|
"mps": "MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors",
|
||||||
|
"clip": "CLIP-ViT-H-14-laion2B-s32B-b79K",
|
||||||
|
},
|
||||||
|
"model_class": MPScore
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def download_preference_model(model_name: preference_model_id, cache_dir="models"):
|
||||||
|
metadata = model_dict[model_name]
|
||||||
|
snapshot_download(model_id=metadata["model_id"], allow_file_pattern=metadata["allow_file_pattern"], cache_dir=cache_dir)
|
||||||
|
load_path = metadata["load_path"]
|
||||||
|
load_path = {key: os.path.join(cache_dir, metadata["model_id"], path) for key, path in load_path.items()}
|
||||||
|
return load_path
|
||||||
|
|
||||||
|
|
||||||
|
def load_preference_model(model_name: preference_model_id, device = "cuda", path = None):
|
||||||
|
model_class = model_dict[model_name]["model_class"]
|
||||||
|
extra_kwargs = model_dict[model_name].get("extra_kwargs", {})
|
||||||
|
preference_model = model_class(device=device, path=path, **extra_kwargs)
|
||||||
|
return preference_model
|
||||||
148
diffsynth/extensions/ImageQualityMetric/aesthetic.py
Normal file
148
diffsynth/extensions/ImageQualityMetric/aesthetic.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
from PIL import Image
|
||||||
|
import torch
|
||||||
|
from transformers import AutoProcessor, AutoModel
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
import os
|
||||||
|
from typing import Union, List
|
||||||
|
from .config import MODEL_PATHS
|
||||||
|
|
||||||
|
class MLP(torch.nn.Module):
|
||||||
|
def __init__(self, input_size: int, xcol: str = "emb", ycol: str = "avg_rating"):
|
||||||
|
super().__init__()
|
||||||
|
self.input_size = input_size
|
||||||
|
self.xcol = xcol
|
||||||
|
self.ycol = ycol
|
||||||
|
self.layers = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(self.input_size, 1024),
|
||||||
|
#torch.nn.ReLU(),
|
||||||
|
torch.nn.Dropout(0.2),
|
||||||
|
torch.nn.Linear(1024, 128),
|
||||||
|
#torch.nn.ReLU(),
|
||||||
|
torch.nn.Dropout(0.2),
|
||||||
|
torch.nn.Linear(128, 64),
|
||||||
|
#torch.nn.ReLU(),
|
||||||
|
torch.nn.Dropout(0.1),
|
||||||
|
torch.nn.Linear(64, 16),
|
||||||
|
#torch.nn.ReLU(),
|
||||||
|
torch.nn.Linear(16, 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.layers(x)
|
||||||
|
|
||||||
|
def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
|
||||||
|
x = batch[self.xcol]
|
||||||
|
y = batch[self.ycol].reshape(-1, 1)
|
||||||
|
x_hat = self.layers(x)
|
||||||
|
loss = torch.nn.functional.mse_loss(x_hat, y)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def validation_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
|
||||||
|
x = batch[self.xcol]
|
||||||
|
y = batch[self.ycol].reshape(-1, 1)
|
||||||
|
x_hat = self.layers(x)
|
||||||
|
loss = torch.nn.functional.mse_loss(x_hat, y)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def configure_optimizers(self) -> torch.optim.Optimizer:
|
||||||
|
return torch.optim.Adam(self.parameters(), lr=1e-3)
|
||||||
|
|
||||||
|
|
||||||
|
class AestheticScore(torch.nn.Module):
|
||||||
|
def __init__(self, device: torch.device, path: str = MODEL_PATHS):
|
||||||
|
super().__init__()
|
||||||
|
self.device = device
|
||||||
|
self.aes_model_path = path.get("aesthetic_predictor")
|
||||||
|
# Load the MLP model
|
||||||
|
self.model = MLP(768)
|
||||||
|
try:
|
||||||
|
if self.aes_model_path.endswith(".safetensors"):
|
||||||
|
state_dict = load_file(self.aes_model_path)
|
||||||
|
else:
|
||||||
|
state_dict = torch.load(self.aes_model_path)
|
||||||
|
self.model.load_state_dict(state_dict)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Error loading model weights from {self.aes_model_path}: {e}")
|
||||||
|
|
||||||
|
self.model.to(device)
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
# Load the CLIP model and processor
|
||||||
|
clip_model_name = path.get('clip-large')
|
||||||
|
self.model2 = AutoModel.from_pretrained(clip_model_name).eval().to(device)
|
||||||
|
self.processor = AutoProcessor.from_pretrained(clip_model_name)
|
||||||
|
|
||||||
|
def _calculate_score(self, image: torch.Tensor) -> float:
|
||||||
|
"""Calculate the aesthetic score for a single image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (torch.Tensor): The processed image tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: The aesthetic score.
|
||||||
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
|
# Get image embeddings
|
||||||
|
image_embs = self.model2.get_image_features(image)
|
||||||
|
image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
# Compute score
|
||||||
|
score = self.model(image_embs).cpu().flatten().item()
|
||||||
|
|
||||||
|
return score
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str = "") -> List[float]:
|
||||||
|
"""Score the images based on their aesthetic quality.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[float]: List of scores for the images.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if isinstance(images, (str, Image.Image)):
|
||||||
|
# Single image
|
||||||
|
if isinstance(images, str):
|
||||||
|
pil_image = Image.open(images)
|
||||||
|
else:
|
||||||
|
pil_image = images
|
||||||
|
|
||||||
|
# Prepare image inputs
|
||||||
|
image_inputs = self.processor(
|
||||||
|
images=pil_image,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=77,
|
||||||
|
return_tensors="pt",
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
|
return [self._calculate_score(image_inputs["pixel_values"])]
|
||||||
|
elif isinstance(images, list):
|
||||||
|
# Multiple images
|
||||||
|
scores = []
|
||||||
|
for one_image in images:
|
||||||
|
if isinstance(one_image, str):
|
||||||
|
pil_image = Image.open(one_image)
|
||||||
|
elif isinstance(one_image, Image.Image):
|
||||||
|
pil_image = one_image
|
||||||
|
else:
|
||||||
|
raise TypeError("The type of parameter images is illegal.")
|
||||||
|
|
||||||
|
# Prepare image inputs
|
||||||
|
image_inputs = self.processor(
|
||||||
|
images=pil_image,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=77,
|
||||||
|
return_tensors="pt",
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
|
scores.append(self._calculate_score(image_inputs["pixel_values"]))
|
||||||
|
return scores
|
||||||
|
else:
|
||||||
|
raise TypeError("The type of parameter images is illegal.")
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Error in scoring images: {e}")
|
||||||
97
diffsynth/extensions/ImageQualityMetric/clip.py
Normal file
97
diffsynth/extensions/ImageQualityMetric/clip.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
from typing import List, Union
|
||||||
|
from PIL import Image
|
||||||
|
import torch
|
||||||
|
from .open_clip import create_model_and_transforms, get_tokenizer
|
||||||
|
from .config import MODEL_PATHS
|
||||||
|
|
||||||
|
class CLIPScore(torch.nn.Module):
|
||||||
|
def __init__(self, device: torch.device, path: str = MODEL_PATHS):
|
||||||
|
super().__init__()
|
||||||
|
"""Initialize the CLIPScore with a model and tokenizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device (torch.device): The device to load the model on.
|
||||||
|
"""
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
# Create model and transforms
|
||||||
|
self.model, _, self.preprocess_val = create_model_and_transforms(
|
||||||
|
"ViT-H-14",
|
||||||
|
# "laion2B-s32B-b79K",
|
||||||
|
pretrained=path.get("open_clip"),
|
||||||
|
precision="amp",
|
||||||
|
device=device,
|
||||||
|
jit=False,
|
||||||
|
force_quick_gelu=False,
|
||||||
|
force_custom_text=False,
|
||||||
|
force_patch_dropout=False,
|
||||||
|
force_image_size=None,
|
||||||
|
pretrained_image=False,
|
||||||
|
image_mean=None,
|
||||||
|
image_std=None,
|
||||||
|
light_augmentation=True,
|
||||||
|
aug_cfg={},
|
||||||
|
output_dict=True,
|
||||||
|
with_score_predictor=False,
|
||||||
|
with_region_predictor=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize tokenizer
|
||||||
|
self.tokenizer = get_tokenizer("ViT-H-14", path["open_clip_bpe"])
|
||||||
|
self.model = self.model.to(device)
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
|
||||||
|
"""Calculate the CLIP score for a single image and prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (torch.Tensor): The processed image tensor.
|
||||||
|
prompt (str): The prompt text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: The CLIP score.
|
||||||
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
|
# Process the prompt
|
||||||
|
text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
|
||||||
|
|
||||||
|
# Calculate the CLIP score
|
||||||
|
outputs = self.model(image, text)
|
||||||
|
image_features, text_features = outputs["image_features"], outputs["text_features"]
|
||||||
|
logits_per_image = image_features @ text_features.T
|
||||||
|
clip_score = torch.diagonal(logits_per_image).cpu().numpy()
|
||||||
|
|
||||||
|
return clip_score[0].item()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
||||||
|
"""Score the images based on the prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
||||||
|
prompt (str): The prompt text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[float]: List of CLIP scores for the images.
|
||||||
|
"""
|
||||||
|
if isinstance(images, (str, Image.Image)):
|
||||||
|
# Single image
|
||||||
|
if isinstance(images, str):
|
||||||
|
image = self.preprocess_val(Image.open(images)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||||
|
else:
|
||||||
|
image = self.preprocess_val(images).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||||
|
return [self._calculate_score(image, prompt)]
|
||||||
|
elif isinstance(images, list):
|
||||||
|
# Multiple images
|
||||||
|
scores = []
|
||||||
|
for one_images in images:
|
||||||
|
if isinstance(one_images, str):
|
||||||
|
image = self.preprocess_val(Image.open(one_images)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||||
|
elif isinstance(one_images, Image.Image):
|
||||||
|
image = self.preprocess_val(one_images).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||||
|
else:
|
||||||
|
raise TypeError("The type of parameter images is illegal.")
|
||||||
|
scores.append(self._calculate_score(image, prompt))
|
||||||
|
return scores
|
||||||
|
else:
|
||||||
|
raise TypeError("The type of parameter images is illegal.")
|
||||||
23
diffsynth/extensions/ImageQualityMetric/config.py
Normal file
23
diffsynth/extensions/ImageQualityMetric/config.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
project_root = os.path.abspath(os.path.join(current_dir, '../../../'))
|
||||||
|
model_path = os.path.join(project_root, 'models', 'QualityMetric')
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_path(model_name):
|
||||||
|
return os.path.join(model_path, model_name)
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_PATHS = {
|
||||||
|
"aesthetic_predictor": get_model_path("aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors"),
|
||||||
|
"open_clip": get_model_path("CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin"),
|
||||||
|
"hpsv2": get_model_path("HPS_v2/HPS_v2_compressed.safetensors"),
|
||||||
|
"hpsv2.1": get_model_path("HPS_v2/HPS_v2.1_compressed.safetensors"),
|
||||||
|
"imagereward": get_model_path("ImageReward/ImageReward.safetensors"),
|
||||||
|
"med_config": get_model_path("ImageReward/med_config.json"),
|
||||||
|
"clip": get_model_path("CLIP-ViT-H-14-laion2B-s32B-b79K"),
|
||||||
|
"clip-large": get_model_path("clip-vit-large-patch14"),
|
||||||
|
"mps": get_model_path("MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors"),
|
||||||
|
"pickscore": get_model_path("PickScore_v1")
|
||||||
|
}
|
||||||
118
diffsynth/extensions/ImageQualityMetric/hps.py
Normal file
118
diffsynth/extensions/ImageQualityMetric/hps.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
from typing import List, Union
|
||||||
|
from PIL import Image
|
||||||
|
import torch
|
||||||
|
from .open_clip import create_model_and_transforms, get_tokenizer
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
import os
|
||||||
|
from .config import MODEL_PATHS
|
||||||
|
|
||||||
|
class HPScore_v2(torch.nn.Module):
|
||||||
|
def __init__(self, device: torch.device, path: str = MODEL_PATHS, model_version: str = "v2"):
|
||||||
|
super().__init__()
|
||||||
|
"""Initialize the Selector with a model and tokenizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device (torch.device): The device to load the model on.
|
||||||
|
model_version (str): The version of the model to load. Supports "v2" or "v21". Default is "v2".
|
||||||
|
"""
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
if model_version == "v2":
|
||||||
|
safetensors_path = path.get("hpsv2")
|
||||||
|
elif model_version == "v21":
|
||||||
|
safetensors_path = path.get("hpsv2.1")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported model version: {model_version}. Choose 'v2' or 'v21'.")
|
||||||
|
|
||||||
|
# Create model and transforms
|
||||||
|
model, _, self.preprocess_val = create_model_and_transforms(
|
||||||
|
"ViT-H-14",
|
||||||
|
# "laion2B-s32B-b79K",
|
||||||
|
pretrained=path.get("open_clip"),
|
||||||
|
precision="amp",
|
||||||
|
device=device,
|
||||||
|
jit=False,
|
||||||
|
force_quick_gelu=False,
|
||||||
|
force_custom_text=False,
|
||||||
|
force_patch_dropout=False,
|
||||||
|
force_image_size=None,
|
||||||
|
pretrained_image=False,
|
||||||
|
image_mean=None,
|
||||||
|
image_std=None,
|
||||||
|
light_augmentation=True,
|
||||||
|
aug_cfg={},
|
||||||
|
output_dict=True,
|
||||||
|
with_score_predictor=False,
|
||||||
|
with_region_predictor=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load model weights
|
||||||
|
try:
|
||||||
|
state_dict = load_file(safetensors_path)
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Error loading model weights from {safetensors_path}: {e}")
|
||||||
|
|
||||||
|
# Initialize tokenizer and model
|
||||||
|
self.tokenizer = get_tokenizer("ViT-H-14", path["open_clip_bpe"])
|
||||||
|
model = model.to(device)
|
||||||
|
model.eval()
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
|
||||||
|
"""Calculate the HPS score for a single image and prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (torch.Tensor): The processed image tensor.
|
||||||
|
prompt (str): The prompt text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: The HPS score.
|
||||||
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
|
# Process the prompt
|
||||||
|
text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
|
||||||
|
|
||||||
|
# Calculate the HPS score
|
||||||
|
outputs = self.model(image, text)
|
||||||
|
image_features, text_features = outputs["image_features"], outputs["text_features"]
|
||||||
|
logits_per_image = image_features @ text_features.T
|
||||||
|
hps_score = torch.diagonal(logits_per_image).cpu().numpy()
|
||||||
|
|
||||||
|
return hps_score[0].item()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
||||||
|
"""Score the images based on the prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
||||||
|
prompt (str): The prompt text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[float]: List of HPS scores for the images.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if isinstance(images, (str, Image.Image)):
|
||||||
|
# Single image
|
||||||
|
if isinstance(images, str):
|
||||||
|
image = self.preprocess_val(Image.open(images)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||||
|
else:
|
||||||
|
image = self.preprocess_val(images).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||||
|
return [self._calculate_score(image, prompt)]
|
||||||
|
elif isinstance(images, list):
|
||||||
|
# Multiple images
|
||||||
|
scores = []
|
||||||
|
for one_images in images:
|
||||||
|
if isinstance(one_images, str):
|
||||||
|
image = self.preprocess_val(Image.open(one_images)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||||
|
elif isinstance(one_images, Image.Image):
|
||||||
|
image = self.preprocess_val(one_images).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||||
|
else:
|
||||||
|
raise TypeError("The type of parameter images is illegal.")
|
||||||
|
scores.append(self._calculate_score(image, prompt))
|
||||||
|
return scores
|
||||||
|
else:
|
||||||
|
raise TypeError("The type of parameter images is illegal.")
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Error in scoring images: {e}")
|
||||||
212
diffsynth/extensions/ImageQualityMetric/imagereward.py
Normal file
212
diffsynth/extensions/ImageQualityMetric/imagereward.py
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from typing import List, Union
|
||||||
|
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
||||||
|
from .BLIP.blip_pretrain import BLIP_Pretrain
|
||||||
|
from torchvision.transforms import InterpolationMode
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
from .config import MODEL_PATHS
|
||||||
|
BICUBIC = InterpolationMode.BICUBIC
|
||||||
|
|
||||||
|
def _convert_image_to_rgb(image):
|
||||||
|
return image.convert("RGB")
|
||||||
|
|
||||||
|
def _transform(n_px):
|
||||||
|
return Compose([
|
||||||
|
Resize(n_px, interpolation=BICUBIC),
|
||||||
|
CenterCrop(n_px),
|
||||||
|
_convert_image_to_rgb,
|
||||||
|
ToTensor(),
|
||||||
|
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
||||||
|
])
|
||||||
|
|
||||||
|
class MLP(torch.nn.Module):
|
||||||
|
def __init__(self, input_size):
|
||||||
|
super().__init__()
|
||||||
|
self.input_size = input_size
|
||||||
|
|
||||||
|
self.layers = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(self.input_size, 1024),
|
||||||
|
#nn.ReLU(),
|
||||||
|
torch.nn.Dropout(0.2),
|
||||||
|
torch.nn.Linear(1024, 128),
|
||||||
|
#nn.ReLU(),
|
||||||
|
torch.nn.Dropout(0.2),
|
||||||
|
torch.nn.Linear(128, 64),
|
||||||
|
#nn.ReLU(),
|
||||||
|
torch.nn.Dropout(0.1),
|
||||||
|
torch.nn.Linear(64, 16),
|
||||||
|
#nn.ReLU(),
|
||||||
|
torch.nn.Linear(16, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
# initial MLP param
|
||||||
|
for name, param in self.layers.named_parameters():
|
||||||
|
if 'weight' in name:
|
||||||
|
torch.nn.init.normal_(param, mean=0.0, std=1.0/(self.input_size+1))
|
||||||
|
if 'bias' in name:
|
||||||
|
torch.nn.init.constant_(param, val=0)
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
return self.layers(input)
|
||||||
|
|
||||||
|
class ImageReward(torch.nn.Module):
|
||||||
|
def __init__(self, med_config, device='cpu', bert_model_path=""):
|
||||||
|
super().__init__()
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config, bert_model_path=bert_model_path)
|
||||||
|
self.preprocess = _transform(224)
|
||||||
|
self.mlp = MLP(768)
|
||||||
|
|
||||||
|
self.mean = 0.16717362830052426
|
||||||
|
self.std = 1.0333394966054072
|
||||||
|
|
||||||
|
def score_grad(self, prompt_ids, prompt_attention_mask, image):
|
||||||
|
"""Calculate the score with gradient for a single image and prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt_ids (torch.Tensor): Tokenized prompt IDs.
|
||||||
|
prompt_attention_mask (torch.Tensor): Attention mask for the prompt.
|
||||||
|
image (torch.Tensor): The processed image tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The reward score.
|
||||||
|
"""
|
||||||
|
image_embeds = self.blip.visual_encoder(image)
|
||||||
|
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
|
||||||
|
text_output = self.blip.text_encoder(
|
||||||
|
prompt_ids,
|
||||||
|
attention_mask=prompt_attention_mask,
|
||||||
|
encoder_hidden_states=image_embeds,
|
||||||
|
encoder_attention_mask=image_atts,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
txt_features = text_output.last_hidden_state[:, 0, :]
|
||||||
|
rewards = self.mlp(txt_features)
|
||||||
|
rewards = (rewards - self.mean) / self.std
|
||||||
|
return rewards
|
||||||
|
|
||||||
|
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str = "") -> List[float]:
|
||||||
|
"""Score the images based on the prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): The prompt text.
|
||||||
|
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[float]: List of scores for the images.
|
||||||
|
"""
|
||||||
|
if isinstance(images, (str, Image.Image)):
|
||||||
|
# Single image
|
||||||
|
if isinstance(images, str):
|
||||||
|
pil_image = Image.open(images)
|
||||||
|
else:
|
||||||
|
pil_image = images
|
||||||
|
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
|
||||||
|
return [self._calculate_score(prompt, image).item()]
|
||||||
|
elif isinstance(images, list):
|
||||||
|
# Multiple images
|
||||||
|
scores = []
|
||||||
|
for one_image in images:
|
||||||
|
if isinstance(one_image, str):
|
||||||
|
pil_image = Image.open(one_image)
|
||||||
|
elif isinstance(one_image, Image.Image):
|
||||||
|
pil_image = one_image
|
||||||
|
else:
|
||||||
|
raise TypeError("The type of parameter images is illegal.")
|
||||||
|
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
|
||||||
|
scores.append(self._calculate_score(prompt, image).item())
|
||||||
|
return scores
|
||||||
|
else:
|
||||||
|
raise TypeError("The type of parameter images is illegal.")
|
||||||
|
|
||||||
|
def _calculate_score(self, prompt: str, image: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Calculate the score for a single image and prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): The prompt text.
|
||||||
|
image (torch.Tensor): The processed image tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The reward score.
|
||||||
|
"""
|
||||||
|
text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
|
||||||
|
image_embeds = self.blip.visual_encoder(image)
|
||||||
|
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
|
||||||
|
text_output = self.blip.text_encoder(
|
||||||
|
text_input.input_ids,
|
||||||
|
attention_mask=text_input.attention_mask,
|
||||||
|
encoder_hidden_states=image_embeds,
|
||||||
|
encoder_attention_mask=image_atts,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
txt_features = text_output.last_hidden_state[:, 0, :].float()
|
||||||
|
rewards = self.mlp(txt_features)
|
||||||
|
rewards = (rewards - self.mean) / self.std
|
||||||
|
return rewards
|
||||||
|
|
||||||
|
def inference_rank(self, prompt: str, generations_list: List[Union[str, Image.Image]]) -> tuple:
|
||||||
|
"""Rank the images based on the prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): The prompt text.
|
||||||
|
generations_list (List[Union[str, Image.Image]]): List of image paths or PIL images.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (indices, rewards) where indices are the ranks and rewards are the scores.
|
||||||
|
"""
|
||||||
|
text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
|
||||||
|
txt_set = []
|
||||||
|
for generation in generations_list:
|
||||||
|
if isinstance(generation, str):
|
||||||
|
pil_image = Image.open(generation)
|
||||||
|
elif isinstance(generation, Image.Image):
|
||||||
|
pil_image = generation
|
||||||
|
else:
|
||||||
|
raise TypeError("The type of parameter generations_list is illegal.")
|
||||||
|
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
|
||||||
|
image_embeds = self.blip.visual_encoder(image)
|
||||||
|
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
|
||||||
|
text_output = self.blip.text_encoder(
|
||||||
|
text_input.input_ids,
|
||||||
|
attention_mask=text_input.attention_mask,
|
||||||
|
encoder_hidden_states=image_embeds,
|
||||||
|
encoder_attention_mask=image_atts,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
txt_set.append(text_output.last_hidden_state[:, 0, :])
|
||||||
|
txt_features = torch.cat(txt_set, 0).float()
|
||||||
|
rewards = self.mlp(txt_features)
|
||||||
|
rewards = (rewards - self.mean) / self.std
|
||||||
|
rewards = torch.squeeze(rewards)
|
||||||
|
_, rank = torch.sort(rewards, dim=0, descending=True)
|
||||||
|
_, indices = torch.sort(rank, dim=0)
|
||||||
|
indices = indices + 1
|
||||||
|
return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
|
||||||
|
|
||||||
|
|
||||||
|
class ImageRewardScore(torch.nn.Module):
|
||||||
|
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS):
|
||||||
|
super().__init__()
|
||||||
|
self.device = device if isinstance(device, torch.device) else torch.device(device)
|
||||||
|
model_path = path.get("imagereward")
|
||||||
|
med_config = path.get("med_config")
|
||||||
|
state_dict = load_file(model_path)
|
||||||
|
self.model = ImageReward(device=self.device, med_config=med_config, bert_model_path=path.get("bert_model_path")).to(self.device)
|
||||||
|
self.model.load_state_dict(state_dict, strict=False)
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
||||||
|
"""Score the images based on the prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
||||||
|
prompt (str): The prompt text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[float]: List of scores for the images.
|
||||||
|
"""
|
||||||
|
return self.model.score(images, prompt)
|
||||||
129
diffsynth/extensions/ImageQualityMetric/mps.py
Normal file
129
diffsynth/extensions/ImageQualityMetric/mps.py
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from io import BytesIO
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from transformers import CLIPFeatureExtractor, CLIPImageProcessor
|
||||||
|
from transformers import CLIPConfig
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from transformers import CLIPModel as HFCLIPModel
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
from torch import nn, einsum
|
||||||
|
|
||||||
|
from .trainer.models.base_model import BaseModelConfig
|
||||||
|
|
||||||
|
from transformers import CLIPConfig
|
||||||
|
from transformers import AutoProcessor, AutoModel, AutoTokenizer
|
||||||
|
from typing import Any, Optional, Tuple, Union, List
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .trainer.models.cross_modeling import Cross_model
|
||||||
|
from .trainer.models import clip_model
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import gc
|
||||||
|
import json
|
||||||
|
from .config import MODEL_PATHS
|
||||||
|
|
||||||
|
class MPScore(torch.nn.Module):
|
||||||
|
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS, condition: str = 'overall'):
|
||||||
|
super().__init__()
|
||||||
|
"""Initialize the MPSModel with a processor, tokenizer, and model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device (Union[str, torch.device]): The device to load the model on.
|
||||||
|
"""
|
||||||
|
self.device = device
|
||||||
|
processor_name_or_path = path.get("clip")
|
||||||
|
self.image_processor = CLIPImageProcessor.from_pretrained(processor_name_or_path)
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True)
|
||||||
|
self.model = clip_model.CLIPModel(processor_name_or_path, config_file=True)
|
||||||
|
state_dict = load_file(path.get("mps"))
|
||||||
|
self.model.load_state_dict(state_dict, strict=False)
|
||||||
|
self.model.to(device)
|
||||||
|
self.condition = condition
|
||||||
|
|
||||||
|
def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
|
||||||
|
"""Calculate the reward score for a single image and prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (torch.Tensor): The processed image tensor.
|
||||||
|
prompt (str): The prompt text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: The reward score.
|
||||||
|
"""
|
||||||
|
def _tokenize(caption):
|
||||||
|
input_ids = self.tokenizer(
|
||||||
|
caption,
|
||||||
|
max_length=self.tokenizer.model_max_length,
|
||||||
|
padding="max_length",
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt"
|
||||||
|
).input_ids
|
||||||
|
return input_ids
|
||||||
|
|
||||||
|
text_input = _tokenize(prompt).to(self.device)
|
||||||
|
if self.condition == 'overall':
|
||||||
|
condition_prompt = 'light, color, clarity, tone, style, ambiance, artistry, shape, face, hair, hands, limbs, structure, instance, texture, quantity, attributes, position, number, location, word, things'
|
||||||
|
elif self.condition == 'aesthetics':
|
||||||
|
condition_prompt = 'light, color, clarity, tone, style, ambiance, artistry'
|
||||||
|
elif self.condition == 'quality':
|
||||||
|
condition_prompt = 'shape, face, hair, hands, limbs, structure, instance, texture'
|
||||||
|
elif self.condition == 'semantic':
|
||||||
|
condition_prompt = 'quantity, attributes, position, number, location'
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported condition: {self.condition}. Choose 'overall', 'aesthetics', 'quality', or 'semantic'.")
|
||||||
|
condition_batch = _tokenize(condition_prompt).repeat(text_input.shape[0], 1).to(self.device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
text_f, text_features = self.model.model.get_text_features(text_input)
|
||||||
|
|
||||||
|
image_f = self.model.model.get_image_features(image.half())
|
||||||
|
condition_f, _ = self.model.model.get_text_features(condition_batch)
|
||||||
|
|
||||||
|
sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)
|
||||||
|
sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
|
||||||
|
sim_text_condition = sim_text_condition / sim_text_condition.max()
|
||||||
|
mask = torch.where(sim_text_condition > 0.3, 0, float('-inf'))
|
||||||
|
mask = mask.repeat(1, image_f.shape[1], 1)
|
||||||
|
image_features = self.model.cross_model(image_f, text_f, mask.half())[:, 0, :]
|
||||||
|
|
||||||
|
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
|
||||||
|
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
|
||||||
|
image_score = self.model.logit_scale.exp() * text_features @ image_features.T
|
||||||
|
|
||||||
|
return image_score[0].cpu().numpy().item()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
||||||
|
"""Score the images based on the prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
||||||
|
prompt (str): The prompt text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[float]: List of reward scores for the images.
|
||||||
|
"""
|
||||||
|
if isinstance(images, (str, Image.Image)):
|
||||||
|
# Single image
|
||||||
|
if isinstance(images, str):
|
||||||
|
image = self.image_processor(Image.open(images), return_tensors="pt")["pixel_values"].to(self.device)
|
||||||
|
else:
|
||||||
|
image = self.image_processor(images, return_tensors="pt")["pixel_values"].to(self.device)
|
||||||
|
return [self._calculate_score(image, prompt)]
|
||||||
|
elif isinstance(images, list):
|
||||||
|
# Multiple images
|
||||||
|
scores = []
|
||||||
|
for one_images in images:
|
||||||
|
if isinstance(one_images, str):
|
||||||
|
image = self.image_processor(Image.open(one_images), return_tensors="pt")["pixel_values"].to(self.device)
|
||||||
|
elif isinstance(one_images, Image.Image):
|
||||||
|
image = self.image_processor(one_images, return_tensors="pt")["pixel_values"].to(self.device)
|
||||||
|
else:
|
||||||
|
raise TypeError("The type of parameter images is illegal.")
|
||||||
|
scores.append(self._calculate_score(image, prompt))
|
||||||
|
return scores
|
||||||
|
else:
|
||||||
|
raise TypeError("The type of parameter images is illegal.")
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
from .coca_model import CoCa
|
||||||
|
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
||||||
|
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
|
||||||
|
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
|
||||||
|
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
|
||||||
|
from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
|
||||||
|
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
|
||||||
|
from .openai import load_openai_model, list_openai_models
|
||||||
|
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
|
||||||
|
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
|
||||||
|
from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
|
||||||
|
from .tokenizer import SimpleTokenizer
|
||||||
|
from .transform import image_transform, AugmentationCfg
|
||||||
|
from .utils import freeze_batch_norm_2d
|
||||||
458
diffsynth/extensions/ImageQualityMetric/open_clip/coca_model.py
Normal file
458
diffsynth/extensions/ImageQualityMetric/open_clip/coca_model.py
Normal file
@@ -0,0 +1,458 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
import numpy as np
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from .transformer import (
|
||||||
|
LayerNormFp32,
|
||||||
|
LayerNorm,
|
||||||
|
QuickGELU,
|
||||||
|
MultimodalTransformer,
|
||||||
|
)
|
||||||
|
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
|
||||||
|
|
||||||
|
try:
|
||||||
|
from transformers import (
|
||||||
|
BeamSearchScorer,
|
||||||
|
LogitsProcessorList,
|
||||||
|
TopPLogitsWarper,
|
||||||
|
TopKLogitsWarper,
|
||||||
|
RepetitionPenaltyLogitsProcessor,
|
||||||
|
MinLengthLogitsProcessor,
|
||||||
|
MaxLengthCriteria,
|
||||||
|
StoppingCriteriaList
|
||||||
|
)
|
||||||
|
|
||||||
|
GENERATION_TYPES = {
|
||||||
|
"top_k": TopKLogitsWarper,
|
||||||
|
"top_p": TopPLogitsWarper,
|
||||||
|
"beam_search": "beam_search"
|
||||||
|
}
|
||||||
|
_has_transformers = True
|
||||||
|
except ImportError as e:
|
||||||
|
GENERATION_TYPES = {
|
||||||
|
"top_k": None,
|
||||||
|
"top_p": None,
|
||||||
|
"beam_search": "beam_search"
|
||||||
|
}
|
||||||
|
_has_transformers = False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MultimodalCfg(CLIPTextCfg):
|
||||||
|
mlp_ratio: int = 4
|
||||||
|
dim_head: int = 64
|
||||||
|
heads: int = 8
|
||||||
|
n_queries: int = 256
|
||||||
|
attn_pooler_heads: int = 8
|
||||||
|
|
||||||
|
|
||||||
|
def _build_text_decoder_tower(
|
||||||
|
embed_dim,
|
||||||
|
multimodal_cfg,
|
||||||
|
quick_gelu: bool = False,
|
||||||
|
cast_dtype: Optional[torch.dtype] = None,
|
||||||
|
):
|
||||||
|
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
|
||||||
|
act_layer = QuickGELU if quick_gelu else nn.GELU
|
||||||
|
norm_layer = (
|
||||||
|
LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder = MultimodalTransformer(
|
||||||
|
context_length=multimodal_cfg.context_length,
|
||||||
|
width=multimodal_cfg.width,
|
||||||
|
heads=multimodal_cfg.heads,
|
||||||
|
layers=multimodal_cfg.layers,
|
||||||
|
ls_init_value=multimodal_cfg.ls_init_value,
|
||||||
|
output_dim=embed_dim,
|
||||||
|
act_layer=act_layer,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
)
|
||||||
|
|
||||||
|
return decoder
|
||||||
|
|
||||||
|
|
||||||
|
class CoCa(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embed_dim,
|
||||||
|
multimodal_cfg: MultimodalCfg,
|
||||||
|
text_cfg: CLIPTextCfg,
|
||||||
|
vision_cfg: CLIPVisionCfg,
|
||||||
|
quick_gelu: bool = False,
|
||||||
|
cast_dtype: Optional[torch.dtype] = None,
|
||||||
|
pad_id: int = 0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
|
||||||
|
text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
|
||||||
|
vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg
|
||||||
|
|
||||||
|
self.text = _build_text_tower(
|
||||||
|
embed_dim=embed_dim,
|
||||||
|
text_cfg=text_cfg,
|
||||||
|
quick_gelu=quick_gelu,
|
||||||
|
cast_dtype=cast_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
vocab_size = (
|
||||||
|
text_cfg.vocab_size # for hf models
|
||||||
|
if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
|
||||||
|
else text_cfg.vocab_size
|
||||||
|
)
|
||||||
|
|
||||||
|
self.visual = _build_vision_tower(
|
||||||
|
embed_dim=embed_dim,
|
||||||
|
vision_cfg=vision_cfg,
|
||||||
|
quick_gelu=quick_gelu,
|
||||||
|
cast_dtype=cast_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.text_decoder = _build_text_decoder_tower(
|
||||||
|
vocab_size,
|
||||||
|
multimodal_cfg=multimodal_cfg,
|
||||||
|
quick_gelu=quick_gelu,
|
||||||
|
cast_dtype=cast_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
||||||
|
self.pad_id = pad_id
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def set_grad_checkpointing(self, enable=True):
|
||||||
|
self.visual.set_grad_checkpointing(enable)
|
||||||
|
self.text.set_grad_checkpointing(enable)
|
||||||
|
self.text_decoder.set_grad_checkpointing(enable)
|
||||||
|
|
||||||
|
def _encode_image(self, images, normalize=True):
|
||||||
|
image_latent, tokens_embs = self.visual(images)
|
||||||
|
image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
|
||||||
|
return image_latent, tokens_embs
|
||||||
|
|
||||||
|
def _encode_text(self, text, normalize=True, embed_cls=True):
|
||||||
|
text = text[:, :-1] if embed_cls else text # make space for CLS token
|
||||||
|
text_latent, token_emb = self.text(text)
|
||||||
|
text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
|
||||||
|
return text_latent, token_emb
|
||||||
|
|
||||||
|
def encode_image(self, images, normalize=True):
|
||||||
|
image_latent, _ = self._encode_image(images, normalize=normalize)
|
||||||
|
return image_latent
|
||||||
|
|
||||||
|
def encode_text(self, text, normalize=True, embed_cls=True):
|
||||||
|
text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
|
||||||
|
return text_latent
|
||||||
|
|
||||||
|
def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):
|
||||||
|
text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
|
||||||
|
if image_latent is None or image_embs is None:
|
||||||
|
image_latent, image_embs = self._encode_image(image)
|
||||||
|
|
||||||
|
# TODO: add assertion to avoid bugs?
|
||||||
|
labels = text[:, -token_embs.shape[1]:]
|
||||||
|
|
||||||
|
logits = self.text_decoder(image_embs, token_embs)
|
||||||
|
return {
|
||||||
|
"image_features": image_latent,
|
||||||
|
"text_features": text_latent,
|
||||||
|
"logits": logits,
|
||||||
|
"labels": labels,
|
||||||
|
"logit_scale": self.logit_scale.exp()
|
||||||
|
}
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
image,
|
||||||
|
text=None,
|
||||||
|
seq_len=30,
|
||||||
|
max_seq_len=77,
|
||||||
|
temperature=1.,
|
||||||
|
generation_type="beam_search",
|
||||||
|
top_p=0.1, # keep tokens in the 1 - top_p quantile
|
||||||
|
top_k=1, # keeps the top_k most probable tokens
|
||||||
|
pad_token_id=None,
|
||||||
|
eos_token_id=None,
|
||||||
|
sot_token_id=None,
|
||||||
|
num_beams=6,
|
||||||
|
num_beam_groups=3,
|
||||||
|
min_seq_len=5,
|
||||||
|
stopping_criteria=None,
|
||||||
|
repetition_penalty=1.0,
|
||||||
|
fixed_output_length=False # if True output.shape == (batch_size, seq_len)
|
||||||
|
):
|
||||||
|
# taking many ideas and components from HuggingFace GenerationMixin
|
||||||
|
# https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
|
||||||
|
assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
|
||||||
|
assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
sot_token_id = 49406 if sot_token_id is None else sot_token_id
|
||||||
|
eos_token_id = 49407 if eos_token_id is None else eos_token_id
|
||||||
|
pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
|
||||||
|
logit_processor = LogitsProcessorList(
|
||||||
|
[
|
||||||
|
MinLengthLogitsProcessor(min_seq_len, eos_token_id),
|
||||||
|
RepetitionPenaltyLogitsProcessor(repetition_penalty),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
if stopping_criteria is None:
|
||||||
|
stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
|
||||||
|
|
||||||
|
stopping_criteria = StoppingCriteriaList(
|
||||||
|
stopping_criteria
|
||||||
|
)
|
||||||
|
|
||||||
|
device = image.device
|
||||||
|
|
||||||
|
if generation_type == "beam_search":
|
||||||
|
output = self._generate_beamsearch(
|
||||||
|
image_inputs = image,
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
sot_token_id=sot_token_id,
|
||||||
|
num_beams=num_beams,
|
||||||
|
num_beam_groups=num_beam_groups,
|
||||||
|
min_seq_len=min_seq_len,
|
||||||
|
stopping_criteria=stopping_criteria,
|
||||||
|
logit_processor=logit_processor,
|
||||||
|
)
|
||||||
|
if fixed_output_length and output.shape[1] < seq_len:
|
||||||
|
return torch.cat(
|
||||||
|
(output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id),
|
||||||
|
dim=1
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
elif generation_type == "top_p":
|
||||||
|
logit_warper = GENERATION_TYPES[generation_type](top_p)
|
||||||
|
elif generation_type == "top_k":
|
||||||
|
logit_warper = GENERATION_TYPES[generation_type](top_k)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"generation_type has to be one of "
|
||||||
|
f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
|
||||||
|
)
|
||||||
|
|
||||||
|
image_latent, image_embs = self._encode_image(image)
|
||||||
|
|
||||||
|
if text is None:
|
||||||
|
text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id
|
||||||
|
|
||||||
|
was_training = self.training
|
||||||
|
num_dims = len(text.shape)
|
||||||
|
|
||||||
|
if num_dims == 1:
|
||||||
|
text = text[None, :]
|
||||||
|
|
||||||
|
cur_len = text.shape[1]
|
||||||
|
self.eval()
|
||||||
|
out = text
|
||||||
|
|
||||||
|
while True:
|
||||||
|
x = out[:, -max_seq_len:]
|
||||||
|
cur_len = x.shape[1]
|
||||||
|
logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1]
|
||||||
|
mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
|
||||||
|
sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id
|
||||||
|
|
||||||
|
if mask.all():
|
||||||
|
if not fixed_output_length:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
logits = logits[~mask, :]
|
||||||
|
filtered_logits = logit_processor(x[~mask, :], logits)
|
||||||
|
filtered_logits = logit_warper(x[~mask, :], filtered_logits)
|
||||||
|
probs = F.softmax(filtered_logits / temperature, dim=-1)
|
||||||
|
|
||||||
|
if (cur_len + 1 == seq_len):
|
||||||
|
sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
|
||||||
|
else:
|
||||||
|
sample[~mask, :] = torch.multinomial(probs, 1)
|
||||||
|
|
||||||
|
out = torch.cat((out, sample), dim=-1)
|
||||||
|
|
||||||
|
cur_len += 1
|
||||||
|
|
||||||
|
if stopping_criteria(out, None):
|
||||||
|
break
|
||||||
|
|
||||||
|
if num_dims == 1:
|
||||||
|
out = out.squeeze(0)
|
||||||
|
|
||||||
|
self.train(was_training)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def _generate_beamsearch(
|
||||||
|
self,
|
||||||
|
image_inputs,
|
||||||
|
pad_token_id=None,
|
||||||
|
eos_token_id=None,
|
||||||
|
sot_token_id=None,
|
||||||
|
num_beams=6,
|
||||||
|
num_beam_groups=3,
|
||||||
|
min_seq_len=5,
|
||||||
|
stopping_criteria=None,
|
||||||
|
logit_processor=None,
|
||||||
|
logit_warper=None,
|
||||||
|
):
|
||||||
|
device = image_inputs.device
|
||||||
|
batch_size = image_inputs.shape[0]
|
||||||
|
image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
|
||||||
|
image_latent, image_embs = self._encode_image(image_inputs)
|
||||||
|
|
||||||
|
input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
|
||||||
|
input_ids = input_ids * sot_token_id
|
||||||
|
beam_scorer = BeamSearchScorer(
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_beams=num_beams,
|
||||||
|
device=device,
|
||||||
|
num_beam_groups=num_beam_groups,
|
||||||
|
)
|
||||||
|
# instantiate logits processors
|
||||||
|
logits_processor = (
|
||||||
|
LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
|
||||||
|
if logit_processor is None
|
||||||
|
else logit_processor
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_size = len(beam_scorer._beam_hyps)
|
||||||
|
num_beams = beam_scorer.num_beams
|
||||||
|
num_beam_groups = beam_scorer.num_beam_groups
|
||||||
|
num_sub_beams = num_beams // num_beam_groups
|
||||||
|
batch_beam_size, cur_len = input_ids.shape
|
||||||
|
beam_indices = None
|
||||||
|
|
||||||
|
if num_beams * batch_size != batch_beam_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
||||||
|
)
|
||||||
|
|
||||||
|
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
|
||||||
|
# initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in
|
||||||
|
# the same group don't produce same tokens everytime.
|
||||||
|
beam_scores[:, ::num_sub_beams] = 0
|
||||||
|
beam_scores = beam_scores.view((batch_size * num_beams,))
|
||||||
|
|
||||||
|
while True:
|
||||||
|
|
||||||
|
# predicted tokens in cur_len step
|
||||||
|
current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
|
||||||
|
|
||||||
|
# indices which will form the beams in the next time step
|
||||||
|
reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
# do one decoder step on all beams of all sentences in batch
|
||||||
|
model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
|
||||||
|
outputs = self(
|
||||||
|
model_inputs['images'],
|
||||||
|
model_inputs['text'],
|
||||||
|
embed_cls=False,
|
||||||
|
image_latent=image_latent,
|
||||||
|
image_embs=image_embs
|
||||||
|
)
|
||||||
|
|
||||||
|
for beam_group_idx in range(num_beam_groups):
|
||||||
|
group_start_idx = beam_group_idx * num_sub_beams
|
||||||
|
group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
|
||||||
|
group_size = group_end_idx - group_start_idx
|
||||||
|
|
||||||
|
# indices of beams of current group among all sentences in batch
|
||||||
|
batch_group_indices = []
|
||||||
|
|
||||||
|
for batch_idx in range(batch_size):
|
||||||
|
batch_group_indices.extend(
|
||||||
|
[batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
|
||||||
|
)
|
||||||
|
group_input_ids = input_ids[batch_group_indices]
|
||||||
|
|
||||||
|
# select outputs of beams of currentg group only
|
||||||
|
next_token_logits = outputs['logits'][batch_group_indices, -1, :]
|
||||||
|
vocab_size = next_token_logits.shape[-1]
|
||||||
|
|
||||||
|
next_token_scores_processed = logits_processor(
|
||||||
|
group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
|
||||||
|
)
|
||||||
|
next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
|
||||||
|
next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
|
||||||
|
|
||||||
|
# reshape for beam search
|
||||||
|
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
|
||||||
|
|
||||||
|
next_token_scores, next_tokens = torch.topk(
|
||||||
|
next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
|
||||||
|
)
|
||||||
|
|
||||||
|
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
|
||||||
|
next_tokens = next_tokens % vocab_size
|
||||||
|
|
||||||
|
# stateless
|
||||||
|
process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
|
||||||
|
beam_outputs = beam_scorer.process(
|
||||||
|
group_input_ids,
|
||||||
|
next_token_scores,
|
||||||
|
next_tokens,
|
||||||
|
next_indices,
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
beam_indices=process_beam_indices,
|
||||||
|
)
|
||||||
|
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
|
||||||
|
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
||||||
|
beam_idx = beam_outputs["next_beam_indices"]
|
||||||
|
|
||||||
|
input_ids[batch_group_indices] = group_input_ids[beam_idx]
|
||||||
|
group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
|
||||||
|
current_tokens[batch_group_indices] = group_input_ids[:, -1]
|
||||||
|
|
||||||
|
# (beam_idx // group_size) -> batch_idx
|
||||||
|
# (beam_idx % group_size) -> offset of idx inside the group
|
||||||
|
reordering_indices[batch_group_indices] = (
|
||||||
|
num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
|
||||||
|
|
||||||
|
# increase cur_len
|
||||||
|
cur_len = cur_len + 1
|
||||||
|
if beam_scorer.is_done or stopping_criteria(input_ids, None):
|
||||||
|
break
|
||||||
|
|
||||||
|
final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
|
||||||
|
sequence_outputs = beam_scorer.finalize(
|
||||||
|
input_ids,
|
||||||
|
beam_scores,
|
||||||
|
next_tokens,
|
||||||
|
next_indices,
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
max_length=stopping_criteria.max_length,
|
||||||
|
beam_indices=final_beam_indices,
|
||||||
|
)
|
||||||
|
return sequence_outputs['sequences']
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
|
||||||
|
if past:
|
||||||
|
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||||
|
|
||||||
|
attention_mask = kwargs.get("attention_mask", None)
|
||||||
|
position_ids = kwargs.get("position_ids", None)
|
||||||
|
|
||||||
|
if attention_mask is not None and position_ids is None:
|
||||||
|
# create position_ids on the fly for batch generation
|
||||||
|
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||||
|
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||||
|
else:
|
||||||
|
position_ids = None
|
||||||
|
return {
|
||||||
|
"text": input_ids,
|
||||||
|
"images": image_inputs,
|
||||||
|
"past_key_values": past,
|
||||||
|
"position_ids": position_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
}
|
||||||
@@ -0,0 +1,2 @@
|
|||||||
|
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
||||||
|
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
|
||||||
433
diffsynth/extensions/ImageQualityMetric/open_clip/factory.py
Normal file
433
diffsynth/extensions/ImageQualityMetric/open_clip/factory.py
Normal file
@@ -0,0 +1,433 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import re
|
||||||
|
from copy import deepcopy
|
||||||
|
from pathlib import Path
|
||||||
|
# from turtle import forward
|
||||||
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
||||||
|
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
|
||||||
|
resize_pos_embed, get_cast_dtype
|
||||||
|
from .coca_model import CoCa
|
||||||
|
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
|
||||||
|
from .openai import load_openai_model
|
||||||
|
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf
|
||||||
|
from .transform import image_transform, AugmentationCfg
|
||||||
|
from .tokenizer import HFTokenizer, SimpleTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
HF_HUB_PREFIX = 'hf-hub:'
|
||||||
|
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
|
||||||
|
_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
|
||||||
|
|
||||||
|
|
||||||
|
def _natural_key(string_):
|
||||||
|
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
|
||||||
|
|
||||||
|
|
||||||
|
def _rescan_model_configs():
|
||||||
|
global _MODEL_CONFIGS
|
||||||
|
|
||||||
|
config_ext = ('.json',)
|
||||||
|
config_files = []
|
||||||
|
for config_path in _MODEL_CONFIG_PATHS:
|
||||||
|
if config_path.is_file() and config_path.suffix in config_ext:
|
||||||
|
config_files.append(config_path)
|
||||||
|
elif config_path.is_dir():
|
||||||
|
for ext in config_ext:
|
||||||
|
config_files.extend(config_path.glob(f'*{ext}'))
|
||||||
|
|
||||||
|
for cf in config_files:
|
||||||
|
with open(cf, 'r') as f:
|
||||||
|
model_cfg = json.load(f)
|
||||||
|
if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
|
||||||
|
_MODEL_CONFIGS[cf.stem] = model_cfg
|
||||||
|
|
||||||
|
_MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
|
||||||
|
|
||||||
|
|
||||||
|
_rescan_model_configs() # initial populate of model config registry
|
||||||
|
|
||||||
|
|
||||||
|
def list_models():
|
||||||
|
""" enumerate available model architectures based on config files """
|
||||||
|
return list(_MODEL_CONFIGS.keys())
|
||||||
|
|
||||||
|
|
||||||
|
def add_model_config(path):
|
||||||
|
""" add model config path or file and update registry """
|
||||||
|
if not isinstance(path, Path):
|
||||||
|
path = Path(path)
|
||||||
|
_MODEL_CONFIG_PATHS.append(path)
|
||||||
|
_rescan_model_configs()
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_config(model_name):
|
||||||
|
if model_name in _MODEL_CONFIGS:
|
||||||
|
return deepcopy(_MODEL_CONFIGS[model_name])
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_tokenizer(model_name, open_clip_bpe_path=None):
|
||||||
|
if model_name.startswith(HF_HUB_PREFIX):
|
||||||
|
tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):])
|
||||||
|
else:
|
||||||
|
config = get_model_config(model_name)
|
||||||
|
tokenizer = HFTokenizer(
|
||||||
|
config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else SimpleTokenizer(open_clip_bpe_path)
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def load_state_dict(checkpoint_path: str, map_location='cpu'):
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||||
|
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
||||||
|
state_dict = checkpoint['state_dict']
|
||||||
|
else:
|
||||||
|
state_dict = checkpoint
|
||||||
|
if next(iter(state_dict.items()))[0].startswith('module'):
|
||||||
|
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def load_checkpoint(model, checkpoint_path, strict=True):
|
||||||
|
state_dict = load_state_dict(checkpoint_path)
|
||||||
|
# detect old format and make compatible with new format
|
||||||
|
if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
|
||||||
|
state_dict = convert_to_custom_text_state_dict(state_dict)
|
||||||
|
resize_pos_embed(state_dict, model)
|
||||||
|
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
|
||||||
|
return incompatible_keys
|
||||||
|
|
||||||
|
|
||||||
|
def create_model(
|
||||||
|
model_name: str,
|
||||||
|
pretrained: Optional[str] = None,
|
||||||
|
precision: str = 'fp32',
|
||||||
|
device: Union[str, torch.device] = 'cpu',
|
||||||
|
jit: bool = False,
|
||||||
|
force_quick_gelu: bool = False,
|
||||||
|
force_custom_text: bool = False,
|
||||||
|
force_patch_dropout: Optional[float] = None,
|
||||||
|
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
|
||||||
|
pretrained_image: bool = False,
|
||||||
|
pretrained_hf: bool = True,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
output_dict: Optional[bool] = None,
|
||||||
|
require_pretrained: bool = False,
|
||||||
|
):
|
||||||
|
has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
|
||||||
|
if has_hf_hub_prefix:
|
||||||
|
model_id = model_name[len(HF_HUB_PREFIX):]
|
||||||
|
checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
|
||||||
|
config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)
|
||||||
|
|
||||||
|
with open(config_path, 'r', encoding='utf-8') as f:
|
||||||
|
config = json.load(f)
|
||||||
|
pretrained_cfg = config['preprocess_cfg']
|
||||||
|
model_cfg = config['model_cfg']
|
||||||
|
else:
|
||||||
|
model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
|
||||||
|
checkpoint_path = None
|
||||||
|
pretrained_cfg = {}
|
||||||
|
model_cfg = None
|
||||||
|
|
||||||
|
if isinstance(device, str):
|
||||||
|
device = torch.device(device)
|
||||||
|
|
||||||
|
if pretrained and pretrained.lower() == 'openai':
|
||||||
|
logging.info(f'Loading pretrained {model_name} from OpenAI.')
|
||||||
|
model = load_openai_model(
|
||||||
|
model_name,
|
||||||
|
precision=precision,
|
||||||
|
device=device,
|
||||||
|
jit=jit,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
# to always output dict even if it is clip
|
||||||
|
if output_dict and hasattr(model, "output_dict"):
|
||||||
|
model.output_dict = True
|
||||||
|
else:
|
||||||
|
model_cfg = model_cfg or get_model_config(model_name)
|
||||||
|
if model_cfg is not None:
|
||||||
|
logging.info(f'Loaded {model_name} model config.')
|
||||||
|
else:
|
||||||
|
logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
|
||||||
|
raise RuntimeError(f'Model config for {model_name} not found.')
|
||||||
|
|
||||||
|
if force_quick_gelu:
|
||||||
|
# override for use of QuickGELU on non-OpenAI transformer models
|
||||||
|
model_cfg["quick_gelu"] = True
|
||||||
|
|
||||||
|
if force_patch_dropout is not None:
|
||||||
|
# override the default patch dropout value
|
||||||
|
model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
|
||||||
|
|
||||||
|
if force_image_size is not None:
|
||||||
|
# override model config's image size
|
||||||
|
model_cfg["vision_cfg"]["image_size"] = force_image_size
|
||||||
|
|
||||||
|
if pretrained_image:
|
||||||
|
if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
|
||||||
|
# pretrained weight loading for timm models set via vision_cfg
|
||||||
|
model_cfg['vision_cfg']['timm_model_pretrained'] = True
|
||||||
|
else:
|
||||||
|
assert False, 'pretrained image towers currently only supported for timm models'
|
||||||
|
|
||||||
|
cast_dtype = get_cast_dtype(precision)
|
||||||
|
is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
|
||||||
|
custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model
|
||||||
|
|
||||||
|
if custom_text:
|
||||||
|
if is_hf_model:
|
||||||
|
model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
|
||||||
|
if "coca" in model_name:
|
||||||
|
model = CoCa(**model_cfg, cast_dtype=cast_dtype)
|
||||||
|
else:
|
||||||
|
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
|
||||||
|
else:
|
||||||
|
model = CLIP(**model_cfg, cast_dtype=cast_dtype)
|
||||||
|
|
||||||
|
pretrained_loaded = False
|
||||||
|
if pretrained:
|
||||||
|
checkpoint_path = ''
|
||||||
|
pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
|
||||||
|
if pretrained_cfg:
|
||||||
|
checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
|
||||||
|
elif os.path.exists(pretrained):
|
||||||
|
checkpoint_path = pretrained
|
||||||
|
|
||||||
|
if checkpoint_path:
|
||||||
|
logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
|
||||||
|
load_checkpoint(model, checkpoint_path)
|
||||||
|
else:
|
||||||
|
error_str = (
|
||||||
|
f'Pretrained weights ({pretrained}) not found for model {model_name}.'
|
||||||
|
f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
|
||||||
|
logging.warning(error_str)
|
||||||
|
raise RuntimeError(error_str)
|
||||||
|
pretrained_loaded = True
|
||||||
|
elif has_hf_hub_prefix:
|
||||||
|
logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
|
||||||
|
load_checkpoint(model, checkpoint_path)
|
||||||
|
pretrained_loaded = True
|
||||||
|
|
||||||
|
if require_pretrained and not pretrained_loaded:
|
||||||
|
# callers of create_model_from_pretrained always expect pretrained weights
|
||||||
|
raise RuntimeError(
|
||||||
|
f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')
|
||||||
|
|
||||||
|
model.to(device=device)
|
||||||
|
if precision in ("fp16", "bf16"):
|
||||||
|
convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)
|
||||||
|
|
||||||
|
# set image / mean metadata from pretrained_cfg if available, or use default
|
||||||
|
model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
|
||||||
|
model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
|
||||||
|
|
||||||
|
# to always output dict even if it is clip
|
||||||
|
if output_dict and hasattr(model, "output_dict"):
|
||||||
|
model.output_dict = True
|
||||||
|
|
||||||
|
if jit:
|
||||||
|
model = torch.jit.script(model)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def create_loss(args):
|
||||||
|
if args.distill:
|
||||||
|
return DistillClipLoss(
|
||||||
|
local_loss=args.local_loss,
|
||||||
|
gather_with_grad=args.gather_with_grad,
|
||||||
|
cache_labels=True,
|
||||||
|
rank=args.rank,
|
||||||
|
world_size=args.world_size,
|
||||||
|
use_horovod=args.horovod,
|
||||||
|
)
|
||||||
|
elif "coca" in args.model.lower():
|
||||||
|
return CoCaLoss(
|
||||||
|
caption_loss_weight=args.coca_caption_loss_weight,
|
||||||
|
clip_loss_weight=args.coca_contrastive_loss_weight,
|
||||||
|
local_loss=args.local_loss,
|
||||||
|
gather_with_grad=args.gather_with_grad,
|
||||||
|
cache_labels=True,
|
||||||
|
rank=args.rank,
|
||||||
|
world_size=args.world_size,
|
||||||
|
use_horovod=args.horovod,
|
||||||
|
)
|
||||||
|
return ClipLoss(
|
||||||
|
local_loss=args.local_loss,
|
||||||
|
gather_with_grad=args.gather_with_grad,
|
||||||
|
cache_labels=True,
|
||||||
|
rank=args.rank,
|
||||||
|
world_size=args.world_size,
|
||||||
|
use_horovod=args.horovod,
|
||||||
|
)
|
||||||
|
|
||||||
|
class MLP(torch.nn.Module):
|
||||||
|
def __init__(self, input_size):
|
||||||
|
super().__init__()
|
||||||
|
self.input_size = input_size
|
||||||
|
self.layers = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(self.input_size, 1024),
|
||||||
|
torch.nn.Dropout(0.2),
|
||||||
|
torch.nn.Linear(1024, 128),
|
||||||
|
torch.nn.Dropout(0.2),
|
||||||
|
torch.nn.Linear(128, 64),
|
||||||
|
torch.nn.Dropout(0.1),
|
||||||
|
torch.nn.Linear(64, 16),
|
||||||
|
torch.nn.Linear(16, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.layers(x)
|
||||||
|
|
||||||
|
# class semantic_head(torch.nn.Module):
|
||||||
|
# def __init__(self, input_size):
|
||||||
|
# super().__init__()
|
||||||
|
# self.input_size = input_size # for ViT-L-14 is 1024
|
||||||
|
# self.seg_head = torch.nn.Sequential(
|
||||||
|
# torch.nn.Linear(input_size, 128),
|
||||||
|
# torch.nn.Dropout(0.2),
|
||||||
|
# torch.nn.Linear(128, 64),
|
||||||
|
# torch.nn.Dropout(0.1),
|
||||||
|
# torch.nn.Linear(64, 16),
|
||||||
|
# torch.nn.Linear(16, 1),
|
||||||
|
# )
|
||||||
|
# self.sigmoid = torch.nn.Sigmoid()
|
||||||
|
|
||||||
|
# def forward(self, x):
|
||||||
|
# return self.sigmoid(self.seg_head(x))
|
||||||
|
|
||||||
|
def create_model_and_transforms(
|
||||||
|
model_name: str,
|
||||||
|
pretrained: Optional[str] = None,
|
||||||
|
precision: str = 'fp32',
|
||||||
|
device: Union[str, torch.device] = 'cpu',
|
||||||
|
jit: bool = False,
|
||||||
|
force_quick_gelu: bool = False,
|
||||||
|
force_custom_text: bool = False,
|
||||||
|
force_patch_dropout: Optional[float] = None,
|
||||||
|
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
|
||||||
|
pretrained_image: bool = False,
|
||||||
|
pretrained_hf: bool = True,
|
||||||
|
image_mean: Optional[Tuple[float, ...]] = None,
|
||||||
|
image_std: Optional[Tuple[float, ...]] = None,
|
||||||
|
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
light_augmentation = False,
|
||||||
|
output_dict: Optional[bool] = None,
|
||||||
|
with_score_predictor: bool = False,
|
||||||
|
with_region_predictor: bool = False
|
||||||
|
):
|
||||||
|
model = create_model(
|
||||||
|
model_name,
|
||||||
|
pretrained,
|
||||||
|
precision=precision,
|
||||||
|
device=device,
|
||||||
|
jit=jit,
|
||||||
|
force_quick_gelu=force_quick_gelu,
|
||||||
|
force_custom_text=force_custom_text,
|
||||||
|
force_patch_dropout=force_patch_dropout,
|
||||||
|
force_image_size=force_image_size,
|
||||||
|
pretrained_image=pretrained_image,
|
||||||
|
pretrained_hf=pretrained_hf,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
output_dict=output_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
|
||||||
|
image_std = image_std or getattr(model.visual, 'image_std', None)
|
||||||
|
|
||||||
|
if with_score_predictor:
|
||||||
|
model.score_predictor = MLP(model.visual.proj.size(1)).to(device=device, dtype=model.visual.proj.dtype)
|
||||||
|
|
||||||
|
if with_region_predictor:
|
||||||
|
# model.region_predictor = semantic_head(model.visual.proj.size(1)).to(device=device, dtype=model.visual.proj.dtype)
|
||||||
|
model.region_predictor = torch.nn.Linear(model.visual.proj.size(0), 1).to(device=device, dtype=model.visual.proj.dtype)
|
||||||
|
# preprocess_train = image_transform_region(
|
||||||
|
# model.visual.image_size,
|
||||||
|
# is_train=True,
|
||||||
|
# mean=image_mean,
|
||||||
|
# std=image_std
|
||||||
|
# )
|
||||||
|
# preprocess_val = image_transform_region(
|
||||||
|
# model.visual.image_size,
|
||||||
|
# is_train=False,
|
||||||
|
# mean=image_mean,
|
||||||
|
# std=image_std
|
||||||
|
# )
|
||||||
|
|
||||||
|
if light_augmentation:
|
||||||
|
preprocess_val = image_transform(
|
||||||
|
model.visual.image_size,
|
||||||
|
is_train=False,
|
||||||
|
mean=image_mean,
|
||||||
|
std=image_std,
|
||||||
|
resize_longest_max=True,
|
||||||
|
)
|
||||||
|
preprocess_train = preprocess_val
|
||||||
|
else:
|
||||||
|
preprocess_train = image_transform(
|
||||||
|
model.visual.image_size,
|
||||||
|
is_train=True,
|
||||||
|
mean=image_mean,
|
||||||
|
std=image_std
|
||||||
|
)
|
||||||
|
preprocess_val = image_transform(
|
||||||
|
model.visual.image_size,
|
||||||
|
is_train=False,
|
||||||
|
mean=image_mean,
|
||||||
|
std=image_std
|
||||||
|
)
|
||||||
|
|
||||||
|
return model, preprocess_train, preprocess_val
|
||||||
|
|
||||||
|
|
||||||
|
def create_model_from_pretrained(
|
||||||
|
model_name: str,
|
||||||
|
pretrained: Optional[str] = None,
|
||||||
|
precision: str = 'fp32',
|
||||||
|
device: Union[str, torch.device] = 'cpu',
|
||||||
|
jit: bool = False,
|
||||||
|
force_quick_gelu: bool = False,
|
||||||
|
force_custom_text: bool = False,
|
||||||
|
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
|
||||||
|
return_transform: bool = True,
|
||||||
|
image_mean: Optional[Tuple[float, ...]] = None,
|
||||||
|
image_std: Optional[Tuple[float, ...]] = None,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
):
|
||||||
|
model = create_model(
|
||||||
|
model_name,
|
||||||
|
pretrained,
|
||||||
|
precision=precision,
|
||||||
|
device=device,
|
||||||
|
jit=jit,
|
||||||
|
force_quick_gelu=force_quick_gelu,
|
||||||
|
force_custom_text=force_custom_text,
|
||||||
|
force_image_size=force_image_size,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
require_pretrained=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not return_transform:
|
||||||
|
return model
|
||||||
|
|
||||||
|
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
|
||||||
|
image_std = image_std or getattr(model.visual, 'image_std', None)
|
||||||
|
preprocess = image_transform(
|
||||||
|
model.visual.image_size,
|
||||||
|
is_train=False,
|
||||||
|
mean=image_mean,
|
||||||
|
std=image_std,
|
||||||
|
)
|
||||||
|
|
||||||
|
return model, preprocess
|
||||||
@@ -0,0 +1,45 @@
|
|||||||
|
# HF architecture dict:
|
||||||
|
arch_dict = {
|
||||||
|
# https://huggingface.co/docs/transformers/model_doc/roberta#roberta
|
||||||
|
"roberta": {
|
||||||
|
"config_names": {
|
||||||
|
"context_length": "max_position_embeddings",
|
||||||
|
"vocab_size": "vocab_size",
|
||||||
|
"width": "hidden_size",
|
||||||
|
"heads": "num_attention_heads",
|
||||||
|
"layers": "num_hidden_layers",
|
||||||
|
"layer_attr": "layer",
|
||||||
|
"token_embeddings_attr": "embeddings"
|
||||||
|
},
|
||||||
|
"pooler": "mean_pooler",
|
||||||
|
},
|
||||||
|
# https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
|
||||||
|
"xlm-roberta": {
|
||||||
|
"config_names": {
|
||||||
|
"context_length": "max_position_embeddings",
|
||||||
|
"vocab_size": "vocab_size",
|
||||||
|
"width": "hidden_size",
|
||||||
|
"heads": "num_attention_heads",
|
||||||
|
"layers": "num_hidden_layers",
|
||||||
|
"layer_attr": "layer",
|
||||||
|
"token_embeddings_attr": "embeddings"
|
||||||
|
},
|
||||||
|
"pooler": "mean_pooler",
|
||||||
|
},
|
||||||
|
# https://huggingface.co/docs/transformers/model_doc/mt5#mt5
|
||||||
|
"mt5": {
|
||||||
|
"config_names": {
|
||||||
|
# unlimited seqlen
|
||||||
|
# https://github.com/google-research/text-to-text-transfer-transformer/issues/273
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
|
||||||
|
"context_length": "",
|
||||||
|
"vocab_size": "vocab_size",
|
||||||
|
"width": "d_model",
|
||||||
|
"heads": "num_heads",
|
||||||
|
"layers": "num_layers",
|
||||||
|
"layer_attr": "block",
|
||||||
|
"token_embeddings_attr": "embed_tokens"
|
||||||
|
},
|
||||||
|
"pooler": "mean_pooler",
|
||||||
|
},
|
||||||
|
}
|
||||||
176
diffsynth/extensions/ImageQualityMetric/open_clip/hf_model.py
Normal file
176
diffsynth/extensions/ImageQualityMetric/open_clip/hf_model.py
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
""" huggingface model adapter
|
||||||
|
|
||||||
|
Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch import TensorType
|
||||||
|
|
||||||
|
try:
|
||||||
|
import transformers
|
||||||
|
from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
|
||||||
|
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
|
||||||
|
BaseModelOutputWithPoolingAndCrossAttentions
|
||||||
|
except ImportError as e:
|
||||||
|
transformers = None
|
||||||
|
|
||||||
|
|
||||||
|
class BaseModelOutput:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class PretrainedConfig:
|
||||||
|
pass
|
||||||
|
|
||||||
|
from .hf_configs import arch_dict
|
||||||
|
|
||||||
|
|
||||||
|
# utils
|
||||||
|
def _camel2snake(s):
|
||||||
|
return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: ?last - for gpt-like models
|
||||||
|
_POOLERS = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register_pooler(cls):
|
||||||
|
"""Decorator registering pooler class"""
|
||||||
|
_POOLERS[_camel2snake(cls.__name__)] = cls
|
||||||
|
return cls
|
||||||
|
|
||||||
|
|
||||||
|
@register_pooler
|
||||||
|
class MeanPooler(nn.Module):
|
||||||
|
"""Mean pooling"""
|
||||||
|
|
||||||
|
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
|
||||||
|
masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
|
||||||
|
return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
|
||||||
|
|
||||||
|
|
||||||
|
@register_pooler
|
||||||
|
class MaxPooler(nn.Module):
|
||||||
|
"""Max pooling"""
|
||||||
|
|
||||||
|
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
|
||||||
|
masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1), -torch.inf)
|
||||||
|
return masked_output.max(1).values
|
||||||
|
|
||||||
|
|
||||||
|
@register_pooler
|
||||||
|
class ClsPooler(nn.Module):
|
||||||
|
"""CLS token pooling"""
|
||||||
|
|
||||||
|
def __init__(self, use_pooler_output=True):
|
||||||
|
super().__init__()
|
||||||
|
self.cls_token_position = 0
|
||||||
|
self.use_pooler_output = use_pooler_output
|
||||||
|
|
||||||
|
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
|
||||||
|
if (self.use_pooler_output and
|
||||||
|
isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
|
||||||
|
(x.pooler_output is not None)
|
||||||
|
):
|
||||||
|
return x.pooler_output
|
||||||
|
|
||||||
|
return x.last_hidden_state[:, self.cls_token_position, :]
|
||||||
|
|
||||||
|
|
||||||
|
class HFTextEncoder(nn.Module):
|
||||||
|
"""HuggingFace model adapter"""
|
||||||
|
output_tokens: torch.jit.Final[bool]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name_or_path: str,
|
||||||
|
output_dim: int,
|
||||||
|
config: PretrainedConfig = None,
|
||||||
|
pooler_type: str = None,
|
||||||
|
proj: str = None,
|
||||||
|
pretrained: bool = True,
|
||||||
|
output_tokens: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.output_tokens = output_tokens
|
||||||
|
self.output_dim = output_dim
|
||||||
|
|
||||||
|
# TODO: find better way to get this information
|
||||||
|
uses_transformer_pooler = (pooler_type == "cls_pooler")
|
||||||
|
|
||||||
|
if transformers is None:
|
||||||
|
raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
|
||||||
|
if config is None:
|
||||||
|
self.config = AutoConfig.from_pretrained(model_name_or_path)
|
||||||
|
create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
|
||||||
|
AutoModel.from_config, self.config)
|
||||||
|
# TODO: do all model configs have this attribute? PretrainedConfig does so yes??
|
||||||
|
if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
|
||||||
|
self.transformer = create_func(model_args)
|
||||||
|
self.transformer = self.transformer.encoder
|
||||||
|
else:
|
||||||
|
self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
|
||||||
|
else:
|
||||||
|
self.config = config
|
||||||
|
self.transformer = AutoModel.from_config(config)
|
||||||
|
if pooler_type is None: # get default arch pooler
|
||||||
|
pooler_type = (arch_dict[self.config.model_type]["pooler"])
|
||||||
|
|
||||||
|
self.pooler = _POOLERS[pooler_type]()
|
||||||
|
|
||||||
|
d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
|
||||||
|
if (d_model == output_dim) and (proj is None): # do we always need a proj?
|
||||||
|
self.proj = nn.Identity()
|
||||||
|
elif proj == 'linear':
|
||||||
|
self.proj = nn.Linear(d_model, output_dim, bias=False)
|
||||||
|
elif proj == 'mlp':
|
||||||
|
hidden_size = (d_model + output_dim) // 2
|
||||||
|
self.proj = nn.Sequential(
|
||||||
|
nn.Linear(d_model, hidden_size, bias=False),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Linear(hidden_size, output_dim, bias=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: TensorType):
|
||||||
|
attn_mask = (x != self.config.pad_token_id).long()
|
||||||
|
out = self.transformer(input_ids=x, attention_mask=attn_mask)
|
||||||
|
pooled_out = self.pooler(out, attn_mask)
|
||||||
|
projected = self.proj(pooled_out)
|
||||||
|
|
||||||
|
seq_len = out.last_hidden_state.shape[1]
|
||||||
|
tokens = (
|
||||||
|
out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :]
|
||||||
|
if type(self.pooler) == ClsPooler
|
||||||
|
else out.last_hidden_state
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.output_tokens:
|
||||||
|
return projected, tokens
|
||||||
|
return projected
|
||||||
|
|
||||||
|
def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
|
||||||
|
if not unlocked_layers: # full freezing
|
||||||
|
for n, p in self.transformer.named_parameters():
|
||||||
|
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
|
||||||
|
return
|
||||||
|
|
||||||
|
encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
|
||||||
|
layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
|
||||||
|
print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
|
||||||
|
embeddings = getattr(
|
||||||
|
self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
|
||||||
|
modules = [embeddings, *layer_list][:-unlocked_layers]
|
||||||
|
# freeze layers
|
||||||
|
for module in modules:
|
||||||
|
for n, p in module.named_parameters():
|
||||||
|
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def set_grad_checkpointing(self, enable=True):
|
||||||
|
self.transformer.gradient_checkpointing_enable()
|
||||||
|
|
||||||
|
def init_parameters(self):
|
||||||
|
pass
|
||||||
270
diffsynth/extensions/ImageQualityMetric/open_clip/loss.py
Normal file
270
diffsynth/extensions/ImageQualityMetric/open_clip/loss.py
Normal file
@@ -0,0 +1,270 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch.distributed.nn
|
||||||
|
from torch import distributed as dist
|
||||||
|
|
||||||
|
has_distributed = True
|
||||||
|
except ImportError:
|
||||||
|
has_distributed = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import horovod.torch as hvd
|
||||||
|
except ImportError:
|
||||||
|
hvd = None
|
||||||
|
|
||||||
|
|
||||||
|
def gather_features(
|
||||||
|
image_features,
|
||||||
|
text_features,
|
||||||
|
local_loss=False,
|
||||||
|
gather_with_grad=False,
|
||||||
|
rank=0,
|
||||||
|
world_size=1,
|
||||||
|
use_horovod=False
|
||||||
|
):
|
||||||
|
assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
|
||||||
|
if use_horovod:
|
||||||
|
assert hvd is not None, 'Please install horovod'
|
||||||
|
if gather_with_grad:
|
||||||
|
all_image_features = hvd.allgather(image_features)
|
||||||
|
all_text_features = hvd.allgather(text_features)
|
||||||
|
else:
|
||||||
|
with torch.no_grad():
|
||||||
|
all_image_features = hvd.allgather(image_features)
|
||||||
|
all_text_features = hvd.allgather(text_features)
|
||||||
|
if not local_loss:
|
||||||
|
# ensure grads for local rank when all_* features don't have a gradient
|
||||||
|
gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
|
||||||
|
gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
|
||||||
|
gathered_image_features[rank] = image_features
|
||||||
|
gathered_text_features[rank] = text_features
|
||||||
|
all_image_features = torch.cat(gathered_image_features, dim=0)
|
||||||
|
all_text_features = torch.cat(gathered_text_features, dim=0)
|
||||||
|
else:
|
||||||
|
# We gather tensors from all gpus
|
||||||
|
if gather_with_grad:
|
||||||
|
all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
|
||||||
|
all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
|
||||||
|
else:
|
||||||
|
gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
|
||||||
|
gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
|
||||||
|
dist.all_gather(gathered_image_features, image_features)
|
||||||
|
dist.all_gather(gathered_text_features, text_features)
|
||||||
|
if not local_loss:
|
||||||
|
# ensure grads for local rank when all_* features don't have a gradient
|
||||||
|
gathered_image_features[rank] = image_features
|
||||||
|
gathered_text_features[rank] = text_features
|
||||||
|
all_image_features = torch.cat(gathered_image_features, dim=0)
|
||||||
|
all_text_features = torch.cat(gathered_text_features, dim=0)
|
||||||
|
|
||||||
|
return all_image_features, all_text_features
|
||||||
|
|
||||||
|
|
||||||
|
class ClipLoss(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
local_loss=False,
|
||||||
|
gather_with_grad=False,
|
||||||
|
cache_labels=False,
|
||||||
|
rank=0,
|
||||||
|
world_size=1,
|
||||||
|
use_horovod=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.local_loss = local_loss
|
||||||
|
self.gather_with_grad = gather_with_grad
|
||||||
|
self.cache_labels = cache_labels
|
||||||
|
self.rank = rank
|
||||||
|
self.world_size = world_size
|
||||||
|
self.use_horovod = use_horovod
|
||||||
|
|
||||||
|
# cache state
|
||||||
|
self.prev_num_logits = 0
|
||||||
|
self.labels = {}
|
||||||
|
|
||||||
|
def get_ground_truth(self, device, num_logits) -> torch.Tensor:
|
||||||
|
# calculated ground-truth and cache if enabled
|
||||||
|
if self.prev_num_logits != num_logits or device not in self.labels:
|
||||||
|
labels = torch.arange(num_logits, device=device, dtype=torch.long)
|
||||||
|
if self.world_size > 1 and self.local_loss:
|
||||||
|
labels = labels + num_logits * self.rank
|
||||||
|
if self.cache_labels:
|
||||||
|
self.labels[device] = labels
|
||||||
|
self.prev_num_logits = num_logits
|
||||||
|
else:
|
||||||
|
labels = self.labels[device]
|
||||||
|
return labels
|
||||||
|
|
||||||
|
def get_logits(self, image_features, text_features, logit_scale):
|
||||||
|
if self.world_size > 1:
|
||||||
|
all_image_features, all_text_features = gather_features(
|
||||||
|
image_features, text_features,
|
||||||
|
self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
|
||||||
|
|
||||||
|
if self.local_loss:
|
||||||
|
logits_per_image = logit_scale * image_features @ all_text_features.T
|
||||||
|
logits_per_text = logit_scale * text_features @ all_image_features.T
|
||||||
|
else:
|
||||||
|
logits_per_image = logit_scale * all_image_features @ all_text_features.T
|
||||||
|
logits_per_text = logits_per_image.T
|
||||||
|
else:
|
||||||
|
logits_per_image = logit_scale * image_features @ text_features.T
|
||||||
|
logits_per_text = logit_scale * text_features @ image_features.T
|
||||||
|
|
||||||
|
return logits_per_image, logits_per_text
|
||||||
|
|
||||||
|
def forward(self, image_features, text_features, logit_scale, output_dict=False):
|
||||||
|
device = image_features.device
|
||||||
|
logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
|
||||||
|
|
||||||
|
labels = self.get_ground_truth(device, logits_per_image.shape[0])
|
||||||
|
|
||||||
|
total_loss = (
|
||||||
|
F.cross_entropy(logits_per_image, labels) +
|
||||||
|
F.cross_entropy(logits_per_text, labels)
|
||||||
|
) / 2
|
||||||
|
return total_loss
|
||||||
|
|
||||||
|
class PreferenceLoss(nn.Module):
|
||||||
|
|
||||||
|
def forward(self, logits_per_image, num_images, labels):
|
||||||
|
|
||||||
|
paired_logits_list = [logit[:,i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))]
|
||||||
|
paired_logits = pad_sequence(paired_logits_list, batch_first=True, padding_value=-999)
|
||||||
|
|
||||||
|
ce_loss = F.cross_entropy(paired_logits, labels)
|
||||||
|
return ce_loss
|
||||||
|
|
||||||
|
class HPSLoss(nn.Module):
|
||||||
|
|
||||||
|
def forward(self, text_logits, labels):
|
||||||
|
|
||||||
|
device = text_logits.device
|
||||||
|
text_0_logits, text_1_logits = text_logits.chunk(2, dim=-1)
|
||||||
|
label_0, label_1 = labels.chunk(2, dim=-1)
|
||||||
|
|
||||||
|
index = torch.arange(text_0_logits.shape[0], device=device, dtype=torch.long)
|
||||||
|
text_0_logits = text_0_logits[index, index]
|
||||||
|
text_1_logits = text_1_logits[index, index]
|
||||||
|
text_logits = torch.stack([text_0_logits, text_1_logits], dim=-1)
|
||||||
|
text_0_labels = torch.zeros(text_logits.shape[0], device=device, dtype=torch.long)
|
||||||
|
text_1_labels = text_0_labels + 1
|
||||||
|
|
||||||
|
text_0_loss = torch.nn.functional.cross_entropy(text_logits, text_0_labels, reduction="none")
|
||||||
|
text_1_loss = torch.nn.functional.cross_entropy(text_logits, text_1_labels, reduction="none")
|
||||||
|
|
||||||
|
text_loss = label_0 * text_0_loss + label_1 * text_1_loss
|
||||||
|
|
||||||
|
# absolute_example_weight = 1 / num_per_prompt
|
||||||
|
# denominator = absolute_example_weight.sum()
|
||||||
|
# weight_per_example = absolute_example_weight / denominator
|
||||||
|
# text_loss *= weight_per_example
|
||||||
|
|
||||||
|
text_loss = text_loss.sum()
|
||||||
|
return text_loss
|
||||||
|
|
||||||
|
class RankingLoss(nn.Module):
|
||||||
|
|
||||||
|
def forward(self, logits_per_image, num_images, labels, margin = 1.0):
|
||||||
|
paired_logits_list = [logit[:,i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))]
|
||||||
|
label_list = [label for label in labels.split(num_images.tolist())]
|
||||||
|
# ranked_logits = [torch.index_select(paired_logits_list[i], 0, rank) for i, rank in enumerate(label_list)]
|
||||||
|
|
||||||
|
paired_logits = pad_sequence(paired_logits_list, batch_first=True, padding_value=-1)
|
||||||
|
padded_labels = pad_sequence(label_list, batch_first=True, padding_value=10)
|
||||||
|
|
||||||
|
# regulized_logits = torch.log(torch.sigmoid(paired_logits))
|
||||||
|
|
||||||
|
diff = paired_logits.unsqueeze(1) - paired_logits.unsqueeze(2)
|
||||||
|
# diff = paired_logits.unsqueeze(1) - paired_logits.unsqueeze(2)
|
||||||
|
# diff_label = torch.clamp(padded_labels.unsqueeze(1) - padded_labels.unsqueeze(2), min=-1, max=1)
|
||||||
|
diff_label = - (padded_labels.unsqueeze(1) - padded_labels.unsqueeze(2))
|
||||||
|
mask = torch.triu(torch.ones(diff.shape[1], diff.shape[1]), diagonal=1).bool().detach()
|
||||||
|
|
||||||
|
loss = torch.clamp(margin - torch.mul(diff[:, ~mask],diff_label[:,~mask]), min=0).mean()
|
||||||
|
return loss
|
||||||
|
|
||||||
|
class CoCaLoss(ClipLoss):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
caption_loss_weight,
|
||||||
|
clip_loss_weight,
|
||||||
|
pad_id=0, # pad_token for open_clip custom tokenizer
|
||||||
|
local_loss=False,
|
||||||
|
gather_with_grad=False,
|
||||||
|
cache_labels=False,
|
||||||
|
rank=0,
|
||||||
|
world_size=1,
|
||||||
|
use_horovod=False,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
local_loss=local_loss,
|
||||||
|
gather_with_grad=gather_with_grad,
|
||||||
|
cache_labels=cache_labels,
|
||||||
|
rank=rank,
|
||||||
|
world_size=world_size,
|
||||||
|
use_horovod=use_horovod
|
||||||
|
)
|
||||||
|
|
||||||
|
self.clip_loss_weight = clip_loss_weight
|
||||||
|
self.caption_loss_weight = caption_loss_weight
|
||||||
|
self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)
|
||||||
|
|
||||||
|
def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
|
||||||
|
clip_loss = super().forward(image_features, text_features, logit_scale)
|
||||||
|
clip_loss = self.clip_loss_weight * clip_loss
|
||||||
|
|
||||||
|
caption_loss = self.caption_loss(
|
||||||
|
logits.permute(0, 2, 1),
|
||||||
|
labels,
|
||||||
|
)
|
||||||
|
caption_loss = caption_loss * self.caption_loss_weight
|
||||||
|
|
||||||
|
if output_dict:
|
||||||
|
return {"contrastive_loss": clip_loss, "caption_loss": caption_loss}
|
||||||
|
|
||||||
|
return clip_loss, caption_loss
|
||||||
|
|
||||||
|
|
||||||
|
class DistillClipLoss(ClipLoss):
|
||||||
|
|
||||||
|
def dist_loss(self, teacher_logits, student_logits):
|
||||||
|
return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
image_features,
|
||||||
|
text_features,
|
||||||
|
logit_scale,
|
||||||
|
dist_image_features,
|
||||||
|
dist_text_features,
|
||||||
|
dist_logit_scale,
|
||||||
|
output_dict=False,
|
||||||
|
):
|
||||||
|
logits_per_image, logits_per_text = \
|
||||||
|
self.get_logits(image_features, text_features, logit_scale)
|
||||||
|
|
||||||
|
dist_logits_per_image, dist_logits_per_text = \
|
||||||
|
self.get_logits(dist_image_features, dist_text_features, dist_logit_scale)
|
||||||
|
|
||||||
|
labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0])
|
||||||
|
|
||||||
|
contrastive_loss = (
|
||||||
|
F.cross_entropy(logits_per_image, labels) +
|
||||||
|
F.cross_entropy(logits_per_text, labels)
|
||||||
|
) / 2
|
||||||
|
|
||||||
|
distill_loss = (
|
||||||
|
self.dist_loss(dist_logits_per_image, logits_per_image) +
|
||||||
|
self.dist_loss(dist_logits_per_text, logits_per_text)
|
||||||
|
) / 2
|
||||||
|
|
||||||
|
if output_dict:
|
||||||
|
return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}
|
||||||
|
|
||||||
|
return contrastive_loss, distill_loss
|
||||||
461
diffsynth/extensions/ImageQualityMetric/open_clip/model.py
Normal file
461
diffsynth/extensions/ImageQualityMetric/open_clip/model.py
Normal file
@@ -0,0 +1,461 @@
|
|||||||
|
""" CLIP Model
|
||||||
|
|
||||||
|
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
||||||
|
"""
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn
|
||||||
|
from torch.utils.checkpoint import checkpoint
|
||||||
|
|
||||||
|
from .hf_model import HFTextEncoder
|
||||||
|
from .modified_resnet import ModifiedResNet
|
||||||
|
from .timm_model import TimmModel
|
||||||
|
from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
|
||||||
|
from .utils import to_2tuple
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CLIPVisionCfg:
|
||||||
|
layers: Union[Tuple[int, int, int, int], int] = 12
|
||||||
|
width: int = 768
|
||||||
|
head_width: int = 64
|
||||||
|
mlp_ratio: float = 4.0
|
||||||
|
patch_size: int = 16
|
||||||
|
image_size: Union[Tuple[int, int], int] = 224
|
||||||
|
ls_init_value: Optional[float] = None # layer scale initial value
|
||||||
|
patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
|
||||||
|
input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design
|
||||||
|
global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
|
||||||
|
attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer
|
||||||
|
n_queries: int = 256 # n_queries for attentional pooler
|
||||||
|
attn_pooler_heads: int = 8 # n heads for attentional_pooling
|
||||||
|
timm_model_name: str = None # a valid model name overrides layers, width, patch_size
|
||||||
|
timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
|
||||||
|
timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
|
||||||
|
timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
|
||||||
|
timm_proj_bias: bool = False # enable bias final projection
|
||||||
|
timm_drop: float = 0. # head dropout
|
||||||
|
timm_drop_path: Optional[float] = None # backbone stochastic depth
|
||||||
|
output_tokens: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CLIPTextCfg:
|
||||||
|
context_length: int = 77
|
||||||
|
vocab_size: int = 49408
|
||||||
|
width: int = 512
|
||||||
|
heads: int = 8
|
||||||
|
layers: int = 12
|
||||||
|
ls_init_value: Optional[float] = None # layer scale initial value
|
||||||
|
hf_model_name: str = None
|
||||||
|
hf_tokenizer_name: str = None
|
||||||
|
hf_model_pretrained: bool = True
|
||||||
|
proj: str = 'mlp'
|
||||||
|
pooler_type: str = 'mean_pooler'
|
||||||
|
embed_cls: bool = False
|
||||||
|
pad_id: int = 0
|
||||||
|
output_tokens: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
def get_cast_dtype(precision: str):
|
||||||
|
cast_dtype = None
|
||||||
|
if precision == 'bf16':
|
||||||
|
cast_dtype = torch.bfloat16
|
||||||
|
elif precision == 'fp16':
|
||||||
|
cast_dtype = torch.float16
|
||||||
|
return cast_dtype
|
||||||
|
|
||||||
|
|
||||||
|
def _build_vision_tower(
|
||||||
|
embed_dim: int,
|
||||||
|
vision_cfg: CLIPVisionCfg,
|
||||||
|
quick_gelu: bool = False,
|
||||||
|
cast_dtype: Optional[torch.dtype] = None
|
||||||
|
):
|
||||||
|
if isinstance(vision_cfg, dict):
|
||||||
|
vision_cfg = CLIPVisionCfg(**vision_cfg)
|
||||||
|
|
||||||
|
# OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
|
||||||
|
# memory efficient in recent PyTorch releases (>= 1.10).
|
||||||
|
# NOTE: timm models always use native GELU regardless of quick_gelu flag.
|
||||||
|
act_layer = QuickGELU if quick_gelu else nn.GELU
|
||||||
|
|
||||||
|
if vision_cfg.timm_model_name:
|
||||||
|
visual = TimmModel(
|
||||||
|
vision_cfg.timm_model_name,
|
||||||
|
pretrained=vision_cfg.timm_model_pretrained,
|
||||||
|
pool=vision_cfg.timm_pool,
|
||||||
|
proj=vision_cfg.timm_proj,
|
||||||
|
proj_bias=vision_cfg.timm_proj_bias,
|
||||||
|
drop=vision_cfg.timm_drop,
|
||||||
|
drop_path=vision_cfg.timm_drop_path,
|
||||||
|
embed_dim=embed_dim,
|
||||||
|
image_size=vision_cfg.image_size,
|
||||||
|
)
|
||||||
|
act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
|
||||||
|
elif isinstance(vision_cfg.layers, (tuple, list)):
|
||||||
|
vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
|
||||||
|
visual = ModifiedResNet(
|
||||||
|
layers=vision_cfg.layers,
|
||||||
|
output_dim=embed_dim,
|
||||||
|
heads=vision_heads,
|
||||||
|
image_size=vision_cfg.image_size,
|
||||||
|
width=vision_cfg.width,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
vision_heads = vision_cfg.width // vision_cfg.head_width
|
||||||
|
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
|
||||||
|
visual = VisionTransformer(
|
||||||
|
image_size=vision_cfg.image_size,
|
||||||
|
patch_size=vision_cfg.patch_size,
|
||||||
|
width=vision_cfg.width,
|
||||||
|
layers=vision_cfg.layers,
|
||||||
|
heads=vision_heads,
|
||||||
|
mlp_ratio=vision_cfg.mlp_ratio,
|
||||||
|
ls_init_value=vision_cfg.ls_init_value,
|
||||||
|
patch_dropout=vision_cfg.patch_dropout,
|
||||||
|
input_patchnorm=vision_cfg.input_patchnorm,
|
||||||
|
global_average_pool=vision_cfg.global_average_pool,
|
||||||
|
attentional_pool=vision_cfg.attentional_pool,
|
||||||
|
n_queries=vision_cfg.n_queries,
|
||||||
|
attn_pooler_heads=vision_cfg.attn_pooler_heads,
|
||||||
|
output_tokens=vision_cfg.output_tokens,
|
||||||
|
output_dim=embed_dim,
|
||||||
|
act_layer=act_layer,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
)
|
||||||
|
|
||||||
|
return visual
|
||||||
|
|
||||||
|
|
||||||
|
def _build_text_tower(
|
||||||
|
embed_dim: int,
|
||||||
|
text_cfg: CLIPTextCfg,
|
||||||
|
quick_gelu: bool = False,
|
||||||
|
cast_dtype: Optional[torch.dtype] = None,
|
||||||
|
):
|
||||||
|
if isinstance(text_cfg, dict):
|
||||||
|
text_cfg = CLIPTextCfg(**text_cfg)
|
||||||
|
|
||||||
|
if text_cfg.hf_model_name:
|
||||||
|
text = HFTextEncoder(
|
||||||
|
text_cfg.hf_model_name,
|
||||||
|
output_dim=embed_dim,
|
||||||
|
proj=text_cfg.proj,
|
||||||
|
pooler_type=text_cfg.pooler_type,
|
||||||
|
pretrained=text_cfg.hf_model_pretrained,
|
||||||
|
output_tokens=text_cfg.output_tokens,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
act_layer = QuickGELU if quick_gelu else nn.GELU
|
||||||
|
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
|
||||||
|
|
||||||
|
text = TextTransformer(
|
||||||
|
context_length=text_cfg.context_length,
|
||||||
|
vocab_size=text_cfg.vocab_size,
|
||||||
|
width=text_cfg.width,
|
||||||
|
heads=text_cfg.heads,
|
||||||
|
layers=text_cfg.layers,
|
||||||
|
ls_init_value=text_cfg.ls_init_value,
|
||||||
|
output_dim=embed_dim,
|
||||||
|
embed_cls=text_cfg.embed_cls,
|
||||||
|
output_tokens=text_cfg.output_tokens,
|
||||||
|
pad_id=text_cfg.pad_id,
|
||||||
|
act_layer=act_layer,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
)
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
class CLIP(nn.Module):
|
||||||
|
output_dict: torch.jit.Final[bool]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embed_dim: int,
|
||||||
|
vision_cfg: CLIPVisionCfg,
|
||||||
|
text_cfg: CLIPTextCfg,
|
||||||
|
quick_gelu: bool = False,
|
||||||
|
cast_dtype: Optional[torch.dtype] = None,
|
||||||
|
output_dict: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.output_dict = output_dict
|
||||||
|
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
|
||||||
|
|
||||||
|
text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
|
||||||
|
self.transformer = text.transformer
|
||||||
|
self.vocab_size = text.vocab_size
|
||||||
|
self.token_embedding = text.token_embedding
|
||||||
|
self.positional_embedding = text.positional_embedding
|
||||||
|
self.ln_final = text.ln_final
|
||||||
|
self.text_projection = text.text_projection
|
||||||
|
self.register_buffer('attn_mask', text.attn_mask, persistent=False)
|
||||||
|
|
||||||
|
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
||||||
|
|
||||||
|
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
|
||||||
|
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
|
||||||
|
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
|
||||||
|
|
||||||
|
def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
|
||||||
|
locked_layers = []
|
||||||
|
locked_layers.append(self.token_embedding)
|
||||||
|
self.positional_embedding.requires_grad = False
|
||||||
|
if unlocked_layers > 0:
|
||||||
|
locked_layers.append(self.transformer.resblocks[:-unlocked_layers])
|
||||||
|
else:
|
||||||
|
locked_layers.append(self.transformer)
|
||||||
|
locked_layers.append(self.ln_final)
|
||||||
|
self.text_projection.requires_grad = False
|
||||||
|
|
||||||
|
# freeze layers
|
||||||
|
for module in locked_layers:
|
||||||
|
for n, p in module.named_parameters():
|
||||||
|
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def set_grad_checkpointing(self, enable=True):
|
||||||
|
self.visual.set_grad_checkpointing(enable)
|
||||||
|
self.transformer.grad_checkpointing = enable
|
||||||
|
|
||||||
|
def encode_image(self, image, normalize: bool = False):
|
||||||
|
features = self.visual(image)
|
||||||
|
return F.normalize(features, dim=-1) if normalize else features
|
||||||
|
|
||||||
|
def encode_text(self, text, normalize: bool = False):
|
||||||
|
cast_dtype = self.transformer.get_cast_dtype()
|
||||||
|
|
||||||
|
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
|
||||||
|
|
||||||
|
x = x + self.positional_embedding.to(cast_dtype)
|
||||||
|
x = x.permute(1, 0, 2) # NLD -> LND
|
||||||
|
x = self.transformer(x, attn_mask=self.attn_mask)
|
||||||
|
x = x.permute(1, 0, 2) # LND -> NLD
|
||||||
|
x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
|
||||||
|
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||||
|
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
||||||
|
return F.normalize(x, dim=-1) if normalize else x
|
||||||
|
|
||||||
|
def forward(self, image, text):
|
||||||
|
image_features = self.encode_image(image, normalize=True)
|
||||||
|
text_features = self.encode_text(text, normalize=True)
|
||||||
|
if self.output_dict:
|
||||||
|
return {
|
||||||
|
"image_features": image_features,
|
||||||
|
"text_features": text_features,
|
||||||
|
"logit_scale": self.logit_scale.exp()
|
||||||
|
}
|
||||||
|
return image_features, text_features, self.logit_scale.exp()
|
||||||
|
|
||||||
|
|
||||||
|
class CustomTextCLIP(nn.Module):
|
||||||
|
output_dict: torch.jit.Final[bool]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embed_dim: int,
|
||||||
|
vision_cfg: CLIPVisionCfg,
|
||||||
|
text_cfg: CLIPTextCfg,
|
||||||
|
quick_gelu: bool = False,
|
||||||
|
cast_dtype: Optional[torch.dtype] = None,
|
||||||
|
output_dict: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.output_dict = output_dict
|
||||||
|
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
|
||||||
|
self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
|
||||||
|
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
||||||
|
|
||||||
|
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
|
||||||
|
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
|
||||||
|
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
|
||||||
|
|
||||||
|
def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
|
||||||
|
self.text.lock(unlocked_layers, freeze_layer_norm)
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def set_grad_checkpointing(self, enable=True):
|
||||||
|
self.visual.set_grad_checkpointing(enable)
|
||||||
|
self.text.set_grad_checkpointing(enable)
|
||||||
|
|
||||||
|
def encode_image(self, image, normalize: bool = False):
|
||||||
|
features = self.visual(image)
|
||||||
|
return F.normalize(features, dim=-1) if normalize else features
|
||||||
|
|
||||||
|
def encode_text(self, text, normalize: bool = False):
|
||||||
|
features = self.text(text)
|
||||||
|
return F.normalize(features, dim=-1) if normalize else features
|
||||||
|
|
||||||
|
def forward(self, image, text):
|
||||||
|
image_features = self.encode_image(image, normalize=True)
|
||||||
|
text_features = self.encode_text(text, normalize=True)
|
||||||
|
if self.output_dict:
|
||||||
|
return {
|
||||||
|
"image_features": image_features,
|
||||||
|
"text_features": text_features,
|
||||||
|
"logit_scale": self.logit_scale.exp()
|
||||||
|
}
|
||||||
|
return image_features, text_features, self.logit_scale.exp()
|
||||||
|
|
||||||
|
|
||||||
|
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
|
||||||
|
"""Convert applicable model parameters to low-precision (bf16 or fp16)"""
|
||||||
|
|
||||||
|
def _convert_weights(l):
|
||||||
|
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
||||||
|
l.weight.data = l.weight.data.to(dtype)
|
||||||
|
if l.bias is not None:
|
||||||
|
l.bias.data = l.bias.data.to(dtype)
|
||||||
|
|
||||||
|
if isinstance(l, (nn.MultiheadAttention, Attention)):
|
||||||
|
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
||||||
|
tensor = getattr(l, attr)
|
||||||
|
if tensor is not None:
|
||||||
|
tensor.data = tensor.data.to(dtype)
|
||||||
|
|
||||||
|
for name in ["text_projection", "proj"]:
|
||||||
|
if hasattr(l, name):
|
||||||
|
attr = getattr(l, name)
|
||||||
|
if attr is not None:
|
||||||
|
attr.data = attr.data.to(dtype)
|
||||||
|
|
||||||
|
model.apply(_convert_weights)
|
||||||
|
|
||||||
|
|
||||||
|
convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
|
||||||
|
|
||||||
|
|
||||||
|
# used to maintain checkpoint compatibility
|
||||||
|
def convert_to_custom_text_state_dict(state_dict: dict):
|
||||||
|
if 'text_projection' in state_dict:
|
||||||
|
# old format state_dict, move text tower -> .text
|
||||||
|
new_state_dict = {}
|
||||||
|
for k, v in state_dict.items():
|
||||||
|
if any(k.startswith(p) for p in (
|
||||||
|
'text_projection',
|
||||||
|
'positional_embedding',
|
||||||
|
'token_embedding',
|
||||||
|
'transformer',
|
||||||
|
'ln_final',
|
||||||
|
)):
|
||||||
|
k = 'text.' + k
|
||||||
|
new_state_dict[k] = v
|
||||||
|
return new_state_dict
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def build_model_from_openai_state_dict(
|
||||||
|
state_dict: dict,
|
||||||
|
quick_gelu=True,
|
||||||
|
cast_dtype=torch.float16,
|
||||||
|
):
|
||||||
|
vit = "visual.proj" in state_dict
|
||||||
|
|
||||||
|
if vit:
|
||||||
|
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
||||||
|
vision_layers = len(
|
||||||
|
[k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
||||||
|
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
||||||
|
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
||||||
|
image_size = vision_patch_size * grid_size
|
||||||
|
else:
|
||||||
|
counts: list = [
|
||||||
|
len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
|
||||||
|
vision_layers = tuple(counts)
|
||||||
|
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
||||||
|
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
||||||
|
vision_patch_size = None
|
||||||
|
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
||||||
|
image_size = output_width * 32
|
||||||
|
|
||||||
|
embed_dim = state_dict["text_projection"].shape[1]
|
||||||
|
context_length = state_dict["positional_embedding"].shape[0]
|
||||||
|
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
||||||
|
transformer_width = state_dict["ln_final.weight"].shape[0]
|
||||||
|
transformer_heads = transformer_width // 64
|
||||||
|
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
|
||||||
|
|
||||||
|
vision_cfg = CLIPVisionCfg(
|
||||||
|
layers=vision_layers,
|
||||||
|
width=vision_width,
|
||||||
|
patch_size=vision_patch_size,
|
||||||
|
image_size=image_size,
|
||||||
|
)
|
||||||
|
text_cfg = CLIPTextCfg(
|
||||||
|
context_length=context_length,
|
||||||
|
vocab_size=vocab_size,
|
||||||
|
width=transformer_width,
|
||||||
|
heads=transformer_heads,
|
||||||
|
layers=transformer_layers,
|
||||||
|
)
|
||||||
|
model = CLIP(
|
||||||
|
embed_dim,
|
||||||
|
vision_cfg=vision_cfg,
|
||||||
|
text_cfg=text_cfg,
|
||||||
|
quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
|
||||||
|
cast_dtype=cast_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
for key in ["input_resolution", "context_length", "vocab_size"]:
|
||||||
|
state_dict.pop(key, None)
|
||||||
|
|
||||||
|
convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
return model.eval()
|
||||||
|
|
||||||
|
|
||||||
|
def trace_model(model, batch_size=256, device=torch.device('cpu')):
|
||||||
|
model.eval()
|
||||||
|
image_size = model.visual.image_size
|
||||||
|
example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
|
||||||
|
example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
|
||||||
|
model = torch.jit.trace_module(
|
||||||
|
model,
|
||||||
|
inputs=dict(
|
||||||
|
forward=(example_images, example_text),
|
||||||
|
encode_text=(example_text,),
|
||||||
|
encode_image=(example_images,)
|
||||||
|
))
|
||||||
|
model.visual.image_size = image_size
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
|
||||||
|
# Rescale the grid of position embeddings when loading from state_dict
|
||||||
|
old_pos_embed = state_dict.get('visual.positional_embedding', None)
|
||||||
|
if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
|
||||||
|
return
|
||||||
|
grid_size = to_2tuple(model.visual.grid_size)
|
||||||
|
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
|
||||||
|
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
|
||||||
|
if new_seq_len == old_pos_embed.shape[0]:
|
||||||
|
return
|
||||||
|
|
||||||
|
if extra_tokens:
|
||||||
|
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
|
||||||
|
else:
|
||||||
|
pos_emb_tok, pos_emb_img = None, old_pos_embed
|
||||||
|
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
|
||||||
|
|
||||||
|
logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
|
||||||
|
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
|
||||||
|
pos_emb_img = F.interpolate(
|
||||||
|
pos_emb_img,
|
||||||
|
size=grid_size,
|
||||||
|
mode=interpolation,
|
||||||
|
antialias=antialias,
|
||||||
|
align_corners=False,
|
||||||
|
)
|
||||||
|
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
|
||||||
|
if pos_emb_tok is not None:
|
||||||
|
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
|
||||||
|
else:
|
||||||
|
new_pos_embed = pos_emb_img
|
||||||
|
state_dict['visual.positional_embedding'] = new_pos_embed
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
{
|
||||||
|
"embed_dim": 1024,
|
||||||
|
"vision_cfg": {
|
||||||
|
"image_size": 224,
|
||||||
|
"layers": 32,
|
||||||
|
"width": 1280,
|
||||||
|
"head_width": 80,
|
||||||
|
"patch_size": 14
|
||||||
|
},
|
||||||
|
"text_cfg": {
|
||||||
|
"context_length": 77,
|
||||||
|
"vocab_size": 49408,
|
||||||
|
"width": 1024,
|
||||||
|
"heads": 16,
|
||||||
|
"layers": 24
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,181 @@
|
|||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from .utils import freeze_batch_norm_2d
|
||||||
|
|
||||||
|
|
||||||
|
class Bottleneck(nn.Module):
|
||||||
|
expansion = 4
|
||||||
|
|
||||||
|
def __init__(self, inplanes, planes, stride=1):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
||||||
|
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(planes)
|
||||||
|
self.act1 = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
||||||
|
self.bn2 = nn.BatchNorm2d(planes)
|
||||||
|
self.act2 = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
||||||
|
|
||||||
|
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
||||||
|
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
||||||
|
self.act3 = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
self.downsample = None
|
||||||
|
self.stride = stride
|
||||||
|
|
||||||
|
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
||||||
|
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
||||||
|
self.downsample = nn.Sequential(OrderedDict([
|
||||||
|
("-1", nn.AvgPool2d(stride)),
|
||||||
|
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
||||||
|
("1", nn.BatchNorm2d(planes * self.expansion))
|
||||||
|
]))
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
identity = x
|
||||||
|
|
||||||
|
out = self.act1(self.bn1(self.conv1(x)))
|
||||||
|
out = self.act2(self.bn2(self.conv2(out)))
|
||||||
|
out = self.avgpool(out)
|
||||||
|
out = self.bn3(self.conv3(out))
|
||||||
|
|
||||||
|
if self.downsample is not None:
|
||||||
|
identity = self.downsample(x)
|
||||||
|
|
||||||
|
out += identity
|
||||||
|
out = self.act3(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionPool2d(nn.Module):
|
||||||
|
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
||||||
|
super().__init__()
|
||||||
|
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
|
||||||
|
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
||||||
|
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
||||||
|
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
||||||
|
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
||||||
|
self.num_heads = num_heads
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
|
||||||
|
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
||||||
|
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
||||||
|
x, _ = F.multi_head_attention_forward(
|
||||||
|
query=x, key=x, value=x,
|
||||||
|
embed_dim_to_check=x.shape[-1],
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
q_proj_weight=self.q_proj.weight,
|
||||||
|
k_proj_weight=self.k_proj.weight,
|
||||||
|
v_proj_weight=self.v_proj.weight,
|
||||||
|
in_proj_weight=None,
|
||||||
|
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
||||||
|
bias_k=None,
|
||||||
|
bias_v=None,
|
||||||
|
add_zero_attn=False,
|
||||||
|
dropout_p=0.,
|
||||||
|
out_proj_weight=self.c_proj.weight,
|
||||||
|
out_proj_bias=self.c_proj.bias,
|
||||||
|
use_separate_proj_weight=True,
|
||||||
|
training=self.training,
|
||||||
|
need_weights=False
|
||||||
|
)
|
||||||
|
|
||||||
|
return x[0]
|
||||||
|
|
||||||
|
|
||||||
|
class ModifiedResNet(nn.Module):
|
||||||
|
"""
|
||||||
|
A ResNet class that is similar to torchvision's but contains the following changes:
|
||||||
|
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
||||||
|
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
||||||
|
- The final pooling layer is a QKV attention instead of an average pool
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, layers, output_dim, heads, image_size=224, width=64):
|
||||||
|
super().__init__()
|
||||||
|
self.output_dim = output_dim
|
||||||
|
self.image_size = image_size
|
||||||
|
|
||||||
|
# the 3-layer stem
|
||||||
|
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(width // 2)
|
||||||
|
self.act1 = nn.ReLU(inplace=True)
|
||||||
|
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
||||||
|
self.bn2 = nn.BatchNorm2d(width // 2)
|
||||||
|
self.act2 = nn.ReLU(inplace=True)
|
||||||
|
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
||||||
|
self.bn3 = nn.BatchNorm2d(width)
|
||||||
|
self.act3 = nn.ReLU(inplace=True)
|
||||||
|
self.avgpool = nn.AvgPool2d(2)
|
||||||
|
|
||||||
|
# residual layers
|
||||||
|
self._inplanes = width # this is a *mutable* variable used during construction
|
||||||
|
self.layer1 = self._make_layer(width, layers[0])
|
||||||
|
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
||||||
|
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
||||||
|
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
||||||
|
|
||||||
|
embed_dim = width * 32 # the ResNet feature dimension
|
||||||
|
self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
|
||||||
|
|
||||||
|
self.init_parameters()
|
||||||
|
|
||||||
|
def _make_layer(self, planes, blocks, stride=1):
|
||||||
|
layers = [Bottleneck(self._inplanes, planes, stride)]
|
||||||
|
|
||||||
|
self._inplanes = planes * Bottleneck.expansion
|
||||||
|
for _ in range(1, blocks):
|
||||||
|
layers.append(Bottleneck(self._inplanes, planes))
|
||||||
|
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def init_parameters(self):
|
||||||
|
if self.attnpool is not None:
|
||||||
|
std = self.attnpool.c_proj.in_features ** -0.5
|
||||||
|
nn.init.normal_(self.attnpool.q_proj.weight, std=std)
|
||||||
|
nn.init.normal_(self.attnpool.k_proj.weight, std=std)
|
||||||
|
nn.init.normal_(self.attnpool.v_proj.weight, std=std)
|
||||||
|
nn.init.normal_(self.attnpool.c_proj.weight, std=std)
|
||||||
|
|
||||||
|
for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
|
||||||
|
for name, param in resnet_block.named_parameters():
|
||||||
|
if name.endswith("bn3.weight"):
|
||||||
|
nn.init.zeros_(param)
|
||||||
|
|
||||||
|
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
||||||
|
assert unlocked_groups == 0, 'partial locking not currently supported for this model'
|
||||||
|
for param in self.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
if freeze_bn_stats:
|
||||||
|
freeze_batch_norm_2d(self)
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def set_grad_checkpointing(self, enable=True):
|
||||||
|
# FIXME support for non-transformer
|
||||||
|
pass
|
||||||
|
|
||||||
|
def stem(self, x):
|
||||||
|
x = self.act1(self.bn1(self.conv1(x)))
|
||||||
|
x = self.act2(self.bn2(self.conv2(x)))
|
||||||
|
x = self.act3(self.bn3(self.conv3(x)))
|
||||||
|
x = self.avgpool(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.stem(x)
|
||||||
|
x = self.layer1(x)
|
||||||
|
x = self.layer2(x)
|
||||||
|
x = self.layer3(x)
|
||||||
|
x = self.layer4(x)
|
||||||
|
x = self.attnpool(x)
|
||||||
|
|
||||||
|
return x
|
||||||
144
diffsynth/extensions/ImageQualityMetric/open_clip/openai.py
Normal file
144
diffsynth/extensions/ImageQualityMetric/open_clip/openai.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
""" OpenAI pretrained model functions
|
||||||
|
|
||||||
|
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
|
||||||
|
from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
|
||||||
|
|
||||||
|
__all__ = ["list_openai_models", "load_openai_model"]
|
||||||
|
|
||||||
|
|
||||||
|
def list_openai_models() -> List[str]:
|
||||||
|
"""Returns the names of available CLIP models"""
|
||||||
|
return list_pretrained_models_by_tag('openai')
|
||||||
|
|
||||||
|
|
||||||
|
def load_openai_model(
|
||||||
|
name: str,
|
||||||
|
precision: Optional[str] = None,
|
||||||
|
device: Optional[Union[str, torch.device]] = None,
|
||||||
|
jit: bool = True,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""Load a CLIP model
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
|
||||||
|
precision: str
|
||||||
|
Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
|
||||||
|
device : Union[str, torch.device]
|
||||||
|
The device to put the loaded model
|
||||||
|
jit : bool
|
||||||
|
Whether to load the optimized JIT model (default) or more hackable non-JIT model.
|
||||||
|
cache_dir : Optional[str]
|
||||||
|
The directory to cache the downloaded model weights
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
model : torch.nn.Module
|
||||||
|
The CLIP model
|
||||||
|
preprocess : Callable[[PIL.Image], torch.Tensor]
|
||||||
|
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
||||||
|
"""
|
||||||
|
if device is None:
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
if precision is None:
|
||||||
|
precision = 'fp32' if device == 'cpu' else 'fp16'
|
||||||
|
|
||||||
|
if get_pretrained_url(name, 'openai'):
|
||||||
|
model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
|
||||||
|
elif os.path.isfile(name):
|
||||||
|
model_path = name
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# loading JIT archive
|
||||||
|
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
|
||||||
|
state_dict = None
|
||||||
|
except RuntimeError:
|
||||||
|
# loading saved state dict
|
||||||
|
if jit:
|
||||||
|
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
|
||||||
|
jit = False
|
||||||
|
state_dict = torch.load(model_path, map_location="cpu")
|
||||||
|
|
||||||
|
if not jit:
|
||||||
|
# Build a non-jit model from the OpenAI jitted model state dict
|
||||||
|
cast_dtype = get_cast_dtype(precision)
|
||||||
|
try:
|
||||||
|
model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
|
||||||
|
except KeyError:
|
||||||
|
sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
|
||||||
|
model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
|
||||||
|
|
||||||
|
# model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
|
||||||
|
model = model.to(device)
|
||||||
|
if precision.startswith('amp') or precision == 'fp32':
|
||||||
|
model.float()
|
||||||
|
elif precision == 'bf16':
|
||||||
|
convert_weights_to_lp(model, dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
# patch the device names
|
||||||
|
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
|
||||||
|
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
|
||||||
|
|
||||||
|
def patch_device(module):
|
||||||
|
try:
|
||||||
|
graphs = [module.graph] if hasattr(module, "graph") else []
|
||||||
|
except RuntimeError:
|
||||||
|
graphs = []
|
||||||
|
|
||||||
|
if hasattr(module, "forward1"):
|
||||||
|
graphs.append(module.forward1.graph)
|
||||||
|
|
||||||
|
for graph in graphs:
|
||||||
|
for node in graph.findAllNodes("prim::Constant"):
|
||||||
|
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
|
||||||
|
node.copyAttributes(device_node)
|
||||||
|
|
||||||
|
model.apply(patch_device)
|
||||||
|
patch_device(model.encode_image)
|
||||||
|
patch_device(model.encode_text)
|
||||||
|
|
||||||
|
# patch dtype to float32 (typically for CPU)
|
||||||
|
if precision == 'fp32':
|
||||||
|
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
|
||||||
|
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
||||||
|
float_node = float_input.node()
|
||||||
|
|
||||||
|
def patch_float(module):
|
||||||
|
try:
|
||||||
|
graphs = [module.graph] if hasattr(module, "graph") else []
|
||||||
|
except RuntimeError:
|
||||||
|
graphs = []
|
||||||
|
|
||||||
|
if hasattr(module, "forward1"):
|
||||||
|
graphs.append(module.forward1.graph)
|
||||||
|
|
||||||
|
for graph in graphs:
|
||||||
|
for node in graph.findAllNodes("aten::to"):
|
||||||
|
inputs = list(node.inputs())
|
||||||
|
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
|
||||||
|
if inputs[i].node()["value"] == 5:
|
||||||
|
inputs[i].node().copyAttributes(float_node)
|
||||||
|
|
||||||
|
model.apply(patch_float)
|
||||||
|
patch_float(model.encode_image)
|
||||||
|
patch_float(model.encode_text)
|
||||||
|
model.float()
|
||||||
|
|
||||||
|
# ensure image_size attr available at consistent location for both jit and non-jit
|
||||||
|
model.visual.image_size = model.input_resolution.item()
|
||||||
|
return model
|
||||||
376
diffsynth/extensions/ImageQualityMetric/open_clip/pretrained.py
Normal file
376
diffsynth/extensions/ImageQualityMetric/open_clip/pretrained.py
Normal file
@@ -0,0 +1,376 @@
|
|||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
import urllib
|
||||||
|
import warnings
|
||||||
|
from functools import partial
|
||||||
|
from typing import Dict, Union
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from .version import __version__
|
||||||
|
|
||||||
|
try:
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__)
|
||||||
|
_has_hf_hub = True
|
||||||
|
except ImportError:
|
||||||
|
hf_hub_download = None
|
||||||
|
_has_hf_hub = False
|
||||||
|
|
||||||
|
|
||||||
|
def _pcfg(url='', hf_hub='', mean=None, std=None):
|
||||||
|
return dict(
|
||||||
|
url=url,
|
||||||
|
hf_hub=hf_hub,
|
||||||
|
mean=mean,
|
||||||
|
std=std,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_RN50 = dict(
|
||||||
|
openai=_pcfg(
|
||||||
|
"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
|
||||||
|
yfcc15m=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
|
||||||
|
cc12m=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
|
||||||
|
)
|
||||||
|
|
||||||
|
_RN50_quickgelu = dict(
|
||||||
|
openai=_pcfg(
|
||||||
|
"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
|
||||||
|
yfcc15m=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
|
||||||
|
cc12m=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
|
||||||
|
)
|
||||||
|
|
||||||
|
_RN101 = dict(
|
||||||
|
openai=_pcfg(
|
||||||
|
"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
|
||||||
|
yfcc15m=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
|
||||||
|
)
|
||||||
|
|
||||||
|
_RN101_quickgelu = dict(
|
||||||
|
openai=_pcfg(
|
||||||
|
"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
|
||||||
|
yfcc15m=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
|
||||||
|
)
|
||||||
|
|
||||||
|
_RN50x4 = dict(
|
||||||
|
openai=_pcfg(
|
||||||
|
"https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"),
|
||||||
|
)
|
||||||
|
|
||||||
|
_RN50x16 = dict(
|
||||||
|
openai=_pcfg(
|
||||||
|
"https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"),
|
||||||
|
)
|
||||||
|
|
||||||
|
_RN50x64 = dict(
|
||||||
|
openai=_pcfg(
|
||||||
|
"https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"),
|
||||||
|
)
|
||||||
|
|
||||||
|
_VITB32 = dict(
|
||||||
|
openai=_pcfg(
|
||||||
|
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
|
||||||
|
laion400m_e31=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
|
||||||
|
laion400m_e32=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
|
||||||
|
laion2b_e16=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
|
||||||
|
laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/')
|
||||||
|
)
|
||||||
|
|
||||||
|
_VITB32_quickgelu = dict(
|
||||||
|
openai=_pcfg(
|
||||||
|
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
|
||||||
|
laion400m_e31=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
|
||||||
|
laion400m_e32=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
|
||||||
|
)
|
||||||
|
|
||||||
|
_VITB16 = dict(
|
||||||
|
openai=_pcfg(
|
||||||
|
"https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
|
||||||
|
laion400m_e31=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
|
||||||
|
laion400m_e32=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
|
||||||
|
# laion400m_32k=_pcfg(
|
||||||
|
# url="",
|
||||||
|
# mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
||||||
|
# laion400m_64k=_pcfg(
|
||||||
|
# url="",
|
||||||
|
# mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
||||||
|
laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
|
||||||
|
)
|
||||||
|
|
||||||
|
_VITB16_PLUS_240 = dict(
|
||||||
|
laion400m_e31=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
|
||||||
|
laion400m_e32=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
|
||||||
|
)
|
||||||
|
|
||||||
|
_VITL14 = dict(
|
||||||
|
openai=_pcfg(
|
||||||
|
"https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
|
||||||
|
laion400m_e31=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
|
||||||
|
laion400m_e32=_pcfg(
|
||||||
|
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
|
||||||
|
laion2b_s32b_b82k=_pcfg(
|
||||||
|
hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
|
||||||
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
||||||
|
)
|
||||||
|
|
||||||
|
_VITL14_336 = dict(
|
||||||
|
openai=_pcfg(
|
||||||
|
"https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
|
||||||
|
)
|
||||||
|
|
||||||
|
_VITH14 = dict(
|
||||||
|
laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
|
||||||
|
)
|
||||||
|
|
||||||
|
_VITg14 = dict(
|
||||||
|
laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
|
||||||
|
laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),
|
||||||
|
)
|
||||||
|
|
||||||
|
_VITbigG14 = dict(
|
||||||
|
laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
|
||||||
|
)
|
||||||
|
|
||||||
|
_robertaViTB32 = dict(
|
||||||
|
laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'),
|
||||||
|
)
|
||||||
|
|
||||||
|
_xlmRobertaBaseViTB32 = dict(
|
||||||
|
laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'),
|
||||||
|
)
|
||||||
|
|
||||||
|
_xlmRobertaLargeFrozenViTH14 = dict(
|
||||||
|
frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'),
|
||||||
|
)
|
||||||
|
|
||||||
|
_convnext_base = dict(
|
||||||
|
laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'),
|
||||||
|
)
|
||||||
|
|
||||||
|
_convnext_base_w = dict(
|
||||||
|
laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'),
|
||||||
|
laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'),
|
||||||
|
laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'),
|
||||||
|
)
|
||||||
|
|
||||||
|
_convnext_base_w_320 = dict(
|
||||||
|
laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'),
|
||||||
|
laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'),
|
||||||
|
)
|
||||||
|
|
||||||
|
_convnext_large_d = dict(
|
||||||
|
laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'),
|
||||||
|
)
|
||||||
|
|
||||||
|
_convnext_large_d_320 = dict(
|
||||||
|
laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'),
|
||||||
|
laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'),
|
||||||
|
)
|
||||||
|
|
||||||
|
_convnext_xxlarge = dict(
|
||||||
|
laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'),
|
||||||
|
laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'),
|
||||||
|
laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'),
|
||||||
|
)
|
||||||
|
|
||||||
|
_coca_VITB32 = dict(
|
||||||
|
laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'),
|
||||||
|
mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/')
|
||||||
|
)
|
||||||
|
|
||||||
|
_coca_VITL14 = dict(
|
||||||
|
laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'),
|
||||||
|
mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/')
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_PRETRAINED = {
|
||||||
|
"RN50": _RN50,
|
||||||
|
"RN50-quickgelu": _RN50_quickgelu,
|
||||||
|
"RN101": _RN101,
|
||||||
|
"RN101-quickgelu": _RN101_quickgelu,
|
||||||
|
"RN50x4": _RN50x4,
|
||||||
|
"RN50x16": _RN50x16,
|
||||||
|
"RN50x64": _RN50x64,
|
||||||
|
"ViT-B-32": _VITB32,
|
||||||
|
"ViT-B-32-quickgelu": _VITB32_quickgelu,
|
||||||
|
"ViT-B-16": _VITB16,
|
||||||
|
"ViT-B-16-plus-240": _VITB16_PLUS_240,
|
||||||
|
"ViT-L-14": _VITL14,
|
||||||
|
"ViT-L-14-336": _VITL14_336,
|
||||||
|
"ViT-H-14": _VITH14,
|
||||||
|
"ViT-g-14": _VITg14,
|
||||||
|
"ViT-bigG-14": _VITbigG14,
|
||||||
|
"roberta-ViT-B-32": _robertaViTB32,
|
||||||
|
"xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32,
|
||||||
|
"xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14,
|
||||||
|
"convnext_base": _convnext_base,
|
||||||
|
"convnext_base_w": _convnext_base_w,
|
||||||
|
"convnext_base_w_320": _convnext_base_w_320,
|
||||||
|
"convnext_large_d": _convnext_large_d,
|
||||||
|
"convnext_large_d_320": _convnext_large_d_320,
|
||||||
|
"convnext_xxlarge": _convnext_xxlarge,
|
||||||
|
"coca_ViT-B-32": _coca_VITB32,
|
||||||
|
"coca_ViT-L-14": _coca_VITL14,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _clean_tag(tag: str):
|
||||||
|
# normalize pretrained tags
|
||||||
|
return tag.lower().replace('-', '_')
|
||||||
|
|
||||||
|
|
||||||
|
def list_pretrained(as_str: bool = False):
|
||||||
|
""" returns list of pretrained models
|
||||||
|
Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
|
||||||
|
"""
|
||||||
|
return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
|
||||||
|
|
||||||
|
|
||||||
|
def list_pretrained_models_by_tag(tag: str):
|
||||||
|
""" return all models having the specified pretrain tag """
|
||||||
|
models = []
|
||||||
|
tag = _clean_tag(tag)
|
||||||
|
for k in _PRETRAINED.keys():
|
||||||
|
if tag in _PRETRAINED[k]:
|
||||||
|
models.append(k)
|
||||||
|
return models
|
||||||
|
|
||||||
|
|
||||||
|
def list_pretrained_tags_by_model(model: str):
|
||||||
|
""" return all pretrain tags for the specified model architecture """
|
||||||
|
tags = []
|
||||||
|
if model in _PRETRAINED:
|
||||||
|
tags.extend(_PRETRAINED[model].keys())
|
||||||
|
return tags
|
||||||
|
|
||||||
|
|
||||||
|
def is_pretrained_cfg(model: str, tag: str):
|
||||||
|
if model not in _PRETRAINED:
|
||||||
|
return False
|
||||||
|
return _clean_tag(tag) in _PRETRAINED[model]
|
||||||
|
|
||||||
|
|
||||||
|
def get_pretrained_cfg(model: str, tag: str):
|
||||||
|
if model not in _PRETRAINED:
|
||||||
|
return {}
|
||||||
|
model_pretrained = _PRETRAINED[model]
|
||||||
|
return model_pretrained.get(_clean_tag(tag), {})
|
||||||
|
|
||||||
|
|
||||||
|
def get_pretrained_url(model: str, tag: str):
|
||||||
|
cfg = get_pretrained_cfg(model, _clean_tag(tag))
|
||||||
|
return cfg.get('url', '')
|
||||||
|
|
||||||
|
|
||||||
|
def download_pretrained_from_url(
|
||||||
|
url: str,
|
||||||
|
cache_dir: Union[str, None] = None,
|
||||||
|
):
|
||||||
|
if not cache_dir:
|
||||||
|
cache_dir = os.path.expanduser("~/.cache/clip")
|
||||||
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
|
filename = os.path.basename(url)
|
||||||
|
|
||||||
|
if 'openaipublic' in url:
|
||||||
|
expected_sha256 = url.split("/")[-2]
|
||||||
|
elif 'mlfoundations' in url:
|
||||||
|
expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
|
||||||
|
else:
|
||||||
|
expected_sha256 = ''
|
||||||
|
|
||||||
|
download_target = os.path.join(cache_dir, filename)
|
||||||
|
|
||||||
|
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
||||||
|
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
||||||
|
|
||||||
|
if os.path.isfile(download_target):
|
||||||
|
if expected_sha256:
|
||||||
|
if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
|
||||||
|
return download_target
|
||||||
|
else:
|
||||||
|
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
||||||
|
else:
|
||||||
|
return download_target
|
||||||
|
|
||||||
|
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
||||||
|
with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
|
||||||
|
while True:
|
||||||
|
buffer = source.read(8192)
|
||||||
|
if not buffer:
|
||||||
|
break
|
||||||
|
|
||||||
|
output.write(buffer)
|
||||||
|
loop.update(len(buffer))
|
||||||
|
|
||||||
|
if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
|
||||||
|
raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
|
||||||
|
|
||||||
|
return download_target
|
||||||
|
|
||||||
|
|
||||||
|
def has_hf_hub(necessary=False):
|
||||||
|
if not _has_hf_hub and necessary:
|
||||||
|
# if no HF Hub module installed, and it is necessary to continue, raise error
|
||||||
|
raise RuntimeError(
|
||||||
|
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
|
||||||
|
return _has_hf_hub
|
||||||
|
|
||||||
|
|
||||||
|
def download_pretrained_from_hf(
|
||||||
|
model_id: str,
|
||||||
|
filename: str = 'open_clip_pytorch_model.bin',
|
||||||
|
revision=None,
|
||||||
|
cache_dir: Union[str, None] = None,
|
||||||
|
):
|
||||||
|
has_hf_hub(True)
|
||||||
|
cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
|
||||||
|
return cached_file
|
||||||
|
|
||||||
|
|
||||||
|
def download_pretrained(
|
||||||
|
cfg: Dict,
|
||||||
|
force_hf_hub: bool = False,
|
||||||
|
cache_dir: Union[str, None] = None,
|
||||||
|
):
|
||||||
|
target = ''
|
||||||
|
if not cfg:
|
||||||
|
return target
|
||||||
|
|
||||||
|
download_url = cfg.get('url', '')
|
||||||
|
download_hf_hub = cfg.get('hf_hub', '')
|
||||||
|
if download_hf_hub and force_hf_hub:
|
||||||
|
# use HF hub even if url exists
|
||||||
|
download_url = ''
|
||||||
|
|
||||||
|
if download_url:
|
||||||
|
target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
|
||||||
|
elif download_hf_hub:
|
||||||
|
has_hf_hub(True)
|
||||||
|
# we assume the hf_hub entries in pretrained config combine model_id + filename in
|
||||||
|
# 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
|
||||||
|
# use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
|
||||||
|
model_id, filename = os.path.split(download_hf_hub)
|
||||||
|
if filename:
|
||||||
|
target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
|
||||||
|
else:
|
||||||
|
target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
|
||||||
|
|
||||||
|
return target
|
||||||
@@ -0,0 +1,243 @@
|
|||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from tempfile import TemporaryDirectory
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
try:
|
||||||
|
from huggingface_hub import (
|
||||||
|
create_repo,
|
||||||
|
get_hf_file_metadata,
|
||||||
|
hf_hub_download,
|
||||||
|
hf_hub_url,
|
||||||
|
repo_type_and_id_from_hf_id,
|
||||||
|
upload_folder,
|
||||||
|
)
|
||||||
|
from huggingface_hub.utils import EntryNotFoundError
|
||||||
|
_has_hf_hub = True
|
||||||
|
except ImportError:
|
||||||
|
_has_hf_hub = False
|
||||||
|
|
||||||
|
from .factory import create_model_from_pretrained, get_model_config, get_tokenizer
|
||||||
|
from .tokenizer import HFTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def save_config_for_hf(
|
||||||
|
model,
|
||||||
|
config_path: str,
|
||||||
|
model_config: Optional[dict]
|
||||||
|
):
|
||||||
|
preprocess_cfg = {
|
||||||
|
'mean': model.visual.image_mean,
|
||||||
|
'std': model.visual.image_std,
|
||||||
|
}
|
||||||
|
hf_config = {
|
||||||
|
'model_cfg': model_config,
|
||||||
|
'preprocess_cfg': preprocess_cfg,
|
||||||
|
}
|
||||||
|
|
||||||
|
with config_path.open('w') as f:
|
||||||
|
json.dump(hf_config, f, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
def save_for_hf(
|
||||||
|
model,
|
||||||
|
tokenizer: HFTokenizer,
|
||||||
|
model_config: dict,
|
||||||
|
save_directory: str,
|
||||||
|
weights_filename='open_clip_pytorch_model.bin',
|
||||||
|
config_filename='open_clip_config.json',
|
||||||
|
):
|
||||||
|
save_directory = Path(save_directory)
|
||||||
|
save_directory.mkdir(exist_ok=True, parents=True)
|
||||||
|
|
||||||
|
weights_path = save_directory / weights_filename
|
||||||
|
torch.save(model.state_dict(), weights_path)
|
||||||
|
|
||||||
|
tokenizer.save_pretrained(save_directory)
|
||||||
|
|
||||||
|
config_path = save_directory / config_filename
|
||||||
|
save_config_for_hf(model, config_path, model_config=model_config)
|
||||||
|
|
||||||
|
|
||||||
|
def push_to_hf_hub(
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
model_config: Optional[dict],
|
||||||
|
repo_id: str,
|
||||||
|
commit_message: str = 'Add model',
|
||||||
|
token: Optional[str] = None,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
private: bool = False,
|
||||||
|
create_pr: bool = False,
|
||||||
|
model_card: Optional[dict] = None,
|
||||||
|
):
|
||||||
|
if not isinstance(tokenizer, HFTokenizer):
|
||||||
|
# default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14
|
||||||
|
tokenizer = HFTokenizer('openai/clip-vit-large-patch14')
|
||||||
|
|
||||||
|
# Create repo if it doesn't exist yet
|
||||||
|
repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
|
||||||
|
|
||||||
|
# Infer complete repo_id from repo_url
|
||||||
|
# Can be different from the input `repo_id` if repo_owner was implicit
|
||||||
|
_, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
|
||||||
|
repo_id = f"{repo_owner}/{repo_name}"
|
||||||
|
|
||||||
|
# Check if README file already exist in repo
|
||||||
|
try:
|
||||||
|
get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
|
||||||
|
has_readme = True
|
||||||
|
except EntryNotFoundError:
|
||||||
|
has_readme = False
|
||||||
|
|
||||||
|
# Dump model and push to Hub
|
||||||
|
with TemporaryDirectory() as tmpdir:
|
||||||
|
# Save model weights and config.
|
||||||
|
save_for_hf(
|
||||||
|
model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
model_config=model_config,
|
||||||
|
save_directory=tmpdir,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add readme if it does not exist
|
||||||
|
if not has_readme:
|
||||||
|
model_card = model_card or {}
|
||||||
|
model_name = repo_id.split('/')[-1]
|
||||||
|
readme_path = Path(tmpdir) / "README.md"
|
||||||
|
readme_text = generate_readme(model_card, model_name)
|
||||||
|
readme_path.write_text(readme_text)
|
||||||
|
|
||||||
|
# Upload model and return
|
||||||
|
return upload_folder(
|
||||||
|
repo_id=repo_id,
|
||||||
|
folder_path=tmpdir,
|
||||||
|
revision=revision,
|
||||||
|
create_pr=create_pr,
|
||||||
|
commit_message=commit_message,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def push_pretrained_to_hf_hub(
|
||||||
|
model_name,
|
||||||
|
pretrained: str,
|
||||||
|
repo_id: str,
|
||||||
|
image_mean: Optional[Tuple[float, ...]] = None,
|
||||||
|
image_std: Optional[Tuple[float, ...]] = None,
|
||||||
|
commit_message: str = 'Add model',
|
||||||
|
token: Optional[str] = None,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
private: bool = False,
|
||||||
|
create_pr: bool = False,
|
||||||
|
model_card: Optional[dict] = None,
|
||||||
|
):
|
||||||
|
model, preprocess_eval = create_model_from_pretrained(
|
||||||
|
model_name,
|
||||||
|
pretrained=pretrained,
|
||||||
|
image_mean=image_mean,
|
||||||
|
image_std=image_std,
|
||||||
|
)
|
||||||
|
|
||||||
|
model_config = get_model_config(model_name)
|
||||||
|
assert model_config
|
||||||
|
|
||||||
|
tokenizer = get_tokenizer(model_name)
|
||||||
|
|
||||||
|
push_to_hf_hub(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
model_config=model_config,
|
||||||
|
repo_id=repo_id,
|
||||||
|
commit_message=commit_message,
|
||||||
|
token=token,
|
||||||
|
revision=revision,
|
||||||
|
private=private,
|
||||||
|
create_pr=create_pr,
|
||||||
|
model_card=model_card,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_readme(model_card: dict, model_name: str):
|
||||||
|
readme_text = "---\n"
|
||||||
|
readme_text += "tags:\n- zero-shot-image-classification\n- clip\n"
|
||||||
|
readme_text += "library_tag: open_clip\n"
|
||||||
|
readme_text += f"license: {model_card.get('license', 'mit')}\n"
|
||||||
|
if 'details' in model_card and 'Dataset' in model_card['details']:
|
||||||
|
readme_text += 'datasets:\n'
|
||||||
|
readme_text += f"- {model_card['details']['Dataset'].lower()}\n"
|
||||||
|
readme_text += "---\n"
|
||||||
|
readme_text += f"# Model card for {model_name}\n"
|
||||||
|
if 'description' in model_card:
|
||||||
|
readme_text += f"\n{model_card['description']}\n"
|
||||||
|
if 'details' in model_card:
|
||||||
|
readme_text += f"\n## Model Details\n"
|
||||||
|
for k, v in model_card['details'].items():
|
||||||
|
if isinstance(v, (list, tuple)):
|
||||||
|
readme_text += f"- **{k}:**\n"
|
||||||
|
for vi in v:
|
||||||
|
readme_text += f" - {vi}\n"
|
||||||
|
elif isinstance(v, dict):
|
||||||
|
readme_text += f"- **{k}:**\n"
|
||||||
|
for ki, vi in v.items():
|
||||||
|
readme_text += f" - {ki}: {vi}\n"
|
||||||
|
else:
|
||||||
|
readme_text += f"- **{k}:** {v}\n"
|
||||||
|
if 'usage' in model_card:
|
||||||
|
readme_text += f"\n## Model Usage\n"
|
||||||
|
readme_text += model_card['usage']
|
||||||
|
readme_text += '\n'
|
||||||
|
|
||||||
|
if 'comparison' in model_card:
|
||||||
|
readme_text += f"\n## Model Comparison\n"
|
||||||
|
readme_text += model_card['comparison']
|
||||||
|
readme_text += '\n'
|
||||||
|
|
||||||
|
if 'citation' in model_card:
|
||||||
|
readme_text += f"\n## Citation\n"
|
||||||
|
if not isinstance(model_card['citation'], (list, tuple)):
|
||||||
|
citations = [model_card['citation']]
|
||||||
|
else:
|
||||||
|
citations = model_card['citation']
|
||||||
|
for c in citations:
|
||||||
|
readme_text += f"```bibtex\n{c}\n```\n"
|
||||||
|
|
||||||
|
return readme_text
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Push to Hugging Face Hub")
|
||||||
|
parser.add_argument(
|
||||||
|
"--model", type=str, help="Name of the model to use.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pretrained", type=str,
|
||||||
|
help="Use a pretrained CLIP model weights with the specified tag or file path.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--repo-id", type=str,
|
||||||
|
help="Destination HF Hub repo-id ie 'organization/model_id'.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--image-mean', type=float, nargs='+', default=None, metavar='MEAN',
|
||||||
|
help='Override default image mean value of dataset')
|
||||||
|
parser.add_argument(
|
||||||
|
'--image-std', type=float, nargs='+', default=None, metavar='STD',
|
||||||
|
help='Override default image std deviation of of dataset')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}')
|
||||||
|
|
||||||
|
# FIXME add support to pass model_card json / template from file via cmd line
|
||||||
|
|
||||||
|
push_pretrained_to_hf_hub(
|
||||||
|
args.model,
|
||||||
|
args.pretrained,
|
||||||
|
args.repo_id,
|
||||||
|
image_mean=args.image_mean, # override image mean/std if trained w/ non defaults
|
||||||
|
image_std=args.image_std,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f'{args.model} saved.')
|
||||||
127
diffsynth/extensions/ImageQualityMetric/open_clip/timm_model.py
Normal file
127
diffsynth/extensions/ImageQualityMetric/open_clip/timm_model.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
""" timm model adapter
|
||||||
|
|
||||||
|
Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
try:
|
||||||
|
import timm
|
||||||
|
from timm.models.layers import Mlp, to_2tuple
|
||||||
|
try:
|
||||||
|
# old timm imports < 0.8.1
|
||||||
|
from timm.models.layers.attention_pool2d import RotAttentionPool2d
|
||||||
|
from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
|
||||||
|
except ImportError:
|
||||||
|
# new timm imports >= 0.8.1
|
||||||
|
from timm.layers import RotAttentionPool2d
|
||||||
|
from timm.layers import AttentionPool2d as AbsAttentionPool2d
|
||||||
|
except ImportError:
|
||||||
|
timm = None
|
||||||
|
|
||||||
|
from .utils import freeze_batch_norm_2d
|
||||||
|
|
||||||
|
|
||||||
|
class TimmModel(nn.Module):
|
||||||
|
""" timm model adapter
|
||||||
|
# FIXME this adapter is a work in progress, may change in ways that break weight compat
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name,
|
||||||
|
embed_dim,
|
||||||
|
image_size=224,
|
||||||
|
pool='avg',
|
||||||
|
proj='linear',
|
||||||
|
proj_bias=False,
|
||||||
|
drop=0.,
|
||||||
|
drop_path=None,
|
||||||
|
pretrained=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
if timm is None:
|
||||||
|
raise RuntimeError("Please `pip install timm` to use timm models.")
|
||||||
|
|
||||||
|
self.image_size = to_2tuple(image_size)
|
||||||
|
timm_kwargs = {}
|
||||||
|
if drop_path is not None:
|
||||||
|
timm_kwargs['drop_path_rate'] = drop_path
|
||||||
|
self.trunk = timm.create_model(model_name, pretrained=pretrained, **timm_kwargs)
|
||||||
|
feat_size = self.trunk.default_cfg.get('pool_size', None)
|
||||||
|
feature_ndim = 1 if not feat_size else 2
|
||||||
|
if pool in ('abs_attn', 'rot_attn'):
|
||||||
|
assert feature_ndim == 2
|
||||||
|
# if attn pooling used, remove both classifier and default pool
|
||||||
|
self.trunk.reset_classifier(0, global_pool='')
|
||||||
|
else:
|
||||||
|
# reset global pool if pool config set, otherwise leave as network default
|
||||||
|
reset_kwargs = dict(global_pool=pool) if pool else {}
|
||||||
|
self.trunk.reset_classifier(0, **reset_kwargs)
|
||||||
|
prev_chs = self.trunk.num_features
|
||||||
|
|
||||||
|
head_layers = OrderedDict()
|
||||||
|
if pool == 'abs_attn':
|
||||||
|
head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
|
||||||
|
prev_chs = embed_dim
|
||||||
|
elif pool == 'rot_attn':
|
||||||
|
head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
|
||||||
|
prev_chs = embed_dim
|
||||||
|
else:
|
||||||
|
assert proj, 'projection layer needed if non-attention pooling is used.'
|
||||||
|
|
||||||
|
# NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
|
||||||
|
if proj == 'linear':
|
||||||
|
head_layers['drop'] = nn.Dropout(drop)
|
||||||
|
head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
|
||||||
|
elif proj == 'mlp':
|
||||||
|
head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))
|
||||||
|
|
||||||
|
self.head = nn.Sequential(head_layers)
|
||||||
|
|
||||||
|
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
||||||
|
""" lock modules
|
||||||
|
Args:
|
||||||
|
unlocked_groups (int): leave last n layer groups unlocked (default: 0)
|
||||||
|
"""
|
||||||
|
if not unlocked_groups:
|
||||||
|
# lock full model
|
||||||
|
for param in self.trunk.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
if freeze_bn_stats:
|
||||||
|
freeze_batch_norm_2d(self.trunk)
|
||||||
|
else:
|
||||||
|
# NOTE: partial freeze requires latest timm (master) branch and is subject to change
|
||||||
|
try:
|
||||||
|
# FIXME import here until API stable and in an official release
|
||||||
|
from timm.models.helpers import group_parameters, group_modules
|
||||||
|
except ImportError:
|
||||||
|
raise RuntimeError(
|
||||||
|
'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
|
||||||
|
matcher = self.trunk.group_matcher()
|
||||||
|
gparams = group_parameters(self.trunk, matcher)
|
||||||
|
max_layer_id = max(gparams.keys())
|
||||||
|
max_layer_id = max_layer_id - unlocked_groups
|
||||||
|
for group_idx in range(max_layer_id + 1):
|
||||||
|
group = gparams[group_idx]
|
||||||
|
for param in group:
|
||||||
|
self.trunk.get_parameter(param).requires_grad = False
|
||||||
|
if freeze_bn_stats:
|
||||||
|
gmodules = group_modules(self.trunk, matcher, reverse=True)
|
||||||
|
gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
|
||||||
|
freeze_batch_norm_2d(self.trunk, gmodules)
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def set_grad_checkpointing(self, enable=True):
|
||||||
|
try:
|
||||||
|
self.trunk.set_grad_checkpointing(enable)
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.trunk(x)
|
||||||
|
x = self.head(x)
|
||||||
|
return x
|
||||||
211
diffsynth/extensions/ImageQualityMetric/open_clip/tokenizer.py
Normal file
211
diffsynth/extensions/ImageQualityMetric/open_clip/tokenizer.py
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
""" CLIP tokenizer
|
||||||
|
|
||||||
|
Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
||||||
|
"""
|
||||||
|
import gzip
|
||||||
|
import html
|
||||||
|
import os
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Union, List
|
||||||
|
|
||||||
|
import ftfy
|
||||||
|
import regex as re
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# https://stackoverflow.com/q/62691279
|
||||||
|
import os
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def default_bpe():
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
project_root = os.path.abspath(os.path.join(current_dir, '../../../../'))
|
||||||
|
quality_metric_path = os.path.join(project_root, 'models', 'QualityMetric')
|
||||||
|
return os.path.join(quality_metric_path, "bpe_simple_vocab_16e6.txt.gz")
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def bytes_to_unicode():
|
||||||
|
"""
|
||||||
|
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
||||||
|
The reversible bpe codes work on unicode strings.
|
||||||
|
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
||||||
|
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
||||||
|
This is a significant percentage of your normal, say, 32K bpe vocab.
|
||||||
|
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||||
|
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||||
|
"""
|
||||||
|
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
||||||
|
cs = bs[:]
|
||||||
|
n = 0
|
||||||
|
for b in range(2**8):
|
||||||
|
if b not in bs:
|
||||||
|
bs.append(b)
|
||||||
|
cs.append(2**8+n)
|
||||||
|
n += 1
|
||||||
|
cs = [chr(n) for n in cs]
|
||||||
|
return dict(zip(bs, cs))
|
||||||
|
|
||||||
|
|
||||||
|
def get_pairs(word):
|
||||||
|
"""Return set of symbol pairs in a word.
|
||||||
|
Word is represented as tuple of symbols (symbols being variable-length strings).
|
||||||
|
"""
|
||||||
|
pairs = set()
|
||||||
|
prev_char = word[0]
|
||||||
|
for char in word[1:]:
|
||||||
|
pairs.add((prev_char, char))
|
||||||
|
prev_char = char
|
||||||
|
return pairs
|
||||||
|
|
||||||
|
|
||||||
|
def basic_clean(text):
|
||||||
|
text = ftfy.fix_text(text)
|
||||||
|
text = html.unescape(html.unescape(text))
|
||||||
|
return text.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def whitespace_clean(text):
|
||||||
|
text = re.sub(r'\s+', ' ', text)
|
||||||
|
text = text.strip()
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleTokenizer(object):
|
||||||
|
def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
|
||||||
|
self.byte_encoder = bytes_to_unicode()
|
||||||
|
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
||||||
|
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
|
||||||
|
merges = merges[1:49152-256-2+1]
|
||||||
|
merges = [tuple(merge.split()) for merge in merges]
|
||||||
|
vocab = list(bytes_to_unicode().values())
|
||||||
|
vocab = vocab + [v+'</w>' for v in vocab]
|
||||||
|
for merge in merges:
|
||||||
|
vocab.append(''.join(merge))
|
||||||
|
if not special_tokens:
|
||||||
|
special_tokens = ['<start_of_text>', '<end_of_text>']
|
||||||
|
else:
|
||||||
|
special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
|
||||||
|
vocab.extend(special_tokens)
|
||||||
|
self.encoder = dict(zip(vocab, range(len(vocab))))
|
||||||
|
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||||
|
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||||
|
self.cache = {t:t for t in special_tokens}
|
||||||
|
special = "|".join(special_tokens)
|
||||||
|
self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
|
||||||
|
|
||||||
|
self.vocab_size = len(self.encoder)
|
||||||
|
self.all_special_ids = [self.encoder[t] for t in special_tokens]
|
||||||
|
|
||||||
|
def bpe(self, token):
|
||||||
|
if token in self.cache:
|
||||||
|
return self.cache[token]
|
||||||
|
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
||||||
|
pairs = get_pairs(word)
|
||||||
|
|
||||||
|
if not pairs:
|
||||||
|
return token+'</w>'
|
||||||
|
|
||||||
|
while True:
|
||||||
|
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
||||||
|
if bigram not in self.bpe_ranks:
|
||||||
|
break
|
||||||
|
first, second = bigram
|
||||||
|
new_word = []
|
||||||
|
i = 0
|
||||||
|
while i < len(word):
|
||||||
|
try:
|
||||||
|
j = word.index(first, i)
|
||||||
|
new_word.extend(word[i:j])
|
||||||
|
i = j
|
||||||
|
except:
|
||||||
|
new_word.extend(word[i:])
|
||||||
|
break
|
||||||
|
|
||||||
|
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
||||||
|
new_word.append(first+second)
|
||||||
|
i += 2
|
||||||
|
else:
|
||||||
|
new_word.append(word[i])
|
||||||
|
i += 1
|
||||||
|
new_word = tuple(new_word)
|
||||||
|
word = new_word
|
||||||
|
if len(word) == 1:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
pairs = get_pairs(word)
|
||||||
|
word = ' '.join(word)
|
||||||
|
self.cache[token] = word
|
||||||
|
return word
|
||||||
|
|
||||||
|
def encode(self, text):
|
||||||
|
bpe_tokens = []
|
||||||
|
text = whitespace_clean(basic_clean(text)).lower()
|
||||||
|
for token in re.findall(self.pat, text):
|
||||||
|
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
||||||
|
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
||||||
|
return bpe_tokens
|
||||||
|
|
||||||
|
def decode(self, tokens):
|
||||||
|
text = ''.join([self.decoder[token] for token in tokens])
|
||||||
|
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
||||||
|
return text
|
||||||
|
|
||||||
|
def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
|
||||||
|
"""
|
||||||
|
Returns the tokenized representation of given input string(s)
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
texts : Union[str, List[str]]
|
||||||
|
An input string or a list of input strings to tokenize
|
||||||
|
context_length : int
|
||||||
|
The context length to use; all CLIP models use 77 as the context length
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
|
||||||
|
"""
|
||||||
|
if isinstance(texts, str):
|
||||||
|
texts = [texts]
|
||||||
|
|
||||||
|
sot_token = self.encoder["<start_of_text>"]
|
||||||
|
eot_token = self.encoder["<end_of_text>"]
|
||||||
|
all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
|
||||||
|
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
||||||
|
|
||||||
|
for i, tokens in enumerate(all_tokens):
|
||||||
|
if len(tokens) > context_length:
|
||||||
|
tokens = tokens[:context_length] # Truncate
|
||||||
|
tokens[-1] = eot_token
|
||||||
|
result[i, :len(tokens)] = torch.tensor(tokens)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class HFTokenizer:
|
||||||
|
"""HuggingFace tokenizer wrapper"""
|
||||||
|
|
||||||
|
def __init__(self, tokenizer_name: str):
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||||
|
|
||||||
|
def save_pretrained(self, dest):
|
||||||
|
self.tokenizer.save_pretrained(dest)
|
||||||
|
|
||||||
|
def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor:
|
||||||
|
# same cleaning as for default tokenizer, except lowercasing
|
||||||
|
# adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
|
||||||
|
if isinstance(texts, str):
|
||||||
|
texts = [texts]
|
||||||
|
texts = [whitespace_clean(basic_clean(text)) for text in texts]
|
||||||
|
input_ids = self.tokenizer(
|
||||||
|
texts,
|
||||||
|
return_tensors='pt',
|
||||||
|
max_length=context_length,
|
||||||
|
padding='max_length',
|
||||||
|
truncation=True,
|
||||||
|
).input_ids
|
||||||
|
return input_ids
|
||||||
216
diffsynth/extensions/ImageQualityMetric/open_clip/transform.py
Normal file
216
diffsynth/extensions/ImageQualityMetric/open_clip/transform.py
Normal file
@@ -0,0 +1,216 @@
|
|||||||
|
import warnings
|
||||||
|
from dataclasses import dataclass, asdict
|
||||||
|
from typing import Any, Dict, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torchvision.transforms.functional as F
|
||||||
|
from functools import partial
|
||||||
|
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
|
||||||
|
CenterCrop
|
||||||
|
|
||||||
|
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AugmentationCfg:
|
||||||
|
scale: Tuple[float, float] = (0.9, 1.0)
|
||||||
|
ratio: Optional[Tuple[float, float]] = None
|
||||||
|
color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None
|
||||||
|
interpolation: Optional[str] = None
|
||||||
|
re_prob: Optional[float] = None
|
||||||
|
re_count: Optional[int] = None
|
||||||
|
use_timm: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class ResizeMaxSize(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
|
||||||
|
super().__init__()
|
||||||
|
if not isinstance(max_size, int):
|
||||||
|
raise TypeError(f"Size should be int. Got {type(max_size)}")
|
||||||
|
self.max_size = max_size
|
||||||
|
self.interpolation = interpolation
|
||||||
|
self.fn = min if fn == 'min' else min
|
||||||
|
self.fill = fill
|
||||||
|
|
||||||
|
def forward(self, img):
|
||||||
|
if isinstance(img, torch.Tensor):
|
||||||
|
height, width = img.shape[1:]
|
||||||
|
else:
|
||||||
|
width, height = img.size
|
||||||
|
scale = self.max_size / float(max(height, width))
|
||||||
|
if scale != 1.0:
|
||||||
|
new_size = tuple(round(dim * scale) for dim in (height, width))
|
||||||
|
img = F.resize(img, new_size, self.interpolation)
|
||||||
|
pad_h = self.max_size - new_size[0]
|
||||||
|
pad_w = self.max_size - new_size[1]
|
||||||
|
img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_to_rgb_or_rgba(image):
|
||||||
|
if image.mode == 'RGBA':
|
||||||
|
return image
|
||||||
|
else:
|
||||||
|
return image.convert('RGB')
|
||||||
|
|
||||||
|
# def transform_and_split(merged, transform_fn, normalize_fn):
|
||||||
|
# transformed = transform_fn(merged)
|
||||||
|
# crop_img, crop_label = torch.split(transformed, [3,1], dim=0)
|
||||||
|
|
||||||
|
# # crop_img = _convert_to_rgb(crop_img)
|
||||||
|
# crop_img = normalize_fn(ToTensor()(crop_img))
|
||||||
|
# return crop_img, crop_label
|
||||||
|
|
||||||
|
class MaskAwareNormalize(nn.Module):
|
||||||
|
def __init__(self, mean, std):
|
||||||
|
super().__init__()
|
||||||
|
self.normalize = Normalize(mean=mean, std=std)
|
||||||
|
|
||||||
|
def forward(self, tensor):
|
||||||
|
if tensor.shape[0] == 4:
|
||||||
|
return torch.cat([self.normalize(tensor[:3]), tensor[3:]], dim=0)
|
||||||
|
else:
|
||||||
|
return self.normalize(tensor)
|
||||||
|
|
||||||
|
def image_transform(
|
||||||
|
image_size: int,
|
||||||
|
is_train: bool,
|
||||||
|
mean: Optional[Tuple[float, ...]] = None,
|
||||||
|
std: Optional[Tuple[float, ...]] = None,
|
||||||
|
resize_longest_max: bool = False,
|
||||||
|
fill_color: int = 0,
|
||||||
|
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
|
||||||
|
):
|
||||||
|
mean = mean or OPENAI_DATASET_MEAN
|
||||||
|
if not isinstance(mean, (list, tuple)):
|
||||||
|
mean = (mean,) * 3
|
||||||
|
|
||||||
|
std = std or OPENAI_DATASET_STD
|
||||||
|
if not isinstance(std, (list, tuple)):
|
||||||
|
std = (std,) * 3
|
||||||
|
|
||||||
|
if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
|
||||||
|
# for square size, pass size as int so that Resize() uses aspect preserving shortest edge
|
||||||
|
image_size = image_size[0]
|
||||||
|
|
||||||
|
if isinstance(aug_cfg, dict):
|
||||||
|
aug_cfg = AugmentationCfg(**aug_cfg)
|
||||||
|
else:
|
||||||
|
aug_cfg = aug_cfg or AugmentationCfg()
|
||||||
|
normalize = MaskAwareNormalize(mean=mean, std=std)
|
||||||
|
if is_train:
|
||||||
|
aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
|
||||||
|
use_timm = aug_cfg_dict.pop('use_timm', False)
|
||||||
|
if use_timm:
|
||||||
|
assert False, "not tested for augmentation with mask"
|
||||||
|
from timm.data import create_transform # timm can still be optional
|
||||||
|
if isinstance(image_size, (tuple, list)):
|
||||||
|
assert len(image_size) >= 2
|
||||||
|
input_size = (3,) + image_size[-2:]
|
||||||
|
else:
|
||||||
|
input_size = (3, image_size, image_size)
|
||||||
|
# by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time
|
||||||
|
aug_cfg_dict.setdefault('interpolation', 'random')
|
||||||
|
aug_cfg_dict.setdefault('color_jitter', None) # disable by default
|
||||||
|
train_transform = create_transform(
|
||||||
|
input_size=input_size,
|
||||||
|
is_training=True,
|
||||||
|
hflip=0.,
|
||||||
|
mean=mean,
|
||||||
|
std=std,
|
||||||
|
re_mode='pixel',
|
||||||
|
**aug_cfg_dict,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
train_transform = Compose([
|
||||||
|
_convert_to_rgb_or_rgba,
|
||||||
|
ToTensor(),
|
||||||
|
RandomResizedCrop(
|
||||||
|
image_size,
|
||||||
|
scale=aug_cfg_dict.pop('scale'),
|
||||||
|
interpolation=InterpolationMode.BICUBIC,
|
||||||
|
),
|
||||||
|
normalize,
|
||||||
|
])
|
||||||
|
if aug_cfg_dict:
|
||||||
|
warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')
|
||||||
|
return train_transform
|
||||||
|
else:
|
||||||
|
transforms = [
|
||||||
|
_convert_to_rgb_or_rgba,
|
||||||
|
ToTensor(),
|
||||||
|
]
|
||||||
|
if resize_longest_max:
|
||||||
|
transforms.extend([
|
||||||
|
ResizeMaxSize(image_size, fill=fill_color)
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
transforms.extend([
|
||||||
|
Resize(image_size, interpolation=InterpolationMode.BICUBIC),
|
||||||
|
CenterCrop(image_size),
|
||||||
|
])
|
||||||
|
transforms.extend([
|
||||||
|
normalize,
|
||||||
|
])
|
||||||
|
return Compose(transforms)
|
||||||
|
|
||||||
|
|
||||||
|
# def image_transform_region(
|
||||||
|
# image_size: int,
|
||||||
|
# is_train: bool,
|
||||||
|
# mean: Optional[Tuple[float, ...]] = None,
|
||||||
|
# std: Optional[Tuple[float, ...]] = None,
|
||||||
|
# resize_longest_max: bool = False,
|
||||||
|
# fill_color: int = 0,
|
||||||
|
# aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
|
||||||
|
# ):
|
||||||
|
# mean = mean or OPENAI_DATASET_MEAN
|
||||||
|
# if not isinstance(mean, (list, tuple)):
|
||||||
|
# mean = (mean,) * 3
|
||||||
|
|
||||||
|
# std = std or OPENAI_DATASET_STD
|
||||||
|
# if not isinstance(std, (list, tuple)):
|
||||||
|
# std = (std,) * 3
|
||||||
|
|
||||||
|
# if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
|
||||||
|
# # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
|
||||||
|
# image_size = image_size[0]
|
||||||
|
|
||||||
|
# if isinstance(aug_cfg, dict):
|
||||||
|
# aug_cfg = AugmentationCfg(**aug_cfg)
|
||||||
|
# else:
|
||||||
|
# aug_cfg = aug_cfg or AugmentationCfg()
|
||||||
|
# normalize = Normalize(mean=mean, std=std)
|
||||||
|
# if is_train:
|
||||||
|
# aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
|
||||||
|
|
||||||
|
# transform = Compose([
|
||||||
|
# RandomResizedCrop(
|
||||||
|
# image_size,
|
||||||
|
# scale=aug_cfg_dict.pop('scale'),
|
||||||
|
# interpolation=InterpolationMode.BICUBIC,
|
||||||
|
# ),
|
||||||
|
# ])
|
||||||
|
# train_transform = Compose([
|
||||||
|
# partial(transform_and_split, transform_fn=transform,normalize_fn=normalize)
|
||||||
|
# ])
|
||||||
|
# return train_transform
|
||||||
|
# else:
|
||||||
|
# if resize_longest_max:
|
||||||
|
# transform = [
|
||||||
|
# ResizeMaxSize(image_size, fill=fill_color)
|
||||||
|
# ]
|
||||||
|
# val_transform = Compose([
|
||||||
|
# partial(transform_and_split, transform_fn=transform,normalize_fn=normalize),
|
||||||
|
# ])
|
||||||
|
# else:
|
||||||
|
# transform = [
|
||||||
|
# Resize(image_size, interpolation=InterpolationMode.BICUBIC),
|
||||||
|
# CenterCrop(image_size),
|
||||||
|
# ]
|
||||||
|
# val_transform = Compose([
|
||||||
|
# partial(transform_and_split, transform_fn=transform,normalize_fn=normalize),
|
||||||
|
# ])
|
||||||
|
# return val_transform
|
||||||
727
diffsynth/extensions/ImageQualityMetric/open_clip/transformer.py
Normal file
727
diffsynth/extensions/ImageQualityMetric/open_clip/transformer.py
Normal file
@@ -0,0 +1,727 @@
|
|||||||
|
from collections import OrderedDict
|
||||||
|
import math
|
||||||
|
from typing import Callable, Optional, Sequence, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from torch.utils.checkpoint import checkpoint
|
||||||
|
|
||||||
|
from .utils import to_2tuple
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNormFp32(nn.LayerNorm):
|
||||||
|
"""Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
orig_type = x.dtype
|
||||||
|
x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)
|
||||||
|
return x.to(orig_type)
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm(nn.LayerNorm):
|
||||||
|
"""Subclass torch's LayerNorm (with cast back to input dtype)."""
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
orig_type = x.dtype
|
||||||
|
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||||
|
return x.to(orig_type)
|
||||||
|
|
||||||
|
|
||||||
|
class QuickGELU(nn.Module):
|
||||||
|
# NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
return x * torch.sigmoid(1.702 * x)
|
||||||
|
|
||||||
|
|
||||||
|
class LayerScale(nn.Module):
|
||||||
|
def __init__(self, dim, init_values=1e-5, inplace=False):
|
||||||
|
super().__init__()
|
||||||
|
self.inplace = inplace
|
||||||
|
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
||||||
|
|
||||||
|
|
||||||
|
class PatchDropout(nn.Module):
|
||||||
|
"""
|
||||||
|
https://arxiv.org/abs/2212.00794
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, prob, exclude_first_token=True):
|
||||||
|
super().__init__()
|
||||||
|
assert 0 <= prob < 1.
|
||||||
|
self.prob = prob
|
||||||
|
self.exclude_first_token = exclude_first_token # exclude CLS token
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if not self.training or self.prob == 0.:
|
||||||
|
return x
|
||||||
|
|
||||||
|
if self.exclude_first_token:
|
||||||
|
cls_tokens, x = x[:, :1], x[:, 1:]
|
||||||
|
else:
|
||||||
|
cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
|
||||||
|
|
||||||
|
batch = x.size()[0]
|
||||||
|
num_tokens = x.size()[1]
|
||||||
|
|
||||||
|
batch_indices = torch.arange(batch)
|
||||||
|
batch_indices = batch_indices[..., None]
|
||||||
|
|
||||||
|
keep_prob = 1 - self.prob
|
||||||
|
num_patches_keep = max(1, int(num_tokens * keep_prob))
|
||||||
|
|
||||||
|
rand = torch.randn(batch, num_tokens)
|
||||||
|
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
|
||||||
|
|
||||||
|
x = x[batch_indices, patch_indices_keep]
|
||||||
|
|
||||||
|
if self.exclude_first_token:
|
||||||
|
x = torch.cat((cls_tokens, x), dim=1)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
num_heads=8,
|
||||||
|
qkv_bias=True,
|
||||||
|
scaled_cosine=False,
|
||||||
|
scale_heads=False,
|
||||||
|
logit_scale_max=math.log(1. / 0.01),
|
||||||
|
attn_drop=0.,
|
||||||
|
proj_drop=0.
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.scaled_cosine = scaled_cosine
|
||||||
|
self.scale_heads = scale_heads
|
||||||
|
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = dim // num_heads
|
||||||
|
self.scale = self.head_dim ** -0.5
|
||||||
|
self.logit_scale_max = logit_scale_max
|
||||||
|
|
||||||
|
# keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
|
||||||
|
self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
|
||||||
|
if qkv_bias:
|
||||||
|
self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
|
||||||
|
else:
|
||||||
|
self.in_proj_bias = None
|
||||||
|
|
||||||
|
if self.scaled_cosine:
|
||||||
|
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
|
||||||
|
else:
|
||||||
|
self.logit_scale = None
|
||||||
|
self.attn_drop = nn.Dropout(attn_drop)
|
||||||
|
if self.scale_heads:
|
||||||
|
self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
|
||||||
|
else:
|
||||||
|
self.head_scale = None
|
||||||
|
self.out_proj = nn.Linear(dim, dim)
|
||||||
|
self.out_drop = nn.Dropout(proj_drop)
|
||||||
|
|
||||||
|
def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
|
||||||
|
L, N, C = x.shape
|
||||||
|
q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
|
||||||
|
q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
||||||
|
k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
||||||
|
v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
||||||
|
|
||||||
|
if self.logit_scale is not None:
|
||||||
|
attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
|
||||||
|
logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
|
||||||
|
attn = attn.view(N, self.num_heads, L, L) * logit_scale
|
||||||
|
attn = attn.view(-1, L, L)
|
||||||
|
else:
|
||||||
|
q = q * self.scale
|
||||||
|
attn = torch.bmm(q, k.transpose(-1, -2))
|
||||||
|
|
||||||
|
if attn_mask is not None:
|
||||||
|
if attn_mask.dtype == torch.bool:
|
||||||
|
new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
|
||||||
|
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
|
||||||
|
attn_mask = new_attn_mask
|
||||||
|
attn += attn_mask
|
||||||
|
|
||||||
|
attn = attn.softmax(dim=-1)
|
||||||
|
attn = self.attn_drop(attn)
|
||||||
|
|
||||||
|
x = torch.bmm(attn, v)
|
||||||
|
if self.head_scale is not None:
|
||||||
|
x = x.view(N, self.num_heads, L, C) * self.head_scale
|
||||||
|
x = x.view(-1, L, C)
|
||||||
|
x = x.transpose(0, 1).reshape(L, N, C)
|
||||||
|
x = self.out_proj(x)
|
||||||
|
x = self.out_drop(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionalPooler(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model: int,
|
||||||
|
context_dim: int,
|
||||||
|
n_head: int = 8,
|
||||||
|
n_queries: int = 256,
|
||||||
|
norm_layer: Callable = LayerNorm
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.query = nn.Parameter(torch.randn(n_queries, d_model))
|
||||||
|
self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim)
|
||||||
|
self.ln_q = norm_layer(d_model)
|
||||||
|
self.ln_k = norm_layer(context_dim)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND
|
||||||
|
N = x.shape[1]
|
||||||
|
q = self.ln_q(self.query)
|
||||||
|
out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0]
|
||||||
|
return out.permute(1, 0, 2) # LND -> NLD
|
||||||
|
|
||||||
|
def _repeat(self, query, N: int):
|
||||||
|
return query.unsqueeze(1).repeat(1, N, 1)
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualAttentionBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model: int,
|
||||||
|
n_head: int,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
ls_init_value: float = None,
|
||||||
|
act_layer: Callable = nn.GELU,
|
||||||
|
norm_layer: Callable = LayerNorm,
|
||||||
|
is_cross_attention: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.ln_1 = norm_layer(d_model)
|
||||||
|
self.attn = nn.MultiheadAttention(d_model, n_head)
|
||||||
|
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
||||||
|
if is_cross_attention:
|
||||||
|
self.ln_1_kv = norm_layer(d_model)
|
||||||
|
|
||||||
|
self.ln_2 = norm_layer(d_model)
|
||||||
|
mlp_width = int(d_model * mlp_ratio)
|
||||||
|
self.mlp = nn.Sequential(OrderedDict([
|
||||||
|
("c_fc", nn.Linear(d_model, mlp_width)),
|
||||||
|
("gelu", act_layer()),
|
||||||
|
("c_proj", nn.Linear(mlp_width, d_model))
|
||||||
|
]))
|
||||||
|
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
||||||
|
|
||||||
|
def attention(
|
||||||
|
self,
|
||||||
|
q_x: torch.Tensor,
|
||||||
|
k_x: Optional[torch.Tensor] = None,
|
||||||
|
v_x: Optional[torch.Tensor] = None,
|
||||||
|
attn_mask: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
k_x = k_x if k_x is not None else q_x
|
||||||
|
v_x = v_x if v_x is not None else q_x
|
||||||
|
|
||||||
|
attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
|
||||||
|
return self.attn(
|
||||||
|
q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
q_x: torch.Tensor,
|
||||||
|
k_x: Optional[torch.Tensor] = None,
|
||||||
|
v_x: Optional[torch.Tensor] = None,
|
||||||
|
attn_mask: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
|
||||||
|
v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
|
||||||
|
|
||||||
|
x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))
|
||||||
|
x = x + self.ls_2(self.mlp(self.ln_2(x)))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CustomResidualAttentionBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model: int,
|
||||||
|
n_head: int,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
ls_init_value: float = None,
|
||||||
|
act_layer: Callable = nn.GELU,
|
||||||
|
norm_layer: Callable = LayerNorm,
|
||||||
|
scale_cosine_attn: bool = False,
|
||||||
|
scale_heads: bool = False,
|
||||||
|
scale_attn: bool = False,
|
||||||
|
scale_fc: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.ln_1 = norm_layer(d_model)
|
||||||
|
self.attn = Attention(
|
||||||
|
d_model, n_head,
|
||||||
|
scaled_cosine=scale_cosine_attn,
|
||||||
|
scale_heads=scale_heads,
|
||||||
|
)
|
||||||
|
self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
|
||||||
|
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
||||||
|
|
||||||
|
self.ln_2 = norm_layer(d_model)
|
||||||
|
mlp_width = int(d_model * mlp_ratio)
|
||||||
|
self.mlp = nn.Sequential(OrderedDict([
|
||||||
|
("c_fc", nn.Linear(d_model, mlp_width)),
|
||||||
|
('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),
|
||||||
|
("gelu", act_layer()),
|
||||||
|
("c_proj", nn.Linear(mlp_width, d_model))
|
||||||
|
]))
|
||||||
|
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
||||||
|
x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask)))
|
||||||
|
x = x + self.ls_2(self.mlp(self.ln_2(x)))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Transformer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
width: int,
|
||||||
|
layers: int,
|
||||||
|
heads: int,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
ls_init_value: float = None,
|
||||||
|
act_layer: Callable = nn.GELU,
|
||||||
|
norm_layer: Callable = LayerNorm,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.width = width
|
||||||
|
self.layers = layers
|
||||||
|
self.grad_checkpointing = False
|
||||||
|
|
||||||
|
self.resblocks = nn.ModuleList([
|
||||||
|
ResidualAttentionBlock(
|
||||||
|
width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer)
|
||||||
|
for _ in range(layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
def get_cast_dtype(self) -> torch.dtype:
|
||||||
|
return self.resblocks[0].mlp.c_fc.weight.dtype
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
||||||
|
for r in self.resblocks:
|
||||||
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||||
|
# TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
|
||||||
|
x = checkpoint(r, x, None, None, attn_mask)
|
||||||
|
else:
|
||||||
|
x = r(x, attn_mask=attn_mask)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class VisionTransformer(nn.Module):
|
||||||
|
output_tokens: torch.jit.Final[bool]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
image_size: int,
|
||||||
|
patch_size: int,
|
||||||
|
width: int,
|
||||||
|
layers: int,
|
||||||
|
heads: int,
|
||||||
|
mlp_ratio: float,
|
||||||
|
ls_init_value: float = None,
|
||||||
|
global_average_pool: bool = False,
|
||||||
|
attentional_pool: bool = False,
|
||||||
|
n_queries: int = 256,
|
||||||
|
attn_pooler_heads: int = 8,
|
||||||
|
output_dim: int = 512,
|
||||||
|
patch_dropout: float = 0.,
|
||||||
|
input_patchnorm: bool = False,
|
||||||
|
act_layer: Callable = nn.GELU,
|
||||||
|
norm_layer: Callable = LayerNorm,
|
||||||
|
output_tokens: bool = False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.output_tokens = output_tokens
|
||||||
|
image_height, image_width = self.image_size = to_2tuple(image_size)
|
||||||
|
patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
|
||||||
|
self.grid_size = (image_height // patch_height, image_width // patch_width)
|
||||||
|
self.output_dim = output_dim
|
||||||
|
|
||||||
|
# whether to layernorm each patch, as done in dual patchnorm paper - https://arxiv.org/abs/2302.01327v1
|
||||||
|
self.input_patchnorm = input_patchnorm
|
||||||
|
|
||||||
|
if input_patchnorm:
|
||||||
|
patch_input_dim = patch_height * patch_width * 3
|
||||||
|
self.patchnorm_pre_ln = LayerNorm(patch_input_dim)
|
||||||
|
self.conv1 = nn.Linear(patch_input_dim, width)
|
||||||
|
else:
|
||||||
|
self.patchnorm_pre_ln = nn.Identity()
|
||||||
|
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
||||||
|
|
||||||
|
# class embeddings and positional embeddings
|
||||||
|
scale = width ** -0.5
|
||||||
|
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
||||||
|
self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
|
||||||
|
|
||||||
|
# setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
|
||||||
|
self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
|
||||||
|
|
||||||
|
self.ln_pre = norm_layer(width)
|
||||||
|
self.transformer = Transformer(
|
||||||
|
width,
|
||||||
|
layers,
|
||||||
|
heads,
|
||||||
|
mlp_ratio,
|
||||||
|
ls_init_value=ls_init_value,
|
||||||
|
act_layer=act_layer,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.global_average_pool = global_average_pool
|
||||||
|
if attentional_pool:
|
||||||
|
self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries)
|
||||||
|
self.ln_post = norm_layer(output_dim)
|
||||||
|
self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim))
|
||||||
|
else:
|
||||||
|
self.attn_pool = None
|
||||||
|
self.ln_post = norm_layer(width)
|
||||||
|
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
||||||
|
|
||||||
|
self.init_parameters()
|
||||||
|
|
||||||
|
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
||||||
|
for param in self.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
if unlocked_groups != 0:
|
||||||
|
groups = [
|
||||||
|
[
|
||||||
|
self.conv1,
|
||||||
|
self.class_embedding,
|
||||||
|
self.positional_embedding,
|
||||||
|
self.ln_pre,
|
||||||
|
],
|
||||||
|
*self.transformer.resblocks[:-1],
|
||||||
|
[
|
||||||
|
self.transformer.resblocks[-1],
|
||||||
|
self.ln_post,
|
||||||
|
],
|
||||||
|
self.proj,
|
||||||
|
]
|
||||||
|
|
||||||
|
def _unlock(x):
|
||||||
|
if isinstance(x, Sequence):
|
||||||
|
for g in x:
|
||||||
|
_unlock(g)
|
||||||
|
else:
|
||||||
|
if isinstance(x, torch.nn.Parameter):
|
||||||
|
x.requires_grad = True
|
||||||
|
else:
|
||||||
|
for p in x.parameters():
|
||||||
|
p.requires_grad = True
|
||||||
|
|
||||||
|
_unlock(groups[-unlocked_groups:])
|
||||||
|
|
||||||
|
def init_parameters(self):
|
||||||
|
# FIXME OpenAI CLIP did not define an init for the VisualTransformer
|
||||||
|
# TODO experiment if default PyTorch init, below, or alternate init is best.
|
||||||
|
|
||||||
|
# nn.init.normal_(self.class_embedding, std=self.scale)
|
||||||
|
# nn.init.normal_(self.positional_embedding, std=self.scale)
|
||||||
|
#
|
||||||
|
# proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
||||||
|
# attn_std = self.transformer.width ** -0.5
|
||||||
|
# fc_std = (2 * self.transformer.width) ** -0.5
|
||||||
|
# for block in self.transformer.resblocks:
|
||||||
|
# nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
||||||
|
# nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
||||||
|
# nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
||||||
|
# nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
||||||
|
#
|
||||||
|
# if self.text_projection is not None:
|
||||||
|
# nn.init.normal_(self.text_projection, std=self.scale)
|
||||||
|
pass
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def set_grad_checkpointing(self, enable=True):
|
||||||
|
self.transformer.grad_checkpointing = enable
|
||||||
|
|
||||||
|
def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if self.global_average_pool:
|
||||||
|
return x.mean(dim=1), x
|
||||||
|
else:
|
||||||
|
return x[:, 0], x[:, 1:]
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, skip_pool: bool = False):
|
||||||
|
|
||||||
|
# to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
|
||||||
|
if self.input_patchnorm:
|
||||||
|
# einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
|
||||||
|
x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1], self.patch_size[1])
|
||||||
|
x = x.permute(0, 2, 4, 1, 3, 5)
|
||||||
|
x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1)
|
||||||
|
x = self.patchnorm_pre_ln(x)
|
||||||
|
x = self.conv1(x)
|
||||||
|
else:
|
||||||
|
x = self.conv1(x) # shape = [*, width, grid, grid]
|
||||||
|
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
||||||
|
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
||||||
|
|
||||||
|
# class embeddings and positional embeddings
|
||||||
|
x = torch.cat(
|
||||||
|
[self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
|
||||||
|
x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
||||||
|
x = x + self.positional_embedding.to(x.dtype)
|
||||||
|
|
||||||
|
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
|
||||||
|
x = self.patch_dropout(x)
|
||||||
|
x = self.ln_pre(x)
|
||||||
|
|
||||||
|
x = x.permute(1, 0, 2) # NLD -> LND
|
||||||
|
x = self.transformer(x)
|
||||||
|
x = x.permute(1, 0, 2) # LND -> NLD
|
||||||
|
|
||||||
|
if skip_pool:
|
||||||
|
return x
|
||||||
|
|
||||||
|
if self.attn_pool is not None:
|
||||||
|
x = self.attn_pool(x)
|
||||||
|
x = self.ln_post(x)
|
||||||
|
pooled, tokens = self._global_pool(x)
|
||||||
|
else:
|
||||||
|
pooled, tokens = self._global_pool(x)
|
||||||
|
pooled = self.ln_post(pooled)
|
||||||
|
|
||||||
|
if self.proj is not None:
|
||||||
|
pooled = pooled @ self.proj
|
||||||
|
|
||||||
|
if self.output_tokens:
|
||||||
|
return pooled, tokens
|
||||||
|
|
||||||
|
return pooled
|
||||||
|
|
||||||
|
|
||||||
|
class TextTransformer(nn.Module):
|
||||||
|
output_tokens: torch.jit.Final[bool]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
context_length: int = 77,
|
||||||
|
vocab_size: int = 49408,
|
||||||
|
width: int = 512,
|
||||||
|
heads: int = 8,
|
||||||
|
layers: int = 12,
|
||||||
|
ls_init_value: float = None,
|
||||||
|
output_dim: int = 512,
|
||||||
|
act_layer: Callable = nn.GELU,
|
||||||
|
norm_layer: Callable = LayerNorm,
|
||||||
|
embed_cls: bool = False,
|
||||||
|
pad_id: int = 0,
|
||||||
|
output_tokens: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.output_tokens = output_tokens
|
||||||
|
self.num_pos = self.context_length = context_length
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.width = width
|
||||||
|
self.output_dim = output_dim
|
||||||
|
self.heads = heads
|
||||||
|
self.pad_id = pad_id
|
||||||
|
|
||||||
|
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
|
||||||
|
|
||||||
|
if embed_cls:
|
||||||
|
self.cls_emb = nn.Parameter(torch.empty(width))
|
||||||
|
self.num_pos += 1
|
||||||
|
else:
|
||||||
|
self.cls_emb = None
|
||||||
|
|
||||||
|
self.token_embedding = nn.Embedding(vocab_size, width)
|
||||||
|
self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
|
||||||
|
self.transformer = Transformer(
|
||||||
|
width=width,
|
||||||
|
layers=layers,
|
||||||
|
heads=heads,
|
||||||
|
ls_init_value=ls_init_value,
|
||||||
|
act_layer=act_layer,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
)
|
||||||
|
self.ln_final = norm_layer(width)
|
||||||
|
|
||||||
|
self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
|
||||||
|
|
||||||
|
self.init_parameters()
|
||||||
|
|
||||||
|
def init_parameters(self):
|
||||||
|
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
||||||
|
nn.init.normal_(self.positional_embedding, std=0.01)
|
||||||
|
if self.cls_emb is not None:
|
||||||
|
nn.init.normal_(self.cls_emb, std=0.01)
|
||||||
|
|
||||||
|
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
||||||
|
attn_std = self.transformer.width ** -0.5
|
||||||
|
fc_std = (2 * self.transformer.width) ** -0.5
|
||||||
|
for block in self.transformer.resblocks:
|
||||||
|
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
||||||
|
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
||||||
|
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
||||||
|
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
||||||
|
|
||||||
|
if self.text_projection is not None:
|
||||||
|
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def set_grad_checkpointing(self, enable=True):
|
||||||
|
self.transformer.grad_checkpointing = enable
|
||||||
|
|
||||||
|
def build_attention_mask(self):
|
||||||
|
# lazily create causal attention mask, with full attention between the tokens
|
||||||
|
# pytorch uses additive attention mask; fill with -inf
|
||||||
|
mask = torch.empty(self.num_pos, self.num_pos)
|
||||||
|
mask.fill_(float("-inf"))
|
||||||
|
mask.triu_(1) # zero out the lower diagonal
|
||||||
|
return mask
|
||||||
|
|
||||||
|
def build_cls_mask(self, text, cast_dtype: torch.dtype):
|
||||||
|
cls_mask = (text != self.pad_id).unsqueeze(1)
|
||||||
|
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)
|
||||||
|
additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
|
||||||
|
additive_mask.fill_(0)
|
||||||
|
additive_mask.masked_fill_(~cls_mask, float("-inf"))
|
||||||
|
additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
|
||||||
|
return additive_mask
|
||||||
|
|
||||||
|
def _repeat(self, t, N: int):
|
||||||
|
return t.reshape(1, 1, -1).repeat(N, 1, 1)
|
||||||
|
|
||||||
|
def forward(self, text):
|
||||||
|
cast_dtype = self.transformer.get_cast_dtype()
|
||||||
|
seq_len = text.shape[1]
|
||||||
|
|
||||||
|
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
|
||||||
|
attn_mask = self.attn_mask
|
||||||
|
if self.cls_emb is not None:
|
||||||
|
seq_len += 1
|
||||||
|
x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1)
|
||||||
|
cls_mask = self.build_cls_mask(text, cast_dtype)
|
||||||
|
attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len]
|
||||||
|
|
||||||
|
x = x + self.positional_embedding[:seq_len].to(cast_dtype)
|
||||||
|
x = x.permute(1, 0, 2) # NLD -> LND
|
||||||
|
x = self.transformer(x, attn_mask=attn_mask)
|
||||||
|
x = x.permute(1, 0, 2) # LND -> NLD
|
||||||
|
|
||||||
|
# x.shape = [batch_size, n_ctx, transformer.width]
|
||||||
|
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||||
|
if self.cls_emb is not None:
|
||||||
|
pooled, tokens = x[:, -1], x[:, :-1]
|
||||||
|
pooled = self.ln_final(pooled)
|
||||||
|
else:
|
||||||
|
x = self.ln_final(x)
|
||||||
|
pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
|
||||||
|
|
||||||
|
if self.text_projection is not None:
|
||||||
|
pooled = pooled @ self.text_projection
|
||||||
|
|
||||||
|
if self.output_tokens:
|
||||||
|
return pooled, tokens
|
||||||
|
|
||||||
|
return pooled
|
||||||
|
|
||||||
|
|
||||||
|
class MultimodalTransformer(Transformer):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
width: int,
|
||||||
|
layers: int,
|
||||||
|
heads: int,
|
||||||
|
context_length: int = 77,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
ls_init_value: float = None,
|
||||||
|
act_layer: Callable = nn.GELU,
|
||||||
|
norm_layer: Callable = LayerNorm,
|
||||||
|
output_dim: int = 512,
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
width=width,
|
||||||
|
layers=layers,
|
||||||
|
heads=heads,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
ls_init_value=ls_init_value,
|
||||||
|
act_layer=act_layer,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
)
|
||||||
|
self.context_length = context_length
|
||||||
|
self.cross_attn = nn.ModuleList([
|
||||||
|
ResidualAttentionBlock(
|
||||||
|
width,
|
||||||
|
heads,
|
||||||
|
mlp_ratio,
|
||||||
|
ls_init_value=ls_init_value,
|
||||||
|
act_layer=act_layer,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
is_cross_attention=True,
|
||||||
|
)
|
||||||
|
for _ in range(layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
|
||||||
|
|
||||||
|
self.ln_final = norm_layer(width)
|
||||||
|
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
|
||||||
|
|
||||||
|
def init_parameters(self):
|
||||||
|
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
||||||
|
attn_std = self.transformer.width ** -0.5
|
||||||
|
fc_std = (2 * self.transformer.width) ** -0.5
|
||||||
|
for block in self.transformer.resblocks:
|
||||||
|
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
||||||
|
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
||||||
|
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
||||||
|
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
||||||
|
for block in self.transformer.cross_attn:
|
||||||
|
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
||||||
|
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
||||||
|
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
||||||
|
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
||||||
|
|
||||||
|
if self.text_projection is not None:
|
||||||
|
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
||||||
|
|
||||||
|
def build_attention_mask(self):
|
||||||
|
# lazily create causal attention mask, with full attention between the tokens
|
||||||
|
# pytorch uses additive attention mask; fill with -inf
|
||||||
|
mask = torch.empty(self.context_length, self.context_length)
|
||||||
|
mask.fill_(float("-inf"))
|
||||||
|
mask.triu_(1) # zero out the lower diagonal
|
||||||
|
return mask
|
||||||
|
|
||||||
|
def forward(self, image_embs, text_embs):
|
||||||
|
text_embs = text_embs.permute(1, 0, 2) # NLD -> LNDsq
|
||||||
|
image_embs = image_embs.permute(1, 0, 2) # NLD -> LND
|
||||||
|
seq_len = text_embs.shape[0]
|
||||||
|
|
||||||
|
for resblock, cross_attn in zip(self.resblocks, self.cross_attn):
|
||||||
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||||
|
# TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
|
||||||
|
text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len])
|
||||||
|
text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None)
|
||||||
|
else:
|
||||||
|
text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len])
|
||||||
|
text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)
|
||||||
|
|
||||||
|
x = text_embs.permute(1, 0, 2) # LND -> NLD
|
||||||
|
x = self.ln_final(x)
|
||||||
|
|
||||||
|
if self.text_projection is not None:
|
||||||
|
x = x @ self.text_projection
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def set_grad_checkpointing(self, enable=True):
|
||||||
|
self.grad_checkpointing = enable
|
||||||
60
diffsynth/extensions/ImageQualityMetric/open_clip/utils.py
Normal file
60
diffsynth/extensions/ImageQualityMetric/open_clip/utils.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
from itertools import repeat
|
||||||
|
import collections.abc
|
||||||
|
|
||||||
|
from torch import nn as nn
|
||||||
|
from torchvision.ops.misc import FrozenBatchNorm2d
|
||||||
|
|
||||||
|
|
||||||
|
def freeze_batch_norm_2d(module, module_match={}, name=''):
|
||||||
|
"""
|
||||||
|
Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
|
||||||
|
itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
|
||||||
|
returned. Otherwise, the module is walked recursively and submodules are converted in place.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module (torch.nn.Module): Any PyTorch module.
|
||||||
|
module_match (dict): Dictionary of full module names to freeze (all if empty)
|
||||||
|
name (str): Full module name (prefix)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.nn.Module: Resulting module
|
||||||
|
|
||||||
|
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
|
||||||
|
"""
|
||||||
|
res = module
|
||||||
|
is_match = True
|
||||||
|
if module_match:
|
||||||
|
is_match = name in module_match
|
||||||
|
if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
|
||||||
|
res = FrozenBatchNorm2d(module.num_features)
|
||||||
|
res.num_features = module.num_features
|
||||||
|
res.affine = module.affine
|
||||||
|
if module.affine:
|
||||||
|
res.weight.data = module.weight.data.clone().detach()
|
||||||
|
res.bias.data = module.bias.data.clone().detach()
|
||||||
|
res.running_mean.data = module.running_mean.data
|
||||||
|
res.running_var.data = module.running_var.data
|
||||||
|
res.eps = module.eps
|
||||||
|
else:
|
||||||
|
for child_name, child in module.named_children():
|
||||||
|
full_child_name = '.'.join([name, child_name]) if name else child_name
|
||||||
|
new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
|
||||||
|
if new_child is not child:
|
||||||
|
res.add_module(child_name, new_child)
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
# From PyTorch internals
|
||||||
|
def _ntuple(n):
|
||||||
|
def parse(x):
|
||||||
|
if isinstance(x, collections.abc.Iterable):
|
||||||
|
return x
|
||||||
|
return tuple(repeat(x, n))
|
||||||
|
return parse
|
||||||
|
|
||||||
|
|
||||||
|
to_1tuple = _ntuple(1)
|
||||||
|
to_2tuple = _ntuple(2)
|
||||||
|
to_3tuple = _ntuple(3)
|
||||||
|
to_4tuple = _ntuple(4)
|
||||||
|
to_ntuple = lambda n, x: _ntuple(n)(x)
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
__version__ = '2.16.0'
|
||||||
112
diffsynth/extensions/ImageQualityMetric/pickscore.py
Normal file
112
diffsynth/extensions/ImageQualityMetric/pickscore.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import AutoProcessor, AutoModel
|
||||||
|
from typing import List, Union
|
||||||
|
import os
|
||||||
|
from .config import MODEL_PATHS
|
||||||
|
|
||||||
|
class PickScore(torch.nn.Module):
|
||||||
|
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS):
|
||||||
|
super().__init__()
|
||||||
|
"""Initialize the Selector with a processor and model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device (Union[str, torch.device]): The device to load the model on.
|
||||||
|
"""
|
||||||
|
self.device = device if isinstance(device, torch.device) else torch.device(device)
|
||||||
|
processor_name_or_path = path.get("clip")
|
||||||
|
model_pretrained_name_or_path = path.get("pickscore")
|
||||||
|
self.processor = AutoProcessor.from_pretrained(processor_name_or_path)
|
||||||
|
self.model = AutoModel.from_pretrained(model_pretrained_name_or_path).eval().to(self.device)
|
||||||
|
|
||||||
|
def _calculate_score(self, image: torch.Tensor, prompt: str, softmax: bool = False) -> float:
|
||||||
|
"""Calculate the score for a single image and prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (torch.Tensor): The processed image tensor.
|
||||||
|
prompt (str): The prompt text.
|
||||||
|
softmax (bool): Whether to apply softmax to the scores.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: The score for the image.
|
||||||
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
|
# Prepare text inputs
|
||||||
|
text_inputs = self.processor(
|
||||||
|
text=prompt,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=77,
|
||||||
|
return_tensors="pt",
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
|
# Embed images and text
|
||||||
|
image_embs = self.model.get_image_features(pixel_values=image)
|
||||||
|
image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
|
||||||
|
text_embs = self.model.get_text_features(**text_inputs)
|
||||||
|
text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
# Compute score
|
||||||
|
score = (text_embs @ image_embs.T)[0]
|
||||||
|
if softmax:
|
||||||
|
# Apply logit scale and softmax
|
||||||
|
score = torch.softmax(self.model.logit_scale.exp() * score, dim=-1)
|
||||||
|
|
||||||
|
return score.cpu().item()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str, softmax: bool = False) -> List[float]:
|
||||||
|
"""Score the images based on the prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
||||||
|
prompt (str): The prompt text.
|
||||||
|
softmax (bool): Whether to apply softmax to the scores.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[float]: List of scores for the images.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if isinstance(images, (str, Image.Image)):
|
||||||
|
# Single image
|
||||||
|
if isinstance(images, str):
|
||||||
|
pil_image = Image.open(images)
|
||||||
|
else:
|
||||||
|
pil_image = images
|
||||||
|
|
||||||
|
# Prepare image inputs
|
||||||
|
image_inputs = self.processor(
|
||||||
|
images=pil_image,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=77,
|
||||||
|
return_tensors="pt",
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
|
return [self._calculate_score(image_inputs["pixel_values"], prompt, softmax)]
|
||||||
|
elif isinstance(images, list):
|
||||||
|
# Multiple images
|
||||||
|
scores = []
|
||||||
|
for one_image in images:
|
||||||
|
if isinstance(one_image, str):
|
||||||
|
pil_image = Image.open(one_image)
|
||||||
|
elif isinstance(one_image, Image.Image):
|
||||||
|
pil_image = one_image
|
||||||
|
else:
|
||||||
|
raise TypeError("The type of parameter images is illegal.")
|
||||||
|
|
||||||
|
# Prepare image inputs
|
||||||
|
image_inputs = self.processor(
|
||||||
|
images=pil_image,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=77,
|
||||||
|
return_tensors="pt",
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
|
scores.append(self._calculate_score(image_inputs["pixel_values"], prompt, softmax))
|
||||||
|
return scores
|
||||||
|
else:
|
||||||
|
raise TypeError("The type of parameter images is illegal.")
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Error in scoring images: {e}")
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
from .models import *
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
from .base_model import *
|
||||||
|
from .clip_model import *
|
||||||
|
from .cross_modeling import *
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BaseModelConfig:
|
||||||
|
pass
|
||||||
@@ -0,0 +1,146 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from transformers import CLIPModel as HFCLIPModel
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from torch import nn, einsum
|
||||||
|
|
||||||
|
from .base_model import BaseModelConfig
|
||||||
|
|
||||||
|
from transformers import CLIPConfig
|
||||||
|
from typing import Any, Optional, Tuple, Union
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .cross_modeling import Cross_model
|
||||||
|
|
||||||
|
import json, os
|
||||||
|
|
||||||
|
class XCLIPModel(HFCLIPModel):
|
||||||
|
def __init__(self, config: CLIPConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
def get_text_features(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.Tensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
|
||||||
|
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
text_outputs = self.text_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
# pooled_output = text_outputs[1]
|
||||||
|
# text_features = self.text_projection(pooled_output)
|
||||||
|
last_hidden_state = text_outputs[0]
|
||||||
|
text_features = self.text_projection(last_hidden_state)
|
||||||
|
|
||||||
|
pooled_output = text_outputs[1]
|
||||||
|
text_features_EOS = self.text_projection(pooled_output)
|
||||||
|
|
||||||
|
|
||||||
|
# del last_hidden_state, text_outputs
|
||||||
|
# gc.collect()
|
||||||
|
|
||||||
|
return text_features, text_features_EOS
|
||||||
|
|
||||||
|
def get_image_features(
|
||||||
|
self,
|
||||||
|
pixel_values: Optional[torch.FloatTensor] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
|
||||||
|
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
vision_outputs = self.vision_model(
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
# pooled_output = vision_outputs[1] # pooled_output
|
||||||
|
# image_features = self.visual_projection(pooled_output)
|
||||||
|
last_hidden_state = vision_outputs[0]
|
||||||
|
image_features = self.visual_projection(last_hidden_state)
|
||||||
|
|
||||||
|
return image_features
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ClipModelConfig(BaseModelConfig):
|
||||||
|
_target_: str = "diffsynth.extensions.QualityMetric.trainer.models.clip_model.CLIPModel"
|
||||||
|
pretrained_model_name_or_path: str ="checkpoints/clip-vit-base-patch32"
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPModel(nn.Module):
|
||||||
|
def __init__(self, ckpt, config_file=False):
|
||||||
|
super().__init__()
|
||||||
|
if config_file is None:
|
||||||
|
self.model = XCLIPModel.from_pretrained(ckpt)
|
||||||
|
else:
|
||||||
|
with open(os.path.join(ckpt, "config.json"), "r", encoding="utf-8") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
config = CLIPConfig(**config)
|
||||||
|
self.model = XCLIPModel._from_config(config)
|
||||||
|
self.cross_model = Cross_model(dim=1024, layer_num=4, heads=16)
|
||||||
|
|
||||||
|
def get_text_features(self, *args, **kwargs):
|
||||||
|
return self.model.get_text_features(*args, **kwargs)
|
||||||
|
|
||||||
|
def get_image_features(self, *args, **kwargs):
|
||||||
|
return self.model.get_image_features(*args, **kwargs)
|
||||||
|
|
||||||
|
def forward(self, text_inputs=None, image_inputs=None, condition_inputs=None):
|
||||||
|
outputs = ()
|
||||||
|
|
||||||
|
text_f, text_EOS = self.model.get_text_features(text_inputs) # B*77*1024
|
||||||
|
outputs += text_EOS,
|
||||||
|
|
||||||
|
image_f = self.model.get_image_features(image_inputs.half()) # 2B*257*1024
|
||||||
|
condition_f, _ = self.model.get_text_features(condition_inputs) # B*5*1024
|
||||||
|
|
||||||
|
sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)
|
||||||
|
sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
|
||||||
|
sim_text_condition = sim_text_condition / sim_text_condition.max()
|
||||||
|
mask = torch.where(sim_text_condition > 0.01, 0, float('-inf')) # B*1*77
|
||||||
|
|
||||||
|
mask = mask.repeat(1,image_f.shape[1],1) # B*257*77
|
||||||
|
bc = int(image_f.shape[0]/2)
|
||||||
|
|
||||||
|
sim0 = self.cross_model(image_f[:bc,:,:], text_f,mask.half())
|
||||||
|
sim1 = self.cross_model(image_f[bc:,:,:], text_f,mask.half())
|
||||||
|
outputs += sim0[:,0,:],
|
||||||
|
outputs += sim1[:,0,:],
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def logit_scale(self):
|
||||||
|
return self.model.logit_scale
|
||||||
|
|
||||||
|
def save(self, path):
|
||||||
|
self.model.save_pretrained(path)
|
||||||
|
|
||||||
@@ -0,0 +1,292 @@
|
|||||||
|
import torch
|
||||||
|
from torch import einsum, nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
|
# helper functions
|
||||||
|
|
||||||
|
def exists(val):
|
||||||
|
return val is not None
|
||||||
|
|
||||||
|
def default(val, d):
|
||||||
|
return val if exists(val) else d
|
||||||
|
|
||||||
|
# normalization
|
||||||
|
# they use layernorm without bias, something that pytorch does not offer
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm(nn.Module):
|
||||||
|
def __init__(self, dim):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(torch.ones(dim))
|
||||||
|
self.register_buffer("bias", torch.zeros(dim))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return F.layer_norm(x, x.shape[-1:], self.weight, self.bias)
|
||||||
|
|
||||||
|
# residual
|
||||||
|
|
||||||
|
|
||||||
|
class Residual(nn.Module):
|
||||||
|
def __init__(self, fn):
|
||||||
|
super().__init__()
|
||||||
|
self.fn = fn
|
||||||
|
|
||||||
|
def forward(self, x, *args, **kwargs):
|
||||||
|
return self.fn(x, *args, **kwargs) + x
|
||||||
|
|
||||||
|
|
||||||
|
# rotary positional embedding
|
||||||
|
# https://arxiv.org/abs/2104.09864
|
||||||
|
|
||||||
|
|
||||||
|
class RotaryEmbedding(nn.Module):
|
||||||
|
def __init__(self, dim):
|
||||||
|
super().__init__()
|
||||||
|
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
|
||||||
|
self.register_buffer("inv_freq", inv_freq)
|
||||||
|
|
||||||
|
def forward(self, max_seq_len, *, device):
|
||||||
|
seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
|
||||||
|
freqs = einsum("i , j -> i j", seq, self.inv_freq)
|
||||||
|
return torch.cat((freqs, freqs), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_half(x):
|
||||||
|
x = rearrange(x, "... (j d) -> ... j d", j=2)
|
||||||
|
x1, x2 = x.unbind(dim=-2)
|
||||||
|
return torch.cat((-x2, x1), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_pos_emb(pos, t):
|
||||||
|
return (t * pos.cos()) + (rotate_half(t) * pos.sin())
|
||||||
|
|
||||||
|
|
||||||
|
# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
|
||||||
|
# https://arxiv.org/abs/2002.05202
|
||||||
|
|
||||||
|
|
||||||
|
class SwiGLU(nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
x, gate = x.chunk(2, dim=-1)
|
||||||
|
return F.silu(gate) * x
|
||||||
|
|
||||||
|
|
||||||
|
# parallel attention and feedforward with residual
|
||||||
|
# discovered by Wang et al + EleutherAI from GPT-J fame
|
||||||
|
|
||||||
|
class ParallelTransformerBlock(nn.Module):
|
||||||
|
def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
|
||||||
|
super().__init__()
|
||||||
|
self.norm = LayerNorm(dim)
|
||||||
|
|
||||||
|
attn_inner_dim = dim_head * heads
|
||||||
|
ff_inner_dim = dim * ff_mult
|
||||||
|
self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))
|
||||||
|
|
||||||
|
self.heads = heads
|
||||||
|
self.scale = dim_head**-0.5
|
||||||
|
self.rotary_emb = RotaryEmbedding(dim_head)
|
||||||
|
|
||||||
|
self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
|
||||||
|
self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
|
||||||
|
|
||||||
|
self.ff_out = nn.Sequential(
|
||||||
|
SwiGLU(),
|
||||||
|
nn.Linear(ff_inner_dim, dim, bias=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.register_buffer("pos_emb", None, persistent=False)
|
||||||
|
|
||||||
|
|
||||||
|
def get_rotary_embedding(self, n, device):
|
||||||
|
if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
|
||||||
|
return self.pos_emb[:n]
|
||||||
|
|
||||||
|
pos_emb = self.rotary_emb(n, device=device)
|
||||||
|
self.register_buffer("pos_emb", pos_emb, persistent=False)
|
||||||
|
return pos_emb
|
||||||
|
|
||||||
|
def forward(self, x, attn_mask=None):
|
||||||
|
"""
|
||||||
|
einstein notation
|
||||||
|
b - batch
|
||||||
|
h - heads
|
||||||
|
n, i, j - sequence length (base sequence length, source, target)
|
||||||
|
d - feature dimension
|
||||||
|
"""
|
||||||
|
|
||||||
|
n, device, h = x.shape[1], x.device, self.heads
|
||||||
|
|
||||||
|
# pre layernorm
|
||||||
|
|
||||||
|
x = self.norm(x)
|
||||||
|
|
||||||
|
# attention queries, keys, values, and feedforward inner
|
||||||
|
|
||||||
|
q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)
|
||||||
|
|
||||||
|
# split heads
|
||||||
|
# they use multi-query single-key-value attention, yet another Noam Shazeer paper
|
||||||
|
# they found no performance loss past a certain scale, and more efficient decoding obviously
|
||||||
|
# https://arxiv.org/abs/1911.02150
|
||||||
|
|
||||||
|
q = rearrange(q, "b n (h d) -> b h n d", h=h)
|
||||||
|
|
||||||
|
# rotary embeddings
|
||||||
|
|
||||||
|
positions = self.get_rotary_embedding(n, device)
|
||||||
|
q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))
|
||||||
|
|
||||||
|
# scale
|
||||||
|
|
||||||
|
q = q * self.scale
|
||||||
|
|
||||||
|
# similarity
|
||||||
|
|
||||||
|
sim = einsum("b h i d, b j d -> b h i j", q, k)
|
||||||
|
|
||||||
|
|
||||||
|
# extra attention mask - for masking out attention from text CLS token to padding
|
||||||
|
|
||||||
|
if exists(attn_mask):
|
||||||
|
attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j')
|
||||||
|
sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
|
||||||
|
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
|
||||||
|
attn = sim.softmax(dim=-1)
|
||||||
|
|
||||||
|
# aggregate values
|
||||||
|
|
||||||
|
out = einsum("b h i j, b j d -> b h i d", attn, v)
|
||||||
|
|
||||||
|
# merge heads
|
||||||
|
|
||||||
|
out = rearrange(out, "b h n d -> b n (h d)")
|
||||||
|
return self.attn_out(out) + self.ff_out(ff)
|
||||||
|
|
||||||
|
# cross attention - using multi-query + one-headed key / values as in PaLM w/ optional parallel feedforward
|
||||||
|
|
||||||
|
class CrossAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
*,
|
||||||
|
context_dim=None,
|
||||||
|
dim_head=64,
|
||||||
|
heads=12,
|
||||||
|
parallel_ff=False,
|
||||||
|
ff_mult=4,
|
||||||
|
norm_context=False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.heads = heads
|
||||||
|
self.scale = dim_head ** -0.5
|
||||||
|
inner_dim = heads * dim_head
|
||||||
|
context_dim = default(context_dim, dim)
|
||||||
|
|
||||||
|
self.norm = LayerNorm(dim)
|
||||||
|
self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity()
|
||||||
|
|
||||||
|
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
||||||
|
self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
|
||||||
|
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
||||||
|
|
||||||
|
# whether to have parallel feedforward
|
||||||
|
|
||||||
|
ff_inner_dim = ff_mult * dim
|
||||||
|
|
||||||
|
self.ff = nn.Sequential(
|
||||||
|
nn.Linear(dim, ff_inner_dim * 2, bias=False),
|
||||||
|
SwiGLU(),
|
||||||
|
nn.Linear(ff_inner_dim, dim, bias=False)
|
||||||
|
) if parallel_ff else None
|
||||||
|
|
||||||
|
def forward(self, x, context, mask):
|
||||||
|
"""
|
||||||
|
einstein notation
|
||||||
|
b - batch
|
||||||
|
h - heads
|
||||||
|
n, i, j - sequence length (base sequence length, source, target)
|
||||||
|
d - feature dimension
|
||||||
|
"""
|
||||||
|
|
||||||
|
# pre-layernorm, for queries and context
|
||||||
|
|
||||||
|
x = self.norm(x)
|
||||||
|
context = self.context_norm(context)
|
||||||
|
|
||||||
|
# get queries
|
||||||
|
|
||||||
|
q = self.to_q(x)
|
||||||
|
q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
|
||||||
|
|
||||||
|
# scale
|
||||||
|
|
||||||
|
q = q * self.scale
|
||||||
|
|
||||||
|
# get key / values
|
||||||
|
|
||||||
|
k, v = self.to_kv(context).chunk(2, dim=-1)
|
||||||
|
|
||||||
|
# query / key similarity
|
||||||
|
|
||||||
|
sim = einsum('b h i d, b j d -> b h i j', q, k)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
mask = mask.unsqueeze(1).repeat(1,self.heads,1,1)
|
||||||
|
sim = sim + mask # context mask
|
||||||
|
sim = sim - sim.amax(dim=-1, keepdim=True)
|
||||||
|
attn = sim.softmax(dim=-1)
|
||||||
|
|
||||||
|
# aggregate
|
||||||
|
|
||||||
|
out = einsum('b h i j, b j d -> b h i d', attn, v)
|
||||||
|
|
||||||
|
# merge and combine heads
|
||||||
|
|
||||||
|
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||||
|
out = self.to_out(out)
|
||||||
|
|
||||||
|
# add parallel feedforward (for multimodal layers)
|
||||||
|
|
||||||
|
if exists(self.ff):
|
||||||
|
out = out + self.ff(x)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Cross_model(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim=512,
|
||||||
|
layer_num=4,
|
||||||
|
dim_head=64,
|
||||||
|
heads=8,
|
||||||
|
ff_mult=4
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList([])
|
||||||
|
|
||||||
|
|
||||||
|
for ind in range(layer_num):
|
||||||
|
self.layers.append(nn.ModuleList([
|
||||||
|
Residual(CrossAttention(dim=dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult)),
|
||||||
|
Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult))
|
||||||
|
]))
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
query_tokens,
|
||||||
|
context_tokens,
|
||||||
|
mask
|
||||||
|
):
|
||||||
|
|
||||||
|
for cross_attn, self_attn_ff in self.layers:
|
||||||
|
query_tokens = cross_attn(query_tokens, context_tokens,mask)
|
||||||
|
query_tokens = self_attn_ff(query_tokens)
|
||||||
|
|
||||||
|
return query_tokens
|
||||||
45
diffsynth/lora/__init__.py
Normal file
45
diffsynth/lora/__init__.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class GeneralLoRALoader:
|
||||||
|
def __init__(self, device="cpu", torch_dtype=torch.float32):
|
||||||
|
self.device = device
|
||||||
|
self.torch_dtype = torch_dtype
|
||||||
|
|
||||||
|
|
||||||
|
def get_name_dict(self, lora_state_dict):
|
||||||
|
lora_name_dict = {}
|
||||||
|
for key in lora_state_dict:
|
||||||
|
if ".lora_B." not in key:
|
||||||
|
continue
|
||||||
|
keys = key.split(".")
|
||||||
|
if len(keys) > keys.index("lora_B") + 2:
|
||||||
|
keys.pop(keys.index("lora_B") + 1)
|
||||||
|
keys.pop(keys.index("lora_B"))
|
||||||
|
if keys[0] == "diffusion_model":
|
||||||
|
keys.pop(0)
|
||||||
|
keys.pop(-1)
|
||||||
|
target_name = ".".join(keys)
|
||||||
|
lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
|
||||||
|
return lora_name_dict
|
||||||
|
|
||||||
|
|
||||||
|
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
|
||||||
|
updated_num = 0
|
||||||
|
lora_name_dict = self.get_name_dict(state_dict_lora)
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if name in lora_name_dict:
|
||||||
|
weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=self.device, dtype=self.torch_dtype)
|
||||||
|
weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=self.device, dtype=self.torch_dtype)
|
||||||
|
if len(weight_up.shape) == 4:
|
||||||
|
weight_up = weight_up.squeeze(3).squeeze(2)
|
||||||
|
weight_down = weight_down.squeeze(3).squeeze(2)
|
||||||
|
weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||||
|
else:
|
||||||
|
weight_lora = alpha * torch.mm(weight_up, weight_down)
|
||||||
|
state_dict = module.state_dict()
|
||||||
|
state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora
|
||||||
|
module.load_state_dict(state_dict)
|
||||||
|
updated_num += 1
|
||||||
|
print(f"{updated_num} tensors are updated by LoRA.")
|
||||||
13
diffsynth/lora/flux_lora.py
Normal file
13
diffsynth/lora/flux_lora.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.lora import GeneralLoRALoader
|
||||||
|
from diffsynth.models.lora import FluxLoRAFromCivitai
|
||||||
|
|
||||||
|
|
||||||
|
class FluxLoRALoader(GeneralLoRALoader):
|
||||||
|
def __init__(self, device="cpu", torch_dtype=torch.float32):
|
||||||
|
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||||
|
self.loader = FluxLoRAFromCivitai()
|
||||||
|
|
||||||
|
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
|
||||||
|
lora_prefix, model_resource = self.loader.match(model, state_dict_lora)
|
||||||
|
self.loader.load(model, state_dict_lora, lora_prefix, alpha=alpha, model_resource=model_resource)
|
||||||
@@ -318,6 +318,10 @@ class FluxControlNetStateDictConverter:
|
|||||||
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
|
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
|
||||||
elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
|
elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
|
||||||
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
|
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
|
||||||
|
elif hash_value == "7f9583eb8ba86642abb9a21a4b2c9e16":
|
||||||
|
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 10}
|
||||||
|
elif hash_value == "43ad5aaa27dd4ee01b832ed16773fa52":
|
||||||
|
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0}
|
||||||
else:
|
else:
|
||||||
extra_kwargs = {}
|
extra_kwargs = {}
|
||||||
return state_dict_, extra_kwargs
|
return state_dict_, extra_kwargs
|
||||||
|
|||||||
@@ -276,21 +276,23 @@ class AdaLayerNormContinuous(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FluxDiT(torch.nn.Module):
|
class FluxDiT(torch.nn.Module):
|
||||||
def __init__(self, disable_guidance_embedder=False):
|
def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
|
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
|
||||||
self.time_embedder = TimestepEmbeddings(256, 3072)
|
self.time_embedder = TimestepEmbeddings(256, 3072)
|
||||||
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
|
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
|
||||||
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
|
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
|
||||||
self.context_embedder = torch.nn.Linear(4096, 3072)
|
self.context_embedder = torch.nn.Linear(4096, 3072)
|
||||||
self.x_embedder = torch.nn.Linear(64, 3072)
|
self.x_embedder = torch.nn.Linear(input_dim, 3072)
|
||||||
|
|
||||||
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(19)])
|
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_blocks)])
|
||||||
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])
|
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])
|
||||||
|
|
||||||
self.final_norm_out = AdaLayerNormContinuous(3072)
|
self.final_norm_out = AdaLayerNormContinuous(3072)
|
||||||
self.final_proj_out = torch.nn.Linear(3072, 64)
|
self.final_proj_out = torch.nn.Linear(3072, 64)
|
||||||
|
|
||||||
|
self.input_dim = input_dim
|
||||||
|
|
||||||
|
|
||||||
def patchify(self, hidden_states):
|
def patchify(self, hidden_states):
|
||||||
hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
|
hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
|
||||||
@@ -628,19 +630,22 @@ class FluxDiTStateDictConverter:
|
|||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
for name in list(state_dict_.keys()):
|
for name in list(state_dict_.keys()):
|
||||||
if ".proj_in_besides_attn." in name:
|
if "single_blocks." in name and ".a_to_q." in name:
|
||||||
name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
|
mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
|
||||||
|
if mlp is None:
|
||||||
|
mlp = torch.zeros(4 * state_dict_[name].shape[0],
|
||||||
|
*state_dict_[name].shape[1:],
|
||||||
|
dtype=state_dict_[name].dtype)
|
||||||
|
else:
|
||||||
|
state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
|
||||||
param = torch.concat([
|
param = torch.concat([
|
||||||
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
|
state_dict_.pop(name),
|
||||||
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
|
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
|
||||||
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
|
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
|
||||||
state_dict_[name],
|
mlp,
|
||||||
], dim=0)
|
], dim=0)
|
||||||
|
name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
|
||||||
state_dict_[name_] = param
|
state_dict_[name_] = param
|
||||||
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
|
|
||||||
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
|
|
||||||
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
|
|
||||||
state_dict_.pop(name)
|
|
||||||
for name in list(state_dict_.keys()):
|
for name in list(state_dict_.keys()):
|
||||||
for component in ["a", "b"]:
|
for component in ["a", "b"]:
|
||||||
if f".{component}_to_q." in name:
|
if f".{component}_to_q." in name:
|
||||||
@@ -735,5 +740,7 @@ class FluxDiTStateDictConverter:
|
|||||||
pass
|
pass
|
||||||
if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_:
|
if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_:
|
||||||
return state_dict_, {"disable_guidance_embedder": True}
|
return state_dict_, {"disable_guidance_embedder": True}
|
||||||
|
elif "blocks.8.attn.norm_k_a.weight" not in state_dict_:
|
||||||
|
return state_dict_, {"input_dim": 196, "num_blocks": 8}
|
||||||
else:
|
else:
|
||||||
return state_dict_
|
return state_dict_
|
||||||
|
|||||||
128
diffsynth/models/flux_infiniteyou.py
Normal file
128
diffsynth/models/flux_infiniteyou.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
# FFN
|
||||||
|
def FeedForward(dim, mult=4):
|
||||||
|
inner_dim = int(dim * mult)
|
||||||
|
return nn.Sequential(
|
||||||
|
nn.LayerNorm(dim),
|
||||||
|
nn.Linear(dim, inner_dim, bias=False),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Linear(inner_dim, dim, bias=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def reshape_tensor(x, heads):
|
||||||
|
bs, length, width = x.shape
|
||||||
|
#(bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
||||||
|
x = x.view(bs, length, heads, -1)
|
||||||
|
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
||||||
|
x = x.transpose(1, 2)
|
||||||
|
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
||||||
|
x = x.reshape(bs, heads, length, -1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class PerceiverAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, *, dim, dim_head=64, heads=8):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = dim_head**-0.5
|
||||||
|
self.dim_head = dim_head
|
||||||
|
self.heads = heads
|
||||||
|
inner_dim = dim_head * heads
|
||||||
|
|
||||||
|
self.norm1 = nn.LayerNorm(dim)
|
||||||
|
self.norm2 = nn.LayerNorm(dim)
|
||||||
|
|
||||||
|
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
||||||
|
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
||||||
|
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x, latents):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): image features
|
||||||
|
shape (b, n1, D)
|
||||||
|
latent (torch.Tensor): latent features
|
||||||
|
shape (b, n2, D)
|
||||||
|
"""
|
||||||
|
x = self.norm1(x)
|
||||||
|
latents = self.norm2(latents)
|
||||||
|
|
||||||
|
b, l, _ = latents.shape
|
||||||
|
|
||||||
|
q = self.to_q(latents)
|
||||||
|
kv_input = torch.cat((x, latents), dim=-2)
|
||||||
|
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
||||||
|
|
||||||
|
q = reshape_tensor(q, self.heads)
|
||||||
|
k = reshape_tensor(k, self.heads)
|
||||||
|
v = reshape_tensor(v, self.heads)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
||||||
|
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
||||||
|
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||||
|
out = weight @ v
|
||||||
|
|
||||||
|
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
||||||
|
|
||||||
|
return self.to_out(out)
|
||||||
|
|
||||||
|
|
||||||
|
class InfiniteYouImageProjector(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim=1280,
|
||||||
|
depth=4,
|
||||||
|
dim_head=64,
|
||||||
|
heads=20,
|
||||||
|
num_queries=8,
|
||||||
|
embedding_dim=512,
|
||||||
|
output_dim=4096,
|
||||||
|
ff_mult=4,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
|
||||||
|
self.proj_in = nn.Linear(embedding_dim, dim)
|
||||||
|
|
||||||
|
self.proj_out = nn.Linear(dim, output_dim)
|
||||||
|
self.norm_out = nn.LayerNorm(output_dim)
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList([])
|
||||||
|
for _ in range(depth):
|
||||||
|
self.layers.append(
|
||||||
|
nn.ModuleList([
|
||||||
|
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
||||||
|
FeedForward(dim=dim, mult=ff_mult),
|
||||||
|
]))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
|
||||||
|
latents = self.latents.repeat(x.size(0), 1, 1)
|
||||||
|
|
||||||
|
x = self.proj_in(x)
|
||||||
|
|
||||||
|
for attn, ff in self.layers:
|
||||||
|
latents = attn(x, latents) + latents
|
||||||
|
latents = ff(latents) + latents
|
||||||
|
|
||||||
|
latents = self.proj_out(latents)
|
||||||
|
return self.norm_out(latents)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return FluxInfiniteYouImageProjectorStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
class FluxInfiniteYouImageProjectorStateDictConverter:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_diffusers(self, state_dict):
|
||||||
|
return state_dict['image_proj']
|
||||||
58
diffsynth/models/flux_value_control.py
Normal file
58
diffsynth/models/flux_value_control.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.models.svd_unet import TemporalTimesteps
|
||||||
|
|
||||||
|
|
||||||
|
class MultiValueEncoder(torch.nn.Module):
|
||||||
|
def __init__(self, encoders=()):
|
||||||
|
super().__init__()
|
||||||
|
self.encoders = torch.nn.ModuleList(encoders)
|
||||||
|
|
||||||
|
def __call__(self, values, dtype):
|
||||||
|
emb = []
|
||||||
|
for encoder, value in zip(self.encoders, values):
|
||||||
|
if value is not None:
|
||||||
|
value = value.unsqueeze(0)
|
||||||
|
emb.append(encoder(value, dtype))
|
||||||
|
emb = torch.concat(emb, dim=0)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
class SingleValueEncoder(torch.nn.Module):
|
||||||
|
def __init__(self, dim_in=256, dim_out=4096, prefer_len=32, computation_device=None):
|
||||||
|
super().__init__()
|
||||||
|
self.prefer_len = prefer_len
|
||||||
|
self.prefer_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device)
|
||||||
|
self.prefer_value_embedder = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
|
||||||
|
)
|
||||||
|
self.positional_embedding = torch.nn.Parameter(
|
||||||
|
torch.randn(self.prefer_len, dim_in)
|
||||||
|
)
|
||||||
|
self._initialize_weights()
|
||||||
|
|
||||||
|
def _initialize_weights(self):
|
||||||
|
last_linear = self.prefer_value_embedder[-1]
|
||||||
|
torch.nn.init.zeros_(last_linear.weight)
|
||||||
|
torch.nn.init.zeros_(last_linear.bias)
|
||||||
|
|
||||||
|
def forward(self, value, dtype):
|
||||||
|
emb = self.prefer_proj(value).to(dtype)
|
||||||
|
emb = emb.expand(self.prefer_len, -1)
|
||||||
|
emb = emb + self.positional_embedding
|
||||||
|
emb = self.prefer_value_embedder(emb)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return SingleValueEncoderStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
class SingleValueEncoderStateDictConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_diffusers(self, state_dict):
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
return state_dict
|
||||||
@@ -4,6 +4,7 @@ from .utils import init_weights_on_device
|
|||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from typing import Union, Tuple, List
|
from typing import Union, Tuple, List
|
||||||
|
from .utils import hash_state_dict_keys
|
||||||
|
|
||||||
|
|
||||||
def HunyuanVideoRope(latents):
|
def HunyuanVideoRope(latents):
|
||||||
@@ -281,7 +282,12 @@ class ModulateDiT(torch.nn.Module):
|
|||||||
return self.linear(self.act(x))
|
return self.linear(self.act(x))
|
||||||
|
|
||||||
|
|
||||||
def modulate(x, shift=None, scale=None):
|
def modulate(x, shift=None, scale=None, tr_shift=None, tr_scale=None, tr_token=None):
|
||||||
|
if tr_shift is not None:
|
||||||
|
x_zero = x[:, :tr_token] * (1 + tr_scale.unsqueeze(1)) + tr_shift.unsqueeze(1)
|
||||||
|
x_orig = x[:, tr_token:] * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||||
|
x = torch.concat((x_zero, x_orig), dim=1)
|
||||||
|
return x
|
||||||
if scale is None and shift is None:
|
if scale is None and shift is None:
|
||||||
return x
|
return x
|
||||||
elif shift is None:
|
elif shift is None:
|
||||||
@@ -385,6 +391,15 @@ def attention(q, k, v):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def apply_gate(x, gate, tr_gate=None, tr_token=None):
|
||||||
|
if tr_gate is not None:
|
||||||
|
x_zero = x[:, :tr_token] * tr_gate.unsqueeze(1)
|
||||||
|
x_orig = x[:, tr_token:] * gate.unsqueeze(1)
|
||||||
|
return torch.concat((x_zero, x_orig), dim=1)
|
||||||
|
else:
|
||||||
|
return x * gate.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
class MMDoubleStreamBlockComponent(torch.nn.Module):
|
class MMDoubleStreamBlockComponent(torch.nn.Module):
|
||||||
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -405,11 +420,17 @@ class MMDoubleStreamBlockComponent(torch.nn.Module):
|
|||||||
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size)
|
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states, conditioning, freqs_cis=None):
|
def forward(self, hidden_states, conditioning, freqs_cis=None, token_replace_vec=None, tr_token=None):
|
||||||
mod1_shift, mod1_scale, mod1_gate, mod2_shift, mod2_scale, mod2_gate = self.mod(conditioning).chunk(6, dim=-1)
|
mod1_shift, mod1_scale, mod1_gate, mod2_shift, mod2_scale, mod2_gate = self.mod(conditioning).chunk(6, dim=-1)
|
||||||
|
if token_replace_vec is not None:
|
||||||
|
assert tr_token is not None
|
||||||
|
tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = self.mod(token_replace_vec).chunk(6, dim=-1)
|
||||||
|
else:
|
||||||
|
tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None, None, None
|
||||||
|
|
||||||
norm_hidden_states = self.norm1(hidden_states)
|
norm_hidden_states = self.norm1(hidden_states)
|
||||||
norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale)
|
norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale,
|
||||||
|
tr_shift=tr_mod1_shift, tr_scale=tr_mod1_scale, tr_token=tr_token)
|
||||||
qkv = self.to_qkv(norm_hidden_states)
|
qkv = self.to_qkv(norm_hidden_states)
|
||||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
||||||
|
|
||||||
@@ -418,13 +439,17 @@ class MMDoubleStreamBlockComponent(torch.nn.Module):
|
|||||||
|
|
||||||
if freqs_cis is not None:
|
if freqs_cis is not None:
|
||||||
q, k = apply_rotary_emb(q, k, freqs_cis, head_first=False)
|
q, k = apply_rotary_emb(q, k, freqs_cis, head_first=False)
|
||||||
|
return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate), (tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate)
|
||||||
|
|
||||||
return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate)
|
def process_ff(self, hidden_states, attn_output, mod, mod_tr=None, tr_token=None):
|
||||||
|
|
||||||
def process_ff(self, hidden_states, attn_output, mod):
|
|
||||||
mod1_gate, mod2_shift, mod2_scale, mod2_gate = mod
|
mod1_gate, mod2_shift, mod2_scale, mod2_gate = mod
|
||||||
hidden_states = hidden_states + self.to_out(attn_output) * mod1_gate.unsqueeze(1)
|
if mod_tr is not None:
|
||||||
hidden_states = hidden_states + self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale)) * mod2_gate.unsqueeze(1)
|
tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = mod_tr
|
||||||
|
else:
|
||||||
|
tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None
|
||||||
|
hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod1_gate, tr_mod1_gate, tr_token)
|
||||||
|
x = self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale, tr_shift=tr_mod2_shift, tr_scale=tr_mod2_scale, tr_token=tr_token))
|
||||||
|
hidden_states = hidden_states + apply_gate(x, mod2_gate, tr_mod2_gate, tr_token)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@@ -434,18 +459,18 @@ class MMDoubleStreamBlock(torch.nn.Module):
|
|||||||
self.component_a = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
|
self.component_a = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
|
||||||
self.component_b = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
|
self.component_b = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
|
||||||
|
|
||||||
def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis):
|
def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis, token_replace_vec=None, tr_token=None, split_token=71):
|
||||||
(q_a, k_a, v_a), mod_a = self.component_a(hidden_states_a, conditioning, freqs_cis)
|
(q_a, k_a, v_a), mod_a, mod_tr = self.component_a(hidden_states_a, conditioning, freqs_cis, token_replace_vec, tr_token)
|
||||||
(q_b, k_b, v_b), mod_b = self.component_b(hidden_states_b, conditioning, freqs_cis=None)
|
(q_b, k_b, v_b), mod_b, _ = self.component_b(hidden_states_b, conditioning, freqs_cis=None)
|
||||||
|
|
||||||
q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
|
q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous()
|
||||||
k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()
|
k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous()
|
||||||
v_a, v_b = torch.concat([v_a, v_b[:, :71]], dim=1), v_b[:, 71:].contiguous()
|
v_a, v_b = torch.concat([v_a, v_b[:, :split_token]], dim=1), v_b[:, split_token:].contiguous()
|
||||||
attn_output_a = attention(q_a, k_a, v_a)
|
attn_output_a = attention(q_a, k_a, v_a)
|
||||||
attn_output_b = attention(q_b, k_b, v_b)
|
attn_output_b = attention(q_b, k_b, v_b)
|
||||||
attn_output_a, attn_output_b = attn_output_a[:, :-71].contiguous(), torch.concat([attn_output_a[:, -71:], attn_output_b], dim=1)
|
attn_output_a, attn_output_b = attn_output_a[:, :-split_token].contiguous(), torch.concat([attn_output_a[:, -split_token:], attn_output_b], dim=1)
|
||||||
|
|
||||||
hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a)
|
hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a, mod_tr, tr_token)
|
||||||
hidden_states_b = self.component_b.process_ff(hidden_states_b, attn_output_b, mod_b)
|
hidden_states_b = self.component_b.process_ff(hidden_states_b, attn_output_b, mod_b)
|
||||||
return hidden_states_a, hidden_states_b
|
return hidden_states_a, hidden_states_b
|
||||||
|
|
||||||
@@ -509,11 +534,17 @@ class MMSingleStreamBlock(torch.nn.Module):
|
|||||||
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size, bias=False)
|
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size, bias=False)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256):
|
def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256, token_replace_vec=None, tr_token=None, split_token=71):
|
||||||
mod_shift, mod_scale, mod_gate = self.mod(conditioning).chunk(3, dim=-1)
|
mod_shift, mod_scale, mod_gate = self.mod(conditioning).chunk(3, dim=-1)
|
||||||
|
if token_replace_vec is not None:
|
||||||
|
assert tr_token is not None
|
||||||
|
tr_mod_shift, tr_mod_scale, tr_mod_gate = self.mod(token_replace_vec).chunk(3, dim=-1)
|
||||||
|
else:
|
||||||
|
tr_mod_shift, tr_mod_scale, tr_mod_gate = None, None, None
|
||||||
|
|
||||||
norm_hidden_states = self.norm(hidden_states)
|
norm_hidden_states = self.norm(hidden_states)
|
||||||
norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale)
|
norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale,
|
||||||
|
tr_shift=tr_mod_shift, tr_scale=tr_mod_scale, tr_token=tr_token)
|
||||||
qkv = self.to_qkv(norm_hidden_states)
|
qkv = self.to_qkv(norm_hidden_states)
|
||||||
|
|
||||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
||||||
@@ -525,16 +556,17 @@ class MMSingleStreamBlock(torch.nn.Module):
|
|||||||
k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
|
k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
|
||||||
q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)
|
q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)
|
||||||
|
|
||||||
q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
|
v_len = txt_len - split_token
|
||||||
k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()
|
q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous()
|
||||||
v_a, v_b = v[:, :-185].contiguous(), v[:, -185:].contiguous()
|
k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous()
|
||||||
|
v_a, v_b = v[:, :-v_len].contiguous(), v[:, -v_len:].contiguous()
|
||||||
|
|
||||||
attn_output_a = attention(q_a, k_a, v_a)
|
attn_output_a = attention(q_a, k_a, v_a)
|
||||||
attn_output_b = attention(q_b, k_b, v_b)
|
attn_output_b = attention(q_b, k_b, v_b)
|
||||||
attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)
|
attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)
|
||||||
|
|
||||||
hidden_states = hidden_states + self.to_out(attn_output) * mod_gate.unsqueeze(1)
|
hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod_gate, tr_mod_gate, tr_token)
|
||||||
hidden_states = hidden_states + self.ff(norm_hidden_states) * mod_gate.unsqueeze(1)
|
hidden_states = hidden_states + apply_gate(self.ff(norm_hidden_states), mod_gate, tr_mod_gate, tr_token)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@@ -555,7 +587,7 @@ class FinalLayer(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class HunyuanVideoDiT(torch.nn.Module):
|
class HunyuanVideoDiT(torch.nn.Module):
|
||||||
def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40):
|
def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40, guidance_embed=True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.img_in = PatchEmbed(in_channels=in_channels, embed_dim=hidden_size)
|
self.img_in = PatchEmbed(in_channels=in_channels, embed_dim=hidden_size)
|
||||||
self.txt_in = SingleTokenRefiner(in_channels=text_dim, hidden_size=hidden_size)
|
self.txt_in = SingleTokenRefiner(in_channels=text_dim, hidden_size=hidden_size)
|
||||||
@@ -565,7 +597,7 @@ class HunyuanVideoDiT(torch.nn.Module):
|
|||||||
torch.nn.SiLU(),
|
torch.nn.SiLU(),
|
||||||
torch.nn.Linear(hidden_size, hidden_size)
|
torch.nn.Linear(hidden_size, hidden_size)
|
||||||
)
|
)
|
||||||
self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
|
self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu") if guidance_embed else None
|
||||||
self.double_blocks = torch.nn.ModuleList([MMDoubleStreamBlock(hidden_size) for _ in range(num_double_blocks)])
|
self.double_blocks = torch.nn.ModuleList([MMDoubleStreamBlock(hidden_size) for _ in range(num_double_blocks)])
|
||||||
self.single_blocks = torch.nn.ModuleList([MMSingleStreamBlock(hidden_size) for _ in range(num_single_blocks)])
|
self.single_blocks = torch.nn.ModuleList([MMSingleStreamBlock(hidden_size) for _ in range(num_single_blocks)])
|
||||||
self.final_layer = FinalLayer(hidden_size)
|
self.final_layer = FinalLayer(hidden_size)
|
||||||
@@ -610,7 +642,9 @@ class HunyuanVideoDiT(torch.nn.Module):
|
|||||||
):
|
):
|
||||||
B, C, T, H, W = x.shape
|
B, C, T, H, W = x.shape
|
||||||
|
|
||||||
vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb) + self.guidance_in(guidance * 1000, dtype=torch.float32)
|
vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb)
|
||||||
|
if self.guidance_in is not None:
|
||||||
|
vec += self.guidance_in(guidance * 1000, dtype=torch.float32)
|
||||||
img = self.img_in(x)
|
img = self.img_in(x)
|
||||||
txt = self.txt_in(prompt_emb, t, text_mask)
|
txt = self.txt_in(prompt_emb, t, text_mask)
|
||||||
|
|
||||||
@@ -777,12 +811,12 @@ class HunyuanVideoDiT(torch.nn.Module):
|
|||||||
return HunyuanVideoDiTStateDictConverter()
|
return HunyuanVideoDiTStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanVideoDiTStateDictConverter:
|
class HunyuanVideoDiTStateDictConverter:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
def from_civitai(self, state_dict):
|
||||||
|
origin_hash_key = hash_state_dict_keys(state_dict, with_shape=True)
|
||||||
if "module" in state_dict:
|
if "module" in state_dict:
|
||||||
state_dict = state_dict["module"]
|
state_dict = state_dict["module"]
|
||||||
direct_dict = {
|
direct_dict = {
|
||||||
@@ -882,4 +916,5 @@ class HunyuanVideoDiTStateDictConverter:
|
|||||||
state_dict_[name_] = param
|
state_dict_[name_] = param
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return state_dict_
|
return state_dict_
|
||||||
|
|||||||
@@ -1,24 +1,18 @@
|
|||||||
from transformers import LlamaModel, LlamaConfig, DynamicCache
|
from transformers import LlamaModel, LlamaConfig, DynamicCache, LlavaForConditionalGeneration
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class HunyuanVideoLLMEncoder(LlamaModel):
|
class HunyuanVideoLLMEncoder(LlamaModel):
|
||||||
|
|
||||||
def __init__(self, config: LlamaConfig):
|
def __init__(self, config: LlamaConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.auto_offload = False
|
self.auto_offload = False
|
||||||
|
|
||||||
|
|
||||||
def enable_auto_offload(self, **kwargs):
|
def enable_auto_offload(self, **kwargs):
|
||||||
self.auto_offload = True
|
self.auto_offload = True
|
||||||
|
|
||||||
|
def forward(self, input_ids, attention_mask, hidden_state_skip_layer=2):
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input_ids,
|
|
||||||
attention_mask,
|
|
||||||
hidden_state_skip_layer=2
|
|
||||||
):
|
|
||||||
embed_tokens = deepcopy(self.embed_tokens).to(input_ids.device) if self.auto_offload else self.embed_tokens
|
embed_tokens = deepcopy(self.embed_tokens).to(input_ids.device) if self.auto_offload else self.embed_tokens
|
||||||
inputs_embeds = embed_tokens(input_ids)
|
inputs_embeds = embed_tokens(input_ids)
|
||||||
|
|
||||||
@@ -53,3 +47,22 @@ class HunyuanVideoLLMEncoder(LlamaModel):
|
|||||||
break
|
break
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class HunyuanVideoMLLMEncoder(LlavaForConditionalGeneration):
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.auto_offload = False
|
||||||
|
|
||||||
|
def enable_auto_offload(self, **kwargs):
|
||||||
|
self.auto_offload = True
|
||||||
|
|
||||||
|
# TODO: implement the low VRAM inference for MLLM.
|
||||||
|
def forward(self, input_ids, pixel_values, attention_mask, hidden_state_skip_layer=2):
|
||||||
|
outputs = super().forward(input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
output_hidden_states=True,
|
||||||
|
pixel_values=pixel_values)
|
||||||
|
hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
|
||||||
|
return hidden_state
|
||||||
|
|||||||
@@ -73,7 +73,6 @@ try:
|
|||||||
)
|
)
|
||||||
except Exception as exception:
|
except Exception as exception:
|
||||||
kernels = None
|
kernels = None
|
||||||
logger.warning("Failed to load cpm_kernels:" + str(exception))
|
|
||||||
|
|
||||||
|
|
||||||
class W8A16Linear(torch.autograd.Function):
|
class W8A16Linear(torch.autograd.Function):
|
||||||
@@ -981,7 +980,7 @@ class Embedding(torch.nn.Module):
|
|||||||
# Embeddings.
|
# Embeddings.
|
||||||
words_embeddings = self.word_embeddings(input_ids)
|
words_embeddings = self.word_embeddings(input_ids)
|
||||||
embeddings = words_embeddings
|
embeddings = words_embeddings
|
||||||
# Data format change to avoid explicit tranposes : [b s h] --> [s b h].
|
# Data format change to avoid explicit transposes : [b s h] --> [s b h].
|
||||||
embeddings = embeddings.transpose(0, 1).contiguous()
|
embeddings = embeddings.transpose(0, 1).contiguous()
|
||||||
# If the input flag for fp32 residual connection is set, convert for float.
|
# If the input flag for fp32 residual connection is set, convert for float.
|
||||||
if self.fp32_residual_connection:
|
if self.fp32_residual_connection:
|
||||||
@@ -1374,7 +1373,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|||||||
elif generation_config.max_new_tokens is not None:
|
elif generation_config.max_new_tokens is not None:
|
||||||
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
||||||
if not has_default_max_length:
|
if not has_default_max_length:
|
||||||
logger.warn(
|
logger.warning(
|
||||||
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
||||||
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
|
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
|
||||||
"Please refer to the documentation for more information. "
|
"Please refer to the documentation for more information. "
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from .flux_dit import FluxDiT
|
|||||||
from .hunyuan_dit import HunyuanDiT
|
from .hunyuan_dit import HunyuanDiT
|
||||||
from .cog_dit import CogDiT
|
from .cog_dit import CogDiT
|
||||||
from .hunyuan_video_dit import HunyuanVideoDiT
|
from .hunyuan_video_dit import HunyuanVideoDiT
|
||||||
|
from .wan_video_dit import WanModel
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -195,70 +196,73 @@ class FluxLoRAFromCivitai(LoRAFromCivitai):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class GeneralLoRAFromPeft:
|
class GeneralLoRAFromPeft:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT]
|
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT, WanModel]
|
||||||
|
|
||||||
|
|
||||||
def fetch_device_dtype_from_state_dict(self, state_dict):
|
def get_name_dict(self, lora_state_dict):
|
||||||
device, torch_dtype = None, None
|
lora_name_dict = {}
|
||||||
for name, param in state_dict.items():
|
for key in lora_state_dict:
|
||||||
device, torch_dtype = param.device, param.dtype
|
|
||||||
break
|
|
||||||
return device, torch_dtype
|
|
||||||
|
|
||||||
|
|
||||||
def convert_state_dict(self, state_dict, alpha=1.0, target_state_dict={}):
|
|
||||||
device, torch_dtype = self.fetch_device_dtype_from_state_dict(target_state_dict)
|
|
||||||
state_dict_ = {}
|
|
||||||
for key in state_dict:
|
|
||||||
if ".lora_B." not in key:
|
if ".lora_B." not in key:
|
||||||
continue
|
continue
|
||||||
weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
|
|
||||||
weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
|
|
||||||
if len(weight_up.shape) == 4:
|
|
||||||
weight_up = weight_up.squeeze(3).squeeze(2)
|
|
||||||
weight_down = weight_down.squeeze(3).squeeze(2)
|
|
||||||
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
|
||||||
else:
|
|
||||||
lora_weight = alpha * torch.mm(weight_up, weight_down)
|
|
||||||
keys = key.split(".")
|
keys = key.split(".")
|
||||||
if len(keys) > keys.index("lora_B") + 2:
|
if len(keys) > keys.index("lora_B") + 2:
|
||||||
keys.pop(keys.index("lora_B") + 1)
|
keys.pop(keys.index("lora_B") + 1)
|
||||||
keys.pop(keys.index("lora_B"))
|
keys.pop(keys.index("lora_B"))
|
||||||
|
if keys[0] == "diffusion_model":
|
||||||
|
keys.pop(0)
|
||||||
target_name = ".".join(keys)
|
target_name = ".".join(keys)
|
||||||
if target_name not in target_state_dict:
|
lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
|
||||||
return {}
|
return lora_name_dict
|
||||||
state_dict_[target_name] = lora_weight.cpu()
|
|
||||||
return state_dict_
|
|
||||||
|
def match(self, model: torch.nn.Module, state_dict_lora):
|
||||||
|
lora_name_dict = self.get_name_dict(state_dict_lora)
|
||||||
|
model_name_dict = {name: None for name, _ in model.named_parameters()}
|
||||||
|
matched_num = sum([i in model_name_dict for i in lora_name_dict])
|
||||||
|
if matched_num == len(lora_name_dict):
|
||||||
|
return "", ""
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_device_and_dtype(self, state_dict):
|
||||||
|
device, dtype = None, None
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
device, dtype = param.device, param.dtype
|
||||||
|
break
|
||||||
|
computation_device = device
|
||||||
|
computation_dtype = dtype
|
||||||
|
if computation_device == torch.device("cpu"):
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
computation_device = torch.device("cuda")
|
||||||
|
if computation_dtype == torch.float8_e4m3fn:
|
||||||
|
computation_dtype = torch.float32
|
||||||
|
return device, dtype, computation_device, computation_dtype
|
||||||
|
|
||||||
|
|
||||||
def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
|
def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
|
||||||
state_dict_model = model.state_dict()
|
state_dict_model = model.state_dict()
|
||||||
state_dict_lora = self.convert_state_dict(state_dict_lora, alpha=alpha, target_state_dict=state_dict_model)
|
device, dtype, computation_device, computation_dtype = self.fetch_device_and_dtype(state_dict_model)
|
||||||
if len(state_dict_lora) > 0:
|
lora_name_dict = self.get_name_dict(state_dict_lora)
|
||||||
print(f" {len(state_dict_lora)} tensors are updated.")
|
for name in lora_name_dict:
|
||||||
for name in state_dict_lora:
|
weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=computation_device, dtype=computation_dtype)
|
||||||
state_dict_model[name] += state_dict_lora[name].to(
|
weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=computation_device, dtype=computation_dtype)
|
||||||
dtype=state_dict_model[name].dtype,
|
if len(weight_up.shape) == 4:
|
||||||
device=state_dict_model[name].device
|
weight_up = weight_up.squeeze(3).squeeze(2)
|
||||||
)
|
weight_down = weight_down.squeeze(3).squeeze(2)
|
||||||
model.load_state_dict(state_dict_model)
|
weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||||
|
else:
|
||||||
|
weight_lora = alpha * torch.mm(weight_up, weight_down)
|
||||||
|
weight_model = state_dict_model[name].to(device=computation_device, dtype=computation_dtype)
|
||||||
|
weight_patched = weight_model + weight_lora
|
||||||
|
state_dict_model[name] = weight_patched.to(device=device, dtype=dtype)
|
||||||
|
print(f" {len(lora_name_dict)} tensors are updated.")
|
||||||
|
model.load_state_dict(state_dict_model)
|
||||||
|
|
||||||
|
|
||||||
def match(self, model, state_dict_lora):
|
|
||||||
for model_class in self.supported_model_classes:
|
|
||||||
if not isinstance(model, model_class):
|
|
||||||
continue
|
|
||||||
state_dict_model = model.state_dict()
|
|
||||||
try:
|
|
||||||
state_dict_lora_ = self.convert_state_dict(state_dict_lora, alpha=1.0, target_state_dict=state_dict_model)
|
|
||||||
if len(state_dict_lora_) > 0:
|
|
||||||
return "", ""
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
|
class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -363,5 +367,20 @@ class FluxLoRAConverter:
|
|||||||
return state_dict_
|
return state_dict_
|
||||||
|
|
||||||
|
|
||||||
|
class WanLoRAConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def align_to_opensource_format(state_dict, **kwargs):
|
||||||
|
state_dict = {"diffusion_model." + name.replace(".default.", "."): param for name, param in state_dict.items()}
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def align_to_diffsynth_format(state_dict, **kwargs):
|
||||||
|
state_dict = {name.replace("diffusion_model.", "").replace(".lora_A.weight", ".lora_A.default.weight").replace(".lora_B.weight", ".lora_B.default.weight"): param for name, param in state_dict.items()}
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
def get_lora_loaders():
|
def get_lora_loaders():
|
||||||
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]
|
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]
|
||||||
|
|||||||
@@ -69,7 +69,9 @@ def load_model_from_single_file(state_dict, model_names, model_classes, model_re
|
|||||||
model_state_dict, extra_kwargs = state_dict_results, {}
|
model_state_dict, extra_kwargs = state_dict_results, {}
|
||||||
torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
|
torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
|
||||||
with init_weights_on_device():
|
with init_weights_on_device():
|
||||||
model= model_class(**extra_kwargs)
|
model = model_class(**extra_kwargs)
|
||||||
|
if hasattr(model, "eval"):
|
||||||
|
model = model.eval()
|
||||||
model.load_state_dict(model_state_dict, assign=True)
|
model.load_state_dict(model_state_dict, assign=True)
|
||||||
model = model.to(dtype=torch_dtype, device=device)
|
model = model.to(dtype=torch_dtype, device=device)
|
||||||
loaded_model_names.append(model_name)
|
loaded_model_names.append(model_name)
|
||||||
@@ -374,6 +376,7 @@ class ModelManager:
|
|||||||
self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
|
self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
|
||||||
else:
|
else:
|
||||||
print(f"Loading LoRA models from file: {file_path}")
|
print(f"Loading LoRA models from file: {file_path}")
|
||||||
|
is_loaded = False
|
||||||
if len(state_dict) == 0:
|
if len(state_dict) == 0:
|
||||||
state_dict = load_state_dict(file_path)
|
state_dict = load_state_dict(file_path)
|
||||||
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
|
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
|
||||||
@@ -383,7 +386,10 @@ class ModelManager:
|
|||||||
print(f" Adding LoRA to {model_name} ({model_path}).")
|
print(f" Adding LoRA to {model_name} ({model_path}).")
|
||||||
lora_prefix, model_resource = match_results
|
lora_prefix, model_resource = match_results
|
||||||
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
|
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
|
||||||
|
is_loaded = True
|
||||||
break
|
break
|
||||||
|
if not is_loaded:
|
||||||
|
print(f" Cannot load LoRA: {file_path}")
|
||||||
|
|
||||||
|
|
||||||
def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
|
def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
|
||||||
|
|||||||
168
diffsynth/models/qwenvl.py
Normal file
168
diffsynth/models/qwenvl.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen25VL_7b_Embedder(torch.nn.Module):
|
||||||
|
def __init__(self, model_path, max_length=640, dtype=torch.bfloat16, device="cuda"):
|
||||||
|
super(Qwen25VL_7b_Embedder, self).__init__()
|
||||||
|
self.max_length = max_length
|
||||||
|
self.dtype = dtype
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
|
||||||
|
|
||||||
|
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||||
|
model_path,
|
||||||
|
torch_dtype=dtype,
|
||||||
|
).to(torch.cuda.current_device())
|
||||||
|
|
||||||
|
self.model.requires_grad_(False)
|
||||||
|
self.processor = AutoProcessor.from_pretrained(
|
||||||
|
model_path, min_pixels=256 * 28 * 28, max_pixels=324 * 28 * 28
|
||||||
|
)
|
||||||
|
|
||||||
|
Qwen25VL_7b_PREFIX = '''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:
|
||||||
|
- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.
|
||||||
|
- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n
|
||||||
|
Here are examples of how to transform or refine prompts:
|
||||||
|
- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.
|
||||||
|
- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n
|
||||||
|
Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:
|
||||||
|
User Prompt:'''
|
||||||
|
|
||||||
|
self.prefix = Qwen25VL_7b_PREFIX
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_pretrained(path, torch_dtype=torch.bfloat16, device="cuda"):
|
||||||
|
return Qwen25VL_7b_Embedder(path, dtype=torch_dtype, device=device)
|
||||||
|
|
||||||
|
def forward(self, caption, ref_images):
|
||||||
|
text_list = caption
|
||||||
|
embs = torch.zeros(
|
||||||
|
len(text_list),
|
||||||
|
self.max_length,
|
||||||
|
self.model.config.hidden_size,
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
device=torch.cuda.current_device(),
|
||||||
|
)
|
||||||
|
hidden_states = torch.zeros(
|
||||||
|
len(text_list),
|
||||||
|
self.max_length,
|
||||||
|
self.model.config.hidden_size,
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
device=torch.cuda.current_device(),
|
||||||
|
)
|
||||||
|
masks = torch.zeros(
|
||||||
|
len(text_list),
|
||||||
|
self.max_length,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=torch.cuda.current_device(),
|
||||||
|
)
|
||||||
|
input_ids_list = []
|
||||||
|
attention_mask_list = []
|
||||||
|
emb_list = []
|
||||||
|
|
||||||
|
def split_string(s):
|
||||||
|
s = s.replace("“", '"').replace("”", '"').replace("'", '''"''') # use english quotes
|
||||||
|
result = []
|
||||||
|
in_quotes = False
|
||||||
|
temp = ""
|
||||||
|
|
||||||
|
for idx,char in enumerate(s):
|
||||||
|
if char == '"' and idx>155:
|
||||||
|
temp += char
|
||||||
|
if not in_quotes:
|
||||||
|
result.append(temp)
|
||||||
|
temp = ""
|
||||||
|
|
||||||
|
in_quotes = not in_quotes
|
||||||
|
continue
|
||||||
|
if in_quotes:
|
||||||
|
if char.isspace():
|
||||||
|
pass # have space token
|
||||||
|
|
||||||
|
result.append("“" + char + "”")
|
||||||
|
else:
|
||||||
|
temp += char
|
||||||
|
|
||||||
|
if temp:
|
||||||
|
result.append(temp)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
for idx, (txt, imgs) in enumerate(zip(text_list, ref_images)):
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": []}]
|
||||||
|
|
||||||
|
messages[0]["content"].append({"type": "text", "text": f"{self.prefix}"})
|
||||||
|
|
||||||
|
messages[0]["content"].append({"type": "image", "image": imgs})
|
||||||
|
|
||||||
|
# 再添加 text
|
||||||
|
messages[0]["content"].append({"type": "text", "text": f"{txt}"})
|
||||||
|
|
||||||
|
# Preparation for inference
|
||||||
|
text = self.processor.apply_chat_template(
|
||||||
|
messages, tokenize=False, add_generation_prompt=True, add_vision_id=True
|
||||||
|
)
|
||||||
|
|
||||||
|
image_inputs = [imgs]
|
||||||
|
|
||||||
|
inputs = self.processor(
|
||||||
|
text=[text],
|
||||||
|
images=image_inputs,
|
||||||
|
padding=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
old_inputs_ids = inputs.input_ids
|
||||||
|
text_split_list = split_string(text)
|
||||||
|
|
||||||
|
token_list = []
|
||||||
|
for text_each in text_split_list:
|
||||||
|
txt_inputs = self.processor(
|
||||||
|
text=text_each,
|
||||||
|
images=None,
|
||||||
|
videos=None,
|
||||||
|
padding=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
token_each = txt_inputs.input_ids
|
||||||
|
if token_each[0][0] == 2073 and token_each[0][-1] == 854:
|
||||||
|
token_each = token_each[:, 1:-1]
|
||||||
|
token_list.append(token_each)
|
||||||
|
else:
|
||||||
|
token_list.append(token_each)
|
||||||
|
|
||||||
|
new_txt_ids = torch.cat(token_list, dim=1).to("cuda")
|
||||||
|
|
||||||
|
new_txt_ids = new_txt_ids.to(old_inputs_ids.device)
|
||||||
|
|
||||||
|
idx1 = (old_inputs_ids == 151653).nonzero(as_tuple=True)[1][0]
|
||||||
|
idx2 = (new_txt_ids == 151653).nonzero(as_tuple=True)[1][0]
|
||||||
|
inputs.input_ids = (
|
||||||
|
torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.to("cuda")
|
||||||
|
)
|
||||||
|
inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=inputs.input_ids,
|
||||||
|
attention_mask=inputs.attention_mask,
|
||||||
|
pixel_values=inputs.pixel_values.to("cuda"),
|
||||||
|
image_grid_thw=inputs.image_grid_thw.to("cuda"),
|
||||||
|
output_hidden_states=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
emb = outputs["hidden_states"][-1]
|
||||||
|
|
||||||
|
embs[idx, : min(self.max_length, emb.shape[1] - 217)] = emb[0, 217:][
|
||||||
|
: self.max_length
|
||||||
|
]
|
||||||
|
|
||||||
|
masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones(
|
||||||
|
(min(self.max_length, emb.shape[1] - 217)),
|
||||||
|
dtype=torch.long,
|
||||||
|
device=torch.cuda.current_device(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return embs, masks
|
||||||
683
diffsynth/models/step1x_connector.py
Normal file
683
diffsynth/models/step1x_connector.py
Normal file
@@ -0,0 +1,683 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch, math
|
||||||
|
import torch.nn
|
||||||
|
from einops import rearrange
|
||||||
|
from torch import nn
|
||||||
|
from functools import partial
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def attention(q, k, v, attn_mask, mode="torch"):
|
||||||
|
q = q.transpose(1, 2)
|
||||||
|
k = k.transpose(1, 2)
|
||||||
|
v = v.transpose(1, 2)
|
||||||
|
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||||
|
x = rearrange(x, "b n s d -> b s (n d)")
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(nn.Module):
|
||||||
|
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
hidden_channels=None,
|
||||||
|
out_features=None,
|
||||||
|
act_layer=nn.GELU,
|
||||||
|
norm_layer=None,
|
||||||
|
bias=True,
|
||||||
|
drop=0.0,
|
||||||
|
use_conv=False,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
out_features = out_features or in_channels
|
||||||
|
hidden_channels = hidden_channels or in_channels
|
||||||
|
bias = (bias, bias)
|
||||||
|
drop_probs = (drop, drop)
|
||||||
|
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
||||||
|
|
||||||
|
self.fc1 = linear_layer(
|
||||||
|
in_channels, hidden_channels, bias=bias[0], device=device, dtype=dtype
|
||||||
|
)
|
||||||
|
self.act = act_layer()
|
||||||
|
self.drop1 = nn.Dropout(drop_probs[0])
|
||||||
|
self.norm = (
|
||||||
|
norm_layer(hidden_channels, device=device, dtype=dtype)
|
||||||
|
if norm_layer is not None
|
||||||
|
else nn.Identity()
|
||||||
|
)
|
||||||
|
self.fc2 = linear_layer(
|
||||||
|
hidden_channels, out_features, bias=bias[1], device=device, dtype=dtype
|
||||||
|
)
|
||||||
|
self.drop2 = nn.Dropout(drop_probs[1])
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.fc1(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.drop1(x)
|
||||||
|
x = self.norm(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
x = self.drop2(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TextProjection(nn.Module):
|
||||||
|
"""
|
||||||
|
Projects text embeddings. Also handles dropout for classifier-free guidance.
|
||||||
|
|
||||||
|
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
|
||||||
|
factory_kwargs = {"dtype": dtype, "device": device}
|
||||||
|
super().__init__()
|
||||||
|
self.linear_1 = nn.Linear(
|
||||||
|
in_features=in_channels,
|
||||||
|
out_features=hidden_size,
|
||||||
|
bias=True,
|
||||||
|
**factory_kwargs,
|
||||||
|
)
|
||||||
|
self.act_1 = act_layer()
|
||||||
|
self.linear_2 = nn.Linear(
|
||||||
|
in_features=hidden_size,
|
||||||
|
out_features=hidden_size,
|
||||||
|
bias=True,
|
||||||
|
**factory_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, caption):
|
||||||
|
hidden_states = self.linear_1(caption)
|
||||||
|
hidden_states = self.act_1(hidden_states)
|
||||||
|
hidden_states = self.linear_2(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEmbedder(nn.Module):
|
||||||
|
"""
|
||||||
|
Embeds scalar timesteps into vector representations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size,
|
||||||
|
act_layer,
|
||||||
|
frequency_embedding_size=256,
|
||||||
|
max_period=10000,
|
||||||
|
out_size=None,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
):
|
||||||
|
factory_kwargs = {"dtype": dtype, "device": device}
|
||||||
|
super().__init__()
|
||||||
|
self.frequency_embedding_size = frequency_embedding_size
|
||||||
|
self.max_period = max_period
|
||||||
|
if out_size is None:
|
||||||
|
out_size = hidden_size
|
||||||
|
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
nn.Linear(
|
||||||
|
frequency_embedding_size, hidden_size, bias=True, **factory_kwargs
|
||||||
|
),
|
||||||
|
act_layer(),
|
||||||
|
nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
|
||||||
|
)
|
||||||
|
nn.init.normal_(self.mlp[0].weight, std=0.02) # type: ignore
|
||||||
|
nn.init.normal_(self.mlp[2].weight, std=0.02) # type: ignore
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def timestep_embedding(t, dim, max_period=10000):
|
||||||
|
"""
|
||||||
|
Create sinusoidal timestep embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
|
||||||
|
dim (int): the dimension of the output.
|
||||||
|
max_period (int): controls the minimum frequency of the embeddings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
|
||||||
|
|
||||||
|
.. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
||||||
|
"""
|
||||||
|
half = dim // 2
|
||||||
|
freqs = torch.exp(
|
||||||
|
-math.log(max_period)
|
||||||
|
* torch.arange(start=0, end=half, dtype=torch.float32)
|
||||||
|
/ half
|
||||||
|
).to(device=t.device)
|
||||||
|
args = t[:, None].float() * freqs[None]
|
||||||
|
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||||
|
if dim % 2:
|
||||||
|
embedding = torch.cat(
|
||||||
|
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
||||||
|
)
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
def forward(self, t):
|
||||||
|
t_freq = self.timestep_embedding(
|
||||||
|
t, self.frequency_embedding_size, self.max_period
|
||||||
|
).type(self.mlp[0].weight.dtype) # type: ignore
|
||||||
|
t_emb = self.mlp(t_freq)
|
||||||
|
return t_emb
|
||||||
|
|
||||||
|
|
||||||
|
def apply_gate(x, gate=None, tanh=False):
|
||||||
|
"""AI is creating summary for apply_gate
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): input tensor.
|
||||||
|
gate (torch.Tensor, optional): gate tensor. Defaults to None.
|
||||||
|
tanh (bool, optional): whether to use tanh function. Defaults to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: the output tensor after apply gate.
|
||||||
|
"""
|
||||||
|
if gate is None:
|
||||||
|
return x
|
||||||
|
if tanh:
|
||||||
|
return x * gate.unsqueeze(1).tanh()
|
||||||
|
else:
|
||||||
|
return x * gate.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
elementwise_affine=True,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the RMSNorm normalization layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dim (int): The dimension of the input tensor.
|
||||||
|
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
eps (float): A small value added to the denominator for numerical stability.
|
||||||
|
weight (nn.Parameter): Learnable scaling parameter.
|
||||||
|
|
||||||
|
"""
|
||||||
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
if elementwise_affine:
|
||||||
|
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
|
||||||
|
|
||||||
|
def _norm(self, x):
|
||||||
|
"""
|
||||||
|
Apply the RMSNorm normalization to the input tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): The input tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The normalized tensor.
|
||||||
|
|
||||||
|
"""
|
||||||
|
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Forward pass through the RMSNorm layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): The input tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The output tensor after applying RMSNorm.
|
||||||
|
|
||||||
|
"""
|
||||||
|
output = self._norm(x.float()).type_as(x)
|
||||||
|
if hasattr(self, "weight"):
|
||||||
|
output = output * self.weight
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def get_norm_layer(norm_layer):
|
||||||
|
"""
|
||||||
|
Get the normalization layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
norm_layer (str): The type of normalization layer.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
norm_layer (nn.Module): The normalization layer.
|
||||||
|
"""
|
||||||
|
if norm_layer == "layer":
|
||||||
|
return nn.LayerNorm
|
||||||
|
elif norm_layer == "rms":
|
||||||
|
return RMSNorm
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
|
||||||
|
|
||||||
|
|
||||||
|
def get_activation_layer(act_type):
|
||||||
|
"""get activation layer
|
||||||
|
|
||||||
|
Args:
|
||||||
|
act_type (str): the activation type
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.nn.functional: the activation layer
|
||||||
|
"""
|
||||||
|
if act_type == "gelu":
|
||||||
|
return lambda: nn.GELU()
|
||||||
|
elif act_type == "gelu_tanh":
|
||||||
|
return lambda: nn.GELU(approximate="tanh")
|
||||||
|
elif act_type == "relu":
|
||||||
|
return nn.ReLU
|
||||||
|
elif act_type == "silu":
|
||||||
|
return nn.SiLU
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown activation type: {act_type}")
|
||||||
|
|
||||||
|
class IndividualTokenRefinerBlock(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size,
|
||||||
|
heads_num,
|
||||||
|
mlp_width_ratio: str = 4.0,
|
||||||
|
mlp_drop_rate: float = 0.0,
|
||||||
|
act_type: str = "silu",
|
||||||
|
qk_norm: bool = False,
|
||||||
|
qk_norm_type: str = "layer",
|
||||||
|
qkv_bias: bool = True,
|
||||||
|
need_CA: bool = False,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
):
|
||||||
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
super().__init__()
|
||||||
|
self.need_CA = need_CA
|
||||||
|
self.heads_num = heads_num
|
||||||
|
head_dim = hidden_size // heads_num
|
||||||
|
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
|
||||||
|
|
||||||
|
self.norm1 = nn.LayerNorm(
|
||||||
|
hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
|
||||||
|
)
|
||||||
|
self.self_attn_qkv = nn.Linear(
|
||||||
|
hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
|
||||||
|
)
|
||||||
|
qk_norm_layer = get_norm_layer(qk_norm_type)
|
||||||
|
self.self_attn_q_norm = (
|
||||||
|
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
||||||
|
if qk_norm
|
||||||
|
else nn.Identity()
|
||||||
|
)
|
||||||
|
self.self_attn_k_norm = (
|
||||||
|
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
||||||
|
if qk_norm
|
||||||
|
else nn.Identity()
|
||||||
|
)
|
||||||
|
self.self_attn_proj = nn.Linear(
|
||||||
|
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
self.norm2 = nn.LayerNorm(
|
||||||
|
hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
|
||||||
|
)
|
||||||
|
act_layer = get_activation_layer(act_type)
|
||||||
|
self.mlp = MLP(
|
||||||
|
in_channels=hidden_size,
|
||||||
|
hidden_channels=mlp_hidden_dim,
|
||||||
|
act_layer=act_layer,
|
||||||
|
drop=mlp_drop_rate,
|
||||||
|
**factory_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.adaLN_modulation = nn.Sequential(
|
||||||
|
act_layer(),
|
||||||
|
nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.need_CA:
|
||||||
|
self.cross_attnblock=CrossAttnBlock(hidden_size=hidden_size,
|
||||||
|
heads_num=heads_num,
|
||||||
|
mlp_width_ratio=mlp_width_ratio,
|
||||||
|
mlp_drop_rate=mlp_drop_rate,
|
||||||
|
act_type=act_type,
|
||||||
|
qk_norm=qk_norm,
|
||||||
|
qk_norm_type=qk_norm_type,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
**factory_kwargs,)
|
||||||
|
# Zero-initialize the modulation
|
||||||
|
nn.init.zeros_(self.adaLN_modulation[1].weight)
|
||||||
|
nn.init.zeros_(self.adaLN_modulation[1].bias)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
c: torch.Tensor, # timestep_aware_representations + context_aware_representations
|
||||||
|
attn_mask: torch.Tensor = None,
|
||||||
|
y: torch.Tensor = None,
|
||||||
|
):
|
||||||
|
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||||
|
|
||||||
|
norm_x = self.norm1(x)
|
||||||
|
qkv = self.self_attn_qkv(norm_x)
|
||||||
|
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
||||||
|
# Apply QK-Norm if needed
|
||||||
|
q = self.self_attn_q_norm(q).to(v)
|
||||||
|
k = self.self_attn_k_norm(k).to(v)
|
||||||
|
|
||||||
|
# Self-Attention
|
||||||
|
attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
|
||||||
|
|
||||||
|
x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
|
||||||
|
|
||||||
|
if self.need_CA:
|
||||||
|
x = self.cross_attnblock(x, c, attn_mask, y)
|
||||||
|
|
||||||
|
# FFN Layer
|
||||||
|
x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class CrossAttnBlock(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size,
|
||||||
|
heads_num,
|
||||||
|
mlp_width_ratio: str = 4.0,
|
||||||
|
mlp_drop_rate: float = 0.0,
|
||||||
|
act_type: str = "silu",
|
||||||
|
qk_norm: bool = False,
|
||||||
|
qk_norm_type: str = "layer",
|
||||||
|
qkv_bias: bool = True,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
):
|
||||||
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
super().__init__()
|
||||||
|
self.heads_num = heads_num
|
||||||
|
head_dim = hidden_size // heads_num
|
||||||
|
|
||||||
|
self.norm1 = nn.LayerNorm(
|
||||||
|
hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
|
||||||
|
)
|
||||||
|
self.norm1_2 = nn.LayerNorm(
|
||||||
|
hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
|
||||||
|
)
|
||||||
|
self.self_attn_q = nn.Linear(
|
||||||
|
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
|
||||||
|
)
|
||||||
|
self.self_attn_kv = nn.Linear(
|
||||||
|
hidden_size, hidden_size*2, bias=qkv_bias, **factory_kwargs
|
||||||
|
)
|
||||||
|
qk_norm_layer = get_norm_layer(qk_norm_type)
|
||||||
|
self.self_attn_q_norm = (
|
||||||
|
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
||||||
|
if qk_norm
|
||||||
|
else nn.Identity()
|
||||||
|
)
|
||||||
|
self.self_attn_k_norm = (
|
||||||
|
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
||||||
|
if qk_norm
|
||||||
|
else nn.Identity()
|
||||||
|
)
|
||||||
|
self.self_attn_proj = nn.Linear(
|
||||||
|
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
self.norm2 = nn.LayerNorm(
|
||||||
|
hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
|
||||||
|
)
|
||||||
|
act_layer = get_activation_layer(act_type)
|
||||||
|
|
||||||
|
self.adaLN_modulation = nn.Sequential(
|
||||||
|
act_layer(),
|
||||||
|
nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
|
||||||
|
)
|
||||||
|
# Zero-initialize the modulation
|
||||||
|
nn.init.zeros_(self.adaLN_modulation[1].weight)
|
||||||
|
nn.init.zeros_(self.adaLN_modulation[1].bias)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
c: torch.Tensor, # timestep_aware_representations + context_aware_representations
|
||||||
|
attn_mask: torch.Tensor = None,
|
||||||
|
y: torch.Tensor=None,
|
||||||
|
|
||||||
|
):
|
||||||
|
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||||
|
|
||||||
|
norm_x = self.norm1(x)
|
||||||
|
norm_y = self.norm1_2(y)
|
||||||
|
q = self.self_attn_q(norm_x)
|
||||||
|
q = rearrange(q, "B L (H D) -> B L H D", H=self.heads_num)
|
||||||
|
kv = self.self_attn_kv(norm_y)
|
||||||
|
k, v = rearrange(kv, "B L (K H D) -> K B L H D", K=2, H=self.heads_num)
|
||||||
|
# Apply QK-Norm if needed
|
||||||
|
q = self.self_attn_q_norm(q).to(v)
|
||||||
|
k = self.self_attn_k_norm(k).to(v)
|
||||||
|
|
||||||
|
# Self-Attention
|
||||||
|
attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
|
||||||
|
|
||||||
|
x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class IndividualTokenRefiner(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size,
|
||||||
|
heads_num,
|
||||||
|
depth,
|
||||||
|
mlp_width_ratio: float = 4.0,
|
||||||
|
mlp_drop_rate: float = 0.0,
|
||||||
|
act_type: str = "silu",
|
||||||
|
qk_norm: bool = False,
|
||||||
|
qk_norm_type: str = "layer",
|
||||||
|
qkv_bias: bool = True,
|
||||||
|
need_CA:bool=False,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
):
|
||||||
|
|
||||||
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
super().__init__()
|
||||||
|
self.need_CA = need_CA
|
||||||
|
self.blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
IndividualTokenRefinerBlock(
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
heads_num=heads_num,
|
||||||
|
mlp_width_ratio=mlp_width_ratio,
|
||||||
|
mlp_drop_rate=mlp_drop_rate,
|
||||||
|
act_type=act_type,
|
||||||
|
qk_norm=qk_norm,
|
||||||
|
qk_norm_type=qk_norm_type,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
need_CA=self.need_CA,
|
||||||
|
**factory_kwargs,
|
||||||
|
)
|
||||||
|
for _ in range(depth)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
c: torch.LongTensor,
|
||||||
|
mask: Optional[torch.Tensor] = None,
|
||||||
|
y:torch.Tensor=None,
|
||||||
|
):
|
||||||
|
self_attn_mask = None
|
||||||
|
if mask is not None:
|
||||||
|
batch_size = mask.shape[0]
|
||||||
|
seq_len = mask.shape[1]
|
||||||
|
mask = mask.to(x.device)
|
||||||
|
# batch_size x 1 x seq_len x seq_len
|
||||||
|
self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(
|
||||||
|
1, 1, seq_len, 1
|
||||||
|
)
|
||||||
|
# batch_size x 1 x seq_len x seq_len
|
||||||
|
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
|
||||||
|
# batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
|
||||||
|
self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
|
||||||
|
# avoids self-attention weight being NaN for padding tokens
|
||||||
|
self_attn_mask[:, :, :, 0] = True
|
||||||
|
|
||||||
|
|
||||||
|
for block in self.blocks:
|
||||||
|
x = block(x, c, self_attn_mask,y)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SingleTokenRefiner(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
A single token refiner block for llm text embedding refine.
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
hidden_size,
|
||||||
|
heads_num,
|
||||||
|
depth,
|
||||||
|
mlp_width_ratio: float = 4.0,
|
||||||
|
mlp_drop_rate: float = 0.0,
|
||||||
|
act_type: str = "silu",
|
||||||
|
qk_norm: bool = False,
|
||||||
|
qk_norm_type: str = "layer",
|
||||||
|
qkv_bias: bool = True,
|
||||||
|
need_CA:bool=False,
|
||||||
|
attn_mode: str = "torch",
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
):
|
||||||
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
super().__init__()
|
||||||
|
self.attn_mode = attn_mode
|
||||||
|
self.need_CA = need_CA
|
||||||
|
assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."
|
||||||
|
|
||||||
|
self.input_embedder = nn.Linear(
|
||||||
|
in_channels, hidden_size, bias=True, **factory_kwargs
|
||||||
|
)
|
||||||
|
if self.need_CA:
|
||||||
|
self.input_embedder_CA = nn.Linear(
|
||||||
|
in_channels, hidden_size, bias=True, **factory_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
act_layer = get_activation_layer(act_type)
|
||||||
|
# Build timestep embedding layer
|
||||||
|
self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
|
||||||
|
# Build context embedding layer
|
||||||
|
self.c_embedder = TextProjection(
|
||||||
|
in_channels, hidden_size, act_layer, **factory_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
self.individual_token_refiner = IndividualTokenRefiner(
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
heads_num=heads_num,
|
||||||
|
depth=depth,
|
||||||
|
mlp_width_ratio=mlp_width_ratio,
|
||||||
|
mlp_drop_rate=mlp_drop_rate,
|
||||||
|
act_type=act_type,
|
||||||
|
qk_norm=qk_norm,
|
||||||
|
qk_norm_type=qk_norm_type,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
need_CA=need_CA,
|
||||||
|
**factory_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
t: torch.LongTensor,
|
||||||
|
mask: Optional[torch.LongTensor] = None,
|
||||||
|
y: torch.LongTensor=None,
|
||||||
|
):
|
||||||
|
timestep_aware_representations = self.t_embedder(t)
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
|
context_aware_representations = x.mean(dim=1)
|
||||||
|
else:
|
||||||
|
mask_float = mask.unsqueeze(-1) # [b, s1, 1]
|
||||||
|
context_aware_representations = (x * mask_float).sum(
|
||||||
|
dim=1
|
||||||
|
) / mask_float.sum(dim=1)
|
||||||
|
context_aware_representations = self.c_embedder(context_aware_representations)
|
||||||
|
c = timestep_aware_representations + context_aware_representations
|
||||||
|
|
||||||
|
x = self.input_embedder(x)
|
||||||
|
if self.need_CA:
|
||||||
|
y = self.input_embedder_CA(y)
|
||||||
|
x = self.individual_token_refiner(x, c, mask, y)
|
||||||
|
else:
|
||||||
|
x = self.individual_token_refiner(x, c, mask)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2Connector(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
# biclip_dim=1024,
|
||||||
|
in_channels=3584,
|
||||||
|
hidden_size=4096,
|
||||||
|
heads_num=32,
|
||||||
|
depth=2,
|
||||||
|
need_CA=False,
|
||||||
|
device=None,
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
factory_kwargs = {"device": device, "dtype":dtype}
|
||||||
|
|
||||||
|
self.S =SingleTokenRefiner(in_channels=in_channels,hidden_size=hidden_size,heads_num=heads_num,depth=depth,need_CA=need_CA,**factory_kwargs)
|
||||||
|
self.global_proj_out=nn.Linear(in_channels,768)
|
||||||
|
|
||||||
|
self.scale_factor = nn.Parameter(torch.zeros(1))
|
||||||
|
with torch.no_grad():
|
||||||
|
self.scale_factor.data += -(1 - 0.09)
|
||||||
|
|
||||||
|
def forward(self, x,t,mask):
|
||||||
|
mask_float = mask.unsqueeze(-1) # [b, s1, 1]
|
||||||
|
x_mean = (x * mask_float).sum(
|
||||||
|
dim=1
|
||||||
|
) / mask_float.sum(dim=1) * (1 + self.scale_factor)
|
||||||
|
|
||||||
|
global_out=self.global_proj_out(x_mean)
|
||||||
|
encoder_hidden_states = self.S(x,t,mask)
|
||||||
|
return encoder_hidden_states,global_out
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return Qwen2ConnectorStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2ConnectorStateDictConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_diffusers(self, state_dict):
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
state_dict_ = {}
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
if name.startswith("connector."):
|
||||||
|
name_ = name[len("connector."):]
|
||||||
|
state_dict_[name_] = param
|
||||||
|
return state_dict_
|
||||||
@@ -10,7 +10,7 @@
|
|||||||
# The above copyright notice and this permission notice shall be included in all
|
# The above copyright notice and this permission notice shall be included in all
|
||||||
# copies or substantial portions of the Software.
|
# copies or substantial portions of the Software.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
from typing import Dict, Optional, Tuple
|
from typing import Dict, Optional, Tuple, Union, List
|
||||||
import torch, math
|
import torch, math
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
@@ -398,7 +398,7 @@ class RoPE1D:
|
|||||||
* tokens: batch_size x ntokens x nheads x dim
|
* tokens: batch_size x ntokens x nheads x dim
|
||||||
* positions: batch_size x ntokens (t position of each token)
|
* positions: batch_size x ntokens (t position of each token)
|
||||||
output:
|
output:
|
||||||
* tokens after appplying RoPE2D (batch_size x ntokens x nheads x dim)
|
* tokens after applying RoPE2D (batch_size x ntokens x nheads x dim)
|
||||||
"""
|
"""
|
||||||
D = tokens.size(3)
|
D = tokens.size(3)
|
||||||
assert positions.ndim == 2 # Batch, Seq
|
assert positions.ndim == 2 # Batch, Seq
|
||||||
@@ -428,7 +428,7 @@ class RoPE3D(RoPE1D):
|
|||||||
* tokens: batch_size x ntokens x nheads x dim
|
* tokens: batch_size x ntokens x nheads x dim
|
||||||
* rope_positions: list of (f, h, w)
|
* rope_positions: list of (f, h, w)
|
||||||
output:
|
output:
|
||||||
* tokens after appplying RoPE2D (batch_size x ntokens x nheads x dim)
|
* tokens after applying RoPE2D (batch_size x ntokens x nheads x dim)
|
||||||
"""
|
"""
|
||||||
assert sum(ch_split) == tokens.size(-1);
|
assert sum(ch_split) == tokens.size(-1);
|
||||||
|
|
||||||
@@ -757,7 +757,7 @@ class StepVideoModel(torch.nn.Module):
|
|||||||
norm_elementwise_affine: bool = False,
|
norm_elementwise_affine: bool = False,
|
||||||
norm_eps: float = 1e-6,
|
norm_eps: float = 1e-6,
|
||||||
use_additional_conditions: Optional[bool] = False,
|
use_additional_conditions: Optional[bool] = False,
|
||||||
caption_channels: Optional[int]|list|tuple = [6144, 1024],
|
caption_channels: Optional[Union[int, List, Tuple]] = [6144, 1024],
|
||||||
attention_type: Optional[str] = "torch",
|
attention_type: Optional[str] = "torch",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ class LLaMaEmbedding(nn.Module):
|
|||||||
embeddings = embeddings.to(self.params_dtype)
|
embeddings = embeddings.to(self.params_dtype)
|
||||||
self.word_embeddings = self.word_embeddings.to(self.params_dtype)
|
self.word_embeddings = self.word_embeddings.to(self.params_dtype)
|
||||||
|
|
||||||
# Data format change to avoid explicit tranposes : [b s h] --> [s b h].
|
# Data format change to avoid explicit transposes : [b s h] --> [s b h].
|
||||||
embeddings = embeddings.transpose(0, 1).contiguous()
|
embeddings = embeddings.transpose(0, 1).contiguous()
|
||||||
|
|
||||||
# If the input flag for fp32 residual connection is set, convert for float.
|
# If the input flag for fp32 residual connection is set, convert for float.
|
||||||
@@ -326,7 +326,7 @@ class MultiQueryAttention(nn.Module):
|
|||||||
dim=-1,
|
dim=-1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# gather on 1st dimention
|
# gather on 1st dimension
|
||||||
xq = xq.view(seqlen, bsz, self.n_local_heads, self.head_dim)
|
xq = xq.view(seqlen, bsz, self.n_local_heads, self.head_dim)
|
||||||
xkv = xkv.view(seqlen, bsz, self.n_local_groups, 2 * self.head_dim)
|
xkv = xkv.view(seqlen, bsz, self.n_local_groups, 2 * self.head_dim)
|
||||||
xk, xv = xkv.chunk(2, -1)
|
xk, xv = xkv.chunk(2, -1)
|
||||||
@@ -357,7 +357,7 @@ class MultiQueryAttention(nn.Module):
|
|||||||
output = self.core_attention(xq, xk, xv,
|
output = self.core_attention(xq, xk, xv,
|
||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
max_seq_len=max_seq_len)
|
max_seq_len=max_seq_len)
|
||||||
# reduce-scatter only support first dimention now
|
# reduce-scatter only support first dimension now
|
||||||
output = rearrange(output, "b s h d -> s b (h d)").contiguous()
|
output = rearrange(output, "b s h d -> s b (h d)").contiguous()
|
||||||
else:
|
else:
|
||||||
xq, xk, xv = [
|
xq, xk, xv = [
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ class TileWorker:
|
|||||||
|
|
||||||
|
|
||||||
def io_scale(self, model_output, tile_size):
|
def io_scale(self, model_output, tile_size):
|
||||||
# Determine the size modification happend in forward_fn
|
# Determine the size modification happened in forward_fn
|
||||||
# We only consider the same scale on height and width.
|
# We only consider the same scale on height and width.
|
||||||
io_scale = model_output.shape[2] / tile_size
|
io_scale = model_output.shape[2] / tile_size
|
||||||
return io_scale
|
return io_scale
|
||||||
|
|||||||
@@ -62,16 +62,16 @@ def load_state_dict_from_folder(file_path, torch_dtype=None):
|
|||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict(file_path, torch_dtype=None):
|
def load_state_dict(file_path, torch_dtype=None, device="cpu"):
|
||||||
if file_path.endswith(".safetensors"):
|
if file_path.endswith(".safetensors"):
|
||||||
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
|
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
|
||||||
else:
|
else:
|
||||||
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
|
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
|
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
|
||||||
state_dict = {}
|
state_dict = {}
|
||||||
with safe_open(file_path, framework="pt", device="cpu") as f:
|
with safe_open(file_path, framework="pt", device=device) as f:
|
||||||
for k in f.keys():
|
for k in f.keys():
|
||||||
state_dict[k] = f.get_tensor(k)
|
state_dict[k] = f.get_tensor(k)
|
||||||
if torch_dtype is not None:
|
if torch_dtype is not None:
|
||||||
@@ -79,8 +79,8 @@ def load_state_dict_from_safetensors(file_path, torch_dtype=None):
|
|||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict_from_bin(file_path, torch_dtype=None):
|
def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
|
||||||
state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
|
state_dict = torch.load(file_path, map_location=device, weights_only=True)
|
||||||
if torch_dtype is not None:
|
if torch_dtype is not None:
|
||||||
for i in state_dict:
|
for i in state_dict:
|
||||||
if isinstance(state_dict[i], torch.Tensor):
|
if isinstance(state_dict[i], torch.Tensor):
|
||||||
|
|||||||
202
diffsynth/models/wan_video_camera_controller.py
Normal file
202
diffsynth/models/wan_video_camera_controller.py
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
from einops import rearrange
|
||||||
|
import os
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
class SimpleAdapter(nn.Module):
|
||||||
|
def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1):
|
||||||
|
super(SimpleAdapter, self).__init__()
|
||||||
|
|
||||||
|
# Pixel Unshuffle: reduce spatial dimensions by a factor of 8
|
||||||
|
self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8)
|
||||||
|
|
||||||
|
# Convolution: reduce spatial dimensions by a factor
|
||||||
|
# of 2 (without overlap)
|
||||||
|
self.conv = nn.Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0)
|
||||||
|
|
||||||
|
# Residual blocks for feature extraction
|
||||||
|
self.residual_blocks = nn.Sequential(
|
||||||
|
*[ResidualBlock(out_dim) for _ in range(num_residual_blocks)]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# Reshape to merge the frame dimension into batch
|
||||||
|
bs, c, f, h, w = x.size()
|
||||||
|
x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)
|
||||||
|
|
||||||
|
# Pixel Unshuffle operation
|
||||||
|
x_unshuffled = self.pixel_unshuffle(x)
|
||||||
|
|
||||||
|
# Convolution operation
|
||||||
|
x_conv = self.conv(x_unshuffled)
|
||||||
|
|
||||||
|
# Feature extraction with residual blocks
|
||||||
|
out = self.residual_blocks(x_conv)
|
||||||
|
|
||||||
|
# Reshape to restore original bf dimension
|
||||||
|
out = out.view(bs, f, out.size(1), out.size(2), out.size(3))
|
||||||
|
|
||||||
|
# Permute dimensions to reorder (if needed), e.g., swap channels and feature frames
|
||||||
|
out = out.permute(0, 2, 1, 3, 4)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def process_camera_coordinates(
|
||||||
|
self,
|
||||||
|
direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"],
|
||||||
|
length: int,
|
||||||
|
height: int,
|
||||||
|
width: int,
|
||||||
|
speed: float = 1/54,
|
||||||
|
origin=(0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
|
||||||
|
):
|
||||||
|
if origin is None:
|
||||||
|
origin = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
|
||||||
|
coordinates = generate_camera_coordinates(direction, length, speed, origin)
|
||||||
|
plucker_embedding = process_pose_file(coordinates, width, height)
|
||||||
|
return plucker_embedding
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlock(nn.Module):
|
||||||
|
def __init__(self, dim):
|
||||||
|
super(ResidualBlock, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
residual = x
|
||||||
|
out = self.relu(self.conv1(x))
|
||||||
|
out = self.conv2(out)
|
||||||
|
out += residual
|
||||||
|
return out
|
||||||
|
|
||||||
|
class Camera(object):
|
||||||
|
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
||||||
|
"""
|
||||||
|
def __init__(self, entry):
|
||||||
|
fx, fy, cx, cy = entry[1:5]
|
||||||
|
self.fx = fx
|
||||||
|
self.fy = fy
|
||||||
|
self.cx = cx
|
||||||
|
self.cy = cy
|
||||||
|
w2c_mat = np.array(entry[7:]).reshape(3, 4)
|
||||||
|
w2c_mat_4x4 = np.eye(4)
|
||||||
|
w2c_mat_4x4[:3, :] = w2c_mat
|
||||||
|
self.w2c_mat = w2c_mat_4x4
|
||||||
|
self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
|
||||||
|
|
||||||
|
def get_relative_pose(cam_params):
|
||||||
|
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
||||||
|
"""
|
||||||
|
abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
|
||||||
|
abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
|
||||||
|
cam_to_origin = 0
|
||||||
|
target_cam_c2w = np.array([
|
||||||
|
[1, 0, 0, 0],
|
||||||
|
[0, 1, 0, -cam_to_origin],
|
||||||
|
[0, 0, 1, 0],
|
||||||
|
[0, 0, 0, 1]
|
||||||
|
])
|
||||||
|
abs2rel = target_cam_c2w @ abs_w2cs[0]
|
||||||
|
ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
|
||||||
|
ret_poses = np.array(ret_poses, dtype=np.float32)
|
||||||
|
return ret_poses
|
||||||
|
|
||||||
|
def custom_meshgrid(*args):
|
||||||
|
# torch>=2.0.0 only
|
||||||
|
return torch.meshgrid(*args, indexing='ij')
|
||||||
|
|
||||||
|
|
||||||
|
def ray_condition(K, c2w, H, W, device):
|
||||||
|
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
||||||
|
"""
|
||||||
|
# c2w: B, V, 4, 4
|
||||||
|
# K: B, V, 4
|
||||||
|
|
||||||
|
B = K.shape[0]
|
||||||
|
|
||||||
|
j, i = custom_meshgrid(
|
||||||
|
torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
|
||||||
|
torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
|
||||||
|
)
|
||||||
|
i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
|
||||||
|
j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
|
||||||
|
|
||||||
|
fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
|
||||||
|
|
||||||
|
zs = torch.ones_like(i) # [B, HxW]
|
||||||
|
xs = (i - cx) / fx * zs
|
||||||
|
ys = (j - cy) / fy * zs
|
||||||
|
zs = zs.expand_as(ys)
|
||||||
|
|
||||||
|
directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
|
||||||
|
directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
|
||||||
|
|
||||||
|
rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
|
||||||
|
rays_o = c2w[..., :3, 3] # B, V, 3
|
||||||
|
rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
|
||||||
|
# c2w @ dirctions
|
||||||
|
rays_dxo = torch.linalg.cross(rays_o, rays_d)
|
||||||
|
plucker = torch.cat([rays_dxo, rays_d], dim=-1)
|
||||||
|
plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
|
||||||
|
# plucker = plucker.permute(0, 1, 4, 2, 3)
|
||||||
|
return plucker
|
||||||
|
|
||||||
|
|
||||||
|
def process_pose_file(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False):
|
||||||
|
if return_poses:
|
||||||
|
return cam_params
|
||||||
|
else:
|
||||||
|
cam_params = [Camera(cam_param) for cam_param in cam_params]
|
||||||
|
|
||||||
|
sample_wh_ratio = width / height
|
||||||
|
pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed
|
||||||
|
|
||||||
|
if pose_wh_ratio > sample_wh_ratio:
|
||||||
|
resized_ori_w = height * pose_wh_ratio
|
||||||
|
for cam_param in cam_params:
|
||||||
|
cam_param.fx = resized_ori_w * cam_param.fx / width
|
||||||
|
else:
|
||||||
|
resized_ori_h = width / pose_wh_ratio
|
||||||
|
for cam_param in cam_params:
|
||||||
|
cam_param.fy = resized_ori_h * cam_param.fy / height
|
||||||
|
|
||||||
|
intrinsic = np.asarray([[cam_param.fx * width,
|
||||||
|
cam_param.fy * height,
|
||||||
|
cam_param.cx * width,
|
||||||
|
cam_param.cy * height]
|
||||||
|
for cam_param in cam_params], dtype=np.float32)
|
||||||
|
|
||||||
|
K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
|
||||||
|
c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere
|
||||||
|
c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
|
||||||
|
plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W
|
||||||
|
plucker_embedding = plucker_embedding[None]
|
||||||
|
plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
|
||||||
|
return plucker_embedding
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def generate_camera_coordinates(
|
||||||
|
direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"],
|
||||||
|
length: int,
|
||||||
|
speed: float = 1/54,
|
||||||
|
origin=(0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
|
||||||
|
):
|
||||||
|
coordinates = [list(origin)]
|
||||||
|
while len(coordinates) < length:
|
||||||
|
coor = coordinates[-1].copy()
|
||||||
|
if "Left" in direction:
|
||||||
|
coor[9] += speed
|
||||||
|
if "Right" in direction:
|
||||||
|
coor[9] -= speed
|
||||||
|
if "Up" in direction:
|
||||||
|
coor[13] += speed
|
||||||
|
if "Down" in direction:
|
||||||
|
coor[13] -= speed
|
||||||
|
coordinates.append(coor)
|
||||||
|
return coordinates
|
||||||
664
diffsynth/models/wan_video_dit.py
Normal file
664
diffsynth/models/wan_video_dit.py
Normal file
@@ -0,0 +1,664 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import math
|
||||||
|
from typing import Tuple, Optional
|
||||||
|
from einops import rearrange
|
||||||
|
from .utils import hash_state_dict_keys
|
||||||
|
from .wan_video_camera_controller import SimpleAdapter
|
||||||
|
try:
|
||||||
|
import flash_attn_interface
|
||||||
|
FLASH_ATTN_3_AVAILABLE = True
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
FLASH_ATTN_3_AVAILABLE = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import flash_attn
|
||||||
|
FLASH_ATTN_2_AVAILABLE = True
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
FLASH_ATTN_2_AVAILABLE = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
from sageattention import sageattn
|
||||||
|
SAGE_ATTN_AVAILABLE = True
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
SAGE_ATTN_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False):
|
||||||
|
if compatibility_mode:
|
||||||
|
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
||||||
|
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
||||||
|
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
||||||
|
x = F.scaled_dot_product_attention(q, k, v)
|
||||||
|
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
||||||
|
elif FLASH_ATTN_3_AVAILABLE:
|
||||||
|
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
|
||||||
|
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
|
||||||
|
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
|
||||||
|
x = flash_attn_interface.flash_attn_func(q, k, v)
|
||||||
|
if isinstance(x,tuple):
|
||||||
|
x = x[0]
|
||||||
|
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
|
||||||
|
elif FLASH_ATTN_2_AVAILABLE:
|
||||||
|
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
|
||||||
|
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
|
||||||
|
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
|
||||||
|
x = flash_attn.flash_attn_func(q, k, v)
|
||||||
|
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
|
||||||
|
elif SAGE_ATTN_AVAILABLE:
|
||||||
|
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
||||||
|
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
||||||
|
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
||||||
|
x = sageattn(q, k, v)
|
||||||
|
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
||||||
|
else:
|
||||||
|
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
||||||
|
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
||||||
|
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
||||||
|
x = F.scaled_dot_product_attention(q, k, v)
|
||||||
|
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
|
||||||
|
return (x * (1 + scale) + shift)
|
||||||
|
|
||||||
|
|
||||||
|
def sinusoidal_embedding_1d(dim, position):
|
||||||
|
sinusoid = torch.outer(position.type(torch.float64), torch.pow(
|
||||||
|
10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
|
||||||
|
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
||||||
|
return x.to(position.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
|
||||||
|
# 3d rope precompute
|
||||||
|
f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta)
|
||||||
|
h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
|
||||||
|
w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
|
||||||
|
return f_freqs_cis, h_freqs_cis, w_freqs_cis
|
||||||
|
|
||||||
|
|
||||||
|
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
|
||||||
|
# 1d rope precompute
|
||||||
|
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)
|
||||||
|
[: (dim // 2)].double() / dim))
|
||||||
|
freqs = torch.outer(torch.arange(end, device=freqs.device), freqs)
|
||||||
|
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
||||||
|
return freqs_cis
|
||||||
|
|
||||||
|
|
||||||
|
def rope_apply(x, freqs, num_heads):
|
||||||
|
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
||||||
|
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
|
||||||
|
x.shape[0], x.shape[1], x.shape[2], -1, 2))
|
||||||
|
x_out = torch.view_as_real(x_out * freqs).flatten(2)
|
||||||
|
return x_out.to(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(nn.Module):
|
||||||
|
def __init__(self, dim, eps=1e-5):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
self.weight = nn.Parameter(torch.ones(dim))
|
||||||
|
|
||||||
|
def norm(self, x):
|
||||||
|
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
dtype = x.dtype
|
||||||
|
return self.norm(x.float()).to(dtype) * self.weight
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionModule(nn.Module):
|
||||||
|
def __init__(self, num_heads):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
|
||||||
|
def forward(self, q, k, v):
|
||||||
|
x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SelfAttention(nn.Module):
|
||||||
|
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = dim // num_heads
|
||||||
|
|
||||||
|
self.q = nn.Linear(dim, dim)
|
||||||
|
self.k = nn.Linear(dim, dim)
|
||||||
|
self.v = nn.Linear(dim, dim)
|
||||||
|
self.o = nn.Linear(dim, dim)
|
||||||
|
self.norm_q = RMSNorm(dim, eps=eps)
|
||||||
|
self.norm_k = RMSNorm(dim, eps=eps)
|
||||||
|
|
||||||
|
self.attn = AttentionModule(self.num_heads)
|
||||||
|
|
||||||
|
def forward(self, x, freqs):
|
||||||
|
q = self.norm_q(self.q(x))
|
||||||
|
k = self.norm_k(self.k(x))
|
||||||
|
v = self.v(x)
|
||||||
|
q = rope_apply(q, freqs, self.num_heads)
|
||||||
|
k = rope_apply(k, freqs, self.num_heads)
|
||||||
|
x = self.attn(q, k, v)
|
||||||
|
return self.o(x)
|
||||||
|
|
||||||
|
|
||||||
|
class CrossAttention(nn.Module):
|
||||||
|
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = dim // num_heads
|
||||||
|
|
||||||
|
self.q = nn.Linear(dim, dim)
|
||||||
|
self.k = nn.Linear(dim, dim)
|
||||||
|
self.v = nn.Linear(dim, dim)
|
||||||
|
self.o = nn.Linear(dim, dim)
|
||||||
|
self.norm_q = RMSNorm(dim, eps=eps)
|
||||||
|
self.norm_k = RMSNorm(dim, eps=eps)
|
||||||
|
self.has_image_input = has_image_input
|
||||||
|
if has_image_input:
|
||||||
|
self.k_img = nn.Linear(dim, dim)
|
||||||
|
self.v_img = nn.Linear(dim, dim)
|
||||||
|
self.norm_k_img = RMSNorm(dim, eps=eps)
|
||||||
|
|
||||||
|
self.attn = AttentionModule(self.num_heads)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, y: torch.Tensor):
|
||||||
|
if self.has_image_input:
|
||||||
|
img = y[:, :257]
|
||||||
|
ctx = y[:, 257:]
|
||||||
|
else:
|
||||||
|
ctx = y
|
||||||
|
q = self.norm_q(self.q(x))
|
||||||
|
k = self.norm_k(self.k(ctx))
|
||||||
|
v = self.v(ctx)
|
||||||
|
x = self.attn(q, k, v)
|
||||||
|
if self.has_image_input:
|
||||||
|
k_img = self.norm_k_img(self.k_img(img))
|
||||||
|
v_img = self.v_img(img)
|
||||||
|
y = flash_attention(q, k_img, v_img, num_heads=self.num_heads)
|
||||||
|
x = x + y
|
||||||
|
return self.o(x)
|
||||||
|
|
||||||
|
|
||||||
|
class GateModule(nn.Module):
|
||||||
|
def __init__(self,):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x, gate, residual):
|
||||||
|
return x + gate * residual
|
||||||
|
|
||||||
|
class DiTBlock(nn.Module):
|
||||||
|
def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.ffn_dim = ffn_dim
|
||||||
|
|
||||||
|
self.self_attn = SelfAttention(dim, num_heads, eps)
|
||||||
|
self.cross_attn = CrossAttention(
|
||||||
|
dim, num_heads, eps, has_image_input=has_image_input)
|
||||||
|
self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
|
||||||
|
self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
|
||||||
|
self.norm3 = nn.LayerNorm(dim, eps=eps)
|
||||||
|
self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
|
||||||
|
approximate='tanh'), nn.Linear(ffn_dim, dim))
|
||||||
|
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
||||||
|
self.gate = GateModule()
|
||||||
|
|
||||||
|
def forward(self, x, context, t_mod, freqs):
|
||||||
|
# msa: multi-head self-attention mlp: multi-layer perceptron
|
||||||
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||||
|
self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
|
||||||
|
input_x = modulate(self.norm1(x), shift_msa, scale_msa)
|
||||||
|
x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
|
||||||
|
x = x + self.cross_attn(self.norm3(x), context)
|
||||||
|
input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
|
||||||
|
x = self.gate(x, gate_mlp, self.ffn(input_x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(torch.nn.Module):
|
||||||
|
def __init__(self, in_dim, out_dim, has_pos_emb=False):
|
||||||
|
super().__init__()
|
||||||
|
self.proj = torch.nn.Sequential(
|
||||||
|
nn.LayerNorm(in_dim),
|
||||||
|
nn.Linear(in_dim, in_dim),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Linear(in_dim, out_dim),
|
||||||
|
nn.LayerNorm(out_dim)
|
||||||
|
)
|
||||||
|
self.has_pos_emb = has_pos_emb
|
||||||
|
if has_pos_emb:
|
||||||
|
self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.has_pos_emb:
|
||||||
|
x = x + self.emb_pos.to(dtype=x.dtype, device=x.device)
|
||||||
|
return self.proj(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Head(nn.Module):
|
||||||
|
def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
|
||||||
|
self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
|
||||||
|
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
|
||||||
|
|
||||||
|
def forward(self, x, t_mod):
|
||||||
|
shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1)
|
||||||
|
x = (self.head(self.norm(x) * (1 + scale) + shift))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class WanModel(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
in_dim: int,
|
||||||
|
ffn_dim: int,
|
||||||
|
out_dim: int,
|
||||||
|
text_dim: int,
|
||||||
|
freq_dim: int,
|
||||||
|
eps: float,
|
||||||
|
patch_size: Tuple[int, int, int],
|
||||||
|
num_heads: int,
|
||||||
|
num_layers: int,
|
||||||
|
has_image_input: bool,
|
||||||
|
has_image_pos_emb: bool = False,
|
||||||
|
has_ref_conv: bool = False,
|
||||||
|
add_control_adapter: bool = False,
|
||||||
|
in_dim_control_adapter: int = 24,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.freq_dim = freq_dim
|
||||||
|
self.has_image_input = has_image_input
|
||||||
|
self.patch_size = patch_size
|
||||||
|
|
||||||
|
self.patch_embedding = nn.Conv3d(
|
||||||
|
in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
||||||
|
self.text_embedding = nn.Sequential(
|
||||||
|
nn.Linear(text_dim, dim),
|
||||||
|
nn.GELU(approximate='tanh'),
|
||||||
|
nn.Linear(dim, dim)
|
||||||
|
)
|
||||||
|
self.time_embedding = nn.Sequential(
|
||||||
|
nn.Linear(freq_dim, dim),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(dim, dim)
|
||||||
|
)
|
||||||
|
self.time_projection = nn.Sequential(
|
||||||
|
nn.SiLU(), nn.Linear(dim, dim * 6))
|
||||||
|
self.blocks = nn.ModuleList([
|
||||||
|
DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
])
|
||||||
|
self.head = Head(dim, out_dim, patch_size, eps)
|
||||||
|
head_dim = dim // num_heads
|
||||||
|
self.freqs = precompute_freqs_cis_3d(head_dim)
|
||||||
|
|
||||||
|
if has_image_input:
|
||||||
|
self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280
|
||||||
|
if has_ref_conv:
|
||||||
|
self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2))
|
||||||
|
self.has_image_pos_emb = has_image_pos_emb
|
||||||
|
self.has_ref_conv = has_ref_conv
|
||||||
|
if add_control_adapter:
|
||||||
|
self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:])
|
||||||
|
else:
|
||||||
|
self.control_adapter = None
|
||||||
|
|
||||||
|
def patchify(self, x: torch.Tensor,control_camera_latents_input: torch.Tensor = None):
|
||||||
|
x = self.patch_embedding(x)
|
||||||
|
if self.control_adapter is not None and control_camera_latents_input is not None:
|
||||||
|
y_camera = self.control_adapter(control_camera_latents_input)
|
||||||
|
x = [u + v for u, v in zip(x, y_camera)]
|
||||||
|
x = x[0].unsqueeze(0)
|
||||||
|
grid_size = x.shape[2:]
|
||||||
|
x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
|
||||||
|
return x, grid_size # x, grid_size: (f, h, w)
|
||||||
|
|
||||||
|
def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
|
||||||
|
return rearrange(
|
||||||
|
x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
|
||||||
|
f=grid_size[0], h=grid_size[1], w=grid_size[2],
|
||||||
|
x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
timestep: torch.Tensor,
|
||||||
|
context: torch.Tensor,
|
||||||
|
clip_feature: Optional[torch.Tensor] = None,
|
||||||
|
y: Optional[torch.Tensor] = None,
|
||||||
|
use_gradient_checkpointing: bool = False,
|
||||||
|
use_gradient_checkpointing_offload: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
t = self.time_embedding(
|
||||||
|
sinusoidal_embedding_1d(self.freq_dim, timestep))
|
||||||
|
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
|
||||||
|
context = self.text_embedding(context)
|
||||||
|
|
||||||
|
if self.has_image_input:
|
||||||
|
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
|
||||||
|
clip_embdding = self.img_emb(clip_feature)
|
||||||
|
context = torch.cat([clip_embdding, context], dim=1)
|
||||||
|
|
||||||
|
x, (f, h, w) = self.patchify(x)
|
||||||
|
|
||||||
|
freqs = torch.cat([
|
||||||
|
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||||
|
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||||
|
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||||
|
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
return module(*inputs)
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
for block in self.blocks:
|
||||||
|
if self.training and use_gradient_checkpointing:
|
||||||
|
if use_gradient_checkpointing_offload:
|
||||||
|
with torch.autograd.graph.save_on_cpu():
|
||||||
|
x = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(block),
|
||||||
|
x, context, t_mod, freqs,
|
||||||
|
use_reentrant=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
x = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(block),
|
||||||
|
x, context, t_mod, freqs,
|
||||||
|
use_reentrant=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
x = block(x, context, t_mod, freqs)
|
||||||
|
|
||||||
|
x = self.head(x, t)
|
||||||
|
x = self.unpatchify(x, (f, h, w))
|
||||||
|
return x
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return WanModelStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
class WanModelStateDictConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_diffusers(self, state_dict):
|
||||||
|
rename_dict = {
|
||||||
|
"blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
|
||||||
|
"blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
|
||||||
|
"blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
|
||||||
|
"blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
|
||||||
|
"blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
|
||||||
|
"blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
|
||||||
|
"blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
|
||||||
|
"blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
|
||||||
|
"blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
|
||||||
|
"blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
|
||||||
|
"blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
|
||||||
|
"blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
|
||||||
|
"blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
|
||||||
|
"blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
|
||||||
|
"blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
|
||||||
|
"blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
|
||||||
|
"blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
|
||||||
|
"blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
|
||||||
|
"blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
|
||||||
|
"blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
|
||||||
|
"blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
|
||||||
|
"blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
|
||||||
|
"blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
|
||||||
|
"blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
|
||||||
|
"blocks.0.norm2.bias": "blocks.0.norm3.bias",
|
||||||
|
"blocks.0.norm2.weight": "blocks.0.norm3.weight",
|
||||||
|
"blocks.0.scale_shift_table": "blocks.0.modulation",
|
||||||
|
"condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
|
||||||
|
"condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
|
||||||
|
"condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
|
||||||
|
"condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
|
||||||
|
"condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
|
||||||
|
"condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
|
||||||
|
"condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
|
||||||
|
"condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
|
||||||
|
"condition_embedder.time_proj.bias": "time_projection.1.bias",
|
||||||
|
"condition_embedder.time_proj.weight": "time_projection.1.weight",
|
||||||
|
"patch_embedding.bias": "patch_embedding.bias",
|
||||||
|
"patch_embedding.weight": "patch_embedding.weight",
|
||||||
|
"scale_shift_table": "head.modulation",
|
||||||
|
"proj_out.bias": "head.head.bias",
|
||||||
|
"proj_out.weight": "head.head.weight",
|
||||||
|
}
|
||||||
|
state_dict_ = {}
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
if name in rename_dict:
|
||||||
|
state_dict_[rename_dict[name]] = param
|
||||||
|
else:
|
||||||
|
name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
|
||||||
|
if name_ in rename_dict:
|
||||||
|
name_ = rename_dict[name_]
|
||||||
|
name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
|
||||||
|
state_dict_[name_] = param
|
||||||
|
if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b":
|
||||||
|
config = {
|
||||||
|
"model_type": "t2v",
|
||||||
|
"patch_size": (1, 2, 2),
|
||||||
|
"text_len": 512,
|
||||||
|
"in_dim": 16,
|
||||||
|
"dim": 5120,
|
||||||
|
"ffn_dim": 13824,
|
||||||
|
"freq_dim": 256,
|
||||||
|
"text_dim": 4096,
|
||||||
|
"out_dim": 16,
|
||||||
|
"num_heads": 40,
|
||||||
|
"num_layers": 40,
|
||||||
|
"window_size": (-1, -1),
|
||||||
|
"qk_norm": True,
|
||||||
|
"cross_attn_norm": True,
|
||||||
|
"eps": 1e-6,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
config = {}
|
||||||
|
return state_dict_, config
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
state_dict = {name: param for name, param in state_dict.items() if not name.startswith("vace")}
|
||||||
|
if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
|
||||||
|
config = {
|
||||||
|
"has_image_input": False,
|
||||||
|
"patch_size": [1, 2, 2],
|
||||||
|
"in_dim": 16,
|
||||||
|
"dim": 1536,
|
||||||
|
"ffn_dim": 8960,
|
||||||
|
"freq_dim": 256,
|
||||||
|
"text_dim": 4096,
|
||||||
|
"out_dim": 16,
|
||||||
|
"num_heads": 12,
|
||||||
|
"num_layers": 30,
|
||||||
|
"eps": 1e-6
|
||||||
|
}
|
||||||
|
elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
|
||||||
|
config = {
|
||||||
|
"has_image_input": False,
|
||||||
|
"patch_size": [1, 2, 2],
|
||||||
|
"in_dim": 16,
|
||||||
|
"dim": 5120,
|
||||||
|
"ffn_dim": 13824,
|
||||||
|
"freq_dim": 256,
|
||||||
|
"text_dim": 4096,
|
||||||
|
"out_dim": 16,
|
||||||
|
"num_heads": 40,
|
||||||
|
"num_layers": 40,
|
||||||
|
"eps": 1e-6
|
||||||
|
}
|
||||||
|
elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
|
||||||
|
config = {
|
||||||
|
"has_image_input": True,
|
||||||
|
"patch_size": [1, 2, 2],
|
||||||
|
"in_dim": 36,
|
||||||
|
"dim": 5120,
|
||||||
|
"ffn_dim": 13824,
|
||||||
|
"freq_dim": 256,
|
||||||
|
"text_dim": 4096,
|
||||||
|
"out_dim": 16,
|
||||||
|
"num_heads": 40,
|
||||||
|
"num_layers": 40,
|
||||||
|
"eps": 1e-6
|
||||||
|
}
|
||||||
|
elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893":
|
||||||
|
config = {
|
||||||
|
"has_image_input": True,
|
||||||
|
"patch_size": [1, 2, 2],
|
||||||
|
"in_dim": 36,
|
||||||
|
"dim": 1536,
|
||||||
|
"ffn_dim": 8960,
|
||||||
|
"freq_dim": 256,
|
||||||
|
"text_dim": 4096,
|
||||||
|
"out_dim": 16,
|
||||||
|
"num_heads": 12,
|
||||||
|
"num_layers": 30,
|
||||||
|
"eps": 1e-6
|
||||||
|
}
|
||||||
|
elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
|
||||||
|
config = {
|
||||||
|
"has_image_input": True,
|
||||||
|
"patch_size": [1, 2, 2],
|
||||||
|
"in_dim": 36,
|
||||||
|
"dim": 5120,
|
||||||
|
"ffn_dim": 13824,
|
||||||
|
"freq_dim": 256,
|
||||||
|
"text_dim": 4096,
|
||||||
|
"out_dim": 16,
|
||||||
|
"num_heads": 40,
|
||||||
|
"num_layers": 40,
|
||||||
|
"eps": 1e-6
|
||||||
|
}
|
||||||
|
elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
|
||||||
|
# 1.3B PAI control
|
||||||
|
config = {
|
||||||
|
"has_image_input": True,
|
||||||
|
"patch_size": [1, 2, 2],
|
||||||
|
"in_dim": 48,
|
||||||
|
"dim": 1536,
|
||||||
|
"ffn_dim": 8960,
|
||||||
|
"freq_dim": 256,
|
||||||
|
"text_dim": 4096,
|
||||||
|
"out_dim": 16,
|
||||||
|
"num_heads": 12,
|
||||||
|
"num_layers": 30,
|
||||||
|
"eps": 1e-6
|
||||||
|
}
|
||||||
|
elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
|
||||||
|
# 14B PAI control
|
||||||
|
config = {
|
||||||
|
"has_image_input": True,
|
||||||
|
"patch_size": [1, 2, 2],
|
||||||
|
"in_dim": 48,
|
||||||
|
"dim": 5120,
|
||||||
|
"ffn_dim": 13824,
|
||||||
|
"freq_dim": 256,
|
||||||
|
"text_dim": 4096,
|
||||||
|
"out_dim": 16,
|
||||||
|
"num_heads": 40,
|
||||||
|
"num_layers": 40,
|
||||||
|
"eps": 1e-6
|
||||||
|
}
|
||||||
|
elif hash_state_dict_keys(state_dict) == "3ef3b1f8e1dab83d5b71fd7b617f859f":
|
||||||
|
config = {
|
||||||
|
"has_image_input": True,
|
||||||
|
"patch_size": [1, 2, 2],
|
||||||
|
"in_dim": 36,
|
||||||
|
"dim": 5120,
|
||||||
|
"ffn_dim": 13824,
|
||||||
|
"freq_dim": 256,
|
||||||
|
"text_dim": 4096,
|
||||||
|
"out_dim": 16,
|
||||||
|
"num_heads": 40,
|
||||||
|
"num_layers": 40,
|
||||||
|
"eps": 1e-6,
|
||||||
|
"has_image_pos_emb": True
|
||||||
|
}
|
||||||
|
elif hash_state_dict_keys(state_dict) == "70ddad9d3a133785da5ea371aae09504":
|
||||||
|
# 1.3B PAI control v1.1
|
||||||
|
config = {
|
||||||
|
"has_image_input": True,
|
||||||
|
"patch_size": [1, 2, 2],
|
||||||
|
"in_dim": 48,
|
||||||
|
"dim": 1536,
|
||||||
|
"ffn_dim": 8960,
|
||||||
|
"freq_dim": 256,
|
||||||
|
"text_dim": 4096,
|
||||||
|
"out_dim": 16,
|
||||||
|
"num_heads": 12,
|
||||||
|
"num_layers": 30,
|
||||||
|
"eps": 1e-6,
|
||||||
|
"has_ref_conv": True
|
||||||
|
}
|
||||||
|
elif hash_state_dict_keys(state_dict) == "26bde73488a92e64cc20b0a7485b9e5b":
|
||||||
|
# 14B PAI control v1.1
|
||||||
|
config = {
|
||||||
|
"has_image_input": True,
|
||||||
|
"patch_size": [1, 2, 2],
|
||||||
|
"in_dim": 48,
|
||||||
|
"dim": 5120,
|
||||||
|
"ffn_dim": 13824,
|
||||||
|
"freq_dim": 256,
|
||||||
|
"text_dim": 4096,
|
||||||
|
"out_dim": 16,
|
||||||
|
"num_heads": 40,
|
||||||
|
"num_layers": 40,
|
||||||
|
"eps": 1e-6,
|
||||||
|
"has_ref_conv": True
|
||||||
|
}
|
||||||
|
elif hash_state_dict_keys(state_dict) == "ac6a5aa74f4a0aab6f64eb9a72f19901":
|
||||||
|
# 1.3B PAI control-camera v1.1
|
||||||
|
config = {
|
||||||
|
"has_image_input": True,
|
||||||
|
"patch_size": [1, 2, 2],
|
||||||
|
"in_dim": 32,
|
||||||
|
"dim": 1536,
|
||||||
|
"ffn_dim": 8960,
|
||||||
|
"freq_dim": 256,
|
||||||
|
"text_dim": 4096,
|
||||||
|
"out_dim": 16,
|
||||||
|
"num_heads": 12,
|
||||||
|
"num_layers": 30,
|
||||||
|
"eps": 1e-6,
|
||||||
|
"has_ref_conv": False,
|
||||||
|
"add_control_adapter": True,
|
||||||
|
"in_dim_control_adapter": 24,
|
||||||
|
}
|
||||||
|
elif hash_state_dict_keys(state_dict) == "b61c605c2adbd23124d152ed28e049ae":
|
||||||
|
# 14B PAI control-camera v1.1
|
||||||
|
config = {
|
||||||
|
"has_image_input": True,
|
||||||
|
"patch_size": [1, 2, 2],
|
||||||
|
"in_dim": 32,
|
||||||
|
"dim": 5120,
|
||||||
|
"ffn_dim": 13824,
|
||||||
|
"freq_dim": 256,
|
||||||
|
"text_dim": 4096,
|
||||||
|
"out_dim": 16,
|
||||||
|
"num_heads": 40,
|
||||||
|
"num_layers": 40,
|
||||||
|
"eps": 1e-6,
|
||||||
|
"has_ref_conv": False,
|
||||||
|
"add_control_adapter": True,
|
||||||
|
"in_dim_control_adapter": 24,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
config = {}
|
||||||
|
return state_dict, config
|
||||||
902
diffsynth/models/wan_video_image_encoder.py
Normal file
902
diffsynth/models/wan_video_image_encoder.py
Normal file
@@ -0,0 +1,902 @@
|
|||||||
|
"""
|
||||||
|
Concise re-implementation of
|
||||||
|
``https://github.com/openai/CLIP'' and
|
||||||
|
``https://github.com/mlfoundations/open_clip''.
|
||||||
|
"""
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torchvision.transforms as T
|
||||||
|
from .wan_video_dit import flash_attention
|
||||||
|
|
||||||
|
|
||||||
|
class SelfAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
|
||||||
|
assert dim % num_heads == 0
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = dim // num_heads
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
# layers
|
||||||
|
self.q = nn.Linear(dim, dim)
|
||||||
|
self.k = nn.Linear(dim, dim)
|
||||||
|
self.v = nn.Linear(dim, dim)
|
||||||
|
self.o = nn.Linear(dim, dim)
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
def forward(self, x, mask):
|
||||||
|
"""
|
||||||
|
x: [B, L, C].
|
||||||
|
"""
|
||||||
|
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
|
||||||
|
|
||||||
|
# compute query, key, value
|
||||||
|
q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
|
||||||
|
k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
|
||||||
|
v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
|
||||||
|
|
||||||
|
# compute attention
|
||||||
|
p = self.dropout.p if self.training else 0.0
|
||||||
|
x = F.scaled_dot_product_attention(q, k, v, mask, p)
|
||||||
|
x = x.permute(0, 2, 1, 3).reshape(b, s, c)
|
||||||
|
|
||||||
|
# output
|
||||||
|
x = self.o(x)
|
||||||
|
x = self.dropout(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.post_norm = post_norm
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
# layers
|
||||||
|
self.attn = SelfAttention(dim, num_heads, dropout, eps)
|
||||||
|
self.norm1 = nn.LayerNorm(dim, eps=eps)
|
||||||
|
self.ffn = nn.Sequential(
|
||||||
|
nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
|
||||||
|
nn.Dropout(dropout))
|
||||||
|
self.norm2 = nn.LayerNorm(dim, eps=eps)
|
||||||
|
|
||||||
|
def forward(self, x, mask):
|
||||||
|
if self.post_norm:
|
||||||
|
x = self.norm1(x + self.attn(x, mask))
|
||||||
|
x = self.norm2(x + self.ffn(x))
|
||||||
|
else:
|
||||||
|
x = x + self.attn(self.norm1(x), mask)
|
||||||
|
x = x + self.ffn(self.norm2(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class XLMRoberta(nn.Module):
|
||||||
|
"""
|
||||||
|
XLMRobertaModel with no pooler and no LM head.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
vocab_size=250002,
|
||||||
|
max_seq_len=514,
|
||||||
|
type_size=1,
|
||||||
|
pad_id=1,
|
||||||
|
dim=1024,
|
||||||
|
num_heads=16,
|
||||||
|
num_layers=24,
|
||||||
|
post_norm=True,
|
||||||
|
dropout=0.1,
|
||||||
|
eps=1e-5):
|
||||||
|
super().__init__()
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.max_seq_len = max_seq_len
|
||||||
|
self.type_size = type_size
|
||||||
|
self.pad_id = pad_id
|
||||||
|
self.dim = dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.post_norm = post_norm
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
# embeddings
|
||||||
|
self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
|
||||||
|
self.type_embedding = nn.Embedding(type_size, dim)
|
||||||
|
self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
# blocks
|
||||||
|
self.blocks = nn.ModuleList([
|
||||||
|
AttentionBlock(dim, num_heads, post_norm, dropout, eps)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
# norm layer
|
||||||
|
self.norm = nn.LayerNorm(dim, eps=eps)
|
||||||
|
|
||||||
|
def forward(self, ids):
|
||||||
|
"""
|
||||||
|
ids: [B, L] of torch.LongTensor.
|
||||||
|
"""
|
||||||
|
b, s = ids.shape
|
||||||
|
mask = ids.ne(self.pad_id).long()
|
||||||
|
|
||||||
|
# embeddings
|
||||||
|
x = self.token_embedding(ids) + \
|
||||||
|
self.type_embedding(torch.zeros_like(ids)) + \
|
||||||
|
self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
|
||||||
|
if self.post_norm:
|
||||||
|
x = self.norm(x)
|
||||||
|
x = self.dropout(x)
|
||||||
|
|
||||||
|
# blocks
|
||||||
|
mask = torch.where(
|
||||||
|
mask.view(b, 1, 1, s).gt(0), 0.0,
|
||||||
|
torch.finfo(x.dtype).min)
|
||||||
|
for block in self.blocks:
|
||||||
|
x = block(x, mask)
|
||||||
|
|
||||||
|
# output
|
||||||
|
if not self.post_norm:
|
||||||
|
x = self.norm(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def xlm_roberta_large(pretrained=False,
|
||||||
|
return_tokenizer=False,
|
||||||
|
device='cpu',
|
||||||
|
**kwargs):
|
||||||
|
"""
|
||||||
|
XLMRobertaLarge adapted from Huggingface.
|
||||||
|
"""
|
||||||
|
# params
|
||||||
|
cfg = dict(
|
||||||
|
vocab_size=250002,
|
||||||
|
max_seq_len=514,
|
||||||
|
type_size=1,
|
||||||
|
pad_id=1,
|
||||||
|
dim=1024,
|
||||||
|
num_heads=16,
|
||||||
|
num_layers=24,
|
||||||
|
post_norm=True,
|
||||||
|
dropout=0.1,
|
||||||
|
eps=1e-5)
|
||||||
|
cfg.update(**kwargs)
|
||||||
|
|
||||||
|
# init model
|
||||||
|
if pretrained:
|
||||||
|
from sora import DOWNLOAD_TO_CACHE
|
||||||
|
|
||||||
|
# init a meta model
|
||||||
|
with torch.device('meta'):
|
||||||
|
model = XLMRoberta(**cfg)
|
||||||
|
|
||||||
|
# load checkpoint
|
||||||
|
model.load_state_dict(
|
||||||
|
torch.load(
|
||||||
|
DOWNLOAD_TO_CACHE('models/xlm_roberta/xlm_roberta_large.pth'),
|
||||||
|
map_location=device),
|
||||||
|
assign=True)
|
||||||
|
else:
|
||||||
|
# init a model on device
|
||||||
|
with torch.device(device):
|
||||||
|
model = XLMRoberta(**cfg)
|
||||||
|
|
||||||
|
# init tokenizer
|
||||||
|
if return_tokenizer:
|
||||||
|
from sora.data import HuggingfaceTokenizer
|
||||||
|
tokenizer = HuggingfaceTokenizer(
|
||||||
|
name='xlm-roberta-large',
|
||||||
|
seq_len=model.text_len,
|
||||||
|
clean='whitespace')
|
||||||
|
return model, tokenizer
|
||||||
|
else:
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def pos_interpolate(pos, seq_len):
|
||||||
|
if pos.size(1) == seq_len:
|
||||||
|
return pos
|
||||||
|
else:
|
||||||
|
src_grid = int(math.sqrt(pos.size(1)))
|
||||||
|
tar_grid = int(math.sqrt(seq_len))
|
||||||
|
n = pos.size(1) - src_grid * src_grid
|
||||||
|
return torch.cat([
|
||||||
|
pos[:, :n],
|
||||||
|
F.interpolate(
|
||||||
|
pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
|
||||||
|
0, 3, 1, 2),
|
||||||
|
size=(tar_grid, tar_grid),
|
||||||
|
mode='bicubic',
|
||||||
|
align_corners=False).flatten(2).transpose(1, 2)
|
||||||
|
],
|
||||||
|
dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
class QuickGELU(nn.Module):
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x * torch.sigmoid(1.702 * x)
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm(nn.LayerNorm):
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return super().forward(x).type_as(x)
|
||||||
|
|
||||||
|
|
||||||
|
class SelfAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
dim,
|
||||||
|
num_heads,
|
||||||
|
causal=False,
|
||||||
|
attn_dropout=0.0,
|
||||||
|
proj_dropout=0.0):
|
||||||
|
assert dim % num_heads == 0
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = dim // num_heads
|
||||||
|
self.causal = causal
|
||||||
|
self.attn_dropout = attn_dropout
|
||||||
|
self.proj_dropout = proj_dropout
|
||||||
|
|
||||||
|
# layers
|
||||||
|
self.to_qkv = nn.Linear(dim, dim * 3)
|
||||||
|
self.proj = nn.Linear(dim, dim)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
x: [B, L, C].
|
||||||
|
"""
|
||||||
|
# compute query, key, value
|
||||||
|
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
||||||
|
|
||||||
|
# compute attention
|
||||||
|
x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)
|
||||||
|
|
||||||
|
# output
|
||||||
|
x = self.proj(x)
|
||||||
|
x = F.dropout(x, self.proj_dropout, self.training)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SwiGLU(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim, mid_dim):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.mid_dim = mid_dim
|
||||||
|
|
||||||
|
# layers
|
||||||
|
self.fc1 = nn.Linear(dim, mid_dim)
|
||||||
|
self.fc2 = nn.Linear(dim, mid_dim)
|
||||||
|
self.fc3 = nn.Linear(mid_dim, dim)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = F.silu(self.fc1(x)) * self.fc2(x)
|
||||||
|
x = self.fc3(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
dim,
|
||||||
|
mlp_ratio,
|
||||||
|
num_heads,
|
||||||
|
post_norm=False,
|
||||||
|
causal=False,
|
||||||
|
activation='quick_gelu',
|
||||||
|
attn_dropout=0.0,
|
||||||
|
proj_dropout=0.0,
|
||||||
|
norm_eps=1e-5):
|
||||||
|
assert activation in ['quick_gelu', 'gelu', 'swi_glu']
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.mlp_ratio = mlp_ratio
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.post_norm = post_norm
|
||||||
|
self.causal = causal
|
||||||
|
self.norm_eps = norm_eps
|
||||||
|
|
||||||
|
# layers
|
||||||
|
self.norm1 = LayerNorm(dim, eps=norm_eps)
|
||||||
|
self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
|
||||||
|
proj_dropout)
|
||||||
|
self.norm2 = LayerNorm(dim, eps=norm_eps)
|
||||||
|
if activation == 'swi_glu':
|
||||||
|
self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
|
||||||
|
else:
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
nn.Linear(dim, int(dim * mlp_ratio)),
|
||||||
|
QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
|
||||||
|
nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.post_norm:
|
||||||
|
x = x + self.norm1(self.attn(x))
|
||||||
|
x = x + self.norm2(self.mlp(x))
|
||||||
|
else:
|
||||||
|
x = x + self.attn(self.norm1(x))
|
||||||
|
x = x + self.mlp(self.norm2(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionPool(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
dim,
|
||||||
|
mlp_ratio,
|
||||||
|
num_heads,
|
||||||
|
activation='gelu',
|
||||||
|
proj_dropout=0.0,
|
||||||
|
norm_eps=1e-5):
|
||||||
|
assert dim % num_heads == 0
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.mlp_ratio = mlp_ratio
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = dim // num_heads
|
||||||
|
self.proj_dropout = proj_dropout
|
||||||
|
self.norm_eps = norm_eps
|
||||||
|
|
||||||
|
# layers
|
||||||
|
gain = 1.0 / math.sqrt(dim)
|
||||||
|
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
|
||||||
|
self.to_q = nn.Linear(dim, dim)
|
||||||
|
self.to_kv = nn.Linear(dim, dim * 2)
|
||||||
|
self.proj = nn.Linear(dim, dim)
|
||||||
|
self.norm = LayerNorm(dim, eps=norm_eps)
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
nn.Linear(dim, int(dim * mlp_ratio)),
|
||||||
|
QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
|
||||||
|
nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
x: [B, L, C].
|
||||||
|
"""
|
||||||
|
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
|
||||||
|
|
||||||
|
# compute query, key, value
|
||||||
|
q = self.to_q(self.cls_embedding).view(1, 1, n*d).expand(b, -1, -1)
|
||||||
|
k, v = self.to_kv(x).chunk(2, dim=-1)
|
||||||
|
|
||||||
|
# compute attention
|
||||||
|
x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)
|
||||||
|
x = x.reshape(b, 1, c)
|
||||||
|
|
||||||
|
# output
|
||||||
|
x = self.proj(x)
|
||||||
|
x = F.dropout(x, self.proj_dropout, self.training)
|
||||||
|
|
||||||
|
# mlp
|
||||||
|
x = x + self.mlp(self.norm(x))
|
||||||
|
return x[:, 0]
|
||||||
|
|
||||||
|
|
||||||
|
class VisionTransformer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
image_size=224,
|
||||||
|
patch_size=16,
|
||||||
|
dim=768,
|
||||||
|
mlp_ratio=4,
|
||||||
|
out_dim=512,
|
||||||
|
num_heads=12,
|
||||||
|
num_layers=12,
|
||||||
|
pool_type='token',
|
||||||
|
pre_norm=True,
|
||||||
|
post_norm=False,
|
||||||
|
activation='quick_gelu',
|
||||||
|
attn_dropout=0.0,
|
||||||
|
proj_dropout=0.0,
|
||||||
|
embedding_dropout=0.0,
|
||||||
|
norm_eps=1e-5):
|
||||||
|
if image_size % patch_size != 0:
|
||||||
|
print(
|
||||||
|
'[WARNING] image_size is not divisible by patch_size',
|
||||||
|
flush=True)
|
||||||
|
assert pool_type in ('token', 'token_fc', 'attn_pool')
|
||||||
|
out_dim = out_dim or dim
|
||||||
|
super().__init__()
|
||||||
|
self.image_size = image_size
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.num_patches = (image_size // patch_size)**2
|
||||||
|
self.dim = dim
|
||||||
|
self.mlp_ratio = mlp_ratio
|
||||||
|
self.out_dim = out_dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.pool_type = pool_type
|
||||||
|
self.post_norm = post_norm
|
||||||
|
self.norm_eps = norm_eps
|
||||||
|
|
||||||
|
# embeddings
|
||||||
|
gain = 1.0 / math.sqrt(dim)
|
||||||
|
self.patch_embedding = nn.Conv2d(
|
||||||
|
3,
|
||||||
|
dim,
|
||||||
|
kernel_size=patch_size,
|
||||||
|
stride=patch_size,
|
||||||
|
bias=not pre_norm)
|
||||||
|
if pool_type in ('token', 'token_fc'):
|
||||||
|
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
|
||||||
|
self.pos_embedding = nn.Parameter(gain * torch.randn(
|
||||||
|
1, self.num_patches +
|
||||||
|
(1 if pool_type in ('token', 'token_fc') else 0), dim))
|
||||||
|
self.dropout = nn.Dropout(embedding_dropout)
|
||||||
|
|
||||||
|
# transformer
|
||||||
|
self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
|
||||||
|
self.transformer = nn.Sequential(*[
|
||||||
|
AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
|
||||||
|
activation, attn_dropout, proj_dropout, norm_eps)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
])
|
||||||
|
self.post_norm = LayerNorm(dim, eps=norm_eps)
|
||||||
|
|
||||||
|
# head
|
||||||
|
if pool_type == 'token':
|
||||||
|
self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
|
||||||
|
elif pool_type == 'token_fc':
|
||||||
|
self.head = nn.Linear(dim, out_dim)
|
||||||
|
elif pool_type == 'attn_pool':
|
||||||
|
self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
|
||||||
|
proj_dropout, norm_eps)
|
||||||
|
|
||||||
|
def forward(self, x, interpolation=False, use_31_block=False):
|
||||||
|
b = x.size(0)
|
||||||
|
|
||||||
|
# embeddings
|
||||||
|
x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
|
||||||
|
if self.pool_type in ('token', 'token_fc'):
|
||||||
|
x = torch.cat([self.cls_embedding.expand(b, -1, -1).to(dtype=x.dtype, device=x.device), x], dim=1)
|
||||||
|
if interpolation:
|
||||||
|
e = pos_interpolate(self.pos_embedding, x.size(1))
|
||||||
|
else:
|
||||||
|
e = self.pos_embedding
|
||||||
|
e = e.to(dtype=x.dtype, device=x.device)
|
||||||
|
x = self.dropout(x + e)
|
||||||
|
if self.pre_norm is not None:
|
||||||
|
x = self.pre_norm(x)
|
||||||
|
|
||||||
|
# transformer
|
||||||
|
if use_31_block:
|
||||||
|
x = self.transformer[:-1](x)
|
||||||
|
return x
|
||||||
|
else:
|
||||||
|
x = self.transformer(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CLIP(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
embed_dim=512,
|
||||||
|
image_size=224,
|
||||||
|
patch_size=16,
|
||||||
|
vision_dim=768,
|
||||||
|
vision_mlp_ratio=4,
|
||||||
|
vision_heads=12,
|
||||||
|
vision_layers=12,
|
||||||
|
vision_pool='token',
|
||||||
|
vision_pre_norm=True,
|
||||||
|
vision_post_norm=False,
|
||||||
|
vocab_size=49408,
|
||||||
|
text_len=77,
|
||||||
|
text_dim=512,
|
||||||
|
text_mlp_ratio=4,
|
||||||
|
text_heads=8,
|
||||||
|
text_layers=12,
|
||||||
|
text_causal=True,
|
||||||
|
text_pool='argmax',
|
||||||
|
text_head_bias=False,
|
||||||
|
logit_bias=None,
|
||||||
|
activation='quick_gelu',
|
||||||
|
attn_dropout=0.0,
|
||||||
|
proj_dropout=0.0,
|
||||||
|
embedding_dropout=0.0,
|
||||||
|
norm_eps=1e-5):
|
||||||
|
super().__init__()
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.image_size = image_size
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.vision_dim = vision_dim
|
||||||
|
self.vision_mlp_ratio = vision_mlp_ratio
|
||||||
|
self.vision_heads = vision_heads
|
||||||
|
self.vision_layers = vision_layers
|
||||||
|
self.vision_pool = vision_pool
|
||||||
|
self.vision_pre_norm = vision_pre_norm
|
||||||
|
self.vision_post_norm = vision_post_norm
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.text_len = text_len
|
||||||
|
self.text_dim = text_dim
|
||||||
|
self.text_mlp_ratio = text_mlp_ratio
|
||||||
|
self.text_heads = text_heads
|
||||||
|
self.text_layers = text_layers
|
||||||
|
self.text_causal = text_causal
|
||||||
|
self.text_pool = text_pool
|
||||||
|
self.text_head_bias = text_head_bias
|
||||||
|
self.norm_eps = norm_eps
|
||||||
|
|
||||||
|
# models
|
||||||
|
self.visual = VisionTransformer(
|
||||||
|
image_size=image_size,
|
||||||
|
patch_size=patch_size,
|
||||||
|
dim=vision_dim,
|
||||||
|
mlp_ratio=vision_mlp_ratio,
|
||||||
|
out_dim=embed_dim,
|
||||||
|
num_heads=vision_heads,
|
||||||
|
num_layers=vision_layers,
|
||||||
|
pool_type=vision_pool,
|
||||||
|
pre_norm=vision_pre_norm,
|
||||||
|
post_norm=vision_post_norm,
|
||||||
|
activation=activation,
|
||||||
|
attn_dropout=attn_dropout,
|
||||||
|
proj_dropout=proj_dropout,
|
||||||
|
embedding_dropout=embedding_dropout,
|
||||||
|
norm_eps=norm_eps)
|
||||||
|
self.textual = TextTransformer(
|
||||||
|
vocab_size=vocab_size,
|
||||||
|
text_len=text_len,
|
||||||
|
dim=text_dim,
|
||||||
|
mlp_ratio=text_mlp_ratio,
|
||||||
|
out_dim=embed_dim,
|
||||||
|
num_heads=text_heads,
|
||||||
|
num_layers=text_layers,
|
||||||
|
causal=text_causal,
|
||||||
|
pool_type=text_pool,
|
||||||
|
head_bias=text_head_bias,
|
||||||
|
activation=activation,
|
||||||
|
attn_dropout=attn_dropout,
|
||||||
|
proj_dropout=proj_dropout,
|
||||||
|
embedding_dropout=embedding_dropout,
|
||||||
|
norm_eps=norm_eps)
|
||||||
|
self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
|
||||||
|
if logit_bias is not None:
|
||||||
|
self.logit_bias = nn.Parameter(logit_bias * torch.ones([]))
|
||||||
|
|
||||||
|
# initialize weights
|
||||||
|
self.init_weights()
|
||||||
|
|
||||||
|
def forward(self, imgs, txt_ids):
|
||||||
|
"""
|
||||||
|
imgs: [B, 3, H, W] of torch.float32.
|
||||||
|
- mean: [0.48145466, 0.4578275, 0.40821073]
|
||||||
|
- std: [0.26862954, 0.26130258, 0.27577711]
|
||||||
|
txt_ids: [B, L] of torch.long. Encoded by data.CLIPTokenizer.
|
||||||
|
"""
|
||||||
|
xi = self.visual(imgs)
|
||||||
|
xt = self.textual(txt_ids)
|
||||||
|
return xi, xt
|
||||||
|
|
||||||
|
def init_weights(self):
|
||||||
|
# embeddings
|
||||||
|
nn.init.normal_(self.textual.token_embedding.weight, std=0.02)
|
||||||
|
nn.init.normal_(self.visual.patch_embedding.weight, std=0.1)
|
||||||
|
|
||||||
|
# attentions
|
||||||
|
for modality in ['visual', 'textual']:
|
||||||
|
dim = self.vision_dim if modality == 'visual' else self.text_dim
|
||||||
|
transformer = getattr(self, modality).transformer
|
||||||
|
proj_gain = (1.0 / math.sqrt(dim)) * (
|
||||||
|
1.0 / math.sqrt(2 * len(transformer)))
|
||||||
|
attn_gain = 1.0 / math.sqrt(dim)
|
||||||
|
mlp_gain = 1.0 / math.sqrt(2.0 * dim)
|
||||||
|
for block in transformer:
|
||||||
|
nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain)
|
||||||
|
nn.init.normal_(block.attn.proj.weight, std=proj_gain)
|
||||||
|
nn.init.normal_(block.mlp[0].weight, std=mlp_gain)
|
||||||
|
nn.init.normal_(block.mlp[2].weight, std=proj_gain)
|
||||||
|
|
||||||
|
def param_groups(self):
|
||||||
|
groups = [{
|
||||||
|
'params': [
|
||||||
|
p for n, p in self.named_parameters()
|
||||||
|
if 'norm' in n or n.endswith('bias')
|
||||||
|
],
|
||||||
|
'weight_decay': 0.0
|
||||||
|
}, {
|
||||||
|
'params': [
|
||||||
|
p for n, p in self.named_parameters()
|
||||||
|
if not ('norm' in n or n.endswith('bias'))
|
||||||
|
]
|
||||||
|
}]
|
||||||
|
return groups
|
||||||
|
|
||||||
|
|
||||||
|
class XLMRobertaWithHead(XLMRoberta):
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self.out_dim = kwargs.pop('out_dim')
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
# head
|
||||||
|
mid_dim = (self.dim + self.out_dim) // 2
|
||||||
|
self.head = nn.Sequential(
|
||||||
|
nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
|
||||||
|
nn.Linear(mid_dim, self.out_dim, bias=False))
|
||||||
|
|
||||||
|
def forward(self, ids):
|
||||||
|
# xlm-roberta
|
||||||
|
x = super().forward(ids)
|
||||||
|
|
||||||
|
# average pooling
|
||||||
|
mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
|
||||||
|
x = (x * mask).sum(dim=1) / mask.sum(dim=1)
|
||||||
|
|
||||||
|
# head
|
||||||
|
x = self.head(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class XLMRobertaCLIP(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
embed_dim=1024,
|
||||||
|
image_size=224,
|
||||||
|
patch_size=14,
|
||||||
|
vision_dim=1280,
|
||||||
|
vision_mlp_ratio=4,
|
||||||
|
vision_heads=16,
|
||||||
|
vision_layers=32,
|
||||||
|
vision_pool='token',
|
||||||
|
vision_pre_norm=True,
|
||||||
|
vision_post_norm=False,
|
||||||
|
activation='gelu',
|
||||||
|
vocab_size=250002,
|
||||||
|
max_text_len=514,
|
||||||
|
type_size=1,
|
||||||
|
pad_id=1,
|
||||||
|
text_dim=1024,
|
||||||
|
text_heads=16,
|
||||||
|
text_layers=24,
|
||||||
|
text_post_norm=True,
|
||||||
|
text_dropout=0.1,
|
||||||
|
attn_dropout=0.0,
|
||||||
|
proj_dropout=0.0,
|
||||||
|
embedding_dropout=0.0,
|
||||||
|
norm_eps=1e-5):
|
||||||
|
super().__init__()
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.image_size = image_size
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.vision_dim = vision_dim
|
||||||
|
self.vision_mlp_ratio = vision_mlp_ratio
|
||||||
|
self.vision_heads = vision_heads
|
||||||
|
self.vision_layers = vision_layers
|
||||||
|
self.vision_pre_norm = vision_pre_norm
|
||||||
|
self.vision_post_norm = vision_post_norm
|
||||||
|
self.activation = activation
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.max_text_len = max_text_len
|
||||||
|
self.type_size = type_size
|
||||||
|
self.pad_id = pad_id
|
||||||
|
self.text_dim = text_dim
|
||||||
|
self.text_heads = text_heads
|
||||||
|
self.text_layers = text_layers
|
||||||
|
self.text_post_norm = text_post_norm
|
||||||
|
self.norm_eps = norm_eps
|
||||||
|
|
||||||
|
# models
|
||||||
|
self.visual = VisionTransformer(
|
||||||
|
image_size=image_size,
|
||||||
|
patch_size=patch_size,
|
||||||
|
dim=vision_dim,
|
||||||
|
mlp_ratio=vision_mlp_ratio,
|
||||||
|
out_dim=embed_dim,
|
||||||
|
num_heads=vision_heads,
|
||||||
|
num_layers=vision_layers,
|
||||||
|
pool_type=vision_pool,
|
||||||
|
pre_norm=vision_pre_norm,
|
||||||
|
post_norm=vision_post_norm,
|
||||||
|
activation=activation,
|
||||||
|
attn_dropout=attn_dropout,
|
||||||
|
proj_dropout=proj_dropout,
|
||||||
|
embedding_dropout=embedding_dropout,
|
||||||
|
norm_eps=norm_eps)
|
||||||
|
self.textual = None
|
||||||
|
self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
|
||||||
|
|
||||||
|
def forward(self, imgs, txt_ids):
|
||||||
|
"""
|
||||||
|
imgs: [B, 3, H, W] of torch.float32.
|
||||||
|
- mean: [0.48145466, 0.4578275, 0.40821073]
|
||||||
|
- std: [0.26862954, 0.26130258, 0.27577711]
|
||||||
|
txt_ids: [B, L] of torch.long.
|
||||||
|
Encoded by data.CLIPTokenizer.
|
||||||
|
"""
|
||||||
|
xi = self.visual(imgs)
|
||||||
|
xt = self.textual(txt_ids)
|
||||||
|
return xi, xt
|
||||||
|
|
||||||
|
def param_groups(self):
|
||||||
|
groups = [{
|
||||||
|
'params': [
|
||||||
|
p for n, p in self.named_parameters()
|
||||||
|
if 'norm' in n or n.endswith('bias')
|
||||||
|
],
|
||||||
|
'weight_decay': 0.0
|
||||||
|
}, {
|
||||||
|
'params': [
|
||||||
|
p for n, p in self.named_parameters()
|
||||||
|
if not ('norm' in n or n.endswith('bias'))
|
||||||
|
]
|
||||||
|
}]
|
||||||
|
return groups
|
||||||
|
|
||||||
|
|
||||||
|
def _clip(pretrained=False,
|
||||||
|
pretrained_name=None,
|
||||||
|
model_cls=CLIP,
|
||||||
|
return_transforms=False,
|
||||||
|
return_tokenizer=False,
|
||||||
|
tokenizer_padding='eos',
|
||||||
|
dtype=torch.float32,
|
||||||
|
device='cpu',
|
||||||
|
**kwargs):
|
||||||
|
# init model
|
||||||
|
if pretrained and pretrained_name:
|
||||||
|
from sora import BUCKET, DOWNLOAD_TO_CACHE
|
||||||
|
|
||||||
|
# init a meta model
|
||||||
|
with torch.device('meta'):
|
||||||
|
model = model_cls(**kwargs)
|
||||||
|
|
||||||
|
# checkpoint path
|
||||||
|
checkpoint = f'models/clip/{pretrained_name}'
|
||||||
|
if dtype in (torch.float16, torch.bfloat16):
|
||||||
|
suffix = '-' + {
|
||||||
|
torch.float16: 'fp16',
|
||||||
|
torch.bfloat16: 'bf16'
|
||||||
|
}[dtype]
|
||||||
|
if object_exists(BUCKET, f'{checkpoint}{suffix}.pth'):
|
||||||
|
checkpoint = f'{checkpoint}{suffix}'
|
||||||
|
checkpoint += '.pth'
|
||||||
|
|
||||||
|
# load
|
||||||
|
model.load_state_dict(
|
||||||
|
torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device),
|
||||||
|
assign=True,
|
||||||
|
strict=False)
|
||||||
|
else:
|
||||||
|
# init a model on device
|
||||||
|
with torch.device(device):
|
||||||
|
model = model_cls(**kwargs)
|
||||||
|
|
||||||
|
# set device
|
||||||
|
output = (model,)
|
||||||
|
|
||||||
|
# init transforms
|
||||||
|
if return_transforms:
|
||||||
|
# mean and std
|
||||||
|
if 'siglip' in pretrained_name.lower():
|
||||||
|
mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
|
||||||
|
else:
|
||||||
|
mean = [0.48145466, 0.4578275, 0.40821073]
|
||||||
|
std = [0.26862954, 0.26130258, 0.27577711]
|
||||||
|
|
||||||
|
# transforms
|
||||||
|
transforms = T.Compose([
|
||||||
|
T.Resize((model.image_size, model.image_size),
|
||||||
|
interpolation=T.InterpolationMode.BICUBIC),
|
||||||
|
T.ToTensor(),
|
||||||
|
T.Normalize(mean=mean, std=std)
|
||||||
|
])
|
||||||
|
output += (transforms,)
|
||||||
|
|
||||||
|
# init tokenizer
|
||||||
|
if return_tokenizer:
|
||||||
|
from sora import data
|
||||||
|
if 'siglip' in pretrained_name.lower():
|
||||||
|
tokenizer = data.HuggingfaceTokenizer(
|
||||||
|
name=f'timm/{pretrained_name}',
|
||||||
|
seq_len=model.text_len,
|
||||||
|
clean='canonicalize')
|
||||||
|
elif 'xlm' in pretrained_name.lower():
|
||||||
|
tokenizer = data.HuggingfaceTokenizer(
|
||||||
|
name='xlm-roberta-large',
|
||||||
|
seq_len=model.max_text_len - 2,
|
||||||
|
clean='whitespace')
|
||||||
|
elif 'mba' in pretrained_name.lower():
|
||||||
|
tokenizer = data.HuggingfaceTokenizer(
|
||||||
|
name='facebook/xlm-roberta-xl',
|
||||||
|
seq_len=model.max_text_len - 2,
|
||||||
|
clean='whitespace')
|
||||||
|
else:
|
||||||
|
tokenizer = data.CLIPTokenizer(
|
||||||
|
seq_len=model.text_len, padding=tokenizer_padding)
|
||||||
|
output += (tokenizer,)
|
||||||
|
return output[0] if len(output) == 1 else output
|
||||||
|
|
||||||
|
|
||||||
|
def clip_xlm_roberta_vit_h_14(
|
||||||
|
pretrained=False,
|
||||||
|
pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
|
||||||
|
**kwargs):
|
||||||
|
cfg = dict(
|
||||||
|
embed_dim=1024,
|
||||||
|
image_size=224,
|
||||||
|
patch_size=14,
|
||||||
|
vision_dim=1280,
|
||||||
|
vision_mlp_ratio=4,
|
||||||
|
vision_heads=16,
|
||||||
|
vision_layers=32,
|
||||||
|
vision_pool='token',
|
||||||
|
activation='gelu',
|
||||||
|
vocab_size=250002,
|
||||||
|
max_text_len=514,
|
||||||
|
type_size=1,
|
||||||
|
pad_id=1,
|
||||||
|
text_dim=1024,
|
||||||
|
text_heads=16,
|
||||||
|
text_layers=24,
|
||||||
|
text_post_norm=True,
|
||||||
|
text_dropout=0.1,
|
||||||
|
attn_dropout=0.0,
|
||||||
|
proj_dropout=0.0,
|
||||||
|
embedding_dropout=0.0)
|
||||||
|
cfg.update(**kwargs)
|
||||||
|
return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
|
||||||
|
|
||||||
|
|
||||||
|
class WanImageEncoder(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
# init model
|
||||||
|
self.model, self.transforms = clip_xlm_roberta_vit_h_14(
|
||||||
|
pretrained=False,
|
||||||
|
return_transforms=True,
|
||||||
|
return_tokenizer=False,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device="cpu")
|
||||||
|
|
||||||
|
def encode_image(self, videos):
|
||||||
|
# preprocess
|
||||||
|
size = (self.model.image_size,) * 2
|
||||||
|
videos = torch.cat([
|
||||||
|
F.interpolate(
|
||||||
|
u,
|
||||||
|
size=size,
|
||||||
|
mode='bicubic',
|
||||||
|
align_corners=False) for u in videos
|
||||||
|
])
|
||||||
|
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
|
||||||
|
|
||||||
|
# forward
|
||||||
|
dtype = next(iter(self.model.visual.parameters())).dtype
|
||||||
|
videos = videos.to(dtype)
|
||||||
|
out = self.model.visual(videos, use_31_block=True)
|
||||||
|
return out
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return WanImageEncoderStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
class WanImageEncoderStateDictConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_diffusers(self, state_dict):
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
state_dict_ = {}
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
if name.startswith("textual."):
|
||||||
|
continue
|
||||||
|
name = "model." + name
|
||||||
|
state_dict_[name] = param
|
||||||
|
return state_dict_
|
||||||
|
|
||||||
44
diffsynth/models/wan_video_motion_controller.py
Normal file
44
diffsynth/models/wan_video_motion_controller.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from .wan_video_dit import sinusoidal_embedding_1d
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class WanMotionControllerModel(torch.nn.Module):
|
||||||
|
def __init__(self, freq_dim=256, dim=1536):
|
||||||
|
super().__init__()
|
||||||
|
self.freq_dim = freq_dim
|
||||||
|
self.linear = nn.Sequential(
|
||||||
|
nn.Linear(freq_dim, dim),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(dim, dim),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(dim, dim * 6),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, motion_bucket_id):
|
||||||
|
emb = sinusoidal_embedding_1d(self.freq_dim, motion_bucket_id * 10)
|
||||||
|
emb = self.linear(emb)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
def init(self):
|
||||||
|
state_dict = self.linear[-1].state_dict()
|
||||||
|
state_dict = {i: state_dict[i] * 0 for i in state_dict}
|
||||||
|
self.linear[-1].load_state_dict(state_dict)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return WanMotionControllerModelDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class WanMotionControllerModelDictConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_diffusers(self, state_dict):
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
return state_dict
|
||||||
|
|
||||||
@@ -206,7 +206,7 @@ def init_weights(m):
|
|||||||
m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
|
m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
|
||||||
|
|
||||||
|
|
||||||
class WanXTextEncoder(torch.nn.Module):
|
class WanTextEncoder(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
vocab=256384,
|
vocab=256384,
|
||||||
@@ -218,7 +218,7 @@ class WanXTextEncoder(torch.nn.Module):
|
|||||||
num_buckets=32,
|
num_buckets=32,
|
||||||
shared_pos=False,
|
shared_pos=False,
|
||||||
dropout=0.1):
|
dropout=0.1):
|
||||||
super(WanXTextEncoder, self).__init__()
|
super(WanTextEncoder, self).__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.dim_attn = dim_attn
|
self.dim_attn = dim_attn
|
||||||
self.dim_ffn = dim_ffn
|
self.dim_ffn = dim_ffn
|
||||||
@@ -252,3 +252,18 @@ class WanXTextEncoder(torch.nn.Module):
|
|||||||
x = self.norm(x)
|
x = self.norm(x)
|
||||||
x = self.dropout(x)
|
x = self.dropout(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return WanTextEncoderStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
class WanTextEncoderStateDictConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_diffusers(self, state_dict):
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
return state_dict
|
||||||
113
diffsynth/models/wan_video_vace.py
Normal file
113
diffsynth/models/wan_video_vace.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
import torch
|
||||||
|
from .wan_video_dit import DiTBlock
|
||||||
|
from .utils import hash_state_dict_keys
|
||||||
|
|
||||||
|
class VaceWanAttentionBlock(DiTBlock):
|
||||||
|
def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
|
||||||
|
super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps)
|
||||||
|
self.block_id = block_id
|
||||||
|
if block_id == 0:
|
||||||
|
self.before_proj = torch.nn.Linear(self.dim, self.dim)
|
||||||
|
self.after_proj = torch.nn.Linear(self.dim, self.dim)
|
||||||
|
|
||||||
|
def forward(self, c, x, context, t_mod, freqs):
|
||||||
|
if self.block_id == 0:
|
||||||
|
c = self.before_proj(c) + x
|
||||||
|
all_c = []
|
||||||
|
else:
|
||||||
|
all_c = list(torch.unbind(c))
|
||||||
|
c = all_c.pop(-1)
|
||||||
|
c = super().forward(c, context, t_mod, freqs)
|
||||||
|
c_skip = self.after_proj(c)
|
||||||
|
all_c += [c_skip, c]
|
||||||
|
c = torch.stack(all_c)
|
||||||
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
class VaceWanModel(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vace_layers=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28),
|
||||||
|
vace_in_dim=96,
|
||||||
|
patch_size=(1, 2, 2),
|
||||||
|
has_image_input=False,
|
||||||
|
dim=1536,
|
||||||
|
num_heads=12,
|
||||||
|
ffn_dim=8960,
|
||||||
|
eps=1e-6,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.vace_layers = vace_layers
|
||||||
|
self.vace_in_dim = vace_in_dim
|
||||||
|
self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
|
||||||
|
|
||||||
|
# vace blocks
|
||||||
|
self.vace_blocks = torch.nn.ModuleList([
|
||||||
|
VaceWanAttentionBlock(has_image_input, dim, num_heads, ffn_dim, eps, block_id=i)
|
||||||
|
for i in self.vace_layers
|
||||||
|
])
|
||||||
|
|
||||||
|
# vace patch embeddings
|
||||||
|
self.vace_patch_embedding = torch.nn.Conv3d(vace_in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, x, vace_context, context, t_mod, freqs,
|
||||||
|
use_gradient_checkpointing: bool = False,
|
||||||
|
use_gradient_checkpointing_offload: bool = False,
|
||||||
|
):
|
||||||
|
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
||||||
|
c = [u.flatten(2).transpose(1, 2) for u in c]
|
||||||
|
c = torch.cat([
|
||||||
|
torch.cat([u, u.new_zeros(1, x.shape[1] - u.size(1), u.size(2))],
|
||||||
|
dim=1) for u in c
|
||||||
|
])
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
return module(*inputs)
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
for block in self.vace_blocks:
|
||||||
|
if use_gradient_checkpointing_offload:
|
||||||
|
with torch.autograd.graph.save_on_cpu():
|
||||||
|
c = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(block),
|
||||||
|
c, x, context, t_mod, freqs,
|
||||||
|
use_reentrant=False,
|
||||||
|
)
|
||||||
|
elif use_gradient_checkpointing:
|
||||||
|
c = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(block),
|
||||||
|
c, x, context, t_mod, freqs,
|
||||||
|
use_reentrant=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
c = block(c, x, context, t_mod, freqs)
|
||||||
|
hints = torch.unbind(c)[:-1]
|
||||||
|
return hints
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return VaceWanModelDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
class VaceWanModelDictConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("vace")}
|
||||||
|
if hash_state_dict_keys(state_dict_) == '3b2726384e4f64837bdf216eea3f310d': # vace 14B
|
||||||
|
config = {
|
||||||
|
"vace_layers": (0, 5, 10, 15, 20, 25, 30, 35),
|
||||||
|
"vace_in_dim": 96,
|
||||||
|
"patch_size": (1, 2, 2),
|
||||||
|
"has_image_input": False,
|
||||||
|
"dim": 5120,
|
||||||
|
"num_heads": 40,
|
||||||
|
"ffn_dim": 13824,
|
||||||
|
"eps": 1e-06,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
config = {}
|
||||||
|
return state_dict_, config
|
||||||
@@ -7,6 +7,15 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
CACHE_T = 2
|
CACHE_T = 2
|
||||||
|
|
||||||
|
|
||||||
|
def check_is_instance(model, module_class):
|
||||||
|
if isinstance(model, module_class):
|
||||||
|
return True
|
||||||
|
if hasattr(model, "module") and isinstance(model.module, module_class):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def block_causal_mask(x, block_size):
|
def block_causal_mask(x, block_size):
|
||||||
# params
|
# params
|
||||||
b, n, s, _, device = *x.size(), x.device
|
b, n, s, _, device = *x.size(), x.device
|
||||||
@@ -205,7 +214,7 @@ class ResidualBlock(nn.Module):
|
|||||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||||
h = self.shortcut(x)
|
h = self.shortcut(x)
|
||||||
for layer in self.residual:
|
for layer in self.residual:
|
||||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
|
||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||||
@@ -342,14 +351,14 @@ class Encoder3d(nn.Module):
|
|||||||
|
|
||||||
## middle
|
## middle
|
||||||
for layer in self.middle:
|
for layer in self.middle:
|
||||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
||||||
x = layer(x, feat_cache, feat_idx)
|
x = layer(x, feat_cache, feat_idx)
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
|
|
||||||
## head
|
## head
|
||||||
for layer in self.head:
|
for layer in self.head:
|
||||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
|
||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||||
@@ -440,7 +449,7 @@ class Decoder3d(nn.Module):
|
|||||||
|
|
||||||
## middle
|
## middle
|
||||||
for layer in self.middle:
|
for layer in self.middle:
|
||||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
||||||
x = layer(x, feat_cache, feat_idx)
|
x = layer(x, feat_cache, feat_idx)
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
@@ -454,7 +463,7 @@ class Decoder3d(nn.Module):
|
|||||||
|
|
||||||
## head
|
## head
|
||||||
for layer in self.head:
|
for layer in self.head:
|
||||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
|
||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||||
@@ -475,7 +484,7 @@ class Decoder3d(nn.Module):
|
|||||||
def count_conv3d(model):
|
def count_conv3d(model):
|
||||||
count = 0
|
count = 0
|
||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
if isinstance(m, CausalConv3d):
|
if check_is_instance(m, CausalConv3d):
|
||||||
count += 1
|
count += 1
|
||||||
return count
|
return count
|
||||||
|
|
||||||
@@ -587,7 +596,7 @@ class VideoVAE_(nn.Module):
|
|||||||
self._enc_feat_map = [None] * self._enc_conv_num
|
self._enc_feat_map = [None] * self._enc_conv_num
|
||||||
|
|
||||||
|
|
||||||
class WanXVideoVAE(nn.Module):
|
class WanVideoVAE(nn.Module):
|
||||||
|
|
||||||
def __init__(self, z_dim=16):
|
def __init__(self, z_dim=16):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -679,7 +688,7 @@ class WanXVideoVAE(nn.Module):
|
|||||||
target_w: target_w + hidden_states_batch.shape[4],
|
target_w: target_w + hidden_states_batch.shape[4],
|
||||||
] += mask
|
] += mask
|
||||||
values = values / weight
|
values = values / weight
|
||||||
values = values.float().clamp_(-1, 1)
|
values = values.clamp_(-1, 1)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
@@ -731,64 +740,61 @@ class WanXVideoVAE(nn.Module):
|
|||||||
target_w: target_w + hidden_states_batch.shape[4],
|
target_w: target_w + hidden_states_batch.shape[4],
|
||||||
] += mask
|
] += mask
|
||||||
values = values / weight
|
values = values / weight
|
||||||
values = values.float()
|
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
def single_encode(self, video, device):
|
def single_encode(self, video, device):
|
||||||
video = video.to(device)
|
video = video.to(device)
|
||||||
x = self.model.encode(video, self.scale)
|
x = self.model.encode(video, self.scale)
|
||||||
return x.float()
|
return x
|
||||||
|
|
||||||
|
|
||||||
def single_decode(self, hidden_state, device):
|
def single_decode(self, hidden_state, device):
|
||||||
hidden_state = hidden_state.to(device)
|
hidden_state = hidden_state.to(device)
|
||||||
video = self.model.decode(hidden_state, self.scale)
|
video = self.model.decode(hidden_state, self.scale)
|
||||||
return video.float().clamp_(-1, 1)
|
return video.clamp_(-1, 1)
|
||||||
|
|
||||||
|
|
||||||
def encode(self, videos, device, tiled=False, tile_size=(272, 272), tile_stride=(144, 128)):
|
def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||||
|
|
||||||
videos = [video.to("cpu") for video in videos]
|
videos = [video.to("cpu") for video in videos]
|
||||||
hidden_states = []
|
hidden_states = []
|
||||||
for video in videos:
|
for video in videos:
|
||||||
video = video.unsqueeze(0)
|
video = video.unsqueeze(0)
|
||||||
if tiled:
|
if tiled:
|
||||||
assert tile_size[0] % self.upsampling_factor == 0 and tile_size[1] % self.upsampling_factor == 0, f"tile_size must be devisible by {self.upsampling_factor}"
|
tile_size = (tile_size[0] * 8, tile_size[1] * 8)
|
||||||
|
tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8)
|
||||||
hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
|
hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
|
||||||
else:
|
else:
|
||||||
hidden_state = self.single_encode(video, device)
|
hidden_state = self.single_encode(video, device)
|
||||||
hidden_state = hidden_state.squeeze(0)
|
hidden_state = hidden_state.squeeze(0)
|
||||||
hidden_states.append(hidden_state)
|
hidden_states.append(hidden_state)
|
||||||
|
hidden_states = torch.stack(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||||
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
|
if tiled:
|
||||||
videos = []
|
video = self.tiled_decode(hidden_states, device, tile_size, tile_stride)
|
||||||
for hidden_state in hidden_states:
|
else:
|
||||||
hidden_state = hidden_state.unsqueeze(0)
|
video = self.single_decode(hidden_states, device)
|
||||||
if tiled:
|
return video
|
||||||
video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
|
|
||||||
else:
|
|
||||||
video = self.single_decode(hidden_state, device)
|
|
||||||
video = video.squeeze(0)
|
|
||||||
videos.append(video)
|
|
||||||
return videos
|
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def state_dict_converter():
|
def state_dict_converter():
|
||||||
return WanXVideoVAEStateDictConverter()
|
return WanVideoVAEStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
class WanXVideoVAEStateDictConverter:
|
class WanVideoVAEStateDictConverter:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
def from_civitai(self, state_dict):
|
||||||
state_dict_ = {}
|
state_dict_ = {}
|
||||||
for name in state_dict['model_state']:
|
if 'model_state' in state_dict:
|
||||||
state_dict_['model.' + name] = state_dict['model_state'][name]
|
state_dict = state_dict['model_state']
|
||||||
|
for name in state_dict:
|
||||||
|
state_dict_['model.' + name] = state_dict[name]
|
||||||
return state_dict_
|
return state_dict_
|
||||||
@@ -11,4 +11,5 @@ from .omnigen_image import OmnigenImagePipeline
|
|||||||
from .pipeline_runner import SDVideoPipelineRunner
|
from .pipeline_runner import SDVideoPipelineRunner
|
||||||
from .hunyuan_video import HunyuanVideoPipeline
|
from .hunyuan_video import HunyuanVideoPipeline
|
||||||
from .step_video import StepVideoPipeline
|
from .step_video import StepVideoPipeline
|
||||||
|
from .wan_video import WanVideoPipeline
|
||||||
KolorsImagePipeline = SDXLImagePipeline
|
KolorsImagePipeline = SDXLImagePipeline
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder, FluxIpAdapter
|
from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder, FluxIpAdapter
|
||||||
|
from ..models.step1x_connector import Qwen2Connector
|
||||||
from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
||||||
from ..prompters import FluxPrompter
|
from ..prompters import FluxPrompter
|
||||||
from ..schedulers import FlowMatchScheduler
|
from ..schedulers import FlowMatchScheduler
|
||||||
@@ -31,105 +32,113 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
self.controlnet: FluxMultiControlNetManager = None
|
self.controlnet: FluxMultiControlNetManager = None
|
||||||
self.ipadapter: FluxIpAdapter = None
|
self.ipadapter: FluxIpAdapter = None
|
||||||
self.ipadapter_image_encoder: SiglipVisionModel = None
|
self.ipadapter_image_encoder: SiglipVisionModel = None
|
||||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
|
self.infinityou_processor: InfinitYou = None
|
||||||
|
self.qwenvl = None
|
||||||
|
self.step1x_connector: Qwen2Connector = None
|
||||||
|
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder', 'qwenvl', 'step1x_connector']
|
||||||
|
|
||||||
|
|
||||||
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
||||||
dtype = next(iter(self.text_encoder_1.parameters())).dtype
|
if self.text_encoder_1 is not None:
|
||||||
enable_vram_management(
|
dtype = next(iter(self.text_encoder_1.parameters())).dtype
|
||||||
self.text_encoder_1,
|
enable_vram_management(
|
||||||
module_map = {
|
self.text_encoder_1,
|
||||||
torch.nn.Linear: AutoWrappedLinear,
|
module_map = {
|
||||||
torch.nn.Embedding: AutoWrappedModule,
|
torch.nn.Linear: AutoWrappedLinear,
|
||||||
torch.nn.LayerNorm: AutoWrappedModule,
|
torch.nn.Embedding: AutoWrappedModule,
|
||||||
},
|
torch.nn.LayerNorm: AutoWrappedModule,
|
||||||
module_config = dict(
|
},
|
||||||
offload_dtype=dtype,
|
module_config = dict(
|
||||||
offload_device="cpu",
|
offload_dtype=dtype,
|
||||||
onload_dtype=dtype,
|
offload_device="cpu",
|
||||||
onload_device="cpu",
|
onload_dtype=dtype,
|
||||||
computation_dtype=self.torch_dtype,
|
onload_device="cpu",
|
||||||
computation_device=self.device,
|
computation_dtype=self.torch_dtype,
|
||||||
),
|
computation_device=self.device,
|
||||||
)
|
),
|
||||||
dtype = next(iter(self.text_encoder_2.parameters())).dtype
|
)
|
||||||
enable_vram_management(
|
if self.text_encoder_2 is not None:
|
||||||
self.text_encoder_2,
|
dtype = next(iter(self.text_encoder_2.parameters())).dtype
|
||||||
module_map = {
|
enable_vram_management(
|
||||||
torch.nn.Linear: AutoWrappedLinear,
|
self.text_encoder_2,
|
||||||
torch.nn.Embedding: AutoWrappedModule,
|
module_map = {
|
||||||
T5LayerNorm: AutoWrappedModule,
|
torch.nn.Linear: AutoWrappedLinear,
|
||||||
T5DenseActDense: AutoWrappedModule,
|
torch.nn.Embedding: AutoWrappedModule,
|
||||||
T5DenseGatedActDense: AutoWrappedModule,
|
T5LayerNorm: AutoWrappedModule,
|
||||||
},
|
T5DenseActDense: AutoWrappedModule,
|
||||||
module_config = dict(
|
T5DenseGatedActDense: AutoWrappedModule,
|
||||||
offload_dtype=dtype,
|
},
|
||||||
offload_device="cpu",
|
module_config = dict(
|
||||||
onload_dtype=dtype,
|
offload_dtype=dtype,
|
||||||
onload_device="cpu",
|
offload_device="cpu",
|
||||||
computation_dtype=self.torch_dtype,
|
onload_dtype=dtype,
|
||||||
computation_device=self.device,
|
onload_device="cpu",
|
||||||
),
|
computation_dtype=self.torch_dtype,
|
||||||
)
|
computation_device=self.device,
|
||||||
dtype = next(iter(self.dit.parameters())).dtype
|
),
|
||||||
enable_vram_management(
|
)
|
||||||
self.dit,
|
if self.dit is not None:
|
||||||
module_map = {
|
dtype = next(iter(self.dit.parameters())).dtype
|
||||||
RMSNorm: AutoWrappedModule,
|
enable_vram_management(
|
||||||
torch.nn.Linear: AutoWrappedLinear,
|
self.dit,
|
||||||
},
|
module_map = {
|
||||||
module_config = dict(
|
RMSNorm: AutoWrappedModule,
|
||||||
offload_dtype=dtype,
|
torch.nn.Linear: AutoWrappedLinear,
|
||||||
offload_device="cpu",
|
},
|
||||||
onload_dtype=dtype,
|
module_config = dict(
|
||||||
onload_device="cuda",
|
offload_dtype=dtype,
|
||||||
computation_dtype=self.torch_dtype,
|
offload_device="cpu",
|
||||||
computation_device=self.device,
|
onload_dtype=dtype,
|
||||||
),
|
onload_device="cuda",
|
||||||
max_num_param=num_persistent_param_in_dit,
|
computation_dtype=self.torch_dtype,
|
||||||
overflow_module_config = dict(
|
computation_device=self.device,
|
||||||
offload_dtype=dtype,
|
),
|
||||||
offload_device="cpu",
|
max_num_param=num_persistent_param_in_dit,
|
||||||
onload_dtype=dtype,
|
overflow_module_config = dict(
|
||||||
onload_device="cpu",
|
offload_dtype=dtype,
|
||||||
computation_dtype=self.torch_dtype,
|
offload_device="cpu",
|
||||||
computation_device=self.device,
|
onload_dtype=dtype,
|
||||||
),
|
onload_device="cpu",
|
||||||
)
|
computation_dtype=self.torch_dtype,
|
||||||
dtype = next(iter(self.vae_decoder.parameters())).dtype
|
computation_device=self.device,
|
||||||
enable_vram_management(
|
),
|
||||||
self.vae_decoder,
|
)
|
||||||
module_map = {
|
if self.vae_decoder is not None:
|
||||||
torch.nn.Linear: AutoWrappedLinear,
|
dtype = next(iter(self.vae_decoder.parameters())).dtype
|
||||||
torch.nn.Conv2d: AutoWrappedModule,
|
enable_vram_management(
|
||||||
torch.nn.GroupNorm: AutoWrappedModule,
|
self.vae_decoder,
|
||||||
},
|
module_map = {
|
||||||
module_config = dict(
|
torch.nn.Linear: AutoWrappedLinear,
|
||||||
offload_dtype=dtype,
|
torch.nn.Conv2d: AutoWrappedModule,
|
||||||
offload_device="cpu",
|
torch.nn.GroupNorm: AutoWrappedModule,
|
||||||
onload_dtype=dtype,
|
},
|
||||||
onload_device="cpu",
|
module_config = dict(
|
||||||
computation_dtype=self.torch_dtype,
|
offload_dtype=dtype,
|
||||||
computation_device=self.device,
|
offload_device="cpu",
|
||||||
),
|
onload_dtype=dtype,
|
||||||
)
|
onload_device="cpu",
|
||||||
dtype = next(iter(self.vae_encoder.parameters())).dtype
|
computation_dtype=self.torch_dtype,
|
||||||
enable_vram_management(
|
computation_device=self.device,
|
||||||
self.vae_encoder,
|
),
|
||||||
module_map = {
|
)
|
||||||
torch.nn.Linear: AutoWrappedLinear,
|
if self.vae_encoder is not None:
|
||||||
torch.nn.Conv2d: AutoWrappedModule,
|
dtype = next(iter(self.vae_encoder.parameters())).dtype
|
||||||
torch.nn.GroupNorm: AutoWrappedModule,
|
enable_vram_management(
|
||||||
},
|
self.vae_encoder,
|
||||||
module_config = dict(
|
module_map = {
|
||||||
offload_dtype=dtype,
|
torch.nn.Linear: AutoWrappedLinear,
|
||||||
offload_device="cpu",
|
torch.nn.Conv2d: AutoWrappedModule,
|
||||||
onload_dtype=dtype,
|
torch.nn.GroupNorm: AutoWrappedModule,
|
||||||
onload_device="cpu",
|
},
|
||||||
computation_dtype=self.torch_dtype,
|
module_config = dict(
|
||||||
computation_device=self.device,
|
offload_dtype=dtype,
|
||||||
),
|
offload_device="cpu",
|
||||||
)
|
onload_dtype=dtype,
|
||||||
|
onload_device="cpu",
|
||||||
|
computation_dtype=self.torch_dtype,
|
||||||
|
computation_device=self.device,
|
||||||
|
),
|
||||||
|
)
|
||||||
self.enable_cpu_offload()
|
self.enable_cpu_offload()
|
||||||
|
|
||||||
|
|
||||||
@@ -162,6 +171,15 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
self.ipadapter = model_manager.fetch_model("flux_ipadapter")
|
self.ipadapter = model_manager.fetch_model("flux_ipadapter")
|
||||||
self.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
|
self.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
|
||||||
|
|
||||||
|
# InfiniteYou
|
||||||
|
self.image_proj_model = model_manager.fetch_model("infiniteyou_image_projector")
|
||||||
|
if self.image_proj_model is not None:
|
||||||
|
self.infinityou_processor = InfinitYou(device=self.device)
|
||||||
|
|
||||||
|
# Step1x
|
||||||
|
self.qwenvl = model_manager.fetch_model("qwenvl")
|
||||||
|
self.step1x_connector = model_manager.fetch_model("step1x_connector")
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None, torch_dtype=None):
|
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None, torch_dtype=None):
|
||||||
@@ -185,10 +203,13 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
|
|
||||||
|
|
||||||
def encode_prompt(self, prompt, positive=True, t5_sequence_length=512):
|
def encode_prompt(self, prompt, positive=True, t5_sequence_length=512):
|
||||||
prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
|
if self.text_encoder_1 is not None and self.text_encoder_2 is not None:
|
||||||
prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
|
prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
|
||||||
)
|
prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
|
||||||
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
|
)
|
||||||
|
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def prepare_extra_input(self, latents=None, guidance=1.0):
|
def prepare_extra_input(self, latents=None, guidance=1.0):
|
||||||
@@ -349,6 +370,53 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals
|
return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_infinite_you(self, id_image, controlnet_image, infinityou_guidance, height, width):
|
||||||
|
if self.infinityou_processor is not None and id_image is not None:
|
||||||
|
return self.infinityou_processor.prepare_infinite_you(self.image_proj_model, id_image, controlnet_image, infinityou_guidance, height, width)
|
||||||
|
else:
|
||||||
|
return {}, controlnet_image
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_flex_kwargs(self, latents, flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None, flex_control_strength=0.5, flex_control_stop=0.5, tiled=False, tile_size=64, tile_stride=32):
|
||||||
|
if self.dit.input_dim == 196:
|
||||||
|
if flex_inpaint_image is None:
|
||||||
|
flex_inpaint_image = torch.zeros_like(latents)
|
||||||
|
else:
|
||||||
|
flex_inpaint_image = self.preprocess_image(flex_inpaint_image).to(device=self.device, dtype=self.torch_dtype)
|
||||||
|
flex_inpaint_image = self.encode_image(flex_inpaint_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||||
|
if flex_inpaint_mask is None:
|
||||||
|
flex_inpaint_mask = torch.ones_like(latents)[:, 0:1, :, :]
|
||||||
|
else:
|
||||||
|
flex_inpaint_mask = flex_inpaint_mask.resize((latents.shape[3], latents.shape[2]))
|
||||||
|
flex_inpaint_mask = self.preprocess_image(flex_inpaint_mask).to(device=self.device, dtype=self.torch_dtype)
|
||||||
|
flex_inpaint_mask = (flex_inpaint_mask[:, 0:1, :, :] + 1) / 2
|
||||||
|
flex_inpaint_image = flex_inpaint_image * (1 - flex_inpaint_mask)
|
||||||
|
if flex_control_image is None:
|
||||||
|
flex_control_image = torch.zeros_like(latents)
|
||||||
|
else:
|
||||||
|
flex_control_image = self.preprocess_image(flex_control_image).to(device=self.device, dtype=self.torch_dtype)
|
||||||
|
flex_control_image = self.encode_image(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) * flex_control_strength
|
||||||
|
flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1)
|
||||||
|
flex_uncondition = torch.concat([flex_inpaint_image, flex_inpaint_mask, torch.zeros_like(flex_control_image)], dim=1)
|
||||||
|
flex_control_stop_timestep = self.scheduler.timesteps[int(flex_control_stop * (len(self.scheduler.timesteps) - 1))]
|
||||||
|
flex_kwargs = {"flex_condition": flex_condition, "flex_uncondition": flex_uncondition, "flex_control_stop_timestep": flex_control_stop_timestep}
|
||||||
|
else:
|
||||||
|
flex_kwargs = {}
|
||||||
|
return flex_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_step1x_kwargs(self, prompt, negative_prompt, image):
|
||||||
|
if image is None:
|
||||||
|
return {}, {}
|
||||||
|
self.load_models_to_device(["qwenvl", "vae_encoder"])
|
||||||
|
captions = [prompt, negative_prompt]
|
||||||
|
ref_images = [image, image]
|
||||||
|
embs, masks = self.qwenvl(captions, ref_images)
|
||||||
|
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
|
||||||
|
image = self.encode_image(image)
|
||||||
|
return {"step1x_llm_embedding": embs[0:1], "step1x_mask": masks[0:1], "step1x_reference_latents": image}, {"step1x_llm_embedding": embs[1:2], "step1x_mask": masks[1:2], "step1x_reference_latents": image}
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -382,6 +450,17 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
eligen_entity_masks=None,
|
eligen_entity_masks=None,
|
||||||
enable_eligen_on_negative=False,
|
enable_eligen_on_negative=False,
|
||||||
enable_eligen_inpaint=False,
|
enable_eligen_inpaint=False,
|
||||||
|
# InfiniteYou
|
||||||
|
infinityou_id_image=None,
|
||||||
|
infinityou_guidance=1.0,
|
||||||
|
# Flex
|
||||||
|
flex_inpaint_image=None,
|
||||||
|
flex_inpaint_mask=None,
|
||||||
|
flex_control_image=None,
|
||||||
|
flex_control_strength=0.5,
|
||||||
|
flex_control_stop=0.5,
|
||||||
|
# Step1x
|
||||||
|
step1x_reference_image=None,
|
||||||
# TeaCache
|
# TeaCache
|
||||||
tea_cache_l1_thresh=None,
|
tea_cache_l1_thresh=None,
|
||||||
# Tile
|
# Tile
|
||||||
@@ -409,6 +488,9 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
# Extra input
|
# Extra input
|
||||||
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
|
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
|
||||||
|
|
||||||
|
# InfiniteYou
|
||||||
|
infiniteyou_kwargs, controlnet_image = self.prepare_infinite_you(infinityou_id_image, controlnet_image, infinityou_guidance, height, width)
|
||||||
|
|
||||||
# Entity control
|
# Entity control
|
||||||
eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask = self.prepare_eligen(prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale)
|
eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask = self.prepare_eligen(prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale)
|
||||||
|
|
||||||
@@ -418,19 +500,25 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
# ControlNets
|
# ControlNets
|
||||||
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
|
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
|
||||||
|
|
||||||
|
# Flex
|
||||||
|
flex_kwargs = self.prepare_flex_kwargs(latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, flex_control_strength=flex_control_strength, flex_control_stop=flex_control_stop, **tiler_kwargs)
|
||||||
|
|
||||||
|
# Step1x
|
||||||
|
step1x_kwargs_posi, step1x_kwargs_nega = self.prepare_step1x_kwargs(prompt, negative_prompt, image=step1x_reference_image)
|
||||||
|
|
||||||
# TeaCache
|
# TeaCache
|
||||||
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
|
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
|
||||||
|
|
||||||
# Denoise
|
# Denoise
|
||||||
self.load_models_to_device(['dit', 'controlnet'])
|
self.load_models_to_device(['dit', 'controlnet', 'step1x_connector'])
|
||||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||||
timestep = timestep.unsqueeze(0).to(self.device)
|
timestep = timestep.unsqueeze(0).to(self.device)
|
||||||
|
|
||||||
# Positive side
|
# Positive side
|
||||||
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
|
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
|
||||||
dit=self.dit, controlnet=self.controlnet,
|
dit=self.dit, controlnet=self.controlnet, step1x_connector=self.step1x_connector,
|
||||||
hidden_states=latents, timestep=timestep,
|
hidden_states=latents, timestep=timestep,
|
||||||
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs,
|
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs, **flex_kwargs, **step1x_kwargs_posi,
|
||||||
)
|
)
|
||||||
noise_pred_posi = self.control_noise_via_local_prompts(
|
noise_pred_posi = self.control_noise_via_local_prompts(
|
||||||
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
|
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
|
||||||
@@ -445,9 +533,9 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
if cfg_scale != 1.0:
|
if cfg_scale != 1.0:
|
||||||
# Negative side
|
# Negative side
|
||||||
noise_pred_nega = lets_dance_flux(
|
noise_pred_nega = lets_dance_flux(
|
||||||
dit=self.dit, controlnet=self.controlnet,
|
dit=self.dit, controlnet=self.controlnet, step1x_connector=self.step1x_connector,
|
||||||
hidden_states=latents, timestep=timestep,
|
hidden_states=latents, timestep=timestep,
|
||||||
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega,
|
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs, **flex_kwargs, **step1x_kwargs_nega,
|
||||||
)
|
)
|
||||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||||
else:
|
else:
|
||||||
@@ -469,6 +557,58 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class InfinitYou:
|
||||||
|
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||||
|
from facexlib.recognition import init_recognition_model
|
||||||
|
from insightface.app import FaceAnalysis
|
||||||
|
self.device = device
|
||||||
|
self.torch_dtype = torch_dtype
|
||||||
|
insightface_root_path = 'models/InfiniteYou/insightface'
|
||||||
|
self.app_640 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
||||||
|
self.app_640.prepare(ctx_id=0, det_size=(640, 640))
|
||||||
|
self.app_320 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
||||||
|
self.app_320.prepare(ctx_id=0, det_size=(320, 320))
|
||||||
|
self.app_160 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
||||||
|
self.app_160.prepare(ctx_id=0, det_size=(160, 160))
|
||||||
|
self.arcface_model = init_recognition_model('arcface', device=self.device)
|
||||||
|
|
||||||
|
def _detect_face(self, id_image_cv2):
|
||||||
|
face_info = self.app_640.get(id_image_cv2)
|
||||||
|
if len(face_info) > 0:
|
||||||
|
return face_info
|
||||||
|
face_info = self.app_320.get(id_image_cv2)
|
||||||
|
if len(face_info) > 0:
|
||||||
|
return face_info
|
||||||
|
face_info = self.app_160.get(id_image_cv2)
|
||||||
|
return face_info
|
||||||
|
|
||||||
|
def extract_arcface_bgr_embedding(self, in_image, landmark):
|
||||||
|
from insightface.utils import face_align
|
||||||
|
arc_face_image = face_align.norm_crop(in_image, landmark=np.array(landmark), image_size=112)
|
||||||
|
arc_face_image = torch.from_numpy(arc_face_image).unsqueeze(0).permute(0, 3, 1, 2) / 255.
|
||||||
|
arc_face_image = 2 * arc_face_image - 1
|
||||||
|
arc_face_image = arc_face_image.contiguous().to(self.device)
|
||||||
|
face_emb = self.arcface_model(arc_face_image)[0] # [512], normalized
|
||||||
|
return face_emb
|
||||||
|
|
||||||
|
def prepare_infinite_you(self, model, id_image, controlnet_image, infinityou_guidance, height, width):
|
||||||
|
import cv2
|
||||||
|
if id_image is None:
|
||||||
|
return {'id_emb': None}, controlnet_image
|
||||||
|
id_image_cv2 = cv2.cvtColor(np.array(id_image), cv2.COLOR_RGB2BGR)
|
||||||
|
face_info = self._detect_face(id_image_cv2)
|
||||||
|
if len(face_info) == 0:
|
||||||
|
raise ValueError('No face detected in the input ID image')
|
||||||
|
landmark = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]['kps'] # only use the maximum face
|
||||||
|
id_emb = self.extract_arcface_bgr_embedding(id_image_cv2, landmark)
|
||||||
|
id_emb = model(id_emb.unsqueeze(0).reshape([1, -1, 512]).to(dtype=self.torch_dtype))
|
||||||
|
if controlnet_image is None:
|
||||||
|
controlnet_image = Image.fromarray(np.zeros([height, width, 3]).astype(np.uint8))
|
||||||
|
infinityou_guidance = torch.Tensor([infinityou_guidance]).to(device=self.device, dtype=self.torch_dtype)
|
||||||
|
return {'id_emb': id_emb, 'infinityou_guidance': infinityou_guidance}, controlnet_image
|
||||||
|
|
||||||
|
|
||||||
class TeaCache:
|
class TeaCache:
|
||||||
def __init__(self, num_inference_steps, rel_l1_thresh):
|
def __init__(self, num_inference_steps, rel_l1_thresh):
|
||||||
self.num_inference_steps = num_inference_steps
|
self.num_inference_steps = num_inference_steps
|
||||||
@@ -515,6 +655,7 @@ class TeaCache:
|
|||||||
def lets_dance_flux(
|
def lets_dance_flux(
|
||||||
dit: FluxDiT,
|
dit: FluxDiT,
|
||||||
controlnet: FluxMultiControlNetManager = None,
|
controlnet: FluxMultiControlNetManager = None,
|
||||||
|
step1x_connector: Qwen2Connector = None,
|
||||||
hidden_states=None,
|
hidden_states=None,
|
||||||
timestep=None,
|
timestep=None,
|
||||||
prompt_emb=None,
|
prompt_emb=None,
|
||||||
@@ -529,6 +670,14 @@ def lets_dance_flux(
|
|||||||
entity_prompt_emb=None,
|
entity_prompt_emb=None,
|
||||||
entity_masks=None,
|
entity_masks=None,
|
||||||
ipadapter_kwargs_list={},
|
ipadapter_kwargs_list={},
|
||||||
|
id_emb=None,
|
||||||
|
infinityou_guidance=None,
|
||||||
|
flex_condition=None,
|
||||||
|
flex_uncondition=None,
|
||||||
|
flex_control_stop_timestep=None,
|
||||||
|
step1x_llm_embedding=None,
|
||||||
|
step1x_mask=None,
|
||||||
|
step1x_reference_latents=None,
|
||||||
tea_cache: TeaCache = None,
|
tea_cache: TeaCache = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
@@ -573,10 +722,25 @@ def lets_dance_flux(
|
|||||||
"tile_size": tile_size,
|
"tile_size": tile_size,
|
||||||
"tile_stride": tile_stride,
|
"tile_stride": tile_stride,
|
||||||
}
|
}
|
||||||
|
if id_emb is not None:
|
||||||
|
controlnet_text_ids = torch.zeros(id_emb.shape[0], id_emb.shape[1], 3).to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||||
|
controlnet_extra_kwargs.update({"prompt_emb": id_emb, 'text_ids': controlnet_text_ids, 'guidance': infinityou_guidance})
|
||||||
controlnet_res_stack, controlnet_single_res_stack = controlnet(
|
controlnet_res_stack, controlnet_single_res_stack = controlnet(
|
||||||
controlnet_frames, **controlnet_extra_kwargs
|
controlnet_frames, **controlnet_extra_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Flex
|
||||||
|
if flex_condition is not None:
|
||||||
|
if timestep.tolist()[0] >= flex_control_stop_timestep:
|
||||||
|
hidden_states = torch.concat([hidden_states, flex_condition], dim=1)
|
||||||
|
else:
|
||||||
|
hidden_states = torch.concat([hidden_states, flex_uncondition], dim=1)
|
||||||
|
|
||||||
|
# Step1x
|
||||||
|
if step1x_llm_embedding is not None:
|
||||||
|
prompt_emb, pooled_prompt_emb = step1x_connector(step1x_llm_embedding, timestep / 1000, step1x_mask)
|
||||||
|
text_ids = torch.zeros((1, prompt_emb.shape[1], 3), dtype=prompt_emb.dtype, device=prompt_emb.device)
|
||||||
|
|
||||||
if image_ids is None:
|
if image_ids is None:
|
||||||
image_ids = dit.prepare_image_ids(hidden_states)
|
image_ids = dit.prepare_image_ids(hidden_states)
|
||||||
|
|
||||||
@@ -587,6 +751,14 @@ def lets_dance_flux(
|
|||||||
|
|
||||||
height, width = hidden_states.shape[-2:]
|
height, width = hidden_states.shape[-2:]
|
||||||
hidden_states = dit.patchify(hidden_states)
|
hidden_states = dit.patchify(hidden_states)
|
||||||
|
|
||||||
|
# Step1x
|
||||||
|
if step1x_reference_latents is not None:
|
||||||
|
step1x_reference_image_ids = dit.prepare_image_ids(step1x_reference_latents)
|
||||||
|
step1x_reference_latents = dit.patchify(step1x_reference_latents)
|
||||||
|
image_ids = torch.concat([image_ids, step1x_reference_image_ids], dim=-2)
|
||||||
|
hidden_states = torch.concat([hidden_states, step1x_reference_latents], dim=1)
|
||||||
|
|
||||||
hidden_states = dit.x_embedder(hidden_states)
|
hidden_states = dit.x_embedder(hidden_states)
|
||||||
|
|
||||||
if entity_prompt_emb is not None and entity_masks is not None:
|
if entity_prompt_emb is not None and entity_masks is not None:
|
||||||
@@ -641,6 +813,11 @@ def lets_dance_flux(
|
|||||||
|
|
||||||
hidden_states = dit.final_norm_out(hidden_states, conditioning)
|
hidden_states = dit.final_norm_out(hidden_states, conditioning)
|
||||||
hidden_states = dit.final_proj_out(hidden_states)
|
hidden_states = dit.final_proj_out(hidden_states)
|
||||||
|
|
||||||
|
# Step1x
|
||||||
|
if step1x_reference_latents is not None:
|
||||||
|
hidden_states = hidden_states[:, :hidden_states.shape[1] // 2]
|
||||||
|
|
||||||
hidden_states = dit.unpatchify(hidden_states, height, width)
|
hidden_states = dit.unpatchify(hidden_states, height, width)
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|||||||
1052
diffsynth/pipelines/flux_image_new.py
Normal file
1052
diffsynth/pipelines/flux_image_new.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -5,13 +5,13 @@ from ..schedulers.flow_match import FlowMatchScheduler
|
|||||||
from .base import BasePipeline
|
from .base import BasePipeline
|
||||||
from ..prompters import HunyuanVideoPrompter
|
from ..prompters import HunyuanVideoPrompter
|
||||||
import torch
|
import torch
|
||||||
|
import torchvision.transforms as transforms
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanVideoPipeline(BasePipeline):
|
class HunyuanVideoPipeline(BasePipeline):
|
||||||
|
|
||||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||||
@@ -53,10 +53,58 @@ class HunyuanVideoPipeline(BasePipeline):
|
|||||||
pipe.enable_vram_management()
|
pipe.enable_vram_management()
|
||||||
return pipe
|
return pipe
|
||||||
|
|
||||||
|
def generate_crop_size_list(self, base_size=256, patch_size=32, max_ratio=4.0):
|
||||||
|
num_patches = round((base_size / patch_size)**2)
|
||||||
|
assert max_ratio >= 1.0
|
||||||
|
crop_size_list = []
|
||||||
|
wp, hp = num_patches, 1
|
||||||
|
while wp > 0:
|
||||||
|
if max(wp, hp) / min(wp, hp) <= max_ratio:
|
||||||
|
crop_size_list.append((wp * patch_size, hp * patch_size))
|
||||||
|
if (hp + 1) * wp <= num_patches:
|
||||||
|
hp += 1
|
||||||
|
else:
|
||||||
|
wp -= 1
|
||||||
|
return crop_size_list
|
||||||
|
|
||||||
def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256):
|
|
||||||
|
def get_closest_ratio(self, height: float, width: float, ratios: list, buckets: list):
|
||||||
|
aspect_ratio = float(height) / float(width)
|
||||||
|
closest_ratio_id = np.abs(ratios - aspect_ratio).argmin()
|
||||||
|
closest_ratio = min(ratios, key=lambda ratio: abs(float(ratio) - aspect_ratio))
|
||||||
|
return buckets[closest_ratio_id], float(closest_ratio)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_vae_images_inputs(self, semantic_images, i2v_resolution="720p"):
|
||||||
|
if i2v_resolution == "720p":
|
||||||
|
bucket_hw_base_size = 960
|
||||||
|
elif i2v_resolution == "540p":
|
||||||
|
bucket_hw_base_size = 720
|
||||||
|
elif i2v_resolution == "360p":
|
||||||
|
bucket_hw_base_size = 480
|
||||||
|
else:
|
||||||
|
raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")
|
||||||
|
origin_size = semantic_images[0].size
|
||||||
|
|
||||||
|
crop_size_list = self.generate_crop_size_list(bucket_hw_base_size, 32)
|
||||||
|
aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
|
||||||
|
closest_size, closest_ratio = self.get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)
|
||||||
|
ref_image_transform = transforms.Compose([
|
||||||
|
transforms.Resize(closest_size),
|
||||||
|
transforms.CenterCrop(closest_size),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize([0.5], [0.5])
|
||||||
|
])
|
||||||
|
|
||||||
|
semantic_image_pixel_values = [ref_image_transform(semantic_image) for semantic_image in semantic_images]
|
||||||
|
semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(self.device)
|
||||||
|
target_height, target_width = closest_size
|
||||||
|
return semantic_image_pixel_values, target_height, target_width
|
||||||
|
|
||||||
|
|
||||||
|
def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256, input_images=None):
|
||||||
prompt_emb, pooled_prompt_emb, text_mask = self.prompter.encode_prompt(
|
prompt_emb, pooled_prompt_emb, text_mask = self.prompter.encode_prompt(
|
||||||
prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length
|
prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length, images=input_images
|
||||||
)
|
)
|
||||||
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_mask": text_mask}
|
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_mask": text_mask}
|
||||||
|
|
||||||
@@ -87,6 +135,9 @@ class HunyuanVideoPipeline(BasePipeline):
|
|||||||
prompt,
|
prompt,
|
||||||
negative_prompt="",
|
negative_prompt="",
|
||||||
input_video=None,
|
input_video=None,
|
||||||
|
input_images=None,
|
||||||
|
i2v_resolution="720p",
|
||||||
|
i2v_stability=True,
|
||||||
denoising_strength=1.0,
|
denoising_strength=1.0,
|
||||||
seed=None,
|
seed=None,
|
||||||
rand_device=None,
|
rand_device=None,
|
||||||
@@ -109,6 +160,13 @@ class HunyuanVideoPipeline(BasePipeline):
|
|||||||
# Scheduler
|
# Scheduler
|
||||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||||
|
|
||||||
|
# encoder input images
|
||||||
|
if input_images is not None:
|
||||||
|
self.load_models_to_device(['vae_encoder'])
|
||||||
|
image_pixel_values, height, width = self.prepare_vae_images_inputs(input_images, i2v_resolution=i2v_resolution)
|
||||||
|
with torch.autocast(device_type=self.device, dtype=torch.float16, enabled=True):
|
||||||
|
image_latents = self.vae_encoder(image_pixel_values)
|
||||||
|
|
||||||
# Initialize noise
|
# Initialize noise
|
||||||
rand_device = self.device if rand_device is None else rand_device
|
rand_device = self.device if rand_device is None else rand_device
|
||||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=self.torch_dtype).to(self.device)
|
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=self.torch_dtype).to(self.device)
|
||||||
@@ -118,12 +176,18 @@ class HunyuanVideoPipeline(BasePipeline):
|
|||||||
input_video = torch.stack(input_video, dim=2)
|
input_video = torch.stack(input_video, dim=2)
|
||||||
latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
|
latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
|
||||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||||
|
elif input_images is not None and i2v_stability:
|
||||||
|
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=image_latents.dtype).to(self.device)
|
||||||
|
t = torch.tensor([0.999]).to(device=self.device)
|
||||||
|
latents = noise * t + image_latents.repeat(1, 1, (num_frames - 1) // 4 + 1, 1, 1) * (1 - t)
|
||||||
|
latents = latents.to(dtype=image_latents.dtype)
|
||||||
else:
|
else:
|
||||||
latents = noise
|
latents = noise
|
||||||
|
|
||||||
# Encode prompts
|
# Encode prompts
|
||||||
self.load_models_to_device(["text_encoder_1"] if self.vram_management else ["text_encoder_1", "text_encoder_2"])
|
# current mllm does not support vram_management
|
||||||
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
self.load_models_to_device(["text_encoder_1"] if self.vram_management and input_images is None else ["text_encoder_1", "text_encoder_2"])
|
||||||
|
prompt_emb_posi = self.encode_prompt(prompt, positive=True, input_images=input_images)
|
||||||
if cfg_scale != 1.0:
|
if cfg_scale != 1.0:
|
||||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
|
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
|
||||||
|
|
||||||
@@ -139,11 +203,16 @@ class HunyuanVideoPipeline(BasePipeline):
|
|||||||
timestep = timestep.unsqueeze(0).to(self.device)
|
timestep = timestep.unsqueeze(0).to(self.device)
|
||||||
print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
|
print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
|
||||||
|
|
||||||
|
forward_func = lets_dance_hunyuan_video
|
||||||
|
if input_images is not None:
|
||||||
|
latents = torch.concat([image_latents, latents[:, :, 1:, :, :]], dim=2)
|
||||||
|
forward_func = lets_dance_hunyuan_video_i2v
|
||||||
|
|
||||||
# Inference
|
# Inference
|
||||||
with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
|
with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
|
||||||
noise_pred_posi = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs)
|
noise_pred_posi = forward_func(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs)
|
||||||
if cfg_scale != 1.0:
|
if cfg_scale != 1.0:
|
||||||
noise_pred_nega = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_nega, **extra_input)
|
noise_pred_nega = forward_func(self.dit, latents, timestep, **prompt_emb_nega, **extra_input)
|
||||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||||
else:
|
else:
|
||||||
noise_pred = noise_pred_posi
|
noise_pred = noise_pred_posi
|
||||||
@@ -163,7 +232,11 @@ class HunyuanVideoPipeline(BasePipeline):
|
|||||||
self.load_models_to_device([] if self.vram_management else ["dit"])
|
self.load_models_to_device([] if self.vram_management else ["dit"])
|
||||||
|
|
||||||
# Scheduler
|
# Scheduler
|
||||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
if input_images is not None:
|
||||||
|
latents = self.scheduler.step(noise_pred[:, :, 1:, :, :], self.scheduler.timesteps[progress_id], latents[:, :, 1:, :, :])
|
||||||
|
latents = torch.concat([image_latents, latents], dim=2)
|
||||||
|
else:
|
||||||
|
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||||
|
|
||||||
# Decode
|
# Decode
|
||||||
self.load_models_to_device(['vae_decoder'])
|
self.load_models_to_device(['vae_decoder'])
|
||||||
@@ -250,13 +323,70 @@ def lets_dance_hunyuan_video(
|
|||||||
print("TeaCache skip forward.")
|
print("TeaCache skip forward.")
|
||||||
img = tea_cache.update(img)
|
img = tea_cache.update(img)
|
||||||
else:
|
else:
|
||||||
|
split_token = int(text_mask.sum(dim=1))
|
||||||
|
txt_len = int(txt.shape[1])
|
||||||
for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
|
for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
|
||||||
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
|
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin), split_token=split_token)
|
||||||
|
|
||||||
x = torch.concat([img, txt], dim=1)
|
x = torch.concat([img, txt], dim=1)
|
||||||
for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
|
for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
|
||||||
x = block(x, vec, (freqs_cos, freqs_sin))
|
x = block(x, vec, (freqs_cos, freqs_sin), txt_len=txt_len, split_token=split_token)
|
||||||
img = x[:, :-256]
|
img = x[:, :-txt_len]
|
||||||
|
|
||||||
|
if tea_cache is not None:
|
||||||
|
tea_cache.store(img)
|
||||||
|
img = dit.final_layer(img, vec)
|
||||||
|
img = dit.unpatchify(img, T=T//1, H=H//2, W=W//2)
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def lets_dance_hunyuan_video_i2v(
|
||||||
|
dit: HunyuanVideoDiT,
|
||||||
|
x: torch.Tensor,
|
||||||
|
t: torch.Tensor,
|
||||||
|
prompt_emb: torch.Tensor = None,
|
||||||
|
text_mask: torch.Tensor = None,
|
||||||
|
pooled_prompt_emb: torch.Tensor = None,
|
||||||
|
freqs_cos: torch.Tensor = None,
|
||||||
|
freqs_sin: torch.Tensor = None,
|
||||||
|
guidance: torch.Tensor = None,
|
||||||
|
tea_cache: TeaCache = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
B, C, T, H, W = x.shape
|
||||||
|
# Uncomment below to keep same as official implementation
|
||||||
|
# guidance = guidance.to(dtype=torch.float32).to(torch.bfloat16)
|
||||||
|
vec = dit.time_in(t, dtype=torch.bfloat16)
|
||||||
|
vec_2 = dit.vector_in(pooled_prompt_emb)
|
||||||
|
vec = vec + vec_2
|
||||||
|
vec = vec + dit.guidance_in(guidance * 1000., dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
token_replace_vec = dit.time_in(torch.zeros_like(t), dtype=torch.bfloat16)
|
||||||
|
tr_token = (H // 2) * (W // 2)
|
||||||
|
token_replace_vec = token_replace_vec + vec_2
|
||||||
|
|
||||||
|
img = dit.img_in(x)
|
||||||
|
txt = dit.txt_in(prompt_emb, t, text_mask)
|
||||||
|
|
||||||
|
# TeaCache
|
||||||
|
if tea_cache is not None:
|
||||||
|
tea_cache_update = tea_cache.check(dit, img, vec)
|
||||||
|
else:
|
||||||
|
tea_cache_update = False
|
||||||
|
|
||||||
|
if tea_cache_update:
|
||||||
|
print("TeaCache skip forward.")
|
||||||
|
img = tea_cache.update(img)
|
||||||
|
else:
|
||||||
|
split_token = int(text_mask.sum(dim=1))
|
||||||
|
txt_len = int(txt.shape[1])
|
||||||
|
for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
|
||||||
|
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin), token_replace_vec, tr_token, split_token)
|
||||||
|
|
||||||
|
x = torch.concat([img, txt], dim=1)
|
||||||
|
for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
|
||||||
|
x = block(x, vec, (freqs_cos, freqs_sin), txt_len, token_replace_vec, tr_token, split_token)
|
||||||
|
img = x[:, :-txt_len]
|
||||||
|
|
||||||
if tea_cache is not None:
|
if tea_cache is not None:
|
||||||
tea_cache.store(img)
|
tea_cache.store(img)
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ class OmniGenCache(DynamicCache):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
num_tokens_for_img: int, offload_kv_cache: bool=False) -> None:
|
num_tokens_for_img: int, offload_kv_cache: bool=False) -> None:
|
||||||
if not torch.cuda.is_available():
|
if not torch.cuda.is_available():
|
||||||
print("No avaliable GPU, offload_kv_cache wiil be set to False, which will result in large memory usage and time cost when input multiple images!!!")
|
print("No available GPU, offload_kv_cache will be set to False, which will result in large memory usage and time cost when input multiple images!!!")
|
||||||
offload_kv_cache = False
|
offload_kv_cache = False
|
||||||
raise RuntimeError("OffloadedCache can only be used with a GPU")
|
raise RuntimeError("OffloadedCache can only be used with a GPU")
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
617
diffsynth/pipelines/wan_video.py
Normal file
617
diffsynth/pipelines/wan_video.py
Normal file
@@ -0,0 +1,617 @@
|
|||||||
|
import types
|
||||||
|
from ..models import ModelManager
|
||||||
|
from ..models.wan_video_dit import WanModel
|
||||||
|
from ..models.wan_video_text_encoder import WanTextEncoder
|
||||||
|
from ..models.wan_video_vae import WanVideoVAE
|
||||||
|
from ..models.wan_video_image_encoder import WanImageEncoder
|
||||||
|
from ..models.wan_video_vace import VaceWanModel
|
||||||
|
from ..schedulers.flow_match import FlowMatchScheduler
|
||||||
|
from .base import BasePipeline
|
||||||
|
from ..prompters import WanPrompter
|
||||||
|
import torch, os
|
||||||
|
from einops import rearrange
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
||||||
|
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
|
||||||
|
from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
|
||||||
|
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
|
||||||
|
from ..models.wan_video_motion_controller import WanMotionControllerModel
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class WanVideoPipeline(BasePipeline):
|
||||||
|
|
||||||
|
def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None):
|
||||||
|
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||||
|
self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
|
||||||
|
self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
|
||||||
|
self.text_encoder: WanTextEncoder = None
|
||||||
|
self.image_encoder: WanImageEncoder = None
|
||||||
|
self.dit: WanModel = None
|
||||||
|
self.vae: WanVideoVAE = None
|
||||||
|
self.motion_controller: WanMotionControllerModel = None
|
||||||
|
self.vace: VaceWanModel = None
|
||||||
|
self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'motion_controller', 'vace']
|
||||||
|
self.height_division_factor = 16
|
||||||
|
self.width_division_factor = 16
|
||||||
|
self.use_unified_sequence_parallel = False
|
||||||
|
|
||||||
|
|
||||||
|
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
||||||
|
dtype = next(iter(self.text_encoder.parameters())).dtype
|
||||||
|
enable_vram_management(
|
||||||
|
self.text_encoder,
|
||||||
|
module_map = {
|
||||||
|
torch.nn.Linear: AutoWrappedLinear,
|
||||||
|
torch.nn.Embedding: AutoWrappedModule,
|
||||||
|
T5RelativeEmbedding: AutoWrappedModule,
|
||||||
|
T5LayerNorm: AutoWrappedModule,
|
||||||
|
},
|
||||||
|
module_config = dict(
|
||||||
|
offload_dtype=dtype,
|
||||||
|
offload_device="cpu",
|
||||||
|
onload_dtype=dtype,
|
||||||
|
onload_device="cpu",
|
||||||
|
computation_dtype=self.torch_dtype,
|
||||||
|
computation_device=self.device,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
dtype = next(iter(self.dit.parameters())).dtype
|
||||||
|
enable_vram_management(
|
||||||
|
self.dit,
|
||||||
|
module_map = {
|
||||||
|
torch.nn.Linear: AutoWrappedLinear,
|
||||||
|
torch.nn.Conv3d: AutoWrappedModule,
|
||||||
|
torch.nn.LayerNorm: AutoWrappedModule,
|
||||||
|
RMSNorm: AutoWrappedModule,
|
||||||
|
torch.nn.Conv2d: AutoWrappedModule,
|
||||||
|
},
|
||||||
|
module_config = dict(
|
||||||
|
offload_dtype=dtype,
|
||||||
|
offload_device="cpu",
|
||||||
|
onload_dtype=dtype,
|
||||||
|
onload_device=self.device,
|
||||||
|
computation_dtype=self.torch_dtype,
|
||||||
|
computation_device=self.device,
|
||||||
|
),
|
||||||
|
max_num_param=num_persistent_param_in_dit,
|
||||||
|
overflow_module_config = dict(
|
||||||
|
offload_dtype=dtype,
|
||||||
|
offload_device="cpu",
|
||||||
|
onload_dtype=dtype,
|
||||||
|
onload_device="cpu",
|
||||||
|
computation_dtype=self.torch_dtype,
|
||||||
|
computation_device=self.device,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
dtype = next(iter(self.vae.parameters())).dtype
|
||||||
|
enable_vram_management(
|
||||||
|
self.vae,
|
||||||
|
module_map = {
|
||||||
|
torch.nn.Linear: AutoWrappedLinear,
|
||||||
|
torch.nn.Conv2d: AutoWrappedModule,
|
||||||
|
RMS_norm: AutoWrappedModule,
|
||||||
|
CausalConv3d: AutoWrappedModule,
|
||||||
|
Upsample: AutoWrappedModule,
|
||||||
|
torch.nn.SiLU: AutoWrappedModule,
|
||||||
|
torch.nn.Dropout: AutoWrappedModule,
|
||||||
|
},
|
||||||
|
module_config = dict(
|
||||||
|
offload_dtype=dtype,
|
||||||
|
offload_device="cpu",
|
||||||
|
onload_dtype=dtype,
|
||||||
|
onload_device=self.device,
|
||||||
|
computation_dtype=self.torch_dtype,
|
||||||
|
computation_device=self.device,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if self.image_encoder is not None:
|
||||||
|
dtype = next(iter(self.image_encoder.parameters())).dtype
|
||||||
|
enable_vram_management(
|
||||||
|
self.image_encoder,
|
||||||
|
module_map = {
|
||||||
|
torch.nn.Linear: AutoWrappedLinear,
|
||||||
|
torch.nn.Conv2d: AutoWrappedModule,
|
||||||
|
torch.nn.LayerNorm: AutoWrappedModule,
|
||||||
|
},
|
||||||
|
module_config = dict(
|
||||||
|
offload_dtype=dtype,
|
||||||
|
offload_device="cpu",
|
||||||
|
onload_dtype=dtype,
|
||||||
|
onload_device="cpu",
|
||||||
|
computation_dtype=dtype,
|
||||||
|
computation_device=self.device,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if self.motion_controller is not None:
|
||||||
|
dtype = next(iter(self.motion_controller.parameters())).dtype
|
||||||
|
enable_vram_management(
|
||||||
|
self.motion_controller,
|
||||||
|
module_map = {
|
||||||
|
torch.nn.Linear: AutoWrappedLinear,
|
||||||
|
},
|
||||||
|
module_config = dict(
|
||||||
|
offload_dtype=dtype,
|
||||||
|
offload_device="cpu",
|
||||||
|
onload_dtype=dtype,
|
||||||
|
onload_device="cpu",
|
||||||
|
computation_dtype=dtype,
|
||||||
|
computation_device=self.device,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if self.vace is not None:
|
||||||
|
enable_vram_management(
|
||||||
|
self.vace,
|
||||||
|
module_map = {
|
||||||
|
torch.nn.Linear: AutoWrappedLinear,
|
||||||
|
torch.nn.Conv3d: AutoWrappedModule,
|
||||||
|
torch.nn.LayerNorm: AutoWrappedModule,
|
||||||
|
RMSNorm: AutoWrappedModule,
|
||||||
|
},
|
||||||
|
module_config = dict(
|
||||||
|
offload_dtype=dtype,
|
||||||
|
offload_device="cpu",
|
||||||
|
onload_dtype=dtype,
|
||||||
|
onload_device=self.device,
|
||||||
|
computation_dtype=self.torch_dtype,
|
||||||
|
computation_device=self.device,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.enable_cpu_offload()
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_models(self, model_manager: ModelManager):
|
||||||
|
text_encoder_model_and_path = model_manager.fetch_model("wan_video_text_encoder", require_model_path=True)
|
||||||
|
if text_encoder_model_and_path is not None:
|
||||||
|
self.text_encoder, tokenizer_path = text_encoder_model_and_path
|
||||||
|
self.prompter.fetch_models(self.text_encoder)
|
||||||
|
self.prompter.fetch_tokenizer(os.path.join(os.path.dirname(tokenizer_path), "google/umt5-xxl"))
|
||||||
|
self.dit = model_manager.fetch_model("wan_video_dit")
|
||||||
|
self.vae = model_manager.fetch_model("wan_video_vae")
|
||||||
|
self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
|
||||||
|
self.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
|
||||||
|
self.vace = model_manager.fetch_model("wan_video_vace")
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
|
||||||
|
if device is None: device = model_manager.device
|
||||||
|
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
|
||||||
|
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
|
||||||
|
pipe.fetch_models(model_manager)
|
||||||
|
if use_usp:
|
||||||
|
from xfuser.core.distributed import get_sequence_parallel_world_size
|
||||||
|
from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
|
||||||
|
|
||||||
|
for block in pipe.dit.blocks:
|
||||||
|
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
||||||
|
pipe.dit.forward = types.MethodType(usp_dit_forward, pipe.dit)
|
||||||
|
pipe.sp_size = get_sequence_parallel_world_size()
|
||||||
|
pipe.use_unified_sequence_parallel = True
|
||||||
|
return pipe
|
||||||
|
|
||||||
|
|
||||||
|
def denoising_model(self):
|
||||||
|
return self.dit
|
||||||
|
|
||||||
|
|
||||||
|
def encode_prompt(self, prompt, positive=True):
|
||||||
|
prompt_emb = self.prompter.encode_prompt(prompt, positive=positive, device=self.device)
|
||||||
|
return {"context": prompt_emb}
|
||||||
|
|
||||||
|
|
||||||
|
def encode_image(self, image, end_image, num_frames, height, width, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||||
|
image = self.preprocess_image(image.resize((width, height))).to(self.device)
|
||||||
|
clip_context = self.image_encoder.encode_image([image])
|
||||||
|
msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
|
||||||
|
msk[:, 1:] = 0
|
||||||
|
if end_image is not None:
|
||||||
|
end_image = self.preprocess_image(end_image.resize((width, height))).to(self.device)
|
||||||
|
vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
|
||||||
|
if self.dit.has_image_pos_emb:
|
||||||
|
clip_context = torch.concat([clip_context, self.image_encoder.encode_image([end_image])], dim=1)
|
||||||
|
msk[:, -1:] = 1
|
||||||
|
else:
|
||||||
|
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
|
||||||
|
|
||||||
|
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
||||||
|
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
|
||||||
|
msk = msk.transpose(1, 2)[0]
|
||||||
|
|
||||||
|
y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||||
|
y = y.to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
y = torch.concat([msk, y])
|
||||||
|
y = y.unsqueeze(0)
|
||||||
|
clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
y = y.to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
return {"clip_feature": clip_context, "y": y}
|
||||||
|
|
||||||
|
|
||||||
|
def encode_control_video(self, control_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||||
|
control_video = self.preprocess_images(control_video)
|
||||||
|
control_video = torch.stack(control_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
latents = self.encode_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
return latents
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_reference_image(self, reference_image, height, width):
|
||||||
|
if reference_image is not None:
|
||||||
|
self.load_models_to_device(["vae"])
|
||||||
|
reference_image = reference_image.resize((width, height))
|
||||||
|
reference_image = self.preprocess_images([reference_image])
|
||||||
|
reference_image = torch.stack(reference_image, dim=2).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
reference_latents = self.vae.encode(reference_image, device=self.device)
|
||||||
|
return {"reference_latents": reference_latents}
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_controlnet_kwargs(self, control_video, num_frames, height, width, clip_feature=None, y=None, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||||
|
if control_video is not None:
|
||||||
|
control_latents = self.encode_control_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||||
|
if clip_feature is None or y is None:
|
||||||
|
clip_feature = torch.zeros((1, 257, 1280), dtype=self.torch_dtype, device=self.device)
|
||||||
|
y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=self.torch_dtype, device=self.device)
|
||||||
|
else:
|
||||||
|
y = y[:, -16:]
|
||||||
|
y = torch.concat([control_latents, y], dim=1)
|
||||||
|
return {"clip_feature": clip_feature, "y": y}
|
||||||
|
|
||||||
|
|
||||||
|
def tensor2video(self, frames):
|
||||||
|
frames = rearrange(frames, "C T H W -> T H W C")
|
||||||
|
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
|
||||||
|
frames = [Image.fromarray(frame) for frame in frames]
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_extra_input(self, latents=None):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||||
|
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||||
|
return latents
|
||||||
|
|
||||||
|
|
||||||
|
def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||||
|
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_unified_sequence_parallel(self):
|
||||||
|
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_motion_bucket_id(self, motion_bucket_id):
|
||||||
|
motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
return {"motion_bucket_id": motion_bucket_id}
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_vace_kwargs(
|
||||||
|
self,
|
||||||
|
latents,
|
||||||
|
vace_video=None, vace_mask=None, vace_reference_image=None, vace_scale=1.0,
|
||||||
|
height=480, width=832, num_frames=81,
|
||||||
|
seed=None, rand_device="cpu",
|
||||||
|
tiled=True, tile_size=(34, 34), tile_stride=(18, 16)
|
||||||
|
):
|
||||||
|
if vace_video is not None or vace_mask is not None or vace_reference_image is not None:
|
||||||
|
self.load_models_to_device(["vae"])
|
||||||
|
if vace_video is None:
|
||||||
|
vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=self.torch_dtype, device=self.device)
|
||||||
|
else:
|
||||||
|
vace_video = self.preprocess_images(vace_video)
|
||||||
|
vace_video = torch.stack(vace_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
|
||||||
|
if vace_mask is None:
|
||||||
|
vace_mask = torch.ones_like(vace_video)
|
||||||
|
else:
|
||||||
|
vace_mask = self.preprocess_images(vace_mask)
|
||||||
|
vace_mask = torch.stack(vace_mask, dim=2).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
|
||||||
|
inactive = vace_video * (1 - vace_mask) + 0 * vace_mask
|
||||||
|
reactive = vace_video * vace_mask + 0 * (1 - vace_mask)
|
||||||
|
inactive = self.encode_video(inactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
reactive = self.encode_video(reactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
vace_video_latents = torch.concat((inactive, reactive), dim=1)
|
||||||
|
|
||||||
|
vace_mask_latents = rearrange(vace_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8)
|
||||||
|
vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact')
|
||||||
|
|
||||||
|
if vace_reference_image is None:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
vace_reference_image = self.preprocess_images([vace_reference_image])
|
||||||
|
vace_reference_image = torch.stack(vace_reference_image, dim=2).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
vace_reference_latents = self.encode_video(vace_reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
|
||||||
|
vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2)
|
||||||
|
vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2)
|
||||||
|
|
||||||
|
noise = self.generate_noise((1, 16, 1, latents.shape[3], latents.shape[4]), seed=seed, device=rand_device, dtype=torch.float32)
|
||||||
|
noise = noise.to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
latents = torch.concat((noise, latents), dim=2)
|
||||||
|
|
||||||
|
vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1)
|
||||||
|
return latents, {"vace_context": vace_context, "vace_scale": vace_scale}
|
||||||
|
else:
|
||||||
|
return latents, {"vace_context": None, "vace_scale": vace_scale}
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
prompt,
|
||||||
|
negative_prompt="",
|
||||||
|
input_image=None,
|
||||||
|
end_image=None,
|
||||||
|
input_video=None,
|
||||||
|
control_video=None,
|
||||||
|
reference_image=None,
|
||||||
|
vace_video=None,
|
||||||
|
vace_video_mask=None,
|
||||||
|
vace_reference_image=None,
|
||||||
|
vace_scale=1.0,
|
||||||
|
denoising_strength=1.0,
|
||||||
|
seed=None,
|
||||||
|
rand_device="cpu",
|
||||||
|
height=480,
|
||||||
|
width=832,
|
||||||
|
num_frames=81,
|
||||||
|
cfg_scale=5.0,
|
||||||
|
num_inference_steps=50,
|
||||||
|
sigma_shift=5.0,
|
||||||
|
motion_bucket_id=None,
|
||||||
|
tiled=True,
|
||||||
|
tile_size=(30, 52),
|
||||||
|
tile_stride=(15, 26),
|
||||||
|
tea_cache_l1_thresh=None,
|
||||||
|
tea_cache_model_id="",
|
||||||
|
progress_bar_cmd=tqdm,
|
||||||
|
progress_bar_st=None,
|
||||||
|
):
|
||||||
|
# Parameter check
|
||||||
|
height, width = self.check_resize_height_width(height, width)
|
||||||
|
if num_frames % 4 != 1:
|
||||||
|
num_frames = (num_frames + 2) // 4 * 4 + 1
|
||||||
|
print(f"Only `num_frames % 4 == 1` is acceptable. We round it up to {num_frames}.")
|
||||||
|
|
||||||
|
# Tiler parameters
|
||||||
|
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||||
|
|
||||||
|
# Scheduler
|
||||||
|
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
|
||||||
|
|
||||||
|
# Initialize noise
|
||||||
|
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32)
|
||||||
|
noise = noise.to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
if input_video is not None:
|
||||||
|
self.load_models_to_device(['vae'])
|
||||||
|
input_video = self.preprocess_images(input_video)
|
||||||
|
input_video = torch.stack(input_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||||
|
else:
|
||||||
|
latents = noise
|
||||||
|
|
||||||
|
# Encode prompts
|
||||||
|
self.load_models_to_device(["text_encoder"])
|
||||||
|
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
||||||
|
if cfg_scale != 1.0:
|
||||||
|
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
|
||||||
|
|
||||||
|
# Encode image
|
||||||
|
if input_image is not None and self.image_encoder is not None:
|
||||||
|
self.load_models_to_device(["image_encoder", "vae"])
|
||||||
|
image_emb = self.encode_image(input_image, end_image, num_frames, height, width, **tiler_kwargs)
|
||||||
|
else:
|
||||||
|
image_emb = {}
|
||||||
|
|
||||||
|
# Reference image
|
||||||
|
reference_image_kwargs = self.prepare_reference_image(reference_image, height, width)
|
||||||
|
|
||||||
|
# ControlNet
|
||||||
|
if control_video is not None:
|
||||||
|
self.load_models_to_device(["image_encoder", "vae"])
|
||||||
|
image_emb = self.prepare_controlnet_kwargs(control_video, num_frames, height, width, **image_emb, **tiler_kwargs)
|
||||||
|
|
||||||
|
# Motion Controller
|
||||||
|
if self.motion_controller is not None and motion_bucket_id is not None:
|
||||||
|
motion_kwargs = self.prepare_motion_bucket_id(motion_bucket_id)
|
||||||
|
else:
|
||||||
|
motion_kwargs = {}
|
||||||
|
|
||||||
|
# Extra input
|
||||||
|
extra_input = self.prepare_extra_input(latents)
|
||||||
|
|
||||||
|
# VACE
|
||||||
|
latents, vace_kwargs = self.prepare_vace_kwargs(
|
||||||
|
latents, vace_video, vace_video_mask, vace_reference_image, vace_scale,
|
||||||
|
height=height, width=width, num_frames=num_frames, seed=seed, rand_device=rand_device, **tiler_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# TeaCache
|
||||||
|
tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
|
||||||
|
tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
|
||||||
|
|
||||||
|
# Unified Sequence Parallel
|
||||||
|
usp_kwargs = self.prepare_unified_sequence_parallel()
|
||||||
|
|
||||||
|
# Denoise
|
||||||
|
self.load_models_to_device(["dit", "motion_controller", "vace"])
|
||||||
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||||
|
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
|
||||||
|
# Inference
|
||||||
|
noise_pred_posi = model_fn_wan_video(
|
||||||
|
self.dit, motion_controller=self.motion_controller, vace=self.vace,
|
||||||
|
x=latents, timestep=timestep,
|
||||||
|
**prompt_emb_posi, **image_emb, **extra_input,
|
||||||
|
**tea_cache_posi, **usp_kwargs, **motion_kwargs, **vace_kwargs, **reference_image_kwargs,
|
||||||
|
)
|
||||||
|
if cfg_scale != 1.0:
|
||||||
|
noise_pred_nega = model_fn_wan_video(
|
||||||
|
self.dit, motion_controller=self.motion_controller, vace=self.vace,
|
||||||
|
x=latents, timestep=timestep,
|
||||||
|
**prompt_emb_nega, **image_emb, **extra_input,
|
||||||
|
**tea_cache_nega, **usp_kwargs, **motion_kwargs, **vace_kwargs, **reference_image_kwargs,
|
||||||
|
)
|
||||||
|
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||||
|
else:
|
||||||
|
noise_pred = noise_pred_posi
|
||||||
|
|
||||||
|
# Scheduler
|
||||||
|
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||||
|
|
||||||
|
if vace_reference_image is not None:
|
||||||
|
latents = latents[:, :, 1:]
|
||||||
|
|
||||||
|
# Decode
|
||||||
|
self.load_models_to_device(['vae'])
|
||||||
|
frames = self.decode_video(latents, **tiler_kwargs)
|
||||||
|
self.load_models_to_device([])
|
||||||
|
frames = self.tensor2video(frames[0])
|
||||||
|
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TeaCache:
|
||||||
|
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
|
||||||
|
self.num_inference_steps = num_inference_steps
|
||||||
|
self.step = 0
|
||||||
|
self.accumulated_rel_l1_distance = 0
|
||||||
|
self.previous_modulated_input = None
|
||||||
|
self.rel_l1_thresh = rel_l1_thresh
|
||||||
|
self.previous_residual = None
|
||||||
|
self.previous_hidden_states = None
|
||||||
|
|
||||||
|
self.coefficients_dict = {
|
||||||
|
"Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
|
||||||
|
"Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
|
||||||
|
"Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
|
||||||
|
"Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
|
||||||
|
}
|
||||||
|
if model_id not in self.coefficients_dict:
|
||||||
|
supported_model_ids = ", ".join([i for i in self.coefficients_dict])
|
||||||
|
raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
|
||||||
|
self.coefficients = self.coefficients_dict[model_id]
|
||||||
|
|
||||||
|
def check(self, dit: WanModel, x, t_mod):
|
||||||
|
modulated_inp = t_mod.clone()
|
||||||
|
if self.step == 0 or self.step == self.num_inference_steps - 1:
|
||||||
|
should_calc = True
|
||||||
|
self.accumulated_rel_l1_distance = 0
|
||||||
|
else:
|
||||||
|
coefficients = self.coefficients
|
||||||
|
rescale_func = np.poly1d(coefficients)
|
||||||
|
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
|
||||||
|
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
|
||||||
|
should_calc = False
|
||||||
|
else:
|
||||||
|
should_calc = True
|
||||||
|
self.accumulated_rel_l1_distance = 0
|
||||||
|
self.previous_modulated_input = modulated_inp
|
||||||
|
self.step += 1
|
||||||
|
if self.step == self.num_inference_steps:
|
||||||
|
self.step = 0
|
||||||
|
if should_calc:
|
||||||
|
self.previous_hidden_states = x.clone()
|
||||||
|
return not should_calc
|
||||||
|
|
||||||
|
def store(self, hidden_states):
|
||||||
|
self.previous_residual = hidden_states - self.previous_hidden_states
|
||||||
|
self.previous_hidden_states = None
|
||||||
|
|
||||||
|
def update(self, hidden_states):
|
||||||
|
hidden_states = hidden_states + self.previous_residual
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def model_fn_wan_video(
|
||||||
|
dit: WanModel,
|
||||||
|
motion_controller: WanMotionControllerModel = None,
|
||||||
|
vace: VaceWanModel = None,
|
||||||
|
x: torch.Tensor = None,
|
||||||
|
timestep: torch.Tensor = None,
|
||||||
|
context: torch.Tensor = None,
|
||||||
|
clip_feature: Optional[torch.Tensor] = None,
|
||||||
|
y: Optional[torch.Tensor] = None,
|
||||||
|
reference_latents = None,
|
||||||
|
vace_context = None,
|
||||||
|
vace_scale = 1.0,
|
||||||
|
tea_cache: TeaCache = None,
|
||||||
|
use_unified_sequence_parallel: bool = False,
|
||||||
|
motion_bucket_id: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if use_unified_sequence_parallel:
|
||||||
|
import torch.distributed as dist
|
||||||
|
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
||||||
|
get_sequence_parallel_world_size,
|
||||||
|
get_sp_group)
|
||||||
|
|
||||||
|
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
|
||||||
|
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
|
||||||
|
if motion_bucket_id is not None and motion_controller is not None:
|
||||||
|
t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
|
||||||
|
context = dit.text_embedding(context)
|
||||||
|
|
||||||
|
if dit.has_image_input:
|
||||||
|
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
|
||||||
|
clip_embdding = dit.img_emb(clip_feature)
|
||||||
|
context = torch.cat([clip_embdding, context], dim=1)
|
||||||
|
|
||||||
|
x, (f, h, w) = dit.patchify(x)
|
||||||
|
|
||||||
|
# Reference image
|
||||||
|
if reference_latents is not None:
|
||||||
|
reference_latents = dit.ref_conv(reference_latents[:, :, 0]).flatten(2).transpose(1, 2)
|
||||||
|
x = torch.concat([reference_latents, x], dim=1)
|
||||||
|
f += 1
|
||||||
|
|
||||||
|
freqs = torch.cat([
|
||||||
|
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||||
|
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||||
|
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||||
|
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
||||||
|
|
||||||
|
# TeaCache
|
||||||
|
if tea_cache is not None:
|
||||||
|
tea_cache_update = tea_cache.check(dit, x, t_mod)
|
||||||
|
else:
|
||||||
|
tea_cache_update = False
|
||||||
|
|
||||||
|
if vace_context is not None:
|
||||||
|
vace_hints = vace(x, vace_context, context, t_mod, freqs)
|
||||||
|
|
||||||
|
# blocks
|
||||||
|
if use_unified_sequence_parallel:
|
||||||
|
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||||
|
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
||||||
|
if tea_cache_update:
|
||||||
|
x = tea_cache.update(x)
|
||||||
|
else:
|
||||||
|
for block_id, block in enumerate(dit.blocks):
|
||||||
|
x = block(x, context, t_mod, freqs)
|
||||||
|
if vace_context is not None and block_id in vace.vace_layers_mapping:
|
||||||
|
x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale
|
||||||
|
if tea_cache is not None:
|
||||||
|
tea_cache.store(x)
|
||||||
|
|
||||||
|
if reference_latents is not None:
|
||||||
|
x = x[:, reference_latents.shape[1]:]
|
||||||
|
f -= 1
|
||||||
|
|
||||||
|
x = dit.head(x, t)
|
||||||
|
if use_unified_sequence_parallel:
|
||||||
|
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||||
|
x = get_sp_group().all_gather(x, dim=1)
|
||||||
|
x = dit.unpatchify(x, (f, h, w))
|
||||||
|
return x
|
||||||
1205
diffsynth/pipelines/wan_video_new.py
Normal file
1205
diffsynth/pipelines/wan_video_new.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -9,4 +9,4 @@ from .omost import OmostPromter
|
|||||||
from .cog_prompter import CogPrompter
|
from .cog_prompter import CogPrompter
|
||||||
from .hunyuan_video_prompter import HunyuanVideoPrompter
|
from .hunyuan_video_prompter import HunyuanVideoPrompter
|
||||||
from .stepvideo_prompter import StepVideoPrompter
|
from .stepvideo_prompter import StepVideoPrompter
|
||||||
from .wanx_prompter import WanXPrompter
|
from .wan_prompter import WanPrompter
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
from .base_prompter import BasePrompter
|
from .base_prompter import BasePrompter
|
||||||
from ..models.sd3_text_encoder import SD3TextEncoder1
|
from ..models.sd3_text_encoder import SD3TextEncoder1
|
||||||
from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
|
from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder, HunyuanVideoMLLMEncoder
|
||||||
from transformers import CLIPTokenizer, LlamaTokenizerFast
|
from transformers import CLIPTokenizer, LlamaTokenizerFast, CLIPImageProcessor
|
||||||
import os, torch
|
import os, torch
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
PROMPT_TEMPLATE_ENCODE = (
|
PROMPT_TEMPLATE_ENCODE = (
|
||||||
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
|
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
|
||||||
@@ -18,6 +19,24 @@ PROMPT_TEMPLATE_ENCODE_VIDEO = (
|
|||||||
"5. camera angles, movements, and transitions used in the video:<|eot_id|>"
|
"5. camera angles, movements, and transitions used in the video:<|eot_id|>"
|
||||||
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")
|
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")
|
||||||
|
|
||||||
|
PROMPT_TEMPLATE_ENCODE_I2V = (
|
||||||
|
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the image by detailing the color, shape, size, texture, "
|
||||||
|
"quantity, text, spatial relationships of the objects and background:<|eot_id|>"
|
||||||
|
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
||||||
|
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
|
||||||
|
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
|
||||||
|
"1. The main content and theme of the video."
|
||||||
|
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
|
||||||
|
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
|
||||||
|
"4. background environment, light, style and atmosphere."
|
||||||
|
"5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
|
||||||
|
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
||||||
|
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
PROMPT_TEMPLATE = {
|
PROMPT_TEMPLATE = {
|
||||||
"dit-llm-encode": {
|
"dit-llm-encode": {
|
||||||
"template": PROMPT_TEMPLATE_ENCODE,
|
"template": PROMPT_TEMPLATE_ENCODE,
|
||||||
@@ -27,6 +46,22 @@ PROMPT_TEMPLATE = {
|
|||||||
"template": PROMPT_TEMPLATE_ENCODE_VIDEO,
|
"template": PROMPT_TEMPLATE_ENCODE_VIDEO,
|
||||||
"crop_start": 95,
|
"crop_start": 95,
|
||||||
},
|
},
|
||||||
|
"dit-llm-encode-i2v": {
|
||||||
|
"template": PROMPT_TEMPLATE_ENCODE_I2V,
|
||||||
|
"crop_start": 36,
|
||||||
|
"image_emb_start": 5,
|
||||||
|
"image_emb_end": 581,
|
||||||
|
"image_emb_len": 576,
|
||||||
|
"double_return_token_id": 271
|
||||||
|
},
|
||||||
|
"dit-llm-encode-video-i2v": {
|
||||||
|
"template": PROMPT_TEMPLATE_ENCODE_VIDEO_I2V,
|
||||||
|
"crop_start": 103,
|
||||||
|
"image_emb_start": 5,
|
||||||
|
"image_emb_end": 581,
|
||||||
|
"image_emb_len": 576,
|
||||||
|
"double_return_token_id": 271
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
|
NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
|
||||||
@@ -56,9 +91,20 @@ class HunyuanVideoPrompter(BasePrompter):
|
|||||||
self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode']
|
self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode']
|
||||||
self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video']
|
self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video']
|
||||||
|
|
||||||
def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: HunyuanVideoLLMEncoder = None):
|
def fetch_models(self,
|
||||||
|
text_encoder_1: SD3TextEncoder1 = None,
|
||||||
|
text_encoder_2: Union[HunyuanVideoLLMEncoder, HunyuanVideoMLLMEncoder] = None):
|
||||||
self.text_encoder_1 = text_encoder_1
|
self.text_encoder_1 = text_encoder_1
|
||||||
self.text_encoder_2 = text_encoder_2
|
self.text_encoder_2 = text_encoder_2
|
||||||
|
if isinstance(text_encoder_2, HunyuanVideoMLLMEncoder):
|
||||||
|
# processor
|
||||||
|
# TODO: may need to replace processor with local implementation
|
||||||
|
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||||
|
tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/hunyuan_video/tokenizer_2")
|
||||||
|
self.processor = CLIPImageProcessor.from_pretrained(tokenizer_2_path)
|
||||||
|
# template
|
||||||
|
self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode-i2v']
|
||||||
|
self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video-i2v']
|
||||||
|
|
||||||
def apply_text_to_template(self, text, template):
|
def apply_text_to_template(self, text, template):
|
||||||
assert isinstance(template, str)
|
assert isinstance(template, str)
|
||||||
@@ -107,8 +153,89 @@ class HunyuanVideoPrompter(BasePrompter):
|
|||||||
|
|
||||||
return last_hidden_state, attention_mask
|
return last_hidden_state, attention_mask
|
||||||
|
|
||||||
|
def encode_prompt_using_mllm(self,
|
||||||
|
prompt,
|
||||||
|
images,
|
||||||
|
max_length,
|
||||||
|
device,
|
||||||
|
crop_start,
|
||||||
|
hidden_state_skip_layer=2,
|
||||||
|
use_attention_mask=True,
|
||||||
|
image_embed_interleave=4):
|
||||||
|
image_outputs = self.processor(images, return_tensors="pt")["pixel_values"].to(device)
|
||||||
|
max_length += crop_start
|
||||||
|
inputs = self.tokenizer_2(prompt,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding="max_length",
|
||||||
|
max_length=max_length,
|
||||||
|
truncation=True)
|
||||||
|
input_ids = inputs.input_ids.to(device)
|
||||||
|
attention_mask = inputs.attention_mask.to(device)
|
||||||
|
last_hidden_state = self.text_encoder_2(input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
hidden_state_skip_layer=hidden_state_skip_layer,
|
||||||
|
pixel_values=image_outputs)
|
||||||
|
|
||||||
|
text_crop_start = (crop_start - 1 + self.prompt_template_video.get("image_emb_len", 576))
|
||||||
|
image_crop_start = self.prompt_template_video.get("image_emb_start", 5)
|
||||||
|
image_crop_end = self.prompt_template_video.get("image_emb_end", 581)
|
||||||
|
batch_indices, last_double_return_token_indices = torch.where(
|
||||||
|
input_ids == self.prompt_template_video.get("double_return_token_id", 271))
|
||||||
|
if last_double_return_token_indices.shape[0] == 3:
|
||||||
|
# in case the prompt is too long
|
||||||
|
last_double_return_token_indices = torch.cat((
|
||||||
|
last_double_return_token_indices,
|
||||||
|
torch.tensor([input_ids.shape[-1]]),
|
||||||
|
))
|
||||||
|
batch_indices = torch.cat((batch_indices, torch.tensor([0])))
|
||||||
|
last_double_return_token_indices = (last_double_return_token_indices.reshape(input_ids.shape[0], -1)[:, -1])
|
||||||
|
batch_indices = batch_indices.reshape(input_ids.shape[0], -1)[:, -1]
|
||||||
|
assistant_crop_start = (last_double_return_token_indices - 1 + self.prompt_template_video.get("image_emb_len", 576) - 4)
|
||||||
|
assistant_crop_end = (last_double_return_token_indices - 1 + self.prompt_template_video.get("image_emb_len", 576))
|
||||||
|
attention_mask_assistant_crop_start = (last_double_return_token_indices - 4)
|
||||||
|
attention_mask_assistant_crop_end = last_double_return_token_indices
|
||||||
|
text_last_hidden_state = []
|
||||||
|
text_attention_mask = []
|
||||||
|
image_last_hidden_state = []
|
||||||
|
image_attention_mask = []
|
||||||
|
for i in range(input_ids.shape[0]):
|
||||||
|
text_last_hidden_state.append(
|
||||||
|
torch.cat([
|
||||||
|
last_hidden_state[i, text_crop_start:assistant_crop_start[i].item()],
|
||||||
|
last_hidden_state[i, assistant_crop_end[i].item():],
|
||||||
|
]))
|
||||||
|
text_attention_mask.append(
|
||||||
|
torch.cat([
|
||||||
|
attention_mask[
|
||||||
|
i,
|
||||||
|
crop_start:attention_mask_assistant_crop_start[i].item(),
|
||||||
|
],
|
||||||
|
attention_mask[i, attention_mask_assistant_crop_end[i].item():],
|
||||||
|
]) if use_attention_mask else None)
|
||||||
|
image_last_hidden_state.append(last_hidden_state[i, image_crop_start:image_crop_end])
|
||||||
|
image_attention_mask.append(
|
||||||
|
torch.ones(image_last_hidden_state[-1].shape[0]).to(last_hidden_state.device).
|
||||||
|
to(attention_mask.dtype) if use_attention_mask else None)
|
||||||
|
|
||||||
|
text_last_hidden_state = torch.stack(text_last_hidden_state)
|
||||||
|
text_attention_mask = torch.stack(text_attention_mask)
|
||||||
|
image_last_hidden_state = torch.stack(image_last_hidden_state)
|
||||||
|
image_attention_mask = torch.stack(image_attention_mask)
|
||||||
|
|
||||||
|
image_last_hidden_state = image_last_hidden_state[:, ::image_embed_interleave, :]
|
||||||
|
image_attention_mask = image_attention_mask[:, ::image_embed_interleave]
|
||||||
|
|
||||||
|
assert (text_last_hidden_state.shape[0] == text_attention_mask.shape[0] and
|
||||||
|
image_last_hidden_state.shape[0] == image_attention_mask.shape[0])
|
||||||
|
|
||||||
|
last_hidden_state = torch.cat([image_last_hidden_state, text_last_hidden_state], dim=1)
|
||||||
|
attention_mask = torch.cat([image_attention_mask, text_attention_mask], dim=1)
|
||||||
|
|
||||||
|
return last_hidden_state, attention_mask
|
||||||
|
|
||||||
def encode_prompt(self,
|
def encode_prompt(self,
|
||||||
prompt,
|
prompt,
|
||||||
|
images=None,
|
||||||
positive=True,
|
positive=True,
|
||||||
device="cuda",
|
device="cuda",
|
||||||
clip_sequence_length=77,
|
clip_sequence_length=77,
|
||||||
@@ -116,7 +243,8 @@ class HunyuanVideoPrompter(BasePrompter):
|
|||||||
data_type='video',
|
data_type='video',
|
||||||
use_template=True,
|
use_template=True,
|
||||||
hidden_state_skip_layer=2,
|
hidden_state_skip_layer=2,
|
||||||
use_attention_mask=True):
|
use_attention_mask=True,
|
||||||
|
image_embed_interleave=4):
|
||||||
|
|
||||||
prompt = self.process_prompt(prompt, positive=positive)
|
prompt = self.process_prompt(prompt, positive=positive)
|
||||||
|
|
||||||
@@ -136,8 +264,12 @@ class HunyuanVideoPrompter(BasePrompter):
|
|||||||
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, clip_sequence_length, device)
|
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, clip_sequence_length, device)
|
||||||
|
|
||||||
# LLM
|
# LLM
|
||||||
prompt_emb, attention_mask = self.encode_prompt_using_llm(
|
if images is None:
|
||||||
prompt_formated, llm_sequence_length, device, crop_start,
|
prompt_emb, attention_mask = self.encode_prompt_using_llm(prompt_formated, llm_sequence_length, device, crop_start,
|
||||||
hidden_state_skip_layer, use_attention_mask)
|
hidden_state_skip_layer, use_attention_mask)
|
||||||
|
else:
|
||||||
|
prompt_emb, attention_mask = self.encode_prompt_using_mllm(prompt_formated, images, llm_sequence_length, device,
|
||||||
|
crop_start, hidden_state_skip_layer, use_attention_mask,
|
||||||
|
image_embed_interleave)
|
||||||
|
|
||||||
return prompt_emb, pooled_prompt_emb, attention_mask
|
return prompt_emb, pooled_prompt_emb, attention_mask
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
from .base_prompter import BasePrompter
|
from .base_prompter import BasePrompter
|
||||||
from ..models.wanx_text_encoder import WanXTextEncoder
|
from ..models.wan_video_text_encoder import WanTextEncoder
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
import os, torch
|
import os, torch
|
||||||
import ftfy
|
import ftfy
|
||||||
import html
|
import html
|
||||||
import string
|
import string
|
||||||
|
|
||||||
import regex as re
|
import regex as re
|
||||||
|
|
||||||
|
|
||||||
@@ -14,11 +13,13 @@ def basic_clean(text):
|
|||||||
text = html.unescape(html.unescape(text))
|
text = html.unescape(html.unescape(text))
|
||||||
return text.strip()
|
return text.strip()
|
||||||
|
|
||||||
|
|
||||||
def whitespace_clean(text):
|
def whitespace_clean(text):
|
||||||
text = re.sub(r'\s+', ' ', text)
|
text = re.sub(r'\s+', ' ', text)
|
||||||
text = text.strip()
|
text = text.strip()
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def canonicalize(text, keep_punctuation_exact_string=None):
|
def canonicalize(text, keep_punctuation_exact_string=None):
|
||||||
text = text.replace('_', ' ')
|
text = text.replace('_', ' ')
|
||||||
if keep_punctuation_exact_string:
|
if keep_punctuation_exact_string:
|
||||||
@@ -31,6 +32,7 @@ def canonicalize(text, keep_punctuation_exact_string=None):
|
|||||||
text = re.sub(r'\s+', ' ', text)
|
text = re.sub(r'\s+', ' ', text)
|
||||||
return text.strip()
|
return text.strip()
|
||||||
|
|
||||||
|
|
||||||
class HuggingfaceTokenizer:
|
class HuggingfaceTokenizer:
|
||||||
|
|
||||||
def __init__(self, name, seq_len=None, clean=None, **kwargs):
|
def __init__(self, name, seq_len=None, clean=None, **kwargs):
|
||||||
@@ -78,26 +80,30 @@ class HuggingfaceTokenizer:
|
|||||||
text = canonicalize(basic_clean(text))
|
text = canonicalize(basic_clean(text))
|
||||||
return text
|
return text
|
||||||
|
|
||||||
class WanXPrompter(BasePrompter):
|
|
||||||
|
class WanPrompter(BasePrompter):
|
||||||
|
|
||||||
def __init__(self, tokenizer_path=None, text_len=512):
|
def __init__(self, tokenizer_path=None, text_len=512):
|
||||||
if tokenizer_path is None:
|
|
||||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
|
||||||
tokenizer_path = os.path.join(
|
|
||||||
base_path, "tokenizer_configs/hunyuan_dit/tokenizer")
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean='whitespace')
|
self.text_len = text_len
|
||||||
self.text_encoder = None
|
self.text_encoder = None
|
||||||
|
self.fetch_tokenizer(tokenizer_path)
|
||||||
|
|
||||||
def fetch_models(self, text_encoder: WanXTextEncoder = None):
|
def fetch_tokenizer(self, tokenizer_path=None):
|
||||||
|
if tokenizer_path is not None:
|
||||||
|
self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=self.text_len, clean='whitespace')
|
||||||
|
|
||||||
|
def fetch_models(self, text_encoder: WanTextEncoder = None):
|
||||||
self.text_encoder = text_encoder
|
self.text_encoder = text_encoder
|
||||||
|
|
||||||
def encode_prompt(self, prompt, device="cuda"):
|
def encode_prompt(self, prompt, positive=True, device="cuda"):
|
||||||
|
prompt = self.process_prompt(prompt, positive=positive)
|
||||||
|
|
||||||
ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True)
|
ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True)
|
||||||
ids = ids.to(device)
|
ids = ids.to(device)
|
||||||
mask = mask.to(device)
|
mask = mask.to(device)
|
||||||
seq_lens = mask.gt(0).sum(dim=1).long()
|
seq_lens = mask.gt(0).sum(dim=1).long()
|
||||||
prompt_emb = self.text_encoder(ids, mask)
|
prompt_emb = self.text_encoder(ids, mask)
|
||||||
prompt_emb = [u[:v] for u, v in zip(prompt_emb, seq_lens)]
|
for i, v in enumerate(seq_lens):
|
||||||
|
prompt_emb[:, v:] = 0
|
||||||
return prompt_emb
|
return prompt_emb
|
||||||
|
|
||||||
@@ -15,7 +15,9 @@ class FlowMatchScheduler():
|
|||||||
self.set_timesteps(num_inference_steps)
|
self.set_timesteps(num_inference_steps)
|
||||||
|
|
||||||
|
|
||||||
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False):
|
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
|
||||||
|
if shift is not None:
|
||||||
|
self.shift = shift
|
||||||
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
|
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
|
||||||
if self.extra_one_step:
|
if self.extra_one_step:
|
||||||
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
|
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
|
||||||
@@ -33,9 +35,12 @@ class FlowMatchScheduler():
|
|||||||
y_shifted = y - y.min()
|
y_shifted = y - y.min()
|
||||||
bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
|
bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
|
||||||
self.linear_timesteps_weights = bsmntw_weighing
|
self.linear_timesteps_weights = bsmntw_weighing
|
||||||
|
self.training = True
|
||||||
|
else:
|
||||||
|
self.training = False
|
||||||
|
|
||||||
|
|
||||||
def step(self, model_output, timestep, sample, to_final=False):
|
def step(self, model_output, timestep, sample, to_final=False, **kwargs):
|
||||||
if isinstance(timestep, torch.Tensor):
|
if isinstance(timestep, torch.Tensor):
|
||||||
timestep = timestep.cpu()
|
timestep = timestep.cpu()
|
||||||
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||||
|
|||||||
@@ -0,0 +1,45 @@
|
|||||||
|
{
|
||||||
|
"_valid_processor_keys": [
|
||||||
|
"images",
|
||||||
|
"do_resize",
|
||||||
|
"size",
|
||||||
|
"resample",
|
||||||
|
"do_center_crop",
|
||||||
|
"crop_size",
|
||||||
|
"do_rescale",
|
||||||
|
"rescale_factor",
|
||||||
|
"do_normalize",
|
||||||
|
"image_mean",
|
||||||
|
"image_std",
|
||||||
|
"do_convert_rgb",
|
||||||
|
"return_tensors",
|
||||||
|
"data_format",
|
||||||
|
"input_data_format"
|
||||||
|
],
|
||||||
|
"crop_size": {
|
||||||
|
"height": 336,
|
||||||
|
"width": 336
|
||||||
|
},
|
||||||
|
"do_center_crop": true,
|
||||||
|
"do_convert_rgb": true,
|
||||||
|
"do_normalize": true,
|
||||||
|
"do_rescale": true,
|
||||||
|
"do_resize": true,
|
||||||
|
"image_mean": [
|
||||||
|
0.48145466,
|
||||||
|
0.4578275,
|
||||||
|
0.40821073
|
||||||
|
],
|
||||||
|
"image_processor_type": "CLIPImageProcessor",
|
||||||
|
"image_std": [
|
||||||
|
0.26862954,
|
||||||
|
0.26130258,
|
||||||
|
0.27577711
|
||||||
|
],
|
||||||
|
"processor_class": "LlavaProcessor",
|
||||||
|
"resample": 3,
|
||||||
|
"rescale_factor": 0.00392156862745098,
|
||||||
|
"size": {
|
||||||
|
"shortest_edge": 336
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -250,6 +250,17 @@ def add_general_parsers(parser):
|
|||||||
default=None,
|
default=None,
|
||||||
help="Pretrained LoRA path. Required if the training is resumed.",
|
help="Pretrained LoRA path. Required if the training is resumed.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_swanlab",
|
||||||
|
default=False,
|
||||||
|
action="store_true",
|
||||||
|
help="Whether to use SwanLab logger.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--swanlab_mode",
|
||||||
|
default=None,
|
||||||
|
help="SwanLab mode (cloud or local).",
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@@ -269,8 +280,21 @@ def launch_training_task(model, args):
|
|||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
num_workers=args.dataloader_num_workers
|
num_workers=args.dataloader_num_workers
|
||||||
)
|
)
|
||||||
|
|
||||||
# train
|
# train
|
||||||
|
if args.use_swanlab:
|
||||||
|
from swanlab.integration.pytorch_lightning import SwanLabLogger
|
||||||
|
swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
|
||||||
|
swanlab_config.update(vars(args))
|
||||||
|
swanlab_logger = SwanLabLogger(
|
||||||
|
project="diffsynth_studio",
|
||||||
|
name="diffsynth_studio",
|
||||||
|
config=swanlab_config,
|
||||||
|
mode=args.swanlab_mode,
|
||||||
|
logdir=os.path.join(args.output_path, "swanlog"),
|
||||||
|
)
|
||||||
|
logger = [swanlab_logger]
|
||||||
|
else:
|
||||||
|
logger = None
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer(
|
||||||
max_epochs=args.max_epochs,
|
max_epochs=args.max_epochs,
|
||||||
accelerator="gpu",
|
accelerator="gpu",
|
||||||
@@ -279,7 +303,8 @@ def launch_training_task(model, args):
|
|||||||
strategy=args.training_strategy,
|
strategy=args.training_strategy,
|
||||||
default_root_dir=args.output_path,
|
default_root_dir=args.output_path,
|
||||||
accumulate_grad_batches=args.accumulate_grad_batches,
|
accumulate_grad_batches=args.accumulate_grad_batches,
|
||||||
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)]
|
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
|
||||||
|
logger=logger,
|
||||||
)
|
)
|
||||||
trainer.fit(model=model, train_dataloaders=train_loader)
|
trainer.fit(model=model, train_dataloaders=train_loader)
|
||||||
|
|
||||||
|
|||||||
465
diffsynth/trainers/utils.py
Normal file
465
diffsynth/trainers/utils.py
Normal file
@@ -0,0 +1,465 @@
|
|||||||
|
import imageio, os, torch, warnings, torchvision, argparse, json
|
||||||
|
from peft import LoraConfig, inject_adapter_in_model
|
||||||
|
from PIL import Image
|
||||||
|
import pandas as pd
|
||||||
|
from tqdm import tqdm
|
||||||
|
from accelerate import Accelerator
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ImageDataset(torch.utils.data.Dataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_path=None, metadata_path=None,
|
||||||
|
max_pixels=1920*1080, height=None, width=None,
|
||||||
|
height_division_factor=16, width_division_factor=16,
|
||||||
|
data_file_keys=("image",),
|
||||||
|
image_file_extension=("jpg", "jpeg", "png", "webp"),
|
||||||
|
repeat=1,
|
||||||
|
args=None,
|
||||||
|
):
|
||||||
|
if args is not None:
|
||||||
|
base_path = args.dataset_base_path
|
||||||
|
metadata_path = args.dataset_metadata_path
|
||||||
|
height = args.height
|
||||||
|
width = args.width
|
||||||
|
max_pixels = args.max_pixels
|
||||||
|
data_file_keys = args.data_file_keys.split(",")
|
||||||
|
repeat = args.dataset_repeat
|
||||||
|
|
||||||
|
self.base_path = base_path
|
||||||
|
self.max_pixels = max_pixels
|
||||||
|
self.height = height
|
||||||
|
self.width = width
|
||||||
|
self.height_division_factor = height_division_factor
|
||||||
|
self.width_division_factor = width_division_factor
|
||||||
|
self.data_file_keys = data_file_keys
|
||||||
|
self.image_file_extension = image_file_extension
|
||||||
|
self.repeat = repeat
|
||||||
|
|
||||||
|
if height is not None and width is not None:
|
||||||
|
print("Height and width are fixed. Setting `dynamic_resolution` to False.")
|
||||||
|
self.dynamic_resolution = False
|
||||||
|
elif height is None and width is None:
|
||||||
|
print("Height and width are none. Setting `dynamic_resolution` to True.")
|
||||||
|
self.dynamic_resolution = True
|
||||||
|
|
||||||
|
if metadata_path is None:
|
||||||
|
print("No metadata. Trying to generate it.")
|
||||||
|
metadata = self.generate_metadata(base_path)
|
||||||
|
print(f"{len(metadata)} lines in metadata.")
|
||||||
|
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
|
||||||
|
elif metadata_path.endswith(".json"):
|
||||||
|
with open(metadata_path, "r") as f:
|
||||||
|
metadata = json.load(f)
|
||||||
|
self.data = metadata
|
||||||
|
else:
|
||||||
|
metadata = pd.read_csv(metadata_path)
|
||||||
|
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
|
||||||
|
|
||||||
|
|
||||||
|
def generate_metadata(self, folder):
|
||||||
|
image_list, prompt_list = [], []
|
||||||
|
file_set = set(os.listdir(folder))
|
||||||
|
for file_name in file_set:
|
||||||
|
if "." not in file_name:
|
||||||
|
continue
|
||||||
|
file_ext_name = file_name.split(".")[-1].lower()
|
||||||
|
file_base_name = file_name[:-len(file_ext_name)-1]
|
||||||
|
if file_ext_name not in self.image_file_extension:
|
||||||
|
continue
|
||||||
|
prompt_file_name = file_base_name + ".txt"
|
||||||
|
if prompt_file_name not in file_set:
|
||||||
|
continue
|
||||||
|
with open(os.path.join(folder, prompt_file_name), "r", encoding="utf-8") as f:
|
||||||
|
prompt = f.read().strip()
|
||||||
|
image_list.append(file_name)
|
||||||
|
prompt_list.append(prompt)
|
||||||
|
metadata = pd.DataFrame()
|
||||||
|
metadata["image"] = image_list
|
||||||
|
metadata["prompt"] = prompt_list
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
|
||||||
|
def crop_and_resize(self, image, target_height, target_width):
|
||||||
|
width, height = image.size
|
||||||
|
scale = max(target_width / width, target_height / height)
|
||||||
|
image = torchvision.transforms.functional.resize(
|
||||||
|
image,
|
||||||
|
(round(height*scale), round(width*scale)),
|
||||||
|
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
|
||||||
|
)
|
||||||
|
image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def get_height_width(self, image):
|
||||||
|
if self.dynamic_resolution:
|
||||||
|
width, height = image.size
|
||||||
|
if width * height > self.max_pixels:
|
||||||
|
scale = (width * height / self.max_pixels) ** 0.5
|
||||||
|
height, width = int(height / scale), int(width / scale)
|
||||||
|
height = height // self.height_division_factor * self.height_division_factor
|
||||||
|
width = width // self.width_division_factor * self.width_division_factor
|
||||||
|
else:
|
||||||
|
height, width = self.height, self.width
|
||||||
|
return height, width
|
||||||
|
|
||||||
|
|
||||||
|
def load_image(self, file_path):
|
||||||
|
image = Image.open(file_path).convert("RGB")
|
||||||
|
image = self.crop_and_resize(image, *self.get_height_width(image))
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def load_data(self, file_path):
|
||||||
|
return self.load_image(file_path)
|
||||||
|
|
||||||
|
|
||||||
|
def __getitem__(self, data_id):
|
||||||
|
data = self.data[data_id % len(self.data)].copy()
|
||||||
|
for key in self.data_file_keys:
|
||||||
|
if key in data:
|
||||||
|
path = os.path.join(self.base_path, data[key])
|
||||||
|
data[key] = self.load_data(path)
|
||||||
|
if data[key] is None:
|
||||||
|
warnings.warn(f"cannot load file {data[key]}.")
|
||||||
|
return None
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data) * self.repeat
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class VideoDataset(torch.utils.data.Dataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_path=None, metadata_path=None,
|
||||||
|
num_frames=81,
|
||||||
|
time_division_factor=4, time_division_remainder=1,
|
||||||
|
max_pixels=1920*1080, height=None, width=None,
|
||||||
|
height_division_factor=16, width_division_factor=16,
|
||||||
|
data_file_keys=("video",),
|
||||||
|
image_file_extension=("jpg", "jpeg", "png", "webp"),
|
||||||
|
video_file_extension=("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"),
|
||||||
|
repeat=1,
|
||||||
|
args=None,
|
||||||
|
):
|
||||||
|
if args is not None:
|
||||||
|
base_path = args.dataset_base_path
|
||||||
|
metadata_path = args.dataset_metadata_path
|
||||||
|
height = args.height
|
||||||
|
width = args.width
|
||||||
|
max_pixels = args.max_pixels
|
||||||
|
num_frames = args.num_frames
|
||||||
|
data_file_keys = args.data_file_keys.split(",")
|
||||||
|
repeat = args.dataset_repeat
|
||||||
|
|
||||||
|
self.base_path = base_path
|
||||||
|
self.num_frames = num_frames
|
||||||
|
self.time_division_factor = time_division_factor
|
||||||
|
self.time_division_remainder = time_division_remainder
|
||||||
|
self.max_pixels = max_pixels
|
||||||
|
self.height = height
|
||||||
|
self.width = width
|
||||||
|
self.height_division_factor = height_division_factor
|
||||||
|
self.width_division_factor = width_division_factor
|
||||||
|
self.data_file_keys = data_file_keys
|
||||||
|
self.image_file_extension = image_file_extension
|
||||||
|
self.video_file_extension = video_file_extension
|
||||||
|
self.repeat = repeat
|
||||||
|
|
||||||
|
if height is not None and width is not None:
|
||||||
|
print("Height and width are fixed. Setting `dynamic_resolution` to False.")
|
||||||
|
self.dynamic_resolution = False
|
||||||
|
elif height is None and width is None:
|
||||||
|
print("Height and width are none. Setting `dynamic_resolution` to True.")
|
||||||
|
self.dynamic_resolution = True
|
||||||
|
|
||||||
|
if metadata_path is None:
|
||||||
|
print("No metadata. Trying to generate it.")
|
||||||
|
metadata = self.generate_metadata(base_path)
|
||||||
|
print(f"{len(metadata)} lines in metadata.")
|
||||||
|
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
|
||||||
|
elif metadata_path.endswith(".json"):
|
||||||
|
with open(metadata_path, "r") as f:
|
||||||
|
metadata = json.load(f)
|
||||||
|
self.data = metadata
|
||||||
|
else:
|
||||||
|
metadata = pd.read_csv(metadata_path)
|
||||||
|
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
|
||||||
|
|
||||||
|
|
||||||
|
def generate_metadata(self, folder):
|
||||||
|
video_list, prompt_list = [], []
|
||||||
|
file_set = set(os.listdir(folder))
|
||||||
|
for file_name in file_set:
|
||||||
|
if "." not in file_name:
|
||||||
|
continue
|
||||||
|
file_ext_name = file_name.split(".")[-1].lower()
|
||||||
|
file_base_name = file_name[:-len(file_ext_name)-1]
|
||||||
|
if file_ext_name not in self.image_file_extension and file_ext_name not in self.video_file_extension:
|
||||||
|
continue
|
||||||
|
prompt_file_name = file_base_name + ".txt"
|
||||||
|
if prompt_file_name not in file_set:
|
||||||
|
continue
|
||||||
|
with open(os.path.join(folder, prompt_file_name), "r", encoding="utf-8") as f:
|
||||||
|
prompt = f.read().strip()
|
||||||
|
video_list.append(file_name)
|
||||||
|
prompt_list.append(prompt)
|
||||||
|
metadata = pd.DataFrame()
|
||||||
|
metadata["video"] = video_list
|
||||||
|
metadata["prompt"] = prompt_list
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
|
||||||
|
def crop_and_resize(self, image, target_height, target_width):
|
||||||
|
width, height = image.size
|
||||||
|
scale = max(target_width / width, target_height / height)
|
||||||
|
image = torchvision.transforms.functional.resize(
|
||||||
|
image,
|
||||||
|
(round(height*scale), round(width*scale)),
|
||||||
|
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
|
||||||
|
)
|
||||||
|
image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def get_height_width(self, image):
|
||||||
|
if self.dynamic_resolution:
|
||||||
|
width, height = image.size
|
||||||
|
if width * height > self.max_pixels:
|
||||||
|
scale = (width * height / self.max_pixels) ** 0.5
|
||||||
|
height, width = int(height / scale), int(width / scale)
|
||||||
|
height = height // self.height_division_factor * self.height_division_factor
|
||||||
|
width = width // self.width_division_factor * self.width_division_factor
|
||||||
|
else:
|
||||||
|
height, width = self.height, self.width
|
||||||
|
return height, width
|
||||||
|
|
||||||
|
|
||||||
|
def get_num_frames(self, reader):
|
||||||
|
num_frames = self.num_frames
|
||||||
|
if int(reader.count_frames()) < num_frames:
|
||||||
|
num_frames = int(reader.count_frames())
|
||||||
|
while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
|
||||||
|
num_frames -= 1
|
||||||
|
return num_frames
|
||||||
|
|
||||||
|
|
||||||
|
def load_video(self, file_path):
|
||||||
|
reader = imageio.get_reader(file_path)
|
||||||
|
num_frames = self.get_num_frames(reader)
|
||||||
|
frames = []
|
||||||
|
for frame_id in range(num_frames):
|
||||||
|
frame = reader.get_data(frame_id)
|
||||||
|
frame = Image.fromarray(frame)
|
||||||
|
frame = self.crop_and_resize(frame, *self.get_height_width(frame))
|
||||||
|
frames.append(frame)
|
||||||
|
reader.close()
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
def load_image(self, file_path):
|
||||||
|
image = Image.open(file_path).convert("RGB")
|
||||||
|
image = self.crop_and_resize(image, *self.get_height_width(image))
|
||||||
|
frames = [image]
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
def is_image(self, file_path):
|
||||||
|
file_ext_name = file_path.split(".")[-1]
|
||||||
|
return file_ext_name.lower() in self.image_file_extension
|
||||||
|
|
||||||
|
|
||||||
|
def is_video(self, file_path):
|
||||||
|
file_ext_name = file_path.split(".")[-1]
|
||||||
|
return file_ext_name.lower() in self.video_file_extension
|
||||||
|
|
||||||
|
|
||||||
|
def load_data(self, file_path):
|
||||||
|
if self.is_image(file_path):
|
||||||
|
return self.load_image(file_path)
|
||||||
|
elif self.is_video(file_path):
|
||||||
|
return self.load_video(file_path)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def __getitem__(self, data_id):
|
||||||
|
data = self.data[data_id % len(self.data)].copy()
|
||||||
|
for key in self.data_file_keys:
|
||||||
|
if key in data:
|
||||||
|
path = os.path.join(self.base_path, data[key])
|
||||||
|
data[key] = self.load_data(path)
|
||||||
|
if data[key] is None:
|
||||||
|
warnings.warn(f"cannot load file {data[key]}.")
|
||||||
|
return None
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data) * self.repeat
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class DiffusionTrainingModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
|
||||||
|
def to(self, *args, **kwargs):
|
||||||
|
for name, model in self.named_children():
|
||||||
|
model.to(*args, **kwargs)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
def trainable_modules(self):
|
||||||
|
trainable_modules = filter(lambda p: p.requires_grad, self.parameters())
|
||||||
|
return trainable_modules
|
||||||
|
|
||||||
|
|
||||||
|
def trainable_param_names(self):
|
||||||
|
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.named_parameters()))
|
||||||
|
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
|
||||||
|
return trainable_param_names
|
||||||
|
|
||||||
|
|
||||||
|
def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None):
|
||||||
|
if lora_alpha is None:
|
||||||
|
lora_alpha = lora_rank
|
||||||
|
lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules)
|
||||||
|
model = inject_adapter_in_model(lora_config, model)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def export_trainable_state_dict(self, state_dict, remove_prefix=None):
|
||||||
|
trainable_param_names = self.trainable_param_names()
|
||||||
|
state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names}
|
||||||
|
if remove_prefix is not None:
|
||||||
|
state_dict_ = {}
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
if name.startswith(remove_prefix):
|
||||||
|
name = name[len(remove_prefix):]
|
||||||
|
state_dict_[name] = param
|
||||||
|
state_dict = state_dict_
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ModelLogger:
|
||||||
|
def __init__(self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x:x):
|
||||||
|
self.output_path = output_path
|
||||||
|
self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
|
||||||
|
self.state_dict_converter = state_dict_converter
|
||||||
|
|
||||||
|
|
||||||
|
def on_step_end(self, loss):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def on_epoch_end(self, accelerator, model, epoch_id):
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
state_dict = accelerator.get_state_dict(model)
|
||||||
|
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
|
||||||
|
state_dict = self.state_dict_converter(state_dict)
|
||||||
|
os.makedirs(self.output_path, exist_ok=True)
|
||||||
|
path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
|
||||||
|
accelerator.save(state_dict, path, safe_serialization=True)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def launch_training_task(
|
||||||
|
dataset: torch.utils.data.Dataset,
|
||||||
|
model: DiffusionTrainingModule,
|
||||||
|
model_logger: ModelLogger,
|
||||||
|
optimizer: torch.optim.Optimizer,
|
||||||
|
scheduler: torch.optim.lr_scheduler.LRScheduler,
|
||||||
|
num_epochs: int = 1,
|
||||||
|
gradient_accumulation_steps: int = 1,
|
||||||
|
):
|
||||||
|
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0])
|
||||||
|
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
|
||||||
|
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
|
||||||
|
|
||||||
|
for epoch_id in range(num_epochs):
|
||||||
|
for data in tqdm(dataloader):
|
||||||
|
with accelerator.accumulate(model):
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss = model(data)
|
||||||
|
accelerator.backward(loss)
|
||||||
|
optimizer.step()
|
||||||
|
model_logger.on_step_end(loss)
|
||||||
|
scheduler.step()
|
||||||
|
model_logger.on_epoch_end(accelerator, model, epoch_id)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"):
|
||||||
|
dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0])
|
||||||
|
accelerator = Accelerator()
|
||||||
|
model, dataloader = accelerator.prepare(model, dataloader)
|
||||||
|
os.makedirs(os.path.join(output_path, "data_cache"), exist_ok=True)
|
||||||
|
for data_id, data in enumerate(tqdm(dataloader)):
|
||||||
|
with torch.no_grad():
|
||||||
|
inputs = model.forward_preprocess(data)
|
||||||
|
inputs = {key: inputs[key] for key in model.model_input_keys if key in inputs}
|
||||||
|
torch.save(inputs, os.path.join(output_path, "data_cache", f"{data_id}.pth"))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def wan_parser():
|
||||||
|
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||||
|
parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.")
|
||||||
|
parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.")
|
||||||
|
parser.add_argument("--max_pixels", type=int, default=1280*720, help="Maximum number of pixels per frame, used for dynamic resolution..")
|
||||||
|
parser.add_argument("--height", type=int, default=None, help="Height of images or videos. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||||
|
parser.add_argument("--width", type=int, default=None, help="Width of images or videos. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||||
|
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.")
|
||||||
|
parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in the metadata. Comma-separated.")
|
||||||
|
parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.")
|
||||||
|
parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.")
|
||||||
|
parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")
|
||||||
|
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
|
||||||
|
parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
|
||||||
|
parser.add_argument("--output_path", type=str, default="./models", help="Output save path.")
|
||||||
|
parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.")
|
||||||
|
parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.")
|
||||||
|
parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
|
||||||
|
parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
|
||||||
|
parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
|
||||||
|
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
|
||||||
|
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
|
||||||
|
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def flux_parser():
|
||||||
|
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||||
|
parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.")
|
||||||
|
parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.")
|
||||||
|
parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution..")
|
||||||
|
parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||||
|
parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||||
|
parser.add_argument("--data_file_keys", type=str, default="image", help="Data file keys in the metadata. Comma-separated.")
|
||||||
|
parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.")
|
||||||
|
parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.")
|
||||||
|
parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")
|
||||||
|
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
|
||||||
|
parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
|
||||||
|
parser.add_argument("--output_path", type=str, default="./models", help="Output save path.")
|
||||||
|
parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.")
|
||||||
|
parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.")
|
||||||
|
parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
|
||||||
|
parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
|
||||||
|
parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
|
||||||
|
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
|
||||||
|
parser.add_argument("--align_to_opensource_format", default=False, action="store_true", help="Whether to align the lora format to opensource format. Only for DiT's LoRA.")
|
||||||
|
parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
|
||||||
|
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
|
||||||
|
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
|
||||||
|
return parser
|
||||||
@@ -1 +1,2 @@
|
|||||||
from .layers import *
|
from .layers import *
|
||||||
|
from .gradient_checkpointing import *
|
||||||
|
|||||||
34
diffsynth/vram_management/gradient_checkpointing.py
Normal file
34
diffsynth/vram_management/gradient_checkpointing.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs, **kwargs):
|
||||||
|
return module(*inputs, **kwargs)
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
|
||||||
|
def gradient_checkpoint_forward(
|
||||||
|
model,
|
||||||
|
use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if use_gradient_checkpointing_offload:
|
||||||
|
with torch.autograd.graph.save_on_cpu():
|
||||||
|
model_output = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(model),
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
use_reentrant=False,
|
||||||
|
)
|
||||||
|
elif use_gradient_checkpointing:
|
||||||
|
model_output = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(model),
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
use_reentrant=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model_output = model(*args, **kwargs)
|
||||||
|
return model_output
|
||||||
@@ -8,8 +8,33 @@ def cast_to(weight, dtype, device):
|
|||||||
return r
|
return r
|
||||||
|
|
||||||
|
|
||||||
class AutoWrappedModule(torch.nn.Module):
|
class AutoTorchModule(torch.nn.Module):
|
||||||
def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def check_free_vram(self):
|
||||||
|
gpu_mem_state = torch.cuda.mem_get_info(self.computation_device)
|
||||||
|
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024 ** 3)
|
||||||
|
return used_memory < self.vram_limit
|
||||||
|
|
||||||
|
def offload(self):
|
||||||
|
if self.state != 0:
|
||||||
|
self.to(dtype=self.offload_dtype, device=self.offload_device)
|
||||||
|
self.state = 0
|
||||||
|
|
||||||
|
def onload(self):
|
||||||
|
if self.state != 1:
|
||||||
|
self.to(dtype=self.onload_dtype, device=self.onload_device)
|
||||||
|
self.state = 1
|
||||||
|
|
||||||
|
def keep(self):
|
||||||
|
if self.state != 2:
|
||||||
|
self.to(dtype=self.computation_dtype, device=self.computation_device)
|
||||||
|
self.state = 2
|
||||||
|
|
||||||
|
|
||||||
|
class AutoWrappedModule(AutoTorchModule):
|
||||||
|
def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.module = module.to(dtype=offload_dtype, device=offload_device)
|
self.module = module.to(dtype=offload_dtype, device=offload_device)
|
||||||
self.offload_dtype = offload_dtype
|
self.offload_dtype = offload_dtype
|
||||||
@@ -18,28 +43,57 @@ class AutoWrappedModule(torch.nn.Module):
|
|||||||
self.onload_device = onload_device
|
self.onload_device = onload_device
|
||||||
self.computation_dtype = computation_dtype
|
self.computation_dtype = computation_dtype
|
||||||
self.computation_device = computation_device
|
self.computation_device = computation_device
|
||||||
|
self.vram_limit = vram_limit
|
||||||
self.state = 0
|
self.state = 0
|
||||||
|
|
||||||
def offload(self):
|
|
||||||
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
|
||||||
self.module.to(dtype=self.offload_dtype, device=self.offload_device)
|
|
||||||
self.state = 0
|
|
||||||
|
|
||||||
def onload(self):
|
|
||||||
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
|
||||||
self.module.to(dtype=self.onload_dtype, device=self.onload_device)
|
|
||||||
self.state = 1
|
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
|
if self.state == 2:
|
||||||
module = self.module
|
module = self.module
|
||||||
else:
|
else:
|
||||||
module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
|
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
|
||||||
|
module = self.module
|
||||||
|
elif self.vram_limit is not None and self.check_free_vram():
|
||||||
|
self.keep()
|
||||||
|
module = self.module
|
||||||
|
else:
|
||||||
|
module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
|
||||||
return module(*args, **kwargs)
|
return module(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class AutoWrappedLinear(torch.nn.Linear):
|
class WanAutoCastLayerNorm(torch.nn.LayerNorm, AutoTorchModule):
|
||||||
def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
|
def __init__(self, module: torch.nn.LayerNorm, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit, **kwargs):
|
||||||
|
with init_weights_on_device(device=torch.device("meta")):
|
||||||
|
super().__init__(module.normalized_shape, eps=module.eps, elementwise_affine=module.elementwise_affine, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
|
||||||
|
self.weight = module.weight
|
||||||
|
self.bias = module.bias
|
||||||
|
self.offload_dtype = offload_dtype
|
||||||
|
self.offload_device = offload_device
|
||||||
|
self.onload_dtype = onload_dtype
|
||||||
|
self.onload_device = onload_device
|
||||||
|
self.computation_dtype = computation_dtype
|
||||||
|
self.computation_device = computation_device
|
||||||
|
self.vram_limit = vram_limit
|
||||||
|
self.state = 0
|
||||||
|
|
||||||
|
def forward(self, x, *args, **kwargs):
|
||||||
|
if self.state == 2:
|
||||||
|
weight, bias = self.weight, self.bias
|
||||||
|
else:
|
||||||
|
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
|
||||||
|
weight, bias = self.weight, self.bias
|
||||||
|
elif self.vram_limit is not None and self.check_free_vram():
|
||||||
|
self.keep()
|
||||||
|
weight, bias = self.weight, self.bias
|
||||||
|
else:
|
||||||
|
weight = None if self.weight is None else cast_to(self.weight, self.computation_dtype, self.computation_device)
|
||||||
|
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
|
||||||
|
with torch.amp.autocast(device_type=x.device.type):
|
||||||
|
x = torch.nn.functional.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps).type_as(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
|
||||||
|
def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit, name="", **kwargs):
|
||||||
with init_weights_on_device(device=torch.device("meta")):
|
with init_weights_on_device(device=torch.device("meta")):
|
||||||
super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
|
super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
|
||||||
self.weight = module.weight
|
self.weight = module.weight
|
||||||
@@ -50,29 +104,28 @@ class AutoWrappedLinear(torch.nn.Linear):
|
|||||||
self.onload_device = onload_device
|
self.onload_device = onload_device
|
||||||
self.computation_dtype = computation_dtype
|
self.computation_dtype = computation_dtype
|
||||||
self.computation_device = computation_device
|
self.computation_device = computation_device
|
||||||
|
self.vram_limit = vram_limit
|
||||||
self.state = 0
|
self.state = 0
|
||||||
|
self.name = name
|
||||||
def offload(self):
|
|
||||||
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
|
||||||
self.to(dtype=self.offload_dtype, device=self.offload_device)
|
|
||||||
self.state = 0
|
|
||||||
|
|
||||||
def onload(self):
|
|
||||||
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
|
||||||
self.to(dtype=self.onload_dtype, device=self.onload_device)
|
|
||||||
self.state = 1
|
|
||||||
|
|
||||||
def forward(self, x, *args, **kwargs):
|
def forward(self, x, *args, **kwargs):
|
||||||
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
|
if self.state == 2:
|
||||||
weight, bias = self.weight, self.bias
|
weight, bias = self.weight, self.bias
|
||||||
else:
|
else:
|
||||||
weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
|
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
|
||||||
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
|
weight, bias = self.weight, self.bias
|
||||||
|
elif self.vram_limit is not None and self.check_free_vram():
|
||||||
|
self.keep()
|
||||||
|
weight, bias = self.weight, self.bias
|
||||||
|
else:
|
||||||
|
weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
|
||||||
|
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
|
||||||
return torch.nn.functional.linear(x, weight, bias)
|
return torch.nn.functional.linear(x, weight, bias)
|
||||||
|
|
||||||
|
|
||||||
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0):
|
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0, vram_limit=None, name_prefix=""):
|
||||||
for name, module in model.named_children():
|
for name, module in model.named_children():
|
||||||
|
layer_name = name if name_prefix == "" else name_prefix + "." + name
|
||||||
for source_module, target_module in module_map.items():
|
for source_module, target_module in module_map.items():
|
||||||
if isinstance(module, source_module):
|
if isinstance(module, source_module):
|
||||||
num_param = sum(p.numel() for p in module.parameters())
|
num_param = sum(p.numel() for p in module.parameters())
|
||||||
@@ -80,16 +133,16 @@ def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict,
|
|||||||
module_config_ = overflow_module_config
|
module_config_ = overflow_module_config
|
||||||
else:
|
else:
|
||||||
module_config_ = module_config
|
module_config_ = module_config
|
||||||
module_ = target_module(module, **module_config_)
|
module_ = target_module(module, **module_config_, vram_limit=vram_limit, name=layer_name)
|
||||||
setattr(model, name, module_)
|
setattr(model, name, module_)
|
||||||
total_num_param += num_param
|
total_num_param += num_param
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param)
|
total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param, vram_limit=vram_limit, name_prefix=layer_name)
|
||||||
return total_num_param
|
return total_num_param
|
||||||
|
|
||||||
|
|
||||||
def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None):
|
def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, vram_limit=None):
|
||||||
enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0)
|
enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0, vram_limit=vram_limit)
|
||||||
model.vram_management_enabled = True
|
model.vram_management_enabled = True
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ We propose EliGen, a novel approach that leverages fine-grained entity-level inf
|
|||||||
|
|
||||||
* Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
|
* Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
|
||||||
* Github: [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio)
|
* Github: [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio)
|
||||||
* Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)
|
* Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
|
||||||
* Online Demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
|
* Online Demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
|
||||||
* Training Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
|
* Training Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
|
||||||
|
|
||||||
@@ -77,6 +77,11 @@ Demonstration of the styled entity control results with EliGen and IP-Adapter, s
|
|||||||
|-|-|-|-|
|
|-|-|-|-|
|
||||||
|||||
|
|||||
|
||||||
|
|
||||||
|
We also provide a demo of the styled entity control results with EliGen and specific styled lora, see [./styled_entity_control.py](./styled_entity_control.py) for details. Here is the visualization of EliGen with [Lego dreambooth lora](https://huggingface.co/merve/flux-lego-lora-dreambooth).
|
||||||
|
|||||
|
||||||
|
|-|-|-|-|
|
||||||
|
|||||
|
||||||
|
|
||||||
### Entity Transfer
|
### Entity Transfer
|
||||||
Demonstration of the entity transfer results with EliGen and In-Context LoRA, see [./entity_transfer.py](./entity_transfer.py) for generation prompts.
|
Demonstration of the entity transfer results with EliGen and In-Context LoRA, see [./entity_transfer.py](./entity_transfer.py) for generation prompts.
|
||||||
|
|
||||||
|
|||||||
@@ -27,11 +27,20 @@ def example(pipe, seeds, example_id, global_prompt, entity_prompts):
|
|||||||
|
|
||||||
# download and load model
|
# download and load model
|
||||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
|
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
|
||||||
|
# set download_from_modelscope = False if you want to download model from huggingface
|
||||||
|
download_from_modelscope = True
|
||||||
|
if download_from_modelscope:
|
||||||
|
model_id = "DiffSynth-Studio/Eligen"
|
||||||
|
downloading_priority = ["ModelScope"]
|
||||||
|
else:
|
||||||
|
model_id = "modelscope/EliGen"
|
||||||
|
downloading_priority = ["HuggingFace"]
|
||||||
model_manager.load_lora(
|
model_manager.load_lora(
|
||||||
download_customized_models(
|
download_customized_models(
|
||||||
model_id="DiffSynth-Studio/Eligen",
|
model_id=model_id,
|
||||||
origin_file_path="model_bf16.safetensors",
|
origin_file_path="model_bf16.safetensors",
|
||||||
local_dir="models/lora/entity_control"
|
local_dir="models/lora/entity_control",
|
||||||
|
downloading_priority=downloading_priority
|
||||||
),
|
),
|
||||||
lora_alpha=1
|
lora_alpha=1
|
||||||
)
|
)
|
||||||
|
|||||||
90
examples/EntityControl/styled_entity_control.py
Normal file
90
examples/EntityControl/styled_entity_control.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
|
from examples.EntityControl.utils import visualize_masks
|
||||||
|
from PIL import Image
|
||||||
|
import torch
|
||||||
|
|
||||||
|
def example(pipe, seeds, example_id, global_prompt, entity_prompts):
|
||||||
|
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/entity_control/example_{example_id}/*.png")
|
||||||
|
masks = [Image.open(f"./data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
|
||||||
|
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
|
||||||
|
for seed in seeds:
|
||||||
|
# generate image
|
||||||
|
image = pipe(
|
||||||
|
prompt=global_prompt,
|
||||||
|
cfg_scale=3.0,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
num_inference_steps=50,
|
||||||
|
embedded_guidance=3.5,
|
||||||
|
seed=seed,
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
eligen_entity_prompts=entity_prompts,
|
||||||
|
eligen_entity_masks=masks,
|
||||||
|
)
|
||||||
|
image.save(f"styled_eligen_example_{example_id}_{seed}.png")
|
||||||
|
visualize_masks(image, masks, entity_prompts, f"styled_entity_control_example_{example_id}_mask_{seed}.png")
|
||||||
|
|
||||||
|
# download and load model
|
||||||
|
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
|
||||||
|
model_manager.load_lora(
|
||||||
|
download_customized_models(
|
||||||
|
model_id="FluxLora/merve-flux-lego-lora-dreambooth",
|
||||||
|
origin_file_path="pytorch_lora_weights.safetensors",
|
||||||
|
local_dir="models/lora/merve-flux-lego-lora-dreambooth"
|
||||||
|
),
|
||||||
|
lora_alpha=1
|
||||||
|
)
|
||||||
|
model_manager.load_lora(
|
||||||
|
download_customized_models(
|
||||||
|
model_id="DiffSynth-Studio/Eligen",
|
||||||
|
origin_file_path="model_bf16.safetensors",
|
||||||
|
local_dir="models/lora/entity_control"
|
||||||
|
),
|
||||||
|
lora_alpha=1
|
||||||
|
)
|
||||||
|
pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||||
|
|
||||||
|
# example 1
|
||||||
|
trigger_word = "lego set in style of TOK, "
|
||||||
|
global_prompt = "A breathtaking beauty of Raja Ampat by the late-night moonlight , one beautiful woman from behind wearing a pale blue long dress with soft glow, sitting at the top of a cliff looking towards the beach,pastell light colors, a group of small distant birds flying in far sky, a boat sailing on the sea, best quality, realistic, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, maximalist style, photorealistic, concept art, sharp focus, harmony, serenity, tranquility, soft pastell colors,ambient occlusion, cozy ambient lighting, masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning, view from above\n"
|
||||||
|
global_prompt = trigger_word + global_prompt
|
||||||
|
entity_prompts = ["cliff", "sea", "moon", "sailing boat", "a seated beautiful woman", "pale blue long dress with soft glow"]
|
||||||
|
example(pipe, [0], 1, global_prompt, entity_prompts)
|
||||||
|
|
||||||
|
# example 2
|
||||||
|
global_prompt = "samurai girl wearing a kimono, she's holding a sword glowing with red flame, her long hair is flowing in the wind, she is looking at a small bird perched on the back of her hand. ultra realist style. maximum image detail. maximum realistic render."
|
||||||
|
global_prompt = trigger_word + global_prompt
|
||||||
|
entity_prompts = ["flowing hair", "sword glowing with red flame", "A cute bird", "blue belt"]
|
||||||
|
example(pipe, [0], 2, global_prompt, entity_prompts)
|
||||||
|
|
||||||
|
# example 3
|
||||||
|
global_prompt = "Image of a neverending staircase up to a mysterious palace in the sky, The ancient palace stood majestically atop a mist-shrouded mountain, sunrise, two traditional monk walk in the stair looking at the sunrise, fog,see-through, best quality, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, photorealistic, concept art, harmony, serenity, tranquility, ambient occlusion, halation, cozy ambient lighting, dynamic lighting,masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning,"
|
||||||
|
global_prompt = trigger_word + global_prompt
|
||||||
|
entity_prompts = ["ancient palace", "stone staircase with railings", "a traditional monk", "a traditional monk"]
|
||||||
|
example(pipe, [27], 3, global_prompt, entity_prompts)
|
||||||
|
|
||||||
|
# example 4
|
||||||
|
global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'"
|
||||||
|
global_prompt = trigger_word + global_prompt
|
||||||
|
entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"]
|
||||||
|
example(pipe, [21], 4, global_prompt, entity_prompts)
|
||||||
|
|
||||||
|
# example 5
|
||||||
|
global_prompt = "A captivating, dramatic scene in a painting that exudes mystery and foreboding. A white sky, swirling blue clouds, and a crescent yellow moon illuminate a solitary woman standing near the water's edge. Her long dress flows in the wind, silhouetted against the eerie glow. The water mirrors the fiery sky and moonlight, amplifying the uneasy atmosphere."
|
||||||
|
global_prompt = trigger_word + global_prompt
|
||||||
|
entity_prompts = ["crescent yellow moon", "a solitary woman", "water", "swirling blue clouds"]
|
||||||
|
example(pipe, [0], 5, global_prompt, entity_prompts)
|
||||||
|
|
||||||
|
# example 6
|
||||||
|
global_prompt = "Snow White and the 6 Dwarfs."
|
||||||
|
global_prompt = trigger_word + global_prompt
|
||||||
|
entity_prompts = ["Dwarf 1", "Dwarf 2", "Dwarf 3", "Snow White", "Dwarf 4", "Dwarf 5", "Dwarf 6"]
|
||||||
|
example(pipe, [8], 6, global_prompt, entity_prompts)
|
||||||
|
|
||||||
|
# example 7, same prompt with different seeds
|
||||||
|
seeds = range(5, 9)
|
||||||
|
global_prompt = "A beautiful woman wearing white dress, holding a mirror, with a warm light background;"
|
||||||
|
global_prompt = trigger_word + global_prompt
|
||||||
|
entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"]
|
||||||
|
example(pipe, seeds, 7, global_prompt, entity_prompts)
|
||||||
@@ -8,6 +8,12 @@
|
|||||||
|24G|[hunyuanvideo_24G.py](hunyuanvideo_24G.py)|129|720*1280|The video is consistent with the original implementation, but it requires 5%~10% more time than [hunyuanvideo_80G.py](hunyuanvideo_80G.py)|
|
|24G|[hunyuanvideo_24G.py](hunyuanvideo_24G.py)|129|720*1280|The video is consistent with the original implementation, but it requires 5%~10% more time than [hunyuanvideo_80G.py](hunyuanvideo_80G.py)|
|
||||||
|6G|[hunyuanvideo_6G.py](hunyuanvideo_6G.py)|129|512*384|The base model doesn't support low resolutions. We recommend users to use some LoRA ([example](https://civitai.com/models/1032126/walking-animation-hunyuan-video)) trained using low resolutions.|
|
|6G|[hunyuanvideo_6G.py](hunyuanvideo_6G.py)|129|512*384|The base model doesn't support low resolutions. We recommend users to use some LoRA ([example](https://civitai.com/models/1032126/walking-animation-hunyuan-video)) trained using low resolutions.|
|
||||||
|
|
||||||
|
[HunyuanVideo-I2V](https://github.com/Tencent/HunyuanVideo-I2V) is the image-to-video generation version of HunyuanVideo. We also provide advanced VRAM management for this model.
|
||||||
|
|VRAM required|Example script|Frames|Resolution|Note|
|
||||||
|
|-|-|-|-|-|
|
||||||
|
|80G|[hunyuanvideo_i2v_80G.py](hunyuanvideo_i2v_80G.py)|129|720p|No VRAM management.|
|
||||||
|
|24G|[hunyuanvideo_i2v_24G.py](hunyuanvideo_i2v_24G.py)|129|720p|The video is consistent with the original implementation, but it requires 5%~10% more time than [hunyuanvideo_80G.py](hunyuanvideo_80G.py)|
|
||||||
|
|
||||||
## Gallery
|
## Gallery
|
||||||
|
|
||||||
Video generated by [hunyuanvideo_80G.py](hunyuanvideo_80G.py) and [hunyuanvideo_24G.py](hunyuanvideo_24G.py):
|
Video generated by [hunyuanvideo_80G.py](hunyuanvideo_80G.py) and [hunyuanvideo_24G.py](hunyuanvideo_24G.py):
|
||||||
@@ -21,3 +27,7 @@ https://github.com/user-attachments/assets/2997f107-d02d-4ecb-89bb-5ce1a7f93817
|
|||||||
Video to video generated by [hunyuanvideo_v2v_6G.py](./hunyuanvideo_v2v_6G.py) using [this LoRA](https://civitai.com/models/1032126/walking-animation-hunyuan-video):
|
Video to video generated by [hunyuanvideo_v2v_6G.py](./hunyuanvideo_v2v_6G.py) using [this LoRA](https://civitai.com/models/1032126/walking-animation-hunyuan-video):
|
||||||
|
|
||||||
https://github.com/user-attachments/assets/4b89e52e-ce42-434e-aa57-08f09dfa2b10
|
https://github.com/user-attachments/assets/4b89e52e-ce42-434e-aa57-08f09dfa2b10
|
||||||
|
|
||||||
|
Video generated by [hunyuanvideo_i2v_80G.py](hunyuanvideo_i2v_80G.py) and [hunyuanvideo_i2v_24G.py](hunyuanvideo_i2v_24G.py):
|
||||||
|
|
||||||
|
https://github.com/user-attachments/assets/494f252a-c9af-440d-84ba-a8ddcdcc538a
|
||||||
|
|||||||
43
examples/HunyuanVideo/hunyuanvideo_i2v_24G.py
Normal file
43
examples/HunyuanVideo/hunyuanvideo_i2v_24G.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
download_models(["HunyuanVideoI2V"])
|
||||||
|
model_manager = ModelManager()
|
||||||
|
|
||||||
|
# The DiT model is loaded in bfloat16.
|
||||||
|
model_manager.load_models(
|
||||||
|
[
|
||||||
|
"models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
|
||||||
|
],
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cpu"
|
||||||
|
)
|
||||||
|
|
||||||
|
# The other modules are loaded in float16.
|
||||||
|
model_manager.load_models(
|
||||||
|
[
|
||||||
|
"models/HunyuanVideoI2V/text_encoder/model.safetensors",
|
||||||
|
"models/HunyuanVideoI2V/text_encoder_2",
|
||||||
|
'models/HunyuanVideoI2V/vae/pytorch_model.pt'
|
||||||
|
],
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
device="cpu"
|
||||||
|
)
|
||||||
|
# The computation device is "cuda".
|
||||||
|
pipe = HunyuanVideoPipeline.from_model_manager(model_manager,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
enable_vram_management=True)
|
||||||
|
|
||||||
|
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||||
|
local_dir="./",
|
||||||
|
allow_file_pattern=f"data/examples/hunyuanvideo/*")
|
||||||
|
|
||||||
|
i2v_resolution = "720p"
|
||||||
|
prompt = "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick."
|
||||||
|
images = [Image.open("data/examples/hunyuanvideo/0.jpg").convert('RGB')]
|
||||||
|
video = pipe(prompt, input_images=images, num_inference_steps=50, seed=0, i2v_resolution=i2v_resolution)
|
||||||
|
save_video(video, f"video_{i2v_resolution}_low_vram.mp4", fps=30, quality=6)
|
||||||
45
examples/HunyuanVideo/hunyuanvideo_i2v_80G.py
Normal file
45
examples/HunyuanVideo/hunyuanvideo_i2v_80G.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
download_models(["HunyuanVideoI2V"])
|
||||||
|
model_manager = ModelManager()
|
||||||
|
|
||||||
|
# The DiT model is loaded in bfloat16.
|
||||||
|
model_manager.load_models(
|
||||||
|
[
|
||||||
|
"models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
|
||||||
|
],
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda"
|
||||||
|
)
|
||||||
|
|
||||||
|
# The other modules are loaded in float16.
|
||||||
|
model_manager.load_models(
|
||||||
|
[
|
||||||
|
"models/HunyuanVideoI2V/text_encoder/model.safetensors",
|
||||||
|
"models/HunyuanVideoI2V/text_encoder_2",
|
||||||
|
'models/HunyuanVideoI2V/vae/pytorch_model.pt'
|
||||||
|
],
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
device="cuda"
|
||||||
|
)
|
||||||
|
# The computation device is "cuda".
|
||||||
|
pipe = HunyuanVideoPipeline.from_model_manager(model_manager,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
enable_vram_management=False)
|
||||||
|
# Although you have enough VRAM, we still recommend you to enable offload.
|
||||||
|
pipe.enable_cpu_offload()
|
||||||
|
|
||||||
|
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||||
|
local_dir="./",
|
||||||
|
allow_file_pattern=f"data/examples/hunyuanvideo/*")
|
||||||
|
|
||||||
|
i2v_resolution = "720p"
|
||||||
|
prompt = "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick."
|
||||||
|
images = [Image.open("data/examples/hunyuanvideo/0.jpg").convert('RGB')]
|
||||||
|
video = pipe(prompt, input_images=images, num_inference_steps=50, seed=0, i2v_resolution=i2v_resolution)
|
||||||
|
save_video(video, f"video_{i2v_resolution}.mp4", fps=30, quality=6)
|
||||||
7
examples/InfiniteYou/README.md
Normal file
7
examples/InfiniteYou/README.md
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# InfiniteYou: Flexible Photo Recrafting While Preserving Your Identity
|
||||||
|
We support the identity preserving feature of InfiniteYou. See [./infiniteyou.py](./infiniteyou.py) for example. The visualization of the result is shown below.
|
||||||
|
|
||||||
|
|Identity Image|Generated Image|
|
||||||
|
|-|-|
|
||||||
|
|||
|
||||||
|
|||
|
||||||
58
examples/InfiniteYou/infiniteyou.py
Normal file
58
examples/InfiniteYou/infiniteyou.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
import importlib
|
||||||
|
import torch
|
||||||
|
from diffsynth import ModelManager, FluxImagePipeline, download_models, ControlNetConfigUnit
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
if importlib.util.find_spec("facexlib") is None:
|
||||||
|
raise ImportError("You are using InifiniteYou. It depends on facexlib, which is not installed. Please install it with `pip install facexlib`.")
|
||||||
|
if importlib.util.find_spec("insightface") is None:
|
||||||
|
raise ImportError("You are using InifiniteYou. It depends on insightface, which is not installed. Please install it with `pip install insightface`.")
|
||||||
|
|
||||||
|
download_models(["InfiniteYou"])
|
||||||
|
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
|
||||||
|
model_manager.load_models([
|
||||||
|
[
|
||||||
|
"models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors",
|
||||||
|
"models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors"
|
||||||
|
],
|
||||||
|
"models/InfiniteYou/image_proj_model.bin",
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
pipe = FluxImagePipeline.from_model_manager(
|
||||||
|
model_manager,
|
||||||
|
controlnet_config_units=[
|
||||||
|
ControlNetConfigUnit(
|
||||||
|
processor_id="none",
|
||||||
|
model_path=[
|
||||||
|
'models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors',
|
||||||
|
'models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors'
|
||||||
|
],
|
||||||
|
scale=1.0
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/infiniteyou/*")
|
||||||
|
|
||||||
|
prompt = "A man, portrait, cinematic"
|
||||||
|
id_image = "data/examples/infiniteyou/man.jpg"
|
||||||
|
id_image = Image.open(id_image).convert('RGB')
|
||||||
|
image = pipe(
|
||||||
|
prompt=prompt, seed=1,
|
||||||
|
infinityou_id_image=id_image, infinityou_guidance=1.0,
|
||||||
|
num_inference_steps=50, embedded_guidance=3.5,
|
||||||
|
height=1024, width=1024,
|
||||||
|
)
|
||||||
|
image.save("man.jpg")
|
||||||
|
|
||||||
|
prompt = "A woman, portrait, cinematic"
|
||||||
|
id_image = "data/examples/infiniteyou/woman.jpg"
|
||||||
|
id_image = Image.open(id_image).convert('RGB')
|
||||||
|
image = pipe(
|
||||||
|
prompt=prompt, seed=1,
|
||||||
|
infinityou_id_image=id_image, infinityou_guidance=1.0,
|
||||||
|
num_inference_steps=50, embedded_guidance=3.5,
|
||||||
|
height=1024, width=1024,
|
||||||
|
)
|
||||||
|
image.save("woman.jpg")
|
||||||
@@ -16,14 +16,14 @@ The IP-Adapter model based on Stable Diffusion XL is more powerful. You have the
|
|||||||
|
|
||||||
* Content controlling (original usage of IP-Adapter)
|
* Content controlling (original usage of IP-Adapter)
|
||||||
|
|
||||||
|First, we generate a rabbit.|Next, enable IP-Adapter and let the rabbit jump.|For comparision, disable IP-Adapter to see the generated image.|
|
|First, we generate a rabbit.|Next, enable IP-Adapter and let the rabbit jump.|For comparison, disable IP-Adapter to see the generated image.|
|
||||||
|-|-|-|
|
|-|-|-|
|
||||||
||||
|
||||
|
||||||
|
|
||||||
|
|
||||||
* Style controlling (InstantStyle)
|
* Style controlling (InstantStyle)
|
||||||
|
|
||||||
|First, we generate a rabbit.|Next, enable InstantStyle and convert the rabbit to a cat.|For comparision, disable IP-Adapter to see the generated image.|
|
|First, we generate a rabbit.|Next, enable InstantStyle and convert the rabbit to a cat.|For comparison, disable IP-Adapter to see the generated image.|
|
||||||
|-|-|-|
|
|-|-|-|
|
||||||
||||
|
||||
|
||||||
|
|
||||||
|
|||||||
@@ -1,18 +0,0 @@
|
|||||||
import torch
|
|
||||||
from diffsynth.prompters import WanXPrompter
|
|
||||||
from diffsynth.models.wanx_text_encoder import WanXTextEncoder
|
|
||||||
|
|
||||||
prompter = WanXPrompter('models/WanX/google/umt5-xxl')
|
|
||||||
text_encoder = WanXTextEncoder()
|
|
||||||
text_encoder.load_state_dict(torch.load('models/WanX/models_t5_umt5-xxl-enc-bf16.pth', map_location='cpu'))
|
|
||||||
text_encoder = text_encoder.eval().requires_grad_(False).to(dtype=torch.bfloat16, device='cuda')
|
|
||||||
|
|
||||||
prompter.fetch_models(text_encoder)
|
|
||||||
|
|
||||||
prompt = '维京战士双手挥舞着大斧,对抗猛犸象,黄昏,雪地中,漫天飞雪'
|
|
||||||
neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
|
|
||||||
|
|
||||||
prompt_emb = prompter.encode_prompt(prompt)
|
|
||||||
neg_prompt_emb = prompter.encode_prompt(neg_prompt)
|
|
||||||
print(prompt_emb[0]) # torch.Size([31, 4096])
|
|
||||||
print(neg_prompt_emb[0]) # torch.Size([126, 4096])
|
|
||||||
@@ -1,46 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torchvision
|
|
||||||
import imageio
|
|
||||||
from diffsynth import ModelManager
|
|
||||||
|
|
||||||
def save_video(tensor,
|
|
||||||
save_file=None,
|
|
||||||
fps=30,
|
|
||||||
nrow=8,
|
|
||||||
normalize=True,
|
|
||||||
value_range=(-1, 1)):
|
|
||||||
|
|
||||||
tensor = tensor.clamp(min(value_range), max(value_range))
|
|
||||||
tensor = torch.stack([
|
|
||||||
torchvision.utils.make_grid(
|
|
||||||
u, nrow=nrow, normalize=normalize, value_range=value_range)
|
|
||||||
for u in tensor.unbind(2)
|
|
||||||
],
|
|
||||||
dim=1).permute(1, 2, 3, 0) #frame, h, w, 3
|
|
||||||
tensor = (tensor * 255).type(torch.uint8).cpu()
|
|
||||||
|
|
||||||
# write video
|
|
||||||
writer = imageio.get_writer(
|
|
||||||
save_file, fps=fps, codec='libx264', quality=8)
|
|
||||||
for frame in tensor.numpy():
|
|
||||||
writer.append_data(frame)
|
|
||||||
writer.close()
|
|
||||||
|
|
||||||
torch.cuda.memory._record_memory_history()
|
|
||||||
|
|
||||||
model_manager = ModelManager(torch_dtype=torch.float, device="cuda")
|
|
||||||
model_manager.load_models([
|
|
||||||
"models/WanX/vae.pth",
|
|
||||||
])
|
|
||||||
|
|
||||||
vae = model_manager.fetch_model('wanxvideo_vae')
|
|
||||||
|
|
||||||
latents = [torch.load('sample.pt')]
|
|
||||||
videos = vae.decode(latents, device=latents[0].device, tiled=True)
|
|
||||||
back_encode = vae.encode(videos, device=latents[0].device, tiled=True)
|
|
||||||
|
|
||||||
videos_back_encode = vae.decode(back_encode, device=latents[0].device, tiled=False)
|
|
||||||
torch.cuda.memory._dump_snapshot("my_snapshot.pickle")
|
|
||||||
|
|
||||||
save_video(videos[0][None], save_file='example.mp4', fps=16, nrow=1)
|
|
||||||
save_video(videos_back_encode[0][None], save_file='example_backencode.mp4', fps=16, nrow=1)
|
|
||||||
318
examples/flux/README.md
Normal file
318
examples/flux/README.md
Normal file
@@ -0,0 +1,318 @@
|
|||||||
|
# FLUX
|
||||||
|
|
||||||
|
[切换到中文](./README_zh.md)
|
||||||
|
|
||||||
|
FLUX is a series of image generation models open-sourced by Black-Forest-Labs.
|
||||||
|
|
||||||
|
**DiffSynth-Studio has introduced a new inference and training framework. If you need to use the old version, please click [here](https://github.com/modelscope/DiffSynth-Studio/tree/3edf3583b1f08944cee837b94d9f84d669c2729c).**
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
Before using these models, please install DiffSynth-Studio from source code:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||||
|
cd DiffSynth-Studio
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
You can quickly load the FLUX.1-dev model and perform inference by running the following code:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||||
|
|
||||||
|
pipe = FluxImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
image = pipe(prompt="a cat", seed=0)
|
||||||
|
image.save("image.jpg")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Model Overview
|
||||||
|
|
||||||
|
**Support for the new framework of the FLUX series models is under active development. Stay tuned!**
|
||||||
|
|
||||||
|
| Model ID | Additional Parameters | Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
|
||||||
|
|-|-|-|-|-|-|-|
|
||||||
|
|[black-forest-labs/FLUX.1-dev](https://modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](./model_inference/FLUX.1-dev.py)|[code](./model_training/full/FLUX.1-dev.sh)|[code](./model_training/validate_full/FLUX.1-dev.py)|[code](./model_training/lora/FLUX.1-dev.sh)|[code](./model_training/validate_lora/FLUX.1-dev.py)|
|
||||||
|
|[black-forest-labs/FLUX.1-Kontext-dev](https://modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](./model_inference/FLUX.1-Kontext-dev.py)|[code](./model_training/full/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](./model_training/lora/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_lora/FLUX.1-Kontext-dev.py)|
|
||||||
|
|
||||||
|
## Model Inference
|
||||||
|
|
||||||
|
The following sections will help you understand our features and write inference code.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Loading Models</summary>
|
||||||
|
|
||||||
|
Models are loaded using `from_pretrained`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
pipe = FluxImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Here, `torch_dtype` and `device` refer to the computation precision and device, respectively. The `model_configs` can be configured in various ways to specify model paths:
|
||||||
|
|
||||||
|
* Download the model from [ModelScope Community](https://modelscope.cn/) and load it. In this case, provide `model_id` and `origin_file_pattern`, for example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors")
|
||||||
|
```
|
||||||
|
|
||||||
|
* Load the model from a local file path. In this case, provide the `path`, for example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors")
|
||||||
|
```
|
||||||
|
|
||||||
|
For models that consist of multiple files, use a list as follows:
|
||||||
|
|
||||||
|
```python
|
||||||
|
ModelConfig(path=[
|
||||||
|
"models/xxx/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||||
|
"models/xxx/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||||
|
"models/xxx/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||||
|
])
|
||||||
|
```
|
||||||
|
|
||||||
|
The `from_pretrained` method also provides additional parameters to control model loading behavior:
|
||||||
|
|
||||||
|
* `local_model_path`: Path for saving downloaded models. The default is `"./models"`.
|
||||||
|
* `skip_download`: Whether to skip downloading models. The default is `False`. If your network cannot access [ModelScope Community](https://modelscope.cn/), manually download the required files and set this to `True`.
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>VRAM Management</summary>
|
||||||
|
|
||||||
|
DiffSynth-Studio provides fine-grained VRAM management for FLUX models, enabling inference on devices with limited VRAM. You can enable offloading functionality via the following code, which moves certain modules to system memory on devices with limited GPU memory.
|
||||||
|
|
||||||
|
```python
|
||||||
|
pipe = FluxImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
pipe.enable_vram_management()
|
||||||
|
```
|
||||||
|
|
||||||
|
The `enable_vram_management` function provides the following parameters to control VRAM usage:
|
||||||
|
|
||||||
|
* `vram_limit`: VRAM usage limit in GB. By default, it uses the remaining VRAM available on the device. Note that this is not an absolute limit; if the set VRAM is insufficient but more VRAM is actually available, the model will run with minimal VRAM consumption. Setting it to 0 achieves the theoretical minimum VRAM usage.
|
||||||
|
* `vram_buffer`: VRAM buffer size in GB. The default is 0.5GB. Since some large neural network layers may consume extra VRAM during onload phases, a VRAM buffer is necessary. Ideally, the optimal value should match the VRAM occupied by the largest layer in the model.
|
||||||
|
* `num_persistent_param_in_dit`: Number of persistent parameters in the DiT model (default: no limit). We plan to remove this parameter in the future, so please avoid relying on it.
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Inference Acceleration</summary>
|
||||||
|
|
||||||
|
* TeaCache: Acceleration technique [TeaCache](https://github.com/ali-vilab/TeaCache), please refer to the [sample code](./acceleration/teacache.py).
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Input Parameters</summary>
|
||||||
|
|
||||||
|
The pipeline accepts the following input parameters during inference:
|
||||||
|
|
||||||
|
* `prompt`: Prompt describing what should appear in the image.
|
||||||
|
* `negative_prompt`: Negative prompt describing what should **not** appear in the image. Default is `""`.
|
||||||
|
* `cfg_scale`: Classifier-free guidance scale. Default is 1. It becomes effective when set to a value greater than 1.
|
||||||
|
* `embedded_guidance`: Embedded guidance parameter for FLUX-dev. Default is 3.5.
|
||||||
|
* `t5_sequence_length`: Sequence length of T5 text embeddings. Default is 512.
|
||||||
|
* `input_image`: Input image used for image-to-image generation. This works together with `denoising_strength`.
|
||||||
|
* `denoising_strength`: Denoising strength, ranging from 0 to 1. Default is 1. When close to 0, the generated image will be similar to the input image; when close to 1, the generated image will differ significantly from the input. Do not set this to a non-1 value if no `input_image` is provided.
|
||||||
|
* `height`: Height of the generated image. Must be a multiple of 16.
|
||||||
|
* `width`: Width of the generated image. Must be a multiple of 16.
|
||||||
|
* `seed`: Random seed. Default is `None`, meaning completely random.
|
||||||
|
* `rand_device`: Device for generating random Gaussian noise. Default is `"cpu"`. Setting it to `"cuda"` may lead to different results across GPUs.
|
||||||
|
* `sigma_shift`: Parameter from Rectified Flow theory. Default is 3. A larger value increases the number of steps spent at the beginning of denoising and can improve image quality. However, it may cause inconsistencies between the generation process and training data.
|
||||||
|
* `num_inference_steps`: Number of inference steps. Default is 30.
|
||||||
|
* `kontext_images`: Input images for the Kontext model.
|
||||||
|
* `controlnet_inputs`: Inputs for the ControlNet model.
|
||||||
|
* `ipadapter_images`: Input images for the IP-Adapter model.
|
||||||
|
* `ipadapter_scale`: Control strength of the IP-Adapter model.
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## Model Training
|
||||||
|
|
||||||
|
FLUX series models are trained using a unified script [`./model_training/train.py`](./model_training/train.py).
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Script Parameters</summary>
|
||||||
|
|
||||||
|
The script supports the following parameters:
|
||||||
|
|
||||||
|
* Dataset
|
||||||
|
* `--dataset_base_path`: Root path to the dataset.
|
||||||
|
* `--dataset_metadata_path`: Path to the metadata file of the dataset.
|
||||||
|
* `--max_pixels`: Maximum pixel area, default is 1024*1024. When dynamic resolution is enabled, any image with a resolution larger than this value will be scaled down.。
|
||||||
|
* `--height`: Height of images or videos. Leave `height` and `width` empty to enable dynamic resolution.
|
||||||
|
* `--width`: Width of images or videos. Leave `height` and `width` empty to enable dynamic resolution.
|
||||||
|
* `--data_file_keys`: Keys in metadata for data files. Comma-separated.
|
||||||
|
* `--dataset_repeat`: Number of times the dataset repeats per epoch.
|
||||||
|
* Models
|
||||||
|
* `--model_paths`: Paths to load models. JSON format.
|
||||||
|
* `--model_id_with_origin_paths`: Model IDs with original paths, e.g., black-forest-labs/FLUX.1-dev:flux1-dev.safetensors. Comma-separated.
|
||||||
|
* Training
|
||||||
|
* `--learning_rate`: Learning rate.
|
||||||
|
* `--num_epochs`: Number of training epochs.
|
||||||
|
* `--output_path`: Output path for saving checkpoints.
|
||||||
|
* `--remove_prefix_in_ckpt`: Remove prefix in checkpoint filenames.
|
||||||
|
* Trainable Modules
|
||||||
|
* `--trainable_models`: Models that can be trained, e.g., dit, vae, text_encoder.
|
||||||
|
* `--lora_base_model`: Which base model to apply LoRA on.
|
||||||
|
* `--lora_target_modules`: Which layers to apply LoRA on.
|
||||||
|
* `--lora_rank`: Rank of LoRA.
|
||||||
|
* Extra Inputs
|
||||||
|
* `--extra_inputs`: Additional model inputs. Comma-separated.
|
||||||
|
* VRAM Management
|
||||||
|
* `--use_gradient_checkpointing`: Whether to use gradient checkpointing.
|
||||||
|
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
|
||||||
|
* `--gradient_accumulation_steps`: Number of steps for gradient accumulation.
|
||||||
|
* Miscellaneous
|
||||||
|
* `--align_to_opensource_format`: Whether to align the FLUX DiT LoRA format with the open-source version. Only applicable to LoRA training for FLUX.1-dev and FLUX.1-Kontext-dev.
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Step 1: Prepare Dataset</summary>
|
||||||
|
|
||||||
|
The dataset contains a series of files. We recommend organizing your dataset files as follows:
|
||||||
|
|
||||||
|
```
|
||||||
|
data/example_image_dataset/
|
||||||
|
├── metadata.csv
|
||||||
|
├── image1.jpg
|
||||||
|
└── image2.jpg
|
||||||
|
```
|
||||||
|
|
||||||
|
Here, `image1.jpg`, `image2.jpg` are training image data, and `metadata.csv` is the metadata list, for example:
|
||||||
|
|
||||||
|
```
|
||||||
|
image,prompt
|
||||||
|
image1.jpg,"a cat is sleeping"
|
||||||
|
image2.jpg,"a dog is running"
|
||||||
|
```
|
||||||
|
|
||||||
|
We have built a sample image dataset to help you test more conveniently. You can download this dataset using the following command:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
|
||||||
|
```
|
||||||
|
|
||||||
|
The dataset supports multiple image formats: `"jpg", "jpeg", "png", "webp"`.
|
||||||
|
|
||||||
|
The image resolution can be controlled via script parameters `--height` and `--width`. When both `--height` and `--width` are left empty, dynamic resolution will be enabled, allowing training with the actual width and height of each image in the dataset.
|
||||||
|
|
||||||
|
**We strongly recommend using fixed-resolution training, as there may be load-balancing issues in multi-GPU training with dynamic resolution.**
|
||||||
|
|
||||||
|
When the model requires additional inputs—for instance, `kontext_images` required by the controllable model [`black-forest-labs/FLUX.1-Kontext-dev`](https://modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)—please add corresponding columns in the dataset, for example:
|
||||||
|
|
||||||
|
```
|
||||||
|
image,prompt,kontext_images
|
||||||
|
image1.jpg,"a cat is sleeping",image1_reference.jpg
|
||||||
|
```
|
||||||
|
|
||||||
|
If additional inputs include image files, you need to specify the column names to parse using the `--data_file_keys` parameter. You can add more column names accordingly, e.g., `--data_file_keys "image,kontext_images"`.
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Step 2: Load Model</summary>
|
||||||
|
|
||||||
|
Similar to the model loading logic during inference, you can directly configure the model to be loaded using its model ID. For example, during inference we load the model with the following configuration:
|
||||||
|
|
||||||
|
```python
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
Then during training, simply provide the following parameter to load the corresponding model:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors"
|
||||||
|
```
|
||||||
|
|
||||||
|
If you prefer to load the model from local files, as in the inference example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors"),
|
||||||
|
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/text_encoder/model.safetensors"),
|
||||||
|
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/text_encoder_2/"),
|
||||||
|
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/ae.safetensors"),
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
Then during training, set it up as follows:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
--model_paths '[
|
||||||
|
"models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors",
|
||||||
|
"models/black-forest-labs/FLUX.1-dev/text_encoder/model.safetensors",
|
||||||
|
"models/black-forest-labs/FLUX.1-dev/text_encoder_2/",
|
||||||
|
"models/black-forest-labs/FLUX.1-dev/ae.safetensors"
|
||||||
|
]' \
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Step 3: Configure Trainable Modules</summary>
|
||||||
|
|
||||||
|
The training framework supports both full-model training and LoRA-based fine-tuning. Below are some examples:
|
||||||
|
|
||||||
|
* Full training of the DiT module: `--trainable_models dit`
|
||||||
|
* Training a LoRA model on the DiT module: `--lora_base_model dit --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" --lora_rank 32`
|
||||||
|
|
||||||
|
Additionally, since the training script loads multiple modules (text encoder, DiT, VAE), you need to remove prefixes when saving the model files. For example, when performing full DiT training or LoRA training on the DiT module, please set `--remove_prefix_in_ckpt pipe.dit.`
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Step 4: Launch the Training Script</summary>
|
||||||
|
|
||||||
|
We have written specific training commands for each model. Please refer to the table at the beginning of this document for details.
|
||||||
|
|
||||||
|
</details>
|
||||||
327
examples/flux/README_zh.md
Normal file
327
examples/flux/README_zh.md
Normal file
@@ -0,0 +1,327 @@
|
|||||||
|
# FLUX
|
||||||
|
|
||||||
|
[Switch to English](./README.md)
|
||||||
|
|
||||||
|
FLUX 是由 Black-Forest-Labs 开源的一系列图像生成模型。
|
||||||
|
|
||||||
|
**DiffSynth-Studio 启用了新的推理和训练框架,如需使用旧版本,请点击[这里](https://github.com/modelscope/DiffSynth-Studio/tree/3edf3583b1f08944cee837b94d9f84d669c2729c)。**
|
||||||
|
|
||||||
|
## 安装
|
||||||
|
|
||||||
|
在使用本系列模型之前,请通过源码安装 DiffSynth-Studio。
|
||||||
|
|
||||||
|
```shell
|
||||||
|
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||||
|
cd DiffSynth-Studio
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
## 快速开始
|
||||||
|
|
||||||
|
通过运行以下代码可以快速加载 FLUX.1-dev 模型并进行推理。
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||||
|
|
||||||
|
pipe = FluxImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
image = pipe(prompt="a cat", seed=0)
|
||||||
|
image.save("image.jpg")
|
||||||
|
```
|
||||||
|
|
||||||
|
## 模型总览
|
||||||
|
|
||||||
|
**FLUX 系列模型的全新框架支持正在开发中,敬请期待!**
|
||||||
|
|
||||||
|
|模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||||
|
|-|-|-|-|-|-|-|
|
||||||
|
|[black-forest-labs/FLUX.1-dev](https://modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](./model_inference/FLUX.1-dev.py)|[code](./model_training/full/FLUX.1-dev.sh)|[code](./model_training/validate_full/FLUX.1-dev.py)|[code](./model_training/lora/FLUX.1-dev.sh)|[code](./model_training/validate_lora/FLUX.1-dev.py)|
|
||||||
|
|[black-forest-labs/FLUX.1-Kontext-dev](https://modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](./model_inference/FLUX.1-Kontext-dev.py)|[code](./model_training/full/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](./model_training/lora/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_lora/FLUX.1-Kontext-dev.py)|
|
||||||
|
|
||||||
|
## 模型推理
|
||||||
|
|
||||||
|
以下部分将会帮助您理解我们的功能并编写推理代码。
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>加载模型</summary>
|
||||||
|
|
||||||
|
模型通过 `from_pretrained` 加载:
|
||||||
|
|
||||||
|
```python
|
||||||
|
pipe = FluxImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
其中 `torch_dtype` 和 `device` 是计算精度和计算设备。`model_configs` 可通过多种方式配置模型路径:
|
||||||
|
|
||||||
|
* 从[魔搭社区](https://modelscope.cn/)下载模型并加载。此时需要填写 `model_id` 和 `origin_file_pattern`,例如
|
||||||
|
|
||||||
|
```python
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors")
|
||||||
|
```
|
||||||
|
|
||||||
|
* 从本地文件路径加载模型。此时需要填写 `path`,例如
|
||||||
|
|
||||||
|
```python
|
||||||
|
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors")
|
||||||
|
```
|
||||||
|
|
||||||
|
对于从多个文件加载的单一模型,使用列表即可,例如
|
||||||
|
|
||||||
|
```python
|
||||||
|
ModelConfig(path=[
|
||||||
|
"models/xxx/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||||
|
"models/xxx/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||||
|
"models/xxx/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||||
|
])
|
||||||
|
```
|
||||||
|
|
||||||
|
`from_pretrained` 还提供了额外的参数用于控制模型加载时的行为:
|
||||||
|
|
||||||
|
* `local_model_path`: 用于保存下载模型的路径,默认值为 `"./models"`。
|
||||||
|
* `skip_download`: 是否跳过下载,默认值为 `False`。当您的网络无法访问[魔搭社区](https://modelscope.cn/)时,请手动下载必要的文件,并将其设置为 `True`。
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>显存管理</summary>
|
||||||
|
|
||||||
|
DiffSynth-Studio 为 FLUX 模型提供了细粒度的显存管理,让模型能够在低显存设备上进行推理,可通过以下代码开启 offload 功能,在显存有限的设备上将部分模块 offload 到内存中。
|
||||||
|
|
||||||
|
```python
|
||||||
|
pipe = FluxImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
pipe.enable_vram_management()
|
||||||
|
```
|
||||||
|
|
||||||
|
`enable_vram_management` 函数提供了以下参数,用于控制显存使用情况:
|
||||||
|
|
||||||
|
* `vram_limit`: 显存占用量(GB),默认占用设备上的剩余显存。注意这不是一个绝对限制,当设置的显存不足以支持模型进行推理,但实际可用显存足够时,将会以最小化显存占用的形式进行推理。将其设置为0时,将会实现理论最小显存占用。
|
||||||
|
* `vram_buffer`: 显存缓冲区大小(GB),默认为 0.5GB。由于部分较大的神经网络层在 onload 阶段会不可控地占用更多显存,因此一个显存缓冲区是必要的,理论上的最优值为模型中最大的层所占的显存。
|
||||||
|
* `num_persistent_param_in_dit`: DiT 模型中常驻显存的参数数量(个),默认为无限制。我们将会在未来删除这个参数,请不要依赖这个参数。
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>推理加速</summary>
|
||||||
|
|
||||||
|
* TeaCache:加速技术 [TeaCache](https://github.com/ali-vilab/TeaCache),请参考[示例代码](./acceleration/teacache.py)。
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>输入参数</summary>
|
||||||
|
|
||||||
|
Pipeline 在推理阶段能够接收以下输入参数:
|
||||||
|
|
||||||
|
* `prompt`: 提示词,描述画面中出现的内容。
|
||||||
|
* `negative_prompt`: 负向提示词,描述画面中不应该出现的内容,默认值为 `""`。
|
||||||
|
* `cfg_scale`: Classifier-free guidance 的参数,默认值为 1,当设置为大于1的数值时生效。
|
||||||
|
* `embedded_guidance`: FLUX-dev 的内嵌引导参数,默认值为 3.5。
|
||||||
|
* `t5_sequence_length`: T5 模型的文本向量序列长度,默认值为 512。
|
||||||
|
* `input_image`: 输入图像,用于图生图,该参数与 `denoising_strength` 配合使用。
|
||||||
|
* `denoising_strength`: 去噪强度,范围是 0~1,默认值为 1,当数值接近 0 时,生成图像与输入图像相似;当数值接近 1 时,生成图像与输入图像相差更大。在不输入 `input_image` 参数时,请不要将其设置为非 1 的数值。
|
||||||
|
* `height`: 图像高度,需保证高度为 16 的倍数。
|
||||||
|
* `width`: 图像宽度,需保证宽度为 16 的倍数。
|
||||||
|
* `seed`: 随机种子。默认为 `None`,即完全随机。
|
||||||
|
* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cpu"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。
|
||||||
|
* `sigma_shift`: Rectified Flow 理论中的参数,默认为 3。数值越大,模型在去噪的开始阶段停留的步骤数越多,可适当调大这个参数来提高画面质量,但会因生成过程与训练过程不一致导致生成的图像内容与训练数据存在差异。
|
||||||
|
* `num_inference_steps`: 推理次数,默认值为 30。
|
||||||
|
* `kontext_images`: Kontext 模型的输入图像。
|
||||||
|
* `controlnet_inputs`: ControlNet 模型的输入。
|
||||||
|
* `ipadapter_images`: IP-Adapter 模型的输入图像。
|
||||||
|
* `ipadapter_scale`: IP-Adapter 模型的控制强度。
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
## 模型训练
|
||||||
|
|
||||||
|
FLUX 系列模型训练通过统一的 [`./model_training/train.py`](./model_training/train.py) 脚本进行。
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>脚本参数</summary>
|
||||||
|
|
||||||
|
脚本包含以下参数:
|
||||||
|
|
||||||
|
* 数据集
|
||||||
|
* `--dataset_base_path`: 数据集的根路径。
|
||||||
|
* `--dataset_metadata_path`: 数据集的元数据文件路径。
|
||||||
|
* `--max_pixels`: 最大像素面积,默认为 1024*1024,当启用动态分辨率时,任何分辨率大于这个数值的图片都会被缩小。
|
||||||
|
* `--height`: 图像或视频的高度。将 `height` 和 `width` 留空以启用动态分辨率。
|
||||||
|
* `--width`: 图像或视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。
|
||||||
|
* `--data_file_keys`: 元数据中的数据文件键。用逗号分隔。
|
||||||
|
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
|
||||||
|
* 模型
|
||||||
|
* `--model_paths`: 要加载的模型路径。JSON 格式。
|
||||||
|
* `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 black-forest-labs/FLUX.1-dev:flux1-dev.safetensors。用逗号分隔。
|
||||||
|
* 训练
|
||||||
|
* `--learning_rate`: 学习率。
|
||||||
|
* `--num_epochs`: 轮数(Epoch)数量。
|
||||||
|
* `--output_path`: 保存路径。
|
||||||
|
* `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。
|
||||||
|
* 可训练模块
|
||||||
|
* `--trainable_models`: 可训练的模型,例如 dit、vae、text_encoder。
|
||||||
|
* `--lora_base_model`: LoRA 添加到哪个模型上。
|
||||||
|
* `--lora_target_modules`: LoRA 添加到哪一层上。
|
||||||
|
* `--lora_rank`: LoRA 的秩(Rank)。
|
||||||
|
* 额外模型输入
|
||||||
|
* `--extra_inputs`: 额外的模型输入,以逗号分隔。
|
||||||
|
* 显存管理
|
||||||
|
* `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。
|
||||||
|
* `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
|
||||||
|
* `--gradient_accumulation_steps`: 梯度累积步数。
|
||||||
|
* 其他
|
||||||
|
* `--align_to_opensource_format`: 是否将 FLUX DiT LoRA 的格式与开源版本对齐,仅对 FLUX.1-dev 和 FLUX.1-Kontext-dev 的 LoRA 训练生效。
|
||||||
|
|
||||||
|
|
||||||
|
此外,训练框架基于 [`accelerate`](https://huggingface.co/docs/accelerate/index) 构建,在开始训练前运行 `accelerate config` 可配置 GPU 的相关参数。对于部分模型训练(例如模型的全量训练)脚本,我们提供了建议的 `accelerate` 配置文件,可在对应的训练脚本中查看。
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Step 1: 准备数据集</summary>
|
||||||
|
|
||||||
|
数据集包含一系列文件,我们建议您这样组织数据集文件:
|
||||||
|
|
||||||
|
```
|
||||||
|
data/example_image_dataset/
|
||||||
|
├── metadata.csv
|
||||||
|
├── image1.jpg
|
||||||
|
└── image2.jpg
|
||||||
|
```
|
||||||
|
|
||||||
|
其中 `image1.jpg`、`image2.jpg` 为训练用图像数据,`metadata.csv` 为元数据列表,例如
|
||||||
|
|
||||||
|
```
|
||||||
|
image,prompt
|
||||||
|
image1.jpg,"a cat is sleeping"
|
||||||
|
image2.jpg,"a dog is running"
|
||||||
|
```
|
||||||
|
|
||||||
|
我们构建了一个样例图像数据集,以方便您进行测试,通过以下命令可以下载这个数据集:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
|
||||||
|
```
|
||||||
|
|
||||||
|
数据集支持多种图片格式,`"jpg", "jpeg", "png", "webp"`。
|
||||||
|
|
||||||
|
图片的尺寸可通过脚本参数 `--height`、`--width` 控制。当 `--height` 和 `--width` 为空时将会开启动态分辨率,按照数据集中每个图像的实际宽高训练。
|
||||||
|
|
||||||
|
**我们强烈建议使用固定分辨率训练,因为在多卡训练中存在负载均衡问题。**
|
||||||
|
|
||||||
|
当模型需要额外输入时,例如具备控制能力的模型 [`black-forest-labs/FLUX.1-Kontext-dev`](https://modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev) 所需的 `kontext_images`,请在数据集中补充相应的列,例如:
|
||||||
|
|
||||||
|
```
|
||||||
|
image,prompt,kontext_images
|
||||||
|
image1.jpg,"a cat is sleeping",image1_reference.jpg
|
||||||
|
```
|
||||||
|
|
||||||
|
额外输入若包含图像文件,则需要在 `--data_file_keys` 参数中指定要解析的列名。可根据额外输入增加相应的列名,例如 `--data_file_keys "image,kontext_images"`,同时启用 `--extra_inputs "kontext_images"`。
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Step 2: 加载模型</summary>
|
||||||
|
|
||||||
|
类似于推理时的模型加载逻辑,可直接通过模型 ID 配置要加载的模型。例如,推理时我们通过以下设置加载模型
|
||||||
|
|
||||||
|
```python
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
那么在训练时,填入以下参数即可加载对应的模型。
|
||||||
|
|
||||||
|
```shell
|
||||||
|
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors"
|
||||||
|
```
|
||||||
|
|
||||||
|
如果您希望从本地文件加载模型,例如推理时
|
||||||
|
|
||||||
|
```python
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors"),
|
||||||
|
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/text_encoder/model.safetensors"),
|
||||||
|
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/text_encoder_2/"),
|
||||||
|
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/ae.safetensors"),
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
那么训练时需设置为
|
||||||
|
|
||||||
|
```shell
|
||||||
|
--model_paths '[
|
||||||
|
"models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors",
|
||||||
|
"models/black-forest-labs/FLUX.1-dev/text_encoder/model.safetensors",
|
||||||
|
"models/black-forest-labs/FLUX.1-dev/text_encoder_2/",
|
||||||
|
"models/black-forest-labs/FLUX.1-dev/ae.safetensors"
|
||||||
|
]' \
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Step 3: 设置可训练模块</summary>
|
||||||
|
|
||||||
|
训练框架支持训练基础模型,或 LoRA 模型。以下是几个例子:
|
||||||
|
|
||||||
|
* 全量训练 DiT 部分:`--trainable_models dit`
|
||||||
|
* 训练 DiT 部分的 LoRA 模型:`--lora_base_model dit --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" --lora_rank 32`
|
||||||
|
|
||||||
|
此外,由于训练脚本中加载了多个模块(text encoder、dit、vae),保存模型文件时需要移除前缀,例如在全量训练 DiT 部分或者训练 DiT 部分的 LoRA 模型时,请设置 `--remove_prefix_in_ckpt pipe.dit.`
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Step 4: 启动训练程序</summary>
|
||||||
|
|
||||||
|
我们为每一个模型编写了训练命令,请参考本文档开头的表格。
|
||||||
|
|
||||||
|
</details>
|
||||||
24
examples/flux/acceleration/teacache.py
Normal file
24
examples/flux/acceleration/teacache.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||||
|
|
||||||
|
|
||||||
|
pipe = FluxImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
|
||||||
|
|
||||||
|
for tea_cache_l1_thresh in [None, 0.2, 0.4, 0.6, 0.8]:
|
||||||
|
image = pipe(
|
||||||
|
prompt=prompt, embedded_guidance=3.5, seed=0,
|
||||||
|
num_inference_steps=50, tea_cache_l1_thresh=tea_cache_l1_thresh
|
||||||
|
)
|
||||||
|
image.save(f"image_{tea_cache_l1_thresh}.png")
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user