Compare commits

..

67 Commits
zero3 ... doc

Author SHA1 Message Date
Qianyi Zhao
b13963382c Update FLUX.md 2024-11-08 00:44:06 -06:00
Qianyi Zhao
e11cf9e318 Update FLUX.md 2024-11-08 00:34:29 -06:00
Qianyi Zhao
4f93be2f5a Update StableDiffusion3.md 2024-11-08 00:11:51 -06:00
Qing112
c9c6be2201 Add ModelQuantization.md
Update Kolors.md
2024-11-07 19:20:19 +08:00
Qianyi Zhao
2b07df1c7a Update MultiControlnet.md 2024-11-06 00:07:50 -06:00
Qianyi Zhao
31161317e3 Update MultiControlnet.md 2024-11-05 23:52:12 -06:00
Qianyi Zhao
787813ab98 Update MultiControlnet.md 2024-11-05 23:51:29 -06:00
Qianyi Zhao
9fef3ee281 Update overview.md 2024-11-05 23:45:09 -06:00
Qianyi Zhao
b205513041 Update MultiControlnet.md 2024-11-05 23:40:24 -06:00
Qianyi Zhao
97dd398f17 Update overview.md 2024-11-05 23:33:24 -06:00
Qianyi Zhao
85ac23c0c3 Update introduction.md 2024-11-05 09:20:40 -06:00
Qianyi Zhao
b4073d2643 Update introduction.md 2024-11-05 09:12:28 -06:00
Qianyi Zhao
9583f16a43 Update introduction.md 2024-11-05 09:08:16 -06:00
Qianyi Zhao
b0633ac8bb Update introduction.md 2024-11-05 05:38:15 -06:00
Qianyi Zhao
9166a6742c Update introduction.md 2024-11-05 05:31:54 -06:00
Qianyi Zhao
10cfa6d711 Update introduction.md 2024-11-05 05:29:35 -06:00
Qianyi Zhao
b78ffbe09e Update introduction.md 2024-11-05 05:25:30 -06:00
Qianyi Zhao
64af33fe33 Update introduction.md 2024-11-05 03:45:57 -06:00
Qianyi Zhao
1180f450ca Update introduction.md 2024-11-05 03:45:33 -06:00
Qianyi Zhao
99726e02de Update MultiControlnet.md 2024-11-05 03:45:15 -06:00
Qianyi Zhao
e0c09ed53d Update overview.md 2024-11-05 03:44:09 -06:00
Qianyi Zhao
250ebf5c72 Update train_flux_lora.md 2024-11-05 03:40:33 -06:00
Artiprocher
47a2f86f7b update docs 2024-11-04 17:05:58 +08:00
Artiprocher
e2d9710d86 update docs 2024-11-04 16:23:03 +08:00
Artiprocher
384ea0dc69 update docs 2024-11-04 16:09:32 +08:00
Artiprocher
e0ef3eea60 update docs 2024-11-04 15:54:58 +08:00
Artiprocher
ac67acd235 update docs 2024-11-04 15:49:41 +08:00
Artiprocher
fe68a3d1bb update docs 2024-11-04 14:09:36 +08:00
Zhongjie Duan
deff4512f7 Merge pull request #256 from mi804/doc
AnimateDiff doc
2024-11-01 17:35:54 +08:00
张鸿
29efb1c828 AnimateDiff doc 2024-11-01 17:26:06 +08:00
Artiprocher
e833a31909 update doc index 2024-11-01 15:20:09 +08:00
Zhongjie Duan
b626d2aad7 Doc patch (#255)
* update multi-controlnet doc

* update multi-controlnet doc
2024-11-01 15:18:18 +08:00
Zhongjie Duan
0bc89f973e Merge pull request #254 from mi804/doc
IP-Adapter doc
2024-11-01 14:49:45 +08:00
张鸿
3eeaa1cd32 IP-Adapter doc 2024-11-01 14:44:12 +08:00
Artiprocher
4b25495921 update cogvideo doc 2024-11-01 14:40:31 +08:00
Artiprocher
ac2b187b9f update svd doc 2024-11-01 14:01:31 +08:00
Qianyi Zhao
eece711313 Update en index.rst
英文caption随便写了下,需要重新写
2024-10-28 06:09:08 -05:00
tc2000731
4e1cea64ad update HunyuanDiT.md Kolors.md 2024-10-25 17:32:29 +08:00
Qianyi Zhao
1a2ce26d37 Update StableDiffusionXL.md 2024-10-25 01:18:11 -05:00
Qianyi Zhao
b17a0297a2 Update StableDiffusion.md 2024-10-25 01:17:47 -05:00
Zhongjie Duan
9af2d08a33 Update index.rst 2024-10-23 10:12:05 +08:00
Zhongjie Duan
b2df73d033 Merge pull request #247 from mi804/doc
support controlnet cn doc
2024-10-22 14:53:27 +08:00
Zhongjie Duan
3514eba956 Merge pull request #246 from yrk111222/doc
Doc
2024-10-22 14:52:34 +08:00
mi804
fb0e5d1f38 support controlnet cn doc 2024-10-22 11:38:17 +08:00
yrk111222
b43cc35dd9 Create .readthedocs.yaml 2024-10-22 11:30:17 +08:00
yrk111222
34ca18a217 Create .readthedocs.yaml 2024-10-22 11:29:15 +08:00
Zhongjie Duan
550d780cd6 Merge pull request #245 from yrk111222/doc
Doc
2024-10-22 10:18:10 +08:00
Zhongjie Duan
ded2882e87 Merge branch 'doc' into doc 2024-10-22 10:17:18 +08:00
yrk111222
f6e676cdf9 Add files via upload
再改一次
2024-10-22 09:56:03 +08:00
yrk111222
157ba2e426 Delete docs directory 2024-10-22 09:54:59 +08:00
yrk111222
1a004ffe81 Add files via upload
上一次上传到docs文件夹中了,修改一下
2024-10-22 09:52:12 +08:00
yrk111222
70c4ff4121 Add files via upload
第三遍检查,一些明显的错误,像是格式之类的。
2024-10-22 09:45:58 +08:00
yrk111222
883d26abb4 Add files via upload
第一版翻译完成,保留了getStart目录,有一些名词还是需要重新检查
2024-10-18 18:02:52 +08:00
Artiprocher
105d4ffbc2 update docs 2024-10-18 15:38:12 +08:00
yrk111222
24b78148b8 Add files via upload
电脑更换,完成到D:\translate\DiffSynth-Studio\docs\source_en\finetune,该写第四个文档
2024-10-18 11:36:48 +08:00
Artiprocher
793062e141 update docs 2024-10-15 18:17:26 +08:00
Artiprocher
98f07f2435 update docs 2024-10-15 17:38:40 +08:00
Artiprocher
ca4b9c8bf4 update docs 2024-10-15 17:38:25 +08:00
Artiprocher
a2ab597eb0 update docs 2024-10-09 09:55:29 +08:00
Artiprocher
950fb486d6 update docs 2024-10-09 09:53:13 +08:00
Artiprocher
28b4a5313e update docs 2024-10-09 09:39:12 +08:00
Artiprocher
d9d37568a7 update docs 2024-10-08 21:20:36 +08:00
Artiprocher
55f1a10255 update docs 2024-10-08 21:18:35 +08:00
Artiprocher
677ecbf1d2 update docs 2024-09-19 09:47:54 +08:00
Artiprocher
5a06ac5e31 update docs 2024-09-11 21:07:01 +08:00
Artiprocher
41f58e2d41 update docs 2024-09-11 16:37:46 +08:00
Qianyi Zhao
7f6e35fe35 add docs (#201)
add docs.
2024-09-10 18:46:24 +08:00
907 changed files with 844648 additions and 53287 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 146 KiB

View File

@@ -20,9 +20,9 @@ jobs:
with:
python-version: '3.10'
- name: Install wheel
run: pip install wheel==0.44.0 && pip install -r requirements.txt
run: pip install wheel && pip install -r requirements.txt
- name: Build DiffSynth
run: python -m build
run: python setup.py sdist bdist_wheel
- name: Publish package to PyPI
run: |
pip install twine

175
.gitignore vendored
View File

@@ -1,175 +0,0 @@
/data
/models
/scripts
/diffusers
*.pkl
*.safetensors
*.pth
*.ckpt
*.pt
*.bin
*.DS_Store
*.msc
*.mv
log*.txt
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

845
README.md
View File

@@ -1,782 +1,209 @@
# DiffSynth-Studio
<a href="https://github.com/modelscope/DiffSynth-Studio"><img src=".github/workflows/logo.gif" title="Logo" style="max-width:100%;" width="55" /></a> <a href="https://trendshift.io/repositories/10946" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10946" alt="modelscope%2FDiffSynth-Studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a></p>
# DiffSynth Studio
[![PyPI](https://img.shields.io/pypi/v/DiffSynth)](https://pypi.org/project/DiffSynth/)
[![license](https://img.shields.io/github/license/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/blob/master/LICENSE)
[![open issues](https://isitmaintained.com/badge/open/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/issues)
[![GitHub pull-requests](https://img.shields.io/github/issues-pr/modelscope/DiffSynth-Studio.svg)](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
[![GitHub latest commit](https://badgen.net/github/last-commit/modelscope/DiffSynth-Studio)](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
[切换到中文版](./README_zh.md)
<p align="center">
<a href="https://trendshift.io/repositories/10946" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10946" alt="modelscope%2FDiffSynth-Studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
</p>
## Introduction
Welcome to the magical world of Diffusion models! DiffSynth-Studio is an open-source Diffusion model engine developed and maintained by the [ModelScope Community](https://www.modelscope.cn/). We hope to foster technological innovation through framework construction, aggregate the power of the open-source community, and explore the boundaries of generative model technology!
DiffSynth Studio is a Diffusion engine. We have restructured architectures including Text Encoder, UNet, VAE, among others, maintaining compatibility with models from the open-source community while enhancing computational performance. We provide many interesting features. Enjoy the magic of Diffusion models!
DiffSynth currently includes two open-source projects:
* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): Focused on aggressive technical exploration, targeting academia, and providing cutting-edge model capability support.
* [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): Focused on stable model deployment, targeting industry, and providing higher computational performance and more stable features.
Until now, DiffSynth Studio has supported the following models:
[DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) and [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) are the core engines of the ModelScope AIGC zone. Welcome to experience our carefully crafted productized features:
* [CogVideo](https://huggingface.co/THUDM/CogVideoX-5b)
* [FLUX](https://huggingface.co/black-forest-labs/FLUX.1-dev)
* [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
* [Kolors](https://huggingface.co/Kwai-Kolors/Kolors)
* [Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
* [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt)
* [Hunyuan-DiT](https://github.com/Tencent/HunyuanDiT)
* [RIFE](https://github.com/hzwer/ECCV2022-RIFE)
* [ESRGAN](https://github.com/xinntao/ESRGAN)
* [Ip-Adapter](https://github.com/tencent-ailab/IP-Adapter)
* [AnimateDiff](https://github.com/guoyww/animatediff/)
* [ControlNet](https://github.com/lllyasviel/ControlNet)
* [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
* ModelScope AIGC Zone (for Chinese users): https://modelscope.cn/aigc/home
* ModelScope Civision (for global users): https://modelscope.ai/civision/home
## News
> DiffSynth-Studio Documentation: [中文版](/docs/zh/README.md)、[English version](/docs/en/README.md)
We believe that a well-developed open-source code framework can lower the threshold for technical exploration. We have achieved many [interesting technologies](#innovative-achievements) based on this codebase. Perhaps you also have many wild ideas, and with DiffSynth-Studio, you can quickly realize these ideas. For this reason, we have prepared detailed documentation for developers. We hope that through these documents, developers can understand the principles of Diffusion models, and we look forward to expanding the boundaries of technology together with you.
## Update History
> DiffSynth-Studio has undergone major version updates, and some old features are no longer maintained. If you need to use old features, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) before the major version update.
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
- **January 19, 2026**: Added support for [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) and [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/FLUX2.md) and [example code](/examples/flux2/) are now available.
- **January 12, 2026**: We trained and open-sourced a text-guided image layer separation model ([Model Link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)). Given an input image and a textual description, the model isolates the image layer corresponding to the described content. For more details, please refer to our blog post ([Chinese version](https://modelscope.cn/learn/4938), [English version](https://huggingface.co/blog/kelseye/qwen-image-layered-control)).
- **December 24, 2025**: Based on Qwen-Image-Edit-2511, we trained an In-Context Editing LoRA model ([Model Link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)). This model takes three images as input (Image A, Image B, and Image C), and automatically analyzes the transformation from Image A to Image B, then applies the same transformation to Image C to generate Image D. For more details, please refer to our blog post ([Chinese version](https://mp.weixin.qq.com/s/41aEiN3lXKGCJs1-we4Q2g), [English version](https://huggingface.co/blog/kelseye/qwen-image-edit-2511-icedit-lora)).
- **December 9, 2025** We release a wild model based on DiffSynth-Studio 2.0: [Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L) (Image-to-LoRA). This model takes an image as input and outputs a LoRA. Although this version still has significant room for improvement in terms of generalization, detail preservation, and other aspects, we are open-sourcing these models to inspire more innovative research. For more details, please refer to our [blog](https://huggingface.co/blog/kelseye/qwen-image-i2l).
- **December 4, 2025** DiffSynth-Studio 2.0 released! Many new features online
- [Documentation](/docs/en/README.md) online: Our documentation is still continuously being optimized and updated
- [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md) module upgraded, supporting layer-level disk offload, releasing both memory and VRAM simultaneously
- New model support
- Z-Image Turbo: [Model](https://www.modelscope.ai/models/Tongyi-MAI/Z-Image-Turbo), [Documentation](/docs/en/Model_Details/Z-Image.md), [Code](/examples/z_image/)
- FLUX.2-dev: [Model](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev), [Documentation](/docs/en/Model_Details/FLUX2.md), [Code](/examples/flux2/)
- Training framework upgrade
- [Split Training](/docs/zh/Training/Split_Training.md): Supports automatically splitting the training process into two stages: data processing and training (even for training ControlNet or any other model). Computations that do not require gradient backpropagation, such as text encoding and VAE encoding, are performed during the data processing stage, while other computations are handled during the training stage. Faster speed, less VRAM requirement.
- [Differential LoRA Training](/docs/zh/Training/Differential_LoRA.md): This is a training technique we used in [ArtAug](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), now available for LoRA training of any model.
- [FP8 Training](/docs/zh/Training/FP8_Precision.md): FP8 can be applied to any non-training model during training, i.e., models with gradients turned off or gradients that only affect LoRA weights.
<details>
<summary>More</summary>
- **November 4, 2025** Supported the [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) model, which is trained based on Wan 2.1 and supports generating corresponding actions based on reference videos.
- **October 30, 2025** Supported the [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) model, which supports text-to-video, image-to-video, and video continuation. This model uses the Wan framework for inference and training in this project.
- **October 27, 2025** Supported the [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) model, adding another member to the Wan model ecosystem.
- **September 23, 2025** [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) released! This model was jointly developed and open-sourced by us and Taobao Experience Design Team. Built upon Qwen-Image, the model is specifically designed for e-commerce poster scenarios, supporting precise partition layout control. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py).
- **September 9, 2025** Our training framework supports various training modes. Currently adapted for Qwen-Image, in addition to the standard SFT training mode, Direct Distill is now supported. Please refer to [our sample code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh). This feature is experimental, and we will continue to improve it to support more comprehensive model training functions.
- **August 28, 2025** We support Wan2.2-S2V, an audio-driven cinematic video generation model. See [./examples/wanvideo/](./examples/wanvideo/).
- **August 21, 2025** [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) released! Compared to the V1 version, the training dataset has been changed to [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset), so the generated images better conform to Qwen-Image's own image distribution and style. Please refer to [our sample code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py).
- **August 21, 2025** We open-sourced the [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) structural control LoRA model, adopting the In Context technical route, supporting multiple categories of structural control conditions, including canny, depth, lineart, softedge, normal, and openpose. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py).
- **August 20, 2025** We open-sourced the [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) model, improving the editing effect of Qwen-Image-Edit on low-resolution image inputs. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)
- **August 19, 2025** 🔥 Qwen-Image-Edit open-sourced, welcome a new member to the image editing model family!
- **August 18, 2025** We trained and open-sourced the Qwen-Image inpainting ControlNet model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint). The model structure adopts a lightweight design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py).
- **August 15, 2025** We open-sourced the [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset) dataset. This is an image dataset generated using the Qwen-Image model, containing 160,000 `1024 x 1024` images. It includes general, English text rendering, and Chinese text rendering subsets. We provide annotations for image descriptions, entities, and structural control images for each image. Developers can use this dataset to train Qwen-Image models' ControlNet and EliGen models. We aim to promote technological development through open-sourcing!
- **August 13, 2025** We trained and open-sourced the Qwen-Image ControlNet model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth). The model structure adopts a lightweight design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py).
- **August 12, 2025** We trained and open-sourced the Qwen-Image ControlNet model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny). The model structure adopts a lightweight design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py).
- **August 11, 2025** We open-sourced the distilled acceleration model [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) for Qwen-Image, following the same training process as [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full), but the model structure has been modified to LoRA, thus being better compatible with other open-source ecosystem models.
- **August 7, 2025** We open-sourced the entity control LoRA model [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) for Qwen-Image. Qwen-Image-EliGen can achieve entity-level controlled text-to-image generation. Technical details can be found in [the paper](https://arxiv.org/abs/2501.01097). Training dataset: [EliGenTrainSet](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet).
- **August 5, 2025** We open-sourced the distilled acceleration model [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) for Qwen-Image, achieving approximately 5x acceleration.
- **August 4, 2025** 🔥 Qwen-Image open-sourced, welcome a new member to the image generation model family!
- **August 1, 2025** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) open-sourced, a text-to-image model focused on aesthetic photography. We provided comprehensive support in a timely manner, including low VRAM layer-by-layer offload, LoRA training, and full training. For more details, please refer to [./examples/flux/](./examples/flux/).
- **July 28, 2025** Wan 2.2 open-sourced. We provided comprehensive support in a timely manner, including low VRAM layer-by-layer offload, FP8 quantization, sequence parallelism, LoRA training, and full training. For more details, please refer to [./examples/wanvideo/](./examples/wanvideo/).
- **July 11, 2025** We propose Nexus-Gen, a unified framework that combines the language reasoning capabilities of Large Language Models (LLMs) with the image generation capabilities of diffusion models. This framework supports seamless image understanding, generation, and editing tasks.
- Paper: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
- GitHub Repository: https://github.com/modelscope/Nexus-Gen
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
- Training Dataset: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
- Online Experience: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
- **June 15, 2025** ModelScope's official evaluation framework [EvalScope](https://github.com/modelscope/evalscope) now supports text-to-image generation evaluation. Please refer to the [best practices](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html) guide to try it out.
- **March 25, 2025** Our new open-source project [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) is now open-sourced! Focused on stable model deployment, targeting industry, providing better engineering support, higher computational performance, and more stable features.
- **March 31, 2025** We support InfiniteYou, a face feature preservation method for FLUX. More details can be found in [./examples/InfiniteYou/](./examples/InfiniteYou/).
- **March 13, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of Tencent's open-source HunyuanVideo. More details can be found in [./examples/HunyuanVideo/](./examples/HunyuanVideo/).
- **February 25, 2025** We support Wan-Video, a series of state-of-the-art video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).
- **February 17, 2025** We support [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary)! Advanced video synthesis model! See [./examples/stepvideo](./examples/stepvideo/).
- **December 31, 2024** We propose EliGen, a new framework for entity-level controlled text-to-image generation, supplemented with an inpainting fusion pipeline, extending its capabilities to image inpainting tasks. EliGen can seamlessly integrate existing community models such as IP-Adapter and In-Context LoRA, enhancing their versatility. For more details, see [./examples/EntityControl](./examples/EntityControl/).
- Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
- Online Experience: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
- Training Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
- **December 19, 2024** We implemented advanced VRAM management for HunyuanVideo, enabling video generation with resolutions of 129x720x1280 on 24GB VRAM or 129x512x384 on just 6GB VRAM. More details can be found in [./examples/HunyuanVideo/](./examples/HunyuanVideo/).
- **December 18, 2024** We propose ArtAug, a method to improve text-to-image models through synthesis-understanding interaction. We trained an ArtAug enhancement module for FLUX.1-dev in LoRA format. This model incorporates the aesthetic understanding of Qwen2-VL-72B into FLUX.1-dev, thereby improving the quality of generated images.
- Paper: https://arxiv.org/abs/2412.12888
- Example: https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/ArtAug
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
- Demo: [ModelScope](https://modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0), HuggingFace (coming soon)
- **October 25, 2024** We provide extensive FLUX ControlNet support. This project supports many different ControlNet models and can be freely combined, even if their structures are different. Additionally, ControlNet models are compatible with high-resolution optimization and partition control technologies, enabling very powerful controllable image generation. See [`./examples/ControlNet/`](./examples/ControlNet/).
- **October 8, 2024** We released extended LoRAs based on CogVideoX-5B and ExVideo. You can download this model from [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) or [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1).
- **August 22, 2024** This project now supports CogVideoX-5B. See [here](/examples/video_synthesis/). We provide several interesting features for this text-to-video model, including:
- Text-to-video
- **August 22, 2024.** CogVideoX-5B is supported in this project. See [here](/examples/video_synthesis/). We provide several interesting features for this text-to-video model, including
- Text to video
- Video editing
- Self super-resolution
- Self-upscaling
- Video interpolation
- **August 22, 2024** We implemented an interesting brush feature that supports all text-to-image models. Now you can create stunning images with the assistance of AI using the brush!
- **August 22, 2024.** We have implemented an interesting painter that supports all text-to-image models. Now you can create stunning images using the painter, with assistance from AI!
- Use it in our [WebUI](#usage-in-webui).
- **August 21, 2024** DiffSynth-Studio now supports FLUX.
- Enable CFG and high-resolution inpainting to improve visual quality. See [here](/examples/image_synthesis/README.md)
- LoRA, ControlNet, and other addon models will be released soon.
- **August 21, 2024.** FLUX is supported in DiffSynth-Studio.
- Enable CFG and highres-fix to improve visual quality. See [here](/examples/image_synthesis/README.md)
- LoRA, ControlNet, and additional models will be available soon.
- **June 21, 2024** We propose ExVideo, a post-training fine-tuning technique aimed at enhancing the capabilities of video generation models. We extended Stable Video Diffusion to achieve long video generation of up to 128 frames.
- **June 21, 2024.** 🔥🔥🔥 We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
- [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
- Source code has been released in this repository. See [`examples/ExVideo`](./examples/ExVideo/).
- Model has been released at [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
- Technical report has been released at [arXiv](https://arxiv.org/abs/2406.14130).
- You can try ExVideo in this [demo](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1)!
- Source code is released in this repo. See [`examples/ExVideo`](./examples/ExVideo/).
- Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
- Technical report is released on [arXiv](https://arxiv.org/abs/2406.14130).
- You can try ExVideo in this [Demo](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1)!
- **June 13, 2024** DiffSynth Studio has migrated to ModelScope. The development team has also transitioned from "me" to "us". Of course, I will still participate in subsequent development and maintenance work.
- **June 13, 2024.** DiffSynth Studio is transferred to ModelScope. The developers have transitioned from "I" to "we". Of course, I will still participate in development and maintenance.
- **January 29, 2024** We propose Diffutoon, an excellent cartoon coloring solution.
- **Jan 29, 2024.** We propose Diffutoon, a fantastic solution for toon shading.
- [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
- Source code has been released in this project.
- Technical report (IJCAI 2024) has been released at [arXiv](https://arxiv.org/abs/2401.16224).
- The source codes are released in this project.
- The technical report (IJCAI 2024) is released on [arXiv](https://arxiv.org/abs/2401.16224).
- **December 8, 2023** We decided to initiate a new project aimed at unleashing the potential of diffusion models, especially in video synthesis. The development work of this project officially began.
- **Dec 8, 2023.** We decide to develop a new Project, aiming to release the potential of diffusion models, especially in video synthesis. The development of this project is started.
- **November 15, 2023** We propose FastBlend, a powerful video deflickering algorithm.
- sd-webui extension has been released at [GitHub](https://github.com/Artiprocher/sd-webui-fastblend).
- Demonstration videos have been showcased on Bilibili, including three tasks:
- [Video Deflickering](https://www.bilibili.com/video/BV1d94y1W7PE)
- [Video Interpolation](https://www.bilibili.com/video/BV1Lw411m71p)
- [Image-Driven Video Rendering](https://www.bilibili.com/video/BV1RB4y1Z7LF)
- Technical report has been released at [arXiv](https://arxiv.org/abs/2311.09265).
- Unofficial ComfyUI extensions developed by other users have been released at [GitHub](https://github.com/AInseven/ComfyUI-fastblend).
- **Nov 15, 2023.** We propose FastBlend, a powerful video deflickering algorithm.
- The sd-webui extension is released on [GitHub](https://github.com/Artiprocher/sd-webui-fastblend).
- Demo videos are shown on Bilibili, including three tasks.
- [Video deflickering](https://www.bilibili.com/video/BV1d94y1W7PE)
- [Video interpolation](https://www.bilibili.com/video/BV1Lw411m71p)
- [Image-driven video rendering](https://www.bilibili.com/video/BV1RB4y1Z7LF)
- The technical report is released on [arXiv](https://arxiv.org/abs/2311.09265).
- An unofficial ComfyUI extension developed by other users is released on [GitHub](https://github.com/AInseven/ComfyUI-fastblend).
- **October 1, 2023** We released an early version of the project named FastSDXL. This was an initial attempt to build a diffusion engine.
- Source code has been released at [GitHub](https://github.com/Artiprocher/FastSDXL).
- FastSDXL includes a trainable OLSS scheduler to improve efficiency.
- The original repository of OLSS is located [here](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler).
- Technical report (CIKM 2023) has been released at [arXiv](https://arxiv.org/abs/2305.14677).
- Demonstration video has been released at [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj).
- Since OLSS requires additional training, we did not implement it in this project.
- **Oct 1, 2023.** We release an early version of this project, namely FastSDXL. A try for building a diffusion engine.
- The source codes are released on [GitHub](https://github.com/Artiprocher/FastSDXL).
- FastSDXL includes a trainable OLSS scheduler for efficiency improvement.
- The original repo of OLSS is [here](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler).
- The technical report (CIKM 2023) is released on [arXiv](https://arxiv.org/abs/2305.14677).
- A demo video is shown on [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj).
- Since OLSS requires additional training, we don't implement it in this project.
- **August 29, 2023** We propose DiffSynth, a video synthesis framework.
- **Aug 29, 2023.** We propose DiffSynth, a video synthesis framework.
- [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/).
- Source code has been released at [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth).
- Technical report (ECML PKDD 2024) has been released at [arXiv](https://arxiv.org/abs/2308.03463).
- The source codes are released in [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth).
- The technical report (ECML PKDD 2024) is released on [arXiv](https://arxiv.org/abs/2308.03463).
</details>
## Installation
Install from source (recommended):
Install from source code (recommended):
```
git clone https://github.com/modelscope/DiffSynth-Studio.git
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
For more installation methods and instructions for non-NVIDIA GPUs, please refer to the [Installation Guide](/docs/en/Pipeline_Usage/Setup.md).
Or install from pypi:
</details>
```
pip install diffsynth
```
## Basic Framework
## Usage (in Python code)
DiffSynth-Studio redesigns the inference and training pipelines for mainstream Diffusion models (including FLUX, Wan, etc.), enabling efficient memory management and flexible model training.
The Python examples are in [`examples`](./examples/). We provide an overview here.
<details>
<summary>Environment Variable Configuration</summary>
### Download Models
> Before running model inference or training, you can configure settings such as the model download source via [environment variables](/docs/en/Pipeline_Usage/Environment_Variables.md).
>
> By default, this project downloads models from ModelScope. For users outside China, you can configure the system to download models from the ModelScope international site as follows:
>
> ```python
> import os
> os.environ["MODELSCOPE_DOMAIN"] = "www.modelscope.ai"
> ```
>
> To download models from other sources, please modify the environment variable [DIFFSYNTH_DOWNLOAD_SOURCE](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_download_source).
</details>
### Image Synthesis
![Image](https://github.com/user-attachments/assets/c01258e2-f251-441a-aa1e-ebb22f02594d)
#### Z-Image: [/docs/en/Model_Details/Z-Image.md](/docs/en/Model_Details/Z-Image.md)
<details>
<summary>Quick Start</summary>
Running the following code will quickly load the [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) model for inference. FP8 quantization significantly degrades image quality, so we do not recommend enabling any quantization for the Z-Image Turbo model. CPU offloading is recommended, and the model can run with as little as 8 GB of GPU memory.
Download the pre-set models. Model IDs can be found in [config file](/diffsynth/configs/model_config.py).
```python
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
import torch
from diffsynth import download_models
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = ZImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors", **vram_config),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
image = pipe(prompt=prompt, seed=42, rand_device="cuda")
image.save("image.jpg")
download_models(["FLUX.1-dev", "Kolors"])
```
</details>
<details>
<summary>Examples</summary>
Example code for Z-Image is available at: [/examples/z_image/](/examples/z_image/)
| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
</details>
#### FLUX.2: [/docs/en/Model_Details/FLUX2.md](/docs/en/Model_Details/FLUX2.md)
<details>
<summary>Quick Start</summary>
Running the following code will quickly load the [black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev) model for inference. VRAM management is enabled, and the framework automatically loads model parameters based on available GPU memory. The model can run with as little as 10 GB of VRAM.
Download your own models.
```python
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
from diffsynth.models.downloader import download_from_huggingface, download_from_modelscope
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.float8_e4m3fn,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "High resolution. A dreamy underwater portrait of a serene young woman in a flowing blue dress. Her hair floats softly around her face, strands delicately suspended in the water. Clear, shimmering light filters through, casting gentle highlights, while tiny bubbles rise around her. Her expression is calm, her features finely detailed—creating a tranquil, ethereal scene."
image = pipe(prompt, seed=42, rand_device="cuda", num_inference_steps=50)
image.save("image.jpg")
# From Modelscope (recommended)
download_from_modelscope("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.fp16.bin", "models/kolors/Kolors/vae")
# From Huggingface
download_from_huggingface("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.fp16.safetensors", "models/kolors/Kolors/vae")
```
</details>
<details>
<summary>Examples</summary>
Example code for FLUX.2 is available at: [/examples/flux2/](/examples/flux2/)
| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
</details>
#### Qwen-Image: [/docs/en/Model_Details/Qwen-Image.md](/docs/en/Model_Details/Qwen-Image.md)
<details>
<summary>Quick Start</summary>
Running the following code will quickly load the [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) model for inference. VRAM management is enabled, and the framework automatically adjusts model parameter loading based on available GPU memory. The model can run with as little as 8 GB of VRAM.
```python
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.float8_e4m3fn,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(prompt, seed=0, num_inference_steps=40)
image.save("image.jpg")
```
</details>
<details>
<summary>Model Lineage</summary>
```mermaid
graph LR;
Qwen/Qwen-Image-->Qwen/Qwen-Image-Edit;
Qwen/Qwen-Image-Edit-->Qwen/Qwen-Image-Edit-2509;
Qwen/Qwen-Image-->EliGen-Series;
EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen;
DiffSynth-Studio/Qwen-Image-EliGen-->DiffSynth-Studio/Qwen-Image-EliGen-V2;
EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen-Poster;
Qwen/Qwen-Image-->Distill-Series;
Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-Full;
Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-LoRA;
Qwen/Qwen-Image-->ControlNet-Series;
ControlNet-Series-->Blockwise-ControlNet-Series;
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny;
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth;
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint;
ControlNet-Series-->DiffSynth-Studio/Qwen-Image-In-Context-Control-Union;
Qwen/Qwen-Image-->DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix;
```
</details>
<details>
<summary>Examples</summary>
Example code for Qwen-Image is available at: [/examples/qwen_image/](/examples/qwen_image/)
| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-|
</details>
#### FLUX.1: [/docs/en/Model_Details/FLUX.md](/docs/en/Model_Details/FLUX.md)
<details>
<summary>Quick Start</summary>
Running the following code will quickly load the [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) model for inference. VRAM management is enabled, and the framework automatically adjusts model parameter loading based on available GPU memory. The model can run with as little as 8 GB of VRAM.
```python
import torch
from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig
vram_config = {
"offload_dtype": torch.float8_e4m3fn,
"offload_device": "cpu",
"onload_dtype": torch.float8_e4m3fn,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config),
],
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 1,
)
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
image = pipe(prompt=prompt, seed=0)
image.save("image.jpg")
```
</details>
<details>
<summary>Model Lineage</summary>
```mermaid
graph LR;
FLUX.1-Series-->black-forest-labs/FLUX.1-dev;
FLUX.1-Series-->black-forest-labs/FLUX.1-Krea-dev;
FLUX.1-Series-->black-forest-labs/FLUX.1-Kontext-dev;
black-forest-labs/FLUX.1-dev-->FLUX.1-dev-ControlNet-Series;
FLUX.1-dev-ControlNet-Series-->alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta;
FLUX.1-dev-ControlNet-Series-->InstantX/FLUX.1-dev-Controlnet-Union-alpha;
FLUX.1-dev-ControlNet-Series-->jasperai/Flux.1-dev-Controlnet-Upscaler;
black-forest-labs/FLUX.1-dev-->InstantX/FLUX.1-dev-IP-Adapter;
black-forest-labs/FLUX.1-dev-->ByteDance/InfiniteYou;
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Eligen;
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev;
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev;
black-forest-labs/FLUX.1-dev-->ostris/Flex.2-preview;
black-forest-labs/FLUX.1-dev-->stepfun-ai/Step1X-Edit;
Qwen/Qwen2.5-VL-7B-Instruct-->stepfun-ai/Step1X-Edit;
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Nexus-GenV2;
Qwen/Qwen2.5-VL-7B-Instruct-->DiffSynth-Studio/Nexus-GenV2;
```
</details>
<details>
<summary>Examples</summary>
Example code for FLUX.1 is available at: [/examples/flux/](/examples/flux/)
| Model ID | Extra Args | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|-|
|[black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py)|
|[black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](/examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)|
|[black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)|
|[alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|
|[InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|
|[jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|
|[InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|
|[ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
|[DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)|
|[DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
|[DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
|[stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](/examples/flux/model_inference/Step1X-Edit.py)|[code](/examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](/examples/flux/model_training/full/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](/examples/flux/model_training/lora/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_lora/Step1X-Edit.py)|
|[ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](/examples/flux/model_inference/FLEX.2-preview.py)|[code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](/examples/flux/model_training/full/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](/examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py)|
|[DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](/examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](/examples/flux/model_training/full/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](/examples/flux/model_training/lora/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_lora/Nexus-Gen.py)|
</details>
### Video Synthesis
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
#### Text-to-video using CogVideoX-5B
#### Wan: [/docs/en/Model_Details/Wan.md](/docs/en/Model_Details/Wan.md)
CogVideoX-5B is released by ZhiPu. We provide an improved pipeline, supporting text-to-video, video editing, self-upscaling and video interpolation. [`examples/video_synthesis`](./examples/video_synthesis/)
<details>
The video on the left is generated using the original text-to-video pipeline, while the video on the right is the result after editing and frame interpolation.
<summary>Quick Start</summary>
https://github.com/user-attachments/assets/26b044c1-4a60-44a4-842f-627ff289d006
Running the following code will quickly load the [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) model for inference. VRAM management is enabled, and the framework automatically adjusts model parameter loading based on available GPU memory. The model can run with as little as 8 GB of VRAM.
#### Long Video Synthesis
```python
import torch
from diffsynth.utils.data import save_video, VideoData
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config),
],
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
)
video = pipe(
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=0, tiled=True,
)
save_video(video, "video.mp4", fps=15, quality=5)
```
</details>
<details>
<summary>Model Lineage</summary>
```mermaid
graph LR;
Wan-Series-->Wan2.1-Series;
Wan-Series-->Wan2.2-Series;
Wan2.1-Series-->Wan-AI/Wan2.1-T2V-1.3B;
Wan2.1-Series-->Wan-AI/Wan2.1-T2V-14B;
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-I2V-14B-480P;
Wan-AI/Wan2.1-I2V-14B-480P-->Wan-AI/Wan2.1-I2V-14B-720P;
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-FLF2V-14B-720P;
Wan-AI/Wan2.1-T2V-1.3B-->iic/VACE-Wan2.1-1.3B-Preview;
iic/VACE-Wan2.1-1.3B-Preview-->Wan-AI/Wan2.1-VACE-1.3B;
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-VACE-14B;
Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-1.3B-Series;
Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-InP;
Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-Control;
Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-14B-Series;
Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-InP;
Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-Control;
Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-V1.1-1.3B-Series;
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control;
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-InP;
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera;
Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-V1.1-14B-Series;
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control;
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-InP;
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control-Camera;
Wan-AI/Wan2.1-T2V-1.3B-->DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1;
Wan-AI/Wan2.1-T2V-14B-->krea/krea-realtime-video;
Wan-AI/Wan2.1-T2V-14B-->meituan-longcat/LongCat-Video;
Wan-AI/Wan2.1-I2V-14B-720P-->ByteDance/Video-As-Prompt-Wan2.1-14B;
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-Animate-14B;
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-S2V-14B;
Wan2.2-Series-->Wan-AI/Wan2.2-T2V-A14B;
Wan2.2-Series-->Wan-AI/Wan2.2-I2V-A14B;
Wan2.2-Series-->Wan-AI/Wan2.2-TI2V-5B;
Wan-AI/Wan2.2-T2V-A14B-->Wan2.2-Fun-Series;
Wan2.2-Fun-Series-->PAI/Wan2.2-VACE-Fun-A14B;
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-InP;
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control;
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control-Camera;
```
</details>
<details>
<summary>Examples</summary>
Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/)
| Model ID | Extra Args | Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|
|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](/examples/wanvideo/model_inference/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|
|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](/examples/wanvideo/model_inference/LongCat-Video.py)|[code](/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|
|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|
|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
</details>
## Innovative Achievements
DiffSynth-Studio is not just an engineered model framework, but also an incubator for innovative achievements.
<details>
<summary>AttriCtrl: Attribute Intensity Control for Image Generation Models</summary>
- Paper: [AttriCtrl: Fine-Grained Control of Aesthetic Attribute Intensity in Diffusion Models](https://arxiv.org/abs/2508.02151)
- Sample Code: [/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py](/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py)
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev)
|brightness scale = 0.1|brightness scale = 0.3|brightness scale = 0.5|brightness scale = 0.7|brightness scale = 0.9|
|-|-|-|-|-|
|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.5.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.7.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.9.jpg)|
</details>
<details>
<summary>AutoLoRA: Automated LoRA Retrieval and Fusion</summary>
- Paper: [AutoLoRA: Automatic LoRA Retrieval and Fine-Grained Gated Fusion for Text-to-Image Generation](https://arxiv.org/abs/2508.02107)
- Sample Code: [/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)
||[LoRA 1](https://modelscope.cn/models/cancel13/cxsk)|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2)|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL)|
|-|-|-|-|-|
|[LoRA 1](https://modelscope.cn/models/cancel13/cxsk) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_0.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)|
|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)|
|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)|
|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_3_3.jpg)|
</details>
<details>
<summary>Nexus-Gen: Unified Architecture for Image Understanding, Generation, and Editing</summary>
- Detailed Page: https://github.com/modelscope/Nexus-Gen
- Paper: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
- Dataset: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
- Online Experience: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
![](https://github.com/modelscope/Nexus-Gen/raw/main/assets/illustrations/gen_edit.jpg)
</details>
<details>
<summary>ArtAug: Aesthetic Enhancement for Image Generation Models</summary>
- Detailed Page: [./examples/ArtAug/](./examples/ArtAug/)
- Paper: [ArtAug: Enhancing Text-to-Image Generation through Synthesis-Understanding Interaction](https://arxiv.org/abs/2412.12888)
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
- Online Experience: [ModelScope AIGC Tab](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0)
|FLUX.1-dev|FLUX.1-dev + ArtAug LoRA|
|-|-|
|![image_1_base](https://github.com/user-attachments/assets/e1d5c505-b423-45fe-be01-25c2758f5417)|![image_1_enhance](https://github.com/user-attachments/assets/335908e3-d0bd-41c2-9d99-d10528a2d719)|
</details>
<details>
<summary>EliGen: Precise Image Partition Control</summary>
- Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
- Sample Code: [/examples/flux/model_inference/FLUX.1-dev-EliGen.py](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
- Online Experience: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
- Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
|Entity Control Region|Generated Image|
|-|-|
|![eligen_example_2_mask_0](https://github.com/user-attachments/assets/1c6d9445-5022-4d91-ad2e-dc05321883d1)|![eligen_example_2_0](https://github.com/user-attachments/assets/86739945-cb07-4a49-b3b3-3bb65c90d14f)|
</details>
<details>
<summary>ExVideo: Extended Training for Video Generation Models</summary>
- Project Page: [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
- Paper: [ExVideo: Extending Video Diffusion Models via Parameter-Efficient Post-Tuning](https://arxiv.org/abs/2406.14130)
- Sample Code: Please refer to the [older version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/ExVideo)
- Model: [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
We trained an extended video synthesis model, which can generate 128 frames. [`examples/ExVideo`](./examples/ExVideo/)
https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
</details>
#### Toon Shading
<details>
<summary>Diffutoon: High-Resolution Anime-Style Video Rendering</summary>
- Project Page: [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
- Paper: [Diffutoon: High-Resolution Editable Toon Shading via Diffusion Models](https://arxiv.org/abs/2401.16224)
- Sample Code: Please refer to the [older version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/Diffutoon)
Render realistic videos in a flatten style and enable video editing features. [`examples/Diffutoon`](./examples/Diffutoon/)
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd
</details>
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/20528af5-5100-474a-8cdc-440b9efdd86c
#### Video Stylization
<details>
<summary>DiffSynth: The Original Version of This Project</summary>
- Project Page: [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/)
- Paper: [DiffSynth: Latent In-Iteration Deflickering for Realistic Video Synthesis](https://arxiv.org/abs/2308.03463)
- Sample Code: Please refer to the [older version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/diffsynth)
Video stylization without video models. [`examples/diffsynth`](./examples/diffsynth/)
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
</details>
### Image Synthesis
Generate high-resolution images, by breaking the limitation of diffusion models! [`examples/image_synthesis`](./examples/image_synthesis/).
LoRA fine-tuning is supported in [`examples/train`](./examples/train/).
|FLUX|Stable Diffusion 3|
|-|-|
|![image_1024_cfg](https://github.com/user-attachments/assets/6af5b106-0673-4e58-9213-cd9157eef4c0)|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/4df346db-6f91-420a-b4c1-26e205376098)|
|Kolors|Hunyuan-DiT|
|-|-|
|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/53ef6f41-da11-4701-8665-9f64392607bf)|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/60b022c8-df3f-4541-95ab-bf39f2fa8bb5)|
|Stable Diffusion|Stable Diffusion XL|
|-|-|
|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/6fc84611-8da6-4a1f-8fee-9a34eba3b4a5)|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/67687748-e738-438c-aee5-96096f09ac90)|
## Usage (in WebUI)
Create stunning images using the painter, with assistance from AI!
https://github.com/user-attachments/assets/95265d21-cdd6-4125-a7cb-9fbcf6ceb7b0
**This video is not rendered in real-time.**
Before launching the WebUI, please download models to the folder `./models`. See [here](#download-models).
* `Gradio` version
```
pip install gradio
```
```
python apps/gradio/DiffSynth_Studio.py
```
![20240822102002](https://github.com/user-attachments/assets/59613157-de51-4109-99b3-97cbffd88076)
* `Streamlit` version
```
pip install streamlit streamlit-drawable-canvas
```
```
python -m streamlit run apps/streamlit/DiffSynth_Studio.py
```
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/93085557-73f3-4eee-a205-9829591ef954

View File

@@ -1,784 +0,0 @@
# DiffSynth-Studio
<a href="https://github.com/modelscope/DiffSynth-Studio"><img src=".github/workflows/logo.gif" title="Logo" style="max-width:100%;" width="55" /></a> <a href="https://trendshift.io/repositories/10946" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10946" alt="modelscope%2FDiffSynth-Studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a></p>
[![PyPI](https://img.shields.io/pypi/v/DiffSynth)](https://pypi.org/project/DiffSynth/)
[![license](https://img.shields.io/github/license/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/blob/master/LICENSE)
[![open issues](https://isitmaintained.com/badge/open/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/issues)
[![GitHub pull-requests](https://img.shields.io/github/issues-pr/modelscope/DiffSynth-Studio.svg)](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
[![GitHub latest commit](https://badgen.net/github/last-commit/modelscope/DiffSynth-Studio)](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
[Switch to English](./README.md)
## 简介
欢迎来到 Diffusion 模型的魔法世界DiffSynth-Studio 是由[魔搭社区](https://www.modelscope.cn/)团队开发和维护的开源 Diffusion 模型引擎。我们期望以框架建设孵化技术创新,凝聚开源社区的力量,探索生成式模型技术的边界!
DiffSynth 目前包括两个开源项目:
* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): 聚焦于激进的技术探索,面向学术界,提供更前沿的模型能力支持。
* [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): 聚焦于稳定的模型部署,面向工业界,提供更高的计算性能与更稳定的功能。
[DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) 与 [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) 是魔搭社区 AIGC 专区的核心引擎,欢迎体验我们精心打造的产品化功能:
* 魔搭社区 AIGC 专区 (面向中国用户): https://modelscope.cn/aigc/home
* ModelScope Civision (for global users): https://modelscope.ai/civision/home
> DiffSynth-Studio 文档:[中文版](/docs/zh/README.md)、[English version](/docs/en/README.md)
我们相信,一个完善的开源代码框架能够降低技术探索的门槛,我们基于这个代码库搞出了不少[有意思的技术](#创新成果)。或许你也有许多天马行空的构想,借助 DiffSynth-Studio你可以快速实现这些想法。为此我们为开发者准备了详细的文档我们希望通过这些文档帮助开发者理解 Diffusion 模型的原理,更期待与你一同拓展技术的边界。
## 更新历史
> DiffSynth-Studio 经历了大版本更新,部分旧功能已停止维护,如需使用旧版功能,请切换到大版本更新前的[最后一个历史版本](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3)。
> 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 负责因此新功能的开发进展会比较缓慢issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
- **2026年1月19日** 新增对 [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) 和 [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/FLUX2.md)和[示例代码](/examples/flux2/)现已可用。
- **2026年1月12日** 我们训练并开源了一个文本引导的图层拆分模型([模型链接](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)),这一模型输入一张图与一段文本描述,模型会将图像中与文本描述相关的图层拆分出来。更多细节请阅读我们的 blog[中文版](https://modelscope.cn/learn/4938)、[英文版](https://huggingface.co/blog/kelseye/qwen-image-layered-control))。
- **2025年12月24日** 我们基于 Qwen-Image-Edit-2511 训练了一个 In-Context Editing LoRA 模型([模型链接](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)这个模型可以输入三张图图A、图B、图C模型会自行分析图A到图B的变化并将这样的变化应用到图C生成图D。更多细节请阅读我们的 blog[中文版](https://mp.weixin.qq.com/s/41aEiN3lXKGCJs1-we4Q2g)、[英文版](https://huggingface.co/blog/kelseye/qwen-image-edit-2511-icedit-lora))。
- **2025年12月9日** 我们基于 DiffSynth-Studio 2.0 训练了一个疯狂的模型:[Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)Image to LoRA。这一模型以图像为输入以 LoRA 为输出。尽管这个版本的模型在泛化能力、细节保持能力等方面还有很大改进空间,我们将这些模型开源,以启发更多创新性的研究工作。更多细节,请参考我们的 [blog](https://huggingface.co/blog/kelseye/qwen-image-i2l)。
- **2025年12月4日** DiffSynth-Studio 2.0 发布!众多新功能上线
- [文档](/docs/zh/README.md)上线:我们的文档还在持续优化更新中
- [显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md)模块升级,支持 Layer 级别的 Disk Offload同时释放内存与显存
- 新模型支持
- Z-Image Turbo: [模型](https://www.modelscope.ai/models/Tongyi-MAI/Z-Image-Turbo)、[文档](/docs/zh/Model_Details/Z-Image.md)、[代码](/examples/z_image/)
- FLUX.2-dev: [模型](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)、[文档](/docs/zh/Model_Details/FLUX2.md)、[代码](/examples/flux2/)
- 训练框架升级
- [拆分训练](/docs/zh/Training/Split_Training.md):支持自动化地将训练过程拆分为数据处理和训练两阶段(即使训练的是 ControlNet 或其他任意模型在数据处理阶段进行文本编码、VAE 编码等不需要梯度回传的计算,在训练阶段处理其他计算。速度更快,显存需求更少。
- [差分 LoRA 训练](/docs/zh/Training/Differential_LoRA.md):这是我们曾在 [ArtAug](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) 中使用的训练技术,目前已可用于任意模型的 LoRA 训练。
- [FP8 训练](/docs/zh/Training/FP8_Precision.md)FP8 在训练中支持应用到任意非训练模型,即梯度关闭或者梯度仅影响 LoRA 权重的模型。
<details>
<summary>更多</summary>
- **2025年11月4日** 支持了 [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) 模型,该模型基于 Wan 2.1 训练,支持根据参考视频生成相应的动作。
- **2025年10月30日** 支持了 [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) 模型,该模型支持文生视频、图生视频、视频续写。这个模型在本项目中沿用 Wan 的框架进行推理和训练。
- **2025年10月27日** 支持了 [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) 模型Wan 模型生态再添一员。
- **2025年9月23日** [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) 发布!本模型由我们与淘天体验设计团队联合研发并开源。模型基于 Qwen-Image 构建,专为电商海报场景设计,支持精确的分区布局控制。 请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)。
- **2025年9月9日** 我们的训练框架支持了多种训练模式,目前已适配 Qwen-Image除标准 SFT 训练模式外,已支持 Direct Distill请参考[我们的示例代码](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)。这项功能是实验性的,我们将会继续完善已支持更全面的模型训练功能。
- **2025年8月28日** 我们支持了Wan2.2-S2V一个音频驱动的电影级视频生成模型。请参见[./examples/wanvideo/](./examples/wanvideo/)。
- **2025年8月21日** [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) 发布!相比于 V1 版本,训练数据集变为 [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset),因此,生成的图像更符合 Qwen-Image 本身的图像分布和风格。 请参考[我们的示例代码](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)。
- **2025年8月21日** 我们开源了 [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) 结构控制 LoRA 模型,采用 In Context 的技术路线,支持多种类别的结构控制条件,包括 canny, depth, lineart, softedge, normal, openpose。 请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)。
- **2025年8月20日** 我们开源了 [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) 模型,提升了 Qwen-Image-Edit 对低分辨率图像输入的编辑效果。请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)
- **2025年8月19日** 🔥 Qwen-Image-Edit 开源,欢迎图像编辑模型新成员!
- **2025年8月18日** 我们训练并开源了 Qwen-Image 的图像重绘 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)。
- **2025年8月15日** 我们开源了 [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset) 数据集。这是一个使用 Qwen-Image 模型生成的图像数据集,共包含 160,000 张`1024 x 1024`图像。它包括通用、英文文本渲染和中文文本渲染子集。我们为每张图像提供了图像描述、实体和结构控制图像的标注。开发者可以使用这个数据集来训练 Qwen-Image 模型的 ControlNet 和 EliGen 等模型,我们旨在通过开源推动技术发展!
- **2025年8月13日** 我们训练并开源了 Qwen-Image 的 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)。
- **2025年8月12日** 我们训练并开源了 Qwen-Image 的 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)。
- **2025年8月11日** 我们开源了 Qwen-Image 的蒸馏加速模型 [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA),沿用了与 [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) 相同的训练流程,但模型结构修改为了 LoRA因此能够更好地与其他开源生态模型兼容。
- **2025年8月7日** 我们开源了 Qwen-Image 的实体控制 LoRA 模型 [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)。Qwen-Image-EliGen 能够实现实体级可控的文生图。技术细节请参见[论文](https://arxiv.org/abs/2501.01097)。训练数据集:[EliGenTrainSet](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)。
- **2025年8月5日** 我们开源了 Qwen-Image 的蒸馏加速模型 [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full),实现了约 5 倍加速。
- **2025年8月4日** 🔥 Qwen-Image 开源,欢迎图像生成模型家族新成员!
- **2025年8月1日** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) 开源,这是一个专注于美学摄影的文生图模型。我们第一时间提供了全方位支持,包括低显存逐层 offload、LoRA 训练、全量训练。详细信息请参考 [./examples/flux/](./examples/flux/)。
- **2025年7月28日** Wan 2.2 开源,我们第一时间提供了全方位支持,包括低显存逐层 offload、FP8 量化、序列并行、LoRA 训练、全量训练。详细信息请参考 [./examples/wanvideo/](./examples/wanvideo/)。
- **2025年7月11日** 我们提出 Nexus-Gen一个将大语言模型LLM的语言推理能力与扩散模型的图像生成能力相结合的统一框架。该框架支持无缝的图像理解、生成和编辑任务。
- 论文: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
- Github 仓库: https://github.com/modelscope/Nexus-Gen
- 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
- 训练数据集: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
- 在线体验: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
- **2025年6月15日** ModelScope 官方评测框架 [EvalScope](https://github.com/modelscope/evalscope) 现已支持文生图生成评测。请参考[最佳实践](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html)指南进行尝试。
- **2025年3月25日** 我们的新开源项目 [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) 现已开源!专注于稳定的模型部署,面向工业界,提供更好的工程支持、更高的计算性能和更稳定的功能。
- **2025年3月31日** 我们支持 InfiniteYou一种用于 FLUX 的人脸特征保留方法。更多细节请参考 [./examples/InfiniteYou/](./examples/InfiniteYou/)。
- **2025年3月13日** 我们支持 HunyuanVideo-I2V即腾讯开源的 HunyuanVideo 的图像到视频生成版本。更多细节请参考 [./examples/HunyuanVideo/](./examples/HunyuanVideo/)。
- **2025年2月25日** 我们支持 Wan-Video这是阿里巴巴开源的一系列最先进的视频合成模型。详见 [./examples/wanvideo/](./examples/wanvideo/)。
- **2025年2月17日** 我们支持 [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary)!先进的视频合成模型!详见 [./examples/stepvideo](./examples/stepvideo/)。
- **2024年12月31日** 我们提出 EliGen一种用于精确实体级别控制的文本到图像生成的新框架并辅以修复融合管道将其能力扩展到图像修复任务。EliGen 可以无缝集成现有的社区模型,如 IP-Adapter 和 In-Context LoRA提升其通用性。更多详情请见 [./examples/EntityControl](./examples/EntityControl/)。
- 论文: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
- 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
- 在线体验: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
- 训练数据集: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
- **2024年12月19日** 我们为 HunyuanVideo 实现了高级显存管理,使得在 24GB 显存下可以生成分辨率为 129x720x1280 的视频,或在仅 6GB 显存下生成分辨率为 129x512x384 的视频。更多细节请参考 [./examples/HunyuanVideo/](./examples/HunyuanVideo/)。
- **2024年12月18日** 我们提出 ArtAug一种通过合成-理解交互来改进文生图模型的方法。我们以 LoRA 格式为 FLUX.1-dev 训练了一个 ArtAug 增强模块。该模型将 Qwen2-VL-72B 的美学理解融入 FLUX.1-dev从而提升了生成图像的质量。
- 论文: https://arxiv.org/abs/2412.12888
- 示例: https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/ArtAug
- 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
- 演示: [ModelScope](https://modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0), HuggingFace (即将上线)
- **2024年10月25日** 我们提供了广泛的 FLUX ControlNet 支持。该项目支持许多不同的 ControlNet 模型并且可以自由组合即使它们的结构不同。此外ControlNet 模型兼容高分辨率优化和分区控制技术,能够实现非常强大的可控图像生成。详见 [`./examples/ControlNet/`](./examples/ControlNet/)。
- **2024年10月8日** 我们发布了基于 CogVideoX-5B 和 ExVideo 的扩展 LoRA。您可以从 [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) 或 [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) 下载此模型。
- **2024年8月22日** 本项目现已支持 CogVideoX-5B。详见 [此处](/examples/video_synthesis/)。我们为这个文生视频模型提供了几个有趣的功能,包括:
- 文本到视频
- 视频编辑
- 自我超分
- 视频插帧
- **2024年8月22日** 我们实现了一个有趣的画笔功能,支持所有文生图模型。现在,您可以在 AI 的辅助下使用画笔创作惊艳的图像了!
- 在我们的 [WebUI](#usage-in-webui) 中使用它。
- **2024年8月21日** DiffSynth-Studio 现已支持 FLUX。
- 启用 CFG 和高分辨率修复以提升视觉质量。详见 [此处](/examples/image_synthesis/README.md)
- LoRA、ControlNet 和其他附加模型将很快推出。
- **2024年6月21日** 我们提出 ExVideo一种旨在增强视频生成模型能力的后训练微调技术。我们将 Stable Video Diffusion 进行了扩展,实现了长达 128 帧的长视频生成。
- [项目页面](https://ecnu-cilab.github.io/ExVideoProjectPage/)
- 源代码已在此仓库中发布。详见 [`examples/ExVideo`](./examples/ExVideo/)。
- 模型已发布于 [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) 和 [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1)。
- 技术报告已发布于 [arXiv](https://arxiv.org/abs/2406.14130)。
- 您可以在此 [演示](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1) 中试用 ExVideo
- **2024年6月13日** DiffSynth Studio 已迁移至 ModelScope。开发团队也从“我”转变为“我们”。当然我仍会参与后续的开发和维护工作。
- **2024年1月29日** 我们提出 Diffutoon这是一个出色的卡通着色解决方案。
- [项目页面](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
- 源代码已在此项目中发布。
- 技术报告IJCAI 2024已发布于 [arXiv](https://arxiv.org/abs/2401.16224)。
- **2023年12月8日** 我们决定启动一个新项目,旨在释放扩散模型的潜力,尤其是在视频合成方面。该项目的开发工作正式开始。
- **2023年11月15日** 我们提出 FastBlend一种强大的视频去闪烁算法。
- sd-webui 扩展已发布于 [GitHub](https://github.com/Artiprocher/sd-webui-fastblend)。
- 演示视频已在 Bilibili 上展示,包含三个任务:
- [视频去闪烁](https://www.bilibili.com/video/BV1d94y1W7PE)
- [视频插帧](https://www.bilibili.com/video/BV1Lw411m71p)
- [图像驱动的视频渲染](https://www.bilibili.com/video/BV1RB4y1Z7LF)
- 技术报告已发布于 [arXiv](https://arxiv.org/abs/2311.09265)。
- 其他用户开发的非官方 ComfyUI 扩展已发布于 [GitHub](https://github.com/AInseven/ComfyUI-fastblend)。
- **2023年10月1日** 我们发布了该项目的早期版本,名为 FastSDXL。这是构建一个扩散引擎的初步尝试。
- 源代码已发布于 [GitHub](https://github.com/Artiprocher/FastSDXL)。
- FastSDXL 包含一个可训练的 OLSS 调度器,以提高效率。
- OLSS 的原始仓库位于 [此处](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler)。
- 技术报告CIKM 2023已发布于 [arXiv](https://arxiv.org/abs/2305.14677)。
- 演示视频已发布于 [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj)。
- 由于 OLSS 需要额外训练,我们未在本项目中实现它。
- **2023年8月29日** 我们提出 DiffSynth一个视频合成框架。
- [项目页面](https://ecnu-cilab.github.io/DiffSynth.github.io/)。
- 源代码已发布在 [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth)。
- 技术报告ECML PKDD 2024已发布于 [arXiv](https://arxiv.org/abs/2308.03463)。
</details>
## 安装
从源码安装(推荐):
```
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
更多安装方式,以及非 NVIDIA GPU 的安装,请参考[安装文档](/docs/zh/Pipeline_Usage/Setup.md)。
</details>
## 基础框架
DiffSynth-Studio 为主流 Diffusion 模型(包括 FLUX、Wan 等)重新设计了推理和训练流水线,能够实现高效的显存管理、灵活的模型训练。
<details>
<summary>环境变量配置</summary>
> 在进行模型推理和训练前,可通过[环境变量](/docs/zh/Pipeline_Usage/Environment_Variables.md)配置模型下载源等。
>
> 本项目默认从魔搭社区下载模型。对于非中国区域的用户,可以通过以下配置从魔搭社区的国际站下载模型:
>
> ```python
> import os
> os.environ["MODELSCOPE_DOMAIN"] = "www.modelscope.ai"
> ```
>
> 如需从其他站点下载,请修改[环境变量 DIFFSYNTH_DOWNLOAD_SOURCE](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_download_source)。
</details>
### 图像生成模型
![Image](https://github.com/user-attachments/assets/c01258e2-f251-441a-aa1e-ebb22f02594d)
#### Z-Image[/docs/zh/Model_Details/Z-Image.md](/docs/zh/Model_Details/Z-Image.md)
<details>
<summary>快速开始</summary>
运行以下代码可以快速加载 [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) 模型并进行推理。FP8 精度量化会导致明显的图像质量劣化,因此不建议在 Z-Image Turbo 模型上开启任何量化,仅建议开启 CPU Offload最低 8G 显存即可运行。
```python
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = ZImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors", **vram_config),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
image = pipe(prompt=prompt, seed=42, rand_device="cuda")
image.save("image.jpg")
```
</details>
<details>
<summary>示例代码</summary>
Z-Image 的示例代码位于:[/examples/z_image/](/examples/z_image/)
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
</details>
#### FLUX.2: [/docs/zh/Model_Details/FLUX2.md](/docs/zh/Model_Details/FLUX2.md)
<details>
<summary>快速开始</summary>
运行以下代码可以快速加载 [black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 10G 显存即可运行。
```python
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.float8_e4m3fn,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "High resolution. A dreamy underwater portrait of a serene young woman in a flowing blue dress. Her hair floats softly around her face, strands delicately suspended in the water. Clear, shimmering light filters through, casting gentle highlights, while tiny bubbles rise around her. Her expression is calm, her features finely detailed—creating a tranquil, ethereal scene."
image = pipe(prompt, seed=42, rand_device="cuda", num_inference_steps=50)
image.save("image.jpg")
```
</details>
<details>
<summary>示例代码</summary>
FLUX.2 的示例代码位于:[/examples/flux2/](/examples/flux2/)
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
</details>
#### Qwen-Image: [/docs/zh/Model_Details/Qwen-Image.md](/docs/zh/Model_Details/Qwen-Image.md)
<details>
<summary>快速开始</summary>
运行以下代码可以快速加载 [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8G 显存即可运行。
```python
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.float8_e4m3fn,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(prompt, seed=0, num_inference_steps=40)
image.save("image.jpg")
```
</details>
<details>
<summary>模型血缘</summary>
```mermaid
graph LR;
Qwen/Qwen-Image-->Qwen/Qwen-Image-Edit;
Qwen/Qwen-Image-Edit-->Qwen/Qwen-Image-Edit-2509;
Qwen/Qwen-Image-->EliGen-Series;
EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen;
DiffSynth-Studio/Qwen-Image-EliGen-->DiffSynth-Studio/Qwen-Image-EliGen-V2;
EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen-Poster;
Qwen/Qwen-Image-->Distill-Series;
Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-Full;
Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-LoRA;
Qwen/Qwen-Image-->ControlNet-Series;
ControlNet-Series-->Blockwise-ControlNet-Series;
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny;
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth;
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint;
ControlNet-Series-->DiffSynth-Studio/Qwen-Image-In-Context-Control-Union;
Qwen/Qwen-Image-->DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix;
```
</details>
<details>
<summary>示例代码</summary>
Qwen-Image 的示例代码位于:[/examples/qwen_image/](/examples/qwen_image/)
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-|
</details>
#### FLUX.1: [/docs/zh/Model_Details/FLUX.md](/docs/zh/Model_Details/FLUX.md)
<details>
<summary>快速开始</summary>
运行以下代码可以快速加载 [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8G 显存即可运行。
```python
import torch
from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig
vram_config = {
"offload_dtype": torch.float8_e4m3fn,
"offload_device": "cpu",
"onload_dtype": torch.float8_e4m3fn,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config),
],
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 1,
)
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
image = pipe(prompt=prompt, seed=0)
image.save("image.jpg")
```
</details>
<details>
<summary>模型血缘</summary>
```mermaid
graph LR;
FLUX.1-Series-->black-forest-labs/FLUX.1-dev;
FLUX.1-Series-->black-forest-labs/FLUX.1-Krea-dev;
FLUX.1-Series-->black-forest-labs/FLUX.1-Kontext-dev;
black-forest-labs/FLUX.1-dev-->FLUX.1-dev-ControlNet-Series;
FLUX.1-dev-ControlNet-Series-->alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta;
FLUX.1-dev-ControlNet-Series-->InstantX/FLUX.1-dev-Controlnet-Union-alpha;
FLUX.1-dev-ControlNet-Series-->jasperai/Flux.1-dev-Controlnet-Upscaler;
black-forest-labs/FLUX.1-dev-->InstantX/FLUX.1-dev-IP-Adapter;
black-forest-labs/FLUX.1-dev-->ByteDance/InfiniteYou;
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Eligen;
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev;
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev;
black-forest-labs/FLUX.1-dev-->ostris/Flex.2-preview;
black-forest-labs/FLUX.1-dev-->stepfun-ai/Step1X-Edit;
Qwen/Qwen2.5-VL-7B-Instruct-->stepfun-ai/Step1X-Edit;
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Nexus-GenV2;
Qwen/Qwen2.5-VL-7B-Instruct-->DiffSynth-Studio/Nexus-GenV2;
```
</details>
<details>
<summary>示例代码</summary>
FLUX.1 的示例代码位于:[/examples/flux/](/examples/flux/)
|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|-|
|[black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py)|
|[black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](/examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)|
|[black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)|
|[alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|
|[InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|
|[jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|
|[InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|
|[ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
|[DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)|
|[DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
|[DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
|[stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](/examples/flux/model_inference/Step1X-Edit.py)|[code](/examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](/examples/flux/model_training/full/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](/examples/flux/model_training/lora/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_lora/Step1X-Edit.py)|
|[ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](/examples/flux/model_inference/FLEX.2-preview.py)|[code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](/examples/flux/model_training/full/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](/examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py)|
|[DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](/examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](/examples/flux/model_training/full/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](/examples/flux/model_training/lora/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_lora/Nexus-Gen.py)|
</details>
### 视频生成模型
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
#### Wan: [/docs/zh/Model_Details/Wan.md](/docs/zh/Model_Details/Wan.md)
<details>
<summary>快速开始</summary>
运行以下代码可以快速加载 [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8G 显存即可运行。
```python
import torch
from diffsynth.utils.data import save_video, VideoData
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config),
],
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
)
video = pipe(
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=0, tiled=True,
)
save_video(video, "video.mp4", fps=15, quality=5)
```
</details>
<details>
<summary>模型血缘</summary>
```mermaid
graph LR;
Wan-Series-->Wan2.1-Series;
Wan-Series-->Wan2.2-Series;
Wan2.1-Series-->Wan-AI/Wan2.1-T2V-1.3B;
Wan2.1-Series-->Wan-AI/Wan2.1-T2V-14B;
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-I2V-14B-480P;
Wan-AI/Wan2.1-I2V-14B-480P-->Wan-AI/Wan2.1-I2V-14B-720P;
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-FLF2V-14B-720P;
Wan-AI/Wan2.1-T2V-1.3B-->iic/VACE-Wan2.1-1.3B-Preview;
iic/VACE-Wan2.1-1.3B-Preview-->Wan-AI/Wan2.1-VACE-1.3B;
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-VACE-14B;
Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-1.3B-Series;
Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-InP;
Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-Control;
Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-14B-Series;
Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-InP;
Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-Control;
Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-V1.1-1.3B-Series;
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control;
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-InP;
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera;
Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-V1.1-14B-Series;
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control;
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-InP;
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control-Camera;
Wan-AI/Wan2.1-T2V-1.3B-->DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1;
Wan-AI/Wan2.1-T2V-14B-->krea/krea-realtime-video;
Wan-AI/Wan2.1-T2V-14B-->meituan-longcat/LongCat-Video;
Wan-AI/Wan2.1-I2V-14B-720P-->ByteDance/Video-As-Prompt-Wan2.1-14B;
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-Animate-14B;
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-S2V-14B;
Wan2.2-Series-->Wan-AI/Wan2.2-T2V-A14B;
Wan2.2-Series-->Wan-AI/Wan2.2-I2V-A14B;
Wan2.2-Series-->Wan-AI/Wan2.2-TI2V-5B;
Wan-AI/Wan2.2-T2V-A14B-->Wan2.2-Fun-Series;
Wan2.2-Fun-Series-->PAI/Wan2.2-VACE-Fun-A14B;
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-InP;
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control;
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control-Camera;
```
</details>
<details>
<summary>示例代码</summary>
Wan 的示例代码位于:[/examples/wanvideo/](/examples/wanvideo/)
|模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|
|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](/examples/wanvideo/model_inference/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|
|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](/examples/wanvideo/model_inference/LongCat-Video.py)|[code](/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|
|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|
|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
</details>
## 创新成果
DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果的孵化器。
<details>
<summary>AttriCtrl: 图像生成模型的属性强度控制</summary>
- 论文:[AttriCtrl: Fine-Grained Control of Aesthetic Attribute Intensity in Diffusion Models
](https://arxiv.org/abs/2508.02151)
- 代码样例:[/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py](/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py)
- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev)
|brightness scale = 0.1|brightness scale = 0.3|brightness scale = 0.5|brightness scale = 0.7|brightness scale = 0.9|
|-|-|-|-|-|
|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.5.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.7.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.9.jpg)|
</details>
<details>
<summary>AutoLoRA: 自动化的 LoRA 检索和融合</summary>
- 论文:[AutoLoRA: Automatic LoRA Retrieval and Fine-Grained Gated Fusion for Text-to-Image Generation
](https://arxiv.org/abs/2508.02107)
- 代码样例:[/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)
- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)
||[LoRA 1](https://modelscope.cn/models/cancel13/cxsk)|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2)|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL)|
|-|-|-|-|-|
|[LoRA 1](https://modelscope.cn/models/cancel13/cxsk) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_0.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)|
|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)|
|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)|
|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_3_3.jpg)|
</details>
<details>
<summary>Nexus-Gen: 统一架构的图像理解、生成、编辑</summary>
- 详细页面https://github.com/modelscope/Nexus-Gen
- 论文:[Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
- 数据集:[ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
- 在线体验:[ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
![](https://github.com/modelscope/Nexus-Gen/raw/main/assets/illustrations/gen_edit.jpg)
</details>
<details>
<summary>ArtAug: 图像生成模型的美学提升</summary>
- 详细页面:[./examples/ArtAug/](./examples/ArtAug/)
- 论文:[ArtAug: Enhancing Text-to-Image Generation through Synthesis-Understanding Interaction](https://arxiv.org/abs/2412.12888)
- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
- 在线体验:[ModelScope AIGC Tab](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0)
|FLUX.1-dev|FLUX.1-dev + ArtAug LoRA|
|-|-|
|![image_1_base](https://github.com/user-attachments/assets/e1d5c505-b423-45fe-be01-25c2758f5417)|![image_1_enhance](https://github.com/user-attachments/assets/335908e3-d0bd-41c2-9d99-d10528a2d719)|
</details>
<details>
<summary>EliGen: 精准的图像分区控制</summary>
- 论文:[EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
- 代码样例:[/examples/flux/model_inference/FLUX.1-dev-EliGen.py](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)
- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
- 在线体验:[ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
- 数据集:[EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
|实体控制区域|生成图像|
|-|-|
|![eligen_example_2_mask_0](https://github.com/user-attachments/assets/1c6d9445-5022-4d91-ad2e-dc05321883d1)|![eligen_example_2_0](https://github.com/user-attachments/assets/86739945-cb07-4a49-b3b3-3bb65c90d14f)|
</details>
<details>
<summary>ExVideo: 视频生成模型的扩展训练</summary>
- 项目页面:[Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
- 论文:[ExVideo: Extending Video Diffusion Models via Parameter-Efficient Post-Tuning](https://arxiv.org/abs/2406.14130)
- 代码样例:请前往[旧版本](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/ExVideo)查看
- 模型:[ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
</details>
<details>
<summary>Diffutoon: 高分辨率动漫风格视频渲染</summary>
- 项目页面:[Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
- 论文:[Diffutoon: High-Resolution Editable Toon Shading via Diffusion Models](https://arxiv.org/abs/2401.16224)
- 代码样例:请前往[旧版本](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/Diffutoon)查看
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd
</details>
<details>
<summary>DiffSynth: 本项目的初代版本</summary>
- 项目页面:[Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/)
- 论文:[DiffSynth: Latent In-Iteration Deflickering for Realistic Video Synthesis](https://arxiv.org/abs/2308.03463)
- 代码样例:请前往[旧版本](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/diffsynth)查看
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
</details>

View File

@@ -0,0 +1,252 @@
import gradio as gr
from diffsynth import ModelManager, SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline, FluxImagePipeline
import os, torch
from PIL import Image
import numpy as np
config = {
"model_config": {
"Stable Diffusion": {
"model_folder": "models/stable_diffusion",
"pipeline_class": SDImagePipeline,
"default_parameters": {
"cfg_scale": 7.0,
"height": 512,
"width": 512,
}
},
"Stable Diffusion XL": {
"model_folder": "models/stable_diffusion_xl",
"pipeline_class": SDXLImagePipeline,
"default_parameters": {
"cfg_scale": 7.0,
}
},
"Stable Diffusion 3": {
"model_folder": "models/stable_diffusion_3",
"pipeline_class": SD3ImagePipeline,
"default_parameters": {
"cfg_scale": 7.0,
}
},
"Stable Diffusion XL Turbo": {
"model_folder": "models/stable_diffusion_xl_turbo",
"pipeline_class": SDXLImagePipeline,
"default_parameters": {
"negative_prompt": "",
"cfg_scale": 1.0,
"num_inference_steps": 1,
"height": 512,
"width": 512,
}
},
"Kolors": {
"model_folder": "models/kolors",
"pipeline_class": SDXLImagePipeline,
"default_parameters": {
"cfg_scale": 7.0,
}
},
"HunyuanDiT": {
"model_folder": "models/HunyuanDiT",
"pipeline_class": HunyuanDiTImagePipeline,
"default_parameters": {
"cfg_scale": 7.0,
}
},
"FLUX": {
"model_folder": "models/FLUX",
"pipeline_class": FluxImagePipeline,
"default_parameters": {
"cfg_scale": 1.0,
}
}
},
"max_num_painter_layers": 8,
"max_num_model_cache": 1,
}
def load_model_list(model_type):
if model_type is None:
return []
folder = config["model_config"][model_type]["model_folder"]
file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
file_list = sorted(file_list)
return file_list
def load_model(model_type, model_path):
global model_dict
model_key = f"{model_type}:{model_path}"
if model_key in model_dict:
return model_dict[model_key]
model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
model_manager = ModelManager()
if model_type == "HunyuanDiT":
model_manager.load_models([
os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
os.path.join(model_path, "mt5/pytorch_model.bin"),
os.path.join(model_path, "model/pytorch_model_ema.pt"),
os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
])
elif model_type == "Kolors":
model_manager.load_models([
os.path.join(model_path, "text_encoder"),
os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
])
elif model_type == "FLUX":
model_manager.torch_dtype = torch.bfloat16
file_list = [
os.path.join(model_path, "text_encoder/model.safetensors"),
os.path.join(model_path, "text_encoder_2"),
]
for file_name in os.listdir(model_path):
if file_name.endswith(".safetensors"):
file_list.append(os.path.join(model_path, file_name))
model_manager.load_models(file_list)
else:
model_manager.load_model(model_path)
pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
while len(model_dict) + 1 > config["max_num_model_cache"]:
key = next(iter(model_dict.keys()))
model_manager_to_release, _ = model_dict[key]
model_manager_to_release.to("cpu")
del model_dict[key]
torch.cuda.empty_cache()
model_dict[model_key] = model_manager, pipe
return model_manager, pipe
model_dict = {}
with gr.Blocks() as app:
gr.Markdown("# DiffSynth-Studio Painter")
with gr.Row():
with gr.Column(scale=382, min_width=100):
with gr.Accordion(label="Model"):
model_type = gr.Dropdown(choices=[i for i in config["model_config"]], label="Model type")
model_path = gr.Dropdown(choices=[], interactive=True, label="Model path")
@gr.on(inputs=model_type, outputs=model_path, triggers=model_type.change)
def model_type_to_model_path(model_type):
return gr.Dropdown(choices=load_model_list(model_type))
with gr.Accordion(label="Prompt"):
prompt = gr.Textbox(label="Prompt", lines=3)
negative_prompt = gr.Textbox(label="Negative prompt", lines=1)
cfg_scale = gr.Slider(minimum=1.0, maximum=10.0, value=7.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=0.0, step=0.1, interactive=True, label="Embedded guidance scale (only for FLUX)")
with gr.Accordion(label="Image"):
num_inference_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, interactive=True, label="Inference steps")
height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
with gr.Column():
use_fixed_seed = gr.Checkbox(value=True, interactive=False, label="Use fixed seed")
seed = gr.Number(minimum=0, maximum=10**9, value=0, interactive=True, label="Random seed", show_label=False)
@gr.on(
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
outputs=[prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
triggers=model_path.change
)
def model_path_to_default_params(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width):
load_model(model_type, model_path)
cfg_scale = config["model_config"][model_type]["default_parameters"].get("cfg_scale", cfg_scale)
embedded_guidance = config["model_config"][model_type]["default_parameters"].get("embedded_guidance", embedded_guidance)
num_inference_steps = config["model_config"][model_type]["default_parameters"].get("num_inference_steps", num_inference_steps)
height = config["model_config"][model_type]["default_parameters"].get("height", height)
width = config["model_config"][model_type]["default_parameters"].get("width", width)
return prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width
with gr.Column(scale=618, min_width=100):
with gr.Accordion(label="Painter"):
enable_local_prompt_list = []
local_prompt_list = []
mask_scale_list = []
canvas_list = []
for painter_layer_id in range(config["max_num_painter_layers"]):
with gr.Tab(label=f"Layer {painter_layer_id}"):
enable_local_prompt = gr.Checkbox(label="Enable", value=False, key=f"enable_local_prompt_{painter_layer_id}")
local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
mask_scale = gr.Slider(minimum=0.0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Mask scale", key=f"mask_scale_{painter_layer_id}")
canvas = gr.ImageEditor(canvas_size=(512, 1), sources=None, layers=False, interactive=True, image_mode="RGBA",
brush=gr.Brush(default_size=100, default_color="#000000", colors=["#000000"]),
label="Painter", key=f"canvas_{painter_layer_id}")
@gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear, enable_local_prompt.change], show_progress="hidden")
def resize_canvas(height, width, canvas):
h, w = canvas["background"].shape[:2]
if h != height or width != w:
return np.ones((height, width, 3), dtype=np.uint8) * 255
else:
return canvas
enable_local_prompt_list.append(enable_local_prompt)
local_prompt_list.append(local_prompt)
mask_scale_list.append(mask_scale)
canvas_list.append(canvas)
with gr.Accordion(label="Results"):
run_button = gr.Button(value="Generate", variant="primary")
output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
with gr.Row():
with gr.Column():
output_to_painter_button = gr.Button(value="Set as painter's background")
with gr.Column():
output_to_input_button = gr.Button(value="Set as input image")
painter_background = gr.State(None)
input_background = gr.State(None)
@gr.on(
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed] + enable_local_prompt_list + local_prompt_list + mask_scale_list + canvas_list,
outputs=[output_image],
triggers=run_button.click
)
def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed, *args, progress=gr.Progress()):
_, pipe = load_model(model_type, model_path)
input_params = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"cfg_scale": cfg_scale,
"num_inference_steps": num_inference_steps,
"height": height,
"width": width,
"progress_bar_cmd": progress.tqdm,
}
if isinstance(pipe, FluxImagePipeline):
input_params["embedded_guidance"] = embedded_guidance
enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list = (
args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
args[2 * config["max_num_painter_layers"]: 3 * config["max_num_painter_layers"]],
args[3 * config["max_num_painter_layers"]: 4 * config["max_num_painter_layers"]]
)
local_prompts, masks, mask_scales = [], [], []
for enable_local_prompt, local_prompt, mask_scale, canvas in zip(
enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list
):
if enable_local_prompt:
local_prompts.append(local_prompt)
masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
mask_scales.append(mask_scale)
input_params.update({
"local_prompts": local_prompts,
"masks": masks,
"mask_scales": mask_scales,
})
torch.manual_seed(seed)
image = pipe(**input_params)
return image
@gr.on(inputs=[output_image] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
def send_output_to_painter_background(output_image, *canvas_list):
for canvas in canvas_list:
h, w = canvas["background"].shape[:2]
canvas["background"] = output_image.resize((w, h))
return tuple(canvas_list)
app.launch()

View File

@@ -0,0 +1,15 @@
# Set web page format
import streamlit as st
st.set_page_config(layout="wide")
# Diasble virtual VRAM on windows system
import torch
torch.cuda.set_per_process_memory_fraction(0.999, 0)
st.markdown("""
# DiffSynth Studio
[Source Code](https://github.com/Artiprocher/DiffSynth-Studio)
Welcome to DiffSynth Studio.
""")

View File

@@ -0,0 +1,362 @@
import torch, os, io, json, time
import numpy as np
from PIL import Image
import streamlit as st
st.set_page_config(layout="wide")
from streamlit_drawable_canvas import st_canvas
from diffsynth.models import ModelManager
from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline, FluxImagePipeline
from diffsynth.data.video import crop_and_resize
config = {
"Stable Diffusion": {
"model_folder": "models/stable_diffusion",
"pipeline_class": SDImagePipeline,
"fixed_parameters": {}
},
"Stable Diffusion XL": {
"model_folder": "models/stable_diffusion_xl",
"pipeline_class": SDXLImagePipeline,
"fixed_parameters": {}
},
"Stable Diffusion 3": {
"model_folder": "models/stable_diffusion_3",
"pipeline_class": SD3ImagePipeline,
"fixed_parameters": {}
},
"Stable Diffusion XL Turbo": {
"model_folder": "models/stable_diffusion_xl_turbo",
"pipeline_class": SDXLImagePipeline,
"fixed_parameters": {
"negative_prompt": "",
"cfg_scale": 1.0,
"num_inference_steps": 1,
"height": 512,
"width": 512,
}
},
"Kolors": {
"model_folder": "models/kolors",
"pipeline_class": SDXLImagePipeline,
"fixed_parameters": {}
},
"HunyuanDiT": {
"model_folder": "models/HunyuanDiT",
"pipeline_class": HunyuanDiTImagePipeline,
"fixed_parameters": {
"height": 1024,
"width": 1024,
}
},
"FLUX": {
"model_folder": "models/FLUX",
"pipeline_class": FluxImagePipeline,
"fixed_parameters": {
"cfg_scale": 1.0,
}
}
}
def load_model_list(model_type):
folder = config[model_type]["model_folder"]
file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
file_list = sorted(file_list)
return file_list
def release_model():
if "model_manager" in st.session_state:
st.session_state["model_manager"].to("cpu")
del st.session_state["loaded_model_path"]
del st.session_state["model_manager"]
del st.session_state["pipeline"]
torch.cuda.empty_cache()
def load_model(model_type, model_path):
model_manager = ModelManager()
if model_type == "HunyuanDiT":
model_manager.load_models([
os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
os.path.join(model_path, "mt5/pytorch_model.bin"),
os.path.join(model_path, "model/pytorch_model_ema.pt"),
os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
])
elif model_type == "Kolors":
model_manager.load_models([
os.path.join(model_path, "text_encoder"),
os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
])
elif model_type == "FLUX":
model_manager.torch_dtype = torch.bfloat16
file_list = [
os.path.join(model_path, "text_encoder/model.safetensors"),
os.path.join(model_path, "text_encoder_2"),
]
for file_name in os.listdir(model_path):
if file_name.endswith(".safetensors"):
file_list.append(os.path.join(model_path, file_name))
model_manager.load_models(file_list)
else:
model_manager.load_model(model_path)
pipeline = config[model_type]["pipeline_class"].from_model_manager(model_manager)
st.session_state.loaded_model_path = model_path
st.session_state.model_manager = model_manager
st.session_state.pipeline = pipeline
return model_manager, pipeline
def use_output_image_as_input(update=True):
# Search for input image
output_image_id = 0
selected_output_image = None
while True:
if f"use_output_as_input_{output_image_id}" not in st.session_state:
break
if st.session_state[f"use_output_as_input_{output_image_id}"]:
selected_output_image = st.session_state["output_images"][output_image_id]
break
output_image_id += 1
if update and selected_output_image is not None:
st.session_state["input_image"] = selected_output_image
return selected_output_image is not None
def apply_stroke_to_image(stroke_image, image):
image = np.array(image.convert("RGB")).astype(np.float32)
height, width, _ = image.shape
stroke_image = np.array(Image.fromarray(stroke_image).resize((width, height))).astype(np.float32)
weight = stroke_image[:, :, -1:] / 255
stroke_image = stroke_image[:, :, :-1]
image = stroke_image * weight + image * (1 - weight)
image = np.clip(image, 0, 255).astype(np.uint8)
image = Image.fromarray(image)
return image
@st.cache_data
def image2bits(image):
image_byte = io.BytesIO()
image.save(image_byte, format="PNG")
image_byte = image_byte.getvalue()
return image_byte
def show_output_image(image):
st.image(image, use_column_width="always")
st.button("Use it as input image", key=f"use_output_as_input_{image_id}")
st.download_button("Download", data=image2bits(image), file_name="image.png", mime="image/png", key=f"download_output_{image_id}")
column_input, column_output = st.columns(2)
with st.sidebar:
# Select a model
with st.expander("Model", expanded=True):
model_type = st.selectbox("Model type", [model_type_ for model_type_ in config])
fixed_parameters = config[model_type]["fixed_parameters"]
model_path_list = ["None"] + load_model_list(model_type)
model_path = st.selectbox("Model path", model_path_list)
# Load the model
if model_path == "None":
# No models are selected. Release VRAM.
st.markdown("No models are selected.")
release_model()
else:
# A model is selected.
model_path = os.path.join(config[model_type]["model_folder"], model_path)
if st.session_state.get("loaded_model_path", "") != model_path:
# The loaded model is not the selected model. Reload it.
st.markdown(f"Loading model at {model_path}.")
st.markdown("Please wait a moment...")
release_model()
model_manager, pipeline = load_model(model_type, model_path)
st.markdown("Done.")
else:
# The loaded model is not the selected model. Fetch it from `st.session_state`.
st.markdown(f"Loading model at {model_path}.")
st.markdown("Please wait a moment...")
model_manager, pipeline = st.session_state.model_manager, st.session_state.pipeline
st.markdown("Done.")
# Show parameters
with st.expander("Prompt", expanded=True):
prompt = st.text_area("Positive prompt")
if "negative_prompt" in fixed_parameters:
negative_prompt = fixed_parameters["negative_prompt"]
else:
negative_prompt = st.text_area("Negative prompt")
if "cfg_scale" in fixed_parameters:
cfg_scale = fixed_parameters["cfg_scale"]
else:
cfg_scale = st.slider("Classifier-free guidance scale", min_value=1.0, max_value=10.0, value=7.5)
with st.expander("Image", expanded=True):
if "num_inference_steps" in fixed_parameters:
num_inference_steps = fixed_parameters["num_inference_steps"]
else:
num_inference_steps = st.slider("Inference steps", min_value=1, max_value=100, value=20)
if "height" in fixed_parameters:
height = fixed_parameters["height"]
else:
height = st.select_slider("Height", options=[256, 512, 768, 1024, 2048], value=512)
if "width" in fixed_parameters:
width = fixed_parameters["width"]
else:
width = st.select_slider("Width", options=[256, 512, 768, 1024, 2048], value=512)
num_images = st.number_input("Number of images", value=2)
use_fixed_seed = st.checkbox("Use fixed seed", value=False)
if use_fixed_seed:
seed = st.number_input("Random seed", min_value=0, max_value=10**9, step=1, value=0)
# Other fixed parameters
denoising_strength = 1.0
repetition = 1
# Show input image
with column_input:
with st.expander("Input image (Optional)", expanded=True):
with st.container(border=True):
column_white_board, column_upload_image = st.columns([1, 2])
with column_white_board:
create_white_board = st.button("Create white board")
delete_input_image = st.button("Delete input image")
with column_upload_image:
upload_image = st.file_uploader("Upload image", type=["png", "jpg"], key="upload_image")
if upload_image is not None:
st.session_state["input_image"] = crop_and_resize(Image.open(upload_image), height, width)
elif create_white_board:
st.session_state["input_image"] = Image.fromarray(np.ones((height, width, 3), dtype=np.uint8) * 255)
else:
use_output_image_as_input()
if delete_input_image and "input_image" in st.session_state:
del st.session_state.input_image
if delete_input_image and "upload_image" in st.session_state:
del st.session_state.upload_image
input_image = st.session_state.get("input_image", None)
if input_image is not None:
with st.container(border=True):
column_drawing_mode, column_color_1, column_color_2 = st.columns([4, 1, 1])
with column_drawing_mode:
drawing_mode = st.radio("Drawing tool", ["transform", "freedraw", "line", "rect"], horizontal=True, index=1)
with column_color_1:
stroke_color = st.color_picker("Stroke color")
with column_color_2:
fill_color = st.color_picker("Fill color")
stroke_width = st.slider("Stroke width", min_value=1, max_value=50, value=10)
with st.container(border=True):
denoising_strength = st.slider("Denoising strength", min_value=0.0, max_value=1.0, value=0.7)
repetition = st.slider("Repetition", min_value=1, max_value=8, value=1)
with st.container(border=True):
input_width, input_height = input_image.size
canvas_result = st_canvas(
fill_color=fill_color,
stroke_width=stroke_width,
stroke_color=stroke_color,
background_color="rgba(255, 255, 255, 0)",
background_image=input_image,
update_streamlit=True,
height=int(512 / input_width * input_height),
width=512,
drawing_mode=drawing_mode,
key="canvas"
)
num_painter_layer = st.number_input("Number of painter layers", min_value=0, max_value=10, step=1, value=0)
local_prompts, masks, mask_scales = [], [], []
white_board = Image.fromarray(np.ones((512, 512, 3), dtype=np.uint8) * 255)
painter_layers_json_data = []
for painter_tab_id in range(num_painter_layer):
with st.expander(f"Painter layer {painter_tab_id}", expanded=True):
enable_local_prompt = st.checkbox(f"Enable prompt {painter_tab_id}", value=True)
local_prompt = st.text_area(f"Prompt {painter_tab_id}")
mask_scale = st.slider(f"Mask scale {painter_tab_id}", min_value=0.0, max_value=3.0, value=1.0)
stroke_width = st.slider(f"Stroke width {painter_tab_id}", min_value=1, max_value=300, value=100)
canvas_result_local = st_canvas(
fill_color="#000000",
stroke_width=stroke_width,
stroke_color="#000000",
background_color="rgba(255, 255, 255, 0)",
background_image=white_board,
update_streamlit=True,
height=512,
width=512,
drawing_mode="freedraw",
key=f"canvas_{painter_tab_id}"
)
if canvas_result_local.json_data is not None:
painter_layers_json_data.append(canvas_result_local.json_data.copy())
painter_layers_json_data[-1]["prompt"] = local_prompt
if enable_local_prompt:
local_prompts.append(local_prompt)
if canvas_result_local.image_data is not None:
mask = apply_stroke_to_image(canvas_result_local.image_data, white_board)
else:
mask = white_board
mask = Image.fromarray(255 - np.array(mask))
masks.append(mask)
mask_scales.append(mask_scale)
save_painter_layers = st.button("Save painter layers")
if save_painter_layers:
os.makedirs("data/painter_layers", exist_ok=True)
json_file_path = f"data/painter_layers/{time.time_ns()}.json"
with open(json_file_path, "w") as f:
json.dump(painter_layers_json_data, f, indent=4)
st.markdown(f"Painter layers are saved in {json_file_path}.")
with column_output:
run_button = st.button("Generate image", type="primary")
auto_update = st.checkbox("Auto update", value=False)
num_image_columns = st.slider("Columns", min_value=1, max_value=8, value=2)
image_columns = st.columns(num_image_columns)
# Run
if (run_button or auto_update) and model_path != "None":
if input_image is not None:
input_image = input_image.resize((width, height))
if canvas_result.image_data is not None:
input_image = apply_stroke_to_image(canvas_result.image_data, input_image)
output_images = []
for image_id in range(num_images * repetition):
if use_fixed_seed:
torch.manual_seed(seed + image_id)
else:
torch.manual_seed(np.random.randint(0, 10**9))
if image_id >= num_images:
input_image = output_images[image_id - num_images]
with image_columns[image_id % num_image_columns]:
progress_bar_st = st.progress(0.0)
image = pipeline(
prompt, negative_prompt=negative_prompt,
local_prompts=local_prompts, masks=masks, mask_scales=mask_scales,
cfg_scale=cfg_scale, num_inference_steps=num_inference_steps,
height=height, width=width,
input_image=input_image, denoising_strength=denoising_strength,
progress_bar_st=progress_bar_st
)
output_images.append(image)
progress_bar_st.progress(1.0)
show_output_image(image)
st.session_state["output_images"] = output_images
elif "output_images" in st.session_state:
for image_id in range(len(st.session_state.output_images)):
with image_columns[image_id % num_image_columns]:
image = st.session_state.output_images[image_id]
progress_bar = st.progress(1.0)
show_output_image(image)
if "upload_image" in st.session_state and use_output_image_as_input(update=False):
st.markdown("If you want to use an output image as input image, please delete the uploaded image manually.")

View File

@@ -0,0 +1,197 @@
import streamlit as st
st.set_page_config(layout="wide")
from diffsynth import SDVideoPipelineRunner
import os
import numpy as np
def load_model_list(folder):
file_list = os.listdir(folder)
file_list = [i for i in file_list if i.endswith(".safetensors") or i.endswith(".pth") or i.endswith(".ckpt")]
file_list = sorted(file_list)
return file_list
def match_processor_id(model_name, supported_processor_id_list):
sorted_processor_id = [i[1] for i in sorted([(-len(i), i) for i in supported_processor_id_list])]
for processor_id in sorted_processor_id:
if processor_id in model_name:
return supported_processor_id_list.index(processor_id) + 1
return 0
config = {
"models": {
"model_list": [],
"textual_inversion_folder": "models/textual_inversion",
"device": "cuda",
"lora_alphas": [],
"controlnet_units": []
},
"data": {
"input_frames": None,
"controlnet_frames": [],
"output_folder": "output",
"fps": 60
},
"pipeline": {
"seed": 0,
"pipeline_inputs": {}
}
}
with st.expander("Model", expanded=True):
stable_diffusion_ckpt = st.selectbox("Stable Diffusion", ["None"] + load_model_list("models/stable_diffusion"))
if stable_diffusion_ckpt != "None":
config["models"]["model_list"].append(os.path.join("models/stable_diffusion", stable_diffusion_ckpt))
animatediff_ckpt = st.selectbox("AnimateDiff", ["None"] + load_model_list("models/AnimateDiff"))
if animatediff_ckpt != "None":
config["models"]["model_list"].append(os.path.join("models/AnimateDiff", animatediff_ckpt))
column_lora, column_lora_alpha = st.columns([2, 1])
with column_lora:
sd_lora_ckpt = st.selectbox("LoRA", ["None"] + load_model_list("models/lora"))
with column_lora_alpha:
lora_alpha = st.slider("LoRA Alpha", min_value=-4.0, max_value=4.0, value=1.0, step=0.1)
if sd_lora_ckpt != "None":
config["models"]["model_list"].append(os.path.join("models/lora", sd_lora_ckpt))
config["models"]["lora_alphas"].append(lora_alpha)
with st.expander("Data", expanded=True):
with st.container(border=True):
input_video = st.text_input("Input Video File Path (e.g., data/your_video.mp4)", value="")
column_height, column_width, column_start_frame_index, column_end_frame_index = st.columns([2, 2, 1, 1])
with column_height:
height = st.select_slider("Height", options=[256, 512, 768, 1024, 1536, 2048], value=1024)
with column_width:
width = st.select_slider("Width", options=[256, 512, 768, 1024, 1536, 2048], value=1024)
with column_start_frame_index:
start_frame_id = st.number_input("Start Frame id", value=0)
with column_end_frame_index:
end_frame_id = st.number_input("End Frame id", value=16)
if input_video != "":
config["data"]["input_frames"] = {
"video_file": input_video,
"image_folder": None,
"height": height,
"width": width,
"start_frame_id": start_frame_id,
"end_frame_id": end_frame_id
}
with st.container(border=True):
output_video = st.text_input("Output Video File Path (e.g., data/a_folder_to_save_something)", value="output")
fps = st.number_input("FPS", value=60)
config["data"]["output_folder"] = output_video
config["data"]["fps"] = fps
with st.expander("ControlNet Units", expanded=True):
supported_processor_id_list = ["canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"]
controlnet_units = st.tabs(["ControlNet Unit 0", "ControlNet Unit 1", "ControlNet Unit 2"])
for controlnet_id in range(len(controlnet_units)):
with controlnet_units[controlnet_id]:
controlnet_ckpt = st.selectbox("ControlNet", ["None"] + load_model_list("models/ControlNet"),
key=f"controlnet_ckpt_{controlnet_id}")
processor_id = st.selectbox("Processor", ["None"] + supported_processor_id_list,
index=match_processor_id(controlnet_ckpt, supported_processor_id_list),
disabled=controlnet_ckpt == "None", key=f"processor_id_{controlnet_id}")
controlnet_scale = st.slider("Scale", min_value=0.0, max_value=1.0, step=0.01, value=0.5,
disabled=controlnet_ckpt == "None", key=f"controlnet_scale_{controlnet_id}")
use_input_video_as_controlnet_input = st.checkbox("Use input video as ControlNet input", value=True,
disabled=controlnet_ckpt == "None",
key=f"use_input_video_as_controlnet_input_{controlnet_id}")
if not use_input_video_as_controlnet_input:
controlnet_input_video = st.text_input("ControlNet Input Video File Path", value="",
disabled=controlnet_ckpt == "None", key=f"controlnet_input_video_{controlnet_id}")
column_height, column_width, column_start_frame_index, column_end_frame_index = st.columns([2, 2, 1, 1])
with column_height:
height = st.select_slider("Height", options=[256, 512, 768, 1024, 1536, 2048], value=1024,
disabled=controlnet_ckpt == "None", key=f"controlnet_height_{controlnet_id}")
with column_width:
width = st.select_slider("Width", options=[256, 512, 768, 1024, 1536, 2048], value=1024,
disabled=controlnet_ckpt == "None", key=f"controlnet_width_{controlnet_id}")
with column_start_frame_index:
start_frame_id = st.number_input("Start Frame id", value=0,
disabled=controlnet_ckpt == "None", key=f"controlnet_start_frame_id_{controlnet_id}")
with column_end_frame_index:
end_frame_id = st.number_input("End Frame id", value=16,
disabled=controlnet_ckpt == "None", key=f"controlnet_end_frame_id_{controlnet_id}")
if input_video != "":
config["data"]["input_video"] = {
"video_file": input_video,
"image_folder": None,
"height": height,
"width": width,
"start_frame_id": start_frame_id,
"end_frame_id": end_frame_id
}
if controlnet_ckpt != "None":
config["models"]["model_list"].append(os.path.join("models/ControlNet", controlnet_ckpt))
config["models"]["controlnet_units"].append({
"processor_id": processor_id,
"model_path": os.path.join("models/ControlNet", controlnet_ckpt),
"scale": controlnet_scale,
})
if use_input_video_as_controlnet_input:
config["data"]["controlnet_frames"].append(config["data"]["input_frames"])
else:
config["data"]["controlnet_frames"].append({
"video_file": input_video,
"image_folder": None,
"height": height,
"width": width,
"start_frame_id": start_frame_id,
"end_frame_id": end_frame_id
})
with st.container(border=True):
with st.expander("Seed", expanded=True):
use_fixed_seed = st.checkbox("Use fixed seed", value=False)
if use_fixed_seed:
seed = st.number_input("Random seed", min_value=0, max_value=10**9, step=1, value=0)
else:
seed = np.random.randint(0, 10**9)
with st.expander("Textual Guidance", expanded=True):
prompt = st.text_area("Positive prompt")
negative_prompt = st.text_area("Negative prompt")
column_cfg_scale, column_clip_skip = st.columns(2)
with column_cfg_scale:
cfg_scale = st.slider("Classifier-free guidance scale", min_value=1.0, max_value=10.0, value=7.0)
with column_clip_skip:
clip_skip = st.slider("Clip Skip", min_value=1, max_value=4, value=1)
with st.expander("Denoising", expanded=True):
column_num_inference_steps, column_denoising_strength = st.columns(2)
with column_num_inference_steps:
num_inference_steps = st.slider("Inference steps", min_value=1, max_value=100, value=10)
with column_denoising_strength:
denoising_strength = st.slider("Denoising strength", min_value=0.0, max_value=1.0, value=1.0)
with st.expander("Efficiency", expanded=False):
animatediff_batch_size = st.slider("Animatediff batch size (sliding window size)", min_value=1, max_value=32, value=16, step=1)
animatediff_stride = st.slider("Animatediff stride",
min_value=1,
max_value=max(2, animatediff_batch_size),
value=max(1, animatediff_batch_size // 2),
step=1)
unet_batch_size = st.slider("UNet batch size", min_value=1, max_value=32, value=1, step=1)
controlnet_batch_size = st.slider("ControlNet batch size", min_value=1, max_value=32, value=1, step=1)
cross_frame_attention = st.checkbox("Enable Cross-Frame Attention", value=False)
config["pipeline"]["seed"] = seed
config["pipeline"]["pipeline_inputs"] = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"cfg_scale": cfg_scale,
"clip_skip": clip_skip,
"denoising_strength": denoising_strength,
"num_inference_steps": num_inference_steps,
"animatediff_batch_size": animatediff_batch_size,
"animatediff_stride": animatediff_stride,
"unet_batch_size": unet_batch_size,
"controlnet_batch_size": controlnet_batch_size,
"cross_frame_attention": cross_frame_attention,
}
run_button = st.button("Run☢", type="primary")
if run_button:
SDVideoPipelineRunner(in_streamlit=True).run(config)

View File

@@ -1 +1,6 @@
from .core import *
from .data import *
from .models import *
from .prompters import *
from .schedulers import *
from .pipelines import *
from .controlnets import *

View File

@@ -1,2 +0,0 @@
from .model_configs import MODEL_CONFIGS
from .vram_management_module_maps import VRAM_MANAGEMENT_MODULE_MAPS

View File

@@ -0,0 +1,358 @@
from typing_extensions import Literal, TypeAlias
from ..models.sd_text_encoder import SDTextEncoder
from ..models.sd_unet import SDUNet
from ..models.sd_vae_encoder import SDVAEEncoder
from ..models.sd_vae_decoder import SDVAEDecoder
from ..models.sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from ..models.sdxl_unet import SDXLUNet
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
from ..models.sdxl_vae_encoder import SDXLVAEEncoder
from ..models.sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
from ..models.sd3_dit import SD3DiT
from ..models.sd3_vae_decoder import SD3VAEDecoder
from ..models.sd3_vae_encoder import SD3VAEEncoder
from ..models.sd_controlnet import SDControlNet
from ..models.sdxl_controlnet import SDXLControlNetUnion
from ..models.sd_motion import SDMotionModel
from ..models.sdxl_motion import SDXLMotionModel
from ..models.svd_image_encoder import SVDImageEncoder
from ..models.svd_unet import SVDUNet
from ..models.svd_vae_decoder import SVDVAEDecoder
from ..models.svd_vae_encoder import SVDVAEEncoder
from ..models.sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
from ..models.sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from ..models.hunyuan_dit import HunyuanDiT
from ..models.flux_dit import FluxDiT
from ..models.flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
from ..models.cog_dit import CogDiT
from ..extensions.RIFE import IFNet
from ..extensions.ESRGAN import RRDBNet
model_loader_configs = [
# These configs are provided for detecting model type automatically.
# The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
(None, "091b0e30e77c76626b3ba62acdf95343", ["sd_controlnet"], [SDControlNet], "civitai"),
(None, "4a6c8306a27d916dea81263c8c88f450", ["hunyuan_dit_clip_text_encoder"], [HunyuanDiTCLIPTextEncoder], "civitai"),
(None, "f4aec400fe394297961218c768004521", ["hunyuan_dit"], [HunyuanDiT], "civitai"),
(None, "9e6e58043a5a2e332803ed42f6ee7181", ["hunyuan_dit_t5_text_encoder"], [HunyuanDiTT5TextEncoder], "civitai"),
(None, "13115dd45a6e1c39860f91ab073b8a78", ["sdxl_vae_encoder", "sdxl_vae_decoder"], [SDXLVAEEncoder, SDXLVAEDecoder], "diffusers"),
(None, "d78aa6797382a6d455362358a3295ea9", ["sd_ipadapter_clip_image_encoder"], [IpAdapterCLIPImageEmbedder], "diffusers"),
(None, "e291636cc15e803186b47404262ef812", ["sd_ipadapter"], [SDIpAdapter], "civitai"),
(None, "399c81f2f8de8d1843d0127a00f3c224", ["sdxl_ipadapter_clip_image_encoder"], [IpAdapterXLCLIPImageEmbedder], "diffusers"),
(None, "a64eac9aa0db4b9602213bc0131281c7", ["sdxl_ipadapter"], [SDXLIpAdapter], "civitai"),
(None, "52817e4fdd89df154f02749ca6f692ac", ["sdxl_unet"], [SDXLUNet], "diffusers"),
(None, "03343c606f16d834d6411d0902b53636", ["sd_text_encoder", "sd_unet", "sd_vae_decoder", "sd_vae_encoder"], [SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder], "civitai"),
(None, "d4ba77a7ece070679b4a987f58f201e9", ["sd_text_encoder"], [SDTextEncoder], "civitai"),
(None, "d0c89e55c5a57cf3981def0cb1c9e65a", ["sd_vae_decoder", "sd_vae_encoder"], [SDVAEDecoder, SDVAEEncoder], "civitai"),
(None, "3926bf373b39a67eeafd7901478a47a7", ["sd_unet"], [SDUNet], "civitai"),
(None, "1e0c39ec176b9007c05f76d52b554a4d", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
(None, "d9e0290829ba8d98e28e1a2b1407db4a", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
(None, "5072d0b24e406b49507abe861cf97691", ["sd3_text_encoder_3"], [SD3TextEncoder3], "civitai"),
(None, "4cf64a799d04260df438c6f33c9a047e", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"),
(None, "d9b008a867c498ab12ad24042eff8e3f", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"), # SDXL-Turbo
(None, "025bb7452e531a3853d951d77c63f032", ["sdxl_text_encoder", "sdxl_text_encoder_2"], [SDXLTextEncoder, SDXLTextEncoder2], "civitai"),
(None, "298997b403a4245c04102c9f36aac348", ["sdxl_unet"], [SDXLUNet], "civitai"),
(None, "2a07abce74b4bdc696b76254ab474da6", ["svd_image_encoder", "svd_unet", "svd_vae_decoder", "svd_vae_encoder"], [SVDImageEncoder, SVDUNet, SVDVAEDecoder, SVDVAEEncoder], "civitai"),
(None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
(None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
(None, "31d2d9614fba60511fc9bf2604aa01f7", ["sdxl_controlnet"], [SDXLControlNetUnion], "diffusers"),
(None, "94eefa3dac9cec93cb1ebaf1747d7b78", ["flux_text_encoder_1"], [FluxTextEncoder1], "diffusers"),
(None, "1aafa3cc91716fb6b300cc1cd51b85a3", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "diffusers"),
(None, "21ea55f476dfc4fd135587abb59dfe5d", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "civitai"),
(None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai"),
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
]
huggingface_model_loader_configs = [
# These configs are provided for detecting model type automatically.
# The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
("Qwen2ForCausalLM", "transformers.models.qwen2.modeling_qwen2", "qwen_prompt", None),
("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
]
patch_model_loader_configs = [
# These configs are provided for detecting model type automatically.
# The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
("9a4ab6869ac9b7d6e31f9854e397c867", ["svd_unet"], [SVDUNet], {"add_positional_conv": 128}),
]
preset_models_on_huggingface = {
"HunyuanDiT": [
("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
],
"stable-video-diffusion-img2vid-xt": [
("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
],
"ExVideo-SVD-128f-v1": [
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
],
# Kolors
"Kolors": [
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
],
# FLUX
"FLUX.1-dev": [
("black-forest-labs/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
("black-forest-labs/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
],
}
preset_models_on_modelscope = {
# Hunyuan DiT
"HunyuanDiT": [
("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
],
# Stable Video Diffusion
"stable-video-diffusion-img2vid-xt": [
("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
],
# ExVideo
"ExVideo-SVD-128f-v1": [
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
],
# Stable Diffusion
"StableDiffusion_v15": [
("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
],
"DreamShaper_8": [
("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
],
"AingDiffusion_v12": [
("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
],
"Flat2DAnimerge_v45Sharp": [
("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
],
# Textual Inversion
"TextualInversion_VeryBadImageNegative_v1.3": [
("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
],
# Stable Diffusion XL
"StableDiffusionXL_v1": [
("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
],
"BluePencilXL_v200": [
("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
],
"StableDiffusionXL_Turbo": [
("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
],
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0": [
("sd_lora/zyd232_ChineseInkStyle_SDXL_v1_0", "zyd232_ChineseInkStyle_SDXL_v1_0.safetensors", "models/lora"),
],
# Stable Diffusion 3
"StableDiffusion3": [
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
],
"StableDiffusion3_without_T5": [
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
],
# ControlNet
"ControlNet_v11f1p_sd15_depth": [
("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
],
"ControlNet_v11p_sd15_softedge": [
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
],
"ControlNet_v11f1e_sd15_tile": [
("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
],
"ControlNet_v11p_sd15_lineart": [
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
],
"ControlNet_union_sdxl_promax": [
("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
],
# AnimateDiff
"AnimateDiff_v2": [
("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
],
"AnimateDiff_xl_beta": [
("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
],
# RIFE
"RIFE": [
("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
],
# Qwen Prompt
"QwenPrompt": [
("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
],
# Beautiful Prompt
"BeautifulPrompt": [
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
],
# Omost prompt
"OmostPrompt":[
("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
],
# Translator
"opus-mt-zh-en": [
("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
],
# IP-Adapter
"IP-Adapter-SD": [
("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
],
"IP-Adapter-SDXL": [
("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
],
# Kolors
"Kolors": [
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
],
"SDXL-vae-fp16-fix": [
("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
],
# FLUX
"FLUX.1-dev": [
("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
],
# ESRGAN
"ESRGAN_x4": [
("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
],
# RIFE
"RIFE": [
("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
],
# CogVideo
"CogVideoX-5B": [
("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
],
}
Preset_model_id: TypeAlias = Literal[
"HunyuanDiT",
"stable-video-diffusion-img2vid-xt",
"ExVideo-SVD-128f-v1",
"StableDiffusion_v15",
"DreamShaper_8",
"AingDiffusion_v12",
"Flat2DAnimerge_v45Sharp",
"TextualInversion_VeryBadImageNegative_v1.3",
"StableDiffusionXL_v1",
"BluePencilXL_v200",
"StableDiffusionXL_Turbo",
"ControlNet_v11f1p_sd15_depth",
"ControlNet_v11p_sd15_softedge",
"ControlNet_v11f1e_sd15_tile",
"ControlNet_v11p_sd15_lineart",
"AnimateDiff_v2",
"AnimateDiff_xl_beta",
"RIFE",
"BeautifulPrompt",
"opus-mt-zh-en",
"IP-Adapter-SD",
"IP-Adapter-SDXL",
"StableDiffusion3",
"StableDiffusion3_without_T5",
"Kolors",
"SDXL-vae-fp16-fix",
"ControlNet_union_sdxl_promax",
"FLUX.1-dev",
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
"QwenPrompt",
"OmostPrompt",
"ESRGAN_x4",
"RIFE",
"CogVideoX-5B",
]

View File

@@ -1,594 +0,0 @@
qwen_image_series = [
{
# Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors")
"model_hash": "0319a1cb19835fb510907dd3367c95ff",
"model_name": "qwen_image_dit",
"model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT",
},
{
# Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
"model_hash": "8004730443f55db63092006dd9f7110e",
"model_name": "qwen_image_text_encoder",
"model_class": "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.qwen_image_text_encoder.QwenImageTextEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
"model_hash": "ed4ea5824d55ec3107b09815e318123a",
"model_name": "qwen_image_vae",
"model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth", origin_file_pattern="model.safetensors")
"model_hash": "073bce9cf969e317e5662cd570c3e79c",
"model_name": "qwen_image_blockwise_controlnet",
"model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors")
"model_hash": "a9e54e480a628f0b956a688a81c33bab",
"model_name": "qwen_image_blockwise_controlnet",
"model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet",
"extra_kwargs": {"additional_in_dim": 4},
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors")
"model_hash": "469c78b61e3e31bc9eec0d0af3d3f2f8",
"model_name": "siglip2_image_encoder",
"model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors")
"model_hash": "5722b5c873720009de96422993b15682",
"model_name": "dinov3_image_encoder",
"model_class": "diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder",
},
{
# Example:
"model_hash": "a166c33455cdbd89c0888a3645ca5c0f",
"model_name": "qwen_image_image2lora_coarse",
"model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
},
{
# Example:
"model_hash": "a5476e691767a4da6d3a6634a10f7408",
"model_name": "qwen_image_image2lora_fine",
"model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
"extra_kwargs": {"residual_length": 37*37+7, "residual_mid_dim": 64}
},
{
# Example:
"model_hash": "0aad514690602ecaff932c701cb4b0bb",
"model_name": "qwen_image_image2lora_style",
"model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
"extra_kwargs": {"compress_dim": 64, "use_residual": False}
},
{
# Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
"model_hash": "8dc8cda05de16c73afa755e2c1ce2839",
"model_name": "qwen_image_dit",
"model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT",
"extra_kwargs": {"use_layer3d_rope": True, "use_additional_t_cond": True}
},
{
# Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
"model_hash": "44b39ddc499e027cfb24f7878d7416b9",
"model_name": "qwen_image_vae",
"model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE",
"extra_kwargs": {"image_channels": 4}
},
]
wan_series = [
{
# Example: ModelConfig(model_id="krea/krea-realtime-video", origin_file_pattern="krea-realtime-video-14b.safetensors")
"model_hash": "5ec04e02b42d2580483ad69f4e76346a",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth")
"model_hash": "9c8818c2cbea55eca56c7b447df170da",
"model_name": "wan_video_text_encoder",
"model_class": "diffsynth.models.wan_video_text_encoder.WanTextEncoder",
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth")
"model_hash": "ccc42284ea13e1ad04693284c7a09be6",
"model_name": "wan_video_vae",
"model_class": "diffsynth.models.wan_video_vae.WanVideoVAE",
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter",
},
{
# Example: ModelConfig(model_id="meituan-longcat/LongCat-Video", origin_file_pattern="dit/diffusion_pytorch_model*.safetensors")
"model_hash": "8b27900f680d7251ce44e2dc8ae1ffef",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel",
},
{
# Example: ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
"model_hash": "5f90e66a0672219f12d9a626c8c21f61",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTFromDiffusers"
},
{
# Example: ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
"model_hash": "5f90e66a0672219f12d9a626c8c21f61",
"model_name": "wan_video_vap",
"model_class": "diffsynth.models.wan_video_mot.MotWanModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_mot.WanVideoMotStateDictConverter"
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth")
"model_hash": "5941c53e207d62f20f9025686193c40b",
"model_name": "wan_video_image_encoder",
"model_class": "diffsynth.models.wan_video_image_encoder.WanImageEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_image_encoder.WanImageEncoderStateDictConverter"
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", origin_file_pattern="model.safetensors")
"model_hash": "dbd5ec76bbf977983f972c151d545389",
"model_name": "wan_video_motion_controller",
"model_class": "diffsynth.models.wan_video_motion_controller.WanMotionControllerModel",
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "9269f8db9040a9d860eaca435be61814",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "3ef3b1f8e1dab83d5b71fd7b617f859f",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_image_pos_emb': True}
},
{
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "349723183fc063b2bfc10bb2835cf677",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
},
{
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "6d6ccde6845b95ad9114ab993d917893",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
},
{
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "efa44cddf936c70abd0ea28b6cbe946c",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
},
{
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "6bfcfb3b342cb286ce886889d519a77e",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
},
{
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "ac6a5aa74f4a0aab6f64eb9a72f19901",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24}
},
{
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "70ddad9d3a133785da5ea371aae09504",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': True}
},
{
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "b61c605c2adbd23124d152ed28e049ae",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24}
},
{
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "26bde73488a92e64cc20b0a7485b9e5b",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True}
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "aafcfd9672c3a2456dc46e1cb6e52c70",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
},
{
# Example: ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "a61453409b67cd3246cf0c3bebad47ba",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06},
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
},
{
# Example: ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "a61453409b67cd3246cf0c3bebad47ba",
"model_name": "wan_video_vace",
"model_class": "diffsynth.models.wan_video_vace.VaceWanModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter"
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "7a513e1f257a861512b1afd387a8ecd9",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "7a513e1f257a861512b1afd387a8ecd9",
"model_name": "wan_video_vace",
"model_class": "diffsynth.models.wan_video_vace.VaceWanModel",
"extra_kwargs": {'vace_layers': (0, 5, 10, 15, 20, 25, 30, 35), 'vace_in_dim': 96, 'patch_size': (1, 2, 2), 'has_image_input': False, 'dim': 5120, 'num_heads': 40, 'ffn_dim': 13824, 'eps': 1e-06},
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter"
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "31fa352acb8a1b1d33cd8764273d80a2",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter"
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "31fa352acb8a1b1d33cd8764273d80a2",
"model_name": "wan_video_animate_adapter",
"model_class": "diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter",
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_animate_adapter.WanAnimateAdapterStateDictConverter"
},
{
# Example: ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
"model_hash": "47dbeab5e560db3180adf51dc0232fb1",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24, 'require_clip_embedding': False}
},
{
# Example: ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
"model_hash": "2267d489f0ceb9f21836532952852ee5",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 52, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True, 'require_clip_embedding': False},
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
"model_hash": "5b013604280dd715f8457c6ed6d6a626",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'require_clip_embedding': False}
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "966cffdcc52f9c46c391768b27637614",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit_s2v.WanS2VModel",
"extra_kwargs": {'dim': 5120, 'in_dim': 16, 'ffn_dim': 13824, 'out_dim': 16, 'text_dim': 4096, 'freq_dim': 256, 'eps': 1e-06, 'patch_size': (1, 2, 2), 'num_heads': 40, 'num_layers': 40, 'cond_dim': 16, 'audio_dim': 1024, 'num_audio_token': 4}
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "1f5ab7703c6fc803fdded85ff040c316",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 3072, 'ffn_dim': 14336, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 48, 'num_heads': 24, 'num_layers': 30, 'eps': 1e-06, 'seperated_timestep': True, 'require_clip_embedding': False, 'require_vae_embedding': False, 'fuse_vae_embedding_in_latents': True}
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth")
"model_hash": "e1de6c02cdac79f8b739f4d3698cd216",
"model_name": "wan_video_vae",
"model_class": "diffsynth.models.wan_video_vae.WanVideoVAE38",
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter",
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors")
"model_hash": "06be60f3a4526586d8431cd038a71486",
"model_name": "wans2v_audio_encoder",
"model_class": "diffsynth.models.wav2vec.WanS2VAudioEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.wans2v_audio_encoder.WanS2VAudioEncoderStateDictConverter",
},
]
flux_series = [
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors")
"model_hash": "a29710fea6dddb0314663ee823598e50",
"model_name": "flux_dit",
"model_class": "diffsynth.models.flux_dit.FluxDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
},
{
# Supported due to historical reasons.
"model_hash": "605c56eab23e9e2af863ad8f0813a25d",
"model_name": "flux_dit",
"model_class": "diffsynth.models.flux_dit.FluxDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverterFromDiffusers",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors")
"model_hash": "94eefa3dac9cec93cb1ebaf1747d7b78",
"model_name": "flux_text_encoder_clip",
"model_class": "diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_clip.FluxTextEncoderClipStateDictConverter",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors")
"model_hash": "22540b49eaedbc2f2784b2091a234c7c",
"model_name": "flux_text_encoder_t5",
"model_class": "diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_t5.FluxTextEncoderT5StateDictConverter",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors")
"model_hash": "21ea55f476dfc4fd135587abb59dfe5d",
"model_name": "flux_vae_encoder",
"model_class": "diffsynth.models.flux_vae.FluxVAEEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors")
"model_hash": "21ea55f476dfc4fd135587abb59dfe5d",
"model_name": "flux_vae_decoder",
"model_class": "diffsynth.models.flux_vae.FluxVAEDecoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors")
"model_hash": "d02f41c13549fa5093d3521f62a5570a",
"model_name": "flux_dit",
"model_class": "diffsynth.models.flux_dit.FluxDiT",
"extra_kwargs": {'input_dim': 196, 'num_blocks': 8},
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/AttriCtrl-FLUX.1-Dev", origin_file_pattern="models/brightness.safetensors")
"model_hash": "0629116fce1472503a66992f96f3eb1a",
"model_name": "flux_value_controller",
"model_class": "diffsynth.models.flux_value_control.SingleValueEncoder",
},
{
# Example: ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors")
"model_hash": "52357cb26250681367488a8954c271e8",
"model_name": "flux_controlnet",
"model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
"extra_kwargs": {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4},
},
{
# Example: ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors")
"model_hash": "78d18b9101345ff695f312e7e62538c0",
"model_name": "flux_controlnet",
"model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
"extra_kwargs": {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}},
},
{
# Example: ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors")
"model_hash": "b001c89139b5f053c715fe772362dd2a",
"model_name": "flux_controlnet",
"model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
"extra_kwargs": {"num_single_blocks": 0},
},
{
# Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin")
"model_hash": "c07c0f04f5ff55e86b4e937c7a40d481",
"model_name": "infiniteyou_image_projector",
"model_class": "diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_infiniteyou.FluxInfiniteYouImageProjectorStateDictConverter",
},
{
# Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors")
"model_hash": "7f9583eb8ba86642abb9a21a4b2c9e16",
"model_name": "flux_controlnet",
"model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
"extra_kwargs": {"num_joint_blocks": 4, "num_single_blocks": 10},
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors")
"model_hash": "77c2e4dd2440269eb33bfaa0d004f6ab",
"model_name": "flux_lora_encoder",
"model_class": "diffsynth.models.flux_lora_encoder.FluxLoRAEncoder",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev", origin_file_pattern="model.safetensors")
"model_hash": "30143afb2dea73d1ac580e0787628f8c",
"model_name": "flux_lora_patcher",
"model_class": "diffsynth.models.flux_lora_patcher.FluxLoraPatcher",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors")
"model_hash": "2bd19e845116e4f875a0a048e27fc219",
"model_name": "nexus_gen_llm",
"model_class": "diffsynth.models.nexus_gen.NexusGenAutoregressiveModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen.NexusGenAutoregressiveModelStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin")
"model_hash": "63c969fd37cce769a90aa781fbff5f81",
"model_name": "nexus_gen_editing_adapter",
"model_class": "diffsynth.models.nexus_gen_projector.NexusGenImageEmbeddingMerger",
"state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenMergerStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin")
"model_hash": "63c969fd37cce769a90aa781fbff5f81",
"model_name": "flux_dit",
"model_class": "diffsynth.models.flux_dit.FluxDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin")
"model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d",
"model_name": "nexus_gen_generation_adapter",
"model_class": "diffsynth.models.nexus_gen_projector.NexusGenAdapter",
"state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenAdapterStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin")
"model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d",
"model_name": "flux_dit",
"model_class": "diffsynth.models.flux_dit.FluxDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
},
{
# Example: ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin")
"model_hash": "4daaa66cc656a8fe369908693dad0a35",
"model_name": "flux_ipadapter",
"model_class": "diffsynth.models.flux_ipadapter.FluxIpAdapter",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.FluxIpAdapterStateDictConverter",
},
{
# Example: ModelConfig(model_id="google/siglip-so400m-patch14-384", origin_file_pattern="model.safetensors")
"model_hash": "04d8c1e20a1f1b25f7434f111992a33f",
"model_name": "siglip_vision_model",
"model_class": "diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.SiglipStateDictConverter",
},
{
# Example: ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
"model_hash": "d30fb9e02b1dbf4e509142f05cf7dd50",
"model_name": "step1x_connector",
"model_class": "diffsynth.models.step1x_connector.Qwen2Connector",
"state_dict_converter": "diffsynth.utils.state_dict_converters.step1x_connector.Qwen2ConnectorStateDictConverter",
},
{
# Example: ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
"model_hash": "d30fb9e02b1dbf4e509142f05cf7dd50",
"model_name": "flux_dit",
"model_class": "diffsynth.models.flux_dit.FluxDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
"extra_kwargs": {"disable_guidance_embedder": True},
},
{
# Example: ModelConfig(model_id="MAILAND/majicflus_v1", origin_file_pattern="majicflus_v134.safetensors")
"model_hash": "3394f306c4cbf04334b712bf5aaed95f",
"model_name": "flux_dit",
"model_class": "diffsynth.models.flux_dit.FluxDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
},
]
flux2_series = [
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors")
"model_hash": "28fca3d8e5bf2a2d1271748a773f6757",
"model_name": "flux2_text_encoder",
"model_class": "diffsynth.models.flux2_text_encoder.Flux2TextEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux2_text_encoder.Flux2TextEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors")
"model_hash": "d38e1d5c5aec3b0a11e79327ac6e3b0f",
"model_name": "flux2_dit",
"model_class": "diffsynth.models.flux2_dit.Flux2DiT",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
"model_hash": "c54288e3ee12ca215898840682337b95",
"model_name": "flux2_vae",
"model_class": "diffsynth.models.flux2_vae.Flux2VAE",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors")
"model_hash": "3bde7b817fec8143028b6825a63180df",
"model_name": "flux2_dit",
"model_class": "diffsynth.models.flux2_dit.Flux2DiT",
"extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 7680, "num_attention_heads": 24, "num_layers": 5, "num_single_layers": 20}
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors")
"model_hash": "9195f3ea256fcd0ae6d929c203470754",
"model_name": "z_image_text_encoder",
"model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
"extra_kwargs": {"model_size": "8B"},
"state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors")
"model_hash": "39c6fc48f07bebecedbbaa971ff466c8",
"model_name": "flux2_dit",
"model_class": "diffsynth.models.flux2_dit.Flux2DiT",
"extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 12288, "num_attention_heads": 32, "num_layers": 8, "num_single_layers": 24}
},
]
z_image_series = [
{
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors")
"model_hash": "fc3a8a1247fe185ce116ccbe0e426c28",
"model_name": "z_image_dit",
"model_class": "diffsynth.models.z_image_dit.ZImageDiT",
},
{
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors")
"model_hash": "0f050f62a88876fea6eae0a18dac5a2e",
"model_name": "z_image_text_encoder",
"model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
},
{
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors")
"model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3",
"model_name": "flux_vae_encoder",
"model_class": "diffsynth.models.flux_vae.FluxVAEEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverterDiffusers",
"extra_kwargs": {"use_conv_attention": False},
},
{
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors")
"model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3",
"model_name": "flux_vae_decoder",
"model_class": "diffsynth.models.flux_vae.FluxVAEDecoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers",
"extra_kwargs": {"use_conv_attention": False},
},
{
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors")
"model_hash": "aa3563718e5c3ecde3dfbb020ca61180",
"model_name": "z_image_dit",
"model_class": "diffsynth.models.z_image_dit.ZImageDiT",
"extra_kwargs": {"siglip_feat_dim": 1152},
},
{
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors")
"model_hash": "89d48e420f45cff95115a9f3e698d44a",
"model_name": "siglip_vision_model_428m",
"model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M",
},
{
# Example: ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors")
"model_hash": "1677708d40029ab380a95f6c731a57d7",
"model_name": "z_image_controlnet",
"model_class": "diffsynth.models.z_image_controlnet.ZImageControlNet",
},
{
# Example: ???
"model_hash": "9510cb8cd1dd34ee0e4f111c24905510",
"model_name": "z_image_image2lora_style",
"model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel",
"extra_kwargs": {"compress_dim": 128},
},
]
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series

View File

@@ -1,213 +0,0 @@
flux_general_vram_config = {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.general_modules.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.flux_lora_encoder.LoRALayerBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.flux_lora_patcher.LoraMerger": "diffsynth.core.vram.layers.AutoWrappedModule",
}
VRAM_MANAGEMENT_MODULE_MAPS = {
"diffsynth.models.qwen_image_dit.QwenImageDiT": {
"diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionPatchEmbed": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.qwen_image_vae.QwenImageVAE": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.qwen_image_vae.QwenImageRMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.qwen_image_controlnet.BlockWiseControlBlock": {
"diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
"diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder": {
"transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
"diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder": {
"transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTLayerScale": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTRopePositionEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
"diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
"diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter": {
"diffsynth.models.wan_video_animate_adapter.FaceEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_animate_adapter.EqualLinear": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_animate_adapter.ConvLayer": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_animate_adapter.FusedLeakyReLU": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_animate_adapter.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.wan_video_dit_s2v.WanS2VModel": {
"diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_dit_s2v.WanS2VDiTBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_dit_s2v.CausalAudioEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.wan_video_dit.WanModel": {
"diffsynth.models.wan_video_dit.MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule",
"diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.wan_video_image_encoder.WanImageEncoder": {
"diffsynth.models.wan_video_image_encoder.VisionTransformer": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.wan_video_mot.MotWanModel": {
"diffsynth.models.wan_video_mot.MotWanAttentionBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.wan_video_motion_controller.WanMotionControllerModel": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
"diffsynth.models.wan_video_text_encoder.WanTextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_text_encoder.T5RelativeEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_text_encoder.T5LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.wan_video_vace.VaceWanModel": {
"diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.wan_video_vae.WanVideoVAE": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_vae.RMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_vae.CausalConv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_vae.Upsample": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.wan_video_vae.WanVideoVAE38": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_vae.RMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_vae.CausalConv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_vae.Upsample": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.wav2vec.WanS2VAudioEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.longcat_video_dit.RMSNorm_FP32": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.longcat_video_dit.LayerNorm_FP32": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.flux_dit.FluxDiT": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"diffsynth.models.flux_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip": flux_general_vram_config,
"diffsynth.models.flux_vae.FluxVAEEncoder": flux_general_vram_config,
"diffsynth.models.flux_vae.FluxVAEDecoder": flux_general_vram_config,
"diffsynth.models.flux_controlnet.FluxControlNet": flux_general_vram_config,
"diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector": flux_general_vram_config,
"diffsynth.models.flux_ipadapter.FluxIpAdapter": flux_general_vram_config,
"diffsynth.models.flux_lora_patcher.FluxLoraPatcher": flux_general_vram_config,
"diffsynth.models.step1x_connector.Qwen2Connector": flux_general_vram_config,
"diffsynth.models.flux_lora_encoder.FluxLoRAEncoder": flux_general_vram_config,
"diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.t5.modeling_t5.T5LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.t5.modeling_t5.T5DenseActDense": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.t5.modeling_t5.T5DenseGatedActDense": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M": {
"transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.siglip.modeling_siglip.SiglipEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.MultiheadAttention": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.flux2_dit.Flux2DiT": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.flux2_text_encoder.Flux2TextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.mistral.modeling_mistral.MistralRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.flux2_vae.Flux2VAE": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.z_image_text_encoder.ZImageTextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.z_image_dit.ZImageDiT": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.z_image_controlnet.ZImageControlNet": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
"diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M": {
"transformers.models.siglip2.modeling_siglip2.Siglip2VisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.siglip2.modeling_siglip2.Siglip2MultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
}

View File

@@ -0,0 +1,2 @@
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager
from .processors import Annotator

View File

@@ -0,0 +1,62 @@
import torch
import numpy as np
from .processors import Processor_id
class ControlNetConfigUnit:
def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
self.processor_id = processor_id
self.model_path = model_path
self.scale = scale
class ControlNetUnit:
def __init__(self, processor, model, scale=1.0):
self.processor = processor
self.model = model
self.scale = scale
class MultiControlNetManager:
def __init__(self, controlnet_units=[]):
self.processors = [unit.processor for unit in controlnet_units]
self.models = [unit.model for unit in controlnet_units]
self.scales = [unit.scale for unit in controlnet_units]
def cpu(self):
for model in self.models:
model.cpu()
def to(self, device):
for model in self.models:
model.to(device)
def process_image(self, image, processor_id=None):
if processor_id is None:
processed_image = [processor(image) for processor in self.processors]
else:
processed_image = [self.processors[processor_id](image)]
processed_image = torch.concat([
torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
for image_ in processed_image
], dim=0)
return processed_image
def __call__(
self,
sample, timestep, encoder_hidden_states, conditionings,
tiled=False, tile_size=64, tile_stride=32, **kwargs
):
res_stack = None
for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
res_stack_ = model(
sample, timestep, encoder_hidden_states, conditioning, **kwargs,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
processor_id=processor.processor_id
)
res_stack_ = [res * scale for res in res_stack_]
if res_stack is None:
res_stack = res_stack_
else:
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
return res_stack

View File

@@ -0,0 +1,51 @@
from typing_extensions import Literal, TypeAlias
import warnings
with warnings.catch_warnings():
warnings.simplefilter("ignore")
from controlnet_aux.processor import (
CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector
)
Processor_id: TypeAlias = Literal[
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"
]
class Annotator:
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda'):
if processor_id == "canny":
self.processor = CannyDetector()
elif processor_id == "depth":
self.processor = MidasDetector.from_pretrained(model_path).to(device)
elif processor_id == "softedge":
self.processor = HEDdetector.from_pretrained(model_path).to(device)
elif processor_id == "lineart":
self.processor = LineartDetector.from_pretrained(model_path).to(device)
elif processor_id == "lineart_anime":
self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
elif processor_id == "openpose":
self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
elif processor_id == "tile":
self.processor = None
else:
raise ValueError(f"Unsupported processor_id: {processor_id}")
self.processor_id = processor_id
self.detect_resolution = detect_resolution
def __call__(self, image):
width, height = image.size
if self.processor_id == "openpose":
kwargs = {
"include_body": True,
"include_hand": True,
"include_face": True
}
else:
kwargs = {}
if self.processor is not None:
detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height)
image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs)
image = image.resize((width, height))
return image

View File

@@ -1,6 +0,0 @@
from .attention import *
from .data import *
from .gradient import *
from .loader import *
from .vram import *
from .device import *

View File

@@ -1 +0,0 @@
from .attention import attention_forward

View File

@@ -1,121 +0,0 @@
import torch, os
from einops import rearrange
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_3_AVAILABLE = False
try:
import flash_attn
FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_2_AVAILABLE = False
try:
from sageattention import sageattn
SAGE_ATTN_AVAILABLE = True
except ModuleNotFoundError:
SAGE_ATTN_AVAILABLE = False
try:
import xformers.ops as xops
XFORMERS_AVAILABLE = True
except ModuleNotFoundError:
XFORMERS_AVAILABLE = False
def initialize_attention_priority():
if os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION') is not None:
return os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION').lower()
elif FLASH_ATTN_3_AVAILABLE:
return "flash_attention_3"
elif FLASH_ATTN_2_AVAILABLE:
return "flash_attention_2"
elif SAGE_ATTN_AVAILABLE:
return "sage_attention"
elif XFORMERS_AVAILABLE:
return "xformers"
else:
return "torch"
ATTENTION_IMPLEMENTATION = initialize_attention_priority()
def rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", required_in_pattern="b n s d", dims=None):
dims = {} if dims is None else dims
if q_pattern != required_in_pattern:
q = rearrange(q, f"{q_pattern} -> {required_in_pattern}", **dims)
if k_pattern != required_in_pattern:
k = rearrange(k, f"{k_pattern} -> {required_in_pattern}", **dims)
if v_pattern != required_in_pattern:
v = rearrange(v, f"{q_pattern} -> {required_in_pattern}", **dims)
return q, k, v
def rearrange_out(out: torch.Tensor, out_pattern="b n s d", required_out_pattern="b n s d", dims=None):
dims = {} if dims is None else dims
if out_pattern != required_out_pattern:
out = rearrange(out, f"{required_out_pattern} -> {out_pattern}", **dims)
return out
def torch_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None):
required_in_pattern, required_out_pattern= "b n s d", "b n s d"
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask, scale=scale)
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
return out
def flash_attention_3(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
required_in_pattern, required_out_pattern= "b s n d", "b s n d"
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
out = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=scale)
if isinstance(out, tuple):
out = out[0]
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
return out
def flash_attention_2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
required_in_pattern, required_out_pattern= "b s n d", "b s n d"
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
out = flash_attn.flash_attn_func(q, k, v, softmax_scale=scale)
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
return out
def sage_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
required_in_pattern, required_out_pattern= "b n s d", "b n s d"
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
out = sageattn(q, k, v, sm_scale=scale)
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
return out
def xformers_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
required_in_pattern, required_out_pattern= "b s n d", "b s n d"
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
out = xops.memory_efficient_attention(q, k, v, scale=scale)
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
return out
def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None, compatibility_mode=False):
if compatibility_mode or (attn_mask is not None):
return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, attn_mask=attn_mask, scale=scale)
else:
if ATTENTION_IMPLEMENTATION == "flash_attention_3":
return flash_attention_3(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
elif ATTENTION_IMPLEMENTATION == "flash_attention_2":
return flash_attention_2(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
elif ATTENTION_IMPLEMENTATION == "sage_attention":
return sage_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
elif ATTENTION_IMPLEMENTATION == "xformers":
return xformers_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
else:
return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)

View File

@@ -1 +0,0 @@
from .unified_dataset import UnifiedDataset

View File

@@ -1,220 +0,0 @@
import torch, torchvision, imageio, os
import imageio.v3 as iio
from PIL import Image
class DataProcessingPipeline:
def __init__(self, operators=None):
self.operators: list[DataProcessingOperator] = [] if operators is None else operators
def __call__(self, data):
for operator in self.operators:
data = operator(data)
return data
def __rshift__(self, pipe):
if isinstance(pipe, DataProcessingOperator):
pipe = DataProcessingPipeline([pipe])
return DataProcessingPipeline(self.operators + pipe.operators)
class DataProcessingOperator:
def __call__(self, data):
raise NotImplementedError("DataProcessingOperator cannot be called directly.")
def __rshift__(self, pipe):
if isinstance(pipe, DataProcessingOperator):
pipe = DataProcessingPipeline([pipe])
return DataProcessingPipeline([self]).__rshift__(pipe)
class DataProcessingOperatorRaw(DataProcessingOperator):
def __call__(self, data):
return data
class ToInt(DataProcessingOperator):
def __call__(self, data):
return int(data)
class ToFloat(DataProcessingOperator):
def __call__(self, data):
return float(data)
class ToStr(DataProcessingOperator):
def __init__(self, none_value=""):
self.none_value = none_value
def __call__(self, data):
if data is None: data = self.none_value
return str(data)
class LoadImage(DataProcessingOperator):
def __init__(self, convert_RGB=True, convert_RGBA=False):
self.convert_RGB = convert_RGB
self.convert_RGBA = convert_RGBA
def __call__(self, data: str):
image = Image.open(data)
if self.convert_RGB: image = image.convert("RGB")
if self.convert_RGBA: image = image.convert("RGBA")
return image
class ImageCropAndResize(DataProcessingOperator):
def __init__(self, height=None, width=None, max_pixels=None, height_division_factor=1, width_division_factor=1):
self.height = height
self.width = width
self.max_pixels = max_pixels
self.height_division_factor = height_division_factor
self.width_division_factor = width_division_factor
def crop_and_resize(self, image, target_height, target_width):
width, height = image.size
scale = max(target_width / width, target_height / height)
image = torchvision.transforms.functional.resize(
image,
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
return image
def get_height_width(self, image):
if self.height is None or self.width is None:
width, height = image.size
if width * height > self.max_pixels:
scale = (width * height / self.max_pixels) ** 0.5
height, width = int(height / scale), int(width / scale)
height = height // self.height_division_factor * self.height_division_factor
width = width // self.width_division_factor * self.width_division_factor
else:
height, width = self.height, self.width
return height, width
def __call__(self, data: Image.Image):
image = self.crop_and_resize(data, *self.get_height_width(data))
return image
class ToList(DataProcessingOperator):
def __call__(self, data):
return [data]
class LoadVideo(DataProcessingOperator):
def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
self.num_frames = num_frames
self.time_division_factor = time_division_factor
self.time_division_remainder = time_division_remainder
# frame_processor is build in the video loader for high efficiency.
self.frame_processor = frame_processor
def get_num_frames(self, reader):
num_frames = self.num_frames
if int(reader.count_frames()) < num_frames:
num_frames = int(reader.count_frames())
while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
num_frames -= 1
return num_frames
def __call__(self, data: str):
reader = imageio.get_reader(data)
num_frames = self.get_num_frames(reader)
frames = []
for frame_id in range(num_frames):
frame = reader.get_data(frame_id)
frame = Image.fromarray(frame)
frame = self.frame_processor(frame)
frames.append(frame)
reader.close()
return frames
class SequencialProcess(DataProcessingOperator):
def __init__(self, operator=lambda x: x):
self.operator = operator
def __call__(self, data):
return [self.operator(i) for i in data]
class LoadGIF(DataProcessingOperator):
def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
self.num_frames = num_frames
self.time_division_factor = time_division_factor
self.time_division_remainder = time_division_remainder
# frame_processor is build in the video loader for high efficiency.
self.frame_processor = frame_processor
def get_num_frames(self, path):
num_frames = self.num_frames
images = iio.imread(path, mode="RGB")
if len(images) < num_frames:
num_frames = len(images)
while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
num_frames -= 1
return num_frames
def __call__(self, data: str):
num_frames = self.get_num_frames(data)
frames = []
images = iio.imread(data, mode="RGB")
for img in images:
frame = Image.fromarray(img)
frame = self.frame_processor(frame)
frames.append(frame)
if len(frames) >= num_frames:
break
return frames
class RouteByExtensionName(DataProcessingOperator):
def __init__(self, operator_map):
self.operator_map = operator_map
def __call__(self, data: str):
file_ext_name = data.split(".")[-1].lower()
for ext_names, operator in self.operator_map:
if ext_names is None or file_ext_name in ext_names:
return operator(data)
raise ValueError(f"Unsupported file: {data}")
class RouteByType(DataProcessingOperator):
def __init__(self, operator_map):
self.operator_map = operator_map
def __call__(self, data):
for dtype, operator in self.operator_map:
if dtype is None or isinstance(data, dtype):
return operator(data)
raise ValueError(f"Unsupported data: {data}")
class LoadTorchPickle(DataProcessingOperator):
def __init__(self, map_location="cpu"):
self.map_location = map_location
def __call__(self, data):
return torch.load(data, map_location=self.map_location, weights_only=False)
class ToAbsolutePath(DataProcessingOperator):
def __init__(self, base_path=""):
self.base_path = base_path
def __call__(self, data):
return os.path.join(self.base_path, data)
class LoadAudio(DataProcessingOperator):
def __init__(self, sr=16000):
self.sr = sr
def __call__(self, data: str):
import librosa
input_audio, sample_rate = librosa.load(data, sr=self.sr)
return input_audio

View File

@@ -1,116 +0,0 @@
from .operators import *
import torch, json, pandas
class UnifiedDataset(torch.utils.data.Dataset):
def __init__(
self,
base_path=None, metadata_path=None,
repeat=1,
data_file_keys=tuple(),
main_data_operator=lambda x: x,
special_operator_map=None,
max_data_items=None,
):
self.base_path = base_path
self.metadata_path = metadata_path
self.repeat = repeat
self.data_file_keys = data_file_keys
self.main_data_operator = main_data_operator
self.cached_data_operator = LoadTorchPickle()
self.special_operator_map = {} if special_operator_map is None else special_operator_map
self.max_data_items = max_data_items
self.data = []
self.cached_data = []
self.load_from_cache = metadata_path is None
self.load_metadata(metadata_path)
@staticmethod
def default_image_operator(
base_path="",
max_pixels=1920*1080, height=None, width=None,
height_division_factor=16, width_division_factor=16,
):
return RouteByType(operator_map=[
(str, ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)),
(list, SequencialProcess(ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor))),
])
@staticmethod
def default_video_operator(
base_path="",
max_pixels=1920*1080, height=None, width=None,
height_division_factor=16, width_division_factor=16,
num_frames=81, time_division_factor=4, time_division_remainder=1,
):
return RouteByType(operator_map=[
(str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[
(("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()),
(("gif",), LoadGIF(
num_frames, time_division_factor, time_division_remainder,
frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
)),
(("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo(
num_frames, time_division_factor, time_division_remainder,
frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
)),
])),
])
def search_for_cached_data_files(self, path):
for file_name in os.listdir(path):
subpath = os.path.join(path, file_name)
if os.path.isdir(subpath):
self.search_for_cached_data_files(subpath)
elif subpath.endswith(".pth"):
self.cached_data.append(subpath)
def load_metadata(self, metadata_path):
if metadata_path is None:
print("No metadata_path. Searching for cached data files.")
self.search_for_cached_data_files(self.base_path)
print(f"{len(self.cached_data)} cached data files found.")
elif metadata_path.endswith(".json"):
with open(metadata_path, "r") as f:
metadata = json.load(f)
self.data = metadata
elif metadata_path.endswith(".jsonl"):
metadata = []
with open(metadata_path, 'r') as f:
for line in f:
metadata.append(json.loads(line.strip()))
self.data = metadata
else:
metadata = pandas.read_csv(metadata_path)
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
def __getitem__(self, data_id):
if self.load_from_cache:
data = self.cached_data[data_id % len(self.cached_data)]
data = self.cached_data_operator(data)
else:
data = self.data[data_id % len(self.data)].copy()
for key in self.data_file_keys:
if key in data:
if key in self.special_operator_map:
data[key] = self.special_operator_map[key](data[key])
elif key in self.data_file_keys:
data[key] = self.main_data_operator(data[key])
return data
def __len__(self):
if self.max_data_items is not None:
return self.max_data_items
elif self.load_from_cache:
return len(self.cached_data) * self.repeat
else:
return len(self.data) * self.repeat
def check_data_equal(self, data1, data2):
# Debug only
if len(data1) != len(data2):
return False
for k in data1:
if data1[k] != data2[k]:
return False
return True

View File

@@ -1,2 +0,0 @@
from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type, get_device_name
from .npu_compatible_device import IS_NPU_AVAILABLE, IS_CUDA_AVAILABLE

View File

@@ -1,107 +0,0 @@
import importlib
import torch
from typing import Any
def is_torch_npu_available():
return importlib.util.find_spec("torch_npu") is not None
IS_CUDA_AVAILABLE = torch.cuda.is_available()
IS_NPU_AVAILABLE = is_torch_npu_available() and torch.npu.is_available()
if IS_NPU_AVAILABLE:
import torch_npu
torch.npu.config.allow_internal_format = False
def get_device_type() -> str:
"""Get device type based on current machine, currently only support CPU, CUDA, NPU."""
if IS_CUDA_AVAILABLE:
device = "cuda"
elif IS_NPU_AVAILABLE:
device = "npu"
else:
device = "cpu"
return device
def get_torch_device() -> Any:
"""Get torch attribute based on device type, e.g. torch.cuda or torch.npu"""
device_name = get_device_type()
try:
return getattr(torch, device_name)
except AttributeError:
print(f"Device namespace '{device_name}' not found in torch, try to load 'torch.cuda'.")
return torch.cuda
def get_device_id() -> int:
"""Get current device id based on device type."""
return get_torch_device().current_device()
def get_device_name() -> str:
"""Get current device name based on device type."""
return f"{get_device_type()}:{get_device_id()}"
def synchronize() -> None:
"""Execute torch synchronize operation."""
get_torch_device().synchronize()
def empty_cache() -> None:
"""Execute torch empty cache operation."""
get_torch_device().empty_cache()
def get_nccl_backend() -> str:
"""Return distributed communication backend type based on device type."""
if IS_CUDA_AVAILABLE:
return "nccl"
elif IS_NPU_AVAILABLE:
return "hccl"
else:
raise RuntimeError(f"No available distributed communication backend found on device type {get_device_type()}.")
def enable_high_precision_for_bf16():
"""
Set high accumulation dtype for matmul and reduction.
"""
if IS_CUDA_AVAILABLE:
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
if IS_NPU_AVAILABLE:
torch.npu.matmul.allow_tf32 = False
torch.npu.matmul.allow_bf16_reduced_precision_reduction = False
def parse_device_type(device):
if isinstance(device, str):
if device.startswith("cuda"):
return "cuda"
elif device.startswith("npu"):
return "npu"
else:
return "cpu"
elif isinstance(device, torch.device):
return device.type
def parse_nccl_backend(device_type):
if device_type == "cuda":
return "nccl"
elif device_type == "npu":
return "hccl"
else:
raise RuntimeError(f"No available distributed communication backend found on device type {device_type}.")
def get_available_device_type():
return get_device_type()

View File

@@ -1 +0,0 @@
from .gradient_checkpoint import gradient_checkpoint_forward

View File

@@ -1,34 +0,0 @@
import torch
def create_custom_forward(module):
def custom_forward(*inputs, **kwargs):
return module(*inputs, **kwargs)
return custom_forward
def gradient_checkpoint_forward(
model,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
*args,
**kwargs,
):
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
model_output = torch.utils.checkpoint.checkpoint(
create_custom_forward(model),
*args,
**kwargs,
use_reentrant=False,
)
elif use_gradient_checkpointing:
model_output = torch.utils.checkpoint.checkpoint(
create_custom_forward(model),
*args,
**kwargs,
use_reentrant=False,
)
else:
model_output = model(*args, **kwargs)
return model_output

View File

@@ -1,3 +0,0 @@
from .file import load_state_dict, hash_state_dict_keys, hash_model_file
from .model import load_model, load_model_with_disk_offload
from .config import ModelConfig

View File

@@ -1,118 +0,0 @@
import torch, glob, os
from typing import Optional, Union
from dataclasses import dataclass
from modelscope import snapshot_download
from huggingface_hub import snapshot_download as hf_snapshot_download
from typing import Optional
@dataclass
class ModelConfig:
path: Union[str, list[str]] = None
model_id: str = None
origin_file_pattern: Union[str, list[str]] = None
download_source: str = None
local_model_path: str = None
skip_download: bool = None
offload_device: Optional[Union[str, torch.device]] = None
offload_dtype: Optional[torch.dtype] = None
onload_device: Optional[Union[str, torch.device]] = None
onload_dtype: Optional[torch.dtype] = None
preparing_device: Optional[Union[str, torch.device]] = None
preparing_dtype: Optional[torch.dtype] = None
computation_device: Optional[Union[str, torch.device]] = None
computation_dtype: Optional[torch.dtype] = None
clear_parameters: bool = False
def check_input(self):
if self.path is None and self.model_id is None:
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""")
def parse_original_file_pattern(self):
if self.origin_file_pattern is None or self.origin_file_pattern == "":
return "*"
elif self.origin_file_pattern.endswith("/"):
return self.origin_file_pattern + "*"
else:
return self.origin_file_pattern
def parse_download_source(self):
if self.download_source is None:
if os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE') is not None:
return os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE')
else:
return "modelscope"
else:
return self.download_source
def parse_skip_download(self):
if self.skip_download is None:
if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD') is not None:
if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "true":
return True
elif os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "false":
return False
else:
return False
else:
return self.skip_download
def download(self):
origin_file_pattern = self.parse_original_file_pattern()
downloaded_files = glob.glob(origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
download_source = self.parse_download_source()
if download_source.lower() == "modelscope":
snapshot_download(
self.model_id,
local_dir=os.path.join(self.local_model_path, self.model_id),
allow_file_pattern=origin_file_pattern,
ignore_file_pattern=downloaded_files,
local_files_only=False
)
elif download_source.lower() == "huggingface":
hf_snapshot_download(
self.model_id,
local_dir=os.path.join(self.local_model_path, self.model_id),
allow_patterns=origin_file_pattern,
ignore_patterns=downloaded_files,
local_files_only=False
)
else:
raise ValueError("`download_source` should be `modelscope` or `huggingface`.")
def require_downloading(self):
if self.path is not None:
return False
skip_download = self.parse_skip_download()
return not skip_download
def reset_local_model_path(self):
if os.environ.get('DIFFSYNTH_MODEL_BASE_PATH') is not None:
self.local_model_path = os.environ.get('DIFFSYNTH_MODEL_BASE_PATH')
elif self.local_model_path is None:
self.local_model_path = "./models"
def download_if_necessary(self):
self.check_input()
self.reset_local_model_path()
if self.require_downloading():
self.download()
if self.path is None:
if self.origin_file_pattern is None or self.origin_file_pattern == "":
self.path = os.path.join(self.local_model_path, self.model_id)
else:
self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
if isinstance(self.path, list) and len(self.path) == 1:
self.path = self.path[0]
def vram_config(self):
return {
"offload_device": self.offload_device,
"offload_dtype": self.offload_dtype,
"onload_device": self.onload_device,
"onload_dtype": self.onload_dtype,
"preparing_device": self.preparing_device,
"preparing_dtype": self.preparing_dtype,
"computation_device": self.computation_device,
"computation_dtype": self.computation_dtype,
}

View File

@@ -1,121 +0,0 @@
from safetensors import safe_open
import torch, hashlib
def load_state_dict(file_path, torch_dtype=None, device="cpu"):
if isinstance(file_path, list):
state_dict = {}
for file_path_ in file_path:
state_dict.update(load_state_dict(file_path_, torch_dtype, device))
return state_dict
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
state_dict = {}
with safe_open(file_path, framework="pt", device=str(device)) as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
if torch_dtype is not None:
state_dict[k] = state_dict[k].to(torch_dtype)
return state_dict
def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
state_dict = torch.load(file_path, map_location=device, weights_only=True)
if len(state_dict) == 1:
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
elif "module" in state_dict:
state_dict = state_dict["module"]
elif "model_state" in state_dict:
state_dict = state_dict["model_state"]
if torch_dtype is not None:
for i in state_dict:
if isinstance(state_dict[i], torch.Tensor):
state_dict[i] = state_dict[i].to(torch_dtype)
return state_dict
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
keys = []
for key, value in state_dict.items():
if isinstance(key, str):
if isinstance(value, torch.Tensor):
if with_shape:
shape = "_".join(map(str, list(value.shape)))
keys.append(key + ":" + shape)
keys.append(key)
elif isinstance(value, dict):
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
keys.sort()
keys_str = ",".join(keys)
return keys_str
def hash_state_dict_keys(state_dict, with_shape=True):
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
keys_str = keys_str.encode(encoding="UTF-8")
return hashlib.md5(keys_str).hexdigest()
def load_keys_dict(file_path):
if isinstance(file_path, list):
state_dict = {}
for file_path_ in file_path:
state_dict.update(load_keys_dict(file_path_))
return state_dict
if file_path.endswith(".safetensors"):
return load_keys_dict_from_safetensors(file_path)
else:
return load_keys_dict_from_bin(file_path)
def load_keys_dict_from_safetensors(file_path):
keys_dict = {}
with safe_open(file_path, framework="pt", device="cpu") as f:
for k in f.keys():
keys_dict[k] = f.get_slice(k).get_shape()
return keys_dict
def convert_state_dict_to_keys_dict(state_dict):
keys_dict = {}
for k, v in state_dict.items():
if isinstance(v, torch.Tensor):
keys_dict[k] = list(v.shape)
else:
keys_dict[k] = convert_state_dict_to_keys_dict(v)
return keys_dict
def load_keys_dict_from_bin(file_path):
state_dict = load_state_dict_from_bin(file_path)
keys_dict = convert_state_dict_to_keys_dict(state_dict)
return keys_dict
def convert_keys_dict_to_single_str(state_dict, with_shape=True):
keys = []
for key, value in state_dict.items():
if isinstance(key, str):
if isinstance(value, dict):
keys.append(key + "|" + convert_keys_dict_to_single_str(value, with_shape=with_shape))
else:
if with_shape:
shape = "_".join(map(str, list(value)))
keys.append(key + ":" + shape)
keys.append(key)
keys.sort()
keys_str = ",".join(keys)
return keys_str
def hash_model_file(path, with_shape=True):
keys_dict = load_keys_dict(path)
keys_str = convert_keys_dict_to_single_str(keys_dict, with_shape=with_shape)
keys_str = keys_str.encode(encoding="UTF-8")
return hashlib.md5(keys_str).hexdigest()

View File

@@ -1,102 +0,0 @@
from ..vram.initialization import skip_model_initialization
from ..vram.disk_map import DiskMap
from ..vram.layers import enable_vram_management
from .file import load_state_dict
import torch
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils import ContextManagers
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
config = {} if config is None else config
with ContextManagers(get_init_context(torch_dtype=torch_dtype, device=device)):
model = model_class(**config)
# What is `module_map`?
# This is a module mapping table for VRAM management.
if module_map is not None:
devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"], vram_config["computation_device"]]
device = [d for d in devices if d != "disk"][0]
dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
dtype = [d for d in dtypes if d != "disk"][0]
if vram_config["offload_device"] != "disk":
state_dict = DiskMap(path, device, torch_dtype=dtype)
if state_dict_converter is not None:
state_dict = state_dict_converter(state_dict)
else:
state_dict = {i: state_dict[i] for i in state_dict}
model.load_state_dict(state_dict, assign=True)
model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None, vram_limit=vram_limit)
else:
disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=vram_limit)
else:
# Why do we use `DiskMap`?
# Sometimes a model file contains multiple models,
# and DiskMap can load only the parameters of a single model,
# avoiding the need to load all parameters in the file.
if use_disk_map:
state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
else:
state_dict = load_state_dict(path, torch_dtype, device)
# Why do we use `state_dict_converter`?
# Some models are saved in complex formats,
# and we need to convert the state dict into the appropriate format.
if state_dict_converter is not None:
state_dict = state_dict_converter(state_dict)
else:
state_dict = {i: state_dict[i] for i in state_dict}
# Why does DeepSpeed ZeRO Stage 3 need to be handled separately?
# Because at this stage, model parameters are partitioned across multiple GPUs.
# Loading them directly could lead to excessive GPU memory consumption.
if is_deepspeed_zero3_enabled():
from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model
_load_state_dict_into_zero3_model(model, state_dict)
else:
model.load_state_dict(state_dict, assign=True)
# Why do we call `to()`?
# Because some models override the behavior of `to()`,
# especially those from libraries like Transformers.
model = model.to(dtype=torch_dtype, device=device)
if hasattr(model, "eval"):
model = model.eval()
return model
def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, module_map=None):
if isinstance(path, str):
path = [path]
config = {} if config is None else config
with skip_model_initialization():
model = model_class(**config)
if hasattr(model, "eval"):
model = model.eval()
disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": "disk",
"onload_device": "disk",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": device,
"computation_dtype": torch_dtype,
"computation_device": device,
}
enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80)
return model
def get_init_context(torch_dtype, device):
if is_deepspeed_zero3_enabled():
from transformers.modeling_utils import set_zero3_state
import deepspeed
# Why do we use "deepspeed.zero.Init"?
# Weight segmentation of the model can be performed on the CPU side
# and loading the segmented weights onto the computing card
init_contexts = [deepspeed.zero.Init(remote_device=device, dtype=torch_dtype), set_zero3_state()]
else:
# Why do we use `skip_model_initialization`?
# It skips the random initialization of model parameters,
# thereby speeding up model loading and avoiding excessive memory usage.
init_contexts = [skip_model_initialization()]
return init_contexts

View File

@@ -1,2 +0,0 @@
from .initialization import skip_model_initialization
from .layers import *

View File

@@ -1,93 +0,0 @@
from safetensors import safe_open
import torch, os
class SafetensorsCompatibleTensor:
def __init__(self, tensor):
self.tensor = tensor
def get_shape(self):
return list(self.tensor.shape)
class SafetensorsCompatibleBinaryLoader:
def __init__(self, path, device):
print("Detected non-safetensors files, which may cause slower loading. It's recommended to convert it to a safetensors file.")
self.state_dict = torch.load(path, weights_only=True, map_location=device)
def keys(self):
return self.state_dict.keys()
def get_tensor(self, name):
return self.state_dict[name]
def get_slice(self, name):
return SafetensorsCompatibleTensor(self.state_dict[name])
class DiskMap:
def __init__(self, path, device, torch_dtype=None, state_dict_converter=None, buffer_size=10**9):
self.path = path if isinstance(path, list) else [path]
self.device = device
self.torch_dtype = torch_dtype
if os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE') is not None:
self.buffer_size = int(os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE'))
else:
self.buffer_size = buffer_size
self.files = []
self.flush_files()
self.name_map = {}
for file_id, file in enumerate(self.files):
for name in file.keys():
self.name_map[name] = file_id
self.rename_dict = self.fetch_rename_dict(state_dict_converter)
def flush_files(self):
if len(self.files) == 0:
for path in self.path:
if path.endswith(".safetensors"):
self.files.append(safe_open(path, framework="pt", device=str(self.device)))
else:
self.files.append(SafetensorsCompatibleBinaryLoader(path, device=self.device))
else:
for i, path in enumerate(self.path):
if path.endswith(".safetensors"):
self.files[i] = safe_open(path, framework="pt", device=str(self.device))
self.num_params = 0
def __getitem__(self, name):
if self.rename_dict is not None: name = self.rename_dict[name]
file_id = self.name_map[name]
param = self.files[file_id].get_tensor(name)
if self.torch_dtype is not None and isinstance(param, torch.Tensor):
param = param.to(self.torch_dtype)
if isinstance(param, torch.Tensor) and param.device == "cpu":
param = param.clone()
if isinstance(param, torch.Tensor):
self.num_params += param.numel()
if self.num_params > self.buffer_size:
self.flush_files()
return param
def fetch_rename_dict(self, state_dict_converter):
if state_dict_converter is None:
return None
state_dict = {}
for file in self.files:
for name in file.keys():
state_dict[name] = name
state_dict = state_dict_converter(state_dict)
return state_dict
def __iter__(self):
if self.rename_dict is not None:
return self.rename_dict.__iter__()
else:
return self.name_map.__iter__()
def __contains__(self, x):
if self.rename_dict is not None:
return x in self.rename_dict
else:
return x in self.name_map

View File

@@ -1,21 +0,0 @@
import torch
from contextlib import contextmanager
@contextmanager
def skip_model_initialization(device=torch.device("meta")):
def register_empty_parameter(module, name, param):
old_register_parameter(module, name, param)
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
kwargs["requires_grad"] = param.requires_grad
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
old_register_parameter = torch.nn.Module.register_parameter
torch.nn.Module.register_parameter = register_empty_parameter
try:
yield
finally:
torch.nn.Module.register_parameter = old_register_parameter

View File

@@ -1,479 +0,0 @@
import torch, copy
from typing import Union
from .initialization import skip_model_initialization
from .disk_map import DiskMap
from ..device import parse_device_type, get_device_name, IS_NPU_AVAILABLE
class AutoTorchModule(torch.nn.Module):
def __init__(
self,
offload_dtype: torch.dtype = None,
offload_device: Union[str, torch.device] = None,
onload_dtype: torch.dtype = None,
onload_device: Union[str, torch.device] = None,
preparing_dtype: torch.dtype = None,
preparing_device: Union[str, torch.device] = None,
computation_dtype: torch.dtype = None,
computation_device: Union[str, torch.device] = None,
vram_limit: float = None,
):
super().__init__()
self.set_dtype_and_device(
offload_dtype,
offload_device,
onload_dtype,
onload_device,
preparing_dtype,
preparing_device,
computation_dtype,
computation_device,
vram_limit,
)
self.state = 0
self.name = ""
self.computation_device_type = parse_device_type(self.computation_device)
def set_dtype_and_device(
self,
offload_dtype: torch.dtype = None,
offload_device: Union[str, torch.device] = None,
onload_dtype: torch.dtype = None,
onload_device: Union[str, torch.device] = None,
preparing_dtype: torch.dtype = None,
preparing_device: Union[str, torch.device] = None,
computation_dtype: torch.dtype = None,
computation_device: Union[str, torch.device] = None,
vram_limit: float = None,
):
self.offload_dtype = offload_dtype or computation_dtype
self.offload_device = offload_device or computation_dtype
self.onload_dtype = onload_dtype or computation_dtype
self.onload_device = onload_device or computation_dtype
self.preparing_dtype = preparing_dtype or computation_dtype
self.preparing_device = preparing_device or computation_dtype
self.computation_dtype = computation_dtype
self.computation_device = computation_device
self.vram_limit = vram_limit
def cast_to(self, weight, dtype, device):
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight)
return r
def check_free_vram(self):
device = self.computation_device if not IS_NPU_AVAILABLE else get_device_name()
gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(device)
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
return used_memory < self.vram_limit
def offload(self):
if self.state != 0:
self.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0
def onload(self):
if self.state != 1:
self.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
def param_name(self, name):
if self.name == "":
return name
else:
return self.name + "." + name
class AutoWrappedModule(AutoTorchModule):
def __init__(
self,
module: torch.nn.Module,
offload_dtype: torch.dtype = None,
offload_device: Union[str, torch.device] = None,
onload_dtype: torch.dtype = None,
onload_device: Union[str, torch.device] = None,
preparing_dtype: torch.dtype = None,
preparing_device: Union[str, torch.device] = None,
computation_dtype: torch.dtype = None,
computation_device: Union[str, torch.device] = None,
vram_limit: float = None,
name: str = "",
disk_map: DiskMap = None,
**kwargs
):
super().__init__(
offload_dtype,
offload_device,
onload_dtype,
onload_device,
preparing_dtype,
preparing_device,
computation_dtype,
computation_device,
vram_limit,
)
self.module = module
if offload_dtype == "disk":
self.name = name
self.disk_map = disk_map
self.required_params = [name for name, _ in self.module.named_parameters()]
self.disk_offload = True
else:
self.disk_offload = False
def load_from_disk(self, torch_dtype, device, copy_module=False):
if copy_module:
module = copy.deepcopy(self.module)
else:
module = self.module
state_dict = {}
for name in self.required_params:
param = self.disk_map[self.param_name(name)]
param = param.to(dtype=torch_dtype, device=device)
state_dict[name] = param
module.load_state_dict(state_dict, assign=True)
module.to(dtype=torch_dtype, device=device)
return module
def offload_to_disk(self, model: torch.nn.Module):
for buf in model.buffers():
# If there are some parameters are registed in buffers (not in state dict),
# We cannot offload the model.
for children in model.children():
self.offload_to_disk(children)
break
else:
model.to("meta")
def offload(self):
# offload / onload / preparing -> offload
if self.state != 0:
if self.disk_offload:
self.offload_to_disk(self.module)
else:
self.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0
def onload(self):
# offload / onload / preparing -> onload
if self.state < 1:
if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk":
self.load_from_disk(self.onload_dtype, self.onload_device)
elif self.onload_device != "disk":
self.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
def preparing(self):
# onload / preparing -> preparing
if self.state != 2:
if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk":
self.load_from_disk(self.preparing_dtype, self.preparing_device)
elif self.preparing_device != "disk":
self.to(dtype=self.preparing_dtype, device=self.preparing_device)
self.state = 2
def cast_to(self, module, dtype, device):
return copy.deepcopy(module).to(dtype=dtype, device=device)
def computation(self):
# onload / preparing -> computation (temporary)
if self.state == 2:
torch_dtype, device = self.preparing_dtype, self.preparing_device
else:
torch_dtype, device = self.onload_dtype, self.onload_device
if torch_dtype == self.computation_dtype and device == self.computation_device:
module = self.module
elif self.disk_offload and device == "disk":
module = self.load_from_disk(self.computation_dtype, self.computation_device, copy_module=True)
else:
module = self.cast_to(self.module, dtype=self.computation_dtype, device=self.computation_device)
return module
def forward(self, *args, **kwargs):
if self.state == 1 and (self.vram_limit is None or self.check_free_vram()):
self.preparing()
module = self.computation()
return module(*args, **kwargs)
def __getattr__(self, name):
if name in self.__dict__ or name == "module":
return super().__getattr__(name)
else:
return getattr(self.module, name)
class AutoWrappedNonRecurseModule(AutoWrappedModule):
def __init__(
self,
module: torch.nn.Module,
offload_dtype: torch.dtype = None,
offload_device: Union[str, torch.device] = None,
onload_dtype: torch.dtype = None,
onload_device: Union[str, torch.device] = None,
preparing_dtype: torch.dtype = None,
preparing_device: Union[str, torch.device] = None,
computation_dtype: torch.dtype = None,
computation_device: Union[str, torch.device] = None,
vram_limit: float = None,
name: str = "",
disk_map: DiskMap = None,
**kwargs
):
super().__init__(
module,
offload_dtype,
offload_device,
onload_dtype,
onload_device,
preparing_dtype,
preparing_device,
computation_dtype,
computation_device,
vram_limit,
name,
disk_map,
**kwargs
)
if self.disk_offload:
self.required_params = [name for name, _ in self.module.named_parameters(recurse=False)]
def load_from_disk(self, torch_dtype, device, copy_module=False):
if copy_module:
module = copy.deepcopy(self.module)
else:
module = self.module
state_dict = {}
for name in self.required_params:
param = self.disk_map[self.param_name(name)]
param = param.to(dtype=torch_dtype, device=device)
state_dict[name] = param
module.load_state_dict(state_dict, assign=True, strict=False)
return module
def offload_to_disk(self, model: torch.nn.Module):
for name in self.required_params:
getattr(self, name).to("meta")
def cast_to(self, module, dtype, device):
# Parameter casting is implemented in the model architecture.
return module
def __getattr__(self, name):
if name in self.__dict__ or name == "module":
return super().__getattr__(name)
else:
return getattr(self.module, name)
class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
def __init__(
self,
module: torch.nn.Linear,
offload_dtype: torch.dtype = None,
offload_device: Union[str, torch.device] = None,
onload_dtype: torch.dtype = None,
onload_device: Union[str, torch.device] = None,
preparing_dtype: torch.dtype = None,
preparing_device: Union[str, torch.device] = None,
computation_dtype: torch.dtype = None,
computation_device: Union[str, torch.device] = None,
vram_limit: float = None,
name: str = "",
disk_map: DiskMap = None,
**kwargs
):
with skip_model_initialization():
super().__init__(
in_features=module.in_features,
out_features=module.out_features,
bias=module.bias is not None,
)
self.set_dtype_and_device(
offload_dtype,
offload_device,
onload_dtype,
onload_device,
preparing_dtype,
preparing_device,
computation_dtype,
computation_device,
vram_limit,
)
self.weight = module.weight
self.bias = module.bias
self.state = 0
self.name = name
self.lora_A_weights = []
self.lora_B_weights = []
self.lora_merger = None
self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
self.computation_device_type = parse_device_type(self.computation_device)
if offload_dtype == "disk":
self.disk_map = disk_map
self.disk_offload = True
else:
self.disk_offload = False
def fp8_linear(
self,
input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor = None,
) -> torch.Tensor:
device = input.device
origin_dtype = input.dtype
origin_shape = input.shape
input = input.reshape(-1, origin_shape[-1])
x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
fp8_max = 448.0
# For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
# To avoid overflow and ensure numerical compatibility during FP8 computation,
# we scale down the input by 2.0 in advance.
# This scaling will be compensated later during the final result scaling.
if self.computation_dtype == torch.float8_e4m3fnuz:
fp8_max = fp8_max / 2.0
scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
scale_b = torch.ones((weight.shape[0], 1)).to(device=device)
input = input / (scale_a + 1e-8)
input = input.to(self.computation_dtype)
weight = weight.to(self.computation_dtype)
bias = bias.to(torch.bfloat16)
result = torch._scaled_mm(
input,
weight.T,
scale_a=scale_a,
scale_b=scale_b.T,
bias=bias,
out_dtype=origin_dtype,
)
new_shape = origin_shape[:-1] + result.shape[-1:]
result = result.reshape(new_shape)
return result
def load_from_disk(self, torch_dtype, device, assign=True):
weight = self.disk_map[self.name + ".weight"].to(dtype=torch_dtype, device=device)
bias = None if self.bias is None else self.disk_map[self.name + ".bias"].to(dtype=torch_dtype, device=device)
if assign:
state_dict = {"weight": weight}
if bias is not None: state_dict["bias"] = bias
self.load_state_dict(state_dict, assign=True)
return weight, bias
def offload(self):
# offload / onload / preparing -> offload
if self.state != 0:
if self.disk_offload:
self.to("meta")
else:
self.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0
def onload(self):
# offload / onload / preparing -> onload
if self.state < 1:
if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk":
self.load_from_disk(self.onload_dtype, self.onload_device)
elif self.onload_device != "disk":
self.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
def preparing(self):
# onload / preparing -> preparing
if self.state != 2:
if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk":
self.load_from_disk(self.preparing_dtype, self.preparing_device)
elif self.preparing_device != "disk":
self.to(dtype=self.preparing_dtype, device=self.preparing_device)
self.state = 2
def computation(self):
# onload / preparing -> computation (temporary)
if self.state == 2:
torch_dtype, device = self.preparing_dtype, self.preparing_device
else:
torch_dtype, device = self.onload_dtype, self.onload_device
if torch_dtype == self.computation_dtype and device == self.computation_device:
weight, bias = self.weight, self.bias
elif self.disk_offload and device == "disk":
weight, bias = self.load_from_disk(self.computation_dtype, self.computation_device, assign=False)
else:
weight = self.cast_to(self.weight, self.computation_dtype, self.computation_device)
bias = None if self.bias is None else self.cast_to(self.bias, self.computation_dtype, self.computation_device)
return weight, bias
def linear_forward(self, x, weight, bias):
if self.enable_fp8:
out = self.fp8_linear(x, weight, bias)
else:
out = torch.nn.functional.linear(x, weight, bias)
return out
def lora_forward(self, x, out):
if self.lora_merger is None:
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
out = out + x @ lora_A.T @ lora_B.T
else:
lora_output = []
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
lora_output.append(x @ lora_A.T @ lora_B.T)
lora_output = torch.stack(lora_output)
out = self.lora_merger(out, lora_output)
return out
def forward(self, x, *args, **kwargs):
if self.state == 1 and (self.vram_limit is None or self.check_free_vram()):
self.preparing()
weight, bias = self.computation()
out = self.linear_forward(x, weight, bias)
if len(self.lora_A_weights) > 0:
out = self.lora_forward(x, out)
return out
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, name_prefix="", disk_map=None, **kwargs):
if isinstance(model, AutoWrappedNonRecurseModule):
model = model.module
for name, module in model.named_children():
layer_name = name if name_prefix == "" else name_prefix + "." + name
for source_module, target_module in module_map.items():
if isinstance(module, source_module):
module_ = target_module(module, **vram_config, vram_limit=vram_limit, name=layer_name, disk_map=disk_map, **kwargs)
if isinstance(module_, AutoWrappedNonRecurseModule):
enable_vram_management_recursively(module_, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)
setattr(model, name, module_)
break
else:
enable_vram_management_recursively(module, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)
def fill_vram_config(model, vram_config):
vram_config_ = vram_config.copy()
vram_config_["onload_dtype"] = vram_config["computation_dtype"]
vram_config_["onload_device"] = vram_config["computation_device"]
vram_config_["preparing_dtype"] = vram_config["computation_dtype"]
vram_config_["preparing_device"] = vram_config["computation_device"]
for k in vram_config:
if vram_config[k] != vram_config_[k]:
print(f"No fine-grained VRAM configuration is provided for {model.__class__.__name__}. [`onload`, `preparing`, `computation`] will be the same state. `vram_config` is set to {vram_config_}")
break
return vram_config_
def enable_vram_management(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, disk_map=None, **kwargs):
for source_module, target_module in module_map.items():
# If no fine-grained VRAM configuration is provided, the entire model will be managed uniformly.
if isinstance(model, source_module):
vram_config = fill_vram_config(model, vram_config)
model = target_module(model, **vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs)
break
else:
enable_vram_management_recursively(model, module_map, vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs)
# `vram_management_enabled` is a flag that allows the pipeline to determine whether VRAM management is enabled.
model.vram_management_enabled = True
return model

View File

@@ -0,0 +1 @@
from .video import VideoData, save_video, save_frames

View File

@@ -0,0 +1,35 @@
import torch, os
from torchvision import transforms
import pandas as pd
from PIL import Image
class TextImageDataset(torch.utils.data.Dataset):
def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
self.steps_per_epoch = steps_per_epoch
metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
self.text = metadata["text"].to_list()
self.image_processor = transforms.Compose(
[
transforms.Resize(max(height, width), interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def __getitem__(self, index):
data_id = torch.randint(0, len(self.path), (1,))[0]
data_id = (data_id + index) % len(self.path) # For fixed seed.
text = self.text[data_id]
image = Image.open(self.path[data_id]).convert("RGB")
image = self.image_processor(image)
return {"text": text, "image": image}
def __len__(self):
return self.steps_per_epoch

View File

@@ -2,8 +2,6 @@ import imageio, os
import numpy as np
from PIL import Image
from tqdm import tqdm
import subprocess
import shutil
class LowMemoryVideo:
@@ -137,8 +135,8 @@ class VideoData:
frame.save(os.path.join(folder, f"{i}.png"))
def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params)
def save_video(frames, save_path, fps, quality=9):
writer = imageio.get_writer(save_path, fps=fps, quality=quality)
for frame in tqdm(frames, desc="Saving video"):
frame = np.array(frame)
writer.append_data(frame)
@@ -148,70 +146,3 @@ def save_frames(frames, save_path):
os.makedirs(save_path, exist_ok=True)
for i, frame in enumerate(tqdm(frames, desc="Saving images")):
frame.save(os.path.join(save_path, f"{i}.png"))
def merge_video_audio(video_path: str, audio_path: str):
# TODO: may need a in-python implementation to avoid subprocess dependency
"""
Merge the video and audio into a new video, with the duration set to the shorter of the two,
and overwrite the original video file.
Parameters:
video_path (str): Path to the original video file
audio_path (str): Path to the audio file
"""
# check
if not os.path.exists(video_path):
raise FileNotFoundError(f"video file {video_path} does not exist")
if not os.path.exists(audio_path):
raise FileNotFoundError(f"audio file {audio_path} does not exist")
base, ext = os.path.splitext(video_path)
temp_output = f"{base}_temp{ext}"
try:
# create ffmpeg command
command = [
'ffmpeg',
'-y', # overwrite
'-i',
video_path,
'-i',
audio_path,
'-c:v',
'copy', # copy video stream
'-c:a',
'aac', # use AAC audio encoder
'-b:a',
'192k', # set audio bitrate (optional)
'-map',
'0:v:0', # select the first video stream
'-map',
'1:a:0', # select the first audio stream
'-shortest', # choose the shortest duration
temp_output
]
# execute the command
result = subprocess.run(
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
# check result
if result.returncode != 0:
error_msg = f"FFmpeg execute failed: {result.stderr}"
print(error_msg)
raise RuntimeError(error_msg)
shutil.move(temp_output, video_path)
print(f"Merge completed, saved to {video_path}")
except Exception as e:
if os.path.exists(temp_output):
os.remove(temp_output)
print(f"merge_video_audio failed with error: {e}")
def save_video_with_audio(frames, save_path, audio_path, fps=16, quality=9, ffmpeg_params=None):
save_video(frames, save_path, fps, quality, ffmpeg_params)
merge_video_audio(save_path, audio_path)

View File

@@ -1,6 +0,0 @@
from .flow_match import FlowMatchScheduler
from .training_module import DiffusionTrainingModule
from .logger import ModelLogger
from .runner import launch_training_task, launch_data_process_task
from .parsers import *
from .loss import *

View File

@@ -1,451 +0,0 @@
from PIL import Image
import torch
import numpy as np
from einops import repeat, reduce
from typing import Union
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
from ..core.device.npu_compatible_device import get_device_type
from ..utils.lora import GeneralLoRALoader
from ..models.model_loader import ModelPool
from ..utils.controlnet import ControlNetInput
from ..core.device import get_device_name, IS_NPU_AVAILABLE
class PipelineUnit:
def __init__(
self,
seperate_cfg: bool = False,
take_over: bool = False,
input_params: tuple[str] = None,
output_params: tuple[str] = None,
input_params_posi: dict[str, str] = None,
input_params_nega: dict[str, str] = None,
onload_model_names: tuple[str] = None
):
self.seperate_cfg = seperate_cfg
self.take_over = take_over
self.input_params = input_params
self.output_params = output_params
self.input_params_posi = input_params_posi
self.input_params_nega = input_params_nega
self.onload_model_names = onload_model_names
def fetch_input_params(self):
params = []
if self.input_params is not None:
for param in self.input_params:
params.append(param)
if self.input_params_posi is not None:
for _, param in self.input_params_posi.items():
params.append(param)
if self.input_params_nega is not None:
for _, param in self.input_params_nega.items():
params.append(param)
params = sorted(list(set(params)))
return params
def fetch_output_params(self):
params = []
if self.output_params is not None:
for param in self.output_params:
params.append(param)
return params
def process(self, pipe, **kwargs) -> dict:
return {}
def post_process(self, pipe, **kwargs) -> dict:
return {}
class BasePipeline(torch.nn.Module):
def __init__(
self,
device=get_device_type(), torch_dtype=torch.float16,
height_division_factor=64, width_division_factor=64,
time_division_factor=None, time_division_remainder=None,
):
super().__init__()
# The device and torch_dtype is used for the storage of intermediate variables, not models.
self.device = device
self.torch_dtype = torch_dtype
self.device_type = parse_device_type(device)
# The following parameters are used for shape check.
self.height_division_factor = height_division_factor
self.width_division_factor = width_division_factor
self.time_division_factor = time_division_factor
self.time_division_remainder = time_division_remainder
# VRAM management
self.vram_management_enabled = False
# Pipeline Unit Runner
self.unit_runner = PipelineUnitRunner()
# LoRA Loader
self.lora_loader = GeneralLoRALoader
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if device is not None:
self.device = device
if dtype is not None:
self.torch_dtype = dtype
super().to(*args, **kwargs)
return self
def check_resize_height_width(self, height, width, num_frames=None):
# Shape check
if height % self.height_division_factor != 0:
height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
if width % self.width_division_factor != 0:
width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
if num_frames is None:
return height, width
else:
if num_frames % self.time_division_factor != self.time_division_remainder:
num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
return height, width, num_frames
def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1):
# Transform a PIL.Image to torch.Tensor
image = torch.Tensor(np.array(image, dtype=np.float32))
image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
image = image * ((max_value - min_value) / 255) + min_value
image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {}))
return image
def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1):
# Transform a list of PIL.Image to torch.Tensor
video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video]
video = torch.stack(video, dim=pattern.index("T") // 2)
return video
def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1):
# Transform a torch.Tensor to PIL.Image
if pattern != "H W C":
vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean")
image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255)
image = image.to(device="cpu", dtype=torch.uint8)
image = Image.fromarray(image.numpy())
return image
def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1):
# Transform a torch.Tensor to list of PIL.Image
if pattern != "T H W C":
vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean")
video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
return video
def load_models_to_device(self, model_names):
if self.vram_management_enabled:
# offload models
for name, model in self.named_children():
if name not in model_names:
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
if hasattr(model, "offload"):
model.offload()
else:
for module in model.modules():
if hasattr(module, "offload"):
module.offload()
getattr(torch, self.device_type).empty_cache()
# onload models
for name, model in self.named_children():
if name in model_names:
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
if hasattr(model, "onload"):
model.onload()
else:
for module in model.modules():
if hasattr(module, "onload"):
module.onload()
def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None):
# Initialize Gaussian noise
generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed)
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype)
noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
return noise
def get_vram(self):
device = self.device if not IS_NPU_AVAILABLE else get_device_name()
return getattr(torch, self.device_type).mem_get_info(device)[1] / (1024 ** 3)
def get_module(self, model, name):
if "." in name:
name, suffix = name[:name.index(".")], name[name.index(".") + 1:]
if name.isdigit():
return self.get_module(model[int(name)], suffix)
else:
return self.get_module(getattr(model, name), suffix)
else:
return getattr(model, name)
def freeze_except(self, model_names):
self.eval()
self.requires_grad_(False)
for name in model_names:
module = self.get_module(self, name)
if module is None:
print(f"No {name} models in the pipeline. We cannot enable training on the model. If this occurs during the data processing stage, it is normal.")
continue
module.train()
module.requires_grad_(True)
def blend_with_mask(self, base, addition, mask):
return base * (1 - mask) + addition * mask
def step(self, scheduler, latents, progress_id, noise_pred, input_latents=None, inpaint_mask=None, **kwargs):
timestep = scheduler.timesteps[progress_id]
if inpaint_mask is not None:
noise_pred_expected = scheduler.return_to_timestep(scheduler.timesteps[progress_id], latents, input_latents)
noise_pred = self.blend_with_mask(noise_pred_expected, noise_pred, inpaint_mask)
latents_next = scheduler.step(noise_pred, timestep, latents)
return latents_next
def split_pipeline_units(self, model_names: list[str]):
return PipelineUnitGraph().split_pipeline_units(self.units, model_names)
def flush_vram_management_device(self, device):
for module in self.modules():
if isinstance(module, AutoTorchModule):
module.offload_device = device
module.onload_device = device
module.preparing_device = device
module.computation_device = device
def load_lora(
self,
module: torch.nn.Module,
lora_config: Union[ModelConfig, str] = None,
alpha=1,
hotload=None,
state_dict=None,
verbose=1,
):
if state_dict is None:
if isinstance(lora_config, str):
lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
else:
lora_config.download_if_necessary()
lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
else:
lora = state_dict
lora_loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device)
lora = lora_loader.convert_state_dict(lora)
if hotload is None:
hotload = hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")
if hotload:
if not (hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")):
raise ValueError("VRAM Management is not enabled. LoRA hotloading is not supported.")
updated_num = 0
for _, module in module.named_modules():
if isinstance(module, AutoWrappedLinear):
name = module.name
lora_a_name = f'{name}.lora_A.weight'
lora_b_name = f'{name}.lora_B.weight'
if lora_a_name in lora and lora_b_name in lora:
updated_num += 1
module.lora_A_weights.append(lora[lora_a_name] * alpha)
module.lora_B_weights.append(lora[lora_b_name])
if verbose >= 1:
print(f"{updated_num} tensors are patched by LoRA. You can use `pipe.clear_lora()` to clear all LoRA layers.")
else:
lora_loader.fuse_lora_to_base_model(module, lora, alpha=alpha)
def clear_lora(self, verbose=1):
cleared_num = 0
for name, module in self.named_modules():
if isinstance(module, AutoWrappedLinear):
if hasattr(module, "lora_A_weights"):
if len(module.lora_A_weights) > 0:
cleared_num += 1
module.lora_A_weights.clear()
if hasattr(module, "lora_B_weights"):
module.lora_B_weights.clear()
if verbose >= 1:
print(f"{cleared_num} LoRA layers are cleared.")
def download_and_load_models(self, model_configs: list[ModelConfig] = [], vram_limit: float = None):
model_pool = ModelPool()
for model_config in model_configs:
model_config.download_if_necessary()
vram_config = model_config.vram_config()
vram_config["computation_dtype"] = vram_config["computation_dtype"] or self.torch_dtype
vram_config["computation_device"] = vram_config["computation_device"] or self.device
model_pool.auto_load_model(
model_config.path,
vram_config=vram_config,
vram_limit=vram_limit,
clear_parameters=model_config.clear_parameters,
)
return model_pool
def check_vram_management_state(self):
vram_management_enabled = False
for module in self.children():
if hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled"):
vram_management_enabled = True
return vram_management_enabled
def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
if inputs_shared.get("positive_only_lora", None) is not None:
self.clear_lora(verbose=0)
self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
if cfg_scale != 1.0:
if inputs_shared.get("positive_only_lora", None) is not None:
self.clear_lora(verbose=0)
noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
return noise_pred
class PipelineUnitGraph:
def __init__(self):
pass
def build_edges(self, units: list[PipelineUnit]):
# Establish dependencies between units
# to search for subsequent related computation units.
last_compute_unit_id = {}
edges = []
for unit_id, unit in enumerate(units):
for input_param in unit.fetch_input_params():
if input_param in last_compute_unit_id:
edges.append((last_compute_unit_id[input_param], unit_id))
for output_param in unit.fetch_output_params():
last_compute_unit_id[output_param] = unit_id
return edges
def build_chains(self, units: list[PipelineUnit]):
# Establish updating chains for each variable
# to track their computation process.
params = sum([unit.fetch_input_params() + unit.fetch_output_params() for unit in units], [])
params = sorted(list(set(params)))
chains = {param: [] for param in params}
for unit_id, unit in enumerate(units):
for param in unit.fetch_output_params():
chains[param].append(unit_id)
return chains
def search_direct_unit_ids(self, units: list[PipelineUnit], model_names: list[str]):
# Search for units that directly participate in the model's computation.
related_unit_ids = []
for unit_id, unit in enumerate(units):
for model_name in model_names:
if unit.onload_model_names is not None and model_name in unit.onload_model_names:
related_unit_ids.append(unit_id)
break
return related_unit_ids
def search_related_unit_ids(self, edges, start_unit_ids, direction="target"):
# Search for subsequent related computation units.
related_unit_ids = [unit_id for unit_id in start_unit_ids]
while True:
neighbors = []
for source, target in edges:
if direction == "target" and source in related_unit_ids and target not in related_unit_ids:
neighbors.append(target)
elif direction == "source" and source not in related_unit_ids and target in related_unit_ids:
neighbors.append(source)
neighbors = sorted(list(set(neighbors)))
if len(neighbors) == 0:
break
else:
related_unit_ids.extend(neighbors)
related_unit_ids = sorted(list(set(related_unit_ids)))
return related_unit_ids
def search_updating_unit_ids(self, units: list[PipelineUnit], chains, related_unit_ids):
# If the input parameters of this subgraph are updated outside the subgraph,
# search for the units where these updates occur.
first_compute_unit_id = {}
for unit_id in related_unit_ids:
for param in units[unit_id].fetch_input_params():
if param not in first_compute_unit_id:
first_compute_unit_id[param] = unit_id
updating_unit_ids = []
for param in first_compute_unit_id:
unit_id = first_compute_unit_id[param]
chain = chains[param]
if unit_id in chain and chain.index(unit_id) != len(chain) - 1:
for unit_id_ in chain[chain.index(unit_id) + 1:]:
if unit_id_ not in related_unit_ids:
updating_unit_ids.append(unit_id_)
related_unit_ids.extend(updating_unit_ids)
related_unit_ids = sorted(list(set(related_unit_ids)))
return related_unit_ids
def split_pipeline_units(self, units: list[PipelineUnit], model_names: list[str]):
# Split the computation graph,
# separating all model-related computations.
related_unit_ids = self.search_direct_unit_ids(units, model_names)
edges = self.build_edges(units)
chains = self.build_chains(units)
while True:
num_related_unit_ids = len(related_unit_ids)
related_unit_ids = self.search_related_unit_ids(edges, related_unit_ids, "target")
related_unit_ids = self.search_updating_unit_ids(units, chains, related_unit_ids)
if len(related_unit_ids) == num_related_unit_ids:
break
else:
num_related_unit_ids = len(related_unit_ids)
related_units = [units[i] for i in related_unit_ids]
unrelated_units = [units[i] for i in range(len(units)) if i not in related_unit_ids]
return related_units, unrelated_units
class PipelineUnitRunner:
def __init__(self):
pass
def __call__(self, unit: PipelineUnit, pipe: BasePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict]:
if unit.take_over:
# Let the pipeline unit take over this function.
inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega)
elif unit.seperate_cfg:
# Positive side
processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()}
if unit.input_params is not None:
for name in unit.input_params:
processor_inputs[name] = inputs_shared.get(name)
processor_outputs = unit.process(pipe, **processor_inputs)
inputs_posi.update(processor_outputs)
# Negative side
if inputs_shared["cfg_scale"] != 1:
processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()}
if unit.input_params is not None:
for name in unit.input_params:
processor_inputs[name] = inputs_shared.get(name)
processor_outputs = unit.process(pipe, **processor_inputs)
inputs_nega.update(processor_outputs)
else:
inputs_nega.update(processor_outputs)
else:
processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params}
processor_outputs = unit.process(pipe, **processor_inputs)
inputs_shared.update(processor_outputs)
return inputs_shared, inputs_posi, inputs_nega

View File

@@ -1,184 +0,0 @@
import torch, math
from typing_extensions import Literal
class FlowMatchScheduler():
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image"] = "FLUX.1"):
self.set_timesteps_fn = {
"FLUX.1": FlowMatchScheduler.set_timesteps_flux,
"Wan": FlowMatchScheduler.set_timesteps_wan,
"Qwen-Image": FlowMatchScheduler.set_timesteps_qwen_image,
"FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
"Z-Image": FlowMatchScheduler.set_timesteps_z_image,
}.get(template, FlowMatchScheduler.set_timesteps_flux)
self.num_train_timesteps = 1000
@staticmethod
def set_timesteps_flux(num_inference_steps=100, denoising_strength=1.0, shift=None):
sigma_min = 0.003/1.002
sigma_max = 1.0
shift = 3 if shift is None else shift
num_train_timesteps = 1000
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod
def set_timesteps_wan(num_inference_steps=100, denoising_strength=1.0, shift=None):
sigma_min = 0.0
sigma_max = 1.0
shift = 5 if shift is None else shift
num_train_timesteps = 1000
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod
def _calculate_shift_qwen_image(image_seq_len, base_seq_len=256, max_seq_len=8192, base_shift=0.5, max_shift=0.9):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
@staticmethod
def set_timesteps_qwen_image(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None):
sigma_min = 0.0
sigma_max = 1.0
num_train_timesteps = 1000
shift_terminal = 0.02
# Sigmas
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
# Mu
if exponential_shift_mu is not None:
mu = exponential_shift_mu
elif dynamic_shift_len is not None:
mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len)
else:
mu = 0.8
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
# Shift terminal
one_minus_z = 1 - sigmas
scale_factor = one_minus_z[-1] / (1 - shift_terminal)
sigmas = 1 - (one_minus_z / scale_factor)
# Timesteps
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod
def compute_empirical_mu(image_seq_len, num_steps):
a1, b1 = 8.73809524e-05, 1.89833333
a2, b2 = 0.00016927, 0.45666666
if image_seq_len > 4300:
mu = a2 * image_seq_len + b2
return float(mu)
m_200 = a2 * image_seq_len + b2
m_10 = a1 * image_seq_len + b1
a = (m_200 - m_10) / 190.0
b = m_200 - 200.0 * a
mu = a * num_steps + b
return float(mu)
@staticmethod
def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None):
sigma_min = 1 / num_inference_steps
sigma_max = 1.0
num_train_timesteps = 1000
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
if dynamic_shift_len is None:
# If you ask me why I set mu=0.8,
# I can only say that it yields better training results.
mu = 0.8
else:
mu = FlowMatchScheduler.compute_empirical_mu(dynamic_shift_len, num_inference_steps)
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod
def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None):
sigma_min = 0.0
sigma_max = 1.0
shift = 3 if shift is None else shift
num_train_timesteps = 1000
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
timesteps = sigmas * num_train_timesteps
if target_timesteps is not None:
target_timesteps = target_timesteps.to(dtype=timesteps.dtype, device=timesteps.device)
for timestep in target_timesteps:
timestep_id = torch.argmin((timesteps - timestep).abs())
timesteps[timestep_id] = timestep
return sigmas, timesteps
def set_training_weight(self):
steps = 1000
x = self.timesteps
y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
y_shifted = y - y.min()
bsmntw_weighing = y_shifted * (steps / y_shifted.sum())
if len(self.timesteps) != 1000:
# This is an empirical formula.
bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps)
bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1]
self.linear_timesteps_weights = bsmntw_weighing
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
self.sigmas, self.timesteps = self.set_timesteps_fn(
num_inference_steps=num_inference_steps,
denoising_strength=denoising_strength,
**kwargs,
)
if training:
self.set_training_weight()
self.training = True
else:
self.training = False
def step(self, model_output, timestep, sample, to_final=False, **kwargs):
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id]
if to_final or timestep_id + 1 >= len(self.timesteps):
sigma_ = 0
else:
sigma_ = self.sigmas[timestep_id + 1]
prev_sample = sample + model_output * (sigma_ - sigma)
return prev_sample
def return_to_timestep(self, timestep, sample, sample_stablized):
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id]
model_output = (sample - sample_stablized) / sigma
return model_output
def add_noise(self, original_samples, noise, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id]
sample = (1 - sigma) * original_samples + sigma * noise
return sample
def training_target(self, sample, noise, timestep):
target = noise - sample
return target
def training_weight(self, timestep):
timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
weights = self.linear_timesteps_weights[timestep_id]
return weights

View File

@@ -1,43 +0,0 @@
import os, torch
from accelerate import Accelerator
class ModelLogger:
def __init__(self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x:x):
self.output_path = output_path
self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
self.state_dict_converter = state_dict_converter
self.num_steps = 0
def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None, **kwargs):
self.num_steps += 1
if save_steps is not None and self.num_steps % save_steps == 0:
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id):
accelerator.wait_for_everyone()
state_dict = accelerator.get_state_dict(model)
if accelerator.is_main_process:
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
state_dict = self.state_dict_converter(state_dict)
os.makedirs(self.output_path, exist_ok=True)
path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
accelerator.save(state_dict, path, safe_serialization=True)
def on_training_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None):
if save_steps is not None and self.num_steps % save_steps != 0:
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
def save_model(self, accelerator: Accelerator, model: torch.nn.Module, file_name):
accelerator.wait_for_everyone()
state_dict = accelerator.get_state_dict(model)
if accelerator.is_main_process:
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
state_dict = self.state_dict_converter(state_dict)
os.makedirs(self.output_path, exist_ok=True)
path = os.path.join(self.output_path, file_name)
accelerator.save(state_dict, path, safe_serialization=True)

View File

@@ -1,119 +0,0 @@
from .base_pipeline import BasePipeline
import torch
def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
noise = torch.randn_like(inputs["input_latents"])
inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * pipe.scheduler.training_weight(timestep)
return loss
def DirectDistillLoss(pipe: BasePipeline, **inputs):
pipe.scheduler.set_timesteps(inputs["num_inference_steps"])
pipe.scheduler.training = True
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep, progress_id=progress_id)
inputs["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs)
loss = torch.nn.functional.mse_loss(inputs["latents"].float(), inputs["input_latents"].float())
return loss
class TrajectoryImitationLoss(torch.nn.Module):
def __init__(self):
super().__init__()
self.initialized = False
def initialize(self, device):
import lpips # TODO: remove it
self.loss_fn = lpips.LPIPS(net='alex').to(device)
self.initialized = True
def fetch_trajectory(self, pipe: BasePipeline, timesteps_student, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
trajectory = [inputs_shared["latents"].clone()]
pipe.scheduler.set_timesteps(num_inference_steps, target_timesteps=timesteps_student)
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
noise_pred = pipe.cfg_guided_model_fn(
pipe.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)
trajectory.append(inputs_shared["latents"].clone())
return pipe.scheduler.timesteps, trajectory
def align_trajectory(self, pipe: BasePipeline, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
loss = 0
pipe.scheduler.set_timesteps(num_inference_steps, training=True)
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
progress_id_teacher = torch.argmin((timesteps_teacher - timestep).abs())
inputs_shared["latents"] = trajectory_teacher[progress_id_teacher]
noise_pred = pipe.cfg_guided_model_fn(
pipe.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
sigma = pipe.scheduler.sigmas[progress_id]
sigma_ = 0 if progress_id + 1 >= len(pipe.scheduler.timesteps) else pipe.scheduler.sigmas[progress_id + 1]
if progress_id + 1 >= len(pipe.scheduler.timesteps):
latents_ = trajectory_teacher[-1]
else:
progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs())
latents_ = trajectory_teacher[progress_id_teacher]
target = (latents_ - inputs_shared["latents"]) / (sigma_ - sigma)
loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep)
return loss
def compute_regularization(self, pipe: BasePipeline, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
inputs_shared["latents"] = trajectory_teacher[0]
pipe.scheduler.set_timesteps(num_inference_steps)
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
noise_pred = pipe.cfg_guided_model_fn(
pipe.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)
image_pred = pipe.vae_decoder(inputs_shared["latents"])
image_real = pipe.vae_decoder(trajectory_teacher[-1])
loss = self.loss_fn(image_pred.float(), image_real.float())
return loss
def forward(self, pipe: BasePipeline, inputs_shared, inputs_posi, inputs_nega):
if not self.initialized:
self.initialize(pipe.device)
with torch.no_grad():
pipe.scheduler.set_timesteps(8)
timesteps_teacher, trajectory_teacher = self.fetch_trajectory(inputs_shared["teacher"], pipe.scheduler.timesteps, inputs_shared, inputs_posi, inputs_nega, 50, 2)
timesteps_teacher = timesteps_teacher.to(dtype=pipe.torch_dtype, device=pipe.device)
loss_1 = self.align_trajectory(pipe, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
loss_2 = self.compute_regularization(pipe, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
loss = loss_1 + loss_2
return loss

View File

@@ -1,70 +0,0 @@
import argparse
def add_dataset_base_config(parser: argparse.ArgumentParser):
parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.")
parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.")
parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.")
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in the metadata. Comma-separated.")
return parser
def add_image_size_config(parser: argparse.ArgumentParser):
parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.")
return parser
def add_video_size_config(parser: argparse.ArgumentParser):
parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.")
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.")
return parser
def add_model_config(parser: argparse.ArgumentParser):
parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.")
parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
parser.add_argument("--fp8_models", default=None, help="Models with FP8 precision, comma-separated.")
parser.add_argument("--offload_models", default=None, help="Models with offload, comma-separated. Only used in splited training.")
return parser
def add_training_config(parser: argparse.ArgumentParser):
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.")
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
return parser
def add_output_config(parser: argparse.ArgumentParser):
parser.add_argument("--output_path", type=str, default="./models", help="Output save path.")
parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.")
parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
return parser
def add_lora_config(parser: argparse.ArgumentParser):
parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
parser.add_argument("--lora_checkpoint", type=str, default=None, help="Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.")
parser.add_argument("--preset_lora_path", type=str, default=None, help="Path to the preset LoRA checkpoint. If provided, this LoRA will be fused to the base model.")
parser.add_argument("--preset_lora_model", type=str, default=None, help="Which model the preset LoRA is fused to.")
return parser
def add_gradient_config(parser: argparse.ArgumentParser):
parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
return parser
def add_general_config(parser: argparse.ArgumentParser):
parser = add_dataset_base_config(parser)
parser = add_model_config(parser)
parser = add_training_config(parser)
parser = add_output_config(parser)
parser = add_lora_config(parser)
parser = add_gradient_config(parser)
return parser

View File

@@ -1,72 +0,0 @@
import os, torch
from tqdm import tqdm
from accelerate import Accelerator
from .training_module import DiffusionTrainingModule
from .logger import ModelLogger
def launch_training_task(
accelerator: Accelerator,
dataset: torch.utils.data.Dataset,
model: DiffusionTrainingModule,
model_logger: ModelLogger,
learning_rate: float = 1e-5,
weight_decay: float = 1e-2,
num_workers: int = 1,
save_steps: int = None,
num_epochs: int = 1,
args = None,
):
if args is not None:
learning_rate = args.learning_rate
weight_decay = args.weight_decay
num_workers = args.dataset_num_workers
save_steps = args.save_steps
num_epochs = args.num_epochs
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
model.to(device=accelerator.device)
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
for epoch_id in range(num_epochs):
for data in tqdm(dataloader):
with accelerator.accumulate(model):
optimizer.zero_grad()
if dataset.load_from_cache:
loss = model({}, inputs=data)
else:
loss = model(data)
accelerator.backward(loss)
optimizer.step()
model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
scheduler.step()
if save_steps is None:
model_logger.on_epoch_end(accelerator, model, epoch_id)
model_logger.on_training_end(accelerator, model, save_steps)
def launch_data_process_task(
accelerator: Accelerator,
dataset: torch.utils.data.Dataset,
model: DiffusionTrainingModule,
model_logger: ModelLogger,
num_workers: int = 8,
args = None,
):
if args is not None:
num_workers = args.dataset_num_workers
dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
model.to(device=accelerator.device)
model, dataloader = accelerator.prepare(model, dataloader)
for data_id, data in enumerate(tqdm(dataloader)):
with accelerator.accumulate(model):
with torch.no_grad():
folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
os.makedirs(folder, exist_ok=True)
save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f"{data_id}.pth")
data = model(data)
torch.save(data, save_path)

View File

@@ -1,263 +0,0 @@
import torch, json, os
from ..core import ModelConfig, load_state_dict
from ..utils.controlnet import ControlNetInput
from peft import LoraConfig, inject_adapter_in_model
class DiffusionTrainingModule(torch.nn.Module):
def __init__(self):
super().__init__()
def to(self, *args, **kwargs):
for name, model in self.named_children():
model.to(*args, **kwargs)
return self
def trainable_modules(self):
trainable_modules = filter(lambda p: p.requires_grad, self.parameters())
return trainable_modules
def trainable_param_names(self):
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.named_parameters()))
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
return trainable_param_names
def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None, upcast_dtype=None):
if lora_alpha is None:
lora_alpha = lora_rank
if isinstance(target_modules, list) and len(target_modules) == 1:
target_modules = target_modules[0]
lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules)
model = inject_adapter_in_model(lora_config, model)
if upcast_dtype is not None:
for param in model.parameters():
if param.requires_grad:
param.data = param.to(upcast_dtype)
return model
def mapping_lora_state_dict(self, state_dict):
new_state_dict = {}
for key, value in state_dict.items():
if "lora_A.weight" in key or "lora_B.weight" in key:
new_key = key.replace("lora_A.weight", "lora_A.default.weight").replace("lora_B.weight", "lora_B.default.weight")
new_state_dict[new_key] = value
elif "lora_A.default.weight" in key or "lora_B.default.weight" in key:
new_state_dict[key] = value
return new_state_dict
def export_trainable_state_dict(self, state_dict, remove_prefix=None):
trainable_param_names = self.trainable_param_names()
state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names}
if remove_prefix is not None:
state_dict_ = {}
for name, param in state_dict.items():
if name.startswith(remove_prefix):
name = name[len(remove_prefix):]
state_dict_[name] = param
state_dict = state_dict_
return state_dict
def transfer_data_to_device(self, data, device, torch_float_dtype=None):
if data is None:
return data
elif isinstance(data, torch.Tensor):
data = data.to(device)
if torch_float_dtype is not None and data.dtype in [torch.float, torch.float16, torch.bfloat16]:
data = data.to(torch_float_dtype)
return data
elif isinstance(data, tuple):
data = tuple(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data)
return data
elif isinstance(data, list):
data = list(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data)
return data
elif isinstance(data, dict):
data = {i: self.transfer_data_to_device(data[i], device, torch_float_dtype) for i in data}
return data
else:
return data
def parse_vram_config(self, fp8=False, offload=False, device="cpu"):
if fp8:
return {
"offload_dtype": torch.float8_e4m3fn,
"offload_device": device,
"onload_dtype": torch.float8_e4m3fn,
"onload_device": device,
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": device,
"computation_dtype": torch.bfloat16,
"computation_device": device,
}
elif offload:
return {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": "disk",
"onload_device": "disk",
"preparing_dtype": torch.bfloat16,
"preparing_device": device,
"computation_dtype": torch.bfloat16,
"computation_device": device,
"clear_parameters": True,
}
else:
return {}
def parse_model_configs(self, model_paths, model_id_with_origin_paths, fp8_models=None, offload_models=None, device="cpu"):
fp8_models = [] if fp8_models is None else fp8_models.split(",")
offload_models = [] if offload_models is None else offload_models.split(",")
model_configs = []
if model_paths is not None:
model_paths = json.loads(model_paths)
for path in model_paths:
vram_config = self.parse_vram_config(
fp8=path in fp8_models,
offload=path in offload_models,
device=device
)
model_configs.append(ModelConfig(path=path, **vram_config))
if model_id_with_origin_paths is not None:
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
for model_id_with_origin_path in model_id_with_origin_paths:
vram_config = self.parse_vram_config(
fp8=model_id_with_origin_path in fp8_models,
offload=model_id_with_origin_path in offload_models,
device=device
)
config = self.parse_path_or_model_id(model_id_with_origin_path)
model_configs.append(ModelConfig(model_id=config.model_id, origin_file_pattern=config.origin_file_pattern, **vram_config))
return model_configs
def parse_path_or_model_id(self, model_id_with_origin_path, default_value=None):
if model_id_with_origin_path is None:
return default_value
elif os.path.exists(model_id_with_origin_path):
return ModelConfig(path=model_id_with_origin_path)
else:
if ":" not in model_id_with_origin_path:
raise ValueError(f"Failed to parse model config: {model_id_with_origin_path}. This is neither a valid path nor in the format of `model_id/origin_file_pattern`.")
split_id = model_id_with_origin_path.rfind(":")
model_id = model_id_with_origin_path[:split_id]
origin_file_pattern = model_id_with_origin_path[split_id + 1:]
return ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern)
def auto_detect_lora_target_modules(
self,
model: torch.nn.Module,
search_for_linear=False,
linear_detector=lambda x: min(x.weight.shape) >= 512,
block_list_detector=lambda x: isinstance(x, torch.nn.ModuleList) and len(x) > 1,
name_prefix="",
):
lora_target_modules = []
if search_for_linear:
for name, module in model.named_modules():
module_name = name_prefix + ["", "."][name_prefix != ""] + name
if isinstance(module, torch.nn.Linear) and linear_detector(module):
lora_target_modules.append(module_name)
else:
for name, module in model.named_children():
module_name = name_prefix + ["", "."][name_prefix != ""] + name
lora_target_modules += self.auto_detect_lora_target_modules(
module,
search_for_linear=block_list_detector(module),
linear_detector=linear_detector,
block_list_detector=block_list_detector,
name_prefix=module_name,
)
return lora_target_modules
def parse_lora_target_modules(self, model, lora_target_modules):
if lora_target_modules == "":
print("No LoRA target modules specified. The framework will automatically search for them.")
lora_target_modules = self.auto_detect_lora_target_modules(model)
print(f"LoRA will be patched at {lora_target_modules}.")
else:
lora_target_modules = lora_target_modules.split(",")
return lora_target_modules
def switch_pipe_to_training_mode(
self,
pipe,
trainable_models=None,
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
preset_lora_path=None, preset_lora_model=None,
task="sft",
):
# Scheduler
pipe.scheduler.set_timesteps(1000, training=True)
# Freeze untrainable models
pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
# Preset LoRA
if preset_lora_path is not None:
pipe.load_lora(getattr(pipe, preset_lora_model), preset_lora_path)
# FP8
# FP8 relies on a model-specific memory management scheme.
# It is delegated to the subclass.
# Add LoRA to the base models
if lora_base_model is not None and not task.endswith(":data_process"):
if (not hasattr(pipe, lora_base_model)) or getattr(pipe, lora_base_model) is None:
print(f"No {lora_base_model} models in the pipeline. We cannot patch LoRA on the model. If this occurs during the data processing stage, it is normal.")
return
model = self.add_lora_to_model(
getattr(pipe, lora_base_model),
target_modules=self.parse_lora_target_modules(getattr(pipe, lora_base_model), lora_target_modules),
lora_rank=lora_rank,
upcast_dtype=pipe.torch_dtype,
)
if lora_checkpoint is not None:
state_dict = load_state_dict(lora_checkpoint)
state_dict = self.mapping_lora_state_dict(state_dict)
load_result = model.load_state_dict(state_dict, strict=False)
print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
if len(load_result[1]) > 0:
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
setattr(pipe, lora_base_model, model)
def split_pipeline_units(self, task, pipe, trainable_models=None, lora_base_model=None):
models_require_backward = []
if trainable_models is not None:
models_require_backward += trainable_models.split(",")
if lora_base_model is not None:
models_require_backward += [lora_base_model]
if task.endswith(":data_process"):
_, pipe.units = pipe.split_pipeline_units(models_require_backward)
elif task.endswith(":train"):
pipe.units, _ = pipe.split_pipeline_units(models_require_backward)
return pipe
def parse_extra_inputs(self, data, extra_inputs, inputs_shared):
controlnet_keys_map = (
("blockwise_controlnet_", "blockwise_controlnet_inputs",),
("controlnet_", "controlnet_inputs"),
)
controlnet_inputs = {}
for extra_input in extra_inputs:
for prefix, name in controlnet_keys_map:
if extra_input.startswith(prefix):
if name not in controlnet_inputs:
controlnet_inputs[name] = {}
controlnet_inputs[name][extra_input.replace(prefix, "")] = data[extra_input]
break
else:
inputs_shared[extra_input] = data[extra_input]
for name, params in controlnet_inputs.items():
inputs_shared[name] = [ControlNetInput(**params)]
return inputs_shared

View File

@@ -0,0 +1,129 @@
import torch
from einops import repeat
from PIL import Image
import numpy as np
class ResidualDenseBlock(torch.nn.Module):
def __init__(self, num_feat=64, num_grow_ch=32):
super(ResidualDenseBlock, self).__init__()
self.conv1 = torch.nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
self.conv2 = torch.nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
self.conv3 = torch.nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
self.conv4 = torch.nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
self.conv5 = torch.nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
return x5 * 0.2 + x
class RRDB(torch.nn.Module):
def __init__(self, num_feat, num_grow_ch=32):
super(RRDB, self).__init__()
self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
def forward(self, x):
out = self.rdb1(x)
out = self.rdb2(out)
out = self.rdb3(out)
return out * 0.2 + x
class RRDBNet(torch.nn.Module):
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, **kwargs):
super(RRDBNet, self).__init__()
self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
self.body = torch.torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
self.conv_body = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
# upsample
self.conv_up1 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_up2 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_hr = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_last = torch.nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
feat = x
feat = self.conv_first(feat)
body_feat = self.conv_body(self.body(feat))
feat = feat + body_feat
# upsample
feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
feat = self.lrelu(self.conv_up1(feat))
feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
feat = self.lrelu(self.conv_up2(feat))
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
return out
@staticmethod
def state_dict_converter():
return RRDBNetStateDictConverter()
class RRDBNetStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict, {"upcast_to_float32": True}
def from_civitai(self, state_dict):
return state_dict, {"upcast_to_float32": True}
class ESRGAN(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
@staticmethod
def from_model_manager(model_manager):
return ESRGAN(model_manager.fetch_model("esrgan"))
def process_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)
return image
def process_images(self, images):
images = [self.process_image(image) for image in images]
images = torch.stack(images)
return images
def decode_images(self, images):
images = (images.permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
images = [Image.fromarray(image) for image in images]
return images
@torch.no_grad()
def upscale(self, images, batch_size=4, progress_bar=lambda x:x):
# Preprocess
input_tensor = self.process_images(images)
# Interpolate
output_tensor = []
for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
batch_input_tensor = input_tensor[batch_id: batch_id_]
batch_input_tensor = batch_input_tensor.to(
device=self.model.conv_first.weight.device,
dtype=self.model.conv_first.weight.dtype)
batch_output_tensor = self.model(batch_input_tensor)
output_tensor.append(batch_output_tensor.cpu())
# Output
output_tensor = torch.concat(output_tensor, dim=0)
# To images
output_images = self.decode_images(output_tensor)
return output_images

View File

@@ -0,0 +1,63 @@
from .runners.fast import TableManager, PyramidPatchMatcher
from PIL import Image
import numpy as np
import cupy as cp
class FastBlendSmoother:
def __init__(self):
self.batch_size = 8
self.window_size = 64
self.ebsynth_config = {
"minimum_patch_size": 5,
"threads_per_block": 8,
"num_iter": 5,
"gpu_id": 0,
"guide_weight": 10.0,
"initialize": "identity",
"tracking_window_size": 0,
}
@staticmethod
def from_model_manager(model_manager):
# TODO: fetch GPU ID from model_manager
return FastBlendSmoother()
def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config):
frames_guide = [np.array(frame) for frame in frames_guide]
frames_style = [np.array(frame) for frame in frames_style]
table_manager = TableManager()
patch_match_engine = PyramidPatchMatcher(
image_height=frames_style[0].shape[0],
image_width=frames_style[0].shape[1],
channel=3,
**ebsynth_config
)
# left part
table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="FastBlend Step 1/4")
table_l = table_manager.remapping_table_to_blending_table(table_l)
table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="FastBlend Step 2/4")
# right part
table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="FastBlend Step 3/4")
table_r = table_manager.remapping_table_to_blending_table(table_r)
table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="FastBlend Step 4/4")[::-1]
# merge
frames = []
for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
weight_m = -1
weight = weight_l + weight_m + weight_r
frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
frames.append(frame)
frames = [Image.fromarray(frame.clip(0, 255).astype("uint8")) for frame in frames]
return frames
def __call__(self, rendered_frames, original_frames=None, **kwargs):
frames = self.run(
original_frames, rendered_frames,
self.batch_size, self.window_size, self.ebsynth_config
)
mempool = cp.get_default_memory_pool()
pinned_mempool = cp.get_default_pinned_memory_pool()
mempool.free_all_blocks()
pinned_mempool.free_all_blocks()
return frames

View File

@@ -0,0 +1,397 @@
from .runners import AccurateModeRunner, FastModeRunner, BalancedModeRunner, InterpolationModeRunner, InterpolationModeSingleFrameRunner
from .data import VideoData, get_video_fps, save_video, search_for_images
import os
import gradio as gr
def check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder):
frames_guide = VideoData(video_guide, video_guide_folder)
frames_style = VideoData(video_style, video_style_folder)
message = ""
if len(frames_guide) < len(frames_style):
message += f"The number of frames mismatches. Only the first {len(frames_guide)} frames of style video will be used.\n"
frames_style.set_length(len(frames_guide))
elif len(frames_guide) > len(frames_style):
message += f"The number of frames mismatches. Only the first {len(frames_style)} frames of guide video will be used.\n"
frames_guide.set_length(len(frames_style))
height_guide, width_guide = frames_guide.shape()
height_style, width_style = frames_style.shape()
if height_guide != height_style or width_guide != width_style:
message += f"The shape of frames mismatches. The frames in style video will be resized to (height: {height_guide}, width: {width_guide})\n"
frames_style.set_shape(height_guide, width_guide)
return frames_guide, frames_style, message
def smooth_video(
video_guide,
video_guide_folder,
video_style,
video_style_folder,
mode,
window_size,
batch_size,
tracking_window_size,
output_path,
fps,
minimum_patch_size,
num_iter,
guide_weight,
initialize,
progress = None,
):
# input
frames_guide, frames_style, message = check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder)
if len(message) > 0:
print(message)
# output
if output_path == "":
if video_style is None:
output_path = os.path.join(video_style_folder, "output")
else:
output_path = os.path.join(os.path.split(video_style)[0], "output")
os.makedirs(output_path, exist_ok=True)
print("No valid output_path. Your video will be saved here:", output_path)
elif not os.path.exists(output_path):
os.makedirs(output_path, exist_ok=True)
print("Your video will be saved here:", output_path)
frames_path = os.path.join(output_path, "frames")
video_path = os.path.join(output_path, "video.mp4")
os.makedirs(frames_path, exist_ok=True)
# process
if mode == "Fast" or mode == "Balanced":
tracking_window_size = 0
ebsynth_config = {
"minimum_patch_size": minimum_patch_size,
"threads_per_block": 8,
"num_iter": num_iter,
"gpu_id": 0,
"guide_weight": guide_weight,
"initialize": initialize,
"tracking_window_size": tracking_window_size,
}
if mode == "Fast":
FastModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
elif mode == "Balanced":
BalancedModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
elif mode == "Accurate":
AccurateModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
# output
try:
fps = int(fps)
except:
fps = get_video_fps(video_style) if video_style is not None else 30
print("Fps:", fps)
print("Saving video...")
video_path = save_video(frames_path, video_path, num_frames=len(frames_style), fps=fps)
print("Success!")
print("Your frames are here:", frames_path)
print("Your video is here:", video_path)
return output_path, fps, video_path
class KeyFrameMatcher:
def __init__(self):
pass
def extract_number_from_filename(self, file_name):
result = []
number = -1
for i in file_name:
if ord(i)>=ord("0") and ord(i)<=ord("9"):
if number == -1:
number = 0
number = number*10 + ord(i) - ord("0")
else:
if number != -1:
result.append(number)
number = -1
if number != -1:
result.append(number)
result = tuple(result)
return result
def extract_number_from_filenames(self, file_names):
numbers = [self.extract_number_from_filename(file_name) for file_name in file_names]
min_length = min(len(i) for i in numbers)
for i in range(min_length-1, -1, -1):
if len(set(number[i] for number in numbers))==len(file_names):
return [number[i] for number in numbers]
return list(range(len(file_names)))
def match_using_filename(self, file_names_a, file_names_b):
file_names_b_set = set(file_names_b)
matched_file_name = []
for file_name in file_names_a:
if file_name not in file_names_b_set:
matched_file_name.append(None)
else:
matched_file_name.append(file_name)
return matched_file_name
def match_using_numbers(self, file_names_a, file_names_b):
numbers_a = self.extract_number_from_filenames(file_names_a)
numbers_b = self.extract_number_from_filenames(file_names_b)
numbers_b_dict = {number: file_name for number, file_name in zip(numbers_b, file_names_b)}
matched_file_name = []
for number in numbers_a:
if number in numbers_b_dict:
matched_file_name.append(numbers_b_dict[number])
else:
matched_file_name.append(None)
return matched_file_name
def match_filenames(self, file_names_a, file_names_b):
matched_file_name = self.match_using_filename(file_names_a, file_names_b)
if sum([i is not None for i in matched_file_name]) > 0:
return matched_file_name
matched_file_name = self.match_using_numbers(file_names_a, file_names_b)
return matched_file_name
def detect_frames(frames_path, keyframes_path):
if not os.path.exists(frames_path) and not os.path.exists(keyframes_path):
return "Please input the directory of guide video and rendered frames"
elif not os.path.exists(frames_path):
return "Please input the directory of guide video"
elif not os.path.exists(keyframes_path):
return "Please input the directory of rendered frames"
frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
if len(frames)==0:
return f"No images detected in {frames_path}"
if len(keyframes)==0:
return f"No images detected in {keyframes_path}"
matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
max_filename_length = max([len(i) for i in frames])
if sum([i is not None for i in matched_keyframes])==0:
message = ""
for frame, matched_keyframe in zip(frames, matched_keyframes):
message += frame + " " * (max_filename_length - len(frame) + 1)
message += "--> No matched keyframes\n"
else:
message = ""
for frame, matched_keyframe in zip(frames, matched_keyframes):
message += frame + " " * (max_filename_length - len(frame) + 1)
if matched_keyframe is None:
message += "--> [to be rendered]\n"
else:
message += f"--> {matched_keyframe}\n"
return message
def check_input_for_interpolating(frames_path, keyframes_path):
# search for images
frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
# match frames
matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
file_list = [file_name for file_name in matched_keyframes if file_name is not None]
index_style = [i for i, file_name in enumerate(matched_keyframes) if file_name is not None]
frames_guide = VideoData(None, frames_path)
frames_style = VideoData(None, keyframes_path, file_list=file_list)
# match shape
message = ""
height_guide, width_guide = frames_guide.shape()
height_style, width_style = frames_style.shape()
if height_guide != height_style or width_guide != width_style:
message += f"The shape of frames mismatches. The rendered keyframes will be resized to (height: {height_guide}, width: {width_guide})\n"
frames_style.set_shape(height_guide, width_guide)
return frames_guide, frames_style, index_style, message
def interpolate_video(
frames_path,
keyframes_path,
output_path,
fps,
batch_size,
tracking_window_size,
minimum_patch_size,
num_iter,
guide_weight,
initialize,
progress = None,
):
# input
frames_guide, frames_style, index_style, message = check_input_for_interpolating(frames_path, keyframes_path)
if len(message) > 0:
print(message)
# output
if output_path == "":
output_path = os.path.join(keyframes_path, "output")
os.makedirs(output_path, exist_ok=True)
print("No valid output_path. Your video will be saved here:", output_path)
elif not os.path.exists(output_path):
os.makedirs(output_path, exist_ok=True)
print("Your video will be saved here:", output_path)
output_frames_path = os.path.join(output_path, "frames")
output_video_path = os.path.join(output_path, "video.mp4")
os.makedirs(output_frames_path, exist_ok=True)
# process
ebsynth_config = {
"minimum_patch_size": minimum_patch_size,
"threads_per_block": 8,
"num_iter": num_iter,
"gpu_id": 0,
"guide_weight": guide_weight,
"initialize": initialize,
"tracking_window_size": tracking_window_size
}
if len(index_style)==1:
InterpolationModeSingleFrameRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
else:
InterpolationModeRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
try:
fps = int(fps)
except:
fps = 30
print("Fps:", fps)
print("Saving video...")
video_path = save_video(output_frames_path, output_video_path, num_frames=len(frames_guide), fps=fps)
print("Success!")
print("Your frames are here:", output_frames_path)
print("Your video is here:", video_path)
return output_path, fps, video_path
def on_ui_tabs():
with gr.Blocks(analytics_enabled=False) as ui_component:
with gr.Tab("Blend"):
gr.Markdown("""
# Blend
Given a guide video and a style video, this algorithm will make the style video fluent according to the motion features of the guide video. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/208d902d-6aba-48d7-b7d5-cd120ebd306d) to see the example. Note that this extension doesn't support long videos. Please use short videos (e.g., several seconds). The algorithm is mainly designed for 512*512 resolution. Please use a larger `Minimum patch size` for higher resolution.
""")
with gr.Row():
with gr.Column():
with gr.Tab("Guide video"):
video_guide = gr.Video(label="Guide video")
with gr.Tab("Guide video (images format)"):
video_guide_folder = gr.Textbox(label="Guide video (images format)", value="")
with gr.Column():
with gr.Tab("Style video"):
video_style = gr.Video(label="Style video")
with gr.Tab("Style video (images format)"):
video_style_folder = gr.Textbox(label="Style video (images format)", value="")
with gr.Column():
output_path = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of style video")
fps = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
video_output = gr.Video(label="Output video", interactive=False, show_share_button=True)
btn = gr.Button(value="Blend")
with gr.Row():
with gr.Column():
gr.Markdown("# Settings")
mode = gr.Radio(["Fast", "Balanced", "Accurate"], label="Inference mode", value="Fast", interactive=True)
window_size = gr.Slider(label="Sliding window size", value=15, minimum=1, maximum=1000, step=1, interactive=True)
batch_size = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
tracking_window_size = gr.Slider(label="Tracking window size (only for accurate mode)", value=0, minimum=0, maximum=10, step=1, interactive=True)
gr.Markdown("## Advanced Settings")
minimum_patch_size = gr.Slider(label="Minimum patch size (odd number)", value=5, minimum=5, maximum=99, step=2, interactive=True)
num_iter = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
guide_weight = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
initialize = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
with gr.Column():
gr.Markdown("""
# Reference
* Output directory: the directory to save the video.
* Inference mode
|Mode|Time|Memory|Quality|Frame by frame output|Description|
|-|-|-|-|-|-|
|Fast|■|■■■|■■|No|Blend the frames using a tree-like data structure, which requires much RAM but is fast.|
|Balanced|■■|■|■■|Yes|Blend the frames naively.|
|Accurate|■■■|■|■■■|Yes|Blend the frames and align them together for higher video quality. When [batch size] >= [sliding window size] * 2 + 1, the performance is the best.|
* Sliding window size: our algorithm will blend the frames in a sliding windows. If the size is n, each frame will be blended with the last n frames and the next n frames. A large sliding window can make the video fluent but sometimes smoggy.
* Batch size: a larger batch size makes the program faster but requires more VRAM.
* Tracking window size (only for accurate mode): The size of window in which our algorithm tracks moving objects. Empirically, 1 is enough.
* Advanced settings
* Minimum patch size (odd number): the minimum patch size used for patch matching. (Default: 5)
* Number of iterations: the number of iterations of patch matching. (Default: 5)
* Guide weight: a parameter that determines how much motion feature applied to the style video. (Default: 10)
* NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
""")
btn.click(
smooth_video,
inputs=[
video_guide,
video_guide_folder,
video_style,
video_style_folder,
mode,
window_size,
batch_size,
tracking_window_size,
output_path,
fps,
minimum_patch_size,
num_iter,
guide_weight,
initialize
],
outputs=[output_path, fps, video_output]
)
with gr.Tab("Interpolate"):
gr.Markdown("""
# Interpolate
Given a guide video and some rendered keyframes, this algorithm will render the remaining frames. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/3490c5b4-8f67-478f-86de-f9adc2ace16a) to see the example. The algorithm is experimental and is only tested for 512*512 resolution.
""")
with gr.Row():
with gr.Column():
with gr.Row():
with gr.Column():
video_guide_folder_ = gr.Textbox(label="Guide video (images format)", value="")
with gr.Column():
rendered_keyframes_ = gr.Textbox(label="Rendered keyframes (images format)", value="")
with gr.Row():
detected_frames = gr.Textbox(label="Detected frames", value="Please input the directory of guide video and rendered frames", lines=9, max_lines=9, interactive=False)
video_guide_folder_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
rendered_keyframes_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
with gr.Column():
output_path_ = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of rendered keyframes")
fps_ = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
video_output_ = gr.Video(label="Output video", interactive=False, show_share_button=True)
btn_ = gr.Button(value="Interpolate")
with gr.Row():
with gr.Column():
gr.Markdown("# Settings")
batch_size_ = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
tracking_window_size_ = gr.Slider(label="Tracking window size", value=0, minimum=0, maximum=10, step=1, interactive=True)
gr.Markdown("## Advanced Settings")
minimum_patch_size_ = gr.Slider(label="Minimum patch size (odd number, larger is better)", value=15, minimum=5, maximum=99, step=2, interactive=True)
num_iter_ = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
guide_weight_ = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
initialize_ = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
with gr.Column():
gr.Markdown("""
# Reference
* Output directory: the directory to save the video.
* Batch size: a larger batch size makes the program faster but requires more VRAM.
* Tracking window size (only for accurate mode): The size of window in which our algorithm tracks moving objects. Empirically, 1 is enough.
* Advanced settings
* Minimum patch size (odd number): the minimum patch size used for patch matching. **This parameter should be larger than that in blending. (Default: 15)**
* Number of iterations: the number of iterations of patch matching. (Default: 5)
* Guide weight: a parameter that determines how much motion feature applied to the style video. (Default: 10)
* NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
""")
btn_.click(
interpolate_video,
inputs=[
video_guide_folder_,
rendered_keyframes_,
output_path_,
fps_,
batch_size_,
tracking_window_size_,
minimum_patch_size_,
num_iter_,
guide_weight_,
initialize_,
],
outputs=[output_path_, fps_, video_output_]
)
return [(ui_component, "FastBlend", "FastBlend_ui")]

View File

@@ -0,0 +1,119 @@
import cupy as cp
remapping_kernel = cp.RawKernel(r'''
extern "C" __global__
void remap(
const int height,
const int width,
const int channel,
const int patch_size,
const int pad_size,
const float* source_style,
const int* nnf,
float* target_style
) {
const int r = (patch_size - 1) / 2;
const int x = blockDim.x * blockIdx.x + threadIdx.x;
const int y = blockDim.y * blockIdx.y + threadIdx.y;
if (x >= height or y >= width) return;
const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
const int pid = (x + pad_size) * (width + pad_size * 2) + (y + pad_size);
const int min_px = x < r ? -x : -r;
const int max_px = x + r > height - 1 ? height - 1 - x : r;
const int min_py = y < r ? -y : -r;
const int max_py = y + r > width - 1 ? width - 1 - y : r;
int num = 0;
for (int px = min_px; px <= max_px; px++){
for (int py = min_py; py <= max_py; py++){
const int nid = (x + px) * width + y + py;
const int x_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 0] - px;
const int y_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 1] - py;
if (x_ < 0 or y_ < 0 or x_ >= height or y_ >= width)continue;
const int pid_ = (x_ + pad_size) * (width + pad_size * 2) + (y_ + pad_size);
num++;
for (int c = 0; c < channel; c++){
target_style[z + pid * channel + c] += source_style[z + pid_ * channel + c];
}
}
}
for (int c = 0; c < channel; c++){
target_style[z + pid * channel + c] /= num;
}
}
''', 'remap')
patch_error_kernel = cp.RawKernel(r'''
extern "C" __global__
void patch_error(
const int height,
const int width,
const int channel,
const int patch_size,
const int pad_size,
const float* source,
const int* nnf,
const float* target,
float* error
) {
const int r = (patch_size - 1) / 2;
const int x = blockDim.x * blockIdx.x + threadIdx.x;
const int y = blockDim.y * blockIdx.y + threadIdx.y;
const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
if (x >= height or y >= width) return;
const int x_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 0];
const int y_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 1];
float e = 0;
for (int px = -r; px <= r; px++){
for (int py = -r; py <= r; py++){
const int pid = (x + pad_size + px) * (width + pad_size * 2) + y + pad_size + py;
const int pid_ = (x_ + pad_size + px) * (width + pad_size * 2) + y_ + pad_size + py;
for (int c = 0; c < channel; c++){
const float diff = target[z + pid * channel + c] - source[z + pid_ * channel + c];
e += diff * diff;
}
}
}
error[blockIdx.z * height * width + x * width + y] = e;
}
''', 'patch_error')
pairwise_patch_error_kernel = cp.RawKernel(r'''
extern "C" __global__
void pairwise_patch_error(
const int height,
const int width,
const int channel,
const int patch_size,
const int pad_size,
const float* source_a,
const int* nnf_a,
const float* source_b,
const int* nnf_b,
float* error
) {
const int r = (patch_size - 1) / 2;
const int x = blockDim.x * blockIdx.x + threadIdx.x;
const int y = blockDim.y * blockIdx.y + threadIdx.y;
const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
if (x >= height or y >= width) return;
const int z_nnf = blockIdx.z * height * width * 2 + (x * width + y) * 2;
const int x_a = nnf_a[z_nnf + 0];
const int y_a = nnf_a[z_nnf + 1];
const int x_b = nnf_b[z_nnf + 0];
const int y_b = nnf_b[z_nnf + 1];
float e = 0;
for (int px = -r; px <= r; px++){
for (int py = -r; py <= r; py++){
const int pid_a = (x_a + pad_size + px) * (width + pad_size * 2) + y_a + pad_size + py;
const int pid_b = (x_b + pad_size + px) * (width + pad_size * 2) + y_b + pad_size + py;
for (int c = 0; c < channel; c++){
const float diff = source_a[z + pid_a * channel + c] - source_b[z + pid_b * channel + c];
e += diff * diff;
}
}
}
error[blockIdx.z * height * width + x * width + y] = e;
}
''', 'pairwise_patch_error')

View File

@@ -0,0 +1,146 @@
import imageio, os
import numpy as np
from PIL import Image
def read_video(file_name):
reader = imageio.get_reader(file_name)
video = []
for frame in reader:
frame = np.array(frame)
video.append(frame)
reader.close()
return video
def get_video_fps(file_name):
reader = imageio.get_reader(file_name)
fps = reader.get_meta_data()["fps"]
reader.close()
return fps
def save_video(frames_path, video_path, num_frames, fps):
writer = imageio.get_writer(video_path, fps=fps, quality=9)
for i in range(num_frames):
frame = np.array(Image.open(os.path.join(frames_path, "%05d.png" % i)))
writer.append_data(frame)
writer.close()
return video_path
class LowMemoryVideo:
def __init__(self, file_name):
self.reader = imageio.get_reader(file_name)
def __len__(self):
return self.reader.count_frames()
def __getitem__(self, item):
return np.array(self.reader.get_data(item))
def __del__(self):
self.reader.close()
def split_file_name(file_name):
result = []
number = -1
for i in file_name:
if ord(i)>=ord("0") and ord(i)<=ord("9"):
if number == -1:
number = 0
number = number*10 + ord(i) - ord("0")
else:
if number != -1:
result.append(number)
number = -1
result.append(i)
if number != -1:
result.append(number)
result = tuple(result)
return result
def search_for_images(folder):
file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
file_list = [i[1] for i in sorted(file_list)]
file_list = [os.path.join(folder, i) for i in file_list]
return file_list
def read_images(folder):
file_list = search_for_images(folder)
frames = [np.array(Image.open(i)) for i in file_list]
return frames
class LowMemoryImageFolder:
def __init__(self, folder, file_list=None):
if file_list is None:
self.file_list = search_for_images(folder)
else:
self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
def __len__(self):
return len(self.file_list)
def __getitem__(self, item):
return np.array(Image.open(self.file_list[item]))
def __del__(self):
pass
class VideoData:
def __init__(self, video_file, image_folder, **kwargs):
if video_file is not None:
self.data_type = "video"
self.data = LowMemoryVideo(video_file, **kwargs)
elif image_folder is not None:
self.data_type = "images"
self.data = LowMemoryImageFolder(image_folder, **kwargs)
else:
raise ValueError("Cannot open video or image folder")
self.length = None
self.height = None
self.width = None
def raw_data(self):
frames = []
for i in range(self.__len__()):
frames.append(self.__getitem__(i))
return frames
def set_length(self, length):
self.length = length
def set_shape(self, height, width):
self.height = height
self.width = width
def __len__(self):
if self.length is None:
return len(self.data)
else:
return self.length
def shape(self):
if self.height is not None and self.width is not None:
return self.height, self.width
else:
height, width, _ = self.__getitem__(0).shape
return height, width
def __getitem__(self, item):
frame = self.data.__getitem__(item)
height, width, _ = frame.shape
if self.height is not None and self.width is not None:
if self.height != height or self.width != width:
frame = Image.fromarray(frame).resize((self.width, self.height))
frame = np.array(frame)
return frame
def __del__(self):
pass

View File

@@ -0,0 +1,298 @@
from .cupy_kernels import remapping_kernel, patch_error_kernel, pairwise_patch_error_kernel
import numpy as np
import cupy as cp
import cv2
class PatchMatcher:
def __init__(
self, height, width, channel, minimum_patch_size,
threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
random_search_steps=3, random_search_range=4,
use_mean_target_style=False, use_pairwise_patch_error=False,
tracking_window_size=0
):
self.height = height
self.width = width
self.channel = channel
self.minimum_patch_size = minimum_patch_size
self.threads_per_block = threads_per_block
self.num_iter = num_iter
self.gpu_id = gpu_id
self.guide_weight = guide_weight
self.random_search_steps = random_search_steps
self.random_search_range = random_search_range
self.use_mean_target_style = use_mean_target_style
self.use_pairwise_patch_error = use_pairwise_patch_error
self.tracking_window_size = tracking_window_size
self.patch_size_list = [minimum_patch_size + i*2 for i in range(num_iter)][::-1]
self.pad_size = self.patch_size_list[0] // 2
self.grid = (
(height + threads_per_block - 1) // threads_per_block,
(width + threads_per_block - 1) // threads_per_block
)
self.block = (threads_per_block, threads_per_block)
def pad_image(self, image):
return cp.pad(image, ((0, 0), (self.pad_size, self.pad_size), (self.pad_size, self.pad_size), (0, 0)))
def unpad_image(self, image):
return image[:, self.pad_size: -self.pad_size, self.pad_size: -self.pad_size, :]
def apply_nnf_to_image(self, nnf, source):
batch_size = source.shape[0]
target = cp.zeros((batch_size, self.height + self.pad_size * 2, self.width + self.pad_size * 2, self.channel), dtype=cp.float32)
remapping_kernel(
self.grid + (batch_size,),
self.block,
(self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target)
)
return target
def get_patch_error(self, source, nnf, target):
batch_size = source.shape[0]
error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
patch_error_kernel(
self.grid + (batch_size,),
self.block,
(self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target, error)
)
return error
def get_pairwise_patch_error(self, source, nnf):
batch_size = source.shape[0]//2
error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
source_a, nnf_a = source[0::2].copy(), nnf[0::2].copy()
source_b, nnf_b = source[1::2].copy(), nnf[1::2].copy()
pairwise_patch_error_kernel(
self.grid + (batch_size,),
self.block,
(self.height, self.width, self.channel, self.patch_size, self.pad_size, source_a, nnf_a, source_b, nnf_b, error)
)
error = error.repeat(2, axis=0)
return error
def get_error(self, source_guide, target_guide, source_style, target_style, nnf):
error_guide = self.get_patch_error(source_guide, nnf, target_guide)
if self.use_mean_target_style:
target_style = self.apply_nnf_to_image(nnf, source_style)
target_style = target_style.mean(axis=0, keepdims=True)
target_style = target_style.repeat(source_guide.shape[0], axis=0)
if self.use_pairwise_patch_error:
error_style = self.get_pairwise_patch_error(source_style, nnf)
else:
error_style = self.get_patch_error(source_style, nnf, target_style)
error = error_guide * self.guide_weight + error_style
return error
def clamp_bound(self, nnf):
nnf[:,:,:,0] = cp.clip(nnf[:,:,:,0], 0, self.height-1)
nnf[:,:,:,1] = cp.clip(nnf[:,:,:,1], 0, self.width-1)
return nnf
def random_step(self, nnf, r):
batch_size = nnf.shape[0]
step = cp.random.randint(-r, r+1, size=(batch_size, self.height, self.width, 2), dtype=cp.int32)
upd_nnf = self.clamp_bound(nnf + step)
return upd_nnf
def neighboor_step(self, nnf, d):
if d==0:
upd_nnf = cp.concatenate([nnf[:, :1, :], nnf[:, :-1, :]], axis=1)
upd_nnf[:, :, :, 0] += 1
elif d==1:
upd_nnf = cp.concatenate([nnf[:, :, :1], nnf[:, :, :-1]], axis=2)
upd_nnf[:, :, :, 1] += 1
elif d==2:
upd_nnf = cp.concatenate([nnf[:, 1:, :], nnf[:, -1:, :]], axis=1)
upd_nnf[:, :, :, 0] -= 1
elif d==3:
upd_nnf = cp.concatenate([nnf[:, :, 1:], nnf[:, :, -1:]], axis=2)
upd_nnf[:, :, :, 1] -= 1
upd_nnf = self.clamp_bound(upd_nnf)
return upd_nnf
def shift_nnf(self, nnf, d):
if d>0:
d = min(nnf.shape[0], d)
upd_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
else:
d = max(-nnf.shape[0], d)
upd_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
return upd_nnf
def track_step(self, nnf, d):
if self.use_pairwise_patch_error:
upd_nnf = cp.zeros_like(nnf)
upd_nnf[0::2] = self.shift_nnf(nnf[0::2], d)
upd_nnf[1::2] = self.shift_nnf(nnf[1::2], d)
else:
upd_nnf = self.shift_nnf(nnf, d)
return upd_nnf
def C(self, n, m):
# not used
c = 1
for i in range(1, n+1):
c *= i
for i in range(1, m+1):
c //= i
for i in range(1, n-m+1):
c //= i
return c
def bezier_step(self, nnf, r):
# not used
n = r * 2 - 1
upd_nnf = cp.zeros(shape=nnf.shape, dtype=cp.float32)
for i, d in enumerate(list(range(-r, 0)) + list(range(1, r+1))):
if d>0:
ctl_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
elif d<0:
ctl_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
upd_nnf += ctl_nnf * (self.C(n, i) / 2**n)
upd_nnf = self.clamp_bound(upd_nnf).astype(nnf.dtype)
return upd_nnf
def update(self, source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf):
upd_err = self.get_error(source_guide, target_guide, source_style, target_style, upd_nnf)
upd_idx = (upd_err < err)
nnf[upd_idx] = upd_nnf[upd_idx]
err[upd_idx] = upd_err[upd_idx]
return nnf, err
def propagation(self, source_guide, target_guide, source_style, target_style, nnf, err):
for d in cp.random.permutation(4):
upd_nnf = self.neighboor_step(nnf, d)
nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
return nnf, err
def random_search(self, source_guide, target_guide, source_style, target_style, nnf, err):
for i in range(self.random_search_steps):
upd_nnf = self.random_step(nnf, self.random_search_range)
nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
return nnf, err
def track(self, source_guide, target_guide, source_style, target_style, nnf, err):
for d in range(1, self.tracking_window_size + 1):
upd_nnf = self.track_step(nnf, d)
nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
upd_nnf = self.track_step(nnf, -d)
nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
return nnf, err
def iteration(self, source_guide, target_guide, source_style, target_style, nnf, err):
nnf, err = self.propagation(source_guide, target_guide, source_style, target_style, nnf, err)
nnf, err = self.random_search(source_guide, target_guide, source_style, target_style, nnf, err)
nnf, err = self.track(source_guide, target_guide, source_style, target_style, nnf, err)
return nnf, err
def estimate_nnf(self, source_guide, target_guide, source_style, nnf):
with cp.cuda.Device(self.gpu_id):
source_guide = self.pad_image(source_guide)
target_guide = self.pad_image(target_guide)
source_style = self.pad_image(source_style)
for it in range(self.num_iter):
self.patch_size = self.patch_size_list[it]
target_style = self.apply_nnf_to_image(nnf, source_style)
err = self.get_error(source_guide, target_guide, source_style, target_style, nnf)
nnf, err = self.iteration(source_guide, target_guide, source_style, target_style, nnf, err)
target_style = self.unpad_image(self.apply_nnf_to_image(nnf, source_style))
return nnf, target_style
class PyramidPatchMatcher:
def __init__(
self, image_height, image_width, channel, minimum_patch_size,
threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
use_mean_target_style=False, use_pairwise_patch_error=False,
tracking_window_size=0,
initialize="identity"
):
maximum_patch_size = minimum_patch_size + (num_iter - 1) * 2
self.pyramid_level = int(np.log2(min(image_height, image_width) / maximum_patch_size))
self.pyramid_heights = []
self.pyramid_widths = []
self.patch_matchers = []
self.minimum_patch_size = minimum_patch_size
self.num_iter = num_iter
self.gpu_id = gpu_id
self.initialize = initialize
for level in range(self.pyramid_level):
height = image_height//(2**(self.pyramid_level - 1 - level))
width = image_width//(2**(self.pyramid_level - 1 - level))
self.pyramid_heights.append(height)
self.pyramid_widths.append(width)
self.patch_matchers.append(PatchMatcher(
height, width, channel, minimum_patch_size=minimum_patch_size,
threads_per_block=threads_per_block, num_iter=num_iter, gpu_id=gpu_id, guide_weight=guide_weight,
use_mean_target_style=use_mean_target_style, use_pairwise_patch_error=use_pairwise_patch_error,
tracking_window_size=tracking_window_size
))
def resample_image(self, images, level):
height, width = self.pyramid_heights[level], self.pyramid_widths[level]
images = images.get()
images_resample = []
for image in images:
image_resample = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
images_resample.append(image_resample)
images_resample = cp.array(np.stack(images_resample), dtype=cp.float32)
return images_resample
def initialize_nnf(self, batch_size):
if self.initialize == "random":
height, width = self.pyramid_heights[0], self.pyramid_widths[0]
nnf = cp.stack([
cp.random.randint(0, height, (batch_size, height, width), dtype=cp.int32),
cp.random.randint(0, width, (batch_size, height, width), dtype=cp.int32)
], axis=3)
elif self.initialize == "identity":
height, width = self.pyramid_heights[0], self.pyramid_widths[0]
nnf = cp.stack([
cp.repeat(cp.arange(height), width).reshape(height, width),
cp.tile(cp.arange(width), height).reshape(height, width)
], axis=2)
nnf = cp.stack([nnf] * batch_size)
else:
raise NotImplementedError()
return nnf
def update_nnf(self, nnf, level):
# upscale
nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2
nnf[:,[i for i in range(nnf.shape[0]) if i&1],:,0] += 1
nnf[:,:,[i for i in range(nnf.shape[0]) if i&1],1] += 1
# check if scale is 2
height, width = self.pyramid_heights[level], self.pyramid_widths[level]
if height != nnf.shape[0] * 2 or width != nnf.shape[1] * 2:
nnf = nnf.get().astype(np.float32)
nnf = [cv2.resize(n, (width, height), interpolation=cv2.INTER_LINEAR) for n in nnf]
nnf = cp.array(np.stack(nnf), dtype=cp.int32)
nnf = self.patch_matchers[level].clamp_bound(nnf)
return nnf
def apply_nnf_to_image(self, nnf, image):
with cp.cuda.Device(self.gpu_id):
image = self.patch_matchers[-1].pad_image(image)
image = self.patch_matchers[-1].apply_nnf_to_image(nnf, image)
return image
def estimate_nnf(self, source_guide, target_guide, source_style):
with cp.cuda.Device(self.gpu_id):
if not isinstance(source_guide, cp.ndarray):
source_guide = cp.array(source_guide, dtype=cp.float32)
if not isinstance(target_guide, cp.ndarray):
target_guide = cp.array(target_guide, dtype=cp.float32)
if not isinstance(source_style, cp.ndarray):
source_style = cp.array(source_style, dtype=cp.float32)
for level in range(self.pyramid_level):
nnf = self.initialize_nnf(source_guide.shape[0]) if level==0 else self.update_nnf(nnf, level)
source_guide_ = self.resample_image(source_guide, level)
target_guide_ = self.resample_image(target_guide, level)
source_style_ = self.resample_image(source_style, level)
nnf, target_style = self.patch_matchers[level].estimate_nnf(
source_guide_, target_guide_, source_style_, nnf
)
return nnf.get(), target_style.get()

View File

@@ -0,0 +1,4 @@
from .accurate import AccurateModeRunner
from .fast import FastModeRunner
from .balanced import BalancedModeRunner
from .interpolation import InterpolationModeRunner, InterpolationModeSingleFrameRunner

View File

@@ -0,0 +1,35 @@
from ..patch_match import PyramidPatchMatcher
import os
import numpy as np
from PIL import Image
from tqdm import tqdm
class AccurateModeRunner:
def __init__(self):
pass
def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Accurate Mode", save_path=None):
patch_match_engine = PyramidPatchMatcher(
image_height=frames_style[0].shape[0],
image_width=frames_style[0].shape[1],
channel=3,
use_mean_target_style=True,
**ebsynth_config
)
# run
n = len(frames_style)
for target in tqdm(range(n), desc=desc):
l, r = max(target - window_size, 0), min(target + window_size + 1, n)
remapped_frames = []
for i in range(l, r, batch_size):
j = min(i + batch_size, r)
source_guide = np.stack([frames_guide[source] for source in range(i, j)])
target_guide = np.stack([frames_guide[target]] * (j - i))
source_style = np.stack([frames_style[source] for source in range(i, j)])
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
remapped_frames.append(target_style)
frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
frame = frame.clip(0, 255).astype("uint8")
if save_path is not None:
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))

View File

@@ -0,0 +1,46 @@
from ..patch_match import PyramidPatchMatcher
import os
import numpy as np
from PIL import Image
from tqdm import tqdm
class BalancedModeRunner:
def __init__(self):
pass
def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Balanced Mode", save_path=None):
patch_match_engine = PyramidPatchMatcher(
image_height=frames_style[0].shape[0],
image_width=frames_style[0].shape[1],
channel=3,
**ebsynth_config
)
# tasks
n = len(frames_style)
tasks = []
for target in range(n):
for source in range(target - window_size, target + window_size + 1):
if source >= 0 and source < n and source != target:
tasks.append((source, target))
# run
frames = [(None, 1) for i in range(n)]
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
source_style = np.stack([frames_style[source] for source, target in tasks_batch])
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
for (source, target), result in zip(tasks_batch, target_style):
frame, weight = frames[target]
if frame is None:
frame = frames_style[target]
frames[target] = (
frame * (weight / (weight + 1)) + result / (weight + 1),
weight + 1
)
if weight + 1 == min(n, target + window_size + 1) - max(0, target - window_size):
frame = frame.clip(0, 255).astype("uint8")
if save_path is not None:
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
frames[target] = (None, 1)

View File

@@ -0,0 +1,141 @@
from ..patch_match import PyramidPatchMatcher
import functools, os
import numpy as np
from PIL import Image
from tqdm import tqdm
class TableManager:
def __init__(self):
pass
def task_list(self, n):
tasks = []
max_level = 1
while (1<<max_level)<=n:
max_level += 1
for i in range(n):
j = i
for level in range(max_level):
if i&(1<<level):
continue
j |= 1<<level
if j>=n:
break
meta_data = {
"source": i,
"target": j,
"level": level + 1
}
tasks.append(meta_data)
tasks.sort(key=functools.cmp_to_key(lambda u, v: u["level"]-v["level"]))
return tasks
def build_remapping_table(self, frames_guide, frames_style, patch_match_engine, batch_size, desc=""):
n = len(frames_guide)
tasks = self.task_list(n)
remapping_table = [[(frames_style[i], 1)] for i in range(n)]
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
source_style = np.stack([frames_style[task["source"]] for task in tasks_batch])
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
for task, result in zip(tasks_batch, target_style):
target, level = task["target"], task["level"]
if len(remapping_table[target])==level:
remapping_table[target].append((result, 1))
else:
frame, weight = remapping_table[target][level]
remapping_table[target][level] = (
frame * (weight / (weight + 1)) + result / (weight + 1),
weight + 1
)
return remapping_table
def remapping_table_to_blending_table(self, table):
for i in range(len(table)):
for j in range(1, len(table[i])):
frame_1, weight_1 = table[i][j-1]
frame_2, weight_2 = table[i][j]
frame = (frame_1 + frame_2) / 2
weight = weight_1 + weight_2
table[i][j] = (frame, weight)
return table
def tree_query(self, leftbound, rightbound):
node_list = []
node_index = rightbound
while node_index>=leftbound:
node_level = 0
while (1<<node_level)&node_index and node_index-(1<<node_level+1)+1>=leftbound:
node_level += 1
node_list.append((node_index, node_level))
node_index -= 1<<node_level
return node_list
def process_window_sum(self, frames_guide, blending_table, patch_match_engine, window_size, batch_size, desc=""):
n = len(blending_table)
tasks = []
frames_result = []
for target in range(n):
node_list = self.tree_query(max(target-window_size, 0), target)
for source, level in node_list:
if source!=target:
meta_data = {
"source": source,
"target": target,
"level": level
}
tasks.append(meta_data)
else:
frames_result.append(blending_table[target][level])
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
source_style = np.stack([blending_table[task["source"]][task["level"]][0] for task in tasks_batch])
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
for task, frame_2 in zip(tasks_batch, target_style):
source, target, level = task["source"], task["target"], task["level"]
frame_1, weight_1 = frames_result[target]
weight_2 = blending_table[source][level][1]
weight = weight_1 + weight_2
frame = frame_1 * (weight_1 / weight) + frame_2 * (weight_2 / weight)
frames_result[target] = (frame, weight)
return frames_result
class FastModeRunner:
def __init__(self):
pass
def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, save_path=None):
frames_guide = frames_guide.raw_data()
frames_style = frames_style.raw_data()
table_manager = TableManager()
patch_match_engine = PyramidPatchMatcher(
image_height=frames_style[0].shape[0],
image_width=frames_style[0].shape[1],
channel=3,
**ebsynth_config
)
# left part
table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="Fast Mode Step 1/4")
table_l = table_manager.remapping_table_to_blending_table(table_l)
table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 2/4")
# right part
table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="Fast Mode Step 3/4")
table_r = table_manager.remapping_table_to_blending_table(table_r)
table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 4/4")[::-1]
# merge
frames = []
for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
weight_m = -1
weight = weight_l + weight_m + weight_r
frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
frames.append(frame)
frames = [frame.clip(0, 255).astype("uint8") for frame in frames]
if save_path is not None:
for target, frame in enumerate(frames):
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))

View File

@@ -0,0 +1,121 @@
from ..patch_match import PyramidPatchMatcher
import os
import numpy as np
from PIL import Image
from tqdm import tqdm
class InterpolationModeRunner:
def __init__(self):
pass
def get_index_dict(self, index_style):
index_dict = {}
for i, index in enumerate(index_style):
index_dict[index] = i
return index_dict
def get_weight(self, l, m, r):
weight_l, weight_r = abs(m - r), abs(m - l)
if weight_l + weight_r == 0:
weight_l, weight_r = 0.5, 0.5
else:
weight_l, weight_r = weight_l / (weight_l + weight_r), weight_r / (weight_l + weight_r)
return weight_l, weight_r
def get_task_group(self, index_style, n):
task_group = []
index_style = sorted(index_style)
# first frame
if index_style[0]>0:
tasks = []
for m in range(index_style[0]):
tasks.append((index_style[0], m, index_style[0]))
task_group.append(tasks)
# middle frames
for l, r in zip(index_style[:-1], index_style[1:]):
tasks = []
for m in range(l, r):
tasks.append((l, m, r))
task_group.append(tasks)
# last frame
tasks = []
for m in range(index_style[-1], n):
tasks.append((index_style[-1], m, index_style[-1]))
task_group.append(tasks)
return task_group
def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
patch_match_engine = PyramidPatchMatcher(
image_height=frames_style[0].shape[0],
image_width=frames_style[0].shape[1],
channel=3,
use_mean_target_style=False,
use_pairwise_patch_error=True,
**ebsynth_config
)
# task
index_dict = self.get_index_dict(index_style)
task_group = self.get_task_group(index_style, len(frames_guide))
# run
for tasks in task_group:
index_start, index_end = min([i[1] for i in tasks]), max([i[1] for i in tasks])
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=f"Rendering frames {index_start}...{index_end}"):
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
source_guide, target_guide, source_style = [], [], []
for l, m, r in tasks_batch:
# l -> m
source_guide.append(frames_guide[l])
target_guide.append(frames_guide[m])
source_style.append(frames_style[index_dict[l]])
# r -> m
source_guide.append(frames_guide[r])
target_guide.append(frames_guide[m])
source_style.append(frames_style[index_dict[r]])
source_guide = np.stack(source_guide)
target_guide = np.stack(target_guide)
source_style = np.stack(source_style)
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
if save_path is not None:
for frame_l, frame_r, (l, m, r) in zip(target_style[0::2], target_style[1::2], tasks_batch):
weight_l, weight_r = self.get_weight(l, m, r)
frame = frame_l * weight_l + frame_r * weight_r
frame = frame.clip(0, 255).astype("uint8")
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % m))
class InterpolationModeSingleFrameRunner:
def __init__(self):
pass
def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
# check input
tracking_window_size = ebsynth_config["tracking_window_size"]
if tracking_window_size * 2 >= batch_size:
raise ValueError("batch_size should be larger than track_window_size * 2")
frame_style = frames_style[0]
frame_guide = frames_guide[index_style[0]]
patch_match_engine = PyramidPatchMatcher(
image_height=frame_style.shape[0],
image_width=frame_style.shape[1],
channel=3,
**ebsynth_config
)
# run
frame_id, n = 0, len(frames_guide)
for i in tqdm(range(0, n, batch_size - tracking_window_size * 2), desc=f"Rendering frames 0...{n}"):
if i + batch_size > n:
l, r = max(n - batch_size, 0), n
else:
l, r = i, i + batch_size
source_guide = np.stack([frame_guide] * (r-l))
target_guide = np.stack([frames_guide[i] for i in range(l, r)])
source_style = np.stack([frame_style] * (r-l))
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
for i, frame in zip(range(l, r), target_style):
if i==frame_id:
frame = frame.clip(0, 255).astype("uint8")
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % frame_id))
frame_id += 1
if r < n and r-frame_id <= tracking_window_size:
break

View File

@@ -0,0 +1,242 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
def warp(tenInput, tenFlow, device):
backwarp_tenGrid = {}
k = (str(tenFlow.device), str(tenFlow.size()))
if k not in backwarp_tenGrid:
tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
backwarp_tenGrid[k] = torch.cat(
[tenHorizontal, tenVertical], 1).to(device)
tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
nn.PReLU(out_planes)
)
class IFBlock(nn.Module):
def __init__(self, in_planes, c=64):
super(IFBlock, self).__init__()
self.conv0 = nn.Sequential(conv(in_planes, c//2, 3, 2, 1), conv(c//2, c, 3, 2, 1),)
self.convblock0 = nn.Sequential(conv(c, c), conv(c, c))
self.convblock1 = nn.Sequential(conv(c, c), conv(c, c))
self.convblock2 = nn.Sequential(conv(c, c), conv(c, c))
self.convblock3 = nn.Sequential(conv(c, c), conv(c, c))
self.conv1 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 4, 4, 2, 1))
self.conv2 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 1, 4, 2, 1))
def forward(self, x, flow, scale=1):
x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 1. / scale
feat = self.conv0(torch.cat((x, flow), 1))
feat = self.convblock0(feat) + feat
feat = self.convblock1(feat) + feat
feat = self.convblock2(feat) + feat
feat = self.convblock3(feat) + feat
flow = self.conv1(feat)
mask = self.conv2(feat)
flow = F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * scale
mask = F.interpolate(mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
return flow, mask
class IFNet(nn.Module):
def __init__(self, **kwargs):
super(IFNet, self).__init__()
self.block0 = IFBlock(7+4, c=90)
self.block1 = IFBlock(7+4, c=90)
self.block2 = IFBlock(7+4, c=90)
self.block_tea = IFBlock(10+4, c=90)
def forward(self, x, scale_list=[4, 2, 1], training=False):
if training == False:
channel = x.shape[1] // 2
img0 = x[:, :channel]
img1 = x[:, channel:]
flow_list = []
merged = []
mask_list = []
warped_img0 = img0
warped_img1 = img1
flow = (x[:, :4]).detach() * 0
mask = (x[:, :1]).detach() * 0
block = [self.block0, self.block1, self.block2]
for i in range(3):
f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i])
flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
mask = mask + (m0 + (-m1)) / 2
mask_list.append(mask)
flow_list.append(flow)
warped_img0 = warp(img0, flow[:, :2], device=x.device)
warped_img1 = warp(img1, flow[:, 2:4], device=x.device)
merged.append((warped_img0, warped_img1))
'''
c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, flow[:, 2:4])
tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
res = tmp[:, 1:4] * 2 - 1
'''
for i in range(3):
mask_list[i] = torch.sigmoid(mask_list[i])
merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
return flow_list, mask_list[2], merged
@staticmethod
def state_dict_converter():
return IFNetStateDictConverter()
class IFNetStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
state_dict_ = {k.replace("module.", ""): v for k, v in state_dict.items()}
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict), {"upcast_to_float32": True}
class RIFEInterpolater:
def __init__(self, model, device="cuda"):
self.model = model
self.device = device
# IFNet only does not support float16
self.torch_dtype = torch.float32
@staticmethod
def from_model_manager(model_manager):
return RIFEInterpolater(model_manager.fetch_model("rife"), device=model_manager.device)
def process_image(self, image):
width, height = image.size
if width % 32 != 0 or height % 32 != 0:
width = (width + 31) // 32
height = (height + 31) // 32
image = image.resize((width, height))
image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1)
return image
def process_images(self, images):
images = [self.process_image(image) for image in images]
images = torch.stack(images)
return images
def decode_images(self, images):
images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
images = [Image.fromarray(image) for image in images]
return images
def add_interpolated_images(self, images, interpolated_images):
output_images = []
for image, interpolated_image in zip(images, interpolated_images):
output_images.append(image)
output_images.append(interpolated_image)
output_images.append(images[-1])
return output_images
@torch.no_grad()
def interpolate_(self, images, scale=1.0):
input_tensor = self.process_images(images)
input_tensor = torch.cat((input_tensor[:-1], input_tensor[1:]), dim=1)
input_tensor = input_tensor.to(device=self.device, dtype=self.torch_dtype)
flow, mask, merged = self.model(input_tensor, [4/scale, 2/scale, 1/scale])
output_images = self.decode_images(merged[2].cpu())
if output_images[0].size != images[0].size:
output_images = [image.resize(images[0].size) for image in output_images]
return output_images
@torch.no_grad()
def interpolate(self, images, scale=1.0, batch_size=4, num_iter=1, progress_bar=lambda x:x):
# Preprocess
processed_images = self.process_images(images)
for iter in range(num_iter):
# Input
input_tensor = torch.cat((processed_images[:-1], processed_images[1:]), dim=1)
# Interpolate
output_tensor = []
for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
batch_input_tensor = input_tensor[batch_id: batch_id_]
batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
output_tensor.append(merged[2].cpu())
# Output
output_tensor = torch.concat(output_tensor, dim=0).clip(0, 1)
processed_images = self.add_interpolated_images(processed_images, output_tensor)
processed_images = torch.stack(processed_images)
# To images
output_images = self.decode_images(processed_images)
if output_images[0].size != images[0].size:
output_images = [image.resize(images[0].size) for image in output_images]
return output_images
class RIFESmoother(RIFEInterpolater):
def __init__(self, model, device="cuda"):
super(RIFESmoother, self).__init__(model, device=device)
@staticmethod
def from_model_manager(model_manager):
return RIFEInterpolater(model_manager.fetch_model("rife"), device=model_manager.device)
def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
output_tensor = []
for batch_id in range(0, input_tensor.shape[0], batch_size):
batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
batch_input_tensor = input_tensor[batch_id: batch_id_]
batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
output_tensor.append(merged[2].cpu())
output_tensor = torch.concat(output_tensor, dim=0)
return output_tensor
@torch.no_grad()
def __call__(self, rendered_frames, scale=1.0, batch_size=4, num_iter=1, **kwargs):
# Preprocess
processed_images = self.process_images(rendered_frames)
for iter in range(num_iter):
# Input
input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1)
# Interpolate
output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)
# Blend
input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1)
output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)
# Add to frames
processed_images[1:-1] = output_tensor
# To images
output_images = self.decode_images(processed_images)
if output_images[0].size != rendered_frames[0].size:
output_images = [image.resize(rendered_frames[0].size) for image in output_images]
return output_images

View File

@@ -0,0 +1 @@
from .model_manager import *

View File

@@ -0,0 +1,89 @@
import torch
from einops import rearrange
def low_version_attention(query, key, value, attn_bias=None):
scale = 1 / query.shape[-1] ** 0.5
query = query * scale
attn = torch.matmul(query, key.transpose(-2, -1))
if attn_bias is not None:
attn = attn + attn_bias
attn = attn.softmax(-1)
return attn @ value
class Attention(torch.nn.Module):
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
super().__init__()
dim_inner = head_dim * num_heads
kv_dim = kv_dim if kv_dim is not None else q_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
batch_size = q.shape[0]
ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
hidden_states = hidden_states + scale * ip_hidden_states
return hidden_states
def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
batch_size = encoder_hidden_states.shape[0]
q = self.to_q(hidden_states)
k = self.to_k(encoder_hidden_states)
v = self.to_v(encoder_hidden_states)
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
if qkv_preprocessor is not None:
q, k, v = qkv_preprocessor(q, k, v)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
if ipadapter_kwargs is not None:
hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
hidden_states = self.to_out(hidden_states)
return hidden_states
def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
q = self.to_q(hidden_states)
k = self.to_k(encoder_hidden_states)
v = self.to_v(encoder_hidden_states)
q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)
if attn_mask is not None:
hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
else:
import xformers.ops as xops
hidden_states = xops.memory_efficient_attention(q, k, v)
hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)
hidden_states = hidden_states.to(q.dtype)
hidden_states = self.to_out(hidden_states)
return hidden_states
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)

395
diffsynth/models/cog_dit.py Normal file
View File

@@ -0,0 +1,395 @@
import torch
from einops import rearrange, repeat
from .sd3_dit import TimestepEmbeddings
from .attention import Attention
from .utils import load_state_dict_from_folder
from .tiler import TileWorker2Dto3D
import numpy as np
class CogPatchify(torch.nn.Module):
def __init__(self, dim_in, dim_out, patch_size) -> None:
super().__init__()
self.proj = torch.nn.Conv3d(dim_in, dim_out, kernel_size=(1, patch_size, patch_size), stride=(1, patch_size, patch_size))
def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
hidden_states = rearrange(hidden_states, "B C T H W -> B (T H W) C")
return hidden_states
class CogAdaLayerNorm(torch.nn.Module):
def __init__(self, dim, dim_cond, single=False):
super().__init__()
self.single = single
self.linear = torch.nn.Linear(dim_cond, dim * (2 if single else 6))
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=True, eps=1e-5)
def forward(self, hidden_states, prompt_emb, emb):
emb = self.linear(torch.nn.functional.silu(emb))
if self.single:
shift, scale = emb.unsqueeze(1).chunk(2, dim=2)
hidden_states = self.norm(hidden_states) * (1 + scale) + shift
return hidden_states
else:
shift_a, scale_a, gate_a, shift_b, scale_b, gate_b = emb.unsqueeze(1).chunk(6, dim=2)
hidden_states = self.norm(hidden_states) * (1 + scale_a) + shift_a
prompt_emb = self.norm(prompt_emb) * (1 + scale_b) + shift_b
return hidden_states, prompt_emb, gate_a, gate_b
class CogDiTBlock(torch.nn.Module):
def __init__(self, dim, dim_cond, num_heads):
super().__init__()
self.norm1 = CogAdaLayerNorm(dim, dim_cond)
self.attn1 = Attention(q_dim=dim, num_heads=48, head_dim=dim//num_heads, bias_q=True, bias_kv=True, bias_out=True)
self.norm_q = torch.nn.LayerNorm((dim//num_heads,), eps=1e-06, elementwise_affine=True)
self.norm_k = torch.nn.LayerNorm((dim//num_heads,), eps=1e-06, elementwise_affine=True)
self.norm2 = CogAdaLayerNorm(dim, dim_cond)
self.ff = torch.nn.Sequential(
torch.nn.Linear(dim, dim*4),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(dim*4, dim)
)
def apply_rotary_emb(self, x, freqs_cis):
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
def process_qkv(self, q, k, v, image_rotary_emb, text_seq_length):
q = self.norm_q(q)
k = self.norm_k(k)
q[:, :, text_seq_length:] = self.apply_rotary_emb(q[:, :, text_seq_length:], image_rotary_emb)
k[:, :, text_seq_length:] = self.apply_rotary_emb(k[:, :, text_seq_length:], image_rotary_emb)
return q, k, v
def forward(self, hidden_states, prompt_emb, time_emb, image_rotary_emb):
# Attention
norm_hidden_states, norm_encoder_hidden_states, gate_a, gate_b = self.norm1(
hidden_states, prompt_emb, time_emb
)
attention_io = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
attention_io = self.attn1(
attention_io,
qkv_preprocessor=lambda q, k, v: self.process_qkv(q, k, v, image_rotary_emb, prompt_emb.shape[1])
)
hidden_states = hidden_states + gate_a * attention_io[:, prompt_emb.shape[1]:]
prompt_emb = prompt_emb + gate_b * attention_io[:, :prompt_emb.shape[1]]
# Feed forward
norm_hidden_states, norm_encoder_hidden_states, gate_a, gate_b = self.norm2(
hidden_states, prompt_emb, time_emb
)
ff_io = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
ff_io = self.ff(ff_io)
hidden_states = hidden_states + gate_a * ff_io[:, prompt_emb.shape[1]:]
prompt_emb = prompt_emb + gate_b * ff_io[:, :prompt_emb.shape[1]]
return hidden_states, prompt_emb
class CogDiT(torch.nn.Module):
def __init__(self):
super().__init__()
self.patchify = CogPatchify(16, 3072, 2)
self.time_embedder = TimestepEmbeddings(3072, 512)
self.context_embedder = torch.nn.Linear(4096, 3072)
self.blocks = torch.nn.ModuleList([CogDiTBlock(3072, 512, 48) for _ in range(42)])
self.norm_final = torch.nn.LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
self.norm_out = CogAdaLayerNorm(3072, 512, single=True)
self.proj_out = torch.nn.Linear(3072, 64, bias=True)
def get_resize_crop_region_for_grid(self, src, tgt_width, tgt_height):
tw = tgt_width
th = tgt_height
h, w = src
r = h / w
if r > (th / tw):
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h))
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
def get_3d_rotary_pos_embed(
self, embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
):
start, stop = crops_coords
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
# Compute dimensions for each axis
dim_t = embed_dim // 4
dim_h = embed_dim // 8 * 3
dim_w = embed_dim // 8 * 3
# Temporal frequencies
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
grid_t = torch.from_numpy(grid_t).float()
freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
freqs_t = freqs_t.repeat_interleave(2, dim=-1)
# Spatial frequencies for height and width
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
grid_h = torch.from_numpy(grid_h).float()
grid_w = torch.from_numpy(grid_w).float()
freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
freqs_h = freqs_h.repeat_interleave(2, dim=-1)
freqs_w = freqs_w.repeat_interleave(2, dim=-1)
# Broadcast and concatenate tensors along specified dimension
def broadcast(tensors, dim=-1):
num_tensors = len(tensors)
shape_lens = {len(t.shape) for t in tensors}
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
shape_len = list(shape_lens)[0]
dim = (dim + shape_len) if dim < 0 else dim
dims = list(zip(*(list(t.shape) for t in tensors)))
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
assert all(
[*(len(set(t[1])) <= 2 for t in expandable_dims)]
), "invalid dimensions for broadcastable concatenation"
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
expanded_dims.insert(dim, (dim, dims[dim]))
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
return torch.cat(tensors, dim=dim)
freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
t, h, w, d = freqs.shape
freqs = freqs.view(t * h * w, d)
# Generate sine and cosine components
sin = freqs.sin()
cos = freqs.cos()
if use_real:
return cos, sin
else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
def prepare_rotary_positional_embeddings(
self,
height: int,
width: int,
num_frames: int,
device: torch.device,
):
grid_height = height // 2
grid_width = width // 2
base_size_width = 720 // (8 * 2)
base_size_height = 480 // (8 * 2)
grid_crops_coords = self.get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
)
freqs_cos, freqs_sin = self.get_3d_rotary_pos_embed(
embed_dim=64,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
use_real=True,
)
freqs_cos = freqs_cos.to(device=device)
freqs_sin = freqs_sin.to(device=device)
return freqs_cos, freqs_sin
def unpatchify(self, hidden_states, height, width):
hidden_states = rearrange(hidden_states, "B (T H W) (C P Q) -> B C T (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
return hidden_states
def build_mask(self, T, H, W, dtype, device, is_bound):
t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W)
h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W)
w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W)
border_width = (H + W) // 4
pad = torch.ones_like(h) * border_width
mask = torch.stack([
pad if is_bound[0] else t + 1,
pad if is_bound[1] else T - t,
pad if is_bound[2] else h + 1,
pad if is_bound[3] else H - h,
pad if is_bound[4] else w + 1,
pad if is_bound[5] else W - w
]).min(dim=0).values
mask = mask.clip(1, border_width)
mask = (mask / border_width).to(dtype=dtype, device=device)
mask = rearrange(mask, "T H W -> 1 1 T H W")
return mask
def tiled_forward(self, hidden_states, timestep, prompt_emb, tile_size=(60, 90), tile_stride=(30, 45)):
B, C, T, H, W = hidden_states.shape
value = torch.zeros((B, C, T, H, W), dtype=hidden_states.dtype, device=hidden_states.device)
weight = torch.zeros((B, C, T, H, W), dtype=hidden_states.dtype, device=hidden_states.device)
# Split tasks
tasks = []
for h in range(0, H, tile_stride):
for w in range(0, W, tile_stride):
if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
continue
h_, w_ = h + tile_size, w + tile_size
if h_ > H: h, h_ = max(H - tile_size, 0), H
if w_ > W: w, w_ = max(W - tile_size, 0), W
tasks.append((h, h_, w, w_))
# Run
for hl, hr, wl, wr in tasks:
mask = self.build_mask(
value.shape[2], (hr-hl), (wr-wl),
hidden_states.dtype, hidden_states.device,
is_bound=(True, True, hl==0, hr>=H, wl==0, wr>=W)
)
model_output = self.forward(hidden_states[:, :, :, hl:hr, wl:wr], timestep, prompt_emb)
value[:, :, :, hl:hr, wl:wr] += model_output * mask
weight[:, :, :, hl:hr, wl:wr] += mask
value = value / weight
return value
def forward(self, hidden_states, timestep, prompt_emb, image_rotary_emb=None, tiled=False, tile_size=90, tile_stride=30):
if tiled:
return TileWorker2Dto3D().tiled_forward(
forward_fn=lambda x: self.forward(x, timestep, prompt_emb),
model_input=hidden_states,
tile_size=tile_size, tile_stride=tile_stride,
tile_device=hidden_states.device, tile_dtype=hidden_states.dtype,
computation_device=self.context_embedder.weight.device, computation_dtype=self.context_embedder.weight.dtype
)
num_frames, height, width = hidden_states.shape[-3:]
if image_rotary_emb is None:
image_rotary_emb = self.prepare_rotary_positional_embeddings(height, width, num_frames, device=self.context_embedder.weight.device)
hidden_states = self.patchify(hidden_states)
time_emb = self.time_embedder(timestep, dtype=hidden_states.dtype)
prompt_emb = self.context_embedder(prompt_emb)
for block in self.blocks:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, time_emb, image_rotary_emb)
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
hidden_states = self.norm_final(hidden_states)
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
hidden_states = self.norm_out(hidden_states, prompt_emb, time_emb)
hidden_states = self.proj_out(hidden_states)
hidden_states = self.unpatchify(hidden_states, height, width)
return hidden_states
@staticmethod
def state_dict_converter():
return CogDiTStateDictConverter()
@staticmethod
def from_pretrained(file_path, torch_dtype=torch.bfloat16):
model = CogDiT().to(torch_dtype)
state_dict = load_state_dict_from_folder(file_path, torch_dtype=torch_dtype)
state_dict = CogDiT.state_dict_converter().from_diffusers(state_dict)
model.load_state_dict(state_dict)
return model
class CogDiTStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"patch_embed.proj.weight": "patchify.proj.weight",
"patch_embed.proj.bias": "patchify.proj.bias",
"patch_embed.text_proj.weight": "context_embedder.weight",
"patch_embed.text_proj.bias": "context_embedder.bias",
"time_embedding.linear_1.weight": "time_embedder.timestep_embedder.0.weight",
"time_embedding.linear_1.bias": "time_embedder.timestep_embedder.0.bias",
"time_embedding.linear_2.weight": "time_embedder.timestep_embedder.2.weight",
"time_embedding.linear_2.bias": "time_embedder.timestep_embedder.2.bias",
"norm_final.weight": "norm_final.weight",
"norm_final.bias": "norm_final.bias",
"norm_out.linear.weight": "norm_out.linear.weight",
"norm_out.linear.bias": "norm_out.linear.bias",
"norm_out.norm.weight": "norm_out.norm.weight",
"norm_out.norm.bias": "norm_out.norm.bias",
"proj_out.weight": "proj_out.weight",
"proj_out.bias": "proj_out.bias",
}
suffix_dict = {
"norm1.linear.weight": "norm1.linear.weight",
"norm1.linear.bias": "norm1.linear.bias",
"norm1.norm.weight": "norm1.norm.weight",
"norm1.norm.bias": "norm1.norm.bias",
"attn1.norm_q.weight": "norm_q.weight",
"attn1.norm_q.bias": "norm_q.bias",
"attn1.norm_k.weight": "norm_k.weight",
"attn1.norm_k.bias": "norm_k.bias",
"attn1.to_q.weight": "attn1.to_q.weight",
"attn1.to_q.bias": "attn1.to_q.bias",
"attn1.to_k.weight": "attn1.to_k.weight",
"attn1.to_k.bias": "attn1.to_k.bias",
"attn1.to_v.weight": "attn1.to_v.weight",
"attn1.to_v.bias": "attn1.to_v.bias",
"attn1.to_out.0.weight": "attn1.to_out.weight",
"attn1.to_out.0.bias": "attn1.to_out.bias",
"norm2.linear.weight": "norm2.linear.weight",
"norm2.linear.bias": "norm2.linear.bias",
"norm2.norm.weight": "norm2.norm.weight",
"norm2.norm.bias": "norm2.norm.bias",
"ff.net.0.proj.weight": "ff.0.weight",
"ff.net.0.proj.bias": "ff.0.bias",
"ff.net.2.weight": "ff.2.weight",
"ff.net.2.bias": "ff.2.bias",
}
state_dict_ = {}
for name, param in state_dict.items():
if name in rename_dict:
if name == "patch_embed.proj.weight":
param = param.unsqueeze(2)
state_dict_[rename_dict[name]] = param
else:
names = name.split(".")
if names[0] == "transformer_blocks":
suffix = ".".join(names[2:])
state_dict_[f"blocks.{names[1]}." + suffix_dict[suffix]] = param
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

518
diffsynth/models/cog_vae.py Normal file
View File

@@ -0,0 +1,518 @@
import torch
from einops import rearrange, repeat
from .tiler import TileWorker2Dto3D
class Downsample3D(torch.nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 2,
padding: int = 0,
compress_time: bool = False,
):
super().__init__()
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.compress_time = compress_time
def forward(self, x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor:
if self.compress_time:
batch_size, channels, frames, height, width = x.shape
# (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
if x.shape[-1] % 2 == 1:
x_first, x_rest = x[..., 0], x[..., 1:]
if x_rest.shape[-1] > 0:
# (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
x = torch.cat([x_first[..., None], x_rest], dim=-1)
# (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
else:
# (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
# (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
# Pad the tensor
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
batch_size, channels, frames, height, width = x.shape
# (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
x = self.conv(x)
# (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
return x
class Upsample3D(torch.nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 1,
padding: int = 1,
compress_time: bool = False,
) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.compress_time = compress_time
def forward(self, inputs: torch.Tensor, xq: torch.Tensor) -> torch.Tensor:
if self.compress_time:
if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
# split first frame
x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0)
x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0)
x_first = x_first[:, :, None, :, :]
inputs = torch.cat([x_first, x_rest], dim=2)
elif inputs.shape[2] > 1:
inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
else:
inputs = inputs.squeeze(2)
inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
inputs = inputs[:, :, None, :, :]
else:
# only interpolate 2D
b, c, t, h, w = inputs.shape
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
b, c, t, h, w = inputs.shape
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
inputs = self.conv(inputs)
inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
return inputs
class CogVideoXSpatialNorm3D(torch.nn.Module):
def __init__(self, f_channels, zq_channels, groups):
super().__init__()
self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
self.conv_y = torch.nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1)
self.conv_b = torch.nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1)
def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
if f.shape[2] > 1 and f.shape[2] % 2 == 1:
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
z_first = torch.nn.functional.interpolate(z_first, size=f_first_size)
z_rest = torch.nn.functional.interpolate(z_rest, size=f_rest_size)
zq = torch.cat([z_first, z_rest], dim=2)
else:
zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:])
norm_f = self.norm_layer(f)
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
return new_f
class Resnet3DBlock(torch.nn.Module):
def __init__(self, in_channels, out_channels, spatial_norm_dim, groups, eps=1e-6, use_conv_shortcut=False):
super().__init__()
self.nonlinearity = torch.nn.SiLU()
if spatial_norm_dim is None:
self.norm1 = torch.nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
self.norm2 = torch.nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
else:
self.norm1 = CogVideoXSpatialNorm3D(in_channels, spatial_norm_dim, groups)
self.norm2 = CogVideoXSpatialNorm3D(out_channels, spatial_norm_dim, groups)
self.conv1 = CachedConv3d(in_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
self.conv2 = CachedConv3d(out_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
if in_channels != out_channels:
if use_conv_shortcut:
self.conv_shortcut = CachedConv3d(in_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
else:
self.conv_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1)
else:
self.conv_shortcut = lambda x: x
def forward(self, hidden_states, zq):
residual = hidden_states
hidden_states = self.norm1(hidden_states, zq) if isinstance(self.norm1, CogVideoXSpatialNorm3D) else self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
hidden_states = self.norm2(hidden_states, zq) if isinstance(self.norm2, CogVideoXSpatialNorm3D) else self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv2(hidden_states)
hidden_states = hidden_states + self.conv_shortcut(residual)
return hidden_states
class CachedConv3d(torch.nn.Conv3d):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
super().__init__(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.cached_tensor = None
def clear_cache(self):
self.cached_tensor = None
def forward(self, input: torch.Tensor, use_cache = True) -> torch.Tensor:
if use_cache:
if self.cached_tensor is None:
self.cached_tensor = torch.concat([input[:, :, :1]] * 2, dim=2)
input = torch.concat([self.cached_tensor, input], dim=2)
self.cached_tensor = input[:, :, -2:]
return super().forward(input)
class CogVAEDecoder(torch.nn.Module):
def __init__(self):
super().__init__()
self.scaling_factor = 0.7
self.conv_in = CachedConv3d(16, 512, kernel_size=3, stride=1, padding=(0, 1, 1))
self.blocks = torch.nn.ModuleList([
Resnet3DBlock(512, 512, 16, 32),
Resnet3DBlock(512, 512, 16, 32),
Resnet3DBlock(512, 512, 16, 32),
Resnet3DBlock(512, 512, 16, 32),
Resnet3DBlock(512, 512, 16, 32),
Resnet3DBlock(512, 512, 16, 32),
Upsample3D(512, 512, compress_time=True),
Resnet3DBlock(512, 256, 16, 32),
Resnet3DBlock(256, 256, 16, 32),
Resnet3DBlock(256, 256, 16, 32),
Resnet3DBlock(256, 256, 16, 32),
Upsample3D(256, 256, compress_time=True),
Resnet3DBlock(256, 256, 16, 32),
Resnet3DBlock(256, 256, 16, 32),
Resnet3DBlock(256, 256, 16, 32),
Resnet3DBlock(256, 256, 16, 32),
Upsample3D(256, 256, compress_time=False),
Resnet3DBlock(256, 128, 16, 32),
Resnet3DBlock(128, 128, 16, 32),
Resnet3DBlock(128, 128, 16, 32),
Resnet3DBlock(128, 128, 16, 32),
])
self.norm_out = CogVideoXSpatialNorm3D(128, 16, 32)
self.conv_act = torch.nn.SiLU()
self.conv_out = CachedConv3d(128, 3, kernel_size=3, stride=1, padding=(0, 1, 1))
def forward(self, sample):
sample = sample / self.scaling_factor
hidden_states = self.conv_in(sample)
for block in self.blocks:
hidden_states = block(hidden_states, sample)
hidden_states = self.norm_out(hidden_states, sample)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
def decode_video(self, sample, tiled=True, tile_size=(60, 90), tile_stride=(30, 45), progress_bar=lambda x:x):
if tiled:
B, C, T, H, W = sample.shape
return TileWorker2Dto3D().tiled_forward(
forward_fn=lambda x: self.decode_small_video(x),
model_input=sample,
tile_size=tile_size, tile_stride=tile_stride,
tile_device=sample.device, tile_dtype=sample.dtype,
computation_device=sample.device, computation_dtype=sample.dtype,
scales=(3/16, (T//2*8+T%2)/T, 8, 8),
progress_bar=progress_bar
)
else:
return self.decode_small_video(sample)
def decode_small_video(self, sample):
B, C, T, H, W = sample.shape
computation_device = self.conv_in.weight.device
computation_dtype = self.conv_in.weight.dtype
value = []
for i in range(T//2):
tl = i*2 + T%2 - (T%2 and i==0)
tr = i*2 + 2 + T%2
model_input = sample[:, :, tl: tr, :, :].to(dtype=computation_dtype, device=computation_device)
model_output = self.forward(model_input).to(dtype=sample.dtype, device=sample.device)
value.append(model_output)
value = torch.concat(value, dim=2)
for name, module in self.named_modules():
if isinstance(module, CachedConv3d):
module.clear_cache()
return value
@staticmethod
def state_dict_converter():
return CogVAEDecoderStateDictConverter()
class CogVAEEncoder(torch.nn.Module):
def __init__(self):
super().__init__()
self.scaling_factor = 0.7
self.conv_in = CachedConv3d(3, 128, kernel_size=3, stride=1, padding=(0, 1, 1))
self.blocks = torch.nn.ModuleList([
Resnet3DBlock(128, 128, None, 32),
Resnet3DBlock(128, 128, None, 32),
Resnet3DBlock(128, 128, None, 32),
Downsample3D(128, 128, compress_time=True),
Resnet3DBlock(128, 256, None, 32),
Resnet3DBlock(256, 256, None, 32),
Resnet3DBlock(256, 256, None, 32),
Downsample3D(256, 256, compress_time=True),
Resnet3DBlock(256, 256, None, 32),
Resnet3DBlock(256, 256, None, 32),
Resnet3DBlock(256, 256, None, 32),
Downsample3D(256, 256, compress_time=False),
Resnet3DBlock(256, 512, None, 32),
Resnet3DBlock(512, 512, None, 32),
Resnet3DBlock(512, 512, None, 32),
Resnet3DBlock(512, 512, None, 32),
Resnet3DBlock(512, 512, None, 32),
])
self.norm_out = torch.nn.GroupNorm(32, 512, eps=1e-06, affine=True)
self.conv_act = torch.nn.SiLU()
self.conv_out = CachedConv3d(512, 32, kernel_size=3, stride=1, padding=(0, 1, 1))
def forward(self, sample):
hidden_states = self.conv_in(sample)
for block in self.blocks:
hidden_states = block(hidden_states, sample)
hidden_states = self.norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)[:, :16]
hidden_states = hidden_states * self.scaling_factor
return hidden_states
def encode_video(self, sample, tiled=True, tile_size=(60, 90), tile_stride=(30, 45), progress_bar=lambda x:x):
if tiled:
B, C, T, H, W = sample.shape
return TileWorker2Dto3D().tiled_forward(
forward_fn=lambda x: self.encode_small_video(x),
model_input=sample,
tile_size=(i * 8 for i in tile_size), tile_stride=(i * 8 for i in tile_stride),
tile_device=sample.device, tile_dtype=sample.dtype,
computation_device=sample.device, computation_dtype=sample.dtype,
scales=(16/3, (T//4+T%2)/T, 1/8, 1/8),
progress_bar=progress_bar
)
else:
return self.encode_small_video(sample)
def encode_small_video(self, sample):
B, C, T, H, W = sample.shape
computation_device = self.conv_in.weight.device
computation_dtype = self.conv_in.weight.dtype
value = []
for i in range(T//8):
t = i*8 + T%2 - (T%2 and i==0)
t_ = i*8 + 8 + T%2
model_input = sample[:, :, t: t_, :, :].to(dtype=computation_dtype, device=computation_device)
model_output = self.forward(model_input).to(dtype=sample.dtype, device=sample.device)
value.append(model_output)
value = torch.concat(value, dim=2)
for name, module in self.named_modules():
if isinstance(module, CachedConv3d):
module.clear_cache()
return value
@staticmethod
def state_dict_converter():
return CogVAEEncoderStateDictConverter()
class CogVAEEncoderStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"encoder.conv_in.conv.weight": "conv_in.weight",
"encoder.conv_in.conv.bias": "conv_in.bias",
"encoder.down_blocks.0.downsamplers.0.conv.weight": "blocks.3.conv.weight",
"encoder.down_blocks.0.downsamplers.0.conv.bias": "blocks.3.conv.bias",
"encoder.down_blocks.1.downsamplers.0.conv.weight": "blocks.7.conv.weight",
"encoder.down_blocks.1.downsamplers.0.conv.bias": "blocks.7.conv.bias",
"encoder.down_blocks.2.downsamplers.0.conv.weight": "blocks.11.conv.weight",
"encoder.down_blocks.2.downsamplers.0.conv.bias": "blocks.11.conv.bias",
"encoder.norm_out.weight": "norm_out.weight",
"encoder.norm_out.bias": "norm_out.bias",
"encoder.conv_out.conv.weight": "conv_out.weight",
"encoder.conv_out.conv.bias": "conv_out.bias",
}
prefix_dict = {
"encoder.down_blocks.0.resnets.0.": "blocks.0.",
"encoder.down_blocks.0.resnets.1.": "blocks.1.",
"encoder.down_blocks.0.resnets.2.": "blocks.2.",
"encoder.down_blocks.1.resnets.0.": "blocks.4.",
"encoder.down_blocks.1.resnets.1.": "blocks.5.",
"encoder.down_blocks.1.resnets.2.": "blocks.6.",
"encoder.down_blocks.2.resnets.0.": "blocks.8.",
"encoder.down_blocks.2.resnets.1.": "blocks.9.",
"encoder.down_blocks.2.resnets.2.": "blocks.10.",
"encoder.down_blocks.3.resnets.0.": "blocks.12.",
"encoder.down_blocks.3.resnets.1.": "blocks.13.",
"encoder.down_blocks.3.resnets.2.": "blocks.14.",
"encoder.mid_block.resnets.0.": "blocks.15.",
"encoder.mid_block.resnets.1.": "blocks.16.",
}
suffix_dict = {
"norm1.norm_layer.weight": "norm1.norm_layer.weight",
"norm1.norm_layer.bias": "norm1.norm_layer.bias",
"norm1.conv_y.conv.weight": "norm1.conv_y.weight",
"norm1.conv_y.conv.bias": "norm1.conv_y.bias",
"norm1.conv_b.conv.weight": "norm1.conv_b.weight",
"norm1.conv_b.conv.bias": "norm1.conv_b.bias",
"norm2.norm_layer.weight": "norm2.norm_layer.weight",
"norm2.norm_layer.bias": "norm2.norm_layer.bias",
"norm2.conv_y.conv.weight": "norm2.conv_y.weight",
"norm2.conv_y.conv.bias": "norm2.conv_y.bias",
"norm2.conv_b.conv.weight": "norm2.conv_b.weight",
"norm2.conv_b.conv.bias": "norm2.conv_b.bias",
"conv1.conv.weight": "conv1.weight",
"conv1.conv.bias": "conv1.bias",
"conv2.conv.weight": "conv2.weight",
"conv2.conv.bias": "conv2.bias",
"conv_shortcut.weight": "conv_shortcut.weight",
"conv_shortcut.bias": "conv_shortcut.bias",
"norm1.weight": "norm1.weight",
"norm1.bias": "norm1.bias",
"norm2.weight": "norm2.weight",
"norm2.bias": "norm2.bias",
}
state_dict_ = {}
for name, param in state_dict.items():
if name in rename_dict:
state_dict_[rename_dict[name]] = param
else:
for prefix in prefix_dict:
if name.startswith(prefix):
suffix = name[len(prefix):]
state_dict_[prefix_dict[prefix] + suffix_dict[suffix]] = param
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)
class CogVAEDecoderStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"decoder.conv_in.conv.weight": "conv_in.weight",
"decoder.conv_in.conv.bias": "conv_in.bias",
"decoder.up_blocks.0.upsamplers.0.conv.weight": "blocks.6.conv.weight",
"decoder.up_blocks.0.upsamplers.0.conv.bias": "blocks.6.conv.bias",
"decoder.up_blocks.1.upsamplers.0.conv.weight": "blocks.11.conv.weight",
"decoder.up_blocks.1.upsamplers.0.conv.bias": "blocks.11.conv.bias",
"decoder.up_blocks.2.upsamplers.0.conv.weight": "blocks.16.conv.weight",
"decoder.up_blocks.2.upsamplers.0.conv.bias": "blocks.16.conv.bias",
"decoder.norm_out.norm_layer.weight": "norm_out.norm_layer.weight",
"decoder.norm_out.norm_layer.bias": "norm_out.norm_layer.bias",
"decoder.norm_out.conv_y.conv.weight": "norm_out.conv_y.weight",
"decoder.norm_out.conv_y.conv.bias": "norm_out.conv_y.bias",
"decoder.norm_out.conv_b.conv.weight": "norm_out.conv_b.weight",
"decoder.norm_out.conv_b.conv.bias": "norm_out.conv_b.bias",
"decoder.conv_out.conv.weight": "conv_out.weight",
"decoder.conv_out.conv.bias": "conv_out.bias"
}
prefix_dict = {
"decoder.mid_block.resnets.0.": "blocks.0.",
"decoder.mid_block.resnets.1.": "blocks.1.",
"decoder.up_blocks.0.resnets.0.": "blocks.2.",
"decoder.up_blocks.0.resnets.1.": "blocks.3.",
"decoder.up_blocks.0.resnets.2.": "blocks.4.",
"decoder.up_blocks.0.resnets.3.": "blocks.5.",
"decoder.up_blocks.1.resnets.0.": "blocks.7.",
"decoder.up_blocks.1.resnets.1.": "blocks.8.",
"decoder.up_blocks.1.resnets.2.": "blocks.9.",
"decoder.up_blocks.1.resnets.3.": "blocks.10.",
"decoder.up_blocks.2.resnets.0.": "blocks.12.",
"decoder.up_blocks.2.resnets.1.": "blocks.13.",
"decoder.up_blocks.2.resnets.2.": "blocks.14.",
"decoder.up_blocks.2.resnets.3.": "blocks.15.",
"decoder.up_blocks.3.resnets.0.": "blocks.17.",
"decoder.up_blocks.3.resnets.1.": "blocks.18.",
"decoder.up_blocks.3.resnets.2.": "blocks.19.",
"decoder.up_blocks.3.resnets.3.": "blocks.20.",
}
suffix_dict = {
"norm1.norm_layer.weight": "norm1.norm_layer.weight",
"norm1.norm_layer.bias": "norm1.norm_layer.bias",
"norm1.conv_y.conv.weight": "norm1.conv_y.weight",
"norm1.conv_y.conv.bias": "norm1.conv_y.bias",
"norm1.conv_b.conv.weight": "norm1.conv_b.weight",
"norm1.conv_b.conv.bias": "norm1.conv_b.bias",
"norm2.norm_layer.weight": "norm2.norm_layer.weight",
"norm2.norm_layer.bias": "norm2.norm_layer.bias",
"norm2.conv_y.conv.weight": "norm2.conv_y.weight",
"norm2.conv_y.conv.bias": "norm2.conv_y.bias",
"norm2.conv_b.conv.weight": "norm2.conv_b.weight",
"norm2.conv_b.conv.bias": "norm2.conv_b.bias",
"conv1.conv.weight": "conv1.weight",
"conv1.conv.bias": "conv1.bias",
"conv2.conv.weight": "conv2.weight",
"conv2.conv.bias": "conv2.bias",
"conv_shortcut.weight": "conv_shortcut.weight",
"conv_shortcut.bias": "conv_shortcut.bias",
}
state_dict_ = {}
for name, param in state_dict.items():
if name in rename_dict:
state_dict_[rename_dict[name]] = param
else:
for prefix in prefix_dict:
if name.startswith(prefix):
suffix = name[len(prefix):]
state_dict_[prefix_dict[prefix] + suffix_dict[suffix]] = param
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -1,96 +0,0 @@
from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast
from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
import torch
from ..core.device.npu_compatible_device import get_device_type
class DINOv3ImageEncoder(DINOv3ViTModel):
def __init__(self):
config = DINOv3ViTConfig(
architectures = [
"DINOv3ViTModel"
],
attention_dropout = 0.0,
drop_path_rate = 0.0,
dtype = "float32",
hidden_act = "silu",
hidden_size = 4096,
image_size = 224,
initializer_range = 0.02,
intermediate_size = 8192,
key_bias = False,
layer_norm_eps = 1e-05,
layerscale_value = 1.0,
mlp_bias = True,
model_type = "dinov3_vit",
num_attention_heads = 32,
num_channels = 3,
num_hidden_layers = 40,
num_register_tokens = 4,
patch_size = 16,
pos_embed_jitter = None,
pos_embed_rescale = 2.0,
pos_embed_shift = None,
proj_bias = True,
query_bias = False,
rope_theta = 100.0,
transformers_version = "4.56.1",
use_gated_mlp = True,
value_bias = False
)
super().__init__(config)
self.processor = DINOv3ViTImageProcessorFast(
crop_size = None,
data_format = "channels_first",
default_to_square = True,
device = None,
disable_grouping = None,
do_center_crop = None,
do_convert_rgb = None,
do_normalize = True,
do_rescale = True,
do_resize = True,
image_mean = [
0.485,
0.456,
0.406
],
image_processor_type = "DINOv3ViTImageProcessorFast",
image_std = [
0.229,
0.224,
0.225
],
input_data_format = None,
resample = 2,
rescale_factor = 0.00392156862745098,
return_tensors = None,
size = {
"height": 224,
"width": 224
}
)
def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
inputs = self.processor(images=image, return_tensors="pt")
pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device)
bool_masked_pos = None
head_mask = None
pixel_values = pixel_values.to(torch_dtype)
hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
position_embeddings = self.rope_embeddings(pixel_values)
for i, layer_module in enumerate(self.layer):
layer_head_mask = head_mask[i] if head_mask is not None else None
hidden_states = layer_module(
hidden_states,
attention_mask=layer_head_mask,
position_embeddings=position_embeddings,
)
sequence_output = self.norm(hidden_states)
pooled_output = sequence_output[:, 0, :]
return pooled_output

View File

@@ -0,0 +1,66 @@
from huggingface_hub import hf_hub_download
from modelscope import snapshot_download
import os, shutil
from typing_extensions import Literal, TypeAlias
from typing import List
from ..configs.model_config import preset_models_on_huggingface, preset_models_on_modelscope, Preset_model_id
def download_from_modelscope(model_id, origin_file_path, local_dir):
os.makedirs(local_dir, exist_ok=True)
if os.path.basename(origin_file_path) in os.listdir(local_dir):
print(f" {os.path.basename(origin_file_path)} has been already in {local_dir}.")
return
else:
print(f" Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
downloaded_file_path = os.path.join(local_dir, origin_file_path)
target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
if downloaded_file_path != target_file_path:
shutil.move(downloaded_file_path, target_file_path)
shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
def download_from_huggingface(model_id, origin_file_path, local_dir):
os.makedirs(local_dir, exist_ok=True)
if os.path.basename(origin_file_path) in os.listdir(local_dir):
print(f" {os.path.basename(origin_file_path)} has been already in {local_dir}.")
return
else:
print(f" Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
Preset_model_website: TypeAlias = Literal[
"HuggingFace",
"ModelScope",
]
website_to_preset_models = {
"HuggingFace": preset_models_on_huggingface,
"ModelScope": preset_models_on_modelscope,
}
website_to_download_fn = {
"HuggingFace": download_from_huggingface,
"ModelScope": download_from_modelscope,
}
def download_models(
model_id_list: List[Preset_model_id] = [],
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
):
print(f"Downloading models: {model_id_list}")
downloaded_files = []
for model_id in model_id_list:
for website in downloading_priority:
if model_id in website_to_preset_models[website]:
for model_id, origin_file_path, local_dir in website_to_preset_models[website][model_id]:
# Check if the file is downloaded.
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
if file_to_download in downloaded_files:
continue
# Download
website_to_download_fn[website](model_id, origin_file_path, local_dir)
if os.path.basename(origin_file_path) in os.listdir(local_dir):
downloaded_files.append(file_to_download)
return downloaded_files

File diff suppressed because it is too large Load Diff

View File

@@ -1,58 +0,0 @@
from transformers import Mistral3ForConditionalGeneration, Mistral3Config
class Flux2TextEncoder(Mistral3ForConditionalGeneration):
def __init__(self):
config = Mistral3Config(**{
"architectures": [
"Mistral3ForConditionalGeneration"
],
"dtype": "bfloat16",
"image_token_index": 10,
"model_type": "mistral3",
"multimodal_projector_bias": False,
"projector_hidden_act": "gelu",
"spatial_merge_size": 2,
"text_config": {
"attention_dropout": 0.0,
"dtype": "bfloat16",
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 5120,
"initializer_range": 0.02,
"intermediate_size": 32768,
"max_position_embeddings": 131072,
"model_type": "mistral",
"num_attention_heads": 32,
"num_hidden_layers": 40,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-05,
"rope_theta": 1000000000.0,
"sliding_window": None,
"use_cache": True,
"vocab_size": 131072
},
"transformers_version": "4.57.1",
"vision_config": {
"attention_dropout": 0.0,
"dtype": "bfloat16",
"head_dim": 64,
"hidden_act": "silu",
"hidden_size": 1024,
"image_size": 1540,
"initializer_range": 0.02,
"intermediate_size": 4096,
"model_type": "pixtral",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 24,
"patch_size": 14,
"rope_theta": 10000.0
},
"vision_feature_layer": -1
})
super().__init__(config)
def forward(self, input_ids = None, pixel_values = None, attention_mask = None, position_ids = None, past_key_values = None, inputs_embeds = None, labels = None, use_cache = None, output_attentions = None, output_hidden_states = None, return_dict = None, cache_position = None, logits_to_keep = 0, image_sizes = None, **kwargs):
return super().forward(input_ids, pixel_values, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, logits_to_keep, image_sizes, **kwargs)

File diff suppressed because it is too large Load Diff

View File

@@ -1,384 +0,0 @@
import torch
from einops import rearrange, repeat
from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock, RMSNorm
# from .utils import hash_state_dict_keys, init_weights_on_device
from contextlib import contextmanager
def hash_state_dict_keys(state_dict, with_shape=True):
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
keys_str = keys_str.encode(encoding="UTF-8")
return hashlib.md5(keys_str).hexdigest()
@contextmanager
def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):
old_register_parameter = torch.nn.Module.register_parameter
if include_buffers:
old_register_buffer = torch.nn.Module.register_buffer
def register_empty_parameter(module, name, param):
old_register_parameter(module, name, param)
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
kwargs["requires_grad"] = param.requires_grad
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
def register_empty_buffer(module, name, buffer, persistent=True):
old_register_buffer(module, name, buffer, persistent=persistent)
if buffer is not None:
module._buffers[name] = module._buffers[name].to(device)
def patch_tensor_constructor(fn):
def wrapper(*args, **kwargs):
kwargs["device"] = device
return fn(*args, **kwargs)
return wrapper
if include_buffers:
tensor_constructors_to_patch = {
torch_function_name: getattr(torch, torch_function_name)
for torch_function_name in ["empty", "zeros", "ones", "full"]
}
else:
tensor_constructors_to_patch = {}
try:
torch.nn.Module.register_parameter = register_empty_parameter
if include_buffers:
torch.nn.Module.register_buffer = register_empty_buffer
for torch_function_name in tensor_constructors_to_patch.keys():
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
yield
finally:
torch.nn.Module.register_parameter = old_register_parameter
if include_buffers:
torch.nn.Module.register_buffer = old_register_buffer
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
setattr(torch, torch_function_name, old_torch_function)
class FluxControlNet(torch.nn.Module):
def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0):
super().__init__()
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
self.time_embedder = TimestepEmbeddings(256, 3072)
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
self.context_embedder = torch.nn.Linear(4096, 3072)
self.x_embedder = torch.nn.Linear(64, 3072)
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_joint_blocks)])
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(num_single_blocks)])
self.controlnet_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_joint_blocks)])
self.controlnet_single_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_single_blocks)])
self.mode_dict = mode_dict
self.controlnet_mode_embedder = torch.nn.Embedding(num_mode, 3072) if len(mode_dict) > 0 else None
self.controlnet_x_embedder = torch.nn.Linear(64 + additional_input_dim, 3072)
def prepare_image_ids(self, latents):
batch_size, _, height, width = latents.shape
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
latent_image_ids = latent_image_ids.reshape(
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
return latent_image_ids
def patchify(self, hidden_states):
hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
return hidden_states
def align_res_stack_to_original_blocks(self, res_stack, num_blocks, hidden_states):
if len(res_stack) == 0:
return [torch.zeros_like(hidden_states)] * num_blocks
interval = (num_blocks + len(res_stack) - 1) // len(res_stack)
aligned_res_stack = [res_stack[block_id // interval] for block_id in range(num_blocks)]
return aligned_res_stack
def forward(
self,
hidden_states,
controlnet_conditioning,
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
processor_id=None,
tiled=False, tile_size=128, tile_stride=64,
**kwargs
):
if image_ids is None:
image_ids = self.prepare_image_ids(hidden_states)
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
if self.guidance_embedder is not None:
guidance = guidance * 1000
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
prompt_emb = self.context_embedder(prompt_emb)
if self.controlnet_mode_embedder is not None: # Different from FluxDiT
processor_id = torch.tensor([self.mode_dict[processor_id]], dtype=torch.int)
processor_id = repeat(processor_id, "D -> B D", B=1).to(text_ids.device)
prompt_emb = torch.concat([self.controlnet_mode_embedder(processor_id), prompt_emb], dim=1)
text_ids = torch.cat([text_ids[:, :1], text_ids], dim=1)
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
hidden_states = self.patchify(hidden_states)
hidden_states = self.x_embedder(hidden_states)
controlnet_conditioning = self.patchify(controlnet_conditioning) # Different from FluxDiT
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_conditioning) # Different from FluxDiT
controlnet_res_stack = []
for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
controlnet_res_stack.append(controlnet_block(hidden_states))
controlnet_single_res_stack = []
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
for block, controlnet_block in zip(self.single_blocks, self.controlnet_single_blocks):
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
controlnet_single_res_stack.append(controlnet_block(hidden_states[:, prompt_emb.shape[1]:]))
controlnet_res_stack = self.align_res_stack_to_original_blocks(controlnet_res_stack, 19, hidden_states[:, prompt_emb.shape[1]:])
controlnet_single_res_stack = self.align_res_stack_to_original_blocks(controlnet_single_res_stack, 38, hidden_states[:, prompt_emb.shape[1]:])
return controlnet_res_stack, controlnet_single_res_stack
# @staticmethod
# def state_dict_converter():
# return FluxControlNetStateDictConverter()
def quantize(self):
def cast_to(weight, dtype=None, device=None, copy=False):
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
return weight
return weight.to(dtype=dtype, copy=copy)
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight)
return r
def cast_weight(s, input=None, dtype=None, device=None):
if input is not None:
if dtype is None:
dtype = input.dtype
if device is None:
device = input.device
weight = cast_to(s.weight, dtype, device)
return weight
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
if input is not None:
if dtype is None:
dtype = input.dtype
if bias_dtype is None:
bias_dtype = dtype
if device is None:
device = input.device
bias = None
weight = cast_to(s.weight, dtype, device)
bias = cast_to(s.bias, bias_dtype, device)
return weight, bias
class quantized_layer:
class QLinear(torch.nn.Linear):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self,input,**kwargs):
weight,bias= cast_bias_weight(self,input)
return torch.nn.functional.linear(input,weight,bias)
class QRMSNorm(torch.nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
def forward(self,hidden_states,**kwargs):
weight= cast_weight(self.module,hidden_states)
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
hidden_states = hidden_states.to(input_dtype) * weight
return hidden_states
class QEmbedding(torch.nn.Embedding):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self,input,**kwargs):
weight= cast_weight(self,input)
return torch.nn.functional.embedding(
input, weight, self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.sparse)
def replace_layer(model):
for name, module in model.named_children():
if isinstance(module,quantized_layer.QRMSNorm):
continue
if isinstance(module, torch.nn.Linear):
with init_weights_on_device():
new_layer = quantized_layer.QLinear(module.in_features,module.out_features)
new_layer.weight = module.weight
if module.bias is not None:
new_layer.bias = module.bias
setattr(model, name, new_layer)
elif isinstance(module, RMSNorm):
if hasattr(module,"quantized"):
continue
module.quantized= True
new_layer = quantized_layer.QRMSNorm(module)
setattr(model, name, new_layer)
elif isinstance(module,torch.nn.Embedding):
rows, cols = module.weight.shape
new_layer = quantized_layer.QEmbedding(
num_embeddings=rows,
embedding_dim=cols,
_weight=module.weight,
# _freeze=module.freeze,
padding_idx=module.padding_idx,
max_norm=module.max_norm,
norm_type=module.norm_type,
scale_grad_by_freq=module.scale_grad_by_freq,
sparse=module.sparse)
setattr(model, name, new_layer)
else:
replace_layer(module)
replace_layer(self)
class FluxControlNetStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
hash_value = hash_state_dict_keys(state_dict)
global_rename_dict = {
"context_embedder": "context_embedder",
"x_embedder": "x_embedder",
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
"time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
"norm_out.linear": "final_norm_out.linear",
"proj_out": "final_proj_out",
}
rename_dict = {
"proj_out": "proj_out",
"norm1.linear": "norm1_a.linear",
"norm1_context.linear": "norm1_b.linear",
"attn.to_q": "attn.a_to_q",
"attn.to_k": "attn.a_to_k",
"attn.to_v": "attn.a_to_v",
"attn.to_out.0": "attn.a_to_out",
"attn.add_q_proj": "attn.b_to_q",
"attn.add_k_proj": "attn.b_to_k",
"attn.add_v_proj": "attn.b_to_v",
"attn.to_add_out": "attn.b_to_out",
"ff.net.0.proj": "ff_a.0",
"ff.net.2": "ff_a.2",
"ff_context.net.0.proj": "ff_b.0",
"ff_context.net.2": "ff_b.2",
"attn.norm_q": "attn.norm_q_a",
"attn.norm_k": "attn.norm_k_a",
"attn.norm_added_q": "attn.norm_q_b",
"attn.norm_added_k": "attn.norm_k_b",
}
rename_dict_single = {
"attn.to_q": "a_to_q",
"attn.to_k": "a_to_k",
"attn.to_v": "a_to_v",
"attn.norm_q": "norm_q_a",
"attn.norm_k": "norm_k_a",
"norm.linear": "norm.linear",
"proj_mlp": "proj_in_besides_attn",
"proj_out": "proj_out",
}
state_dict_ = {}
for name, param in state_dict.items():
if name.endswith(".weight") or name.endswith(".bias"):
suffix = ".weight" if name.endswith(".weight") else ".bias"
prefix = name[:-len(suffix)]
if prefix in global_rename_dict:
state_dict_[global_rename_dict[prefix] + suffix] = param
elif prefix.startswith("transformer_blocks."):
names = prefix.split(".")
names[0] = "blocks"
middle = ".".join(names[2:])
if middle in rename_dict:
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
state_dict_[name_] = param
elif prefix.startswith("single_transformer_blocks."):
names = prefix.split(".")
names[0] = "single_blocks"
middle = ".".join(names[2:])
if middle in rename_dict_single:
name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
state_dict_[name_] = param
else:
state_dict_[name] = param
else:
state_dict_[name] = param
for name in list(state_dict_.keys()):
if ".proj_in_besides_attn." in name:
name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
param = torch.concat([
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
state_dict_[name],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
state_dict_.pop(name)
for name in list(state_dict_.keys()):
for component in ["a", "b"]:
if f".{component}_to_q." in name:
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
param = torch.concat([
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
if hash_value == "78d18b9101345ff695f312e7e62538c0":
extra_kwargs = {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}}
elif hash_value == "b001c89139b5f053c715fe772362dd2a":
extra_kwargs = {"num_single_blocks": 0}
elif hash_value == "52357cb26250681367488a8954c271e8":
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
elif hash_value == "7f9583eb8ba86642abb9a21a4b2c9e16":
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 10}
elif hash_value == "43ad5aaa27dd4ee01b832ed16773fa52":
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0}
else:
extra_kwargs = {}
return state_dict_, extra_kwargs
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -1,15 +1,9 @@
import torch
from .general_modules import TimestepEmbeddings, AdaLayerNorm, RMSNorm
from .sd3_dit import TimestepEmbeddings, AdaLayerNorm
from einops import rearrange
from .tiler import TileWorker
def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0):
batch_size, num_tokens = hidden_states.shape[0:2]
ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, num_tokens, -1)
hidden_states = hidden_states + scale * ip_hidden_states
return hidden_states
class RoPEEmbedding(torch.nn.Module):
def __init__(self, dim, theta, axes_dim):
@@ -39,8 +33,23 @@ class RoPEEmbedding(torch.nn.Module):
n_axes = ids.shape[-1]
emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
return emb.unsqueeze(1)
class RMSNorm(torch.nn.Module):
def __init__(self, dim, eps):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones((dim,)))
self.eps = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
hidden_states = hidden_states.to(input_dtype) * self.weight
return hidden_states
class FluxJointAttention(torch.nn.Module):
def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
@@ -69,7 +78,8 @@ class FluxJointAttention(torch.nn.Module):
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb):
batch_size = hidden_states_a.shape[0]
# Part A
@@ -90,19 +100,17 @@ class FluxJointAttention(torch.nn.Module):
q, k = self.apply_rope(q, k, image_rotary_emb)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
if ipadapter_kwargs_list is not None:
hidden_states_a = interact_with_ipadapter(hidden_states_a, q_a, **ipadapter_kwargs_list)
hidden_states_a = self.a_to_out(hidden_states_a)
if self.only_out_a:
return hidden_states_a
else:
hidden_states_b = self.b_to_out(hidden_states_b)
return hidden_states_a, hidden_states_b
class FluxJointTransformerBlock(torch.nn.Module):
@@ -128,12 +136,12 @@ class FluxJointTransformerBlock(torch.nn.Module):
)
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb):
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
# Attention
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb)
# Part A
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
@@ -146,7 +154,7 @@ class FluxJointTransformerBlock(torch.nn.Module):
hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
return hidden_states_a, hidden_states_b
class FluxSingleAttention(torch.nn.Module):
@@ -183,7 +191,7 @@ class FluxSingleAttention(torch.nn.Module):
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
return hidden_states
class AdaLayerNormSingle(torch.nn.Module):
@@ -199,7 +207,7 @@ class AdaLayerNormSingle(torch.nn.Module):
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa
class FluxSingleTransformerBlock(torch.nn.Module):
@@ -224,8 +232,8 @@ class FluxSingleTransformerBlock(torch.nn.Module):
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
def process_attention(self, hidden_states, image_rotary_emb):
batch_size = hidden_states.shape[0]
qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
@@ -234,29 +242,27 @@ class FluxSingleTransformerBlock(torch.nn.Module):
q, k = self.apply_rope(q, k, image_rotary_emb)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
if ipadapter_kwargs_list is not None:
hidden_states = interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs_list)
return hidden_states
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb):
residual = hidden_states_a
norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]
attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
attn_output = self.process_attention(attn_output, image_rotary_emb)
mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")
hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)
hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a)
hidden_states_a = residual + hidden_states_a
return hidden_states_a, hidden_states_b
class AdaLayerNormContinuous(torch.nn.Module):
@@ -268,29 +274,27 @@ class AdaLayerNormContinuous(torch.nn.Module):
def forward(self, x, conditioning):
emb = self.linear(self.silu(conditioning))
shift, scale = torch.chunk(emb, 2, dim=1)
scale, shift = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None] + shift[:, None]
return x
class FluxDiT(torch.nn.Module):
def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19):
def __init__(self):
super().__init__()
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
self.time_embedder = TimestepEmbeddings(256, 3072)
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
self.guidance_embedder = TimestepEmbeddings(256, 3072)
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
self.context_embedder = torch.nn.Linear(4096, 3072)
self.x_embedder = torch.nn.Linear(input_dim, 3072)
self.x_embedder = torch.nn.Linear(64, 3072)
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_blocks)])
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(19)])
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])
self.final_norm_out = AdaLayerNormContinuous(3072)
self.final_proj_out = torch.nn.Linear(3072, 64)
self.input_dim = input_dim
def patchify(self, hidden_states):
@@ -301,7 +305,7 @@ class FluxDiT(torch.nn.Module):
def unpatchify(self, hidden_states, height, width):
hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
return hidden_states
def prepare_image_ids(self, latents):
batch_size, _, height, width = latents.shape
@@ -318,78 +322,272 @@ class FluxDiT(torch.nn.Module):
latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
return latent_image_ids
def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len):
N = len(entity_masks)
batch_size = entity_masks[0].shape[0]
total_seq_len = N * prompt_seq_len + image_seq_len
patched_masks = [self.patchify(entity_masks[i]) for i in range(N)]
attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)
image_start = N * prompt_seq_len
image_end = N * prompt_seq_len + image_seq_len
# prompt-image mask
for i in range(N):
prompt_start = i * prompt_seq_len
prompt_end = (i + 1) * prompt_seq_len
image_mask = torch.sum(patched_masks[i], dim=-1) > 0
image_mask = image_mask.unsqueeze(1).repeat(1, prompt_seq_len, 1)
# prompt update with image
attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
# image update with prompt
attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
# prompt-prompt mask
for i in range(N):
for j in range(N):
if i != j:
prompt_start_i = i * prompt_seq_len
prompt_end_i = (i + 1) * prompt_seq_len
prompt_start_j = j * prompt_seq_len
prompt_end_j = (j + 1) * prompt_seq_len
attention_mask[:, prompt_start_i:prompt_end_i, prompt_start_j:prompt_end_j] = False
attention_mask = attention_mask.float()
attention_mask[attention_mask == 0] = float('-inf')
attention_mask[attention_mask == 1] = 0
return attention_mask
def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, repeat_dim):
max_masks = 0
attention_mask = None
prompt_embs = [prompt_emb]
if entity_masks is not None:
# entity_masks
batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1]
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
# global mask
global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype)
entity_masks = entity_masks + [global_mask] # append global to last
# attention mask
attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1])
attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype)
attention_mask = attention_mask.unsqueeze(1)
# embds: n_masks * b * seq * d
local_embs = [entity_prompt_emb[:, i, None].squeeze(1) for i in range(max_masks)]
prompt_embs = local_embs + prompt_embs # append global to last
prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs]
prompt_emb = torch.cat(prompt_embs, dim=1)
# positional embedding
text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1)
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
return prompt_emb, image_rotary_emb, attention_mask
def tiled_forward(
self,
hidden_states,
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
tile_size=128, tile_stride=64,
**kwargs
):
# Due to the global positional embedding, we cannot implement layer-wise tiled forward.
hidden_states = TileWorker().tiled_forward(
lambda x: self.forward(x, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None),
hidden_states,
tile_size,
tile_stride,
tile_device=hidden_states.device,
tile_dtype=hidden_states.dtype
)
return hidden_states
def forward(
self,
hidden_states,
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
tiled=False, tile_size=128, tile_stride=64, entity_prompt_emb=None, entity_masks=None,
tiled=False, tile_size=128, tile_stride=64,
use_gradient_checkpointing=False,
**kwargs
):
# (Deprecated) The real forward is in `pipelines.flux_image`.
return None
if tiled:
return self.tiled_forward(
hidden_states,
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
tile_size=tile_size, tile_stride=tile_stride,
**kwargs
)
if image_ids is None:
image_ids = self.prepare_image_ids(hidden_states)
conditioning = self.time_embedder(timestep, hidden_states.dtype)\
+ self.guidance_embedder(guidance, hidden_states.dtype)\
+ self.pooled_text_embedder(pooled_prompt_emb)
prompt_emb = self.context_embedder(prompt_emb)
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
height, width = hidden_states.shape[-2:]
hidden_states = self.patchify(hidden_states)
hidden_states = self.x_embedder(hidden_states)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block in self.blocks:
if self.training and use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning, image_rotary_emb,
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
for block in self.single_blocks:
if self.training and use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning, image_rotary_emb,
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
hidden_states = self.final_norm_out(hidden_states, conditioning)
hidden_states = self.final_proj_out(hidden_states)
hidden_states = self.unpatchify(hidden_states, height, width)
return hidden_states
@staticmethod
def state_dict_converter():
return FluxDiTStateDictConverter()
class FluxDiTStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
global_rename_dict = {
"context_embedder": "context_embedder",
"x_embedder": "x_embedder",
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
"time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
"norm_out.linear": "final_norm_out.linear",
"proj_out": "final_proj_out",
}
rename_dict = {
"proj_out": "proj_out",
"norm1.linear": "norm1_a.linear",
"norm1_context.linear": "norm1_b.linear",
"attn.to_q": "attn.a_to_q",
"attn.to_k": "attn.a_to_k",
"attn.to_v": "attn.a_to_v",
"attn.to_out.0": "attn.a_to_out",
"attn.add_q_proj": "attn.b_to_q",
"attn.add_k_proj": "attn.b_to_k",
"attn.add_v_proj": "attn.b_to_v",
"attn.to_add_out": "attn.b_to_out",
"ff.net.0.proj": "ff_a.0",
"ff.net.2": "ff_a.2",
"ff_context.net.0.proj": "ff_b.0",
"ff_context.net.2": "ff_b.2",
"attn.norm_q": "attn.norm_q_a",
"attn.norm_k": "attn.norm_k_a",
"attn.norm_added_q": "attn.norm_q_b",
"attn.norm_added_k": "attn.norm_k_b",
}
rename_dict_single = {
"attn.to_q": "a_to_q",
"attn.to_k": "a_to_k",
"attn.to_v": "a_to_v",
"attn.norm_q": "norm_q_a",
"attn.norm_k": "norm_k_a",
"norm.linear": "norm.linear",
"proj_mlp": "proj_in_besides_attn",
"proj_out": "proj_out",
}
state_dict_ = {}
for name, param in state_dict.items():
if name.endswith(".weight") or name.endswith(".bias"):
suffix = ".weight" if name.endswith(".weight") else ".bias"
prefix = name[:-len(suffix)]
if prefix in global_rename_dict:
state_dict_[global_rename_dict[prefix] + suffix] = param
elif prefix.startswith("transformer_blocks."):
names = prefix.split(".")
names[0] = "blocks"
middle = ".".join(names[2:])
if middle in rename_dict:
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
state_dict_[name_] = param
elif prefix.startswith("single_transformer_blocks."):
names = prefix.split(".")
names[0] = "single_blocks"
middle = ".".join(names[2:])
if middle in rename_dict_single:
name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
state_dict_[name_] = param
else:
pass
else:
pass
for name in list(state_dict_.keys()):
if ".proj_in_besides_attn." in name:
name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
param = torch.concat([
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
state_dict_[name],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
state_dict_.pop(name)
for name in list(state_dict_.keys()):
for component in ["a", "b"]:
if f".{component}_to_q." in name:
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
param = torch.concat([
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
return state_dict_
def from_civitai(self, state_dict):
rename_dict = {
"time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias",
"time_in.in_layer.weight": "time_embedder.timestep_embedder.0.weight",
"time_in.out_layer.bias": "time_embedder.timestep_embedder.2.bias",
"time_in.out_layer.weight": "time_embedder.timestep_embedder.2.weight",
"txt_in.bias": "context_embedder.bias",
"txt_in.weight": "context_embedder.weight",
"vector_in.in_layer.bias": "pooled_text_embedder.0.bias",
"vector_in.in_layer.weight": "pooled_text_embedder.0.weight",
"vector_in.out_layer.bias": "pooled_text_embedder.2.bias",
"vector_in.out_layer.weight": "pooled_text_embedder.2.weight",
"final_layer.linear.bias": "final_proj_out.bias",
"final_layer.linear.weight": "final_proj_out.weight",
"guidance_in.in_layer.bias": "guidance_embedder.timestep_embedder.0.bias",
"guidance_in.in_layer.weight": "guidance_embedder.timestep_embedder.0.weight",
"guidance_in.out_layer.bias": "guidance_embedder.timestep_embedder.2.bias",
"guidance_in.out_layer.weight": "guidance_embedder.timestep_embedder.2.weight",
"img_in.bias": "x_embedder.bias",
"img_in.weight": "x_embedder.weight",
"final_layer.adaLN_modulation.1.weight": "final_norm_out.linear.weight",
"final_layer.adaLN_modulation.1.bias": "final_norm_out.linear.bias",
}
suffix_rename_dict = {
"img_attn.norm.key_norm.scale": "attn.norm_k_a.weight",
"img_attn.norm.query_norm.scale": "attn.norm_q_a.weight",
"img_attn.proj.bias": "attn.a_to_out.bias",
"img_attn.proj.weight": "attn.a_to_out.weight",
"img_attn.qkv.bias": "attn.a_to_qkv.bias",
"img_attn.qkv.weight": "attn.a_to_qkv.weight",
"img_mlp.0.bias": "ff_a.0.bias",
"img_mlp.0.weight": "ff_a.0.weight",
"img_mlp.2.bias": "ff_a.2.bias",
"img_mlp.2.weight": "ff_a.2.weight",
"img_mod.lin.bias": "norm1_a.linear.bias",
"img_mod.lin.weight": "norm1_a.linear.weight",
"txt_attn.norm.key_norm.scale": "attn.norm_k_b.weight",
"txt_attn.norm.query_norm.scale": "attn.norm_q_b.weight",
"txt_attn.proj.bias": "attn.b_to_out.bias",
"txt_attn.proj.weight": "attn.b_to_out.weight",
"txt_attn.qkv.bias": "attn.b_to_qkv.bias",
"txt_attn.qkv.weight": "attn.b_to_qkv.weight",
"txt_mlp.0.bias": "ff_b.0.bias",
"txt_mlp.0.weight": "ff_b.0.weight",
"txt_mlp.2.bias": "ff_b.2.bias",
"txt_mlp.2.weight": "ff_b.2.weight",
"txt_mod.lin.bias": "norm1_b.linear.bias",
"txt_mod.lin.weight": "norm1_b.linear.weight",
"linear1.bias": "to_qkv_mlp.bias",
"linear1.weight": "to_qkv_mlp.weight",
"linear2.bias": "proj_out.bias",
"linear2.weight": "proj_out.weight",
"modulation.lin.bias": "norm.linear.bias",
"modulation.lin.weight": "norm.linear.weight",
"norm.key_norm.scale": "norm_k_a.weight",
"norm.query_norm.scale": "norm_q_a.weight",
}
state_dict_ = {}
for name, param in state_dict.items():
names = name.split(".")
if name in rename_dict:
rename = rename_dict[name]
if name.startswith("final_layer.adaLN_modulation.1."):
param = torch.concat([param[3072:], param[:3072]], dim=0)
state_dict_[rename] = param
elif names[0] == "double_blocks":
rename = f"blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
state_dict_[rename] = param
elif names[0] == "single_blocks":
if ".".join(names[2:]) in suffix_rename_dict:
rename = f"single_blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
state_dict_[rename] = param
else:
pass
return state_dict_

View File

@@ -1,129 +0,0 @@
import math
import torch
import torch.nn as nn
# FFN
def FeedForward(dim, mult=4):
inner_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.GELU(),
nn.Linear(inner_dim, dim, bias=False),
)
def reshape_tensor(x, heads):
bs, length, width = x.shape
#(bs, length, width) --> (bs, length, n_heads, dim_per_head)
x = x.view(bs, length, heads, -1)
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
x = x.transpose(1, 2)
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
x = x.reshape(bs, heads, length, -1)
return x
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head**-0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents):
"""
Args:
x (torch.Tensor): image features
shape (b, n1, D)
latent (torch.Tensor): latent features
shape (b, n2, D)
"""
x = self.norm1(x)
latents = self.norm2(latents)
b, l, _ = latents.shape
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q = reshape_tensor(q, self.heads)
k = reshape_tensor(k, self.heads)
v = reshape_tensor(v, self.heads)
# attention
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
return self.to_out(out)
class InfiniteYouImageProjector(nn.Module):
def __init__(
self,
dim=1280,
depth=4,
dim_head=64,
heads=20,
num_queries=8,
embedding_dim=512,
output_dim=4096,
ff_mult=4,
):
super().__init__()
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
self.proj_in = nn.Linear(embedding_dim, dim)
self.proj_out = nn.Linear(dim, output_dim)
self.norm_out = nn.LayerNorm(output_dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList([
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]))
def forward(self, x):
latents = self.latents.repeat(x.size(0), 1, 1)
latents = latents.to(dtype=x.dtype, device=x.device)
x = self.proj_in(x)
for attn, ff in self.layers:
latents = attn(x, latents) + latents
latents = ff(latents) + latents
latents = self.proj_out(latents)
return self.norm_out(latents)
@staticmethod
def state_dict_converter():
return FluxInfiniteYouImageProjectorStateDictConverter()
class FluxInfiniteYouImageProjectorStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict['image_proj']

View File

@@ -1,110 +0,0 @@
from .general_modules import RMSNorm
from transformers import SiglipVisionModel, SiglipVisionConfig
import torch
class SiglipVisionModelSO400M(SiglipVisionModel):
def __init__(self):
config = SiglipVisionConfig(
hidden_size=1152,
image_size=384,
intermediate_size=4304,
model_type="siglip_vision_model",
num_attention_heads=16,
num_hidden_layers=27,
patch_size=14,
architectures=["SiglipModel"],
initializer_factor=1.0,
torch_dtype="float32",
transformers_version="4.37.0.dev0"
)
super().__init__(config)
class MLPProjModel(torch.nn.Module):
def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
super().__init__()
self.cross_attention_dim = cross_attention_dim
self.num_tokens = num_tokens
self.proj = torch.nn.Sequential(
torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
torch.nn.GELU(),
torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
)
self.norm = torch.nn.LayerNorm(cross_attention_dim)
def forward(self, id_embeds):
x = self.proj(id_embeds)
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
x = self.norm(x)
return x
class IpAdapterModule(torch.nn.Module):
def __init__(self, num_attention_heads, attention_head_dim, input_dim):
super().__init__()
self.num_heads = num_attention_heads
self.head_dim = attention_head_dim
output_dim = num_attention_heads * attention_head_dim
self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
self.norm_added_k = RMSNorm(attention_head_dim, eps=1e-5, elementwise_affine=False)
def forward(self, hidden_states):
batch_size = hidden_states.shape[0]
# ip_k
ip_k = self.to_k_ip(hidden_states)
ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
ip_k = self.norm_added_k(ip_k)
# ip_v
ip_v = self.to_v_ip(hidden_states)
ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
return ip_k, ip_v
class FluxIpAdapter(torch.nn.Module):
def __init__(self, num_attention_heads=24, attention_head_dim=128, cross_attention_dim=4096, num_tokens=128, num_blocks=57):
super().__init__()
self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(num_attention_heads, attention_head_dim, cross_attention_dim) for _ in range(num_blocks)])
self.image_proj = MLPProjModel(cross_attention_dim=cross_attention_dim, id_embeddings_dim=1152, num_tokens=num_tokens)
self.set_adapter()
def set_adapter(self):
self.call_block_id = {i:i for i in range(len(self.ipadapter_modules))}
def forward(self, hidden_states, scale=1.0):
hidden_states = self.image_proj(hidden_states)
hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
ip_kv_dict = {}
for block_id in self.call_block_id:
ipadapter_id = self.call_block_id[block_id]
ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
ip_kv_dict[block_id] = {
"ip_k": ip_k,
"ip_v": ip_v,
"scale": scale
}
return ip_kv_dict
@staticmethod
def state_dict_converter():
return FluxIpAdapterStateDictConverter()
class FluxIpAdapterStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
state_dict_ = {}
for name in state_dict["ip_adapter"]:
name_ = 'ipadapter_modules.' + name
state_dict_[name_] = state_dict["ip_adapter"][name]
for name in state_dict["image_proj"]:
name_ = "image_proj." + name
state_dict_[name_] = state_dict["image_proj"][name]
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -1,521 +0,0 @@
import torch
from einops import rearrange
def low_version_attention(query, key, value, attn_bias=None):
scale = 1 / query.shape[-1] ** 0.5
query = query * scale
attn = torch.matmul(query, key.transpose(-2, -1))
if attn_bias is not None:
attn = attn + attn_bias
attn = attn.softmax(-1)
return attn @ value
class Attention(torch.nn.Module):
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
super().__init__()
dim_inner = head_dim * num_heads
kv_dim = kv_dim if kv_dim is not None else q_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
batch_size = q.shape[0]
ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
hidden_states = hidden_states + scale * ip_hidden_states
return hidden_states
def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
batch_size = encoder_hidden_states.shape[0]
q = self.to_q(hidden_states)
k = self.to_k(encoder_hidden_states)
v = self.to_v(encoder_hidden_states)
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
if qkv_preprocessor is not None:
q, k, v = qkv_preprocessor(q, k, v)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
if ipadapter_kwargs is not None:
hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
hidden_states = self.to_out(hidden_states)
return hidden_states
def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
q = self.to_q(hidden_states)
k = self.to_k(encoder_hidden_states)
v = self.to_v(encoder_hidden_states)
q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)
if attn_mask is not None:
hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
else:
import xformers.ops as xops
hidden_states = xops.memory_efficient_attention(q, k, v)
hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)
hidden_states = hidden_states.to(q.dtype)
hidden_states = self.to_out(hidden_states)
return hidden_states
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)
class CLIPEncoderLayer(torch.nn.Module):
def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
super().__init__()
self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)
self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)
self.use_quick_gelu = use_quick_gelu
def quickGELU(self, x):
return x * torch.sigmoid(1.702 * x)
def forward(self, hidden_states, attn_mask=None):
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.attn(hidden_states, attn_mask=attn_mask)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.fc1(hidden_states)
if self.use_quick_gelu:
hidden_states = self.quickGELU(hidden_states)
else:
hidden_states = torch.nn.functional.gelu(hidden_states)
hidden_states = self.fc2(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class SDTextEncoder(torch.nn.Module):
def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
super().__init__()
# token_embedding
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
# position_embeds (This is a fixed tensor)
self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
# encoders
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
# attn_mask
self.attn_mask = self.attention_mask(max_position_embeddings)
# final_layer_norm
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
def attention_mask(self, length):
mask = torch.empty(length, length)
mask.fill_(float("-inf"))
mask.triu_(1)
return mask
def forward(self, input_ids, clip_skip=1):
embeds = self.token_embedding(input_ids) + self.position_embeds
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
for encoder_id, encoder in enumerate(self.encoders):
embeds = encoder(embeds, attn_mask=attn_mask)
if encoder_id + clip_skip == len(self.encoders):
break
embeds = self.final_layer_norm(embeds)
return embeds
@staticmethod
def state_dict_converter():
return SDTextEncoderStateDictConverter()
class SDTextEncoderStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"text_model.embeddings.token_embedding.weight": "token_embedding.weight",
"text_model.embeddings.position_embedding.weight": "position_embeds",
"text_model.final_layer_norm.weight": "final_layer_norm.weight",
"text_model.final_layer_norm.bias": "final_layer_norm.bias"
}
attn_rename_dict = {
"self_attn.q_proj": "attn.to_q",
"self_attn.k_proj": "attn.to_k",
"self_attn.v_proj": "attn.to_v",
"self_attn.out_proj": "attn.to_out",
"layer_norm1": "layer_norm1",
"layer_norm2": "layer_norm2",
"mlp.fc1": "fc1",
"mlp.fc2": "fc2",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if name == "text_model.embeddings.position_embedding.weight":
param = param.reshape((1, param.shape[0], param.shape[1]))
state_dict_[rename_dict[name]] = param
elif name.startswith("text_model.encoder.layers."):
param = state_dict[name]
names = name.split(".")
layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
state_dict_[name_] = param
return state_dict_
def from_civitai(self, state_dict):
rename_dict = {
"cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "encoders.11.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight",
"cond_stage_model.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias",
"cond_stage_model.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight",
"cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "position_embeds"
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight":
param = param.reshape((1, param.shape[0], param.shape[1]))
state_dict_[rename_dict[name]] = param
return state_dict_
class LoRALayerBlock(torch.nn.Module):
def __init__(self, L, dim_in, dim_out):
super().__init__()
self.x = torch.nn.Parameter(torch.randn(1, L, dim_in))
self.layer_norm = torch.nn.LayerNorm(dim_out)
def forward(self, lora_A, lora_B):
x = self.x @ lora_A.T @ lora_B.T
x = self.layer_norm(x)
return x
class LoRAEmbedder(torch.nn.Module):
def __init__(self, lora_patterns=None, L=1, out_dim=2048):
super().__init__()
if lora_patterns is None:
lora_patterns = self.default_lora_patterns()
model_dict = {}
for lora_pattern in lora_patterns:
name, dim = lora_pattern["name"], lora_pattern["dim"]
model_dict[name.replace(".", "___")] = LoRALayerBlock(L, dim[0], dim[1])
self.model_dict = torch.nn.ModuleDict(model_dict)
proj_dict = {}
for lora_pattern in lora_patterns:
layer_type, dim = lora_pattern["type"], lora_pattern["dim"]
if layer_type not in proj_dict:
proj_dict[layer_type.replace(".", "___")] = torch.nn.Linear(dim[1], out_dim)
self.proj_dict = torch.nn.ModuleDict(proj_dict)
self.lora_patterns = lora_patterns
def default_lora_patterns(self):
lora_patterns = []
lora_dict = {
"attn.a_to_qkv": (3072, 9216), "attn.a_to_out": (3072, 3072), "ff_a.0": (3072, 12288), "ff_a.2": (12288, 3072), "norm1_a.linear": (3072, 18432),
"attn.b_to_qkv": (3072, 9216), "attn.b_to_out": (3072, 3072), "ff_b.0": (3072, 12288), "ff_b.2": (12288, 3072), "norm1_b.linear": (3072, 18432),
}
for i in range(19):
for suffix in lora_dict:
lora_patterns.append({
"name": f"blocks.{i}.{suffix}",
"dim": lora_dict[suffix],
"type": suffix,
})
lora_dict = {"to_qkv_mlp": (3072, 21504), "proj_out": (15360, 3072), "norm.linear": (3072, 9216)}
for i in range(38):
for suffix in lora_dict:
lora_patterns.append({
"name": f"single_blocks.{i}.{suffix}",
"dim": lora_dict[suffix],
"type": suffix,
})
return lora_patterns
def forward(self, lora):
lora_emb = []
for lora_pattern in self.lora_patterns:
name, layer_type = lora_pattern["name"], lora_pattern["type"]
lora_A = lora[name + ".lora_A.weight"]
lora_B = lora[name + ".lora_B.weight"]
lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B)
lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out)
lora_emb.append(lora_out)
lora_emb = torch.concat(lora_emb, dim=1)
return lora_emb
class FluxLoRAEncoder(torch.nn.Module):
def __init__(self, embed_dim=4096, encoder_intermediate_size=8192, num_encoder_layers=1, num_embeds_per_lora=16, num_special_embeds=1):
super().__init__()
self.num_embeds_per_lora = num_embeds_per_lora
# embedder
self.embedder = LoRAEmbedder(L=num_embeds_per_lora, out_dim=embed_dim)
# encoders
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=32, head_dim=128) for _ in range(num_encoder_layers)])
# special embedding
self.special_embeds = torch.nn.Parameter(torch.randn(1, num_special_embeds, embed_dim))
self.num_special_embeds = num_special_embeds
# final layer
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
self.final_linear = torch.nn.Linear(embed_dim, embed_dim)
def forward(self, lora):
lora_embeds = self.embedder(lora)
special_embeds = self.special_embeds.to(dtype=lora_embeds.dtype, device=lora_embeds.device)
embeds = torch.concat([special_embeds, lora_embeds], dim=1)
for encoder_id, encoder in enumerate(self.encoders):
embeds = encoder(embeds)
embeds = embeds[:, :self.num_special_embeds]
embeds = self.final_layer_norm(embeds)
embeds = self.final_linear(embeds)
return embeds
@staticmethod
def state_dict_converter():
return FluxLoRAEncoderStateDictConverter()
class FluxLoRAEncoderStateDictConverter:
def from_civitai(self, state_dict):
return state_dict

View File

@@ -1,306 +0,0 @@
import torch, math
from ..core.loader import load_state_dict
from typing import Union
class GeneralLoRALoader:
def __init__(self, device="cpu", torch_dtype=torch.float32):
self.device = device
self.torch_dtype = torch_dtype
def get_name_dict(self, lora_state_dict):
lora_name_dict = {}
for key in lora_state_dict:
if ".lora_B." not in key:
continue
keys = key.split(".")
if len(keys) > keys.index("lora_B") + 2:
keys.pop(keys.index("lora_B") + 1)
keys.pop(keys.index("lora_B"))
if keys[0] == "diffusion_model":
keys.pop(0)
keys.pop(-1)
target_name = ".".join(keys)
lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
return lora_name_dict
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
updated_num = 0
lora_name_dict = self.get_name_dict(state_dict_lora)
for name, module in model.named_modules():
if name in lora_name_dict:
weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=self.device, dtype=self.torch_dtype)
weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=self.device, dtype=self.torch_dtype)
if len(weight_up.shape) == 4:
weight_up = weight_up.squeeze(3).squeeze(2)
weight_down = weight_down.squeeze(3).squeeze(2)
weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
weight_lora = alpha * torch.mm(weight_up, weight_down)
state_dict = module.state_dict()
state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora
module.load_state_dict(state_dict)
updated_num += 1
print(f"{updated_num} tensors are updated by LoRA.")
class FluxLoRALoader(GeneralLoRALoader):
def __init__(self, device="cpu", torch_dtype=torch.float32):
super().__init__(device=device, torch_dtype=torch_dtype)
self.diffusers_rename_dict = {
"transformer.single_transformer_blocks.blockid.attn.to_k.lora_A.weight":"single_blocks.blockid.a_to_k.lora_A.default.weight",
"transformer.single_transformer_blocks.blockid.attn.to_k.lora_B.weight":"single_blocks.blockid.a_to_k.lora_B.default.weight",
"transformer.single_transformer_blocks.blockid.attn.to_q.lora_A.weight":"single_blocks.blockid.a_to_q.lora_A.default.weight",
"transformer.single_transformer_blocks.blockid.attn.to_q.lora_B.weight":"single_blocks.blockid.a_to_q.lora_B.default.weight",
"transformer.single_transformer_blocks.blockid.attn.to_v.lora_A.weight":"single_blocks.blockid.a_to_v.lora_A.default.weight",
"transformer.single_transformer_blocks.blockid.attn.to_v.lora_B.weight":"single_blocks.blockid.a_to_v.lora_B.default.weight",
"transformer.single_transformer_blocks.blockid.norm.linear.lora_A.weight":"single_blocks.blockid.norm.linear.lora_A.default.weight",
"transformer.single_transformer_blocks.blockid.norm.linear.lora_B.weight":"single_blocks.blockid.norm.linear.lora_B.default.weight",
"transformer.single_transformer_blocks.blockid.proj_mlp.lora_A.weight":"single_blocks.blockid.proj_in_besides_attn.lora_A.default.weight",
"transformer.single_transformer_blocks.blockid.proj_mlp.lora_B.weight":"single_blocks.blockid.proj_in_besides_attn.lora_B.default.weight",
"transformer.single_transformer_blocks.blockid.proj_out.lora_A.weight":"single_blocks.blockid.proj_out.lora_A.default.weight",
"transformer.single_transformer_blocks.blockid.proj_out.lora_B.weight":"single_blocks.blockid.proj_out.lora_B.default.weight",
"transformer.transformer_blocks.blockid.attn.add_k_proj.lora_A.weight":"blocks.blockid.attn.b_to_k.lora_A.default.weight",
"transformer.transformer_blocks.blockid.attn.add_k_proj.lora_B.weight":"blocks.blockid.attn.b_to_k.lora_B.default.weight",
"transformer.transformer_blocks.blockid.attn.add_q_proj.lora_A.weight":"blocks.blockid.attn.b_to_q.lora_A.default.weight",
"transformer.transformer_blocks.blockid.attn.add_q_proj.lora_B.weight":"blocks.blockid.attn.b_to_q.lora_B.default.weight",
"transformer.transformer_blocks.blockid.attn.add_v_proj.lora_A.weight":"blocks.blockid.attn.b_to_v.lora_A.default.weight",
"transformer.transformer_blocks.blockid.attn.add_v_proj.lora_B.weight":"blocks.blockid.attn.b_to_v.lora_B.default.weight",
"transformer.transformer_blocks.blockid.attn.to_add_out.lora_A.weight":"blocks.blockid.attn.b_to_out.lora_A.default.weight",
"transformer.transformer_blocks.blockid.attn.to_add_out.lora_B.weight":"blocks.blockid.attn.b_to_out.lora_B.default.weight",
"transformer.transformer_blocks.blockid.attn.to_k.lora_A.weight":"blocks.blockid.attn.a_to_k.lora_A.default.weight",
"transformer.transformer_blocks.blockid.attn.to_k.lora_B.weight":"blocks.blockid.attn.a_to_k.lora_B.default.weight",
"transformer.transformer_blocks.blockid.attn.to_out.0.lora_A.weight":"blocks.blockid.attn.a_to_out.lora_A.default.weight",
"transformer.transformer_blocks.blockid.attn.to_out.0.lora_B.weight":"blocks.blockid.attn.a_to_out.lora_B.default.weight",
"transformer.transformer_blocks.blockid.attn.to_q.lora_A.weight":"blocks.blockid.attn.a_to_q.lora_A.default.weight",
"transformer.transformer_blocks.blockid.attn.to_q.lora_B.weight":"blocks.blockid.attn.a_to_q.lora_B.default.weight",
"transformer.transformer_blocks.blockid.attn.to_v.lora_A.weight":"blocks.blockid.attn.a_to_v.lora_A.default.weight",
"transformer.transformer_blocks.blockid.attn.to_v.lora_B.weight":"blocks.blockid.attn.a_to_v.lora_B.default.weight",
"transformer.transformer_blocks.blockid.ff.net.0.proj.lora_A.weight":"blocks.blockid.ff_a.0.lora_A.default.weight",
"transformer.transformer_blocks.blockid.ff.net.0.proj.lora_B.weight":"blocks.blockid.ff_a.0.lora_B.default.weight",
"transformer.transformer_blocks.blockid.ff.net.2.lora_A.weight":"blocks.blockid.ff_a.2.lora_A.default.weight",
"transformer.transformer_blocks.blockid.ff.net.2.lora_B.weight":"blocks.blockid.ff_a.2.lora_B.default.weight",
"transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_A.weight":"blocks.blockid.ff_b.0.lora_A.default.weight",
"transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_B.weight":"blocks.blockid.ff_b.0.lora_B.default.weight",
"transformer.transformer_blocks.blockid.ff_context.net.2.lora_A.weight":"blocks.blockid.ff_b.2.lora_A.default.weight",
"transformer.transformer_blocks.blockid.ff_context.net.2.lora_B.weight":"blocks.blockid.ff_b.2.lora_B.default.weight",
"transformer.transformer_blocks.blockid.norm1.linear.lora_A.weight":"blocks.blockid.norm1_a.linear.lora_A.default.weight",
"transformer.transformer_blocks.blockid.norm1.linear.lora_B.weight":"blocks.blockid.norm1_a.linear.lora_B.default.weight",
"transformer.transformer_blocks.blockid.norm1_context.linear.lora_A.weight":"blocks.blockid.norm1_b.linear.lora_A.default.weight",
"transformer.transformer_blocks.blockid.norm1_context.linear.lora_B.weight":"blocks.blockid.norm1_b.linear.lora_B.default.weight",
}
self.civitai_rename_dict = {
"lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
"lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight",
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight",
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight",
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight",
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight",
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight",
"lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight",
"lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight",
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight",
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight",
"lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight",
"lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight",
"lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight",
"lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight",
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight",
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight",
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight",
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight",
"lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight",
"lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight",
"lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight",
"lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight",
"lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight",
"lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight",
}
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
super().load(model, state_dict_lora, alpha)
def convert_state_dict(self,state_dict):
def guess_block_id(name,model_resource):
if model_resource == 'civitai':
names = name.split("_")
for i in names:
if i.isdigit():
return i, name.replace(f"_{i}_", "_blockid_")
if model_resource == 'diffusers':
names = name.split(".")
for i in names:
if i.isdigit():
return i, name.replace(f"transformer_blocks.{i}.", "transformer_blocks.blockid.")
return None, None
def guess_resource(state_dict):
for k in state_dict:
if "lora_unet_" in k:
return 'civitai'
elif k.startswith("transformer."):
return 'diffusers'
else:
None
model_resource = guess_resource(state_dict)
if model_resource is None:
return state_dict
rename_dict = self.diffusers_rename_dict if model_resource == 'diffusers' else self.civitai_rename_dict
def guess_alpha(state_dict):
for name, param in state_dict.items():
if ".alpha" in name:
for suffix in [".lora_down.weight", ".lora_A.weight"]:
name_ = name.replace(".alpha", suffix)
if name_ in state_dict:
lora_alpha = param.item() / state_dict[name_].shape[0]
lora_alpha = math.sqrt(lora_alpha)
return lora_alpha
return 1
alpha = guess_alpha(state_dict)
state_dict_ = {}
for name, param in state_dict.items():
block_id, source_name = guess_block_id(name,model_resource)
if alpha != 1:
param *= alpha
if source_name in rename_dict:
target_name = rename_dict[source_name]
target_name = target_name.replace(".blockid.", f".{block_id}.")
state_dict_[target_name] = param
else:
state_dict_[name] = param
if model_resource == 'diffusers':
for name in list(state_dict_.keys()):
if "single_blocks." in name and ".a_to_q." in name:
mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
if mlp is None:
dim = 4
if 'lora_A' in name:
dim = 1
mlp = torch.zeros(dim * state_dict_[name].shape[0],
*state_dict_[name].shape[1:],
dtype=state_dict_[name].dtype)
else:
state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
if 'lora_A' in name:
param = torch.concat([
state_dict_.pop(name),
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
mlp,
], dim=0)
elif 'lora_B' in name:
d, r = state_dict_[name].shape
param = torch.zeros((3*d+mlp.shape[0], 3*r+mlp.shape[1]), dtype=state_dict_[name].dtype, device=state_dict_[name].device)
param[:d, :r] = state_dict_.pop(name)
param[d:2*d, r:2*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_k."))
param[2*d:3*d, 2*r:3*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_v."))
param[3*d:, 3*r:] = mlp
else:
param = torch.concat([
state_dict_.pop(name),
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
mlp,
], dim=0)
name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
state_dict_[name_] = param
for name in list(state_dict_.keys()):
for component in ["a", "b"]:
if f".{component}_to_q." in name:
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
concat_dim = 0
if 'lora_A' in name:
param = torch.concat([
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
], dim=0)
elif 'lora_B' in name:
origin = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")]
d, r = origin.shape
# print(d, r)
param = torch.zeros((3*d, 3*r), dtype=origin.dtype, device=origin.device)
param[:d, :r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")]
param[d:2*d, r:2*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")]
param[2*d:3*d, 2*r:3*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")]
else:
param = torch.concat([
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
return state_dict_
class LoraMerger(torch.nn.Module):
def __init__(self, dim):
super().__init__()
self.weight_base = torch.nn.Parameter(torch.randn((dim,)))
self.weight_lora = torch.nn.Parameter(torch.randn((dim,)))
self.weight_cross = torch.nn.Parameter(torch.randn((dim,)))
self.weight_out = torch.nn.Parameter(torch.ones((dim,)))
self.bias = torch.nn.Parameter(torch.randn((dim,)))
self.activation = torch.nn.Sigmoid()
self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5)
self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5)
def forward(self, base_output, lora_outputs):
norm_base_output = self.norm_base(base_output)
norm_lora_outputs = self.norm_lora(lora_outputs)
gate = self.activation(
norm_base_output * self.weight_base \
+ norm_lora_outputs * self.weight_lora \
+ norm_base_output * norm_lora_outputs * self.weight_cross + self.bias
)
output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0)
return output
class FluxLoraPatcher(torch.nn.Module):
def __init__(self, lora_patterns=None):
super().__init__()
if lora_patterns is None:
lora_patterns = self.default_lora_patterns()
model_dict = {}
for lora_pattern in lora_patterns:
name, dim = lora_pattern["name"], lora_pattern["dim"]
model_dict[name.replace(".", "___")] = LoraMerger(dim)
self.model_dict = torch.nn.ModuleDict(model_dict)
def default_lora_patterns(self):
lora_patterns = []
lora_dict = {
"attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432,
"attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432,
}
for i in range(19):
for suffix in lora_dict:
lora_patterns.append({
"name": f"blocks.{i}.{suffix}",
"dim": lora_dict[suffix]
})
lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216}
for i in range(38):
for suffix in lora_dict:
lora_patterns.append({
"name": f"single_blocks.{i}.{suffix}",
"dim": lora_dict[suffix]
})
return lora_patterns
def forward(self, base_output, lora_outputs, name):
return self.model_dict[name.replace(".", "___")](base_output, lora_outputs)

View File

@@ -0,0 +1,93 @@
import torch
from transformers import T5EncoderModel, T5Config
from .sd_text_encoder import SDTextEncoder
class FluxTextEncoder1(SDTextEncoder):
def __init__(self, vocab_size=49408):
super().__init__(vocab_size=vocab_size)
def forward(self, input_ids, clip_skip=2):
embeds = self.token_embedding(input_ids) + self.position_embeds
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
for encoder_id, encoder in enumerate(self.encoders):
embeds = encoder(embeds, attn_mask=attn_mask)
if encoder_id + clip_skip == len(self.encoders):
hidden_states = embeds
embeds = self.final_layer_norm(embeds)
pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
return embeds, pooled_embeds
@staticmethod
def state_dict_converter():
return FluxTextEncoder1StateDictConverter()
class FluxTextEncoder2(T5EncoderModel):
def __init__(self, config):
super().__init__(config)
self.eval()
def forward(self, input_ids):
outputs = super().forward(input_ids=input_ids)
prompt_emb = outputs.last_hidden_state
return prompt_emb
@staticmethod
def state_dict_converter():
return FluxTextEncoder2StateDictConverter()
class FluxTextEncoder1StateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"text_model.embeddings.token_embedding.weight": "token_embedding.weight",
"text_model.embeddings.position_embedding.weight": "position_embeds",
"text_model.final_layer_norm.weight": "final_layer_norm.weight",
"text_model.final_layer_norm.bias": "final_layer_norm.bias"
}
attn_rename_dict = {
"self_attn.q_proj": "attn.to_q",
"self_attn.k_proj": "attn.to_k",
"self_attn.v_proj": "attn.to_v",
"self_attn.out_proj": "attn.to_out",
"layer_norm1": "layer_norm1",
"layer_norm2": "layer_norm2",
"mlp.fc1": "fc1",
"mlp.fc2": "fc2",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if name == "text_model.embeddings.position_embedding.weight":
param = param.reshape((1, param.shape[0], param.shape[1]))
state_dict_[rename_dict[name]] = param
elif name.startswith("text_model.encoder.layers."):
param = state_dict[name]
names = name.split(".")
layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
state_dict_[name_] = param
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)
class FluxTextEncoder2StateDictConverter():
def __init__(self):
pass
def from_diffusers(self, state_dict):
state_dict_ = state_dict
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -1,112 +0,0 @@
import torch
class Attention(torch.nn.Module):
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
super().__init__()
dim_inner = head_dim * num_heads
kv_dim = kv_dim if kv_dim is not None else q_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
batch_size = encoder_hidden_states.shape[0]
q = self.to_q(hidden_states)
k = self.to_k(encoder_hidden_states)
v = self.to_v(encoder_hidden_states)
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
hidden_states = self.to_out(hidden_states)
return hidden_states
class CLIPEncoderLayer(torch.nn.Module):
def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
super().__init__()
self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)
self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)
self.use_quick_gelu = use_quick_gelu
def quickGELU(self, x):
return x * torch.sigmoid(1.702 * x)
def forward(self, hidden_states, attn_mask=None):
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.attn(hidden_states, attn_mask=attn_mask)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.fc1(hidden_states)
if self.use_quick_gelu:
hidden_states = self.quickGELU(hidden_states)
else:
hidden_states = torch.nn.functional.gelu(hidden_states)
hidden_states = self.fc2(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class FluxTextEncoderClip(torch.nn.Module):
def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
super().__init__()
# token_embedding
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
# position_embeds (This is a fixed tensor)
self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
# encoders
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
# attn_mask
self.attn_mask = self.attention_mask(max_position_embeddings)
# final_layer_norm
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
def attention_mask(self, length):
mask = torch.empty(length, length)
mask.fill_(float("-inf"))
mask.triu_(1)
return mask
def forward(self, input_ids, clip_skip=2, extra_mask=None):
embeds = self.token_embedding(input_ids)
embeds = embeds + self.position_embeds.to(dtype=embeds.dtype, device=input_ids.device)
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
if extra_mask is not None:
attn_mask[:, extra_mask[0]==0] = float("-inf")
for encoder_id, encoder in enumerate(self.encoders):
embeds = encoder(embeds, attn_mask=attn_mask)
if encoder_id + clip_skip == len(self.encoders):
hidden_states = embeds
embeds = self.final_layer_norm(embeds)
pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
return pooled_embeds, hidden_states

View File

@@ -1,43 +0,0 @@
import torch
from transformers import T5EncoderModel, T5Config
class FluxTextEncoderT5(T5EncoderModel):
def __init__(self):
config = T5Config(**{
"architectures": [
"T5EncoderModel"
],
"classifier_dropout": 0.0,
"d_ff": 10240,
"d_kv": 64,
"d_model": 4096,
"decoder_start_token_id": 0,
"dense_act_fn": "gelu_new",
"dropout_rate": 0.1,
"dtype": "bfloat16",
"eos_token_id": 1,
"feed_forward_proj": "gated-gelu",
"initializer_factor": 1.0,
"is_encoder_decoder": True,
"is_gated_act": True,
"layer_norm_epsilon": 1e-06,
"model_type": "t5",
"num_decoder_layers": 24,
"num_heads": 64,
"num_layers": 24,
"output_past": True,
"pad_token_id": 0,
"relative_attention_max_distance": 128,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": False,
"transformers_version": "4.57.1",
"use_cache": True,
"vocab_size": 32128
})
super().__init__(config)
def forward(self, input_ids):
outputs = super().forward(input_ids=input_ids)
prompt_emb = outputs.last_hidden_state
return prompt_emb

View File

@@ -1,451 +1,303 @@
import torch
from einops import rearrange, repeat
from .sd3_vae_encoder import SD3VAEEncoder, SDVAEEncoderStateDictConverter
from .sd3_vae_decoder import SD3VAEDecoder, SDVAEDecoderStateDictConverter
class TileWorker:
class FluxVAEEncoder(SD3VAEEncoder):
def __init__(self):
super().__init__()
self.scaling_factor = 0.3611
self.shift_factor = 0.1159
@staticmethod
def state_dict_converter():
return FluxVAEEncoderStateDictConverter()
class FluxVAEDecoder(SD3VAEDecoder):
def __init__(self):
super().__init__()
self.scaling_factor = 0.3611
self.shift_factor = 0.1159
@staticmethod
def state_dict_converter():
return FluxVAEDecoderStateDictConverter()
class FluxVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
def __init__(self):
pass
def mask(self, height, width, border_width):
# Create a mask with shape (height, width).
# The centre area is filled with 1, and the border line is filled with values in range (0, 1].
x = torch.arange(height).repeat(width, 1).T
y = torch.arange(width).repeat(height, 1)
mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values
mask = (mask / border_width).clip(0, 1)
return mask
def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype):
# Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num)
batch_size, channel, _, _ = model_input.shape
model_input = model_input.to(device=tile_device, dtype=tile_dtype)
unfold_operator = torch.nn.Unfold(
kernel_size=(tile_size, tile_size),
stride=(tile_stride, tile_stride)
)
model_input = unfold_operator(model_input)
model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1))
return model_input
def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype):
# Call y=forward_fn(x) for each tile
tile_num = model_input.shape[-1]
model_output_stack = []
for tile_id in range(0, tile_num, tile_batch_size):
# process input
tile_id_ = min(tile_id + tile_batch_size, tile_num)
x = model_input[:, :, :, :, tile_id: tile_id_]
x = x.to(device=inference_device, dtype=inference_dtype)
x = rearrange(x, "b c h w n -> (n b) c h w")
# process output
y = forward_fn(x)
y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_-tile_id)
y = y.to(device=tile_device, dtype=tile_dtype)
model_output_stack.append(y)
model_output = torch.concat(model_output_stack, dim=-1)
return model_output
def io_scale(self, model_output, tile_size):
# Determine the size modification happened in forward_fn
# We only consider the same scale on height and width.
io_scale = model_output.shape[2] / tile_size
return io_scale
def from_civitai(self, state_dict):
rename_dict = {
"encoder.conv_in.bias": "conv_in.bias",
"encoder.conv_in.weight": "conv_in.weight",
"encoder.conv_out.bias": "conv_out.bias",
"encoder.conv_out.weight": "conv_out.weight",
"encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
"encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
"encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
"encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
"encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
"encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
"encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
"encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
"encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
"encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
"encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
"encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
"encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
"encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
"encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
"encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
"encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
"encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
"encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
"encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
"encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
"encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
"encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
"encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
"encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
"encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
"encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
"encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
"encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
"encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
"encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
"encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
"encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
"encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
"encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
"encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
"encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
"encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
"encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
"encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
"encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
"encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
"encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
"encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
"encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
"encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
"encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
"encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
"encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
"encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
"encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
"encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
"encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
"encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
"encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
"encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
"encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
"encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
"encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
"encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
"encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
"encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
"encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
"encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
"encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
"encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
"encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
"encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
"encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
"encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
"encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
"encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
"encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
"encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
"encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
"encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
"encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
"encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
"encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
"encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
"encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
"encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
"encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
"encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
"encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
"encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
"encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
"encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
"encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
"encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
"encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
"encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
"encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
"encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
"encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
"encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
"encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
"encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
"encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
"encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
"encoder.norm_out.bias": "conv_norm_out.bias",
"encoder.norm_out.weight": "conv_norm_out.weight",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if "transformer_blocks" in rename_dict[name]:
param = param.squeeze()
state_dict_[rename_dict[name]] = param
return state_dict_
def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype):
# The reversed function of tile
mask = self.mask(tile_size, tile_size, border_width)
mask = mask.to(device=tile_device, dtype=tile_dtype)
mask = rearrange(mask, "h w -> 1 1 h w 1")
model_output = model_output * mask
fold_operator = torch.nn.Fold(
output_size=(height, width),
kernel_size=(tile_size, tile_size),
stride=(tile_stride, tile_stride)
)
mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1])
model_output = rearrange(model_output, "b c h w n -> b (c h w) n")
model_output = fold_operator(model_output) / fold_operator(mask)
return model_output
def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
# Prepare
inference_device, inference_dtype = model_input.device, model_input.dtype
height, width = model_input.shape[2], model_input.shape[3]
border_width = int(tile_stride*0.5) if border_width is None else border_width
# tile
model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype)
# inference
model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype)
# resize
io_scale = self.io_scale(model_output, tile_size)
height, width = int(height*io_scale), int(width*io_scale)
tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale)
border_width = int(border_width*io_scale)
# untile
model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype)
# Done!
model_output = model_output.to(device=inference_device, dtype=inference_dtype)
return model_output
class ConvAttention(torch.nn.Module):
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
super().__init__()
dim_inner = head_dim * num_heads
kv_dim = kv_dim if kv_dim is not None else q_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.to_q = torch.nn.Conv2d(q_dim, dim_inner, kernel_size=(1, 1), bias=bias_q)
self.to_k = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv)
self.to_v = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv)
self.to_out = torch.nn.Conv2d(dim_inner, q_dim, kernel_size=(1, 1), bias=bias_out)
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
batch_size = encoder_hidden_states.shape[0]
conv_input = rearrange(hidden_states, "B L C -> B C L 1")
q = self.to_q(conv_input)
q = rearrange(q[:, :, :, 0], "B C L -> B L C")
conv_input = rearrange(encoder_hidden_states, "B L C -> B C L 1")
k = self.to_k(conv_input)
v = self.to_v(conv_input)
k = rearrange(k[:, :, :, 0], "B C L -> B L C")
v = rearrange(v[:, :, :, 0], "B C L -> B L C")
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
conv_input = rearrange(hidden_states, "B L C -> B C L 1")
hidden_states = self.to_out(conv_input)
hidden_states = rearrange(hidden_states[:, :, :, 0], "B C L -> B L C")
return hidden_states
class Attention(torch.nn.Module):
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
super().__init__()
dim_inner = head_dim * num_heads
kv_dim = kv_dim if kv_dim is not None else q_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
batch_size = encoder_hidden_states.shape[0]
q = self.to_q(hidden_states)
k = self.to_k(encoder_hidden_states)
v = self.to_v(encoder_hidden_states)
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
hidden_states = self.to_out(hidden_states)
return hidden_states
class VAEAttentionBlock(torch.nn.Module):
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5, use_conv_attention=True):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
if use_conv_attention:
self.transformer_blocks = torch.nn.ModuleList([
ConvAttention(
inner_dim,
num_attention_heads,
attention_head_dim,
bias_q=True,
bias_kv=True,
bias_out=True
)
for d in range(num_layers)
])
else:
self.transformer_blocks = torch.nn.ModuleList([
Attention(
inner_dim,
num_attention_heads,
attention_head_dim,
bias_q=True,
bias_kv=True,
bias_out=True
)
for d in range(num_layers)
])
def forward(self, hidden_states, time_emb, text_emb, res_stack):
batch, _, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
for block in self.transformer_blocks:
hidden_states = block(hidden_states)
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
hidden_states = hidden_states + residual
return hidden_states, time_emb, text_emb, res_stack
class ResnetBlock(torch.nn.Module):
def __init__(self, in_channels, out_channels, temb_channels=None, groups=32, eps=1e-5):
super().__init__()
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.nonlinearity = torch.nn.SiLU()
self.conv_shortcut = None
if in_channels != out_channels:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True)
def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
x = hidden_states
x = self.norm1(x)
x = self.nonlinearity(x)
x = self.conv1(x)
if time_emb is not None:
emb = self.nonlinearity(time_emb)
emb = self.time_emb_proj(emb)[:, :, None, None]
x = x + emb
x = self.norm2(x)
x = self.nonlinearity(x)
x = self.conv2(x)
if self.conv_shortcut is not None:
hidden_states = self.conv_shortcut(hidden_states)
hidden_states = hidden_states + x
return hidden_states, time_emb, text_emb, res_stack
class UpSampler(torch.nn.Module):
def __init__(self, channels):
super().__init__()
self.conv = torch.nn.Conv2d(channels, channels, 3, padding=1)
def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
hidden_states = self.conv(hidden_states)
return hidden_states, time_emb, text_emb, res_stack
class DownSampler(torch.nn.Module):
def __init__(self, channels, padding=1, extra_padding=False):
super().__init__()
self.conv = torch.nn.Conv2d(channels, channels, 3, stride=2, padding=padding)
self.extra_padding = extra_padding
def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
if self.extra_padding:
hidden_states = torch.nn.functional.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
hidden_states = self.conv(hidden_states)
return hidden_states, time_emb, text_emb, res_stack
class FluxVAEDecoder(torch.nn.Module):
def __init__(self, use_conv_attention=True):
super().__init__()
self.scaling_factor = 0.3611
self.shift_factor = 0.1159
self.conv_in = torch.nn.Conv2d(16, 512, kernel_size=3, padding=1) # Different from SD 1.x
self.blocks = torch.nn.ModuleList([
# UNetMidBlock2D
ResnetBlock(512, 512, eps=1e-6),
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention),
ResnetBlock(512, 512, eps=1e-6),
# UpDecoderBlock2D
ResnetBlock(512, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
UpSampler(512),
# UpDecoderBlock2D
ResnetBlock(512, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
UpSampler(512),
# UpDecoderBlock2D
ResnetBlock(512, 256, eps=1e-6),
ResnetBlock(256, 256, eps=1e-6),
ResnetBlock(256, 256, eps=1e-6),
UpSampler(256),
# UpDecoderBlock2D
ResnetBlock(256, 128, eps=1e-6),
ResnetBlock(128, 128, eps=1e-6),
ResnetBlock(128, 128, eps=1e-6),
])
self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6)
self.conv_act = torch.nn.SiLU()
self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
hidden_states = TileWorker().tiled_forward(
lambda x: self.forward(x),
sample,
tile_size,
tile_stride,
tile_device=sample.device,
tile_dtype=sample.dtype
)
return hidden_states
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
# For VAE Decoder, we do not need to apply the tiler on each layer.
if tiled:
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
# 1. pre-process
hidden_states = sample / self.scaling_factor + self.shift_factor
hidden_states = self.conv_in(hidden_states)
time_emb = None
text_emb = None
res_stack = None
# 2. blocks
for i, block in enumerate(self.blocks):
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
# 3. output
hidden_states = self.conv_norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class FluxVAEEncoder(torch.nn.Module):
def __init__(self, use_conv_attention=True):
super().__init__()
self.scaling_factor = 0.3611
self.shift_factor = 0.1159
self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)
self.blocks = torch.nn.ModuleList([
# DownEncoderBlock2D
ResnetBlock(128, 128, eps=1e-6),
ResnetBlock(128, 128, eps=1e-6),
DownSampler(128, padding=0, extra_padding=True),
# DownEncoderBlock2D
ResnetBlock(128, 256, eps=1e-6),
ResnetBlock(256, 256, eps=1e-6),
DownSampler(256, padding=0, extra_padding=True),
# DownEncoderBlock2D
ResnetBlock(256, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
DownSampler(512, padding=0, extra_padding=True),
# DownEncoderBlock2D
ResnetBlock(512, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
# UNetMidBlock2D
ResnetBlock(512, 512, eps=1e-6),
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention),
ResnetBlock(512, 512, eps=1e-6),
])
self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
self.conv_act = torch.nn.SiLU()
self.conv_out = torch.nn.Conv2d(512, 32, kernel_size=3, padding=1)
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
hidden_states = TileWorker().tiled_forward(
lambda x: self.forward(x),
sample,
tile_size,
tile_stride,
tile_device=sample.device,
tile_dtype=sample.dtype
)
return hidden_states
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
# For VAE Decoder, we do not need to apply the tiler on each layer.
if tiled:
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
# 1. pre-process
hidden_states = self.conv_in(sample)
time_emb = None
text_emb = None
res_stack = None
# 2. blocks
for i, block in enumerate(self.blocks):
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
# 3. output
hidden_states = self.conv_norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
hidden_states = hidden_states[:, :16]
hidden_states = (hidden_states - self.shift_factor) * self.scaling_factor
return hidden_states
def encode_video(self, sample, batch_size=8):
B = sample.shape[0]
hidden_states = []
for i in range(0, sample.shape[2], batch_size):
j = min(i + batch_size, sample.shape[2])
sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W")
hidden_states_batch = self(sample_batch)
hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)
hidden_states.append(hidden_states_batch)
hidden_states = torch.concat(hidden_states, dim=2)
return hidden_states
class FluxVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
def __init__(self):
pass
def from_civitai(self, state_dict):
rename_dict = {
"decoder.conv_in.bias": "conv_in.bias",
"decoder.conv_in.weight": "conv_in.weight",
"decoder.conv_out.bias": "conv_out.bias",
"decoder.conv_out.weight": "conv_out.weight",
"decoder.mid.attn_1.k.bias": "blocks.1.transformer_blocks.0.to_k.bias",
"decoder.mid.attn_1.k.weight": "blocks.1.transformer_blocks.0.to_k.weight",
"decoder.mid.attn_1.norm.bias": "blocks.1.norm.bias",
"decoder.mid.attn_1.norm.weight": "blocks.1.norm.weight",
"decoder.mid.attn_1.proj_out.bias": "blocks.1.transformer_blocks.0.to_out.bias",
"decoder.mid.attn_1.proj_out.weight": "blocks.1.transformer_blocks.0.to_out.weight",
"decoder.mid.attn_1.q.bias": "blocks.1.transformer_blocks.0.to_q.bias",
"decoder.mid.attn_1.q.weight": "blocks.1.transformer_blocks.0.to_q.weight",
"decoder.mid.attn_1.v.bias": "blocks.1.transformer_blocks.0.to_v.bias",
"decoder.mid.attn_1.v.weight": "blocks.1.transformer_blocks.0.to_v.weight",
"decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
"decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
"decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
"decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
"decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
"decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
"decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
"decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
"decoder.mid.block_2.conv1.bias": "blocks.2.conv1.bias",
"decoder.mid.block_2.conv1.weight": "blocks.2.conv1.weight",
"decoder.mid.block_2.conv2.bias": "blocks.2.conv2.bias",
"decoder.mid.block_2.conv2.weight": "blocks.2.conv2.weight",
"decoder.mid.block_2.norm1.bias": "blocks.2.norm1.bias",
"decoder.mid.block_2.norm1.weight": "blocks.2.norm1.weight",
"decoder.mid.block_2.norm2.bias": "blocks.2.norm2.bias",
"decoder.mid.block_2.norm2.weight": "blocks.2.norm2.weight",
"decoder.norm_out.bias": "conv_norm_out.bias",
"decoder.norm_out.weight": "conv_norm_out.weight",
"decoder.up.0.block.0.conv1.bias": "blocks.15.conv1.bias",
"decoder.up.0.block.0.conv1.weight": "blocks.15.conv1.weight",
"decoder.up.0.block.0.conv2.bias": "blocks.15.conv2.bias",
"decoder.up.0.block.0.conv2.weight": "blocks.15.conv2.weight",
"decoder.up.0.block.0.nin_shortcut.bias": "blocks.15.conv_shortcut.bias",
"decoder.up.0.block.0.nin_shortcut.weight": "blocks.15.conv_shortcut.weight",
"decoder.up.0.block.0.norm1.bias": "blocks.15.norm1.bias",
"decoder.up.0.block.0.norm1.weight": "blocks.15.norm1.weight",
"decoder.up.0.block.0.norm2.bias": "blocks.15.norm2.bias",
"decoder.up.0.block.0.norm2.weight": "blocks.15.norm2.weight",
"decoder.up.0.block.1.conv1.bias": "blocks.16.conv1.bias",
"decoder.up.0.block.1.conv1.weight": "blocks.16.conv1.weight",
"decoder.up.0.block.1.conv2.bias": "blocks.16.conv2.bias",
"decoder.up.0.block.1.conv2.weight": "blocks.16.conv2.weight",
"decoder.up.0.block.1.norm1.bias": "blocks.16.norm1.bias",
"decoder.up.0.block.1.norm1.weight": "blocks.16.norm1.weight",
"decoder.up.0.block.1.norm2.bias": "blocks.16.norm2.bias",
"decoder.up.0.block.1.norm2.weight": "blocks.16.norm2.weight",
"decoder.up.0.block.2.conv1.bias": "blocks.17.conv1.bias",
"decoder.up.0.block.2.conv1.weight": "blocks.17.conv1.weight",
"decoder.up.0.block.2.conv2.bias": "blocks.17.conv2.bias",
"decoder.up.0.block.2.conv2.weight": "blocks.17.conv2.weight",
"decoder.up.0.block.2.norm1.bias": "blocks.17.norm1.bias",
"decoder.up.0.block.2.norm1.weight": "blocks.17.norm1.weight",
"decoder.up.0.block.2.norm2.bias": "blocks.17.norm2.bias",
"decoder.up.0.block.2.norm2.weight": "blocks.17.norm2.weight",
"decoder.up.1.block.0.conv1.bias": "blocks.11.conv1.bias",
"decoder.up.1.block.0.conv1.weight": "blocks.11.conv1.weight",
"decoder.up.1.block.0.conv2.bias": "blocks.11.conv2.bias",
"decoder.up.1.block.0.conv2.weight": "blocks.11.conv2.weight",
"decoder.up.1.block.0.nin_shortcut.bias": "blocks.11.conv_shortcut.bias",
"decoder.up.1.block.0.nin_shortcut.weight": "blocks.11.conv_shortcut.weight",
"decoder.up.1.block.0.norm1.bias": "blocks.11.norm1.bias",
"decoder.up.1.block.0.norm1.weight": "blocks.11.norm1.weight",
"decoder.up.1.block.0.norm2.bias": "blocks.11.norm2.bias",
"decoder.up.1.block.0.norm2.weight": "blocks.11.norm2.weight",
"decoder.up.1.block.1.conv1.bias": "blocks.12.conv1.bias",
"decoder.up.1.block.1.conv1.weight": "blocks.12.conv1.weight",
"decoder.up.1.block.1.conv2.bias": "blocks.12.conv2.bias",
"decoder.up.1.block.1.conv2.weight": "blocks.12.conv2.weight",
"decoder.up.1.block.1.norm1.bias": "blocks.12.norm1.bias",
"decoder.up.1.block.1.norm1.weight": "blocks.12.norm1.weight",
"decoder.up.1.block.1.norm2.bias": "blocks.12.norm2.bias",
"decoder.up.1.block.1.norm2.weight": "blocks.12.norm2.weight",
"decoder.up.1.block.2.conv1.bias": "blocks.13.conv1.bias",
"decoder.up.1.block.2.conv1.weight": "blocks.13.conv1.weight",
"decoder.up.1.block.2.conv2.bias": "blocks.13.conv2.bias",
"decoder.up.1.block.2.conv2.weight": "blocks.13.conv2.weight",
"decoder.up.1.block.2.norm1.bias": "blocks.13.norm1.bias",
"decoder.up.1.block.2.norm1.weight": "blocks.13.norm1.weight",
"decoder.up.1.block.2.norm2.bias": "blocks.13.norm2.bias",
"decoder.up.1.block.2.norm2.weight": "blocks.13.norm2.weight",
"decoder.up.1.upsample.conv.bias": "blocks.14.conv.bias",
"decoder.up.1.upsample.conv.weight": "blocks.14.conv.weight",
"decoder.up.2.block.0.conv1.bias": "blocks.7.conv1.bias",
"decoder.up.2.block.0.conv1.weight": "blocks.7.conv1.weight",
"decoder.up.2.block.0.conv2.bias": "blocks.7.conv2.bias",
"decoder.up.2.block.0.conv2.weight": "blocks.7.conv2.weight",
"decoder.up.2.block.0.norm1.bias": "blocks.7.norm1.bias",
"decoder.up.2.block.0.norm1.weight": "blocks.7.norm1.weight",
"decoder.up.2.block.0.norm2.bias": "blocks.7.norm2.bias",
"decoder.up.2.block.0.norm2.weight": "blocks.7.norm2.weight",
"decoder.up.2.block.1.conv1.bias": "blocks.8.conv1.bias",
"decoder.up.2.block.1.conv1.weight": "blocks.8.conv1.weight",
"decoder.up.2.block.1.conv2.bias": "blocks.8.conv2.bias",
"decoder.up.2.block.1.conv2.weight": "blocks.8.conv2.weight",
"decoder.up.2.block.1.norm1.bias": "blocks.8.norm1.bias",
"decoder.up.2.block.1.norm1.weight": "blocks.8.norm1.weight",
"decoder.up.2.block.1.norm2.bias": "blocks.8.norm2.bias",
"decoder.up.2.block.1.norm2.weight": "blocks.8.norm2.weight",
"decoder.up.2.block.2.conv1.bias": "blocks.9.conv1.bias",
"decoder.up.2.block.2.conv1.weight": "blocks.9.conv1.weight",
"decoder.up.2.block.2.conv2.bias": "blocks.9.conv2.bias",
"decoder.up.2.block.2.conv2.weight": "blocks.9.conv2.weight",
"decoder.up.2.block.2.norm1.bias": "blocks.9.norm1.bias",
"decoder.up.2.block.2.norm1.weight": "blocks.9.norm1.weight",
"decoder.up.2.block.2.norm2.bias": "blocks.9.norm2.bias",
"decoder.up.2.block.2.norm2.weight": "blocks.9.norm2.weight",
"decoder.up.2.upsample.conv.bias": "blocks.10.conv.bias",
"decoder.up.2.upsample.conv.weight": "blocks.10.conv.weight",
"decoder.up.3.block.0.conv1.bias": "blocks.3.conv1.bias",
"decoder.up.3.block.0.conv1.weight": "blocks.3.conv1.weight",
"decoder.up.3.block.0.conv2.bias": "blocks.3.conv2.bias",
"decoder.up.3.block.0.conv2.weight": "blocks.3.conv2.weight",
"decoder.up.3.block.0.norm1.bias": "blocks.3.norm1.bias",
"decoder.up.3.block.0.norm1.weight": "blocks.3.norm1.weight",
"decoder.up.3.block.0.norm2.bias": "blocks.3.norm2.bias",
"decoder.up.3.block.0.norm2.weight": "blocks.3.norm2.weight",
"decoder.up.3.block.1.conv1.bias": "blocks.4.conv1.bias",
"decoder.up.3.block.1.conv1.weight": "blocks.4.conv1.weight",
"decoder.up.3.block.1.conv2.bias": "blocks.4.conv2.bias",
"decoder.up.3.block.1.conv2.weight": "blocks.4.conv2.weight",
"decoder.up.3.block.1.norm1.bias": "blocks.4.norm1.bias",
"decoder.up.3.block.1.norm1.weight": "blocks.4.norm1.weight",
"decoder.up.3.block.1.norm2.bias": "blocks.4.norm2.bias",
"decoder.up.3.block.1.norm2.weight": "blocks.4.norm2.weight",
"decoder.up.3.block.2.conv1.bias": "blocks.5.conv1.bias",
"decoder.up.3.block.2.conv1.weight": "blocks.5.conv1.weight",
"decoder.up.3.block.2.conv2.bias": "blocks.5.conv2.bias",
"decoder.up.3.block.2.conv2.weight": "blocks.5.conv2.weight",
"decoder.up.3.block.2.norm1.bias": "blocks.5.norm1.bias",
"decoder.up.3.block.2.norm1.weight": "blocks.5.norm1.weight",
"decoder.up.3.block.2.norm2.bias": "blocks.5.norm2.bias",
"decoder.up.3.block.2.norm2.weight": "blocks.5.norm2.weight",
"decoder.up.3.upsample.conv.bias": "blocks.6.conv.bias",
"decoder.up.3.upsample.conv.weight": "blocks.6.conv.weight",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if "transformer_blocks" in rename_dict[name]:
param = param.squeeze()
state_dict_[rename_dict[name]] = param
return state_dict_

View File

@@ -1,56 +0,0 @@
import torch
from .general_modules import TemporalTimesteps
class MultiValueEncoder(torch.nn.Module):
def __init__(self, encoders=()):
super().__init__()
if not isinstance(encoders, list):
encoders = [encoders]
self.encoders = torch.nn.ModuleList(encoders)
def __call__(self, values, dtype):
emb = []
for encoder, value in zip(self.encoders, values):
if value is not None:
value = value.unsqueeze(0)
emb.append(encoder(value, dtype))
emb = torch.concat(emb, dim=0)
return emb
class SingleValueEncoder(torch.nn.Module):
def __init__(self, dim_in=256, dim_out=4096, prefer_len=32, computation_device=None):
super().__init__()
self.prefer_len = prefer_len
self.prefer_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device)
self.prefer_value_embedder = torch.nn.Sequential(
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
)
self.positional_embedding = torch.nn.Parameter(
torch.randn(self.prefer_len, dim_out)
)
def forward(self, value, dtype):
value = value * 1000
emb = self.prefer_proj(value).to(dtype)
emb = self.prefer_value_embedder(emb).squeeze(0)
base_embeddings = emb.expand(self.prefer_len, -1)
positional_embedding = self.positional_embedding.to(dtype=base_embeddings.dtype, device=base_embeddings.device)
learned_embeddings = base_embeddings + positional_embedding
return learned_embeddings
@staticmethod
def state_dict_converter():
return SingleValueEncoderStateDictConverter()
class SingleValueEncoderStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict
def from_civitai(self, state_dict):
return state_dict

View File

@@ -1,146 +0,0 @@
import torch, math
def get_timestep_embedding(
timesteps: torch.Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
computation_device = None,
align_dtype_to_timestep = False,
):
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device if computation_device is None else computation_device
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
if align_dtype_to_timestep:
emb = emb.to(timesteps.dtype)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings
emb = scale * emb
# concat sine and cosine embeddings
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
# flip sine and cosine embeddings
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
# zero pad
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
class TemporalTimesteps(torch.nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, computation_device = None, scale=1, align_dtype_to_timestep=False):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
self.computation_device = computation_device
self.scale = scale
self.align_dtype_to_timestep = align_dtype_to_timestep
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
computation_device=self.computation_device,
scale=self.scale,
align_dtype_to_timestep=self.align_dtype_to_timestep,
)
return t_emb
class DiffusersCompatibleTimestepProj(torch.nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.linear_1 = torch.nn.Linear(dim_in, dim_out)
self.act = torch.nn.SiLU()
self.linear_2 = torch.nn.Linear(dim_out, dim_out)
def forward(self, x):
x = self.linear_1(x)
x = self.act(x)
x = self.linear_2(x)
return x
class TimestepEmbeddings(torch.nn.Module):
def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False, use_additional_t_cond=False):
super().__init__()
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device, scale=scale, align_dtype_to_timestep=align_dtype_to_timestep)
if diffusers_compatible_format:
self.timestep_embedder = DiffusersCompatibleTimestepProj(dim_in, dim_out)
else:
self.timestep_embedder = torch.nn.Sequential(
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
)
self.use_additional_t_cond = use_additional_t_cond
if use_additional_t_cond:
self.addition_t_embedding = torch.nn.Embedding(2, dim_out)
def forward(self, timestep, dtype, addition_t_cond=None):
time_emb = self.time_proj(timestep).to(dtype)
time_emb = self.timestep_embedder(time_emb)
if addition_t_cond is not None:
addition_t_emb = self.addition_t_embedding(addition_t_cond)
addition_t_emb = addition_t_emb.to(dtype=dtype)
time_emb = time_emb + addition_t_emb
return time_emb
class RMSNorm(torch.nn.Module):
def __init__(self, dim, eps, elementwise_affine=True):
super().__init__()
self.eps = eps
if elementwise_affine:
self.weight = torch.nn.Parameter(torch.ones((dim,)))
else:
self.weight = None
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
hidden_states = hidden_states.to(input_dtype)
if self.weight is not None:
hidden_states = hidden_states * self.weight
return hidden_states
class AdaLayerNorm(torch.nn.Module):
def __init__(self, dim, single=False, dual=False):
super().__init__()
self.single = single
self.dual = dual
self.linear = torch.nn.Linear(dim, dim * [[6, 2][single], 9][dual])
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb):
emb = self.linear(torch.nn.functional.silu(emb))
if self.single:
scale, shift = emb.unsqueeze(1).chunk(2, dim=2)
x = self.norm(x) * (1 + scale) + shift
return x
elif self.dual:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.unsqueeze(1).chunk(9, dim=2)
norm_x = self.norm(x)
x = norm_x * (1 + scale_msa) + shift_msa
norm_x2 = norm_x * (1 + scale_msa2) + shift_msa2
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_x2, gate_msa2
else:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)
x = self.norm(x) * (1 + scale_msa) + shift_msa
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp

View File

@@ -0,0 +1,451 @@
from .attention import Attention
from einops import repeat, rearrange
import math
import torch
class HunyuanDiTRotaryEmbedding(torch.nn.Module):
def __init__(self, q_norm_shape=88, k_norm_shape=88, rotary_emb_on_k=True):
super().__init__()
self.q_norm = torch.nn.LayerNorm((q_norm_shape,), elementwise_affine=True, eps=1e-06)
self.k_norm = torch.nn.LayerNorm((k_norm_shape,), elementwise_affine=True, eps=1e-06)
self.rotary_emb_on_k = rotary_emb_on_k
self.k_cache, self.v_cache = [], []
def reshape_for_broadcast(self, freqs_cis, x):
ndim = x.ndim
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
def rotate_half(self, x):
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
def apply_rotary_emb(self, xq, xk, freqs_cis):
xk_out = None
cos, sin = self.reshape_for_broadcast(freqs_cis, xq)
cos, sin = cos.to(xq.device), sin.to(xq.device)
xq_out = (xq.float() * cos + self.rotate_half(xq.float()) * sin).type_as(xq)
if xk is not None:
xk_out = (xk.float() * cos + self.rotate_half(xk.float()) * sin).type_as(xk)
return xq_out, xk_out
def forward(self, q, k, v, freqs_cis_img, to_cache=False):
# norm
q = self.q_norm(q)
k = self.k_norm(k)
# RoPE
if self.rotary_emb_on_k:
q, k = self.apply_rotary_emb(q, k, freqs_cis_img)
else:
q, _ = self.apply_rotary_emb(q, None, freqs_cis_img)
if to_cache:
self.k_cache.append(k)
self.v_cache.append(v)
elif len(self.k_cache) > 0 and len(self.v_cache) > 0:
k = torch.concat([k] + self.k_cache, dim=2)
v = torch.concat([v] + self.v_cache, dim=2)
self.k_cache, self.v_cache = [], []
return q, k, v
class FP32_Layernorm(torch.nn.LayerNorm):
def forward(self, inputs):
origin_dtype = inputs.dtype
return torch.nn.functional.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).to(origin_dtype)
class FP32_SiLU(torch.nn.SiLU):
def forward(self, inputs):
origin_dtype = inputs.dtype
return torch.nn.functional.silu(inputs.float(), inplace=False).to(origin_dtype)
class HunyuanDiTFinalLayer(torch.nn.Module):
def __init__(self, final_hidden_size=1408, condition_dim=1408, patch_size=2, out_channels=8):
super().__init__()
self.norm_final = torch.nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = torch.nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = torch.nn.Sequential(
FP32_SiLU(),
torch.nn.Linear(condition_dim, 2 * final_hidden_size, bias=True)
)
def modulate(self, x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def forward(self, hidden_states, condition_emb):
shift, scale = self.adaLN_modulation(condition_emb).chunk(2, dim=1)
hidden_states = self.modulate(self.norm_final(hidden_states), shift, scale)
hidden_states = self.linear(hidden_states)
return hidden_states
class HunyuanDiTBlock(torch.nn.Module):
def __init__(
self,
hidden_dim=1408,
condition_dim=1408,
num_heads=16,
mlp_ratio=4.3637,
text_dim=1024,
skip_connection=False
):
super().__init__()
self.norm1 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
self.rota1 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads)
self.attn1 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, bias_q=True, bias_kv=True, bias_out=True)
self.norm2 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
self.rota2 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads, rotary_emb_on_k=False)
self.attn2 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, kv_dim=text_dim, bias_q=True, bias_kv=True, bias_out=True)
self.norm3 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
self.modulation = torch.nn.Sequential(FP32_SiLU(), torch.nn.Linear(condition_dim, hidden_dim, bias=True))
self.mlp = torch.nn.Sequential(
torch.nn.Linear(hidden_dim, int(hidden_dim*mlp_ratio), bias=True),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(int(hidden_dim*mlp_ratio), hidden_dim, bias=True)
)
if skip_connection:
self.skip_norm = FP32_Layernorm((hidden_dim * 2,), eps=1e-6, elementwise_affine=True)
self.skip_linear = torch.nn.Linear(hidden_dim * 2, hidden_dim, bias=True)
else:
self.skip_norm, self.skip_linear = None, None
def forward(self, hidden_states, condition_emb, text_emb, freq_cis_img, residual=None, to_cache=False):
# Long Skip Connection
if self.skip_norm is not None and self.skip_linear is not None:
hidden_states = torch.cat([hidden_states, residual], dim=-1)
hidden_states = self.skip_norm(hidden_states)
hidden_states = self.skip_linear(hidden_states)
# Self-Attention
shift_msa = self.modulation(condition_emb).unsqueeze(dim=1)
attn_input = self.norm1(hidden_states) + shift_msa
hidden_states = hidden_states + self.attn1(attn_input, qkv_preprocessor=lambda q, k, v: self.rota1(q, k, v, freq_cis_img, to_cache=to_cache))
# Cross-Attention
attn_input = self.norm3(hidden_states)
hidden_states = hidden_states + self.attn2(attn_input, text_emb, qkv_preprocessor=lambda q, k, v: self.rota2(q, k, v, freq_cis_img))
# FFN Layer
mlp_input = self.norm2(hidden_states)
hidden_states = hidden_states + self.mlp(mlp_input)
return hidden_states
class AttentionPool(torch.nn.Module):
def __init__(self, spacial_dim, embed_dim, num_heads, output_dim = None):
super().__init__()
self.positional_embedding = torch.nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
self.k_proj = torch.nn.Linear(embed_dim, embed_dim)
self.q_proj = torch.nn.Linear(embed_dim, embed_dim)
self.v_proj = torch.nn.Linear(embed_dim, embed_dim)
self.c_proj = torch.nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.permute(1, 0, 2) # NLC -> LNC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
x, _ = torch.nn.functional.multi_head_attention_forward(
query=x[:1], key=x, value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False
)
return x.squeeze(0)
class PatchEmbed(torch.nn.Module):
def __init__(
self,
patch_size=(2, 2),
in_chans=4,
embed_dim=1408,
bias=True,
):
super().__init__()
self.proj = torch.nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
def forward(self, x):
x = self.proj(x)
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
return x
def timestep_embedding(t, dim, max_period=10000, repeat_only=False):
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=t.device) # size: [dim/2], 一个指数衰减的曲线
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
else:
embedding = repeat(t, "b -> b d", d=dim)
return embedding
class TimestepEmbedder(torch.nn.Module):
def __init__(self, hidden_size=1408, frequency_embedding_size=256):
super().__init__()
self.mlp = torch.nn.Sequential(
torch.nn.Linear(frequency_embedding_size, hidden_size, bias=True),
torch.nn.SiLU(),
torch.nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
def forward(self, t):
t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
t_emb = self.mlp(t_freq)
return t_emb
class HunyuanDiT(torch.nn.Module):
def __init__(self, num_layers_down=21, num_layers_up=19, in_channels=4, out_channels=8, hidden_dim=1408, text_dim=1024, t5_dim=2048, text_length=77, t5_length=256):
super().__init__()
# Embedders
self.text_emb_padding = torch.nn.Parameter(torch.randn(text_length + t5_length, text_dim, dtype=torch.float32))
self.t5_embedder = torch.nn.Sequential(
torch.nn.Linear(t5_dim, t5_dim * 4, bias=True),
FP32_SiLU(),
torch.nn.Linear(t5_dim * 4, text_dim, bias=True),
)
self.t5_pooler = AttentionPool(t5_length, t5_dim, num_heads=8, output_dim=1024)
self.style_embedder = torch.nn.Parameter(torch.randn(hidden_dim))
self.patch_embedder = PatchEmbed(in_chans=in_channels)
self.timestep_embedder = TimestepEmbedder()
self.extra_embedder = torch.nn.Sequential(
torch.nn.Linear(256 * 6 + 1024 + hidden_dim, hidden_dim * 4),
FP32_SiLU(),
torch.nn.Linear(hidden_dim * 4, hidden_dim),
)
# Transformer blocks
self.num_layers_down = num_layers_down
self.num_layers_up = num_layers_up
self.blocks = torch.nn.ModuleList(
[HunyuanDiTBlock(skip_connection=False) for _ in range(num_layers_down)] + \
[HunyuanDiTBlock(skip_connection=True) for _ in range(num_layers_up)]
)
# Output layers
self.final_layer = HunyuanDiTFinalLayer()
self.out_channels = out_channels
def prepare_text_emb(self, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5):
text_emb_mask = text_emb_mask.bool()
text_emb_mask_t5 = text_emb_mask_t5.bool()
text_emb_t5 = self.t5_embedder(text_emb_t5)
text_emb = torch.cat([text_emb, text_emb_t5], dim=1)
text_emb_mask = torch.cat([text_emb_mask, text_emb_mask_t5], dim=-1)
text_emb = torch.where(text_emb_mask.unsqueeze(2), text_emb, self.text_emb_padding.to(text_emb))
return text_emb
def prepare_extra_emb(self, text_emb_t5, timestep, size_emb, dtype, batch_size):
# Text embedding
pooled_text_emb_t5 = self.t5_pooler(text_emb_t5)
# Timestep embedding
timestep_emb = self.timestep_embedder(timestep)
# Size embedding
size_emb = timestep_embedding(size_emb.view(-1), 256).to(dtype)
size_emb = size_emb.view(-1, 6 * 256)
# Style embedding
style_emb = repeat(self.style_embedder, "D -> B D", B=batch_size)
# Concatenate all extra vectors
extra_emb = torch.cat([pooled_text_emb_t5, size_emb, style_emb], dim=1)
condition_emb = timestep_emb + self.extra_embedder(extra_emb)
return condition_emb
def unpatchify(self, x, h, w):
return rearrange(x, "B (H W) (P Q C) -> B C (H P) (W Q)", H=h, W=w, P=2, Q=2)
def build_mask(self, data, is_bound):
_, _, H, W = data.shape
h = repeat(torch.arange(H), "H -> H W", H=H, W=W)
w = repeat(torch.arange(W), "W -> H W", H=H, W=W)
border_width = (H + W) // 4
pad = torch.ones_like(h) * border_width
mask = torch.stack([
pad if is_bound[0] else h + 1,
pad if is_bound[1] else H - h,
pad if is_bound[2] else w + 1,
pad if is_bound[3] else W - w
]).min(dim=0).values
mask = mask.clip(1, border_width)
mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
mask = rearrange(mask, "H W -> 1 H W")
return mask
def tiled_block_forward(self, block, hidden_states, condition_emb, text_emb, freq_cis_img, residual, torch_dtype, data_device, computation_device, tile_size, tile_stride):
B, C, H, W = hidden_states.shape
weight = torch.zeros((1, 1, H, W), dtype=torch_dtype, device=data_device)
values = torch.zeros((B, C, H, W), dtype=torch_dtype, device=data_device)
# Split tasks
tasks = []
for h in range(0, H, tile_stride):
for w in range(0, W, tile_stride):
if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
continue
h_, w_ = h + tile_size, w + tile_size
if h_ > H: h, h_ = H - tile_size, H
if w_ > W: w, w_ = W - tile_size, W
tasks.append((h, h_, w, w_))
# Run
for hl, hr, wl, wr in tasks:
hidden_states_batch = hidden_states[:, :, hl:hr, wl:wr].to(computation_device)
hidden_states_batch = rearrange(hidden_states_batch, "B C H W -> B (H W) C")
if residual is not None:
residual_batch = residual[:, :, hl:hr, wl:wr].to(computation_device)
residual_batch = rearrange(residual_batch, "B C H W -> B (H W) C")
else:
residual_batch = None
# Forward
hidden_states_batch = block(hidden_states_batch, condition_emb, text_emb, freq_cis_img, residual_batch).to(data_device)
hidden_states_batch = rearrange(hidden_states_batch, "B (H W) C -> B C H W", H=hr-hl)
mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W))
values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask
weight[:, :, hl:hr, wl:wr] += mask
values /= weight
return values
def forward(
self, hidden_states, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5, timestep, size_emb, freq_cis_img,
tiled=False, tile_size=64, tile_stride=32,
to_cache=False,
use_gradient_checkpointing=False,
):
# Embeddings
text_emb = self.prepare_text_emb(text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5)
condition_emb = self.prepare_extra_emb(text_emb_t5, timestep, size_emb, hidden_states.dtype, hidden_states.shape[0])
# Input
height, width = hidden_states.shape[-2], hidden_states.shape[-1]
hidden_states = self.patch_embedder(hidden_states)
# Blocks
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if tiled:
hidden_states = rearrange(hidden_states, "B (H W) C -> B C H W", H=height//2)
residuals = []
for block_id, block in enumerate(self.blocks):
residual = residuals.pop() if block_id >= self.num_layers_down else None
hidden_states = self.tiled_block_forward(
block, hidden_states, condition_emb, text_emb, freq_cis_img, residual,
torch_dtype=hidden_states.dtype, data_device=hidden_states.device, computation_device=hidden_states.device,
tile_size=tile_size, tile_stride=tile_stride
)
if block_id < self.num_layers_down - 2:
residuals.append(hidden_states)
hidden_states = rearrange(hidden_states, "B C H W -> B (H W) C")
else:
residuals = []
for block_id, block in enumerate(self.blocks):
residual = residuals.pop() if block_id >= self.num_layers_down else None
if self.training and use_gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, condition_emb, text_emb, freq_cis_img, residual,
use_reentrant=False,
)
else:
hidden_states = block(hidden_states, condition_emb, text_emb, freq_cis_img, residual, to_cache=to_cache)
if block_id < self.num_layers_down - 2:
residuals.append(hidden_states)
# Output
hidden_states = self.final_layer(hidden_states, condition_emb)
hidden_states = self.unpatchify(hidden_states, height//2, width//2)
hidden_states, _ = hidden_states.chunk(2, dim=1)
return hidden_states
@staticmethod
def state_dict_converter():
return HunyuanDiTStateDictConverter()
class HunyuanDiTStateDictConverter():
def __init__(self):
pass
def from_diffusers(self, state_dict):
state_dict_ = {}
for name, param in state_dict.items():
name_ = name
name_ = name_.replace(".default_modulation.", ".modulation.")
name_ = name_.replace(".mlp.fc1.", ".mlp.0.")
name_ = name_.replace(".mlp.fc2.", ".mlp.2.")
name_ = name_.replace(".attn1.q_norm.", ".rota1.q_norm.")
name_ = name_.replace(".attn2.q_norm.", ".rota2.q_norm.")
name_ = name_.replace(".attn1.k_norm.", ".rota1.k_norm.")
name_ = name_.replace(".attn2.k_norm.", ".rota2.k_norm.")
name_ = name_.replace(".q_proj.", ".to_q.")
name_ = name_.replace(".out_proj.", ".to_out.")
name_ = name_.replace("text_embedding_padding", "text_emb_padding")
name_ = name_.replace("mlp_t5.0.", "t5_embedder.0.")
name_ = name_.replace("mlp_t5.2.", "t5_embedder.2.")
name_ = name_.replace("pooler.", "t5_pooler.")
name_ = name_.replace("x_embedder.", "patch_embedder.")
name_ = name_.replace("t_embedder.", "timestep_embedder.")
name_ = name_.replace("t5_pooler.to_q.", "t5_pooler.q_proj.")
name_ = name_.replace("style_embedder.weight", "style_embedder")
if ".kv_proj." in name_:
param_k = param[:param.shape[0]//2]
param_v = param[param.shape[0]//2:]
state_dict_[name_.replace(".kv_proj.", ".to_k.")] = param_k
state_dict_[name_.replace(".kv_proj.", ".to_v.")] = param_v
elif ".Wqkv." in name_:
param_q = param[:param.shape[0]//3]
param_k = param[param.shape[0]//3:param.shape[0]//3*2]
param_v = param[param.shape[0]//3*2:]
state_dict_[name_.replace(".Wqkv.", ".to_q.")] = param_q
state_dict_[name_.replace(".Wqkv.", ".to_k.")] = param_k
state_dict_[name_.replace(".Wqkv.", ".to_v.")] = param_v
elif "style_embedder" in name_:
state_dict_[name_] = param.squeeze()
else:
state_dict_[name_] = param
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -0,0 +1,163 @@
from transformers import BertModel, BertConfig, T5EncoderModel, T5Config
import torch
class HunyuanDiTCLIPTextEncoder(BertModel):
def __init__(self):
config = BertConfig(
_name_or_path = "",
architectures = ["BertModel"],
attention_probs_dropout_prob = 0.1,
bos_token_id = 0,
classifier_dropout = None,
directionality = "bidi",
eos_token_id = 2,
hidden_act = "gelu",
hidden_dropout_prob = 0.1,
hidden_size = 1024,
initializer_range = 0.02,
intermediate_size = 4096,
layer_norm_eps = 1e-12,
max_position_embeddings = 512,
model_type = "bert",
num_attention_heads = 16,
num_hidden_layers = 24,
output_past = True,
pad_token_id = 0,
pooler_fc_size = 768,
pooler_num_attention_heads = 12,
pooler_num_fc_layers = 3,
pooler_size_per_head = 128,
pooler_type = "first_token_transform",
position_embedding_type = "absolute",
torch_dtype = "float32",
transformers_version = "4.37.2",
type_vocab_size = 2,
use_cache = True,
vocab_size = 47020
)
super().__init__(config, add_pooling_layer=False)
self.eval()
def forward(self, input_ids, attention_mask, clip_skip=1):
input_shape = input_ids.size()
batch_size, seq_length = input_shape
device = input_ids.device
past_key_values_length = 0
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=None,
token_type_ids=None,
inputs_embeds=None,
past_key_values_length=0,
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=False,
output_attentions=False,
output_hidden_states=True,
return_dict=True,
)
all_hidden_states = encoder_outputs.hidden_states
prompt_emb = all_hidden_states[-clip_skip]
if clip_skip > 1:
mean, std = all_hidden_states[-1].mean(), all_hidden_states[-1].std()
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
return prompt_emb
@staticmethod
def state_dict_converter():
return HunyuanDiTCLIPTextEncoderStateDictConverter()
class HunyuanDiTT5TextEncoder(T5EncoderModel):
def __init__(self):
config = T5Config(
_name_or_path = "../HunyuanDiT/t2i/mt5",
architectures = ["MT5ForConditionalGeneration"],
classifier_dropout = 0.0,
d_ff = 5120,
d_kv = 64,
d_model = 2048,
decoder_start_token_id = 0,
dense_act_fn = "gelu_new",
dropout_rate = 0.1,
eos_token_id = 1,
feed_forward_proj = "gated-gelu",
initializer_factor = 1.0,
is_encoder_decoder = True,
is_gated_act = True,
layer_norm_epsilon = 1e-06,
model_type = "t5",
num_decoder_layers = 24,
num_heads = 32,
num_layers = 24,
output_past = True,
pad_token_id = 0,
relative_attention_max_distance = 128,
relative_attention_num_buckets = 32,
tie_word_embeddings = False,
tokenizer_class = "T5Tokenizer",
transformers_version = "4.37.2",
use_cache = True,
vocab_size = 250112
)
super().__init__(config)
self.eval()
def forward(self, input_ids, attention_mask, clip_skip=1):
outputs = super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
)
prompt_emb = outputs.hidden_states[-clip_skip]
if clip_skip > 1:
mean, std = outputs.hidden_states[-1].mean(), outputs.hidden_states[-1].std()
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
return prompt_emb
@staticmethod
def state_dict_converter():
return HunyuanDiTT5TextEncoderStateDictConverter()
class HunyuanDiTCLIPTextEncoderStateDictConverter():
def __init__(self):
pass
def from_diffusers(self, state_dict):
state_dict_ = {name[5:]: param for name, param in state_dict.items() if name.startswith("bert.")}
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)
class HunyuanDiTT5TextEncoderStateDictConverter():
def __init__(self):
pass
def from_diffusers(self, state_dict):
state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("encoder.")}
state_dict_["shared.weight"] = state_dict["shared.weight"]
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

File diff suppressed because one or more lines are too long

View File

@@ -1,902 +0,0 @@
from typing import List, Optional, Tuple
import math
import torch
import torch.nn as nn
import torch.amp as amp
import numpy as np
import torch.nn.functional as F
from einops import rearrange, repeat
from .wan_video_dit import flash_attention
from ..core.device.npu_compatible_device import get_device_type
from ..core.gradient import gradient_checkpoint_forward
class RMSNorm_FP32(torch.nn.Module):
def __init__(self, dim: int, eps: float):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
def broadcat(tensors, dim=-1):
num_tensors = len(tensors)
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
shape_len = list(shape_lens)[0]
dim = (dim + shape_len) if dim < 0 else dim
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
assert all(
[*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
), "invalid dimensions for broadcastable concatentation"
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
expanded_dims.insert(dim, (dim, dims[dim]))
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
return torch.cat(tensors, dim=dim)
def rotate_half(x):
x = rearrange(x, "... (d r) -> ... d r", r=2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, "... d r -> ... (d r)")
class RotaryPositionalEmbedding(nn.Module):
def __init__(self,
head_dim,
cp_split_hw=None
):
"""Rotary positional embedding for 3D
Reference : https://blog.eleuther.ai/rotary-embeddings/
Paper: https://arxiv.org/pdf/2104.09864.pdf
Args:
dim: Dimension of embedding
base: Base value for exponential
"""
super().__init__()
self.head_dim = head_dim
assert self.head_dim % 8 == 0, 'Dim must be a multiply of 8 for 3D RoPE.'
self.cp_split_hw = cp_split_hw
# We take the assumption that the longest side of grid will not larger than 512, i.e, 512 * 8 = 4098 input pixels
self.base = 10000
self.freqs_dict = {}
def register_grid_size(self, grid_size):
if grid_size not in self.freqs_dict:
self.freqs_dict.update({
grid_size: self.precompute_freqs_cis_3d(grid_size)
})
def precompute_freqs_cis_3d(self, grid_size):
num_frames, height, width = grid_size
dim_t = self.head_dim - 4 * (self.head_dim // 6)
dim_h = 2 * (self.head_dim // 6)
dim_w = 2 * (self.head_dim // 6)
freqs_t = 1.0 / (self.base ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t))
freqs_h = 1.0 / (self.base ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h))
freqs_w = 1.0 / (self.base ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w))
grid_t = np.linspace(0, num_frames, num_frames, endpoint=False, dtype=np.float32)
grid_h = np.linspace(0, height, height, endpoint=False, dtype=np.float32)
grid_w = np.linspace(0, width, width, endpoint=False, dtype=np.float32)
grid_t = torch.from_numpy(grid_t).float()
grid_h = torch.from_numpy(grid_h).float()
grid_w = torch.from_numpy(grid_w).float()
freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t)
freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h)
freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w)
freqs_t = repeat(freqs_t, "... n -> ... (n r)", r=2)
freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
# (T H W D)
freqs = rearrange(freqs, "T H W D -> (T H W) D")
# if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
# with torch.no_grad():
# freqs = rearrange(freqs, "(T H W) D -> T H W D", T=num_frames, H=height, W=width)
# freqs = context_parallel_util.split_cp_2d(freqs, seq_dim_hw=(1, 2), split_hw=self.cp_split_hw)
# freqs = rearrange(freqs, "T H W D -> (T H W) D")
return freqs
def forward(self, q, k, grid_size):
"""3D RoPE.
Args:
query: [B, head, seq, head_dim]
key: [B, head, seq, head_dim]
Returns:
query and key with the same shape as input.
"""
if grid_size not in self.freqs_dict:
self.register_grid_size(grid_size)
freqs_cis = self.freqs_dict[grid_size].to(q.device)
q_, k_ = q.float(), k.float()
freqs_cis = freqs_cis.float().to(q.device)
cos, sin = freqs_cis.cos(), freqs_cis.sin()
cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
q_ = (q_ * cos) + (rotate_half(q_) * sin)
k_ = (k_ * cos) + (rotate_half(k_) * sin)
return q_.type_as(q), k_.type_as(k)
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
enable_flashattn3: bool = False,
enable_flashattn2: bool = False,
enable_xformers: bool = False,
enable_bsa: bool = False,
bsa_params: dict = None,
cp_split_hw: Optional[List[int]] = None
) -> None:
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.enable_flashattn3 = enable_flashattn3
self.enable_flashattn2 = enable_flashattn2
self.enable_xformers = enable_xformers
self.enable_bsa = enable_bsa
self.bsa_params = bsa_params
self.cp_split_hw = cp_split_hw
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
self.proj = nn.Linear(dim, dim)
self.rope_3d = RotaryPositionalEmbedding(
self.head_dim,
cp_split_hw=cp_split_hw
)
def _process_attn(self, q, k, v, shape):
q = rearrange(q, "B H S D -> B S (H D)")
k = rearrange(k, "B H S D -> B S (H D)")
v = rearrange(v, "B H S D -> B S (H D)")
x = flash_attention(q, k, v, num_heads=self.num_heads)
x = rearrange(x, "B S (H D) -> B H S D", H=self.num_heads)
return x
def forward(self, x: torch.Tensor, shape=None, num_cond_latents=None, return_kv=False) -> torch.Tensor:
"""
"""
B, N, C = x.shape
qkv = self.qkv(x)
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4)) # [3, B, H, N, D]
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
if return_kv:
k_cache, v_cache = k.clone(), v.clone()
q, k = self.rope_3d(q, k, shape)
# cond mode
if num_cond_latents is not None and num_cond_latents > 0:
num_cond_latents_thw = num_cond_latents * (N // shape[0])
# process the condition tokens
q_cond = q[:, :, :num_cond_latents_thw].contiguous()
k_cond = k[:, :, :num_cond_latents_thw].contiguous()
v_cond = v[:, :, :num_cond_latents_thw].contiguous()
x_cond = self._process_attn(q_cond, k_cond, v_cond, shape)
# process the noise tokens
q_noise = q[:, :, num_cond_latents_thw:].contiguous()
x_noise = self._process_attn(q_noise, k, v, shape)
# merge x_cond and x_noise
x = torch.cat([x_cond, x_noise], dim=2).contiguous()
else:
x = self._process_attn(q, k, v, shape)
x_output_shape = (B, N, C)
x = x.transpose(1, 2) # [B, H, N, D] --> [B, N, H, D]
x = x.reshape(x_output_shape) # [B, N, H, D] --> [B, N, C]
x = self.proj(x)
if return_kv:
return x, (k_cache, v_cache)
else:
return x
def forward_with_kv_cache(self, x: torch.Tensor, shape=None, num_cond_latents=None, kv_cache=None) -> torch.Tensor:
"""
"""
B, N, C = x.shape
qkv = self.qkv(x)
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4)) # [3, B, H, N, D]
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
T, H, W = shape
k_cache, v_cache = kv_cache
assert k_cache.shape[0] == v_cache.shape[0] and k_cache.shape[0] in [1, B]
if k_cache.shape[0] == 1:
k_cache = k_cache.repeat(B, 1, 1, 1)
v_cache = v_cache.repeat(B, 1, 1, 1)
if num_cond_latents is not None and num_cond_latents > 0:
k_full = torch.cat([k_cache, k], dim=2).contiguous()
v_full = torch.cat([v_cache, v], dim=2).contiguous()
q_padding = torch.cat([torch.empty_like(k_cache), q], dim=2).contiguous()
q_padding, k_full = self.rope_3d(q_padding, k_full, (T + num_cond_latents, H, W))
q = q_padding[:, :, -N:].contiguous()
x = self._process_attn(q, k_full, v_full, shape)
x_output_shape = (B, N, C)
x = x.transpose(1, 2) # [B, H, N, D] --> [B, N, H, D]
x = x.reshape(x_output_shape) # [B, N, H, D] --> [B, N, C]
x = self.proj(x)
return x
class MultiHeadCrossAttention(nn.Module):
def __init__(
self,
dim,
num_heads,
enable_flashattn3=False,
enable_flashattn2=False,
enable_xformers=False,
):
super(MultiHeadCrossAttention, self).__init__()
assert dim % num_heads == 0, "d_model must be divisible by num_heads"
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.q_linear = nn.Linear(dim, dim)
self.kv_linear = nn.Linear(dim, dim * 2)
self.proj = nn.Linear(dim, dim)
self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
self.enable_flashattn3 = enable_flashattn3
self.enable_flashattn2 = enable_flashattn2
self.enable_xformers = enable_xformers
def _process_cross_attn(self, x, cond, kv_seqlen):
B, N, C = x.shape
assert C == self.dim and cond.shape[2] == self.dim
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
k, v = kv.unbind(2)
q, k = self.q_norm(q), self.k_norm(k)
q = rearrange(q, "B S H D -> B S (H D)")
k = rearrange(k, "B S H D -> B S (H D)")
v = rearrange(v, "B S H D -> B S (H D)")
x = flash_attention(q, k, v, num_heads=self.num_heads)
x = x.view(B, -1, C)
x = self.proj(x)
return x
def forward(self, x, cond, kv_seqlen, num_cond_latents=None, shape=None):
"""
x: [B, N, C]
cond: [B, M, C]
"""
if num_cond_latents is None or num_cond_latents == 0:
return self._process_cross_attn(x, cond, kv_seqlen)
else:
B, N, C = x.shape
if num_cond_latents is not None and num_cond_latents > 0:
assert shape is not None, "SHOULD pass in the shape"
num_cond_latents_thw = num_cond_latents * (N // shape[0])
x_noise = x[:, num_cond_latents_thw:] # [B, N_noise, C]
output_noise = self._process_cross_attn(x_noise, cond, kv_seqlen) # [B, N_noise, C]
output = torch.cat([
torch.zeros((B, num_cond_latents_thw, C), dtype=output_noise.dtype, device=output_noise.device),
output_noise
], dim=1).contiguous()
else:
raise NotImplementedError
return output
class LayerNorm_FP32(nn.LayerNorm):
def __init__(self, dim, eps, elementwise_affine):
super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
origin_dtype = inputs.dtype
out = F.layer_norm(
inputs.float(),
self.normalized_shape,
None if self.weight is None else self.weight.float(),
None if self.bias is None else self.bias.float() ,
self.eps
).to(origin_dtype)
return out
def modulate_fp32(norm_func, x, shift, scale):
# Suppose x is (B, N, D), shift is (B, -1, D), scale is (B, -1, D)
# ensure the modulation params be fp32
assert shift.dtype == torch.float32, scale.dtype == torch.float32
dtype = x.dtype
x = norm_func(x.to(torch.float32))
x = x * (scale + 1) + shift
x = x.to(dtype)
return x
class FinalLayer_FP32(nn.Module):
"""
The final layer of DiT.
"""
def __init__(self, hidden_size, num_patch, out_channels, adaln_tembed_dim):
super().__init__()
self.hidden_size = hidden_size
self.num_patch = num_patch
self.out_channels = out_channels
self.adaln_tembed_dim = adaln_tembed_dim
self.norm_final = LayerNorm_FP32(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(adaln_tembed_dim, 2 * hidden_size, bias=True))
def forward(self, x, t, latent_shape):
# timestep shape: [B, T, C]
assert t.dtype == torch.float32
B, N, C = x.shape
T, _, _ = latent_shape
with amp.autocast(get_device_type(), dtype=torch.float32):
shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C]
x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)
x = self.linear(x)
return x
class FeedForwardSwiGLU(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.dim = dim
self.hidden_dim = hidden_dim
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, t_embed_dim, frequency_embedding_size=256):
super().__init__()
self.t_embed_dim = t_embed_dim
self.frequency_embedding_size = frequency_embedding_size
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, t_embed_dim, bias=True),
nn.SiLU(),
nn.Linear(t_embed_dim, t_embed_dim, bias=True),
)
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
freqs = freqs.to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t, dtype):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
if t_freq.dtype != dtype:
t_freq = t_freq.to(dtype)
t_emb = self.mlp(t_freq)
return t_emb
class CaptionEmbedder(nn.Module):
"""
Embeds class labels into vector representations.
"""
def __init__(self, in_channels, hidden_size):
super().__init__()
self.in_channels = in_channels
self.hidden_size = hidden_size
self.y_proj = nn.Sequential(
nn.Linear(in_channels, hidden_size, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(hidden_size, hidden_size, bias=True),
)
def forward(self, caption):
B, _, N, C = caption.shape
caption = self.y_proj(caption)
return caption
class PatchEmbed3D(nn.Module):
"""Video to Patch Embedding.
Args:
patch_size (int): Patch token size. Default: (2,4,4).
in_chans (int): Number of input video channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(
self,
patch_size=(2, 4, 4),
in_chans=3,
embed_dim=96,
norm_layer=None,
flatten=True,
):
super().__init__()
self.patch_size = patch_size
self.flatten = flatten
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
"""Forward function."""
# padding
_, _, D, H, W = x.size()
if W % self.patch_size[2] != 0:
x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
if H % self.patch_size[1] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
if D % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
B, C, T, H, W = x.shape
x = self.proj(x) # (B C T H W)
if self.norm is not None:
D, Wh, Ww = x.size(2), x.size(3), x.size(4)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
return x
class LongCatSingleStreamBlock(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: int,
adaln_tembed_dim: int,
enable_flashattn3: bool = False,
enable_flashattn2: bool = False,
enable_xformers: bool = False,
enable_bsa: bool = False,
bsa_params=None,
cp_split_hw=None
):
super().__init__()
self.hidden_size = hidden_size
# scale and gate modulation
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(adaln_tembed_dim, 6 * hidden_size, bias=True)
)
self.mod_norm_attn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False)
self.mod_norm_ffn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False)
self.pre_crs_attn_norm = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=True)
self.attn = Attention(
dim=hidden_size,
num_heads=num_heads,
enable_flashattn3=enable_flashattn3,
enable_flashattn2=enable_flashattn2,
enable_xformers=enable_xformers,
enable_bsa=enable_bsa,
bsa_params=bsa_params,
cp_split_hw=cp_split_hw
)
self.cross_attn = MultiHeadCrossAttention(
dim=hidden_size,
num_heads=num_heads,
enable_flashattn3=enable_flashattn3,
enable_flashattn2=enable_flashattn2,
enable_xformers=enable_xformers,
)
self.ffn = FeedForwardSwiGLU(dim=hidden_size, hidden_dim=int(hidden_size * mlp_ratio))
def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return_kv=False, kv_cache=None, skip_crs_attn=False):
"""
x: [B, N, C]
y: [1, N_valid_tokens, C]
t: [B, T, C_t]
y_seqlen: [B]; type of a list
latent_shape: latent shape of a single item
"""
x_dtype = x.dtype
B, N, C = x.shape
T, _, _ = latent_shape # S != T*H*W in case of CP split on H*W.
# compute modulation params in fp32
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
shift_msa, scale_msa, gate_msa, \
shift_mlp, scale_mlp, gate_mlp = \
self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C]
# self attn with modulation
x_m = modulate_fp32(self.mod_norm_attn, x.view(B, T, -1, C), shift_msa, scale_msa).view(B, N, C)
if kv_cache is not None:
kv_cache = (kv_cache[0].to(x.device), kv_cache[1].to(x.device))
attn_outputs = self.attn.forward_with_kv_cache(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, kv_cache=kv_cache)
else:
attn_outputs = self.attn(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, return_kv=return_kv)
if return_kv:
x_s, kv_cache = attn_outputs
else:
x_s = attn_outputs
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
x = x.to(x_dtype)
# cross attn
if not skip_crs_attn:
if kv_cache is not None:
num_cond_latents = None
x = x + self.cross_attn(self.pre_crs_attn_norm(x), y, y_seqlen, num_cond_latents=num_cond_latents, shape=latent_shape)
# ffn with modulation
x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C)
x_s = self.ffn(x_m)
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
x = x.to(x_dtype)
if return_kv:
return x, kv_cache
else:
return x
class LongCatVideoTransformer3DModel(torch.nn.Module):
def __init__(
self,
in_channels: int = 16,
out_channels: int = 16,
hidden_size: int = 4096,
depth: int = 48,
num_heads: int = 32,
caption_channels: int = 4096,
mlp_ratio: int = 4,
adaln_tembed_dim: int = 512,
frequency_embedding_size: int = 256,
# default params
patch_size: Tuple[int] = (1, 2, 2),
# attention config
enable_flashattn3: bool = False,
enable_flashattn2: bool = True,
enable_xformers: bool = False,
enable_bsa: bool = False,
bsa_params: dict = {'sparsity': 0.9375, 'chunk_3d_shape_q': [4, 4, 4], 'chunk_3d_shape_k': [4, 4, 4]},
cp_split_hw: Optional[List[int]] = [1, 1],
text_tokens_zero_pad: bool = True,
) -> None:
super().__init__()
self.patch_size = patch_size
self.in_channels = in_channels
self.out_channels = out_channels
self.cp_split_hw = cp_split_hw
self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
self.t_embedder = TimestepEmbedder(t_embed_dim=adaln_tembed_dim, frequency_embedding_size=frequency_embedding_size)
self.y_embedder = CaptionEmbedder(
in_channels=caption_channels,
hidden_size=hidden_size,
)
self.blocks = nn.ModuleList(
[
LongCatSingleStreamBlock(
hidden_size=hidden_size,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
adaln_tembed_dim=adaln_tembed_dim,
enable_flashattn3=enable_flashattn3,
enable_flashattn2=enable_flashattn2,
enable_xformers=enable_xformers,
enable_bsa=enable_bsa,
bsa_params=bsa_params,
cp_split_hw=cp_split_hw
)
for i in range(depth)
]
)
self.final_layer = FinalLayer_FP32(
hidden_size,
np.prod(self.patch_size),
out_channels,
adaln_tembed_dim,
)
self.gradient_checkpointing = False
self.text_tokens_zero_pad = text_tokens_zero_pad
self.lora_dict = {}
self.active_loras = []
def enable_loras(self, lora_key_list=[]):
self.disable_all_loras()
module_loras = {} # {module_name: [lora1, lora2, ...]}
model_device = next(self.parameters()).device
model_dtype = next(self.parameters()).dtype
for lora_key in lora_key_list:
if lora_key in self.lora_dict:
for lora in self.lora_dict[lora_key].loras:
lora.to(model_device, dtype=model_dtype, non_blocking=True)
module_name = lora.lora_name.replace("lora___lorahyphen___", "").replace("___lorahyphen___", ".")
if module_name not in module_loras:
module_loras[module_name] = []
module_loras[module_name].append(lora)
self.active_loras.append(lora_key)
for module_name, loras in module_loras.items():
module = self._get_module_by_name(module_name)
if not hasattr(module, 'org_forward'):
module.org_forward = module.forward
module.forward = self._create_multi_lora_forward(module, loras)
def _create_multi_lora_forward(self, module, loras):
def multi_lora_forward(x, *args, **kwargs):
weight_dtype = x.dtype
org_output = module.org_forward(x, *args, **kwargs)
total_lora_output = 0
for lora in loras:
if lora.use_lora:
lx = lora.lora_down(x.to(lora.lora_down.weight.dtype))
lx = lora.lora_up(lx)
lora_output = lx.to(weight_dtype) * lora.multiplier * lora.alpha_scale
total_lora_output += lora_output
return org_output + total_lora_output
return multi_lora_forward
def _get_module_by_name(self, module_name):
try:
module = self
for part in module_name.split('.'):
module = getattr(module, part)
return module
except AttributeError as e:
raise ValueError(f"Cannot find module: {module_name}, error: {e}")
def disable_all_loras(self):
for name, module in self.named_modules():
if hasattr(module, 'org_forward'):
module.forward = module.org_forward
delattr(module, 'org_forward')
for lora_key, lora_network in self.lora_dict.items():
for lora in lora_network.loras:
lora.to("cpu")
self.active_loras.clear()
def enable_bsa(self,):
for block in self.blocks:
block.attn.enable_bsa = True
def disable_bsa(self,):
for block in self.blocks:
block.attn.enable_bsa = False
def forward(
self,
hidden_states,
timestep,
encoder_hidden_states,
encoder_attention_mask=None,
num_cond_latents=0,
return_kv=False,
kv_cache_dict={},
skip_crs_attn=False,
offload_kv_cache=False,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
):
B, _, T, H, W = hidden_states.shape
N_t = T // self.patch_size[0]
N_h = H // self.patch_size[1]
N_w = W // self.patch_size[2]
assert self.patch_size[0]==1, "Currently, 3D x_embedder should not compress the temporal dimension."
# expand the shape of timestep from [B] to [B, T]
if len(timestep.shape) == 1:
timestep = timestep.unsqueeze(1).expand(-1, N_t).clone() # [B, T]
timestep[:, :num_cond_latents] = 0
dtype = hidden_states.dtype
hidden_states = hidden_states.to(dtype)
timestep = timestep.to(dtype)
encoder_hidden_states = encoder_hidden_states.to(dtype)
hidden_states = self.x_embedder(hidden_states) # [B, N, C]
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t]
encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C]
if self.text_tokens_zero_pad and encoder_attention_mask is not None:
encoder_hidden_states = encoder_hidden_states * encoder_attention_mask[:, None, :, None]
encoder_attention_mask = (encoder_attention_mask * 0 + 1).to(encoder_attention_mask.dtype)
if encoder_attention_mask is not None:
encoder_attention_mask = encoder_attention_mask.squeeze(1).squeeze(1)
encoder_hidden_states = encoder_hidden_states.squeeze(1).masked_select(encoder_attention_mask.unsqueeze(-1) != 0).view(1, -1, hidden_states.shape[-1]) # [1, N_valid_tokens, C]
y_seqlens = encoder_attention_mask.sum(dim=1).tolist() # [B]
else:
y_seqlens = [encoder_hidden_states.shape[2]] * encoder_hidden_states.shape[0]
encoder_hidden_states = encoder_hidden_states.squeeze(1).view(1, -1, hidden_states.shape[-1])
# if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
# hidden_states = rearrange(hidden_states, "B (T H W) C -> B T H W C", T=N_t, H=N_h, W=N_w)
# hidden_states = context_parallel_util.split_cp_2d(hidden_states, seq_dim_hw=(2, 3), split_hw=self.cp_split_hw)
# hidden_states = rearrange(hidden_states, "B T H W C -> B (T H W) C")
# blocks
kv_cache_dict_ret = {}
for i, block in enumerate(self.blocks):
block_outputs = gradient_checkpoint_forward(
block,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
x=hidden_states,
y=encoder_hidden_states,
t=t,
y_seqlen=y_seqlens,
latent_shape=(N_t, N_h, N_w),
num_cond_latents=num_cond_latents,
return_kv=return_kv,
kv_cache=kv_cache_dict.get(i, None),
skip_crs_attn=skip_crs_attn,
)
if return_kv:
hidden_states, kv_cache = block_outputs
if offload_kv_cache:
kv_cache_dict_ret[i] = (kv_cache[0].cpu(), kv_cache[1].cpu())
else:
kv_cache_dict_ret[i] = (kv_cache[0].contiguous(), kv_cache[1].contiguous())
else:
hidden_states = block_outputs
hidden_states = self.final_layer(hidden_states, t, (N_t, N_h, N_w)) # [B, N, C=T_p*H_p*W_p*C_out]
# if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
# hidden_states = context_parallel_util.gather_cp_2d(hidden_states, shape=(N_t, N_h, N_w), split_hw=self.cp_split_hw)
hidden_states = self.unpatchify(hidden_states, N_t, N_h, N_w) # [B, C_out, H, W]
# cast to float32 for better accuracy
hidden_states = hidden_states.to(torch.float32)
if return_kv:
return hidden_states, kv_cache_dict_ret
else:
return hidden_states
def unpatchify(self, x, N_t, N_h, N_w):
"""
Args:
x (torch.Tensor): of shape [B, N, C]
Return:
x (torch.Tensor): of shape [B, C_out, T, H, W]
"""
T_p, H_p, W_p = self.patch_size
x = rearrange(
x,
"B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
N_t=N_t,
N_h=N_h,
N_w=N_w,
T_p=T_p,
H_p=H_p,
W_p=W_p,
C_out=self.out_channels,
)
return x
@staticmethod
def state_dict_converter():
return LongCatVideoTransformer3DModelDictConverter()
class LongCatVideoTransformer3DModelDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict
def from_civitai(self, state_dict):
return state_dict

252
diffsynth/models/lora.py Normal file
View File

@@ -0,0 +1,252 @@
import torch
from .sd_unet import SDUNet
from .sdxl_unet import SDXLUNet
from .sd_text_encoder import SDTextEncoder
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from .sd3_dit import SD3DiT
from .flux_dit import FluxDiT
from .hunyuan_dit import HunyuanDiT
class LoRAFromCivitai:
def __init__(self):
self.supported_model_classes = []
self.lora_prefix = []
self.renamed_lora_prefix = {}
self.special_keys = {}
def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
for key in state_dict:
if ".lora_up" in key:
return self.convert_state_dict_up_down(state_dict, lora_prefix, alpha)
return self.convert_state_dict_AB(state_dict, lora_prefix, alpha)
def convert_state_dict_up_down(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "")
state_dict_ = {}
for key in state_dict:
if ".lora_up" not in key:
continue
if not key.startswith(lora_prefix):
continue
weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
if len(weight_up.shape) == 4:
weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
lora_weight = alpha * torch.mm(weight_up, weight_down)
target_name = key.split(".")[0].replace(lora_prefix, renamed_lora_prefix).replace("_", ".") + ".weight"
for special_key in self.special_keys:
target_name = target_name.replace(special_key, self.special_keys[special_key])
state_dict_[target_name] = lora_weight.cpu()
return state_dict_
def convert_state_dict_AB(self, state_dict, lora_prefix="", alpha=1.0, device="cuda", torch_dtype=torch.float16):
state_dict_ = {}
for key in state_dict:
if ".lora_B." not in key:
continue
if not key.startswith(lora_prefix):
continue
weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
if len(weight_up.shape) == 4:
weight_up = weight_up.squeeze(3).squeeze(2)
weight_down = weight_down.squeeze(3).squeeze(2)
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
lora_weight = alpha * torch.mm(weight_up, weight_down)
keys = key.split(".")
keys.pop(keys.index("lora_B"))
target_name = ".".join(keys)
target_name = target_name[len(lora_prefix):]
state_dict_[target_name] = lora_weight.cpu()
return state_dict_
def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None):
state_dict_model = model.state_dict()
state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=alpha)
if model_resource == "diffusers":
state_dict_lora = model.__class__.state_dict_converter().from_diffusers(state_dict_lora)
elif model_resource == "civitai":
state_dict_lora = model.__class__.state_dict_converter().from_civitai(state_dict_lora)
if len(state_dict_lora) > 0:
print(f" {len(state_dict_lora)} tensors are updated.")
for name in state_dict_lora:
state_dict_model[name] += state_dict_lora[name].to(
dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
model.load_state_dict(state_dict_model)
def match(self, model, state_dict_lora):
for lora_prefix, model_class in zip(self.lora_prefix, self.supported_model_classes):
if not isinstance(model, model_class):
continue
state_dict_model = model.state_dict()
for model_resource in ["diffusers", "civitai"]:
try:
state_dict_lora_ = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=1.0)
converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == "diffusers" \
else model.__class__.state_dict_converter().from_civitai
state_dict_lora_ = converter_fn(state_dict_lora_)
if len(state_dict_lora_) == 0:
continue
for name in state_dict_lora_:
if name not in state_dict_model:
break
else:
return lora_prefix, model_resource
except:
pass
return None
class SDLoRAFromCivitai(LoRAFromCivitai):
def __init__(self):
super().__init__()
self.supported_model_classes = [SDUNet, SDTextEncoder]
self.lora_prefix = ["lora_unet_", "lora_te_"]
self.special_keys = {
"down.blocks": "down_blocks",
"up.blocks": "up_blocks",
"mid.block": "mid_block",
"proj.in": "proj_in",
"proj.out": "proj_out",
"transformer.blocks": "transformer_blocks",
"to.q": "to_q",
"to.k": "to_k",
"to.v": "to_v",
"to.out": "to_out",
"text.model": "text_model",
"self.attn.q.proj": "self_attn.q_proj",
"self.attn.k.proj": "self_attn.k_proj",
"self.attn.v.proj": "self_attn.v_proj",
"self.attn.out.proj": "self_attn.out_proj",
"input.blocks": "model.diffusion_model.input_blocks",
"middle.block": "model.diffusion_model.middle_block",
"output.blocks": "model.diffusion_model.output_blocks",
}
class SDXLLoRAFromCivitai(LoRAFromCivitai):
def __init__(self):
super().__init__()
self.supported_model_classes = [SDXLUNet, SDXLTextEncoder, SDXLTextEncoder2]
self.lora_prefix = ["lora_unet_", "lora_te1_", "lora_te2_"]
self.renamed_lora_prefix = {"lora_te2_": "2"}
self.special_keys = {
"down.blocks": "down_blocks",
"up.blocks": "up_blocks",
"mid.block": "mid_block",
"proj.in": "proj_in",
"proj.out": "proj_out",
"transformer.blocks": "transformer_blocks",
"to.q": "to_q",
"to.k": "to_k",
"to.v": "to_v",
"to.out": "to_out",
"text.model": "conditioner.embedders.0.transformer.text_model",
"self.attn.q.proj": "self_attn.q_proj",
"self.attn.k.proj": "self_attn.k_proj",
"self.attn.v.proj": "self_attn.v_proj",
"self.attn.out.proj": "self_attn.out_proj",
"input.blocks": "model.diffusion_model.input_blocks",
"middle.block": "model.diffusion_model.middle_block",
"output.blocks": "model.diffusion_model.output_blocks",
"2conditioner.embedders.0.transformer.text_model.encoder.layers": "text_model.encoder.layers"
}
class FluxLoRAFromCivitai(LoRAFromCivitai):
def __init__(self):
super().__init__()
self.supported_model_classes = [FluxDiT, FluxDiT]
self.lora_prefix = ["lora_unet_", "transformer."]
self.renamed_lora_prefix = {}
self.special_keys = {
"single.blocks": "single_blocks",
"double.blocks": "double_blocks",
"img.attn": "img_attn",
"img.mlp": "img_mlp",
"img.mod": "img_mod",
"txt.attn": "txt_attn",
"txt.mlp": "txt_mlp",
"txt.mod": "txt_mod",
}
class GeneralLoRAFromPeft:
def __init__(self):
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT]
def fetch_device_dtype_from_state_dict(self, state_dict):
device, torch_dtype = None, None
for name, param in state_dict.items():
device, torch_dtype = param.device, param.dtype
break
return device, torch_dtype
def convert_state_dict(self, state_dict, alpha=1.0, target_state_dict={}):
device, torch_dtype = self.fetch_device_dtype_from_state_dict(target_state_dict)
state_dict_ = {}
for key in state_dict:
if ".lora_B." not in key:
continue
weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
if len(weight_up.shape) == 4:
weight_up = weight_up.squeeze(3).squeeze(2)
weight_down = weight_down.squeeze(3).squeeze(2)
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
lora_weight = alpha * torch.mm(weight_up, weight_down)
keys = key.split(".")
if len(keys) > keys.index("lora_B") + 2:
keys.pop(keys.index("lora_B") + 1)
keys.pop(keys.index("lora_B"))
target_name = ".".join(keys)
if target_name not in target_state_dict:
return {}
state_dict_[target_name] = lora_weight.cpu()
return state_dict_
def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
state_dict_model = model.state_dict()
state_dict_lora = self.convert_state_dict(state_dict_lora, alpha=alpha, target_state_dict=state_dict_model)
if len(state_dict_lora) > 0:
print(f" {len(state_dict_lora)} tensors are updated.")
for name in state_dict_lora:
state_dict_model[name] += state_dict_lora[name].to(
dtype=state_dict_model[name].dtype,
device=state_dict_model[name].device
)
model.load_state_dict(state_dict_model)
def match(self, model, state_dict_lora):
for model_class in self.supported_model_classes:
if not isinstance(model, model_class):
continue
state_dict_model = model.state_dict()
try:
state_dict_lora_ = self.convert_state_dict(state_dict_lora, alpha=1.0, target_state_dict=state_dict_model)
if len(state_dict_lora_) > 0:
return "", ""
except:
pass
return None
def get_lora_loaders():
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), GeneralLoRAFromPeft(), FluxLoRAFromCivitai()]

View File

@@ -1,111 +0,0 @@
from ..core.loader import load_model, hash_model_file
from ..core.vram import AutoWrappedModule
from ..configs import MODEL_CONFIGS, VRAM_MANAGEMENT_MODULE_MAPS
import importlib, json, torch
class ModelPool:
def __init__(self):
self.model = []
self.model_name = []
self.model_path = []
def import_model_class(self, model_class):
split = model_class.rfind(".")
model_resource, model_class = model_class[:split], model_class[split+1:]
model_class = importlib.import_module(model_resource).__getattribute__(model_class)
return model_class
def need_to_enable_vram_management(self, vram_config):
return vram_config["offload_dtype"] is not None and vram_config["offload_device"] is not None
def fetch_module_map(self, model_class, vram_config):
if self.need_to_enable_vram_management(vram_config):
if model_class in VRAM_MANAGEMENT_MODULE_MAPS:
module_map = {self.import_model_class(source): self.import_model_class(target) for source, target in VRAM_MANAGEMENT_MODULE_MAPS[model_class].items()}
else:
module_map = {self.import_model_class(model_class): AutoWrappedModule}
else:
module_map = None
return module_map
def load_model_file(self, config, path, vram_config, vram_limit=None):
model_class = self.import_model_class(config["model_class"])
model_config = config.get("extra_kwargs", {})
if "state_dict_converter" in config:
state_dict_converter = self.import_model_class(config["state_dict_converter"])
else:
state_dict_converter = None
module_map = self.fetch_module_map(config["model_class"], vram_config)
model = load_model(
model_class, path, model_config,
vram_config["computation_dtype"], vram_config["computation_device"],
state_dict_converter,
use_disk_map=True,
vram_config=vram_config, module_map=module_map, vram_limit=vram_limit,
)
return model
def default_vram_config(self):
vram_config = {
"offload_dtype": None,
"offload_device": None,
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cpu",
"computation_dtype": torch.bfloat16,
"computation_device": "cpu",
}
return vram_config
def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False):
print(f"Loading models from: {json.dumps(path, indent=4)}")
if vram_config is None:
vram_config = self.default_vram_config()
model_hash = hash_model_file(path)
loaded = False
for config in MODEL_CONFIGS:
if config["model_hash"] == model_hash:
model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit)
if clear_parameters: self.clear_parameters(model)
self.model.append(model)
model_name = config["model_name"]
self.model_name.append(model_name)
self.model_path.append(path)
model_info = {"model_name": model_name, "model_class": config["model_class"], "extra_kwargs": config.get("extra_kwargs")}
print(f"Loaded model: {json.dumps(model_info, indent=4)}")
loaded = True
if not loaded:
raise ValueError(f"Cannot detect the model type. File: {path}. Model hash: {model_hash}")
def fetch_model(self, model_name, index=None):
fetched_models = []
fetched_model_paths = []
for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
if model_name == model_name_:
fetched_models.append(model)
fetched_model_paths.append(model_path)
if len(fetched_models) == 0:
print(f"No {model_name} models available. This is not an error.")
model = None
elif len(fetched_models) == 1:
print(f"Using {model_name} from {json.dumps(fetched_model_paths[0], indent=4)}.")
model = fetched_models[0]
else:
if index is None:
model = fetched_models[0]
print(f"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths[0], indent=4)}.")
elif isinstance(index, int):
model = fetched_models[:index]
print(f"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths[:index], indent=4)}.")
else:
model = fetched_models
print(f"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths, indent=4)}.")
return model
def clear_parameters(self, model: torch.nn.Module):
for name, module in model.named_children():
self.clear_parameters(module)
for name, param in model.named_parameters(recurse=False):
setattr(model, name, None)

View File

@@ -0,0 +1,471 @@
import os, torch, hashlib, json, importlib
from safetensors import safe_open
from torch import Tensor
from typing_extensions import Literal, TypeAlias
from typing import List
from .downloader import download_models, Preset_model_id, Preset_model_website
from .sd_text_encoder import SDTextEncoder
from .sd_unet import SDUNet
from .sd_vae_encoder import SDVAEEncoder
from .sd_vae_decoder import SDVAEDecoder
from .lora import get_lora_loaders
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from .sdxl_unet import SDXLUNet
from .sdxl_vae_decoder import SDXLVAEDecoder
from .sdxl_vae_encoder import SDXLVAEEncoder
from .sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
from .sd3_dit import SD3DiT
from .sd3_vae_decoder import SD3VAEDecoder
from .sd3_vae_encoder import SD3VAEEncoder
from .sd_controlnet import SDControlNet
from .sdxl_controlnet import SDXLControlNetUnion
from .sd_motion import SDMotionModel
from .sdxl_motion import SDXLMotionModel
from .svd_image_encoder import SVDImageEncoder
from .svd_unet import SVDUNet
from .svd_vae_decoder import SVDVAEDecoder
from .svd_vae_encoder import SVDVAEEncoder
from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from .hunyuan_dit import HunyuanDiT
from .flux_dit import FluxDiT
from .flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
from .flux_vae import FluxVAEEncoder, FluxVAEDecoder
from .cog_vae import CogVAEEncoder, CogVAEDecoder
from .cog_dit import CogDiT
from ..extensions.RIFE import IFNet
from ..extensions.ESRGAN import RRDBNet
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
from .utils import load_state_dict
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
keys = []
for key, value in state_dict.items():
if isinstance(key, str):
if isinstance(value, Tensor):
if with_shape:
shape = "_".join(map(str, list(value.shape)))
keys.append(key + ":" + shape)
keys.append(key)
elif isinstance(value, dict):
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
keys.sort()
keys_str = ",".join(keys)
return keys_str
def split_state_dict_with_prefix(state_dict):
keys = sorted([key for key in state_dict if isinstance(key, str)])
prefix_dict = {}
for key in keys:
prefix = key if "." not in key else key.split(".")[0]
if prefix not in prefix_dict:
prefix_dict[prefix] = []
prefix_dict[prefix].append(key)
state_dicts = []
for prefix, keys in prefix_dict.items():
sub_state_dict = {key: state_dict[key] for key in keys}
state_dicts.append(sub_state_dict)
return state_dicts
def hash_state_dict_keys(state_dict, with_shape=True):
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
keys_str = keys_str.encode(encoding="UTF-8")
return hashlib.md5(keys_str).hexdigest()
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
loaded_model_names, loaded_models = [], []
for model_name, model_class in zip(model_names, model_classes):
print(f" model_name: {model_name} model_class: {model_class.__name__}")
state_dict_converter = model_class.state_dict_converter()
if model_resource == "civitai":
state_dict_results = state_dict_converter.from_civitai(state_dict)
elif model_resource == "diffusers":
state_dict_results = state_dict_converter.from_diffusers(state_dict)
if isinstance(state_dict_results, tuple):
model_state_dict, extra_kwargs = state_dict_results
print(f" This model is initialized with extra kwargs: {extra_kwargs}")
else:
model_state_dict, extra_kwargs = state_dict_results, {}
torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
model = model_class(**extra_kwargs).to(dtype=torch_dtype, device=device)
model.load_state_dict(model_state_dict)
loaded_model_names.append(model_name)
loaded_models.append(model)
return loaded_model_names, loaded_models
def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
loaded_model_names, loaded_models = [], []
for model_name, model_class in zip(model_names, model_classes):
model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
if torch_dtype == torch.float16 and hasattr(model, "half"):
model = model.half()
try:
model = model.to(device=device)
except:
pass
loaded_model_names.append(model_name)
loaded_models.append(model)
return loaded_model_names, loaded_models
def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
base_state_dict = base_model.state_dict()
base_model.to("cpu")
del base_model
model = model_class(**extra_kwargs)
model.load_state_dict(base_state_dict, strict=False)
model.load_state_dict(state_dict, strict=False)
model.to(dtype=torch_dtype, device=device)
return model
def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
loaded_model_names, loaded_models = [], []
for model_name, model_class in zip(model_names, model_classes):
while True:
for model_id in range(len(model_manager.model)):
base_model_name = model_manager.model_name[model_id]
if base_model_name == model_name:
base_model_path = model_manager.model_path[model_id]
base_model = model_manager.model[model_id]
print(f" Adding patch model to {base_model_name} ({base_model_path})")
patched_model = load_single_patch_model_from_single_file(
state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
loaded_model_names.append(base_model_name)
loaded_models.append(patched_model)
model_manager.model.pop(model_id)
model_manager.model_path.pop(model_id)
model_manager.model_name.pop(model_id)
break
else:
break
return loaded_model_names, loaded_models
class ModelDetectorTemplate:
def __init__(self):
pass
def match(self, file_path="", state_dict={}):
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
return [], []
class ModelDetectorFromSingleFile:
def __init__(self, model_loader_configs=[]):
self.keys_hash_with_shape_dict = {}
self.keys_hash_dict = {}
for metadata in model_loader_configs:
self.add_model_metadata(*metadata)
def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
if keys_hash is not None:
self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)
def match(self, file_path="", state_dict={}):
if os.path.isdir(file_path):
return False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
return True
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
if keys_hash in self.keys_hash_dict:
return True
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
# Load models with strict matching
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
return loaded_model_names, loaded_models
# Load models without strict matching
# (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
if keys_hash in self.keys_hash_dict:
model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
return loaded_model_names, loaded_models
return loaded_model_names, loaded_models
class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
def __init__(self, model_loader_configs=[]):
super().__init__(model_loader_configs)
def match(self, file_path="", state_dict={}):
if os.path.isdir(file_path):
return False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
splited_state_dict = split_state_dict_with_prefix(state_dict)
for sub_state_dict in splited_state_dict:
if super().match(file_path, sub_state_dict):
return True
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
# Split the state_dict and load from each component
splited_state_dict = split_state_dict_with_prefix(state_dict)
valid_state_dict = {}
for sub_state_dict in splited_state_dict:
if super().match(file_path, sub_state_dict):
valid_state_dict.update(sub_state_dict)
if super().match(file_path, valid_state_dict):
loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
else:
loaded_model_names, loaded_models = [], []
for sub_state_dict in splited_state_dict:
if super().match(file_path, sub_state_dict):
loaded_model_names_, loaded_models_ = super().load(file_path, valid_state_dict, device, torch_dtype)
loaded_model_names += loaded_model_names_
loaded_models += loaded_models_
return loaded_model_names, loaded_models
class ModelDetectorFromHuggingfaceFolder:
def __init__(self, model_loader_configs=[]):
self.architecture_dict = {}
for metadata in model_loader_configs:
self.add_model_metadata(*metadata)
def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)
def match(self, file_path="", state_dict={}):
if os.path.isfile(file_path):
return False
file_list = os.listdir(file_path)
if "config.json" not in file_list:
return False
with open(os.path.join(file_path, "config.json"), "r") as f:
config = json.load(f)
if "architectures" not in config and "_class_name" not in config:
return False
return True
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
with open(os.path.join(file_path, "config.json"), "r") as f:
config = json.load(f)
loaded_model_names, loaded_models = [], []
architectures = config["architectures"] if "architectures" in config else [config["_class_name"]]
for architecture in architectures:
huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
if redirected_architecture is not None:
architecture = redirected_architecture
model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
loaded_model_names += loaded_model_names_
loaded_models += loaded_models_
return loaded_model_names, loaded_models
class ModelDetectorFromPatchedSingleFile:
def __init__(self, model_loader_configs=[]):
self.keys_hash_with_shape_dict = {}
for metadata in model_loader_configs:
self.add_model_metadata(*metadata)
def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)
def match(self, file_path="", state_dict={}):
if os.path.isdir(file_path):
return False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
return True
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
# Load models with strict matching
loaded_model_names, loaded_models = [], []
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
loaded_model_names += loaded_model_names_
loaded_models += loaded_models_
return loaded_model_names, loaded_models
class ModelManager:
def __init__(
self,
torch_dtype=torch.float16,
device="cuda",
model_id_list: List[Preset_model_id] = [],
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
file_path_list: List[str] = [],
):
self.torch_dtype = torch_dtype
self.device = device
self.model = []
self.model_path = []
self.model_name = []
downloaded_files = download_models(model_id_list, downloading_priority) if len(model_id_list) > 0 else []
self.model_detector = [
ModelDetectorFromSingleFile(model_loader_configs),
ModelDetectorFromSplitedSingleFile(model_loader_configs),
ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
]
self.load_models(downloaded_files + file_path_list)
def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
print(f"Loading models from file: {file_path}")
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
print(f" The following models are loaded: {model_names}.")
def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
print(f"Loading models from folder: {file_path}")
model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
print(f" The following models are loaded: {model_names}.")
def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
print(f"Loading patch models from file: {file_path}")
model_names, models = load_patch_model_from_single_file(
state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
print(f" The following patched models are loaded: {model_names}.")
def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
print(f"Loading LoRA models from file: {file_path}")
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
for lora in get_lora_loaders():
match_results = lora.match(model, state_dict)
if match_results is not None:
print(f" Adding LoRA to {model_name} ({model_path}).")
lora_prefix, model_resource = match_results
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
break
def load_model(self, file_path, model_names=None):
print(f"Loading models from: {file_path}")
if os.path.isfile(file_path):
state_dict = load_state_dict(file_path)
else:
state_dict = None
for model_detector in self.model_detector:
if model_detector.match(file_path, state_dict):
model_names, models = model_detector.load(
file_path, state_dict,
device=self.device, torch_dtype=self.torch_dtype,
allowed_model_names=model_names, model_manager=self
)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
print(f" The following models are loaded: {model_names}.")
break
else:
print(f" We cannot detect the model type. No models are loaded.")
def load_models(self, file_path_list, model_names=None):
for file_path in file_path_list:
self.load_model(file_path, model_names)
def fetch_model(self, model_name, file_path=None, require_model_path=False):
fetched_models = []
fetched_model_paths = []
for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
if file_path is not None and file_path != model_path:
continue
if model_name == model_name_:
fetched_models.append(model)
fetched_model_paths.append(model_path)
if len(fetched_models) == 0:
print(f"No {model_name} models available.")
return None
if len(fetched_models) == 1:
print(f"Using {model_name} from {fetched_model_paths[0]}.")
else:
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
if require_model_path:
return fetched_models[0], fetched_model_paths[0]
else:
return fetched_models[0]
def to(self, device):
for model in self.model:
model.to(device)

View File

@@ -1,161 +0,0 @@
import torch
from PIL import Image
class NexusGenAutoregressiveModel(torch.nn.Module):
def __init__(self, max_length=1024, max_pixels=262640):
super(NexusGenAutoregressiveModel, self).__init__()
from .nexus_gen_ar_model import Qwen2_5_VLForConditionalGeneration
from transformers import Qwen2_5_VLConfig
self.max_length = max_length
self.max_pixels = max_pixels
model_config = Qwen2_5_VLConfig(**{
"_name_or_path": "DiffSynth-Studio/Nexus-GenV2",
"architectures": [
"Qwen2_5_VLForConditionalGeneration"
],
"attention_dropout": 0.0,
"auto_map": {
"AutoConfig": "configuration_qwen2_5_vl.Qwen2_5_VLConfig",
"AutoModel": "modeling_qwen2_5_vl.Qwen2_5_VLModel",
"AutoModelForCausalLM": "modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration"
},
"bos_token_id": 151643,
"eos_token_id": 151645,
"hidden_act": "silu",
"hidden_size": 3584,
"image_token_id": 151655,
"initializer_range": 0.02,
"intermediate_size": 18944,
"max_position_embeddings": 128000,
"max_window_layers": 28,
"model_type": "qwen2_5_vl",
"num_attention_heads": 28,
"num_hidden_layers": 28,
"num_key_value_heads": 4,
"pad_token_id": 151643,
"rms_norm_eps": 1e-06,
"rope_scaling": {
"mrope_section": [
16,
24,
24
],
"rope_type": "default",
"type": "default"
},
"rope_theta": 1000000.0,
"sliding_window": 32768,
"tie_word_embeddings": False,
"torch_dtype": "bfloat16",
"transformers_version": "4.49.0",
"use_cache": False,
"use_sliding_window": False,
"video_token_id": 151656,
"vision_config": {
"hidden_size": 1280,
"in_chans": 3,
"model_type": "qwen2_5_vl",
"spatial_patch_size": 14,
"tokens_per_second": 2,
"torch_dtype": "bfloat16"
},
"vision_end_token_id": 151653,
"vision_start_token_id": 151652,
"vision_token_id": 151654,
"vocab_size": 152064
})
self.model = Qwen2_5_VLForConditionalGeneration(model_config)
self.processor = None
def load_processor(self, path):
from .nexus_gen_ar_model import Qwen2_5_VLProcessor
self.processor = Qwen2_5_VLProcessor.from_pretrained(path)
@staticmethod
def state_dict_converter():
return NexusGenAutoregressiveModelStateDictConverter()
def bound_image(self, image, max_pixels=262640):
from qwen_vl_utils import smart_resize
resized_height, resized_width = smart_resize(
image.height,
image.width,
max_pixels=max_pixels,
)
return image.resize((resized_width, resized_height))
def get_editing_msg(self, instruction):
if '<image>' not in instruction:
instruction = '<image> ' + instruction
messages = [{"role":"user", "content":instruction}, {"role":"assistant", "content":"Here is the image: <image>"}]
return messages
def get_generation_msg(self, instruction):
instruction = "Generate an image according to the following description: {}".format(instruction)
messages = [{"role":"user", "content":instruction}, {"role":"assistant", "content":"Here is an image based on the description: <image>"}]
return messages
def forward(self, instruction, ref_image=None, num_img_tokens=81):
"""
Generate target embeddings for the given instruction and reference image.
"""
if ref_image is not None:
messages = self.get_editing_msg(instruction)
images = [self.bound_image(ref_image)] + [Image.new(mode='RGB', size=(252, 252), color=(255, 255, 255))]
output_image_embeddings = self.get_target_embeddings(images, messages, self.processor, self.model, num_img_tokens)
else:
messages = self.get_generation_msg(instruction)
images = [Image.new(mode='RGB', size=(252, 252), color=(255, 255, 255))]
output_image_embeddings = self.get_target_embeddings(images, messages, self.processor, self.model, num_img_tokens)
return output_image_embeddings
def get_target_embeddings(self, images, messages, processor, model, num_img_tokens=81):
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
text = text.replace('<image>', '<|vision_start|><|image_pad|><|vision_end|>')
inputs = processor(
text=[text],
images=images,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(model.device)
input_embeds = model.model.embed_tokens(inputs['input_ids'])
image_embeds = model.visual(inputs['pixel_values'], grid_thw=inputs['image_grid_thw'])
ground_truth_image_embeds = image_embeds[-num_img_tokens:]
input_image_embeds = image_embeds[:-num_img_tokens]
image_mask = inputs['input_ids'] == model.config.image_token_id
indices = image_mask.cumsum(dim=1)
input_image_mask = torch.logical_and(indices <= (image_embeds.shape[0] - ground_truth_image_embeds.shape[0]), image_mask)
gt_image_mask = torch.logical_and(image_mask, ~input_image_mask)
input_image_mask = input_image_mask.unsqueeze(-1).expand_as(input_embeds)
input_embeds = input_embeds.masked_scatter(input_image_mask, input_image_embeds)
image_prefill_embeds = model.image_prefill_embeds(
torch.arange(81, device=model.device).long()
)
input_embeds = input_embeds.masked_scatter(gt_image_mask.unsqueeze(-1).expand_as(input_embeds), image_prefill_embeds)
position_ids, _ = model.get_rope_index(
inputs['input_ids'],
inputs['image_grid_thw'],
attention_mask=inputs['attention_mask'])
position_ids = position_ids.contiguous()
outputs = model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=inputs['attention_mask'], return_dict=True)
output_image_embeddings = outputs.image_embeddings[:, :-1, :]
output_image_embeddings = output_image_embeddings[gt_image_mask[:, 1:]]
return output_image_embeddings, input_image_embeds, inputs['image_grid_thw']
class NexusGenAutoregressiveModelStateDictConverter:
def __init__(self):
pass
def from_civitai(self, state_dict):
state_dict = {"model." + key: value for key, value in state_dict.items()}
return state_dict

File diff suppressed because it is too large Load Diff

View File

@@ -1,417 +0,0 @@
import math
import torch
import torch.nn as nn
from typing import Optional, Tuple
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
mrope_section = mrope_section * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
unsqueeze_dim
)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
unsqueeze_dim
)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class Qwen2_5_VLRotaryEmbedding(nn.Module):
def __init__(self, config, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
from transformers.modeling_rope_utils import _compute_default_rope_parameters
self.rope_init_fn = _compute_default_rope_parameters
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
def _dynamic_frequency_update(self, position_ids, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len
@torch.no_grad()
def forward(self, x, position_ids):
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(position_ids, device=x.device)
# Core RoPE block. In contrast to other models, Qwen2_5_VL has different position ids for the grids
# So we expand the inv_freq to shape (3, ...)
inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
cos = cos * self.attention_scaling
sin = sin * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
class Qwen2_5_VLAttention(nn.Module):
def __init__(self, config, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.is_causal = True
self.attention_dropout = config.attention_dropout
self.rope_scaling = config.rope_scaling
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_multimodal_rotary_pos_emb(
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
# Fix precision issues in Qwen2-VL float16 inference
# Replace inf values with zeros in attention weights to prevent NaN propagation
if query_states.dtype == torch.float16:
attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights)
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
return attn_output
class Qwen2MLP(nn.Module):
def __init__(self, config):
super().__init__()
from transformers.activations import ACT2FN
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
class Qwen2RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Qwen2RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class Qwen2_5_VLDecoderLayer(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Qwen2_5_VLAttention(config, layer_idx)
self.mlp = Qwen2MLP(config)
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class NexusGenImageEmbeddingMerger(nn.Module):
def __init__(self, num_layers=1, out_channel=4096, expand_ratio=4, device='cpu'):
super().__init__()
from transformers import Qwen2_5_VLConfig
from transformers.activations import ACT2FN
config = Qwen2_5_VLConfig(**{
"_name_or_path": "DiffSynth-Studio/Nexus-GenV2",
"architectures": [
"Qwen2_5_VLForConditionalGeneration"
],
"attention_dropout": 0.0,
"auto_map": {
"AutoConfig": "configuration_qwen2_5_vl.Qwen2_5_VLConfig",
"AutoModel": "modeling_qwen2_5_vl.Qwen2_5_VLModel",
"AutoModelForCausalLM": "modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration"
},
"bos_token_id": 151643,
"eos_token_id": 151645,
"hidden_act": "silu",
"hidden_size": 3584,
"image_token_id": 151655,
"initializer_range": 0.02,
"intermediate_size": 18944,
"max_position_embeddings": 128000,
"max_window_layers": 28,
"model_type": "qwen2_5_vl",
"num_attention_heads": 28,
"num_hidden_layers": 28,
"num_key_value_heads": 4,
"pad_token_id": 151643,
"rms_norm_eps": 1e-06,
"rope_scaling": {
"mrope_section": [
16,
24,
24
],
"rope_type": "default",
"type": "default"
},
"rope_theta": 1000000.0,
"sliding_window": 32768,
"tie_word_embeddings": False,
"torch_dtype": "bfloat16",
"transformers_version": "4.49.0",
"use_cache": False,
"use_sliding_window": False,
"video_token_id": 151656,
"vision_config": {
"hidden_size": 1280,
"in_chans": 3,
"model_type": "qwen2_5_vl",
"spatial_patch_size": 14,
"tokens_per_second": 2,
"torch_dtype": "bfloat16"
},
"vision_end_token_id": 151653,
"vision_start_token_id": 151652,
"vision_token_id": 151654,
"vocab_size": 152064
})
self.config = config
self.num_layers = num_layers
self.layers = nn.ModuleList([Qwen2_5_VLDecoderLayer(config, layer_idx) for layer_idx in range(num_layers)])
self.projector = nn.Sequential(Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps),
nn.Linear(config.hidden_size, out_channel * expand_ratio),
Qwen2RMSNorm(out_channel * expand_ratio, eps=config.rms_norm_eps),
ACT2FN[config.hidden_act], nn.Linear(out_channel * expand_ratio, out_channel),
Qwen2RMSNorm(out_channel, eps=config.rms_norm_eps))
self.base_grid = torch.tensor([[1, 72, 72]], device=device)
self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config, device=device)
def get_position_ids(self, image_grid_thw):
"""
Generates position ids for the input embeddings grid.
modified from the qwen2_vl mrope.
"""
batch_size = image_grid_thw.shape[0]
spatial_merge_size = self.config.vision_config.spatial_merge_size
t, h, w = (
image_grid_thw[0][0],
image_grid_thw[0][1],
image_grid_thw[0][2],
)
llm_grid_t, llm_grid_h, llm_grid_w = (
t.item(),
h.item() // spatial_merge_size,
w.item() // spatial_merge_size,
)
scale_h = self.base_grid[0][1].item() / h.item()
scale_w = self.base_grid[0][2].item() / w.item()
range_tensor = torch.arange(llm_grid_t).view(-1, 1)
expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
time_tensor = expanded_range * self.config.vision_config.tokens_per_second
t_index = time_tensor.long().flatten().to(image_grid_thw.device)
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten().to(image_grid_thw.device) * scale_h
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten().to(image_grid_thw.device) * scale_w
# 3, B, L
position_ids = torch.stack([t_index, h_index, w_index]).unsqueeze(0).repeat(batch_size, 1, 1).permute(1, 0, 2)
return position_ids
def forward(self, embeds, embeds_grid, ref_embeds=None, ref_embeds_grid=None):
position_ids = self.get_position_ids(embeds_grid)
hidden_states = embeds
if ref_embeds is not None:
position_ids_ref_embeds = self.get_position_ids(ref_embeds_grid)
position_ids = torch.cat((position_ids, position_ids_ref_embeds), dim=-1)
hidden_states = torch.cat((embeds, ref_embeds), dim=1)
position_embeddings = self.rotary_emb(hidden_states, position_ids)
for layer in self.layers:
hidden_states = layer(hidden_states, position_embeddings)
hidden_states = self.projector(hidden_states)
return hidden_states
@staticmethod
def state_dict_converter():
return NexusGenMergerStateDictConverter()
class NexusGenMergerStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict
def from_civitai(self, state_dict):
merger_state_dict = {key.replace("embedding_merger.", ""): value for key, value in state_dict.items() if key.startswith('embedding_merger.')}
return merger_state_dict
class NexusGenAdapter(nn.Module):
"""
Adapter for Nexus-Gen generation decoder.
"""
def __init__(self, input_dim=3584, output_dim=4096):
super(NexusGenAdapter, self).__init__()
self.adapter = nn.Sequential(nn.Linear(input_dim, output_dim),
nn.LayerNorm(output_dim), nn.ReLU(),
nn.Linear(output_dim, output_dim),
nn.LayerNorm(output_dim))
def forward(self, x):
return self.adapter(x)
@staticmethod
def state_dict_converter():
return NexusGenAdapterStateDictConverter()
class NexusGenAdapterStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict
def from_civitai(self, state_dict):
adapter_state_dict = {key: value for key, value in state_dict.items() if key.startswith('adapter.')}
return adapter_state_dict

View File

@@ -1,56 +0,0 @@
import torch
import torch.nn as nn
from .general_modules import RMSNorm
class BlockWiseControlBlock(torch.nn.Module):
# [linear, gelu, linear]
def __init__(self, dim: int = 3072):
super().__init__()
self.x_rms = RMSNorm(dim, eps=1e-6)
self.y_rms = RMSNorm(dim, eps=1e-6)
self.input_proj = nn.Linear(dim, dim)
self.act = nn.GELU()
self.output_proj = nn.Linear(dim, dim)
def forward(self, x, y):
x, y = self.x_rms(x), self.y_rms(y)
x = self.input_proj(x + y)
x = self.act(x)
x = self.output_proj(x)
return x
def init_weights(self):
# zero initialize output_proj
nn.init.zeros_(self.output_proj.weight)
nn.init.zeros_(self.output_proj.bias)
class QwenImageBlockWiseControlNet(torch.nn.Module):
def __init__(
self,
num_layers: int = 60,
in_dim: int = 64,
additional_in_dim: int = 0,
dim: int = 3072,
):
super().__init__()
self.img_in = nn.Linear(in_dim + additional_in_dim, dim)
self.controlnet_blocks = nn.ModuleList(
[
BlockWiseControlBlock(dim)
for _ in range(num_layers)
]
)
def init_weight(self):
nn.init.zeros_(self.img_in.weight)
nn.init.zeros_(self.img_in.bias)
for block in self.controlnet_blocks:
block.init_weights()
def process_controlnet_conditioning(self, controlnet_conditioning):
return self.img_in(controlnet_conditioning)
def blockwise_forward(self, img, controlnet_conditioning, block_id):
return self.controlnet_blocks[block_id](img, controlnet_conditioning)

View File

@@ -1,685 +0,0 @@
import torch, math, functools
import torch.nn as nn
from typing import Tuple, Optional, Union, List
from einops import rearrange
from .general_modules import TimestepEmbeddings, RMSNorm, AdaLayerNorm
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_3_AVAILABLE = False
def qwen_image_flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, attention_mask = None, enable_fp8_attention: bool = False):
if FLASH_ATTN_3_AVAILABLE and attention_mask is None:
if not enable_fp8_attention:
q = rearrange(q, "b n s d -> b s n d", n=num_heads)
k = rearrange(k, "b n s d -> b s n d", n=num_heads)
v = rearrange(v, "b n s d -> b s n d", n=num_heads)
x = flash_attn_interface.flash_attn_func(q, k, v)
if isinstance(x, tuple):
x = x[0]
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
else:
origin_dtype = q.dtype
q_std, k_std, v_std = q.std(), k.std(), v.std()
q, k, v = (q / q_std).to(torch.float8_e4m3fn), (k / k_std).to(torch.float8_e4m3fn), (v / v_std).to(torch.float8_e4m3fn)
q = rearrange(q, "b n s d -> b s n d", n=num_heads)
k = rearrange(k, "b n s d -> b s n d", n=num_heads)
v = rearrange(v, "b n s d -> b s n d", n=num_heads)
x = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=q_std * k_std / math.sqrt(q.size(-1)))
if isinstance(x, tuple):
x = x[0]
x = x.to(origin_dtype) * v_std
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
else:
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
return x
class ApproximateGELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
return x * torch.sigmoid(1.702 * x)
def apply_rotary_emb_qwen(
x: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]
):
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
return x_out.type_as(x)
class QwenEmbedRope(nn.Module):
def __init__(self, theta: int, axes_dim: list[int], scale_rope=False):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
pos_index = torch.arange(4096)
neg_index = torch.arange(4096).flip(0) * -1 - 1
self.pos_freqs = torch.cat([
self.rope_params(pos_index, self.axes_dim[0], self.theta),
self.rope_params(pos_index, self.axes_dim[1], self.theta),
self.rope_params(pos_index, self.axes_dim[2], self.theta),
], dim=1)
self.neg_freqs = torch.cat([
self.rope_params(neg_index, self.axes_dim[0], self.theta),
self.rope_params(neg_index, self.axes_dim[1], self.theta),
self.rope_params(neg_index, self.axes_dim[2], self.theta),
], dim=1)
self.rope_cache = {}
self.scale_rope = scale_rope
def rope_params(self, index, dim, theta=10000):
"""
Args:
index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
"""
assert dim % 2 == 0
freqs = torch.outer(
index,
1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))
)
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
def _expand_pos_freqs_if_needed(self, video_fhw, txt_seq_lens):
if isinstance(video_fhw, list):
video_fhw = tuple(max([i[j] for i in video_fhw]) for j in range(3))
_, height, width = video_fhw
if self.scale_rope:
max_vid_index = max(height // 2, width // 2)
else:
max_vid_index = max(height, width)
required_len = max_vid_index + max(txt_seq_lens)
cur_max_len = self.pos_freqs.shape[0]
if required_len <= cur_max_len:
return
new_max_len = math.ceil(required_len / 512) * 512
pos_index = torch.arange(new_max_len)
neg_index = torch.arange(new_max_len).flip(0) * -1 - 1
self.pos_freqs = torch.cat([
self.rope_params(pos_index, self.axes_dim[0], self.theta),
self.rope_params(pos_index, self.axes_dim[1], self.theta),
self.rope_params(pos_index, self.axes_dim[2], self.theta),
], dim=1)
self.neg_freqs = torch.cat([
self.rope_params(neg_index, self.axes_dim[0], self.theta),
self.rope_params(neg_index, self.axes_dim[1], self.theta),
self.rope_params(neg_index, self.axes_dim[2], self.theta),
], dim=1)
return
def forward(self, video_fhw, txt_seq_lens, device):
self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens)
if self.pos_freqs.device != device:
self.pos_freqs = self.pos_freqs.to(device)
self.neg_freqs = self.neg_freqs.to(device)
vid_freqs = []
max_vid_index = 0
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
rope_key = f"{idx}_{height}_{width}"
if rope_key not in self.rope_cache:
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat(
[freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
)
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
self.rope_cache[rope_key] = freqs.clone().contiguous()
vid_freqs.append(self.rope_cache[rope_key])
if self.scale_rope:
max_vid_index = max(height // 2, width // 2, max_vid_index)
else:
max_vid_index = max(height, width, max_vid_index)
max_len = max(txt_seq_lens)
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
vid_freqs = torch.cat(vid_freqs, dim=0)
return vid_freqs, txt_freqs
def forward_sampling(self, video_fhw, txt_seq_lens, device):
self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens)
if self.pos_freqs.device != device:
self.pos_freqs = self.pos_freqs.to(device)
self.neg_freqs = self.neg_freqs.to(device)
vid_freqs = []
max_vid_index = 0
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
rope_key = f"{idx}_{height}_{width}"
if idx > 0 and f"{0}_{height}_{width}" not in self.rope_cache:
frame_0, height_0, width_0 = video_fhw[0]
rope_key_0 = f"0_{height_0}_{width_0}"
spatial_freqs_0 = self.rope_cache[rope_key_0].reshape(frame_0, height_0, width_0, -1)
h_indices = torch.linspace(0, height_0 - 1, height).long()
w_indices = torch.linspace(0, width_0 - 1, width).long()
h_grid, w_grid = torch.meshgrid(h_indices, w_indices, indexing='ij')
sampled_rope = spatial_freqs_0[:, h_grid, w_grid, :]
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
sampled_rope[:, :, :, :freqs_frame.shape[-1]] = freqs_frame
seq_lens = frame * height * width
self.rope_cache[rope_key] = sampled_rope.reshape(seq_lens, -1).clone()
if rope_key not in self.rope_cache:
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat(
[freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
)
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
self.rope_cache[rope_key] = freqs.clone()
vid_freqs.append(self.rope_cache[rope_key].contiguous())
if self.scale_rope:
max_vid_index = max(height // 2, width // 2, max_vid_index)
else:
max_vid_index = max(height, width, max_vid_index)
max_len = max(txt_seq_lens)
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
vid_freqs = torch.cat(vid_freqs, dim=0)
return vid_freqs, txt_freqs
class QwenEmbedLayer3DRope(nn.Module):
def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
pos_index = torch.arange(4096)
neg_index = torch.arange(4096).flip(0) * -1 - 1
self.pos_freqs = torch.cat(
[
self.rope_params(pos_index, self.axes_dim[0], self.theta),
self.rope_params(pos_index, self.axes_dim[1], self.theta),
self.rope_params(pos_index, self.axes_dim[2], self.theta),
],
dim=1,
)
self.neg_freqs = torch.cat(
[
self.rope_params(neg_index, self.axes_dim[0], self.theta),
self.rope_params(neg_index, self.axes_dim[1], self.theta),
self.rope_params(neg_index, self.axes_dim[2], self.theta),
],
dim=1,
)
self.scale_rope = scale_rope
def rope_params(self, index, dim, theta=10000):
"""
Args:
index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
"""
assert dim % 2 == 0
freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
def forward(self, video_fhw, txt_seq_lens, device):
"""
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
txt_length: [bs] a list of 1 integers representing the length of the text
"""
if self.pos_freqs.device != device:
self.pos_freqs = self.pos_freqs.to(device)
self.neg_freqs = self.neg_freqs.to(device)
video_fhw = [video_fhw]
if isinstance(video_fhw, list):
video_fhw = video_fhw[0]
if not isinstance(video_fhw, list):
video_fhw = [video_fhw]
vid_freqs = []
max_vid_index = 0
layer_num = len(video_fhw) - 1
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
if idx != layer_num:
video_freq = self._compute_video_freqs(frame, height, width, idx)
else:
### For the condition image, we set the layer index to -1
video_freq = self._compute_condition_freqs(frame, height, width)
video_freq = video_freq.to(device)
vid_freqs.append(video_freq)
if self.scale_rope:
max_vid_index = max(height // 2, width // 2, max_vid_index)
else:
max_vid_index = max(height, width, max_vid_index)
max_vid_index = max(max_vid_index, layer_num)
max_len = max(txt_seq_lens)
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
vid_freqs = torch.cat(vid_freqs, dim=0)
return vid_freqs, txt_freqs
@functools.lru_cache(maxsize=None)
def _compute_video_freqs(self, frame, height, width, idx=0):
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
return freqs.clone().contiguous()
@functools.lru_cache(maxsize=None)
def _compute_condition_freqs(self, frame, height, width):
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
return freqs.clone().contiguous()
class QwenFeedForward(nn.Module):
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
dropout: float = 0.0,
):
super().__init__()
inner_dim = int(dim * 4)
self.net = nn.ModuleList([])
self.net.append(ApproximateGELU(dim, inner_dim))
self.net.append(nn.Dropout(dropout))
self.net.append(nn.Linear(inner_dim, dim_out))
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
class QwenDoubleStreamAttention(nn.Module):
def __init__(
self,
dim_a,
dim_b,
num_heads,
head_dim,
):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.to_q = nn.Linear(dim_a, dim_a)
self.to_k = nn.Linear(dim_a, dim_a)
self.to_v = nn.Linear(dim_a, dim_a)
self.norm_q = RMSNorm(head_dim, eps=1e-6)
self.norm_k = RMSNorm(head_dim, eps=1e-6)
self.add_q_proj = nn.Linear(dim_b, dim_b)
self.add_k_proj = nn.Linear(dim_b, dim_b)
self.add_v_proj = nn.Linear(dim_b, dim_b)
self.norm_added_q = RMSNorm(head_dim, eps=1e-6)
self.norm_added_k = RMSNorm(head_dim, eps=1e-6)
self.to_out = torch.nn.Sequential(nn.Linear(dim_a, dim_a))
self.to_add_out = nn.Linear(dim_b, dim_b)
def forward(
self,
image: torch.FloatTensor,
text: torch.FloatTensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
enable_fp8_attention: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)
txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text)
seq_txt = txt_q.shape[1]
img_q = rearrange(img_q, 'b s (h d) -> b h s d', h=self.num_heads)
img_k = rearrange(img_k, 'b s (h d) -> b h s d', h=self.num_heads)
img_v = rearrange(img_v, 'b s (h d) -> b h s d', h=self.num_heads)
txt_q = rearrange(txt_q, 'b s (h d) -> b h s d', h=self.num_heads)
txt_k = rearrange(txt_k, 'b s (h d) -> b h s d', h=self.num_heads)
txt_v = rearrange(txt_v, 'b s (h d) -> b h s d', h=self.num_heads)
img_q, img_k = self.norm_q(img_q), self.norm_k(img_k)
txt_q, txt_k = self.norm_added_q(txt_q), self.norm_added_k(txt_k)
if image_rotary_emb is not None:
img_freqs, txt_freqs = image_rotary_emb
img_q = apply_rotary_emb_qwen(img_q, img_freqs)
img_k = apply_rotary_emb_qwen(img_k, img_freqs)
txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs)
txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs)
joint_q = torch.cat([txt_q, img_q], dim=2)
joint_k = torch.cat([txt_k, img_k], dim=2)
joint_v = torch.cat([txt_v, img_v], dim=2)
joint_attn_out = qwen_image_flash_attention(joint_q, joint_k, joint_v, num_heads=joint_q.shape[1], attention_mask=attention_mask, enable_fp8_attention=enable_fp8_attention).to(joint_q.dtype)
txt_attn_output = joint_attn_out[:, :seq_txt, :]
img_attn_output = joint_attn_out[:, seq_txt:, :]
img_attn_output = self.to_out(img_attn_output)
txt_attn_output = self.to_add_out(txt_attn_output)
return img_attn_output, txt_attn_output
class QwenImageTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
eps: float = 1e-6,
):
super().__init__()
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.img_mod = nn.Sequential(
nn.SiLU(),
nn.Linear(dim, 6 * dim),
)
self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
self.attn = QwenDoubleStreamAttention(
dim_a=dim,
dim_b=dim,
num_heads=num_attention_heads,
head_dim=attention_head_dim,
)
self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
self.img_mlp = QwenFeedForward(dim=dim, dim_out=dim)
self.txt_mod = nn.Sequential(
nn.SiLU(),
nn.Linear(dim, 6 * dim, bias=True),
)
self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
self.txt_mlp = QwenFeedForward(dim=dim, dim_out=dim)
def _modulate(self, x, mod_params, index=None):
shift, scale, gate = mod_params.chunk(3, dim=-1)
if index is not None:
# Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts)
# So shift, scale, gate have shape [2*actual_batch, d]
actual_batch = shift.size(0) // 2
shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:] # each: [actual_batch, d]
scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:]
gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:]
# index: [b, l] where b is actual batch size
# Expand to [b, l, 1] to match feature dimension
index_expanded = index.unsqueeze(-1) # [b, l, 1]
# Expand chunks to [b, 1, d] then broadcast to [b, l, d]
shift_0_exp = shift_0.unsqueeze(1) # [b, 1, d]
shift_1_exp = shift_1.unsqueeze(1) # [b, 1, d]
scale_0_exp = scale_0.unsqueeze(1)
scale_1_exp = scale_1.unsqueeze(1)
gate_0_exp = gate_0.unsqueeze(1)
gate_1_exp = gate_1.unsqueeze(1)
# Use torch.where to select based on index
shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp)
scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp)
gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp)
else:
shift_result = shift.unsqueeze(1)
scale_result = scale.unsqueeze(1)
gate_result = gate.unsqueeze(1)
return x * (1 + scale_result) + shift_result, gate_result
def forward(
self,
image: torch.Tensor,
text: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
enable_fp8_attention = False,
modulate_index: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
if modulate_index is not None:
temb = torch.chunk(temb, 2, dim=0)[0]
txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
img_normed = self.img_norm1(image)
img_modulated, img_gate = self._modulate(img_normed, img_mod_attn, index=modulate_index)
txt_normed = self.txt_norm1(text)
txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn)
img_attn_out, txt_attn_out = self.attn(
image=img_modulated,
text=txt_modulated,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
enable_fp8_attention=enable_fp8_attention,
)
image = image + img_gate * img_attn_out
text = text + txt_gate * txt_attn_out
img_normed_2 = self.img_norm2(image)
img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp, index=modulate_index)
txt_normed_2 = self.txt_norm2(text)
txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp)
img_mlp_out = self.img_mlp(img_modulated_2)
txt_mlp_out = self.txt_mlp(txt_modulated_2)
image = image + img_gate_2 * img_mlp_out
text = text + txt_gate_2 * txt_mlp_out
return text, image
class QwenImageDiT(torch.nn.Module):
def __init__(
self,
num_layers: int = 60,
use_layer3d_rope: bool = False,
use_additional_t_cond: bool = False,
):
super().__init__()
if not use_layer3d_rope:
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
else:
self.pos_embed = QwenEmbedLayer3DRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=False, use_additional_t_cond=use_additional_t_cond)
self.txt_norm = RMSNorm(3584, eps=1e-6)
self.img_in = nn.Linear(64, 3072)
self.txt_in = nn.Linear(3584, 3072)
self.transformer_blocks = nn.ModuleList(
[
QwenImageTransformerBlock(
dim=3072,
num_attention_heads=24,
attention_head_dim=128,
)
for _ in range(num_layers)
]
)
self.norm_out = AdaLayerNorm(3072, single=True)
self.proj_out = nn.Linear(3072, 64)
def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask, entity_masks, height, width, image, img_shapes):
# prompt_emb
all_prompt_emb = entity_prompt_emb + [prompt_emb]
all_prompt_emb = [self.txt_in(self.txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb]
all_prompt_emb = torch.cat(all_prompt_emb, dim=1)
# image_rotary_emb
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
entity_seq_lens = [emb_mask.sum(dim=1).tolist() for emb_mask in entity_prompt_emb_mask]
entity_rotary_emb = [self.pos_embed(img_shapes, entity_seq_len, device=latents.device)[1] for entity_seq_len in entity_seq_lens]
txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0)
image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb)
# attention_mask
repeat_dim = latents.shape[1]
max_masks = entity_masks.shape[1]
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
global_mask = torch.ones_like(entity_masks[0]).to(device=latents.device, dtype=latents.dtype)
entity_masks = entity_masks + [global_mask]
N = len(entity_masks)
batch_size = entity_masks[0].shape[0]
seq_lens = [mask_.sum(dim=1).item() for mask_ in entity_prompt_emb_mask] + [prompt_emb_mask.sum(dim=1).item()]
total_seq_len = sum(seq_lens) + image.shape[1]
patched_masks = []
for i in range(N):
patched_mask = rearrange(entity_masks[i], "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
patched_masks.append(patched_mask)
attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)
# prompt-image attention mask
image_start = sum(seq_lens)
image_end = total_seq_len
cumsum = [0]
single_image_seq = image_end - image_start
for length in seq_lens:
cumsum.append(cumsum[-1] + length)
for i in range(N):
prompt_start = cumsum[i]
prompt_end = cumsum[i+1]
image_mask = torch.sum(patched_masks[i], dim=-1) > 0
image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1)
# repeat image mask to match the single image sequence length
repeat_time = single_image_seq // image_mask.shape[-1]
image_mask = image_mask.repeat(1, 1, repeat_time)
# prompt update with image
attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
# image update with prompt
attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
# prompt-prompt attention mask, let the prompt tokens not attend to each other
for i in range(N):
for j in range(N):
if i == j:
continue
start_i, end_i = cumsum[i], cumsum[i+1]
start_j, end_j = cumsum[j], cumsum[j+1]
attention_mask[:, start_i:end_i, start_j:end_j] = False
attention_mask = attention_mask.float()
attention_mask[attention_mask == 0] = float('-inf')
attention_mask[attention_mask == 1] = 0
attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1)
return all_prompt_emb, image_rotary_emb, attention_mask
def forward(
self,
latents=None,
timestep=None,
prompt_emb=None,
prompt_emb_mask=None,
height=None,
width=None,
):
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
image = self.img_in(image)
text = self.txt_in(self.txt_norm(prompt_emb))
conditioning = self.time_text_embed(timestep, image.dtype)
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
for block in self.transformer_blocks:
text, image = block(
image=image,
text=text,
temb=conditioning,
image_rotary_emb=image_rotary_emb,
)
image = self.norm_out(image, conditioning)
image = self.proj_out(image)
latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
return image

View File

@@ -1,128 +0,0 @@
import torch
class CompressedMLP(torch.nn.Module):
def __init__(self, in_dim, mid_dim, out_dim, bias=False):
super().__init__()
self.proj_in = torch.nn.Linear(in_dim, mid_dim, bias=bias)
self.proj_out = torch.nn.Linear(mid_dim, out_dim, bias=bias)
def forward(self, x, residual=None):
x = self.proj_in(x)
if residual is not None: x = x + residual
x = self.proj_out(x)
return x
class ImageEmbeddingToLoraMatrix(torch.nn.Module):
def __init__(self, in_dim, compress_dim, lora_a_dim, lora_b_dim, rank):
super().__init__()
self.proj_a = CompressedMLP(in_dim, compress_dim, lora_a_dim * rank)
self.proj_b = CompressedMLP(in_dim, compress_dim, lora_b_dim * rank)
self.lora_a_dim = lora_a_dim
self.lora_b_dim = lora_b_dim
self.rank = rank
def forward(self, x, residual=None):
lora_a = self.proj_a(x, residual).view(self.rank, self.lora_a_dim)
lora_b = self.proj_b(x, residual).view(self.lora_b_dim, self.rank)
return lora_a, lora_b
class SequencialMLP(torch.nn.Module):
def __init__(self, length, in_dim, mid_dim, out_dim, bias=False):
super().__init__()
self.proj_in = torch.nn.Linear(in_dim, mid_dim, bias=bias)
self.proj_out = torch.nn.Linear(length * mid_dim, out_dim, bias=bias)
self.length = length
self.in_dim = in_dim
self.mid_dim = mid_dim
def forward(self, x):
x = x.view(self.length, self.in_dim)
x = self.proj_in(x)
x = x.view(1, self.length * self.mid_dim)
x = self.proj_out(x)
return x
class LoRATrainerBlock(torch.nn.Module):
def __init__(self, lora_patterns, in_dim=1536+4096, compress_dim=128, rank=4, block_id=0, use_residual=True, residual_length=64+7, residual_dim=3584, residual_mid_dim=1024):
super().__init__()
self.lora_patterns = lora_patterns
self.block_id = block_id
self.layers = []
for name, lora_a_dim, lora_b_dim in self.lora_patterns:
self.layers.append(ImageEmbeddingToLoraMatrix(in_dim, compress_dim, lora_a_dim, lora_b_dim, rank))
self.layers = torch.nn.ModuleList(self.layers)
if use_residual:
self.proj_residual = SequencialMLP(residual_length, residual_dim, residual_mid_dim, compress_dim)
else:
self.proj_residual = None
def forward(self, x, residual=None):
lora = {}
if self.proj_residual is not None: residual = self.proj_residual(residual)
for lora_pattern, layer in zip(self.lora_patterns, self.layers):
name = lora_pattern[0]
lora_a, lora_b = layer(x, residual=residual)
lora[f"transformer_blocks.{self.block_id}.{name}.lora_A.default.weight"] = lora_a
lora[f"transformer_blocks.{self.block_id}.{name}.lora_B.default.weight"] = lora_b
return lora
class QwenImageImage2LoRAModel(torch.nn.Module):
def __init__(self, num_blocks=60, use_residual=True, compress_dim=128, rank=4, residual_length=64+7, residual_mid_dim=1024):
super().__init__()
self.lora_patterns = [
[
("attn.to_q", 3072, 3072),
("attn.to_k", 3072, 3072),
("attn.to_v", 3072, 3072),
("attn.to_out.0", 3072, 3072),
],
[
("img_mlp.net.2", 3072*4, 3072),
("img_mod.1", 3072, 3072*6),
],
[
("attn.add_q_proj", 3072, 3072),
("attn.add_k_proj", 3072, 3072),
("attn.add_v_proj", 3072, 3072),
("attn.to_add_out", 3072, 3072),
],
[
("txt_mlp.net.2", 3072*4, 3072),
("txt_mod.1", 3072, 3072*6),
],
]
self.num_blocks = num_blocks
self.blocks = []
for lora_patterns in self.lora_patterns:
for block_id in range(self.num_blocks):
self.blocks.append(LoRATrainerBlock(lora_patterns, block_id=block_id, use_residual=use_residual, compress_dim=compress_dim, rank=rank, residual_length=residual_length, residual_mid_dim=residual_mid_dim))
self.blocks = torch.nn.ModuleList(self.blocks)
self.residual_scale = 0.05
self.use_residual = use_residual
def forward(self, x, residual=None):
if residual is not None:
if self.use_residual:
residual = residual * self.residual_scale
else:
residual = None
lora = {}
for block in self.blocks:
lora.update(block(x, residual))
return lora
def initialize_weights(self):
state_dict = self.state_dict()
for name in state_dict:
if ".proj_a." in name:
state_dict[name] = state_dict[name] * 0.3
elif ".proj_b.proj_out." in name:
state_dict[name] = state_dict[name] * 0
elif ".proj_residual.proj_out." in name:
state_dict[name] = state_dict[name] * 0.3
self.load_state_dict(state_dict)

View File

@@ -1,190 +0,0 @@
import torch
from typing import Optional, Union
class QwenImageTextEncoder(torch.nn.Module):
def __init__(self):
super().__init__()
from transformers import Qwen2_5_VLConfig, Qwen2_5_VLModel
config = Qwen2_5_VLConfig(**{
"architectures": [
"Qwen2_5_VLForConditionalGeneration"
],
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151645,
"hidden_act": "silu",
"hidden_size": 3584,
"image_token_id": 151655,
"initializer_range": 0.02,
"intermediate_size": 18944,
"max_position_embeddings": 128000,
"max_window_layers": 28,
"model_type": "qwen2_5_vl",
"num_attention_heads": 28,
"num_hidden_layers": 28,
"num_key_value_heads": 4,
"rms_norm_eps": 1e-06,
"rope_scaling": {
"mrope_section": [
16,
24,
24
],
"rope_type": "default",
"type": "default"
},
"rope_theta": 1000000.0,
"sliding_window": 32768,
"text_config": {
"architectures": [
"Qwen2_5_VLForConditionalGeneration"
],
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151645,
"hidden_act": "silu",
"hidden_size": 3584,
"image_token_id": None,
"initializer_range": 0.02,
"intermediate_size": 18944,
"layer_types": [
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention"
],
"max_position_embeddings": 128000,
"max_window_layers": 28,
"model_type": "qwen2_5_vl_text",
"num_attention_heads": 28,
"num_hidden_layers": 28,
"num_key_value_heads": 4,
"rms_norm_eps": 1e-06,
"rope_scaling": {
"mrope_section": [
16,
24,
24
],
"rope_type": "default",
"type": "default"
},
"rope_theta": 1000000.0,
"sliding_window": None,
"torch_dtype": "float32",
"use_cache": True,
"use_sliding_window": False,
"video_token_id": None,
"vision_end_token_id": 151653,
"vision_start_token_id": 151652,
"vision_token_id": 151654,
"vocab_size": 152064
},
"tie_word_embeddings": False,
"torch_dtype": "float32",
"transformers_version": "4.54.0",
"use_cache": True,
"use_sliding_window": False,
"video_token_id": 151656,
"vision_config": {
"depth": 32,
"fullatt_block_indexes": [
7,
15,
23,
31
],
"hidden_act": "silu",
"hidden_size": 1280,
"in_channels": 3,
"in_chans": 3,
"initializer_range": 0.02,
"intermediate_size": 3420,
"model_type": "qwen2_5_vl",
"num_heads": 16,
"out_hidden_size": 3584,
"patch_size": 14,
"spatial_merge_size": 2,
"spatial_patch_size": 14,
"temporal_patch_size": 2,
"tokens_per_second": 2,
"torch_dtype": "float32",
"window_size": 112
},
"vision_end_token_id": 151653,
"vision_start_token_id": 151652,
"vision_token_id": 151654,
"vocab_size": 152064
})
self.model = Qwen2_5_VLModel(config)
self.lm_head = torch.nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.config = config
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
):
output_attentions = False
output_hidden_states = True
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**kwargs,
)
return outputs.hidden_states

View File

@@ -1,726 +0,0 @@
import torch
from typing import List, Optional, Tuple, Union
from torch import nn
CACHE_T = 2
class QwenImageCausalConv3d(torch.nn.Conv3d):
r"""
A custom 3D causal convolution layer with feature caching support.
This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
caching for efficient inference.
Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the convolution
kernel_size (int or tuple): Size of the convolving kernel
stride (int or tuple, optional): Stride of the convolution. Default: 1
padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
) -> None:
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
)
# Set up causal padding
self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
x = torch.nn.functional.pad(x, padding)
return super().forward(x)
class QwenImageRMS_norm(nn.Module):
r"""
A custom RMS normalization layer.
Args:
dim (int): The number of dimensions to normalize over.
channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
Default is True.
images (bool, optional): Whether the input represents image data. Default is True.
bias (bool, optional): Whether to include a learnable bias term. Default is False.
"""
def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return torch.nn.functional.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
class QwenImageResidualBlock(nn.Module):
r"""
A custom residual block module.
Args:
in_dim (int): Number of input channels.
out_dim (int): Number of output channels.
dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
"""
def __init__(
self,
in_dim: int,
out_dim: int,
dropout: float = 0.0,
non_linearity: str = "silu",
) -> None:
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.nonlinearity = torch.nn.SiLU()
# layers
self.norm1 = QwenImageRMS_norm(in_dim, images=False)
self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1)
self.norm2 = QwenImageRMS_norm(out_dim, images=False)
self.dropout = nn.Dropout(dropout)
self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1)
self.conv_shortcut = QwenImageCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
# Apply shortcut connection
h = self.conv_shortcut(x)
# First normalization and activation
x = self.norm1(x)
x = self.nonlinearity(x)
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
# Second normalization and activation
x = self.norm2(x)
x = self.nonlinearity(x)
# Dropout
x = self.dropout(x)
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv2(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv2(x)
# Add residual connection
return x + h
class QwenImageAttentionBlock(nn.Module):
r"""
Causal self-attention with a single head.
Args:
dim (int): The number of channels in the input tensor.
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
# layers
self.norm = QwenImageRMS_norm(dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
self.proj = nn.Conv2d(dim, dim, 1)
def forward(self, x):
identity = x
batch_size, channels, time, height, width = x.size()
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
x = self.norm(x)
# compute query, key, value
qkv = self.to_qkv(x)
qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
qkv = qkv.permute(0, 1, 3, 2).contiguous()
q, k, v = qkv.chunk(3, dim=-1)
# apply attention
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)
# output projection
x = self.proj(x)
# Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
x = x.view(batch_size, time, channels, height, width)
x = x.permute(0, 2, 1, 3, 4)
return x + identity
class QwenImageUpsample(nn.Upsample):
r"""
Perform upsampling while ensuring the output tensor has the same data type as the input.
Args:
x (torch.Tensor): Input tensor to be upsampled.
Returns:
torch.Tensor: Upsampled tensor with the same data type as the input.
"""
def forward(self, x):
return super().forward(x.float()).type_as(x)
class QwenImageResample(nn.Module):
r"""
A custom resampling module for 2D and 3D data.
Args:
dim (int): The number of input/output channels.
mode (str): The resampling mode. Must be one of:
- 'none': No resampling (identity operation).
- 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
- 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
- 'downsample2d': 2D downsampling with zero-padding and convolution.
- 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
"""
def __init__(self, dim: int, mode: str) -> None:
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == "upsample2d":
self.resample = nn.Sequential(
QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
)
elif mode == "upsample3d":
self.resample = nn.Sequential(
QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
)
self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == "downsample2d":
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == "downsample3d":
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size()
if self.mode == "upsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = "Rep"
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
# cache last frame of last two chunk
cache_x = torch.cat(
[feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
)
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
if feat_cache[idx] == "Rep":
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2]
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
x = self.resample(x)
x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
if self.mode == "downsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
class QwenImageMidBlock(nn.Module):
"""
Middle block for WanVAE encoder and decoder.
Args:
dim (int): Number of input/output channels.
dropout (float): Dropout rate.
non_linearity (str): Type of non-linearity to use.
"""
def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1):
super().__init__()
self.dim = dim
# Create the components
resnets = [QwenImageResidualBlock(dim, dim, dropout, non_linearity)]
attentions = []
for _ in range(num_layers):
attentions.append(QwenImageAttentionBlock(dim))
resnets.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity))
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0]):
# First residual block
x = self.resnets[0](x, feat_cache, feat_idx)
# Process through attention and residual blocks
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
x = attn(x)
x = resnet(x, feat_cache, feat_idx)
return x
class QwenImageEncoder3d(nn.Module):
r"""
A 3D encoder module.
Args:
dim (int): The base number of channels in the first layer.
z_dim (int): The dimensionality of the latent space.
dim_mult (list of int): Multipliers for the number of channels in each block.
num_res_blocks (int): Number of residual blocks in each block.
attn_scales (list of float): Scales at which to apply attention mechanisms.
temperal_downsample (list of bool): Whether to downsample temporally in each block.
dropout (float): Dropout rate for the dropout layers.
non_linearity (str): Type of non-linearity to use.
"""
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0,
non_linearity: str = "silu",
image_channels=3
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
self.nonlinearity = torch.nn.SiLU()
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# init block
self.conv_in = QwenImageCausalConv3d(image_channels, dims[0], 3, padding=1)
# downsample blocks
self.down_blocks = torch.nn.ModuleList([])
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
for _ in range(num_res_blocks):
self.down_blocks.append(QwenImageResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
self.down_blocks.append(QwenImageAttentionBlock(out_dim))
in_dim = out_dim
# downsample block
if i != len(dim_mult) - 1:
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
self.down_blocks.append(QwenImageResample(out_dim, mode=mode))
scale /= 2.0
# middle blocks
self.mid_block = QwenImageMidBlock(out_dim, dropout, non_linearity, num_layers=1)
# output blocks
self.norm_out = QwenImageRMS_norm(out_dim, images=False)
self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1)
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv_in(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv_in(x)
## downsamples
for layer in self.down_blocks:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
x = self.mid_block(x, feat_cache, feat_idx)
## head
x = self.norm_out(x)
x = self.nonlinearity(x)
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv_out(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv_out(x)
return x
class QwenImageUpBlock(nn.Module):
"""
A block that handles upsampling for the WanVAE decoder.
Args:
in_dim (int): Input dimension
out_dim (int): Output dimension
num_res_blocks (int): Number of residual blocks
dropout (float): Dropout rate
upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
non_linearity (str): Type of non-linearity to use
"""
def __init__(
self,
in_dim: int,
out_dim: int,
num_res_blocks: int,
dropout: float = 0.0,
upsample_mode: Optional[str] = None,
non_linearity: str = "silu",
):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# Create layers list
resnets = []
# Add residual blocks and attention if needed
current_dim = in_dim
for _ in range(num_res_blocks + 1):
resnets.append(QwenImageResidualBlock(current_dim, out_dim, dropout, non_linearity))
current_dim = out_dim
self.resnets = nn.ModuleList(resnets)
# Add upsampling layer if needed
self.upsamplers = None
if upsample_mode is not None:
self.upsamplers = nn.ModuleList([QwenImageResample(out_dim, mode=upsample_mode)])
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0]):
"""
Forward pass through the upsampling block.
Args:
x (torch.Tensor): Input tensor
feat_cache (list, optional): Feature cache for causal convolutions
feat_idx (list, optional): Feature index for cache management
Returns:
torch.Tensor: Output tensor
"""
for resnet in self.resnets:
if feat_cache is not None:
x = resnet(x, feat_cache, feat_idx)
else:
x = resnet(x)
if self.upsamplers is not None:
if feat_cache is not None:
x = self.upsamplers[0](x, feat_cache, feat_idx)
else:
x = self.upsamplers[0](x)
return x
class QwenImageDecoder3d(nn.Module):
r"""
A 3D decoder module.
Args:
dim (int): The base number of channels in the first layer.
z_dim (int): The dimensionality of the latent space.
dim_mult (list of int): Multipliers for the number of channels in each block.
num_res_blocks (int): Number of residual blocks in each block.
attn_scales (list of float): Scales at which to apply attention mechanisms.
temperal_upsample (list of bool): Whether to upsample temporally in each block.
dropout (float): Dropout rate for the dropout layers.
non_linearity (str): Type of non-linearity to use.
"""
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0,
non_linearity: str = "silu",
image_channels=3,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
self.nonlinearity = torch.nn.SiLU()
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2 ** (len(dim_mult) - 2)
# init block
self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.mid_block = QwenImageMidBlock(dims[0], dropout, non_linearity, num_layers=1)
# upsample blocks
self.up_blocks = nn.ModuleList([])
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
if i > 0:
in_dim = in_dim // 2
# Determine if we need upsampling
upsample_mode = None
if i != len(dim_mult) - 1:
upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
# Create and add the upsampling block
up_block = QwenImageUpBlock(
in_dim=in_dim,
out_dim=out_dim,
num_res_blocks=num_res_blocks,
dropout=dropout,
upsample_mode=upsample_mode,
non_linearity=non_linearity,
)
self.up_blocks.append(up_block)
# Update scale for next iteration
if upsample_mode is not None:
scale *= 2.0
# output blocks
self.norm_out = QwenImageRMS_norm(out_dim, images=False)
self.conv_out = QwenImageCausalConv3d(out_dim, image_channels, 3, padding=1)
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv_in(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv_in(x)
## middle
x = self.mid_block(x, feat_cache, feat_idx)
## upsamples
for up_block in self.up_blocks:
x = up_block(x, feat_cache, feat_idx)
## head
x = self.norm_out(x)
x = self.nonlinearity(x)
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv_out(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv_out(x)
return x
class QwenImageVAE(torch.nn.Module):
def __init__(
self,
base_dim: int = 96,
z_dim: int = 16,
dim_mult: Tuple[int] = [1, 2, 4, 4],
num_res_blocks: int = 2,
attn_scales: List[float] = [],
temperal_downsample: List[bool] = [False, True, True],
dropout: float = 0.0,
image_channels: int = 3,
) -> None:
super().__init__()
self.z_dim = z_dim
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
self.encoder = QwenImageEncoder3d(
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout, image_channels=image_channels,
)
self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1)
self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1)
self.decoder = QwenImageDecoder3d(
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout, image_channels=image_channels,
)
mean = [
-0.7571,
-0.7089,
-0.9113,
0.1075,
-0.1745,
0.9653,
-0.1517,
1.5508,
0.4134,
-0.0715,
0.5517,
-0.3632,
-0.1922,
-0.9497,
0.2503,
-0.2921,
]
std = [
2.8184,
1.4541,
2.3275,
2.6558,
1.2196,
1.7708,
2.6052,
2.0743,
3.2687,
2.1526,
2.8652,
1.5579,
1.6382,
1.1253,
2.8251,
1.9160,
]
self.mean = torch.tensor(mean).view(1, 16, 1, 1, 1)
self.std = 1 / torch.tensor(std).view(1, 16, 1, 1, 1)
def encode(self, x, **kwargs):
x = x.unsqueeze(2)
x = self.encoder(x)
x = self.quant_conv(x)
x = x[:, :16]
mean, std = self.mean.to(dtype=x.dtype, device=x.device), self.std.to(dtype=x.dtype, device=x.device)
x = (x - mean) * std
x = x.squeeze(2)
return x
def decode(self, x, **kwargs):
x = x.unsqueeze(2)
mean, std = self.mean.to(dtype=x.dtype, device=x.device), self.std.to(dtype=x.dtype, device=x.device)
x = x / std + mean
x = self.post_quant_conv(x)
x = self.decoder(x)
x = x.squeeze(2)
return x

798
diffsynth/models/sd3_dit.py Normal file
View File

@@ -0,0 +1,798 @@
import torch
from einops import rearrange
from .svd_unet import TemporalTimesteps
from .tiler import TileWorker
class PatchEmbed(torch.nn.Module):
def __init__(self, patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192):
super().__init__()
self.pos_embed_max_size = pos_embed_max_size
self.patch_size = patch_size
self.proj = torch.nn.Conv2d(in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size)
self.pos_embed = torch.nn.Parameter(torch.zeros(1, self.pos_embed_max_size, self.pos_embed_max_size, 1536))
def cropped_pos_embed(self, height, width):
height = height // self.patch_size
width = width // self.patch_size
top = (self.pos_embed_max_size - height) // 2
left = (self.pos_embed_max_size - width) // 2
spatial_pos_embed = self.pos_embed[:, top : top + height, left : left + width, :].flatten(1, 2)
return spatial_pos_embed
def forward(self, latent):
height, width = latent.shape[-2:]
latent = self.proj(latent)
latent = latent.flatten(2).transpose(1, 2)
pos_embed = self.cropped_pos_embed(height, width)
return latent + pos_embed
class TimestepEmbeddings(torch.nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = torch.nn.Sequential(
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
)
def forward(self, timestep, dtype):
time_emb = self.time_proj(timestep).to(dtype)
time_emb = self.timestep_embedder(time_emb)
return time_emb
class AdaLayerNorm(torch.nn.Module):
def __init__(self, dim, single=False):
super().__init__()
self.single = single
self.linear = torch.nn.Linear(dim, dim * (2 if single else 6))
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb):
emb = self.linear(torch.nn.functional.silu(emb))
if self.single:
scale, shift = emb.unsqueeze(1).chunk(2, dim=2)
x = self.norm(x) * (1 + scale) + shift
return x
else:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)
x = self.norm(x) * (1 + scale_msa) + shift_msa
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
class JointAttention(torch.nn.Module):
def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.only_out_a = only_out_a
self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)
self.a_to_out = torch.nn.Linear(dim_a, dim_a)
if not only_out_a:
self.b_to_out = torch.nn.Linear(dim_b, dim_b)
def forward(self, hidden_states_a, hidden_states_b):
batch_size = hidden_states_a.shape[0]
qkv = torch.concat([self.a_to_qkv(hidden_states_a), self.b_to_qkv(hidden_states_b)], dim=1)
qkv = qkv.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
q, k, v = qkv.chunk(3, dim=1)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
hidden_states_a, hidden_states_b = hidden_states[:, :hidden_states_a.shape[1]], hidden_states[:, hidden_states_a.shape[1]:]
hidden_states_a = self.a_to_out(hidden_states_a)
if self.only_out_a:
return hidden_states_a
else:
hidden_states_b = self.b_to_out(hidden_states_b)
return hidden_states_a, hidden_states_b
class JointTransformerBlock(torch.nn.Module):
def __init__(self, dim, num_attention_heads):
super().__init__()
self.norm1_a = AdaLayerNorm(dim)
self.norm1_b = AdaLayerNorm(dim)
self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads)
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_a = torch.nn.Sequential(
torch.nn.Linear(dim, dim*4),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(dim*4, dim)
)
self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_b = torch.nn.Sequential(
torch.nn.Linear(dim, dim*4),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(dim*4, dim)
)
def forward(self, hidden_states_a, hidden_states_b, temb):
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
# Attention
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b)
# Part A
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
# Part B
hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
return hidden_states_a, hidden_states_b
class JointTransformerFinalBlock(torch.nn.Module):
def __init__(self, dim, num_attention_heads):
super().__init__()
self.norm1_a = AdaLayerNorm(dim)
self.norm1_b = AdaLayerNorm(dim, single=True)
self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, only_out_a=True)
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_a = torch.nn.Sequential(
torch.nn.Linear(dim, dim*4),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(dim*4, dim)
)
def forward(self, hidden_states_a, hidden_states_b, temb):
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
norm_hidden_states_b = self.norm1_b(hidden_states_b, emb=temb)
# Attention
attn_output_a = self.attn(norm_hidden_states_a, norm_hidden_states_b)
# Part A
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
return hidden_states_a, hidden_states_b
class SD3DiT(torch.nn.Module):
def __init__(self):
super().__init__()
self.pos_embedder = PatchEmbed(patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192)
self.time_embedder = TimestepEmbeddings(256, 1536)
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(2048, 1536), torch.nn.SiLU(), torch.nn.Linear(1536, 1536))
self.context_embedder = torch.nn.Linear(4096, 1536)
self.blocks = torch.nn.ModuleList([JointTransformerBlock(1536, 24) for _ in range(23)] + [JointTransformerFinalBlock(1536, 24)])
self.norm_out = AdaLayerNorm(1536, single=True)
self.proj_out = torch.nn.Linear(1536, 64)
def tiled_forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size=128, tile_stride=64):
# Due to the global positional embedding, we cannot implement layer-wise tiled forward.
hidden_states = TileWorker().tiled_forward(
lambda x: self.forward(x, timestep, prompt_emb, pooled_prompt_emb),
hidden_states,
tile_size,
tile_stride,
tile_device=hidden_states.device,
tile_dtype=hidden_states.dtype
)
return hidden_states
def forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tiled=False, tile_size=128, tile_stride=64, use_gradient_checkpointing=False):
if tiled:
return self.tiled_forward(hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size, tile_stride)
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
prompt_emb = self.context_embedder(prompt_emb)
height, width = hidden_states.shape[-2:]
hidden_states = self.pos_embedder(hidden_states)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block in self.blocks:
if self.training and use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning,
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning)
hidden_states = self.norm_out(hidden_states, conditioning)
hidden_states = self.proj_out(hidden_states)
hidden_states = rearrange(hidden_states, "B (H W) (P Q C) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
return hidden_states
@staticmethod
def state_dict_converter():
return SD3DiTStateDictConverter()
class SD3DiTStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"context_embedder": "context_embedder",
"pos_embed.pos_embed": "pos_embedder.pos_embed",
"pos_embed.proj": "pos_embedder.proj",
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
"norm_out.linear": "norm_out.linear",
"proj_out": "proj_out",
"norm1.linear": "norm1_a.linear",
"norm1_context.linear": "norm1_b.linear",
"attn.to_q": "attn.a_to_q",
"attn.to_k": "attn.a_to_k",
"attn.to_v": "attn.a_to_v",
"attn.to_out.0": "attn.a_to_out",
"attn.add_q_proj": "attn.b_to_q",
"attn.add_k_proj": "attn.b_to_k",
"attn.add_v_proj": "attn.b_to_v",
"attn.to_add_out": "attn.b_to_out",
"ff.net.0.proj": "ff_a.0",
"ff.net.2": "ff_a.2",
"ff_context.net.0.proj": "ff_b.0",
"ff_context.net.2": "ff_b.2",
}
state_dict_ = {}
for name, param in state_dict.items():
if name in rename_dict:
if name == "pos_embed.pos_embed":
param = param.reshape((1, 192, 192, 1536))
state_dict_[rename_dict[name]] = param
elif name.endswith(".weight") or name.endswith(".bias"):
suffix = ".weight" if name.endswith(".weight") else ".bias"
prefix = name[:-len(suffix)]
if prefix in rename_dict:
state_dict_[rename_dict[prefix] + suffix] = param
elif prefix.startswith("transformer_blocks."):
names = prefix.split(".")
names[0] = "blocks"
middle = ".".join(names[2:])
if middle in rename_dict:
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
state_dict_[name_] = param
return state_dict_
def from_civitai(self, state_dict):
rename_dict = {
"model.diffusion_model.context_embedder.bias": "context_embedder.bias",
"model.diffusion_model.context_embedder.weight": "context_embedder.weight",
"model.diffusion_model.final_layer.linear.bias": "proj_out.bias",
"model.diffusion_model.final_layer.linear.weight": "proj_out.weight",
"model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias": "blocks.0.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.weight": "blocks.0.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.0.context_block.attn.proj.bias": "blocks.0.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.0.context_block.attn.proj.weight": "blocks.0.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.0.context_block.attn.qkv.bias": ['blocks.0.attn.b_to_q.bias', 'blocks.0.attn.b_to_k.bias', 'blocks.0.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.0.context_block.attn.qkv.weight": ['blocks.0.attn.b_to_q.weight', 'blocks.0.attn.b_to_k.weight', 'blocks.0.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc1.bias": "blocks.0.ff_b.0.bias",
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc1.weight": "blocks.0.ff_b.0.weight",
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc2.bias": "blocks.0.ff_b.2.bias",
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc2.weight": "blocks.0.ff_b.2.weight",
"model.diffusion_model.joint_blocks.0.x_block.adaLN_modulation.1.bias": "blocks.0.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.0.x_block.adaLN_modulation.1.weight": "blocks.0.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.0.x_block.attn.proj.bias": "blocks.0.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.0.x_block.attn.proj.weight": "blocks.0.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.0.x_block.attn.qkv.bias": ['blocks.0.attn.a_to_q.bias', 'blocks.0.attn.a_to_k.bias', 'blocks.0.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.0.x_block.attn.qkv.weight": ['blocks.0.attn.a_to_q.weight', 'blocks.0.attn.a_to_k.weight', 'blocks.0.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc1.bias": "blocks.0.ff_a.0.bias",
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc1.weight": "blocks.0.ff_a.0.weight",
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc2.bias": "blocks.0.ff_a.2.bias",
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc2.weight": "blocks.0.ff_a.2.weight",
"model.diffusion_model.joint_blocks.1.context_block.adaLN_modulation.1.bias": "blocks.1.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.1.context_block.adaLN_modulation.1.weight": "blocks.1.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.1.context_block.attn.proj.bias": "blocks.1.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.1.context_block.attn.proj.weight": "blocks.1.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.1.context_block.attn.qkv.bias": ['blocks.1.attn.b_to_q.bias', 'blocks.1.attn.b_to_k.bias', 'blocks.1.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.1.context_block.attn.qkv.weight": ['blocks.1.attn.b_to_q.weight', 'blocks.1.attn.b_to_k.weight', 'blocks.1.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc1.bias": "blocks.1.ff_b.0.bias",
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc1.weight": "blocks.1.ff_b.0.weight",
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc2.bias": "blocks.1.ff_b.2.bias",
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc2.weight": "blocks.1.ff_b.2.weight",
"model.diffusion_model.joint_blocks.1.x_block.adaLN_modulation.1.bias": "blocks.1.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.1.x_block.adaLN_modulation.1.weight": "blocks.1.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.1.x_block.attn.proj.bias": "blocks.1.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.1.x_block.attn.proj.weight": "blocks.1.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.1.x_block.attn.qkv.bias": ['blocks.1.attn.a_to_q.bias', 'blocks.1.attn.a_to_k.bias', 'blocks.1.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.1.x_block.attn.qkv.weight": ['blocks.1.attn.a_to_q.weight', 'blocks.1.attn.a_to_k.weight', 'blocks.1.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc1.bias": "blocks.1.ff_a.0.bias",
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc1.weight": "blocks.1.ff_a.0.weight",
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc2.bias": "blocks.1.ff_a.2.bias",
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc2.weight": "blocks.1.ff_a.2.weight",
"model.diffusion_model.joint_blocks.10.context_block.adaLN_modulation.1.bias": "blocks.10.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.10.context_block.adaLN_modulation.1.weight": "blocks.10.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.10.context_block.attn.proj.bias": "blocks.10.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.10.context_block.attn.proj.weight": "blocks.10.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.10.context_block.attn.qkv.bias": ['blocks.10.attn.b_to_q.bias', 'blocks.10.attn.b_to_k.bias', 'blocks.10.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.10.context_block.attn.qkv.weight": ['blocks.10.attn.b_to_q.weight', 'blocks.10.attn.b_to_k.weight', 'blocks.10.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc1.bias": "blocks.10.ff_b.0.bias",
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc1.weight": "blocks.10.ff_b.0.weight",
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc2.bias": "blocks.10.ff_b.2.bias",
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc2.weight": "blocks.10.ff_b.2.weight",
"model.diffusion_model.joint_blocks.10.x_block.adaLN_modulation.1.bias": "blocks.10.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.10.x_block.adaLN_modulation.1.weight": "blocks.10.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.10.x_block.attn.proj.bias": "blocks.10.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.10.x_block.attn.proj.weight": "blocks.10.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.10.x_block.attn.qkv.bias": ['blocks.10.attn.a_to_q.bias', 'blocks.10.attn.a_to_k.bias', 'blocks.10.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.10.x_block.attn.qkv.weight": ['blocks.10.attn.a_to_q.weight', 'blocks.10.attn.a_to_k.weight', 'blocks.10.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc1.bias": "blocks.10.ff_a.0.bias",
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc1.weight": "blocks.10.ff_a.0.weight",
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc2.bias": "blocks.10.ff_a.2.bias",
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc2.weight": "blocks.10.ff_a.2.weight",
"model.diffusion_model.joint_blocks.11.context_block.adaLN_modulation.1.bias": "blocks.11.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.11.context_block.adaLN_modulation.1.weight": "blocks.11.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.11.context_block.attn.proj.bias": "blocks.11.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.11.context_block.attn.proj.weight": "blocks.11.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.11.context_block.attn.qkv.bias": ['blocks.11.attn.b_to_q.bias', 'blocks.11.attn.b_to_k.bias', 'blocks.11.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.11.context_block.attn.qkv.weight": ['blocks.11.attn.b_to_q.weight', 'blocks.11.attn.b_to_k.weight', 'blocks.11.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc1.bias": "blocks.11.ff_b.0.bias",
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc1.weight": "blocks.11.ff_b.0.weight",
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc2.bias": "blocks.11.ff_b.2.bias",
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc2.weight": "blocks.11.ff_b.2.weight",
"model.diffusion_model.joint_blocks.11.x_block.adaLN_modulation.1.bias": "blocks.11.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.11.x_block.adaLN_modulation.1.weight": "blocks.11.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.11.x_block.attn.proj.bias": "blocks.11.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.11.x_block.attn.proj.weight": "blocks.11.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.11.x_block.attn.qkv.bias": ['blocks.11.attn.a_to_q.bias', 'blocks.11.attn.a_to_k.bias', 'blocks.11.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.11.x_block.attn.qkv.weight": ['blocks.11.attn.a_to_q.weight', 'blocks.11.attn.a_to_k.weight', 'blocks.11.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc1.bias": "blocks.11.ff_a.0.bias",
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc1.weight": "blocks.11.ff_a.0.weight",
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc2.bias": "blocks.11.ff_a.2.bias",
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc2.weight": "blocks.11.ff_a.2.weight",
"model.diffusion_model.joint_blocks.12.context_block.adaLN_modulation.1.bias": "blocks.12.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.12.context_block.adaLN_modulation.1.weight": "blocks.12.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.12.context_block.attn.proj.bias": "blocks.12.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.12.context_block.attn.proj.weight": "blocks.12.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.12.context_block.attn.qkv.bias": ['blocks.12.attn.b_to_q.bias', 'blocks.12.attn.b_to_k.bias', 'blocks.12.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.12.context_block.attn.qkv.weight": ['blocks.12.attn.b_to_q.weight', 'blocks.12.attn.b_to_k.weight', 'blocks.12.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc1.bias": "blocks.12.ff_b.0.bias",
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc1.weight": "blocks.12.ff_b.0.weight",
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc2.bias": "blocks.12.ff_b.2.bias",
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc2.weight": "blocks.12.ff_b.2.weight",
"model.diffusion_model.joint_blocks.12.x_block.adaLN_modulation.1.bias": "blocks.12.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.12.x_block.adaLN_modulation.1.weight": "blocks.12.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.12.x_block.attn.proj.bias": "blocks.12.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.12.x_block.attn.proj.weight": "blocks.12.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.12.x_block.attn.qkv.bias": ['blocks.12.attn.a_to_q.bias', 'blocks.12.attn.a_to_k.bias', 'blocks.12.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.12.x_block.attn.qkv.weight": ['blocks.12.attn.a_to_q.weight', 'blocks.12.attn.a_to_k.weight', 'blocks.12.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc1.bias": "blocks.12.ff_a.0.bias",
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc1.weight": "blocks.12.ff_a.0.weight",
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc2.bias": "blocks.12.ff_a.2.bias",
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc2.weight": "blocks.12.ff_a.2.weight",
"model.diffusion_model.joint_blocks.13.context_block.adaLN_modulation.1.bias": "blocks.13.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.13.context_block.adaLN_modulation.1.weight": "blocks.13.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.13.context_block.attn.proj.bias": "blocks.13.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.13.context_block.attn.proj.weight": "blocks.13.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.13.context_block.attn.qkv.bias": ['blocks.13.attn.b_to_q.bias', 'blocks.13.attn.b_to_k.bias', 'blocks.13.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.13.context_block.attn.qkv.weight": ['blocks.13.attn.b_to_q.weight', 'blocks.13.attn.b_to_k.weight', 'blocks.13.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc1.bias": "blocks.13.ff_b.0.bias",
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc1.weight": "blocks.13.ff_b.0.weight",
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc2.bias": "blocks.13.ff_b.2.bias",
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc2.weight": "blocks.13.ff_b.2.weight",
"model.diffusion_model.joint_blocks.13.x_block.adaLN_modulation.1.bias": "blocks.13.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.13.x_block.adaLN_modulation.1.weight": "blocks.13.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.13.x_block.attn.proj.bias": "blocks.13.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.13.x_block.attn.proj.weight": "blocks.13.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.13.x_block.attn.qkv.bias": ['blocks.13.attn.a_to_q.bias', 'blocks.13.attn.a_to_k.bias', 'blocks.13.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.13.x_block.attn.qkv.weight": ['blocks.13.attn.a_to_q.weight', 'blocks.13.attn.a_to_k.weight', 'blocks.13.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc1.bias": "blocks.13.ff_a.0.bias",
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc1.weight": "blocks.13.ff_a.0.weight",
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc2.bias": "blocks.13.ff_a.2.bias",
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc2.weight": "blocks.13.ff_a.2.weight",
"model.diffusion_model.joint_blocks.14.context_block.adaLN_modulation.1.bias": "blocks.14.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.14.context_block.adaLN_modulation.1.weight": "blocks.14.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.14.context_block.attn.proj.bias": "blocks.14.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.14.context_block.attn.proj.weight": "blocks.14.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.14.context_block.attn.qkv.bias": ['blocks.14.attn.b_to_q.bias', 'blocks.14.attn.b_to_k.bias', 'blocks.14.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.14.context_block.attn.qkv.weight": ['blocks.14.attn.b_to_q.weight', 'blocks.14.attn.b_to_k.weight', 'blocks.14.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc1.bias": "blocks.14.ff_b.0.bias",
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc1.weight": "blocks.14.ff_b.0.weight",
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc2.bias": "blocks.14.ff_b.2.bias",
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc2.weight": "blocks.14.ff_b.2.weight",
"model.diffusion_model.joint_blocks.14.x_block.adaLN_modulation.1.bias": "blocks.14.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.14.x_block.adaLN_modulation.1.weight": "blocks.14.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.14.x_block.attn.proj.bias": "blocks.14.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.14.x_block.attn.proj.weight": "blocks.14.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.14.x_block.attn.qkv.bias": ['blocks.14.attn.a_to_q.bias', 'blocks.14.attn.a_to_k.bias', 'blocks.14.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.14.x_block.attn.qkv.weight": ['blocks.14.attn.a_to_q.weight', 'blocks.14.attn.a_to_k.weight', 'blocks.14.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc1.bias": "blocks.14.ff_a.0.bias",
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc1.weight": "blocks.14.ff_a.0.weight",
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc2.bias": "blocks.14.ff_a.2.bias",
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc2.weight": "blocks.14.ff_a.2.weight",
"model.diffusion_model.joint_blocks.15.context_block.adaLN_modulation.1.bias": "blocks.15.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.15.context_block.adaLN_modulation.1.weight": "blocks.15.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.15.context_block.attn.proj.bias": "blocks.15.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.15.context_block.attn.proj.weight": "blocks.15.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.15.context_block.attn.qkv.bias": ['blocks.15.attn.b_to_q.bias', 'blocks.15.attn.b_to_k.bias', 'blocks.15.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.15.context_block.attn.qkv.weight": ['blocks.15.attn.b_to_q.weight', 'blocks.15.attn.b_to_k.weight', 'blocks.15.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc1.bias": "blocks.15.ff_b.0.bias",
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc1.weight": "blocks.15.ff_b.0.weight",
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc2.bias": "blocks.15.ff_b.2.bias",
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc2.weight": "blocks.15.ff_b.2.weight",
"model.diffusion_model.joint_blocks.15.x_block.adaLN_modulation.1.bias": "blocks.15.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.15.x_block.adaLN_modulation.1.weight": "blocks.15.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.15.x_block.attn.proj.bias": "blocks.15.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.15.x_block.attn.proj.weight": "blocks.15.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.15.x_block.attn.qkv.bias": ['blocks.15.attn.a_to_q.bias', 'blocks.15.attn.a_to_k.bias', 'blocks.15.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.15.x_block.attn.qkv.weight": ['blocks.15.attn.a_to_q.weight', 'blocks.15.attn.a_to_k.weight', 'blocks.15.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc1.bias": "blocks.15.ff_a.0.bias",
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc1.weight": "blocks.15.ff_a.0.weight",
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc2.bias": "blocks.15.ff_a.2.bias",
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc2.weight": "blocks.15.ff_a.2.weight",
"model.diffusion_model.joint_blocks.16.context_block.adaLN_modulation.1.bias": "blocks.16.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.16.context_block.adaLN_modulation.1.weight": "blocks.16.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.16.context_block.attn.proj.bias": "blocks.16.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.16.context_block.attn.proj.weight": "blocks.16.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.16.context_block.attn.qkv.bias": ['blocks.16.attn.b_to_q.bias', 'blocks.16.attn.b_to_k.bias', 'blocks.16.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.16.context_block.attn.qkv.weight": ['blocks.16.attn.b_to_q.weight', 'blocks.16.attn.b_to_k.weight', 'blocks.16.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc1.bias": "blocks.16.ff_b.0.bias",
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc1.weight": "blocks.16.ff_b.0.weight",
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc2.bias": "blocks.16.ff_b.2.bias",
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc2.weight": "blocks.16.ff_b.2.weight",
"model.diffusion_model.joint_blocks.16.x_block.adaLN_modulation.1.bias": "blocks.16.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.16.x_block.adaLN_modulation.1.weight": "blocks.16.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.16.x_block.attn.proj.bias": "blocks.16.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.16.x_block.attn.proj.weight": "blocks.16.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.16.x_block.attn.qkv.bias": ['blocks.16.attn.a_to_q.bias', 'blocks.16.attn.a_to_k.bias', 'blocks.16.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.16.x_block.attn.qkv.weight": ['blocks.16.attn.a_to_q.weight', 'blocks.16.attn.a_to_k.weight', 'blocks.16.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc1.bias": "blocks.16.ff_a.0.bias",
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc1.weight": "blocks.16.ff_a.0.weight",
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc2.bias": "blocks.16.ff_a.2.bias",
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc2.weight": "blocks.16.ff_a.2.weight",
"model.diffusion_model.joint_blocks.17.context_block.adaLN_modulation.1.bias": "blocks.17.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.17.context_block.adaLN_modulation.1.weight": "blocks.17.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.17.context_block.attn.proj.bias": "blocks.17.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.17.context_block.attn.proj.weight": "blocks.17.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.17.context_block.attn.qkv.bias": ['blocks.17.attn.b_to_q.bias', 'blocks.17.attn.b_to_k.bias', 'blocks.17.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.17.context_block.attn.qkv.weight": ['blocks.17.attn.b_to_q.weight', 'blocks.17.attn.b_to_k.weight', 'blocks.17.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc1.bias": "blocks.17.ff_b.0.bias",
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc1.weight": "blocks.17.ff_b.0.weight",
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc2.bias": "blocks.17.ff_b.2.bias",
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc2.weight": "blocks.17.ff_b.2.weight",
"model.diffusion_model.joint_blocks.17.x_block.adaLN_modulation.1.bias": "blocks.17.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.17.x_block.adaLN_modulation.1.weight": "blocks.17.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.17.x_block.attn.proj.bias": "blocks.17.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.17.x_block.attn.proj.weight": "blocks.17.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.17.x_block.attn.qkv.bias": ['blocks.17.attn.a_to_q.bias', 'blocks.17.attn.a_to_k.bias', 'blocks.17.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.17.x_block.attn.qkv.weight": ['blocks.17.attn.a_to_q.weight', 'blocks.17.attn.a_to_k.weight', 'blocks.17.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc1.bias": "blocks.17.ff_a.0.bias",
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc1.weight": "blocks.17.ff_a.0.weight",
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc2.bias": "blocks.17.ff_a.2.bias",
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc2.weight": "blocks.17.ff_a.2.weight",
"model.diffusion_model.joint_blocks.18.context_block.adaLN_modulation.1.bias": "blocks.18.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.18.context_block.adaLN_modulation.1.weight": "blocks.18.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.18.context_block.attn.proj.bias": "blocks.18.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.18.context_block.attn.proj.weight": "blocks.18.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.18.context_block.attn.qkv.bias": ['blocks.18.attn.b_to_q.bias', 'blocks.18.attn.b_to_k.bias', 'blocks.18.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.18.context_block.attn.qkv.weight": ['blocks.18.attn.b_to_q.weight', 'blocks.18.attn.b_to_k.weight', 'blocks.18.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc1.bias": "blocks.18.ff_b.0.bias",
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc1.weight": "blocks.18.ff_b.0.weight",
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc2.bias": "blocks.18.ff_b.2.bias",
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc2.weight": "blocks.18.ff_b.2.weight",
"model.diffusion_model.joint_blocks.18.x_block.adaLN_modulation.1.bias": "blocks.18.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.18.x_block.adaLN_modulation.1.weight": "blocks.18.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.18.x_block.attn.proj.bias": "blocks.18.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.18.x_block.attn.proj.weight": "blocks.18.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.18.x_block.attn.qkv.bias": ['blocks.18.attn.a_to_q.bias', 'blocks.18.attn.a_to_k.bias', 'blocks.18.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.18.x_block.attn.qkv.weight": ['blocks.18.attn.a_to_q.weight', 'blocks.18.attn.a_to_k.weight', 'blocks.18.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc1.bias": "blocks.18.ff_a.0.bias",
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc1.weight": "blocks.18.ff_a.0.weight",
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc2.bias": "blocks.18.ff_a.2.bias",
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc2.weight": "blocks.18.ff_a.2.weight",
"model.diffusion_model.joint_blocks.19.context_block.adaLN_modulation.1.bias": "blocks.19.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.19.context_block.adaLN_modulation.1.weight": "blocks.19.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.19.context_block.attn.proj.bias": "blocks.19.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.19.context_block.attn.proj.weight": "blocks.19.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.19.context_block.attn.qkv.bias": ['blocks.19.attn.b_to_q.bias', 'blocks.19.attn.b_to_k.bias', 'blocks.19.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.19.context_block.attn.qkv.weight": ['blocks.19.attn.b_to_q.weight', 'blocks.19.attn.b_to_k.weight', 'blocks.19.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc1.bias": "blocks.19.ff_b.0.bias",
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc1.weight": "blocks.19.ff_b.0.weight",
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc2.bias": "blocks.19.ff_b.2.bias",
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc2.weight": "blocks.19.ff_b.2.weight",
"model.diffusion_model.joint_blocks.19.x_block.adaLN_modulation.1.bias": "blocks.19.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.19.x_block.adaLN_modulation.1.weight": "blocks.19.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.19.x_block.attn.proj.bias": "blocks.19.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.19.x_block.attn.proj.weight": "blocks.19.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.19.x_block.attn.qkv.bias": ['blocks.19.attn.a_to_q.bias', 'blocks.19.attn.a_to_k.bias', 'blocks.19.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.19.x_block.attn.qkv.weight": ['blocks.19.attn.a_to_q.weight', 'blocks.19.attn.a_to_k.weight', 'blocks.19.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc1.bias": "blocks.19.ff_a.0.bias",
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc1.weight": "blocks.19.ff_a.0.weight",
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc2.bias": "blocks.19.ff_a.2.bias",
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc2.weight": "blocks.19.ff_a.2.weight",
"model.diffusion_model.joint_blocks.2.context_block.adaLN_modulation.1.bias": "blocks.2.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.2.context_block.adaLN_modulation.1.weight": "blocks.2.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.2.context_block.attn.proj.bias": "blocks.2.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.2.context_block.attn.proj.weight": "blocks.2.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.2.context_block.attn.qkv.bias": ['blocks.2.attn.b_to_q.bias', 'blocks.2.attn.b_to_k.bias', 'blocks.2.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.2.context_block.attn.qkv.weight": ['blocks.2.attn.b_to_q.weight', 'blocks.2.attn.b_to_k.weight', 'blocks.2.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc1.bias": "blocks.2.ff_b.0.bias",
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc1.weight": "blocks.2.ff_b.0.weight",
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc2.bias": "blocks.2.ff_b.2.bias",
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc2.weight": "blocks.2.ff_b.2.weight",
"model.diffusion_model.joint_blocks.2.x_block.adaLN_modulation.1.bias": "blocks.2.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.2.x_block.adaLN_modulation.1.weight": "blocks.2.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.2.x_block.attn.proj.bias": "blocks.2.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.2.x_block.attn.proj.weight": "blocks.2.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.2.x_block.attn.qkv.bias": ['blocks.2.attn.a_to_q.bias', 'blocks.2.attn.a_to_k.bias', 'blocks.2.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.2.x_block.attn.qkv.weight": ['blocks.2.attn.a_to_q.weight', 'blocks.2.attn.a_to_k.weight', 'blocks.2.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc1.bias": "blocks.2.ff_a.0.bias",
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc1.weight": "blocks.2.ff_a.0.weight",
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc2.bias": "blocks.2.ff_a.2.bias",
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc2.weight": "blocks.2.ff_a.2.weight",
"model.diffusion_model.joint_blocks.20.context_block.adaLN_modulation.1.bias": "blocks.20.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.20.context_block.adaLN_modulation.1.weight": "blocks.20.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.20.context_block.attn.proj.bias": "blocks.20.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.20.context_block.attn.proj.weight": "blocks.20.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.20.context_block.attn.qkv.bias": ['blocks.20.attn.b_to_q.bias', 'blocks.20.attn.b_to_k.bias', 'blocks.20.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.20.context_block.attn.qkv.weight": ['blocks.20.attn.b_to_q.weight', 'blocks.20.attn.b_to_k.weight', 'blocks.20.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc1.bias": "blocks.20.ff_b.0.bias",
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc1.weight": "blocks.20.ff_b.0.weight",
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc2.bias": "blocks.20.ff_b.2.bias",
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc2.weight": "blocks.20.ff_b.2.weight",
"model.diffusion_model.joint_blocks.20.x_block.adaLN_modulation.1.bias": "blocks.20.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.20.x_block.adaLN_modulation.1.weight": "blocks.20.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.20.x_block.attn.proj.bias": "blocks.20.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.20.x_block.attn.proj.weight": "blocks.20.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.20.x_block.attn.qkv.bias": ['blocks.20.attn.a_to_q.bias', 'blocks.20.attn.a_to_k.bias', 'blocks.20.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.20.x_block.attn.qkv.weight": ['blocks.20.attn.a_to_q.weight', 'blocks.20.attn.a_to_k.weight', 'blocks.20.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc1.bias": "blocks.20.ff_a.0.bias",
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc1.weight": "blocks.20.ff_a.0.weight",
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc2.bias": "blocks.20.ff_a.2.bias",
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc2.weight": "blocks.20.ff_a.2.weight",
"model.diffusion_model.joint_blocks.21.context_block.adaLN_modulation.1.bias": "blocks.21.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.21.context_block.adaLN_modulation.1.weight": "blocks.21.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.21.context_block.attn.proj.bias": "blocks.21.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.21.context_block.attn.proj.weight": "blocks.21.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.21.context_block.attn.qkv.bias": ['blocks.21.attn.b_to_q.bias', 'blocks.21.attn.b_to_k.bias', 'blocks.21.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.21.context_block.attn.qkv.weight": ['blocks.21.attn.b_to_q.weight', 'blocks.21.attn.b_to_k.weight', 'blocks.21.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc1.bias": "blocks.21.ff_b.0.bias",
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc1.weight": "blocks.21.ff_b.0.weight",
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc2.bias": "blocks.21.ff_b.2.bias",
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc2.weight": "blocks.21.ff_b.2.weight",
"model.diffusion_model.joint_blocks.21.x_block.adaLN_modulation.1.bias": "blocks.21.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.21.x_block.adaLN_modulation.1.weight": "blocks.21.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.21.x_block.attn.proj.bias": "blocks.21.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.21.x_block.attn.proj.weight": "blocks.21.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.21.x_block.attn.qkv.bias": ['blocks.21.attn.a_to_q.bias', 'blocks.21.attn.a_to_k.bias', 'blocks.21.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.21.x_block.attn.qkv.weight": ['blocks.21.attn.a_to_q.weight', 'blocks.21.attn.a_to_k.weight', 'blocks.21.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc1.bias": "blocks.21.ff_a.0.bias",
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc1.weight": "blocks.21.ff_a.0.weight",
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc2.bias": "blocks.21.ff_a.2.bias",
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc2.weight": "blocks.21.ff_a.2.weight",
"model.diffusion_model.joint_blocks.22.context_block.adaLN_modulation.1.bias": "blocks.22.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.22.context_block.adaLN_modulation.1.weight": "blocks.22.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.22.context_block.attn.proj.bias": "blocks.22.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.22.context_block.attn.proj.weight": "blocks.22.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.22.context_block.attn.qkv.bias": ['blocks.22.attn.b_to_q.bias', 'blocks.22.attn.b_to_k.bias', 'blocks.22.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.22.context_block.attn.qkv.weight": ['blocks.22.attn.b_to_q.weight', 'blocks.22.attn.b_to_k.weight', 'blocks.22.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc1.bias": "blocks.22.ff_b.0.bias",
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc1.weight": "blocks.22.ff_b.0.weight",
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc2.bias": "blocks.22.ff_b.2.bias",
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc2.weight": "blocks.22.ff_b.2.weight",
"model.diffusion_model.joint_blocks.22.x_block.adaLN_modulation.1.bias": "blocks.22.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.22.x_block.adaLN_modulation.1.weight": "blocks.22.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.22.x_block.attn.proj.bias": "blocks.22.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.22.x_block.attn.proj.weight": "blocks.22.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.22.x_block.attn.qkv.bias": ['blocks.22.attn.a_to_q.bias', 'blocks.22.attn.a_to_k.bias', 'blocks.22.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.22.x_block.attn.qkv.weight": ['blocks.22.attn.a_to_q.weight', 'blocks.22.attn.a_to_k.weight', 'blocks.22.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc1.bias": "blocks.22.ff_a.0.bias",
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc1.weight": "blocks.22.ff_a.0.weight",
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc2.bias": "blocks.22.ff_a.2.bias",
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc2.weight": "blocks.22.ff_a.2.weight",
"model.diffusion_model.joint_blocks.23.context_block.attn.qkv.bias": ['blocks.23.attn.b_to_q.bias', 'blocks.23.attn.b_to_k.bias', 'blocks.23.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.23.context_block.attn.qkv.weight": ['blocks.23.attn.b_to_q.weight', 'blocks.23.attn.b_to_k.weight', 'blocks.23.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.23.x_block.adaLN_modulation.1.bias": "blocks.23.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.23.x_block.adaLN_modulation.1.weight": "blocks.23.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.23.x_block.attn.proj.bias": "blocks.23.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.23.x_block.attn.proj.weight": "blocks.23.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.23.x_block.attn.qkv.bias": ['blocks.23.attn.a_to_q.bias', 'blocks.23.attn.a_to_k.bias', 'blocks.23.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.23.x_block.attn.qkv.weight": ['blocks.23.attn.a_to_q.weight', 'blocks.23.attn.a_to_k.weight', 'blocks.23.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc1.bias": "blocks.23.ff_a.0.bias",
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc1.weight": "blocks.23.ff_a.0.weight",
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc2.bias": "blocks.23.ff_a.2.bias",
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc2.weight": "blocks.23.ff_a.2.weight",
"model.diffusion_model.joint_blocks.3.context_block.adaLN_modulation.1.bias": "blocks.3.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.3.context_block.adaLN_modulation.1.weight": "blocks.3.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.3.context_block.attn.proj.bias": "blocks.3.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.3.context_block.attn.proj.weight": "blocks.3.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.3.context_block.attn.qkv.bias": ['blocks.3.attn.b_to_q.bias', 'blocks.3.attn.b_to_k.bias', 'blocks.3.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.3.context_block.attn.qkv.weight": ['blocks.3.attn.b_to_q.weight', 'blocks.3.attn.b_to_k.weight', 'blocks.3.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc1.bias": "blocks.3.ff_b.0.bias",
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc1.weight": "blocks.3.ff_b.0.weight",
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc2.bias": "blocks.3.ff_b.2.bias",
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc2.weight": "blocks.3.ff_b.2.weight",
"model.diffusion_model.joint_blocks.3.x_block.adaLN_modulation.1.bias": "blocks.3.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.3.x_block.adaLN_modulation.1.weight": "blocks.3.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.3.x_block.attn.proj.bias": "blocks.3.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.3.x_block.attn.proj.weight": "blocks.3.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.3.x_block.attn.qkv.bias": ['blocks.3.attn.a_to_q.bias', 'blocks.3.attn.a_to_k.bias', 'blocks.3.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.3.x_block.attn.qkv.weight": ['blocks.3.attn.a_to_q.weight', 'blocks.3.attn.a_to_k.weight', 'blocks.3.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc1.bias": "blocks.3.ff_a.0.bias",
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc1.weight": "blocks.3.ff_a.0.weight",
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc2.bias": "blocks.3.ff_a.2.bias",
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc2.weight": "blocks.3.ff_a.2.weight",
"model.diffusion_model.joint_blocks.4.context_block.adaLN_modulation.1.bias": "blocks.4.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.4.context_block.adaLN_modulation.1.weight": "blocks.4.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.4.context_block.attn.proj.bias": "blocks.4.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.4.context_block.attn.proj.weight": "blocks.4.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.4.context_block.attn.qkv.bias": ['blocks.4.attn.b_to_q.bias', 'blocks.4.attn.b_to_k.bias', 'blocks.4.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.4.context_block.attn.qkv.weight": ['blocks.4.attn.b_to_q.weight', 'blocks.4.attn.b_to_k.weight', 'blocks.4.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc1.bias": "blocks.4.ff_b.0.bias",
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc1.weight": "blocks.4.ff_b.0.weight",
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc2.bias": "blocks.4.ff_b.2.bias",
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc2.weight": "blocks.4.ff_b.2.weight",
"model.diffusion_model.joint_blocks.4.x_block.adaLN_modulation.1.bias": "blocks.4.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.4.x_block.adaLN_modulation.1.weight": "blocks.4.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.4.x_block.attn.proj.bias": "blocks.4.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.4.x_block.attn.proj.weight": "blocks.4.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.4.x_block.attn.qkv.bias": ['blocks.4.attn.a_to_q.bias', 'blocks.4.attn.a_to_k.bias', 'blocks.4.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.4.x_block.attn.qkv.weight": ['blocks.4.attn.a_to_q.weight', 'blocks.4.attn.a_to_k.weight', 'blocks.4.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc1.bias": "blocks.4.ff_a.0.bias",
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc1.weight": "blocks.4.ff_a.0.weight",
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc2.bias": "blocks.4.ff_a.2.bias",
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc2.weight": "blocks.4.ff_a.2.weight",
"model.diffusion_model.joint_blocks.5.context_block.adaLN_modulation.1.bias": "blocks.5.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.5.context_block.adaLN_modulation.1.weight": "blocks.5.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.5.context_block.attn.proj.bias": "blocks.5.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.5.context_block.attn.proj.weight": "blocks.5.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.5.context_block.attn.qkv.bias": ['blocks.5.attn.b_to_q.bias', 'blocks.5.attn.b_to_k.bias', 'blocks.5.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.5.context_block.attn.qkv.weight": ['blocks.5.attn.b_to_q.weight', 'blocks.5.attn.b_to_k.weight', 'blocks.5.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc1.bias": "blocks.5.ff_b.0.bias",
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc1.weight": "blocks.5.ff_b.0.weight",
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc2.bias": "blocks.5.ff_b.2.bias",
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc2.weight": "blocks.5.ff_b.2.weight",
"model.diffusion_model.joint_blocks.5.x_block.adaLN_modulation.1.bias": "blocks.5.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.5.x_block.adaLN_modulation.1.weight": "blocks.5.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.5.x_block.attn.proj.bias": "blocks.5.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.5.x_block.attn.proj.weight": "blocks.5.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.5.x_block.attn.qkv.bias": ['blocks.5.attn.a_to_q.bias', 'blocks.5.attn.a_to_k.bias', 'blocks.5.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.5.x_block.attn.qkv.weight": ['blocks.5.attn.a_to_q.weight', 'blocks.5.attn.a_to_k.weight', 'blocks.5.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc1.bias": "blocks.5.ff_a.0.bias",
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc1.weight": "blocks.5.ff_a.0.weight",
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc2.bias": "blocks.5.ff_a.2.bias",
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc2.weight": "blocks.5.ff_a.2.weight",
"model.diffusion_model.joint_blocks.6.context_block.adaLN_modulation.1.bias": "blocks.6.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.6.context_block.adaLN_modulation.1.weight": "blocks.6.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.6.context_block.attn.proj.bias": "blocks.6.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.6.context_block.attn.proj.weight": "blocks.6.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.6.context_block.attn.qkv.bias": ['blocks.6.attn.b_to_q.bias', 'blocks.6.attn.b_to_k.bias', 'blocks.6.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.6.context_block.attn.qkv.weight": ['blocks.6.attn.b_to_q.weight', 'blocks.6.attn.b_to_k.weight', 'blocks.6.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc1.bias": "blocks.6.ff_b.0.bias",
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc1.weight": "blocks.6.ff_b.0.weight",
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc2.bias": "blocks.6.ff_b.2.bias",
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc2.weight": "blocks.6.ff_b.2.weight",
"model.diffusion_model.joint_blocks.6.x_block.adaLN_modulation.1.bias": "blocks.6.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.6.x_block.adaLN_modulation.1.weight": "blocks.6.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.6.x_block.attn.proj.bias": "blocks.6.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.6.x_block.attn.proj.weight": "blocks.6.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.6.x_block.attn.qkv.bias": ['blocks.6.attn.a_to_q.bias', 'blocks.6.attn.a_to_k.bias', 'blocks.6.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.6.x_block.attn.qkv.weight": ['blocks.6.attn.a_to_q.weight', 'blocks.6.attn.a_to_k.weight', 'blocks.6.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc1.bias": "blocks.6.ff_a.0.bias",
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc1.weight": "blocks.6.ff_a.0.weight",
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc2.bias": "blocks.6.ff_a.2.bias",
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc2.weight": "blocks.6.ff_a.2.weight",
"model.diffusion_model.joint_blocks.7.context_block.adaLN_modulation.1.bias": "blocks.7.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.7.context_block.adaLN_modulation.1.weight": "blocks.7.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.7.context_block.attn.proj.bias": "blocks.7.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.7.context_block.attn.proj.weight": "blocks.7.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.7.context_block.attn.qkv.bias": ['blocks.7.attn.b_to_q.bias', 'blocks.7.attn.b_to_k.bias', 'blocks.7.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.7.context_block.attn.qkv.weight": ['blocks.7.attn.b_to_q.weight', 'blocks.7.attn.b_to_k.weight', 'blocks.7.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc1.bias": "blocks.7.ff_b.0.bias",
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc1.weight": "blocks.7.ff_b.0.weight",
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc2.bias": "blocks.7.ff_b.2.bias",
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc2.weight": "blocks.7.ff_b.2.weight",
"model.diffusion_model.joint_blocks.7.x_block.adaLN_modulation.1.bias": "blocks.7.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.7.x_block.adaLN_modulation.1.weight": "blocks.7.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.7.x_block.attn.proj.bias": "blocks.7.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.7.x_block.attn.proj.weight": "blocks.7.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.7.x_block.attn.qkv.bias": ['blocks.7.attn.a_to_q.bias', 'blocks.7.attn.a_to_k.bias', 'blocks.7.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.7.x_block.attn.qkv.weight": ['blocks.7.attn.a_to_q.weight', 'blocks.7.attn.a_to_k.weight', 'blocks.7.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc1.bias": "blocks.7.ff_a.0.bias",
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc1.weight": "blocks.7.ff_a.0.weight",
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc2.bias": "blocks.7.ff_a.2.bias",
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc2.weight": "blocks.7.ff_a.2.weight",
"model.diffusion_model.joint_blocks.8.context_block.adaLN_modulation.1.bias": "blocks.8.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.8.context_block.adaLN_modulation.1.weight": "blocks.8.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.8.context_block.attn.proj.bias": "blocks.8.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.8.context_block.attn.proj.weight": "blocks.8.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.8.context_block.attn.qkv.bias": ['blocks.8.attn.b_to_q.bias', 'blocks.8.attn.b_to_k.bias', 'blocks.8.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.8.context_block.attn.qkv.weight": ['blocks.8.attn.b_to_q.weight', 'blocks.8.attn.b_to_k.weight', 'blocks.8.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc1.bias": "blocks.8.ff_b.0.bias",
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc1.weight": "blocks.8.ff_b.0.weight",
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc2.bias": "blocks.8.ff_b.2.bias",
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc2.weight": "blocks.8.ff_b.2.weight",
"model.diffusion_model.joint_blocks.8.x_block.adaLN_modulation.1.bias": "blocks.8.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.8.x_block.adaLN_modulation.1.weight": "blocks.8.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.8.x_block.attn.proj.bias": "blocks.8.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.8.x_block.attn.proj.weight": "blocks.8.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.8.x_block.attn.qkv.bias": ['blocks.8.attn.a_to_q.bias', 'blocks.8.attn.a_to_k.bias', 'blocks.8.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.8.x_block.attn.qkv.weight": ['blocks.8.attn.a_to_q.weight', 'blocks.8.attn.a_to_k.weight', 'blocks.8.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc1.bias": "blocks.8.ff_a.0.bias",
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc1.weight": "blocks.8.ff_a.0.weight",
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc2.bias": "blocks.8.ff_a.2.bias",
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc2.weight": "blocks.8.ff_a.2.weight",
"model.diffusion_model.joint_blocks.9.context_block.adaLN_modulation.1.bias": "blocks.9.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.9.context_block.adaLN_modulation.1.weight": "blocks.9.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.9.context_block.attn.proj.bias": "blocks.9.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.9.context_block.attn.proj.weight": "blocks.9.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.9.context_block.attn.qkv.bias": ['blocks.9.attn.b_to_q.bias', 'blocks.9.attn.b_to_k.bias', 'blocks.9.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.9.context_block.attn.qkv.weight": ['blocks.9.attn.b_to_q.weight', 'blocks.9.attn.b_to_k.weight', 'blocks.9.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc1.bias": "blocks.9.ff_b.0.bias",
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc1.weight": "blocks.9.ff_b.0.weight",
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc2.bias": "blocks.9.ff_b.2.bias",
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc2.weight": "blocks.9.ff_b.2.weight",
"model.diffusion_model.joint_blocks.9.x_block.adaLN_modulation.1.bias": "blocks.9.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.9.x_block.adaLN_modulation.1.weight": "blocks.9.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.9.x_block.attn.proj.bias": "blocks.9.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.9.x_block.attn.proj.weight": "blocks.9.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.9.x_block.attn.qkv.bias": ['blocks.9.attn.a_to_q.bias', 'blocks.9.attn.a_to_k.bias', 'blocks.9.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.9.x_block.attn.qkv.weight": ['blocks.9.attn.a_to_q.weight', 'blocks.9.attn.a_to_k.weight', 'blocks.9.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc1.bias": "blocks.9.ff_a.0.bias",
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc1.weight": "blocks.9.ff_a.0.weight",
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.bias": "blocks.9.ff_a.2.bias",
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.weight": "blocks.9.ff_a.2.weight",
"model.diffusion_model.pos_embed": "pos_embedder.pos_embed",
"model.diffusion_model.t_embedder.mlp.0.bias": "time_embedder.timestep_embedder.0.bias",
"model.diffusion_model.t_embedder.mlp.0.weight": "time_embedder.timestep_embedder.0.weight",
"model.diffusion_model.t_embedder.mlp.2.bias": "time_embedder.timestep_embedder.2.bias",
"model.diffusion_model.t_embedder.mlp.2.weight": "time_embedder.timestep_embedder.2.weight",
"model.diffusion_model.x_embedder.proj.bias": "pos_embedder.proj.bias",
"model.diffusion_model.x_embedder.proj.weight": "pos_embedder.proj.weight",
"model.diffusion_model.y_embedder.mlp.0.bias": "pooled_text_embedder.0.bias",
"model.diffusion_model.y_embedder.mlp.0.weight": "pooled_text_embedder.0.weight",
"model.diffusion_model.y_embedder.mlp.2.bias": "pooled_text_embedder.2.bias",
"model.diffusion_model.y_embedder.mlp.2.weight": "pooled_text_embedder.2.weight",
"model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1.weight": "blocks.23.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1.bias": "blocks.23.norm1_b.linear.bias",
"model.diffusion_model.final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
"model.diffusion_model.final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if name.startswith("model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1."):
param = torch.concat([param[1536:], param[:1536]], axis=0)
elif name.startswith("model.diffusion_model.final_layer.adaLN_modulation.1."):
param = torch.concat([param[1536:], param[:1536]], axis=0)
elif name == "model.diffusion_model.pos_embed":
param = param.reshape((1, 192, 192, 1536))
if isinstance(rename_dict[name], str):
state_dict_[rename_dict[name]] = param
else:
name_ = rename_dict[name][0].replace(".a_to_q.", ".a_to_qkv.").replace(".b_to_q.", ".b_to_qkv.")
state_dict_[name_] = param
return state_dict_

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,81 @@
import torch
from .sd_vae_decoder import VAEAttentionBlock, SDVAEDecoderStateDictConverter
from .sd_unet import ResnetBlock, UpSampler
from .tiler import TileWorker
class SD3VAEDecoder(torch.nn.Module):
def __init__(self):
super().__init__()
self.scaling_factor = 1.5305 # Different from SD 1.x
self.shift_factor = 0.0609 # Different from SD 1.x
self.conv_in = torch.nn.Conv2d(16, 512, kernel_size=3, padding=1) # Different from SD 1.x
self.blocks = torch.nn.ModuleList([
# UNetMidBlock2D
ResnetBlock(512, 512, eps=1e-6),
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
# UpDecoderBlock2D
ResnetBlock(512, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
UpSampler(512),
# UpDecoderBlock2D
ResnetBlock(512, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
UpSampler(512),
# UpDecoderBlock2D
ResnetBlock(512, 256, eps=1e-6),
ResnetBlock(256, 256, eps=1e-6),
ResnetBlock(256, 256, eps=1e-6),
UpSampler(256),
# UpDecoderBlock2D
ResnetBlock(256, 128, eps=1e-6),
ResnetBlock(128, 128, eps=1e-6),
ResnetBlock(128, 128, eps=1e-6),
])
self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6)
self.conv_act = torch.nn.SiLU()
self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
hidden_states = TileWorker().tiled_forward(
lambda x: self.forward(x),
sample,
tile_size,
tile_stride,
tile_device=sample.device,
tile_dtype=sample.dtype
)
return hidden_states
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
# For VAE Decoder, we do not need to apply the tiler on each layer.
if tiled:
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
# 1. pre-process
hidden_states = sample / self.scaling_factor + self.shift_factor
hidden_states = self.conv_in(hidden_states)
time_emb = None
text_emb = None
res_stack = None
# 2. blocks
for i, block in enumerate(self.blocks):
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
# 3. output
hidden_states = self.conv_norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
@staticmethod
def state_dict_converter():
return SDVAEDecoderStateDictConverter()

View File

@@ -0,0 +1,95 @@
import torch
from .sd_unet import ResnetBlock, DownSampler
from .sd_vae_encoder import VAEAttentionBlock, SDVAEEncoderStateDictConverter
from .tiler import TileWorker
from einops import rearrange
class SD3VAEEncoder(torch.nn.Module):
def __init__(self):
super().__init__()
self.scaling_factor = 1.5305 # Different from SD 1.x
self.shift_factor = 0.0609 # Different from SD 1.x
self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)
self.blocks = torch.nn.ModuleList([
# DownEncoderBlock2D
ResnetBlock(128, 128, eps=1e-6),
ResnetBlock(128, 128, eps=1e-6),
DownSampler(128, padding=0, extra_padding=True),
# DownEncoderBlock2D
ResnetBlock(128, 256, eps=1e-6),
ResnetBlock(256, 256, eps=1e-6),
DownSampler(256, padding=0, extra_padding=True),
# DownEncoderBlock2D
ResnetBlock(256, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
DownSampler(512, padding=0, extra_padding=True),
# DownEncoderBlock2D
ResnetBlock(512, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
# UNetMidBlock2D
ResnetBlock(512, 512, eps=1e-6),
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
])
self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
self.conv_act = torch.nn.SiLU()
self.conv_out = torch.nn.Conv2d(512, 32, kernel_size=3, padding=1)
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
hidden_states = TileWorker().tiled_forward(
lambda x: self.forward(x),
sample,
tile_size,
tile_stride,
tile_device=sample.device,
tile_dtype=sample.dtype
)
return hidden_states
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
# For VAE Decoder, we do not need to apply the tiler on each layer.
if tiled:
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
# 1. pre-process
hidden_states = self.conv_in(sample)
time_emb = None
text_emb = None
res_stack = None
# 2. blocks
for i, block in enumerate(self.blocks):
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
# 3. output
hidden_states = self.conv_norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
hidden_states = hidden_states[:, :16]
hidden_states = (hidden_states - self.shift_factor) * self.scaling_factor
return hidden_states
def encode_video(self, sample, batch_size=8):
B = sample.shape[0]
hidden_states = []
for i in range(0, sample.shape[2], batch_size):
j = min(i + batch_size, sample.shape[2])
sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W")
hidden_states_batch = self(sample_batch)
hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)
hidden_states.append(hidden_states_batch)
hidden_states = torch.concat(hidden_states, dim=2)
return hidden_states
@staticmethod
def state_dict_converter():
return SDVAEEncoderStateDictConverter()

View File

@@ -0,0 +1,589 @@
import torch
from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler
from .tiler import TileWorker
class ControlNetConditioningLayer(torch.nn.Module):
def __init__(self, channels = (3, 16, 32, 96, 256, 320)):
super().__init__()
self.blocks = torch.nn.ModuleList([])
self.blocks.append(torch.nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1))
self.blocks.append(torch.nn.SiLU())
for i in range(1, len(channels) - 2):
self.blocks.append(torch.nn.Conv2d(channels[i], channels[i], kernel_size=3, padding=1))
self.blocks.append(torch.nn.SiLU())
self.blocks.append(torch.nn.Conv2d(channels[i], channels[i+1], kernel_size=3, padding=1, stride=2))
self.blocks.append(torch.nn.SiLU())
self.blocks.append(torch.nn.Conv2d(channels[-2], channels[-1], kernel_size=3, padding=1))
def forward(self, conditioning):
for block in self.blocks:
conditioning = block(conditioning)
return conditioning
class SDControlNet(torch.nn.Module):
def __init__(self, global_pool=False):
super().__init__()
self.time_proj = Timesteps(320)
self.time_embedding = torch.nn.Sequential(
torch.nn.Linear(320, 1280),
torch.nn.SiLU(),
torch.nn.Linear(1280, 1280)
)
self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320))
self.blocks = torch.nn.ModuleList([
# CrossAttnDownBlock2D
ResnetBlock(320, 320, 1280),
AttentionBlock(8, 40, 320, 1, 768),
PushBlock(),
ResnetBlock(320, 320, 1280),
AttentionBlock(8, 40, 320, 1, 768),
PushBlock(),
DownSampler(320),
PushBlock(),
# CrossAttnDownBlock2D
ResnetBlock(320, 640, 1280),
AttentionBlock(8, 80, 640, 1, 768),
PushBlock(),
ResnetBlock(640, 640, 1280),
AttentionBlock(8, 80, 640, 1, 768),
PushBlock(),
DownSampler(640),
PushBlock(),
# CrossAttnDownBlock2D
ResnetBlock(640, 1280, 1280),
AttentionBlock(8, 160, 1280, 1, 768),
PushBlock(),
ResnetBlock(1280, 1280, 1280),
AttentionBlock(8, 160, 1280, 1, 768),
PushBlock(),
DownSampler(1280),
PushBlock(),
# DownBlock2D
ResnetBlock(1280, 1280, 1280),
PushBlock(),
ResnetBlock(1280, 1280, 1280),
PushBlock(),
# UNetMidBlock2DCrossAttn
ResnetBlock(1280, 1280, 1280),
AttentionBlock(8, 160, 1280, 1, 768),
ResnetBlock(1280, 1280, 1280),
PushBlock()
])
self.controlnet_blocks = torch.nn.ModuleList([
torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
])
self.global_pool = global_pool
def forward(
self,
sample, timestep, encoder_hidden_states, conditioning,
tiled=False, tile_size=64, tile_stride=32,
**kwargs
):
# 1. time
time_emb = self.time_proj(timestep).to(sample.dtype)
time_emb = self.time_embedding(time_emb)
time_emb = time_emb.repeat(sample.shape[0], 1)
# 2. pre-process
height, width = sample.shape[2], sample.shape[3]
hidden_states = self.conv_in(sample) + self.controlnet_conv_in(conditioning)
text_emb = encoder_hidden_states
res_stack = [hidden_states]
# 3. blocks
for i, block in enumerate(self.blocks):
if tiled and not isinstance(block, PushBlock):
_, _, inter_height, _ = hidden_states.shape
resize_scale = inter_height / height
hidden_states = TileWorker().tiled_forward(
lambda x: block(x, time_emb, text_emb, res_stack)[0],
hidden_states,
int(tile_size * resize_scale),
int(tile_stride * resize_scale),
tile_device=hidden_states.device,
tile_dtype=hidden_states.dtype
)
else:
hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack)
# 4. ControlNet blocks
controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]
# pool
if self.global_pool:
controlnet_res_stack = [res.mean(dim=(2, 3), keepdim=True) for res in controlnet_res_stack]
return controlnet_res_stack
@staticmethod
def state_dict_converter():
return SDControlNetStateDictConverter()
class SDControlNetStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
# architecture
block_types = [
'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
'ResnetBlock', 'PushBlock', 'ResnetBlock', 'PushBlock',
'ResnetBlock', 'AttentionBlock', 'ResnetBlock',
'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'UpSampler',
'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock'
]
# controlnet_rename_dict
controlnet_rename_dict = {
"controlnet_cond_embedding.conv_in.weight": "controlnet_conv_in.blocks.0.weight",
"controlnet_cond_embedding.conv_in.bias": "controlnet_conv_in.blocks.0.bias",
"controlnet_cond_embedding.blocks.0.weight": "controlnet_conv_in.blocks.2.weight",
"controlnet_cond_embedding.blocks.0.bias": "controlnet_conv_in.blocks.2.bias",
"controlnet_cond_embedding.blocks.1.weight": "controlnet_conv_in.blocks.4.weight",
"controlnet_cond_embedding.blocks.1.bias": "controlnet_conv_in.blocks.4.bias",
"controlnet_cond_embedding.blocks.2.weight": "controlnet_conv_in.blocks.6.weight",
"controlnet_cond_embedding.blocks.2.bias": "controlnet_conv_in.blocks.6.bias",
"controlnet_cond_embedding.blocks.3.weight": "controlnet_conv_in.blocks.8.weight",
"controlnet_cond_embedding.blocks.3.bias": "controlnet_conv_in.blocks.8.bias",
"controlnet_cond_embedding.blocks.4.weight": "controlnet_conv_in.blocks.10.weight",
"controlnet_cond_embedding.blocks.4.bias": "controlnet_conv_in.blocks.10.bias",
"controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight",
"controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias",
"controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight",
"controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
}
# Rename each parameter
name_list = sorted([name for name in state_dict])
rename_dict = {}
block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
for name in name_list:
names = name.split(".")
if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
pass
elif name in controlnet_rename_dict:
names = controlnet_rename_dict[name].split(".")
elif names[0] == "controlnet_down_blocks":
names[0] = "controlnet_blocks"
elif names[0] == "controlnet_mid_block":
names = ["controlnet_blocks", "12", names[-1]]
elif names[0] in ["time_embedding", "add_embedding"]:
if names[0] == "add_embedding":
names[0] = "add_time_embedding"
names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
if names[0] == "mid_block":
names.insert(1, "0")
block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
block_type_with_id = ".".join(names[:4])
if block_type_with_id != last_block_type_with_id[block_type]:
block_id[block_type] += 1
last_block_type_with_id[block_type] = block_type_with_id
while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
block_id[block_type] += 1
block_type_with_id = ".".join(names[:4])
names = ["blocks", str(block_id[block_type])] + names[4:]
if "ff" in names:
ff_index = names.index("ff")
component = ".".join(names[ff_index:ff_index+3])
component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
names = names[:ff_index] + [component] + names[ff_index+3:]
if "to_out" in names:
names.pop(names.index("to_out") + 1)
else:
raise ValueError(f"Unknown parameters: {name}")
rename_dict[name] = ".".join(names)
# Convert state_dict
state_dict_ = {}
for name, param in state_dict.items():
if ".proj_in." in name or ".proj_out." in name:
param = param.squeeze()
if rename_dict[name] in [
"controlnet_blocks.1.bias", "controlnet_blocks.2.bias", "controlnet_blocks.3.bias", "controlnet_blocks.5.bias", "controlnet_blocks.6.bias",
"controlnet_blocks.8.bias", "controlnet_blocks.9.bias", "controlnet_blocks.10.bias", "controlnet_blocks.11.bias", "controlnet_blocks.12.bias"
]:
continue
state_dict_[rename_dict[name]] = param
return state_dict_
def from_civitai(self, state_dict):
if "mid_block.resnets.1.time_emb_proj.weight" in state_dict:
# For controlnets in diffusers format
return self.from_diffusers(state_dict)
rename_dict = {
"control_model.time_embed.0.weight": "time_embedding.0.weight",
"control_model.time_embed.0.bias": "time_embedding.0.bias",
"control_model.time_embed.2.weight": "time_embedding.2.weight",
"control_model.time_embed.2.bias": "time_embedding.2.bias",
"control_model.input_blocks.0.0.weight": "conv_in.weight",
"control_model.input_blocks.0.0.bias": "conv_in.bias",
"control_model.input_blocks.1.0.in_layers.0.weight": "blocks.0.norm1.weight",
"control_model.input_blocks.1.0.in_layers.0.bias": "blocks.0.norm1.bias",
"control_model.input_blocks.1.0.in_layers.2.weight": "blocks.0.conv1.weight",
"control_model.input_blocks.1.0.in_layers.2.bias": "blocks.0.conv1.bias",
"control_model.input_blocks.1.0.emb_layers.1.weight": "blocks.0.time_emb_proj.weight",
"control_model.input_blocks.1.0.emb_layers.1.bias": "blocks.0.time_emb_proj.bias",
"control_model.input_blocks.1.0.out_layers.0.weight": "blocks.0.norm2.weight",
"control_model.input_blocks.1.0.out_layers.0.bias": "blocks.0.norm2.bias",
"control_model.input_blocks.1.0.out_layers.3.weight": "blocks.0.conv2.weight",
"control_model.input_blocks.1.0.out_layers.3.bias": "blocks.0.conv2.bias",
"control_model.input_blocks.1.1.norm.weight": "blocks.1.norm.weight",
"control_model.input_blocks.1.1.norm.bias": "blocks.1.norm.bias",
"control_model.input_blocks.1.1.proj_in.weight": "blocks.1.proj_in.weight",
"control_model.input_blocks.1.1.proj_in.bias": "blocks.1.proj_in.bias",
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "blocks.1.transformer_blocks.0.attn1.to_q.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "blocks.1.transformer_blocks.0.attn1.to_k.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "blocks.1.transformer_blocks.0.attn1.to_v.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.1.transformer_blocks.0.attn1.to_out.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.1.transformer_blocks.0.attn1.to_out.bias",
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.1.transformer_blocks.0.act_fn.proj.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.1.transformer_blocks.0.act_fn.proj.bias",
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "blocks.1.transformer_blocks.0.ff.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "blocks.1.transformer_blocks.0.ff.bias",
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "blocks.1.transformer_blocks.0.attn2.to_q.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "blocks.1.transformer_blocks.0.attn2.to_k.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "blocks.1.transformer_blocks.0.attn2.to_v.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.1.transformer_blocks.0.attn2.to_out.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.1.transformer_blocks.0.attn2.to_out.bias",
"control_model.input_blocks.1.1.transformer_blocks.0.norm1.weight": "blocks.1.transformer_blocks.0.norm1.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.norm1.bias": "blocks.1.transformer_blocks.0.norm1.bias",
"control_model.input_blocks.1.1.transformer_blocks.0.norm2.weight": "blocks.1.transformer_blocks.0.norm2.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.norm2.bias": "blocks.1.transformer_blocks.0.norm2.bias",
"control_model.input_blocks.1.1.transformer_blocks.0.norm3.weight": "blocks.1.transformer_blocks.0.norm3.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.norm3.bias": "blocks.1.transformer_blocks.0.norm3.bias",
"control_model.input_blocks.1.1.proj_out.weight": "blocks.1.proj_out.weight",
"control_model.input_blocks.1.1.proj_out.bias": "blocks.1.proj_out.bias",
"control_model.input_blocks.2.0.in_layers.0.weight": "blocks.3.norm1.weight",
"control_model.input_blocks.2.0.in_layers.0.bias": "blocks.3.norm1.bias",
"control_model.input_blocks.2.0.in_layers.2.weight": "blocks.3.conv1.weight",
"control_model.input_blocks.2.0.in_layers.2.bias": "blocks.3.conv1.bias",
"control_model.input_blocks.2.0.emb_layers.1.weight": "blocks.3.time_emb_proj.weight",
"control_model.input_blocks.2.0.emb_layers.1.bias": "blocks.3.time_emb_proj.bias",
"control_model.input_blocks.2.0.out_layers.0.weight": "blocks.3.norm2.weight",
"control_model.input_blocks.2.0.out_layers.0.bias": "blocks.3.norm2.bias",
"control_model.input_blocks.2.0.out_layers.3.weight": "blocks.3.conv2.weight",
"control_model.input_blocks.2.0.out_layers.3.bias": "blocks.3.conv2.bias",
"control_model.input_blocks.2.1.norm.weight": "blocks.4.norm.weight",
"control_model.input_blocks.2.1.norm.bias": "blocks.4.norm.bias",
"control_model.input_blocks.2.1.proj_in.weight": "blocks.4.proj_in.weight",
"control_model.input_blocks.2.1.proj_in.bias": "blocks.4.proj_in.bias",
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "blocks.4.transformer_blocks.0.attn1.to_q.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "blocks.4.transformer_blocks.0.attn1.to_k.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "blocks.4.transformer_blocks.0.attn1.to_v.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.4.transformer_blocks.0.attn1.to_out.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.4.transformer_blocks.0.attn1.to_out.bias",
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.4.transformer_blocks.0.act_fn.proj.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.4.transformer_blocks.0.act_fn.proj.bias",
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "blocks.4.transformer_blocks.0.ff.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "blocks.4.transformer_blocks.0.ff.bias",
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "blocks.4.transformer_blocks.0.attn2.to_q.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "blocks.4.transformer_blocks.0.attn2.to_k.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "blocks.4.transformer_blocks.0.attn2.to_v.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.4.transformer_blocks.0.attn2.to_out.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.4.transformer_blocks.0.attn2.to_out.bias",
"control_model.input_blocks.2.1.transformer_blocks.0.norm1.weight": "blocks.4.transformer_blocks.0.norm1.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.norm1.bias": "blocks.4.transformer_blocks.0.norm1.bias",
"control_model.input_blocks.2.1.transformer_blocks.0.norm2.weight": "blocks.4.transformer_blocks.0.norm2.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.norm2.bias": "blocks.4.transformer_blocks.0.norm2.bias",
"control_model.input_blocks.2.1.transformer_blocks.0.norm3.weight": "blocks.4.transformer_blocks.0.norm3.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.norm3.bias": "blocks.4.transformer_blocks.0.norm3.bias",
"control_model.input_blocks.2.1.proj_out.weight": "blocks.4.proj_out.weight",
"control_model.input_blocks.2.1.proj_out.bias": "blocks.4.proj_out.bias",
"control_model.input_blocks.3.0.op.weight": "blocks.6.conv.weight",
"control_model.input_blocks.3.0.op.bias": "blocks.6.conv.bias",
"control_model.input_blocks.4.0.in_layers.0.weight": "blocks.8.norm1.weight",
"control_model.input_blocks.4.0.in_layers.0.bias": "blocks.8.norm1.bias",
"control_model.input_blocks.4.0.in_layers.2.weight": "blocks.8.conv1.weight",
"control_model.input_blocks.4.0.in_layers.2.bias": "blocks.8.conv1.bias",
"control_model.input_blocks.4.0.emb_layers.1.weight": "blocks.8.time_emb_proj.weight",
"control_model.input_blocks.4.0.emb_layers.1.bias": "blocks.8.time_emb_proj.bias",
"control_model.input_blocks.4.0.out_layers.0.weight": "blocks.8.norm2.weight",
"control_model.input_blocks.4.0.out_layers.0.bias": "blocks.8.norm2.bias",
"control_model.input_blocks.4.0.out_layers.3.weight": "blocks.8.conv2.weight",
"control_model.input_blocks.4.0.out_layers.3.bias": "blocks.8.conv2.bias",
"control_model.input_blocks.4.0.skip_connection.weight": "blocks.8.conv_shortcut.weight",
"control_model.input_blocks.4.0.skip_connection.bias": "blocks.8.conv_shortcut.bias",
"control_model.input_blocks.4.1.norm.weight": "blocks.9.norm.weight",
"control_model.input_blocks.4.1.norm.bias": "blocks.9.norm.bias",
"control_model.input_blocks.4.1.proj_in.weight": "blocks.9.proj_in.weight",
"control_model.input_blocks.4.1.proj_in.bias": "blocks.9.proj_in.bias",
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "blocks.9.transformer_blocks.0.attn1.to_q.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "blocks.9.transformer_blocks.0.attn1.to_k.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "blocks.9.transformer_blocks.0.attn1.to_v.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.9.transformer_blocks.0.attn1.to_out.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.9.transformer_blocks.0.attn1.to_out.bias",
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.9.transformer_blocks.0.act_fn.proj.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.9.transformer_blocks.0.act_fn.proj.bias",
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "blocks.9.transformer_blocks.0.ff.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "blocks.9.transformer_blocks.0.ff.bias",
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "blocks.9.transformer_blocks.0.attn2.to_q.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "blocks.9.transformer_blocks.0.attn2.to_k.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "blocks.9.transformer_blocks.0.attn2.to_v.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.9.transformer_blocks.0.attn2.to_out.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.9.transformer_blocks.0.attn2.to_out.bias",
"control_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "blocks.9.transformer_blocks.0.norm1.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "blocks.9.transformer_blocks.0.norm1.bias",
"control_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "blocks.9.transformer_blocks.0.norm2.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "blocks.9.transformer_blocks.0.norm2.bias",
"control_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "blocks.9.transformer_blocks.0.norm3.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "blocks.9.transformer_blocks.0.norm3.bias",
"control_model.input_blocks.4.1.proj_out.weight": "blocks.9.proj_out.weight",
"control_model.input_blocks.4.1.proj_out.bias": "blocks.9.proj_out.bias",
"control_model.input_blocks.5.0.in_layers.0.weight": "blocks.11.norm1.weight",
"control_model.input_blocks.5.0.in_layers.0.bias": "blocks.11.norm1.bias",
"control_model.input_blocks.5.0.in_layers.2.weight": "blocks.11.conv1.weight",
"control_model.input_blocks.5.0.in_layers.2.bias": "blocks.11.conv1.bias",
"control_model.input_blocks.5.0.emb_layers.1.weight": "blocks.11.time_emb_proj.weight",
"control_model.input_blocks.5.0.emb_layers.1.bias": "blocks.11.time_emb_proj.bias",
"control_model.input_blocks.5.0.out_layers.0.weight": "blocks.11.norm2.weight",
"control_model.input_blocks.5.0.out_layers.0.bias": "blocks.11.norm2.bias",
"control_model.input_blocks.5.0.out_layers.3.weight": "blocks.11.conv2.weight",
"control_model.input_blocks.5.0.out_layers.3.bias": "blocks.11.conv2.bias",
"control_model.input_blocks.5.1.norm.weight": "blocks.12.norm.weight",
"control_model.input_blocks.5.1.norm.bias": "blocks.12.norm.bias",
"control_model.input_blocks.5.1.proj_in.weight": "blocks.12.proj_in.weight",
"control_model.input_blocks.5.1.proj_in.bias": "blocks.12.proj_in.bias",
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "blocks.12.transformer_blocks.0.attn1.to_q.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "blocks.12.transformer_blocks.0.attn1.to_k.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "blocks.12.transformer_blocks.0.attn1.to_v.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.12.transformer_blocks.0.attn1.to_out.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.12.transformer_blocks.0.attn1.to_out.bias",
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.12.transformer_blocks.0.act_fn.proj.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.12.transformer_blocks.0.act_fn.proj.bias",
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "blocks.12.transformer_blocks.0.ff.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "blocks.12.transformer_blocks.0.ff.bias",
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "blocks.12.transformer_blocks.0.attn2.to_q.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "blocks.12.transformer_blocks.0.attn2.to_k.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "blocks.12.transformer_blocks.0.attn2.to_v.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.12.transformer_blocks.0.attn2.to_out.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.12.transformer_blocks.0.attn2.to_out.bias",
"control_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "blocks.12.transformer_blocks.0.norm1.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "blocks.12.transformer_blocks.0.norm1.bias",
"control_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "blocks.12.transformer_blocks.0.norm2.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "blocks.12.transformer_blocks.0.norm2.bias",
"control_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "blocks.12.transformer_blocks.0.norm3.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "blocks.12.transformer_blocks.0.norm3.bias",
"control_model.input_blocks.5.1.proj_out.weight": "blocks.12.proj_out.weight",
"control_model.input_blocks.5.1.proj_out.bias": "blocks.12.proj_out.bias",
"control_model.input_blocks.6.0.op.weight": "blocks.14.conv.weight",
"control_model.input_blocks.6.0.op.bias": "blocks.14.conv.bias",
"control_model.input_blocks.7.0.in_layers.0.weight": "blocks.16.norm1.weight",
"control_model.input_blocks.7.0.in_layers.0.bias": "blocks.16.norm1.bias",
"control_model.input_blocks.7.0.in_layers.2.weight": "blocks.16.conv1.weight",
"control_model.input_blocks.7.0.in_layers.2.bias": "blocks.16.conv1.bias",
"control_model.input_blocks.7.0.emb_layers.1.weight": "blocks.16.time_emb_proj.weight",
"control_model.input_blocks.7.0.emb_layers.1.bias": "blocks.16.time_emb_proj.bias",
"control_model.input_blocks.7.0.out_layers.0.weight": "blocks.16.norm2.weight",
"control_model.input_blocks.7.0.out_layers.0.bias": "blocks.16.norm2.bias",
"control_model.input_blocks.7.0.out_layers.3.weight": "blocks.16.conv2.weight",
"control_model.input_blocks.7.0.out_layers.3.bias": "blocks.16.conv2.bias",
"control_model.input_blocks.7.0.skip_connection.weight": "blocks.16.conv_shortcut.weight",
"control_model.input_blocks.7.0.skip_connection.bias": "blocks.16.conv_shortcut.bias",
"control_model.input_blocks.7.1.norm.weight": "blocks.17.norm.weight",
"control_model.input_blocks.7.1.norm.bias": "blocks.17.norm.bias",
"control_model.input_blocks.7.1.proj_in.weight": "blocks.17.proj_in.weight",
"control_model.input_blocks.7.1.proj_in.bias": "blocks.17.proj_in.bias",
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "blocks.17.transformer_blocks.0.attn1.to_q.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "blocks.17.transformer_blocks.0.attn1.to_k.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "blocks.17.transformer_blocks.0.attn1.to_v.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.17.transformer_blocks.0.attn1.to_out.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.17.transformer_blocks.0.attn1.to_out.bias",
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.17.transformer_blocks.0.act_fn.proj.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.17.transformer_blocks.0.act_fn.proj.bias",
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "blocks.17.transformer_blocks.0.ff.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "blocks.17.transformer_blocks.0.ff.bias",
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "blocks.17.transformer_blocks.0.attn2.to_q.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "blocks.17.transformer_blocks.0.attn2.to_k.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "blocks.17.transformer_blocks.0.attn2.to_v.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.17.transformer_blocks.0.attn2.to_out.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.17.transformer_blocks.0.attn2.to_out.bias",
"control_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "blocks.17.transformer_blocks.0.norm1.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "blocks.17.transformer_blocks.0.norm1.bias",
"control_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "blocks.17.transformer_blocks.0.norm2.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "blocks.17.transformer_blocks.0.norm2.bias",
"control_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "blocks.17.transformer_blocks.0.norm3.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "blocks.17.transformer_blocks.0.norm3.bias",
"control_model.input_blocks.7.1.proj_out.weight": "blocks.17.proj_out.weight",
"control_model.input_blocks.7.1.proj_out.bias": "blocks.17.proj_out.bias",
"control_model.input_blocks.8.0.in_layers.0.weight": "blocks.19.norm1.weight",
"control_model.input_blocks.8.0.in_layers.0.bias": "blocks.19.norm1.bias",
"control_model.input_blocks.8.0.in_layers.2.weight": "blocks.19.conv1.weight",
"control_model.input_blocks.8.0.in_layers.2.bias": "blocks.19.conv1.bias",
"control_model.input_blocks.8.0.emb_layers.1.weight": "blocks.19.time_emb_proj.weight",
"control_model.input_blocks.8.0.emb_layers.1.bias": "blocks.19.time_emb_proj.bias",
"control_model.input_blocks.8.0.out_layers.0.weight": "blocks.19.norm2.weight",
"control_model.input_blocks.8.0.out_layers.0.bias": "blocks.19.norm2.bias",
"control_model.input_blocks.8.0.out_layers.3.weight": "blocks.19.conv2.weight",
"control_model.input_blocks.8.0.out_layers.3.bias": "blocks.19.conv2.bias",
"control_model.input_blocks.8.1.norm.weight": "blocks.20.norm.weight",
"control_model.input_blocks.8.1.norm.bias": "blocks.20.norm.bias",
"control_model.input_blocks.8.1.proj_in.weight": "blocks.20.proj_in.weight",
"control_model.input_blocks.8.1.proj_in.bias": "blocks.20.proj_in.bias",
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "blocks.20.transformer_blocks.0.attn1.to_q.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "blocks.20.transformer_blocks.0.attn1.to_k.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "blocks.20.transformer_blocks.0.attn1.to_v.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.20.transformer_blocks.0.attn1.to_out.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.20.transformer_blocks.0.attn1.to_out.bias",
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.20.transformer_blocks.0.act_fn.proj.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.20.transformer_blocks.0.act_fn.proj.bias",
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "blocks.20.transformer_blocks.0.ff.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "blocks.20.transformer_blocks.0.ff.bias",
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "blocks.20.transformer_blocks.0.attn2.to_q.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "blocks.20.transformer_blocks.0.attn2.to_k.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "blocks.20.transformer_blocks.0.attn2.to_v.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.20.transformer_blocks.0.attn2.to_out.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.20.transformer_blocks.0.attn2.to_out.bias",
"control_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "blocks.20.transformer_blocks.0.norm1.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "blocks.20.transformer_blocks.0.norm1.bias",
"control_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "blocks.20.transformer_blocks.0.norm2.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "blocks.20.transformer_blocks.0.norm2.bias",
"control_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "blocks.20.transformer_blocks.0.norm3.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "blocks.20.transformer_blocks.0.norm3.bias",
"control_model.input_blocks.8.1.proj_out.weight": "blocks.20.proj_out.weight",
"control_model.input_blocks.8.1.proj_out.bias": "blocks.20.proj_out.bias",
"control_model.input_blocks.9.0.op.weight": "blocks.22.conv.weight",
"control_model.input_blocks.9.0.op.bias": "blocks.22.conv.bias",
"control_model.input_blocks.10.0.in_layers.0.weight": "blocks.24.norm1.weight",
"control_model.input_blocks.10.0.in_layers.0.bias": "blocks.24.norm1.bias",
"control_model.input_blocks.10.0.in_layers.2.weight": "blocks.24.conv1.weight",
"control_model.input_blocks.10.0.in_layers.2.bias": "blocks.24.conv1.bias",
"control_model.input_blocks.10.0.emb_layers.1.weight": "blocks.24.time_emb_proj.weight",
"control_model.input_blocks.10.0.emb_layers.1.bias": "blocks.24.time_emb_proj.bias",
"control_model.input_blocks.10.0.out_layers.0.weight": "blocks.24.norm2.weight",
"control_model.input_blocks.10.0.out_layers.0.bias": "blocks.24.norm2.bias",
"control_model.input_blocks.10.0.out_layers.3.weight": "blocks.24.conv2.weight",
"control_model.input_blocks.10.0.out_layers.3.bias": "blocks.24.conv2.bias",
"control_model.input_blocks.11.0.in_layers.0.weight": "blocks.26.norm1.weight",
"control_model.input_blocks.11.0.in_layers.0.bias": "blocks.26.norm1.bias",
"control_model.input_blocks.11.0.in_layers.2.weight": "blocks.26.conv1.weight",
"control_model.input_blocks.11.0.in_layers.2.bias": "blocks.26.conv1.bias",
"control_model.input_blocks.11.0.emb_layers.1.weight": "blocks.26.time_emb_proj.weight",
"control_model.input_blocks.11.0.emb_layers.1.bias": "blocks.26.time_emb_proj.bias",
"control_model.input_blocks.11.0.out_layers.0.weight": "blocks.26.norm2.weight",
"control_model.input_blocks.11.0.out_layers.0.bias": "blocks.26.norm2.bias",
"control_model.input_blocks.11.0.out_layers.3.weight": "blocks.26.conv2.weight",
"control_model.input_blocks.11.0.out_layers.3.bias": "blocks.26.conv2.bias",
"control_model.zero_convs.0.0.weight": "controlnet_blocks.0.weight",
"control_model.zero_convs.0.0.bias": "controlnet_blocks.0.bias",
"control_model.zero_convs.1.0.weight": "controlnet_blocks.1.weight",
"control_model.zero_convs.1.0.bias": "controlnet_blocks.0.bias",
"control_model.zero_convs.2.0.weight": "controlnet_blocks.2.weight",
"control_model.zero_convs.2.0.bias": "controlnet_blocks.0.bias",
"control_model.zero_convs.3.0.weight": "controlnet_blocks.3.weight",
"control_model.zero_convs.3.0.bias": "controlnet_blocks.0.bias",
"control_model.zero_convs.4.0.weight": "controlnet_blocks.4.weight",
"control_model.zero_convs.4.0.bias": "controlnet_blocks.4.bias",
"control_model.zero_convs.5.0.weight": "controlnet_blocks.5.weight",
"control_model.zero_convs.5.0.bias": "controlnet_blocks.4.bias",
"control_model.zero_convs.6.0.weight": "controlnet_blocks.6.weight",
"control_model.zero_convs.6.0.bias": "controlnet_blocks.4.bias",
"control_model.zero_convs.7.0.weight": "controlnet_blocks.7.weight",
"control_model.zero_convs.7.0.bias": "controlnet_blocks.7.bias",
"control_model.zero_convs.8.0.weight": "controlnet_blocks.8.weight",
"control_model.zero_convs.8.0.bias": "controlnet_blocks.7.bias",
"control_model.zero_convs.9.0.weight": "controlnet_blocks.9.weight",
"control_model.zero_convs.9.0.bias": "controlnet_blocks.7.bias",
"control_model.zero_convs.10.0.weight": "controlnet_blocks.10.weight",
"control_model.zero_convs.10.0.bias": "controlnet_blocks.7.bias",
"control_model.zero_convs.11.0.weight": "controlnet_blocks.11.weight",
"control_model.zero_convs.11.0.bias": "controlnet_blocks.7.bias",
"control_model.input_hint_block.0.weight": "controlnet_conv_in.blocks.0.weight",
"control_model.input_hint_block.0.bias": "controlnet_conv_in.blocks.0.bias",
"control_model.input_hint_block.2.weight": "controlnet_conv_in.blocks.2.weight",
"control_model.input_hint_block.2.bias": "controlnet_conv_in.blocks.2.bias",
"control_model.input_hint_block.4.weight": "controlnet_conv_in.blocks.4.weight",
"control_model.input_hint_block.4.bias": "controlnet_conv_in.blocks.4.bias",
"control_model.input_hint_block.6.weight": "controlnet_conv_in.blocks.6.weight",
"control_model.input_hint_block.6.bias": "controlnet_conv_in.blocks.6.bias",
"control_model.input_hint_block.8.weight": "controlnet_conv_in.blocks.8.weight",
"control_model.input_hint_block.8.bias": "controlnet_conv_in.blocks.8.bias",
"control_model.input_hint_block.10.weight": "controlnet_conv_in.blocks.10.weight",
"control_model.input_hint_block.10.bias": "controlnet_conv_in.blocks.10.bias",
"control_model.input_hint_block.12.weight": "controlnet_conv_in.blocks.12.weight",
"control_model.input_hint_block.12.bias": "controlnet_conv_in.blocks.12.bias",
"control_model.input_hint_block.14.weight": "controlnet_conv_in.blocks.14.weight",
"control_model.input_hint_block.14.bias": "controlnet_conv_in.blocks.14.bias",
"control_model.middle_block.0.in_layers.0.weight": "blocks.28.norm1.weight",
"control_model.middle_block.0.in_layers.0.bias": "blocks.28.norm1.bias",
"control_model.middle_block.0.in_layers.2.weight": "blocks.28.conv1.weight",
"control_model.middle_block.0.in_layers.2.bias": "blocks.28.conv1.bias",
"control_model.middle_block.0.emb_layers.1.weight": "blocks.28.time_emb_proj.weight",
"control_model.middle_block.0.emb_layers.1.bias": "blocks.28.time_emb_proj.bias",
"control_model.middle_block.0.out_layers.0.weight": "blocks.28.norm2.weight",
"control_model.middle_block.0.out_layers.0.bias": "blocks.28.norm2.bias",
"control_model.middle_block.0.out_layers.3.weight": "blocks.28.conv2.weight",
"control_model.middle_block.0.out_layers.3.bias": "blocks.28.conv2.bias",
"control_model.middle_block.1.norm.weight": "blocks.29.norm.weight",
"control_model.middle_block.1.norm.bias": "blocks.29.norm.bias",
"control_model.middle_block.1.proj_in.weight": "blocks.29.proj_in.weight",
"control_model.middle_block.1.proj_in.bias": "blocks.29.proj_in.bias",
"control_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "blocks.29.transformer_blocks.0.attn1.to_q.weight",
"control_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "blocks.29.transformer_blocks.0.attn1.to_k.weight",
"control_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "blocks.29.transformer_blocks.0.attn1.to_v.weight",
"control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.29.transformer_blocks.0.attn1.to_out.weight",
"control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.29.transformer_blocks.0.attn1.to_out.bias",
"control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.29.transformer_blocks.0.act_fn.proj.weight",
"control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.29.transformer_blocks.0.act_fn.proj.bias",
"control_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "blocks.29.transformer_blocks.0.ff.weight",
"control_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "blocks.29.transformer_blocks.0.ff.bias",
"control_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "blocks.29.transformer_blocks.0.attn2.to_q.weight",
"control_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "blocks.29.transformer_blocks.0.attn2.to_k.weight",
"control_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "blocks.29.transformer_blocks.0.attn2.to_v.weight",
"control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.29.transformer_blocks.0.attn2.to_out.weight",
"control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.29.transformer_blocks.0.attn2.to_out.bias",
"control_model.middle_block.1.transformer_blocks.0.norm1.weight": "blocks.29.transformer_blocks.0.norm1.weight",
"control_model.middle_block.1.transformer_blocks.0.norm1.bias": "blocks.29.transformer_blocks.0.norm1.bias",
"control_model.middle_block.1.transformer_blocks.0.norm2.weight": "blocks.29.transformer_blocks.0.norm2.weight",
"control_model.middle_block.1.transformer_blocks.0.norm2.bias": "blocks.29.transformer_blocks.0.norm2.bias",
"control_model.middle_block.1.transformer_blocks.0.norm3.weight": "blocks.29.transformer_blocks.0.norm3.weight",
"control_model.middle_block.1.transformer_blocks.0.norm3.bias": "blocks.29.transformer_blocks.0.norm3.bias",
"control_model.middle_block.1.proj_out.weight": "blocks.29.proj_out.weight",
"control_model.middle_block.1.proj_out.bias": "blocks.29.proj_out.bias",
"control_model.middle_block.2.in_layers.0.weight": "blocks.30.norm1.weight",
"control_model.middle_block.2.in_layers.0.bias": "blocks.30.norm1.bias",
"control_model.middle_block.2.in_layers.2.weight": "blocks.30.conv1.weight",
"control_model.middle_block.2.in_layers.2.bias": "blocks.30.conv1.bias",
"control_model.middle_block.2.emb_layers.1.weight": "blocks.30.time_emb_proj.weight",
"control_model.middle_block.2.emb_layers.1.bias": "blocks.30.time_emb_proj.bias",
"control_model.middle_block.2.out_layers.0.weight": "blocks.30.norm2.weight",
"control_model.middle_block.2.out_layers.0.bias": "blocks.30.norm2.bias",
"control_model.middle_block.2.out_layers.3.weight": "blocks.30.conv2.weight",
"control_model.middle_block.2.out_layers.3.bias": "blocks.30.conv2.bias",
"control_model.middle_block_out.0.weight": "controlnet_blocks.12.weight",
"control_model.middle_block_out.0.bias": "controlnet_blocks.7.bias",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if ".proj_in." in name or ".proj_out." in name:
param = param.squeeze()
state_dict_[rename_dict[name]] = param
return state_dict_

Some files were not shown because too many files have changed in this diff Show More