mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 23:08:13 +00:00
Compare commits
56 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| | bc5f151dfa | | |
| | 5cd6ed0096 | | |
| | be84b35bfd | | |
| | d9fc30ffd0 | | |
| | 8f59d00d9e | | |
| | 3d8ff39aed | | |
| | b5c194df43 | | |
| | 8680f92b60 | | |
| | 05c97bc755 | | |
| | db88d60750 | | |
| | 40c6da8075 | | |
| | 3981b8084f | | |
| | 9dfb7c1c37 | | |
| | 9ed54c188e | | |
| | 6a47a346b1 | | |
| | e3f8a576cf | | |
| | 0aff733a92 | | |
| | 9471bff8a4 | | |
| | 3f8eea4687 | | |
| | b1b2d50c0d | | |
| | 9c6607f78d | | |
| | 2a4709e572 | | |
| | 04f3fce3b0 | | |
| | be9c3524a5 | | |
| | c3d899dd48 | | |
| | 6e03ee2a75 | | |
| | 979a8814f1 | | |
| | 8be4fad330 | | |
| | 8113f95278 | | |
| | 9ca6c646df | | |
| | 466b37994e | | |
| | 518c6d6ac3 | | |
| | 9920b8d975 | | |
| | 237daa2048 | | |
| | e9af28e6a3 | | |
| | 996515c7ca | | |
| | c2ccc39e3c | | |
| | ad24b93431 | | |
| | bd5fc32d79 | | |
| | 03cefe8f58 | | |
| | 64339f7089 | | |
| | 0b1704976a | | |
| | 0af60b9c73 | | |
| | 280f0eacc0 | | |
| | 03cba5e59e | | |
| | fa0ea0e1a4 | | |
| | 40d24b8907 | | |
| | 1bf02f439f | | |
| | 0489c62550 | | |
| | ad98602da3 | | |
| | fb12ac316a | | |
| | e9ec2f2706 | | |
| | 00f294454b | | |
| | 0465d940c7 | | |
| | 2c549598d0 | | |
| | 7d33082d70 | | |
29
.github/workflows/publish.yaml
vendored
Normal file
@@ -0,0 +1,29 @@
+name: release
+
+on:
+  push:
+    tags:
+      - 'v**'
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}-publish
+  cancel-in-progress: true
+
+jobs:
+  build-n-publish:
+    runs-on: ubuntu-20.04
+    #if: startsWith(github.event.ref, 'refs/tags')
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python 3.10
+      uses: actions/setup-python@v2
+      with:
+        python-version: '3.10'
+    - name: Install wheel
+      run: pip install wheel && pip install -r requirements.txt
+    - name: Build DiffSynth
+      run: python setup.py sdist bdist_wheel
+    - name: Publish package to PyPI
+      run: |
+        pip install twine
+        twine upload dist/* --skip-existing -u __token__ -p ${{ secrets.PYPI_API_TOKEN }}
145
README.md
@@ -1,67 +1,76 @@
|
||||
# DiffSynth Studio
|
||||
|
||||
<p align="center">
|
||||
<a href="https://trendshift.io/repositories/10946" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10946" alt="modelscope%2FDiffSynth-Studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
</p>
|
||||
|
||||
## Introduction
|
||||
|
||||
DiffSynth Studio is a Diffusion engine. We have restructured architectures including Text Encoder, UNet, VAE, among others, maintaining compatibility with models from the open-source community while enhancing computational performance. We provide many interesting features. Enjoy the magic of Diffusion models!
|
||||
|
||||
## Roadmap
|
||||
Until now, DiffSynth Studio has supported the following models:
|
||||
|
||||
* [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
|
||||
* [Kolors](https://huggingface.co/Kwai-Kolors/Kolors)
|
||||
* [Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
|
||||
* [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt)
|
||||
* [Hunyuan-DiT](https://github.com/Tencent/HunyuanDiT)
|
||||
* [RIFE](https://github.com/hzwer/ECCV2022-RIFE)
|
||||
* [ESRGAN](https://github.com/xinntao/ESRGAN)
|
||||
* [Ip-Adapter](https://github.com/tencent-ailab/IP-Adapter)
|
||||
* [AnimateDiff](https://github.com/guoyww/animatediff/)
|
||||
* [ControlNet](https://github.com/lllyasviel/ControlNet)
|
||||
* [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
||||
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
|
||||
|
||||
## News
|
||||
|
||||
|
||||
- **June 21, 2024.** 🔥🔥🔥 We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
|
||||
- [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
|
||||
- Source code is released in this repo. See [`examples/ExVideo`](./examples/ExVideo/).
|
||||
- Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
|
||||
- Technical report is released on [arXiv](https://arxiv.org/abs/2406.14130).
|
||||
- You can try ExVideo in this [Demo](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1)!
|
||||
|
||||
- **June 13, 2024.** DiffSynth Studio is transferred to ModelScope. The developers have transitioned from "I" to "we". Of course, I will still participate in development and maintenance.
|
||||
|
||||
- **Jan 29, 2024.** We propose Diffutoon, a fantastic solution for toon shading.
|
||||
- [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
|
||||
- The source codes are released in this project.
|
||||
- The technical report (IJCAI 2024) is released on [arXiv](https://arxiv.org/abs/2401.16224).
|
||||
|
||||
- **Dec 8, 2023.** We decide to develop a new Project, aiming to release the potential of diffusion models, especially in video synthesis. The development of this project is started.
|
||||
|
||||
- **Nov 15, 2023.** We propose FastBlend, a powerful video deflickering algorithm.
|
||||
- The sd-webui extension is released on [GitHub](https://github.com/Artiprocher/sd-webui-fastblend).
|
||||
- Demo videos are shown on Bilibili, including three tasks.
|
||||
- [Video deflickering](https://www.bilibili.com/video/BV1d94y1W7PE)
|
||||
- [Video interpolation](https://www.bilibili.com/video/BV1Lw411m71p)
|
||||
- [Image-driven video rendering](https://www.bilibili.com/video/BV1RB4y1Z7LF)
|
||||
- The technical report is released on [arXiv](https://arxiv.org/abs/2311.09265).
|
||||
- An unofficial ComfyUI extension developed by other users is released on [GitHub](https://github.com/AInseven/ComfyUI-fastblend).
|
||||
|
||||
- **Oct 1, 2023.** We release an early version of this project, namely FastSDXL. A try for building a diffusion engine.
|
||||
- The source codes are released on [GitHub](https://github.com/Artiprocher/FastSDXL).
|
||||
- FastSDXL includes a trainable OLSS scheduler for efficiency improvement.
|
||||
- The original repo of OLSS is [here](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler).
|
||||
- The technical report (CIKM 2023) is released on [arXiv](https://arxiv.org/abs/2305.14677).
|
||||
- A demo video is shown on [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj).
|
||||
- Since OLSS requires additional training, we don't implement it in this project.
|
||||
|
||||
- **Aug 29, 2023.** We propose DiffSynth, a video synthesis framework.
|
||||
- [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/).
|
||||
- The source codes are released in [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth).
|
||||
- The technical report (ECML PKDD 2024) is released on [arXiv](https://arxiv.org/abs/2308.03463).
|
||||
|
||||
* Aug 29, 2023. We propose DiffSynth, a video synthesis framework.
|
||||
* [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/).
|
||||
* The source codes are released in [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth).
|
||||
* The technical report (ECML PKDD 2024) is released on [arXiv](https://arxiv.org/abs/2308.03463).
|
||||
* Oct 1, 2023. We release an early version of this project, namely FastSDXL. A try for building a diffusion engine.
|
||||
* The source codes are released on [GitHub](https://github.com/Artiprocher/FastSDXL).
|
||||
* FastSDXL includes a trainable OLSS scheduler for efficiency improvement.
|
||||
* The original repo of OLSS is [here](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler).
|
||||
* The technical report (CIKM 2023) is released on [arXiv](https://arxiv.org/abs/2305.14677).
|
||||
* A demo video is shown on [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj).
|
||||
* Since OLSS requires additional training, we don't implement it in this project.
|
||||
* Nov 15, 2023. We propose FastBlend, a powerful video deflickering algorithm.
|
||||
* The sd-webui extension is released on [GitHub](https://github.com/Artiprocher/sd-webui-fastblend).
|
||||
* Demo videos are shown on Bilibili, including three tasks.
|
||||
* [Video deflickering](https://www.bilibili.com/video/BV1d94y1W7PE)
|
||||
* [Video interpolation](https://www.bilibili.com/video/BV1Lw411m71p)
|
||||
* [Image-driven video rendering](https://www.bilibili.com/video/BV1RB4y1Z7LF)
|
||||
* The technical report is released on [arXiv](https://arxiv.org/abs/2311.09265).
|
||||
* An unofficial ComfyUI extension developed by other users is released on [GitHub](https://github.com/AInseven/ComfyUI-fastblend).
|
||||
* Dec 8, 2023. We decide to develop a new Project, aiming to release the potential of diffusion models, especially in video synthesis. The development of this project is started.
|
||||
* Jan 29, 2024. We propose Diffutoon, a fantastic solution for toon shading.
|
||||
* [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/).
|
||||
* The source codes are released in this project.
|
||||
* The technical report (IJCAI 2024) is released on [arXiv](https://arxiv.org/abs/2401.16224).
|
||||
* June 13, 2024. DiffSynth Studio is transfered to ModelScope. The developers have transitioned from "I" to "we". Of course, I will still participate in development and maintenance.
|
||||
* June 21, 2024. We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
|
||||
* [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/).
|
||||
* Source code is released in this repo. See [`examples/ExVideo`](./examples/ExVideo/).
|
||||
* Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
|
||||
* Technical report is released on [arXiv](https://arxiv.org/abs/2406.14130).
|
||||
* Until now, DiffSynth Studio has supported the following models:
|
||||
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
|
||||
* [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
||||
* [ControlNet](https://github.com/lllyasviel/ControlNet)
|
||||
* [AnimateDiff](https://github.com/guoyww/animatediff/)
|
||||
* [Ip-Adapter](https://github.com/tencent-ailab/IP-Adapter)
|
||||
* [ESRGAN](https://github.com/xinntao/ESRGAN)
|
||||
* [RIFE](https://github.com/hzwer/ECCV2022-RIFE)
|
||||
* [Hunyuan-DiT](https://github.com/Tencent/HunyuanDiT)
|
||||
* [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt)
|
||||
* [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
|
||||
|
||||
## Installation
|
||||
|
||||
Create Python environment:
|
||||
|
||||
```
|
||||
conda env create -f environment.yml
|
||||
```
|
||||
|
||||
We find that sometimes `conda` cannot install `cupy` correctly, please install it manually. See [this document](https://docs.cupy.dev/en/stable/install.html) for more details.
|
||||
|
||||
Enter the Python environment:
|
||||
|
||||
```
|
||||
conda activate DiffSynthStudio
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
## Usage (in Python code)
|
||||
@@ -76,15 +85,17 @@ https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5
|
||||
|
||||
### Image Synthesis
|
||||
|
||||
Generate high-resolution images, by breaking the limitation of diffusion models! [`examples/image_synthesis`](./examples/image_synthesis/)
|
||||
Generate high-resolution images, by breaking the limitation of diffusion models! [`examples/image_synthesis`](./examples/image_synthesis/).
|
||||
|
||||
|512*512|1024*1024|2048*2048|4096*4096|
|
||||
|-|-|-|-|
|
||||
|||||
|
||||
LoRA fine-tuning is supported in [`examples/train`](./examples/train/).
|
||||
|
||||
|1024*1024|2048*2048|
|
||||
|Model|Example|
|
||||
|-|-|
|
||||
|||
|
||||
|Stable Diffusion||
|
||||
|Stable Diffusion XL||
|
||||
|Stable Diffusion 3||
|
||||
|Kolors||
|
||||
|Hunyuan-DiT||
|
||||
|
||||
### Toon Shading
|
||||
|
||||
@@ -100,22 +111,6 @@ Video stylization without video models. [`examples/diffsynth`](./examples/diffsy
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
|
||||
|
||||
### Chinese Models
|
||||
|
||||
Use Hunyuan-DiT to generate images with Chinese prompts. We also support LoRA fine-tuning of this model. [`examples/hunyuan_dit`](./examples/hunyuan_dit/)
|
||||
|
||||
Prompt: 少女手捧鲜花,坐在公园的长椅上,夕阳的余晖洒在少女的脸庞,整个画面充满诗意的美感
|
||||
|
||||
|1024x1024|2048x2048 (highres-fix)|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
Prompt: 一只小狗蹦蹦跳跳,周围是姹紫嫣红的鲜花,远处是山脉
|
||||
|
||||
|Without LoRA|With LoRA|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
## Usage (in WebUI)
|
||||
|
||||
```
|
||||
|
||||
@@ -1,6 +1,6 @@
 from .data import *
 from .models import *
-from .prompts import *
+from .prompters import *
 from .schedulers import *
 from .pipelines import *
 from .controlnets import *
0
diffsynth/configs/__init__.py
Normal file
243
diffsynth/configs/model_config.py
Normal file
@@ -0,0 +1,243 @@
+from typing_extensions import Literal, TypeAlias
+
+from ..models.sd_text_encoder import SDTextEncoder
+from ..models.sd_unet import SDUNet
+from ..models.sd_vae_encoder import SDVAEEncoder
+from ..models.sd_vae_decoder import SDVAEDecoder
+
+from ..models.sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
+from ..models.sdxl_unet import SDXLUNet
+from ..models.sdxl_vae_decoder import SDXLVAEDecoder
+from ..models.sdxl_vae_encoder import SDXLVAEEncoder
+
+from ..models.sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
+from ..models.sd3_dit import SD3DiT
+from ..models.sd3_vae_decoder import SD3VAEDecoder
+from ..models.sd3_vae_encoder import SD3VAEEncoder
+
+from ..models.sd_controlnet import SDControlNet
+
+from ..models.sd_motion import SDMotionModel
+from ..models.sdxl_motion import SDXLMotionModel
+
+from ..models.svd_image_encoder import SVDImageEncoder
+from ..models.svd_unet import SVDUNet
+from ..models.svd_vae_decoder import SVDVAEDecoder
+from ..models.svd_vae_encoder import SVDVAEEncoder
+
+from ..models.sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
+from ..models.sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
+
+from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
+from ..models.hunyuan_dit import HunyuanDiT
+
+
+
+model_loader_configs = [
+    # These configs are provided for detecting model type automatically.
+    # The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
+    (None, "091b0e30e77c76626b3ba62acdf95343", ["sd_controlnet"], [SDControlNet], "civitai"),
+    (None, "4a6c8306a27d916dea81263c8c88f450", ["hunyuan_dit_clip_text_encoder"], [HunyuanDiTCLIPTextEncoder], "civitai"),
+    (None, "f4aec400fe394297961218c768004521", ["hunyuan_dit"], [HunyuanDiT], "civitai"),
+    (None, "9e6e58043a5a2e332803ed42f6ee7181", ["hunyuan_dit_t5_text_encoder"], [HunyuanDiTT5TextEncoder], "civitai"),
+    (None, "13115dd45a6e1c39860f91ab073b8a78", ["sdxl_vae_encoder", "sdxl_vae_decoder"], [SDXLVAEEncoder, SDXLVAEDecoder], "diffusers"),
+    (None, "d78aa6797382a6d455362358a3295ea9", ["sd_ipadapter_clip_image_encoder"], [IpAdapterCLIPImageEmbedder], "diffusers"),
+    (None, "e291636cc15e803186b47404262ef812", ["sd_ipadapter"], [SDIpAdapter], "civitai"),
+    (None, "399c81f2f8de8d1843d0127a00f3c224", ["sdxl_ipadapter_clip_image_encoder"], [IpAdapterXLCLIPImageEmbedder], "diffusers"),
+    (None, "a64eac9aa0db4b9602213bc0131281c7", ["sdxl_ipadapter"], [SDXLIpAdapter], "civitai"),
+    (None, "52817e4fdd89df154f02749ca6f692ac", ["sdxl_unet"], [SDXLUNet], "diffusers"),
+    (None, "03343c606f16d834d6411d0902b53636", ["sd_text_encoder", "sd_unet", "sd_vae_decoder", "sd_vae_encoder"], [SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder], "civitai"),
+    (None, "d4ba77a7ece070679b4a987f58f201e9", ["sd_text_encoder"], [SDTextEncoder], "civitai"),
+    (None, "d0c89e55c5a57cf3981def0cb1c9e65a", ["sd_vae_decoder", "sd_vae_encoder"], [SDVAEDecoder, SDVAEEncoder], "civitai"),
+    (None, "3926bf373b39a67eeafd7901478a47a7", ["sd_unet"], [SDUNet], "civitai"),
+    (None, "1e0c39ec176b9007c05f76d52b554a4d", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
+    (None, "d9e0290829ba8d98e28e1a2b1407db4a", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
+    (None, "5072d0b24e406b49507abe861cf97691", ["sd3_text_encoder_3"], [SD3TextEncoder3], "civitai"),
+    (None, "4cf64a799d04260df438c6f33c9a047e", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"),
+    (None, "d9b008a867c498ab12ad24042eff8e3f", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"), # SDXL-Turbo
+    (None, "025bb7452e531a3853d951d77c63f032", ["sdxl_text_encoder", "sdxl_text_encoder_2"], [SDXLTextEncoder, SDXLTextEncoder2], "civitai"),
+    (None, "298997b403a4245c04102c9f36aac348", ["sdxl_unet"], [SDXLUNet], "civitai"),
+    (None, "2a07abce74b4bdc696b76254ab474da6", ["svd_image_encoder", "svd_unet", "svd_vae_decoder", "svd_vae_encoder"], [SVDImageEncoder, SVDUNet, SVDVAEDecoder, SVDVAEEncoder], "civitai"),
+    (None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
+    (None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
+]
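The comment in `model_loader_configs` says each entry carries a hash of the state-dict keys, optionally combined with tensor shapes, so that a checkpoint's model type can be detected automatically. The diff does not show how those hashes are computed, so the following is only a minimal sketch of the lookup idea. The helper names `hash_state_dict_keys` and `detect_model`, the MD5 digest, the string layout being hashed, and the `ToyUNet` class are all assumptions made for illustration, not the project's implementation.

```python
import hashlib


def hash_state_dict_keys(shapes, with_shape=True):
    """Hash parameter names (optionally with shapes) into a hex digest.

    `shapes` maps parameter name -> shape tuple. The string layout that is
    hashed here is an assumption of this sketch.
    """
    parts = []
    for name in sorted(shapes):
        if with_shape:
            dims = "_".join(str(d) for d in shapes[name])
            parts.append(f"{name}:{dims}")
        else:
            parts.append(name)
    return hashlib.md5(",".join(parts).encode("utf-8")).hexdigest()


def detect_model(shapes, configs):
    """Return (model_names, model_classes, model_resource), or None if unknown."""
    keys_hash = hash_state_dict_keys(shapes, with_shape=False)
    keys_hash_with_shape = hash_state_dict_keys(shapes, with_shape=True)
    for cfg_hash, cfg_hash_with_shape, names, classes, resource in configs:
        # Prefer the stricter hash that also encodes tensor shapes.
        if cfg_hash_with_shape == keys_hash_with_shape:
            return names, classes, resource
        # Fall back to the keys-only hash when the entry provides one.
        if cfg_hash is not None and cfg_hash == keys_hash:
            return names, classes, resource
    return None


class ToyUNet:
    """Hypothetical placeholder for a real model class."""


toy_shapes = {"conv_in.weight": (320, 4, 3, 3), "conv_in.bias": (320,)}
toy_configs = [
    (None, hash_state_dict_keys(toy_shapes), ["toy_unet"], [ToyUNet], "civitai"),
]
match = detect_model(toy_shapes, toy_configs)
```

One consequence of hashing keys together with shapes is that two checkpoints with identical parameter names but different tensor sizes map to different entries, which may be why most entries above leave the keys-only hash as `None`.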
+huggingface_model_loader_configs = [
+    # These configs are provided for detecting model type automatically.
+    # The format is (architecture_in_huggingface_config, huggingface_lib, model_name)
+    ("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder"),
+    ("MarianMTModel", "transformers.models.marian.modeling_marian", "translator"),
+    ("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt"),
+]
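Each entry in `huggingface_model_loader_configs` pairs an architecture name with the module that defines it and the name the model is registered under. A plausible way to consume such a table is to read the `architectures` field of a Hugging Face `config.json` and import the matching class dynamically. The sketch below illustrates that lookup with the standard library only; the function name `resolve_huggingface_model` and the demo entry pointing at `collections.OrderedDict` are hypothetical and chosen so the example runs without `transformers` installed.

```python
import importlib
import json


def resolve_huggingface_model(config_text, configs):
    """Map the `architectures` field of a config.json to (model_name, class)."""
    architectures = json.loads(config_text).get("architectures") or []
    for architecture, lib, model_name in configs:
        if architecture in architectures:
            # Import the module lazily and fetch the class by its name.
            module = importlib.import_module(lib)
            return model_name, getattr(module, architecture)
    return None


# Hypothetical demo entry that resolves to a standard-library class.
demo_configs = [("OrderedDict", "collections", "demo_model")]
demo_config_text = json.dumps({"architectures": ["OrderedDict"]})
resolved = resolve_huggingface_model(demo_config_text, demo_configs)
```

Importing the module only after a match means heavyweight libraries are loaded on demand rather than at import time of the config module.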
+patch_model_loader_configs = [
+    # These configs are provided for detecting model type automatically.
+    # The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
+    ("9a4ab6869ac9b7d6e31f9854e397c867", ["svd_unet"], [SVDUNet], {"add_positional_conv": 128}),
+]
+
+preset_models_on_huggingface = {
+    "HunyuanDiT": [
+        ("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
+        ("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
+        ("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
+        ("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
+    ],
+    "stable-video-diffusion-img2vid-xt": [
+        ("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
+    ],
+    "ExVideo-SVD-128f-v1": [
+        ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
+    ],
+}
+preset_models_on_modelscope = {
+    # Hunyuan DiT
+    "HunyuanDiT": [
+        ("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
+        ("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
+        ("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
+        ("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
+    ],
+    # Stable Video Diffusion
+    "stable-video-diffusion-img2vid-xt": [
+        ("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
+    ],
+    # ExVideo
+    "ExVideo-SVD-128f-v1": [
+        ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
+    ],
+    # Stable Diffusion
+    "StableDiffusion_v15": [
+        ("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
+    ],
+    "DreamShaper_8": [
+        ("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
+    ],
+    "AingDiffusion_v12": [
+        ("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
+    ],
+    "Flat2DAnimerge_v45Sharp": [
+        ("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
+    ],
+    # Textual Inversion
+    "TextualInversion_VeryBadImageNegative_v1.3": [
+        ("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
+    ],
+    # Stable Diffusion XL
+    "StableDiffusionXL_v1": [
+        ("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
+    ],
+    "BluePencilXL_v200": [
+        ("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
+    ],
+    "StableDiffusionXL_Turbo": [
+        ("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
+    ],
+    # Stable Diffusion 3
+    "StableDiffusion3": [
+        ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
+    ],
+    "StableDiffusion3_without_T5": [
+        ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
+    ],
+    # ControlNet
+    "ControlNet_v11f1p_sd15_depth": [
+        ("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
+        ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
+    ],
+    "ControlNet_v11p_sd15_softedge": [
+        ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
+        ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
+    ],
+    "ControlNet_v11f1e_sd15_tile": [
+        ("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
+    ],
+    "ControlNet_v11p_sd15_lineart": [
+        ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
+        ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
+        ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
+    ],
+    # AnimateDiff
+    "AnimateDiff_v2": [
+        ("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
+    ],
+    "AnimateDiff_xl_beta": [
+        ("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
+    ],
+    # RIFE
+    "RIFE": [
+        ("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
+    ],
+    # Beautiful Prompt
+    "BeautifulPrompt": [
+        ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
+        ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
+        ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
+        ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
+        ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
+        ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
+    ],
+    # Translator
+    "opus-mt-zh-en": [
+        ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
+        ("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
+        ("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
+        ("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
+        ("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
+        ("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
+        ("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
+        ("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
+    ],
+    # IP-Adapter
+    "IP-Adapter-SD": [
+        ("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
+        ("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
+    ],
+    "IP-Adapter-SDXL": [
+        ("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
+        ("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
+    ],
+    # Kolors
+    "Kolors": [
+        ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
+        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
+        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+        ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
+        ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
+    ],
+    "SDXL-vae-fp16-fix": [
+        ("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
+    ],
+}
+Preset_model_id: TypeAlias = Literal[
+    "HunyuanDiT",
+    "stable-video-diffusion-img2vid-xt",
+    "ExVideo-SVD-128f-v1",
+    "StableDiffusion_v15",
+    "DreamShaper_8",
+    "AingDiffusion_v12",
+    "Flat2DAnimerge_v45Sharp",
+    "TextualInversion_VeryBadImageNegative_v1.3",
+    "StableDiffusionXL_v1",
+    "BluePencilXL_v200",
+    "StableDiffusionXL_Turbo",
+    "ControlNet_v11f1p_sd15_depth",
+    "ControlNet_v11p_sd15_softedge",
+    "ControlNet_v11f1e_sd15_tile",
+    "ControlNet_v11p_sd15_lineart",
+    "AnimateDiff_v2",
+    "AnimateDiff_xl_beta",
+    "RIFE",
+    "BeautifulPrompt",
+    "opus-mt-zh-en",
+    "IP-Adapter-SD",
+    "IP-Adapter-SDXL",
+    "StableDiffusion3",
+    "StableDiffusion3_without_T5",
+    "Kolors",
+    "SDXL-vae-fp16-fix",
+]
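The preset tables map a model id to a list of three-element tuples, and `Preset_model_id` restricts the ids with a `Literal` type. The diff does not show the code that consumes these tables, so the sketch below is only one reading of them: it assumes each tuple is (repository id, file path inside the repository, local directory) and that a file lands in the local directory under its base name. The names `PresetId`, `presets`, and `plan_downloads` are hypothetical, and only two entries are copied from the table to keep the example short.

```python
import posixpath
from typing import Literal, get_args

# Two ids copied from the preset table; the real alias lists many more.
PresetId = Literal["RIFE", "AnimateDiff_v2"]

presets = {
    "RIFE": [
        ("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
    ],
    "AnimateDiff_v2": [
        ("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
    ],
}


def plan_downloads(model_id, table):
    """List (repo_id, file_in_repo, local_path) for every file of a preset."""
    # Literal types are not enforced at runtime, so validate explicitly.
    if model_id not in get_args(PresetId):
        raise ValueError(f"Unknown preset model id: {model_id}")
    plan = []
    for repo_id, file_in_repo, local_dir in table[model_id]:
        # Assumption: the file keeps its base name inside the local directory.
        local_path = posixpath.join(local_dir, posixpath.basename(file_in_repo))
        plan.append((repo_id, file_in_repo, local_path))
    return plan


rife_plan = plan_downloads("RIFE", presets)
```

Because `Literal` is a static typing construct, `get_args` is used here to recover the allowed ids at runtime; a type checker would flag a misspelled id before the code runs.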
@@ -12,24 +12,24 @@ Processor_id: TypeAlias = Literal[
 ]
 
 class Annotator:
-    def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None):
+    def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda'):
         if processor_id == "canny":
             self.processor = CannyDetector()
         elif processor_id == "depth":
-            self.processor = MidasDetector.from_pretrained(model_path).to("cuda")
+            self.processor = MidasDetector.from_pretrained(model_path).to(device)
         elif processor_id == "softedge":
-            self.processor = HEDdetector.from_pretrained(model_path).to("cuda")
+            self.processor = HEDdetector.from_pretrained(model_path).to(device)
         elif processor_id == "lineart":
-            self.processor = LineartDetector.from_pretrained(model_path).to("cuda")
+            self.processor = LineartDetector.from_pretrained(model_path).to(device)
         elif processor_id == "lineart_anime":
-            self.processor = LineartAnimeDetector.from_pretrained(model_path).to("cuda")
+            self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
         elif processor_id == "openpose":
-            self.processor = OpenposeDetector.from_pretrained(model_path).to("cuda")
+            self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
         elif processor_id == "tile":
             self.processor = None
         else:
             raise ValueError(f"Unsupported processor_id: {processor_id}")
-
+
         self.processor_id = processor_id
         self.detect_resolution = detect_resolution
-
+
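The change to `Annotator` replaces the hard-coded `"cuda"` target with a `device` argument that defaults to `'cuda'`, so existing callers keep their behaviour while others can place the detector elsewhere. The sketch below mirrors that pattern with stand-in classes so it runs without the real detectors or a GPU. `FakeDetector` and `AnnotatorSketch` are hypothetical names, and only two branches of the original dispatch are reproduced.

```python
class FakeDetector:
    """Hypothetical stand-in for a detector such as MidasDetector."""

    def __init__(self):
        self.device = "cpu"

    @classmethod
    def from_pretrained(cls, model_path):
        # The real detectors load weights from `model_path`; this stub does not.
        return cls()

    def to(self, device):
        # Mimic the chaining style of `torch.nn.Module.to`.
        self.device = device
        return self


class AnnotatorSketch:
    def __init__(self, processor_id, model_path="models/Annotators", device="cuda"):
        if processor_id == "depth":
            self.processor = FakeDetector.from_pretrained(model_path).to(device)
        elif processor_id == "tile":
            self.processor = None
        else:
            raise ValueError(f"Unsupported processor_id: {processor_id}")
        self.processor_id = processor_id


cpu_annotator = AnnotatorSketch("depth", device="cpu")   # explicit placement
default_annotator = AnnotatorSketch("depth")             # keeps the old default
```

Keeping `'cuda'` as the default is a backward-compatible choice: code written against the old signature continues to work unchanged.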
35
diffsynth/data/simple_text_image.py
Normal file
@@ -0,0 +1,35 @@
+import torch, os
+from torchvision import transforms
+import pandas as pd
+from PIL import Image
+
+
+
+class TextImageDataset(torch.utils.data.Dataset):
+    def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
+        self.steps_per_epoch = steps_per_epoch
+        metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
+        self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
+        self.text = metadata["text"].to_list()
+        self.image_processor = transforms.Compose(
+            [
+                transforms.Resize(max(height, width), interpolation=transforms.InterpolationMode.BILINEAR),
+                transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
+                transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5], [0.5]),
+            ]
+        )
+
+
+    def __getitem__(self, index):
+        data_id = torch.randint(0, len(self.path), (1,))[0]
+        data_id = (data_id + index) % len(self.path) # For fixed seed.
+        text = self.text[data_id]
+        image = Image.open(self.path[data_id]).convert("RGB")
+        image = self.image_processor(image)
+        return {"text": text, "image": image}
+
+
+    def __len__(self):
+        return self.steps_per_epoch
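The indexing scheme in `__getitem__` above can be illustrated without PyTorch or image files. The following is a minimal sketch, not the repository's code: `VirtualEpochDataset` and its `seed` argument are hypothetical names, and Python's `random` module stands in for `torch.randint`. It shows how a dataset can report a fixed `steps_per_epoch` as its length while every lookup wraps a randomly offset index back into the real sample list.

```python
import random


class VirtualEpochDataset:
    """Stand-in for a dataset whose length is a step budget, not a sample count."""

    def __init__(self, samples, steps_per_epoch=10000, seed=0):
        self.samples = list(samples)
        self.steps_per_epoch = steps_per_epoch
        # Hypothetical: a private RNG replaces torch.randint for this sketch.
        self.rng = random.Random(seed)

    def __getitem__(self, index):
        # Draw a random offset, add the requested index, and wrap around,
        # so any index in [0, steps_per_epoch) maps to a valid sample.
        offset = self.rng.randrange(len(self.samples))
        data_id = (offset + index) % len(self.samples)
        return self.samples[data_id]

    def __len__(self):
        # A loader sees steps_per_epoch items regardless of the real size.
        return self.steps_per_epoch


dataset = VirtualEpochDataset(["cat", "dog", "bird"], steps_per_epoch=8)
batch = [dataset[i] for i in range(len(dataset))]
```

Here `batch` holds eight items drawn from three samples, which is the point of decoupling `__len__` from the number of files on disk.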
@@ -99,7 +99,8 @@ class IFNet(nn.Module):
             merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
         return flow_list, mask_list[2], merged

-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return IFNetStateDictConverter()


0
diffsynth/extensions/__init__.py
Normal file
@@ -1,482 +1 @@
-import torch, os
-from safetensors import safe_open
-
-from .sd_text_encoder import SDTextEncoder
-from .sd_unet import SDUNet
-from .sd_vae_encoder import SDVAEEncoder
-from .sd_vae_decoder import SDVAEDecoder
-from .sd_lora import SDLoRA
-
-from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
-from .sdxl_unet import SDXLUNet
-from .sdxl_vae_decoder import SDXLVAEDecoder
-from .sdxl_vae_encoder import SDXLVAEEncoder
-
-from .sd_controlnet import SDControlNet
-
-from .sd_motion import SDMotionModel
-from .sdxl_motion import SDXLMotionModel
-
-from .svd_image_encoder import SVDImageEncoder
-from .svd_unet import SVDUNet
-from .svd_vae_decoder import SVDVAEDecoder
-from .svd_vae_encoder import SVDVAEEncoder
-
-from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
-from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
-
-from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
-from .hunyuan_dit import HunyuanDiT
-
-
-class ModelManager:
-    def __init__(self, torch_dtype=torch.float16, device="cuda"):
-        self.torch_dtype = torch_dtype
-        self.device = device
-        self.model = {}
-        self.model_path = {}
-        self.textual_inversion_dict = {}
-
-    def is_stable_video_diffusion(self, state_dict):
-        param_name = "model.diffusion_model.output_blocks.9.1.time_stack.0.norm_in.weight"
-        return param_name in state_dict
-
-    def is_RIFE(self, state_dict):
-        param_name = "block_tea.convblock3.0.1.weight"
-        return param_name in state_dict or ("module." + param_name) in state_dict
-
-    def is_beautiful_prompt(self, state_dict):
-        param_name = "transformer.h.9.self_attention.query_key_value.weight"
-        return param_name in state_dict
-
-    def is_stabe_diffusion_xl(self, state_dict):
-        param_name = "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight"
-        return param_name in state_dict
-
-    def is_stable_diffusion(self, state_dict):
-        if self.is_stabe_diffusion_xl(state_dict):
-            return False
-        param_name = "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight"
-        return param_name in state_dict
-
-    def is_controlnet(self, state_dict):
-        param_name = "control_model.time_embed.0.weight"
-        param_name_2 = "mid_block.resnets.1.time_emb_proj.weight" # For controlnets in diffusers format
-        return param_name in state_dict or param_name_2 in state_dict
-
-    def is_animatediff(self, state_dict):
-        param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
-        return param_name in state_dict
-
-    def is_animatediff_xl(self, state_dict):
-        param_name = "up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.ff_norm.weight"
-        return param_name in state_dict
-
-    def is_sd_lora(self, state_dict):
-        param_name = "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_ff_net_2.lora_up.weight"
-        return param_name in state_dict
-
-    def is_translator(self, state_dict):
-        param_name = "model.encoder.layers.5.self_attn_layer_norm.weight"
-        return param_name in state_dict and len(state_dict) == 254
-
-    def is_ipadapter(self, state_dict):
-        return "image_proj" in state_dict and "ip_adapter" in state_dict and state_dict["image_proj"]["proj.weight"].shape == torch.Size([3072, 1024])
-
-    def is_ipadapter_image_encoder(self, state_dict):
-        param_name = "vision_model.encoder.layers.31.self_attn.v_proj.weight"
-        return param_name in state_dict and len(state_dict) == 521
-
-    def is_ipadapter_xl(self, state_dict):
-        return "image_proj" in state_dict and "ip_adapter" in state_dict and state_dict["image_proj"]["proj.weight"].shape == torch.Size([8192, 1280])
-
-    def is_ipadapter_xl_image_encoder(self, state_dict):
-        param_name = "vision_model.encoder.layers.47.self_attn.v_proj.weight"
-        return param_name in state_dict and len(state_dict) == 777
-
-    def is_hunyuan_dit_clip_text_encoder(self, state_dict):
-        param_name = "bert.encoder.layer.23.attention.output.dense.weight"
-        return param_name in state_dict
-
-    def is_hunyuan_dit_t5_text_encoder(self, state_dict):
-        param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
-        return param_name in state_dict
-
-    def is_hunyuan_dit(self, state_dict):
-        param_name = "final_layer.adaLN_modulation.1.weight"
-        return param_name in state_dict
-
-    def is_diffusers_vae(self, state_dict):
-        param_name = "quant_conv.weight"
-        return param_name in state_dict
-
-    def is_ExVideo_StableVideoDiffusion(self, state_dict):
-        param_name = "blocks.185.positional_embedding.embeddings"
-        return param_name in state_dict
-
-    def load_stable_video_diffusion(self, state_dict, components=None, file_path="", add_positional_conv=None):
-        component_dict = {
-            "image_encoder": SVDImageEncoder,
-            "unet": SVDUNet,
-            "vae_decoder": SVDVAEDecoder,
-            "vae_encoder": SVDVAEEncoder,
-        }
-        if components is None:
-            components = ["image_encoder", "unet", "vae_decoder", "vae_encoder"]
-        for component in components:
-            if component == "unet":
-                self.model[component] = component_dict[component](add_positional_conv=add_positional_conv)
-                self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict, add_positional_conv=add_positional_conv), strict=False)
-            else:
-                self.model[component] = component_dict[component]()
-                self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
-            self.model[component].to(self.torch_dtype).to(self.device)
-            self.model_path[component] = file_path
-
-    def load_stable_diffusion(self, state_dict, components=None, file_path=""):
-        component_dict = {
-            "text_encoder": SDTextEncoder,
-            "unet": SDUNet,
-            "vae_decoder": SDVAEDecoder,
-            "vae_encoder": SDVAEEncoder,
-            "refiner": SDXLUNet,
-        }
-        if components is None:
-            components = ["text_encoder", "unet", "vae_decoder", "vae_encoder"]
-        for component in components:
-            if component == "text_encoder":
-                # Add additional token embeddings to text encoder
-                token_embeddings = [state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"]]
-                for keyword in self.textual_inversion_dict:
-                    _, embeddings = self.textual_inversion_dict[keyword]
-                    token_embeddings.append(embeddings.to(dtype=token_embeddings[0].dtype))
-                token_embeddings = torch.concat(token_embeddings, dim=0)
-                state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"] = token_embeddings
-                self.model[component] = component_dict[component](vocab_size=token_embeddings.shape[0])
-                self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
-                self.model[component].to(self.torch_dtype).to(self.device)
-            else:
-                self.model[component] = component_dict[component]()
-                self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
-                self.model[component].to(self.torch_dtype).to(self.device)
-            self.model_path[component] = file_path
-
-    def load_stable_diffusion_xl(self, state_dict, components=None, file_path=""):
-        component_dict = {
-            "text_encoder": SDXLTextEncoder,
-            "text_encoder_2": SDXLTextEncoder2,
-            "unet": SDXLUNet,
-            "vae_decoder": SDXLVAEDecoder,
-            "vae_encoder": SDXLVAEEncoder,
-        }
-        if components is None:
-            components = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder"]
-        for component in components:
-            self.model[component] = component_dict[component]()
-            self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
-            if component in ["vae_decoder", "vae_encoder"]:
-                # These two model will output nan when float16 is enabled.
-                # The precision problem happens in the last three resnet blocks.
-                # I do not know how to solve this problem.
-                self.model[component].to(torch.float32).to(self.device)
-            else:
-                self.model[component].to(self.torch_dtype).to(self.device)
-            self.model_path[component] = file_path
-
-    def load_controlnet(self, state_dict, file_path=""):
-        component = "controlnet"
-        if component not in self.model:
-            self.model[component] = []
-            self.model_path[component] = []
-        model = SDControlNet()
-        model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
-        model.to(self.torch_dtype).to(self.device)
-        self.model[component].append(model)
-        self.model_path[component].append(file_path)
-
-    def load_animatediff(self, state_dict, file_path=""):
-        component = "motion_modules"
-        model = SDMotionModel()
-        model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
-        model.to(self.torch_dtype).to(self.device)
-        self.model[component] = model
-        self.model_path[component] = file_path
-
-    def load_animatediff_xl(self, state_dict, file_path=""):
-        component = "motion_modules_xl"
-        model = SDXLMotionModel()
-        model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
-        model.to(self.torch_dtype).to(self.device)
-        self.model[component] = model
-        self.model_path[component] = file_path
-
-    def load_beautiful_prompt(self, state_dict, file_path=""):
-        component = "beautiful_prompt"
-        from transformers import AutoModelForCausalLM
-        model_folder = os.path.dirname(file_path)
-        model = AutoModelForCausalLM.from_pretrained(
-            model_folder, state_dict=state_dict, local_files_only=True, torch_dtype=self.torch_dtype
-        ).to(self.device).eval()
-        self.model[component] = model
-        self.model_path[component] = file_path
-
-    def load_RIFE(self, state_dict, file_path=""):
-        component = "RIFE"
-        from ..extensions.RIFE import IFNet
-        model = IFNet().eval()
-        model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
-        model.to(torch.float32).to(self.device)
-        self.model[component] = model
-        self.model_path[component] = file_path
-
-    def load_sd_lora(self, state_dict, alpha):
-        SDLoRA().add_lora_to_text_encoder(self.model["text_encoder"], state_dict, alpha=alpha, device=self.device)
-        SDLoRA().add_lora_to_unet(self.model["unet"], state_dict, alpha=alpha, device=self.device)
-
-    def load_translator(self, state_dict, file_path=""):
-        # This model is lightweight, we do not place it on GPU.
-        component = "translator"
-        from transformers import AutoModelForSeq2SeqLM
-        model_folder = os.path.dirname(file_path)
-        model = AutoModelForSeq2SeqLM.from_pretrained(model_folder).eval()
-        self.model[component] = model
-        self.model_path[component] = file_path
-
-    def load_ipadapter(self, state_dict, file_path=""):
-        component = "ipadapter"
-        model = SDIpAdapter()
-        model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
-        model.to(self.torch_dtype).to(self.device)
-        self.model[component] = model
-        self.model_path[component] = file_path
-
-    def load_ipadapter_image_encoder(self, state_dict, file_path=""):
-        component = "ipadapter_image_encoder"
-        model = IpAdapterCLIPImageEmbedder()
-        model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
-        model.to(self.torch_dtype).to(self.device)
-        self.model[component] = model
-        self.model_path[component] = file_path
-
-    def load_ipadapter_xl(self, state_dict, file_path=""):
-        component = "ipadapter_xl"
-        model = SDXLIpAdapter()
-        model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
-        model.to(self.torch_dtype).to(self.device)
-        self.model[component] = model
-        self.model_path[component] = file_path
-
-    def load_ipadapter_xl_image_encoder(self, state_dict, file_path=""):
-        component = "ipadapter_xl_image_encoder"
-        model = IpAdapterXLCLIPImageEmbedder()
-        model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
-        model.to(self.torch_dtype).to(self.device)
-        self.model[component] = model
-        self.model_path[component] = file_path
-
-    def load_hunyuan_dit_clip_text_encoder(self, state_dict, file_path=""):
-        component = "hunyuan_dit_clip_text_encoder"
-        model = HunyuanDiTCLIPTextEncoder()
-        model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
-        model.to(self.torch_dtype).to(self.device)
-        self.model[component] = model
-        self.model_path[component] = file_path
-
-    def load_hunyuan_dit_t5_text_encoder(self, state_dict, file_path=""):
-        component = "hunyuan_dit_t5_text_encoder"
-        model = HunyuanDiTT5TextEncoder()
-        model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
-        model.to(self.torch_dtype).to(self.device)
-        self.model[component] = model
-        self.model_path[component] = file_path
-
-    def load_hunyuan_dit(self, state_dict, file_path=""):
-        component = "hunyuan_dit"
-        model = HunyuanDiT()
-        model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
-        model.to(self.torch_dtype).to(self.device)
-        self.model[component] = model
-        self.model_path[component] = file_path
-
-    def load_diffusers_vae(self, state_dict, file_path=""):
-        # TODO: detect SD and SDXL
-        component = "vae_encoder"
-        model = SDXLVAEEncoder()
-        model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
-        model.to(self.torch_dtype).to(self.device)
-        self.model[component] = model
-        self.model_path[component] = file_path
-        component = "vae_decoder"
-        model = SDXLVAEDecoder()
-        model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
-        model.to(self.torch_dtype).to(self.device)
-        self.model[component] = model
-        self.model_path[component] = file_path
-
-    def load_ExVideo_StableVideoDiffusion(self, state_dict, file_path=""):
-        unet_state_dict = self.model["unet"].state_dict()
-        self.model["unet"].to("cpu")
-        del self.model["unet"]
-        add_positional_conv = state_dict["blocks.185.positional_embedding.embeddings"].shape[0]
-        self.model["unet"] = SVDUNet(add_positional_conv=add_positional_conv)
-        self.model["unet"].load_state_dict(unet_state_dict, strict=False)
-        self.model["unet"].load_state_dict(state_dict, strict=False)
-        self.model["unet"].to(self.torch_dtype).to(self.device)
-
-    def search_for_embeddings(self, state_dict):
-        embeddings = []
-        for k in state_dict:
-            if isinstance(state_dict[k], torch.Tensor):
-                embeddings.append(state_dict[k])
-            elif isinstance(state_dict[k], dict):
-                embeddings += self.search_for_embeddings(state_dict[k])
-        return embeddings
-
-    def load_textual_inversions(self, folder):
-        # Store additional tokens here
-        self.textual_inversion_dict = {}
-
-        # Load every textual inversion file
-        for file_name in os.listdir(folder):
-            if file_name.endswith(".txt"):
-                continue
-            keyword = os.path.splitext(file_name)[0]
-            state_dict = load_state_dict(os.path.join(folder, file_name))
-
-            # Search for embeddings
-            for embeddings in self.search_for_embeddings(state_dict):
-                if len(embeddings.shape) == 2 and embeddings.shape[1] == 768:
-                    tokens = [f"{keyword}_{i}" for i in range(embeddings.shape[0])]
-                    self.textual_inversion_dict[keyword] = (tokens, embeddings)
-                    break
-
-    def load_model(self, file_path, components=None, lora_alphas=[]):
-        state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype)
-        if self.is_stable_video_diffusion(state_dict):
-            self.load_stable_video_diffusion(state_dict, file_path=file_path)
-        elif self.is_animatediff(state_dict):
-            self.load_animatediff(state_dict, file_path=file_path)
-        elif self.is_animatediff_xl(state_dict):
-            self.load_animatediff_xl(state_dict, file_path=file_path)
-        elif self.is_controlnet(state_dict):
-            self.load_controlnet(state_dict, file_path=file_path)
-        elif self.is_stabe_diffusion_xl(state_dict):
-            self.load_stable_diffusion_xl(state_dict, components=components, file_path=file_path)
-        elif self.is_stable_diffusion(state_dict):
-            self.load_stable_diffusion(state_dict, components=components, file_path=file_path)
-        elif self.is_sd_lora(state_dict):
-            self.load_sd_lora(state_dict, alpha=lora_alphas.pop(0))
-        elif self.is_beautiful_prompt(state_dict):
-            self.load_beautiful_prompt(state_dict, file_path=file_path)
-        elif self.is_RIFE(state_dict):
-            self.load_RIFE(state_dict, file_path=file_path)
-        elif self.is_translator(state_dict):
-            self.load_translator(state_dict, file_path=file_path)
-        elif self.is_ipadapter(state_dict):
-            self.load_ipadapter(state_dict, file_path=file_path)
-        elif self.is_ipadapter_image_encoder(state_dict):
-            self.load_ipadapter_image_encoder(state_dict, file_path=file_path)
-        elif self.is_ipadapter_xl(state_dict):
-            self.load_ipadapter_xl(state_dict, file_path=file_path)
-        elif self.is_ipadapter_xl_image_encoder(state_dict):
-            self.load_ipadapter_xl_image_encoder(state_dict, file_path=file_path)
-        elif self.is_hunyuan_dit_clip_text_encoder(state_dict):
-            self.load_hunyuan_dit_clip_text_encoder(state_dict, file_path=file_path)
-        elif self.is_hunyuan_dit_t5_text_encoder(state_dict):
-            self.load_hunyuan_dit_t5_text_encoder(state_dict, file_path=file_path)
-        elif self.is_hunyuan_dit(state_dict):
-            self.load_hunyuan_dit(state_dict, file_path=file_path)
-        elif self.is_diffusers_vae(state_dict):
-            self.load_diffusers_vae(state_dict, file_path=file_path)
-        elif self.is_ExVideo_StableVideoDiffusion(state_dict):
-            self.load_ExVideo_StableVideoDiffusion(state_dict, file_path=file_path)
-
-    def load_models(self, file_path_list, lora_alphas=[]):
-        for file_path in file_path_list:
-            self.load_model(file_path, lora_alphas=lora_alphas)
-
-    def to(self, device):
-        for component in self.model:
-            if isinstance(self.model[component], list):
-                for model in self.model[component]:
-                    model.to(device)
-            else:
-                self.model[component].to(device)
-        torch.cuda.empty_cache()
-
-    def get_model_with_model_path(self, model_path):
-        for component in self.model_path:
-            if isinstance(self.model_path[component], str):
-                if os.path.samefile(self.model_path[component], model_path):
-                    return self.model[component]
-            elif isinstance(self.model_path[component], list):
-                for i, model_path_ in enumerate(self.model_path[component]):
-                    if os.path.samefile(model_path_, model_path):
-                        return self.model[component][i]
-        raise ValueError(f"Please load model {model_path} before you use it.")
-
-    def __getattr__(self, __name):
-        if __name in self.model:
-            return self.model[__name]
-        else:
-            return super.__getattribute__(__name)
-
-
-def load_state_dict(file_path, torch_dtype=None):
-    if file_path.endswith(".safetensors"):
-        return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
-    else:
-        return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
-
-
-def load_state_dict_from_safetensors(file_path, torch_dtype=None):
-    state_dict = {}
-    with safe_open(file_path, framework="pt", device="cpu") as f:
-        for k in f.keys():
-            state_dict[k] = f.get_tensor(k)
-            if torch_dtype is not None:
-                state_dict[k] = state_dict[k].to(torch_dtype)
-    return state_dict
-
-
-def load_state_dict_from_bin(file_path, torch_dtype=None):
-    state_dict = torch.load(file_path, map_location="cpu")
-    if torch_dtype is not None:
-        for i in state_dict:
-            if isinstance(state_dict[i], torch.Tensor):
-                state_dict[i] = state_dict[i].to(torch_dtype)
-    return state_dict
-
-
-def search_parameter(param, state_dict):
-    for name, param_ in state_dict.items():
-        if param.numel() == param_.numel():
-            if param.shape == param_.shape:
-                if torch.dist(param, param_) < 1e-6:
-                    return name
-            else:
-                if torch.dist(param.flatten(), param_.flatten()) < 1e-6:
-                    return name
-    return None
-
-
-def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
-    matched_keys = set()
-    with torch.no_grad():
-        for name in source_state_dict:
-            rename = search_parameter(source_state_dict[name], target_state_dict)
-            if rename is not None:
-                print(f'"{name}": "{rename}",')
-                matched_keys.add(rename)
-            elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
-                length = source_state_dict[name].shape[0] // 3
-                rename = []
-                for i in range(3):
-                    rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
-                if None not in rename:
-                    print(f'"{name}": {rename},')
-                    for rename_ in rename:
-                        matched_keys.add(rename_)
-    for name in target_state_dict:
-        if name not in matched_keys:
-            print("Cannot find", name, target_state_dict[name].shape)
+from .model_manager import *
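The removed `load_model` identifies a checkpoint by probing its state dict for one characteristic parameter name and then calls the matching loader. The long `if`/`elif` chain can be read as a lookup over a table of signature keys. The sketch below is one possible way to express that idea, not the code that replaced it in `model_manager`: `SIGNATURE_KEYS` and `detect_model_type` are hypothetical names, and only three of the keys from the `is_*` checks above are included.

```python
# Hypothetical registry: (model type, parameter name that identifies it).
# The key strings are copied from the is_* checks in the removed code;
# the table itself and the function below are illustrative only.
SIGNATURE_KEYS = [
    ("stable_video_diffusion", "model.diffusion_model.output_blocks.9.1.time_stack.0.norm_in.weight"),
    ("animatediff", "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"),
    ("hunyuan_dit", "final_layer.adaLN_modulation.1.weight"),
]


def detect_model_type(state_dict):
    """Return the first model type whose signature key is present, else None."""
    # Order matters, as in the original if/elif chain: the first match wins.
    for model_type, key in SIGNATURE_KEYS:
        if key in state_dict:
            return model_type
    return None


detected = detect_model_type({"final_layer.adaLN_modulation.1.weight": None})
unknown = detect_model_type({"some.other.weight": None})
```

A table like this keeps the detection order explicit, which matters because several checkpoints share parameter names and the first match decides the loader.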
66
diffsynth/models/downloader.py
Normal file
@@ -0,0 +1,66 @@
+from huggingface_hub import hf_hub_download
+from modelscope import snapshot_download
+import os, shutil
+from typing_extensions import Literal, TypeAlias
+from typing import List
+from ..configs.model_config import preset_models_on_huggingface, preset_models_on_modelscope, Preset_model_id
+
+
+def download_from_modelscope(model_id, origin_file_path, local_dir):
+    os.makedirs(local_dir, exist_ok=True)
+    if os.path.basename(origin_file_path) in os.listdir(local_dir):
+        print(f" {os.path.basename(origin_file_path)} has been already in {local_dir}.")
+        return
+    else:
+        print(f" Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
+    snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
+    downloaded_file_path = os.path.join(local_dir, origin_file_path)
+    target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
+    if downloaded_file_path != target_file_path:
+        shutil.move(downloaded_file_path, target_file_path)
+        shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
+
+
+def download_from_huggingface(model_id, origin_file_path, local_dir):
+    os.makedirs(local_dir, exist_ok=True)
+    if os.path.basename(origin_file_path) in os.listdir(local_dir):
+        print(f" {os.path.basename(origin_file_path)} has been already in {local_dir}.")
+        return
+    else:
+        print(f" Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
+    hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
+
+
+Preset_model_website: TypeAlias = Literal[
+    "HuggingFace",
+    "ModelScope",
+]
+website_to_preset_models = {
+    "HuggingFace": preset_models_on_huggingface,
+    "ModelScope": preset_models_on_modelscope,
+}
+website_to_download_fn = {
+    "HuggingFace": download_from_huggingface,
+    "ModelScope": download_from_modelscope,
+}
+
+
+def download_models(
+    model_id_list: List[Preset_model_id] = [],
+    downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
+):
+    print(f"Downloading models: {model_id_list}")
+    downloaded_files = []
+    for model_id in model_id_list:
+        for website in downloading_priority:
+            if model_id in website_to_preset_models[website]:
+                for model_id, origin_file_path, local_dir in website_to_preset_models[website][model_id]:
+                    # Check if the file is downloaded.
+                    file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
+                    if file_to_download in downloaded_files:
+                        continue
+                    # Download
+                    website_to_download_fn[website](model_id, origin_file_path, local_dir)
+                    if os.path.basename(origin_file_path) in os.listdir(local_dir):
+                        downloaded_files.append(file_to_download)
+    return downloaded_files
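The priority logic in `download_models` above tries each website in order and skips any file that an earlier website has already provided. The sketch below isolates that ordering without touching the network. It is a simplification under stated assumptions: `PRESETS`, `plan_downloads`, the site names, and the repository ids are all hypothetical, and every download is assumed to succeed, whereas the original only records a file after confirming it exists on disk.

```python
# Hypothetical preset tables: model id -> list of (repo id, file path, local dir).
PRESETS = {
    "SiteA": {"model-x": [("org/model-x", "weights/x.safetensors", "models/x")]},
    "SiteB": {
        "model-x": [("mirror/model-x", "x.safetensors", "models/x")],
        "model-y": [("mirror/model-y", "y.safetensors", "models/y")],
    },
}


def plan_downloads(model_ids, priority):
    """List (website, repo id, target path) once per target file, in priority order."""
    planned, seen = [], set()
    for model_id in model_ids:
        for website in priority:
            for repo_id, file_path, local_dir in PRESETS[website].get(model_id, []):
                # Files are flattened into local_dir, so only the basename counts.
                target = f"{local_dir}/{file_path.rsplit('/', 1)[-1]}"
                if target in seen:
                    continue  # an earlier website already covers this file
                seen.add(target)
                planned.append((website, repo_id, target))
    return planned


plan = plan_downloads(["model-x", "model-y"], priority=["SiteA", "SiteB"])
```

With this input, `model-x` is taken from `SiteA` and its `SiteB` mirror is skipped, while `model-y` falls through to `SiteB` because `SiteA` does not list it.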
@@ -1,5 +1,4 @@
 from .attention import Attention
-from .tiler import TileWorker
 from einops import repeat, rearrange
 import math
 import torch
@@ -399,7 +398,8 @@ class HunyuanDiT(torch.nn.Module):
         hidden_states, _ = hidden_states.chunk(2, dim=1)
         return hidden_states

-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return HunyuanDiTStateDictConverter()


@@ -79,7 +79,8 @@ class HunyuanDiTCLIPTextEncoder(BertModel):
             prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
         return prompt_emb

-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return HunyuanDiTCLIPTextEncoderStateDictConverter()


@@ -131,7 +132,8 @@ class HunyuanDiTT5TextEncoder(T5EncoderModel):
             prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
         return prompt_emb

-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return HunyuanDiTT5TextEncoderStateDictConverter()


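The hunks above turn `state_dict_converter` from an instance method into a `@staticmethod`. A likely motivation, visible in the new `lora.py` in this same comparison, is that the converter is requested through the class (`model.__class__.state_dict_converter()`), which would raise a `TypeError` if the method still expected `self`. The sketch below shows the calling pattern with hypothetical stand-ins: `ToyConverter` and `ToyModel` are not classes from the repository.

```python
class ToyConverter:
    """Hypothetical converter that renames checkpoint keys."""

    def from_civitai(self, state_dict):
        # Strip a "model." prefix from every key, as a stand-in for real renaming.
        return {key.replace("model.", ""): value for key, value in state_dict.items()}


class ToyModel:
    @staticmethod
    def state_dict_converter():
        # No `self` parameter: callable on the class and on instances alike.
        return ToyConverter()


# The converter is available before any (possibly expensive) model is built...
via_class = ToyModel.state_dict_converter().from_civitai({"model.w": 1})
# ...and existing instance-style call sites keep working unchanged.
via_instance = ToyModel().state_dict_converter().from_civitai({"model.w": 1})
```

Both calls return the same renamed dictionary, so callers that hold only a class can convert a checkpoint without instantiating the model first.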
1363
diffsynth/models/kolors_text_encoder.py
Normal file
File diff suppressed because it is too large
195
diffsynth/models/lora.py
Normal file
@@ -0,0 +1,195 @@
|
||||
import torch
|
||||
from .sd_unet import SDUNet
|
||||
from .sdxl_unet import SDXLUNet
|
||||
from .sd_text_encoder import SDTextEncoder
|
||||
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
||||
from .sd3_dit import SD3DiT
|
||||
from .hunyuan_dit import HunyuanDiT
|
||||
|
||||
|
||||
|
||||
class LoRAFromCivitai:
|
||||
def __init__(self):
|
||||
self.supported_model_classes = []
|
||||
self.lora_prefix = []
|
||||
self.renamed_lora_prefix = {}
|
||||
self.special_keys = {}
|
||||
|
||||
|
||||
def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
|
||||
renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "")
|
||||
state_dict_ = {}
|
||||
for key in state_dict:
|
||||
if ".lora_up" not in key:
|
||||
continue
|
||||
if not key.startswith(lora_prefix):
|
||||
continue
|
||||
weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
|
||||
weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
|
||||
if len(weight_up.shape) == 4:
|
||||
weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
|
||||
weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down)
|
||||
target_name = key.split(".")[0].replace(lora_prefix, renamed_lora_prefix).replace("_", ".") + ".weight"
|
||||
for special_key in self.special_keys:
|
||||
target_name = target_name.replace(special_key, self.special_keys[special_key])
|
||||
state_dict_[target_name] = lora_weight.cpu()
|
||||
return state_dict_
|
||||
|
||||
|
||||
def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None):
|
||||
state_dict_model = model.state_dict()
|
||||
state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=alpha)
|
||||
if model_resource == "diffusers":
|
||||
state_dict_lora = model.__class__.state_dict_converter().from_diffusers(state_dict_lora)
|
||||
elif model_resource == "civitai":
|
||||
state_dict_lora = model.__class__.state_dict_converter().from_civitai(state_dict_lora)
|
||||
if len(state_dict_lora) > 0:
|
||||
print(f" {len(state_dict_lora)} tensors are updated.")
|
||||
for name in state_dict_lora:
|
||||
state_dict_model[name] += state_dict_lora[name].to(
|
||||
dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
|
||||
model.load_state_dict(state_dict_model)
|
||||
|
||||
|
||||
def match(self, model, state_dict_lora):
|
||||
for lora_prefix, model_class in zip(self.lora_prefix, self.supported_model_classes):
|
||||
if not isinstance(model, model_class):
|
||||
continue
|
||||
state_dict_model = model.state_dict()
|
||||
for model_resource in ["diffusers", "civitai"]:
|
||||
try:
|
||||
state_dict_lora_ = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=1.0)
|
||||
converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == "diffusers" \
|
||||
else model.__class__.state_dict_converter().from_civitai
|
||||
state_dict_lora_ = converter_fn(state_dict_lora_)
|
||||
if len(state_dict_lora_) == 0:
|
||||
continue
|
||||
for name in state_dict_lora_:
|
||||
if name not in state_dict_model:
|
||||
break
|
||||
else:
|
||||
return lora_prefix, model_resource
|
||||
except:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
|
||||
class SDLoRAFromCivitai(LoRAFromCivitai):
    def __init__(self):
        super().__init__()
        self.supported_model_classes = [SDUNet, SDTextEncoder]
        self.lora_prefix = ["lora_unet_", "lora_te_"]
        self.special_keys = {
            "down.blocks": "down_blocks",
            "up.blocks": "up_blocks",
            "mid.block": "mid_block",
            "proj.in": "proj_in",
            "proj.out": "proj_out",
            "transformer.blocks": "transformer_blocks",
            "to.q": "to_q",
            "to.k": "to_k",
            "to.v": "to_v",
            "to.out": "to_out",
            "text.model": "text_model",
            "self.attn.q.proj": "self_attn.q_proj",
            "self.attn.k.proj": "self_attn.k_proj",
            "self.attn.v.proj": "self_attn.v_proj",
            "self.attn.out.proj": "self_attn.out_proj",
            "input.blocks": "model.diffusion_model.input_blocks",
            "middle.block": "model.diffusion_model.middle_block",
            "output.blocks": "model.diffusion_model.output_blocks",
        }


class SDXLLoRAFromCivitai(LoRAFromCivitai):
    def __init__(self):
        super().__init__()
        self.supported_model_classes = [SDXLUNet, SDXLTextEncoder, SDXLTextEncoder2]
        self.lora_prefix = ["lora_unet_", "lora_te1_", "lora_te2_"]
        self.renamed_lora_prefix = {"lora_te2_": "2"}
        self.special_keys = {
            "down.blocks": "down_blocks",
            "up.blocks": "up_blocks",
            "mid.block": "mid_block",
            "proj.in": "proj_in",
            "proj.out": "proj_out",
            "transformer.blocks": "transformer_blocks",
            "to.q": "to_q",
            "to.k": "to_k",
            "to.v": "to_v",
            "to.out": "to_out",
            "text.model": "conditioner.embedders.0.transformer.text_model",
            "self.attn.q.proj": "self_attn.q_proj",
            "self.attn.k.proj": "self_attn.k_proj",
            "self.attn.v.proj": "self_attn.v_proj",
            "self.attn.out.proj": "self_attn.out_proj",
            "input.blocks": "model.diffusion_model.input_blocks",
            "middle.block": "model.diffusion_model.middle_block",
            "output.blocks": "model.diffusion_model.output_blocks",
            "2conditioner.embedders.0.transformer.text_model.encoder.layers": "text_model.encoder.layers"
        }

class GeneralLoRAFromPeft:
    def __init__(self):
        self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT]


    def convert_state_dict(self, state_dict, alpha=1.0, device="cuda", torch_dtype=torch.float16):
        state_dict_ = {}
        for key in state_dict:
            # Each ".lora_B." (up) weight is paired with its ".lora_A." (down) weight.
            if ".lora_B." not in key:
                continue
            weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
            weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
            if len(weight_up.shape) == 4:
                # 1x1 convolution weights: drop the spatial dims, multiply, then restore them.
                weight_up = weight_up.squeeze(3).squeeze(2)
                weight_down = weight_down.squeeze(3).squeeze(2)
                lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
            else:
                lora_weight = alpha * torch.mm(weight_up, weight_down)
            # Remove "lora_B" and the adapter name that follows it to recover the target key.
            keys = key.split(".")
            keys.pop(keys.index("lora_B") + 1)
            keys.pop(keys.index("lora_B"))
            target_name = ".".join(keys)
            state_dict_[target_name] = lora_weight.cpu()
        return state_dict_


    def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
        state_dict_model = model.state_dict()
        # Use the dtype and device of the first parameter for the LoRA computation.
        for name, param in state_dict_model.items():
            torch_dtype = param.dtype
            device = param.device
            break
        state_dict_lora = self.convert_state_dict(state_dict_lora, alpha=alpha, device=device, torch_dtype=torch_dtype)
        if len(state_dict_lora) > 0:
            print(f" {len(state_dict_lora)} tensors are updated.")
            for name in state_dict_lora:
                state_dict_model[name] += state_dict_lora[name].to(
                    dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
            model.load_state_dict(state_dict_model)


    def match(self, model, state_dict_lora):
        for model_class in self.supported_model_classes:
            if not isinstance(model, model_class):
                continue
            state_dict_model = model.state_dict()
            try:
                state_dict_lora_ = self.convert_state_dict(state_dict_lora, alpha=1.0)
                if len(state_dict_lora_) == 0:
                    continue
                for name in state_dict_lora_:
                    if name not in state_dict_model:
                        break
                else:
                    return "", ""
            except Exception:
                # Conversion failed for this model class; try the next one.
                pass
        return None

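# Illustrative sketch, not part of the upstream file: the key renaming and low-rank merge
# done by GeneralLoRAFromPeft.convert_state_dict, repeated on plain Python lists so that it
# runs without torch. The helper names (peft_target_name, matmul, merged_delta) and the
# example key are hypothetical and exist only for this demonstration.

```python
def peft_target_name(key):
    """Drop "lora_B" and the adapter name that follows it from a PEFT-style key."""
    parts = key.split(".")
    idx = parts.index("lora_B")
    parts.pop(idx + 1)  # adapter name, e.g. "default"
    parts.pop(idx)      # the "lora_B" marker itself
    return ".".join(parts)


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merged_delta(weight_up, weight_down, alpha=1.0):
    """Return alpha * (up @ down), the update that is added to the base weight."""
    return [[alpha * v for v in row] for row in matmul(weight_up, weight_down)]


# A rank-1 update for a 2x2 weight: up is 2x1, down is 1x2.
up = [[1.0], [2.0]]
down = [[3.0, 4.0]]
delta = merged_delta(up, down, alpha=0.5)
target = peft_target_name("blocks.0.attn.to_q.lora_B.default.weight")
```

# Under these assumptions `target` names the base weight that receives `delta`; the real
# method performs the same product with torch.mm on tensors.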
536
diffsynth/models/model_manager.py
Normal file
@@ -0,0 +1,536 @@
import os, torch, hashlib, json, importlib
from safetensors import safe_open
from torch import Tensor
from typing_extensions import Literal, TypeAlias
from typing import List

from .downloader import download_models, Preset_model_id, Preset_model_website

from .sd_text_encoder import SDTextEncoder
from .sd_unet import SDUNet
from .sd_vae_encoder import SDVAEEncoder
from .sd_vae_decoder import SDVAEDecoder
from .lora import SDLoRAFromCivitai, SDXLLoRAFromCivitai, GeneralLoRAFromPeft

from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from .sdxl_unet import SDXLUNet
from .sdxl_vae_decoder import SDXLVAEDecoder
from .sdxl_vae_encoder import SDXLVAEEncoder

from .sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
from .sd3_dit import SD3DiT
from .sd3_vae_decoder import SD3VAEDecoder
from .sd3_vae_encoder import SD3VAEEncoder

from .sd_controlnet import SDControlNet

from .sd_motion import SDMotionModel
from .sdxl_motion import SDXLMotionModel

from .svd_image_encoder import SVDImageEncoder
from .svd_unet import SVDUNet
from .svd_vae_decoder import SVDVAEDecoder
from .svd_vae_encoder import SVDVAEEncoder

from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder

from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from .hunyuan_dit import HunyuanDiT

from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs


def load_state_dict(file_path, torch_dtype=None):
    if file_path.endswith(".safetensors"):
        return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
    else:
        return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)


def load_state_dict_from_safetensors(file_path, torch_dtype=None):
    state_dict = {}
    with safe_open(file_path, framework="pt", device="cpu") as f:
        for k in f.keys():
            state_dict[k] = f.get_tensor(k)
            if torch_dtype is not None:
                state_dict[k] = state_dict[k].to(torch_dtype)
    return state_dict


def load_state_dict_from_bin(file_path, torch_dtype=None):
    state_dict = torch.load(file_path, map_location="cpu")
    if torch_dtype is not None:
        for i in state_dict:
            if isinstance(state_dict[i], torch.Tensor):
                state_dict[i] = state_dict[i].to(torch_dtype)
    return state_dict


def search_for_embeddings(state_dict):
    embeddings = []
    for k in state_dict:
        if isinstance(state_dict[k], torch.Tensor):
            embeddings.append(state_dict[k])
        elif isinstance(state_dict[k], dict):
            embeddings += search_for_embeddings(state_dict[k])
    return embeddings


def search_parameter(param, state_dict):
    for name, param_ in state_dict.items():
        if param.numel() == param_.numel():
            if param.shape == param_.shape:
                if torch.dist(param, param_) < 1e-6:
                    return name
            else:
                if torch.dist(param.flatten(), param_.flatten()) < 1e-6:
                    return name
    return None

def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
    matched_keys = set()
    with torch.no_grad():
        for name in source_state_dict:
            rename = search_parameter(source_state_dict[name], target_state_dict)
            if rename is not None:
                print(f'"{name}": "{rename}",')
                matched_keys.add(rename)
            elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
                length = source_state_dict[name].shape[0] // 3
                rename = []
                for i in range(3):
                    rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
                if None not in rename:
                    print(f'"{name}": {rename},')
                    for rename_ in rename:
                        matched_keys.add(rename_)
    for name in target_state_dict:
        if name not in matched_keys:
            print("Cannot find", name, target_state_dict[name].shape)


def search_for_files(folder, extensions):
    files = []
    if os.path.isdir(folder):
        for file in sorted(os.listdir(folder)):
            files += search_for_files(os.path.join(folder, file), extensions)
    elif os.path.isfile(folder):
        for extension in extensions:
            if folder.endswith(extension):
                files.append(folder)
                break
    return files

def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
    keys = []
    for key, value in state_dict.items():
        if isinstance(key, str):
            if isinstance(value, Tensor):
                if with_shape:
                    shape = "_".join(map(str, list(value.shape)))
                    keys.append(key + ":" + shape)
                keys.append(key)
            elif isinstance(value, dict):
                keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
    keys.sort()
    keys_str = ",".join(keys)
    return keys_str


def split_state_dict_with_prefix(state_dict):
    keys = sorted([key for key in state_dict if isinstance(key, str)])
    prefix_dict = {}
    for key in keys:
        prefix = key if "." not in key else key.split(".")[0]
        if prefix not in prefix_dict:
            prefix_dict[prefix] = []
        prefix_dict[prefix].append(key)
    state_dicts = []
    for prefix, keys in prefix_dict.items():
        sub_state_dict = {key: state_dict[key] for key in keys}
        state_dicts.append(sub_state_dict)
    return state_dicts

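# Illustrative sketch, not part of the upstream file: how split_state_dict_with_prefix groups
# keys by the text before the first dot. The helper name split_by_prefix and the toy keys are
# hypothetical; plain integers stand in for tensors so the snippet runs without torch.

```python
def split_by_prefix(state_dict):
    """Group entries by the first dotted component of each string key."""
    groups = {}
    for key in sorted(k for k in state_dict if isinstance(k, str)):
        prefix = key.split(".")[0]  # a key without a dot is its own prefix
        groups.setdefault(prefix, {})[key] = state_dict[key]
    return list(groups.values())


toy_state_dict = {"vae.enc.w": 1, "unet.a": 2, "vae.dec.w": 3, "bias": 4}
parts = split_by_prefix(toy_state_dict)
```

# Each element of `parts` is a sub-dictionary that a detector can hash and match on its own,
# which is what ModelDetectorFromSplitedSingleFile relies on.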
def hash_state_dict_keys(state_dict, with_shape=True):
    keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
    keys_str = keys_str.encode(encoding="UTF-8")
    return hashlib.md5(keys_str).hexdigest()

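# Illustrative sketch, not part of the upstream file: the key-fingerprint idea behind
# hash_state_dict_keys, with shape tuples standing in for tensors so it runs without torch.
# The names keys_to_single_str and hash_keys are hypothetical.

```python
import hashlib


def keys_to_single_str(shapes, with_shape=True):
    """Serialize key names (and optionally shapes) into one sorted, comma-joined string."""
    keys = []
    for key, value in shapes.items():
        if not isinstance(key, str):
            continue
        if isinstance(value, tuple):
            if with_shape:
                keys.append(key + ":" + "_".join(map(str, value)))
            keys.append(key)
        elif isinstance(value, dict):
            keys.append(key + "|" + keys_to_single_str(value, with_shape))
    keys.sort()
    return ",".join(keys)


def hash_keys(shapes, with_shape=True):
    """MD5 fingerprint of the serialized keys."""
    return hashlib.md5(keys_to_single_str(shapes, with_shape).encode("utf-8")).hexdigest()


a = {"w": (4, 4), "b": (4,)}
same_keys_other_order = {"b": (4,), "w": (4, 4)}
same_keys_other_shape = {"w": (8, 4), "b": (4,)}
```

# Because the keys are sorted before hashing, insertion order does not matter. The shape-aware
# hash separates checkpoints that share key names but differ in tensor sizes, while the
# shape-free hash treats them as the same architecture family.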
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
    loaded_model_names, loaded_models = [], []
    for model_name, model_class in zip(model_names, model_classes):
        print(f" model_name: {model_name} model_class: {model_class.__name__}")
        state_dict_converter = model_class.state_dict_converter()
        if model_resource == "civitai":
            state_dict_results = state_dict_converter.from_civitai(state_dict)
        elif model_resource == "diffusers":
            state_dict_results = state_dict_converter.from_diffusers(state_dict)
        else:
            # Fail with a clear message instead of an unbound-variable error below.
            raise ValueError(f"Unsupported model_resource: {model_resource!r}")
        if isinstance(state_dict_results, tuple):
            model_state_dict, extra_kwargs = state_dict_results
            print(f" This model is initialized with extra kwargs: {extra_kwargs}")
        else:
            model_state_dict, extra_kwargs = state_dict_results, {}
        torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
        model = model_class(**extra_kwargs).to(dtype=torch_dtype, device=device)
        model.load_state_dict(model_state_dict)
        loaded_model_names.append(model_name)
        loaded_models.append(model)
    return loaded_model_names, loaded_models


def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
    loaded_model_names, loaded_models = [], []
    for model_name, model_class in zip(model_names, model_classes):
        model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
        if torch_dtype == torch.float16 and hasattr(model, "half"):
            model = model.half()
        model = model.to(device=device)
        loaded_model_names.append(model_name)
        loaded_models.append(model)
    return loaded_model_names, loaded_models


def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
    print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
    base_state_dict = base_model.state_dict()
    base_model.to("cpu")
    del base_model
    model = model_class(**extra_kwargs)
    # Load the base weights first, then overlay the patch weights on top of them.
    model.load_state_dict(base_state_dict, strict=False)
    model.load_state_dict(state_dict, strict=False)
    model.to(dtype=torch_dtype, device=device)
    return model


def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
    loaded_model_names, loaded_models = [], []
    for model_name, model_class in zip(model_names, model_classes):
        # Patch every loaded base model with this name; each patched base model is
        # removed from the manager, so the scan restarts until none is left.
        while True:
            for model_id in range(len(model_manager.model)):
                base_model_name = model_manager.model_name[model_id]
                if base_model_name == model_name:
                    base_model_path = model_manager.model_path[model_id]
                    base_model = model_manager.model[model_id]
                    print(f" Adding patch model to {base_model_name} ({base_model_path})")
                    patched_model = load_single_patch_model_from_single_file(
                        state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
                    loaded_model_names.append(base_model_name)
                    loaded_models.append(patched_model)
                    model_manager.model.pop(model_id)
                    model_manager.model_path.pop(model_id)
                    model_manager.model_name.pop(model_id)
                    break
            else:
                break
    return loaded_model_names, loaded_models

class ModelDetectorTemplate:
    def __init__(self):
        pass

    def match(self, file_path="", state_dict={}):
        return False

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        return [], []



class ModelDetectorFromSingleFile:
    def __init__(self, model_loader_configs=[]):
        self.keys_hash_with_shape_dict = {}
        self.keys_hash_dict = {}
        for metadata in model_loader_configs:
            self.add_model_metadata(*metadata)


    def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
        self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
        if keys_hash is not None:
            self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)


    def match(self, file_path="", state_dict={}):
        if os.path.isdir(file_path):
            return False
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            return True
        keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
        if keys_hash in self.keys_hash_dict:
            return True
        return False


    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)

        # Load models with strict matching
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
            loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
            return loaded_model_names, loaded_models

        # Load models without strict matching
        # (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
        keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
        if keys_hash in self.keys_hash_dict:
            model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
            loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
            return loaded_model_names, loaded_models

        # Neither hash is registered: nothing was loaded.
        return [], []

class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
    def __init__(self, model_loader_configs=[]):
        super().__init__(model_loader_configs)


    def match(self, file_path="", state_dict={}):
        if os.path.isdir(file_path):
            return False
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        splited_state_dict = split_state_dict_with_prefix(state_dict)
        for sub_state_dict in splited_state_dict:
            if super().match(file_path, sub_state_dict):
                return True
        return False


    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        # Split the state_dict and load from each component
        splited_state_dict = split_state_dict_with_prefix(state_dict)
        valid_state_dict = {}
        for sub_state_dict in splited_state_dict:
            if super().match(file_path, sub_state_dict):
                valid_state_dict.update(sub_state_dict)
        if super().match(file_path, valid_state_dict):
            loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
        else:
            loaded_model_names, loaded_models = [], []
            for sub_state_dict in splited_state_dict:
                if super().match(file_path, sub_state_dict):
                    # Load the component that matched, not the merged dict that failed to match above.
                    loaded_model_names_, loaded_models_ = super().load(file_path, sub_state_dict, device, torch_dtype)
                    loaded_model_names += loaded_model_names_
                    loaded_models += loaded_models_
        return loaded_model_names, loaded_models

class ModelDetectorFromHuggingfaceFolder:
    def __init__(self, model_loader_configs=[]):
        self.architecture_dict = {}
        for metadata in model_loader_configs:
            self.add_model_metadata(*metadata)


    def add_model_metadata(self, architecture, huggingface_lib, model_name):
        self.architecture_dict[architecture] = (huggingface_lib, model_name)


    def match(self, file_path="", state_dict={}):
        if os.path.isfile(file_path):
            return False
        file_list = os.listdir(file_path)
        if "config.json" not in file_list:
            return False
        with open(os.path.join(file_path, "config.json"), "r") as f:
            config = json.load(f)
        if "architectures" not in config:
            return False
        return True


    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        with open(os.path.join(file_path, "config.json"), "r") as f:
            config = json.load(f)
        loaded_model_names, loaded_models = [], []
        for architecture in config["architectures"]:
            huggingface_lib, model_name = self.architecture_dict[architecture]
            # Resolve the model class by name from the configured library.
            model_class = getattr(importlib.import_module(huggingface_lib), architecture)
            loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
            loaded_model_names += loaded_model_names_
            loaded_models += loaded_models_
        return loaded_model_names, loaded_models

class ModelDetectorFromPatchedSingleFile:
    def __init__(self, model_loader_configs=[]):
        self.keys_hash_with_shape_dict = {}
        for metadata in model_loader_configs:
            self.add_model_metadata(*metadata)


    def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
        self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)


    def match(self, file_path="", state_dict={}):
        if os.path.isdir(file_path):
            return False
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            return True
        return False


    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)

        # Load models with strict matching
        loaded_model_names, loaded_models = [], []
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
            loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
                state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
            loaded_model_names += loaded_model_names_
            loaded_models += loaded_models_
        return loaded_model_names, loaded_models

class ModelManager:
    def __init__(
        self,
        torch_dtype=torch.float16,
        device="cuda",
        model_id_list: List[Preset_model_id] = [],
        downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
        file_path_list: List[str] = [],
    ):
        self.torch_dtype = torch_dtype
        self.device = device
        self.model = []
        self.model_path = []
        self.model_name = []
        downloaded_files = download_models(model_id_list, downloading_priority) if len(model_id_list) > 0 else []
        # Detectors are tried in this order; the first one that matches loads the file.
        self.model_detector = [
            ModelDetectorFromSingleFile(model_loader_configs),
            ModelDetectorFromSplitedSingleFile(model_loader_configs),
            ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
            ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
        ]
        self.load_models(downloaded_files + file_path_list)


    def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
        print(f"Loading models from file: {file_path}")
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device)
        for model_name, model in zip(model_names, models):
            self.model.append(model)
            self.model_path.append(file_path)
            self.model_name.append(model_name)
        print(f" The following models are loaded: {model_names}.")


    def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
        print(f"Loading models from folder: {file_path}")
        model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
        for model_name, model in zip(model_names, models):
            self.model.append(model)
            self.model_path.append(file_path)
            self.model_name.append(model_name)
        print(f" The following models are loaded: {model_names}.")


    def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
        print(f"Loading patch models from file: {file_path}")
        model_names, models = load_patch_model_from_single_file(
            state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
        for model_name, model in zip(model_names, models):
            self.model.append(model)
            self.model_path.append(file_path)
            self.model_name.append(model_name)
        print(f" The following patched models are loaded: {model_names}.")


    def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
        print(f"Loading LoRA models from file: {file_path}")
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
            # Apply the first LoRA format that matches this model.
            for lora in [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), GeneralLoRAFromPeft()]:
                match_results = lora.match(model, state_dict)
                if match_results is not None:
                    print(f" Adding LoRA to {model_name} ({model_path}).")
                    lora_prefix, model_resource = match_results
                    lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
                    break


    def load_model(self, file_path, model_names=None):
        print(f"Loading models from: {file_path}")
        if os.path.isfile(file_path):
            state_dict = load_state_dict(file_path)
        else:
            state_dict = None
        for model_detector in self.model_detector:
            if model_detector.match(file_path, state_dict):
                model_names, models = model_detector.load(
                    file_path, state_dict,
                    device=self.device, torch_dtype=self.torch_dtype,
                    allowed_model_names=model_names, model_manager=self
                )
                for model_name, model in zip(model_names, models):
                    self.model.append(model)
                    self.model_path.append(file_path)
                    self.model_name.append(model_name)
                print(f" The following models are loaded: {model_names}.")
                break
        else:
            print(" The model type cannot be detected. No models are loaded.")


    def load_models(self, file_path_list, model_names=None):
        for file_path in file_path_list:
            self.load_model(file_path, model_names)


    def fetch_model(self, model_name, file_path=None, require_model_path=False):
        fetched_models = []
        fetched_model_paths = []
        for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
            if file_path is not None and file_path != model_path:
                continue
            if model_name == model_name_:
                fetched_models.append(model)
                fetched_model_paths.append(model_path)
        if len(fetched_models) == 0:
            print(f"No {model_name} models available.")
            return None
        if len(fetched_models) == 1:
            print(f"Using {model_name} from {fetched_model_paths[0]}.")
        else:
            print(f"More than one {model_name} model is loaded in the model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
        if require_model_path:
            return fetched_models[0], fetched_model_paths[0]
        else:
            return fetched_models[0]


    def to(self, device):
        for model in self.model:
            model.to(device)

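# Illustrative sketch, not part of the upstream file: the "first matching detector wins"
# dispatch used by ModelManager.load_model, reduced to plain Python. SuffixDetector and
# load_with_detectors are hypothetical names; the real detectors match on hashes of
# state-dict keys or on a Hugging Face config.json, not on file suffixes.

```python
class SuffixDetector:
    """Toy detector that recognizes a file by its suffix."""

    def __init__(self, suffix, label):
        self.suffix = suffix
        self.label = label

    def match(self, file_path):
        return file_path.endswith(self.suffix)

    def load(self, file_path):
        # Mirror the (names, models) pair that the real detectors return.
        return [self.label], [f"<{self.label} from {file_path}>"]


def load_with_detectors(file_path, detectors):
    """Try the detectors in order; the for/else branch runs only if none matched."""
    for detector in detectors:
        if detector.match(file_path):
            names, models = detector.load(file_path)
            break
    else:
        names, models = [], []
    return names, models


detectors = [SuffixDetector(".safetensors", "single_file"), SuffixDetector(".bin", "legacy_bin")]
names, models = load_with_detectors("model.safetensors", detectors)
unknown = load_with_detectors("notes.txt", detectors)
```

# Order matters in this pattern: a file that two detectors could claim is always handled by the
# one listed first, which is why the plain single-file detector precedes the split variant.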
798
diffsynth/models/sd3_dit.py
Normal file
@@ -0,0 +1,798 @@
import torch
from einops import rearrange
from .svd_unet import TemporalTimesteps
from .tiler import TileWorker



class PatchEmbed(torch.nn.Module):
    def __init__(self, patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192):
        super().__init__()
        self.pos_embed_max_size = pos_embed_max_size
        self.patch_size = patch_size

        self.proj = torch.nn.Conv2d(in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size)
        # Use embed_dim rather than a hard-coded 1536 so that non-default widths stay consistent with proj.
        self.pos_embed = torch.nn.Parameter(torch.zeros(1, self.pos_embed_max_size, self.pos_embed_max_size, embed_dim))

    def cropped_pos_embed(self, height, width):
        # Take a centered (height, width) window, measured in patches, from the learned grid.
        height = height // self.patch_size
        width = width // self.patch_size
        top = (self.pos_embed_max_size - height) // 2
        left = (self.pos_embed_max_size - width) // 2
        spatial_pos_embed = self.pos_embed[:, top : top + height, left : left + width, :].flatten(1, 2)
        return spatial_pos_embed

    def forward(self, latent):
        height, width = latent.shape[-2:]
        latent = self.proj(latent)
        latent = latent.flatten(2).transpose(1, 2)
        pos_embed = self.cropped_pos_embed(height, width)
        return latent + pos_embed

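# Illustrative sketch, not part of the upstream file: the index arithmetic of
# PatchEmbed.cropped_pos_embed without torch. crop_window is a hypothetical helper that
# returns the slice bounds; the defaults mirror the constructor defaults shown above.

```python
def crop_window(height, width, patch_size=2, pos_embed_max_size=192):
    """Return (top, bottom, left, right) of the centered window, in patch units."""
    rows = height // patch_size
    cols = width // patch_size
    top = (pos_embed_max_size - rows) // 2
    left = (pos_embed_max_size - cols) // 2
    return top, top + rows, left, left + cols


square = crop_window(128, 128)
wide = crop_window(128, 96)
```

# The window keeps `rows * cols` position vectors, which matches the number of tokens produced
# by the stride-`patch_size` convolution, so the two tensors can be added elementwise.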
class TimestepEmbeddings(torch.nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = torch.nn.Sequential(
            torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
        )

    def forward(self, timestep, dtype):
        time_emb = self.time_proj(timestep).to(dtype)
        time_emb = self.timestep_embedder(time_emb)
        return time_emb



class AdaLayerNorm(torch.nn.Module):
    def __init__(self, dim, single=False):
        super().__init__()
        self.single = single
        self.linear = torch.nn.Linear(dim, dim * (2 if single else 6))
        self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        emb = self.linear(torch.nn.functional.silu(emb))
        if self.single:
            scale, shift = emb.unsqueeze(1).chunk(2, dim=2)
            x = self.norm(x) * (1 + scale) + shift
            return x
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)
            x = self.norm(x) * (1 + scale_msa) + shift_msa
            return x, gate_msa, shift_mlp, scale_mlp, gate_mlp

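# Illustrative sketch, not part of the upstream file: the modulation formula used by
# AdaLayerNorm, norm(x) * (1 + scale) + shift, on one feature vector in pure Python.
# layer_norm and modulate are hypothetical helpers; in the module the scale and shift come
# from a linear projection of the conditioning embedding, which is omitted here.

```python
import math


def layer_norm(x, eps=1e-6):
    """Normalize a vector to zero mean and unit variance, without learned affine terms."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def modulate(x, scale, shift):
    """Apply the adaptive scale and shift to the normalized vector."""
    return [n * (1 + s) + b for n, s, b in zip(layer_norm(x), scale, shift)]


x = [1.0, 2.0, 3.0]
unconditioned = modulate(x, [0.0, 0.0, 0.0], [0.0, 0.0, 0.0])
conditioned = modulate(x, [1.0, 1.0, 1.0], [0.5, 0.5, 0.5])
```

# With zero scale and shift the block reduces to a plain layer norm, so a conditioning signal
# of zero leaves the normalized features untouched.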
class JointAttention(torch.nn.Module):
|
||||
def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.only_out_a = only_out_a
|
||||
|
||||
self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
|
||||
self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)
|
||||
|
||||
self.a_to_out = torch.nn.Linear(dim_a, dim_a)
|
||||
if not only_out_a:
|
||||
self.b_to_out = torch.nn.Linear(dim_b, dim_b)
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b):
|
||||
batch_size = hidden_states_a.shape[0]
|
||||
|
||||
qkv = torch.concat([self.a_to_qkv(hidden_states_a), self.b_to_qkv(hidden_states_b)], dim=1)
|
||||
qkv = qkv.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
||||
q, k, v = qkv.chunk(3, dim=1)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
hidden_states_a, hidden_states_b = hidden_states[:, :hidden_states_a.shape[1]], hidden_states[:, hidden_states_a.shape[1]:]
|
||||
hidden_states_a = self.a_to_out(hidden_states_a)
|
||||
if self.only_out_a:
|
||||
return hidden_states_a
|
||||
else:
|
||||
hidden_states_b = self.b_to_out(hidden_states_b)
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
|
||||
class JointTransformerBlock(torch.nn.Module):
|
||||
def __init__(self, dim, num_attention_heads):
|
||||
super().__init__()
|
||||
self.norm1_a = AdaLayerNorm(dim)
|
||||
self.norm1_b = AdaLayerNorm(dim)
|
||||
|
||||
self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads)
|
||||
|
||||
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_a = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim, dim*4),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(dim*4, dim)
|
||||
)
|
||||
|
||||
self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_b = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim, dim*4),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(dim*4, dim)
|
||||
)
|
||||
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb):
|
||||
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
|
||||
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
|
||||
|
||||
# Attention
|
||||
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b)
|
||||
|
||||
# Part A
|
||||
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
|
||||
norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
|
||||
hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
|
||||
|
||||
# Part B
|
||||
hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
|
||||
norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
|
||||
hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
|
||||
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
|
||||
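Both streams in `JointTransformerBlock.forward` apply the same pattern around the feed-forward: normalize, modulate with `(1 + scale)` and `shift`, then add the result back through a gate. Below is a small pure-Python sketch of that update on a single vector. The names `layer_norm` and `modulated_residual` are illustrative, and `fn` stands in for the feed-forward network; in the source, the shift, scale and gate values come from `AdaLayerNorm`, whose definition is outside this section.

```python
import math


def layer_norm(x, eps=1e-6):
    """Normalize a vector to zero mean and unit variance, with no learned affine parameters."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def modulated_residual(x, shift, scale, gate, fn):
    """Return x + gate * fn(layer_norm(x) * (1 + scale) + shift), element-wise."""
    normed = layer_norm(x)
    modulated = [n * (1 + s) + b for n, s, b in zip(normed, scale, shift)]
    return [xi + g * yi for xi, g, yi in zip(x, gate, fn(modulated))]


x = [1.0, 2.0, 3.0]
zeros = [0.0] * 3
ones = [1.0] * 3
# A zero gate switches the branch off, so the input passes through unchanged.
unchanged = modulated_residual(x, zeros, zeros, zeros, lambda h: h)
# A unit gate with zero shift and scale adds the normalized input back onto x.
shifted = modulated_residual(x, zeros, zeros, ones, lambda h: h)
```

A zero gate makes the block an identity map for that branch, which is one reason gated residuals are convenient for conditioning.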
class JointTransformerFinalBlock(torch.nn.Module):
|
||||
def __init__(self, dim, num_attention_heads):
|
||||
super().__init__()
|
||||
self.norm1_a = AdaLayerNorm(dim)
|
||||
self.norm1_b = AdaLayerNorm(dim, single=True)
|
||||
|
||||
self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, only_out_a=True)
|
||||
|
||||
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_a = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim, dim*4),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(dim*4, dim)
|
||||
)
|
||||
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb):
|
||||
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
|
||||
norm_hidden_states_b = self.norm1_b(hidden_states_b, emb=temb)
|
||||
|
||||
# Attention
|
||||
attn_output_a = self.attn(norm_hidden_states_a, norm_hidden_states_b)
|
||||
|
||||
# Part A
|
||||
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
|
||||
norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
|
||||
hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
|
||||
|
||||
return hidden_states_a, hidden_states_b  # stream B is returned unchanged by the final block
|
||||
|
||||
|
||||
|
||||
class SD3DiT(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.pos_embedder = PatchEmbed(patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192)
|
||||
self.time_embedder = TimestepEmbeddings(256, 1536)
|
||||
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(2048, 1536), torch.nn.SiLU(), torch.nn.Linear(1536, 1536))
|
||||
self.context_embedder = torch.nn.Linear(4096, 1536)
|
||||
self.blocks = torch.nn.ModuleList([JointTransformerBlock(1536, 24) for _ in range(23)] + [JointTransformerFinalBlock(1536, 24)])
|
||||
self.norm_out = AdaLayerNorm(1536, single=True)
|
||||
self.proj_out = torch.nn.Linear(1536, 64)
|
||||
|
||||
def tiled_forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size=128, tile_stride=64):
|
||||
# The positional embedding is global, so tiling wraps the whole forward pass instead of being applied layer by layer.
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: self.forward(x, timestep, prompt_emb, pooled_prompt_emb),
|
||||
hidden_states,
|
||||
tile_size,
|
||||
tile_stride,
|
||||
tile_device=hidden_states.device,
|
||||
tile_dtype=hidden_states.dtype
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tiled=False, tile_size=128, tile_stride=64, use_gradient_checkpointing=False):
|
||||
if tiled:
|
||||
return self.tiled_forward(hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size, tile_stride)
|
||||
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
|
||||
prompt_emb = self.context_embedder(prompt_emb)
|
||||
|
||||
height, width = hidden_states.shape[-2:]
|
||||
hidden_states = self.pos_embedder(hidden_states)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
for block in self.blocks:
|
||||
if self.training and use_gradient_checkpointing:
|
||||
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, prompt_emb, conditioning,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, conditioning)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = rearrange(hidden_states, "B (H W) (P Q C) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)  # unpatchify: each token holds one 2x2 patch of C channels
|
||||
return hidden_states
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SD3DiTStateDictConverter()
|
||||
|
||||
|
||||
|
||||
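The last step of `SD3DiT.forward` turns the token sequence back into an image-shaped tensor with the pattern `"B (H W) (P Q C) -> B C (H P) (W Q)"`. As a rough illustration of what that pattern does, here is a pure-Python version for a single sample using nested lists. `unpatchify` is a hypothetical name, and the sketch assumes the usual einops convention that the leftmost axis inside a parenthesized group varies slowest.

```python
def unpatchify(tokens, H, W, P=2, Q=2):
    """Mirror "B (H W) (P Q C) -> B C (H P) (W Q)" for one sample.

    tokens is a list of H * W rows, each holding P * Q * C values.
    """
    C = len(tokens[0]) // (P * Q)
    out = [[[None] * (W * Q) for _ in range(H * P)] for _ in range(C)]
    for h in range(H):
        for w in range(W):
            token = tokens[h * W + w]
            for p in range(P):
                for q in range(Q):
                    for c in range(C):
                        # Pixel (p, q) of patch (h, w), channel c.
                        out[c][h * P + p][w * Q + q] = token[(p * Q + q) * C + c]
    return out


# Two 2x2 single-channel patches placed side by side.
image = unpatchify([[0, 1, 2, 3], [4, 5, 6, 7]], H=1, W=2)
```

In the model itself each token has 64 values after `proj_out`, so with `P=2, Q=2` the channel count works out to 16.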
class SD3DiTStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
rename_dict = {
|
||||
"context_embedder": "context_embedder",
|
||||
"pos_embed.pos_embed": "pos_embedder.pos_embed",
|
||||
"pos_embed.proj": "pos_embedder.proj",
|
||||
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
|
||||
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
|
||||
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
|
||||
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
|
||||
"norm_out.linear": "norm_out.linear",
|
||||
"proj_out": "proj_out",
|
||||
|
||||
"norm1.linear": "norm1_a.linear",
|
||||
"norm1_context.linear": "norm1_b.linear",
|
||||
"attn.to_q": "attn.a_to_q",
|
||||
"attn.to_k": "attn.a_to_k",
|
||||
"attn.to_v": "attn.a_to_v",
|
||||
"attn.to_out.0": "attn.a_to_out",
|
||||
"attn.add_q_proj": "attn.b_to_q",
|
||||
"attn.add_k_proj": "attn.b_to_k",
|
||||
"attn.add_v_proj": "attn.b_to_v",
|
||||
"attn.to_add_out": "attn.b_to_out",
|
||||
"ff.net.0.proj": "ff_a.0",
|
||||
"ff.net.2": "ff_a.2",
|
||||
"ff_context.net.0.proj": "ff_b.0",
|
||||
"ff_context.net.2": "ff_b.2",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name in rename_dict:
|
||||
if name == "pos_embed.pos_embed":
|
||||
param = param.reshape((1, 192, 192, 1536))  # 2D grid: pos_embed_max_size=192, embed_dim=1536
|
||||
state_dict_[rename_dict[name]] = param
|
||||
elif name.endswith(".weight") or name.endswith(".bias"):
|
||||
suffix = ".weight" if name.endswith(".weight") else ".bias"
|
||||
prefix = name[:-len(suffix)]
|
||||
if prefix in rename_dict:
|
||||
state_dict_[rename_dict[prefix] + suffix] = param
|
||||
elif prefix.startswith("transformer_blocks."):
|
||||
names = prefix.split(".")
|
||||
names[0] = "blocks"
|
||||
middle = ".".join(names[2:])
|
||||
if middle in rename_dict:
|
||||
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])  # e.g. blocks.<i>.attn.a_to_q.weight
|
||||
state_dict_[name_] = param
|
||||
return state_dict_
|
||||
|
||||
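The renaming rules in `from_diffusers` can be read as a three-step lookup: an exact match on the full key, a match on the key with its `.weight` / `.bias` suffix removed, and finally a per-block match where `transformer_blocks.<i>` becomes `blocks.<i>`. The sketch below restates that lookup as a standalone function over key names only, without touching tensors. `rename_key` and the shortened `rename_excerpt` table are illustrative and not part of the converter.

```python
def rename_key(name, rename_dict):
    """Return the converted key name, or None if the key is not covered."""
    if name in rename_dict:
        return rename_dict[name]
    for suffix in (".weight", ".bias"):
        if name.endswith(suffix):
            prefix = name[:-len(suffix)]
            if prefix in rename_dict:
                return rename_dict[prefix] + suffix
            if prefix.startswith("transformer_blocks."):
                parts = prefix.split(".")
                parts[0] = "blocks"
                # Everything after "transformer_blocks.<i>." is looked up as one unit.
                middle = ".".join(parts[2:])
                if middle in rename_dict:
                    return ".".join(parts[:2] + [rename_dict[middle], suffix[1:]])
            return None
    return None


# A few entries copied from the table above, for demonstration only.
rename_excerpt = {
    "pos_embed.pos_embed": "pos_embedder.pos_embed",
    "proj_out": "proj_out",
    "attn.to_q": "attn.a_to_q",
    "ff_context.net.0.proj": "ff_b.0",
}

converted = rename_key("transformer_blocks.3.attn.to_q.weight", rename_excerpt)
```

Keys that match none of the three rules are dropped from the converted state dict, which the sketch represents by returning `None`.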
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
"model.diffusion_model.context_embedder.bias": "context_embedder.bias",
|
||||
"model.diffusion_model.context_embedder.weight": "context_embedder.weight",
|
||||
"model.diffusion_model.final_layer.linear.bias": "proj_out.bias",
|
||||
"model.diffusion_model.final_layer.linear.weight": "proj_out.weight",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias": "blocks.0.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.weight": "blocks.0.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.attn.proj.bias": "blocks.0.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.attn.proj.weight": "blocks.0.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.attn.qkv.bias": ['blocks.0.attn.b_to_q.bias', 'blocks.0.attn.b_to_k.bias', 'blocks.0.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.0.context_block.attn.qkv.weight": ['blocks.0.attn.b_to_q.weight', 'blocks.0.attn.b_to_k.weight', 'blocks.0.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc1.bias": "blocks.0.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc1.weight": "blocks.0.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc2.bias": "blocks.0.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc2.weight": "blocks.0.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.0.x_block.adaLN_modulation.1.bias": "blocks.0.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.0.x_block.adaLN_modulation.1.weight": "blocks.0.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.0.x_block.attn.proj.bias": "blocks.0.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.0.x_block.attn.proj.weight": "blocks.0.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.0.x_block.attn.qkv.bias": ['blocks.0.attn.a_to_q.bias', 'blocks.0.attn.a_to_k.bias', 'blocks.0.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.0.x_block.attn.qkv.weight": ['blocks.0.attn.a_to_q.weight', 'blocks.0.attn.a_to_k.weight', 'blocks.0.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc1.bias": "blocks.0.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc1.weight": "blocks.0.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc2.bias": "blocks.0.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc2.weight": "blocks.0.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.1.context_block.adaLN_modulation.1.bias": "blocks.1.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.1.context_block.adaLN_modulation.1.weight": "blocks.1.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.1.context_block.attn.proj.bias": "blocks.1.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.1.context_block.attn.proj.weight": "blocks.1.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.1.context_block.attn.qkv.bias": ['blocks.1.attn.b_to_q.bias', 'blocks.1.attn.b_to_k.bias', 'blocks.1.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.1.context_block.attn.qkv.weight": ['blocks.1.attn.b_to_q.weight', 'blocks.1.attn.b_to_k.weight', 'blocks.1.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc1.bias": "blocks.1.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc1.weight": "blocks.1.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc2.bias": "blocks.1.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc2.weight": "blocks.1.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.1.x_block.adaLN_modulation.1.bias": "blocks.1.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.1.x_block.adaLN_modulation.1.weight": "blocks.1.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.1.x_block.attn.proj.bias": "blocks.1.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.1.x_block.attn.proj.weight": "blocks.1.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.1.x_block.attn.qkv.bias": ['blocks.1.attn.a_to_q.bias', 'blocks.1.attn.a_to_k.bias', 'blocks.1.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.1.x_block.attn.qkv.weight": ['blocks.1.attn.a_to_q.weight', 'blocks.1.attn.a_to_k.weight', 'blocks.1.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc1.bias": "blocks.1.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc1.weight": "blocks.1.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc2.bias": "blocks.1.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc2.weight": "blocks.1.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.10.context_block.adaLN_modulation.1.bias": "blocks.10.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.10.context_block.adaLN_modulation.1.weight": "blocks.10.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.10.context_block.attn.proj.bias": "blocks.10.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.10.context_block.attn.proj.weight": "blocks.10.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.10.context_block.attn.qkv.bias": ['blocks.10.attn.b_to_q.bias', 'blocks.10.attn.b_to_k.bias', 'blocks.10.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.10.context_block.attn.qkv.weight": ['blocks.10.attn.b_to_q.weight', 'blocks.10.attn.b_to_k.weight', 'blocks.10.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc1.bias": "blocks.10.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc1.weight": "blocks.10.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc2.bias": "blocks.10.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc2.weight": "blocks.10.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.10.x_block.adaLN_modulation.1.bias": "blocks.10.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.10.x_block.adaLN_modulation.1.weight": "blocks.10.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.10.x_block.attn.proj.bias": "blocks.10.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.10.x_block.attn.proj.weight": "blocks.10.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.10.x_block.attn.qkv.bias": ['blocks.10.attn.a_to_q.bias', 'blocks.10.attn.a_to_k.bias', 'blocks.10.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.10.x_block.attn.qkv.weight": ['blocks.10.attn.a_to_q.weight', 'blocks.10.attn.a_to_k.weight', 'blocks.10.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc1.bias": "blocks.10.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc1.weight": "blocks.10.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc2.bias": "blocks.10.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc2.weight": "blocks.10.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.11.context_block.adaLN_modulation.1.bias": "blocks.11.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.11.context_block.adaLN_modulation.1.weight": "blocks.11.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.11.context_block.attn.proj.bias": "blocks.11.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.11.context_block.attn.proj.weight": "blocks.11.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.11.context_block.attn.qkv.bias": ['blocks.11.attn.b_to_q.bias', 'blocks.11.attn.b_to_k.bias', 'blocks.11.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.11.context_block.attn.qkv.weight": ['blocks.11.attn.b_to_q.weight', 'blocks.11.attn.b_to_k.weight', 'blocks.11.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc1.bias": "blocks.11.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc1.weight": "blocks.11.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc2.bias": "blocks.11.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc2.weight": "blocks.11.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.11.x_block.adaLN_modulation.1.bias": "blocks.11.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.11.x_block.adaLN_modulation.1.weight": "blocks.11.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.11.x_block.attn.proj.bias": "blocks.11.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.11.x_block.attn.proj.weight": "blocks.11.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.11.x_block.attn.qkv.bias": ['blocks.11.attn.a_to_q.bias', 'blocks.11.attn.a_to_k.bias', 'blocks.11.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.11.x_block.attn.qkv.weight": ['blocks.11.attn.a_to_q.weight', 'blocks.11.attn.a_to_k.weight', 'blocks.11.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc1.bias": "blocks.11.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc1.weight": "blocks.11.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc2.bias": "blocks.11.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc2.weight": "blocks.11.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.12.context_block.adaLN_modulation.1.bias": "blocks.12.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.12.context_block.adaLN_modulation.1.weight": "blocks.12.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.12.context_block.attn.proj.bias": "blocks.12.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.12.context_block.attn.proj.weight": "blocks.12.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.12.context_block.attn.qkv.bias": ['blocks.12.attn.b_to_q.bias', 'blocks.12.attn.b_to_k.bias', 'blocks.12.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.12.context_block.attn.qkv.weight": ['blocks.12.attn.b_to_q.weight', 'blocks.12.attn.b_to_k.weight', 'blocks.12.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc1.bias": "blocks.12.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc1.weight": "blocks.12.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc2.bias": "blocks.12.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc2.weight": "blocks.12.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.12.x_block.adaLN_modulation.1.bias": "blocks.12.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.12.x_block.adaLN_modulation.1.weight": "blocks.12.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.12.x_block.attn.proj.bias": "blocks.12.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.12.x_block.attn.proj.weight": "blocks.12.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.12.x_block.attn.qkv.bias": ['blocks.12.attn.a_to_q.bias', 'blocks.12.attn.a_to_k.bias', 'blocks.12.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.12.x_block.attn.qkv.weight": ['blocks.12.attn.a_to_q.weight', 'blocks.12.attn.a_to_k.weight', 'blocks.12.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc1.bias": "blocks.12.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc1.weight": "blocks.12.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc2.bias": "blocks.12.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc2.weight": "blocks.12.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.13.context_block.adaLN_modulation.1.bias": "blocks.13.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.13.context_block.adaLN_modulation.1.weight": "blocks.13.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.13.context_block.attn.proj.bias": "blocks.13.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.13.context_block.attn.proj.weight": "blocks.13.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.13.context_block.attn.qkv.bias": ['blocks.13.attn.b_to_q.bias', 'blocks.13.attn.b_to_k.bias', 'blocks.13.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.13.context_block.attn.qkv.weight": ['blocks.13.attn.b_to_q.weight', 'blocks.13.attn.b_to_k.weight', 'blocks.13.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc1.bias": "blocks.13.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc1.weight": "blocks.13.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc2.bias": "blocks.13.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc2.weight": "blocks.13.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.13.x_block.adaLN_modulation.1.bias": "blocks.13.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.13.x_block.adaLN_modulation.1.weight": "blocks.13.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.13.x_block.attn.proj.bias": "blocks.13.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.13.x_block.attn.proj.weight": "blocks.13.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.13.x_block.attn.qkv.bias": ['blocks.13.attn.a_to_q.bias', 'blocks.13.attn.a_to_k.bias', 'blocks.13.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.13.x_block.attn.qkv.weight": ['blocks.13.attn.a_to_q.weight', 'blocks.13.attn.a_to_k.weight', 'blocks.13.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc1.bias": "blocks.13.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc1.weight": "blocks.13.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc2.bias": "blocks.13.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc2.weight": "blocks.13.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.14.context_block.adaLN_modulation.1.bias": "blocks.14.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.14.context_block.adaLN_modulation.1.weight": "blocks.14.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.14.context_block.attn.proj.bias": "blocks.14.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.14.context_block.attn.proj.weight": "blocks.14.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.14.context_block.attn.qkv.bias": ['blocks.14.attn.b_to_q.bias', 'blocks.14.attn.b_to_k.bias', 'blocks.14.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.14.context_block.attn.qkv.weight": ['blocks.14.attn.b_to_q.weight', 'blocks.14.attn.b_to_k.weight', 'blocks.14.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc1.bias": "blocks.14.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc1.weight": "blocks.14.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc2.bias": "blocks.14.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc2.weight": "blocks.14.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.14.x_block.adaLN_modulation.1.bias": "blocks.14.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.14.x_block.adaLN_modulation.1.weight": "blocks.14.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.14.x_block.attn.proj.bias": "blocks.14.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.14.x_block.attn.proj.weight": "blocks.14.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.14.x_block.attn.qkv.bias": ['blocks.14.attn.a_to_q.bias', 'blocks.14.attn.a_to_k.bias', 'blocks.14.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.14.x_block.attn.qkv.weight": ['blocks.14.attn.a_to_q.weight', 'blocks.14.attn.a_to_k.weight', 'blocks.14.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc1.bias": "blocks.14.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc1.weight": "blocks.14.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc2.bias": "blocks.14.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc2.weight": "blocks.14.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.15.context_block.adaLN_modulation.1.bias": "blocks.15.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.15.context_block.adaLN_modulation.1.weight": "blocks.15.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.15.context_block.attn.proj.bias": "blocks.15.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.15.context_block.attn.proj.weight": "blocks.15.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.15.context_block.attn.qkv.bias": ['blocks.15.attn.b_to_q.bias', 'blocks.15.attn.b_to_k.bias', 'blocks.15.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.15.context_block.attn.qkv.weight": ['blocks.15.attn.b_to_q.weight', 'blocks.15.attn.b_to_k.weight', 'blocks.15.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc1.bias": "blocks.15.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc1.weight": "blocks.15.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc2.bias": "blocks.15.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc2.weight": "blocks.15.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.15.x_block.adaLN_modulation.1.bias": "blocks.15.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.15.x_block.adaLN_modulation.1.weight": "blocks.15.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.15.x_block.attn.proj.bias": "blocks.15.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.15.x_block.attn.proj.weight": "blocks.15.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.15.x_block.attn.qkv.bias": ['blocks.15.attn.a_to_q.bias', 'blocks.15.attn.a_to_k.bias', 'blocks.15.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.15.x_block.attn.qkv.weight": ['blocks.15.attn.a_to_q.weight', 'blocks.15.attn.a_to_k.weight', 'blocks.15.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc1.bias": "blocks.15.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc1.weight": "blocks.15.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc2.bias": "blocks.15.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc2.weight": "blocks.15.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.16.context_block.adaLN_modulation.1.bias": "blocks.16.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.16.context_block.adaLN_modulation.1.weight": "blocks.16.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.16.context_block.attn.proj.bias": "blocks.16.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.16.context_block.attn.proj.weight": "blocks.16.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.16.context_block.attn.qkv.bias": ['blocks.16.attn.b_to_q.bias', 'blocks.16.attn.b_to_k.bias', 'blocks.16.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.16.context_block.attn.qkv.weight": ['blocks.16.attn.b_to_q.weight', 'blocks.16.attn.b_to_k.weight', 'blocks.16.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc1.bias": "blocks.16.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc1.weight": "blocks.16.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc2.bias": "blocks.16.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc2.weight": "blocks.16.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.16.x_block.adaLN_modulation.1.bias": "blocks.16.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.16.x_block.adaLN_modulation.1.weight": "blocks.16.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.16.x_block.attn.proj.bias": "blocks.16.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.16.x_block.attn.proj.weight": "blocks.16.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.16.x_block.attn.qkv.bias": ['blocks.16.attn.a_to_q.bias', 'blocks.16.attn.a_to_k.bias', 'blocks.16.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.16.x_block.attn.qkv.weight": ['blocks.16.attn.a_to_q.weight', 'blocks.16.attn.a_to_k.weight', 'blocks.16.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc1.bias": "blocks.16.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc1.weight": "blocks.16.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc2.bias": "blocks.16.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc2.weight": "blocks.16.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.17.context_block.adaLN_modulation.1.bias": "blocks.17.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.17.context_block.adaLN_modulation.1.weight": "blocks.17.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.17.context_block.attn.proj.bias": "blocks.17.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.17.context_block.attn.proj.weight": "blocks.17.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.17.context_block.attn.qkv.bias": ['blocks.17.attn.b_to_q.bias', 'blocks.17.attn.b_to_k.bias', 'blocks.17.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.17.context_block.attn.qkv.weight": ['blocks.17.attn.b_to_q.weight', 'blocks.17.attn.b_to_k.weight', 'blocks.17.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc1.bias": "blocks.17.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc1.weight": "blocks.17.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc2.bias": "blocks.17.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc2.weight": "blocks.17.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.17.x_block.adaLN_modulation.1.bias": "blocks.17.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.17.x_block.adaLN_modulation.1.weight": "blocks.17.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.17.x_block.attn.proj.bias": "blocks.17.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.17.x_block.attn.proj.weight": "blocks.17.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.17.x_block.attn.qkv.bias": ['blocks.17.attn.a_to_q.bias', 'blocks.17.attn.a_to_k.bias', 'blocks.17.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.17.x_block.attn.qkv.weight": ['blocks.17.attn.a_to_q.weight', 'blocks.17.attn.a_to_k.weight', 'blocks.17.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc1.bias": "blocks.17.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc1.weight": "blocks.17.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc2.bias": "blocks.17.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc2.weight": "blocks.17.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.18.context_block.adaLN_modulation.1.bias": "blocks.18.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.18.context_block.adaLN_modulation.1.weight": "blocks.18.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.18.context_block.attn.proj.bias": "blocks.18.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.18.context_block.attn.proj.weight": "blocks.18.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.18.context_block.attn.qkv.bias": ['blocks.18.attn.b_to_q.bias', 'blocks.18.attn.b_to_k.bias', 'blocks.18.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.18.context_block.attn.qkv.weight": ['blocks.18.attn.b_to_q.weight', 'blocks.18.attn.b_to_k.weight', 'blocks.18.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc1.bias": "blocks.18.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc1.weight": "blocks.18.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc2.bias": "blocks.18.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc2.weight": "blocks.18.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.18.x_block.adaLN_modulation.1.bias": "blocks.18.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.18.x_block.adaLN_modulation.1.weight": "blocks.18.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.18.x_block.attn.proj.bias": "blocks.18.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.18.x_block.attn.proj.weight": "blocks.18.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.18.x_block.attn.qkv.bias": ['blocks.18.attn.a_to_q.bias', 'blocks.18.attn.a_to_k.bias', 'blocks.18.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.18.x_block.attn.qkv.weight": ['blocks.18.attn.a_to_q.weight', 'blocks.18.attn.a_to_k.weight', 'blocks.18.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc1.bias": "blocks.18.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc1.weight": "blocks.18.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc2.bias": "blocks.18.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc2.weight": "blocks.18.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.19.context_block.adaLN_modulation.1.bias": "blocks.19.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.19.context_block.adaLN_modulation.1.weight": "blocks.19.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.19.context_block.attn.proj.bias": "blocks.19.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.19.context_block.attn.proj.weight": "blocks.19.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.19.context_block.attn.qkv.bias": ['blocks.19.attn.b_to_q.bias', 'blocks.19.attn.b_to_k.bias', 'blocks.19.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.19.context_block.attn.qkv.weight": ['blocks.19.attn.b_to_q.weight', 'blocks.19.attn.b_to_k.weight', 'blocks.19.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc1.bias": "blocks.19.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc1.weight": "blocks.19.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc2.bias": "blocks.19.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc2.weight": "blocks.19.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.19.x_block.adaLN_modulation.1.bias": "blocks.19.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.19.x_block.adaLN_modulation.1.weight": "blocks.19.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.19.x_block.attn.proj.bias": "blocks.19.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.19.x_block.attn.proj.weight": "blocks.19.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.19.x_block.attn.qkv.bias": ['blocks.19.attn.a_to_q.bias', 'blocks.19.attn.a_to_k.bias', 'blocks.19.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.19.x_block.attn.qkv.weight": ['blocks.19.attn.a_to_q.weight', 'blocks.19.attn.a_to_k.weight', 'blocks.19.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc1.bias": "blocks.19.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc1.weight": "blocks.19.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc2.bias": "blocks.19.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc2.weight": "blocks.19.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.2.context_block.adaLN_modulation.1.bias": "blocks.2.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.2.context_block.adaLN_modulation.1.weight": "blocks.2.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.2.context_block.attn.proj.bias": "blocks.2.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.2.context_block.attn.proj.weight": "blocks.2.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.2.context_block.attn.qkv.bias": ['blocks.2.attn.b_to_q.bias', 'blocks.2.attn.b_to_k.bias', 'blocks.2.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.2.context_block.attn.qkv.weight": ['blocks.2.attn.b_to_q.weight', 'blocks.2.attn.b_to_k.weight', 'blocks.2.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc1.bias": "blocks.2.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc1.weight": "blocks.2.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc2.bias": "blocks.2.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc2.weight": "blocks.2.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.2.x_block.adaLN_modulation.1.bias": "blocks.2.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.2.x_block.adaLN_modulation.1.weight": "blocks.2.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.2.x_block.attn.proj.bias": "blocks.2.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.2.x_block.attn.proj.weight": "blocks.2.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.2.x_block.attn.qkv.bias": ['blocks.2.attn.a_to_q.bias', 'blocks.2.attn.a_to_k.bias', 'blocks.2.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.2.x_block.attn.qkv.weight": ['blocks.2.attn.a_to_q.weight', 'blocks.2.attn.a_to_k.weight', 'blocks.2.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc1.bias": "blocks.2.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc1.weight": "blocks.2.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc2.bias": "blocks.2.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc2.weight": "blocks.2.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.20.context_block.adaLN_modulation.1.bias": "blocks.20.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.20.context_block.adaLN_modulation.1.weight": "blocks.20.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.20.context_block.attn.proj.bias": "blocks.20.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.20.context_block.attn.proj.weight": "blocks.20.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.20.context_block.attn.qkv.bias": ['blocks.20.attn.b_to_q.bias', 'blocks.20.attn.b_to_k.bias', 'blocks.20.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.20.context_block.attn.qkv.weight": ['blocks.20.attn.b_to_q.weight', 'blocks.20.attn.b_to_k.weight', 'blocks.20.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc1.bias": "blocks.20.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc1.weight": "blocks.20.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc2.bias": "blocks.20.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc2.weight": "blocks.20.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.20.x_block.adaLN_modulation.1.bias": "blocks.20.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.20.x_block.adaLN_modulation.1.weight": "blocks.20.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.20.x_block.attn.proj.bias": "blocks.20.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.20.x_block.attn.proj.weight": "blocks.20.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.20.x_block.attn.qkv.bias": ['blocks.20.attn.a_to_q.bias', 'blocks.20.attn.a_to_k.bias', 'blocks.20.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.20.x_block.attn.qkv.weight": ['blocks.20.attn.a_to_q.weight', 'blocks.20.attn.a_to_k.weight', 'blocks.20.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc1.bias": "blocks.20.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc1.weight": "blocks.20.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc2.bias": "blocks.20.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc2.weight": "blocks.20.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.21.context_block.adaLN_modulation.1.bias": "blocks.21.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.21.context_block.adaLN_modulation.1.weight": "blocks.21.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.21.context_block.attn.proj.bias": "blocks.21.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.21.context_block.attn.proj.weight": "blocks.21.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.21.context_block.attn.qkv.bias": ['blocks.21.attn.b_to_q.bias', 'blocks.21.attn.b_to_k.bias', 'blocks.21.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.21.context_block.attn.qkv.weight": ['blocks.21.attn.b_to_q.weight', 'blocks.21.attn.b_to_k.weight', 'blocks.21.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc1.bias": "blocks.21.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc1.weight": "blocks.21.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc2.bias": "blocks.21.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc2.weight": "blocks.21.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.21.x_block.adaLN_modulation.1.bias": "blocks.21.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.21.x_block.adaLN_modulation.1.weight": "blocks.21.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.21.x_block.attn.proj.bias": "blocks.21.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.21.x_block.attn.proj.weight": "blocks.21.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.21.x_block.attn.qkv.bias": ['blocks.21.attn.a_to_q.bias', 'blocks.21.attn.a_to_k.bias', 'blocks.21.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.21.x_block.attn.qkv.weight": ['blocks.21.attn.a_to_q.weight', 'blocks.21.attn.a_to_k.weight', 'blocks.21.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc1.bias": "blocks.21.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc1.weight": "blocks.21.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc2.bias": "blocks.21.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc2.weight": "blocks.21.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.22.context_block.adaLN_modulation.1.bias": "blocks.22.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.22.context_block.adaLN_modulation.1.weight": "blocks.22.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.22.context_block.attn.proj.bias": "blocks.22.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.22.context_block.attn.proj.weight": "blocks.22.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.22.context_block.attn.qkv.bias": ['blocks.22.attn.b_to_q.bias', 'blocks.22.attn.b_to_k.bias', 'blocks.22.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.22.context_block.attn.qkv.weight": ['blocks.22.attn.b_to_q.weight', 'blocks.22.attn.b_to_k.weight', 'blocks.22.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc1.bias": "blocks.22.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc1.weight": "blocks.22.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc2.bias": "blocks.22.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc2.weight": "blocks.22.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.22.x_block.adaLN_modulation.1.bias": "blocks.22.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.22.x_block.adaLN_modulation.1.weight": "blocks.22.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.22.x_block.attn.proj.bias": "blocks.22.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.22.x_block.attn.proj.weight": "blocks.22.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.22.x_block.attn.qkv.bias": ['blocks.22.attn.a_to_q.bias', 'blocks.22.attn.a_to_k.bias', 'blocks.22.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.22.x_block.attn.qkv.weight": ['blocks.22.attn.a_to_q.weight', 'blocks.22.attn.a_to_k.weight', 'blocks.22.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc1.bias": "blocks.22.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc1.weight": "blocks.22.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc2.bias": "blocks.22.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc2.weight": "blocks.22.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.23.context_block.attn.qkv.bias": ['blocks.23.attn.b_to_q.bias', 'blocks.23.attn.b_to_k.bias', 'blocks.23.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.23.context_block.attn.qkv.weight": ['blocks.23.attn.b_to_q.weight', 'blocks.23.attn.b_to_k.weight', 'blocks.23.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.23.x_block.adaLN_modulation.1.bias": "blocks.23.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.23.x_block.adaLN_modulation.1.weight": "blocks.23.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.23.x_block.attn.proj.bias": "blocks.23.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.23.x_block.attn.proj.weight": "blocks.23.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.23.x_block.attn.qkv.bias": ['blocks.23.attn.a_to_q.bias', 'blocks.23.attn.a_to_k.bias', 'blocks.23.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.23.x_block.attn.qkv.weight": ['blocks.23.attn.a_to_q.weight', 'blocks.23.attn.a_to_k.weight', 'blocks.23.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc1.bias": "blocks.23.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc1.weight": "blocks.23.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc2.bias": "blocks.23.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc2.weight": "blocks.23.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.3.context_block.adaLN_modulation.1.bias": "blocks.3.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.3.context_block.adaLN_modulation.1.weight": "blocks.3.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.3.context_block.attn.proj.bias": "blocks.3.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.3.context_block.attn.proj.weight": "blocks.3.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.3.context_block.attn.qkv.bias": ['blocks.3.attn.b_to_q.bias', 'blocks.3.attn.b_to_k.bias', 'blocks.3.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.3.context_block.attn.qkv.weight": ['blocks.3.attn.b_to_q.weight', 'blocks.3.attn.b_to_k.weight', 'blocks.3.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc1.bias": "blocks.3.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc1.weight": "blocks.3.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc2.bias": "blocks.3.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc2.weight": "blocks.3.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.3.x_block.adaLN_modulation.1.bias": "blocks.3.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.3.x_block.adaLN_modulation.1.weight": "blocks.3.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.3.x_block.attn.proj.bias": "blocks.3.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.3.x_block.attn.proj.weight": "blocks.3.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.3.x_block.attn.qkv.bias": ['blocks.3.attn.a_to_q.bias', 'blocks.3.attn.a_to_k.bias', 'blocks.3.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.3.x_block.attn.qkv.weight": ['blocks.3.attn.a_to_q.weight', 'blocks.3.attn.a_to_k.weight', 'blocks.3.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc1.bias": "blocks.3.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc1.weight": "blocks.3.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc2.bias": "blocks.3.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc2.weight": "blocks.3.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.4.context_block.adaLN_modulation.1.bias": "blocks.4.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.4.context_block.adaLN_modulation.1.weight": "blocks.4.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.4.context_block.attn.proj.bias": "blocks.4.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.4.context_block.attn.proj.weight": "blocks.4.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.4.context_block.attn.qkv.bias": ['blocks.4.attn.b_to_q.bias', 'blocks.4.attn.b_to_k.bias', 'blocks.4.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.4.context_block.attn.qkv.weight": ['blocks.4.attn.b_to_q.weight', 'blocks.4.attn.b_to_k.weight', 'blocks.4.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc1.bias": "blocks.4.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc1.weight": "blocks.4.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc2.bias": "blocks.4.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc2.weight": "blocks.4.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.4.x_block.adaLN_modulation.1.bias": "blocks.4.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.4.x_block.adaLN_modulation.1.weight": "blocks.4.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.4.x_block.attn.proj.bias": "blocks.4.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.4.x_block.attn.proj.weight": "blocks.4.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.4.x_block.attn.qkv.bias": ['blocks.4.attn.a_to_q.bias', 'blocks.4.attn.a_to_k.bias', 'blocks.4.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.4.x_block.attn.qkv.weight": ['blocks.4.attn.a_to_q.weight', 'blocks.4.attn.a_to_k.weight', 'blocks.4.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc1.bias": "blocks.4.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc1.weight": "blocks.4.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc2.bias": "blocks.4.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc2.weight": "blocks.4.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.5.context_block.adaLN_modulation.1.bias": "blocks.5.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.5.context_block.adaLN_modulation.1.weight": "blocks.5.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.5.context_block.attn.proj.bias": "blocks.5.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.5.context_block.attn.proj.weight": "blocks.5.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.5.context_block.attn.qkv.bias": ['blocks.5.attn.b_to_q.bias', 'blocks.5.attn.b_to_k.bias', 'blocks.5.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.5.context_block.attn.qkv.weight": ['blocks.5.attn.b_to_q.weight', 'blocks.5.attn.b_to_k.weight', 'blocks.5.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc1.bias": "blocks.5.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc1.weight": "blocks.5.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc2.bias": "blocks.5.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc2.weight": "blocks.5.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.5.x_block.adaLN_modulation.1.bias": "blocks.5.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.5.x_block.adaLN_modulation.1.weight": "blocks.5.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.5.x_block.attn.proj.bias": "blocks.5.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.5.x_block.attn.proj.weight": "blocks.5.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.5.x_block.attn.qkv.bias": ['blocks.5.attn.a_to_q.bias', 'blocks.5.attn.a_to_k.bias', 'blocks.5.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.5.x_block.attn.qkv.weight": ['blocks.5.attn.a_to_q.weight', 'blocks.5.attn.a_to_k.weight', 'blocks.5.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc1.bias": "blocks.5.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc1.weight": "blocks.5.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc2.bias": "blocks.5.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc2.weight": "blocks.5.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.6.context_block.adaLN_modulation.1.bias": "blocks.6.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.6.context_block.adaLN_modulation.1.weight": "blocks.6.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.6.context_block.attn.proj.bias": "blocks.6.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.6.context_block.attn.proj.weight": "blocks.6.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.6.context_block.attn.qkv.bias": ['blocks.6.attn.b_to_q.bias', 'blocks.6.attn.b_to_k.bias', 'blocks.6.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.6.context_block.attn.qkv.weight": ['blocks.6.attn.b_to_q.weight', 'blocks.6.attn.b_to_k.weight', 'blocks.6.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc1.bias": "blocks.6.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc1.weight": "blocks.6.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc2.bias": "blocks.6.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc2.weight": "blocks.6.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.6.x_block.adaLN_modulation.1.bias": "blocks.6.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.6.x_block.adaLN_modulation.1.weight": "blocks.6.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.6.x_block.attn.proj.bias": "blocks.6.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.6.x_block.attn.proj.weight": "blocks.6.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.6.x_block.attn.qkv.bias": ['blocks.6.attn.a_to_q.bias', 'blocks.6.attn.a_to_k.bias', 'blocks.6.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.6.x_block.attn.qkv.weight": ['blocks.6.attn.a_to_q.weight', 'blocks.6.attn.a_to_k.weight', 'blocks.6.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc1.bias": "blocks.6.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc1.weight": "blocks.6.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc2.bias": "blocks.6.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc2.weight": "blocks.6.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.7.context_block.adaLN_modulation.1.bias": "blocks.7.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.7.context_block.adaLN_modulation.1.weight": "blocks.7.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.7.context_block.attn.proj.bias": "blocks.7.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.7.context_block.attn.proj.weight": "blocks.7.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.7.context_block.attn.qkv.bias": ['blocks.7.attn.b_to_q.bias', 'blocks.7.attn.b_to_k.bias', 'blocks.7.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.7.context_block.attn.qkv.weight": ['blocks.7.attn.b_to_q.weight', 'blocks.7.attn.b_to_k.weight', 'blocks.7.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc1.bias": "blocks.7.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc1.weight": "blocks.7.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc2.bias": "blocks.7.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc2.weight": "blocks.7.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.7.x_block.adaLN_modulation.1.bias": "blocks.7.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.7.x_block.adaLN_modulation.1.weight": "blocks.7.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.7.x_block.attn.proj.bias": "blocks.7.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.7.x_block.attn.proj.weight": "blocks.7.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.7.x_block.attn.qkv.bias": ['blocks.7.attn.a_to_q.bias', 'blocks.7.attn.a_to_k.bias', 'blocks.7.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.7.x_block.attn.qkv.weight": ['blocks.7.attn.a_to_q.weight', 'blocks.7.attn.a_to_k.weight', 'blocks.7.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc1.bias": "blocks.7.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc1.weight": "blocks.7.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc2.bias": "blocks.7.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc2.weight": "blocks.7.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.8.context_block.adaLN_modulation.1.bias": "blocks.8.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.8.context_block.adaLN_modulation.1.weight": "blocks.8.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.8.context_block.attn.proj.bias": "blocks.8.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.8.context_block.attn.proj.weight": "blocks.8.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.8.context_block.attn.qkv.bias": ['blocks.8.attn.b_to_q.bias', 'blocks.8.attn.b_to_k.bias', 'blocks.8.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.8.context_block.attn.qkv.weight": ['blocks.8.attn.b_to_q.weight', 'blocks.8.attn.b_to_k.weight', 'blocks.8.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc1.bias": "blocks.8.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc1.weight": "blocks.8.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc2.bias": "blocks.8.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc2.weight": "blocks.8.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.8.x_block.adaLN_modulation.1.bias": "blocks.8.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.8.x_block.adaLN_modulation.1.weight": "blocks.8.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.8.x_block.attn.proj.bias": "blocks.8.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.8.x_block.attn.proj.weight": "blocks.8.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.8.x_block.attn.qkv.bias": ['blocks.8.attn.a_to_q.bias', 'blocks.8.attn.a_to_k.bias', 'blocks.8.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.8.x_block.attn.qkv.weight": ['blocks.8.attn.a_to_q.weight', 'blocks.8.attn.a_to_k.weight', 'blocks.8.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc1.bias": "blocks.8.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc1.weight": "blocks.8.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc2.bias": "blocks.8.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc2.weight": "blocks.8.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.9.context_block.adaLN_modulation.1.bias": "blocks.9.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.9.context_block.adaLN_modulation.1.weight": "blocks.9.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.9.context_block.attn.proj.bias": "blocks.9.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.9.context_block.attn.proj.weight": "blocks.9.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.9.context_block.attn.qkv.bias": ['blocks.9.attn.b_to_q.bias', 'blocks.9.attn.b_to_k.bias', 'blocks.9.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.9.context_block.attn.qkv.weight": ['blocks.9.attn.b_to_q.weight', 'blocks.9.attn.b_to_k.weight', 'blocks.9.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc1.bias": "blocks.9.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc1.weight": "blocks.9.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc2.bias": "blocks.9.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc2.weight": "blocks.9.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.9.x_block.adaLN_modulation.1.bias": "blocks.9.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.9.x_block.adaLN_modulation.1.weight": "blocks.9.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.9.x_block.attn.proj.bias": "blocks.9.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.9.x_block.attn.proj.weight": "blocks.9.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.9.x_block.attn.qkv.bias": ['blocks.9.attn.a_to_q.bias', 'blocks.9.attn.a_to_k.bias', 'blocks.9.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.9.x_block.attn.qkv.weight": ['blocks.9.attn.a_to_q.weight', 'blocks.9.attn.a_to_k.weight', 'blocks.9.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc1.bias": "blocks.9.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc1.weight": "blocks.9.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.bias": "blocks.9.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.weight": "blocks.9.ff_a.2.weight",
|
||||
"model.diffusion_model.pos_embed": "pos_embedder.pos_embed",
|
||||
"model.diffusion_model.t_embedder.mlp.0.bias": "time_embedder.timestep_embedder.0.bias",
|
||||
"model.diffusion_model.t_embedder.mlp.0.weight": "time_embedder.timestep_embedder.0.weight",
|
||||
"model.diffusion_model.t_embedder.mlp.2.bias": "time_embedder.timestep_embedder.2.bias",
|
||||
"model.diffusion_model.t_embedder.mlp.2.weight": "time_embedder.timestep_embedder.2.weight",
|
||||
"model.diffusion_model.x_embedder.proj.bias": "pos_embedder.proj.bias",
|
||||
"model.diffusion_model.x_embedder.proj.weight": "pos_embedder.proj.weight",
|
||||
"model.diffusion_model.y_embedder.mlp.0.bias": "pooled_text_embedder.0.bias",
|
||||
"model.diffusion_model.y_embedder.mlp.0.weight": "pooled_text_embedder.0.weight",
|
||||
"model.diffusion_model.y_embedder.mlp.2.bias": "pooled_text_embedder.2.bias",
|
||||
"model.diffusion_model.y_embedder.mlp.2.weight": "pooled_text_embedder.2.weight",
|
||||
|
||||
"model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1.weight": "blocks.23.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1.bias": "blocks.23.norm1_b.linear.bias",
|
||||
"model.diffusion_model.final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
|
||||
"model.diffusion_model.final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if name.startswith("model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1."):
|
||||
param = torch.concat([param[1536:], param[:1536]], axis=0)
|
||||
elif name.startswith("model.diffusion_model.final_layer.adaLN_modulation.1."):
|
||||
param = torch.concat([param[1536:], param[:1536]], axis=0)
|
||||
elif name == "model.diffusion_model.pos_embed":
|
||||
param = param.reshape((1, 192, 192, 1536))
|
||||
if isinstance(rename_dict[name], str):
|
||||
state_dict_[rename_dict[name]] = param
|
||||
else:
|
||||
name_ = rename_dict[name][0].replace(".a_to_q.", ".a_to_qkv.").replace(".b_to_q.", ".b_to_qkv.")
|
||||
state_dict_[name_] = param
|
||||
return state_dict_
|
||||
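Review note: the conversion loop above does three things per checkpoint entry. It renames the key, it swaps the two 1536-wide halves of the `adaLN_modulation` parameters (final layer and block 23's context branch), and it stores a fused `qkv` tensor under a single `a_to_qkv` / `b_to_qkv` name. A minimal sketch of that control flow follows, with plain Python lists standing in for tensors. `swap_halves` and `convert_state_dict` are illustrative names, not functions from the repository, and the `pos_embed` reshape is left out.

```python
def swap_halves(param, half):
    """Move the first `half` entries to the end, like concat([p[half:], p[:half]])."""
    return param[half:] + param[:half]


def convert_state_dict(state_dict, rename_dict, swap_prefixes=(), half=0):
    converted = {}
    for name, param in state_dict.items():
        if name not in rename_dict:
            continue  # keys without a mapping are skipped, as in the loop above
        if any(name.startswith(prefix) for prefix in swap_prefixes):
            param = swap_halves(param, half)
        target = rename_dict[name]
        if isinstance(target, str):
            converted[target] = param
        else:
            # A list target marks a fused q/k/v tensor: it stays fused and is
            # stored once, under the first name with "to_q" turned into "to_qkv".
            fused = target[0].replace(".a_to_q.", ".a_to_qkv.").replace(".b_to_q.", ".b_to_qkv.")
            converted[fused] = param
    return converted


# Toy input: hypothetical key names, lists instead of tensors.
toy_state_dict = {
    "final.adaLN.weight": [1, 2, 3, 4],
    "blk.attn.qkv.weight": [9],
    "unused": [0],
}
toy_rename_dict = {
    "final.adaLN.weight": "norm_out.linear.weight",
    "blk.attn.qkv.weight": [
        "blocks.0.attn.a_to_q.weight",
        "blocks.0.attn.a_to_k.weight",
        "blocks.0.attn.a_to_v.weight",
    ],
}
toy_converted = convert_state_dict(
    toy_state_dict, toy_rename_dict, swap_prefixes=("final.adaLN.",), half=2
)
```

In the real converter the split point is 1536 and the parameters are `torch` tensors; the list version only illustrates the ordering and naming.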
1107
diffsynth/models/sd3_text_encoder.py
Normal file
File diff suppressed because it is too large
81
diffsynth/models/sd3_vae_decoder.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import torch
|
||||
from .sd_vae_decoder import VAEAttentionBlock, SDVAEDecoderStateDictConverter
|
||||
from .sd_unet import ResnetBlock, UpSampler
|
||||
from .tiler import TileWorker
|
||||
|
||||
|
||||
|
||||
class SD3VAEDecoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.scaling_factor = 1.5305 # Different from SD 1.x
|
||||
self.shift_factor = 0.0609 # Different from SD 1.x
|
||||
self.conv_in = torch.nn.Conv2d(16, 512, kernel_size=3, padding=1) # Different from SD 1.x
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
# UNetMidBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
# UpDecoderBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
UpSampler(512),
|
||||
# UpDecoderBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
UpSampler(512),
|
||||
# UpDecoderBlock2D
|
||||
ResnetBlock(512, 256, eps=1e-6),
|
||||
ResnetBlock(256, 256, eps=1e-6),
|
||||
ResnetBlock(256, 256, eps=1e-6),
|
||||
UpSampler(256),
|
||||
# UpDecoderBlock2D
|
||||
ResnetBlock(256, 128, eps=1e-6),
|
||||
ResnetBlock(128, 128, eps=1e-6),
|
||||
ResnetBlock(128, 128, eps=1e-6),
|
||||
])
|
||||
|
||||
self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6)
|
||||
self.conv_act = torch.nn.SiLU()
|
||||
self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
|
||||
|
||||
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: self.forward(x),
|
||||
sample,
|
||||
tile_size,
|
||||
tile_stride,
|
||||
tile_device=sample.device,
|
||||
tile_dtype=sample.dtype
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
||||
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
||||
if tiled:
|
||||
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
# 1. pre-process
|
||||
hidden_states = sample / self.scaling_factor + self.shift_factor
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
time_emb = None
|
||||
text_emb = None
|
||||
res_stack = None
|
||||
|
||||
# 2. blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
|
||||
# 3. output
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDVAEDecoderStateDictConverter()
|
||||
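Review note: the decoder's pre-process step maps a latent back to the VAE's raw range with `sample / scaling_factor + shift_factor`, and the encoder added in the next file applies `(hidden_states - shift_factor) * scaling_factor` on the way out. A small sketch, on plain floats, showing that the two are inverses of each other. The function names are hypothetical; only the two constants come from the diff.

```python
SCALING_FACTOR = 1.5305  # value set in SD3VAEDecoder / SD3VAEEncoder
SHIFT_FACTOR = 0.0609    # value set in SD3VAEDecoder / SD3VAEEncoder


def to_latent_space(raw):
    """Encoder side: shift first, then scale."""
    return (raw - SHIFT_FACTOR) * SCALING_FACTOR


def from_latent_space(latent):
    """Decoder side: undo the scale, then undo the shift."""
    return latent / SCALING_FACTOR + SHIFT_FACTOR


raw_value = 0.25
round_trip = from_latent_space(to_latent_space(raw_value))
```

The order matters: swapping the shift and the scale on one side only would no longer round-trip.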
95
diffsynth/models/sd3_vae_encoder.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import torch
|
||||
from .sd_unet import ResnetBlock, DownSampler
|
||||
from .sd_vae_encoder import VAEAttentionBlock, SDVAEEncoderStateDictConverter
|
||||
from .tiler import TileWorker
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class SD3VAEEncoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.scaling_factor = 1.5305 # Different from SD 1.x
|
||||
self.shift_factor = 0.0609 # Different from SD 1.x
|
||||
self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
# DownEncoderBlock2D
|
||||
ResnetBlock(128, 128, eps=1e-6),
|
||||
ResnetBlock(128, 128, eps=1e-6),
|
||||
DownSampler(128, padding=0, extra_padding=True),
|
||||
# DownEncoderBlock2D
|
||||
ResnetBlock(128, 256, eps=1e-6),
|
||||
ResnetBlock(256, 256, eps=1e-6),
|
||||
DownSampler(256, padding=0, extra_padding=True),
|
||||
# DownEncoderBlock2D
|
||||
ResnetBlock(256, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
DownSampler(512, padding=0, extra_padding=True),
|
||||
# DownEncoderBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
# UNetMidBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
])
|
||||
|
||||
self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
|
||||
self.conv_act = torch.nn.SiLU()
|
||||
self.conv_out = torch.nn.Conv2d(512, 32, kernel_size=3, padding=1)
|
||||
|
||||
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: self.forward(x),
|
||||
sample,
|
||||
tile_size,
|
||||
tile_stride,
|
||||
tile_device=sample.device,
|
||||
tile_dtype=sample.dtype
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
||||
# For VAE Encoder, we do not need to apply the tiler on each layer.
|
||||
if tiled:
|
||||
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
# 1. pre-process
|
||||
hidden_states = self.conv_in(sample)
|
||||
time_emb = None
|
||||
text_emb = None
|
||||
res_stack = None
|
||||
|
||||
# 2. blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
|
||||
# 3. output
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
hidden_states = hidden_states[:, :16]
|
||||
hidden_states = (hidden_states - self.shift_factor) * self.scaling_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
def encode_video(self, sample, batch_size=8):
|
||||
B = sample.shape[0]
|
||||
hidden_states = []
|
||||
|
||||
for i in range(0, sample.shape[2], batch_size):
|
||||
|
||||
j = min(i + batch_size, sample.shape[2])
|
||||
sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W")
|
||||
|
||||
hidden_states_batch = self(sample_batch)
|
||||
hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)
|
||||
|
||||
hidden_states.append(hidden_states_batch)
|
||||
|
||||
hidden_states = torch.concat(hidden_states, dim=2)
|
||||
return hidden_states
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDVAEEncoderStateDictConverter()
|
||||
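Review note: `encode_video` walks the time axis in steps of `batch_size`, folds each chunk of frames into the batch dimension, encodes it, and concatenates the results along `dim=2`. The index arithmetic can be sketched on its own; `frame_chunks` is a hypothetical helper, not part of the repository.

```python
def frame_chunks(num_frames, batch_size=8):
    """Return the (start, stop) frame ranges that the loop above would visit."""
    return [
        (start, min(start + batch_size, num_frames))
        for start in range(0, num_frames, batch_size)
    ]


chunks_for_20_frames = frame_chunks(20, batch_size=8)
```

The last chunk is simply shorter when the frame count is not a multiple of `batch_size`, so no padding is required.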
@@ -99,7 +99,7 @@ class SDControlNet(torch.nn.Module):
|
||||
tiled=False, tile_size=64, tile_stride=32,
|
||||
):
|
||||
# 1. time
|
||||
-time_emb = self.time_proj(timestep[None]).to(sample.dtype)
+time_emb = self.time_proj(timestep).to(sample.dtype)
|
||||
time_emb = self.time_embedding(time_emb)
|
||||
time_emb = time_emb.repeat(sample.shape[0], 1)
|
||||
|
||||
@@ -134,7 +134,8 @@ class SDControlNet(torch.nn.Module):
|
||||
|
||||
return controlnet_res_stack
|
||||
|
||||
-def state_dict_converter(self):
+@staticmethod
+def state_dict_converter():
|
||||
return SDControlNetStateDictConverter()
|
||||
|
||||
|
||||
|
||||
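Review note: this hunk, like many that follow, turns `state_dict_converter` into a static method. A likely motivation (an assumption, the diff does not state it) is that a loader can then ask the class for its converter before any weights are allocated. A toy sketch with hypothetical class names:

```python
class ToyConverter:
    def from_diffusers(self, state_dict):
        # Identity mapping, just to have something to call.
        return dict(state_dict)


class ToyModel:
    instances_built = 0

    def __init__(self):
        # Stands in for the expensive part: allocating layers and weights.
        ToyModel.instances_built += 1

    @staticmethod
    def state_dict_converter():
        return ToyConverter()


# The converter is reachable from the class itself; no ToyModel is constructed.
converter = ToyModel.state_dict_converter()
converted = converter.from_diffusers({"conv_in.weight": 1})
```

Calling the method on an instance still works, so existing call sites of the form `model.state_dict_converter()` are unaffected.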
@@ -29,7 +29,7 @@ class SDIpAdapter(torch.nn.Module):
|
||||
|
||||
def set_less_adapter(self):
|
||||
# IP-Adapter for SD v1.5 doesn't support this feature.
|
||||
-self.set_full_adapter(self)
+self.set_full_adapter()
|
||||
|
||||
def forward(self, hidden_states, scale=1.0):
|
||||
hidden_states = self.image_proj(hidden_states)
|
||||
@@ -47,7 +47,8 @@ class SDIpAdapter(torch.nn.Module):
|
||||
}
|
||||
return ip_kv_dict
|
||||
|
||||
-def state_dict_converter(self):
+@staticmethod
+def state_dict_converter():
|
||||
return SDIpAdapterStateDictConverter()
|
||||
|
||||
|
||||
|
||||
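Review note: the `set_less_adapter` change in this file is a bug fix rather than a style change. Calling a bound method as `self.set_full_adapter(self)` passes the instance twice, which raises `TypeError` for a method that takes no further arguments. A toy reproduction with a hypothetical class:

```python
class ToyAdapter:
    def __init__(self):
        self.mode = None

    def set_full_adapter(self):
        self.mode = "full"

    def set_less_adapter_old(self):
        # Old form: `self` is bound automatically and then passed again.
        self.set_full_adapter(self)

    def set_less_adapter_new(self):
        # Fixed form: rely on the bound method.
        self.set_full_adapter()


adapter = ToyAdapter()
try:
    adapter.set_less_adapter_old()
    old_call_failed = False
except TypeError:
    old_call_failed = True

adapter.set_less_adapter_new()
```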
@@ -1,60 +0,0 @@
|
||||
import torch
|
||||
from .sd_unet import SDUNetStateDictConverter, SDUNet
|
||||
from .sd_text_encoder import SDTextEncoderStateDictConverter, SDTextEncoder
|
||||
|
||||
|
||||
class SDLoRA:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0, device="cuda"):
|
||||
special_keys = {
|
||||
"down.blocks": "down_blocks",
|
||||
"up.blocks": "up_blocks",
|
||||
"mid.block": "mid_block",
|
||||
"proj.in": "proj_in",
|
||||
"proj.out": "proj_out",
|
||||
"transformer.blocks": "transformer_blocks",
|
||||
"to.q": "to_q",
|
||||
"to.k": "to_k",
|
||||
"to.v": "to_v",
|
||||
"to.out": "to_out",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for key in state_dict:
|
||||
if ".lora_up" not in key:
|
||||
continue
|
||||
if not key.startswith(lora_prefix):
|
||||
continue
|
||||
weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
|
||||
weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
|
||||
if len(weight_up.shape) == 4:
|
||||
weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
|
||||
weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down)
|
||||
target_name = key.split(".")[0].replace("_", ".")[len(lora_prefix):] + ".weight"
|
||||
for special_key in special_keys:
|
||||
target_name = target_name.replace(special_key, special_keys[special_key])
|
||||
state_dict_[target_name] = lora_weight.cpu()
|
||||
return state_dict_
|
||||
|
||||
def add_lora_to_unet(self, unet: SDUNet, state_dict_lora, alpha=1.0, device="cuda"):
|
||||
state_dict_unet = unet.state_dict()
|
||||
state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_unet_", alpha=alpha, device=device)
|
||||
state_dict_lora = SDUNetStateDictConverter().from_diffusers(state_dict_lora)
|
||||
if len(state_dict_lora) > 0:
|
||||
for name in state_dict_lora:
|
||||
state_dict_unet[name] += state_dict_lora[name].to(device=device)
|
||||
unet.load_state_dict(state_dict_unet)
|
||||
|
||||
def add_lora_to_text_encoder(self, text_encoder: SDTextEncoder, state_dict_lora, alpha=1.0, device="cuda"):
|
||||
state_dict_text_encoder = text_encoder.state_dict()
|
||||
state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_te_", alpha=alpha, device=device)
|
||||
state_dict_lora = SDTextEncoderStateDictConverter().from_diffusers(state_dict_lora)
|
||||
if len(state_dict_lora) > 0:
|
||||
for name in state_dict_lora:
|
||||
state_dict_text_encoder[name] += state_dict_lora[name].to(device=device)
|
||||
text_encoder.load_state_dict(state_dict_text_encoder)
|
||||
|
||||
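Review note: the removed `SDLoRA.convert_state_dict` performs two steps per LoRA pair. It rebuilds the target parameter name from the flattened LoRA key, and it forms the weight delta as `alpha * (weight_up @ weight_down)`. It also accepts a `device` argument but moves tensors to `"cuda"` regardless. The sketch below reproduces the string handling exactly and the delta on nested lists; `lora_target_name`, `matmul` and `lora_delta` are illustrative names only.

```python
SPECIAL_KEYS = {
    "down.blocks": "down_blocks",
    "up.blocks": "up_blocks",
    "mid.block": "mid_block",
    "proj.in": "proj_in",
    "proj.out": "proj_out",
    "transformer.blocks": "transformer_blocks",
    "to.q": "to_q",
    "to.k": "to_k",
    "to.v": "to_v",
    "to.out": "to_out",
}


def lora_target_name(key, lora_prefix="lora_unet_"):
    """Same string steps as the removed code: split, '_' -> '.', strip prefix, repair."""
    name = key.split(".")[0].replace("_", ".")[len(lora_prefix):] + ".weight"
    for broken, repaired in SPECIAL_KEYS.items():
        name = name.replace(broken, repaired)
    return name


def matmul(a, b):
    """Plain-Python matrix product of two nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def lora_delta(weight_up, weight_down, alpha=1.0):
    """alpha * (up @ down), the quantity added onto the base weight."""
    return [[alpha * value for value in row] for row in matmul(weight_up, weight_down)]


example_key = "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_up.weight"
example_name = lora_target_name(example_key)
example_delta = lora_delta([[1.0], [2.0]], [[3.0, 4.0]], alpha=0.5)
```

The `SPECIAL_KEYS` pass exists because replacing every underscore with a dot also breaks names such as `down_blocks` that legitimately contain one.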
@@ -144,7 +144,8 @@ class SDMotionModel(torch.nn.Module):
|
||||
def forward(self):
|
||||
pass
|
||||
|
||||
-def state_dict_converter(self):
+@staticmethod
+def state_dict_converter():
|
||||
return SDMotionModelStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -71,7 +71,8 @@ class SDTextEncoder(torch.nn.Module):
|
||||
embeds = self.final_layer_norm(embeds)
|
||||
return embeds
|
||||
|
||||
-def state_dict_converter(self):
+@staticmethod
+def state_dict_converter():
|
||||
return SDTextEncoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -323,7 +323,7 @@ class SDUNet(torch.nn.Module):
|
||||
|
||||
def forward(self, sample, timestep, encoder_hidden_states, **kwargs):
|
||||
# 1. time
|
||||
-time_emb = self.time_proj(timestep[None]).to(sample.dtype)
+time_emb = self.time_proj(timestep).to(sample.dtype)
|
||||
time_emb = self.time_embedding(time_emb)
|
||||
|
||||
# 2. pre-process
|
||||
@@ -342,7 +342,8 @@ class SDUNet(torch.nn.Module):
|
||||
|
||||
return hidden_states
|
||||
|
||||
-def state_dict_converter(self):
+@staticmethod
+def state_dict_converter():
|
||||
return SDUNetStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -90,6 +90,8 @@ class SDVAEDecoder(torch.nn.Module):
|
||||
return hidden_states
|
||||
|
||||
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
||||
+original_dtype = sample.dtype
+sample = sample.to(dtype=next(iter(self.parameters())).dtype)
|
||||
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
||||
if tiled:
|
||||
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
||||
@@ -110,10 +112,12 @@ class SDVAEDecoder(torch.nn.Module):
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
+hidden_states = hidden_states.to(original_dtype)
|
||||
|
||||
return hidden_states
|
||||
|
||||
-def state_dict_converter(self):
+@staticmethod
+def state_dict_converter():
|
||||
return SDVAEDecoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -50,6 +50,8 @@ class SDVAEEncoder(torch.nn.Module):
|
||||
return hidden_states
|
||||
|
||||
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
||||
+original_dtype = sample.dtype
+sample = sample.to(dtype=next(iter(self.parameters())).dtype)
|
||||
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
||||
if tiled:
|
||||
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
||||
@@ -71,6 +73,7 @@ class SDVAEEncoder(torch.nn.Module):
|
||||
hidden_states = self.quant_conv(hidden_states)
|
||||
hidden_states = hidden_states[:, :4]
|
||||
hidden_states *= self.scaling_factor
|
||||
+hidden_states = hidden_states.to(original_dtype)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -91,7 +94,8 @@ class SDVAEEncoder(torch.nn.Module):
|
||||
hidden_states = torch.concat(hidden_states, dim=2)
|
||||
return hidden_states
|
||||
|
||||
-def state_dict_converter(self):
+@staticmethod
+def state_dict_converter():
|
||||
return SDVAEEncoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -96,7 +96,8 @@ class SDXLIpAdapter(torch.nn.Module):
|
||||
}
|
||||
return ip_kv_dict
|
||||
|
||||
-def state_dict_converter(self):
+@staticmethod
+def state_dict_converter():
|
||||
return SDXLIpAdapterStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -49,7 +49,8 @@ class SDXLMotionModel(torch.nn.Module):
|
||||
def forward(self):
|
||||
pass
|
||||
|
||||
-def state_dict_converter(self):
+@staticmethod
+def state_dict_converter():
|
||||
return SDMotionModelStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -36,7 +36,8 @@ class SDXLTextEncoder(torch.nn.Module):
|
||||
break
|
||||
return embeds
|
||||
|
||||
-def state_dict_converter(self):
+@staticmethod
+def state_dict_converter():
|
||||
return SDXLTextEncoderStateDictConverter()
|
||||
|
||||
|
||||
@@ -80,7 +81,8 @@ class SDXLTextEncoder2(torch.nn.Module):
|
||||
pooled_embeds = self.text_projection(pooled_embeds)
|
||||
return pooled_embeds, hidden_states
|
||||
|
||||
-def state_dict_converter(self):
+@staticmethod
+def state_dict_converter():
|
||||
return SDXLTextEncoder2StateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, PopBlock
|
||||
|
||||
|
||||
class SDXLUNet(torch.nn.Module):
|
||||
-def __init__(self):
+def __init__(self, is_kolors=False):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(320)
|
||||
self.time_embedding = torch.nn.Sequential(
|
||||
@@ -13,11 +13,12 @@ class SDXLUNet(torch.nn.Module):
|
||||
)
|
||||
self.add_time_proj = Timesteps(256)
|
||||
self.add_time_embedding = torch.nn.Sequential(
|
||||
-torch.nn.Linear(2816, 1280),
+torch.nn.Linear(5632 if is_kolors else 2816, 1280),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(1280, 1280)
|
||||
)
|
||||
self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
|
||||
+self.text_intermediate_proj = torch.nn.Linear(4096, 2048) if is_kolors else None
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
# DownBlock2D
|
||||
@@ -85,10 +86,12 @@ class SDXLUNet(torch.nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
sample, timestep, encoder_hidden_states, add_time_id, add_text_embeds,
|
||||
-tiled=False, tile_size=64, tile_stride=8, **kwargs
+tiled=False, tile_size=64, tile_stride=8,
+use_gradient_checkpointing=False,
+**kwargs
|
||||
):
|
||||
# 1. time
|
||||
-t_emb = self.time_proj(timestep[None]).to(sample.dtype)
+t_emb = self.time_proj(timestep).to(sample.dtype)
|
||||
t_emb = self.time_embedding(t_emb)
|
||||
|
||||
time_embeds = self.add_time_proj(add_time_id)
|
||||
@@ -102,15 +105,26 @@ class SDXLUNet(torch.nn.Module):
|
||||
# 2. pre-process
|
||||
height, width = sample.shape[2], sample.shape[3]
|
||||
hidden_states = self.conv_in(sample)
|
||||
-text_emb = encoder_hidden_states
+text_emb = encoder_hidden_states if self.text_intermediate_proj is None else self.text_intermediate_proj(encoder_hidden_states)
|
||||
res_stack = [hidden_states]
|
||||
|
||||
# 3. blocks
|
||||
+def create_custom_forward(module):
+def custom_forward(*inputs):
+return module(*inputs)
+return custom_forward
|
||||
for i, block in enumerate(self.blocks):
|
||||
-hidden_states, time_emb, text_emb, res_stack = block(
-hidden_states, time_emb, text_emb, res_stack,
-tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
-)
+if self.training and use_gradient_checkpointing and not (isinstance(block, PushBlock) or isinstance(block, PopBlock)):
+hidden_states, time_emb, text_emb, res_stack = torch.utils.checkpoint.checkpoint(
+create_custom_forward(block),
+hidden_states, time_emb, text_emb, res_stack,
+use_reentrant=False,
+)
+else:
+hidden_states, time_emb, text_emb, res_stack = block(
+hidden_states, time_emb, text_emb, res_stack,
+tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
+)
|
||||
|
||||
# 4. output
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
@@ -119,7 +133,8 @@ class SDXLUNet(torch.nn.Module):
|
||||
|
||||
return hidden_states
|
||||
|
||||
-def state_dict_converter(self):
+@staticmethod
+def state_dict_converter():
|
||||
return SDXLUNetStateDictConverter()
|
||||
|
||||
|
||||
@@ -148,6 +163,8 @@ class SDXLUNetStateDictConverter:
|
||||
names = name.split(".")
|
||||
if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
|
||||
pass
|
||||
+elif names[0] in ["encoder_hid_proj"]:
+names[0] = "text_intermediate_proj"
|
||||
elif names[0] in ["time_embedding", "add_embedding"]:
|
||||
if names[0] == "add_embedding":
|
||||
names[0] = "add_time_embedding"
|
||||
@@ -181,7 +198,10 @@ class SDXLUNetStateDictConverter:
|
||||
if ".proj_in." in name or ".proj_out." in name:
|
||||
param = param.squeeze()
|
||||
state_dict_[rename_dict[name]] = param
|
||||
-return state_dict_
+if "text_intermediate_proj.weight" in state_dict_:
+return state_dict_, {"is_kolors": True}
+else:
+return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
@@ -1873,4 +1893,7 @@ class SDXLUNetStateDictConverter:
|
||||
if ".proj_in." in name or ".proj_out." in name:
|
||||
param = param.squeeze()
|
||||
state_dict_[rename_dict[name]] = param
|
||||
-return state_dict_
+if "text_intermediate_proj.weight" in state_dict_:
+return state_dict_, {"is_kolors": True}
+else:
+return state_dict_
|
||||
|
||||
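Review note: after this change the SDXL UNet converters return either a bare state dict or a `(state_dict, extra_kwargs)` tuple, depending on whether `text_intermediate_proj.weight` is present. Callers therefore have to accept both shapes. How the repository's loader does this is not shown in this part of the diff, so the sketch below is only one plausible way to normalise the result; `detect_kolors` and `unpack_converter_result` are hypothetical names.

```python
def detect_kolors(state_dict):
    """Mirror of the key check at the end of the converter methods above."""
    if "text_intermediate_proj.weight" in state_dict:
        return state_dict, {"is_kolors": True}
    return state_dict


def unpack_converter_result(result):
    """Normalise both return shapes to (state_dict, init_kwargs)."""
    if isinstance(result, tuple):
        state_dict, init_kwargs = result
    else:
        state_dict, init_kwargs = result, {}
    return state_dict, init_kwargs


plain = unpack_converter_result(detect_kolors({"conv_in.weight": 0}))
kolors = unpack_converter_result(detect_kolors({"text_intermediate_proj.weight": 0}))
```

Under that assumption a loader could construct the model as `ModelClass(**init_kwargs)` and then load `state_dict` into it.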
@@ -2,14 +2,23 @@ from .sd_vae_decoder import SDVAEDecoder, SDVAEDecoderStateDictConverter
|
||||
|
||||
|
||||
class SDXLVAEDecoder(SDVAEDecoder):
|
||||
-def __init__(self):
+def __init__(self, upcast_to_float32=True):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.13025
|
||||
|
||||
-def state_dict_converter(self):
+@staticmethod
+def state_dict_converter():
|
||||
return SDXLVAEDecoderStateDictConverter()
|
||||
|
||||
|
||||
class SDXLVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
+def from_diffusers(self, state_dict):
+state_dict = super().from_diffusers(state_dict)
+return state_dict, {"upcast_to_float32": True}
|
||||
|
||||
+def from_civitai(self, state_dict):
+state_dict = super().from_civitai(state_dict)
+return state_dict, {"upcast_to_float32": True}
|
||||
|
||||
@@ -2,14 +2,23 @@ from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder
|
||||
|
||||
|
||||
class SDXLVAEEncoder(SDVAEEncoder):
|
||||
-def __init__(self):
+def __init__(self, upcast_to_float32=True):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.13025
|
||||
|
||||
-def state_dict_converter(self):
+@staticmethod
+def state_dict_converter():
|
||||
return SDXLVAEEncoderStateDictConverter()
|
||||
|
||||
|
||||
class SDXLVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
+def from_diffusers(self, state_dict):
+state_dict = super().from_diffusers(state_dict)
+return state_dict, {"upcast_to_float32": True}
|
||||
|
||||
+def from_civitai(self, state_dict):
+state_dict = super().from_civitai(state_dict)
+return state_dict, {"upcast_to_float32": True}
|
||||
|
||||
@@ -44,7 +44,8 @@ class SVDImageEncoder(torch.nn.Module):
|
||||
embeds = self.visual_projection(embeds)
|
||||
return embeds
|
||||
|
||||
-def state_dict_converter(self):
+@staticmethod
+def state_dict_converter():
|
||||
return SVDImageEncoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -407,7 +407,8 @@ class SVDUNet(torch.nn.Module):
|
||||
|
||||
return hidden_states
|
||||
|
||||
-def state_dict_converter(self):
+@staticmethod
+def state_dict_converter():
|
||||
return SVDUNetStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -199,7 +199,8 @@ class SVDVAEDecoder(torch.nn.Module):
|
||||
return values
|
||||
|
||||
|
||||
-def state_dict_converter(self):
+@staticmethod
+def state_dict_converter():
|
||||
return SVDVAEDecoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,8 @@ class SVDVAEEncoder(SDVAEEncoder):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.13025
|
||||
|
||||
-def state_dict_converter(self):
+@staticmethod
+def state_dict_converter():
|
||||
return SVDVAEEncoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
-from .stable_diffusion import SDImagePipeline
-from .stable_diffusion_xl import SDXLImagePipeline
-from .stable_diffusion_video import SDVideoPipeline, SDVideoPipelineRunner
-from .stable_diffusion_xl_video import SDXLVideoPipeline
-from .stable_video_diffusion import SVDVideoPipeline
-from .hunyuan_dit import HunyuanDiTImagePipeline
+from .sd_image import SDImagePipeline
+from .sd_video import SDVideoPipeline
+from .sdxl_image import SDXLImagePipeline
+from .sdxl_video import SDXLVideoPipeline
+from .sd3_image import SD3ImagePipeline
+from .hunyuan_image import HunyuanDiTImagePipeline
+from .svd_video import SVDVideoPipeline
+from .pipeline_runner import SDVideoPipelineRunner
+KolorsImagePipeline = SDXLImagePipeline
|
||||
|
||||
34
diffsynth/pipelines/base.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
|
||||
class BasePipeline(torch.nn.Module):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
|
||||
|
||||
def preprocess_image(self, image):
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
|
||||
return image
|
||||
|
||||
|
||||
def preprocess_images(self, images):
|
||||
return [self.preprocess_image(image) for image in images]
|
||||
|
||||
|
||||
def vae_output_to_image(self, vae_output):
|
||||
image = vae_output[0].cpu().permute(1, 2, 0).numpy()
|
||||
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||
return image
|
||||
|
||||
|
||||
def vae_output_to_video(self, vae_output):
|
||||
video = vae_output.cpu().permute(1, 2, 0).numpy()
|
||||
video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video]
|
||||
return video
|
||||
|
||||
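Review note: `preprocess_image` and `vae_output_to_image` in the new `BasePipeline` are a pair of value-range conversions. The first maps 8-bit pixel values from [0, 255] to [-1, 1]; the second maps model output back to [0, 255] with clipping. A scalar sketch of the arithmetic, with hypothetical function names and no tensor or image handling:

```python
def to_model_range(pixel):
    """[0, 255] -> [-1, 1], as in preprocess_image."""
    return pixel * (2 / 255) - 1


def to_pixel_range(value):
    """[-1, 1] -> [0, 255] with clipping, as in vae_output_to_image (before the uint8 cast)."""
    clipped = min(max(value / 2 + 0.5, 0.0), 1.0)
    return clipped * 255


darkest = to_model_range(0)
brightest = to_model_range(255)
round_trip_128 = to_pixel_range(to_model_range(128))
overshoot = to_pixel_range(3.0)
undershoot = to_pixel_range(-5.0)
```

One detail worth keeping in mind: the pipeline code finishes with `astype("uint8")`, which truncates rather than rounds, so a value that lands a hair below an integer after the round trip is stored as the next lower integer.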
@@ -22,6 +22,10 @@ def lets_dance(
|
||||
device = "cuda",
|
||||
vram_limit_level = 0,
|
||||
):
|
||||
+# 0. Text embedding alignment (only for video processing)
+if encoder_hidden_states.shape[0] != sample.shape[0]:
+encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0], 1, 1, 1)
|
||||
|
||||
# 1. ControlNet
|
||||
# This part will be repeated on overlapping frames if animatediff_batch_size > animatediff_stride.
|
||||
# I leave it here because I intend to do something interesting on the ControlNets.
|
||||
@@ -50,7 +54,7 @@ def lets_dance(
|
||||
additional_res_stack = None
|
||||
|
||||
# 2. time
|
||||
-time_emb = unet.time_proj(timestep[None]).to(sample.dtype)
+time_emb = unet.time_proj(timestep).to(sample.dtype)
|
||||
time_emb = unet.time_embedding(time_emb)
|
||||
|
||||
# 3. pre-process
|
||||
@@ -133,7 +137,7 @@ def lets_dance_xl(
|
||||
vram_limit_level = 0,
|
||||
):
|
||||
# 2. time
|
||||
-t_emb = unet.time_proj(timestep[None]).to(sample.dtype)
+t_emb = unet.time_proj(timestep).to(sample.dtype)
|
||||
t_emb = unet.time_embedding(t_emb)
|
||||
|
||||
time_embeds = unet.add_time_proj(add_time_id)
|
||||
@@ -147,7 +151,7 @@ def lets_dance_xl(
|
||||
# 3. pre-process
|
||||
height, width = sample.shape[2], sample.shape[3]
|
||||
hidden_states = unet.conv_in(sample)
|
||||
-text_emb = encoder_hidden_states
+text_emb = encoder_hidden_states if unet.text_intermediate_proj is None else unet.text_intermediate_proj(encoder_hidden_states)
|
||||
res_stack = [hidden_states]
|
||||
|
||||
# 4. blocks
|
||||
|
||||
@@ -3,11 +3,11 @@ from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, Hunyuan
|
||||
from ..models.sdxl_vae_encoder import SDXLVAEEncoder
|
||||
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
|
||||
from ..models import ModelManager
|
||||
-from ..prompts import HunyuanDiTPrompter
+from ..prompters import HunyuanDiTPrompter
|
||||
from ..schedulers import EnhancedDDIMScheduler
|
||||
+from .base import BasePipeline
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
-from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
@@ -122,14 +122,12 @@ class ImageSizeManager:
|
||||
|
||||
|
||||
|
||||
-class HunyuanDiTImagePipeline(torch.nn.Module):
+class HunyuanDiTImagePipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
-super().__init__()
+super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
self.scheduler = EnhancedDDIMScheduler(prediction_type="v_prediction", beta_start=0.00085, beta_end=0.03)
|
||||
self.prompter = HunyuanDiTPrompter()
|
||||
-self.device = device
-self.torch_dtype = torch_dtype
|
||||
self.image_size_manager = ImageSizeManager()
|
||||
# models
|
||||
self.text_encoder: HunyuanDiTCLIPTextEncoder = None
|
||||
@@ -139,42 +137,60 @@ class HunyuanDiTImagePipeline(torch.nn.Module):
|
||||
self.vae_encoder: SDXLVAEEncoder = None
|
||||
|
||||
|
||||
-def fetch_main_models(self, model_manager: ModelManager):
-self.text_encoder = model_manager.hunyuan_dit_clip_text_encoder
-self.text_encoder_t5 = model_manager.hunyuan_dit_t5_text_encoder
-self.dit = model_manager.hunyuan_dit
-self.vae_decoder = model_manager.vae_decoder
-self.vae_encoder = model_manager.vae_encoder
+def denoising_model(self):
+return self.dit
|
||||
|
||||
|
||||
-def fetch_prompter(self, model_manager: ModelManager):
-self.prompter.load_from_model_manager(model_manager)
+def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
+# Main models
+self.text_encoder = model_manager.fetch_model("hunyuan_dit_clip_text_encoder")
+self.text_encoder_t5 = model_manager.fetch_model("hunyuan_dit_t5_text_encoder")
+self.dit = model_manager.fetch_model("hunyuan_dit")
+self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
+self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
+self.prompter.fetch_models(self.text_encoder, self.text_encoder_t5)
+self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
|
||||
|
||||
|
||||
@staticmethod
|
||||
-def from_model_manager(model_manager: ModelManager):
+def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]):
|
||||
pipe = HunyuanDiTImagePipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
-pipe.fetch_main_models(model_manager)
-pipe.fetch_prompter(model_manager)
+pipe.fetch_models(model_manager, prompt_refiner_classes)
|
||||
return pipe
|
||||
|
||||
|
||||
-def preprocess_image(self, image):
-image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
-return image
+def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
+latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+return latents
|
||||
|
||||
|
||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||
-image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
-image = image.cpu().permute(1, 2, 0).numpy()
-image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
+image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+image = self.vae_output_to_image(image)
|
||||
return image
|
||||
|
||||
|
||||
-def prepare_extra_input(self, height=1024, width=1024, tiled=False, tile_size=64, tile_stride=32, batch_size=1):
+def encode_prompt(self, prompt, clip_skip=1, clip_skip_2=1, positive=True):
+text_emb, text_emb_mask, text_emb_t5, text_emb_mask_t5 = self.prompter.encode_prompt(
+prompt,
+clip_skip=clip_skip,
+clip_skip_2=clip_skip_2,
+positive=positive,
+device=self.device
+)
+return {
+"text_emb": text_emb,
+"text_emb_mask": text_emb_mask,
+"text_emb_t5": text_emb_t5,
+"text_emb_mask_t5": text_emb_mask_t5
+}
+
+
+def prepare_extra_input(self, latents=None, tiled=False, tile_size=64, tile_stride=32):
+batch_size, height, width = latents.shape[0], latents.shape[2] * 8, latents.shape[3] * 8
|
||||
if tiled:
|
||||
height, width = tile_size * 16, tile_size * 16
|
||||
image_meta_size = torch.as_tensor([width, height, width, height, 0, 0]).to(device=self.device)
|
||||
@@ -198,7 +214,6 @@ class HunyuanDiTImagePipeline(torch.nn.Module):
|
||||
clip_skip=1,
|
||||
clip_skip_2=1,
|
||||
input_image=None,
|
||||
reference_images=[],
|
||||
reference_strengths=[0.4],
|
||||
denoising_strength=1.0,
|
||||
height=1024,
|
||||
@@ -216,71 +231,32 @@ class HunyuanDiTImagePipeline(torch.nn.Module):
|
||||
# Prepare latent tensors
|
||||
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
if input_image is not None:
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=torch.float32)
|
||||
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
latents = noise.clone()
|
||||
|
||||
# Prepare reference latents
|
||||
reference_latents = []
|
||||
for reference_image in reference_images:
|
||||
reference_image = self.preprocess_image(reference_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
reference_latents.append(self.vae_encoder(reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype))
|
||||
|
||||
# Encode prompts
|
||||
prompt_emb_posi, attention_mask_posi, prompt_emb_t5_posi, attention_mask_t5_posi = self.prompter.encode_prompt(
|
||||
self.text_encoder,
|
||||
self.text_encoder_t5,
|
||||
prompt,
|
||||
clip_skip=clip_skip,
|
||||
clip_skip_2=clip_skip_2,
|
||||
positive=True,
|
||||
device=self.device
|
||||
)
|
||||
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
|
||||
if cfg_scale != 1.0:
|
||||
prompt_emb_nega, attention_mask_nega, prompt_emb_t5_nega, attention_mask_t5_nega = self.prompter.encode_prompt(
|
||||
self.text_encoder,
|
||||
self.text_encoder_t5,
|
||||
negative_prompt,
|
||||
clip_skip=clip_skip,
|
||||
clip_skip_2=clip_skip_2,
|
||||
positive=False,
|
||||
device=self.device
|
||||
)
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=False)
|
||||
|
||||
# Prepare positional id
|
||||
extra_input = self.prepare_extra_input(height, width, tiled, tile_size)
|
||||
extra_input = self.prepare_extra_input(latents, tiled, tile_size)
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = torch.tensor([timestep]).to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
# In-context reference
|
||||
for reference_latents_, reference_strength in zip(reference_latents, reference_strengths):
|
||||
if progress_id < num_inference_steps * reference_strength:
|
||||
noisy_reference_latents = self.scheduler.add_noise(reference_latents_, noise, self.scheduler.timesteps[progress_id])
|
||||
self.dit(
|
||||
noisy_reference_latents,
|
||||
prompt_emb_posi, prompt_emb_t5_posi, attention_mask_posi, attention_mask_t5_posi,
|
||||
timestep,
|
||||
**extra_input,
|
||||
to_cache=True
|
||||
)
|
||||
# Positive side
|
||||
noise_pred_posi = self.dit(
|
||||
latents,
|
||||
prompt_emb_posi, prompt_emb_t5_posi, attention_mask_posi, attention_mask_t5_posi,
|
||||
timestep,
|
||||
**extra_input,
|
||||
latents, timestep=timestep, **prompt_emb_posi, **extra_input,
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
# Negative side
|
||||
noise_pred_nega = self.dit(
|
||||
latents,
|
||||
prompt_emb_nega, prompt_emb_t5_nega, attention_mask_nega, attention_mask_t5_nega,
|
||||
timestep,
|
||||
**extra_input
|
||||
latents, timestep=timestep, **prompt_emb_nega, **extra_input,
|
||||
)
|
||||
# Classifier-free guidance
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
@@ -293,6 +269,6 @@ class HunyuanDiTImagePipeline(torch.nn.Module):
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
return image
|
||||
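The guidance line in the hunk above, `noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)`, can be illustrated with a minimal sketch. It uses plain Python floats in place of tensors; the helper name `classifier_free_guidance` is hypothetical and does not appear in the repository.

```python
# Sketch only: plain floats stand in for the noise-prediction tensors.
def classifier_free_guidance(noise_pred_posi, noise_pred_nega, cfg_scale):
    """Combine the positive and negative predictions element-wise."""
    return [
        nega + cfg_scale * (posi - nega)
        for posi, nega in zip(noise_pred_posi, noise_pred_nega)
    ]


posi = [1.0, 2.0]
nega = [0.0, 1.0]
guided = classifier_free_guidance(posi, nega, cfg_scale=7.5)
same_as_posi = classifier_free_guidance(posi, nega, cfg_scale=1.0)
```

With `cfg_scale=1.0` the expression reduces to the positive prediction, which is consistent with the pipeline skipping the negative branch in that case.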
105
diffsynth/pipelines/pipeline_runner.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import os, torch, json
|
||||
from .sd_video import ModelManager, SDVideoPipeline, ControlNetConfigUnit
|
||||
from ..processors.sequencial_processor import SequencialProcessor
|
||||
from ..data import VideoData, save_frames, save_video
|
||||
|
||||
|
||||
|
||||
class SDVideoPipelineRunner:
|
||||
def __init__(self, in_streamlit=False):
|
||||
self.in_streamlit = in_streamlit
|
||||
|
||||
|
||||
def load_pipeline(self, model_list, textual_inversion_folder, device, lora_alphas, controlnet_units):
|
||||
# Load models
|
||||
model_manager = ModelManager(torch_dtype=torch.float16, device=device)
|
||||
model_manager.load_models(model_list)
|
||||
pipe = SDVideoPipeline.from_model_manager(
|
||||
model_manager,
|
||||
[
|
||||
ControlNetConfigUnit(
|
||||
processor_id=unit["processor_id"],
|
||||
model_path=unit["model_path"],
|
||||
scale=unit["scale"]
|
||||
) for unit in controlnet_units
|
||||
]
|
||||
)
|
||||
textual_inversion_paths = []
|
||||
for file_name in os.listdir(textual_inversion_folder):
|
||||
if file_name.endswith(".pt") or file_name.endswith(".bin") or file_name.endswith(".pth") or file_name.endswith(".safetensors"):
|
||||
textual_inversion_paths.append(os.path.join(textual_inversion_folder, file_name))
|
||||
pipe.prompter.load_textual_inversions(textual_inversion_paths)
|
||||
return model_manager, pipe
|
||||
|
||||
|
||||
def load_smoother(self, model_manager, smoother_configs):
|
||||
smoother = SequencialProcessor.from_model_manager(model_manager, smoother_configs)
|
||||
return smoother
|
||||
|
||||
|
||||
def synthesize_video(self, model_manager, pipe, seed, smoother, **pipeline_inputs):
|
||||
torch.manual_seed(seed)
|
||||
if self.in_streamlit:
|
||||
import streamlit as st
|
||||
progress_bar_st = st.progress(0.0)
|
||||
output_video = pipe(**pipeline_inputs, smoother=smoother, progress_bar_st=progress_bar_st)
|
||||
progress_bar_st.progress(1.0)
|
||||
else:
|
||||
output_video = pipe(**pipeline_inputs, smoother=smoother)
|
||||
model_manager.to("cpu")
|
||||
return output_video
|
||||
|
||||
|
||||
def load_video(self, video_file, image_folder, height, width, start_frame_id, end_frame_id):
|
||||
video = VideoData(video_file=video_file, image_folder=image_folder, height=height, width=width)
|
||||
if start_frame_id is None:
|
||||
start_frame_id = 0
|
||||
if end_frame_id is None:
|
||||
end_frame_id = len(video)
|
||||
frames = [video[i] for i in range(start_frame_id, end_frame_id)]
|
||||
return frames
|
||||
|
||||
|
||||
def add_data_to_pipeline_inputs(self, data, pipeline_inputs):
|
||||
pipeline_inputs["input_frames"] = self.load_video(**data["input_frames"])
|
||||
pipeline_inputs["num_frames"] = len(pipeline_inputs["input_frames"])
|
||||
pipeline_inputs["width"], pipeline_inputs["height"] = pipeline_inputs["input_frames"][0].size
|
||||
if len(data["controlnet_frames"]) > 0:
|
||||
pipeline_inputs["controlnet_frames"] = [self.load_video(**unit) for unit in data["controlnet_frames"]]
|
||||
return pipeline_inputs
|
||||
|
||||
|
||||
def save_output(self, video, output_folder, fps, config):
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
save_frames(video, os.path.join(output_folder, "frames"))
|
||||
save_video(video, os.path.join(output_folder, "video.mp4"), fps=fps)
|
||||
config["pipeline"]["pipeline_inputs"]["input_frames"] = []
|
||||
config["pipeline"]["pipeline_inputs"]["controlnet_frames"] = []
|
||||
with open(os.path.join(output_folder, "config.json"), 'w') as file:
|
||||
json.dump(config, file, indent=4)
|
||||
|
||||
|
||||
def run(self, config):
|
||||
if self.in_streamlit:
|
||||
import streamlit as st
|
||||
if self.in_streamlit: st.markdown("Loading videos ...")
|
||||
config["pipeline"]["pipeline_inputs"] = self.add_data_to_pipeline_inputs(config["data"], config["pipeline"]["pipeline_inputs"])
|
||||
if self.in_streamlit: st.markdown("Loading videos ... done!")
|
||||
if self.in_streamlit: st.markdown("Loading models ...")
|
||||
model_manager, pipe = self.load_pipeline(**config["models"])
|
||||
if self.in_streamlit: st.markdown("Loading models ... done!")
|
||||
if "smoother_configs" in config:
|
||||
if self.in_streamlit: st.markdown("Loading smoother ...")
|
||||
smoother = self.load_smoother(model_manager, config["smoother_configs"])
|
||||
if self.in_streamlit: st.markdown("Loading smoother ... done!")
|
||||
else:
|
||||
smoother = None
|
||||
if self.in_streamlit: st.markdown("Synthesizing videos ...")
|
||||
output_video = self.synthesize_video(model_manager, pipe, config["pipeline"]["seed"], smoother, **config["pipeline"]["pipeline_inputs"])
|
||||
if self.in_streamlit: st.markdown("Synthesizing videos ... done!")
|
||||
if self.in_streamlit: st.markdown("Saving videos ...")
|
||||
self.save_output(output_video, config["data"]["output_folder"], config["data"]["fps"], config)
|
||||
if self.in_streamlit: st.markdown("Saving videos ... done!")
|
||||
if self.in_streamlit: st.markdown("Finished!")
|
||||
video_file = open(os.path.join(config["data"]["output_folder"], "video.mp4"), 'rb')
|
||||
if self.in_streamlit: st.video(video_file.read())
|
||||
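The keys that `SDVideoPipelineRunner.run` reads from `config` can be collected into a sketch of the expected dictionary. Only the key names are taken from the code above; every value below (paths, processor id, prompt, fps) is a hypothetical placeholder, and `missing_keys` is an illustrative helper that is not part of the repository.

```python
# Hypothetical values throughout; only the key names come from the runner code.
example_config = {
    "models": {  # unpacked into load_pipeline(**config["models"])
        "model_list": ["models/placeholder.safetensors"],
        "textual_inversion_folder": "models/textual_inversion",
        "device": "cuda",
        "lora_alphas": [],
        "controlnet_units": [
            {"processor_id": "depth", "model_path": "models/placeholder_controlnet.pth", "scale": 0.5},
        ],
    },
    "data": {
        "input_frames": {  # unpacked into load_video(**data["input_frames"])
            "video_file": "input.mp4", "image_folder": None,
            "height": 512, "width": 512,
            "start_frame_id": None, "end_frame_id": None,
        },
        "controlnet_frames": [],  # one load_video kwargs dict per ControlNet unit
        "output_folder": "output",
        "fps": 30,
    },
    "pipeline": {
        "seed": 0,
        "pipeline_inputs": {"prompt": "placeholder prompt"},
    },
    # "smoother_configs" is optional; run() sets smoother = None when it is absent.
}

REQUIRED_KEYS = {
    "models": ["model_list", "textual_inversion_folder", "device", "lora_alphas", "controlnet_units"],
    "data": ["input_frames", "controlnet_frames", "output_folder", "fps"],
    "pipeline": ["seed", "pipeline_inputs"],
}


def missing_keys(config):
    """Return dotted paths of keys that run() would fail to find."""
    missing = []
    for section, keys in REQUIRED_KEYS.items():
        if section not in config:
            missing.append(section)
            continue
        missing.extend(f"{section}.{key}" for key in keys if key not in config[section])
    return missing
```

Because `config["models"]` is passed as keyword arguments, it has to contain exactly the five parameters of `load_pipeline`; an extra key would raise a `TypeError`.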
132
diffsynth/pipelines/sd3_image.py
Normal file
@@ -0,0 +1,132 @@
|
||||
from ..models import ModelManager, SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEDecoder, SD3VAEEncoder
|
||||
from ..prompters import SD3Prompter
|
||||
from ..schedulers import FlowMatchScheduler
|
||||
from .base import BasePipeline
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
|
||||
class SD3ImagePipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
self.scheduler = FlowMatchScheduler()
|
||||
self.prompter = SD3Prompter()
|
||||
# models
|
||||
self.text_encoder_1: SD3TextEncoder1 = None
|
||||
self.text_encoder_2: SD3TextEncoder2 = None
|
||||
self.text_encoder_3: SD3TextEncoder3 = None
|
||||
self.dit: SD3DiT = None
|
||||
self.vae_decoder: SD3VAEDecoder = None
|
||||
self.vae_encoder: SD3VAEEncoder = None
|
||||
|
||||
|
||||
def denoising_model(self):
|
||||
return self.dit
|
||||
|
||||
|
||||
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
|
||||
self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
|
||||
self.text_encoder_2 = model_manager.fetch_model("sd3_text_encoder_2")
|
||||
if "sd3_text_encoder_3" in model_manager.model:
|
||||
self.text_encoder_3 = model_manager.fetch_model("sd3_text_encoder_3")
|
||||
self.dit = model_manager.fetch_model("sd3_dit")
|
||||
self.vae_decoder = model_manager.fetch_model("sd3_vae_decoder")
|
||||
self.vae_encoder = model_manager.fetch_model("sd3_vae_encoder")
|
||||
self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2, self.text_encoder_3)
|
||||
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]):
|
||||
pipe = SD3ImagePipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_models(model_manager, prompt_refiner_classes)
|
||||
return pipe
|
||||
|
||||
|
||||
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
|
||||
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return latents
|
||||
|
||||
|
||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
image = self.vae_output_to_image(image)
|
||||
return image
|
||||
|
||||
|
||||
def encode_prompt(self, prompt, positive=True):
|
||||
prompt_emb, pooled_prompt_emb = self.prompter.encode_prompt(
|
||||
prompt, device=self.device, positive=positive
|
||||
)
|
||||
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb}
|
||||
|
||||
|
||||
def prepare_extra_input(self, latents=None):
|
||||
return {}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
cfg_scale=7.5,
|
||||
input_image=None,
|
||||
denoising_strength=1.0,
|
||||
height=1024,
|
||||
width=1024,
|
||||
num_inference_steps=20,
|
||||
tiled=False,
|
||||
tile_size=128,
|
||||
tile_stride=64,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
if input_image is not None:
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.encode_image(image, **tiler_kwargs)
|
||||
noise = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# Encode prompts
|
||||
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
|
||||
# Classifier-free guidance
|
||||
noise_pred_posi = self.dit(
|
||||
latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs,
|
||||
)
|
||||
noise_pred_nega = self.dit(
|
||||
latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs,
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
|
||||
# Scheduler step (FlowMatchScheduler)
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
|
||||
# UI
|
||||
if progress_bar_st is not None:
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
return image
|
||||
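The latent tensors created in these pipelines follow one shape rule: the spatial size is the image size divided by 8, with 16 channels in `SD3ImagePipeline` and 4 channels in `SDImagePipeline`. A minimal sketch of that arithmetic, where the helper `latent_shape` is hypothetical and not part of the repository:

```python
def latent_shape(height, width, channels, batch_size=1, downsample=8):
    """Shape passed to torch.randn for an image of the given size."""
    return (batch_size, channels, height // downsample, width // downsample)


sd3_shape = latent_shape(1024, 1024, channels=16)  # SD3 pipeline defaults
sd_shape = latent_shape(512, 512, channels=4)      # SD pipelines use 4 channels
```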
@@ -1,23 +1,22 @@
|
||||
from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder
|
||||
from ..models import SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder
|
||||
from ..models.model_manager import ModelManager
|
||||
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
||||
from ..prompts import SDPrompter
|
||||
from ..prompters import SDPrompter
|
||||
from ..schedulers import EnhancedDDIMScheduler
|
||||
from .base import BasePipeline
|
||||
from .dancer import lets_dance
|
||||
from typing import List
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
class SDImagePipeline(torch.nn.Module):
|
||||
|
||||
class SDImagePipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__()
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
self.scheduler = EnhancedDDIMScheduler()
|
||||
self.prompter = SDPrompter()
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
# models
|
||||
self.text_encoder: SDTextEncoder = None
|
||||
self.unet: SDUNet = None
|
||||
@@ -28,61 +27,65 @@ class SDImagePipeline(torch.nn.Module):
|
||||
self.ipadapter: SDIpAdapter = None
|
||||
|
||||
|
||||
def fetch_main_models(self, model_manager: ModelManager):
|
||||
self.text_encoder = model_manager.text_encoder
|
||||
self.unet = model_manager.unet
|
||||
self.vae_decoder = model_manager.vae_decoder
|
||||
self.vae_encoder = model_manager.vae_encoder
|
||||
def denoising_model(self):
|
||||
return self.unet
|
||||
|
||||
|
||||
def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
|
||||
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
|
||||
# Main models
|
||||
self.text_encoder = model_manager.fetch_model("sd_text_encoder")
|
||||
self.unet = model_manager.fetch_model("sd_unet")
|
||||
self.vae_decoder = model_manager.fetch_model("sd_vae_decoder")
|
||||
self.vae_encoder = model_manager.fetch_model("sd_vae_encoder")
|
||||
self.prompter.fetch_models(self.text_encoder)
|
||||
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
|
||||
|
||||
# ControlNets
|
||||
controlnet_units = []
|
||||
for config in controlnet_config_units:
|
||||
controlnet_unit = ControlNetUnit(
|
||||
Annotator(config.processor_id),
|
||||
model_manager.get_model_with_model_path(config.model_path),
|
||||
Annotator(config.processor_id, device=self.device),
|
||||
model_manager.fetch_model("sd_controlnet", config.model_path),
|
||||
config.scale
|
||||
)
|
||||
controlnet_units.append(controlnet_unit)
|
||||
self.controlnet = MultiControlNetManager(controlnet_units)
|
||||
|
||||
|
||||
def fetch_ipadapter(self, model_manager: ModelManager):
|
||||
if "ipadapter" in model_manager.model:
|
||||
self.ipadapter = model_manager.ipadapter
|
||||
if "ipadapter_image_encoder" in model_manager.model:
|
||||
self.ipadapter_image_encoder = model_manager.ipadapter_image_encoder
|
||||
|
||||
|
||||
def fetch_prompter(self, model_manager: ModelManager):
|
||||
self.prompter.load_from_model_manager(model_manager)
|
||||
# IP-Adapters
|
||||
self.ipadapter = model_manager.fetch_model("sd_ipadapter")
|
||||
self.ipadapter_image_encoder = model_manager.fetch_model("sd_ipadapter_clip_image_encoder")
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
|
||||
pipe = SDImagePipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_main_models(model_manager)
|
||||
pipe.fetch_prompter(model_manager)
|
||||
pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
|
||||
pipe.fetch_ipadapter(model_manager)
|
||||
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes=prompt_refiner_classes)
|
||||
return pipe
|
||||
|
||||
|
||||
def preprocess_image(self, image):
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
|
||||
return image
|
||||
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
|
||||
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return latents
|
||||
|
||||
|
||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
image = image.cpu().permute(1, 2, 0).numpy()
|
||||
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
image = self.vae_output_to_image(image)
|
||||
return image
|
||||
|
||||
|
||||
def encode_prompt(self, prompt, clip_skip=1, positive=True):
|
||||
prompt_emb = self.prompter.encode_prompt(prompt, clip_skip=clip_skip, device=self.device, positive=positive)
|
||||
return {"encoder_hidden_states": prompt_emb}
|
||||
|
||||
|
||||
def prepare_extra_input(self, latents=None):
|
||||
return {}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
@@ -104,53 +107,56 @@ class SDImagePipeline(torch.nn.Module):
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
if input_image is not None:
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
latents = self.encode_image(image, **tiler_kwargs)
|
||||
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# Encode prompts
|
||||
prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True)
|
||||
prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False)
|
||||
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)
|
||||
|
||||
# IP-Adapter
|
||||
if ipadapter_images is not None:
|
||||
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
|
||||
ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)
|
||||
ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding))
|
||||
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
|
||||
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
|
||||
else:
|
||||
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {}
|
||||
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
|
||||
|
||||
# Prepare ControlNets
|
||||
if controlnet_image is not None:
|
||||
controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
controlnet_image = controlnet_image.unsqueeze(1)
|
||||
controlnet_kwargs = {"controlnet_frames": controlnet_image}
|
||||
else:
|
||||
controlnet_kwargs = {"controlnet_frames": None}
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = torch.IntTensor((timestep,))[0].to(self.device)
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
|
||||
# Classifier-free guidance
|
||||
noise_pred_posi = lets_dance(
|
||||
self.unet, motion_modules=None, controlnet=self.controlnet,
|
||||
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_image,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list_posi,
|
||||
device=self.device, vram_limit_level=0
|
||||
sample=latents, timestep=timestep,
|
||||
**prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi,
|
||||
device=self.device,
|
||||
)
|
||||
noise_pred_nega = lets_dance(
|
||||
self.unet, motion_modules=None, controlnet=self.controlnet,
|
||||
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_image,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list_nega,
|
||||
device=self.device, vram_limit_level=0
|
||||
sample=latents, timestep=timestep, **prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega,
|
||||
device=self.device,
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
|
||||
266
diffsynth/pipelines/sd_video.py
Normal file
@@ -0,0 +1,266 @@
|
||||
from ..models import SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder, SDMotionModel
|
||||
from ..models.model_manager import ModelManager
|
||||
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
||||
from ..prompters import SDPrompter
|
||||
from ..schedulers import EnhancedDDIMScheduler
|
||||
from .sd_image import SDImagePipeline
|
||||
from .dancer import lets_dance
|
||||
from typing import List
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
|
||||
def lets_dance_with_long_video(
|
||||
unet: SDUNet,
|
||||
motion_modules: SDMotionModel = None,
|
||||
controlnet: MultiControlNetManager = None,
|
||||
sample = None,
|
||||
timestep = None,
|
||||
encoder_hidden_states = None,
|
||||
ipadapter_kwargs_list = {},
|
||||
controlnet_frames = None,
|
||||
unet_batch_size = 1,
|
||||
controlnet_batch_size = 1,
|
||||
cross_frame_attention = False,
|
||||
tiled=False,
|
||||
tile_size=64,
|
||||
tile_stride=32,
|
||||
device="cuda",
|
||||
animatediff_batch_size=16,
|
||||
animatediff_stride=8,
|
||||
):
|
||||
num_frames = sample.shape[0]
|
||||
hidden_states_output = [(torch.zeros(sample[0].shape, dtype=sample[0].dtype), 0) for i in range(num_frames)]
|
||||
|
||||
for batch_id in range(0, num_frames, animatediff_stride):
|
||||
batch_id_ = min(batch_id + animatediff_batch_size, num_frames)
|
||||
|
||||
# process this batch
|
||||
hidden_states_batch = lets_dance(
|
||||
unet, motion_modules, controlnet,
|
||||
sample[batch_id: batch_id_].to(device),
|
||||
timestep,
|
||||
encoder_hidden_states,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list,
|
||||
controlnet_frames=controlnet_frames[:, batch_id: batch_id_].to(device) if controlnet_frames is not None else None,
|
||||
unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
|
||||
cross_frame_attention=cross_frame_attention,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, device=device
|
||||
).cpu()
|
||||
|
||||
# update hidden_states
|
||||
for i, hidden_states_updated in zip(range(batch_id, batch_id_), hidden_states_batch):
|
||||
bias = max(1 - abs(i - (batch_id + batch_id_ - 1) / 2) / ((batch_id_ - batch_id - 1 + 1e-2) / 2), 1e-2)
|
||||
hidden_states, num = hidden_states_output[i]
|
||||
hidden_states = hidden_states * (num / (num + bias)) + hidden_states_updated * (bias / (num + bias))
|
||||
hidden_states_output[i] = (hidden_states, num + bias)
|
||||
|
||||
if batch_id_ == num_frames:
|
||||
break
|
||||
|
||||
# output
|
||||
hidden_states = torch.stack([h for h, _ in hidden_states_output])
|
||||
return hidden_states
|
||||
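The sliding-window blending in `lets_dance_with_long_video` can be sketched without tensors. The version below reproduces the window bounds, the triangular `bias` weight, and the running weighted average, using one float per frame instead of a hidden-state tensor. The helper names `window_bounds`, `frame_weight`, and `blend` are hypothetical.

```python
def window_bounds(num_frames, batch_size=16, stride=8):
    """Overlapping (start, end) windows, stopping once the last frame is covered."""
    bounds = []
    for start in range(0, num_frames, stride):
        end = min(start + batch_size, num_frames)
        bounds.append((start, end))
        if end == num_frames:
            break
    return bounds


def frame_weight(i, start, end):
    """Triangular weight: near 1 at the window center, floored at 1e-2 at the edges."""
    center = (start + end - 1) / 2
    half_width = (end - start - 1 + 1e-2) / 2
    return max(1 - abs(i - center) / half_width, 1e-2)


def blend(num_frames, window_outputs, batch_size=16, stride=8):
    """window_outputs[k][j] is the value predicted for frame start_k + j."""
    blended = [(0.0, 0.0)] * num_frames  # (running average, accumulated weight)
    bounds = window_bounds(num_frames, batch_size, stride)
    for (start, end), outputs in zip(bounds, window_outputs):
        for i, value in zip(range(start, end), outputs):
            bias = frame_weight(i, start, end)
            average, total = blended[i]
            average = average * (total / (total + bias)) + value * (bias / (total + bias))
            blended[i] = (average, total + bias)
    return [average for average, _ in blended]


bounds = window_bounds(20)
outputs = [[3.0] * (end - start) for start, end in bounds]
result = blend(20, outputs)
```

Frames near a window edge receive little weight from that window, so in the overlap region the prediction from the window where the frame is more central dominates.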
|
||||
|
||||
|
||||
class SDVideoPipeline(SDImagePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16, use_original_animatediff=True):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_original_animatediff else "scaled_linear")
|
||||
self.prompter = SDPrompter()
|
||||
# models
|
||||
self.text_encoder: SDTextEncoder = None
|
||||
self.unet: SDUNet = None
|
||||
self.vae_decoder: SDVAEDecoder = None
|
||||
self.vae_encoder: SDVAEEncoder = None
|
||||
self.controlnet: MultiControlNetManager = None
|
||||
self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None
|
||||
self.ipadapter: SDIpAdapter = None
|
||||
self.motion_modules: SDMotionModel = None
|
||||
|
||||
|
||||
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
|
||||
# Main models
|
||||
self.text_encoder = model_manager.fetch_model("sd_text_encoder")
|
||||
self.unet = model_manager.fetch_model("sd_unet")
|
||||
self.vae_decoder = model_manager.fetch_model("sd_vae_decoder")
|
||||
self.vae_encoder = model_manager.fetch_model("sd_vae_encoder")
|
||||
self.prompter.fetch_models(self.text_encoder)
|
||||
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
|
||||
|
||||
# ControlNets
|
||||
controlnet_units = []
|
||||
for config in controlnet_config_units:
|
||||
controlnet_unit = ControlNetUnit(
|
||||
Annotator(config.processor_id, device=self.device),
|
||||
model_manager.fetch_model("sd_controlnet", config.model_path),
|
||||
config.scale
|
||||
)
|
||||
controlnet_units.append(controlnet_unit)
|
||||
self.controlnet = MultiControlNetManager(controlnet_units)
|
||||
|
||||
# IP-Adapters
|
||||
self.ipadapter = model_manager.fetch_model("sd_ipadapter")
|
||||
self.ipadapter_image_encoder = model_manager.fetch_model("sd_ipadapter_clip_image_encoder")
|
||||
|
||||
# Motion Modules
|
||||
self.motion_modules = model_manager.fetch_model("sd_motion_modules")
|
||||
if self.motion_modules is None:
|
||||
self.scheduler = EnhancedDDIMScheduler(beta_schedule="scaled_linear")
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
|
||||
pipe = SDVideoPipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
|
||||
return pipe
|
||||
|
||||
|
||||
def decode_video(self, latents, tiled=False, tile_size=64, tile_stride=32):
|
||||
images = [
|
||||
self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
for frame_id in range(latents.shape[0])
|
||||
]
|
||||
return images
|
||||
|
||||
|
||||
def encode_video(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
|
||||
latents = []
|
||||
for image in processed_images:
|
||||
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latent = self.encode_image(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
latents.append(latent.cpu())
|
||||
latents = torch.concat(latents, dim=0)
|
||||
return latents
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
cfg_scale=7.5,
|
||||
clip_skip=1,
|
||||
num_frames=None,
|
||||
input_frames=None,
|
||||
ipadapter_images=None,
|
||||
ipadapter_scale=1.0,
|
||||
controlnet_frames=None,
|
||||
denoising_strength=1.0,
|
||||
height=512,
|
||||
width=512,
|
||||
num_inference_steps=20,
|
||||
animatediff_batch_size = 16,
|
||||
animatediff_stride = 8,
|
||||
unet_batch_size = 1,
|
||||
controlnet_batch_size = 1,
|
||||
cross_frame_attention = False,
|
||||
smoother=None,
|
||||
smoother_progress_ids=[],
|
||||
tiled=False,
|
||||
tile_size=64,
|
||||
tile_stride=32,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Tiler parameters, batch size ...
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
other_kwargs = {
|
||||
"animatediff_batch_size": animatediff_batch_size, "animatediff_stride": animatediff_stride,
|
||||
"unet_batch_size": unet_batch_size, "controlnet_batch_size": controlnet_batch_size,
|
||||
"cross_frame_attention": cross_frame_attention,
|
||||
}
|
||||
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
if self.motion_modules is None:
|
||||
noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
|
||||
else:
|
||||
noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype)
|
||||
if input_frames is None or denoising_strength == 1.0:
|
||||
latents = noise
|
||||
else:
|
||||
latents = self.encode_video(input_frames, **tiler_kwargs)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
|
||||
# Encode prompts
|
||||
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)
|
||||
|
||||
# IP-Adapter
|
||||
if ipadapter_images is not None:
|
||||
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
|
||||
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
|
||||
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
|
||||
else:
|
||||
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
|
||||
|
||||
# Prepare ControlNets
|
||||
if controlnet_frames is not None:
|
||||
if isinstance(controlnet_frames[0], list):
|
||||
controlnet_frames_ = []
|
||||
for processor_id in range(len(controlnet_frames)):
|
||||
controlnet_frames_.append(
|
||||
torch.stack([
|
||||
self.controlnet.process_image(controlnet_frame, processor_id=processor_id).to(self.torch_dtype)
|
||||
for controlnet_frame in progress_bar_cmd(controlnet_frames[processor_id])
|
||||
], dim=1)
|
||||
)
|
||||
controlnet_frames = torch.concat(controlnet_frames_, dim=0)
|
||||
else:
|
||||
controlnet_frames = torch.stack([
|
||||
self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
|
||||
for controlnet_frame in progress_bar_cmd(controlnet_frames)
|
||||
], dim=1)
|
||||
controlnet_kwargs = {"controlnet_frames": controlnet_frames}
|
||||
else:
|
||||
controlnet_kwargs = {"controlnet_frames": None}
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
|
||||
# Classifier-free guidance
|
||||
noise_pred_posi = lets_dance_with_long_video(
|
||||
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
|
||||
sample=latents, timestep=timestep,
|
||||
**prompt_emb_posi, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **other_kwargs, **tiler_kwargs,
|
||||
device=self.device,
|
||||
)
|
||||
noise_pred_nega = lets_dance_with_long_video(
|
||||
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
|
||||
sample=latents, timestep=timestep,
|
||||
**prompt_emb_nega, **controlnet_kwargs, **ipadapter_kwargs_list_nega, **other_kwargs, **tiler_kwargs,
|
||||
device=self.device,
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
|
||||
# DDIM and smoother
|
||||
if smoother is not None and progress_id in smoother_progress_ids:
|
||||
rendered_frames = self.scheduler.step(noise_pred, timestep, latents, to_final=True)
|
||||
rendered_frames = self.decode_video(rendered_frames)
|
||||
rendered_frames = smoother(rendered_frames, original_frames=input_frames)
|
||||
target_latents = self.encode_video(rendered_frames)
|
||||
noise_pred = self.scheduler.return_to_timestep(timestep, latents, target_latents)
|
||||
latents = self.scheduler.step(noise_pred, timestep, latents)
|
||||
|
||||
# UI
|
||||
if progress_bar_st is not None:
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
output_frames = self.decode_video(latents, **tiler_kwargs)
|
||||
|
||||
# Post-process
|
||||
if smoother is not None and (num_inference_steps in smoother_progress_ids or -1 in smoother_progress_ids):
|
||||
output_frames = smoother(output_frames, original_frames=input_frames)
|
||||
|
||||
return output_frames
|
||||
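The denoising loop above combines the positive and negative predictions with classifier-free guidance. A minimal sketch of that single line on plain Python lists, assuming element-wise arithmetic; the helper name `apply_cfg` is hypothetical and not part of DiffSynth-Studio:

```python
def apply_cfg(noise_pred_posi, noise_pred_nega, cfg_scale):
    """Element-wise classifier-free guidance on flat lists of floats."""
    return [
        nega + cfg_scale * (posi - nega)
        for posi, nega in zip(noise_pred_posi, noise_pred_nega)
    ]


posi = [1.0, 2.0, 3.0]
nega = [0.0, 1.0, 1.0]

# cfg_scale > 1 pushes the result away from the negative prediction.
guided = apply_cfg(posi, nega, cfg_scale=7.5)

# cfg_scale == 1 reproduces the positive prediction unchanged.
unguided = apply_cfg(posi, nega, cfg_scale=1.0)
```

With `cfg_scale=1.0` the negative term cancels, which is why `SDXLImagePipeline.__call__` skips the negative pass in that case.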
191
diffsynth/pipelines/sdxl_image.py
Normal file
@@ -0,0 +1,191 @@
|
||||
from ..models import SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
||||
from ..models.kolors_text_encoder import ChatGLMModel
|
||||
from ..models.model_manager import ModelManager
|
||||
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
||||
from ..prompters import SDXLPrompter, KolorsPrompter
|
||||
from ..schedulers import EnhancedDDIMScheduler
|
||||
from .base import BasePipeline
|
||||
from .dancer import lets_dance_xl
|
||||
from typing import List
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
|
||||
class SDXLImagePipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
self.scheduler = EnhancedDDIMScheduler()
|
||||
self.prompter = SDXLPrompter()
|
||||
# models
|
||||
self.text_encoder: SDXLTextEncoder = None
|
||||
self.text_encoder_2: SDXLTextEncoder2 = None
|
||||
self.text_encoder_kolors: ChatGLMModel = None
|
||||
self.unet: SDXLUNet = None
|
||||
self.vae_decoder: SDXLVAEDecoder = None
|
||||
self.vae_encoder: SDXLVAEEncoder = None
|
||||
# self.controlnet: MultiControlNetManager = None (TODO)
|
||||
self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
|
||||
self.ipadapter: SDXLIpAdapter = None
|
||||
|
||||
|
||||
def denoising_model(self):
|
||||
return self.unet
|
||||
|
||||
|
||||
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
|
||||
# Main models
|
||||
self.text_encoder = model_manager.fetch_model("sdxl_text_encoder")
|
||||
self.text_encoder_2 = model_manager.fetch_model("sdxl_text_encoder_2")
|
||||
self.text_encoder_kolors = model_manager.fetch_model("kolors_text_encoder")
|
||||
self.unet = model_manager.fetch_model("sdxl_unet")
|
||||
self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
|
||||
self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
|
||||
|
||||
# ControlNets (TODO)
|
||||
|
||||
# IP-Adapters
|
||||
self.ipadapter = model_manager.fetch_model("sdxl_ipadapter")
|
||||
self.ipadapter_image_encoder = model_manager.fetch_model("sdxl_ipadapter_clip_image_encoder")
|
||||
|
||||
# Kolors
|
||||
if self.text_encoder_kolors is not None:
|
||||
print("Switch to Kolors. The prompter and scheduler will be replaced.")
|
||||
self.prompter = KolorsPrompter()
|
||||
self.prompter.fetch_models(self.text_encoder_kolors)
|
||||
self.scheduler = EnhancedDDIMScheduler(beta_end=0.014, num_train_timesteps=1100)
|
||||
else:
|
||||
self.prompter.fetch_models(self.text_encoder, self.text_encoder_2)
|
||||
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
|
||||
pipe = SDXLImagePipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
|
||||
return pipe
|
||||
|
||||
|
||||
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
|
||||
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return latents
|
||||
|
||||
|
||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
image = self.vae_output_to_image(image)
|
||||
return image
|
||||
|
||||
|
||||
def encode_prompt(self, prompt, clip_skip=1, clip_skip_2=2, positive=True):
|
||||
add_prompt_emb, prompt_emb = self.prompter.encode_prompt(
|
||||
prompt,
|
||||
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
|
||||
device=self.device,
|
||||
positive=positive,
|
||||
)
|
||||
return {"encoder_hidden_states": prompt_emb, "add_text_embeds": add_prompt_emb}
|
||||
|
||||
|
||||
def prepare_extra_input(self, latents=None):
|
||||
height, width = latents.shape[2] * 8, latents.shape[3] * 8
|
||||
return {"add_time_id": torch.tensor([height, width, 0, 0, height, width], device=self.device)}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
cfg_scale=7.5,
|
||||
clip_skip=1,
|
||||
clip_skip_2=2,
|
||||
input_image=None,
|
||||
ipadapter_images=None,
|
||||
ipadapter_scale=1.0,
|
||||
ipadapter_use_instant_style=False,
|
||||
controlnet_image=None,
|
||||
denoising_strength=1.0,
|
||||
height=1024,
|
||||
width=1024,
|
||||
num_inference_steps=20,
|
||||
tiled=False,
|
||||
tile_size=64,
|
||||
tile_stride=32,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
if input_image is not None:
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.encode_image(image, **tiler_kwargs)
|
||||
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# Encode prompts
|
||||
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=False)
|
||||
|
||||
# IP-Adapter
|
||||
if ipadapter_images is not None:
|
||||
if ipadapter_use_instant_style:
|
||||
self.ipadapter.set_less_adapter()
|
||||
else:
|
||||
self.ipadapter.set_full_adapter()
|
||||
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
|
||||
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
|
||||
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
|
||||
else:
|
||||
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
|
||||
|
||||
# Prepare ControlNets (TODO)
|
||||
controlnet_kwargs = {"controlnet_frames": None}
|
||||
|
||||
# Prepare extra input
|
||||
extra_input = self.prepare_extra_input(latents)
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
|
||||
# Classifier-free guidance
|
||||
noise_pred_posi = lets_dance_xl(
|
||||
self.unet, motion_modules=None, controlnet=None,
|
||||
sample=latents, timestep=timestep, **extra_input,
|
||||
**prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi,
|
||||
device=self.device,
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = lets_dance_xl(
|
||||
self.unet, motion_modules=None, controlnet=None,
|
||||
sample=latents, timestep=timestep, **extra_input,
|
||||
**prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega,
|
||||
device=self.device,
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
# DDIM
|
||||
latents = self.scheduler.step(noise_pred, timestep, latents)
|
||||
|
||||
# UI
|
||||
if progress_bar_st is not None:
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
return image
|
||||
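`prepare_extra_input` above recovers the pixel size from the latent shape by multiplying by 8 and packs it into `add_time_id`. A small sketch of that arithmetic without torch; the function name `time_ids_from_latent_shape` and the `vae_scale` parameter are hypothetical, and reading the six values as (original size, crop offset, target size) is an assumption based on SDXL's usual size conditioning:

```python
def time_ids_from_latent_shape(latent_shape, vae_scale=8):
    """Return [height, width, 0, 0, height, width] for an (N, C, H, W) latent shape."""
    height = latent_shape[2] * vae_scale
    width = latent_shape[3] * vae_scale
    return [height, width, 0, 0, height, width]


# A 1024x1024 image corresponds to a 128x128 latent.
ids = time_ids_from_latent_shape((1, 4, 128, 128))
```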
223
diffsynth/pipelines/sdxl_video.py
Normal file
@@ -0,0 +1,223 @@
|
||||
from ..models import SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder, SDXLMotionModel
|
||||
from ..models.kolors_text_encoder import ChatGLMModel
|
||||
from ..models.model_manager import ModelManager
|
||||
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
||||
from ..prompters import SDXLPrompter, KolorsPrompter
|
||||
from ..schedulers import EnhancedDDIMScheduler
|
||||
from .sdxl_image import SDXLImagePipeline
|
||||
from .dancer import lets_dance_xl
|
||||
from typing import List
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
|
||||
class SDXLVideoPipeline(SDXLImagePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16, use_original_animatediff=True):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_original_animatediff else "scaled_linear")
|
||||
self.prompter = SDXLPrompter()
|
||||
# models
|
||||
self.text_encoder: SDXLTextEncoder = None
|
||||
self.text_encoder_2: SDXLTextEncoder2 = None
|
||||
self.text_encoder_kolors: ChatGLMModel = None
|
||||
self.unet: SDXLUNet = None
|
||||
self.vae_decoder: SDXLVAEDecoder = None
|
||||
self.vae_encoder: SDXLVAEEncoder = None
|
||||
# self.controlnet: MultiControlNetManager = None (TODO)
|
||||
self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
|
||||
self.ipadapter: SDXLIpAdapter = None
|
||||
self.motion_modules: SDXLMotionModel = None
|
||||
|
||||
|
||||
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
|
||||
# Main models
|
||||
self.text_encoder = model_manager.fetch_model("sdxl_text_encoder")
|
||||
self.text_encoder_2 = model_manager.fetch_model("sdxl_text_encoder_2")
|
||||
self.text_encoder_kolors = model_manager.fetch_model("kolors_text_encoder")
|
||||
self.unet = model_manager.fetch_model("sdxl_unet")
|
||||
self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
|
||||
self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
|
||||
self.prompter.fetch_models(self.text_encoder)
|
||||
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
|
||||
|
||||
# ControlNets (TODO)
|
||||
|
||||
# IP-Adapters
|
||||
self.ipadapter = model_manager.fetch_model("sdxl_ipadapter")
|
||||
self.ipadapter_image_encoder = model_manager.fetch_model("sdxl_ipadapter_clip_image_encoder")
|
||||
|
||||
# Motion Modules
|
||||
self.motion_modules = model_manager.fetch_model("sdxl_motion_modules")
|
||||
if self.motion_modules is None:
|
||||
self.scheduler = EnhancedDDIMScheduler(beta_schedule="scaled_linear")
|
||||
|
||||
# Kolors
|
||||
if self.text_encoder_kolors is not None:
|
||||
print("Switch to Kolors. The prompter will be replaced.")
|
||||
self.prompter = KolorsPrompter()
|
||||
self.prompter.fetch_models(self.text_encoder_kolors)
|
||||
# The schedulers of AnimateDiff and Kolors are incompatible. We align it with AnimateDiff.
|
||||
if self.motion_modules is None:
|
||||
self.scheduler = EnhancedDDIMScheduler(beta_end=0.014, num_train_timesteps=1100)
|
||||
else:
|
||||
self.prompter.fetch_models(self.text_encoder, self.text_encoder_2)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
|
||||
pipe = SDXLVideoPipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
|
||||
return pipe
|
||||
|
||||
|
||||
def decode_video(self, latents, tiled=False, tile_size=64, tile_stride=32):
|
||||
images = [
|
||||
self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
for frame_id in range(latents.shape[0])
|
||||
]
|
||||
return images
|
||||
|
||||
|
||||
def encode_video(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
|
||||
latents = []
|
||||
for image in processed_images:
|
||||
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latent = self.encode_image(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
latents.append(latent.cpu())
|
||||
latents = torch.concat(latents, dim=0)
|
||||
return latents
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
cfg_scale=7.5,
|
||||
clip_skip=1,
|
||||
num_frames=None,
|
||||
input_frames=None,
|
||||
ipadapter_images=None,
|
||||
ipadapter_scale=1.0,
|
||||
ipadapter_use_instant_style=False,
|
||||
controlnet_frames=None,
|
||||
denoising_strength=1.0,
|
||||
height=512,
|
||||
width=512,
|
||||
num_inference_steps=20,
|
||||
animatediff_batch_size = 16,
|
||||
animatediff_stride = 8,
|
||||
unet_batch_size = 1,
|
||||
controlnet_batch_size = 1,
|
||||
cross_frame_attention = False,
|
||||
smoother=None,
|
||||
smoother_progress_ids=[],
|
||||
tiled=False,
|
||||
tile_size=64,
|
||||
tile_stride=32,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Tiler parameters, batch size ...
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
if self.motion_modules is None:
|
||||
noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
|
||||
else:
|
||||
noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype)
|
||||
if input_frames is None or denoising_strength == 1.0:
|
||||
latents = noise
|
||||
else:
|
||||
latents = self.encode_video(input_frames, **tiler_kwargs)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
latents = latents.to(self.device) # will be deleted for supporting long videos
|
||||
|
||||
# Encode prompts
|
||||
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)
|
||||
|
||||
# IP-Adapter
|
||||
if ipadapter_images is not None:
|
||||
if ipadapter_use_instant_style:
|
||||
self.ipadapter.set_less_adapter()
|
||||
else:
|
||||
self.ipadapter.set_full_adapter()
|
||||
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
|
||||
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
|
||||
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
|
||||
else:
|
||||
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
|
||||
|
||||
# Prepare ControlNets
|
||||
if controlnet_frames is not None:
|
||||
if isinstance(controlnet_frames[0], list):
|
||||
controlnet_frames_ = []
|
||||
for processor_id in range(len(controlnet_frames)):
|
||||
controlnet_frames_.append(
|
||||
torch.stack([
|
||||
self.controlnet.process_image(controlnet_frame, processor_id=processor_id).to(self.torch_dtype)
|
||||
for controlnet_frame in progress_bar_cmd(controlnet_frames[processor_id])
|
||||
], dim=1)
|
||||
)
|
||||
controlnet_frames = torch.concat(controlnet_frames_, dim=0)
|
||||
else:
|
||||
controlnet_frames = torch.stack([
|
||||
self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
|
||||
for controlnet_frame in progress_bar_cmd(controlnet_frames)
|
||||
], dim=1)
|
||||
controlnet_kwargs = {"controlnet_frames": controlnet_frames}
|
||||
else:
|
||||
controlnet_kwargs = {"controlnet_frames": None}
|
||||
|
||||
# Prepare extra input
|
||||
extra_input = self.prepare_extra_input(latents)
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
|
||||
# Classifier-free guidance
|
||||
noise_pred_posi = lets_dance_xl(
|
||||
self.unet, motion_modules=self.motion_modules, controlnet=None,
|
||||
sample=latents, timestep=timestep,
|
||||
**prompt_emb_posi, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **extra_input, **tiler_kwargs,
|
||||
device=self.device,
|
||||
)
|
||||
noise_pred_nega = lets_dance_xl(
|
||||
self.unet, motion_modules=self.motion_modules, controlnet=None,
|
||||
sample=latents, timestep=timestep,
|
||||
**prompt_emb_nega, **controlnet_kwargs, **ipadapter_kwargs_list_nega, **extra_input, **tiler_kwargs,
|
||||
device=self.device,
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
|
||||
# DDIM and smoother
|
||||
if smoother is not None and progress_id in smoother_progress_ids:
|
||||
rendered_frames = self.scheduler.step(noise_pred, timestep, latents, to_final=True)
|
||||
rendered_frames = self.decode_video(rendered_frames)
|
||||
rendered_frames = smoother(rendered_frames, original_frames=input_frames)
|
||||
target_latents = self.encode_video(rendered_frames)
|
||||
noise_pred = self.scheduler.return_to_timestep(timestep, latents, target_latents)
|
||||
latents = self.scheduler.step(noise_pred, timestep, latents)
|
||||
|
||||
# UI
|
||||
if progress_bar_st is not None:
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
output_frames = self.decode_video(latents, **tiler_kwargs)
|
||||
|
||||
# Post-process
|
||||
if smoother is not None and (num_inference_steps in smoother_progress_ids or -1 in smoother_progress_ids):
|
||||
output_frames = smoother(output_frames, original_frames=input_frames)
|
||||
|
||||
return output_frames
|
||||
@@ -1,356 +0,0 @@
|
||||
from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDMotionModel
|
||||
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
||||
from ..prompts import SDPrompter
|
||||
from ..schedulers import EnhancedDDIMScheduler
|
||||
from ..data import VideoData, save_frames, save_video
|
||||
from .dancer import lets_dance
|
||||
from ..processors.sequencial_processor import SequencialProcessor
|
||||
from typing import List
|
||||
import torch, os, json
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
def lets_dance_with_long_video(
|
||||
unet: SDUNet,
|
||||
motion_modules: SDMotionModel = None,
|
||||
controlnet: MultiControlNetManager = None,
|
||||
sample = None,
|
||||
timestep = None,
|
||||
encoder_hidden_states = None,
|
||||
controlnet_frames = None,
|
||||
animatediff_batch_size = 16,
|
||||
animatediff_stride = 8,
|
||||
unet_batch_size = 1,
|
||||
controlnet_batch_size = 1,
|
||||
cross_frame_attention = False,
|
||||
device = "cuda",
|
||||
vram_limit_level = 0,
|
||||
):
|
||||
num_frames = sample.shape[0]
|
||||
hidden_states_output = [(torch.zeros(sample[0].shape, dtype=sample[0].dtype), 0) for i in range(num_frames)]
|
||||
|
||||
for batch_id in range(0, num_frames, animatediff_stride):
|
||||
batch_id_ = min(batch_id + animatediff_batch_size, num_frames)
|
||||
|
||||
# process this batch
|
||||
hidden_states_batch = lets_dance(
|
||||
unet, motion_modules, controlnet,
|
||||
sample[batch_id: batch_id_].to(device),
|
||||
timestep,
|
||||
encoder_hidden_states[batch_id: batch_id_].to(device),
|
||||
controlnet_frames=controlnet_frames[:, batch_id: batch_id_].to(device) if controlnet_frames is not None else None,
|
||||
unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
|
||||
cross_frame_attention=cross_frame_attention,
|
||||
device=device, vram_limit_level=vram_limit_level
|
||||
).cpu()
|
||||
|
||||
# update hidden_states
|
||||
for i, hidden_states_updated in zip(range(batch_id, batch_id_), hidden_states_batch):
|
||||
bias = max(1 - abs(i - (batch_id + batch_id_ - 1) / 2) / ((batch_id_ - batch_id - 1 + 1e-2) / 2), 1e-2)
|
||||
hidden_states, num = hidden_states_output[i]
|
||||
hidden_states = hidden_states * (num / (num + bias)) + hidden_states_updated * (bias / (num + bias))
|
||||
hidden_states_output[i] = (hidden_states, num + bias)
|
||||
|
||||
if batch_id_ == num_frames:
|
||||
break
|
||||
|
||||
# output
|
||||
hidden_states = torch.stack([h for h, _ in hidden_states_output])
|
||||
return hidden_states
|
||||
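`lets_dance_with_long_video` above processes the video in overlapping windows and blends the per-frame outputs with a weight that peaks at the window center. A sketch that reproduces only the window boundaries and the `bias` weights, without running any model; the helper name `window_weights` is hypothetical:

```python
def window_weights(num_frames, batch_size=16, stride=8):
    """Return (start, end, weights) for each overlapping window over the frames."""
    windows = []
    for start in range(0, num_frames, stride):
        end = min(start + batch_size, num_frames)
        center = (start + end - 1) / 2
        half_width = (end - start - 1 + 1e-2) / 2
        # Triangular weight, floored at 1e-2 so edge frames still contribute.
        weights = [
            max(1 - abs(i - center) / half_width, 1e-2)
            for i in range(start, end)
        ]
        windows.append((start, end, weights))
        if end == num_frames:
            break
    return windows


windows = window_weights(num_frames=32)
spans = [(start, end) for start, end, _ in windows]
```

Each frame's final output is a running weighted average over every window that covers it, so frames near a window edge are dominated by the neighbouring window in which they sit closer to the center.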
|
||||
|
||||
class SDVideoPipeline(torch.nn.Module):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16, use_animatediff=True):
|
||||
super().__init__()
|
||||
self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_animatediff else "scaled_linear")
|
||||
self.prompter = SDPrompter()
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
# models
|
||||
self.text_encoder: SDTextEncoder = None
|
||||
self.unet: SDUNet = None
|
||||
self.vae_decoder: SDVAEDecoder = None
|
||||
self.vae_encoder: SDVAEEncoder = None
|
||||
self.controlnet: MultiControlNetManager = None
|
||||
self.motion_modules: SDMotionModel = None
|
||||
|
||||
|
||||
def fetch_main_models(self, model_manager: ModelManager):
|
||||
self.text_encoder = model_manager.text_encoder
|
||||
self.unet = model_manager.unet
|
||||
self.vae_decoder = model_manager.vae_decoder
|
||||
self.vae_encoder = model_manager.vae_encoder
|
||||
|
||||
|
||||
def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
|
||||
controlnet_units = []
|
||||
for config in controlnet_config_units:
|
||||
controlnet_unit = ControlNetUnit(
|
||||
Annotator(config.processor_id),
|
||||
model_manager.get_model_with_model_path(config.model_path),
|
||||
config.scale
|
||||
)
|
||||
controlnet_units.append(controlnet_unit)
|
||||
self.controlnet = MultiControlNetManager(controlnet_units)
|
||||
|
||||
|
||||
def fetch_motion_modules(self, model_manager: ModelManager):
|
||||
if "motion_modules" in model_manager.model:
|
||||
self.motion_modules = model_manager.motion_modules
|
||||
|
||||
|
||||
def fetch_prompter(self, model_manager: ModelManager):
|
||||
self.prompter.load_from_model_manager(model_manager)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
|
||||
pipe = SDVideoPipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
use_animatediff="motion_modules" in model_manager.model
|
||||
)
|
||||
pipe.fetch_main_models(model_manager)
|
||||
pipe.fetch_motion_modules(model_manager)
|
||||
pipe.fetch_prompter(model_manager)
|
||||
pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
|
||||
return pipe
|
||||
|
||||
|
||||
def preprocess_image(self, image):
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
|
||||
return image
|
||||
|
||||
|
||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
image = image.cpu().permute(1, 2, 0).numpy()
|
||||
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||
return image
|
||||
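`preprocess_image` and `decode_image` above map pixel values between [0, 255] and the [-1, 1] range that the VAE works in. A scalar sketch of both directions, assuming plain Python floats instead of the float32 arrays used by the pipeline; the function names are hypothetical:

```python
def to_model_range(pixel):
    """Map a pixel value in [0, 255] to [-1, 1], as preprocess_image does."""
    return pixel * (2 / 255) - 1


def to_pixel(value):
    """Map a VAE output value back to an integer pixel, clipping out-of-range values."""
    clipped = min(max(value / 2 + 0.5, 0.0), 1.0)
    # int() truncates toward zero, like astype("uint8") on values in [0, 255].
    return int(clipped * 255)


darkest = to_model_range(0)
brightest = to_model_range(255)
```

The clip matters because the decoder output is not guaranteed to stay inside [-1, 1]; without it, the cast to `uint8` could wrap around.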
|
||||
|
||||
def decode_images(self, latents, tiled=False, tile_size=64, tile_stride=32):
|
||||
images = [
|
||||
self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
for frame_id in range(latents.shape[0])
|
||||
]
|
||||
return images
|
||||
|
||||
|
||||
def encode_images(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
|
||||
latents = []
|
||||
for image in processed_images:
|
||||
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latent = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).cpu()
|
||||
latents.append(latent)
|
||||
latents = torch.concat(latents, dim=0)
|
||||
return latents
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
cfg_scale=7.5,
|
||||
clip_skip=1,
|
||||
num_frames=None,
|
||||
input_frames=None,
|
||||
controlnet_frames=None,
|
||||
denoising_strength=1.0,
|
||||
height=512,
|
||||
width=512,
|
||||
num_inference_steps=20,
|
||||
animatediff_batch_size = 16,
|
||||
animatediff_stride = 8,
|
||||
unet_batch_size = 1,
|
||||
controlnet_batch_size = 1,
|
||||
cross_frame_attention = False,
|
||||
smoother=None,
|
||||
smoother_progress_ids=[],
|
||||
vram_limit_level=0,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
if self.motion_modules is None:
|
||||
noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
|
||||
else:
|
||||
noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype)
|
||||
if input_frames is None or denoising_strength == 1.0:
|
||||
latents = noise
|
||||
else:
|
||||
latents = self.encode_images(input_frames)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
|
||||
# Encode prompts
|
||||
prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True).cpu()
|
||||
prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False).cpu()
|
||||
prompt_emb_posi = prompt_emb_posi.repeat(num_frames, 1, 1)
|
||||
prompt_emb_nega = prompt_emb_nega.repeat(num_frames, 1, 1)
|
||||
|
||||
# Prepare ControlNets
|
||||
if controlnet_frames is not None:
|
||||
if isinstance(controlnet_frames[0], list):
|
||||
controlnet_frames_ = []
|
||||
for processor_id in range(len(controlnet_frames)):
|
||||
controlnet_frames_.append(
|
||||
torch.stack([
|
||||
self.controlnet.process_image(controlnet_frame, processor_id=processor_id).to(self.torch_dtype)
|
||||
for controlnet_frame in progress_bar_cmd(controlnet_frames[processor_id])
|
||||
], dim=1)
|
||||
)
|
||||
controlnet_frames = torch.concat(controlnet_frames_, dim=0)
|
||||
else:
|
||||
controlnet_frames = torch.stack([
|
||||
self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
|
||||
for controlnet_frame in progress_bar_cmd(controlnet_frames)
|
||||
], dim=1)
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = torch.IntTensor((timestep,))[0].to(self.device)
|
||||
|
||||
# Classifier-free guidance
|
||||
noise_pred_posi = lets_dance_with_long_video(
|
||||
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
|
||||
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_frames,
|
||||
animatediff_batch_size=animatediff_batch_size, animatediff_stride=animatediff_stride,
|
||||
unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
|
||||
cross_frame_attention=cross_frame_attention,
|
||||
device=self.device, vram_limit_level=vram_limit_level
|
||||
)
|
||||
noise_pred_nega = lets_dance_with_long_video(
|
||||
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
|
||||
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames,
|
||||
animatediff_batch_size=animatediff_batch_size, animatediff_stride=animatediff_stride,
|
||||
unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
|
||||
cross_frame_attention=cross_frame_attention,
|
||||
device=self.device, vram_limit_level=vram_limit_level
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
|
||||
# DDIM and smoother
|
||||
if smoother is not None and progress_id in smoother_progress_ids:
|
||||
rendered_frames = self.scheduler.step(noise_pred, timestep, latents, to_final=True)
|
||||
rendered_frames = self.decode_images(rendered_frames)
|
||||
rendered_frames = smoother(rendered_frames, original_frames=input_frames)
|
||||
target_latents = self.encode_images(rendered_frames)
|
||||
noise_pred = self.scheduler.return_to_timestep(timestep, latents, target_latents)
|
||||
latents = self.scheduler.step(noise_pred, timestep, latents)
|
||||
|
||||
# UI
|
||||
if progress_bar_st is not None:
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
output_frames = self.decode_images(latents)
|
||||
|
||||
# Post-process
|
||||
if smoother is not None and (num_inference_steps in smoother_progress_ids or -1 in smoother_progress_ids):
|
||||
output_frames = smoother(output_frames, original_frames=input_frames)
|
||||
|
||||
return output_frames
|
||||
|
||||
|
||||
|
||||
class SDVideoPipelineRunner:
|
||||
def __init__(self, in_streamlit=False):
|
||||
self.in_streamlit = in_streamlit
|
||||
|
||||
|
||||
def load_pipeline(self, model_list, textual_inversion_folder, device, lora_alphas, controlnet_units):
|
||||
# Load models
|
||||
model_manager = ModelManager(torch_dtype=torch.float16, device=device)
|
||||
model_manager.load_textual_inversions(textual_inversion_folder)
|
||||
model_manager.load_models(model_list, lora_alphas=lora_alphas)
|
||||
pipe = SDVideoPipeline.from_model_manager(
|
||||
model_manager,
|
||||
[
|
||||
ControlNetConfigUnit(
|
||||
processor_id=unit["processor_id"],
|
||||
model_path=unit["model_path"],
|
||||
scale=unit["scale"]
|
||||
) for unit in controlnet_units
|
||||
]
|
||||
)
|
||||
return model_manager, pipe
|
||||
|
||||
|
||||
def load_smoother(self, model_manager, smoother_configs):
|
||||
smoother = SequencialProcessor.from_model_manager(model_manager, smoother_configs)
|
||||
return smoother
|
||||
|
||||
|
||||
def synthesize_video(self, model_manager, pipe, seed, smoother, **pipeline_inputs):
|
||||
torch.manual_seed(seed)
|
||||
if self.in_streamlit:
|
||||
import streamlit as st
|
||||
progress_bar_st = st.progress(0.0)
|
||||
output_video = pipe(**pipeline_inputs, smoother=smoother, progress_bar_st=progress_bar_st)
|
||||
progress_bar_st.progress(1.0)
|
||||
else:
|
||||
output_video = pipe(**pipeline_inputs, smoother=smoother)
|
||||
model_manager.to("cpu")
|
||||
return output_video
|
||||
|
||||
|
||||
def load_video(self, video_file, image_folder, height, width, start_frame_id, end_frame_id):
|
||||
video = VideoData(video_file=video_file, image_folder=image_folder, height=height, width=width)
|
||||
if start_frame_id is None:
|
||||
start_frame_id = 0
|
||||
if end_frame_id is None:
|
||||
end_frame_id = len(video)
|
||||
frames = [video[i] for i in range(start_frame_id, end_frame_id)]
|
||||
return frames
|
||||
|
||||
|
||||
def add_data_to_pipeline_inputs(self, data, pipeline_inputs):
|
||||
pipeline_inputs["input_frames"] = self.load_video(**data["input_frames"])
|
||||
pipeline_inputs["num_frames"] = len(pipeline_inputs["input_frames"])
|
||||
pipeline_inputs["width"], pipeline_inputs["height"] = pipeline_inputs["input_frames"][0].size
|
||||
if len(data["controlnet_frames"]) > 0:
|
||||
pipeline_inputs["controlnet_frames"] = [self.load_video(**unit) for unit in data["controlnet_frames"]]
|
||||
return pipeline_inputs
|
||||
|
||||
|
||||
def save_output(self, video, output_folder, fps, config):
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
save_frames(video, os.path.join(output_folder, "frames"))
|
||||
save_video(video, os.path.join(output_folder, "video.mp4"), fps=fps)
|
||||
config["pipeline"]["pipeline_inputs"]["input_frames"] = []
|
||||
config["pipeline"]["pipeline_inputs"]["controlnet_frames"] = []
|
||||
with open(os.path.join(output_folder, "config.json"), 'w') as file:
|
||||
json.dump(config, file, indent=4)
|
||||
|
||||
|
||||
def run(self, config):
|
||||
if self.in_streamlit:
|
||||
import streamlit as st
|
||||
if self.in_streamlit: st.markdown("Loading videos ...")
|
||||
config["pipeline"]["pipeline_inputs"] = self.add_data_to_pipeline_inputs(config["data"], config["pipeline"]["pipeline_inputs"])
|
||||
if self.in_streamlit: st.markdown("Loading videos ... done!")
|
||||
if self.in_streamlit: st.markdown("Loading models ...")
|
||||
model_manager, pipe = self.load_pipeline(**config["models"])
|
||||
if self.in_streamlit: st.markdown("Loading models ... done!")
|
||||
if "smoother_configs" in config:
|
||||
if self.in_streamlit: st.markdown("Loading smoother ...")
|
||||
smoother = self.load_smoother(model_manager, config["smoother_configs"])
|
||||
if self.in_streamlit: st.markdown("Loading smoother ... done!")
|
||||
else:
|
||||
smoother = None
|
||||
if self.in_streamlit: st.markdown("Synthesizing videos ...")
|
||||
output_video = self.synthesize_video(model_manager, pipe, config["pipeline"]["seed"], smoother, **config["pipeline"]["pipeline_inputs"])
|
||||
if self.in_streamlit: st.markdown("Synthesizing videos ... done!")
|
||||
if self.in_streamlit: st.markdown("Saving videos ...")
|
||||
self.save_output(output_video, config["data"]["output_folder"], config["data"]["fps"], config)
|
||||
if self.in_streamlit: st.markdown("Saving videos ... done!")
|
||||
if self.in_streamlit: st.markdown("Finished!")
|
||||
video_file = open(os.path.join(config["data"]["output_folder"], "video.mp4"), 'rb')
|
||||
if self.in_streamlit: st.video(video_file.read())
|
||||
@@ -1,175 +0,0 @@
|
||||
from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
||||
# TODO: SDXL ControlNet
|
||||
from ..prompts import SDXLPrompter
|
||||
from ..schedulers import EnhancedDDIMScheduler
|
||||
from .dancer import lets_dance_xl
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
class SDXLImagePipeline(torch.nn.Module):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__()
|
||||
self.scheduler = EnhancedDDIMScheduler()
|
||||
self.prompter = SDXLPrompter()
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
# models
|
||||
self.text_encoder: SDXLTextEncoder = None
|
||||
self.text_encoder_2: SDXLTextEncoder2 = None
|
||||
self.unet: SDXLUNet = None
|
||||
self.vae_decoder: SDXLVAEDecoder = None
|
||||
self.vae_encoder: SDXLVAEEncoder = None
|
||||
self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
|
||||
self.ipadapter: SDXLIpAdapter = None
|
||||
# TODO: SDXL ControlNet
|
||||
|
||||
def fetch_main_models(self, model_manager: ModelManager):
|
||||
self.text_encoder = model_manager.text_encoder
|
||||
self.text_encoder_2 = model_manager.text_encoder_2
|
||||
self.unet = model_manager.unet
|
||||
self.vae_decoder = model_manager.vae_decoder
|
||||
self.vae_encoder = model_manager.vae_encoder
|
||||
|
||||
|
||||
def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs):
|
||||
# TODO: SDXL ControlNet
|
||||
pass
|
||||
|
||||
|
||||
def fetch_ipadapter(self, model_manager: ModelManager):
|
||||
if "ipadapter_xl" in model_manager.model:
|
||||
self.ipadapter = model_manager.ipadapter_xl
|
||||
if "ipadapter_xl_image_encoder" in model_manager.model:
|
||||
self.ipadapter_image_encoder = model_manager.ipadapter_xl_image_encoder
|
||||
|
||||
|
||||
def fetch_prompter(self, model_manager: ModelManager):
|
||||
self.prompter.load_from_model_manager(model_manager)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units = [], **kwargs):
|
||||
pipe = SDXLImagePipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_main_models(model_manager)
|
||||
pipe.fetch_prompter(model_manager)
|
||||
pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units)
|
||||
pipe.fetch_ipadapter(model_manager)
|
||||
return pipe
|
||||
|
||||
|
||||
def preprocess_image(self, image):
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
|
||||
return image
|
||||
|
||||
|
||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
image = image.cpu().permute(1, 2, 0).numpy()
|
||||
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||
return image
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
cfg_scale=7.5,
|
||||
clip_skip=1,
|
||||
clip_skip_2=2,
|
||||
input_image=None,
|
||||
ipadapter_images=None,
|
||||
ipadapter_scale=1.0,
|
||||
controlnet_image=None,
|
||||
denoising_strength=1.0,
|
||||
height=1024,
|
||||
width=1024,
|
||||
num_inference_steps=20,
|
||||
tiled=False,
|
||||
tile_size=64,
|
||||
tile_stride=32,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
if input_image is not None:
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.vae_encoder(image.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
|
||||
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# Encode prompts
|
||||
add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt(
|
||||
self.text_encoder,
|
||||
self.text_encoder_2,
|
||||
prompt,
|
||||
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
|
||||
device=self.device,
|
||||
positive=True,
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
|
||||
self.text_encoder,
|
||||
self.text_encoder_2,
|
||||
negative_prompt,
|
||||
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
|
||||
device=self.device,
|
||||
positive=False,
|
||||
)
|
||||
|
||||
# Prepare positional id
|
||||
add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
|
||||
|
||||
# IP-Adapter
|
||||
if ipadapter_images is not None:
|
||||
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
|
||||
ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)
|
||||
ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding))
|
||||
else:
|
||||
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {}
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = torch.IntTensor((timestep,))[0].to(self.device)
|
||||
|
||||
# Classifier-free guidance
|
||||
noise_pred_posi = lets_dance_xl(
|
||||
self.unet,
|
||||
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi,
|
||||
add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list_posi,
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = lets_dance_xl(
|
||||
self.unet,
|
||||
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega,
|
||||
add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list_nega,
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
latents = self.scheduler.step(noise_pred, timestep, latents)
|
||||
|
||||
if progress_bar_st is not None:
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
return image
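The classifier-free guidance step in the denoising loop above is a per-element linear extrapolation from the negative prediction toward the positive one. The sketch below shows only that arithmetic on plain Python lists; `combine_guidance` is a hypothetical helper for illustration and is not part of the pipeline, which operates on torch tensors.

```python
from typing import List


def combine_guidance(noise_pred_posi: List[float],
                     noise_pred_nega: List[float],
                     cfg_scale: float) -> List[float]:
    """Hypothetical helper: nega + cfg_scale * (posi - nega), element by element."""
    return [
        nega + cfg_scale * (posi - nega)
        for posi, nega in zip(noise_pred_posi, noise_pred_nega)
    ]


posi = [1.0, 2.0]
nega = [0.0, 1.0]

# cfg_scale == 1.0 gives back the positive prediction, so the negative
# pass adds nothing in that case.
same_as_posi = combine_guidance(posi, nega, 1.0)

# Larger scales push the result further away from the negative prediction.
amplified = combine_guidance(posi, nega, 7.5)
```

With `cfg_scale` equal to 1.0 the expression is mathematically the positive prediction, which is consistent with the pipeline skipping the negative prompt encoding and the second UNet call in that case.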
|
||||
@@ -1,190 +0,0 @@
|
||||
from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLMotionModel
|
||||
from .dancer import lets_dance_xl
|
||||
# TODO: SDXL ControlNet
|
||||
from ..prompts import SDXLPrompter
|
||||
from ..schedulers import EnhancedDDIMScheduler
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
class SDXLVideoPipeline(torch.nn.Module):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16, use_animatediff=True):
|
||||
super().__init__()
|
||||
self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_animatediff else "scaled_linear")
|
||||
self.prompter = SDXLPrompter()
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
# models
|
||||
self.text_encoder: SDXLTextEncoder = None
|
||||
self.text_encoder_2: SDXLTextEncoder2 = None
|
||||
self.unet: SDXLUNet = None
|
||||
self.vae_decoder: SDXLVAEDecoder = None
|
||||
self.vae_encoder: SDXLVAEEncoder = None
|
||||
# TODO: SDXL ControlNet
|
||||
self.motion_modules: SDXLMotionModel = None
|
||||
|
||||
|
||||
def fetch_main_models(self, model_manager: ModelManager):
|
||||
self.text_encoder = model_manager.text_encoder
|
||||
self.text_encoder_2 = model_manager.text_encoder_2
|
||||
self.unet = model_manager.unet
|
||||
self.vae_decoder = model_manager.vae_decoder
|
||||
self.vae_encoder = model_manager.vae_encoder
|
||||
|
||||
|
||||
def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs):
|
||||
# TODO: SDXL ControlNet
|
||||
pass
|
||||
|
||||
|
||||
def fetch_motion_modules(self, model_manager: ModelManager):
|
||||
if "motion_modules_xl" in model_manager.model:
|
||||
self.motion_modules = model_manager.motion_modules_xl
|
||||
|
||||
|
||||
def fetch_prompter(self, model_manager: ModelManager):
|
||||
self.prompter.load_from_model_manager(model_manager)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units = [], **kwargs):
|
||||
pipe = SDXLVideoPipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
use_animatediff="motion_modules_xl" in model_manager.model
|
||||
)
|
||||
pipe.fetch_main_models(model_manager)
|
||||
pipe.fetch_motion_modules(model_manager)
|
||||
pipe.fetch_prompter(model_manager)
|
||||
pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units)
|
||||
return pipe
|
||||
|
||||
|
||||
def preprocess_image(self, image):
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
|
||||
return image
|
||||
|
||||
|
||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
image = image.cpu().permute(1, 2, 0).numpy()
|
||||
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||
return image
|
||||
|
||||
|
||||
def decode_images(self, latents, tiled=False, tile_size=64, tile_stride=32):
|
||||
images = [
|
||||
self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
for frame_id in range(latents.shape[0])
|
||||
]
|
||||
return images
|
||||
|
||||
|
||||
def encode_images(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
|
||||
latents = []
|
||||
for image in processed_images:
|
||||
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latent = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).cpu()
|
||||
latents.append(latent)
|
||||
latents = torch.concat(latents, dim=0)
|
||||
return latents
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
cfg_scale=7.5,
|
||||
clip_skip=1,
|
||||
clip_skip_2=2,
|
||||
num_frames=None,
|
||||
input_frames=None,
|
||||
controlnet_frames=None,
|
||||
denoising_strength=1.0,
|
||||
height=512,
|
||||
width=512,
|
||||
num_inference_steps=20,
|
||||
animatediff_batch_size = 16,
|
||||
animatediff_stride = 8,
|
||||
unet_batch_size = 1,
|
||||
controlnet_batch_size = 1,
|
||||
cross_frame_attention = False,
|
||||
smoother=None,
|
||||
smoother_progress_ids=[],
|
||||
vram_limit_level=0,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
if self.motion_modules is None:
|
||||
noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
|
||||
else:
|
||||
noise = torch.randn((num_frames, 4, height//8, width//8), device="cuda", dtype=self.torch_dtype)
|
||||
if input_frames is None or denoising_strength == 1.0:
|
||||
latents = noise
|
||||
else:
|
||||
latents = self.encode_images(input_frames)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
|
||||
# Encode prompts
|
||||
add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt(
|
||||
self.text_encoder,
|
||||
self.text_encoder_2,
|
||||
prompt,
|
||||
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
|
||||
device=self.device,
|
||||
positive=True,
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
|
||||
self.text_encoder,
|
||||
self.text_encoder_2,
|
||||
negative_prompt,
|
||||
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
|
||||
device=self.device,
|
||||
positive=False,
|
||||
)
|
||||
|
||||
# Prepare positional id
|
||||
add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = torch.IntTensor((timestep,))[0].to(self.device)
|
||||
|
||||
# Classifier-free guidance
|
||||
noise_pred_posi = lets_dance_xl(
|
||||
self.unet, motion_modules=self.motion_modules, controlnet=None,
|
||||
sample=latents, add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
|
||||
timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_frames,
|
||||
cross_frame_attention=cross_frame_attention,
|
||||
device=self.device, vram_limit_level=vram_limit_level
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = lets_dance_xl(
|
||||
self.unet, motion_modules=self.motion_modules, controlnet=None,
|
||||
sample=latents, add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega,
|
||||
timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames,
|
||||
cross_frame_attention=cross_frame_attention,
|
||||
device=self.device, vram_limit_level=vram_limit_level
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
latents = self.scheduler.step(noise_pred, timestep, latents)
|
||||
|
||||
if progress_bar_st is not None:
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
image = self.decode_images(latents.to(torch.float32))
|
||||
|
||||
return image
|
||||
@@ -1,5 +1,6 @@
|
||||
from ..models import ModelManager, SVDImageEncoder, SVDUNet, SVDVAEEncoder, SVDVAEDecoder
|
||||
from ..schedulers import ContinuousODEScheduler
|
||||
from .base import BasePipeline
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
@@ -8,13 +9,11 @@ from einops import rearrange, repeat
|
||||
|
||||
|
||||
|
||||
class SVDVideoPipeline(torch.nn.Module):
|
||||
class SVDVideoPipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__()
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
self.scheduler = ContinuousODEScheduler()
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
# models
|
||||
self.image_encoder: SVDImageEncoder = None
|
||||
self.unet: SVDUNet = None
|
||||
@@ -22,32 +21,23 @@ class SVDVideoPipeline(torch.nn.Module):
|
||||
self.vae_decoder: SVDVAEDecoder = None
|
||||
|
||||
|
||||
def fetch_main_models(self, model_manager: ModelManager):
|
||||
self.image_encoder = model_manager.image_encoder
|
||||
self.unet = model_manager.unet
|
||||
self.vae_encoder = model_manager.vae_encoder
|
||||
self.vae_decoder = model_manager.vae_decoder
|
||||
def fetch_models(self, model_manager: ModelManager):
|
||||
self.image_encoder = model_manager.fetch_model("svd_image_encoder")
|
||||
self.unet = model_manager.fetch_model("svd_unet")
|
||||
self.vae_encoder = model_manager.fetch_model("svd_vae_encoder")
|
||||
self.vae_decoder = model_manager.fetch_model("svd_vae_decoder")
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, **kwargs):
|
||||
pipe = SVDVideoPipeline(device=model_manager.device, torch_dtype=model_manager.torch_dtype)
|
||||
pipe.fetch_main_models(model_manager)
|
||||
pipe = SVDVideoPipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype
|
||||
)
|
||||
pipe.fetch_models(model_manager)
|
||||
return pipe
|
||||
|
||||
|
||||
def preprocess_image(self, image):
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
|
||||
return image
|
||||
|
||||
|
||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
image = image.cpu().permute(1, 2, 0).numpy()
|
||||
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||
return image
|
||||
|
||||
|
||||
def encode_image_with_clip(self, image):
|
||||
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
|
||||
image = SVDCLIPImageProcessor().resize_with_antialiasing(image, (224, 224))
|
||||
6
diffsynth/prompters/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .prompt_refiners import Translator, BeautifulPrompt
|
||||
from .sd_prompter import SDPrompter
|
||||
from .sdxl_prompter import SDXLPrompter
|
||||
from .sd3_prompter import SD3Prompter
|
||||
from .hunyuan_dit_prompter import HunyuanDiTPrompter
|
||||
from .kolors_prompter import KolorsPrompter
|
||||
57
diffsynth/prompters/base_prompter.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from ..models.model_manager import ModelManager
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
def tokenize_long_prompt(tokenizer, prompt, max_length=None):
|
||||
# Get model_max_length from the tokenizer unless max_length is given
|
||||
length = tokenizer.model_max_length if max_length is None else max_length
|
||||
|
||||
# To avoid the length warning, temporarily set tokenizer.model_max_length to a very large value.
|
||||
tokenizer.model_max_length = 99999999
|
||||
|
||||
# Tokenize it!
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
||||
|
||||
# Determine the real length.
|
||||
max_length = (input_ids.shape[1] + length - 1) // length * length
|
||||
|
||||
# Restore tokenizer.model_max_length
|
||||
tokenizer.model_max_length = length
|
||||
|
||||
# Tokenize it again with fixed length.
|
||||
input_ids = tokenizer(
|
||||
prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True
|
||||
).input_ids
|
||||
|
||||
# Reshape input_ids to fit the text encoder.
|
||||
num_sentence = input_ids.shape[1] // length
|
||||
input_ids = input_ids.reshape((num_sentence, length))
|
||||
|
||||
return input_ids
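The length handling in `tokenize_long_prompt` comes down to rounding the token count up to a multiple of the encoder's window and then reshaping into rows of that width. The sketch below reproduces that arithmetic on plain Python lists, assuming a pad id of 0; `padded_length` and `split_into_chunks` are hypothetical names for illustration, and the real function delegates padding and truncation to the tokenizer.

```python
from typing import List, Sequence


def padded_length(num_tokens: int, chunk_length: int) -> int:
    """Round num_tokens up to the nearest multiple of chunk_length."""
    return (num_tokens + chunk_length - 1) // chunk_length * chunk_length


def split_into_chunks(token_ids: Sequence[int],
                      chunk_length: int,
                      pad_id: int = 0) -> List[List[int]]:
    """Pad token_ids to a multiple of chunk_length, then cut into rows."""
    total = padded_length(len(token_ids), chunk_length)
    padded = list(token_ids) + [pad_id] * (total - len(token_ids))
    return [padded[start:start + chunk_length]
            for start in range(0, total, chunk_length)]


# A prompt one token longer than the window needs two rows.
rows = split_into_chunks(list(range(1, 6)), chunk_length=4)
```

Each row can then be passed to the text encoder as a separate "sentence", which is what the final `reshape((num_sentence, length))` in the function arranges.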
|
||||
|
||||
|
||||
|
||||
class BasePrompter:
|
||||
def __init__(self, refiners=[]):
|
||||
self.refiners = list(refiners)  # copy, so instances never share the mutable default list
|
||||
|
||||
|
||||
def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]):
|
||||
for refiner_class in refiner_classes:
|
||||
refiner = refiner_class.from_model_manager(model_manager)
|
||||
self.refiners.append(refiner)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def process_prompt(self, prompt, positive=True):
|
||||
if isinstance(prompt, list):
|
||||
prompt = [self.process_prompt(prompt_, positive=positive) for prompt_ in prompt]
|
||||
else:
|
||||
for refiner in self.refiners:
|
||||
prompt = refiner(prompt, positive=positive)
|
||||
return prompt
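`process_prompt` applies every refiner in order and recurses into lists of prompts. A minimal standalone sketch of that chaining follows; `UppercaseRefiner` and `SuffixRefiner` are hypothetical stand-ins for real refiners such as `Translator` or `BeautifulPrompt`, assumed only to be callables that accept `(prompt, positive=...)`.

```python
class UppercaseRefiner:
    """Hypothetical refiner: upper-cases positive prompts only."""

    def __call__(self, prompt, positive=True):
        return prompt.upper() if positive else prompt


class SuffixRefiner:
    """Hypothetical refiner: appends a fixed suffix to every prompt."""

    def __init__(self, suffix):
        self.suffix = suffix

    def __call__(self, prompt, positive=True):
        return prompt + self.suffix


def process_prompt(prompt, refiners, positive=True):
    # Lists are handled element by element, mirroring the method above.
    if isinstance(prompt, list):
        return [process_prompt(p, refiners, positive=positive) for p in prompt]
    # Each refiner receives the output of the previous one.
    for refiner in refiners:
        prompt = refiner(prompt, positive=positive)
    return prompt


refiners = [UppercaseRefiner(), SuffixRefiner(", 4k")]
single = process_prompt("a cat", refiners)
batch = process_prompt(["a cat", "a dog"], refiners, positive=False)
```

Because the refiners run in sequence, their order matters: swapping the two in this sketch would upper-case the suffix as well.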
|
||||
@@ -1,19 +1,34 @@
|
||||
from .utils import Prompter
|
||||
from transformers import BertModel, T5EncoderModel, BertTokenizer, AutoTokenizer
|
||||
import warnings
|
||||
from .base_prompter import BasePrompter
|
||||
from ..models.model_manager import ModelManager
|
||||
from ..models import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
||||
from transformers import BertTokenizer, AutoTokenizer
|
||||
import warnings, os
|
||||
|
||||
|
||||
class HunyuanDiTPrompter(Prompter):
|
||||
class HunyuanDiTPrompter(BasePrompter):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer_path="configs/hunyuan_dit/tokenizer",
|
||||
tokenizer_t5_path="configs/hunyuan_dit/tokenizer_t5"
|
||||
tokenizer_path=None,
|
||||
tokenizer_t5_path=None
|
||||
):
|
||||
if tokenizer_path is None:
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_path = os.path.join(base_path, "tokenizer_configs/hunyuan_dit/tokenizer")
|
||||
if tokenizer_t5_path is None:
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_t5_path = os.path.join(base_path, "tokenizer_configs/hunyuan_dit/tokenizer_t5")
|
||||
super().__init__()
|
||||
self.tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
self.tokenizer_t5 = AutoTokenizer.from_pretrained(tokenizer_t5_path)
|
||||
self.text_encoder: HunyuanDiTCLIPTextEncoder = None
|
||||
self.text_encoder_t5: HunyuanDiTT5TextEncoder = None
|
||||
|
||||
|
||||
def fetch_models(self, text_encoder: HunyuanDiTCLIPTextEncoder = None, text_encoder_t5: HunyuanDiTT5TextEncoder = None):
|
||||
self.text_encoder = text_encoder
|
||||
self.text_encoder_t5 = text_encoder_t5
|
||||
|
||||
|
||||
def encode_prompt_using_signle_model(self, prompt, text_encoder, tokenizer, max_length, clip_skip, device):
|
||||
@@ -37,8 +52,6 @@ class HunyuanDiTPrompter(Prompter):
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
text_encoder: BertModel,
|
||||
text_encoder_t5: T5EncoderModel,
|
||||
prompt,
|
||||
clip_skip=1,
|
||||
clip_skip_2=1,
|
||||
@@ -48,9 +61,9 @@ class HunyuanDiTPrompter(Prompter):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
|
||||
# CLIP
|
||||
prompt_emb, attention_mask = self.encode_prompt_using_signle_model(prompt, text_encoder, self.tokenizer, self.tokenizer.model_max_length, clip_skip, device)
|
||||
prompt_emb, attention_mask = self.encode_prompt_using_signle_model(prompt, self.text_encoder, self.tokenizer, self.tokenizer.model_max_length, clip_skip, device)
|
||||
|
||||
# T5
|
||||
prompt_emb_t5, attention_mask_t5 = self.encode_prompt_using_signle_model(prompt, text_encoder_t5, self.tokenizer_t5, self.tokenizer_t5.model_max_length, clip_skip_2, device)
|
||||
prompt_emb_t5, attention_mask_t5 = self.encode_prompt_using_signle_model(prompt, self.text_encoder_t5, self.tokenizer_t5, self.tokenizer_t5.model_max_length, clip_skip_2, device)
|
||||
|
||||
return prompt_emb, attention_mask, prompt_emb_t5, attention_mask_t5
|
||||
353
diffsynth/prompters/kolors_prompter.py
Normal file
@@ -0,0 +1,353 @@
|
||||
from .base_prompter import BasePrompter
|
||||
from ..models.model_manager import ModelManager
|
||||
import json, os, re
|
||||
from typing import List, Optional, Union, Dict
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
from transformers import PreTrainedTokenizer
|
||||
from transformers.utils import PaddingStrategy
|
||||
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
|
||||
from ..models.kolors_text_encoder import ChatGLMModel
|
||||
|
||||
|
||||
class SPTokenizer:
|
||||
def __init__(self, model_path: str):
|
||||
# reload tokenizer
|
||||
assert os.path.isfile(model_path), model_path
|
||||
self.sp_model = SentencePieceProcessor(model_file=model_path)
|
||||
|
||||
# BOS / EOS token IDs
|
||||
self.n_words: int = self.sp_model.vocab_size()
|
||||
self.bos_id: int = self.sp_model.bos_id()
|
||||
self.eos_id: int = self.sp_model.eos_id()
|
||||
self.pad_id: int = self.sp_model.unk_id()
|
||||
assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
|
||||
|
||||
role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
|
||||
special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
|
||||
self.special_tokens = {}
|
||||
self.index_special_tokens = {}
|
||||
for token in special_tokens:
|
||||
self.special_tokens[token] = self.n_words
|
||||
self.index_special_tokens[self.n_words] = token
|
||||
self.n_words += 1
|
||||
self.role_special_token_expression = "|".join([re.escape(token) for token in role_special_tokens])
|
||||
|
||||
def tokenize(self, s: str, encode_special_tokens=False):
|
||||
if encode_special_tokens:
|
||||
last_index = 0
|
||||
t = []
|
||||
for match in re.finditer(self.role_special_token_expression, s):
|
||||
if last_index < match.start():
|
||||
t.extend(self.sp_model.EncodeAsPieces(s[last_index:match.start()]))
|
||||
t.append(s[match.start():match.end()])
|
||||
last_index = match.end()
|
||||
if last_index < len(s):
|
||||
t.extend(self.sp_model.EncodeAsPieces(s[last_index:]))
|
||||
return t
|
||||
else:
|
||||
return self.sp_model.EncodeAsPieces(s)
|
||||
|
||||
def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
|
||||
assert type(s) is str
|
||||
t = self.sp_model.encode(s)
|
||||
if bos:
|
||||
t = [self.bos_id] + t
|
||||
if eos:
|
||||
t = t + [self.eos_id]
|
||||
return t
|
||||
|
||||
def decode(self, t: List[int]) -> str:
|
||||
text, buffer = "", []
|
||||
for token in t:
|
||||
if token in self.index_special_tokens:
|
||||
if buffer:
|
||||
text += self.sp_model.decode(buffer)
|
||||
buffer = []
|
||||
text += self.index_special_tokens[token]
|
||||
else:
|
||||
buffer.append(token)
|
||||
if buffer:
|
||||
text += self.sp_model.decode(buffer)
|
||||
return text
|
||||
|
||||
def decode_tokens(self, tokens: List[str]) -> str:
|
||||
text = self.sp_model.DecodePieces(tokens)
|
||||
return text
|
||||
|
||||
def convert_token_to_id(self, token):
|
||||
"""Converts a token (str) into an id using the vocab."""
|
||||
if token in self.special_tokens:
|
||||
return self.special_tokens[token]
|
||||
return self.sp_model.PieceToId(token)
|
||||
|
||||
def convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) into a token (str) using the vocab."""
|
||||
if index in self.index_special_tokens:
|
||||
return self.index_special_tokens[index]
|
||||
if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
|
||||
return ""
|
||||
return self.sp_model.IdToPiece(index)
|
||||
|
||||
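The `tokenize` method of `SPTokenizer` splits the input on role markers before handing the remaining text to SentencePiece, so that markers such as `<|user|>` survive as single tokens. A minimal sketch of that splitting logic follows. It assumes a hypothetical `encode_pieces` callable (plain whitespace splitting here) standing in for `SentencePieceProcessor.EncodeAsPieces`, and the helper name `split_on_role_tokens` is illustrative, not part of the commit.

```python
import re
from typing import Callable, List

ROLE_TOKENS = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
# Same construction as `role_special_token_expression` in SPTokenizer.
ROLE_PATTERN = "|".join(re.escape(token) for token in ROLE_TOKENS)


def split_on_role_tokens(text: str, encode_pieces: Callable[[str], List[str]]) -> List[str]:
    """Tokenize `text`, keeping every role marker as one token."""
    pieces: List[str] = []
    last_index = 0
    for match in re.finditer(ROLE_PATTERN, text):
        # Encode the ordinary text that precedes the marker, if any.
        if last_index < match.start():
            pieces.extend(encode_pieces(text[last_index:match.start()]))
        pieces.append(match.group(0))
        last_index = match.end()
    # Encode whatever follows the last marker.
    if last_index < len(text):
        pieces.extend(encode_pieces(text[last_index:]))
    return pieces


# Hypothetical stand-in for SentencePiece: split on whitespace.
tokens = split_on_role_tokens("<|user|>draw a cat<|assistant|>", str.split)
```

With a real SentencePiece model the ordinary pieces would differ, but the role markers would still come out unsplit.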
|
||||
|
||||
class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||
vocab_files_names = {"vocab_file": "tokenizer.model"}
|
||||
|
||||
model_input_names = ["input_ids", "attention_mask", "position_ids"]
|
||||
|
||||
def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, encode_special_tokens=False,
|
||||
**kwargs):
|
||||
self.name = "GLMTokenizer"
|
||||
|
||||
self.vocab_file = vocab_file
|
||||
self.tokenizer = SPTokenizer(vocab_file)
|
||||
self.special_tokens = {
|
||||
"<bos>": self.tokenizer.bos_id,
|
||||
"<eos>": self.tokenizer.eos_id,
|
||||
"<pad>": self.tokenizer.pad_id
|
||||
}
|
||||
self.encode_special_tokens = encode_special_tokens
|
||||
super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
encode_special_tokens=encode_special_tokens,
|
||||
**kwargs)
|
||||
|
||||
def get_command(self, token):
|
||||
if token in self.special_tokens:
|
||||
return self.special_tokens[token]
|
||||
assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
|
||||
return self.tokenizer.special_tokens[token]
|
||||
|
||||
@property
|
||||
def unk_token(self) -> str:
|
||||
return "<unk>"
|
||||
|
||||
@property
|
||||
def pad_token(self) -> str:
|
||||
return "<unk>"
|
||||
|
||||
@property
|
||||
def pad_token_id(self):
|
||||
return self.get_command("<pad>")
|
||||
|
||||
@property
|
||||
def eos_token(self) -> str:
|
||||
return "</s>"
|
||||
|
||||
@property
|
||||
def eos_token_id(self):
|
||||
return self.get_command("<eos>")
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return self.tokenizer.n_words
|
||||
|
||||
def get_vocab(self):
|
||||
""" Returns vocab as a dict """
|
||||
vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
|
||||
vocab.update(self.added_tokens_encoder)
|
||||
return vocab
|
||||
|
||||
def _tokenize(self, text, **kwargs):
|
||||
return self.tokenizer.tokenize(text, encode_special_tokens=self.encode_special_tokens)
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
"""Converts a token (str) into an id using the vocab."""
|
||||
return self.tokenizer.convert_token_to_id(token)
|
||||
|
||||
def _convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) into a token (str) using the vocab."""
|
||||
return self.tokenizer.convert_id_to_token(index)
|
||||
|
||||
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
||||
return self.tokenizer.decode_tokens(tokens)
|
||||
|
||||
def save_vocabulary(self, save_directory, filename_prefix=None):
|
||||
"""
|
||||
Save the vocabulary and special tokens file to a directory.
|
||||
|
||||
Args:
|
||||
save_directory (`str`):
|
||||
The directory in which to save the vocabulary.
|
||||
filename_prefix (`str`, *optional*):
|
||||
An optional prefix to add to the names of the saved files.
|
||||
|
||||
Returns:
|
||||
`Tuple(str)`: Paths to the files saved.
|
||||
"""
|
||||
if os.path.isdir(save_directory):
|
||||
vocab_file = os.path.join(
|
||||
save_directory, self.vocab_files_names["vocab_file"]
|
||||
)
|
||||
else:
|
||||
vocab_file = save_directory
|
||||
|
||||
with open(self.vocab_file, 'rb') as fin:
|
||||
proto_str = fin.read()
|
||||
|
||||
with open(vocab_file, "wb") as writer:
|
||||
writer.write(proto_str)
|
||||
|
||||
return (vocab_file,)
|
||||
|
||||
def get_prefix_tokens(self):
|
||||
prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
|
||||
return prefix_tokens
|
||||
|
||||
def build_single_message(self, role, metadata, message):
|
||||
assert role in ["system", "user", "assistant", "observation"], role
|
||||
role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
|
||||
message_tokens = self.tokenizer.encode(message)
|
||||
tokens = role_tokens + message_tokens
|
||||
return tokens
|
||||
|
||||
def build_chat_input(self, query, history=None, role="user"):
|
||||
if history is None:
|
||||
history = []
|
||||
input_ids = []
|
||||
for item in history:
|
||||
content = item["content"]
|
||||
if item["role"] == "system" and "tools" in item:
|
||||
content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
|
||||
input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
|
||||
input_ids.extend(self.build_single_message(role, "", query))
|
||||
input_ids.extend([self.get_command("<|assistant|>")])
|
||||
return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)
|
||||
|
||||
def build_inputs_with_special_tokens(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Build model inputs from a sequence or a pair of sequences by prepending the ChatGLM prefix tokens
|
||||
returned by `get_prefix_tokens()`. The resulting sequence has the following format:
|
||||
|
||||
- single sequence: `[gMASK] sop X`
|
||||
- pair of sequences: `[gMASK] sop A B <eos>`
|
||||
|
||||
Args:
|
||||
token_ids_0 (`List[int]`):
|
||||
List of IDs to which the special tokens will be added.
|
||||
token_ids_1 (`List[int]`, *optional*):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
|
||||
Returns:
|
||||
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
||||
"""
|
||||
prefix_tokens = self.get_prefix_tokens()
|
||||
token_ids_0 = prefix_tokens + token_ids_0
|
||||
if token_ids_1 is not None:
|
||||
token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
|
||||
return token_ids_0
|
||||
|
||||
def _pad(
|
||||
self,
|
||||
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
|
||||
max_length: Optional[int] = None,
|
||||
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
|
||||
|
||||
Args:
|
||||
encoded_inputs:
|
||||
Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
|
||||
max_length: maximum length of the returned list and optionally padding length (see below).
|
||||
Will truncate by taking into account the special tokens.
|
||||
padding_strategy: PaddingStrategy to use for padding.
|
||||
|
||||
- PaddingStrategy.LONGEST: Pad to the longest sequence in the batch
|
||||
- PaddingStrategy.MAX_LENGTH: Pad to the max length
|
||||
- PaddingStrategy.DO_NOT_PAD: Do not pad (default)
|
||||
The tokenizer padding sides are defined in self.padding_side:
|
||||
|
||||
- 'left': pads on the left of the sequences
|
||||
- 'right': pads on the right of the sequences
|
||||
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
|
||||
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
|
||||
`>= 7.0` (Volta).
|
||||
return_attention_mask:
|
||||
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
|
||||
"""
|
||||
# Only left padding is supported by this tokenizer.
|
||||
assert self.padding_side == "left"
|
||||
|
||||
required_input = encoded_inputs[self.model_input_names[0]]
|
||||
seq_length = len(required_input)
|
||||
|
||||
if padding_strategy == PaddingStrategy.LONGEST:
|
||||
max_length = len(required_input)
|
||||
|
||||
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
|
||||
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
||||
|
||||
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
|
||||
|
||||
# Initialize attention mask if not present.
|
||||
if "attention_mask" not in encoded_inputs:
|
||||
encoded_inputs["attention_mask"] = [1] * seq_length
|
||||
|
||||
if "position_ids" not in encoded_inputs:
|
||||
encoded_inputs["position_ids"] = list(range(seq_length))
|
||||
|
||||
if needs_to_be_padded:
|
||||
difference = max_length - len(required_input)
|
||||
|
||||
if "attention_mask" in encoded_inputs:
|
||||
encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
|
||||
if "position_ids" in encoded_inputs:
|
||||
encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
|
||||
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
|
||||
|
||||
return encoded_inputs
|
||||
|
||||
|
||||
|
||||
class KolorsPrompter(BasePrompter):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer_path=None
|
||||
):
|
||||
if tokenizer_path is None:
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_path = os.path.join(base_path, "tokenizer_configs/kolors/tokenizer")
|
||||
super().__init__()
|
||||
self.tokenizer = ChatGLMTokenizer.from_pretrained(tokenizer_path)
|
||||
self.text_encoder: ChatGLMModel = None
|
||||
|
||||
|
||||
def fetch_models(self, text_encoder: ChatGLMModel = None):
|
||||
self.text_encoder = text_encoder
|
||||
|
||||
|
||||
def encode_prompt_using_ChatGLM(self, prompt, text_encoder, tokenizer, max_length, clip_skip, device):
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
).to(device)
|
||||
output = text_encoder(
|
||||
input_ids=text_inputs['input_ids'],
|
||||
attention_mask=text_inputs['attention_mask'],
|
||||
position_ids=text_inputs['position_ids'],
|
||||
output_hidden_states=True
|
||||
)
|
||||
prompt_emb = output.hidden_states[-clip_skip].permute(1, 0, 2).clone()
|
||||
pooled_prompt_emb = output.hidden_states[-1][-1, :, :].clone()
|
||||
return prompt_emb, pooled_prompt_emb
|
||||
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
clip_skip=1,
|
||||
clip_skip_2=2,
|
||||
positive=True,
|
||||
device="cuda"
|
||||
):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
prompt_emb, pooled_prompt_emb = self.encode_prompt_using_ChatGLM(prompt, self.text_encoder, self.tokenizer, 256, clip_skip_2, device)
|
||||
|
||||
return pooled_prompt_emb, prompt_emb
|
||||
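`ChatGLMTokenizer._pad` pads on the left only: the pad token ID is prepended to `input_ids`, while `attention_mask` and `position_ids` are prepended with zeros. The sketch below reproduces that behavior for a single example using plain lists. The function name `left_pad` and the sample token IDs are illustrative assumptions, and the `pad_to_multiple_of` handling is left out.

```python
from typing import Dict, List


def left_pad(encoded: Dict[str, List[int]], max_length: int, pad_token_id: int) -> Dict[str, List[int]]:
    """Left-pad one encoded example to `max_length` (no-op if it is already long enough)."""
    input_ids = encoded["input_ids"]
    seq_length = len(input_ids)
    # Fill in the defaults that `_pad` creates when the keys are missing.
    attention_mask = encoded.get("attention_mask", [1] * seq_length)
    position_ids = encoded.get("position_ids", list(range(seq_length)))
    difference = max(max_length - seq_length, 0)
    return {
        "input_ids": [pad_token_id] * difference + input_ids,
        "attention_mask": [0] * difference + attention_mask,
        "position_ids": [0] * difference + position_ids,
    }


# Made-up token IDs, padded from length 3 to length 5.
padded = left_pad({"input_ids": [11, 12, 13]}, max_length=5, pad_token_id=0)
```

Left padding keeps the real tokens at the end of the sequence, which matches how `encode_prompt_using_ChatGLM` reads the pooled embedding from the last position.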
77
diffsynth/prompters/prompt_refiners.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from transformers import AutoTokenizer
|
||||
from ..models.model_manager import ModelManager
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
class BeautifulPrompt(torch.nn.Module):
|
||||
def __init__(self, tokenizer_path=None, model=None, template=""):
|
||||
super().__init__()
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
self.model = model
|
||||
self.template = template
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_nameger: ModelManager):
|
||||
model, model_path = model_nameger.fetch_model("beautiful_prompt", require_model_path=True)
|
||||
template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'
|
||||
if model_path.endswith("v2"):
|
||||
template = """Converts a simple image description into a prompt. \
|
||||
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
|
||||
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
|
||||
but make sure there is a correlation between the input and output.\n\
|
||||
### Input: {raw_prompt}\n### Output:"""
|
||||
beautiful_prompt = BeautifulPrompt(
|
||||
tokenizer_path=model_path,
|
||||
model=model,
|
||||
template=template
|
||||
)
|
||||
return beautiful_prompt
|
||||
|
||||
|
||||
def __call__(self, raw_prompt, positive=True, **kwargs):
|
||||
if positive:
|
||||
model_input = self.template.format(raw_prompt=raw_prompt)
|
||||
input_ids = self.tokenizer.encode(model_input, return_tensors='pt').to(self.model.device)
|
||||
outputs = self.model.generate(
|
||||
input_ids,
|
||||
max_new_tokens=384,
|
||||
do_sample=True,
|
||||
temperature=0.9,
|
||||
top_k=50,
|
||||
top_p=0.95,
|
||||
repetition_penalty=1.1,
|
||||
num_return_sequences=1
|
||||
)
|
||||
prompt = raw_prompt + ", " + self.tokenizer.batch_decode(
|
||||
outputs[:, input_ids.size(1):],
|
||||
skip_special_tokens=True
|
||||
)[0].strip()
|
||||
print(f"Your prompt is refined by BeautifulPrompt: {prompt}")
|
||||
return prompt
|
||||
else:
|
||||
return raw_prompt
|
||||
|
||||
|
||||
|
||||
class Translator(torch.nn.Module):
|
||||
def __init__(self, tokenizer_path=None, model=None):
|
||||
super().__init__()
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
self.model = model
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_nameger: ModelManager):
|
||||
model, model_path = model_nameger.fetch_model("translator", require_model_path=True)
|
||||
translator = Translator(tokenizer_path=model_path, model=model)
|
||||
return translator
|
||||
|
||||
|
||||
def __call__(self, prompt, **kwargs):
|
||||
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.model.device)
|
||||
output_ids = self.model.generate(input_ids)
|
||||
prompt = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
|
||||
print(f"Your prompt is translated: {prompt}")
|
||||
return prompt
|
||||
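`BeautifulPrompt.__call__` formats the raw prompt into a template, generates a continuation, slices off the echoed prompt tokens with `outputs[:, input_ids.size(1):]`, and appends the decoded remainder to the raw prompt. The sketch below walks through those three steps with plain lists and strings. The helper names and token IDs are made up for illustration, and no model or tokenizer is involved.

```python
from typing import List

# The v1 template from `from_model_manager`.
TEMPLATE = (
    "Instruction: Give a simple description of the image to generate a drawing prompt.\n"
    "Input: {raw_prompt}\nOutput:"
)


def strip_prompt_tokens(generated: List[List[int]], prompt_length: int) -> List[List[int]]:
    """Drop the echoed prompt tokens, mirroring `outputs[:, input_ids.size(1):]`."""
    return [sequence[prompt_length:] for sequence in generated]


def merge_refinement(raw_prompt: str, continuation: str) -> str:
    """Join the raw prompt and the decoded continuation as `__call__` does."""
    return raw_prompt + ", " + continuation.strip()


model_input = TEMPLATE.format(raw_prompt="a cat")
# Pretend the first 3 IDs are the prompt and the last 2 are newly generated.
new_tokens = strip_prompt_tokens([[5, 6, 7, 8, 9]], prompt_length=3)
refined = merge_refinement("a cat", "  fluffy, sunlight ")
```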
92
diffsynth/prompters/sd3_prompter.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from .base_prompter import BasePrompter
|
||||
from ..models.model_manager import ModelManager
|
||||
from ..models import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
|
||||
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||
import os, torch
|
||||
|
||||
|
||||
class SD3Prompter(BasePrompter):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer_1_path=None,
|
||||
tokenizer_2_path=None,
|
||||
tokenizer_3_path=None
|
||||
):
|
||||
if tokenizer_1_path is None:
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_1_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_3/tokenizer_1")
|
||||
if tokenizer_2_path is None:
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_3/tokenizer_2")
|
||||
if tokenizer_3_path is None:
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_3_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_3/tokenizer_3")
|
||||
super().__init__()
|
||||
self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
|
||||
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
|
||||
self.tokenizer_3 = T5TokenizerFast.from_pretrained(tokenizer_3_path)
|
||||
self.text_encoder_1: SD3TextEncoder1 = None
|
||||
self.text_encoder_2: SD3TextEncoder2 = None
|
||||
self.text_encoder_3: SD3TextEncoder3 = None
|
||||
|
||||
|
||||
def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: SD3TextEncoder2 = None, text_encoder_3: SD3TextEncoder3 = None):
|
||||
self.text_encoder_1 = text_encoder_1
|
||||
self.text_encoder_2 = text_encoder_2
|
||||
self.text_encoder_3 = text_encoder_3
|
||||
|
||||
|
||||
def encode_prompt_using_clip(self, prompt, text_encoder, tokenizer, max_length, device):
|
||||
input_ids = tokenizer(
|
||||
prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True
|
||||
).input_ids.to(device)
|
||||
pooled_prompt_emb, prompt_emb = text_encoder(input_ids)
|
||||
return pooled_prompt_emb, prompt_emb
|
||||
|
||||
|
||||
def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_length, device):
|
||||
input_ids = tokenizer(
|
||||
prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
).input_ids.to(device)
|
||||
prompt_emb = text_encoder(input_ids)
|
||||
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
||||
|
||||
return prompt_emb
|
||||
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
positive=True,
|
||||
device="cuda"
|
||||
):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
|
||||
# CLIP
|
||||
pooled_prompt_emb_1, prompt_emb_1 = self.encode_prompt_using_clip(prompt, self.text_encoder_1, self.tokenizer_1, 77, device)
|
||||
pooled_prompt_emb_2, prompt_emb_2 = self.encode_prompt_using_clip(prompt, self.text_encoder_2, self.tokenizer_2, 77, device)
|
||||
|
||||
# T5
|
||||
if self.text_encoder_3 is None:
|
||||
prompt_emb_3 = torch.zeros((prompt_emb_1.shape[0], 256, 4096), dtype=prompt_emb_1.dtype, device=device)
|
||||
else:
|
||||
prompt_emb_3 = self.encode_prompt_using_t5(prompt, self.text_encoder_3, self.tokenizer_3, 256, device)
|
||||
prompt_emb_3 = prompt_emb_3.to(prompt_emb_1.dtype) # float32 -> float16
|
||||
|
||||
# Merge
|
||||
prompt_emb = torch.cat([
|
||||
torch.nn.functional.pad(torch.cat([prompt_emb_1, prompt_emb_2], dim=-1), (0, 4096 - 768 - 1280)),
|
||||
prompt_emb_3
|
||||
], dim=-2)
|
||||
pooled_prompt_emb = torch.cat([pooled_prompt_emb_1, pooled_prompt_emb_2], dim=-1)
|
||||
|
||||
return prompt_emb, pooled_prompt_emb
|
||||
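The merge step in `SD3Prompter.encode_prompt` concatenates the two CLIP embeddings along the feature axis, zero-pads the result to the T5 width, and then stacks it with the T5 embedding along the token axis. The sketch below only does the shape arithmetic, using the constants that appear in the code (77 and 256 tokens, widths 768, 1280 and 4096). It assumes a single prompt and that the pooled CLIP outputs have the same widths as the per-token outputs; neither assumption is stated in the diff.

```python
from typing import Tuple

# Widths implied by the padding expression `4096 - 768 - 1280`.
CLIP_1_DIM, CLIP_2_DIM, T5_DIM = 768, 1280, 4096
# `max_length` values passed to the CLIP and T5 tokenizers.
CLIP_TOKENS, T5_TOKENS = 77, 256


def merged_shapes(batch: int = 1) -> Tuple[Tuple[int, int, int], Tuple[int, int]]:
    """Return the shapes of (prompt_emb, pooled_prompt_emb) after the merge."""
    clip_width = CLIP_1_DIM + CLIP_2_DIM          # concat along dim=-1
    pad = T5_DIM - clip_width                     # zeros appended on the right
    clip_block = (batch, CLIP_TOKENS, clip_width + pad)
    t5_block = (batch, T5_TOKENS, T5_DIM)
    assert clip_block[2] == t5_block[2]           # widths must match before stacking
    prompt_emb = (batch, clip_block[1] + t5_block[1], T5_DIM)  # concat along dim=-2
    pooled = (batch, clip_width)                  # assumed pooled widths, see lead-in
    return prompt_emb, pooled


prompt_shape, pooled_shape = merged_shapes()
```

When `text_encoder_3` is missing, the code substitutes zeros of shape `(batch, 256, 4096)`, so the merged shape stays the same.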
73
diffsynth/prompters/sd_prompter.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from .base_prompter import BasePrompter, tokenize_long_prompt
|
||||
from ..models.model_manager import ModelManager, load_state_dict, search_for_embeddings
|
||||
from ..models import SDTextEncoder
|
||||
from transformers import CLIPTokenizer
|
||||
import torch, os
|
||||
|
||||
|
||||
|
||||
class SDPrompter(BasePrompter):
|
||||
def __init__(self, tokenizer_path=None):
|
||||
if tokenizer_path is None:
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion/tokenizer")
|
||||
super().__init__()
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
self.text_encoder: SDTextEncoder = None
|
||||
self.textual_inversion_dict = {}
|
||||
self.keyword_dict = {}
|
||||
|
||||
|
||||
def fetch_models(self, text_encoder: SDTextEncoder = None):
|
||||
self.text_encoder = text_encoder
|
||||
|
||||
|
||||
def add_textual_inversions_to_model(self, textual_inversion_dict, text_encoder):
|
||||
dtype = next(iter(text_encoder.parameters())).dtype
|
||||
state_dict = text_encoder.token_embedding.state_dict()
|
||||
token_embeddings = [state_dict["weight"]]
|
||||
for keyword in textual_inversion_dict:
|
||||
_, embeddings = textual_inversion_dict[keyword]
|
||||
token_embeddings.append(embeddings.to(dtype=dtype, device=token_embeddings[0].device))
|
||||
token_embeddings = torch.concat(token_embeddings, dim=0)
|
||||
state_dict["weight"] = token_embeddings
|
||||
text_encoder.token_embedding = torch.nn.Embedding(token_embeddings.shape[0], token_embeddings.shape[1])
|
||||
text_encoder.token_embedding = text_encoder.token_embedding.to(dtype=dtype, device=token_embeddings[0].device)
|
||||
text_encoder.token_embedding.load_state_dict(state_dict)
|
||||
|
||||
|
||||
def add_textual_inversions_to_tokenizer(self, textual_inversion_dict, tokenizer):
|
||||
additional_tokens = []
|
||||
for keyword in textual_inversion_dict:
|
||||
tokens, _ = textual_inversion_dict[keyword]
|
||||
additional_tokens += tokens
|
||||
self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
|
||||
tokenizer.add_tokens(additional_tokens)
|
||||
|
||||
|
||||
def load_textual_inversions(self, model_paths):
|
||||
for model_path in model_paths:
|
||||
keyword = os.path.splitext(os.path.split(model_path)[-1])[0]
|
||||
state_dict = load_state_dict(model_path)
|
||||
|
||||
# Search for embeddings
|
||||
for embeddings in search_for_embeddings(state_dict):
|
||||
if len(embeddings.shape) == 2 and embeddings.shape[1] == 768:
|
||||
tokens = [f"{keyword}_{i}" for i in range(embeddings.shape[0])]
|
||||
self.textual_inversion_dict[keyword] = (tokens, embeddings)
|
||||
|
||||
self.add_textual_inversions_to_model(self.textual_inversion_dict, self.text_encoder)
|
||||
self.add_textual_inversions_to_tokenizer(self.textual_inversion_dict, self.tokenizer)
|
||||
|
||||
|
||||
def encode_prompt(self, prompt, clip_skip=1, device="cuda", positive=True):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
for keyword in self.keyword_dict:
|
||||
if keyword in prompt:
|
||||
print(f"Textual inversion {keyword} is enabled.")
|
||||
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
|
||||
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
|
||||
prompt_emb = self.text_encoder(input_ids, clip_skip=clip_skip)
|
||||
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
||||
|
||||
return prompt_emb
|
||||
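`SDPrompter` handles textual inversion by mapping each keyword (the embedding file's base name) to a run of placeholder tokens, one per embedding row, and substituting that run into the prompt before tokenization. A minimal sketch of the string side of this follows. The function names and the `mystyle` keyword are hypothetical, and the embedding tensors themselves are omitted.

```python
from typing import Dict


def build_keyword_dict(token_counts: Dict[str, int]) -> Dict[str, str]:
    """Map each keyword to its space-separated placeholder tokens."""
    keyword_dict = {}
    for keyword, count in token_counts.items():
        # One placeholder per embedding row, as in `load_textual_inversions`.
        tokens = [f"{keyword}_{i}" for i in range(count)]
        keyword_dict[keyword] = " " + " ".join(tokens) + " "
    return keyword_dict


def expand_keywords(prompt: str, keyword_dict: Dict[str, str]) -> str:
    """Replace every known keyword in `prompt` with its placeholder run."""
    for keyword, replacement in keyword_dict.items():
        if keyword in prompt:
            prompt = prompt.replace(keyword, replacement)
    return prompt


# A hypothetical embedding file `mystyle.pt` with two embedding rows.
keyword_dict = build_keyword_dict({"mystyle": 2})
expanded = expand_keywords("a castle, mystyle", keyword_dict)
```

Because the substitution is a plain substring replace, a keyword that is a substring of another keyword's placeholders could be expanded twice; distinct file names avoid this.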
61
diffsynth/prompters/sdxl_prompter.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from .base_prompter import BasePrompter, tokenize_long_prompt
|
||||
from ..models.model_manager import ModelManager
|
||||
from ..models import SDXLTextEncoder, SDXLTextEncoder2
|
||||
from transformers import CLIPTokenizer
|
||||
import torch, os
|
||||
|
||||
|
||||
|
||||
class SDXLPrompter(BasePrompter):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer_path=None,
|
||||
tokenizer_2_path=None
|
||||
):
|
||||
if tokenizer_path is None:
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion/tokenizer")
|
||||
if tokenizer_2_path is None:
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_xl/tokenizer_2")
|
||||
super().__init__()
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
|
||||
self.text_encoder: SDXLTextEncoder = None
|
||||
self.text_encoder_2: SDXLTextEncoder2 = None
|
||||
|
||||
|
||||
def fetch_models(self, text_encoder: SDXLTextEncoder = None, text_encoder_2: SDXLTextEncoder2 = None):
|
||||
self.text_encoder = text_encoder
|
||||
self.text_encoder_2 = text_encoder_2
|
||||
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
clip_skip=1,
|
||||
clip_skip_2=2,
|
||||
positive=True,
|
||||
device="cuda"
|
||||
):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
|
||||
# Text encoder 1
|
||||
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
|
||||
prompt_emb_1 = self.text_encoder(input_ids, clip_skip=clip_skip)
|
||||
|
||||
# Text encoder 2
|
||||
input_ids_2 = tokenize_long_prompt(self.tokenizer_2, prompt).to(device)
|
||||
add_text_embeds, prompt_emb_2 = self.text_encoder_2(input_ids_2, clip_skip=clip_skip_2)
|
||||
|
||||
# Merge
|
||||
if prompt_emb_1.shape[0] != prompt_emb_2.shape[0]:
|
||||
max_batch_size = min(prompt_emb_1.shape[0], prompt_emb_2.shape[0])
|
||||
prompt_emb_1 = prompt_emb_1[: max_batch_size]
|
||||
prompt_emb_2 = prompt_emb_2[: max_batch_size]
|
||||
prompt_emb = torch.concatenate([prompt_emb_1, prompt_emb_2], dim=-1)
|
||||
|
||||
# For very long prompt, we only use the first 77 tokens to compute `add_text_embeds`.
|
||||
add_text_embeds = add_text_embeds[0:1]
|
||||
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
||||
return add_text_embeds, prompt_emb
|
||||
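Both `SDPrompter` and `SDXLPrompter` call `tokenize_long_prompt`, whose earlier implementation appears in the removed `utils.py` hunk of this diff: it rounds the token count up to a multiple of the tokenizer's `model_max_length` and reshapes the IDs into rows of that length. The sketch below shows the same rounding and reshaping on a plain list. The name `chunk_token_ids` and the pad ID are illustrative, and the real function pads through the tokenizer rather than by hand.

```python
from typing import List


def chunk_token_ids(token_ids: List[int], chunk_length: int, pad_id: int) -> List[List[int]]:
    """Pad `token_ids` to a multiple of `chunk_length` and split it into rows."""
    # Round the length up to the next multiple of chunk_length.
    padded_length = (len(token_ids) + chunk_length - 1) // chunk_length * chunk_length
    padded = token_ids + [pad_id] * (padded_length - len(token_ids))
    return [padded[i:i + chunk_length] for i in range(0, padded_length, chunk_length)]


# 100 made-up token IDs with a 77-token window give two rows.
rows = chunk_token_ids(list(range(100)), chunk_length=77, pad_id=0)
```

Each row is then encoded as a separate batch item, which is why the prompters reshape the result back to `(1, rows * length, dim)` and why `SDXLPrompter` keeps only `add_text_embeds[0:1]`.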
@@ -1,3 +0,0 @@
|
||||
from .sd_prompter import SDPrompter
|
||||
from .sdxl_prompter import SDXLPrompter
|
||||
from .hunyuan_dit_prompter import HunyuanDiTPrompter
|
||||
@@ -1,17 +0,0 @@
|
||||
from .utils import Prompter, tokenize_long_prompt
|
||||
from transformers import CLIPTokenizer
|
||||
from ..models import SDTextEncoder
|
||||
|
||||
|
||||
class SDPrompter(Prompter):
|
||||
def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"):
|
||||
super().__init__()
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
|
||||
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
|
||||
prompt_emb = text_encoder(input_ids, clip_skip=clip_skip)
|
||||
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
||||
|
||||
return prompt_emb
|
||||
@@ -1,43 +0,0 @@
|
||||
from .utils import Prompter, tokenize_long_prompt
|
||||
from transformers import CLIPTokenizer
|
||||
from ..models import SDXLTextEncoder, SDXLTextEncoder2
|
||||
import torch
|
||||
|
||||
|
||||
class SDXLPrompter(Prompter):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer_path="configs/stable_diffusion/tokenizer",
|
||||
tokenizer_2_path="configs/stable_diffusion_xl/tokenizer_2"
|
||||
):
|
||||
super().__init__()
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
text_encoder: SDXLTextEncoder,
|
||||
text_encoder_2: SDXLTextEncoder2,
|
||||
prompt,
|
||||
clip_skip=1,
|
||||
clip_skip_2=2,
|
||||
positive=True,
|
||||
device="cuda"
|
||||
):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
|
||||
# 1
|
||||
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
|
||||
prompt_emb_1 = text_encoder(input_ids, clip_skip=clip_skip)
|
||||
|
||||
# 2
|
||||
input_ids_2 = tokenize_long_prompt(self.tokenizer_2, prompt).to(device)
|
||||
add_text_embeds, prompt_emb_2 = text_encoder_2(input_ids_2, clip_skip=clip_skip_2)
|
||||
|
||||
# Merge
|
||||
prompt_emb = torch.concatenate([prompt_emb_1, prompt_emb_2], dim=-1)
|
||||
|
||||
# For very long prompt, we only use the first 77 tokens to compute `add_text_embeds`.
|
||||
add_text_embeds = add_text_embeds[0:1]
|
||||
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
||||
return add_text_embeds, prompt_emb
|
||||
@@ -1,123 +0,0 @@
|
||||
from transformers import CLIPTokenizer, AutoTokenizer
|
||||
from ..models import ModelManager
|
||||
import os
|
||||
|
||||
|
||||
def tokenize_long_prompt(tokenizer, prompt):
|
||||
# Get model_max_length from self.tokenizer
|
||||
length = tokenizer.model_max_length
|
||||
|
||||
# To avoid the warning. set self.tokenizer.model_max_length to +oo.
|
||||
tokenizer.model_max_length = 99999999
|
||||
|
||||
# Tokenize it!
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
||||
|
||||
# Determine the real length.
|
||||
max_length = (input_ids.shape[1] + length - 1) // length * length
|
||||
|
||||
# Restore tokenizer.model_max_length
|
||||
tokenizer.model_max_length = length
|
||||
|
||||
# Tokenize it again with fixed length.
|
||||
input_ids = tokenizer(
|
||||
prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True
|
||||
).input_ids
|
||||
|
||||
# Reshape input_ids to fit the text encoder.
|
||||
num_sentence = input_ids.shape[1] // length
|
||||
input_ids = input_ids.reshape((num_sentence, length))
|
||||
|
||||
return input_ids
|
||||
|
||||
|
||||
class BeautifulPrompt:
|
||||
def __init__(self, tokenizer_path="configs/beautiful_prompt/tokenizer", model=None):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
self.model = model
|
||||
self.template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'
|
||||
|
||||
def __call__(self, raw_prompt):
|
||||
model_input = self.template.format(raw_prompt=raw_prompt)
|
||||
input_ids = self.tokenizer.encode(model_input, return_tensors='pt').to(self.model.device)
|
||||
outputs = self.model.generate(
|
||||
input_ids,
|
||||
max_new_tokens=384,
|
||||
do_sample=True,
|
||||
temperature=0.9,
|
||||
top_k=50,
|
||||
top_p=0.95,
|
||||
            repetition_penalty=1.1,
            num_return_sequences=1
        )
        prompt = raw_prompt + ", " + self.tokenizer.batch_decode(
            outputs[:, input_ids.size(1):],
            skip_special_tokens=True
        )[0].strip()
        return prompt


class Translator:
    def __init__(self, tokenizer_path="configs/translator/tokenizer", model=None):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.model = model

    def __call__(self, prompt):
        input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.model.device)
        output_ids = self.model.generate(input_ids)
        prompt = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
        return prompt


class Prompter:
    def __init__(self):
        self.tokenizer: CLIPTokenizer = None
        self.keyword_dict = {}
        self.translator: Translator = None
        self.beautiful_prompt: BeautifulPrompt = None

    def load_textual_inversion(self, textual_inversion_dict):
        self.keyword_dict = {}
        additional_tokens = []
        for keyword in textual_inversion_dict:
            tokens, _ = textual_inversion_dict[keyword]
            additional_tokens += tokens
            self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
        self.tokenizer.add_tokens(additional_tokens)

    def load_beautiful_prompt(self, model, model_path):
        model_folder = os.path.dirname(model_path)
        self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model)
        if model_folder.endswith("v2"):
            self.beautiful_prompt.template = """Converts a simple image description into a prompt. \
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
but make sure there is a correlation between the input and output.\n\
### Input: {raw_prompt}\n### Output:"""

    def load_translator(self, model, model_path):
        model_folder = os.path.dirname(model_path)
        self.translator = Translator(tokenizer_path=model_folder, model=model)

    def load_from_model_manager(self, model_manager: ModelManager):
        self.load_textual_inversion(model_manager.textual_inversion_dict)
        if "translator" in model_manager.model:
            self.load_translator(model_manager.model["translator"], model_manager.model_path["translator"])
        if "beautiful_prompt" in model_manager.model:
            self.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])

    def process_prompt(self, prompt, positive=True):
        for keyword in self.keyword_dict:
            if keyword in prompt:
                prompt = prompt.replace(keyword, self.keyword_dict[keyword])
        if positive and self.translator is not None:
            prompt = self.translator(prompt)
            print(f"Your prompt is translated: \"{prompt}\"")
        if positive and self.beautiful_prompt is not None:
            prompt = self.beautiful_prompt(prompt)
            print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
        return prompt
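The keyword handling in `Prompter` above can be sketched without a tokenizer or model. This is a minimal illustration rather than the repository's code: `build_keyword_dict`, `expand_keywords`, and the `mystyle` tokens are hypothetical names introduced here, and the translator and BeautifulPrompt stages are left out.

```python
# Hypothetical stand-in for the keyword logic in Prompter; nothing is loaded here.
def build_keyword_dict(textual_inversion_dict):
    """Map each keyword to a space-padded string of its tokens."""
    keyword_dict = {}
    additional_tokens = []
    for keyword, (tokens, _embedding) in textual_inversion_dict.items():
        additional_tokens += tokens
        keyword_dict[keyword] = " " + " ".join(tokens) + " "
    return keyword_dict, additional_tokens


def expand_keywords(prompt, keyword_dict):
    """Replace every known keyword in the prompt with its token string."""
    for keyword, replacement in keyword_dict.items():
        if keyword in prompt:
            prompt = prompt.replace(keyword, replacement)
    return prompt


# Assumed input shape: keyword -> (list of token strings, embedding placeholder).
textual_inversion_dict = {"mystyle": (["mystyle_0", "mystyle_1"], None)}
keyword_dict, additional_tokens = build_keyword_dict(textual_inversion_dict)
expanded = expand_keywords("a cat, mystyle", keyword_dict)
print(repr(expanded))
```

The padding spaces around the joined tokens keep the inserted tokens from fusing with neighbouring words, at the cost of occasional double spaces in the expanded prompt.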
@@ -1,2 +1,3 @@
 from .ddim import EnhancedDDIMScheduler
 from .continuous_ode import ContinuousODEScheduler
+from .flow_match import FlowMatchScheduler
@@ -22,10 +22,10 @@ class EnhancedDDIMScheduler():
         max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0)
         num_inference_steps = min(num_inference_steps, max_timestep + 1)
         if num_inference_steps == 1:
-            self.timesteps = [max_timestep]
+            self.timesteps = torch.Tensor([max_timestep])
         else:
             step_length = max_timestep / (num_inference_steps - 1)
-            self.timesteps = [round(max_timestep - i*step_length) for i in range(num_inference_steps)]
+            self.timesteps = torch.Tensor([round(max_timestep - i*step_length) for i in range(num_inference_steps)])


     def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
@@ -43,31 +43,37 @@ class EnhancedDDIMScheduler():


     def step(self, model_output, timestep, sample, to_final=False):
-        alpha_prod_t = self.alphas_cumprod[timestep]
-        timestep_id = self.timesteps.index(timestep)
+        alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
+        if isinstance(timestep, torch.Tensor):
+            timestep = timestep.cpu()
+        timestep_id = torch.argmin((self.timesteps - timestep).abs())
         if to_final or timestep_id + 1 >= len(self.timesteps):
             alpha_prod_t_prev = 1.0
         else:
-            timestep_prev = self.timesteps[timestep_id + 1]
+            timestep_prev = int(self.timesteps[timestep_id + 1])
             alpha_prod_t_prev = self.alphas_cumprod[timestep_prev]

         return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev)


     def return_to_timestep(self, timestep, sample, sample_stablized):
-        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
         noise_pred = (sample - math.sqrt(alpha_prod_t) * sample_stablized) / math.sqrt(1 - alpha_prod_t)
         return noise_pred


     def add_noise(self, original_samples, noise, timestep):
-        sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[timestep])
-        sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[timestep])
+        sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
+        sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples


     def training_target(self, sample, noise, timestep):
-        sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[timestep])
-        sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[timestep])
-        target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
-        return target
+        if self.prediction_type == "epsilon":
+            return noise
+        else:
+            sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
+            sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
+            target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+            return target
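The hunks above replace exact list lookup (`self.timesteps.index(timestep)`) with a nearest-neighbour lookup on a tensor, and make `training_target` depend on the prediction type. The sketch below mirrors both ideas in plain Python so it runs without PyTorch. It is an illustration under assumptions: the function names, the scalar inputs, and the `"v_prediction"` label for the non-epsilon branch are introduced here and are not taken from the repository.

```python
import math


def make_timesteps(num_train_timesteps=1000, num_inference_steps=5, denoising_strength=1.0):
    """Evenly spaced, rounded timesteps from max_timestep down to 0 (as floats)."""
    max_timestep = max(round(num_train_timesteps * denoising_strength) - 1, 0)
    num_inference_steps = min(num_inference_steps, max_timestep + 1)
    if num_inference_steps == 1:
        return [float(max_timestep)]
    step_length = max_timestep / (num_inference_steps - 1)
    return [float(round(max_timestep - i * step_length)) for i in range(num_inference_steps)]


def nearest_timestep_id(timesteps, timestep):
    """Plain-Python counterpart of torch.argmin((timesteps - timestep).abs())."""
    return min(range(len(timesteps)), key=lambda i: abs(timesteps[i] - timestep))


def training_target(sample, noise, alpha_prod_t, prediction_type="epsilon"):
    """Return the noise for epsilon prediction, otherwise the velocity-style target."""
    if prediction_type == "epsilon":
        return noise
    return math.sqrt(alpha_prod_t) * noise - math.sqrt(1 - alpha_prod_t) * sample


timesteps = make_timesteps()
print(timesteps)
print(nearest_timestep_id(timesteps, 260.0))
print(training_target(2.0, 0.5, 0.25, prediction_type="v_prediction"))
```

Nearest-neighbour lookup tolerates a timestep that arrives as a float or has been moved between devices, where an exact `index` call on a list would raise `ValueError`.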
51
diffsynth/schedulers/flow_match.py
Normal file
@@ -0,0 +1,51 @@
import torch



class FlowMatchScheduler():

    def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002):
        self.num_train_timesteps = num_train_timesteps
        self.shift = shift
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        self.set_timesteps(num_inference_steps)


    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0):
        sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
        self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
        self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
        self.timesteps = self.sigmas * self.num_train_timesteps


    def step(self, model_output, timestep, sample, to_final=False):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        if to_final or timestep_id + 1 >= len(self.timesteps):
            sigma_ = 0
        else:
            sigma_ = self.sigmas[timestep_id + 1]
        prev_sample = sample + model_output * (sigma_ - sigma)
        return prev_sample


    def return_to_timestep(self, timestep, sample, sample_stablized):
        # This scheduler doesn't support this function.
        pass


    def add_noise(self, original_samples, noise, timestep):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        sample = (1 - sigma) * original_samples + sigma * noise
        return sample


    def training_target(self, sample, noise, timestep):
        target = noise - sample
        return target
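The arithmetic in `FlowMatchScheduler` above can be checked with scalars. The sketch below reimplements the shifted sigma schedule, `add_noise`, and the Euler update in plain Python, using the defaults from the diff. The function names and the sample values are assumptions made for illustration, and `torch.linspace` is replaced by a hand-written linear interpolation.

```python
def shifted_sigmas(num_inference_steps=5, shift=3.0, sigma_max=1.0,
                   sigma_min=0.003 / 1.002, denoising_strength=1.0):
    """Linear sigmas from sigma_start down to sigma_min, then the shift transform."""
    sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
    if num_inference_steps == 1:
        raw = [sigma_start]
    else:
        delta = (sigma_min - sigma_start) / (num_inference_steps - 1)
        raw = [sigma_start + i * delta for i in range(num_inference_steps)]
    return [shift * s / (1 + (shift - 1) * s) for s in raw]


def add_noise(original, noise, sigma):
    """Interpolate between the clean sample (sigma = 0) and pure noise (sigma = 1)."""
    return (1 - sigma) * original + sigma * noise


def euler_step(sample, model_output, sigma, sigma_next):
    """One update of the form sample + model_output * (sigma_next - sigma)."""
    return sample + model_output * (sigma_next - sigma)


sigmas = shifted_sigmas()
original, noise = 2.0, -1.0           # illustrative scalar "images"
sigma = sigmas[1]
noisy = add_noise(original, noise, sigma)
velocity = noise - original           # the training target from the diff
recovered = euler_step(noisy, velocity, sigma, 0.0)
print(sigmas)
print(recovered)
```

With a perfect velocity prediction, a single step to `sigma_next = 0` returns the clean sample, which is why `step` sets `sigma_` to 0 when `to_final` is true.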
0
diffsynth/tokenizer_configs/__init__.py
Normal file
BIN
diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model
Normal file
Binary file not shown.
@@ -0,0 +1,12 @@
{
  "name_or_path": "THUDM/chatglm3-6b-base",
  "remove_space": false,
  "do_lower_case": false,
  "tokenizer_class": "ChatGLMTokenizer",
  "auto_map": {
    "AutoTokenizer": [
      "tokenization_chatglm.ChatGLMTokenizer",
      null
    ]
  }
}
BIN
diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt
Normal file
Binary file not shown.
48895
diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt
Normal file
File diff suppressed because it is too large
@@ -0,0 +1,30 @@
{
  "bos_token": {
    "content": "<|startoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
@@ -0,0 +1,30 @@
{
  "add_prefix_space": false,
  "added_tokens_decoder": {
    "49406": {
      "content": "<|startoftext|>",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "49407": {
      "content": "<|endoftext|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "bos_token": "<|startoftext|>",
  "clean_up_tokenization_spaces": true,
  "do_lower_case": true,
  "eos_token": "<|endoftext|>",
  "errors": "replace",
  "model_max_length": 77,
  "pad_token": "<|endoftext|>",
  "tokenizer_class": "CLIPTokenizer",
  "unk_token": "<|endoftext|>"
}
49410
diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json
Normal file
File diff suppressed because it is too large
48895
diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt
Normal file
File diff suppressed because it is too large
@@ -0,0 +1,30 @@
{
  "bos_token": {
    "content": "<|startoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "!",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
@@ -0,0 +1,38 @@
{
  "add_prefix_space": false,
  "added_tokens_decoder": {
    "0": {
      "content": "!",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "49406": {
      "content": "<|startoftext|>",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "49407": {
      "content": "<|endoftext|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "bos_token": "<|startoftext|>",
  "clean_up_tokenization_spaces": true,
  "do_lower_case": true,
  "eos_token": "<|endoftext|>",
  "errors": "replace",
  "model_max_length": 77,
  "pad_token": "!",
  "tokenizer_class": "CLIPTokenizer",
  "unk_token": "<|endoftext|>"
}
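The two `CLIPTokenizer` configs above share `model_max_length` 77 and the same BOS/EOS ids (49406 and 49407), but differ in `pad_token`: the first pads with `<|endoftext|>` (id 49407), the second with `!` (id 0). The sketch below shows what that difference means for a padded id sequence. It is a simplified illustration, not the `transformers` implementation: `pad_ids` is a hypothetical helper, and the ids 320 and 1125 are placeholder token ids.

```python
BOS_ID, EOS_ID = 49406, 49407   # from added_tokens_decoder in both configs
MODEL_MAX_LENGTH = 77           # "model_max_length" in both configs


def pad_ids(token_ids, pad_id, max_length=MODEL_MAX_LENGTH):
    """Wrap ids in BOS/EOS, truncate to fit, then right-pad to max_length."""
    body = token_ids[: max_length - 2]
    ids = [BOS_ID] + body + [EOS_ID]
    return ids + [pad_id] * (max_length - len(ids))


prompt_ids = [320, 1125]                         # placeholder ids for a short prompt
padded_1 = pad_ids(prompt_ids, pad_id=EOS_ID)    # first config: pad with <|endoftext|>
padded_2 = pad_ids(prompt_ids, pad_id=0)         # second config: pad with "!"
print(padded_1[:6], padded_2[:6])
```

Both sequences have the same length and the same first four ids; only the padding tail differs, so the two text encoders see different ids after the end-of-text position.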
49410
diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json
Normal file
File diff suppressed because it is too large
@@ -0,0 +1,125 @@
{
  "additional_special_tokens": [
    "<extra_id_0>",
    "<extra_id_1>",
    "<extra_id_2>",
    "<extra_id_3>",
    "<extra_id_4>",
    "<extra_id_5>",
    "<extra_id_6>",
    "<extra_id_7>",
    "<extra_id_8>",
    "<extra_id_9>",
    "<extra_id_10>",
    "<extra_id_11>",
    "<extra_id_12>",
    "<extra_id_13>",
    "<extra_id_14>",
    "<extra_id_15>",
    "<extra_id_16>",
    "<extra_id_17>",
    "<extra_id_18>",
    "<extra_id_19>",
    "<extra_id_20>",
    "<extra_id_21>",
    "<extra_id_22>",
    "<extra_id_23>",
    "<extra_id_24>",
    "<extra_id_25>",
    "<extra_id_26>",
    "<extra_id_27>",
    "<extra_id_28>",
    "<extra_id_29>",
    "<extra_id_30>",
    "<extra_id_31>",
    "<extra_id_32>",
    "<extra_id_33>",
    "<extra_id_34>",
    "<extra_id_35>",
    "<extra_id_36>",
    "<extra_id_37>",
    "<extra_id_38>",
    "<extra_id_39>",
    "<extra_id_40>",
    "<extra_id_41>",
    "<extra_id_42>",
    "<extra_id_43>",
    "<extra_id_44>",
    "<extra_id_45>",
    "<extra_id_46>",
    "<extra_id_47>",
    "<extra_id_48>",
    "<extra_id_49>",
    "<extra_id_50>",
    "<extra_id_51>",
    "<extra_id_52>",
    "<extra_id_53>",
    "<extra_id_54>",
    "<extra_id_55>",
    "<extra_id_56>",
    "<extra_id_57>",
    "<extra_id_58>",
    "<extra_id_59>",
    "<extra_id_60>",
    "<extra_id_61>",
    "<extra_id_62>",
    "<extra_id_63>",
    "<extra_id_64>",
    "<extra_id_65>",
    "<extra_id_66>",
    "<extra_id_67>",
    "<extra_id_68>",
    "<extra_id_69>",
    "<extra_id_70>",
    "<extra_id_71>",
    "<extra_id_72>",
    "<extra_id_73>",
    "<extra_id_74>",
    "<extra_id_75>",
    "<extra_id_76>",
    "<extra_id_77>",
    "<extra_id_78>",
    "<extra_id_79>",
    "<extra_id_80>",
    "<extra_id_81>",
    "<extra_id_82>",
    "<extra_id_83>",
    "<extra_id_84>",
    "<extra_id_85>",
    "<extra_id_86>",
    "<extra_id_87>",
    "<extra_id_88>",
    "<extra_id_89>",
    "<extra_id_90>",
    "<extra_id_91>",
    "<extra_id_92>",
    "<extra_id_93>",
    "<extra_id_94>",
    "<extra_id_95>",
    "<extra_id_96>",
    "<extra_id_97>",
    "<extra_id_98>",
    "<extra_id_99>"
  ],
  "eos_token": {
    "content": "</s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "<pad>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<unk>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
Binary file not shown.
129428
diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json
Normal file
File diff suppressed because one or more lines are too long
@@ -0,0 +1,940 @@
|
||||
{
|
||||
"add_prefix_space": true,
|
||||
"added_tokens_decoder": {
|
||||
"0": {
|
||||
"content": "<pad>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"1": {
|
||||
"content": "</s>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"2": {
|
||||
"content": "<unk>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32000": {
|
||||
"content": "<extra_id_99>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32001": {
|
||||
"content": "<extra_id_98>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32002": {
|
||||
"content": "<extra_id_97>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32003": {
|
||||
"content": "<extra_id_96>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32004": {
|
||||
"content": "<extra_id_95>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32005": {
|
||||
"content": "<extra_id_94>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32006": {
|
||||
"content": "<extra_id_93>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32007": {
|
||||
"content": "<extra_id_92>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32008": {
|
||||
"content": "<extra_id_91>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32009": {
|
||||
"content": "<extra_id_90>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32010": {
|
||||
"content": "<extra_id_89>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32011": {
|
||||
"content": "<extra_id_88>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32012": {
|
||||
"content": "<extra_id_87>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32013": {
|
||||
"content": "<extra_id_86>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32014": {
|
||||
"content": "<extra_id_85>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32015": {
|
||||
"content": "<extra_id_84>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32016": {
|
||||
"content": "<extra_id_83>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32017": {
|
||||
"content": "<extra_id_82>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32018": {
|
||||
"content": "<extra_id_81>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32019": {
|
||||
"content": "<extra_id_80>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32020": {
|
||||
"content": "<extra_id_79>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32021": {
|
||||
"content": "<extra_id_78>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32022": {
|
||||
"content": "<extra_id_77>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32023": {
|
||||
"content": "<extra_id_76>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32024": {
|
||||
"content": "<extra_id_75>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32025": {
|
||||
"content": "<extra_id_74>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32026": {
|
||||
"content": "<extra_id_73>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32027": {
|
||||
"content": "<extra_id_72>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32028": {
|
||||
"content": "<extra_id_71>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32029": {
|
||||
"content": "<extra_id_70>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32030": {
|
||||
"content": "<extra_id_69>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32031": {
|
||||
"content": "<extra_id_68>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32032": {
|
||||
"content": "<extra_id_67>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32033": {
|
||||
"content": "<extra_id_66>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32034": {
|
||||
"content": "<extra_id_65>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32035": {
|
||||
"content": "<extra_id_64>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32036": {
|
||||
"content": "<extra_id_63>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32037": {
|
||||
"content": "<extra_id_62>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32038": {
|
||||
"content": "<extra_id_61>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32039": {
|
||||
"content": "<extra_id_60>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32040": {
|
||||
"content": "<extra_id_59>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32041": {
|
||||
"content": "<extra_id_58>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32042": {
|
||||
"content": "<extra_id_57>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32043": {
|
||||
"content": "<extra_id_56>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32044": {
|
||||
"content": "<extra_id_55>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32045": {
|
||||
"content": "<extra_id_54>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32046": {
|
||||
"content": "<extra_id_53>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32047": {
|
||||
"content": "<extra_id_52>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32048": {
|
||||
"content": "<extra_id_51>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32049": {
|
||||
"content": "<extra_id_50>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32050": {
|
||||
"content": "<extra_id_49>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32051": {
|
||||
"content": "<extra_id_48>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32052": {
|
||||
"content": "<extra_id_47>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32053": {
|
||||
"content": "<extra_id_46>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32054": {
|
||||
"content": "<extra_id_45>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32055": {
|
||||
"content": "<extra_id_44>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32056": {
|
||||
"content": "<extra_id_43>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32057": {
|
||||
"content": "<extra_id_42>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32058": {
|
||||
"content": "<extra_id_41>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32059": {
|
||||
"content": "<extra_id_40>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32060": {
|
||||
"content": "<extra_id_39>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32061": {
|
||||
"content": "<extra_id_38>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32062": {
|
||||
"content": "<extra_id_37>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32063": {
|
||||
"content": "<extra_id_36>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32064": {
|
||||
"content": "<extra_id_35>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32065": {
|
||||
"content": "<extra_id_34>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32066": {
|
||||
"content": "<extra_id_33>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32067": {
|
||||
"content": "<extra_id_32>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32068": {
|
||||
"content": "<extra_id_31>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32069": {
|
||||
"content": "<extra_id_30>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32070": {
|
||||
"content": "<extra_id_29>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32071": {
|
||||
"content": "<extra_id_28>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32072": {
|
||||
"content": "<extra_id_27>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32073": {
|
||||
"content": "<extra_id_26>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32074": {
|
||||
"content": "<extra_id_25>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32075": {
|
||||
"content": "<extra_id_24>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32076": {
|
||||
"content": "<extra_id_23>",
      "lstrip": true,
      "normalized": false,
      "rstrip": true,
      "single_word": false,
      "special": true
    },
    "32077": {
      "content": "<extra_id_22>",
      "lstrip": true,
      "normalized": false,
      "rstrip": true,
      "single_word": false,
      "special": true
    },
    "32078": {
      "content": "<extra_id_21>",
      "lstrip": true,
      "normalized": false,
      "rstrip": true,
      "single_word": false,
      "special": true
    },
    "32079": {
      "content": "<extra_id_20>",
      "lstrip": true,
      "normalized": false,
      "rstrip": true,
      "single_word": false,
      "special": true
    },
    "32080": {
      "content": "<extra_id_19>",
      "lstrip": true,
      "normalized": false,
      "rstrip": true,
      "single_word": false,
      "special": true
    },
    "32081": {
      "content": "<extra_id_18>",
      "lstrip": true,
      "normalized": false,
      "rstrip": true,
      "single_word": false,
      "special": true
    },
    "32082": {
      "content": "<extra_id_17>",
      "lstrip": true,
      "normalized": false,
      "rstrip": true,
      "single_word": false,
      "special": true
    },
    "32083": {
      "content": "<extra_id_16>",
      "lstrip": true,
      "normalized": false,
      "rstrip": true,
      "single_word": false,
      "special": true
    },
    "32084": {
      "content": "<extra_id_15>",
      "lstrip": true,
      "normalized": false,
      "rstrip": true,
      "single_word": false,
      "special": true
    },
    "32085": {
      "content": "<extra_id_14>",
      "lstrip": true,
      "normalized": false,
      "rstrip": true,
      "single_word": false,
      "special": true
    },
    "32086": {
      "content": "<extra_id_13>",
      "lstrip": true,
      "normalized": false,
      "rstrip": true,
      "single_word": false,
      "special": true
    },
    "32087": {
      "content": "<extra_id_12>",
      "lstrip": true,
      "normalized": false,
      "rstrip": true,
      "single_word": false,
      "special": true
    },
    "32088": {
      "content": "<extra_id_11>",
      "lstrip": true,
      "normalized": false,
      "rstrip": true,
      "single_word": false,
      "special": true
    },
    "32089": {
      "content": "<extra_id_10>",
      "lstrip": true,
      "normalized": false,
      "rstrip": true,
      "single_word": false,
      "special": true
    },
    "32090": {
      "content": "<extra_id_9>",
      "lstrip": true,
      "normalized": false,
      "rstrip": true,
      "single_word": false,
      "special": true
    },
    "32091": {
      "content": "<extra_id_8>",
      "lstrip": true,
      "normalized": false,
      "rstrip": true,
      "single_word": false,
      "special": true
    },
    "32092": {
      "content": "<extra_id_7>",
      "lstrip": true,
      "normalized": false,
      "rstrip": true,
      "single_word": false,
      "special": true
    },
    "32093": {
      "content": "<extra_id_6>",
      "lstrip": true,
      "normalized": false,
      "rstrip": true,
      "single_word": false,
      "special": true
    },
    "32094": {
      "content": "<extra_id_5>",
      "lstrip": true,
      "normalized": false,
      "rstrip": true,
      "single_word": false,
      "special": true
    },
    "32095": {
      "content": "<extra_id_4>",
      "lstrip": true,
      "normalized": false,
      "rstrip": true,
      "single_word": false,
      "special": true
    },
    "32096": {
      "content": "<extra_id_3>",
      "lstrip": true,
      "normalized": false,
      "rstrip": true,
      "single_word": false,
      "special": true
    },
    "32097": {
      "content": "<extra_id_2>",
      "lstrip": true,
      "normalized": false,
      "rstrip": true,
      "single_word": false,
      "special": true
    },
    "32098": {
      "content": "<extra_id_1>",
      "lstrip": true,
      "normalized": false,
      "rstrip": true,
      "single_word": false,
      "special": true
    },
    "32099": {
      "content": "<extra_id_0>",
      "lstrip": true,
      "normalized": false,
      "rstrip": true,
      "single_word": false,
      "special": true
    }
  },
  "additional_special_tokens": [
    "<extra_id_0>",
    "<extra_id_1>",
    "<extra_id_2>",
    "<extra_id_3>",
    "<extra_id_4>",
    "<extra_id_5>",
    "<extra_id_6>",
    "<extra_id_7>",
    "<extra_id_8>",
    "<extra_id_9>",
    "<extra_id_10>",
    "<extra_id_11>",
    "<extra_id_12>",
    "<extra_id_13>",
    "<extra_id_14>",
    "<extra_id_15>",
    "<extra_id_16>",
    "<extra_id_17>",
    "<extra_id_18>",
    "<extra_id_19>",
    "<extra_id_20>",
    "<extra_id_21>",
    "<extra_id_22>",
    "<extra_id_23>",
    "<extra_id_24>",
    "<extra_id_25>",
    "<extra_id_26>",
    "<extra_id_27>",
    "<extra_id_28>",
    "<extra_id_29>",
    "<extra_id_30>",
    "<extra_id_31>",
    "<extra_id_32>",
    "<extra_id_33>",
    "<extra_id_34>",
    "<extra_id_35>",
    "<extra_id_36>",
    "<extra_id_37>",
    "<extra_id_38>",
    "<extra_id_39>",
    "<extra_id_40>",
    "<extra_id_41>",
    "<extra_id_42>",
    "<extra_id_43>",
    "<extra_id_44>",
    "<extra_id_45>",
    "<extra_id_46>",
    "<extra_id_47>",
    "<extra_id_48>",
    "<extra_id_49>",
    "<extra_id_50>",
    "<extra_id_51>",
    "<extra_id_52>",
    "<extra_id_53>",
    "<extra_id_54>",
    "<extra_id_55>",
    "<extra_id_56>",
    "<extra_id_57>",
    "<extra_id_58>",
    "<extra_id_59>",
    "<extra_id_60>",
    "<extra_id_61>",
    "<extra_id_62>",
    "<extra_id_63>",
    "<extra_id_64>",
    "<extra_id_65>",
    "<extra_id_66>",
    "<extra_id_67>",
    "<extra_id_68>",
    "<extra_id_69>",
    "<extra_id_70>",
    "<extra_id_71>",
    "<extra_id_72>",
    "<extra_id_73>",
    "<extra_id_74>",
    "<extra_id_75>",
    "<extra_id_76>",
    "<extra_id_77>",
    "<extra_id_78>",
    "<extra_id_79>",
    "<extra_id_80>",
    "<extra_id_81>",
    "<extra_id_82>",
    "<extra_id_83>",
    "<extra_id_84>",
    "<extra_id_85>",
    "<extra_id_86>",
    "<extra_id_87>",
    "<extra_id_88>",
    "<extra_id_89>",
    "<extra_id_90>",
    "<extra_id_91>",
    "<extra_id_92>",
    "<extra_id_93>",
    "<extra_id_94>",
    "<extra_id_95>",
    "<extra_id_96>",
    "<extra_id_97>",
    "<extra_id_98>",
    "<extra_id_99>"
  ],
  "clean_up_tokenization_spaces": true,
  "eos_token": "</s>",
  "extra_ids": 100,
  "legacy": true,
  "model_max_length": 512,
  "pad_token": "<pad>",
  "sp_model_kwargs": {},
  "tokenizer_class": "T5Tokenizer",
  "unk_token": "<unk>"
}
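The sentinel entries in the config above follow a regular pattern: `<extra_id_0>` has id 32099 and each higher-numbered sentinel has the next lower id (for example, `<extra_id_22>` has id 32077). The sketch below rebuilds that mapping in plain Python without loading any tokenizer. The names `HIGHEST_SENTINEL_ID`, `sentinel_token_id`, and `sentinel_token` are hypothetical helpers introduced for illustration, and it is an assumption that the same linear pattern holds for every one of the 100 sentinels, including those not visible in this part of the diff.

```python
# Values read off the tokenizer config shown above.
HIGHEST_SENTINEL_ID = 32099  # id of "<extra_id_0>"
EXTRA_IDS = 100              # the "extra_ids" field


def sentinel_token(n: int) -> str:
    """Return the sentinel string for index n, e.g. 3 -> '<extra_id_3>'."""
    return f"<extra_id_{n}>"


def sentinel_token_id(n: int) -> int:
    """Return the id of '<extra_id_n>'.

    Assumption: ids decrease by one as n increases, as in the entries
    visible in the config (32099 -> 0, 32098 -> 1, ..., 32077 -> 22).
    """
    if not 0 <= n < EXTRA_IDS:
        raise ValueError(f"sentinel index must be in [0, {EXTRA_IDS}), got {n}")
    return HIGHEST_SENTINEL_ID - n


# Rebuild an id -> token mapping in the same shape as the
# "content" fields of "added_tokens_decoder".
added_tokens = {sentinel_token_id(n): sentinel_token(n) for n in range(EXTRA_IDS)}

if __name__ == "__main__":
    for token_id in (32077, 32090, 32099):
        print(token_id, added_tokens[token_id])
```

A mapping like this can serve as a quick sanity check against a loaded tokenizer, though the config file itself remains the source of truth.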
0
diffsynth/trainers/__init__.py
Normal file
Some files were not shown because too many files have changed in this diff.