Compare commits

..

1 Commits
doc ... ExVideo

Author SHA1 Message Date
Artiprocher
a076adf592 ExVideo for AnimateDiff 2024-07-26 14:35:18 +08:00
242 changed files with 2927 additions and 573937 deletions

View File

@@ -1,29 +0,0 @@
name: release
on:
push:
tags:
- 'v**'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-publish
cancel-in-progress: true
jobs:
build-n-publish:
runs-on: ubuntu-20.04
#if: startsWith(github.event.ref, 'refs/tags')
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.10
uses: actions/setup-python@v2
with:
python-version: '3.10'
- name: Install wheel
run: pip install wheel && pip install -r requirements.txt
- name: Build DiffSynth
run: python setup.py sdist bdist_wheel
- name: Publish package to PyPI
run: |
pip install twine
twine upload dist/* --skip-existing -u __token__ -p ${{ secrets.PYPI_API_TOKEN }}

View File

@@ -0,0 +1,267 @@
import torch, json, os, imageio
from torchvision.transforms import v2
from einops import rearrange
import lightning as pl
from diffsynth import ModelManager, EnhancedDDIMScheduler, SDVideoPipeline, SDUNet, load_state_dict, SDMotionModel
def lets_dance(
unet: SDUNet,
motion_modules: SDMotionModel,
sample,
timestep,
encoder_hidden_states,
use_gradient_checkpointing=False,
):
# 1. ControlNet (skip)
# 2. time
time_emb = unet.time_proj(timestep[None]).to(sample.dtype)
time_emb = unet.time_embedding(time_emb)
# 3. pre-process
hidden_states = unet.conv_in(sample)
text_emb = encoder_hidden_states
res_stack = [hidden_states]
# 4. blocks
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block_id, block in enumerate(unet.blocks):
# 4.1 UNet
if use_gradient_checkpointing:
hidden_states, time_emb, text_emb, res_stack = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, time_emb, text_emb, res_stack,
use_reentrant=False,
)
else:
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
# 4.2 AnimateDiff
if block_id in motion_modules.call_block_id:
motion_module_id = motion_modules.call_block_id[block_id]
if use_gradient_checkpointing:
hidden_states, time_emb, text_emb, res_stack = torch.utils.checkpoint.checkpoint(
create_custom_forward(motion_modules.motion_modules[motion_module_id]),
hidden_states, time_emb, text_emb, res_stack,
use_reentrant=False,
)
else:
hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](hidden_states, time_emb, text_emb, res_stack)
# 5. output
hidden_states = unet.conv_norm_out(hidden_states)
hidden_states = unet.conv_act(hidden_states)
hidden_states = unet.conv_out(hidden_states)
return hidden_states
class TextVideoDataset(torch.utils.data.Dataset):
def __init__(self, base_path, metadata_path, steps_per_epoch=10000, training_shapes=[(128, 1, 128, 512, 512)]):
with open(metadata_path, "r") as f:
metadata = json.load(f)
self.path = [os.path.join(base_path, i["path"]) for i in metadata]
self.text = [i["text"] for i in metadata]
self.steps_per_epoch = steps_per_epoch
self.training_shapes = training_shapes
self.frame_process = []
for max_num_frames, interval, num_frames, height, width in training_shapes:
self.frame_process.append(v2.Compose([
v2.Resize(size=max(height, width), antialias=True),
v2.CenterCrop(size=(height, width)),
v2.Normalize(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5]),
]))
def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
reader = imageio.get_reader(file_path)
if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
reader.close()
return None
frames = []
for frame_id in range(num_frames):
frame = reader.get_data(start_frame_id + frame_id * interval)
frame = torch.tensor(frame, dtype=torch.float32)
frame = rearrange(frame, "H W C -> 1 C H W")
frame = frame_process(frame)
frames.append(frame)
reader.close()
frames = torch.concat(frames, dim=0)
frames = rearrange(frames, "T C H W -> C T H W")
return frames
def load_video(self, file_path, training_shape_id):
data = {}
max_num_frames, interval, num_frames, height, width = self.training_shapes[training_shape_id]
frame_process = self.frame_process[training_shape_id]
start_frame_id = torch.randint(0, max_num_frames - (num_frames - 1) * interval, (1,))[0]
frames = self.load_frames_using_imageio(file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process)
if frames is None:
return None
else:
data[f"frames_{training_shape_id}"] = frames
data[f"start_frame_id_{training_shape_id}"] = start_frame_id
return data
def __getitem__(self, index):
video_data = {}
for training_shape_id in range(len(self.training_shapes)):
while True:
data_id = torch.randint(0, len(self.path), (1,))[0]
data_id = (data_id + index) % len(self.path) # For fixed seed.
text = self.text[data_id]
if isinstance(text, list):
text = text[torch.randint(0, len(text), (1,))[0]]
video_file = self.path[data_id]
try:
data = self.load_video(video_file, training_shape_id)
except:
data = None
if data is not None:
data[f"text_{training_shape_id}"] = text
break
video_data.update(data)
return video_data
def __len__(self):
return self.steps_per_epoch
class LightningModel(pl.LightningModule):
def __init__(self, learning_rate=1e-5, sd_ckpt_path=None):
super().__init__()
# Load models
model_manager = ModelManager(torch_dtype=torch.float16, device="cpu")
model_manager.load_stable_diffusion(load_state_dict(sd_ckpt_path))
# Initialize motion modules
model_manager.model["motion_modules"] = SDMotionModel().to(dtype=self.dtype, device=self.device)
# Build pipeline
self.pipe = SDVideoPipeline.from_model_manager(model_manager)
self.pipe.vae_encoder.eval()
self.pipe.vae_encoder.requires_grad_(False)
self.pipe.vae_decoder.eval()
self.pipe.vae_decoder.requires_grad_(False)
self.pipe.text_encoder.eval()
self.pipe.text_encoder.requires_grad_(False)
self.pipe.unet.eval()
self.pipe.unet.requires_grad_(False)
self.pipe.motion_modules.train()
self.pipe.motion_modules.requires_grad_(True)
# Reset the scheduler
self.pipe.scheduler = EnhancedDDIMScheduler(beta_schedule="scaled_linear")
self.pipe.scheduler.set_timesteps(1000)
# Other parameters
self.learning_rate = learning_rate
def encode_video_with_vae(self, video):
video = video.to(device=self.device, dtype=self.dtype)
video = video.unsqueeze(0)
latents = self.pipe.vae_encoder.encode_video(video, batch_size=16)
latents = rearrange(latents[0], "C T H W -> T C H W")
return latents
def calculate_loss(self, prompt, frames):
with torch.no_grad():
# Call video encoder
latents = self.encode_video_with_vae(frames)
# Call text encoder
prompt_embs = self.pipe.prompter.encode_prompt(self.pipe.text_encoder, prompt, device=self.device, max_length=77)
prompt_embs = prompt_embs.repeat(latents.shape[0], 1, 1)
# Call scheduler
timestep = torch.randint(0, len(self.pipe.scheduler.timesteps), (1,), device=self.device)[0]
noise = torch.randn_like(latents)
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
# Calculate loss
model_pred = lets_dance(
self.pipe.unet, self.pipe.motion_modules,
sample=noisy_latents, encoder_hidden_states=prompt_embs, timestep=timestep
)
loss = torch.nn.functional.mse_loss(model_pred.float(), noise.float(), reduction="mean")
return loss
def training_step(self, batch, batch_idx):
# Loss
frames = batch["frames_0"][0]
prompt = batch["text_0"][0]
loss = self.calculate_loss(prompt, frames)
# Record log
self.log("train_loss", loss, prog_bar=True)
return loss
def configure_optimizers(self):
optimizer = torch.optim.AdamW(self.pipe.motion_modules.parameters(), lr=self.learning_rate)
return optimizer
def on_save_checkpoint(self, checkpoint):
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.motion_modules.named_parameters()))
trainable_param_names = [named_param[0] for named_param in trainable_param_names]
checkpoint["trainable_param_names"] = trainable_param_names
if __name__ == '__main__':
# dataset and data loader
dataset = TextVideoDataset(
"/data/zhongjie/datasets/opensoraplan/data/processed",
"/data/zhongjie/datasets/opensoraplan/data/processed/metadata.json",
training_shapes=[(16, 1, 16, 512, 512)],
steps_per_epoch=7*10000,
)
train_loader = torch.utils.data.DataLoader(
dataset,
shuffle=True,
batch_size=1,
num_workers=4
)
# model
model = LightningModel(
learning_rate=1e-5,
sd_ckpt_path="models/stable_diffusion/v1-5-pruned-emaonly.safetensors",
)
# train
trainer = pl.Trainer(
max_epochs=100000,
accelerator="gpu",
devices="auto",
strategy="deepspeed_stage_1",
precision="16-mixed",
default_root_dir="/data/zhongjie/models/train_extended_animatediff",
accumulate_grad_batches=1,
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)]
)
trainer.fit(
model=model,
train_dataloaders=train_loader,
ckpt_path=None
)

224
README.md
View File

@@ -1,148 +1,92 @@
# DiffSynth Studio
[![PyPI](https://img.shields.io/pypi/v/DiffSynth)](https://pypi.org/project/DiffSynth/)
[![license](https://img.shields.io/github/license/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/blob/master/LICENSE)
[![open issues](https://isitmaintained.com/badge/open/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/issues)
[![GitHub pull-requests](https://img.shields.io/github/issues-pr/modelscope/DiffSynth-Studio.svg)](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
[![GitHub latest commit](https://badgen.net/github/last-commit/modelscope/DiffSynth-Studio)](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
<p align="center">
<a href="https://trendshift.io/repositories/10946" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10946" alt="modelscope%2FDiffSynth-Studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
</p>
## Introduction
DiffSynth Studio is a Diffusion engine. We have restructured architectures including Text Encoder, UNet, VAE, among others, maintaining compatibility with models from the open-source community while enhancing computational performance. We provide many interesting features. Enjoy the magic of Diffusion models!
Until now, DiffSynth Studio has supported the following models:
## Roadmap
* [CogVideo](https://huggingface.co/THUDM/CogVideoX-5b)
* [FLUX](https://huggingface.co/black-forest-labs/FLUX.1-dev)
* [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
* [Kolors](https://huggingface.co/Kwai-Kolors/Kolors)
* [Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
* [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt)
* [Hunyuan-DiT](https://github.com/Tencent/HunyuanDiT)
* [RIFE](https://github.com/hzwer/ECCV2022-RIFE)
* [ESRGAN](https://github.com/xinntao/ESRGAN)
* [Ip-Adapter](https://github.com/tencent-ailab/IP-Adapter)
* [AnimateDiff](https://github.com/guoyww/animatediff/)
* [ControlNet](https://github.com/lllyasviel/ControlNet)
* [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
* Aug 29, 2023. We propose DiffSynth, a video synthesis framework.
* [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/).
* The source codes are released in [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth).
* The technical report (ECML PKDD 2024) is released on [arXiv](https://arxiv.org/abs/2308.03463).
* Oct 1, 2023. We release an early version of this project, namely FastSDXL. A try for building a diffusion engine.
* The source codes are released on [GitHub](https://github.com/Artiprocher/FastSDXL).
* FastSDXL includes a trainable OLSS scheduler for efficiency improvement.
* The original repo of OLSS is [here](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler).
* The technical report (CIKM 2023) is released on [arXiv](https://arxiv.org/abs/2305.14677).
* A demo video is shown on [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj).
* Since OLSS requires additional training, we don't implement it in this project.
* Nov 15, 2023. We propose FastBlend, a powerful video deflickering algorithm.
* The sd-webui extension is released on [GitHub](https://github.com/Artiprocher/sd-webui-fastblend).
* Demo videos are shown on Bilibili, including three tasks.
* [Video deflickering](https://www.bilibili.com/video/BV1d94y1W7PE)
* [Video interpolation](https://www.bilibili.com/video/BV1Lw411m71p)
* [Image-driven video rendering](https://www.bilibili.com/video/BV1RB4y1Z7LF)
* The technical report is released on [arXiv](https://arxiv.org/abs/2311.09265).
* An unofficial ComfyUI extension developed by other users is released on [GitHub](https://github.com/AInseven/ComfyUI-fastblend).
* Dec 8, 2023. We decide to develop a new Project, aiming to release the potential of diffusion models, especially in video synthesis. The development of this project is started.
* Jan 29, 2024. We propose Diffutoon, a fantastic solution for toon shading.
* [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/).
* The source codes are released in this project.
* The technical report (IJCAI 2024) is released on [arXiv](https://arxiv.org/abs/2401.16224).
* June 13, 2024. DiffSynth Studio is transfered to ModelScope. The developers have transitioned from "I" to "we". Of course, I will still participate in development and maintenance.
* June 21, 2024. We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
* [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/).
* Source code is released in this repo. See [`examples/ExVideo`](./examples/ExVideo/).
* Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
* Technical report is released on [arXiv](https://arxiv.org/abs/2406.14130).
* Until now, DiffSynth Studio has supported the following models:
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
## News
- **August 22, 2024.** CogVideoX-5B is supported in this project. See [here](/examples/video_synthesis/). We provide several interesting features for this text-to-video model, including
- Text to video
- Video editing
- Self-upscaling
- Video interpolation
- **August 22, 2024.** We have implemented an interesting painter that supports all text-to-image models. Now you can create stunning images using the painter, with assistance from AI!
- Use it in our [WebUI](#usage-in-webui).
- **August 21, 2024.** FLUX is supported in DiffSynth-Studio.
- Enable CFG and highres-fix to improve visual quality. See [here](/examples/image_synthesis/README.md)
- LoRA, ControlNet, and additional models will be available soon.
- **June 21, 2024.** 🔥🔥🔥 We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
- [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
- Source code is released in this repo. See [`examples/ExVideo`](./examples/ExVideo/).
- Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
- Technical report is released on [arXiv](https://arxiv.org/abs/2406.14130).
- You can try ExVideo in this [Demo](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1)!
- **June 13, 2024.** DiffSynth Studio is transferred to ModelScope. The developers have transitioned from "I" to "we". Of course, I will still participate in development and maintenance.
- **Jan 29, 2024.** We propose Diffutoon, a fantastic solution for toon shading.
- [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
- The source codes are released in this project.
- The technical report (IJCAI 2024) is released on [arXiv](https://arxiv.org/abs/2401.16224).
- **Dec 8, 2023.** We decide to develop a new Project, aiming to release the potential of diffusion models, especially in video synthesis. The development of this project is started.
- **Nov 15, 2023.** We propose FastBlend, a powerful video deflickering algorithm.
- The sd-webui extension is released on [GitHub](https://github.com/Artiprocher/sd-webui-fastblend).
- Demo videos are shown on Bilibili, including three tasks.
- [Video deflickering](https://www.bilibili.com/video/BV1d94y1W7PE)
- [Video interpolation](https://www.bilibili.com/video/BV1Lw411m71p)
- [Image-driven video rendering](https://www.bilibili.com/video/BV1RB4y1Z7LF)
- The technical report is released on [arXiv](https://arxiv.org/abs/2311.09265).
- An unofficial ComfyUI extension developed by other users is released on [GitHub](https://github.com/AInseven/ComfyUI-fastblend).
- **Oct 1, 2023.** We release an early version of this project, namely FastSDXL. A try for building a diffusion engine.
- The source codes are released on [GitHub](https://github.com/Artiprocher/FastSDXL).
- FastSDXL includes a trainable OLSS scheduler for efficiency improvement.
- The original repo of OLSS is [here](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler).
- The technical report (CIKM 2023) is released on [arXiv](https://arxiv.org/abs/2305.14677).
- A demo video is shown on [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj).
- Since OLSS requires additional training, we don't implement it in this project.
- **Aug 29, 2023.** We propose DiffSynth, a video synthesis framework.
- [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/).
- The source codes are released in [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth).
- The technical report (ECML PKDD 2024) is released on [arXiv](https://arxiv.org/abs/2308.03463).
* [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
* [ControlNet](https://github.com/lllyasviel/ControlNet)
* [AnimateDiff](https://github.com/guoyww/animatediff/)
* [Ip-Adapter](https://github.com/tencent-ailab/IP-Adapter)
* [ESRGAN](https://github.com/xinntao/ESRGAN)
* [RIFE](https://github.com/hzwer/ECCV2022-RIFE)
* [Hunyuan-DiT](https://github.com/Tencent/HunyuanDiT)
* [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt)
* [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
## Installation
Install from source code (recommended):
Create Python environment:
```
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
conda env create -f environment.yml
```
Or install from pypi:
We find that sometimes `conda` cannot install `cupy` correctly, please install it manually. See [this document](https://docs.cupy.dev/en/stable/install.html) for more details.
Enter the Python environment:
```
pip install diffsynth
conda activate DiffSynthStudio
```
## Usage (in Python code)
The Python examples are in [`examples`](./examples/). We provide an overview here.
### Download Models
Download the pre-set models. Model IDs can be found in [config file](/diffsynth/configs/model_config.py).
```python
from diffsynth import download_models
download_models(["FLUX.1-dev", "Kolors"])
```
Download your own models.
```python
from diffsynth.models.downloader import download_from_huggingface, download_from_modelscope
# From Modelscope (recommended)
download_from_modelscope("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.fp16.bin", "models/kolors/Kolors/vae")
# From Huggingface
download_from_huggingface("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.fp16.safetensors", "models/kolors/Kolors/vae")
```
### Video Synthesis
#### Text-to-video using CogVideoX-5B
CogVideoX-5B is released by ZhiPu. We provide an improved pipeline, supporting text-to-video, video editing, self-upscaling and video interpolation. [`examples/video_synthesis`](./examples/video_synthesis/)
The video on the left is generated using the original text-to-video pipeline, while the video on the right is the result after editing and frame interpolation.
https://github.com/user-attachments/assets/26b044c1-4a60-44a4-842f-627ff289d006
#### Long Video Synthesis
### Long Video Synthesis
We trained an extended video synthesis model, which can generate 128 frames. [`examples/ExVideo`](./examples/ExVideo/)
https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
### Image Synthesis
#### Toon Shading
Generate high-resolution images, by breaking the limitation of diffusion models! [`examples/image_synthesis`](./examples/image_synthesis/)
|512*512|1024*1024|2048*2048|4096*4096|
|-|-|-|-|
|![512](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/55f679e9-7445-4605-9315-302e93d11370)|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/6fc84611-8da6-4a1f-8fee-9a34eba3b4a5)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/9087a73c-9164-4c58-b2a0-effc694143fb)|![4096](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/edee9e71-fc39-4d1c-9ca9-fa52002c67ac)|
|1024*1024|2048*2048|
|-|-|
|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/67687748-e738-438c-aee5-96096f09ac90)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/584186bc-9855-4140-878e-99541f9a757f)|
### Toon Shading
Render realistic videos in a flatten style and enable video editing features. [`examples/Diffutoon`](./examples/Diffutoon/)
@@ -150,60 +94,32 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-47
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/20528af5-5100-474a-8cdc-440b9efdd86c
#### Video Stylization
### Video Stylization
Video stylization without video models. [`examples/diffsynth`](./examples/diffsynth/)
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
### Image Synthesis
### Chinese Models
Generate high-resolution images, by breaking the limitation of diffusion models! [`examples/image_synthesis`](./examples/image_synthesis/).
Use Hunyuan-DiT to generate images with Chinese prompts. We also support LoRA fine-tuning of this model. [`examples/hunyuan_dit`](./examples/hunyuan_dit/)
LoRA fine-tuning is supported in [`examples/train`](./examples/train/).
Prompt: 少女手捧鲜花,坐在公园的长椅上,夕阳的余晖洒在少女的脸庞,整个画面充满诗意的美感
|FLUX|Stable Diffusion 3|
|1024x1024|2048x2048 (highres-fix)|
|-|-|
|![image_1024_cfg](https://github.com/user-attachments/assets/6af5b106-0673-4e58-9213-cd9157eef4c0)|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/4df346db-6f91-420a-b4c1-26e205376098)|
|![image_1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/2b6528cf-a229-46e9-b7dd-4a9475b07308)|![image_2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/11d264ec-966b-45c9-9804-74b60428b866)|
|Kolors|Hunyuan-DiT|
|-|-|
|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/53ef6f41-da11-4701-8665-9f64392607bf)|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/60b022c8-df3f-4541-95ab-bf39f2fa8bb5)|
Prompt: 一只小狗蹦蹦跳跳,周围是姹紫嫣红的鲜花,远处是山脉
|Stable Diffusion|Stable Diffusion XL|
|Without LoRA|With LoRA|
|-|-|
|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/6fc84611-8da6-4a1f-8fee-9a34eba3b4a5)|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/67687748-e738-438c-aee5-96096f09ac90)|
|![image_without_lora](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/1aa21de5-a992-4b66-b14f-caa44e08876e)|![image_with_lora](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/83a0a41a-691f-4610-8e7b-d8e17c50a282)|
## Usage (in WebUI)
Create stunning images using the painter, with assistance from AI!
https://github.com/user-attachments/assets/95265d21-cdd6-4125-a7cb-9fbcf6ceb7b0
**This video is not rendered in real-time.**
Before launching the WebUI, please download models to the folder `./models`. See [here](#download-models).
* `Gradio` version
```
pip install gradio
```
```
python apps/gradio/DiffSynth_Studio.py
```
![20240822102002](https://github.com/user-attachments/assets/59613157-de51-4109-99b3-97cbffd88076)
* `Streamlit` version
```
pip install streamlit streamlit-drawable-canvas
```
```
python -m streamlit run apps/streamlit/DiffSynth_Studio.py
python -m streamlit run DiffSynth_Studio.py
```
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/93085557-73f3-4eee-a205-9829591ef954

View File

@@ -1,252 +0,0 @@
import gradio as gr
from diffsynth import ModelManager, SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline, FluxImagePipeline
import os, torch
from PIL import Image
import numpy as np
config = {
"model_config": {
"Stable Diffusion": {
"model_folder": "models/stable_diffusion",
"pipeline_class": SDImagePipeline,
"default_parameters": {
"cfg_scale": 7.0,
"height": 512,
"width": 512,
}
},
"Stable Diffusion XL": {
"model_folder": "models/stable_diffusion_xl",
"pipeline_class": SDXLImagePipeline,
"default_parameters": {
"cfg_scale": 7.0,
}
},
"Stable Diffusion 3": {
"model_folder": "models/stable_diffusion_3",
"pipeline_class": SD3ImagePipeline,
"default_parameters": {
"cfg_scale": 7.0,
}
},
"Stable Diffusion XL Turbo": {
"model_folder": "models/stable_diffusion_xl_turbo",
"pipeline_class": SDXLImagePipeline,
"default_parameters": {
"negative_prompt": "",
"cfg_scale": 1.0,
"num_inference_steps": 1,
"height": 512,
"width": 512,
}
},
"Kolors": {
"model_folder": "models/kolors",
"pipeline_class": SDXLImagePipeline,
"default_parameters": {
"cfg_scale": 7.0,
}
},
"HunyuanDiT": {
"model_folder": "models/HunyuanDiT",
"pipeline_class": HunyuanDiTImagePipeline,
"default_parameters": {
"cfg_scale": 7.0,
}
},
"FLUX": {
"model_folder": "models/FLUX",
"pipeline_class": FluxImagePipeline,
"default_parameters": {
"cfg_scale": 1.0,
}
}
},
"max_num_painter_layers": 8,
"max_num_model_cache": 1,
}
def load_model_list(model_type):
if model_type is None:
return []
folder = config["model_config"][model_type]["model_folder"]
file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
file_list = sorted(file_list)
return file_list
def load_model(model_type, model_path):
global model_dict
model_key = f"{model_type}:{model_path}"
if model_key in model_dict:
return model_dict[model_key]
model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
model_manager = ModelManager()
if model_type == "HunyuanDiT":
model_manager.load_models([
os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
os.path.join(model_path, "mt5/pytorch_model.bin"),
os.path.join(model_path, "model/pytorch_model_ema.pt"),
os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
])
elif model_type == "Kolors":
model_manager.load_models([
os.path.join(model_path, "text_encoder"),
os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
])
elif model_type == "FLUX":
model_manager.torch_dtype = torch.bfloat16
file_list = [
os.path.join(model_path, "text_encoder/model.safetensors"),
os.path.join(model_path, "text_encoder_2"),
]
for file_name in os.listdir(model_path):
if file_name.endswith(".safetensors"):
file_list.append(os.path.join(model_path, file_name))
model_manager.load_models(file_list)
else:
model_manager.load_model(model_path)
pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
while len(model_dict) + 1 > config["max_num_model_cache"]:
key = next(iter(model_dict.keys()))
model_manager_to_release, _ = model_dict[key]
model_manager_to_release.to("cpu")
del model_dict[key]
torch.cuda.empty_cache()
model_dict[model_key] = model_manager, pipe
return model_manager, pipe
model_dict = {}
with gr.Blocks() as app:
gr.Markdown("# DiffSynth-Studio Painter")
with gr.Row():
with gr.Column(scale=382, min_width=100):
with gr.Accordion(label="Model"):
model_type = gr.Dropdown(choices=[i for i in config["model_config"]], label="Model type")
model_path = gr.Dropdown(choices=[], interactive=True, label="Model path")
@gr.on(inputs=model_type, outputs=model_path, triggers=model_type.change)
def model_type_to_model_path(model_type):
return gr.Dropdown(choices=load_model_list(model_type))
with gr.Accordion(label="Prompt"):
prompt = gr.Textbox(label="Prompt", lines=3)
negative_prompt = gr.Textbox(label="Negative prompt", lines=1)
cfg_scale = gr.Slider(minimum=1.0, maximum=10.0, value=7.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=0.0, step=0.1, interactive=True, label="Embedded guidance scale (only for FLUX)")
with gr.Accordion(label="Image"):
num_inference_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, interactive=True, label="Inference steps")
height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
with gr.Column():
use_fixed_seed = gr.Checkbox(value=True, interactive=False, label="Use fixed seed")
seed = gr.Number(minimum=0, maximum=10**9, value=0, interactive=True, label="Random seed", show_label=False)
@gr.on(
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
outputs=[prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
triggers=model_path.change
)
def model_path_to_default_params(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width):
load_model(model_type, model_path)
cfg_scale = config["model_config"][model_type]["default_parameters"].get("cfg_scale", cfg_scale)
embedded_guidance = config["model_config"][model_type]["default_parameters"].get("embedded_guidance", embedded_guidance)
num_inference_steps = config["model_config"][model_type]["default_parameters"].get("num_inference_steps", num_inference_steps)
height = config["model_config"][model_type]["default_parameters"].get("height", height)
width = config["model_config"][model_type]["default_parameters"].get("width", width)
return prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width
with gr.Column(scale=618, min_width=100):
with gr.Accordion(label="Painter"):
enable_local_prompt_list = []
local_prompt_list = []
mask_scale_list = []
canvas_list = []
for painter_layer_id in range(config["max_num_painter_layers"]):
with gr.Tab(label=f"Layer {painter_layer_id}"):
enable_local_prompt = gr.Checkbox(label="Enable", value=False, key=f"enable_local_prompt_{painter_layer_id}")
local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
mask_scale = gr.Slider(minimum=0.0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Mask scale", key=f"mask_scale_{painter_layer_id}")
canvas = gr.ImageEditor(canvas_size=(512, 1), sources=None, layers=False, interactive=True, image_mode="RGBA",
brush=gr.Brush(default_size=100, default_color="#000000", colors=["#000000"]),
label="Painter", key=f"canvas_{painter_layer_id}")
@gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear, enable_local_prompt.change], show_progress="hidden")
def resize_canvas(height, width, canvas):
h, w = canvas["background"].shape[:2]
if h != height or width != w:
return np.ones((height, width, 3), dtype=np.uint8) * 255
else:
return canvas
enable_local_prompt_list.append(enable_local_prompt)
local_prompt_list.append(local_prompt)
mask_scale_list.append(mask_scale)
canvas_list.append(canvas)
with gr.Accordion(label="Results"):
run_button = gr.Button(value="Generate", variant="primary")
output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
with gr.Row():
with gr.Column():
output_to_painter_button = gr.Button(value="Set as painter's background")
with gr.Column():
output_to_input_button = gr.Button(value="Set as input image")
painter_background = gr.State(None)
input_background = gr.State(None)
@gr.on(
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed] + enable_local_prompt_list + local_prompt_list + mask_scale_list + canvas_list,
outputs=[output_image],
triggers=run_button.click
)
def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed, *args, progress=gr.Progress()):
_, pipe = load_model(model_type, model_path)
input_params = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"cfg_scale": cfg_scale,
"num_inference_steps": num_inference_steps,
"height": height,
"width": width,
"progress_bar_cmd": progress.tqdm,
}
if isinstance(pipe, FluxImagePipeline):
input_params["embedded_guidance"] = embedded_guidance
enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list = (
args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
args[2 * config["max_num_painter_layers"]: 3 * config["max_num_painter_layers"]],
args[3 * config["max_num_painter_layers"]: 4 * config["max_num_painter_layers"]]
)
local_prompts, masks, mask_scales = [], [], []
for enable_local_prompt, local_prompt, mask_scale, canvas in zip(
enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list
):
if enable_local_prompt:
local_prompts.append(local_prompt)
masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
mask_scales.append(mask_scale)
input_params.update({
"local_prompts": local_prompts,
"masks": masks,
"mask_scales": mask_scales,
})
torch.manual_seed(seed)
image = pipe(**input_params)
return image
@gr.on(inputs=[output_image] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
def send_output_to_painter_background(output_image, *canvas_list):
for canvas in canvas_list:
h, w = canvas["background"].shape[:2]
canvas["background"] = output_image.resize((w, h))
return tuple(canvas_list)
app.launch()

View File

@@ -1,6 +1,6 @@
from .data import *
from .models import *
from .prompters import *
from .prompts import *
from .schedulers import *
from .pipelines import *
from .controlnets import *

View File

@@ -1,358 +0,0 @@
from typing_extensions import Literal, TypeAlias
from ..models.sd_text_encoder import SDTextEncoder
from ..models.sd_unet import SDUNet
from ..models.sd_vae_encoder import SDVAEEncoder
from ..models.sd_vae_decoder import SDVAEDecoder
from ..models.sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from ..models.sdxl_unet import SDXLUNet
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
from ..models.sdxl_vae_encoder import SDXLVAEEncoder
from ..models.sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
from ..models.sd3_dit import SD3DiT
from ..models.sd3_vae_decoder import SD3VAEDecoder
from ..models.sd3_vae_encoder import SD3VAEEncoder
from ..models.sd_controlnet import SDControlNet
from ..models.sdxl_controlnet import SDXLControlNetUnion
from ..models.sd_motion import SDMotionModel
from ..models.sdxl_motion import SDXLMotionModel
from ..models.svd_image_encoder import SVDImageEncoder
from ..models.svd_unet import SVDUNet
from ..models.svd_vae_decoder import SVDVAEDecoder
from ..models.svd_vae_encoder import SVDVAEEncoder
from ..models.sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
from ..models.sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from ..models.hunyuan_dit import HunyuanDiT
from ..models.flux_dit import FluxDiT
from ..models.flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
from ..models.cog_dit import CogDiT
from ..extensions.RIFE import IFNet
from ..extensions.ESRGAN import RRDBNet
model_loader_configs = [
# These configs are provided for detecting model type automatically.
# The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
(None, "091b0e30e77c76626b3ba62acdf95343", ["sd_controlnet"], [SDControlNet], "civitai"),
(None, "4a6c8306a27d916dea81263c8c88f450", ["hunyuan_dit_clip_text_encoder"], [HunyuanDiTCLIPTextEncoder], "civitai"),
(None, "f4aec400fe394297961218c768004521", ["hunyuan_dit"], [HunyuanDiT], "civitai"),
(None, "9e6e58043a5a2e332803ed42f6ee7181", ["hunyuan_dit_t5_text_encoder"], [HunyuanDiTT5TextEncoder], "civitai"),
(None, "13115dd45a6e1c39860f91ab073b8a78", ["sdxl_vae_encoder", "sdxl_vae_decoder"], [SDXLVAEEncoder, SDXLVAEDecoder], "diffusers"),
(None, "d78aa6797382a6d455362358a3295ea9", ["sd_ipadapter_clip_image_encoder"], [IpAdapterCLIPImageEmbedder], "diffusers"),
(None, "e291636cc15e803186b47404262ef812", ["sd_ipadapter"], [SDIpAdapter], "civitai"),
(None, "399c81f2f8de8d1843d0127a00f3c224", ["sdxl_ipadapter_clip_image_encoder"], [IpAdapterXLCLIPImageEmbedder], "diffusers"),
(None, "a64eac9aa0db4b9602213bc0131281c7", ["sdxl_ipadapter"], [SDXLIpAdapter], "civitai"),
(None, "52817e4fdd89df154f02749ca6f692ac", ["sdxl_unet"], [SDXLUNet], "diffusers"),
(None, "03343c606f16d834d6411d0902b53636", ["sd_text_encoder", "sd_unet", "sd_vae_decoder", "sd_vae_encoder"], [SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder], "civitai"),
(None, "d4ba77a7ece070679b4a987f58f201e9", ["sd_text_encoder"], [SDTextEncoder], "civitai"),
(None, "d0c89e55c5a57cf3981def0cb1c9e65a", ["sd_vae_decoder", "sd_vae_encoder"], [SDVAEDecoder, SDVAEEncoder], "civitai"),
(None, "3926bf373b39a67eeafd7901478a47a7", ["sd_unet"], [SDUNet], "civitai"),
(None, "1e0c39ec176b9007c05f76d52b554a4d", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
(None, "d9e0290829ba8d98e28e1a2b1407db4a", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
(None, "5072d0b24e406b49507abe861cf97691", ["sd3_text_encoder_3"], [SD3TextEncoder3], "civitai"),
(None, "4cf64a799d04260df438c6f33c9a047e", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"),
(None, "d9b008a867c498ab12ad24042eff8e3f", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"), # SDXL-Turbo
(None, "025bb7452e531a3853d951d77c63f032", ["sdxl_text_encoder", "sdxl_text_encoder_2"], [SDXLTextEncoder, SDXLTextEncoder2], "civitai"),
(None, "298997b403a4245c04102c9f36aac348", ["sdxl_unet"], [SDXLUNet], "civitai"),
(None, "2a07abce74b4bdc696b76254ab474da6", ["svd_image_encoder", "svd_unet", "svd_vae_decoder", "svd_vae_encoder"], [SVDImageEncoder, SVDUNet, SVDVAEDecoder, SVDVAEEncoder], "civitai"),
(None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
(None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
(None, "31d2d9614fba60511fc9bf2604aa01f7", ["sdxl_controlnet"], [SDXLControlNetUnion], "diffusers"),
(None, "94eefa3dac9cec93cb1ebaf1747d7b78", ["flux_text_encoder_1"], [FluxTextEncoder1], "diffusers"),
(None, "1aafa3cc91716fb6b300cc1cd51b85a3", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "diffusers"),
(None, "21ea55f476dfc4fd135587abb59dfe5d", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "civitai"),
(None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai"),
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
]
huggingface_model_loader_configs = [
# These configs are provided for detecting model type automatically.
# The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
("Qwen2ForCausalLM", "transformers.models.qwen2.modeling_qwen2", "qwen_prompt", None),
("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
]
patch_model_loader_configs = [
# These configs are provided for detecting model type automatically.
# The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
("9a4ab6869ac9b7d6e31f9854e397c867", ["svd_unet"], [SVDUNet], {"add_positional_conv": 128}),
]
preset_models_on_huggingface = {
"HunyuanDiT": [
("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
],
"stable-video-diffusion-img2vid-xt": [
("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
],
"ExVideo-SVD-128f-v1": [
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
],
# Kolors
"Kolors": [
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
],
# FLUX
"FLUX.1-dev": [
("black-forest-labs/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
("black-forest-labs/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
],
}
preset_models_on_modelscope = {
# Hunyuan DiT
"HunyuanDiT": [
("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
],
# Stable Video Diffusion
"stable-video-diffusion-img2vid-xt": [
("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
],
# ExVideo
"ExVideo-SVD-128f-v1": [
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
],
# Stable Diffusion
"StableDiffusion_v15": [
("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
],
"DreamShaper_8": [
("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
],
"AingDiffusion_v12": [
("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
],
"Flat2DAnimerge_v45Sharp": [
("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
],
# Textual Inversion
"TextualInversion_VeryBadImageNegative_v1.3": [
("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
],
# Stable Diffusion XL
"StableDiffusionXL_v1": [
("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
],
"BluePencilXL_v200": [
("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
],
"StableDiffusionXL_Turbo": [
("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
],
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0": [
("sd_lora/zyd232_ChineseInkStyle_SDXL_v1_0", "zyd232_ChineseInkStyle_SDXL_v1_0.safetensors", "models/lora"),
],
# Stable Diffusion 3
"StableDiffusion3": [
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
],
"StableDiffusion3_without_T5": [
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
],
# ControlNet
"ControlNet_v11f1p_sd15_depth": [
("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
],
"ControlNet_v11p_sd15_softedge": [
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
],
"ControlNet_v11f1e_sd15_tile": [
("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
],
"ControlNet_v11p_sd15_lineart": [
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
],
"ControlNet_union_sdxl_promax": [
("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
],
# AnimateDiff
"AnimateDiff_v2": [
("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
],
"AnimateDiff_xl_beta": [
("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
],
# RIFE
"RIFE": [
("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
],
# Qwen Prompt
"QwenPrompt": [
("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
],
# Beautiful Prompt
"BeautifulPrompt": [
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
],
# Omost prompt
"OmostPrompt":[
("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
],
# Translator
"opus-mt-zh-en": [
("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
],
# IP-Adapter
"IP-Adapter-SD": [
("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
],
"IP-Adapter-SDXL": [
("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
],
# Kolors
"Kolors": [
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
],
"SDXL-vae-fp16-fix": [
("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
],
# FLUX
"FLUX.1-dev": [
("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
],
# ESRGAN
"ESRGAN_x4": [
("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
],
# RIFE
"RIFE": [
("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
],
# CogVideo
"CogVideoX-5B": [
("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
],
}
Preset_model_id: TypeAlias = Literal[
"HunyuanDiT",
"stable-video-diffusion-img2vid-xt",
"ExVideo-SVD-128f-v1",
"StableDiffusion_v15",
"DreamShaper_8",
"AingDiffusion_v12",
"Flat2DAnimerge_v45Sharp",
"TextualInversion_VeryBadImageNegative_v1.3",
"StableDiffusionXL_v1",
"BluePencilXL_v200",
"StableDiffusionXL_Turbo",
"ControlNet_v11f1p_sd15_depth",
"ControlNet_v11p_sd15_softedge",
"ControlNet_v11f1e_sd15_tile",
"ControlNet_v11p_sd15_lineart",
"AnimateDiff_v2",
"AnimateDiff_xl_beta",
"RIFE",
"BeautifulPrompt",
"opus-mt-zh-en",
"IP-Adapter-SD",
"IP-Adapter-SDXL",
"StableDiffusion3",
"StableDiffusion3_without_T5",
"Kolors",
"SDXL-vae-fp16-fix",
"ControlNet_union_sdxl_promax",
"FLUX.1-dev",
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
"QwenPrompt",
"OmostPrompt",
"ESRGAN_x4",
"RIFE",
"CogVideoX-5B",
]

View File

@@ -23,14 +23,6 @@ class MultiControlNetManager:
self.models = [unit.model for unit in controlnet_units]
self.scales = [unit.scale for unit in controlnet_units]
def cpu(self):
for model in self.models:
model.cpu()
def to(self, device):
for model in self.models:
model.to(device)
def process_image(self, image, processor_id=None):
if processor_id is None:
processed_image = [processor(image) for processor in self.processors]
@@ -45,14 +37,13 @@ class MultiControlNetManager:
def __call__(
self,
sample, timestep, encoder_hidden_states, conditionings,
tiled=False, tile_size=64, tile_stride=32, **kwargs
tiled=False, tile_size=64, tile_stride=32
):
res_stack = None
for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
for conditioning, model, scale in zip(conditionings, self.models, self.scales):
res_stack_ = model(
sample, timestep, encoder_hidden_states, conditioning, **kwargs,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
processor_id=processor.processor_id
sample, timestep, encoder_hidden_states, conditioning,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
)
res_stack_ = [res * scale for res in res_stack_]
if res_stack is None:

View File

@@ -12,19 +12,19 @@ Processor_id: TypeAlias = Literal[
]
class Annotator:
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda'):
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None):
if processor_id == "canny":
self.processor = CannyDetector()
elif processor_id == "depth":
self.processor = MidasDetector.from_pretrained(model_path).to(device)
self.processor = MidasDetector.from_pretrained(model_path).to("cuda")
elif processor_id == "softedge":
self.processor = HEDdetector.from_pretrained(model_path).to(device)
self.processor = HEDdetector.from_pretrained(model_path).to("cuda")
elif processor_id == "lineart":
self.processor = LineartDetector.from_pretrained(model_path).to(device)
self.processor = LineartDetector.from_pretrained(model_path).to("cuda")
elif processor_id == "lineart_anime":
self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
self.processor = LineartAnimeDetector.from_pretrained(model_path).to("cuda")
elif processor_id == "openpose":
self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
self.processor = OpenposeDetector.from_pretrained(model_path).to("cuda")
elif processor_id == "tile":
self.processor = None
else:

View File

@@ -1,35 +0,0 @@
import torch, os
from torchvision import transforms
import pandas as pd
from PIL import Image
class TextImageDataset(torch.utils.data.Dataset):
def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
self.steps_per_epoch = steps_per_epoch
metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
self.text = metadata["text"].to_list()
self.image_processor = transforms.Compose(
[
transforms.Resize(max(height, width), interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def __getitem__(self, index):
data_id = torch.randint(0, len(self.path), (1,))[0]
data_id = (data_id + index) % len(self.path) # For fixed seed.
text = self.text[data_id]
image = Image.open(self.path[data_id]).convert("RGB")
image = self.image_processor(image)
return {"text": text, "image": image}
def __len__(self):
return self.steps_per_epoch

View File

@@ -41,7 +41,7 @@ class RRDB(torch.nn.Module):
class RRDBNet(torch.nn.Module):
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, **kwargs):
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32):
super(RRDBNet, self).__init__()
self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
self.body = torch.torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
@@ -66,21 +66,6 @@ class RRDBNet(torch.nn.Module):
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
return out
@staticmethod
def state_dict_converter():
return RRDBNetStateDictConverter()
class RRDBNetStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict, {"upcast_to_float32": True}
def from_civitai(self, state_dict):
return state_dict, {"upcast_to_float32": True}
class ESRGAN(torch.nn.Module):
def __init__(self, model):
@@ -88,8 +73,12 @@ class ESRGAN(torch.nn.Module):
self.model = model
@staticmethod
def from_model_manager(model_manager):
return ESRGAN(model_manager.fetch_model("esrgan"))
def from_pretrained(model_path):
model = RRDBNet()
state_dict = torch.load(model_path, map_location="cpu")["params_ema"]
model.load_state_dict(state_dict)
model.eval()
return ESRGAN(model)
def process_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)

View File

@@ -58,7 +58,7 @@ class IFBlock(nn.Module):
class IFNet(nn.Module):
def __init__(self, **kwargs):
def __init__(self):
super(IFNet, self).__init__()
self.block0 = IFBlock(7+4, c=90)
self.block1 = IFBlock(7+4, c=90)
@@ -99,8 +99,7 @@ class IFNet(nn.Module):
merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
return flow_list, mask_list[2], merged
@staticmethod
def state_dict_converter():
def state_dict_converter(self):
return IFNetStateDictConverter()
@@ -113,7 +112,7 @@ class IFNetStateDictConverter:
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict), {"upcast_to_float32": True}
return self.from_diffusers(state_dict)
class RIFEInterpolater:
@@ -125,7 +124,7 @@ class RIFEInterpolater:
@staticmethod
def from_model_manager(model_manager):
return RIFEInterpolater(model_manager.fetch_model("rife"), device=model_manager.device)
return RIFEInterpolater(model_manager.RIFE, device=model_manager.device)
def process_image(self, image):
width, height = image.size
@@ -203,7 +202,7 @@ class RIFESmoother(RIFEInterpolater):
@staticmethod
def from_model_manager(model_manager):
return RIFEInterpolater(model_manager.fetch_model("rife"), device=model_manager.device)
return RIFESmoother(model_manager.RIFE, device=model_manager.device)
def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
output_tensor = []

View File

@@ -1 +1,482 @@
from .model_manager import *
import torch, os
from safetensors import safe_open
from .sd_text_encoder import SDTextEncoder
from .sd_unet import SDUNet
from .sd_vae_encoder import SDVAEEncoder
from .sd_vae_decoder import SDVAEDecoder
from .sd_lora import SDLoRA
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from .sdxl_unet import SDXLUNet
from .sdxl_vae_decoder import SDXLVAEDecoder
from .sdxl_vae_encoder import SDXLVAEEncoder
from .sd_controlnet import SDControlNet
from .sd_motion import SDMotionModel
from .sdxl_motion import SDXLMotionModel
from .svd_image_encoder import SVDImageEncoder
from .svd_unet import SVDUNet
from .svd_vae_decoder import SVDVAEDecoder
from .svd_vae_encoder import SVDVAEEncoder
from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from .hunyuan_dit import HunyuanDiT
class ModelManager:
def __init__(self, torch_dtype=torch.float16, device="cuda"):
self.torch_dtype = torch_dtype
self.device = device
self.model = {}
self.model_path = {}
self.textual_inversion_dict = {}
def is_stable_video_diffusion(self, state_dict):
param_name = "model.diffusion_model.output_blocks.9.1.time_stack.0.norm_in.weight"
return param_name in state_dict
def is_RIFE(self, state_dict):
param_name = "block_tea.convblock3.0.1.weight"
return param_name in state_dict or ("module." + param_name) in state_dict
def is_beautiful_prompt(self, state_dict):
param_name = "transformer.h.9.self_attention.query_key_value.weight"
return param_name in state_dict
def is_stabe_diffusion_xl(self, state_dict):
param_name = "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight"
return param_name in state_dict
def is_stable_diffusion(self, state_dict):
if self.is_stabe_diffusion_xl(state_dict):
return False
param_name = "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight"
return param_name in state_dict
def is_controlnet(self, state_dict):
param_name = "control_model.time_embed.0.weight"
param_name_2 = "mid_block.resnets.1.time_emb_proj.weight" # For controlnets in diffusers format
return param_name in state_dict or param_name_2 in state_dict
def is_animatediff(self, state_dict):
param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
return param_name in state_dict
def is_animatediff_xl(self, state_dict):
param_name = "up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.ff_norm.weight"
return param_name in state_dict
def is_sd_lora(self, state_dict):
param_name = "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_ff_net_2.lora_up.weight"
return param_name in state_dict
def is_translator(self, state_dict):
param_name = "model.encoder.layers.5.self_attn_layer_norm.weight"
return param_name in state_dict and len(state_dict) == 254
def is_ipadapter(self, state_dict):
return "image_proj" in state_dict and "ip_adapter" in state_dict and state_dict["image_proj"]["proj.weight"].shape == torch.Size([3072, 1024])
def is_ipadapter_image_encoder(self, state_dict):
param_name = "vision_model.encoder.layers.31.self_attn.v_proj.weight"
return param_name in state_dict and len(state_dict) == 521
def is_ipadapter_xl(self, state_dict):
return "image_proj" in state_dict and "ip_adapter" in state_dict and state_dict["image_proj"]["proj.weight"].shape == torch.Size([8192, 1280])
def is_ipadapter_xl_image_encoder(self, state_dict):
param_name = "vision_model.encoder.layers.47.self_attn.v_proj.weight"
return param_name in state_dict and len(state_dict) == 777
def is_hunyuan_dit_clip_text_encoder(self, state_dict):
param_name = "bert.encoder.layer.23.attention.output.dense.weight"
return param_name in state_dict
def is_hunyuan_dit_t5_text_encoder(self, state_dict):
param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
return param_name in state_dict
def is_hunyuan_dit(self, state_dict):
param_name = "final_layer.adaLN_modulation.1.weight"
return param_name in state_dict
def is_diffusers_vae(self, state_dict):
param_name = "quant_conv.weight"
return param_name in state_dict
def is_ExVideo_StableVideoDiffusion(self, state_dict):
param_name = "blocks.185.positional_embedding.embeddings"
return param_name in state_dict
def load_stable_video_diffusion(self, state_dict, components=None, file_path="", add_positional_conv=None):
component_dict = {
"image_encoder": SVDImageEncoder,
"unet": SVDUNet,
"vae_decoder": SVDVAEDecoder,
"vae_encoder": SVDVAEEncoder,
}
if components is None:
components = ["image_encoder", "unet", "vae_decoder", "vae_encoder"]
for component in components:
if component == "unet":
self.model[component] = component_dict[component](add_positional_conv=add_positional_conv)
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict, add_positional_conv=add_positional_conv), strict=False)
else:
self.model[component] = component_dict[component]()
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
self.model[component].to(self.torch_dtype).to(self.device)
self.model_path[component] = file_path
def load_stable_diffusion(self, state_dict, components=None, file_path=""):
component_dict = {
"text_encoder": SDTextEncoder,
"unet": SDUNet,
"vae_decoder": SDVAEDecoder,
"vae_encoder": SDVAEEncoder,
"refiner": SDXLUNet,
}
if components is None:
components = ["text_encoder", "unet", "vae_decoder", "vae_encoder"]
for component in components:
if component == "text_encoder":
# Add additional token embeddings to text encoder
token_embeddings = [state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"]]
for keyword in self.textual_inversion_dict:
_, embeddings = self.textual_inversion_dict[keyword]
token_embeddings.append(embeddings.to(dtype=token_embeddings[0].dtype))
token_embeddings = torch.concat(token_embeddings, dim=0)
state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"] = token_embeddings
self.model[component] = component_dict[component](vocab_size=token_embeddings.shape[0])
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
self.model[component].to(self.torch_dtype).to(self.device)
else:
self.model[component] = component_dict[component]()
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
self.model[component].to(self.torch_dtype).to(self.device)
self.model_path[component] = file_path
def load_stable_diffusion_xl(self, state_dict, components=None, file_path=""):
component_dict = {
"text_encoder": SDXLTextEncoder,
"text_encoder_2": SDXLTextEncoder2,
"unet": SDXLUNet,
"vae_decoder": SDXLVAEDecoder,
"vae_encoder": SDXLVAEEncoder,
}
if components is None:
components = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder"]
for component in components:
self.model[component] = component_dict[component]()
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
if component in ["vae_decoder", "vae_encoder"]:
# These two model will output nan when float16 is enabled.
# The precision problem happens in the last three resnet blocks.
# I do not know how to solve this problem.
self.model[component].to(torch.float32).to(self.device)
else:
self.model[component].to(self.torch_dtype).to(self.device)
self.model_path[component] = file_path
def load_controlnet(self, state_dict, file_path=""):
component = "controlnet"
if component not in self.model:
self.model[component] = []
self.model_path[component] = []
model = SDControlNet()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component].append(model)
self.model_path[component].append(file_path)
def load_animatediff(self, state_dict, file_path="", add_positional_conv=None):
component = "motion_modules"
model = SDMotionModel(add_positional_conv=add_positional_conv)
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict, add_positional_conv=add_positional_conv))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_animatediff_xl(self, state_dict, file_path=""):
component = "motion_modules_xl"
model = SDXLMotionModel()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_beautiful_prompt(self, state_dict, file_path=""):
component = "beautiful_prompt"
from transformers import AutoModelForCausalLM
model_folder = os.path.dirname(file_path)
model = AutoModelForCausalLM.from_pretrained(
model_folder, state_dict=state_dict, local_files_only=True, torch_dtype=self.torch_dtype
).to(self.device).eval()
self.model[component] = model
self.model_path[component] = file_path
def load_RIFE(self, state_dict, file_path=""):
component = "RIFE"
from ..extensions.RIFE import IFNet
model = IFNet().eval()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(torch.float32).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_sd_lora(self, state_dict, alpha):
SDLoRA().add_lora_to_text_encoder(self.model["text_encoder"], state_dict, alpha=alpha, device=self.device)
SDLoRA().add_lora_to_unet(self.model["unet"], state_dict, alpha=alpha, device=self.device)
def load_translator(self, state_dict, file_path=""):
# This model is lightweight, we do not place it on GPU.
component = "translator"
from transformers import AutoModelForSeq2SeqLM
model_folder = os.path.dirname(file_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_folder).eval()
self.model[component] = model
self.model_path[component] = file_path
def load_ipadapter(self, state_dict, file_path=""):
component = "ipadapter"
model = SDIpAdapter()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_ipadapter_image_encoder(self, state_dict, file_path=""):
component = "ipadapter_image_encoder"
model = IpAdapterCLIPImageEmbedder()
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_ipadapter_xl(self, state_dict, file_path=""):
component = "ipadapter_xl"
model = SDXLIpAdapter()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_ipadapter_xl_image_encoder(self, state_dict, file_path=""):
component = "ipadapter_xl_image_encoder"
model = IpAdapterXLCLIPImageEmbedder()
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_hunyuan_dit_clip_text_encoder(self, state_dict, file_path=""):
component = "hunyuan_dit_clip_text_encoder"
model = HunyuanDiTCLIPTextEncoder()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_hunyuan_dit_t5_text_encoder(self, state_dict, file_path=""):
component = "hunyuan_dit_t5_text_encoder"
model = HunyuanDiTT5TextEncoder()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_hunyuan_dit(self, state_dict, file_path=""):
component = "hunyuan_dit"
model = HunyuanDiT()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_diffusers_vae(self, state_dict, file_path=""):
# TODO: detect SD and SDXL
component = "vae_encoder"
model = SDXLVAEEncoder()
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
component = "vae_decoder"
model = SDXLVAEDecoder()
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_ExVideo_StableVideoDiffusion(self, state_dict, file_path=""):
unet_state_dict = self.model["unet"].state_dict()
self.model["unet"].to("cpu")
del self.model["unet"]
add_positional_conv = state_dict["blocks.185.positional_embedding.embeddings"].shape[0]
self.model["unet"] = SVDUNet(add_positional_conv=add_positional_conv)
self.model["unet"].load_state_dict(unet_state_dict, strict=False)
self.model["unet"].load_state_dict(state_dict, strict=False)
self.model["unet"].to(self.torch_dtype).to(self.device)
def search_for_embeddings(self, state_dict):
embeddings = []
for k in state_dict:
if isinstance(state_dict[k], torch.Tensor):
embeddings.append(state_dict[k])
elif isinstance(state_dict[k], dict):
embeddings += self.search_for_embeddings(state_dict[k])
return embeddings
def load_textual_inversions(self, folder):
# Store additional tokens here
self.textual_inversion_dict = {}
# Load every textual inversion file
for file_name in os.listdir(folder):
if file_name.endswith(".txt"):
continue
keyword = os.path.splitext(file_name)[0]
state_dict = load_state_dict(os.path.join(folder, file_name))
# Search for embeddings
for embeddings in self.search_for_embeddings(state_dict):
if len(embeddings.shape) == 2 and embeddings.shape[1] == 768:
tokens = [f"{keyword}_{i}" for i in range(embeddings.shape[0])]
self.textual_inversion_dict[keyword] = (tokens, embeddings)
break
def load_model(self, file_path, components=None, lora_alphas=[]):
state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype)
if self.is_stable_video_diffusion(state_dict):
self.load_stable_video_diffusion(state_dict, file_path=file_path)
elif self.is_animatediff(state_dict):
self.load_animatediff(state_dict, file_path=file_path)
elif self.is_animatediff_xl(state_dict):
self.load_animatediff_xl(state_dict, file_path=file_path)
elif self.is_controlnet(state_dict):
self.load_controlnet(state_dict, file_path=file_path)
elif self.is_stabe_diffusion_xl(state_dict):
self.load_stable_diffusion_xl(state_dict, components=components, file_path=file_path)
elif self.is_stable_diffusion(state_dict):
self.load_stable_diffusion(state_dict, components=components, file_path=file_path)
elif self.is_sd_lora(state_dict):
self.load_sd_lora(state_dict, alpha=lora_alphas.pop(0))
elif self.is_beautiful_prompt(state_dict):
self.load_beautiful_prompt(state_dict, file_path=file_path)
elif self.is_RIFE(state_dict):
self.load_RIFE(state_dict, file_path=file_path)
elif self.is_translator(state_dict):
self.load_translator(state_dict, file_path=file_path)
elif self.is_ipadapter(state_dict):
self.load_ipadapter(state_dict, file_path=file_path)
elif self.is_ipadapter_image_encoder(state_dict):
self.load_ipadapter_image_encoder(state_dict, file_path=file_path)
elif self.is_ipadapter_xl(state_dict):
self.load_ipadapter_xl(state_dict, file_path=file_path)
elif self.is_ipadapter_xl_image_encoder(state_dict):
self.load_ipadapter_xl_image_encoder(state_dict, file_path=file_path)
elif self.is_hunyuan_dit_clip_text_encoder(state_dict):
self.load_hunyuan_dit_clip_text_encoder(state_dict, file_path=file_path)
elif self.is_hunyuan_dit_t5_text_encoder(state_dict):
self.load_hunyuan_dit_t5_text_encoder(state_dict, file_path=file_path)
elif self.is_hunyuan_dit(state_dict):
self.load_hunyuan_dit(state_dict, file_path=file_path)
elif self.is_diffusers_vae(state_dict):
self.load_diffusers_vae(state_dict, file_path=file_path)
elif self.is_ExVideo_StableVideoDiffusion(state_dict):
self.load_ExVideo_StableVideoDiffusion(state_dict, file_path=file_path)
def load_models(self, file_path_list, lora_alphas=[]):
for file_path in file_path_list:
self.load_model(file_path, lora_alphas=lora_alphas)
def to(self, device):
for component in self.model:
if isinstance(self.model[component], list):
for model in self.model[component]:
model.to(device)
else:
self.model[component].to(device)
torch.cuda.empty_cache()
def get_model_with_model_path(self, model_path):
for component in self.model_path:
if isinstance(self.model_path[component], str):
if os.path.samefile(self.model_path[component], model_path):
return self.model[component]
elif isinstance(self.model_path[component], list):
for i, model_path_ in enumerate(self.model_path[component]):
if os.path.samefile(model_path_, model_path):
return self.model[component][i]
raise ValueError(f"Please load model {model_path} before you use it.")
def __getattr__(self, __name):
if __name in self.model:
return self.model[__name]
else:
return super.__getattribute__(__name)
def load_state_dict(file_path, torch_dtype=None):
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
state_dict = {}
with safe_open(file_path, framework="pt", device="cpu") as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
if torch_dtype is not None:
state_dict[k] = state_dict[k].to(torch_dtype)
return state_dict
def load_state_dict_from_bin(file_path, torch_dtype=None):
state_dict = torch.load(file_path, map_location="cpu")
if torch_dtype is not None:
for i in state_dict:
if isinstance(state_dict[i], torch.Tensor):
state_dict[i] = state_dict[i].to(torch_dtype)
return state_dict
def search_parameter(param, state_dict):
for name, param_ in state_dict.items():
if param.numel() == param_.numel():
if param.shape == param_.shape:
if torch.dist(param, param_) < 1e-6:
return name
else:
if torch.dist(param.flatten(), param_.flatten()) < 1e-6:
return name
return None
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
matched_keys = set()
with torch.no_grad():
for name in source_state_dict:
rename = search_parameter(source_state_dict[name], target_state_dict)
if rename is not None:
print(f'"{name}": "{rename}",')
matched_keys.add(rename)
elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
length = source_state_dict[name].shape[0] // 3
rename = []
for i in range(3):
rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
if None not in rename:
print(f'"{name}": {rename},')
for rename_ in rename:
matched_keys.add(rename_)
for name in target_state_dict:
if name not in matched_keys:
print("Cannot find", name, target_state_dict[name].shape)

View File

@@ -1,395 +0,0 @@
import torch
from einops import rearrange, repeat
from .sd3_dit import TimestepEmbeddings
from .attention import Attention
from .utils import load_state_dict_from_folder
from .tiler import TileWorker2Dto3D
import numpy as np
class CogPatchify(torch.nn.Module):
def __init__(self, dim_in, dim_out, patch_size) -> None:
super().__init__()
self.proj = torch.nn.Conv3d(dim_in, dim_out, kernel_size=(1, patch_size, patch_size), stride=(1, patch_size, patch_size))
def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
hidden_states = rearrange(hidden_states, "B C T H W -> B (T H W) C")
return hidden_states
class CogAdaLayerNorm(torch.nn.Module):
def __init__(self, dim, dim_cond, single=False):
super().__init__()
self.single = single
self.linear = torch.nn.Linear(dim_cond, dim * (2 if single else 6))
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=True, eps=1e-5)
def forward(self, hidden_states, prompt_emb, emb):
emb = self.linear(torch.nn.functional.silu(emb))
if self.single:
shift, scale = emb.unsqueeze(1).chunk(2, dim=2)
hidden_states = self.norm(hidden_states) * (1 + scale) + shift
return hidden_states
else:
shift_a, scale_a, gate_a, shift_b, scale_b, gate_b = emb.unsqueeze(1).chunk(6, dim=2)
hidden_states = self.norm(hidden_states) * (1 + scale_a) + shift_a
prompt_emb = self.norm(prompt_emb) * (1 + scale_b) + shift_b
return hidden_states, prompt_emb, gate_a, gate_b
class CogDiTBlock(torch.nn.Module):
def __init__(self, dim, dim_cond, num_heads):
super().__init__()
self.norm1 = CogAdaLayerNorm(dim, dim_cond)
self.attn1 = Attention(q_dim=dim, num_heads=48, head_dim=dim//num_heads, bias_q=True, bias_kv=True, bias_out=True)
self.norm_q = torch.nn.LayerNorm((dim//num_heads,), eps=1e-06, elementwise_affine=True)
self.norm_k = torch.nn.LayerNorm((dim//num_heads,), eps=1e-06, elementwise_affine=True)
self.norm2 = CogAdaLayerNorm(dim, dim_cond)
self.ff = torch.nn.Sequential(
torch.nn.Linear(dim, dim*4),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(dim*4, dim)
)
def apply_rotary_emb(self, x, freqs_cis):
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
def process_qkv(self, q, k, v, image_rotary_emb, text_seq_length):
q = self.norm_q(q)
k = self.norm_k(k)
q[:, :, text_seq_length:] = self.apply_rotary_emb(q[:, :, text_seq_length:], image_rotary_emb)
k[:, :, text_seq_length:] = self.apply_rotary_emb(k[:, :, text_seq_length:], image_rotary_emb)
return q, k, v
def forward(self, hidden_states, prompt_emb, time_emb, image_rotary_emb):
# Attention
norm_hidden_states, norm_encoder_hidden_states, gate_a, gate_b = self.norm1(
hidden_states, prompt_emb, time_emb
)
attention_io = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
attention_io = self.attn1(
attention_io,
qkv_preprocessor=lambda q, k, v: self.process_qkv(q, k, v, image_rotary_emb, prompt_emb.shape[1])
)
hidden_states = hidden_states + gate_a * attention_io[:, prompt_emb.shape[1]:]
prompt_emb = prompt_emb + gate_b * attention_io[:, :prompt_emb.shape[1]]
# Feed forward
norm_hidden_states, norm_encoder_hidden_states, gate_a, gate_b = self.norm2(
hidden_states, prompt_emb, time_emb
)
ff_io = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
ff_io = self.ff(ff_io)
hidden_states = hidden_states + gate_a * ff_io[:, prompt_emb.shape[1]:]
prompt_emb = prompt_emb + gate_b * ff_io[:, :prompt_emb.shape[1]]
return hidden_states, prompt_emb
class CogDiT(torch.nn.Module):
def __init__(self):
super().__init__()
self.patchify = CogPatchify(16, 3072, 2)
self.time_embedder = TimestepEmbeddings(3072, 512)
self.context_embedder = torch.nn.Linear(4096, 3072)
self.blocks = torch.nn.ModuleList([CogDiTBlock(3072, 512, 48) for _ in range(42)])
self.norm_final = torch.nn.LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
self.norm_out = CogAdaLayerNorm(3072, 512, single=True)
self.proj_out = torch.nn.Linear(3072, 64, bias=True)
def get_resize_crop_region_for_grid(self, src, tgt_width, tgt_height):
tw = tgt_width
th = tgt_height
h, w = src
r = h / w
if r > (th / tw):
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h))
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
def get_3d_rotary_pos_embed(
self, embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
):
start, stop = crops_coords
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
# Compute dimensions for each axis
dim_t = embed_dim // 4
dim_h = embed_dim // 8 * 3
dim_w = embed_dim // 8 * 3
# Temporal frequencies
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
grid_t = torch.from_numpy(grid_t).float()
freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
freqs_t = freqs_t.repeat_interleave(2, dim=-1)
# Spatial frequencies for height and width
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
grid_h = torch.from_numpy(grid_h).float()
grid_w = torch.from_numpy(grid_w).float()
freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
freqs_h = freqs_h.repeat_interleave(2, dim=-1)
freqs_w = freqs_w.repeat_interleave(2, dim=-1)
# Broadcast and concatenate tensors along specified dimension
def broadcast(tensors, dim=-1):
num_tensors = len(tensors)
shape_lens = {len(t.shape) for t in tensors}
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
shape_len = list(shape_lens)[0]
dim = (dim + shape_len) if dim < 0 else dim
dims = list(zip(*(list(t.shape) for t in tensors)))
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
assert all(
[*(len(set(t[1])) <= 2 for t in expandable_dims)]
), "invalid dimensions for broadcastable concatenation"
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
expanded_dims.insert(dim, (dim, dims[dim]))
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
return torch.cat(tensors, dim=dim)
freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
t, h, w, d = freqs.shape
freqs = freqs.view(t * h * w, d)
# Generate sine and cosine components
sin = freqs.sin()
cos = freqs.cos()
if use_real:
return cos, sin
else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
def prepare_rotary_positional_embeddings(
self,
height: int,
width: int,
num_frames: int,
device: torch.device,
):
grid_height = height // 2
grid_width = width // 2
base_size_width = 720 // (8 * 2)
base_size_height = 480 // (8 * 2)
grid_crops_coords = self.get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
)
freqs_cos, freqs_sin = self.get_3d_rotary_pos_embed(
embed_dim=64,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
use_real=True,
)
freqs_cos = freqs_cos.to(device=device)
freqs_sin = freqs_sin.to(device=device)
return freqs_cos, freqs_sin
def unpatchify(self, hidden_states, height, width):
hidden_states = rearrange(hidden_states, "B (T H W) (C P Q) -> B C T (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
return hidden_states
def build_mask(self, T, H, W, dtype, device, is_bound):
t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W)
h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W)
w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W)
border_width = (H + W) // 4
pad = torch.ones_like(h) * border_width
mask = torch.stack([
pad if is_bound[0] else t + 1,
pad if is_bound[1] else T - t,
pad if is_bound[2] else h + 1,
pad if is_bound[3] else H - h,
pad if is_bound[4] else w + 1,
pad if is_bound[5] else W - w
]).min(dim=0).values
mask = mask.clip(1, border_width)
mask = (mask / border_width).to(dtype=dtype, device=device)
mask = rearrange(mask, "T H W -> 1 1 T H W")
return mask
def tiled_forward(self, hidden_states, timestep, prompt_emb, tile_size=(60, 90), tile_stride=(30, 45)):
B, C, T, H, W = hidden_states.shape
value = torch.zeros((B, C, T, H, W), dtype=hidden_states.dtype, device=hidden_states.device)
weight = torch.zeros((B, C, T, H, W), dtype=hidden_states.dtype, device=hidden_states.device)
# Split tasks
tasks = []
for h in range(0, H, tile_stride):
for w in range(0, W, tile_stride):
if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
continue
h_, w_ = h + tile_size, w + tile_size
if h_ > H: h, h_ = max(H - tile_size, 0), H
if w_ > W: w, w_ = max(W - tile_size, 0), W
tasks.append((h, h_, w, w_))
# Run
for hl, hr, wl, wr in tasks:
mask = self.build_mask(
value.shape[2], (hr-hl), (wr-wl),
hidden_states.dtype, hidden_states.device,
is_bound=(True, True, hl==0, hr>=H, wl==0, wr>=W)
)
model_output = self.forward(hidden_states[:, :, :, hl:hr, wl:wr], timestep, prompt_emb)
value[:, :, :, hl:hr, wl:wr] += model_output * mask
weight[:, :, :, hl:hr, wl:wr] += mask
value = value / weight
return value
def forward(self, hidden_states, timestep, prompt_emb, image_rotary_emb=None, tiled=False, tile_size=90, tile_stride=30):
if tiled:
return TileWorker2Dto3D().tiled_forward(
forward_fn=lambda x: self.forward(x, timestep, prompt_emb),
model_input=hidden_states,
tile_size=tile_size, tile_stride=tile_stride,
tile_device=hidden_states.device, tile_dtype=hidden_states.dtype,
computation_device=self.context_embedder.weight.device, computation_dtype=self.context_embedder.weight.dtype
)
num_frames, height, width = hidden_states.shape[-3:]
if image_rotary_emb is None:
image_rotary_emb = self.prepare_rotary_positional_embeddings(height, width, num_frames, device=self.context_embedder.weight.device)
hidden_states = self.patchify(hidden_states)
time_emb = self.time_embedder(timestep, dtype=hidden_states.dtype)
prompt_emb = self.context_embedder(prompt_emb)
for block in self.blocks:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, time_emb, image_rotary_emb)
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
hidden_states = self.norm_final(hidden_states)
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
hidden_states = self.norm_out(hidden_states, prompt_emb, time_emb)
hidden_states = self.proj_out(hidden_states)
hidden_states = self.unpatchify(hidden_states, height, width)
return hidden_states
@staticmethod
def state_dict_converter():
return CogDiTStateDictConverter()
@staticmethod
def from_pretrained(file_path, torch_dtype=torch.bfloat16):
model = CogDiT().to(torch_dtype)
state_dict = load_state_dict_from_folder(file_path, torch_dtype=torch_dtype)
state_dict = CogDiT.state_dict_converter().from_diffusers(state_dict)
model.load_state_dict(state_dict)
return model
class CogDiTStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"patch_embed.proj.weight": "patchify.proj.weight",
"patch_embed.proj.bias": "patchify.proj.bias",
"patch_embed.text_proj.weight": "context_embedder.weight",
"patch_embed.text_proj.bias": "context_embedder.bias",
"time_embedding.linear_1.weight": "time_embedder.timestep_embedder.0.weight",
"time_embedding.linear_1.bias": "time_embedder.timestep_embedder.0.bias",
"time_embedding.linear_2.weight": "time_embedder.timestep_embedder.2.weight",
"time_embedding.linear_2.bias": "time_embedder.timestep_embedder.2.bias",
"norm_final.weight": "norm_final.weight",
"norm_final.bias": "norm_final.bias",
"norm_out.linear.weight": "norm_out.linear.weight",
"norm_out.linear.bias": "norm_out.linear.bias",
"norm_out.norm.weight": "norm_out.norm.weight",
"norm_out.norm.bias": "norm_out.norm.bias",
"proj_out.weight": "proj_out.weight",
"proj_out.bias": "proj_out.bias",
}
suffix_dict = {
"norm1.linear.weight": "norm1.linear.weight",
"norm1.linear.bias": "norm1.linear.bias",
"norm1.norm.weight": "norm1.norm.weight",
"norm1.norm.bias": "norm1.norm.bias",
"attn1.norm_q.weight": "norm_q.weight",
"attn1.norm_q.bias": "norm_q.bias",
"attn1.norm_k.weight": "norm_k.weight",
"attn1.norm_k.bias": "norm_k.bias",
"attn1.to_q.weight": "attn1.to_q.weight",
"attn1.to_q.bias": "attn1.to_q.bias",
"attn1.to_k.weight": "attn1.to_k.weight",
"attn1.to_k.bias": "attn1.to_k.bias",
"attn1.to_v.weight": "attn1.to_v.weight",
"attn1.to_v.bias": "attn1.to_v.bias",
"attn1.to_out.0.weight": "attn1.to_out.weight",
"attn1.to_out.0.bias": "attn1.to_out.bias",
"norm2.linear.weight": "norm2.linear.weight",
"norm2.linear.bias": "norm2.linear.bias",
"norm2.norm.weight": "norm2.norm.weight",
"norm2.norm.bias": "norm2.norm.bias",
"ff.net.0.proj.weight": "ff.0.weight",
"ff.net.0.proj.bias": "ff.0.bias",
"ff.net.2.weight": "ff.2.weight",
"ff.net.2.bias": "ff.2.bias",
}
state_dict_ = {}
for name, param in state_dict.items():
if name in rename_dict:
if name == "patch_embed.proj.weight":
param = param.unsqueeze(2)
state_dict_[rename_dict[name]] = param
else:
names = name.split(".")
if names[0] == "transformer_blocks":
suffix = ".".join(names[2:])
state_dict_[f"blocks.{names[1]}." + suffix_dict[suffix]] = param
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -1,518 +0,0 @@
import torch
from einops import rearrange, repeat
from .tiler import TileWorker2Dto3D
class Downsample3D(torch.nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 2,
padding: int = 0,
compress_time: bool = False,
):
super().__init__()
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.compress_time = compress_time
def forward(self, x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor:
if self.compress_time:
batch_size, channels, frames, height, width = x.shape
# (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
if x.shape[-1] % 2 == 1:
x_first, x_rest = x[..., 0], x[..., 1:]
if x_rest.shape[-1] > 0:
# (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
x = torch.cat([x_first[..., None], x_rest], dim=-1)
# (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
else:
# (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
# (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
# Pad the tensor
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
batch_size, channels, frames, height, width = x.shape
# (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
x = self.conv(x)
# (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
return x
class Upsample3D(torch.nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 1,
padding: int = 1,
compress_time: bool = False,
) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.compress_time = compress_time
def forward(self, inputs: torch.Tensor, xq: torch.Tensor) -> torch.Tensor:
if self.compress_time:
if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
# split first frame
x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0)
x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0)
x_first = x_first[:, :, None, :, :]
inputs = torch.cat([x_first, x_rest], dim=2)
elif inputs.shape[2] > 1:
inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
else:
inputs = inputs.squeeze(2)
inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
inputs = inputs[:, :, None, :, :]
else:
# only interpolate 2D
b, c, t, h, w = inputs.shape
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
b, c, t, h, w = inputs.shape
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
inputs = self.conv(inputs)
inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
return inputs
class CogVideoXSpatialNorm3D(torch.nn.Module):
def __init__(self, f_channels, zq_channels, groups):
super().__init__()
self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
self.conv_y = torch.nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1)
self.conv_b = torch.nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1)
def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
if f.shape[2] > 1 and f.shape[2] % 2 == 1:
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
z_first = torch.nn.functional.interpolate(z_first, size=f_first_size)
z_rest = torch.nn.functional.interpolate(z_rest, size=f_rest_size)
zq = torch.cat([z_first, z_rest], dim=2)
else:
zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:])
norm_f = self.norm_layer(f)
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
return new_f
class Resnet3DBlock(torch.nn.Module):
def __init__(self, in_channels, out_channels, spatial_norm_dim, groups, eps=1e-6, use_conv_shortcut=False):
super().__init__()
self.nonlinearity = torch.nn.SiLU()
if spatial_norm_dim is None:
self.norm1 = torch.nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
self.norm2 = torch.nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
else:
self.norm1 = CogVideoXSpatialNorm3D(in_channels, spatial_norm_dim, groups)
self.norm2 = CogVideoXSpatialNorm3D(out_channels, spatial_norm_dim, groups)
self.conv1 = CachedConv3d(in_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
self.conv2 = CachedConv3d(out_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
if in_channels != out_channels:
if use_conv_shortcut:
self.conv_shortcut = CachedConv3d(in_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
else:
self.conv_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1)
else:
self.conv_shortcut = lambda x: x
def forward(self, hidden_states, zq):
residual = hidden_states
hidden_states = self.norm1(hidden_states, zq) if isinstance(self.norm1, CogVideoXSpatialNorm3D) else self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
hidden_states = self.norm2(hidden_states, zq) if isinstance(self.norm2, CogVideoXSpatialNorm3D) else self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv2(hidden_states)
hidden_states = hidden_states + self.conv_shortcut(residual)
return hidden_states
class CachedConv3d(torch.nn.Conv3d):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
super().__init__(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.cached_tensor = None
def clear_cache(self):
self.cached_tensor = None
def forward(self, input: torch.Tensor, use_cache = True) -> torch.Tensor:
if use_cache:
if self.cached_tensor is None:
self.cached_tensor = torch.concat([input[:, :, :1]] * 2, dim=2)
input = torch.concat([self.cached_tensor, input], dim=2)
self.cached_tensor = input[:, :, -2:]
return super().forward(input)
class CogVAEDecoder(torch.nn.Module):
def __init__(self):
super().__init__()
self.scaling_factor = 0.7
self.conv_in = CachedConv3d(16, 512, kernel_size=3, stride=1, padding=(0, 1, 1))
self.blocks = torch.nn.ModuleList([
Resnet3DBlock(512, 512, 16, 32),
Resnet3DBlock(512, 512, 16, 32),
Resnet3DBlock(512, 512, 16, 32),
Resnet3DBlock(512, 512, 16, 32),
Resnet3DBlock(512, 512, 16, 32),
Resnet3DBlock(512, 512, 16, 32),
Upsample3D(512, 512, compress_time=True),
Resnet3DBlock(512, 256, 16, 32),
Resnet3DBlock(256, 256, 16, 32),
Resnet3DBlock(256, 256, 16, 32),
Resnet3DBlock(256, 256, 16, 32),
Upsample3D(256, 256, compress_time=True),
Resnet3DBlock(256, 256, 16, 32),
Resnet3DBlock(256, 256, 16, 32),
Resnet3DBlock(256, 256, 16, 32),
Resnet3DBlock(256, 256, 16, 32),
Upsample3D(256, 256, compress_time=False),
Resnet3DBlock(256, 128, 16, 32),
Resnet3DBlock(128, 128, 16, 32),
Resnet3DBlock(128, 128, 16, 32),
Resnet3DBlock(128, 128, 16, 32),
])
self.norm_out = CogVideoXSpatialNorm3D(128, 16, 32)
self.conv_act = torch.nn.SiLU()
self.conv_out = CachedConv3d(128, 3, kernel_size=3, stride=1, padding=(0, 1, 1))
def forward(self, sample):
sample = sample / self.scaling_factor
hidden_states = self.conv_in(sample)
for block in self.blocks:
hidden_states = block(hidden_states, sample)
hidden_states = self.norm_out(hidden_states, sample)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
def decode_video(self, sample, tiled=True, tile_size=(60, 90), tile_stride=(30, 45), progress_bar=lambda x:x):
if tiled:
B, C, T, H, W = sample.shape
return TileWorker2Dto3D().tiled_forward(
forward_fn=lambda x: self.decode_small_video(x),
model_input=sample,
tile_size=tile_size, tile_stride=tile_stride,
tile_device=sample.device, tile_dtype=sample.dtype,
computation_device=sample.device, computation_dtype=sample.dtype,
scales=(3/16, (T//2*8+T%2)/T, 8, 8),
progress_bar=progress_bar
)
else:
return self.decode_small_video(sample)
def decode_small_video(self, sample):
B, C, T, H, W = sample.shape
computation_device = self.conv_in.weight.device
computation_dtype = self.conv_in.weight.dtype
value = []
for i in range(T//2):
tl = i*2 + T%2 - (T%2 and i==0)
tr = i*2 + 2 + T%2
model_input = sample[:, :, tl: tr, :, :].to(dtype=computation_dtype, device=computation_device)
model_output = self.forward(model_input).to(dtype=sample.dtype, device=sample.device)
value.append(model_output)
value = torch.concat(value, dim=2)
for name, module in self.named_modules():
if isinstance(module, CachedConv3d):
module.clear_cache()
return value
@staticmethod
def state_dict_converter():
return CogVAEDecoderStateDictConverter()
class CogVAEEncoder(torch.nn.Module):
def __init__(self):
super().__init__()
self.scaling_factor = 0.7
self.conv_in = CachedConv3d(3, 128, kernel_size=3, stride=1, padding=(0, 1, 1))
self.blocks = torch.nn.ModuleList([
Resnet3DBlock(128, 128, None, 32),
Resnet3DBlock(128, 128, None, 32),
Resnet3DBlock(128, 128, None, 32),
Downsample3D(128, 128, compress_time=True),
Resnet3DBlock(128, 256, None, 32),
Resnet3DBlock(256, 256, None, 32),
Resnet3DBlock(256, 256, None, 32),
Downsample3D(256, 256, compress_time=True),
Resnet3DBlock(256, 256, None, 32),
Resnet3DBlock(256, 256, None, 32),
Resnet3DBlock(256, 256, None, 32),
Downsample3D(256, 256, compress_time=False),
Resnet3DBlock(256, 512, None, 32),
Resnet3DBlock(512, 512, None, 32),
Resnet3DBlock(512, 512, None, 32),
Resnet3DBlock(512, 512, None, 32),
Resnet3DBlock(512, 512, None, 32),
])
self.norm_out = torch.nn.GroupNorm(32, 512, eps=1e-06, affine=True)
self.conv_act = torch.nn.SiLU()
self.conv_out = CachedConv3d(512, 32, kernel_size=3, stride=1, padding=(0, 1, 1))
def forward(self, sample):
hidden_states = self.conv_in(sample)
for block in self.blocks:
hidden_states = block(hidden_states, sample)
hidden_states = self.norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)[:, :16]
hidden_states = hidden_states * self.scaling_factor
return hidden_states
def encode_video(self, sample, tiled=True, tile_size=(60, 90), tile_stride=(30, 45), progress_bar=lambda x:x):
if tiled:
B, C, T, H, W = sample.shape
return TileWorker2Dto3D().tiled_forward(
forward_fn=lambda x: self.encode_small_video(x),
model_input=sample,
tile_size=(i * 8 for i in tile_size), tile_stride=(i * 8 for i in tile_stride),
tile_device=sample.device, tile_dtype=sample.dtype,
computation_device=sample.device, computation_dtype=sample.dtype,
scales=(16/3, (T//4+T%2)/T, 1/8, 1/8),
progress_bar=progress_bar
)
else:
return self.encode_small_video(sample)
def encode_small_video(self, sample):
B, C, T, H, W = sample.shape
computation_device = self.conv_in.weight.device
computation_dtype = self.conv_in.weight.dtype
value = []
for i in range(T//8):
t = i*8 + T%2 - (T%2 and i==0)
t_ = i*8 + 8 + T%2
model_input = sample[:, :, t: t_, :, :].to(dtype=computation_dtype, device=computation_device)
model_output = self.forward(model_input).to(dtype=sample.dtype, device=sample.device)
value.append(model_output)
value = torch.concat(value, dim=2)
for name, module in self.named_modules():
if isinstance(module, CachedConv3d):
module.clear_cache()
return value
@staticmethod
def state_dict_converter():
return CogVAEEncoderStateDictConverter()
class CogVAEEncoderStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"encoder.conv_in.conv.weight": "conv_in.weight",
"encoder.conv_in.conv.bias": "conv_in.bias",
"encoder.down_blocks.0.downsamplers.0.conv.weight": "blocks.3.conv.weight",
"encoder.down_blocks.0.downsamplers.0.conv.bias": "blocks.3.conv.bias",
"encoder.down_blocks.1.downsamplers.0.conv.weight": "blocks.7.conv.weight",
"encoder.down_blocks.1.downsamplers.0.conv.bias": "blocks.7.conv.bias",
"encoder.down_blocks.2.downsamplers.0.conv.weight": "blocks.11.conv.weight",
"encoder.down_blocks.2.downsamplers.0.conv.bias": "blocks.11.conv.bias",
"encoder.norm_out.weight": "norm_out.weight",
"encoder.norm_out.bias": "norm_out.bias",
"encoder.conv_out.conv.weight": "conv_out.weight",
"encoder.conv_out.conv.bias": "conv_out.bias",
}
prefix_dict = {
"encoder.down_blocks.0.resnets.0.": "blocks.0.",
"encoder.down_blocks.0.resnets.1.": "blocks.1.",
"encoder.down_blocks.0.resnets.2.": "blocks.2.",
"encoder.down_blocks.1.resnets.0.": "blocks.4.",
"encoder.down_blocks.1.resnets.1.": "blocks.5.",
"encoder.down_blocks.1.resnets.2.": "blocks.6.",
"encoder.down_blocks.2.resnets.0.": "blocks.8.",
"encoder.down_blocks.2.resnets.1.": "blocks.9.",
"encoder.down_blocks.2.resnets.2.": "blocks.10.",
"encoder.down_blocks.3.resnets.0.": "blocks.12.",
"encoder.down_blocks.3.resnets.1.": "blocks.13.",
"encoder.down_blocks.3.resnets.2.": "blocks.14.",
"encoder.mid_block.resnets.0.": "blocks.15.",
"encoder.mid_block.resnets.1.": "blocks.16.",
}
suffix_dict = {
"norm1.norm_layer.weight": "norm1.norm_layer.weight",
"norm1.norm_layer.bias": "norm1.norm_layer.bias",
"norm1.conv_y.conv.weight": "norm1.conv_y.weight",
"norm1.conv_y.conv.bias": "norm1.conv_y.bias",
"norm1.conv_b.conv.weight": "norm1.conv_b.weight",
"norm1.conv_b.conv.bias": "norm1.conv_b.bias",
"norm2.norm_layer.weight": "norm2.norm_layer.weight",
"norm2.norm_layer.bias": "norm2.norm_layer.bias",
"norm2.conv_y.conv.weight": "norm2.conv_y.weight",
"norm2.conv_y.conv.bias": "norm2.conv_y.bias",
"norm2.conv_b.conv.weight": "norm2.conv_b.weight",
"norm2.conv_b.conv.bias": "norm2.conv_b.bias",
"conv1.conv.weight": "conv1.weight",
"conv1.conv.bias": "conv1.bias",
"conv2.conv.weight": "conv2.weight",
"conv2.conv.bias": "conv2.bias",
"conv_shortcut.weight": "conv_shortcut.weight",
"conv_shortcut.bias": "conv_shortcut.bias",
"norm1.weight": "norm1.weight",
"norm1.bias": "norm1.bias",
"norm2.weight": "norm2.weight",
"norm2.bias": "norm2.bias",
}
state_dict_ = {}
for name, param in state_dict.items():
if name in rename_dict:
state_dict_[rename_dict[name]] = param
else:
for prefix in prefix_dict:
if name.startswith(prefix):
suffix = name[len(prefix):]
state_dict_[prefix_dict[prefix] + suffix_dict[suffix]] = param
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)
class CogVAEDecoderStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"decoder.conv_in.conv.weight": "conv_in.weight",
"decoder.conv_in.conv.bias": "conv_in.bias",
"decoder.up_blocks.0.upsamplers.0.conv.weight": "blocks.6.conv.weight",
"decoder.up_blocks.0.upsamplers.0.conv.bias": "blocks.6.conv.bias",
"decoder.up_blocks.1.upsamplers.0.conv.weight": "blocks.11.conv.weight",
"decoder.up_blocks.1.upsamplers.0.conv.bias": "blocks.11.conv.bias",
"decoder.up_blocks.2.upsamplers.0.conv.weight": "blocks.16.conv.weight",
"decoder.up_blocks.2.upsamplers.0.conv.bias": "blocks.16.conv.bias",
"decoder.norm_out.norm_layer.weight": "norm_out.norm_layer.weight",
"decoder.norm_out.norm_layer.bias": "norm_out.norm_layer.bias",
"decoder.norm_out.conv_y.conv.weight": "norm_out.conv_y.weight",
"decoder.norm_out.conv_y.conv.bias": "norm_out.conv_y.bias",
"decoder.norm_out.conv_b.conv.weight": "norm_out.conv_b.weight",
"decoder.norm_out.conv_b.conv.bias": "norm_out.conv_b.bias",
"decoder.conv_out.conv.weight": "conv_out.weight",
"decoder.conv_out.conv.bias": "conv_out.bias"
}
prefix_dict = {
"decoder.mid_block.resnets.0.": "blocks.0.",
"decoder.mid_block.resnets.1.": "blocks.1.",
"decoder.up_blocks.0.resnets.0.": "blocks.2.",
"decoder.up_blocks.0.resnets.1.": "blocks.3.",
"decoder.up_blocks.0.resnets.2.": "blocks.4.",
"decoder.up_blocks.0.resnets.3.": "blocks.5.",
"decoder.up_blocks.1.resnets.0.": "blocks.7.",
"decoder.up_blocks.1.resnets.1.": "blocks.8.",
"decoder.up_blocks.1.resnets.2.": "blocks.9.",
"decoder.up_blocks.1.resnets.3.": "blocks.10.",
"decoder.up_blocks.2.resnets.0.": "blocks.12.",
"decoder.up_blocks.2.resnets.1.": "blocks.13.",
"decoder.up_blocks.2.resnets.2.": "blocks.14.",
"decoder.up_blocks.2.resnets.3.": "blocks.15.",
"decoder.up_blocks.3.resnets.0.": "blocks.17.",
"decoder.up_blocks.3.resnets.1.": "blocks.18.",
"decoder.up_blocks.3.resnets.2.": "blocks.19.",
"decoder.up_blocks.3.resnets.3.": "blocks.20.",
}
suffix_dict = {
"norm1.norm_layer.weight": "norm1.norm_layer.weight",
"norm1.norm_layer.bias": "norm1.norm_layer.bias",
"norm1.conv_y.conv.weight": "norm1.conv_y.weight",
"norm1.conv_y.conv.bias": "norm1.conv_y.bias",
"norm1.conv_b.conv.weight": "norm1.conv_b.weight",
"norm1.conv_b.conv.bias": "norm1.conv_b.bias",
"norm2.norm_layer.weight": "norm2.norm_layer.weight",
"norm2.norm_layer.bias": "norm2.norm_layer.bias",
"norm2.conv_y.conv.weight": "norm2.conv_y.weight",
"norm2.conv_y.conv.bias": "norm2.conv_y.bias",
"norm2.conv_b.conv.weight": "norm2.conv_b.weight",
"norm2.conv_b.conv.bias": "norm2.conv_b.bias",
"conv1.conv.weight": "conv1.weight",
"conv1.conv.bias": "conv1.bias",
"conv2.conv.weight": "conv2.weight",
"conv2.conv.bias": "conv2.bias",
"conv_shortcut.weight": "conv_shortcut.weight",
"conv_shortcut.bias": "conv_shortcut.bias",
}
state_dict_ = {}
for name, param in state_dict.items():
if name in rename_dict:
state_dict_[rename_dict[name]] = param
else:
for prefix in prefix_dict:
if name.startswith(prefix):
suffix = name[len(prefix):]
state_dict_[prefix_dict[prefix] + suffix_dict[suffix]] = param
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -1,66 +0,0 @@
from huggingface_hub import hf_hub_download
from modelscope import snapshot_download
import os, shutil
from typing_extensions import Literal, TypeAlias
from typing import List
from ..configs.model_config import preset_models_on_huggingface, preset_models_on_modelscope, Preset_model_id
def download_from_modelscope(model_id, origin_file_path, local_dir):
os.makedirs(local_dir, exist_ok=True)
if os.path.basename(origin_file_path) in os.listdir(local_dir):
print(f" {os.path.basename(origin_file_path)} has been already in {local_dir}.")
return
else:
print(f" Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
downloaded_file_path = os.path.join(local_dir, origin_file_path)
target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
if downloaded_file_path != target_file_path:
shutil.move(downloaded_file_path, target_file_path)
shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
def download_from_huggingface(model_id, origin_file_path, local_dir):
os.makedirs(local_dir, exist_ok=True)
if os.path.basename(origin_file_path) in os.listdir(local_dir):
print(f" {os.path.basename(origin_file_path)} has been already in {local_dir}.")
return
else:
print(f" Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
Preset_model_website: TypeAlias = Literal[
"HuggingFace",
"ModelScope",
]
website_to_preset_models = {
"HuggingFace": preset_models_on_huggingface,
"ModelScope": preset_models_on_modelscope,
}
website_to_download_fn = {
"HuggingFace": download_from_huggingface,
"ModelScope": download_from_modelscope,
}
def download_models(
model_id_list: List[Preset_model_id] = [],
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
):
print(f"Downloading models: {model_id_list}")
downloaded_files = []
for model_id in model_id_list:
for website in downloading_priority:
if model_id in website_to_preset_models[website]:
for model_id, origin_file_path, local_dir in website_to_preset_models[website][model_id]:
# Check if the file is downloaded.
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
if file_to_download in downloaded_files:
continue
# Download
website_to_download_fn[website](model_id, origin_file_path, local_dir)
if os.path.basename(origin_file_path) in os.listdir(local_dir):
downloaded_files.append(file_to_download)
return downloaded_files

View File

@@ -1,593 +0,0 @@
import torch
from .sd3_dit import TimestepEmbeddings, AdaLayerNorm
from einops import rearrange
from .tiler import TileWorker
class RoPEEmbedding(torch.nn.Module):
def __init__(self, dim, theta, axes_dim):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0, "The dimension must be even."
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
batch_size, seq_length = pos.shape
out = torch.einsum("...n,d->...nd", pos, omega)
cos_out = torch.cos(out)
sin_out = torch.sin(out)
stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
return out.float()
def forward(self, ids):
n_axes = ids.shape[-1]
emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
return emb.unsqueeze(1)
class RMSNorm(torch.nn.Module):
def __init__(self, dim, eps):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones((dim,)))
self.eps = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
hidden_states = hidden_states.to(input_dtype) * self.weight
return hidden_states
class FluxJointAttention(torch.nn.Module):
def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.only_out_a = only_out_a
self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)
self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
self.norm_q_b = RMSNorm(head_dim, eps=1e-6)
self.norm_k_b = RMSNorm(head_dim, eps=1e-6)
self.a_to_out = torch.nn.Linear(dim_a, dim_a)
if not only_out_a:
self.b_to_out = torch.nn.Linear(dim_b, dim_b)
def apply_rope(self, xq, xk, freqs_cis):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb):
batch_size = hidden_states_a.shape[0]
# Part A
qkv_a = self.a_to_qkv(hidden_states_a)
qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
q_a, k_a, v_a = qkv_a.chunk(3, dim=1)
q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)
# Part B
qkv_b = self.b_to_qkv(hidden_states_b)
qkv_b = qkv_b.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
q_b, k_b, v_b = qkv_b.chunk(3, dim=1)
q_b, k_b = self.norm_q_b(q_b), self.norm_k_b(k_b)
q = torch.concat([q_b, q_a], dim=2)
k = torch.concat([k_b, k_a], dim=2)
v = torch.concat([v_b, v_a], dim=2)
q, k = self.apply_rope(q, k, image_rotary_emb)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
hidden_states_a = self.a_to_out(hidden_states_a)
if self.only_out_a:
return hidden_states_a
else:
hidden_states_b = self.b_to_out(hidden_states_b)
return hidden_states_a, hidden_states_b
class FluxJointTransformerBlock(torch.nn.Module):
def __init__(self, dim, num_attention_heads):
super().__init__()
self.norm1_a = AdaLayerNorm(dim)
self.norm1_b = AdaLayerNorm(dim)
self.attn = FluxJointAttention(dim, dim, num_attention_heads, dim // num_attention_heads)
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_a = torch.nn.Sequential(
torch.nn.Linear(dim, dim*4),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(dim*4, dim)
)
self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_b = torch.nn.Sequential(
torch.nn.Linear(dim, dim*4),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(dim*4, dim)
)
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb):
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
# Attention
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb)
# Part A
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
# Part B
hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
return hidden_states_a, hidden_states_b
class FluxSingleAttention(torch.nn.Module):
def __init__(self, dim_a, dim_b, num_heads, head_dim):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
def apply_rope(self, xq, xk, freqs_cis):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
def forward(self, hidden_states, image_rotary_emb):
batch_size = hidden_states.shape[0]
qkv_a = self.a_to_qkv(hidden_states)
qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
q_a, k_a, v = qkv_a.chunk(3, dim=1)
q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)
q, k = self.apply_rope(q_a, k_a, image_rotary_emb)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
return hidden_states
class AdaLayerNormSingle(torch.nn.Module):
def __init__(self, dim):
super().__init__()
self.silu = torch.nn.SiLU()
self.linear = torch.nn.Linear(dim, 3 * dim, bias=True)
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb):
emb = self.linear(self.silu(emb))
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa
class FluxSingleTransformerBlock(torch.nn.Module):
def __init__(self, dim, num_attention_heads):
super().__init__()
self.num_heads = num_attention_heads
self.head_dim = dim // num_attention_heads
self.dim = dim
self.norm = AdaLayerNormSingle(dim)
self.to_qkv_mlp = torch.nn.Linear(dim, dim * (3 + 4))
self.norm_q_a = RMSNorm(self.head_dim, eps=1e-6)
self.norm_k_a = RMSNorm(self.head_dim, eps=1e-6)
self.proj_out = torch.nn.Linear(dim * 5, dim)
def apply_rope(self, xq, xk, freqs_cis):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
def process_attention(self, hidden_states, image_rotary_emb):
batch_size = hidden_states.shape[0]
qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
q, k, v = qkv.chunk(3, dim=1)
q, k = self.norm_q_a(q), self.norm_k_a(k)
q, k = self.apply_rope(q, k, image_rotary_emb)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
return hidden_states
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb):
residual = hidden_states_a
norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]
attn_output = self.process_attention(attn_output, image_rotary_emb)
mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")
hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)
hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a)
hidden_states_a = residual + hidden_states_a
return hidden_states_a, hidden_states_b
class AdaLayerNormContinuous(torch.nn.Module):
def __init__(self, dim):
super().__init__()
self.silu = torch.nn.SiLU()
self.linear = torch.nn.Linear(dim, dim * 2, bias=True)
self.norm = torch.nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
def forward(self, x, conditioning):
emb = self.linear(self.silu(conditioning))
scale, shift = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None] + shift[:, None]
return x
class FluxDiT(torch.nn.Module):
def __init__(self):
super().__init__()
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
self.time_embedder = TimestepEmbeddings(256, 3072)
self.guidance_embedder = TimestepEmbeddings(256, 3072)
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
self.context_embedder = torch.nn.Linear(4096, 3072)
self.x_embedder = torch.nn.Linear(64, 3072)
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(19)])
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])
self.final_norm_out = AdaLayerNormContinuous(3072)
self.final_proj_out = torch.nn.Linear(3072, 64)
def patchify(self, hidden_states):
hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
return hidden_states
def unpatchify(self, hidden_states, height, width):
hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
return hidden_states
def prepare_image_ids(self, latents):
batch_size, _, height, width = latents.shape
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
latent_image_ids = latent_image_ids.reshape(
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
return latent_image_ids
def tiled_forward(
self,
hidden_states,
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
tile_size=128, tile_stride=64,
**kwargs
):
# Due to the global positional embedding, we cannot implement layer-wise tiled forward.
hidden_states = TileWorker().tiled_forward(
lambda x: self.forward(x, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None),
hidden_states,
tile_size,
tile_stride,
tile_device=hidden_states.device,
tile_dtype=hidden_states.dtype
)
return hidden_states
def forward(
self,
hidden_states,
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
tiled=False, tile_size=128, tile_stride=64,
use_gradient_checkpointing=False,
**kwargs
):
if tiled:
return self.tiled_forward(
hidden_states,
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
tile_size=tile_size, tile_stride=tile_stride,
**kwargs
)
if image_ids is None:
image_ids = self.prepare_image_ids(hidden_states)
conditioning = self.time_embedder(timestep, hidden_states.dtype)\
+ self.guidance_embedder(guidance, hidden_states.dtype)\
+ self.pooled_text_embedder(pooled_prompt_emb)
prompt_emb = self.context_embedder(prompt_emb)
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
height, width = hidden_states.shape[-2:]
hidden_states = self.patchify(hidden_states)
hidden_states = self.x_embedder(hidden_states)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block in self.blocks:
if self.training and use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning, image_rotary_emb,
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
for block in self.single_blocks:
if self.training and use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning, image_rotary_emb,
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
hidden_states = self.final_norm_out(hidden_states, conditioning)
hidden_states = self.final_proj_out(hidden_states)
hidden_states = self.unpatchify(hidden_states, height, width)
return hidden_states
@staticmethod
def state_dict_converter():
return FluxDiTStateDictConverter()
class FluxDiTStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
global_rename_dict = {
"context_embedder": "context_embedder",
"x_embedder": "x_embedder",
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
"time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
"norm_out.linear": "final_norm_out.linear",
"proj_out": "final_proj_out",
}
rename_dict = {
"proj_out": "proj_out",
"norm1.linear": "norm1_a.linear",
"norm1_context.linear": "norm1_b.linear",
"attn.to_q": "attn.a_to_q",
"attn.to_k": "attn.a_to_k",
"attn.to_v": "attn.a_to_v",
"attn.to_out.0": "attn.a_to_out",
"attn.add_q_proj": "attn.b_to_q",
"attn.add_k_proj": "attn.b_to_k",
"attn.add_v_proj": "attn.b_to_v",
"attn.to_add_out": "attn.b_to_out",
"ff.net.0.proj": "ff_a.0",
"ff.net.2": "ff_a.2",
"ff_context.net.0.proj": "ff_b.0",
"ff_context.net.2": "ff_b.2",
"attn.norm_q": "attn.norm_q_a",
"attn.norm_k": "attn.norm_k_a",
"attn.norm_added_q": "attn.norm_q_b",
"attn.norm_added_k": "attn.norm_k_b",
}
rename_dict_single = {
"attn.to_q": "a_to_q",
"attn.to_k": "a_to_k",
"attn.to_v": "a_to_v",
"attn.norm_q": "norm_q_a",
"attn.norm_k": "norm_k_a",
"norm.linear": "norm.linear",
"proj_mlp": "proj_in_besides_attn",
"proj_out": "proj_out",
}
state_dict_ = {}
for name, param in state_dict.items():
if name.endswith(".weight") or name.endswith(".bias"):
suffix = ".weight" if name.endswith(".weight") else ".bias"
prefix = name[:-len(suffix)]
if prefix in global_rename_dict:
state_dict_[global_rename_dict[prefix] + suffix] = param
elif prefix.startswith("transformer_blocks."):
names = prefix.split(".")
names[0] = "blocks"
middle = ".".join(names[2:])
if middle in rename_dict:
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
state_dict_[name_] = param
elif prefix.startswith("single_transformer_blocks."):
names = prefix.split(".")
names[0] = "single_blocks"
middle = ".".join(names[2:])
if middle in rename_dict_single:
name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
state_dict_[name_] = param
else:
pass
else:
pass
for name in list(state_dict_.keys()):
if ".proj_in_besides_attn." in name:
name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
param = torch.concat([
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
state_dict_[name],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
state_dict_.pop(name)
for name in list(state_dict_.keys()):
for component in ["a", "b"]:
if f".{component}_to_q." in name:
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
param = torch.concat([
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
return state_dict_
def from_civitai(self, state_dict):
rename_dict = {
"time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias",
"time_in.in_layer.weight": "time_embedder.timestep_embedder.0.weight",
"time_in.out_layer.bias": "time_embedder.timestep_embedder.2.bias",
"time_in.out_layer.weight": "time_embedder.timestep_embedder.2.weight",
"txt_in.bias": "context_embedder.bias",
"txt_in.weight": "context_embedder.weight",
"vector_in.in_layer.bias": "pooled_text_embedder.0.bias",
"vector_in.in_layer.weight": "pooled_text_embedder.0.weight",
"vector_in.out_layer.bias": "pooled_text_embedder.2.bias",
"vector_in.out_layer.weight": "pooled_text_embedder.2.weight",
"final_layer.linear.bias": "final_proj_out.bias",
"final_layer.linear.weight": "final_proj_out.weight",
"guidance_in.in_layer.bias": "guidance_embedder.timestep_embedder.0.bias",
"guidance_in.in_layer.weight": "guidance_embedder.timestep_embedder.0.weight",
"guidance_in.out_layer.bias": "guidance_embedder.timestep_embedder.2.bias",
"guidance_in.out_layer.weight": "guidance_embedder.timestep_embedder.2.weight",
"img_in.bias": "x_embedder.bias",
"img_in.weight": "x_embedder.weight",
"final_layer.adaLN_modulation.1.weight": "final_norm_out.linear.weight",
"final_layer.adaLN_modulation.1.bias": "final_norm_out.linear.bias",
}
suffix_rename_dict = {
"img_attn.norm.key_norm.scale": "attn.norm_k_a.weight",
"img_attn.norm.query_norm.scale": "attn.norm_q_a.weight",
"img_attn.proj.bias": "attn.a_to_out.bias",
"img_attn.proj.weight": "attn.a_to_out.weight",
"img_attn.qkv.bias": "attn.a_to_qkv.bias",
"img_attn.qkv.weight": "attn.a_to_qkv.weight",
"img_mlp.0.bias": "ff_a.0.bias",
"img_mlp.0.weight": "ff_a.0.weight",
"img_mlp.2.bias": "ff_a.2.bias",
"img_mlp.2.weight": "ff_a.2.weight",
"img_mod.lin.bias": "norm1_a.linear.bias",
"img_mod.lin.weight": "norm1_a.linear.weight",
"txt_attn.norm.key_norm.scale": "attn.norm_k_b.weight",
"txt_attn.norm.query_norm.scale": "attn.norm_q_b.weight",
"txt_attn.proj.bias": "attn.b_to_out.bias",
"txt_attn.proj.weight": "attn.b_to_out.weight",
"txt_attn.qkv.bias": "attn.b_to_qkv.bias",
"txt_attn.qkv.weight": "attn.b_to_qkv.weight",
"txt_mlp.0.bias": "ff_b.0.bias",
"txt_mlp.0.weight": "ff_b.0.weight",
"txt_mlp.2.bias": "ff_b.2.bias",
"txt_mlp.2.weight": "ff_b.2.weight",
"txt_mod.lin.bias": "norm1_b.linear.bias",
"txt_mod.lin.weight": "norm1_b.linear.weight",
"linear1.bias": "to_qkv_mlp.bias",
"linear1.weight": "to_qkv_mlp.weight",
"linear2.bias": "proj_out.bias",
"linear2.weight": "proj_out.weight",
"modulation.lin.bias": "norm.linear.bias",
"modulation.lin.weight": "norm.linear.weight",
"norm.key_norm.scale": "norm_k_a.weight",
"norm.query_norm.scale": "norm_q_a.weight",
}
state_dict_ = {}
for name, param in state_dict.items():
names = name.split(".")
if name in rename_dict:
rename = rename_dict[name]
if name.startswith("final_layer.adaLN_modulation.1."):
param = torch.concat([param[3072:], param[:3072]], dim=0)
state_dict_[rename] = param
elif names[0] == "double_blocks":
rename = f"blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
state_dict_[rename] = param
elif names[0] == "single_blocks":
if ".".join(names[2:]) in suffix_rename_dict:
rename = f"single_blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
state_dict_[rename] = param
else:
pass
return state_dict_

View File

@@ -1,93 +0,0 @@
import torch
from transformers import T5EncoderModel, T5Config
from .sd_text_encoder import SDTextEncoder
class FluxTextEncoder1(SDTextEncoder):
def __init__(self, vocab_size=49408):
super().__init__(vocab_size=vocab_size)
def forward(self, input_ids, clip_skip=2):
embeds = self.token_embedding(input_ids) + self.position_embeds
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
for encoder_id, encoder in enumerate(self.encoders):
embeds = encoder(embeds, attn_mask=attn_mask)
if encoder_id + clip_skip == len(self.encoders):
hidden_states = embeds
embeds = self.final_layer_norm(embeds)
pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
return embeds, pooled_embeds
@staticmethod
def state_dict_converter():
return FluxTextEncoder1StateDictConverter()
class FluxTextEncoder2(T5EncoderModel):
def __init__(self, config):
super().__init__(config)
self.eval()
def forward(self, input_ids):
outputs = super().forward(input_ids=input_ids)
prompt_emb = outputs.last_hidden_state
return prompt_emb
@staticmethod
def state_dict_converter():
return FluxTextEncoder2StateDictConverter()
class FluxTextEncoder1StateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"text_model.embeddings.token_embedding.weight": "token_embedding.weight",
"text_model.embeddings.position_embedding.weight": "position_embeds",
"text_model.final_layer_norm.weight": "final_layer_norm.weight",
"text_model.final_layer_norm.bias": "final_layer_norm.bias"
}
attn_rename_dict = {
"self_attn.q_proj": "attn.to_q",
"self_attn.k_proj": "attn.to_k",
"self_attn.v_proj": "attn.to_v",
"self_attn.out_proj": "attn.to_out",
"layer_norm1": "layer_norm1",
"layer_norm2": "layer_norm2",
"mlp.fc1": "fc1",
"mlp.fc2": "fc2",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if name == "text_model.embeddings.position_embedding.weight":
param = param.reshape((1, param.shape[0], param.shape[1]))
state_dict_[rename_dict[name]] = param
elif name.startswith("text_model.encoder.layers."):
param = state_dict[name]
names = name.split(".")
layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
state_dict_[name_] = param
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)
class FluxTextEncoder2StateDictConverter():
def __init__(self):
pass
def from_diffusers(self, state_dict):
state_dict_ = state_dict
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -1,303 +0,0 @@
from .sd3_vae_encoder import SD3VAEEncoder, SDVAEEncoderStateDictConverter
from .sd3_vae_decoder import SD3VAEDecoder, SDVAEDecoderStateDictConverter
class FluxVAEEncoder(SD3VAEEncoder):
def __init__(self):
super().__init__()
self.scaling_factor = 0.3611
self.shift_factor = 0.1159
@staticmethod
def state_dict_converter():
return FluxVAEEncoderStateDictConverter()
class FluxVAEDecoder(SD3VAEDecoder):
def __init__(self):
super().__init__()
self.scaling_factor = 0.3611
self.shift_factor = 0.1159
@staticmethod
def state_dict_converter():
return FluxVAEDecoderStateDictConverter()
class FluxVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
def __init__(self):
pass
def from_civitai(self, state_dict):
rename_dict = {
"encoder.conv_in.bias": "conv_in.bias",
"encoder.conv_in.weight": "conv_in.weight",
"encoder.conv_out.bias": "conv_out.bias",
"encoder.conv_out.weight": "conv_out.weight",
"encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
"encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
"encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
"encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
"encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
"encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
"encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
"encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
"encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
"encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
"encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
"encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
"encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
"encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
"encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
"encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
"encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
"encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
"encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
"encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
"encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
"encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
"encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
"encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
"encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
"encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
"encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
"encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
"encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
"encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
"encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
"encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
"encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
"encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
"encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
"encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
"encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
"encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
"encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
"encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
"encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
"encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
"encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
"encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
"encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
"encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
"encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
"encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
"encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
"encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
"encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
"encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
"encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
"encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
"encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
"encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
"encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
"encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
"encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
"encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
"encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
"encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
"encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
"encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
"encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
"encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
"encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
"encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
"encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
"encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
"encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
"encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
"encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
"encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
"encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
"encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
"encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
"encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
"encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
"encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
"encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
"encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
"encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
"encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
"encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
"encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
"encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
"encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
"encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
"encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
"encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
"encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
"encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
"encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
"encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
"encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
"encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
"encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
"encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
"encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
"encoder.norm_out.bias": "conv_norm_out.bias",
"encoder.norm_out.weight": "conv_norm_out.weight",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if "transformer_blocks" in rename_dict[name]:
param = param.squeeze()
state_dict_[rename_dict[name]] = param
return state_dict_
class FluxVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
def __init__(self):
pass
def from_civitai(self, state_dict):
rename_dict = {
"decoder.conv_in.bias": "conv_in.bias",
"decoder.conv_in.weight": "conv_in.weight",
"decoder.conv_out.bias": "conv_out.bias",
"decoder.conv_out.weight": "conv_out.weight",
"decoder.mid.attn_1.k.bias": "blocks.1.transformer_blocks.0.to_k.bias",
"decoder.mid.attn_1.k.weight": "blocks.1.transformer_blocks.0.to_k.weight",
"decoder.mid.attn_1.norm.bias": "blocks.1.norm.bias",
"decoder.mid.attn_1.norm.weight": "blocks.1.norm.weight",
"decoder.mid.attn_1.proj_out.bias": "blocks.1.transformer_blocks.0.to_out.bias",
"decoder.mid.attn_1.proj_out.weight": "blocks.1.transformer_blocks.0.to_out.weight",
"decoder.mid.attn_1.q.bias": "blocks.1.transformer_blocks.0.to_q.bias",
"decoder.mid.attn_1.q.weight": "blocks.1.transformer_blocks.0.to_q.weight",
"decoder.mid.attn_1.v.bias": "blocks.1.transformer_blocks.0.to_v.bias",
"decoder.mid.attn_1.v.weight": "blocks.1.transformer_blocks.0.to_v.weight",
"decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
"decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
"decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
"decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
"decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
"decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
"decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
"decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
"decoder.mid.block_2.conv1.bias": "blocks.2.conv1.bias",
"decoder.mid.block_2.conv1.weight": "blocks.2.conv1.weight",
"decoder.mid.block_2.conv2.bias": "blocks.2.conv2.bias",
"decoder.mid.block_2.conv2.weight": "blocks.2.conv2.weight",
"decoder.mid.block_2.norm1.bias": "blocks.2.norm1.bias",
"decoder.mid.block_2.norm1.weight": "blocks.2.norm1.weight",
"decoder.mid.block_2.norm2.bias": "blocks.2.norm2.bias",
"decoder.mid.block_2.norm2.weight": "blocks.2.norm2.weight",
"decoder.norm_out.bias": "conv_norm_out.bias",
"decoder.norm_out.weight": "conv_norm_out.weight",
"decoder.up.0.block.0.conv1.bias": "blocks.15.conv1.bias",
"decoder.up.0.block.0.conv1.weight": "blocks.15.conv1.weight",
"decoder.up.0.block.0.conv2.bias": "blocks.15.conv2.bias",
"decoder.up.0.block.0.conv2.weight": "blocks.15.conv2.weight",
"decoder.up.0.block.0.nin_shortcut.bias": "blocks.15.conv_shortcut.bias",
"decoder.up.0.block.0.nin_shortcut.weight": "blocks.15.conv_shortcut.weight",
"decoder.up.0.block.0.norm1.bias": "blocks.15.norm1.bias",
"decoder.up.0.block.0.norm1.weight": "blocks.15.norm1.weight",
"decoder.up.0.block.0.norm2.bias": "blocks.15.norm2.bias",
"decoder.up.0.block.0.norm2.weight": "blocks.15.norm2.weight",
"decoder.up.0.block.1.conv1.bias": "blocks.16.conv1.bias",
"decoder.up.0.block.1.conv1.weight": "blocks.16.conv1.weight",
"decoder.up.0.block.1.conv2.bias": "blocks.16.conv2.bias",
"decoder.up.0.block.1.conv2.weight": "blocks.16.conv2.weight",
"decoder.up.0.block.1.norm1.bias": "blocks.16.norm1.bias",
"decoder.up.0.block.1.norm1.weight": "blocks.16.norm1.weight",
"decoder.up.0.block.1.norm2.bias": "blocks.16.norm2.bias",
"decoder.up.0.block.1.norm2.weight": "blocks.16.norm2.weight",
"decoder.up.0.block.2.conv1.bias": "blocks.17.conv1.bias",
"decoder.up.0.block.2.conv1.weight": "blocks.17.conv1.weight",
"decoder.up.0.block.2.conv2.bias": "blocks.17.conv2.bias",
"decoder.up.0.block.2.conv2.weight": "blocks.17.conv2.weight",
"decoder.up.0.block.2.norm1.bias": "blocks.17.norm1.bias",
"decoder.up.0.block.2.norm1.weight": "blocks.17.norm1.weight",
"decoder.up.0.block.2.norm2.bias": "blocks.17.norm2.bias",
"decoder.up.0.block.2.norm2.weight": "blocks.17.norm2.weight",
"decoder.up.1.block.0.conv1.bias": "blocks.11.conv1.bias",
"decoder.up.1.block.0.conv1.weight": "blocks.11.conv1.weight",
"decoder.up.1.block.0.conv2.bias": "blocks.11.conv2.bias",
"decoder.up.1.block.0.conv2.weight": "blocks.11.conv2.weight",
"decoder.up.1.block.0.nin_shortcut.bias": "blocks.11.conv_shortcut.bias",
"decoder.up.1.block.0.nin_shortcut.weight": "blocks.11.conv_shortcut.weight",
"decoder.up.1.block.0.norm1.bias": "blocks.11.norm1.bias",
"decoder.up.1.block.0.norm1.weight": "blocks.11.norm1.weight",
"decoder.up.1.block.0.norm2.bias": "blocks.11.norm2.bias",
"decoder.up.1.block.0.norm2.weight": "blocks.11.norm2.weight",
"decoder.up.1.block.1.conv1.bias": "blocks.12.conv1.bias",
"decoder.up.1.block.1.conv1.weight": "blocks.12.conv1.weight",
"decoder.up.1.block.1.conv2.bias": "blocks.12.conv2.bias",
"decoder.up.1.block.1.conv2.weight": "blocks.12.conv2.weight",
"decoder.up.1.block.1.norm1.bias": "blocks.12.norm1.bias",
"decoder.up.1.block.1.norm1.weight": "blocks.12.norm1.weight",
"decoder.up.1.block.1.norm2.bias": "blocks.12.norm2.bias",
"decoder.up.1.block.1.norm2.weight": "blocks.12.norm2.weight",
"decoder.up.1.block.2.conv1.bias": "blocks.13.conv1.bias",
"decoder.up.1.block.2.conv1.weight": "blocks.13.conv1.weight",
"decoder.up.1.block.2.conv2.bias": "blocks.13.conv2.bias",
"decoder.up.1.block.2.conv2.weight": "blocks.13.conv2.weight",
"decoder.up.1.block.2.norm1.bias": "blocks.13.norm1.bias",
"decoder.up.1.block.2.norm1.weight": "blocks.13.norm1.weight",
"decoder.up.1.block.2.norm2.bias": "blocks.13.norm2.bias",
"decoder.up.1.block.2.norm2.weight": "blocks.13.norm2.weight",
"decoder.up.1.upsample.conv.bias": "blocks.14.conv.bias",
"decoder.up.1.upsample.conv.weight": "blocks.14.conv.weight",
"decoder.up.2.block.0.conv1.bias": "blocks.7.conv1.bias",
"decoder.up.2.block.0.conv1.weight": "blocks.7.conv1.weight",
"decoder.up.2.block.0.conv2.bias": "blocks.7.conv2.bias",
"decoder.up.2.block.0.conv2.weight": "blocks.7.conv2.weight",
"decoder.up.2.block.0.norm1.bias": "blocks.7.norm1.bias",
"decoder.up.2.block.0.norm1.weight": "blocks.7.norm1.weight",
"decoder.up.2.block.0.norm2.bias": "blocks.7.norm2.bias",
"decoder.up.2.block.0.norm2.weight": "blocks.7.norm2.weight",
"decoder.up.2.block.1.conv1.bias": "blocks.8.conv1.bias",
"decoder.up.2.block.1.conv1.weight": "blocks.8.conv1.weight",
"decoder.up.2.block.1.conv2.bias": "blocks.8.conv2.bias",
"decoder.up.2.block.1.conv2.weight": "blocks.8.conv2.weight",
"decoder.up.2.block.1.norm1.bias": "blocks.8.norm1.bias",
"decoder.up.2.block.1.norm1.weight": "blocks.8.norm1.weight",
"decoder.up.2.block.1.norm2.bias": "blocks.8.norm2.bias",
"decoder.up.2.block.1.norm2.weight": "blocks.8.norm2.weight",
"decoder.up.2.block.2.conv1.bias": "blocks.9.conv1.bias",
"decoder.up.2.block.2.conv1.weight": "blocks.9.conv1.weight",
"decoder.up.2.block.2.conv2.bias": "blocks.9.conv2.bias",
"decoder.up.2.block.2.conv2.weight": "blocks.9.conv2.weight",
"decoder.up.2.block.2.norm1.bias": "blocks.9.norm1.bias",
"decoder.up.2.block.2.norm1.weight": "blocks.9.norm1.weight",
"decoder.up.2.block.2.norm2.bias": "blocks.9.norm2.bias",
"decoder.up.2.block.2.norm2.weight": "blocks.9.norm2.weight",
"decoder.up.2.upsample.conv.bias": "blocks.10.conv.bias",
"decoder.up.2.upsample.conv.weight": "blocks.10.conv.weight",
"decoder.up.3.block.0.conv1.bias": "blocks.3.conv1.bias",
"decoder.up.3.block.0.conv1.weight": "blocks.3.conv1.weight",
"decoder.up.3.block.0.conv2.bias": "blocks.3.conv2.bias",
"decoder.up.3.block.0.conv2.weight": "blocks.3.conv2.weight",
"decoder.up.3.block.0.norm1.bias": "blocks.3.norm1.bias",
"decoder.up.3.block.0.norm1.weight": "blocks.3.norm1.weight",
"decoder.up.3.block.0.norm2.bias": "blocks.3.norm2.bias",
"decoder.up.3.block.0.norm2.weight": "blocks.3.norm2.weight",
"decoder.up.3.block.1.conv1.bias": "blocks.4.conv1.bias",
"decoder.up.3.block.1.conv1.weight": "blocks.4.conv1.weight",
"decoder.up.3.block.1.conv2.bias": "blocks.4.conv2.bias",
"decoder.up.3.block.1.conv2.weight": "blocks.4.conv2.weight",
"decoder.up.3.block.1.norm1.bias": "blocks.4.norm1.bias",
"decoder.up.3.block.1.norm1.weight": "blocks.4.norm1.weight",
"decoder.up.3.block.1.norm2.bias": "blocks.4.norm2.bias",
"decoder.up.3.block.1.norm2.weight": "blocks.4.norm2.weight",
"decoder.up.3.block.2.conv1.bias": "blocks.5.conv1.bias",
"decoder.up.3.block.2.conv1.weight": "blocks.5.conv1.weight",
"decoder.up.3.block.2.conv2.bias": "blocks.5.conv2.bias",
"decoder.up.3.block.2.conv2.weight": "blocks.5.conv2.weight",
"decoder.up.3.block.2.norm1.bias": "blocks.5.norm1.bias",
"decoder.up.3.block.2.norm1.weight": "blocks.5.norm1.weight",
"decoder.up.3.block.2.norm2.bias": "blocks.5.norm2.bias",
"decoder.up.3.block.2.norm2.weight": "blocks.5.norm2.weight",
"decoder.up.3.upsample.conv.bias": "blocks.6.conv.bias",
"decoder.up.3.upsample.conv.weight": "blocks.6.conv.weight",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if "transformer_blocks" in rename_dict[name]:
param = param.squeeze()
state_dict_[rename_dict[name]] = param
return state_dict_

View File

@@ -1,4 +1,5 @@
from .attention import Attention
from .tiler import TileWorker
from einops import repeat, rearrange
import math
import torch
@@ -398,8 +399,7 @@ class HunyuanDiT(torch.nn.Module):
hidden_states, _ = hidden_states.chunk(2, dim=1)
return hidden_states
@staticmethod
def state_dict_converter():
def state_dict_converter(self):
return HunyuanDiTStateDictConverter()

View File

@@ -79,8 +79,7 @@ class HunyuanDiTCLIPTextEncoder(BertModel):
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
return prompt_emb
@staticmethod
def state_dict_converter():
def state_dict_converter(self):
return HunyuanDiTCLIPTextEncoderStateDictConverter()
@@ -132,8 +131,7 @@ class HunyuanDiTT5TextEncoder(T5EncoderModel):
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
return prompt_emb
@staticmethod
def state_dict_converter():
def state_dict_converter(self):
return HunyuanDiTT5TextEncoderStateDictConverter()

File diff suppressed because one or more lines are too long

View File

@@ -1,252 +0,0 @@
import torch
from .sd_unet import SDUNet
from .sdxl_unet import SDXLUNet
from .sd_text_encoder import SDTextEncoder
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from .sd3_dit import SD3DiT
from .flux_dit import FluxDiT
from .hunyuan_dit import HunyuanDiT
class LoRAFromCivitai:
def __init__(self):
self.supported_model_classes = []
self.lora_prefix = []
self.renamed_lora_prefix = {}
self.special_keys = {}
def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
for key in state_dict:
if ".lora_up" in key:
return self.convert_state_dict_up_down(state_dict, lora_prefix, alpha)
return self.convert_state_dict_AB(state_dict, lora_prefix, alpha)
def convert_state_dict_up_down(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "")
state_dict_ = {}
for key in state_dict:
if ".lora_up" not in key:
continue
if not key.startswith(lora_prefix):
continue
weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
if len(weight_up.shape) == 4:
weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
lora_weight = alpha * torch.mm(weight_up, weight_down)
target_name = key.split(".")[0].replace(lora_prefix, renamed_lora_prefix).replace("_", ".") + ".weight"
for special_key in self.special_keys:
target_name = target_name.replace(special_key, self.special_keys[special_key])
state_dict_[target_name] = lora_weight.cpu()
return state_dict_
def convert_state_dict_AB(self, state_dict, lora_prefix="", alpha=1.0, device="cuda", torch_dtype=torch.float16):
state_dict_ = {}
for key in state_dict:
if ".lora_B." not in key:
continue
if not key.startswith(lora_prefix):
continue
weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
if len(weight_up.shape) == 4:
weight_up = weight_up.squeeze(3).squeeze(2)
weight_down = weight_down.squeeze(3).squeeze(2)
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
lora_weight = alpha * torch.mm(weight_up, weight_down)
keys = key.split(".")
keys.pop(keys.index("lora_B"))
target_name = ".".join(keys)
target_name = target_name[len(lora_prefix):]
state_dict_[target_name] = lora_weight.cpu()
return state_dict_
def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None):
state_dict_model = model.state_dict()
state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=alpha)
if model_resource == "diffusers":
state_dict_lora = model.__class__.state_dict_converter().from_diffusers(state_dict_lora)
elif model_resource == "civitai":
state_dict_lora = model.__class__.state_dict_converter().from_civitai(state_dict_lora)
if len(state_dict_lora) > 0:
print(f" {len(state_dict_lora)} tensors are updated.")
for name in state_dict_lora:
state_dict_model[name] += state_dict_lora[name].to(
dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
model.load_state_dict(state_dict_model)
def match(self, model, state_dict_lora):
for lora_prefix, model_class in zip(self.lora_prefix, self.supported_model_classes):
if not isinstance(model, model_class):
continue
state_dict_model = model.state_dict()
for model_resource in ["diffusers", "civitai"]:
try:
state_dict_lora_ = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=1.0)
converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == "diffusers" \
else model.__class__.state_dict_converter().from_civitai
state_dict_lora_ = converter_fn(state_dict_lora_)
if len(state_dict_lora_) == 0:
continue
for name in state_dict_lora_:
if name not in state_dict_model:
break
else:
return lora_prefix, model_resource
except:
pass
return None
class SDLoRAFromCivitai(LoRAFromCivitai):
def __init__(self):
super().__init__()
self.supported_model_classes = [SDUNet, SDTextEncoder]
self.lora_prefix = ["lora_unet_", "lora_te_"]
self.special_keys = {
"down.blocks": "down_blocks",
"up.blocks": "up_blocks",
"mid.block": "mid_block",
"proj.in": "proj_in",
"proj.out": "proj_out",
"transformer.blocks": "transformer_blocks",
"to.q": "to_q",
"to.k": "to_k",
"to.v": "to_v",
"to.out": "to_out",
"text.model": "text_model",
"self.attn.q.proj": "self_attn.q_proj",
"self.attn.k.proj": "self_attn.k_proj",
"self.attn.v.proj": "self_attn.v_proj",
"self.attn.out.proj": "self_attn.out_proj",
"input.blocks": "model.diffusion_model.input_blocks",
"middle.block": "model.diffusion_model.middle_block",
"output.blocks": "model.diffusion_model.output_blocks",
}
class SDXLLoRAFromCivitai(LoRAFromCivitai):
def __init__(self):
super().__init__()
self.supported_model_classes = [SDXLUNet, SDXLTextEncoder, SDXLTextEncoder2]
self.lora_prefix = ["lora_unet_", "lora_te1_", "lora_te2_"]
self.renamed_lora_prefix = {"lora_te2_": "2"}
self.special_keys = {
"down.blocks": "down_blocks",
"up.blocks": "up_blocks",
"mid.block": "mid_block",
"proj.in": "proj_in",
"proj.out": "proj_out",
"transformer.blocks": "transformer_blocks",
"to.q": "to_q",
"to.k": "to_k",
"to.v": "to_v",
"to.out": "to_out",
"text.model": "conditioner.embedders.0.transformer.text_model",
"self.attn.q.proj": "self_attn.q_proj",
"self.attn.k.proj": "self_attn.k_proj",
"self.attn.v.proj": "self_attn.v_proj",
"self.attn.out.proj": "self_attn.out_proj",
"input.blocks": "model.diffusion_model.input_blocks",
"middle.block": "model.diffusion_model.middle_block",
"output.blocks": "model.diffusion_model.output_blocks",
"2conditioner.embedders.0.transformer.text_model.encoder.layers": "text_model.encoder.layers"
}
class FluxLoRAFromCivitai(LoRAFromCivitai):
def __init__(self):
super().__init__()
self.supported_model_classes = [FluxDiT, FluxDiT]
self.lora_prefix = ["lora_unet_", "transformer."]
self.renamed_lora_prefix = {}
self.special_keys = {
"single.blocks": "single_blocks",
"double.blocks": "double_blocks",
"img.attn": "img_attn",
"img.mlp": "img_mlp",
"img.mod": "img_mod",
"txt.attn": "txt_attn",
"txt.mlp": "txt_mlp",
"txt.mod": "txt_mod",
}
class GeneralLoRAFromPeft:
def __init__(self):
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT]
def fetch_device_dtype_from_state_dict(self, state_dict):
device, torch_dtype = None, None
for name, param in state_dict.items():
device, torch_dtype = param.device, param.dtype
break
return device, torch_dtype
def convert_state_dict(self, state_dict, alpha=1.0, target_state_dict={}):
device, torch_dtype = self.fetch_device_dtype_from_state_dict(target_state_dict)
state_dict_ = {}
for key in state_dict:
if ".lora_B." not in key:
continue
weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
if len(weight_up.shape) == 4:
weight_up = weight_up.squeeze(3).squeeze(2)
weight_down = weight_down.squeeze(3).squeeze(2)
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
lora_weight = alpha * torch.mm(weight_up, weight_down)
keys = key.split(".")
if len(keys) > keys.index("lora_B") + 2:
keys.pop(keys.index("lora_B") + 1)
keys.pop(keys.index("lora_B"))
target_name = ".".join(keys)
if target_name not in target_state_dict:
return {}
state_dict_[target_name] = lora_weight.cpu()
return state_dict_
def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
state_dict_model = model.state_dict()
state_dict_lora = self.convert_state_dict(state_dict_lora, alpha=alpha, target_state_dict=state_dict_model)
if len(state_dict_lora) > 0:
print(f" {len(state_dict_lora)} tensors are updated.")
for name in state_dict_lora:
state_dict_model[name] += state_dict_lora[name].to(
dtype=state_dict_model[name].dtype,
device=state_dict_model[name].device
)
model.load_state_dict(state_dict_model)
def match(self, model, state_dict_lora):
for model_class in self.supported_model_classes:
if not isinstance(model, model_class):
continue
state_dict_model = model.state_dict()
try:
state_dict_lora_ = self.convert_state_dict(state_dict_lora, alpha=1.0, target_state_dict=state_dict_model)
if len(state_dict_lora_) > 0:
return "", ""
except:
pass
return None
def get_lora_loaders():
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), GeneralLoRAFromPeft(), FluxLoRAFromCivitai()]

View File

@@ -1,471 +0,0 @@
import os, torch, hashlib, json, importlib
from safetensors import safe_open
from torch import Tensor
from typing_extensions import Literal, TypeAlias
from typing import List
from .downloader import download_models, Preset_model_id, Preset_model_website
from .sd_text_encoder import SDTextEncoder
from .sd_unet import SDUNet
from .sd_vae_encoder import SDVAEEncoder
from .sd_vae_decoder import SDVAEDecoder
from .lora import get_lora_loaders
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from .sdxl_unet import SDXLUNet
from .sdxl_vae_decoder import SDXLVAEDecoder
from .sdxl_vae_encoder import SDXLVAEEncoder
from .sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
from .sd3_dit import SD3DiT
from .sd3_vae_decoder import SD3VAEDecoder
from .sd3_vae_encoder import SD3VAEEncoder
from .sd_controlnet import SDControlNet
from .sdxl_controlnet import SDXLControlNetUnion
from .sd_motion import SDMotionModel
from .sdxl_motion import SDXLMotionModel
from .svd_image_encoder import SVDImageEncoder
from .svd_unet import SVDUNet
from .svd_vae_decoder import SVDVAEDecoder
from .svd_vae_encoder import SVDVAEEncoder
from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from .hunyuan_dit import HunyuanDiT
from .flux_dit import FluxDiT
from .flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
from .flux_vae import FluxVAEEncoder, FluxVAEDecoder
from .cog_vae import CogVAEEncoder, CogVAEDecoder
from .cog_dit import CogDiT
from ..extensions.RIFE import IFNet
from ..extensions.ESRGAN import RRDBNet
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
from .utils import load_state_dict
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
keys = []
for key, value in state_dict.items():
if isinstance(key, str):
if isinstance(value, Tensor):
if with_shape:
shape = "_".join(map(str, list(value.shape)))
keys.append(key + ":" + shape)
keys.append(key)
elif isinstance(value, dict):
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
keys.sort()
keys_str = ",".join(keys)
return keys_str
def split_state_dict_with_prefix(state_dict):
keys = sorted([key for key in state_dict if isinstance(key, str)])
prefix_dict = {}
for key in keys:
prefix = key if "." not in key else key.split(".")[0]
if prefix not in prefix_dict:
prefix_dict[prefix] = []
prefix_dict[prefix].append(key)
state_dicts = []
for prefix, keys in prefix_dict.items():
sub_state_dict = {key: state_dict[key] for key in keys}
state_dicts.append(sub_state_dict)
return state_dicts
def hash_state_dict_keys(state_dict, with_shape=True):
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
keys_str = keys_str.encode(encoding="UTF-8")
return hashlib.md5(keys_str).hexdigest()
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
loaded_model_names, loaded_models = [], []
for model_name, model_class in zip(model_names, model_classes):
print(f" model_name: {model_name} model_class: {model_class.__name__}")
state_dict_converter = model_class.state_dict_converter()
if model_resource == "civitai":
state_dict_results = state_dict_converter.from_civitai(state_dict)
elif model_resource == "diffusers":
state_dict_results = state_dict_converter.from_diffusers(state_dict)
if isinstance(state_dict_results, tuple):
model_state_dict, extra_kwargs = state_dict_results
print(f" This model is initialized with extra kwargs: {extra_kwargs}")
else:
model_state_dict, extra_kwargs = state_dict_results, {}
torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
model = model_class(**extra_kwargs).to(dtype=torch_dtype, device=device)
model.load_state_dict(model_state_dict)
loaded_model_names.append(model_name)
loaded_models.append(model)
return loaded_model_names, loaded_models
def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
loaded_model_names, loaded_models = [], []
for model_name, model_class in zip(model_names, model_classes):
model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
if torch_dtype == torch.float16 and hasattr(model, "half"):
model = model.half()
try:
model = model.to(device=device)
except:
pass
loaded_model_names.append(model_name)
loaded_models.append(model)
return loaded_model_names, loaded_models
def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
base_state_dict = base_model.state_dict()
base_model.to("cpu")
del base_model
model = model_class(**extra_kwargs)
model.load_state_dict(base_state_dict, strict=False)
model.load_state_dict(state_dict, strict=False)
model.to(dtype=torch_dtype, device=device)
return model
def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
loaded_model_names, loaded_models = [], []
for model_name, model_class in zip(model_names, model_classes):
while True:
for model_id in range(len(model_manager.model)):
base_model_name = model_manager.model_name[model_id]
if base_model_name == model_name:
base_model_path = model_manager.model_path[model_id]
base_model = model_manager.model[model_id]
print(f" Adding patch model to {base_model_name} ({base_model_path})")
patched_model = load_single_patch_model_from_single_file(
state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
loaded_model_names.append(base_model_name)
loaded_models.append(patched_model)
model_manager.model.pop(model_id)
model_manager.model_path.pop(model_id)
model_manager.model_name.pop(model_id)
break
else:
break
return loaded_model_names, loaded_models
class ModelDetectorTemplate:
def __init__(self):
pass
def match(self, file_path="", state_dict={}):
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
return [], []
class ModelDetectorFromSingleFile:
def __init__(self, model_loader_configs=[]):
self.keys_hash_with_shape_dict = {}
self.keys_hash_dict = {}
for metadata in model_loader_configs:
self.add_model_metadata(*metadata)
def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
if keys_hash is not None:
self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)
def match(self, file_path="", state_dict={}):
if os.path.isdir(file_path):
return False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
return True
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
if keys_hash in self.keys_hash_dict:
return True
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
# Load models with strict matching
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
return loaded_model_names, loaded_models
# Load models without strict matching
# (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
if keys_hash in self.keys_hash_dict:
model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
return loaded_model_names, loaded_models
return loaded_model_names, loaded_models
class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
def __init__(self, model_loader_configs=[]):
super().__init__(model_loader_configs)
def match(self, file_path="", state_dict={}):
if os.path.isdir(file_path):
return False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
splited_state_dict = split_state_dict_with_prefix(state_dict)
for sub_state_dict in splited_state_dict:
if super().match(file_path, sub_state_dict):
return True
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
# Split the state_dict and load from each component
splited_state_dict = split_state_dict_with_prefix(state_dict)
valid_state_dict = {}
for sub_state_dict in splited_state_dict:
if super().match(file_path, sub_state_dict):
valid_state_dict.update(sub_state_dict)
if super().match(file_path, valid_state_dict):
loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
else:
loaded_model_names, loaded_models = [], []
for sub_state_dict in splited_state_dict:
if super().match(file_path, sub_state_dict):
loaded_model_names_, loaded_models_ = super().load(file_path, valid_state_dict, device, torch_dtype)
loaded_model_names += loaded_model_names_
loaded_models += loaded_models_
return loaded_model_names, loaded_models
class ModelDetectorFromHuggingfaceFolder:
def __init__(self, model_loader_configs=[]):
self.architecture_dict = {}
for metadata in model_loader_configs:
self.add_model_metadata(*metadata)
def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)
def match(self, file_path="", state_dict={}):
if os.path.isfile(file_path):
return False
file_list = os.listdir(file_path)
if "config.json" not in file_list:
return False
with open(os.path.join(file_path, "config.json"), "r") as f:
config = json.load(f)
if "architectures" not in config and "_class_name" not in config:
return False
return True
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
with open(os.path.join(file_path, "config.json"), "r") as f:
config = json.load(f)
loaded_model_names, loaded_models = [], []
architectures = config["architectures"] if "architectures" in config else [config["_class_name"]]
for architecture in architectures:
huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
if redirected_architecture is not None:
architecture = redirected_architecture
model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
loaded_model_names += loaded_model_names_
loaded_models += loaded_models_
return loaded_model_names, loaded_models
class ModelDetectorFromPatchedSingleFile:
def __init__(self, model_loader_configs=[]):
self.keys_hash_with_shape_dict = {}
for metadata in model_loader_configs:
self.add_model_metadata(*metadata)
def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)
def match(self, file_path="", state_dict={}):
if os.path.isdir(file_path):
return False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
return True
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
# Load models with strict matching
loaded_model_names, loaded_models = [], []
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
loaded_model_names += loaded_model_names_
loaded_models += loaded_models_
return loaded_model_names, loaded_models
class ModelManager:
def __init__(
self,
torch_dtype=torch.float16,
device="cuda",
model_id_list: List[Preset_model_id] = [],
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
file_path_list: List[str] = [],
):
self.torch_dtype = torch_dtype
self.device = device
self.model = []
self.model_path = []
self.model_name = []
downloaded_files = download_models(model_id_list, downloading_priority) if len(model_id_list) > 0 else []
self.model_detector = [
ModelDetectorFromSingleFile(model_loader_configs),
ModelDetectorFromSplitedSingleFile(model_loader_configs),
ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
]
self.load_models(downloaded_files + file_path_list)
def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
print(f"Loading models from file: {file_path}")
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
print(f" The following models are loaded: {model_names}.")
def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
print(f"Loading models from folder: {file_path}")
model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
print(f" The following models are loaded: {model_names}.")
def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
print(f"Loading patch models from file: {file_path}")
model_names, models = load_patch_model_from_single_file(
state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
print(f" The following patched models are loaded: {model_names}.")
def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
print(f"Loading LoRA models from file: {file_path}")
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
for lora in get_lora_loaders():
match_results = lora.match(model, state_dict)
if match_results is not None:
print(f" Adding LoRA to {model_name} ({model_path}).")
lora_prefix, model_resource = match_results
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
break
def load_model(self, file_path, model_names=None):
print(f"Loading models from: {file_path}")
if os.path.isfile(file_path):
state_dict = load_state_dict(file_path)
else:
state_dict = None
for model_detector in self.model_detector:
if model_detector.match(file_path, state_dict):
model_names, models = model_detector.load(
file_path, state_dict,
device=self.device, torch_dtype=self.torch_dtype,
allowed_model_names=model_names, model_manager=self
)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
print(f" The following models are loaded: {model_names}.")
break
else:
print(f" We cannot detect the model type. No models are loaded.")
def load_models(self, file_path_list, model_names=None):
for file_path in file_path_list:
self.load_model(file_path, model_names)
def fetch_model(self, model_name, file_path=None, require_model_path=False):
fetched_models = []
fetched_model_paths = []
for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
if file_path is not None and file_path != model_path:
continue
if model_name == model_name_:
fetched_models.append(model)
fetched_model_paths.append(model_path)
if len(fetched_models) == 0:
print(f"No {model_name} models available.")
return None
if len(fetched_models) == 1:
print(f"Using {model_name} from {fetched_model_paths[0]}.")
else:
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
if require_model_path:
return fetched_models[0], fetched_model_paths[0]
else:
return fetched_models[0]
def to(self, device):
for model in self.model:
model.to(device)

View File

@@ -1,798 +0,0 @@
import torch
from einops import rearrange
from .svd_unet import TemporalTimesteps
from .tiler import TileWorker
class PatchEmbed(torch.nn.Module):
def __init__(self, patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192):
super().__init__()
self.pos_embed_max_size = pos_embed_max_size
self.patch_size = patch_size
self.proj = torch.nn.Conv2d(in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size)
self.pos_embed = torch.nn.Parameter(torch.zeros(1, self.pos_embed_max_size, self.pos_embed_max_size, 1536))
def cropped_pos_embed(self, height, width):
height = height // self.patch_size
width = width // self.patch_size
top = (self.pos_embed_max_size - height) // 2
left = (self.pos_embed_max_size - width) // 2
spatial_pos_embed = self.pos_embed[:, top : top + height, left : left + width, :].flatten(1, 2)
return spatial_pos_embed
def forward(self, latent):
height, width = latent.shape[-2:]
latent = self.proj(latent)
latent = latent.flatten(2).transpose(1, 2)
pos_embed = self.cropped_pos_embed(height, width)
return latent + pos_embed
class TimestepEmbeddings(torch.nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = torch.nn.Sequential(
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
)
def forward(self, timestep, dtype):
time_emb = self.time_proj(timestep).to(dtype)
time_emb = self.timestep_embedder(time_emb)
return time_emb
class AdaLayerNorm(torch.nn.Module):
def __init__(self, dim, single=False):
super().__init__()
self.single = single
self.linear = torch.nn.Linear(dim, dim * (2 if single else 6))
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb):
emb = self.linear(torch.nn.functional.silu(emb))
if self.single:
scale, shift = emb.unsqueeze(1).chunk(2, dim=2)
x = self.norm(x) * (1 + scale) + shift
return x
else:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)
x = self.norm(x) * (1 + scale_msa) + shift_msa
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
class JointAttention(torch.nn.Module):
def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.only_out_a = only_out_a
self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)
self.a_to_out = torch.nn.Linear(dim_a, dim_a)
if not only_out_a:
self.b_to_out = torch.nn.Linear(dim_b, dim_b)
def forward(self, hidden_states_a, hidden_states_b):
batch_size = hidden_states_a.shape[0]
qkv = torch.concat([self.a_to_qkv(hidden_states_a), self.b_to_qkv(hidden_states_b)], dim=1)
qkv = qkv.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
q, k, v = qkv.chunk(3, dim=1)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
hidden_states_a, hidden_states_b = hidden_states[:, :hidden_states_a.shape[1]], hidden_states[:, hidden_states_a.shape[1]:]
hidden_states_a = self.a_to_out(hidden_states_a)
if self.only_out_a:
return hidden_states_a
else:
hidden_states_b = self.b_to_out(hidden_states_b)
return hidden_states_a, hidden_states_b
class JointTransformerBlock(torch.nn.Module):
def __init__(self, dim, num_attention_heads):
super().__init__()
self.norm1_a = AdaLayerNorm(dim)
self.norm1_b = AdaLayerNorm(dim)
self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads)
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_a = torch.nn.Sequential(
torch.nn.Linear(dim, dim*4),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(dim*4, dim)
)
self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_b = torch.nn.Sequential(
torch.nn.Linear(dim, dim*4),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(dim*4, dim)
)
def forward(self, hidden_states_a, hidden_states_b, temb):
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
# Attention
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b)
# Part A
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
# Part B
hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
return hidden_states_a, hidden_states_b
class JointTransformerFinalBlock(torch.nn.Module):
def __init__(self, dim, num_attention_heads):
super().__init__()
self.norm1_a = AdaLayerNorm(dim)
self.norm1_b = AdaLayerNorm(dim, single=True)
self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, only_out_a=True)
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_a = torch.nn.Sequential(
torch.nn.Linear(dim, dim*4),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(dim*4, dim)
)
def forward(self, hidden_states_a, hidden_states_b, temb):
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
norm_hidden_states_b = self.norm1_b(hidden_states_b, emb=temb)
# Attention
attn_output_a = self.attn(norm_hidden_states_a, norm_hidden_states_b)
# Part A
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
return hidden_states_a, hidden_states_b
class SD3DiT(torch.nn.Module):
def __init__(self):
super().__init__()
self.pos_embedder = PatchEmbed(patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192)
self.time_embedder = TimestepEmbeddings(256, 1536)
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(2048, 1536), torch.nn.SiLU(), torch.nn.Linear(1536, 1536))
self.context_embedder = torch.nn.Linear(4096, 1536)
self.blocks = torch.nn.ModuleList([JointTransformerBlock(1536, 24) for _ in range(23)] + [JointTransformerFinalBlock(1536, 24)])
self.norm_out = AdaLayerNorm(1536, single=True)
self.proj_out = torch.nn.Linear(1536, 64)
def tiled_forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size=128, tile_stride=64):
# Due to the global positional embedding, we cannot implement layer-wise tiled forward.
hidden_states = TileWorker().tiled_forward(
lambda x: self.forward(x, timestep, prompt_emb, pooled_prompt_emb),
hidden_states,
tile_size,
tile_stride,
tile_device=hidden_states.device,
tile_dtype=hidden_states.dtype
)
return hidden_states
def forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tiled=False, tile_size=128, tile_stride=64, use_gradient_checkpointing=False):
if tiled:
return self.tiled_forward(hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size, tile_stride)
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
prompt_emb = self.context_embedder(prompt_emb)
height, width = hidden_states.shape[-2:]
hidden_states = self.pos_embedder(hidden_states)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block in self.blocks:
if self.training and use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning,
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning)
hidden_states = self.norm_out(hidden_states, conditioning)
hidden_states = self.proj_out(hidden_states)
hidden_states = rearrange(hidden_states, "B (H W) (P Q C) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
return hidden_states
@staticmethod
def state_dict_converter():
return SD3DiTStateDictConverter()
class SD3DiTStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"context_embedder": "context_embedder",
"pos_embed.pos_embed": "pos_embedder.pos_embed",
"pos_embed.proj": "pos_embedder.proj",
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
"norm_out.linear": "norm_out.linear",
"proj_out": "proj_out",
"norm1.linear": "norm1_a.linear",
"norm1_context.linear": "norm1_b.linear",
"attn.to_q": "attn.a_to_q",
"attn.to_k": "attn.a_to_k",
"attn.to_v": "attn.a_to_v",
"attn.to_out.0": "attn.a_to_out",
"attn.add_q_proj": "attn.b_to_q",
"attn.add_k_proj": "attn.b_to_k",
"attn.add_v_proj": "attn.b_to_v",
"attn.to_add_out": "attn.b_to_out",
"ff.net.0.proj": "ff_a.0",
"ff.net.2": "ff_a.2",
"ff_context.net.0.proj": "ff_b.0",
"ff_context.net.2": "ff_b.2",
}
state_dict_ = {}
for name, param in state_dict.items():
if name in rename_dict:
if name == "pos_embed.pos_embed":
param = param.reshape((1, 192, 192, 1536))
state_dict_[rename_dict[name]] = param
elif name.endswith(".weight") or name.endswith(".bias"):
suffix = ".weight" if name.endswith(".weight") else ".bias"
prefix = name[:-len(suffix)]
if prefix in rename_dict:
state_dict_[rename_dict[prefix] + suffix] = param
elif prefix.startswith("transformer_blocks."):
names = prefix.split(".")
names[0] = "blocks"
middle = ".".join(names[2:])
if middle in rename_dict:
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
state_dict_[name_] = param
return state_dict_
def from_civitai(self, state_dict):
rename_dict = {
"model.diffusion_model.context_embedder.bias": "context_embedder.bias",
"model.diffusion_model.context_embedder.weight": "context_embedder.weight",
"model.diffusion_model.final_layer.linear.bias": "proj_out.bias",
"model.diffusion_model.final_layer.linear.weight": "proj_out.weight",
"model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias": "blocks.0.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.weight": "blocks.0.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.0.context_block.attn.proj.bias": "blocks.0.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.0.context_block.attn.proj.weight": "blocks.0.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.0.context_block.attn.qkv.bias": ['blocks.0.attn.b_to_q.bias', 'blocks.0.attn.b_to_k.bias', 'blocks.0.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.0.context_block.attn.qkv.weight": ['blocks.0.attn.b_to_q.weight', 'blocks.0.attn.b_to_k.weight', 'blocks.0.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc1.bias": "blocks.0.ff_b.0.bias",
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc1.weight": "blocks.0.ff_b.0.weight",
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc2.bias": "blocks.0.ff_b.2.bias",
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc2.weight": "blocks.0.ff_b.2.weight",
"model.diffusion_model.joint_blocks.0.x_block.adaLN_modulation.1.bias": "blocks.0.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.0.x_block.adaLN_modulation.1.weight": "blocks.0.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.0.x_block.attn.proj.bias": "blocks.0.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.0.x_block.attn.proj.weight": "blocks.0.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.0.x_block.attn.qkv.bias": ['blocks.0.attn.a_to_q.bias', 'blocks.0.attn.a_to_k.bias', 'blocks.0.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.0.x_block.attn.qkv.weight": ['blocks.0.attn.a_to_q.weight', 'blocks.0.attn.a_to_k.weight', 'blocks.0.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc1.bias": "blocks.0.ff_a.0.bias",
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc1.weight": "blocks.0.ff_a.0.weight",
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc2.bias": "blocks.0.ff_a.2.bias",
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc2.weight": "blocks.0.ff_a.2.weight",
"model.diffusion_model.joint_blocks.1.context_block.adaLN_modulation.1.bias": "blocks.1.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.1.context_block.adaLN_modulation.1.weight": "blocks.1.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.1.context_block.attn.proj.bias": "blocks.1.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.1.context_block.attn.proj.weight": "blocks.1.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.1.context_block.attn.qkv.bias": ['blocks.1.attn.b_to_q.bias', 'blocks.1.attn.b_to_k.bias', 'blocks.1.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.1.context_block.attn.qkv.weight": ['blocks.1.attn.b_to_q.weight', 'blocks.1.attn.b_to_k.weight', 'blocks.1.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc1.bias": "blocks.1.ff_b.0.bias",
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc1.weight": "blocks.1.ff_b.0.weight",
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc2.bias": "blocks.1.ff_b.2.bias",
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc2.weight": "blocks.1.ff_b.2.weight",
"model.diffusion_model.joint_blocks.1.x_block.adaLN_modulation.1.bias": "blocks.1.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.1.x_block.adaLN_modulation.1.weight": "blocks.1.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.1.x_block.attn.proj.bias": "blocks.1.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.1.x_block.attn.proj.weight": "blocks.1.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.1.x_block.attn.qkv.bias": ['blocks.1.attn.a_to_q.bias', 'blocks.1.attn.a_to_k.bias', 'blocks.1.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.1.x_block.attn.qkv.weight": ['blocks.1.attn.a_to_q.weight', 'blocks.1.attn.a_to_k.weight', 'blocks.1.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc1.bias": "blocks.1.ff_a.0.bias",
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc1.weight": "blocks.1.ff_a.0.weight",
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc2.bias": "blocks.1.ff_a.2.bias",
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc2.weight": "blocks.1.ff_a.2.weight",
"model.diffusion_model.joint_blocks.10.context_block.adaLN_modulation.1.bias": "blocks.10.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.10.context_block.adaLN_modulation.1.weight": "blocks.10.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.10.context_block.attn.proj.bias": "blocks.10.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.10.context_block.attn.proj.weight": "blocks.10.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.10.context_block.attn.qkv.bias": ['blocks.10.attn.b_to_q.bias', 'blocks.10.attn.b_to_k.bias', 'blocks.10.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.10.context_block.attn.qkv.weight": ['blocks.10.attn.b_to_q.weight', 'blocks.10.attn.b_to_k.weight', 'blocks.10.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc1.bias": "blocks.10.ff_b.0.bias",
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc1.weight": "blocks.10.ff_b.0.weight",
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc2.bias": "blocks.10.ff_b.2.bias",
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc2.weight": "blocks.10.ff_b.2.weight",
"model.diffusion_model.joint_blocks.10.x_block.adaLN_modulation.1.bias": "blocks.10.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.10.x_block.adaLN_modulation.1.weight": "blocks.10.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.10.x_block.attn.proj.bias": "blocks.10.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.10.x_block.attn.proj.weight": "blocks.10.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.10.x_block.attn.qkv.bias": ['blocks.10.attn.a_to_q.bias', 'blocks.10.attn.a_to_k.bias', 'blocks.10.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.10.x_block.attn.qkv.weight": ['blocks.10.attn.a_to_q.weight', 'blocks.10.attn.a_to_k.weight', 'blocks.10.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc1.bias": "blocks.10.ff_a.0.bias",
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc1.weight": "blocks.10.ff_a.0.weight",
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc2.bias": "blocks.10.ff_a.2.bias",
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc2.weight": "blocks.10.ff_a.2.weight",
"model.diffusion_model.joint_blocks.11.context_block.adaLN_modulation.1.bias": "blocks.11.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.11.context_block.adaLN_modulation.1.weight": "blocks.11.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.11.context_block.attn.proj.bias": "blocks.11.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.11.context_block.attn.proj.weight": "blocks.11.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.11.context_block.attn.qkv.bias": ['blocks.11.attn.b_to_q.bias', 'blocks.11.attn.b_to_k.bias', 'blocks.11.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.11.context_block.attn.qkv.weight": ['blocks.11.attn.b_to_q.weight', 'blocks.11.attn.b_to_k.weight', 'blocks.11.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc1.bias": "blocks.11.ff_b.0.bias",
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc1.weight": "blocks.11.ff_b.0.weight",
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc2.bias": "blocks.11.ff_b.2.bias",
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc2.weight": "blocks.11.ff_b.2.weight",
"model.diffusion_model.joint_blocks.11.x_block.adaLN_modulation.1.bias": "blocks.11.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.11.x_block.adaLN_modulation.1.weight": "blocks.11.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.11.x_block.attn.proj.bias": "blocks.11.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.11.x_block.attn.proj.weight": "blocks.11.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.11.x_block.attn.qkv.bias": ['blocks.11.attn.a_to_q.bias', 'blocks.11.attn.a_to_k.bias', 'blocks.11.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.11.x_block.attn.qkv.weight": ['blocks.11.attn.a_to_q.weight', 'blocks.11.attn.a_to_k.weight', 'blocks.11.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc1.bias": "blocks.11.ff_a.0.bias",
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc1.weight": "blocks.11.ff_a.0.weight",
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc2.bias": "blocks.11.ff_a.2.bias",
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc2.weight": "blocks.11.ff_a.2.weight",
"model.diffusion_model.joint_blocks.12.context_block.adaLN_modulation.1.bias": "blocks.12.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.12.context_block.adaLN_modulation.1.weight": "blocks.12.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.12.context_block.attn.proj.bias": "blocks.12.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.12.context_block.attn.proj.weight": "blocks.12.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.12.context_block.attn.qkv.bias": ['blocks.12.attn.b_to_q.bias', 'blocks.12.attn.b_to_k.bias', 'blocks.12.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.12.context_block.attn.qkv.weight": ['blocks.12.attn.b_to_q.weight', 'blocks.12.attn.b_to_k.weight', 'blocks.12.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc1.bias": "blocks.12.ff_b.0.bias",
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc1.weight": "blocks.12.ff_b.0.weight",
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc2.bias": "blocks.12.ff_b.2.bias",
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc2.weight": "blocks.12.ff_b.2.weight",
"model.diffusion_model.joint_blocks.12.x_block.adaLN_modulation.1.bias": "blocks.12.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.12.x_block.adaLN_modulation.1.weight": "blocks.12.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.12.x_block.attn.proj.bias": "blocks.12.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.12.x_block.attn.proj.weight": "blocks.12.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.12.x_block.attn.qkv.bias": ['blocks.12.attn.a_to_q.bias', 'blocks.12.attn.a_to_k.bias', 'blocks.12.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.12.x_block.attn.qkv.weight": ['blocks.12.attn.a_to_q.weight', 'blocks.12.attn.a_to_k.weight', 'blocks.12.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc1.bias": "blocks.12.ff_a.0.bias",
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc1.weight": "blocks.12.ff_a.0.weight",
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc2.bias": "blocks.12.ff_a.2.bias",
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc2.weight": "blocks.12.ff_a.2.weight",
"model.diffusion_model.joint_blocks.13.context_block.adaLN_modulation.1.bias": "blocks.13.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.13.context_block.adaLN_modulation.1.weight": "blocks.13.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.13.context_block.attn.proj.bias": "blocks.13.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.13.context_block.attn.proj.weight": "blocks.13.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.13.context_block.attn.qkv.bias": ['blocks.13.attn.b_to_q.bias', 'blocks.13.attn.b_to_k.bias', 'blocks.13.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.13.context_block.attn.qkv.weight": ['blocks.13.attn.b_to_q.weight', 'blocks.13.attn.b_to_k.weight', 'blocks.13.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc1.bias": "blocks.13.ff_b.0.bias",
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc1.weight": "blocks.13.ff_b.0.weight",
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc2.bias": "blocks.13.ff_b.2.bias",
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc2.weight": "blocks.13.ff_b.2.weight",
"model.diffusion_model.joint_blocks.13.x_block.adaLN_modulation.1.bias": "blocks.13.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.13.x_block.adaLN_modulation.1.weight": "blocks.13.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.13.x_block.attn.proj.bias": "blocks.13.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.13.x_block.attn.proj.weight": "blocks.13.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.13.x_block.attn.qkv.bias": ['blocks.13.attn.a_to_q.bias', 'blocks.13.attn.a_to_k.bias', 'blocks.13.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.13.x_block.attn.qkv.weight": ['blocks.13.attn.a_to_q.weight', 'blocks.13.attn.a_to_k.weight', 'blocks.13.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc1.bias": "blocks.13.ff_a.0.bias",
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc1.weight": "blocks.13.ff_a.0.weight",
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc2.bias": "blocks.13.ff_a.2.bias",
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc2.weight": "blocks.13.ff_a.2.weight",
"model.diffusion_model.joint_blocks.14.context_block.adaLN_modulation.1.bias": "blocks.14.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.14.context_block.adaLN_modulation.1.weight": "blocks.14.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.14.context_block.attn.proj.bias": "blocks.14.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.14.context_block.attn.proj.weight": "blocks.14.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.14.context_block.attn.qkv.bias": ['blocks.14.attn.b_to_q.bias', 'blocks.14.attn.b_to_k.bias', 'blocks.14.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.14.context_block.attn.qkv.weight": ['blocks.14.attn.b_to_q.weight', 'blocks.14.attn.b_to_k.weight', 'blocks.14.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc1.bias": "blocks.14.ff_b.0.bias",
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc1.weight": "blocks.14.ff_b.0.weight",
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc2.bias": "blocks.14.ff_b.2.bias",
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc2.weight": "blocks.14.ff_b.2.weight",
"model.diffusion_model.joint_blocks.14.x_block.adaLN_modulation.1.bias": "blocks.14.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.14.x_block.adaLN_modulation.1.weight": "blocks.14.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.14.x_block.attn.proj.bias": "blocks.14.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.14.x_block.attn.proj.weight": "blocks.14.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.14.x_block.attn.qkv.bias": ['blocks.14.attn.a_to_q.bias', 'blocks.14.attn.a_to_k.bias', 'blocks.14.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.14.x_block.attn.qkv.weight": ['blocks.14.attn.a_to_q.weight', 'blocks.14.attn.a_to_k.weight', 'blocks.14.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc1.bias": "blocks.14.ff_a.0.bias",
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc1.weight": "blocks.14.ff_a.0.weight",
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc2.bias": "blocks.14.ff_a.2.bias",
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc2.weight": "blocks.14.ff_a.2.weight",
"model.diffusion_model.joint_blocks.15.context_block.adaLN_modulation.1.bias": "blocks.15.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.15.context_block.adaLN_modulation.1.weight": "blocks.15.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.15.context_block.attn.proj.bias": "blocks.15.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.15.context_block.attn.proj.weight": "blocks.15.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.15.context_block.attn.qkv.bias": ['blocks.15.attn.b_to_q.bias', 'blocks.15.attn.b_to_k.bias', 'blocks.15.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.15.context_block.attn.qkv.weight": ['blocks.15.attn.b_to_q.weight', 'blocks.15.attn.b_to_k.weight', 'blocks.15.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc1.bias": "blocks.15.ff_b.0.bias",
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc1.weight": "blocks.15.ff_b.0.weight",
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc2.bias": "blocks.15.ff_b.2.bias",
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc2.weight": "blocks.15.ff_b.2.weight",
"model.diffusion_model.joint_blocks.15.x_block.adaLN_modulation.1.bias": "blocks.15.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.15.x_block.adaLN_modulation.1.weight": "blocks.15.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.15.x_block.attn.proj.bias": "blocks.15.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.15.x_block.attn.proj.weight": "blocks.15.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.15.x_block.attn.qkv.bias": ['blocks.15.attn.a_to_q.bias', 'blocks.15.attn.a_to_k.bias', 'blocks.15.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.15.x_block.attn.qkv.weight": ['blocks.15.attn.a_to_q.weight', 'blocks.15.attn.a_to_k.weight', 'blocks.15.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc1.bias": "blocks.15.ff_a.0.bias",
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc1.weight": "blocks.15.ff_a.0.weight",
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc2.bias": "blocks.15.ff_a.2.bias",
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc2.weight": "blocks.15.ff_a.2.weight",
"model.diffusion_model.joint_blocks.16.context_block.adaLN_modulation.1.bias": "blocks.16.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.16.context_block.adaLN_modulation.1.weight": "blocks.16.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.16.context_block.attn.proj.bias": "blocks.16.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.16.context_block.attn.proj.weight": "blocks.16.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.16.context_block.attn.qkv.bias": ['blocks.16.attn.b_to_q.bias', 'blocks.16.attn.b_to_k.bias', 'blocks.16.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.16.context_block.attn.qkv.weight": ['blocks.16.attn.b_to_q.weight', 'blocks.16.attn.b_to_k.weight', 'blocks.16.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc1.bias": "blocks.16.ff_b.0.bias",
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc1.weight": "blocks.16.ff_b.0.weight",
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc2.bias": "blocks.16.ff_b.2.bias",
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc2.weight": "blocks.16.ff_b.2.weight",
"model.diffusion_model.joint_blocks.16.x_block.adaLN_modulation.1.bias": "blocks.16.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.16.x_block.adaLN_modulation.1.weight": "blocks.16.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.16.x_block.attn.proj.bias": "blocks.16.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.16.x_block.attn.proj.weight": "blocks.16.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.16.x_block.attn.qkv.bias": ['blocks.16.attn.a_to_q.bias', 'blocks.16.attn.a_to_k.bias', 'blocks.16.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.16.x_block.attn.qkv.weight": ['blocks.16.attn.a_to_q.weight', 'blocks.16.attn.a_to_k.weight', 'blocks.16.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc1.bias": "blocks.16.ff_a.0.bias",
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc1.weight": "blocks.16.ff_a.0.weight",
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc2.bias": "blocks.16.ff_a.2.bias",
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc2.weight": "blocks.16.ff_a.2.weight",
"model.diffusion_model.joint_blocks.17.context_block.adaLN_modulation.1.bias": "blocks.17.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.17.context_block.adaLN_modulation.1.weight": "blocks.17.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.17.context_block.attn.proj.bias": "blocks.17.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.17.context_block.attn.proj.weight": "blocks.17.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.17.context_block.attn.qkv.bias": ['blocks.17.attn.b_to_q.bias', 'blocks.17.attn.b_to_k.bias', 'blocks.17.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.17.context_block.attn.qkv.weight": ['blocks.17.attn.b_to_q.weight', 'blocks.17.attn.b_to_k.weight', 'blocks.17.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc1.bias": "blocks.17.ff_b.0.bias",
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc1.weight": "blocks.17.ff_b.0.weight",
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc2.bias": "blocks.17.ff_b.2.bias",
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc2.weight": "blocks.17.ff_b.2.weight",
"model.diffusion_model.joint_blocks.17.x_block.adaLN_modulation.1.bias": "blocks.17.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.17.x_block.adaLN_modulation.1.weight": "blocks.17.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.17.x_block.attn.proj.bias": "blocks.17.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.17.x_block.attn.proj.weight": "blocks.17.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.17.x_block.attn.qkv.bias": ['blocks.17.attn.a_to_q.bias', 'blocks.17.attn.a_to_k.bias', 'blocks.17.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.17.x_block.attn.qkv.weight": ['blocks.17.attn.a_to_q.weight', 'blocks.17.attn.a_to_k.weight', 'blocks.17.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc1.bias": "blocks.17.ff_a.0.bias",
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc1.weight": "blocks.17.ff_a.0.weight",
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc2.bias": "blocks.17.ff_a.2.bias",
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc2.weight": "blocks.17.ff_a.2.weight",
"model.diffusion_model.joint_blocks.18.context_block.adaLN_modulation.1.bias": "blocks.18.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.18.context_block.adaLN_modulation.1.weight": "blocks.18.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.18.context_block.attn.proj.bias": "blocks.18.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.18.context_block.attn.proj.weight": "blocks.18.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.18.context_block.attn.qkv.bias": ['blocks.18.attn.b_to_q.bias', 'blocks.18.attn.b_to_k.bias', 'blocks.18.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.18.context_block.attn.qkv.weight": ['blocks.18.attn.b_to_q.weight', 'blocks.18.attn.b_to_k.weight', 'blocks.18.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc1.bias": "blocks.18.ff_b.0.bias",
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc1.weight": "blocks.18.ff_b.0.weight",
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc2.bias": "blocks.18.ff_b.2.bias",
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc2.weight": "blocks.18.ff_b.2.weight",
"model.diffusion_model.joint_blocks.18.x_block.adaLN_modulation.1.bias": "blocks.18.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.18.x_block.adaLN_modulation.1.weight": "blocks.18.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.18.x_block.attn.proj.bias": "blocks.18.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.18.x_block.attn.proj.weight": "blocks.18.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.18.x_block.attn.qkv.bias": ['blocks.18.attn.a_to_q.bias', 'blocks.18.attn.a_to_k.bias', 'blocks.18.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.18.x_block.attn.qkv.weight": ['blocks.18.attn.a_to_q.weight', 'blocks.18.attn.a_to_k.weight', 'blocks.18.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc1.bias": "blocks.18.ff_a.0.bias",
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc1.weight": "blocks.18.ff_a.0.weight",
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc2.bias": "blocks.18.ff_a.2.bias",
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc2.weight": "blocks.18.ff_a.2.weight",
"model.diffusion_model.joint_blocks.19.context_block.adaLN_modulation.1.bias": "blocks.19.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.19.context_block.adaLN_modulation.1.weight": "blocks.19.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.19.context_block.attn.proj.bias": "blocks.19.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.19.context_block.attn.proj.weight": "blocks.19.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.19.context_block.attn.qkv.bias": ['blocks.19.attn.b_to_q.bias', 'blocks.19.attn.b_to_k.bias', 'blocks.19.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.19.context_block.attn.qkv.weight": ['blocks.19.attn.b_to_q.weight', 'blocks.19.attn.b_to_k.weight', 'blocks.19.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc1.bias": "blocks.19.ff_b.0.bias",
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc1.weight": "blocks.19.ff_b.0.weight",
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc2.bias": "blocks.19.ff_b.2.bias",
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc2.weight": "blocks.19.ff_b.2.weight",
"model.diffusion_model.joint_blocks.19.x_block.adaLN_modulation.1.bias": "blocks.19.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.19.x_block.adaLN_modulation.1.weight": "blocks.19.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.19.x_block.attn.proj.bias": "blocks.19.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.19.x_block.attn.proj.weight": "blocks.19.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.19.x_block.attn.qkv.bias": ['blocks.19.attn.a_to_q.bias', 'blocks.19.attn.a_to_k.bias', 'blocks.19.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.19.x_block.attn.qkv.weight": ['blocks.19.attn.a_to_q.weight', 'blocks.19.attn.a_to_k.weight', 'blocks.19.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc1.bias": "blocks.19.ff_a.0.bias",
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc1.weight": "blocks.19.ff_a.0.weight",
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc2.bias": "blocks.19.ff_a.2.bias",
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc2.weight": "blocks.19.ff_a.2.weight",
"model.diffusion_model.joint_blocks.2.context_block.adaLN_modulation.1.bias": "blocks.2.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.2.context_block.adaLN_modulation.1.weight": "blocks.2.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.2.context_block.attn.proj.bias": "blocks.2.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.2.context_block.attn.proj.weight": "blocks.2.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.2.context_block.attn.qkv.bias": ['blocks.2.attn.b_to_q.bias', 'blocks.2.attn.b_to_k.bias', 'blocks.2.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.2.context_block.attn.qkv.weight": ['blocks.2.attn.b_to_q.weight', 'blocks.2.attn.b_to_k.weight', 'blocks.2.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc1.bias": "blocks.2.ff_b.0.bias",
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc1.weight": "blocks.2.ff_b.0.weight",
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc2.bias": "blocks.2.ff_b.2.bias",
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc2.weight": "blocks.2.ff_b.2.weight",
"model.diffusion_model.joint_blocks.2.x_block.adaLN_modulation.1.bias": "blocks.2.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.2.x_block.adaLN_modulation.1.weight": "blocks.2.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.2.x_block.attn.proj.bias": "blocks.2.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.2.x_block.attn.proj.weight": "blocks.2.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.2.x_block.attn.qkv.bias": ['blocks.2.attn.a_to_q.bias', 'blocks.2.attn.a_to_k.bias', 'blocks.2.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.2.x_block.attn.qkv.weight": ['blocks.2.attn.a_to_q.weight', 'blocks.2.attn.a_to_k.weight', 'blocks.2.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc1.bias": "blocks.2.ff_a.0.bias",
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc1.weight": "blocks.2.ff_a.0.weight",
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc2.bias": "blocks.2.ff_a.2.bias",
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc2.weight": "blocks.2.ff_a.2.weight",
"model.diffusion_model.joint_blocks.20.context_block.adaLN_modulation.1.bias": "blocks.20.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.20.context_block.adaLN_modulation.1.weight": "blocks.20.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.20.context_block.attn.proj.bias": "blocks.20.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.20.context_block.attn.proj.weight": "blocks.20.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.20.context_block.attn.qkv.bias": ['blocks.20.attn.b_to_q.bias', 'blocks.20.attn.b_to_k.bias', 'blocks.20.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.20.context_block.attn.qkv.weight": ['blocks.20.attn.b_to_q.weight', 'blocks.20.attn.b_to_k.weight', 'blocks.20.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc1.bias": "blocks.20.ff_b.0.bias",
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc1.weight": "blocks.20.ff_b.0.weight",
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc2.bias": "blocks.20.ff_b.2.bias",
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc2.weight": "blocks.20.ff_b.2.weight",
"model.diffusion_model.joint_blocks.20.x_block.adaLN_modulation.1.bias": "blocks.20.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.20.x_block.adaLN_modulation.1.weight": "blocks.20.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.20.x_block.attn.proj.bias": "blocks.20.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.20.x_block.attn.proj.weight": "blocks.20.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.20.x_block.attn.qkv.bias": ['blocks.20.attn.a_to_q.bias', 'blocks.20.attn.a_to_k.bias', 'blocks.20.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.20.x_block.attn.qkv.weight": ['blocks.20.attn.a_to_q.weight', 'blocks.20.attn.a_to_k.weight', 'blocks.20.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc1.bias": "blocks.20.ff_a.0.bias",
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc1.weight": "blocks.20.ff_a.0.weight",
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc2.bias": "blocks.20.ff_a.2.bias",
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc2.weight": "blocks.20.ff_a.2.weight",
"model.diffusion_model.joint_blocks.21.context_block.adaLN_modulation.1.bias": "blocks.21.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.21.context_block.adaLN_modulation.1.weight": "blocks.21.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.21.context_block.attn.proj.bias": "blocks.21.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.21.context_block.attn.proj.weight": "blocks.21.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.21.context_block.attn.qkv.bias": ['blocks.21.attn.b_to_q.bias', 'blocks.21.attn.b_to_k.bias', 'blocks.21.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.21.context_block.attn.qkv.weight": ['blocks.21.attn.b_to_q.weight', 'blocks.21.attn.b_to_k.weight', 'blocks.21.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc1.bias": "blocks.21.ff_b.0.bias",
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc1.weight": "blocks.21.ff_b.0.weight",
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc2.bias": "blocks.21.ff_b.2.bias",
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc2.weight": "blocks.21.ff_b.2.weight",
"model.diffusion_model.joint_blocks.21.x_block.adaLN_modulation.1.bias": "blocks.21.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.21.x_block.adaLN_modulation.1.weight": "blocks.21.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.21.x_block.attn.proj.bias": "blocks.21.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.21.x_block.attn.proj.weight": "blocks.21.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.21.x_block.attn.qkv.bias": ['blocks.21.attn.a_to_q.bias', 'blocks.21.attn.a_to_k.bias', 'blocks.21.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.21.x_block.attn.qkv.weight": ['blocks.21.attn.a_to_q.weight', 'blocks.21.attn.a_to_k.weight', 'blocks.21.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc1.bias": "blocks.21.ff_a.0.bias",
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc1.weight": "blocks.21.ff_a.0.weight",
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc2.bias": "blocks.21.ff_a.2.bias",
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc2.weight": "blocks.21.ff_a.2.weight",
"model.diffusion_model.joint_blocks.22.context_block.adaLN_modulation.1.bias": "blocks.22.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.22.context_block.adaLN_modulation.1.weight": "blocks.22.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.22.context_block.attn.proj.bias": "blocks.22.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.22.context_block.attn.proj.weight": "blocks.22.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.22.context_block.attn.qkv.bias": ['blocks.22.attn.b_to_q.bias', 'blocks.22.attn.b_to_k.bias', 'blocks.22.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.22.context_block.attn.qkv.weight": ['blocks.22.attn.b_to_q.weight', 'blocks.22.attn.b_to_k.weight', 'blocks.22.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc1.bias": "blocks.22.ff_b.0.bias",
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc1.weight": "blocks.22.ff_b.0.weight",
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc2.bias": "blocks.22.ff_b.2.bias",
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc2.weight": "blocks.22.ff_b.2.weight",
"model.diffusion_model.joint_blocks.22.x_block.adaLN_modulation.1.bias": "blocks.22.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.22.x_block.adaLN_modulation.1.weight": "blocks.22.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.22.x_block.attn.proj.bias": "blocks.22.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.22.x_block.attn.proj.weight": "blocks.22.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.22.x_block.attn.qkv.bias": ['blocks.22.attn.a_to_q.bias', 'blocks.22.attn.a_to_k.bias', 'blocks.22.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.22.x_block.attn.qkv.weight": ['blocks.22.attn.a_to_q.weight', 'blocks.22.attn.a_to_k.weight', 'blocks.22.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc1.bias": "blocks.22.ff_a.0.bias",
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc1.weight": "blocks.22.ff_a.0.weight",
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc2.bias": "blocks.22.ff_a.2.bias",
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc2.weight": "blocks.22.ff_a.2.weight",
"model.diffusion_model.joint_blocks.23.context_block.attn.qkv.bias": ['blocks.23.attn.b_to_q.bias', 'blocks.23.attn.b_to_k.bias', 'blocks.23.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.23.context_block.attn.qkv.weight": ['blocks.23.attn.b_to_q.weight', 'blocks.23.attn.b_to_k.weight', 'blocks.23.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.23.x_block.adaLN_modulation.1.bias": "blocks.23.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.23.x_block.adaLN_modulation.1.weight": "blocks.23.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.23.x_block.attn.proj.bias": "blocks.23.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.23.x_block.attn.proj.weight": "blocks.23.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.23.x_block.attn.qkv.bias": ['blocks.23.attn.a_to_q.bias', 'blocks.23.attn.a_to_k.bias', 'blocks.23.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.23.x_block.attn.qkv.weight": ['blocks.23.attn.a_to_q.weight', 'blocks.23.attn.a_to_k.weight', 'blocks.23.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc1.bias": "blocks.23.ff_a.0.bias",
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc1.weight": "blocks.23.ff_a.0.weight",
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc2.bias": "blocks.23.ff_a.2.bias",
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc2.weight": "blocks.23.ff_a.2.weight",
"model.diffusion_model.joint_blocks.3.context_block.adaLN_modulation.1.bias": "blocks.3.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.3.context_block.adaLN_modulation.1.weight": "blocks.3.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.3.context_block.attn.proj.bias": "blocks.3.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.3.context_block.attn.proj.weight": "blocks.3.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.3.context_block.attn.qkv.bias": ['blocks.3.attn.b_to_q.bias', 'blocks.3.attn.b_to_k.bias', 'blocks.3.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.3.context_block.attn.qkv.weight": ['blocks.3.attn.b_to_q.weight', 'blocks.3.attn.b_to_k.weight', 'blocks.3.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc1.bias": "blocks.3.ff_b.0.bias",
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc1.weight": "blocks.3.ff_b.0.weight",
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc2.bias": "blocks.3.ff_b.2.bias",
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc2.weight": "blocks.3.ff_b.2.weight",
"model.diffusion_model.joint_blocks.3.x_block.adaLN_modulation.1.bias": "blocks.3.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.3.x_block.adaLN_modulation.1.weight": "blocks.3.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.3.x_block.attn.proj.bias": "blocks.3.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.3.x_block.attn.proj.weight": "blocks.3.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.3.x_block.attn.qkv.bias": ['blocks.3.attn.a_to_q.bias', 'blocks.3.attn.a_to_k.bias', 'blocks.3.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.3.x_block.attn.qkv.weight": ['blocks.3.attn.a_to_q.weight', 'blocks.3.attn.a_to_k.weight', 'blocks.3.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc1.bias": "blocks.3.ff_a.0.bias",
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc1.weight": "blocks.3.ff_a.0.weight",
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc2.bias": "blocks.3.ff_a.2.bias",
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc2.weight": "blocks.3.ff_a.2.weight",
"model.diffusion_model.joint_blocks.4.context_block.adaLN_modulation.1.bias": "blocks.4.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.4.context_block.adaLN_modulation.1.weight": "blocks.4.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.4.context_block.attn.proj.bias": "blocks.4.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.4.context_block.attn.proj.weight": "blocks.4.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.4.context_block.attn.qkv.bias": ['blocks.4.attn.b_to_q.bias', 'blocks.4.attn.b_to_k.bias', 'blocks.4.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.4.context_block.attn.qkv.weight": ['blocks.4.attn.b_to_q.weight', 'blocks.4.attn.b_to_k.weight', 'blocks.4.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc1.bias": "blocks.4.ff_b.0.bias",
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc1.weight": "blocks.4.ff_b.0.weight",
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc2.bias": "blocks.4.ff_b.2.bias",
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc2.weight": "blocks.4.ff_b.2.weight",
"model.diffusion_model.joint_blocks.4.x_block.adaLN_modulation.1.bias": "blocks.4.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.4.x_block.adaLN_modulation.1.weight": "blocks.4.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.4.x_block.attn.proj.bias": "blocks.4.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.4.x_block.attn.proj.weight": "blocks.4.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.4.x_block.attn.qkv.bias": ['blocks.4.attn.a_to_q.bias', 'blocks.4.attn.a_to_k.bias', 'blocks.4.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.4.x_block.attn.qkv.weight": ['blocks.4.attn.a_to_q.weight', 'blocks.4.attn.a_to_k.weight', 'blocks.4.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc1.bias": "blocks.4.ff_a.0.bias",
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc1.weight": "blocks.4.ff_a.0.weight",
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc2.bias": "blocks.4.ff_a.2.bias",
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc2.weight": "blocks.4.ff_a.2.weight",
"model.diffusion_model.joint_blocks.5.context_block.adaLN_modulation.1.bias": "blocks.5.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.5.context_block.adaLN_modulation.1.weight": "blocks.5.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.5.context_block.attn.proj.bias": "blocks.5.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.5.context_block.attn.proj.weight": "blocks.5.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.5.context_block.attn.qkv.bias": ['blocks.5.attn.b_to_q.bias', 'blocks.5.attn.b_to_k.bias', 'blocks.5.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.5.context_block.attn.qkv.weight": ['blocks.5.attn.b_to_q.weight', 'blocks.5.attn.b_to_k.weight', 'blocks.5.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc1.bias": "blocks.5.ff_b.0.bias",
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc1.weight": "blocks.5.ff_b.0.weight",
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc2.bias": "blocks.5.ff_b.2.bias",
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc2.weight": "blocks.5.ff_b.2.weight",
"model.diffusion_model.joint_blocks.5.x_block.adaLN_modulation.1.bias": "blocks.5.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.5.x_block.adaLN_modulation.1.weight": "blocks.5.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.5.x_block.attn.proj.bias": "blocks.5.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.5.x_block.attn.proj.weight": "blocks.5.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.5.x_block.attn.qkv.bias": ['blocks.5.attn.a_to_q.bias', 'blocks.5.attn.a_to_k.bias', 'blocks.5.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.5.x_block.attn.qkv.weight": ['blocks.5.attn.a_to_q.weight', 'blocks.5.attn.a_to_k.weight', 'blocks.5.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc1.bias": "blocks.5.ff_a.0.bias",
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc1.weight": "blocks.5.ff_a.0.weight",
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc2.bias": "blocks.5.ff_a.2.bias",
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc2.weight": "blocks.5.ff_a.2.weight",
"model.diffusion_model.joint_blocks.6.context_block.adaLN_modulation.1.bias": "blocks.6.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.6.context_block.adaLN_modulation.1.weight": "blocks.6.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.6.context_block.attn.proj.bias": "blocks.6.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.6.context_block.attn.proj.weight": "blocks.6.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.6.context_block.attn.qkv.bias": ['blocks.6.attn.b_to_q.bias', 'blocks.6.attn.b_to_k.bias', 'blocks.6.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.6.context_block.attn.qkv.weight": ['blocks.6.attn.b_to_q.weight', 'blocks.6.attn.b_to_k.weight', 'blocks.6.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc1.bias": "blocks.6.ff_b.0.bias",
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc1.weight": "blocks.6.ff_b.0.weight",
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc2.bias": "blocks.6.ff_b.2.bias",
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc2.weight": "blocks.6.ff_b.2.weight",
"model.diffusion_model.joint_blocks.6.x_block.adaLN_modulation.1.bias": "blocks.6.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.6.x_block.adaLN_modulation.1.weight": "blocks.6.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.6.x_block.attn.proj.bias": "blocks.6.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.6.x_block.attn.proj.weight": "blocks.6.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.6.x_block.attn.qkv.bias": ['blocks.6.attn.a_to_q.bias', 'blocks.6.attn.a_to_k.bias', 'blocks.6.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.6.x_block.attn.qkv.weight": ['blocks.6.attn.a_to_q.weight', 'blocks.6.attn.a_to_k.weight', 'blocks.6.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc1.bias": "blocks.6.ff_a.0.bias",
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc1.weight": "blocks.6.ff_a.0.weight",
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc2.bias": "blocks.6.ff_a.2.bias",
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc2.weight": "blocks.6.ff_a.2.weight",
"model.diffusion_model.joint_blocks.7.context_block.adaLN_modulation.1.bias": "blocks.7.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.7.context_block.adaLN_modulation.1.weight": "blocks.7.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.7.context_block.attn.proj.bias": "blocks.7.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.7.context_block.attn.proj.weight": "blocks.7.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.7.context_block.attn.qkv.bias": ['blocks.7.attn.b_to_q.bias', 'blocks.7.attn.b_to_k.bias', 'blocks.7.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.7.context_block.attn.qkv.weight": ['blocks.7.attn.b_to_q.weight', 'blocks.7.attn.b_to_k.weight', 'blocks.7.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc1.bias": "blocks.7.ff_b.0.bias",
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc1.weight": "blocks.7.ff_b.0.weight",
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc2.bias": "blocks.7.ff_b.2.bias",
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc2.weight": "blocks.7.ff_b.2.weight",
"model.diffusion_model.joint_blocks.7.x_block.adaLN_modulation.1.bias": "blocks.7.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.7.x_block.adaLN_modulation.1.weight": "blocks.7.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.7.x_block.attn.proj.bias": "blocks.7.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.7.x_block.attn.proj.weight": "blocks.7.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.7.x_block.attn.qkv.bias": ['blocks.7.attn.a_to_q.bias', 'blocks.7.attn.a_to_k.bias', 'blocks.7.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.7.x_block.attn.qkv.weight": ['blocks.7.attn.a_to_q.weight', 'blocks.7.attn.a_to_k.weight', 'blocks.7.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc1.bias": "blocks.7.ff_a.0.bias",
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc1.weight": "blocks.7.ff_a.0.weight",
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc2.bias": "blocks.7.ff_a.2.bias",
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc2.weight": "blocks.7.ff_a.2.weight",
"model.diffusion_model.joint_blocks.8.context_block.adaLN_modulation.1.bias": "blocks.8.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.8.context_block.adaLN_modulation.1.weight": "blocks.8.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.8.context_block.attn.proj.bias": "blocks.8.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.8.context_block.attn.proj.weight": "blocks.8.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.8.context_block.attn.qkv.bias": ['blocks.8.attn.b_to_q.bias', 'blocks.8.attn.b_to_k.bias', 'blocks.8.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.8.context_block.attn.qkv.weight": ['blocks.8.attn.b_to_q.weight', 'blocks.8.attn.b_to_k.weight', 'blocks.8.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc1.bias": "blocks.8.ff_b.0.bias",
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc1.weight": "blocks.8.ff_b.0.weight",
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc2.bias": "blocks.8.ff_b.2.bias",
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc2.weight": "blocks.8.ff_b.2.weight",
"model.diffusion_model.joint_blocks.8.x_block.adaLN_modulation.1.bias": "blocks.8.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.8.x_block.adaLN_modulation.1.weight": "blocks.8.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.8.x_block.attn.proj.bias": "blocks.8.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.8.x_block.attn.proj.weight": "blocks.8.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.8.x_block.attn.qkv.bias": ['blocks.8.attn.a_to_q.bias', 'blocks.8.attn.a_to_k.bias', 'blocks.8.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.8.x_block.attn.qkv.weight": ['blocks.8.attn.a_to_q.weight', 'blocks.8.attn.a_to_k.weight', 'blocks.8.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc1.bias": "blocks.8.ff_a.0.bias",
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc1.weight": "blocks.8.ff_a.0.weight",
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc2.bias": "blocks.8.ff_a.2.bias",
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc2.weight": "blocks.8.ff_a.2.weight",
"model.diffusion_model.joint_blocks.9.context_block.adaLN_modulation.1.bias": "blocks.9.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.9.context_block.adaLN_modulation.1.weight": "blocks.9.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.9.context_block.attn.proj.bias": "blocks.9.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.9.context_block.attn.proj.weight": "blocks.9.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.9.context_block.attn.qkv.bias": ['blocks.9.attn.b_to_q.bias', 'blocks.9.attn.b_to_k.bias', 'blocks.9.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.9.context_block.attn.qkv.weight": ['blocks.9.attn.b_to_q.weight', 'blocks.9.attn.b_to_k.weight', 'blocks.9.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc1.bias": "blocks.9.ff_b.0.bias",
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc1.weight": "blocks.9.ff_b.0.weight",
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc2.bias": "blocks.9.ff_b.2.bias",
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc2.weight": "blocks.9.ff_b.2.weight",
"model.diffusion_model.joint_blocks.9.x_block.adaLN_modulation.1.bias": "blocks.9.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.9.x_block.adaLN_modulation.1.weight": "blocks.9.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.9.x_block.attn.proj.bias": "blocks.9.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.9.x_block.attn.proj.weight": "blocks.9.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.9.x_block.attn.qkv.bias": ['blocks.9.attn.a_to_q.bias', 'blocks.9.attn.a_to_k.bias', 'blocks.9.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.9.x_block.attn.qkv.weight": ['blocks.9.attn.a_to_q.weight', 'blocks.9.attn.a_to_k.weight', 'blocks.9.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc1.bias": "blocks.9.ff_a.0.bias",
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc1.weight": "blocks.9.ff_a.0.weight",
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.bias": "blocks.9.ff_a.2.bias",
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.weight": "blocks.9.ff_a.2.weight",
"model.diffusion_model.pos_embed": "pos_embedder.pos_embed",
"model.diffusion_model.t_embedder.mlp.0.bias": "time_embedder.timestep_embedder.0.bias",
"model.diffusion_model.t_embedder.mlp.0.weight": "time_embedder.timestep_embedder.0.weight",
"model.diffusion_model.t_embedder.mlp.2.bias": "time_embedder.timestep_embedder.2.bias",
"model.diffusion_model.t_embedder.mlp.2.weight": "time_embedder.timestep_embedder.2.weight",
"model.diffusion_model.x_embedder.proj.bias": "pos_embedder.proj.bias",
"model.diffusion_model.x_embedder.proj.weight": "pos_embedder.proj.weight",
"model.diffusion_model.y_embedder.mlp.0.bias": "pooled_text_embedder.0.bias",
"model.diffusion_model.y_embedder.mlp.0.weight": "pooled_text_embedder.0.weight",
"model.diffusion_model.y_embedder.mlp.2.bias": "pooled_text_embedder.2.bias",
"model.diffusion_model.y_embedder.mlp.2.weight": "pooled_text_embedder.2.weight",
"model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1.weight": "blocks.23.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1.bias": "blocks.23.norm1_b.linear.bias",
"model.diffusion_model.final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
"model.diffusion_model.final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if name.startswith("model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1."):
param = torch.concat([param[1536:], param[:1536]], axis=0)
elif name.startswith("model.diffusion_model.final_layer.adaLN_modulation.1."):
param = torch.concat([param[1536:], param[:1536]], axis=0)
elif name == "model.diffusion_model.pos_embed":
param = param.reshape((1, 192, 192, 1536))
if isinstance(rename_dict[name], str):
state_dict_[rename_dict[name]] = param
else:
name_ = rename_dict[name][0].replace(".a_to_q.", ".a_to_qkv.").replace(".b_to_q.", ".b_to_qkv.")
state_dict_[name_] = param
return state_dict_

File diff suppressed because it is too large Load Diff

View File

@@ -1,81 +0,0 @@
import torch
from .sd_vae_decoder import VAEAttentionBlock, SDVAEDecoderStateDictConverter
from .sd_unet import ResnetBlock, UpSampler
from .tiler import TileWorker
class SD3VAEDecoder(torch.nn.Module):
def __init__(self):
super().__init__()
self.scaling_factor = 1.5305 # Different from SD 1.x
self.shift_factor = 0.0609 # Different from SD 1.x
self.conv_in = torch.nn.Conv2d(16, 512, kernel_size=3, padding=1) # Different from SD 1.x
self.blocks = torch.nn.ModuleList([
# UNetMidBlock2D
ResnetBlock(512, 512, eps=1e-6),
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
# UpDecoderBlock2D
ResnetBlock(512, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
UpSampler(512),
# UpDecoderBlock2D
ResnetBlock(512, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
UpSampler(512),
# UpDecoderBlock2D
ResnetBlock(512, 256, eps=1e-6),
ResnetBlock(256, 256, eps=1e-6),
ResnetBlock(256, 256, eps=1e-6),
UpSampler(256),
# UpDecoderBlock2D
ResnetBlock(256, 128, eps=1e-6),
ResnetBlock(128, 128, eps=1e-6),
ResnetBlock(128, 128, eps=1e-6),
])
self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6)
self.conv_act = torch.nn.SiLU()
self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
hidden_states = TileWorker().tiled_forward(
lambda x: self.forward(x),
sample,
tile_size,
tile_stride,
tile_device=sample.device,
tile_dtype=sample.dtype
)
return hidden_states
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
# For VAE Decoder, we do not need to apply the tiler on each layer.
if tiled:
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
# 1. pre-process
hidden_states = sample / self.scaling_factor + self.shift_factor
hidden_states = self.conv_in(hidden_states)
time_emb = None
text_emb = None
res_stack = None
# 2. blocks
for i, block in enumerate(self.blocks):
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
# 3. output
hidden_states = self.conv_norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
@staticmethod
def state_dict_converter():
return SDVAEDecoderStateDictConverter()

View File

@@ -1,95 +0,0 @@
import torch
from .sd_unet import ResnetBlock, DownSampler
from .sd_vae_encoder import VAEAttentionBlock, SDVAEEncoderStateDictConverter
from .tiler import TileWorker
from einops import rearrange
class SD3VAEEncoder(torch.nn.Module):
def __init__(self):
super().__init__()
self.scaling_factor = 1.5305 # Different from SD 1.x
self.shift_factor = 0.0609 # Different from SD 1.x
self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)
self.blocks = torch.nn.ModuleList([
# DownEncoderBlock2D
ResnetBlock(128, 128, eps=1e-6),
ResnetBlock(128, 128, eps=1e-6),
DownSampler(128, padding=0, extra_padding=True),
# DownEncoderBlock2D
ResnetBlock(128, 256, eps=1e-6),
ResnetBlock(256, 256, eps=1e-6),
DownSampler(256, padding=0, extra_padding=True),
# DownEncoderBlock2D
ResnetBlock(256, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
DownSampler(512, padding=0, extra_padding=True),
# DownEncoderBlock2D
ResnetBlock(512, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
# UNetMidBlock2D
ResnetBlock(512, 512, eps=1e-6),
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
])
self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
self.conv_act = torch.nn.SiLU()
self.conv_out = torch.nn.Conv2d(512, 32, kernel_size=3, padding=1)
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
hidden_states = TileWorker().tiled_forward(
lambda x: self.forward(x),
sample,
tile_size,
tile_stride,
tile_device=sample.device,
tile_dtype=sample.dtype
)
return hidden_states
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
# For VAE Decoder, we do not need to apply the tiler on each layer.
if tiled:
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
# 1. pre-process
hidden_states = self.conv_in(sample)
time_emb = None
text_emb = None
res_stack = None
# 2. blocks
for i, block in enumerate(self.blocks):
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
# 3. output
hidden_states = self.conv_norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
hidden_states = hidden_states[:, :16]
hidden_states = (hidden_states - self.shift_factor) * self.scaling_factor
return hidden_states
def encode_video(self, sample, batch_size=8):
B = sample.shape[0]
hidden_states = []
for i in range(0, sample.shape[2], batch_size):
j = min(i + batch_size, sample.shape[2])
sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W")
hidden_states_batch = self(sample_batch)
hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)
hidden_states.append(hidden_states_batch)
hidden_states = torch.concat(hidden_states, dim=2)
return hidden_states
@staticmethod
def state_dict_converter():
return SDVAEEncoderStateDictConverter()

View File

@@ -97,10 +97,9 @@ class SDControlNet(torch.nn.Module):
self,
sample, timestep, encoder_hidden_states, conditioning,
tiled=False, tile_size=64, tile_stride=32,
**kwargs
):
# 1. time
time_emb = self.time_proj(timestep).to(sample.dtype)
time_emb = self.time_proj(timestep[None]).to(sample.dtype)
time_emb = self.time_embedding(time_emb)
time_emb = time_emb.repeat(sample.shape[0], 1)
@@ -135,8 +134,7 @@ class SDControlNet(torch.nn.Module):
return controlnet_res_stack
@staticmethod
def state_dict_converter():
def state_dict_converter(self):
return SDControlNetStateDictConverter()

View File

@@ -29,7 +29,7 @@ class SDIpAdapter(torch.nn.Module):
def set_less_adapter(self):
# IP-Adapter for SD v1.5 doesn't support this feature.
self.set_full_adapter()
self.set_full_adapter(self)
def forward(self, hidden_states, scale=1.0):
hidden_states = self.image_proj(hidden_states)
@@ -47,8 +47,7 @@ class SDIpAdapter(torch.nn.Module):
}
return ip_kv_dict
@staticmethod
def state_dict_converter():
def state_dict_converter(self):
return SDIpAdapterStateDictConverter()

View File

@@ -0,0 +1,60 @@
import torch
from .sd_unet import SDUNetStateDictConverter, SDUNet
from .sd_text_encoder import SDTextEncoderStateDictConverter, SDTextEncoder
class SDLoRA:
def __init__(self):
pass
def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0, device="cuda"):
special_keys = {
"down.blocks": "down_blocks",
"up.blocks": "up_blocks",
"mid.block": "mid_block",
"proj.in": "proj_in",
"proj.out": "proj_out",
"transformer.blocks": "transformer_blocks",
"to.q": "to_q",
"to.k": "to_k",
"to.v": "to_v",
"to.out": "to_out",
}
state_dict_ = {}
for key in state_dict:
if ".lora_up" not in key:
continue
if not key.startswith(lora_prefix):
continue
weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
if len(weight_up.shape) == 4:
weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
lora_weight = alpha * torch.mm(weight_up, weight_down)
target_name = key.split(".")[0].replace("_", ".")[len(lora_prefix):] + ".weight"
for special_key in special_keys:
target_name = target_name.replace(special_key, special_keys[special_key])
state_dict_[target_name] = lora_weight.cpu()
return state_dict_
def add_lora_to_unet(self, unet: SDUNet, state_dict_lora, alpha=1.0, device="cuda"):
state_dict_unet = unet.state_dict()
state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_unet_", alpha=alpha, device=device)
state_dict_lora = SDUNetStateDictConverter().from_diffusers(state_dict_lora)
if len(state_dict_lora) > 0:
for name in state_dict_lora:
state_dict_unet[name] += state_dict_lora[name].to(device=device)
unet.load_state_dict(state_dict_unet)
def add_lora_to_text_encoder(self, text_encoder: SDTextEncoder, state_dict_lora, alpha=1.0, device="cuda"):
state_dict_text_encoder = text_encoder.state_dict()
state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_te_", alpha=alpha, device=device)
state_dict_lora = SDTextEncoderStateDictConverter().from_diffusers(state_dict_lora)
if len(state_dict_lora) > 0:
for name in state_dict_lora:
state_dict_text_encoder[name] += state_dict_lora[name].to(device=device)
text_encoder.load_state_dict(state_dict_text_encoder)

View File

@@ -1,20 +1,28 @@
from .sd_unet import SDUNet, Attention, GEGLU
from .svd_unet import get_timestep_embedding
import torch
from einops import rearrange, repeat
class TemporalTransformerBlock(torch.nn.Module):
def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32):
def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32, add_positional_conv=None):
super().__init__()
self.add_positional_conv = add_positional_conv
# 1. Self-Attn
self.pe1 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
emb = get_timestep_embedding(torch.arange(max_position_embeddings), dim, True, 0).reshape(1, max_position_embeddings, dim)
self.pe1 = torch.nn.Parameter(emb)
if add_positional_conv:
self.positional_conv_1 = torch.nn.Conv1d(dim, dim, kernel_size=3, padding=1, padding_mode="reflect")
self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True)
self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
# 2. Cross-Attn
self.pe2 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
emb = get_timestep_embedding(torch.arange(max_position_embeddings), dim, True, 0).reshape(1, max_position_embeddings, dim)
self.pe2 = torch.nn.Parameter(emb)
if add_positional_conv:
self.positional_conv_2 = torch.nn.Conv1d(dim, dim, kernel_size=3, padding=1, padding_mode="reflect")
self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True)
self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
@@ -24,19 +32,47 @@ class TemporalTransformerBlock(torch.nn.Module):
self.ff = torch.nn.Linear(dim * 4, dim)
def frame_id_to_position_id(self, frame_id, max_id, repeat_length):
if frame_id < max_id:
position_id = frame_id
else:
position_id = (frame_id - max_id) % (repeat_length * 2)
if position_id < repeat_length:
position_id = max_id - 2 - position_id
else:
position_id = max_id - 2 * repeat_length + position_id
return position_id
def positional_ids(self, num_frames):
max_id = self.pe1.shape[1]
positional_ids = torch.IntTensor([self.frame_id_to_position_id(i, max_id, max_id - 1) for i in range(num_frames)])
return positional_ids
def forward(self, hidden_states, batch_size=1):
# 1. Self-Attention
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
attn_output = self.attn1(norm_hidden_states + self.pe1[:, :norm_hidden_states.shape[1]])
norm_hidden_states = norm_hidden_states + self.pe1[:, self.positional_ids(norm_hidden_states.shape[1])]
if self.add_positional_conv:
norm_hidden_states = rearrange(norm_hidden_states, "(b h) f c -> (b h) c f", b=batch_size)
norm_hidden_states = self.positional_conv_1(norm_hidden_states)
norm_hidden_states = rearrange(norm_hidden_states, "(b h) c f -> (b h) f c", b=batch_size)
attn_output = self.attn1(norm_hidden_states)
attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
hidden_states = attn_output + hidden_states
# 2. Cross-Attention
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
attn_output = self.attn2(norm_hidden_states + self.pe2[:, :norm_hidden_states.shape[1]])
norm_hidden_states = norm_hidden_states + self.pe2[:, self.positional_ids(norm_hidden_states.shape[1])]
if self.add_positional_conv:
norm_hidden_states = rearrange(norm_hidden_states, "(b h) f c -> (b h) c f", b=batch_size)
norm_hidden_states = self.positional_conv_2(norm_hidden_states)
norm_hidden_states = rearrange(norm_hidden_states, "(b h) c f -> (b h) f c", b=batch_size)
attn_output = self.attn2(norm_hidden_states)
attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
hidden_states = attn_output + hidden_states
@@ -51,7 +87,7 @@ class TemporalTransformerBlock(torch.nn.Module):
class TemporalBlock(torch.nn.Module):
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5, add_positional_conv=None):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
@@ -62,7 +98,9 @@ class TemporalBlock(torch.nn.Module):
TemporalTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim
attention_head_dim,
max_position_embeddings=32 if add_positional_conv is None else add_positional_conv,
add_positional_conv=add_positional_conv
)
for d in range(num_layers)
])
@@ -92,30 +130,30 @@ class TemporalBlock(torch.nn.Module):
class SDMotionModel(torch.nn.Module):
def __init__(self):
def __init__(self, add_positional_conv=None):
super().__init__()
self.motion_modules = torch.nn.ModuleList([
TemporalBlock(8, 40, 320, eps=1e-6),
TemporalBlock(8, 40, 320, eps=1e-6),
TemporalBlock(8, 80, 640, eps=1e-6),
TemporalBlock(8, 80, 640, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 80, 640, eps=1e-6),
TemporalBlock(8, 80, 640, eps=1e-6),
TemporalBlock(8, 80, 640, eps=1e-6),
TemporalBlock(8, 40, 320, eps=1e-6),
TemporalBlock(8, 40, 320, eps=1e-6),
TemporalBlock(8, 40, 320, eps=1e-6),
TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
])
self.call_block_id = {
1: 0,
@@ -144,8 +182,7 @@ class SDMotionModel(torch.nn.Module):
def forward(self):
pass
@staticmethod
def state_dict_converter():
def state_dict_converter(self):
return SDMotionModelStateDictConverter()
@@ -153,7 +190,42 @@ class SDMotionModelStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
def frame_id_to_position_id(self, frame_id, max_id, repeat_length):
if frame_id < max_id:
position_id = frame_id
else:
position_id = (frame_id - max_id) % (repeat_length * 2)
if position_id < repeat_length:
position_id = max_id - 2 - position_id
else:
position_id = max_id - 2 * repeat_length + position_id
return position_id
def process_positional_conv_parameters(self, state_dict, add_positional_conv):
ids = [self.frame_id_to_position_id(i, 16, 15) for i in range(add_positional_conv)]
for i in range(21):
# Extend positional embedding
name = f"motion_modules.{i}.transformer_blocks.0.pe1"
state_dict[name] = state_dict[name][:, ids]
name = f"motion_modules.{i}.transformer_blocks.0.pe2"
state_dict[name] = state_dict[name][:, ids]
# add post convolution
dim = state_dict[f"motion_modules.{i}.transformer_blocks.0.pe1"].shape[-1]
name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_1.bias"
state_dict[name] = torch.zeros((dim,))
name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_2.bias"
state_dict[name] = torch.zeros((dim,))
name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_1.weight"
param = torch.zeros((dim, dim, 3))
param[:, :, 1] = torch.eye(dim, dim)
state_dict[name] = param
name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_2.weight"
param = torch.zeros((dim, dim, 3))
param[:, :, 1] = torch.eye(dim, dim)
state_dict[name] = param
return state_dict
def from_diffusers(self, state_dict, add_positional_conv=None):
rename_dict = {
"norm": "norm",
"proj_in": "proj_in",
@@ -193,7 +265,9 @@ class SDMotionModelStateDictConverter:
else:
rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
state_dict_[rename] = state_dict[name]
if add_positional_conv is not None:
state_dict_ = self.process_positional_conv_parameters(state_dict_, add_positional_conv)
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)
def from_civitai(self, state_dict, add_positional_conv=None):
return self.from_diffusers(state_dict, add_positional_conv=add_positional_conv)

View File

@@ -0,0 +1,115 @@
from .attention import Attention
from .svd_unet import get_timestep_embedding
import torch
from einops import rearrange, repeat
class ExVideoMotionBlock(torch.nn.Module):
def __init__(self, num_attention_heads, attention_head_dim, in_channels, max_position_embeddings=16, num_layers=1, add_positional_conv=None):
super().__init__()
emb = get_timestep_embedding(torch.arange(max_position_embeddings), in_channels, True, 0).reshape(max_position_embeddings, in_channels, 1, 1)
self.positional_embedding = torch.nn.Parameter(emb)
self.positional_conv = torch.nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) if add_positional_conv is not None else None
self.norms = torch.nn.ModuleList([torch.nn.LayerNorm(in_channels) for _ in range(num_layers)])
self.attns = torch.nn.ModuleList([Attention(q_dim=in_channels, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True) for _ in range(num_layers)])
def frame_id_to_position_id(self, frame_id, max_id, repeat_length):
if frame_id < max_id:
position_id = frame_id
else:
position_id = (frame_id - max_id) % (repeat_length * 2)
if position_id < repeat_length:
position_id = max_id - 2 - position_id
else:
position_id = max_id - 2 * repeat_length + position_id
return position_id
def positional_ids(self, num_frames):
max_id = self.positional_embedding.shape[0]
positional_ids = torch.IntTensor([self.frame_id_to_position_id(i, max_id, max_id - 1) for i in range(num_frames)])
return positional_ids
def forward(self, hidden_states, time_emb, text_emb, res_stack, batch_size=1, **kwargs):
batch, inner_dim, height, width = hidden_states.shape
residual = hidden_states
pos_emb = self.positional_ids(batch // batch_size)
pos_emb = self.positional_embedding[pos_emb]
pos_emb = pos_emb.repeat(batch_size)
hidden_states = hidden_states + pos_emb
if self.positional_conv is not None:
hidden_states = rearrange(hidden_states, "(B T) C H W -> B C T H W", B=batch_size)
hidden_states = self.positional_conv(hidden_states)
hidden_states = rearrange(hidden_states, "B C T H W -> (B H W) T C")
else:
hidden_states = rearrange(hidden_states, "(B T) C H W -> (B H W) T C", B=batch_size)
for norm, attn in zip(self.norms, self.attns):
norm_hidden_states = norm(hidden_states)
attn_output = attn(norm_hidden_states)
hidden_states = hidden_states + attn_output
hidden_states = rearrange(hidden_states, "(B H W) T C -> (B T) C H W", B=batch_size, H=height, W=width)
hidden_states = hidden_states + residual
return hidden_states, time_emb, text_emb, res_stack
class ExVideoMotionModel(torch.nn.Module):
def __init__(self, num_layers=2):
super().__init__()
self.motion_modules = torch.nn.ModuleList([
ExVideoMotionBlock(8, 40, 320, num_layers=num_layers),
ExVideoMotionBlock(8, 40, 320, num_layers=num_layers),
ExVideoMotionBlock(8, 80, 640, num_layers=num_layers),
ExVideoMotionBlock(8, 80, 640, num_layers=num_layers),
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
ExVideoMotionBlock(8, 80, 640, num_layers=num_layers),
ExVideoMotionBlock(8, 80, 640, num_layers=num_layers),
ExVideoMotionBlock(8, 80, 640, num_layers=num_layers),
ExVideoMotionBlock(8, 40, 320, num_layers=num_layers),
ExVideoMotionBlock(8, 40, 320, num_layers=num_layers),
ExVideoMotionBlock(8, 40, 320, num_layers=num_layers),
])
self.call_block_id = {
1: 0,
4: 1,
9: 2,
12: 3,
17: 4,
20: 5,
24: 6,
26: 7,
29: 8,
32: 9,
34: 10,
36: 11,
40: 12,
43: 13,
46: 14,
50: 15,
53: 16,
56: 17,
60: 18,
63: 19,
66: 20
}
def forward(self):
pass
def state_dict_converter(self):
pass

View File

@@ -71,8 +71,7 @@ class SDTextEncoder(torch.nn.Module):
embeds = self.final_layer_norm(embeds)
return embeds
@staticmethod
def state_dict_converter():
def state_dict_converter(self):
return SDTextEncoderStateDictConverter()

View File

@@ -323,7 +323,7 @@ class SDUNet(torch.nn.Module):
def forward(self, sample, timestep, encoder_hidden_states, **kwargs):
# 1. time
time_emb = self.time_proj(timestep).to(sample.dtype)
time_emb = self.time_proj(timestep[None]).to(sample.dtype)
time_emb = self.time_embedding(time_emb)
# 2. pre-process
@@ -342,8 +342,7 @@ class SDUNet(torch.nn.Module):
return hidden_states
@staticmethod
def state_dict_converter():
def state_dict_converter(self):
return SDUNetStateDictConverter()

View File

@@ -90,8 +90,6 @@ class SDVAEDecoder(torch.nn.Module):
return hidden_states
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
original_dtype = sample.dtype
sample = sample.to(dtype=next(iter(self.parameters())).dtype)
# For VAE Decoder, we do not need to apply the tiler on each layer.
if tiled:
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
@@ -112,12 +110,10 @@ class SDVAEDecoder(torch.nn.Module):
hidden_states = self.conv_norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
hidden_states = hidden_states.to(original_dtype)
return hidden_states
@staticmethod
def state_dict_converter():
def state_dict_converter(self):
return SDVAEDecoderStateDictConverter()

View File

@@ -50,8 +50,6 @@ class SDVAEEncoder(torch.nn.Module):
return hidden_states
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
original_dtype = sample.dtype
sample = sample.to(dtype=next(iter(self.parameters())).dtype)
# For VAE Decoder, we do not need to apply the tiler on each layer.
if tiled:
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
@@ -73,7 +71,6 @@ class SDVAEEncoder(torch.nn.Module):
hidden_states = self.quant_conv(hidden_states)
hidden_states = hidden_states[:, :4]
hidden_states *= self.scaling_factor
hidden_states = hidden_states.to(original_dtype)
return hidden_states
@@ -94,8 +91,7 @@ class SDVAEEncoder(torch.nn.Module):
hidden_states = torch.concat(hidden_states, dim=2)
return hidden_states
@staticmethod
def state_dict_converter():
def state_dict_converter(self):
return SDVAEEncoderStateDictConverter()

View File

@@ -1,318 +0,0 @@
import torch
from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler
from .sdxl_unet import SDXLUNet
from .tiler import TileWorker
from .sd_controlnet import ControlNetConditioningLayer
from collections import OrderedDict
class QuickGELU(torch.nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class ResidualAttentionBlock(torch.nn.Module):
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
super().__init__()
self.attn = torch.nn.MultiheadAttention(d_model, n_head)
self.ln_1 = torch.nn.LayerNorm(d_model)
self.mlp = torch.nn.Sequential(OrderedDict([
("c_fc", torch.nn.Linear(d_model, d_model * 4)),
("gelu", QuickGELU()),
("c_proj", torch.nn.Linear(d_model * 4, d_model))
]))
self.ln_2 = torch.nn.LayerNorm(d_model)
self.attn_mask = attn_mask
def attention(self, x: torch.Tensor):
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class SDXLControlNetUnion(torch.nn.Module):
def __init__(self, global_pool=False):
super().__init__()
self.time_proj = Timesteps(320)
self.time_embedding = torch.nn.Sequential(
torch.nn.Linear(320, 1280),
torch.nn.SiLU(),
torch.nn.Linear(1280, 1280)
)
self.add_time_proj = Timesteps(256)
self.add_time_embedding = torch.nn.Sequential(
torch.nn.Linear(2816, 1280),
torch.nn.SiLU(),
torch.nn.Linear(1280, 1280)
)
self.control_type_proj = Timesteps(256)
self.control_type_embedding = torch.nn.Sequential(
torch.nn.Linear(256 * 8, 1280),
torch.nn.SiLU(),
torch.nn.Linear(1280, 1280)
)
self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320))
self.controlnet_transformer = ResidualAttentionBlock(320, 8)
self.task_embedding = torch.nn.Parameter(torch.randn(8, 320))
self.spatial_ch_projs = torch.nn.Linear(320, 320)
self.blocks = torch.nn.ModuleList([
# DownBlock2D
ResnetBlock(320, 320, 1280),
PushBlock(),
ResnetBlock(320, 320, 1280),
PushBlock(),
DownSampler(320),
PushBlock(),
# CrossAttnDownBlock2D
ResnetBlock(320, 640, 1280),
AttentionBlock(10, 64, 640, 2, 2048),
PushBlock(),
ResnetBlock(640, 640, 1280),
AttentionBlock(10, 64, 640, 2, 2048),
PushBlock(),
DownSampler(640),
PushBlock(),
# CrossAttnDownBlock2D
ResnetBlock(640, 1280, 1280),
AttentionBlock(20, 64, 1280, 10, 2048),
PushBlock(),
ResnetBlock(1280, 1280, 1280),
AttentionBlock(20, 64, 1280, 10, 2048),
PushBlock(),
# UNetMidBlock2DCrossAttn
ResnetBlock(1280, 1280, 1280),
AttentionBlock(20, 64, 1280, 10, 2048),
ResnetBlock(1280, 1280, 1280),
PushBlock()
])
self.controlnet_blocks = torch.nn.ModuleList([
torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
])
self.global_pool = global_pool
# 0 -- openpose
# 1 -- depth
# 2 -- hed/pidi/scribble/ted
# 3 -- canny/lineart/anime_lineart/mlsd
# 4 -- normal
# 5 -- segment
# 6 -- tile
# 7 -- repaint
self.task_id = {
"openpose": 0,
"depth": 1,
"softedge": 2,
"canny": 3,
"lineart": 3,
"lineart_anime": 3,
"tile": 6,
"inpaint": 7
}
def fuse_condition_to_input(self, hidden_states, task_id, conditioning):
controlnet_cond = self.controlnet_conv_in(conditioning)
feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
feat_seq = feat_seq + self.task_embedding[task_id]
x = torch.stack([feat_seq, torch.mean(hidden_states, dim=(2, 3))], dim=1)
x = self.controlnet_transformer(x)
alpha = self.spatial_ch_projs(x[:,0]).unsqueeze(-1).unsqueeze(-1)
controlnet_cond_fuser = controlnet_cond + alpha
hidden_states = hidden_states + controlnet_cond_fuser
return hidden_states
def forward(
self,
sample, timestep, encoder_hidden_states,
conditioning, processor_id, add_time_id, add_text_embeds,
tiled=False, tile_size=64, tile_stride=32,
unet:SDXLUNet=None,
**kwargs
):
task_id = self.task_id[processor_id]
# 1. time
t_emb = self.time_proj(timestep).to(sample.dtype)
t_emb = self.time_embedding(t_emb)
time_embeds = self.add_time_proj(add_time_id)
time_embeds = time_embeds.reshape((add_text_embeds.shape[0], -1))
add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(sample.dtype)
if unet is not None and unet.is_kolors:
add_embeds = unet.add_time_embedding(add_embeds)
else:
add_embeds = self.add_time_embedding(add_embeds)
control_type = torch.zeros((sample.shape[0], 8), dtype=sample.dtype, device=sample.device)
control_type[:, task_id] = 1
control_embeds = self.control_type_proj(control_type.flatten())
control_embeds = control_embeds.reshape((sample.shape[0], -1))
control_embeds = control_embeds.to(sample.dtype)
control_embeds = self.control_type_embedding(control_embeds)
time_emb = t_emb + add_embeds + control_embeds
# 2. pre-process
height, width = sample.shape[2], sample.shape[3]
hidden_states = self.conv_in(sample)
hidden_states = self.fuse_condition_to_input(hidden_states, task_id, conditioning)
text_emb = encoder_hidden_states
if unet is not None and unet.is_kolors:
text_emb = unet.text_intermediate_proj(text_emb)
res_stack = [hidden_states]
# 3. blocks
for i, block in enumerate(self.blocks):
if tiled and not isinstance(block, PushBlock):
_, _, inter_height, _ = hidden_states.shape
resize_scale = inter_height / height
hidden_states = TileWorker().tiled_forward(
lambda x: block(x, time_emb, text_emb, res_stack)[0],
hidden_states,
int(tile_size * resize_scale),
int(tile_stride * resize_scale),
tile_device=hidden_states.device,
tile_dtype=hidden_states.dtype
)
else:
hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack)
# 4. ControlNet blocks
controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]
# pool
if self.global_pool:
controlnet_res_stack = [res.mean(dim=(2, 3), keepdim=True) for res in controlnet_res_stack]
return controlnet_res_stack
@staticmethod
def state_dict_converter():
return SDXLControlNetUnionStateDictConverter()
class SDXLControlNetUnionStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
# architecture
block_types = [
"ResnetBlock", "PushBlock", "ResnetBlock", "PushBlock", "DownSampler", "PushBlock",
"ResnetBlock", "AttentionBlock", "PushBlock", "ResnetBlock", "AttentionBlock", "PushBlock", "DownSampler", "PushBlock",
"ResnetBlock", "AttentionBlock", "PushBlock", "ResnetBlock", "AttentionBlock", "PushBlock",
"ResnetBlock", "AttentionBlock", "ResnetBlock", "PushBlock"
]
# controlnet_rename_dict
controlnet_rename_dict = {
"controlnet_cond_embedding.conv_in.weight": "controlnet_conv_in.blocks.0.weight",
"controlnet_cond_embedding.conv_in.bias": "controlnet_conv_in.blocks.0.bias",
"controlnet_cond_embedding.blocks.0.weight": "controlnet_conv_in.blocks.2.weight",
"controlnet_cond_embedding.blocks.0.bias": "controlnet_conv_in.blocks.2.bias",
"controlnet_cond_embedding.blocks.1.weight": "controlnet_conv_in.blocks.4.weight",
"controlnet_cond_embedding.blocks.1.bias": "controlnet_conv_in.blocks.4.bias",
"controlnet_cond_embedding.blocks.2.weight": "controlnet_conv_in.blocks.6.weight",
"controlnet_cond_embedding.blocks.2.bias": "controlnet_conv_in.blocks.6.bias",
"controlnet_cond_embedding.blocks.3.weight": "controlnet_conv_in.blocks.8.weight",
"controlnet_cond_embedding.blocks.3.bias": "controlnet_conv_in.blocks.8.bias",
"controlnet_cond_embedding.blocks.4.weight": "controlnet_conv_in.blocks.10.weight",
"controlnet_cond_embedding.blocks.4.bias": "controlnet_conv_in.blocks.10.bias",
"controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight",
"controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias",
"controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight",
"controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
"control_add_embedding.linear_1.weight": "control_type_embedding.0.weight",
"control_add_embedding.linear_1.bias": "control_type_embedding.0.bias",
"control_add_embedding.linear_2.weight": "control_type_embedding.2.weight",
"control_add_embedding.linear_2.bias": "control_type_embedding.2.bias",
}
# Rename each parameter
name_list = sorted([name for name in state_dict])
rename_dict = {}
block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
for name in name_list:
names = name.split(".")
if names[0] in ["conv_in", "conv_norm_out", "conv_out", "task_embedding", "spatial_ch_projs"]:
pass
elif name in controlnet_rename_dict:
names = controlnet_rename_dict[name].split(".")
elif names[0] == "controlnet_down_blocks":
names[0] = "controlnet_blocks"
elif names[0] == "controlnet_mid_block":
names = ["controlnet_blocks", "9", names[-1]]
elif names[0] in ["time_embedding", "add_embedding"]:
if names[0] == "add_embedding":
names[0] = "add_time_embedding"
names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
elif names[0] == "control_add_embedding":
names[0] = "control_type_embedding"
elif names[0] == "transformer_layes":
names[0] = "controlnet_transformer"
names.pop(1)
elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
if names[0] == "mid_block":
names.insert(1, "0")
block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
block_type_with_id = ".".join(names[:4])
if block_type_with_id != last_block_type_with_id[block_type]:
block_id[block_type] += 1
last_block_type_with_id[block_type] = block_type_with_id
while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
block_id[block_type] += 1
block_type_with_id = ".".join(names[:4])
names = ["blocks", str(block_id[block_type])] + names[4:]
if "ff" in names:
ff_index = names.index("ff")
component = ".".join(names[ff_index:ff_index+3])
component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
names = names[:ff_index] + [component] + names[ff_index+3:]
if "to_out" in names:
names.pop(names.index("to_out") + 1)
else:
print(name, state_dict[name].shape)
# raise ValueError(f"Unknown parameters: {name}")
rename_dict[name] = ".".join(names)
# Convert state_dict
state_dict_ = {}
for name, param in state_dict.items():
if name not in rename_dict:
continue
if ".proj_in." in name or ".proj_out." in name:
param = param.squeeze()
state_dict_[rename_dict[name]] = param
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -96,8 +96,7 @@ class SDXLIpAdapter(torch.nn.Module):
}
return ip_kv_dict
@staticmethod
def state_dict_converter():
def state_dict_converter(self):
return SDXLIpAdapterStateDictConverter()

View File

@@ -49,8 +49,7 @@ class SDXLMotionModel(torch.nn.Module):
def forward(self):
pass
@staticmethod
def state_dict_converter():
def state_dict_converter(self):
return SDMotionModelStateDictConverter()

View File

@@ -36,8 +36,7 @@ class SDXLTextEncoder(torch.nn.Module):
break
return embeds
@staticmethod
def state_dict_converter():
def state_dict_converter(self):
return SDXLTextEncoderStateDictConverter()
@@ -81,8 +80,7 @@ class SDXLTextEncoder2(torch.nn.Module):
pooled_embeds = self.text_projection(pooled_embeds)
return pooled_embeds, hidden_states
@staticmethod
def state_dict_converter():
def state_dict_converter(self):
return SDXLTextEncoder2StateDictConverter()

View File

@@ -3,7 +3,7 @@ from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, PopBlock
class SDXLUNet(torch.nn.Module):
def __init__(self, is_kolors=False):
def __init__(self):
super().__init__()
self.time_proj = Timesteps(320)
self.time_embedding = torch.nn.Sequential(
@@ -13,12 +13,11 @@ class SDXLUNet(torch.nn.Module):
)
self.add_time_proj = Timesteps(256)
self.add_time_embedding = torch.nn.Sequential(
torch.nn.Linear(5632 if is_kolors else 2816, 1280),
torch.nn.Linear(2816, 1280),
torch.nn.SiLU(),
torch.nn.Linear(1280, 1280)
)
self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
self.text_intermediate_proj = torch.nn.Linear(4096, 2048) if is_kolors else None
self.blocks = torch.nn.ModuleList([
# DownBlock2D
@@ -83,17 +82,13 @@ class SDXLUNet(torch.nn.Module):
self.conv_act = torch.nn.SiLU()
self.conv_out = torch.nn.Conv2d(320, 4, kernel_size=3, padding=1)
self.is_kolors = is_kolors
def forward(
self,
sample, timestep, encoder_hidden_states, add_time_id, add_text_embeds,
tiled=False, tile_size=64, tile_stride=8,
use_gradient_checkpointing=False,
**kwargs
tiled=False, tile_size=64, tile_stride=8, **kwargs
):
# 1. time
t_emb = self.time_proj(timestep).to(sample.dtype)
t_emb = self.time_proj(timestep[None]).to(sample.dtype)
t_emb = self.time_embedding(t_emb)
time_embeds = self.add_time_proj(add_time_id)
@@ -107,22 +102,11 @@ class SDXLUNet(torch.nn.Module):
# 2. pre-process
height, width = sample.shape[2], sample.shape[3]
hidden_states = self.conv_in(sample)
text_emb = encoder_hidden_states if self.text_intermediate_proj is None else self.text_intermediate_proj(encoder_hidden_states)
text_emb = encoder_hidden_states
res_stack = [hidden_states]
# 3. blocks
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for i, block in enumerate(self.blocks):
if self.training and use_gradient_checkpointing and not (isinstance(block, PushBlock) or isinstance(block, PopBlock)):
hidden_states, time_emb, text_emb, res_stack = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, time_emb, text_emb, res_stack,
use_reentrant=False,
)
else:
hidden_states, time_emb, text_emb, res_stack = block(
hidden_states, time_emb, text_emb, res_stack,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
@@ -135,8 +119,7 @@ class SDXLUNet(torch.nn.Module):
return hidden_states
@staticmethod
def state_dict_converter():
def state_dict_converter(self):
return SDXLUNetStateDictConverter()
@@ -165,8 +148,6 @@ class SDXLUNetStateDictConverter:
names = name.split(".")
if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
pass
elif names[0] in ["encoder_hid_proj"]:
names[0] = "text_intermediate_proj"
elif names[0] in ["time_embedding", "add_embedding"]:
if names[0] == "add_embedding":
names[0] = "add_time_embedding"
@@ -200,9 +181,6 @@ class SDXLUNetStateDictConverter:
if ".proj_in." in name or ".proj_out." in name:
param = param.squeeze()
state_dict_[rename_dict[name]] = param
if "text_intermediate_proj.weight" in state_dict_:
return state_dict_, {"is_kolors": True}
else:
return state_dict_
def from_civitai(self, state_dict):
@@ -1895,7 +1873,4 @@ class SDXLUNetStateDictConverter:
if ".proj_in." in name or ".proj_out." in name:
param = param.squeeze()
state_dict_[rename_dict[name]] = param
if "text_intermediate_proj.weight" in state_dict_:
return state_dict_, {"is_kolors": True}
else:
return state_dict_

View File

@@ -2,23 +2,14 @@ from .sd_vae_decoder import SDVAEDecoder, SDVAEDecoderStateDictConverter
class SDXLVAEDecoder(SDVAEDecoder):
def __init__(self, upcast_to_float32=True):
def __init__(self):
super().__init__()
self.scaling_factor = 0.13025
@staticmethod
def state_dict_converter():
def state_dict_converter(self):
return SDXLVAEDecoderStateDictConverter()
class SDXLVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
def __init__(self):
super().__init__()
def from_diffusers(self, state_dict):
state_dict = super().from_diffusers(state_dict)
return state_dict, {"upcast_to_float32": True}
def from_civitai(self, state_dict):
state_dict = super().from_civitai(state_dict)
return state_dict, {"upcast_to_float32": True}

View File

@@ -2,23 +2,14 @@ from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder
class SDXLVAEEncoder(SDVAEEncoder):
def __init__(self, upcast_to_float32=True):
def __init__(self):
super().__init__()
self.scaling_factor = 0.13025
@staticmethod
def state_dict_converter():
def state_dict_converter(self):
return SDXLVAEEncoderStateDictConverter()
class SDXLVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
def __init__(self):
super().__init__()
def from_diffusers(self, state_dict):
state_dict = super().from_diffusers(state_dict)
return state_dict, {"upcast_to_float32": True}
def from_civitai(self, state_dict):
state_dict = super().from_civitai(state_dict)
return state_dict, {"upcast_to_float32": True}

View File

@@ -44,8 +44,7 @@ class SVDImageEncoder(torch.nn.Module):
embeds = self.visual_projection(embeds)
return embeds
@staticmethod
def state_dict_converter():
def state_dict_converter(self):
return SVDImageEncoderStateDictConverter()

View File

@@ -407,8 +407,7 @@ class SVDUNet(torch.nn.Module):
return hidden_states
@staticmethod
def state_dict_converter():
def state_dict_converter(self):
return SVDUNetStateDictConverter()

View File

@@ -199,8 +199,7 @@ class SVDVAEDecoder(torch.nn.Module):
return values
@staticmethod
def state_dict_converter():
def state_dict_converter(self):
return SVDVAEDecoderStateDictConverter()

View File

@@ -6,8 +6,7 @@ class SVDVAEEncoder(SDVAEEncoder):
super().__init__()
self.scaling_factor = 0.13025
@staticmethod
def state_dict_converter():
def state_dict_converter(self):
return SVDVAEEncoderStateDictConverter()

View File

@@ -104,77 +104,3 @@ class TileWorker:
# Done!
model_output = model_output.to(device=inference_device, dtype=inference_dtype)
return model_output
class TileWorker2Dto3D:
"""
Process 3D tensors, but only enable TileWorker on 2D.
"""
def __init__(self):
pass
def build_mask(self, T, H, W, dtype, device, is_bound, border_width):
t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W)
h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W)
w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W)
border_width = (H + W) // 4 if border_width is None else border_width
pad = torch.ones_like(h) * border_width
mask = torch.stack([
pad if is_bound[0] else t + 1,
pad if is_bound[1] else T - t,
pad if is_bound[2] else h + 1,
pad if is_bound[3] else H - h,
pad if is_bound[4] else w + 1,
pad if is_bound[5] else W - w
]).min(dim=0).values
mask = mask.clip(1, border_width)
mask = (mask / border_width).to(dtype=dtype, device=device)
mask = rearrange(mask, "T H W -> 1 1 T H W")
return mask
def tiled_forward(
self,
forward_fn,
model_input,
tile_size, tile_stride,
tile_device="cpu", tile_dtype=torch.float32,
computation_device="cuda", computation_dtype=torch.float32,
border_width=None, scales=[1, 1, 1, 1],
progress_bar=lambda x:x
):
B, C, T, H, W = model_input.shape
scale_C, scale_T, scale_H, scale_W = scales
tile_size_H, tile_size_W = tile_size
tile_stride_H, tile_stride_W = tile_stride
value = torch.zeros((B, int(C*scale_C), int(T*scale_T), int(H*scale_H), int(W*scale_W)), dtype=tile_dtype, device=tile_device)
weight = torch.zeros((1, 1, int(T*scale_T), int(H*scale_H), int(W*scale_W)), dtype=tile_dtype, device=tile_device)
# Split tasks
tasks = []
for h in range(0, H, tile_stride_H):
for w in range(0, W, tile_stride_W):
if (h-tile_stride_H >= 0 and h-tile_stride_H+tile_size_H >= H) or (w-tile_stride_W >= 0 and w-tile_stride_W+tile_size_W >= W):
continue
h_, w_ = h + tile_size_H, w + tile_size_W
if h_ > H: h, h_ = max(H - tile_size_H, 0), H
if w_ > W: w, w_ = max(W - tile_size_W, 0), W
tasks.append((h, h_, w, w_))
# Run
for hl, hr, wl, wr in progress_bar(tasks):
mask = self.build_mask(
int(T*scale_T), int((hr-hl)*scale_H), int((wr-wl)*scale_W),
tile_dtype, tile_device,
is_bound=(True, True, hl==0, hr>=H, wl==0, wr>=W),
border_width=border_width
)
grid_input = model_input[:, :, :, hl:hr, wl:wr].to(dtype=computation_dtype, device=computation_device)
grid_output = forward_fn(grid_input).to(dtype=tile_dtype, device=tile_device)
value[:, :, :, int(hl*scale_H):int(hr*scale_H), int(wl*scale_W):int(wr*scale_W)] += grid_output * mask
weight[:, :, :, int(hl*scale_H):int(hr*scale_H), int(wl*scale_W):int(wr*scale_W)] += mask
value = value / weight
return value

View File

@@ -1,96 +0,0 @@
import torch, os
from safetensors import safe_open
def load_state_dict_from_folder(file_path, torch_dtype=None):
state_dict = {}
for file_name in os.listdir(file_path):
if "." in file_name and file_name.split(".")[-1] in [
"safetensors", "bin", "ckpt", "pth", "pt"
]:
state_dict.update(load_state_dict(os.path.join(file_path, file_name), torch_dtype=torch_dtype))
return state_dict
def load_state_dict(file_path, torch_dtype=None):
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
state_dict = {}
with safe_open(file_path, framework="pt", device="cpu") as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
if torch_dtype is not None:
state_dict[k] = state_dict[k].to(torch_dtype)
return state_dict
def load_state_dict_from_bin(file_path, torch_dtype=None):
state_dict = torch.load(file_path, map_location="cpu")
if torch_dtype is not None:
for i in state_dict:
if isinstance(state_dict[i], torch.Tensor):
state_dict[i] = state_dict[i].to(torch_dtype)
return state_dict
def search_for_embeddings(state_dict):
embeddings = []
for k in state_dict:
if isinstance(state_dict[k], torch.Tensor):
embeddings.append(state_dict[k])
elif isinstance(state_dict[k], dict):
embeddings += search_for_embeddings(state_dict[k])
return embeddings
def search_parameter(param, state_dict):
for name, param_ in state_dict.items():
if param.numel() == param_.numel():
if param.shape == param_.shape:
if torch.dist(param, param_) < 1e-3:
return name
else:
if torch.dist(param.flatten(), param_.flatten()) < 1e-3:
return name
return None
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
matched_keys = set()
with torch.no_grad():
for name in source_state_dict:
rename = search_parameter(source_state_dict[name], target_state_dict)
if rename is not None:
print(f'"{name}": "{rename}",')
matched_keys.add(rename)
elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
length = source_state_dict[name].shape[0] // 3
rename = []
for i in range(3):
rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
if None not in rename:
print(f'"{name}": {rename},')
for rename_ in rename:
matched_keys.add(rename_)
for name in target_state_dict:
if name not in matched_keys:
print("Cannot find", name, target_state_dict[name].shape)
def search_for_files(folder, extensions):
files = []
if os.path.isdir(folder):
for file in sorted(os.listdir(folder)):
files += search_for_files(os.path.join(folder, file), extensions)
elif os.path.isfile(folder):
for extension in extensions:
if folder.endswith(extension):
files.append(folder)
break
return files

View File

@@ -1,11 +1,6 @@
from .sd_image import SDImagePipeline
from .sd_video import SDVideoPipeline
from .sdxl_image import SDXLImagePipeline
from .sdxl_video import SDXLVideoPipeline
from .sd3_image import SD3ImagePipeline
from .hunyuan_image import HunyuanDiTImagePipeline
from .svd_video import SVDVideoPipeline
from .flux_image import FluxImagePipeline
from .cog_video import CogVideoPipeline
from .pipeline_runner import SDVideoPipelineRunner
KolorsImagePipeline = SDXLImagePipeline
from .stable_diffusion import SDImagePipeline
from .stable_diffusion_xl import SDXLImagePipeline
from .stable_diffusion_video import SDVideoPipeline, SDVideoPipelineRunner
from .stable_diffusion_xl_video import SDXLVideoPipeline
from .stable_video_diffusion import SVDVideoPipeline
from .hunyuan_dit import HunyuanDiTImagePipeline

View File

@@ -1,87 +0,0 @@
import torch
import numpy as np
from PIL import Image
class BasePipeline(torch.nn.Module):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__()
self.device = device
self.torch_dtype = torch_dtype
self.cpu_offload = False
self.model_names = []
def preprocess_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
return image
def preprocess_images(self, images):
return [self.preprocess_image(image) for image in images]
def vae_output_to_image(self, vae_output):
image = vae_output[0].cpu().float().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
return image
def vae_output_to_video(self, vae_output):
video = vae_output.cpu().permute(1, 2, 0).numpy()
video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video]
return video
def merge_latents(self, value, latents, masks, scales):
height, width = value.shape[-2:]
weight = torch.ones_like(value)
for latent, mask, scale in zip(latents, masks, scales):
mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
mask = mask.repeat(1, latent.shape[1], 1, 1)
value[mask] += latent[mask] * scale
weight[mask] += scale
value /= weight
return value
def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback):
noise_pred_global = inference_callback(prompt_emb_global)
noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals]
noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales)
return noise_pred
def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
local_prompts = local_prompts or []
masks = masks or []
mask_scales = mask_scales or []
extended_prompt_dict = self.prompter.extend_prompt(prompt)
prompt = extended_prompt_dict.get("prompt", prompt)
local_prompts += extended_prompt_dict.get("prompts", [])
masks += extended_prompt_dict.get("masks", [])
mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
return prompt, local_prompts, masks, mask_scales
def enable_cpu_offload(self):
self.cpu_offload = True
def load_models_to_device(self, loadmodel_names=[]):
# only load models to device if cpu_offload is enabled
if not self.cpu_offload:
return
# offload the unneeded models to cpu
for model_name in self.model_names:
if model_name not in loadmodel_names:
model = getattr(self, model_name)
if model is not None:
model.cpu()
# load the needed models to device
for model_name in loadmodel_names:
model = getattr(self, model_name)
if model is not None:
model.to(self.device)
# fresh the cuda cache
torch.cuda.empty_cache()

View File

@@ -1,131 +0,0 @@
from ..models import ModelManager, FluxTextEncoder2, CogDiT, CogVAEEncoder, CogVAEDecoder
from ..prompters import CogPrompter
from ..schedulers import EnhancedDDIMScheduler
from .base import BasePipeline
import torch
from tqdm import tqdm
from PIL import Image
import numpy as np
from einops import rearrange
class CogVideoPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = EnhancedDDIMScheduler(rescale_zero_terminal_snr=True, prediction_type="v_prediction")
self.prompter = CogPrompter()
# models
self.text_encoder: FluxTextEncoder2 = None
self.dit: CogDiT = None
self.vae_encoder: CogVAEEncoder = None
self.vae_decoder: CogVAEDecoder = None
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
self.text_encoder = model_manager.fetch_model("flux_text_encoder_2")
self.dit = model_manager.fetch_model("cog_dit")
self.vae_encoder = model_manager.fetch_model("cog_vae_encoder")
self.vae_decoder = model_manager.fetch_model("cog_vae_decoder")
self.prompter.fetch_models(self.text_encoder)
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
@staticmethod
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]):
pipe = CogVideoPipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype
)
pipe.fetch_models(model_manager, prompt_refiner_classes)
return pipe
def tensor2video(self, frames):
frames = rearrange(frames, "C T H W -> T H W C")
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
frames = [Image.fromarray(frame) for frame in frames]
return frames
def encode_prompt(self, prompt, positive=True):
prompt_emb = self.prompter.encode_prompt(prompt, device=self.device, positive=positive)
return {"prompt_emb": prompt_emb}
def prepare_extra_input(self, latents):
return {"image_rotary_emb": self.dit.prepare_rotary_positional_embeddings(latents.shape[3], latents.shape[4], latents.shape[2], device=self.device)}
@torch.no_grad()
def __call__(
self,
prompt,
negative_prompt="",
input_video=None,
cfg_scale=7.0,
denoising_strength=1.0,
num_frames=49,
height=480,
width=720,
num_inference_steps=20,
tiled=False,
tile_size=(60, 90),
tile_stride=(30, 45),
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)
# Prepare latent tensors
noise = torch.randn((1, 16, num_frames // 4 + 1, height//8, width//8), device="cpu", dtype=self.torch_dtype)
if denoising_strength == 1.0:
latents = noise.clone()
else:
input_video = self.preprocess_images(input_video)
input_video = torch.stack(input_video, dim=2)
latents = self.vae_encoder.encode_video(input_video, **tiler_kwargs, progress_bar=progress_bar_cmd).to(dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, self.scheduler.timesteps[0])
if not tiled: latents = latents.to(self.device)
# Encode prompt
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
if cfg_scale != 1.0:
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
# Extra input
extra_input = self.prepare_extra_input(latents)
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
# Classifier-free guidance
noise_pred_posi = self.dit(
latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs, **extra_input
)
if cfg_scale != 1.0:
noise_pred_nega = self.dit(
latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs, **extra_input
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# DDIM
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# Update progress bar
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
video = self.vae_decoder.decode_video(latents.to("cpu"), **tiler_kwargs, progress_bar=progress_bar_cmd)
video = self.tensor2video(video[0])
return video

View File

@@ -22,10 +22,6 @@ def lets_dance(
device = "cuda",
vram_limit_level = 0,
):
# 0. Text embedding alignment (only for video processing)
if encoder_hidden_states.shape[0] != sample.shape[0]:
encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0], 1, 1, 1)
# 1. ControlNet
# This part will be repeated on overlapping frames if animatediff_batch_size > animatediff_stride.
# I leave it here because I intend to do something interesting on the ControlNets.
@@ -54,7 +50,7 @@ def lets_dance(
additional_res_stack = None
# 2. time
time_emb = unet.time_proj(timestep).to(sample.dtype)
time_emb = unet.time_proj(timestep[None]).to(sample.dtype)
time_emb = unet.time_embedding(time_emb)
# 3. pre-process
@@ -136,40 +132,8 @@ def lets_dance_xl(
device = "cuda",
vram_limit_level = 0,
):
# 0. Text embedding alignment (only for video processing)
if encoder_hidden_states.shape[0] != sample.shape[0]:
encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0], 1, 1, 1)
# 1. ControlNet
controlnet_insert_block_id = 22
if controlnet is not None and controlnet_frames is not None:
res_stacks = []
# process controlnet frames with batch
for batch_id in range(0, sample.shape[0], controlnet_batch_size):
batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0])
res_stack = controlnet(
sample[batch_id: batch_id_],
timestep,
encoder_hidden_states[batch_id: batch_id_],
controlnet_frames[:, batch_id: batch_id_],
add_time_id=add_time_id,
add_text_embeds=add_text_embeds,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
unet=unet, # for Kolors, some modules in ControlNets will be replaced.
)
if vram_limit_level >= 1:
res_stack = [res.cpu() for res in res_stack]
res_stacks.append(res_stack)
# concat the residual
additional_res_stack = []
for i in range(len(res_stacks[0])):
res = torch.concat([res_stack[i] for res_stack in res_stacks], dim=0)
additional_res_stack.append(res)
else:
additional_res_stack = None
# 2. time
t_emb = unet.time_proj(timestep).to(sample.dtype)
t_emb = unet.time_proj(timestep[None]).to(sample.dtype)
t_emb = unet.time_embedding(t_emb)
time_embeds = unet.add_time_proj(add_time_id)
@@ -183,36 +147,16 @@ def lets_dance_xl(
# 3. pre-process
height, width = sample.shape[2], sample.shape[3]
hidden_states = unet.conv_in(sample)
text_emb = encoder_hidden_states if unet.text_intermediate_proj is None else unet.text_intermediate_proj(encoder_hidden_states)
text_emb = encoder_hidden_states
res_stack = [hidden_states]
# 4. blocks
for block_id, block in enumerate(unet.blocks):
# 4.1 UNet
if isinstance(block, PushBlock):
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
if vram_limit_level>=1:
res_stack[-1] = res_stack[-1].cpu()
elif isinstance(block, PopBlock):
if vram_limit_level>=1:
res_stack[-1] = res_stack[-1].to(device)
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
else:
hidden_states_input = hidden_states
hidden_states_output = []
for batch_id in range(0, sample.shape[0], unet_batch_size):
batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
hidden_states, _, _, _ = block(
hidden_states_input[batch_id: batch_id_],
time_emb,
text_emb[batch_id: batch_id_],
res_stack,
cross_frame_attention=cross_frame_attention,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {}),
hidden_states, time_emb, text_emb, res_stack = block(
hidden_states, time_emb, text_emb, res_stack,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {})
)
hidden_states_output.append(hidden_states)
hidden_states = torch.concat(hidden_states_output, dim=0)
# 4.2 AnimateDiff
if motion_modules is not None:
if block_id in motion_modules.call_block_id:
@@ -221,10 +165,6 @@ def lets_dance_xl(
hidden_states, time_emb, text_emb, res_stack,
batch_size=1
)
# 4.3 ControlNet
if block_id == controlnet_insert_block_id and additional_res_stack is not None:
hidden_states += additional_res_stack.pop().to(device)
res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)]
# 5. output
hidden_states = unet.conv_norm_out(hidden_states)

View File

@@ -1,155 +0,0 @@
from ..models import ModelManager, FluxDiT, FluxTextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder
from ..prompters import FluxPrompter
from ..schedulers import FlowMatchScheduler
from .base import BasePipeline
import torch
from tqdm import tqdm
class FluxImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = FlowMatchScheduler()
self.prompter = FluxPrompter()
# models
self.text_encoder_1: FluxTextEncoder1 = None
self.text_encoder_2: FluxTextEncoder2 = None
self.dit: FluxDiT = None
self.vae_decoder: FluxVAEDecoder = None
self.vae_encoder: FluxVAEEncoder = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder']
def denoising_model(self):
return self.dit
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[], prompt_extender_classes=[]):
self.text_encoder_1 = model_manager.fetch_model("flux_text_encoder_1")
self.text_encoder_2 = model_manager.fetch_model("flux_text_encoder_2")
self.dit = model_manager.fetch_model("flux_dit")
self.vae_decoder = model_manager.fetch_model("flux_vae_decoder")
self.vae_encoder = model_manager.fetch_model("flux_vae_encoder")
self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
self.prompter.load_prompt_extenders(model_manager, prompt_extender_classes)
@staticmethod
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], prompt_extender_classes=[], device=None):
pipe = FluxImagePipeline(
device=model_manager.device if device is None else device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_models(model_manager, prompt_refiner_classes,prompt_extender_classes)
return pipe
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
image = self.vae_output_to_image(image)
return image
def encode_prompt(self, prompt, positive=True):
prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
prompt, device=self.device, positive=positive
)
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
def prepare_extra_input(self, latents=None, guidance=0.0):
latent_image_ids = self.dit.prepare_image_ids(latents)
guidance = torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype)
return {"image_ids": latent_image_ids, "guidance": guidance}
@torch.no_grad()
def __call__(
self,
prompt,
local_prompts= None,
masks= None,
mask_scales= None,
negative_prompt="",
cfg_scale=1.0,
embedded_guidance=0.0,
input_image=None,
denoising_strength=1.0,
height=1024,
width=1024,
num_inference_steps=30,
tiled=False,
tile_size=128,
tile_stride=64,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if input_image is not None:
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.encode_image(image, **tiler_kwargs)
noise = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
# Extend prompt
self.load_models_to_device(['text_encoder_1', 'text_encoder_2'])
prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales)
# Encode prompts
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
if cfg_scale != 1.0:
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
prompt_emb_locals = [self.encode_prompt(prompt_local) for prompt_local in local_prompts]
# Extra input
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
# Denoise
self.load_models_to_device(['dit'])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
# Classifier-free guidance
inference_callback = lambda prompt_emb_posi: self.dit(
latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs, **extra_input
)
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
if cfg_scale != 1.0:
noise_pred_nega = self.dit(
latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs, **extra_input
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# Iterate
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# UI
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
self.load_models_to_device(['vae_decoder'])
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
# Offload all models
self.load_models_to_device([])
return image

View File

@@ -3,11 +3,11 @@ from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, Hunyuan
from ..models.sdxl_vae_encoder import SDXLVAEEncoder
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
from ..models import ModelManager
from ..prompters import HunyuanDiTPrompter
from ..prompts import HunyuanDiTPrompter
from ..schedulers import EnhancedDDIMScheduler
from .base import BasePipeline
import torch
from tqdm import tqdm
from PIL import Image
import numpy as np
@@ -122,12 +122,14 @@ class ImageSizeManager:
class HunyuanDiTImagePipeline(BasePipeline):
class HunyuanDiTImagePipeline(torch.nn.Module):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
super().__init__()
self.scheduler = EnhancedDDIMScheduler(prediction_type="v_prediction", beta_start=0.00085, beta_end=0.03)
self.prompter = HunyuanDiTPrompter()
self.device = device
self.torch_dtype = torch_dtype
self.image_size_manager = ImageSizeManager()
# models
self.text_encoder: HunyuanDiTCLIPTextEncoder = None
@@ -135,63 +137,44 @@ class HunyuanDiTImagePipeline(BasePipeline):
self.dit: HunyuanDiT = None
self.vae_decoder: SDXLVAEDecoder = None
self.vae_encoder: SDXLVAEEncoder = None
self.model_names = ['text_encoder', 'text_encoder_t5', 'dit', 'vae_decoder', 'vae_encoder']
def denoising_model(self):
return self.dit
def fetch_main_models(self, model_manager: ModelManager):
self.text_encoder = model_manager.hunyuan_dit_clip_text_encoder
self.text_encoder_t5 = model_manager.hunyuan_dit_t5_text_encoder
self.dit = model_manager.hunyuan_dit
self.vae_decoder = model_manager.vae_decoder
self.vae_encoder = model_manager.vae_encoder
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
# Main models
self.text_encoder = model_manager.fetch_model("hunyuan_dit_clip_text_encoder")
self.text_encoder_t5 = model_manager.fetch_model("hunyuan_dit_t5_text_encoder")
self.dit = model_manager.fetch_model("hunyuan_dit")
self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
self.prompter.fetch_models(self.text_encoder, self.text_encoder_t5)
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
def fetch_prompter(self, model_manager: ModelManager):
self.prompter.load_from_model_manager(model_manager)
@staticmethod
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], device=None):
def from_model_manager(model_manager: ModelManager):
pipe = HunyuanDiTImagePipeline(
device=model_manager.device if device is None else device,
device=model_manager.device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_models(model_manager, prompt_refiner_classes)
pipe.fetch_main_models(model_manager)
pipe.fetch_prompter(model_manager)
return pipe
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
image = self.vae_output_to_image(image)
def preprocess_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
return image
def encode_prompt(self, prompt, clip_skip=1, clip_skip_2=1, positive=True):
text_emb, text_emb_mask, text_emb_t5, text_emb_mask_t5 = self.prompter.encode_prompt(
prompt,
clip_skip=clip_skip,
clip_skip_2=clip_skip_2,
positive=positive,
device=self.device
)
return {
"text_emb": text_emb,
"text_emb_mask": text_emb_mask,
"text_emb_t5": text_emb_t5,
"text_emb_mask_t5": text_emb_mask_t5
}
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
image = image.cpu().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
return image
def prepare_extra_input(self, latents=None, tiled=False, tile_size=64, tile_stride=32):
batch_size, height, width = latents.shape[0], latents.shape[2] * 8, latents.shape[3] * 8
def prepare_extra_input(self, height=1024, width=1024, tiled=False, tile_size=64, tile_stride=32, batch_size=1):
if tiled:
height, width = tile_size * 16, tile_size * 16
image_meta_size = torch.as_tensor([width, height, width, height, 0, 0]).to(device=self.device)
@@ -210,14 +193,12 @@ class HunyuanDiTImagePipeline(BasePipeline):
def __call__(
self,
prompt,
local_prompts=[],
masks=[],
mask_scales=[],
negative_prompt="",
cfg_scale=7.5,
clip_skip=1,
clip_skip_2=1,
input_image=None,
reference_images=[],
reference_strengths=[0.4],
denoising_strength=1.0,
height=1024,
@@ -235,36 +216,71 @@ class HunyuanDiTImagePipeline(BasePipeline):
# Prepare latent tensors
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
if input_image is not None:
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=torch.float32)
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = noise.clone()
# Prepare reference latents
reference_latents = []
for reference_image in reference_images:
reference_image = self.preprocess_image(reference_image).to(device=self.device, dtype=self.torch_dtype)
reference_latents.append(self.vae_encoder(reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype))
# Encode prompts
self.load_models_to_device(['text_encoder', 'text_encoder_t5'])
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
prompt_emb_posi, attention_mask_posi, prompt_emb_t5_posi, attention_mask_t5_posi = self.prompter.encode_prompt(
self.text_encoder,
self.text_encoder_t5,
prompt,
clip_skip=clip_skip,
clip_skip_2=clip_skip_2,
positive=True,
device=self.device
)
if cfg_scale != 1.0:
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) for prompt_local in local_prompts]
prompt_emb_nega, attention_mask_nega, prompt_emb_t5_nega, attention_mask_t5_nega = self.prompter.encode_prompt(
self.text_encoder,
self.text_encoder_t5,
negative_prompt,
clip_skip=clip_skip,
clip_skip_2=clip_skip_2,
positive=False,
device=self.device
)
# Prepare positional id
extra_input = self.prepare_extra_input(latents, tiled, tile_size)
extra_input = self.prepare_extra_input(height, width, tiled, tile_size)
# Denoise
self.load_models_to_device(['dit'])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = torch.tensor([timestep]).to(dtype=self.torch_dtype, device=self.device)
# In-context reference
for reference_latents_, reference_strength in zip(reference_latents, reference_strengths):
if progress_id < num_inference_steps * reference_strength:
noisy_reference_latents = self.scheduler.add_noise(reference_latents_, noise, self.scheduler.timesteps[progress_id])
self.dit(
noisy_reference_latents,
prompt_emb_posi, prompt_emb_t5_posi, attention_mask_posi, attention_mask_t5_posi,
timestep,
**extra_input,
to_cache=True
)
# Positive side
inference_callback = lambda prompt_emb_posi: self.dit(latents, timestep=timestep, **prompt_emb_posi, **extra_input)
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
noise_pred_posi = self.dit(
latents,
prompt_emb_posi, prompt_emb_t5_posi, attention_mask_posi, attention_mask_t5_posi,
timestep,
**extra_input,
)
if cfg_scale != 1.0:
# Negative side
noise_pred_nega = self.dit(
latents, timestep=timestep, **prompt_emb_nega, **extra_input,
latents,
prompt_emb_nega, prompt_emb_t5_nega, attention_mask_nega, attention_mask_t5_nega,
timestep,
**extra_input
)
# Classifier-free guidance
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
@@ -277,9 +293,6 @@ class HunyuanDiTImagePipeline(BasePipeline):
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
self.load_models_to_device(['vae_decoder'])
image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
# Offload all models
self.load_models_to_device([])
return image

View File

@@ -1,105 +0,0 @@
import os, torch, json
from .sd_video import ModelManager, SDVideoPipeline, ControlNetConfigUnit
from ..processors.sequencial_processor import SequencialProcessor
from ..data import VideoData, save_frames, save_video
class SDVideoPipelineRunner:
def __init__(self, in_streamlit=False):
self.in_streamlit = in_streamlit
def load_pipeline(self, model_list, textual_inversion_folder, device, lora_alphas, controlnet_units):
# Load models
model_manager = ModelManager(torch_dtype=torch.float16, device=device)
model_manager.load_models(model_list)
pipe = SDVideoPipeline.from_model_manager(
model_manager,
[
ControlNetConfigUnit(
processor_id=unit["processor_id"],
model_path=unit["model_path"],
scale=unit["scale"]
) for unit in controlnet_units
]
)
textual_inversion_paths = []
for file_name in os.listdir(textual_inversion_folder):
if file_name.endswith(".pt") or file_name.endswith(".bin") or file_name.endswith(".pth") or file_name.endswith(".safetensors"):
textual_inversion_paths.append(os.path.join(textual_inversion_folder, file_name))
pipe.prompter.load_textual_inversions(textual_inversion_paths)
return model_manager, pipe
def load_smoother(self, model_manager, smoother_configs):
smoother = SequencialProcessor.from_model_manager(model_manager, smoother_configs)
return smoother
def synthesize_video(self, model_manager, pipe, seed, smoother, **pipeline_inputs):
torch.manual_seed(seed)
if self.in_streamlit:
import streamlit as st
progress_bar_st = st.progress(0.0)
output_video = pipe(**pipeline_inputs, smoother=smoother, progress_bar_st=progress_bar_st)
progress_bar_st.progress(1.0)
else:
output_video = pipe(**pipeline_inputs, smoother=smoother)
model_manager.to("cpu")
return output_video
def load_video(self, video_file, image_folder, height, width, start_frame_id, end_frame_id):
video = VideoData(video_file=video_file, image_folder=image_folder, height=height, width=width)
if start_frame_id is None:
start_frame_id = 0
if end_frame_id is None:
end_frame_id = len(video)
frames = [video[i] for i in range(start_frame_id, end_frame_id)]
return frames
def add_data_to_pipeline_inputs(self, data, pipeline_inputs):
pipeline_inputs["input_frames"] = self.load_video(**data["input_frames"])
pipeline_inputs["num_frames"] = len(pipeline_inputs["input_frames"])
pipeline_inputs["width"], pipeline_inputs["height"] = pipeline_inputs["input_frames"][0].size
if len(data["controlnet_frames"]) > 0:
pipeline_inputs["controlnet_frames"] = [self.load_video(**unit) for unit in data["controlnet_frames"]]
return pipeline_inputs
def save_output(self, video, output_folder, fps, config):
os.makedirs(output_folder, exist_ok=True)
save_frames(video, os.path.join(output_folder, "frames"))
save_video(video, os.path.join(output_folder, "video.mp4"), fps=fps)
config["pipeline"]["pipeline_inputs"]["input_frames"] = []
config["pipeline"]["pipeline_inputs"]["controlnet_frames"] = []
with open(os.path.join(output_folder, "config.json"), 'w') as file:
json.dump(config, file, indent=4)
def run(self, config):
if self.in_streamlit:
import streamlit as st
if self.in_streamlit: st.markdown("Loading videos ...")
config["pipeline"]["pipeline_inputs"] = self.add_data_to_pipeline_inputs(config["data"], config["pipeline"]["pipeline_inputs"])
if self.in_streamlit: st.markdown("Loading videos ... done!")
if self.in_streamlit: st.markdown("Loading models ...")
model_manager, pipe = self.load_pipeline(**config["models"])
if self.in_streamlit: st.markdown("Loading models ... done!")
if "smoother_configs" in config:
if self.in_streamlit: st.markdown("Loading smoother ...")
smoother = self.load_smoother(model_manager, config["smoother_configs"])
if self.in_streamlit: st.markdown("Loading smoother ... done!")
else:
smoother = None
if self.in_streamlit: st.markdown("Synthesizing videos ...")
output_video = self.synthesize_video(model_manager, pipe, config["pipeline"]["seed"], smoother, **config["pipeline"]["pipeline_inputs"])
if self.in_streamlit: st.markdown("Synthesizing videos ... done!")
if self.in_streamlit: st.markdown("Saving videos ...")
self.save_output(output_video, config["data"]["output_folder"], config["data"]["fps"], config)
if self.in_streamlit: st.markdown("Saving videos ... done!")
if self.in_streamlit: st.markdown("Finished!")
video_file = open(os.path.join(os.path.join(config["data"]["output_folder"], "video.mp4")), 'rb')
if self.in_streamlit: st.video(video_file.read())

View File

@@ -1,143 +0,0 @@
from ..models import ModelManager, SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEDecoder, SD3VAEEncoder
from ..prompters import SD3Prompter
from ..schedulers import FlowMatchScheduler
from .base import BasePipeline
import torch
from tqdm import tqdm
class SD3ImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = FlowMatchScheduler()
self.prompter = SD3Prompter()
# models
self.text_encoder_1: SD3TextEncoder1 = None
self.text_encoder_2: SD3TextEncoder2 = None
self.text_encoder_3: SD3TextEncoder3 = None
self.dit: SD3DiT = None
self.vae_decoder: SD3VAEDecoder = None
self.vae_encoder: SD3VAEEncoder = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'text_encoder_3', 'dit', 'vae_decoder', 'vae_encoder']
def denoising_model(self):
return self.dit
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
self.text_encoder_2 = model_manager.fetch_model("sd3_text_encoder_2")
self.text_encoder_3 = model_manager.fetch_model("sd3_text_encoder_3")
self.dit = model_manager.fetch_model("sd3_dit")
self.vae_decoder = model_manager.fetch_model("sd3_vae_decoder")
self.vae_encoder = model_manager.fetch_model("sd3_vae_encoder")
self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2, self.text_encoder_3)
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
@staticmethod
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], device=None):
pipe = SD3ImagePipeline(
device=model_manager.device if device is None else device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_models(model_manager, prompt_refiner_classes)
return pipe
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
image = self.vae_output_to_image(image)
return image
def encode_prompt(self, prompt, positive=True):
prompt_emb, pooled_prompt_emb = self.prompter.encode_prompt(
prompt, device=self.device, positive=positive
)
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb}
def prepare_extra_input(self, latents=None):
return {}
@torch.no_grad()
def __call__(
self,
prompt,
local_prompts=[],
masks=[],
mask_scales=[],
negative_prompt="",
cfg_scale=7.5,
input_image=None,
denoising_strength=1.0,
height=1024,
width=1024,
num_inference_steps=20,
tiled=False,
tile_size=128,
tile_stride=64,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if input_image is not None:
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.encode_image(image, **tiler_kwargs)
noise = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
# Encode prompts
self.load_models_to_device(['text_encoder_1', 'text_encoder_2', 'text_encoder_3'])
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
prompt_emb_locals = [self.encode_prompt(prompt_local) for prompt_local in local_prompts]
# Denoise
self.load_models_to_device(['dit'])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
# Classifier-free guidance
inference_callback = lambda prompt_emb_posi: self.dit(
latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs,
)
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
noise_pred_nega = self.dit(
latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
# DDIM
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# UI
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
self.load_models_to_device(['vae_decoder'])
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
# offload all models
self.load_models_to_device([])
return image

View File

@@ -1,188 +0,0 @@
from ..models import SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder
from ..models.model_manager import ModelManager
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompters import SDPrompter
from ..schedulers import EnhancedDDIMScheduler
from .base import BasePipeline
from .dancer import lets_dance
from typing import List
import torch
from tqdm import tqdm
class SDImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = EnhancedDDIMScheduler()
self.prompter = SDPrompter()
# models
self.text_encoder: SDTextEncoder = None
self.unet: SDUNet = None
self.vae_decoder: SDVAEDecoder = None
self.vae_encoder: SDVAEEncoder = None
self.controlnet: MultiControlNetManager = None
self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None
self.ipadapter: SDIpAdapter = None
self.model_names = ['text_encoder', 'unet', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter_image_encoder', 'ipadapter']
def denoising_model(self):
return self.unet
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
# Main models
self.text_encoder = model_manager.fetch_model("sd_text_encoder")
self.unet = model_manager.fetch_model("sd_unet")
self.vae_decoder = model_manager.fetch_model("sd_vae_decoder")
self.vae_encoder = model_manager.fetch_model("sd_vae_encoder")
self.prompter.fetch_models(self.text_encoder)
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
# ControlNets
controlnet_units = []
for config in controlnet_config_units:
controlnet_unit = ControlNetUnit(
Annotator(config.processor_id, device=self.device),
model_manager.fetch_model("sd_controlnet", config.model_path),
config.scale
)
controlnet_units.append(controlnet_unit)
self.controlnet = MultiControlNetManager(controlnet_units)
# IP-Adapters
self.ipadapter = model_manager.fetch_model("sd_ipadapter")
self.ipadapter_image_encoder = model_manager.fetch_model("sd_ipadapter_clip_image_encoder")
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], device=None):
pipe = SDImagePipeline(
device=model_manager.device if device is None else device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes=[])
return pipe
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
image = self.vae_output_to_image(image)
return image
def encode_prompt(self, prompt, clip_skip=1, positive=True):
prompt_emb = self.prompter.encode_prompt(prompt, clip_skip=clip_skip, device=self.device, positive=positive)
return {"encoder_hidden_states": prompt_emb}
def prepare_extra_input(self, latents=None):
return {}
@torch.no_grad()
def __call__(
self,
prompt,
local_prompts=[],
masks=[],
mask_scales=[],
negative_prompt="",
cfg_scale=7.5,
clip_skip=1,
input_image=None,
ipadapter_images=None,
ipadapter_scale=1.0,
controlnet_image=None,
denoising_strength=1.0,
height=512,
width=512,
num_inference_steps=20,
tiled=False,
tile_size=64,
tile_stride=32,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if input_image is not None:
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.encode_image(image, **tiler_kwargs)
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
# Encode prompts
self.load_models_to_device(['text_encoder'])
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)
prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, positive=True) for prompt_local in local_prompts]
# IP-Adapter
if ipadapter_images is not None:
self.load_models_to_device(['ipadapter_image_encoder'])
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
self.load_models_to_device(['ipadapter'])
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
else:
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
# Prepare ControlNets
if controlnet_image is not None:
self.load_models_to_device(['controlnet'])
controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
controlnet_image = controlnet_image.unsqueeze(1)
controlnet_kwargs = {"controlnet_frames": controlnet_image}
else:
controlnet_kwargs = {"controlnet_frames": None}
# Denoise
self.load_models_to_device(['controlnet', 'unet'])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
# Classifier-free guidance
inference_callback = lambda prompt_emb_posi: lets_dance(
self.unet, motion_modules=None, controlnet=self.controlnet,
sample=latents, timestep=timestep,
**prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi,
device=self.device,
)
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
noise_pred_nega = lets_dance(
self.unet, motion_modules=None, controlnet=self.controlnet,
sample=latents, timestep=timestep, **prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega,
device=self.device,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
# DDIM
latents = self.scheduler.step(noise_pred, timestep, latents)
# UI
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
self.load_models_to_device(['vae_decoder'])
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
# offload all models
self.load_models_to_device([])
return image

View File

@@ -1,266 +0,0 @@
from ..models import SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder, SDMotionModel
from ..models.model_manager import ModelManager
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompters import SDPrompter
from ..schedulers import EnhancedDDIMScheduler
from .sd_image import SDImagePipeline
from .dancer import lets_dance
from typing import List
import torch
from tqdm import tqdm
def lets_dance_with_long_video(
unet: SDUNet,
motion_modules: SDMotionModel = None,
controlnet: MultiControlNetManager = None,
sample = None,
timestep = None,
encoder_hidden_states = None,
ipadapter_kwargs_list = {},
controlnet_frames = None,
unet_batch_size = 1,
controlnet_batch_size = 1,
cross_frame_attention = False,
tiled=False,
tile_size=64,
tile_stride=32,
device="cuda",
animatediff_batch_size=16,
animatediff_stride=8,
):
num_frames = sample.shape[0]
hidden_states_output = [(torch.zeros(sample[0].shape, dtype=sample[0].dtype), 0) for i in range(num_frames)]
for batch_id in range(0, num_frames, animatediff_stride):
batch_id_ = min(batch_id + animatediff_batch_size, num_frames)
# process this batch
hidden_states_batch = lets_dance(
unet, motion_modules, controlnet,
sample[batch_id: batch_id_].to(device),
timestep,
encoder_hidden_states,
ipadapter_kwargs_list=ipadapter_kwargs_list,
controlnet_frames=controlnet_frames[:, batch_id: batch_id_].to(device) if controlnet_frames is not None else None,
unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
cross_frame_attention=cross_frame_attention,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, device=device
).cpu()
# update hidden_states
for i, hidden_states_updated in zip(range(batch_id, batch_id_), hidden_states_batch):
bias = max(1 - abs(i - (batch_id + batch_id_ - 1) / 2) / ((batch_id_ - batch_id - 1 + 1e-2) / 2), 1e-2)
hidden_states, num = hidden_states_output[i]
hidden_states = hidden_states * (num / (num + bias)) + hidden_states_updated * (bias / (num + bias))
hidden_states_output[i] = (hidden_states, num + bias)
if batch_id_ == num_frames:
break
# output
hidden_states = torch.stack([h for h, _ in hidden_states_output])
return hidden_states
class SDVideoPipeline(SDImagePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16, use_original_animatediff=True):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_original_animatediff else "scaled_linear")
self.prompter = SDPrompter()
# models
self.text_encoder: SDTextEncoder = None
self.unet: SDUNet = None
self.vae_decoder: SDVAEDecoder = None
self.vae_encoder: SDVAEEncoder = None
self.controlnet: MultiControlNetManager = None
self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None
self.ipadapter: SDIpAdapter = None
self.motion_modules: SDMotionModel = None
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
# Main models
self.text_encoder = model_manager.fetch_model("sd_text_encoder")
self.unet = model_manager.fetch_model("sd_unet")
self.vae_decoder = model_manager.fetch_model("sd_vae_decoder")
self.vae_encoder = model_manager.fetch_model("sd_vae_encoder")
self.prompter.fetch_models(self.text_encoder)
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
# ControlNets
controlnet_units = []
for config in controlnet_config_units:
controlnet_unit = ControlNetUnit(
Annotator(config.processor_id, device=self.device),
model_manager.fetch_model("sd_controlnet", config.model_path),
config.scale
)
controlnet_units.append(controlnet_unit)
self.controlnet = MultiControlNetManager(controlnet_units)
# IP-Adapters
self.ipadapter = model_manager.fetch_model("sd_ipadapter")
self.ipadapter_image_encoder = model_manager.fetch_model("sd_ipadapter_clip_image_encoder")
# Motion Modules
self.motion_modules = model_manager.fetch_model("sd_motion_modules")
if self.motion_modules is None:
self.scheduler = EnhancedDDIMScheduler(beta_schedule="scaled_linear")
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
pipe = SDVideoPipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
return pipe
def decode_video(self, latents, tiled=False, tile_size=64, tile_stride=32):
images = [
self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
for frame_id in range(latents.shape[0])
]
return images
def encode_video(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
latents = []
for image in processed_images:
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
latent = self.encode_image(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
latents.append(latent.cpu())
latents = torch.concat(latents, dim=0)
return latents
@torch.no_grad()
def __call__(
self,
prompt,
negative_prompt="",
cfg_scale=7.5,
clip_skip=1,
num_frames=None,
input_frames=None,
ipadapter_images=None,
ipadapter_scale=1.0,
controlnet_frames=None,
denoising_strength=1.0,
height=512,
width=512,
num_inference_steps=20,
animatediff_batch_size = 16,
animatediff_stride = 8,
unet_batch_size = 1,
controlnet_batch_size = 1,
cross_frame_attention = False,
smoother=None,
smoother_progress_ids=[],
tiled=False,
tile_size=64,
tile_stride=32,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Tiler parameters, batch size ...
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
other_kwargs = {
"animatediff_batch_size": animatediff_batch_size, "animatediff_stride": animatediff_stride,
"unet_batch_size": unet_batch_size, "controlnet_batch_size": controlnet_batch_size,
"cross_frame_attention": cross_frame_attention,
}
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if self.motion_modules is None:
noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
else:
noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype)
if input_frames is None or denoising_strength == 1.0:
latents = noise
else:
latents = self.encode_video(input_frames, **tiler_kwargs)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
# Encode prompts
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)
# IP-Adapter
if ipadapter_images is not None:
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
else:
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
# Prepare ControlNets
if controlnet_frames is not None:
if isinstance(controlnet_frames[0], list):
controlnet_frames_ = []
for processor_id in range(len(controlnet_frames)):
controlnet_frames_.append(
torch.stack([
self.controlnet.process_image(controlnet_frame, processor_id=processor_id).to(self.torch_dtype)
for controlnet_frame in progress_bar_cmd(controlnet_frames[processor_id])
], dim=1)
)
controlnet_frames = torch.concat(controlnet_frames_, dim=0)
else:
controlnet_frames = torch.stack([
self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
for controlnet_frame in progress_bar_cmd(controlnet_frames)
], dim=1)
controlnet_kwargs = {"controlnet_frames": controlnet_frames}
else:
controlnet_kwargs = {"controlnet_frames": None}
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
# Classifier-free guidance
noise_pred_posi = lets_dance_with_long_video(
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
sample=latents, timestep=timestep,
**prompt_emb_posi, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **other_kwargs, **tiler_kwargs,
device=self.device,
)
noise_pred_nega = lets_dance_with_long_video(
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
sample=latents, timestep=timestep,
**prompt_emb_nega, **controlnet_kwargs, **ipadapter_kwargs_list_nega, **other_kwargs, **tiler_kwargs,
device=self.device,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
# DDIM and smoother
if smoother is not None and progress_id in smoother_progress_ids:
rendered_frames = self.scheduler.step(noise_pred, timestep, latents, to_final=True)
rendered_frames = self.decode_video(rendered_frames)
rendered_frames = smoother(rendered_frames, original_frames=input_frames)
target_latents = self.encode_video(rendered_frames)
noise_pred = self.scheduler.return_to_timestep(timestep, latents, target_latents)
latents = self.scheduler.step(noise_pred, timestep, latents)
# UI
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
output_frames = self.decode_video(latents, **tiler_kwargs)
# Post-process
if smoother is not None and (num_inference_steps in smoother_progress_ids or -1 in smoother_progress_ids):
output_frames = smoother(output_frames, original_frames=input_frames)
return output_frames

View File

@@ -1,221 +0,0 @@
from ..models import SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
from ..models.kolors_text_encoder import ChatGLMModel
from ..models.model_manager import ModelManager
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompters import SDXLPrompter, KolorsPrompter
from ..schedulers import EnhancedDDIMScheduler
from .base import BasePipeline
from .dancer import lets_dance_xl
from typing import List
import torch
from tqdm import tqdm
class SDXLImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = EnhancedDDIMScheduler()
self.prompter = SDXLPrompter()
# models
self.text_encoder: SDXLTextEncoder = None
self.text_encoder_2: SDXLTextEncoder2 = None
self.text_encoder_kolors: ChatGLMModel = None
self.unet: SDXLUNet = None
self.vae_decoder: SDXLVAEDecoder = None
self.vae_encoder: SDXLVAEEncoder = None
self.controlnet: MultiControlNetManager = None
self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
self.ipadapter: SDXLIpAdapter = None
self.model_names = ['text_encoder', 'text_encoder_2', 'text_encoder_kolors', 'unet', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter_image_encoder', 'ipadapter']
def denoising_model(self):
return self.unet
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
# Main models
self.text_encoder = model_manager.fetch_model("sdxl_text_encoder")
self.text_encoder_2 = model_manager.fetch_model("sdxl_text_encoder_2")
self.text_encoder_kolors = model_manager.fetch_model("kolors_text_encoder")
self.unet = model_manager.fetch_model("sdxl_unet")
self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
# ControlNets
controlnet_units = []
for config in controlnet_config_units:
controlnet_unit = ControlNetUnit(
Annotator(config.processor_id, device=self.device),
model_manager.fetch_model("sdxl_controlnet", config.model_path),
config.scale
)
controlnet_units.append(controlnet_unit)
self.controlnet = MultiControlNetManager(controlnet_units)
# IP-Adapters
self.ipadapter = model_manager.fetch_model("sdxl_ipadapter")
self.ipadapter_image_encoder = model_manager.fetch_model("sdxl_ipadapter_clip_image_encoder")
# Kolors
if self.text_encoder_kolors is not None:
print("Switch to Kolors. The prompter and scheduler will be replaced.")
self.prompter = KolorsPrompter()
self.prompter.fetch_models(self.text_encoder_kolors)
self.scheduler = EnhancedDDIMScheduler(beta_end=0.014, num_train_timesteps=1100)
else:
self.prompter.fetch_models(self.text_encoder, self.text_encoder_2)
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], device=None):
pipe = SDXLImagePipeline(
device=model_manager.device if device is None else device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
return pipe
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
image = self.vae_output_to_image(image)
return image
def encode_prompt(self, prompt, clip_skip=1, clip_skip_2=2, positive=True):
add_prompt_emb, prompt_emb = self.prompter.encode_prompt(
prompt,
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
device=self.device,
positive=positive,
)
return {"encoder_hidden_states": prompt_emb, "add_text_embeds": add_prompt_emb}
def prepare_extra_input(self, latents=None):
height, width = latents.shape[2] * 8, latents.shape[3] * 8
return {"add_time_id": torch.tensor([height, width, 0, 0, height, width], device=self.device)}
@torch.no_grad()
def __call__(
self,
prompt,
local_prompts=[],
masks=[],
mask_scales=[],
negative_prompt="",
cfg_scale=7.5,
clip_skip=1,
clip_skip_2=2,
input_image=None,
ipadapter_images=None,
ipadapter_scale=1.0,
ipadapter_use_instant_style=False,
controlnet_image=None,
denoising_strength=1.0,
height=1024,
width=1024,
num_inference_steps=20,
tiled=False,
tile_size=64,
tile_stride=32,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if input_image is not None:
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.encode_image(image, **tiler_kwargs)
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
# Encode prompts
self.load_models_to_device(['text_encoder', 'text_encoder_2', 'text_encoder_kolors'])
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=False)
prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) for prompt_local in local_prompts]
# IP-Adapter
if ipadapter_images is not None:
if ipadapter_use_instant_style:
self.ipadapter.set_less_adapter()
else:
self.ipadapter.set_full_adapter()
self.load_models_to_device(['ipadapter_image_encoder'])
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
self.load_models_to_device(['ipadapter'])
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
else:
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
# Prepare ControlNets
if controlnet_image is not None:
self.load_models_to_device(['controlnet'])
controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
controlnet_image = controlnet_image.unsqueeze(1)
controlnet_kwargs = {"controlnet_frames": controlnet_image}
else:
controlnet_kwargs = {"controlnet_frames": None}
# Prepare extra input
extra_input = self.prepare_extra_input(latents)
# Denoise
self.load_models_to_device(['controlnet', 'unet'])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
# Classifier-free guidance
inference_callback = lambda prompt_emb_posi: lets_dance_xl(
self.unet, motion_modules=None, controlnet=self.controlnet,
sample=latents, timestep=timestep, **extra_input,
**prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi,
device=self.device,
)
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
if cfg_scale != 1.0:
noise_pred_nega = lets_dance_xl(
self.unet, motion_modules=None, controlnet=self.controlnet,
sample=latents, timestep=timestep, **extra_input,
**prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega,
device=self.device,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# DDIM
latents = self.scheduler.step(noise_pred, timestep, latents)
# UI
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
self.load_models_to_device(['vae_decoder'])
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
# offload all models
self.load_models_to_device([])
return image

View File

@@ -1,223 +0,0 @@
from ..models import SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder, SDXLMotionModel
from ..models.kolors_text_encoder import ChatGLMModel
from ..models.model_manager import ModelManager
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompters import SDXLPrompter, KolorsPrompter
from ..schedulers import EnhancedDDIMScheduler
from .sdxl_image import SDXLImagePipeline
from .dancer import lets_dance_xl
from typing import List
import torch
from tqdm import tqdm
class SDXLVideoPipeline(SDXLImagePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16, use_original_animatediff=True):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_original_animatediff else "scaled_linear")
self.prompter = SDXLPrompter()
# models
self.text_encoder: SDXLTextEncoder = None
self.text_encoder_2: SDXLTextEncoder2 = None
self.text_encoder_kolors: ChatGLMModel = None
self.unet: SDXLUNet = None
self.vae_decoder: SDXLVAEDecoder = None
self.vae_encoder: SDXLVAEEncoder = None
# self.controlnet: MultiControlNetManager = None (TODO)
self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
self.ipadapter: SDXLIpAdapter = None
self.motion_modules: SDXLMotionModel = None
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
# Main models
self.text_encoder = model_manager.fetch_model("sdxl_text_encoder")
self.text_encoder_2 = model_manager.fetch_model("sdxl_text_encoder_2")
self.text_encoder_kolors = model_manager.fetch_model("kolors_text_encoder")
self.unet = model_manager.fetch_model("sdxl_unet")
self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
self.prompter.fetch_models(self.text_encoder)
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
# ControlNets (TODO)
# IP-Adapters
self.ipadapter = model_manager.fetch_model("sdxl_ipadapter")
self.ipadapter_image_encoder = model_manager.fetch_model("sdxl_ipadapter_clip_image_encoder")
# Motion Modules
self.motion_modules = model_manager.fetch_model("sdxl_motion_modules")
if self.motion_modules is None:
self.scheduler = EnhancedDDIMScheduler(beta_schedule="scaled_linear")
# Kolors
if self.text_encoder_kolors is not None:
print("Switch to Kolors. The prompter will be replaced.")
self.prompter = KolorsPrompter()
self.prompter.fetch_models(self.text_encoder_kolors)
# The schedulers of AniamteDiff and Kolors are incompatible. We align it with AniamteDiff.
if self.motion_modules is None:
self.scheduler = EnhancedDDIMScheduler(beta_end=0.014, num_train_timesteps=1100)
else:
self.prompter.fetch_models(self.text_encoder, self.text_encoder_2)
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
pipe = SDXLVideoPipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
return pipe
def decode_video(self, latents, tiled=False, tile_size=64, tile_stride=32):
images = [
self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
for frame_id in range(latents.shape[0])
]
return images
def encode_video(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
latents = []
for image in processed_images:
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
latent = self.encode_image(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
latents.append(latent.cpu())
latents = torch.concat(latents, dim=0)
return latents
@torch.no_grad()
def __call__(
self,
prompt,
negative_prompt="",
cfg_scale=7.5,
clip_skip=1,
num_frames=None,
input_frames=None,
ipadapter_images=None,
ipadapter_scale=1.0,
ipadapter_use_instant_style=False,
controlnet_frames=None,
denoising_strength=1.0,
height=512,
width=512,
num_inference_steps=20,
animatediff_batch_size = 16,
animatediff_stride = 8,
unet_batch_size = 1,
controlnet_batch_size = 1,
cross_frame_attention = False,
smoother=None,
smoother_progress_ids=[],
tiled=False,
tile_size=64,
tile_stride=32,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Tiler parameters, batch size ...
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if self.motion_modules is None:
noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
else:
noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype)
if input_frames is None or denoising_strength == 1.0:
latents = noise
else:
latents = self.encode_video(input_frames, **tiler_kwargs)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
latents = latents.to(self.device) # will be deleted for supporting long videos
# Encode prompts
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)
# IP-Adapter
if ipadapter_images is not None:
if ipadapter_use_instant_style:
self.ipadapter.set_less_adapter()
else:
self.ipadapter.set_full_adapter()
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
else:
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
# Prepare ControlNets
if controlnet_frames is not None:
if isinstance(controlnet_frames[0], list):
controlnet_frames_ = []
for processor_id in range(len(controlnet_frames)):
controlnet_frames_.append(
torch.stack([
self.controlnet.process_image(controlnet_frame, processor_id=processor_id).to(self.torch_dtype)
for controlnet_frame in progress_bar_cmd(controlnet_frames[processor_id])
], dim=1)
)
controlnet_frames = torch.concat(controlnet_frames_, dim=0)
else:
controlnet_frames = torch.stack([
self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
for controlnet_frame in progress_bar_cmd(controlnet_frames)
], dim=1)
controlnet_kwargs = {"controlnet_frames": controlnet_frames}
else:
controlnet_kwargs = {"controlnet_frames": None}
# Prepare extra input
extra_input = self.prepare_extra_input(latents)
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
# Classifier-free guidance
noise_pred_posi = lets_dance_xl(
self.unet, motion_modules=self.motion_modules, controlnet=None,
sample=latents, timestep=timestep,
**prompt_emb_posi, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **extra_input, **tiler_kwargs,
device=self.device,
)
noise_pred_nega = lets_dance_xl(
self.unet, motion_modules=self.motion_modules, controlnet=None,
sample=latents, timestep=timestep,
**prompt_emb_nega, **controlnet_kwargs, **ipadapter_kwargs_list_nega, **extra_input, **tiler_kwargs,
device=self.device,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
# DDIM and smoother
if smoother is not None and progress_id in smoother_progress_ids:
rendered_frames = self.scheduler.step(noise_pred, timestep, latents, to_final=True)
rendered_frames = self.decode_video(rendered_frames)
rendered_frames = smoother(rendered_frames, original_frames=input_frames)
target_latents = self.encode_video(rendered_frames)
noise_pred = self.scheduler.return_to_timestep(timestep, latents, target_latents)
latents = self.scheduler.step(noise_pred, timestep, latents)
# UI
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
output_frames = self.decode_video(latents, **tiler_kwargs)
# Post-process
if smoother is not None and (num_inference_steps in smoother_progress_ids or -1 in smoother_progress_ids):
output_frames = smoother(output_frames, original_frames=input_frames)
return output_frames

View File

@@ -0,0 +1,167 @@
from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompts import SDPrompter
from ..schedulers import EnhancedDDIMScheduler
from .dancer import lets_dance
from typing import List
import torch
from tqdm import tqdm
from PIL import Image
import numpy as np
class SDImagePipeline(torch.nn.Module):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__()
self.scheduler = EnhancedDDIMScheduler()
self.prompter = SDPrompter()
self.device = device
self.torch_dtype = torch_dtype
# models
self.text_encoder: SDTextEncoder = None
self.unet: SDUNet = None
self.vae_decoder: SDVAEDecoder = None
self.vae_encoder: SDVAEEncoder = None
self.controlnet: MultiControlNetManager = None
self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None
self.ipadapter: SDIpAdapter = None
def fetch_main_models(self, model_manager: ModelManager):
self.text_encoder = model_manager.text_encoder
self.unet = model_manager.unet
self.vae_decoder = model_manager.vae_decoder
self.vae_encoder = model_manager.vae_encoder
def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
controlnet_units = []
for config in controlnet_config_units:
controlnet_unit = ControlNetUnit(
Annotator(config.processor_id),
model_manager.get_model_with_model_path(config.model_path),
config.scale
)
controlnet_units.append(controlnet_unit)
self.controlnet = MultiControlNetManager(controlnet_units)
def fetch_ipadapter(self, model_manager: ModelManager):
if "ipadapter" in model_manager.model:
self.ipadapter = model_manager.ipadapter
if "ipadapter_image_encoder" in model_manager.model:
self.ipadapter_image_encoder = model_manager.ipadapter_image_encoder
def fetch_prompter(self, model_manager: ModelManager):
self.prompter.load_from_model_manager(model_manager)
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
pipe = SDImagePipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_main_models(model_manager)
pipe.fetch_prompter(model_manager)
pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
pipe.fetch_ipadapter(model_manager)
return pipe
def preprocess_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
return image
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
image = image.cpu().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
return image
@torch.no_grad()
def __call__(
self,
prompt,
negative_prompt="",
cfg_scale=7.5,
clip_skip=1,
input_image=None,
ipadapter_images=None,
ipadapter_scale=1.0,
controlnet_image=None,
denoising_strength=1.0,
height=512,
width=512,
num_inference_steps=20,
tiled=False,
tile_size=64,
tile_stride=32,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if input_image is not None:
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
# Encode prompts
prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True)
prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False)
# IP-Adapter
if ipadapter_images is not None:
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)
ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding))
else:
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {}
# Prepare ControlNets
if controlnet_image is not None:
controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
controlnet_image = controlnet_image.unsqueeze(1)
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = torch.IntTensor((timestep,))[0].to(self.device)
# Classifier-free guidance
noise_pred_posi = lets_dance(
self.unet, motion_modules=None, controlnet=self.controlnet,
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_image,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
ipadapter_kwargs_list=ipadapter_kwargs_list_posi,
device=self.device, vram_limit_level=0
)
noise_pred_nega = lets_dance(
self.unet, motion_modules=None, controlnet=self.controlnet,
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_image,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
ipadapter_kwargs_list=ipadapter_kwargs_list_nega,
device=self.device, vram_limit_level=0
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
# DDIM
latents = self.scheduler.step(noise_pred, timestep, latents)
# UI
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return image

View File

@@ -0,0 +1,371 @@
from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDMotionModel
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompts import SDPrompter
from ..schedulers import EnhancedDDIMScheduler
from ..data import VideoData, save_frames, save_video
from .dancer import lets_dance
from ..processors.sequencial_processor import SequencialProcessor
from typing import List
import torch, os, json
from tqdm import tqdm
from PIL import Image
import numpy as np
from einops import rearrange
def lets_dance_with_long_video(
unet: SDUNet,
motion_modules: SDMotionModel = None,
controlnet: MultiControlNetManager = None,
sample = None,
timestep = None,
encoder_hidden_states = None,
controlnet_frames = None,
animatediff_batch_size = 16,
animatediff_stride = 8,
unet_batch_size = 1,
controlnet_batch_size = 1,
cross_frame_attention = False,
device = "cuda",
vram_limit_level = 0,
):
num_frames = sample.shape[0]
hidden_states_output = [(torch.zeros(sample[0].shape, dtype=sample[0].dtype), 0) for i in range(num_frames)]
for batch_id in range(0, num_frames, animatediff_stride):
batch_id_ = min(batch_id + animatediff_batch_size, num_frames)
# process this batch
hidden_states_batch = lets_dance(
unet, motion_modules, controlnet,
sample[batch_id: batch_id_].to(device),
timestep,
encoder_hidden_states[batch_id: batch_id_].to(device),
controlnet_frames=controlnet_frames[:, batch_id: batch_id_].to(device) if controlnet_frames is not None else None,
unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
cross_frame_attention=cross_frame_attention,
device=device, vram_limit_level=vram_limit_level
).cpu()
# update hidden_states
for i, hidden_states_updated in zip(range(batch_id, batch_id_), hidden_states_batch):
bias = max(1 - abs(i - (batch_id + batch_id_ - 1) / 2) / ((batch_id_ - batch_id - 1 + 1e-2) / 2), 1e-2)
hidden_states, num = hidden_states_output[i]
hidden_states = hidden_states * (num / (num + bias)) + hidden_states_updated * (bias / (num + bias))
hidden_states_output[i] = (hidden_states, num + bias)
if batch_id_ == num_frames:
break
# output
hidden_states = torch.stack([h for h, _ in hidden_states_output])
return hidden_states
class SDVideoPipeline(torch.nn.Module):
def __init__(self, device="cuda", torch_dtype=torch.float16, use_animatediff=True):
super().__init__()
self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_animatediff else "scaled_linear")
self.prompter = SDPrompter()
self.device = device
self.torch_dtype = torch_dtype
# models
self.text_encoder: SDTextEncoder = None
self.unet: SDUNet = None
self.vae_decoder: SDVAEDecoder = None
self.vae_encoder: SDVAEEncoder = None
self.controlnet: MultiControlNetManager = None
self.motion_modules: SDMotionModel = None
def fetch_main_models(self, model_manager: ModelManager):
self.text_encoder = model_manager.text_encoder
self.unet = model_manager.unet
self.vae_decoder = model_manager.vae_decoder
self.vae_encoder = model_manager.vae_encoder
def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
controlnet_units = []
for config in controlnet_config_units:
controlnet_unit = ControlNetUnit(
Annotator(config.processor_id),
model_manager.get_model_with_model_path(config.model_path),
config.scale
)
controlnet_units.append(controlnet_unit)
self.controlnet = MultiControlNetManager(controlnet_units)
def fetch_motion_modules(self, model_manager: ModelManager):
if "motion_modules" in model_manager.model:
self.motion_modules = model_manager.motion_modules
def fetch_prompter(self, model_manager: ModelManager):
self.prompter.load_from_model_manager(model_manager)
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
pipe = SDVideoPipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype,
use_animatediff="motion_modules" in model_manager.model
)
pipe.fetch_main_models(model_manager)
pipe.fetch_motion_modules(model_manager)
pipe.fetch_prompter(model_manager)
pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
return pipe
def preprocess_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
return image
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
image = image.cpu().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
return image
def decode_images(self, latents, tiled=False, tile_size=64, tile_stride=32):
images = [
self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
for frame_id in range(latents.shape[0])
]
return images
def encode_images(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
latents = []
for image in processed_images:
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
latent = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).cpu()
latents.append(latent)
latents = torch.concat(latents, dim=0)
return latents
def post_process_latents(self, latents, post_normalize=True, contrast_enhance_scale=1.0):
if post_normalize:
mean, std = latents.mean(), latents.std()
latents = (latents - latents.mean(dim=[1, 2, 3], keepdim=True)) / latents.std(dim=[1, 2, 3], keepdim=True) * std + mean
latents = latents * contrast_enhance_scale
return latents
@torch.no_grad()
def __call__(
self,
prompt,
negative_prompt="",
cfg_scale=7.5,
clip_skip=1,
num_frames=None,
input_frames=None,
controlnet_frames=None,
denoising_strength=1.0,
height=512,
width=512,
num_inference_steps=20,
animatediff_batch_size = 16,
animatediff_stride = 8,
unet_batch_size = 1,
controlnet_batch_size = 1,
cross_frame_attention = False,
smoother=None,
smoother_progress_ids=[],
vram_limit_level=0,
post_normalize=False,
contrast_enhance_scale=1.0,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if self.motion_modules is None:
noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
else:
noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype)
if input_frames is None or denoising_strength == 1.0:
latents = noise
else:
latents = self.encode_images(input_frames)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
# Encode prompts
prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True).cpu()
prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False).cpu()
prompt_emb_posi = prompt_emb_posi.repeat(num_frames, 1, 1)
prompt_emb_nega = prompt_emb_nega.repeat(num_frames, 1, 1)
# Prepare ControlNets
if controlnet_frames is not None:
if isinstance(controlnet_frames[0], list):
controlnet_frames_ = []
for processor_id in range(len(controlnet_frames)):
controlnet_frames_.append(
torch.stack([
self.controlnet.process_image(controlnet_frame, processor_id=processor_id).to(self.torch_dtype)
for controlnet_frame in progress_bar_cmd(controlnet_frames[processor_id])
], dim=1)
)
controlnet_frames = torch.concat(controlnet_frames_, dim=0)
else:
controlnet_frames = torch.stack([
self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
for controlnet_frame in progress_bar_cmd(controlnet_frames)
], dim=1)
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = torch.IntTensor((timestep,))[0].to(self.device)
# Classifier-free guidance
noise_pred_posi = lets_dance_with_long_video(
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_frames,
animatediff_batch_size=animatediff_batch_size, animatediff_stride=animatediff_stride,
unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
cross_frame_attention=cross_frame_attention,
device=self.device, vram_limit_level=vram_limit_level
)
if cfg_scale != 1.0:
noise_pred_nega = lets_dance_with_long_video(
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames,
animatediff_batch_size=animatediff_batch_size, animatediff_stride=animatediff_stride,
unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
cross_frame_attention=cross_frame_attention,
device=self.device, vram_limit_level=vram_limit_level
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# DDIM and smoother
if smoother is not None and progress_id in smoother_progress_ids:
rendered_frames = self.scheduler.step(noise_pred, timestep, latents, to_final=True)
rendered_frames = self.decode_images(rendered_frames)
rendered_frames = smoother(rendered_frames, original_frames=input_frames)
target_latents = self.encode_images(rendered_frames)
noise_pred = self.scheduler.return_to_timestep(timestep, latents, target_latents)
latents = self.scheduler.step(noise_pred, timestep, latents)
# UI
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
latents = self.post_process_latents(latents, post_normalize=post_normalize, contrast_enhance_scale=contrast_enhance_scale)
output_frames = self.decode_images(latents)
# Post-process
if smoother is not None and (num_inference_steps in smoother_progress_ids or -1 in smoother_progress_ids):
output_frames = smoother(output_frames, original_frames=input_frames)
return output_frames
class SDVideoPipelineRunner:
def __init__(self, in_streamlit=False):
self.in_streamlit = in_streamlit
def load_pipeline(self, model_list, textual_inversion_folder, device, lora_alphas, controlnet_units):
# Load models
model_manager = ModelManager(torch_dtype=torch.float16, device=device)
model_manager.load_textual_inversions(textual_inversion_folder)
model_manager.load_models(model_list, lora_alphas=lora_alphas)
pipe = SDVideoPipeline.from_model_manager(
model_manager,
[
ControlNetConfigUnit(
processor_id=unit["processor_id"],
model_path=unit["model_path"],
scale=unit["scale"]
) for unit in controlnet_units
]
)
return model_manager, pipe
def load_smoother(self, model_manager, smoother_configs):
smoother = SequencialProcessor.from_model_manager(model_manager, smoother_configs)
return smoother
def synthesize_video(self, model_manager, pipe, seed, smoother, **pipeline_inputs):
torch.manual_seed(seed)
if self.in_streamlit:
import streamlit as st
progress_bar_st = st.progress(0.0)
output_video = pipe(**pipeline_inputs, smoother=smoother, progress_bar_st=progress_bar_st)
progress_bar_st.progress(1.0)
else:
output_video = pipe(**pipeline_inputs, smoother=smoother)
model_manager.to("cpu")
return output_video
def load_video(self, video_file, image_folder, height, width, start_frame_id, end_frame_id):
video = VideoData(video_file=video_file, image_folder=image_folder, height=height, width=width)
if start_frame_id is None:
start_frame_id = 0
if end_frame_id is None:
end_frame_id = len(video)
frames = [video[i] for i in range(start_frame_id, end_frame_id)]
return frames
def add_data_to_pipeline_inputs(self, data, pipeline_inputs):
pipeline_inputs["input_frames"] = self.load_video(**data["input_frames"])
pipeline_inputs["num_frames"] = len(pipeline_inputs["input_frames"])
pipeline_inputs["width"], pipeline_inputs["height"] = pipeline_inputs["input_frames"][0].size
if len(data["controlnet_frames"]) > 0:
pipeline_inputs["controlnet_frames"] = [self.load_video(**unit) for unit in data["controlnet_frames"]]
return pipeline_inputs
def save_output(self, video, output_folder, fps, config):
os.makedirs(output_folder, exist_ok=True)
save_frames(video, os.path.join(output_folder, "frames"))
save_video(video, os.path.join(output_folder, "video.mp4"), fps=fps)
config["pipeline"]["pipeline_inputs"]["input_frames"] = []
config["pipeline"]["pipeline_inputs"]["controlnet_frames"] = []
with open(os.path.join(output_folder, "config.json"), 'w') as file:
json.dump(config, file, indent=4)
def run(self, config):
if self.in_streamlit:
import streamlit as st
if self.in_streamlit: st.markdown("Loading videos ...")
config["pipeline"]["pipeline_inputs"] = self.add_data_to_pipeline_inputs(config["data"], config["pipeline"]["pipeline_inputs"])
if self.in_streamlit: st.markdown("Loading videos ... done!")
if self.in_streamlit: st.markdown("Loading models ...")
model_manager, pipe = self.load_pipeline(**config["models"])
if self.in_streamlit: st.markdown("Loading models ... done!")
if "smoother_configs" in config:
if self.in_streamlit: st.markdown("Loading smoother ...")
smoother = self.load_smoother(model_manager, config["smoother_configs"])
if self.in_streamlit: st.markdown("Loading smoother ... done!")
else:
smoother = None
if self.in_streamlit: st.markdown("Synthesizing videos ...")
output_video = self.synthesize_video(model_manager, pipe, config["pipeline"]["seed"], smoother, **config["pipeline"]["pipeline_inputs"])
if self.in_streamlit: st.markdown("Synthesizing videos ... done!")
if self.in_streamlit: st.markdown("Saving videos ...")
self.save_output(output_video, config["data"]["output_folder"], config["data"]["fps"], config)
if self.in_streamlit: st.markdown("Saving videos ... done!")
if self.in_streamlit: st.markdown("Finished!")
video_file = open(os.path.join(os.path.join(config["data"]["output_folder"], "video.mp4")), 'rb')
if self.in_streamlit: st.video(video_file.read())

View File

@@ -0,0 +1,175 @@
from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
# TODO: SDXL ControlNet
from ..prompts import SDXLPrompter
from ..schedulers import EnhancedDDIMScheduler
from .dancer import lets_dance_xl
import torch
from tqdm import tqdm
from PIL import Image
import numpy as np
class SDXLImagePipeline(torch.nn.Module):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__()
self.scheduler = EnhancedDDIMScheduler()
self.prompter = SDXLPrompter()
self.device = device
self.torch_dtype = torch_dtype
# models
self.text_encoder: SDXLTextEncoder = None
self.text_encoder_2: SDXLTextEncoder2 = None
self.unet: SDXLUNet = None
self.vae_decoder: SDXLVAEDecoder = None
self.vae_encoder: SDXLVAEEncoder = None
self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
self.ipadapter: SDXLIpAdapter = None
# TODO: SDXL ControlNet
def fetch_main_models(self, model_manager: ModelManager):
self.text_encoder = model_manager.text_encoder
self.text_encoder_2 = model_manager.text_encoder_2
self.unet = model_manager.unet
self.vae_decoder = model_manager.vae_decoder
self.vae_encoder = model_manager.vae_encoder
def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs):
# TODO: SDXL ControlNet
pass
def fetch_ipadapter(self, model_manager: ModelManager):
if "ipadapter_xl" in model_manager.model:
self.ipadapter = model_manager.ipadapter_xl
if "ipadapter_xl_image_encoder" in model_manager.model:
self.ipadapter_image_encoder = model_manager.ipadapter_xl_image_encoder
def fetch_prompter(self, model_manager: ModelManager):
self.prompter.load_from_model_manager(model_manager)
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units = [], **kwargs):
pipe = SDXLImagePipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_main_models(model_manager)
pipe.fetch_prompter(model_manager)
pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units)
pipe.fetch_ipadapter(model_manager)
return pipe
def preprocess_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
return image
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
image = image.cpu().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
return image
@torch.no_grad()
def __call__(
self,
prompt,
negative_prompt="",
cfg_scale=7.5,
clip_skip=1,
clip_skip_2=2,
input_image=None,
ipadapter_images=None,
ipadapter_scale=1.0,
controlnet_image=None,
denoising_strength=1.0,
height=1024,
width=1024,
num_inference_steps=20,
tiled=False,
tile_size=64,
tile_stride=32,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if input_image is not None:
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.vae_encoder(image.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
# Encode prompts
add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt(
self.text_encoder,
self.text_encoder_2,
prompt,
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
device=self.device,
positive=True,
)
if cfg_scale != 1.0:
add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
self.text_encoder,
self.text_encoder_2,
negative_prompt,
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
device=self.device,
positive=False,
)
# Prepare positional id
add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
# IP-Adapter
if ipadapter_images is not None:
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)
ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding))
else:
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {}
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = torch.IntTensor((timestep,))[0].to(self.device)
# Classifier-free guidance
noise_pred_posi = lets_dance_xl(
self.unet,
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi,
add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
ipadapter_kwargs_list=ipadapter_kwargs_list_posi,
)
if cfg_scale != 1.0:
noise_pred_nega = lets_dance_xl(
self.unet,
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega,
add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
ipadapter_kwargs_list=ipadapter_kwargs_list_nega,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
latents = self.scheduler.step(noise_pred, timestep, latents)
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return image

View File

@@ -0,0 +1,190 @@
from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLMotionModel
from .dancer import lets_dance_xl
# TODO: SDXL ControlNet
from ..prompts import SDXLPrompter
from ..schedulers import EnhancedDDIMScheduler
import torch
from tqdm import tqdm
from PIL import Image
import numpy as np
class SDXLVideoPipeline(torch.nn.Module):
def __init__(self, device="cuda", torch_dtype=torch.float16, use_animatediff=True):
super().__init__()
self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_animatediff else "scaled_linear")
self.prompter = SDXLPrompter()
self.device = device
self.torch_dtype = torch_dtype
# models
self.text_encoder: SDXLTextEncoder = None
self.text_encoder_2: SDXLTextEncoder2 = None
self.unet: SDXLUNet = None
self.vae_decoder: SDXLVAEDecoder = None
self.vae_encoder: SDXLVAEEncoder = None
# TODO: SDXL ControlNet
self.motion_modules: SDXLMotionModel = None
def fetch_main_models(self, model_manager: ModelManager):
self.text_encoder = model_manager.text_encoder
self.text_encoder_2 = model_manager.text_encoder_2
self.unet = model_manager.unet
self.vae_decoder = model_manager.vae_decoder
self.vae_encoder = model_manager.vae_encoder
def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs):
# TODO: SDXL ControlNet
pass
def fetch_motion_modules(self, model_manager: ModelManager):
if "motion_modules_xl" in model_manager.model:
self.motion_modules = model_manager.motion_modules_xl
def fetch_prompter(self, model_manager: ModelManager):
self.prompter.load_from_model_manager(model_manager)
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units = [], **kwargs):
pipe = SDXLVideoPipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype,
use_animatediff="motion_modules_xl" in model_manager.model
)
pipe.fetch_main_models(model_manager)
pipe.fetch_motion_modules(model_manager)
pipe.fetch_prompter(model_manager)
pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units)
return pipe
def preprocess_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
return image
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
image = image.cpu().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
return image
def decode_images(self, latents, tiled=False, tile_size=64, tile_stride=32):
images = [
self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
for frame_id in range(latents.shape[0])
]
return images
def encode_images(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
latents = []
for image in processed_images:
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
latent = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).cpu()
latents.append(latent)
latents = torch.concat(latents, dim=0)
return latents
@torch.no_grad()
def __call__(
self,
prompt,
negative_prompt="",
cfg_scale=7.5,
clip_skip=1,
clip_skip_2=2,
num_frames=None,
input_frames=None,
controlnet_frames=None,
denoising_strength=1.0,
height=512,
width=512,
num_inference_steps=20,
animatediff_batch_size = 16,
animatediff_stride = 8,
unet_batch_size = 1,
controlnet_batch_size = 1,
cross_frame_attention = False,
smoother=None,
smoother_progress_ids=[],
vram_limit_level=0,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if self.motion_modules is None:
noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
else:
noise = torch.randn((num_frames, 4, height//8, width//8), device="cuda", dtype=self.torch_dtype)
if input_frames is None or denoising_strength == 1.0:
latents = noise
else:
latents = self.encode_images(input_frames)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
# Encode prompts
add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt(
self.text_encoder,
self.text_encoder_2,
prompt,
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
device=self.device,
positive=True,
)
if cfg_scale != 1.0:
add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
self.text_encoder,
self.text_encoder_2,
negative_prompt,
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
device=self.device,
positive=False,
)
# Prepare positional id
add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = torch.IntTensor((timestep,))[0].to(self.device)
# Classifier-free guidance
noise_pred_posi = lets_dance_xl(
self.unet, motion_modules=self.motion_modules, controlnet=None,
sample=latents, add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_frames,
cross_frame_attention=cross_frame_attention,
device=self.device, vram_limit_level=vram_limit_level
)
if cfg_scale != 1.0:
noise_pred_nega = lets_dance_xl(
self.unet, motion_modules=self.motion_modules, controlnet=None,
sample=latents, add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega,
timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames,
cross_frame_attention=cross_frame_attention,
device=self.device, vram_limit_level=vram_limit_level
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
latents = self.scheduler.step(noise_pred, timestep, latents)
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
image = self.decode_images(latents.to(torch.float32))
return image

View File

@@ -1,6 +1,5 @@
from ..models import ModelManager, SVDImageEncoder, SVDUNet, SVDVAEEncoder, SVDVAEDecoder
from ..schedulers import ContinuousODEScheduler
from .base import BasePipeline
import torch
from tqdm import tqdm
from PIL import Image
@@ -9,11 +8,13 @@ from einops import rearrange, repeat
class SVDVideoPipeline(BasePipeline):
class SVDVideoPipeline(torch.nn.Module):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
super().__init__()
self.scheduler = ContinuousODEScheduler()
self.device = device
self.torch_dtype = torch_dtype
# models
self.image_encoder: SVDImageEncoder = None
self.unet: SVDUNet = None
@@ -21,23 +22,32 @@ class SVDVideoPipeline(BasePipeline):
self.vae_decoder: SVDVAEDecoder = None
def fetch_models(self, model_manager: ModelManager):
self.image_encoder = model_manager.fetch_model("svd_image_encoder")
self.unet = model_manager.fetch_model("svd_unet")
self.vae_encoder = model_manager.fetch_model("svd_vae_encoder")
self.vae_decoder = model_manager.fetch_model("svd_vae_decoder")
def fetch_main_models(self, model_manager: ModelManager):
self.image_encoder = model_manager.image_encoder
self.unet = model_manager.unet
self.vae_encoder = model_manager.vae_encoder
self.vae_decoder = model_manager.vae_decoder
@staticmethod
def from_model_manager(model_manager: ModelManager, **kwargs):
pipe = SVDVideoPipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype
)
pipe.fetch_models(model_manager)
pipe = SVDVideoPipeline(device=model_manager.device, torch_dtype=model_manager.torch_dtype)
pipe.fetch_main_models(model_manager)
return pipe
def preprocess_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
return image
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
image = image.cpu().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
return image
def encode_image_with_clip(self, image):
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
image = SVDCLIPImageProcessor().resize_with_antialiasing(image, (224, 224))

View File

@@ -1,9 +0,0 @@
from .prompt_refiners import Translator, BeautifulPrompt, QwenPrompt
from .sd_prompter import SDPrompter
from .sdxl_prompter import SDXLPrompter
from .sd3_prompter import SD3Prompter
from .hunyuan_dit_prompter import HunyuanDiTPrompter
from .kolors_prompter import KolorsPrompter
from .flux_prompter import FluxPrompter
from .omost import OmostPromter
from .cog_prompter import CogPrompter

View File

@@ -1,70 +0,0 @@
from ..models.model_manager import ModelManager
import torch
def tokenize_long_prompt(tokenizer, prompt, max_length=None):
# Get model_max_length from self.tokenizer
length = tokenizer.model_max_length if max_length is None else max_length
# To avoid the warning. set self.tokenizer.model_max_length to +oo.
tokenizer.model_max_length = 99999999
# Tokenize it!
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
# Determine the real length.
max_length = (input_ids.shape[1] + length - 1) // length * length
# Restore tokenizer.model_max_length
tokenizer.model_max_length = length
# Tokenize it again with fixed length.
input_ids = tokenizer(
prompt,
return_tensors="pt",
padding="max_length",
max_length=max_length,
truncation=True
).input_ids
# Reshape input_ids to fit the text encoder.
num_sentence = input_ids.shape[1] // length
input_ids = input_ids.reshape((num_sentence, length))
return input_ids
class BasePrompter:
def __init__(self):
self.refiners = []
self.extenders = []
def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]):
for refiner_class in refiner_classes:
refiner = refiner_class.from_model_manager(model_manager)
self.refiners.append(refiner)
def load_prompt_extenders(self,model_manager:ModelManager,extender_classes=[]):
for extender_class in extender_classes:
extender = extender_class.from_model_manager(model_manager)
self.extenders.append(extender)
@torch.no_grad()
def process_prompt(self, prompt, positive=True):
if isinstance(prompt, list):
prompt = [self.process_prompt(prompt_, positive=positive) for prompt_ in prompt]
else:
for refiner in self.refiners:
prompt = refiner(prompt, positive=positive)
return prompt
@torch.no_grad()
def extend_prompt(self, prompt:str, positive=True):
extended_prompt = dict(prompt=prompt)
for extender in self.extenders:
extended_prompt = extender(extended_prompt)
return extended_prompt

View File

@@ -1,46 +0,0 @@
from .base_prompter import BasePrompter
from ..models.flux_text_encoder import FluxTextEncoder2
from transformers import T5TokenizerFast
import os
class CogPrompter(BasePrompter):
def __init__(
self,
tokenizer_path=None
):
if tokenizer_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_path = os.path.join(base_path, "tokenizer_configs/cog/tokenizer")
super().__init__()
self.tokenizer = T5TokenizerFast.from_pretrained(tokenizer_path)
self.text_encoder: FluxTextEncoder2 = None
def fetch_models(self, text_encoder: FluxTextEncoder2 = None):
self.text_encoder = text_encoder
def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_length, device):
input_ids = tokenizer(
prompt,
return_tensors="pt",
padding="max_length",
max_length=max_length,
truncation=True,
).input_ids.to(device)
prompt_emb = text_encoder(input_ids)
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
return prompt_emb
def encode_prompt(
self,
prompt,
positive=True,
device="cuda"
):
prompt = self.process_prompt(prompt, positive=positive)
prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder, self.tokenizer, 226, device)
return prompt_emb

View File

@@ -1,74 +0,0 @@
from .base_prompter import BasePrompter
from ..models.flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
from transformers import CLIPTokenizer, T5TokenizerFast
import os, torch
class FluxPrompter(BasePrompter):
def __init__(
self,
tokenizer_1_path=None,
tokenizer_2_path=None
):
if tokenizer_1_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_1_path = os.path.join(base_path, "tokenizer_configs/flux/tokenizer_1")
if tokenizer_2_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/flux/tokenizer_2")
super().__init__()
self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
self.tokenizer_2 = T5TokenizerFast.from_pretrained(tokenizer_2_path)
self.text_encoder_1: FluxTextEncoder1 = None
self.text_encoder_2: FluxTextEncoder2 = None
def fetch_models(self, text_encoder_1: FluxTextEncoder1 = None, text_encoder_2: FluxTextEncoder2 = None):
self.text_encoder_1 = text_encoder_1
self.text_encoder_2 = text_encoder_2
def encode_prompt_using_clip(self, prompt, text_encoder, tokenizer, max_length, device):
input_ids = tokenizer(
prompt,
return_tensors="pt",
padding="max_length",
max_length=max_length,
truncation=True
).input_ids.to(device)
_, pooled_prompt_emb = text_encoder(input_ids)
return pooled_prompt_emb
def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_length, device):
input_ids = tokenizer(
prompt,
return_tensors="pt",
padding="max_length",
max_length=max_length,
truncation=True,
).input_ids.to(device)
prompt_emb = text_encoder(input_ids)
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
return prompt_emb
def encode_prompt(
self,
prompt,
positive=True,
device="cuda"
):
prompt = self.process_prompt(prompt, positive=positive)
# CLIP
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, self.text_encoder_1, self.tokenizer_1, 77, device)
# T5
prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder_2, self.tokenizer_2, 256, device)
# text_ids
text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype)
return prompt_emb, pooled_prompt_emb, text_ids

View File

@@ -1,353 +0,0 @@
from .base_prompter import BasePrompter
from ..models.model_manager import ModelManager
import json, os, re
from typing import List, Optional, Union, Dict
from sentencepiece import SentencePieceProcessor
from transformers import PreTrainedTokenizer
from transformers.utils import PaddingStrategy
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
from ..models.kolors_text_encoder import ChatGLMModel
class SPTokenizer:
def __init__(self, model_path: str):
# reload tokenizer
assert os.path.isfile(model_path), model_path
self.sp_model = SentencePieceProcessor(model_file=model_path)
# BOS / EOS token IDs
self.n_words: int = self.sp_model.vocab_size()
self.bos_id: int = self.sp_model.bos_id()
self.eos_id: int = self.sp_model.eos_id()
self.pad_id: int = self.sp_model.unk_id()
assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
self.special_tokens = {}
self.index_special_tokens = {}
for token in special_tokens:
self.special_tokens[token] = self.n_words
self.index_special_tokens[self.n_words] = token
self.n_words += 1
self.role_special_token_expression = "|".join([re.escape(token) for token in role_special_tokens])
def tokenize(self, s: str, encode_special_tokens=False):
if encode_special_tokens:
last_index = 0
t = []
for match in re.finditer(self.role_special_token_expression, s):
if last_index < match.start():
t.extend(self.sp_model.EncodeAsPieces(s[last_index:match.start()]))
t.append(s[match.start():match.end()])
last_index = match.end()
if last_index < len(s):
t.extend(self.sp_model.EncodeAsPieces(s[last_index:]))
return t
else:
return self.sp_model.EncodeAsPieces(s)
def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
assert type(s) is str
t = self.sp_model.encode(s)
if bos:
t = [self.bos_id] + t
if eos:
t = t + [self.eos_id]
return t
def decode(self, t: List[int]) -> str:
text, buffer = "", []
for token in t:
if token in self.index_special_tokens:
if buffer:
text += self.sp_model.decode(buffer)
buffer = []
text += self.index_special_tokens[token]
else:
buffer.append(token)
if buffer:
text += self.sp_model.decode(buffer)
return text
def decode_tokens(self, tokens: List[str]) -> str:
text = self.sp_model.DecodePieces(tokens)
return text
def convert_token_to_id(self, token):
""" Converts a token (str) in an id using the vocab. """
if token in self.special_tokens:
return self.special_tokens[token]
return self.sp_model.PieceToId(token)
def convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
if index in self.index_special_tokens:
return self.index_special_tokens[index]
if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
return ""
return self.sp_model.IdToPiece(index)
class ChatGLMTokenizer(PreTrainedTokenizer):
vocab_files_names = {"vocab_file": "tokenizer.model"}
model_input_names = ["input_ids", "attention_mask", "position_ids"]
def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, encode_special_tokens=False,
**kwargs):
self.name = "GLMTokenizer"
self.vocab_file = vocab_file
self.tokenizer = SPTokenizer(vocab_file)
self.special_tokens = {
"<bos>": self.tokenizer.bos_id,
"<eos>": self.tokenizer.eos_id,
"<pad>": self.tokenizer.pad_id
}
self.encode_special_tokens = encode_special_tokens
super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
encode_special_tokens=encode_special_tokens,
**kwargs)
def get_command(self, token):
if token in self.special_tokens:
return self.special_tokens[token]
assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
return self.tokenizer.special_tokens[token]
@property
def unk_token(self) -> str:
return "<unk>"
@property
def pad_token(self) -> str:
return "<unk>"
@property
def pad_token_id(self):
return self.get_command("<pad>")
@property
def eos_token(self) -> str:
return "</s>"
@property
def eos_token_id(self):
return self.get_command("<eos>")
@property
def vocab_size(self):
return self.tokenizer.n_words
def get_vocab(self):
""" Returns vocab as a dict """
vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
def _tokenize(self, text, **kwargs):
return self.tokenizer.tokenize(text, encode_special_tokens=self.encode_special_tokens)
def _convert_token_to_id(self, token):
""" Converts a token (str) in an id using the vocab. """
return self.tokenizer.convert_token_to_id(token)
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
return self.tokenizer.convert_id_to_token(index)
def convert_tokens_to_string(self, tokens: List[str]) -> str:
return self.tokenizer.decode_tokens(tokens)
def save_vocabulary(self, save_directory, filename_prefix=None):
"""
Save the vocabulary and special tokens file to a directory.
Args:
save_directory (`str`):
The directory in which to save the vocabulary.
filename_prefix (`str`, *optional*):
An optional prefix to add to the named of the saved files.
Returns:
`Tuple(str)`: Paths to the files saved.
"""
if os.path.isdir(save_directory):
vocab_file = os.path.join(
save_directory, self.vocab_files_names["vocab_file"]
)
else:
vocab_file = save_directory
with open(self.vocab_file, 'rb') as fin:
proto_str = fin.read()
with open(vocab_file, "wb") as writer:
writer.write(proto_str)
return (vocab_file,)
def get_prefix_tokens(self):
prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
return prefix_tokens
def build_single_message(self, role, metadata, message):
assert role in ["system", "user", "assistant", "observation"], role
role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
message_tokens = self.tokenizer.encode(message)
tokens = role_tokens + message_tokens
return tokens
def build_chat_input(self, query, history=None, role="user"):
if history is None:
history = []
input_ids = []
for item in history:
content = item["content"]
if item["role"] == "system" and "tools" in item:
content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
input_ids.extend(self.build_single_message(role, "", query))
input_ids.extend([self.get_command("<|assistant|>")])
return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
adding special tokens. A BERT sequence has the following format:
- single sequence: `[CLS] X [SEP]`
- pair of sequences: `[CLS] A [SEP] B [SEP]`
Args:
token_ids_0 (`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
"""
prefix_tokens = self.get_prefix_tokens()
token_ids_0 = prefix_tokens + token_ids_0
if token_ids_1 is not None:
token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
return token_ids_0
def _pad(
self,
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
max_length: Optional[int] = None,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
) -> dict:
"""
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
Args:
encoded_inputs:
Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
max_length: maximum length of the returned list and optionally padding length (see below).
Will truncate by taking into account the special tokens.
padding_strategy: PaddingStrategy to use for padding.
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
- PaddingStrategy.DO_NOT_PAD: Do not pad
The tokenizer padding sides are defined in self.padding_side:
- 'left': pads on the left of the sequences
- 'right': pads on the right of the sequences
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
`>= 7.5` (Volta).
return_attention_mask:
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
"""
# Load from model defaults
assert self.padding_side == "left"
required_input = encoded_inputs[self.model_input_names[0]]
seq_length = len(required_input)
if padding_strategy == PaddingStrategy.LONGEST:
max_length = len(required_input)
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
# Initialize attention mask if not present.
if "attention_mask" not in encoded_inputs:
encoded_inputs["attention_mask"] = [1] * seq_length
if "position_ids" not in encoded_inputs:
encoded_inputs["position_ids"] = list(range(seq_length))
if needs_to_be_padded:
difference = max_length - len(required_input)
if "attention_mask" in encoded_inputs:
encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
if "position_ids" in encoded_inputs:
encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
return encoded_inputs
class KolorsPrompter(BasePrompter):
def __init__(
self,
tokenizer_path=None
):
if tokenizer_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_path = os.path.join(base_path, "tokenizer_configs/kolors/tokenizer")
super().__init__()
self.tokenizer = ChatGLMTokenizer.from_pretrained(tokenizer_path)
self.text_encoder: ChatGLMModel = None
def fetch_models(self, text_encoder: ChatGLMModel = None):
self.text_encoder = text_encoder
def encode_prompt_using_ChatGLM(self, prompt, text_encoder, tokenizer, max_length, clip_skip, device):
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
).to(device)
output = text_encoder(
input_ids=text_inputs['input_ids'] ,
attention_mask=text_inputs['attention_mask'],
position_ids=text_inputs['position_ids'],
output_hidden_states=True
)
prompt_emb = output.hidden_states[-clip_skip].permute(1, 0, 2).clone()
pooled_prompt_emb = output.hidden_states[-1][-1, :, :].clone()
return prompt_emb, pooled_prompt_emb
def encode_prompt(
self,
prompt,
clip_skip=1,
clip_skip_2=2,
positive=True,
device="cuda"
):
prompt = self.process_prompt(prompt, positive=positive)
prompt_emb, pooled_prompt_emb = self.encode_prompt_using_ChatGLM(prompt, self.text_encoder, self.tokenizer, 256, clip_skip_2, device)
return pooled_prompt_emb, prompt_emb

View File

@@ -1,323 +0,0 @@
from transformers import AutoTokenizer, TextIteratorStreamer
import difflib
import torch
import numpy as np
import re
from ..models.model_manager import ModelManager
from PIL import Image
valid_colors = { # r, g, b
'aliceblue': (240, 248, 255), 'antiquewhite': (250, 235, 215), 'aqua': (0, 255, 255),
'aquamarine': (127, 255, 212), 'azure': (240, 255, 255), 'beige': (245, 245, 220),
'bisque': (255, 228, 196), 'black': (0, 0, 0), 'blanchedalmond': (255, 235, 205), 'blue': (0, 0, 255),
'blueviolet': (138, 43, 226), 'brown': (165, 42, 42), 'burlywood': (222, 184, 135),
'cadetblue': (95, 158, 160), 'chartreuse': (127, 255, 0), 'chocolate': (210, 105, 30),
'coral': (255, 127, 80), 'cornflowerblue': (100, 149, 237), 'cornsilk': (255, 248, 220),
'crimson': (220, 20, 60), 'cyan': (0, 255, 255), 'darkblue': (0, 0, 139), 'darkcyan': (0, 139, 139),
'darkgoldenrod': (184, 134, 11), 'darkgray': (169, 169, 169), 'darkgrey': (169, 169, 169),
'darkgreen': (0, 100, 0), 'darkkhaki': (189, 183, 107), 'darkmagenta': (139, 0, 139),
'darkolivegreen': (85, 107, 47), 'darkorange': (255, 140, 0), 'darkorchid': (153, 50, 204),
'darkred': (139, 0, 0), 'darksalmon': (233, 150, 122), 'darkseagreen': (143, 188, 143),
'darkslateblue': (72, 61, 139), 'darkslategray': (47, 79, 79), 'darkslategrey': (47, 79, 79),
'darkturquoise': (0, 206, 209), 'darkviolet': (148, 0, 211), 'deeppink': (255, 20, 147),
'deepskyblue': (0, 191, 255), 'dimgray': (105, 105, 105), 'dimgrey': (105, 105, 105),
'dodgerblue': (30, 144, 255), 'firebrick': (178, 34, 34), 'floralwhite': (255, 250, 240),
'forestgreen': (34, 139, 34), 'fuchsia': (255, 0, 255), 'gainsboro': (220, 220, 220),
'ghostwhite': (248, 248, 255), 'gold': (255, 215, 0), 'goldenrod': (218, 165, 32),
'gray': (128, 128, 128), 'grey': (128, 128, 128), 'green': (0, 128, 0), 'greenyellow': (173, 255, 47),
'honeydew': (240, 255, 240), 'hotpink': (255, 105, 180), 'indianred': (205, 92, 92),
'indigo': (75, 0, 130), 'ivory': (255, 255, 240), 'khaki': (240, 230, 140), 'lavender': (230, 230, 250),
'lavenderblush': (255, 240, 245), 'lawngreen': (124, 252, 0), 'lemonchiffon': (255, 250, 205),
'lightblue': (173, 216, 230), 'lightcoral': (240, 128, 128), 'lightcyan': (224, 255, 255),
'lightgoldenrodyellow': (250, 250, 210), 'lightgray': (211, 211, 211), 'lightgrey': (211, 211, 211),
'lightgreen': (144, 238, 144), 'lightpink': (255, 182, 193), 'lightsalmon': (255, 160, 122),
'lightseagreen': (32, 178, 170), 'lightskyblue': (135, 206, 250), 'lightslategray': (119, 136, 153),
'lightslategrey': (119, 136, 153), 'lightsteelblue': (176, 196, 222), 'lightyellow': (255, 255, 224),
'lime': (0, 255, 0), 'limegreen': (50, 205, 50), 'linen': (250, 240, 230), 'magenta': (255, 0, 255),
'maroon': (128, 0, 0), 'mediumaquamarine': (102, 205, 170), 'mediumblue': (0, 0, 205),
'mediumorchid': (186, 85, 211), 'mediumpurple': (147, 112, 219), 'mediumseagreen': (60, 179, 113),
'mediumslateblue': (123, 104, 238), 'mediumspringgreen': (0, 250, 154),
'mediumturquoise': (72, 209, 204), 'mediumvioletred': (199, 21, 133), 'midnightblue': (25, 25, 112),
'mintcream': (245, 255, 250), 'mistyrose': (255, 228, 225), 'moccasin': (255, 228, 181),
'navajowhite': (255, 222, 173), 'navy': (0, 0, 128), 'navyblue': (0, 0, 128),
'oldlace': (253, 245, 230), 'olive': (128, 128, 0), 'olivedrab': (107, 142, 35),
'orange': (255, 165, 0), 'orangered': (255, 69, 0), 'orchid': (218, 112, 214),
'palegoldenrod': (238, 232, 170), 'palegreen': (152, 251, 152), 'paleturquoise': (175, 238, 238),
'palevioletred': (219, 112, 147), 'papayawhip': (255, 239, 213), 'peachpuff': (255, 218, 185),
'peru': (205, 133, 63), 'pink': (255, 192, 203), 'plum': (221, 160, 221), 'powderblue': (176, 224, 230),
'purple': (128, 0, 128), 'rebeccapurple': (102, 51, 153), 'red': (255, 0, 0),
'rosybrown': (188, 143, 143), 'royalblue': (65, 105, 225), 'saddlebrown': (139, 69, 19),
'salmon': (250, 128, 114), 'sandybrown': (244, 164, 96), 'seagreen': (46, 139, 87),
'seashell': (255, 245, 238), 'sienna': (160, 82, 45), 'silver': (192, 192, 192),
'skyblue': (135, 206, 235), 'slateblue': (106, 90, 205), 'slategray': (112, 128, 144),
'slategrey': (112, 128, 144), 'snow': (255, 250, 250), 'springgreen': (0, 255, 127),
'steelblue': (70, 130, 180), 'tan': (210, 180, 140), 'teal': (0, 128, 128), 'thistle': (216, 191, 216),
'tomato': (255, 99, 71), 'turquoise': (64, 224, 208), 'violet': (238, 130, 238),
'wheat': (245, 222, 179), 'white': (255, 255, 255), 'whitesmoke': (245, 245, 245),
'yellow': (255, 255, 0), 'yellowgreen': (154, 205, 50)
}
valid_locations = { # x, y in 90*90
'in the center': (45, 45),
'on the left': (15, 45),
'on the right': (75, 45),
'on the top': (45, 15),
'on the bottom': (45, 75),
'on the top-left': (15, 15),
'on the top-right': (75, 15),
'on the bottom-left': (15, 75),
'on the bottom-right': (75, 75)
}
valid_offsets = { # x, y in 90*90
'no offset': (0, 0),
'slightly to the left': (-10, 0),
'slightly to the right': (10, 0),
'slightly to the upper': (0, -10),
'slightly to the lower': (0, 10),
'slightly to the upper-left': (-10, -10),
'slightly to the upper-right': (10, -10),
'slightly to the lower-left': (-10, 10),
'slightly to the lower-right': (10, 10)}
valid_areas = { # w, h in 90*90
"a small square area": (50, 50),
"a small vertical area": (40, 60),
"a small horizontal area": (60, 40),
"a medium-sized square area": (60, 60),
"a medium-sized vertical area": (50, 80),
"a medium-sized horizontal area": (80, 50),
"a large square area": (70, 70),
"a large vertical area": (60, 90),
"a large horizontal area": (90, 60)
}
def safe_str(x):
return x.strip(',. ') + '.'
def closest_name(input_str, options):
input_str = input_str.lower()
closest_match = difflib.get_close_matches(input_str, list(options.keys()), n=1, cutoff=0.5)
assert isinstance(closest_match, list) and len(closest_match) > 0, f'The value [{input_str}] is not valid!'
result = closest_match[0]
if result != input_str:
print(f'Automatically corrected [{input_str}] -> [{result}].')
return result
class Canvas:
@staticmethod
def from_bot_response(response: str):
matched = re.search(r'```python\n(.*?)\n```', response, re.DOTALL)
assert matched, 'Response does not contain codes!'
code_content = matched.group(1)
assert 'canvas = Canvas()' in code_content, 'Code block must include valid canvas var!'
local_vars = {'Canvas': Canvas}
exec(code_content, {}, local_vars)
canvas = local_vars.get('canvas', None)
assert isinstance(canvas, Canvas), 'Code block must produce valid canvas var!'
return canvas
def __init__(self):
self.components = []
self.color = None
self.record_tags = True
self.prefixes = []
self.suffixes = []
return
def set_global_description(self, description: str, detailed_descriptions: list, tags: str,
HTML_web_color_name: str):
assert isinstance(description, str), 'Global description is not valid!'
assert isinstance(detailed_descriptions, list) and all(isinstance(item, str) for item in detailed_descriptions), \
'Global detailed_descriptions is not valid!'
assert isinstance(tags, str), 'Global tags is not valid!'
HTML_web_color_name = closest_name(HTML_web_color_name, valid_colors)
self.color = np.array([[valid_colors[HTML_web_color_name]]], dtype=np.uint8)
self.prefixes = [description]
self.suffixes = detailed_descriptions
if self.record_tags:
self.suffixes = self.suffixes + [tags]
self.prefixes = [safe_str(x) for x in self.prefixes]
self.suffixes = [safe_str(x) for x in self.suffixes]
return
def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str,
detailed_descriptions: list, tags: str, atmosphere: str, style: str,
quality_meta: str, HTML_web_color_name: str):
assert isinstance(description, str), 'Local description is wrong!'
assert isinstance(distance_to_viewer, (int, float)) and distance_to_viewer > 0, \
f'The distance_to_viewer for [{description}] is not positive float number!'
assert isinstance(detailed_descriptions, list) and all(isinstance(item, str) for item in detailed_descriptions), \
f'The detailed_descriptions for [{description}] is not valid!'
assert isinstance(tags, str), f'The tags for [{description}] is not valid!'
assert isinstance(atmosphere, str), f'The atmosphere for [{description}] is not valid!'
assert isinstance(style, str), f'The style for [{description}] is not valid!'
assert isinstance(quality_meta, str), f'The quality_meta for [{description}] is not valid!'
location = closest_name(location, valid_locations)
offset = closest_name(offset, valid_offsets)
area = closest_name(area, valid_areas)
HTML_web_color_name = closest_name(HTML_web_color_name, valid_colors)
xb, yb = valid_locations[location]
xo, yo = valid_offsets[offset]
w, h = valid_areas[area]
rect = (yb + yo - h // 2, yb + yo + h // 2, xb + xo - w // 2, xb + xo + w // 2)
rect = [max(0, min(90, i)) for i in rect]
color = np.array([[valid_colors[HTML_web_color_name]]], dtype=np.uint8)
prefixes = self.prefixes + [description]
suffixes = detailed_descriptions
if self.record_tags:
suffixes = suffixes + [tags, atmosphere, style, quality_meta]
prefixes = [safe_str(x) for x in prefixes]
suffixes = [safe_str(x) for x in suffixes]
self.components.append(dict(
rect=rect,
distance_to_viewer=distance_to_viewer,
color=color,
prefixes=prefixes,
suffixes=suffixes,
location=location,
))
return
def process(self):
# sort components
self.components = sorted(self.components, key=lambda x: x['distance_to_viewer'], reverse=True)
# compute initial latent
# print(self.color)
initial_latent = np.zeros(shape=(90, 90, 3), dtype=np.float32) + self.color
for component in self.components:
a, b, c, d = component['rect']
initial_latent[a:b, c:d] = 0.7 * component['color'] + 0.3 * initial_latent[a:b, c:d]
initial_latent = initial_latent.clip(0, 255).astype(np.uint8)
# compute conditions
bag_of_conditions = [
dict(mask=np.ones(shape=(90, 90), dtype=np.float32), prefixes=self.prefixes, suffixes=self.suffixes,location= "full")
]
for i, component in enumerate(self.components):
a, b, c, d = component['rect']
m = np.zeros(shape=(90, 90), dtype=np.float32)
m[a:b, c:d] = 1.0
bag_of_conditions.append(dict(
mask = m,
prefixes = component['prefixes'],
suffixes = component['suffixes'],
location = component['location'],
))
return dict(
initial_latent = initial_latent,
bag_of_conditions = bag_of_conditions,
)
class OmostPromter(torch.nn.Module):
def __init__(self,model = None,tokenizer = None, template = "",device="cpu"):
super().__init__()
self.model=model
self.tokenizer = tokenizer
self.device = device
if template == "":
template = r'''You are a helpful AI assistant to compose images using the below python class `Canvas`:
```python
class Canvas:
def set_global_description(self, description: str, detailed_descriptions: list[str], tags: str, HTML_web_color_name: str):
pass
def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str, detailed_descriptions: list[str], tags: str, atmosphere: str, style: str, quality_meta: str, HTML_web_color_name: str):
assert location in ["in the center", "on the left", "on the right", "on the top", "on the bottom", "on the top-left", "on the top-right", "on the bottom-left", "on the bottom-right"]
assert offset in ["no offset", "slightly to the left", "slightly to the right", "slightly to the upper", "slightly to the lower", "slightly to the upper-left", "slightly to the upper-right", "slightly to the lower-left", "slightly to the lower-right"]
assert area in ["a small square area", "a small vertical area", "a small horizontal area", "a medium-sized square area", "a medium-sized vertical area", "a medium-sized horizontal area", "a large square area", "a large vertical area", "a large horizontal area"]
assert distance_to_viewer > 0
pass
```'''
self.template = template
@staticmethod
def from_model_manager(model_manager: ModelManager):
model, model_path = model_manager.fetch_model("omost_prompt", require_model_path=True)
tokenizer = AutoTokenizer.from_pretrained(model_path)
omost = OmostPromter(
model= model,
tokenizer = tokenizer,
device = model_manager.device
)
return omost
def __call__(self,prompt_dict:dict):
raw_prompt=prompt_dict["prompt"]
conversation = [{"role": "system", "content": self.template}]
conversation.append({"role": "user", "content": raw_prompt})
input_ids = self.tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True).to(self.device)
streamer = TextIteratorStreamer(self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
attention_mask = torch.ones(input_ids.shape, dtype=torch.bfloat16, device=self.device)
generate_kwargs = dict(
input_ids = input_ids,
streamer = streamer,
# stopping_criteria=stopping_criteria,
# max_new_tokens=max_new_tokens,
do_sample = True,
attention_mask = attention_mask,
pad_token_id = self.tokenizer.eos_token_id,
# temperature=temperature,
# top_p=top_p,
)
self.model.generate(**generate_kwargs)
outputs = []
for text in streamer:
outputs.append(text)
llm_outputs = "".join(outputs)
canvas = Canvas.from_bot_response(llm_outputs)
canvas_output = canvas.process()
prompts = [" ".join(_["prefixes"]+_["suffixes"][:2]) for _ in canvas_output["bag_of_conditions"]]
canvas_output["prompt"] = prompts[0]
canvas_output["prompts"] = prompts[1:]
raw_masks = [_["mask"] for _ in canvas_output["bag_of_conditions"]]
masks=[]
for mask in raw_masks:
mask[mask>0.5]=255
mask = np.stack([mask] * 3, axis=-1).astype("uint8")
masks.append(Image.fromarray(mask))
canvas_output["masks"] = masks
prompt_dict.update(canvas_output)
print(f"Your prompt is extended by Omost:\n")
cnt = 0
for component,pmt in zip(canvas_output["bag_of_conditions"],prompts):
loc = component["location"]
cnt += 1
print(f"Component {cnt} - Location : {loc}\nPrompt:{pmt}\n")
return prompt_dict

View File

@@ -1,130 +0,0 @@
from transformers import AutoTokenizer
from ..models.model_manager import ModelManager
import torch
from .omost import OmostPromter
class BeautifulPrompt(torch.nn.Module):
def __init__(self, tokenizer_path=None, model=None, template=""):
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
self.model = model
self.template = template
@staticmethod
def from_model_manager(model_manager: ModelManager):
model, model_path = model_manager.fetch_model("beautiful_prompt", require_model_path=True)
template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'
if model_path.endswith("v2"):
template = """Converts a simple image description into a prompt. \
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
but make sure there is a correlation between the input and output.\n\
### Input: {raw_prompt}\n### Output:"""
beautiful_prompt = BeautifulPrompt(
tokenizer_path=model_path,
model=model,
template=template
)
return beautiful_prompt
def __call__(self, raw_prompt, positive=True, **kwargs):
if positive:
model_input = self.template.format(raw_prompt=raw_prompt)
input_ids = self.tokenizer.encode(model_input, return_tensors='pt').to(self.model.device)
outputs = self.model.generate(
input_ids,
max_new_tokens=384,
do_sample=True,
temperature=0.9,
top_k=50,
top_p=0.95,
repetition_penalty=1.1,
num_return_sequences=1
)
prompt = raw_prompt + ", " + self.tokenizer.batch_decode(
outputs[:, input_ids.size(1):],
skip_special_tokens=True
)[0].strip()
print(f"Your prompt is refined by BeautifulPrompt: {prompt}")
return prompt
else:
return raw_prompt
class QwenPrompt(torch.nn.Module):
# This class leverages the open-source Qwen model to translate Chinese prompts into English,
# with an integrated optimization mechanism for enhanced translation quality.
def __init__(self, tokenizer_path=None, model=None, system_prompt=""):
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
self.model = model
self.system_prompt = system_prompt
@staticmethod
def from_model_manager(model_nameger: ModelManager):
model, model_path = model_nameger.fetch_model("qwen_prompt", require_model_path=True)
system_prompt = """You are an English image describer. Here are some example image styles:\n\n1. Extreme close-up: Clear focus on a single object with a blurred background, highlighted under natural sunlight.\n2. Vintage: A photograph of a historical scene, using techniques such as Daguerreotype or cyanotype.\n3. Anime: A stylized cartoon image, emphasizing hyper-realistic portraits and luminous brushwork.\n4. Candid: A natural, unposed shot capturing spontaneous moments, often with cinematic qualities.\n5. Landscape: A photorealistic image of natural scenery, such as a sunrise over the sea.\n6. Design: Colorful and detailed illustrations, often in the style of 2D game art or botanical illustrations.\n7. Urban: An ultrarealistic scene in a modern setting, possibly a cityscape viewed from indoors.\n\nYour task is to translate a given Chinese image description into a concise and precise English description. Ensure that the imagery is vivid and descriptive, and include stylistic elements to enrich the description.\nPlease note the following points:\n\n1. Capture the essence and mood of the Chinese description without including direct phrases or words from the examples provided.\n2. You should add appropriate words to make the images described in the prompt more aesthetically pleasing. If the Chinese description does not specify a style, you need to add some stylistic descriptions based on the essence of the Chinese text.\n3. The generated English description should not exceed 200 words.\n\n"""
qwen_prompt = QwenPrompt(
tokenizer_path=model_path,
model=model,
system_prompt=system_prompt
)
return qwen_prompt
def __call__(self, raw_prompt, positive=True, **kwargs):
if positive:
messages = [{
'role': 'system',
'content': self.system_prompt
}, {
'role': 'user',
'content': raw_prompt
}]
text = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
generated_ids = self.model.generate(
model_inputs.input_ids,
max_new_tokens=512
)
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
prompt = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(f"Your prompt is refined by Qwen: {prompt}")
return prompt
else:
return raw_prompt
class Translator(torch.nn.Module):
def __init__(self, tokenizer_path=None, model=None):
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
self.model = model
@staticmethod
def from_model_manager(model_manager: ModelManager):
model, model_path = model_manager.fetch_model("translator", require_model_path=True)
translator = Translator(tokenizer_path=model_path, model=model)
return translator
def __call__(self, prompt, **kwargs):
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.model.device)
output_ids = self.model.generate(input_ids)
prompt = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
print(f"Your prompt is translated: {prompt}")
return prompt

View File

@@ -1,92 +0,0 @@
from .base_prompter import BasePrompter
from ..models.model_manager import ModelManager
from ..models import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
from transformers import CLIPTokenizer, T5TokenizerFast
import os, torch
class SD3Prompter(BasePrompter):
def __init__(
self,
tokenizer_1_path=None,
tokenizer_2_path=None,
tokenizer_3_path=None
):
if tokenizer_1_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_1_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_3/tokenizer_1")
if tokenizer_2_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_3/tokenizer_2")
if tokenizer_3_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_3_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_3/tokenizer_3")
super().__init__()
self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
self.tokenizer_3 = T5TokenizerFast.from_pretrained(tokenizer_3_path)
self.text_encoder_1: SD3TextEncoder1 = None
self.text_encoder_2: SD3TextEncoder2 = None
self.text_encoder_3: SD3TextEncoder3 = None
def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: SD3TextEncoder2 = None, text_encoder_3: SD3TextEncoder3 = None):
self.text_encoder_1 = text_encoder_1
self.text_encoder_2 = text_encoder_2
self.text_encoder_3 = text_encoder_3
def encode_prompt_using_clip(self, prompt, text_encoder, tokenizer, max_length, device):
input_ids = tokenizer(
prompt,
return_tensors="pt",
padding="max_length",
max_length=max_length,
truncation=True
).input_ids.to(device)
pooled_prompt_emb, prompt_emb = text_encoder(input_ids)
return pooled_prompt_emb, prompt_emb
def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_length, device):
input_ids = tokenizer(
prompt,
return_tensors="pt",
padding="max_length",
max_length=max_length,
truncation=True,
add_special_tokens=True,
).input_ids.to(device)
prompt_emb = text_encoder(input_ids)
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
return prompt_emb
def encode_prompt(
self,
prompt,
positive=True,
device="cuda"
):
prompt = self.process_prompt(prompt, positive=positive)
# CLIP
pooled_prompt_emb_1, prompt_emb_1 = self.encode_prompt_using_clip(prompt, self.text_encoder_1, self.tokenizer_1, 77, device)
pooled_prompt_emb_2, prompt_emb_2 = self.encode_prompt_using_clip(prompt, self.text_encoder_2, self.tokenizer_2, 77, device)
# T5
if self.text_encoder_3 is None:
prompt_emb_3 = torch.zeros((prompt_emb_1.shape[0], 256, 4096), dtype=prompt_emb_1.dtype, device=device)
else:
prompt_emb_3 = self.encode_prompt_using_t5(prompt, self.text_encoder_3, self.tokenizer_3, 256, device)
prompt_emb_3 = prompt_emb_3.to(prompt_emb_1.dtype) # float32 -> float16
# Merge
prompt_emb = torch.cat([
torch.nn.functional.pad(torch.cat([prompt_emb_1, prompt_emb_2], dim=-1), (0, 4096 - 768 - 1280)),
prompt_emb_3
], dim=-2)
pooled_prompt_emb = torch.cat([pooled_prompt_emb_1, pooled_prompt_emb_2], dim=-1)
return prompt_emb, pooled_prompt_emb

View File

@@ -1,73 +0,0 @@
from .base_prompter import BasePrompter, tokenize_long_prompt
from ..models.utils import load_state_dict, search_for_embeddings
from ..models import SDTextEncoder
from transformers import CLIPTokenizer
import torch, os
class SDPrompter(BasePrompter):
def __init__(self, tokenizer_path=None):
if tokenizer_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion/tokenizer")
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
self.text_encoder: SDTextEncoder = None
self.textual_inversion_dict = {}
self.keyword_dict = {}
def fetch_models(self, text_encoder: SDTextEncoder = None):
self.text_encoder = text_encoder
def add_textual_inversions_to_model(self, textual_inversion_dict, text_encoder):
dtype = next(iter(text_encoder.parameters())).dtype
state_dict = text_encoder.token_embedding.state_dict()
token_embeddings = [state_dict["weight"]]
for keyword in textual_inversion_dict:
_, embeddings = textual_inversion_dict[keyword]
token_embeddings.append(embeddings.to(dtype=dtype, device=token_embeddings[0].device))
token_embeddings = torch.concat(token_embeddings, dim=0)
state_dict["weight"] = token_embeddings
text_encoder.token_embedding = torch.nn.Embedding(token_embeddings.shape[0], token_embeddings.shape[1])
text_encoder.token_embedding = text_encoder.token_embedding.to(dtype=dtype, device=token_embeddings[0].device)
text_encoder.token_embedding.load_state_dict(state_dict)
def add_textual_inversions_to_tokenizer(self, textual_inversion_dict, tokenizer):
additional_tokens = []
for keyword in textual_inversion_dict:
tokens, _ = textual_inversion_dict[keyword]
additional_tokens += tokens
self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
tokenizer.add_tokens(additional_tokens)
def load_textual_inversions(self, model_paths):
for model_path in model_paths:
keyword = os.path.splitext(os.path.split(model_path)[-1])[0]
state_dict = load_state_dict(model_path)
# Search for embeddings
for embeddings in search_for_embeddings(state_dict):
if len(embeddings.shape) == 2 and embeddings.shape[1] == 768:
tokens = [f"{keyword}_{i}" for i in range(embeddings.shape[0])]
self.textual_inversion_dict[keyword] = (tokens, embeddings)
self.add_textual_inversions_to_model(self.textual_inversion_dict, self.text_encoder)
self.add_textual_inversions_to_tokenizer(self.textual_inversion_dict, self.tokenizer)
def encode_prompt(self, prompt, clip_skip=1, device="cuda", positive=True):
prompt = self.process_prompt(prompt, positive=positive)
for keyword in self.keyword_dict:
if keyword in prompt:
print(f"Textual inversion {keyword} is enabled.")
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
prompt_emb = self.text_encoder(input_ids, clip_skip=clip_skip)
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
return prompt_emb

View File

@@ -1,61 +0,0 @@
from .base_prompter import BasePrompter, tokenize_long_prompt
from ..models.model_manager import ModelManager
from ..models import SDXLTextEncoder, SDXLTextEncoder2
from transformers import CLIPTokenizer
import torch, os
class SDXLPrompter(BasePrompter):
def __init__(
self,
tokenizer_path=None,
tokenizer_2_path=None
):
if tokenizer_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion/tokenizer")
if tokenizer_2_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_xl/tokenizer_2")
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
self.text_encoder: SDXLTextEncoder = None
self.text_encoder_2: SDXLTextEncoder2 = None
def fetch_models(self, text_encoder: SDXLTextEncoder = None, text_encoder_2: SDXLTextEncoder2 = None):
self.text_encoder = text_encoder
self.text_encoder_2 = text_encoder_2
def encode_prompt(
self,
prompt,
clip_skip=1,
clip_skip_2=2,
positive=True,
device="cuda"
):
prompt = self.process_prompt(prompt, positive=positive)
# 1
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
prompt_emb_1 = self.text_encoder(input_ids, clip_skip=clip_skip)
# 2
input_ids_2 = tokenize_long_prompt(self.tokenizer_2, prompt).to(device)
add_text_embeds, prompt_emb_2 = self.text_encoder_2(input_ids_2, clip_skip=clip_skip_2)
# Merge
if prompt_emb_1.shape[0] != prompt_emb_2.shape[0]:
max_batch_size = min(prompt_emb_1.shape[0], prompt_emb_2.shape[0])
prompt_emb_1 = prompt_emb_1[: max_batch_size]
prompt_emb_2 = prompt_emb_2[: max_batch_size]
prompt_emb = torch.concatenate([prompt_emb_1, prompt_emb_2], dim=-1)
# For very long prompt, we only use the first 77 tokens to compute `add_text_embeds`.
add_text_embeds = add_text_embeds[0:1]
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
return add_text_embeds, prompt_emb

View File

@@ -0,0 +1,3 @@
from .sd_prompter import SDPrompter
from .sdxl_prompter import SDXLPrompter
from .hunyuan_dit_prompter import HunyuanDiTPrompter

View File

@@ -1,34 +1,19 @@
from .base_prompter import BasePrompter
from ..models.model_manager import ModelManager
from ..models import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from transformers import BertTokenizer, AutoTokenizer
import warnings, os
from .utils import Prompter
from transformers import BertModel, T5EncoderModel, BertTokenizer, AutoTokenizer
import warnings
class HunyuanDiTPrompter(BasePrompter):
class HunyuanDiTPrompter(Prompter):
def __init__(
self,
tokenizer_path=None,
tokenizer_t5_path=None
tokenizer_path="configs/hunyuan_dit/tokenizer",
tokenizer_t5_path="configs/hunyuan_dit/tokenizer_t5"
):
if tokenizer_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_path = os.path.join(base_path, "tokenizer_configs/hunyuan_dit/tokenizer")
if tokenizer_t5_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_t5_path = os.path.join(base_path, "tokenizer_configs/hunyuan_dit/tokenizer_t5")
super().__init__()
self.tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
self.tokenizer_t5 = AutoTokenizer.from_pretrained(tokenizer_t5_path)
self.text_encoder: HunyuanDiTCLIPTextEncoder = None
self.text_encoder_t5: HunyuanDiTT5TextEncoder = None
def fetch_models(self, text_encoder: HunyuanDiTCLIPTextEncoder = None, text_encoder_t5: HunyuanDiTT5TextEncoder = None):
self.text_encoder = text_encoder
self.text_encoder_t5 = text_encoder_t5
def encode_prompt_using_signle_model(self, prompt, text_encoder, tokenizer, max_length, clip_skip, device):
@@ -52,6 +37,8 @@ class HunyuanDiTPrompter(BasePrompter):
def encode_prompt(
self,
text_encoder: BertModel,
text_encoder_t5: T5EncoderModel,
prompt,
clip_skip=1,
clip_skip_2=1,
@@ -61,9 +48,9 @@ class HunyuanDiTPrompter(BasePrompter):
prompt = self.process_prompt(prompt, positive=positive)
# CLIP
prompt_emb, attention_mask = self.encode_prompt_using_signle_model(prompt, self.text_encoder, self.tokenizer, self.tokenizer.model_max_length, clip_skip, device)
prompt_emb, attention_mask = self.encode_prompt_using_signle_model(prompt, text_encoder, self.tokenizer, self.tokenizer.model_max_length, clip_skip, device)
# T5
prompt_emb_t5, attention_mask_t5 = self.encode_prompt_using_signle_model(prompt, self.text_encoder_t5, self.tokenizer_t5, self.tokenizer_t5.model_max_length, clip_skip_2, device)
prompt_emb_t5, attention_mask_t5 = self.encode_prompt_using_signle_model(prompt, text_encoder_t5, self.tokenizer_t5, self.tokenizer_t5.model_max_length, clip_skip_2, device)
return prompt_emb, attention_mask, prompt_emb_t5, attention_mask_t5

View File

@@ -0,0 +1,17 @@
from .utils import Prompter, tokenize_long_prompt
from transformers import CLIPTokenizer
from ..models import SDTextEncoder
class SDPrompter(Prompter):
def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True, max_length=99999999):
prompt = self.process_prompt(prompt, positive=positive)
input_ids = tokenize_long_prompt(self.tokenizer, prompt, max_length=max_length).to(device)
prompt_emb = text_encoder(input_ids, clip_skip=clip_skip)
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
return prompt_emb

View File

@@ -0,0 +1,43 @@
from .utils import Prompter, tokenize_long_prompt
from transformers import CLIPTokenizer
from ..models import SDXLTextEncoder, SDXLTextEncoder2
import torch
class SDXLPrompter(Prompter):
def __init__(
self,
tokenizer_path="configs/stable_diffusion/tokenizer",
tokenizer_2_path="configs/stable_diffusion_xl/tokenizer_2"
):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
def encode_prompt(
self,
text_encoder: SDXLTextEncoder,
text_encoder_2: SDXLTextEncoder2,
prompt,
clip_skip=1,
clip_skip_2=2,
positive=True,
device="cuda"
):
prompt = self.process_prompt(prompt, positive=positive)
# 1
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
prompt_emb_1 = text_encoder(input_ids, clip_skip=clip_skip)
# 2
input_ids_2 = tokenize_long_prompt(self.tokenizer_2, prompt).to(device)
add_text_embeds, prompt_emb_2 = text_encoder_2(input_ids_2, clip_skip=clip_skip_2)
# Merge
prompt_emb = torch.concatenate([prompt_emb_1, prompt_emb_2], dim=-1)
# For very long prompt, we only use the first 77 tokens to compute `add_text_embeds`.
add_text_embeds = add_text_embeds[0:1]
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
return add_text_embeds, prompt_emb

123
diffsynth/prompts/utils.py Normal file
View File

@@ -0,0 +1,123 @@
from transformers import CLIPTokenizer, AutoTokenizer
from ..models import ModelManager
import os
def tokenize_long_prompt(tokenizer, prompt, max_length=99999999):
# Get model_max_length from self.tokenizer
length = tokenizer.model_max_length
# To avoid the warning. set self.tokenizer.model_max_length to +oo.
tokenizer.model_max_length = max_length
# Tokenize it!
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
# Determine the real length.
max_length = (input_ids.shape[1] + length - 1) // length * length
# Restore tokenizer.model_max_length
tokenizer.model_max_length = length
# Tokenize it again with fixed length.
input_ids = tokenizer(
prompt,
return_tensors="pt",
padding="max_length",
max_length=max_length,
truncation=True
).input_ids
# Reshape input_ids to fit the text encoder.
num_sentence = input_ids.shape[1] // length
input_ids = input_ids.reshape((num_sentence, length))
return input_ids
class BeautifulPrompt:
def __init__(self, tokenizer_path="configs/beautiful_prompt/tokenizer", model=None):
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
self.model = model
self.template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'
def __call__(self, raw_prompt):
model_input = self.template.format(raw_prompt=raw_prompt)
input_ids = self.tokenizer.encode(model_input, return_tensors='pt').to(self.model.device)
outputs = self.model.generate(
input_ids,
max_new_tokens=384,
do_sample=True,
temperature=0.9,
top_k=50,
top_p=0.95,
repetition_penalty=1.1,
num_return_sequences=1
)
prompt = raw_prompt + ", " + self.tokenizer.batch_decode(
outputs[:, input_ids.size(1):],
skip_special_tokens=True
)[0].strip()
return prompt
class Translator:
def __init__(self, tokenizer_path="configs/translator/tokenizer", model=None):
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
self.model = model
def __call__(self, prompt):
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.model.device)
output_ids = self.model.generate(input_ids)
prompt = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
return prompt
class Prompter:
def __init__(self):
self.tokenizer: CLIPTokenizer = None
self.keyword_dict = {}
self.translator: Translator = None
self.beautiful_prompt: BeautifulPrompt = None
def load_textual_inversion(self, textual_inversion_dict):
self.keyword_dict = {}
additional_tokens = []
for keyword in textual_inversion_dict:
tokens, _ = textual_inversion_dict[keyword]
additional_tokens += tokens
self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
self.tokenizer.add_tokens(additional_tokens)
def load_beautiful_prompt(self, model, model_path):
model_folder = os.path.dirname(model_path)
self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model)
if model_folder.endswith("v2"):
self.beautiful_prompt.template = """Converts a simple image description into a prompt. \
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
but make sure there is a correlation between the input and output.\n\
### Input: {raw_prompt}\n### Output:"""
def load_translator(self, model, model_path):
model_folder = os.path.dirname(model_path)
self.translator = Translator(tokenizer_path=model_folder, model=model)
def load_from_model_manager(self, model_manager: ModelManager):
self.load_textual_inversion(model_manager.textual_inversion_dict)
if "translator" in model_manager.model:
self.load_translator(model_manager.model["translator"], model_manager.model_path["translator"])
if "beautiful_prompt" in model_manager.model:
self.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
def process_prompt(self, prompt, positive=True):
for keyword in self.keyword_dict:
if keyword in prompt:
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
if positive and self.translator is not None:
prompt = self.translator(prompt)
print(f"Your prompt is translated: \"{prompt}\"")
if positive and self.beautiful_prompt is not None:
prompt = self.beautiful_prompt(prompt)
print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
return prompt

Some files were not shown because too many files have changed in this diff Show More