mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 08:40:47 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a076adf592 |
29
.github/workflows/publish.yaml
vendored
29
.github/workflows/publish.yaml
vendored
@@ -1,29 +0,0 @@
|
||||
name: release
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v**'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-publish
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build-n-publish:
|
||||
runs-on: ubuntu-20.04
|
||||
#if: startsWith(github.event.ref, 'refs/tags')
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Install wheel
|
||||
run: pip install wheel && pip install -r requirements.txt
|
||||
- name: Build DiffSynth
|
||||
run: python setup.py sdist bdist_wheel
|
||||
- name: Publish package to PyPI
|
||||
run: |
|
||||
pip install twine
|
||||
twine upload dist/* --skip-existing -u __token__ -p ${{ secrets.PYPI_API_TOKEN }}
|
||||
267
ExVideo_animatediff_train.py
Normal file
267
ExVideo_animatediff_train.py
Normal file
@@ -0,0 +1,267 @@
|
||||
import torch, json, os, imageio
|
||||
from torchvision.transforms import v2
|
||||
from einops import rearrange
|
||||
import lightning as pl
|
||||
from diffsynth import ModelManager, EnhancedDDIMScheduler, SDVideoPipeline, SDUNet, load_state_dict, SDMotionModel
|
||||
|
||||
|
||||
|
||||
def lets_dance(
    unet: SDUNet,
    motion_modules: SDMotionModel,
    sample,
    timestep,
    encoder_hidden_states,
    use_gradient_checkpointing=False,
):
    """Run a Stable Diffusion UNet forward pass with AnimateDiff motion modules interleaved.

    Args:
        unet: the SD UNet whose blocks are executed in order.
        motion_modules: AnimateDiff modules; ``call_block_id`` maps a UNet
            block index to the motion module that runs right after it.
        sample: noisy latent tensor fed to ``unet.conv_in``.
        timestep: scalar diffusion timestep tensor.
        encoder_hidden_states: text-encoder embeddings (cross-attention context).
        use_gradient_checkpointing: if True, wrap every block call in
            ``torch.utils.checkpoint`` to trade compute for activation memory.

    Returns:
        The predicted noise tensor from ``unet.conv_out``.
    """
    def run_module(module, hidden_states, time_emb, text_emb, res_stack):
        # Single place that decides between checkpointed and direct execution,
        # so the UNet-block and motion-module paths cannot drift apart.
        if use_gradient_checkpointing:
            def custom_forward(*inputs):
                return module(*inputs)
            return torch.utils.checkpoint.checkpoint(
                custom_forward,
                hidden_states, time_emb, text_emb, res_stack,
                use_reentrant=False,
            )
        return module(hidden_states, time_emb, text_emb, res_stack)

    # 1. ControlNet (skip)
    # 2. time embedding
    time_emb = unet.time_proj(timestep[None]).to(sample.dtype)
    time_emb = unet.time_embedding(time_emb)

    # 3. pre-process
    hidden_states = unet.conv_in(sample)
    text_emb = encoder_hidden_states
    res_stack = [hidden_states]

    # 4. UNet blocks, each optionally followed by its AnimateDiff motion module
    for block_id, block in enumerate(unet.blocks):
        # 4.1 UNet
        hidden_states, time_emb, text_emb, res_stack = run_module(
            block, hidden_states, time_emb, text_emb, res_stack
        )
        # 4.2 AnimateDiff
        if block_id in motion_modules.call_block_id:
            motion_module_id = motion_modules.call_block_id[block_id]
            hidden_states, time_emb, text_emb, res_stack = run_module(
                motion_modules.motion_modules[motion_module_id],
                hidden_states, time_emb, text_emb, res_stack,
            )

    # 5. output projection
    hidden_states = unet.conv_norm_out(hidden_states)
    hidden_states = unet.conv_act(hidden_states)
    hidden_states = unet.conv_out(hidden_states)

    return hidden_states
|
||||
|
||||
|
||||
|
||||
class TextVideoDataset(torch.utils.data.Dataset):
    """Map-style dataset that samples (video clip, caption) pairs from disk.

    Each entry of ``training_shapes`` is a tuple
    ``(max_num_frames, interval, num_frames, height, width)``; one clip per
    shape is drawn for every ``__getitem__`` call. Unreadable or too-short
    videos are skipped and another video is sampled.
    """

    # NOTE: the default was a mutable list literal; an immutable tuple avoids
    # the shared-mutable-default pitfall while behaving identically.
    def __init__(self, base_path, metadata_path, steps_per_epoch=10000, training_shapes=((128, 1, 128, 512, 512),)):
        # metadata is a JSON list of {"path": ..., "text": ...} records.
        with open(metadata_path, "r") as f:
            metadata = json.load(f)
        self.path = [os.path.join(base_path, item["path"]) for item in metadata]
        self.text = [item["text"] for item in metadata]
        self.steps_per_epoch = steps_per_epoch
        self.training_shapes = training_shapes

        # One preprocessing pipeline per training shape: resize so the longer
        # target side fits, center-crop to (height, width), then map uint8
        # pixel values into roughly [-1, 1].
        self.frame_process = []
        for max_num_frames, interval, num_frames, height, width in training_shapes:
            self.frame_process.append(v2.Compose([
                v2.Resize(size=max(height, width), antialias=True),
                v2.CenterCrop(size=(height, width)),
                v2.Normalize(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5]),
            ]))


    def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
        """Read ``num_frames`` frames starting at ``start_frame_id`` with the given stride.

        Returns a ``C T H W`` float tensor, or ``None`` when the video is too
        short for the requested window.
        """
        reader = imageio.get_reader(file_path)
        # count_frames() can be expensive; call it once instead of twice.
        num_available = reader.count_frames()
        if num_available < max_num_frames or num_available - 1 < start_frame_id + (num_frames - 1) * interval:
            reader.close()
            return None

        # Close the reader even if decoding a frame raises, so file handles
        # don't leak when __getitem__ swallows the exception and resamples.
        try:
            frames = []
            for frame_id in range(num_frames):
                frame = reader.get_data(start_frame_id + frame_id * interval)
                frame = torch.tensor(frame, dtype=torch.float32)
                frame = rearrange(frame, "H W C -> 1 C H W")
                frame = frame_process(frame)
                frames.append(frame)
        finally:
            reader.close()

        frames = torch.concat(frames, dim=0)
        frames = rearrange(frames, "T C H W -> C T H W")

        return frames


    def load_video(self, file_path, training_shape_id):
        """Sample a random clip from ``file_path`` for one training shape.

        Returns a dict with ``frames_{id}`` and ``start_frame_id_{id}`` keys,
        or ``None`` if the video cannot supply the requested window.
        """
        data = {}
        max_num_frames, interval, num_frames, height, width = self.training_shapes[training_shape_id]
        frame_process = self.frame_process[training_shape_id]
        # Random start so different epochs see different windows of the video.
        start_frame_id = torch.randint(0, max_num_frames - (num_frames - 1) * interval, (1,))[0]
        frames = self.load_frames_using_imageio(file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process)
        if frames is None:
            return None
        else:
            data[f"frames_{training_shape_id}"] = frames
            data[f"start_frame_id_{training_shape_id}"] = start_frame_id
            return data


    def __getitem__(self, index):
        """Return one clip (plus caption) per training shape, resampling on failure."""
        video_data = {}
        for training_shape_id in range(len(self.training_shapes)):
            while True:
                data_id = torch.randint(0, len(self.path), (1,))[0]
                data_id = (data_id + index) % len(self.path) # For fixed seed.
                text = self.text[data_id]
                if isinstance(text, list):
                    # Multiple captions for one video: pick one at random.
                    text = text[torch.randint(0, len(text), (1,))[0]]
                video_file = self.path[data_id]
                try:
                    data = self.load_video(video_file, training_shape_id)
                except Exception:
                    # Broken/undecodable video: treat as invalid and resample.
                    # (Was a bare `except:`, which also swallowed
                    # KeyboardInterrupt/SystemExit.)
                    data = None
                if data is not None:
                    data[f"text_{training_shape_id}"] = text
                    break
            video_data.update(data)
        return video_data


    def __len__(self):
        # Epoch length is fixed, independent of the number of videos.
        return self.steps_per_epoch
|
||||
|
||||
|
||||
|
||||
class LightningModel(pl.LightningModule):
    """Trains AnimateDiff motion modules on top of a frozen Stable Diffusion pipeline."""

    def __init__(self, learning_rate=1e-5, sd_ckpt_path=None):
        super().__init__()
        # Load the base Stable Diffusion weights (fp16, CPU) and attach
        # freshly initialized motion modules.
        model_manager = ModelManager(torch_dtype=torch.float16, device="cpu")
        model_manager.load_stable_diffusion(load_state_dict(sd_ckpt_path))
        model_manager.model["motion_modules"] = SDMotionModel().to(dtype=self.dtype, device=self.device)

        # Assemble the video pipeline; freeze everything except the motion modules.
        self.pipe = SDVideoPipeline.from_model_manager(model_manager)
        for frozen_module in (self.pipe.vae_encoder, self.pipe.vae_decoder, self.pipe.text_encoder, self.pipe.unet):
            frozen_module.eval()
            frozen_module.requires_grad_(False)
        self.pipe.motion_modules.train()
        self.pipe.motion_modules.requires_grad_(True)

        # Replace the scheduler with a DDIM variant over the full 1000-step schedule.
        self.pipe.scheduler = EnhancedDDIMScheduler(beta_schedule="scaled_linear")
        self.pipe.scheduler.set_timesteps(1000)

        self.learning_rate = learning_rate


    def encode_video_with_vae(self, video):
        """Encode a C T H W video into per-frame latents (T C H W)."""
        batched = video.to(device=self.device, dtype=self.dtype).unsqueeze(0)
        latents = self.pipe.vae_encoder.encode_video(batched, batch_size=16)
        return rearrange(latents[0], "C T H W -> T C H W")


    def calculate_loss(self, prompt, frames):
        """Standard epsilon-prediction MSE loss for one (prompt, clip) pair."""
        with torch.no_grad():
            # Frozen encoders: video -> latents, prompt -> embeddings.
            latents = self.encode_video_with_vae(frames)
            prompt_embs = self.pipe.prompter.encode_prompt(self.pipe.text_encoder, prompt, device=self.device, max_length=77)
            prompt_embs = prompt_embs.repeat(latents.shape[0], 1, 1)

            # Sample one timestep and add the corresponding noise.
            timestep = torch.randint(0, len(self.pipe.scheduler.timesteps), (1,), device=self.device)[0]
            noise = torch.randn_like(latents)
            noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)

        # Predict the noise and compare against the true noise.
        model_pred = lets_dance(
            self.pipe.unet, self.pipe.motion_modules,
            sample=noisy_latents, encoder_hidden_states=prompt_embs, timestep=timestep
        )
        return torch.nn.functional.mse_loss(model_pred.float(), noise.float(), reduction="mean")


    def training_step(self, batch, batch_idx):
        # Only the first (and, here, only) training shape is used.
        clip = batch["frames_0"][0]
        caption = batch["text_0"][0]
        loss = self.calculate_loss(caption, clip)
        self.log("train_loss", loss, prog_bar=True)
        return loss


    def configure_optimizers(self):
        # Optimize only the motion modules; everything else is frozen.
        return torch.optim.AdamW(self.pipe.motion_modules.parameters(), lr=self.learning_rate)


    def on_save_checkpoint(self, checkpoint):
        # Record which parameter names were trainable so the motion-module
        # weights can be extracted from the full checkpoint later.
        checkpoint["trainable_param_names"] = [
            name
            for name, param in self.pipe.motion_modules.named_parameters()
            if param.requires_grad
        ]
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Data: 16-frame clips at 512x512, stride 1, fixed-length epochs.
    train_dataset = TextVideoDataset(
        "/data/zhongjie/datasets/opensoraplan/data/processed",
        "/data/zhongjie/datasets/opensoraplan/data/processed/metadata.json",
        training_shapes=[(16, 1, 16, 512, 512)],
        steps_per_epoch=7*10000,
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=1,
        num_workers=4,
    )

    # Model: motion modules trained on top of frozen SD 1.5 weights.
    model = LightningModel(
        learning_rate=1e-5,
        sd_ckpt_path="models/stable_diffusion/v1-5-pruned-emaonly.safetensors",
    )

    # Trainer: DeepSpeed ZeRO stage 1, fp16 mixed precision, keep every checkpoint.
    trainer = pl.Trainer(
        max_epochs=100000,
        accelerator="gpu",
        devices="auto",
        strategy="deepspeed_stage_1",
        precision="16-mixed",
        default_root_dir="/data/zhongjie/models/train_extended_animatediff",
        accumulate_grad_batches=1,
        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
    )
    trainer.fit(model=model, train_dataloaders=train_loader, ckpt_path=None)
|
||||
145
README.md
145
README.md
@@ -1,76 +1,67 @@
|
||||
# DiffSynth Studio
|
||||
|
||||
<p align="center">
|
||||
<a href="https://trendshift.io/repositories/10946" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10946" alt="modelscope%2FDiffSynth-Studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
</p>
|
||||
|
||||
## Introduction
|
||||
|
||||
DiffSynth Studio is a Diffusion engine. We have restructured architectures including Text Encoder, UNet, VAE, among others, maintaining compatibility with models from the open-source community while enhancing computational performance. We provide many interesting features. Enjoy the magic of Diffusion models!
|
||||
|
||||
Until now, DiffSynth Studio has supported the following models:
|
||||
|
||||
* [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
|
||||
* [Kolors](https://huggingface.co/Kwai-Kolors/Kolors)
|
||||
* [Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
|
||||
* [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt)
|
||||
* [Hunyuan-DiT](https://github.com/Tencent/HunyuanDiT)
|
||||
* [RIFE](https://github.com/hzwer/ECCV2022-RIFE)
|
||||
* [ESRGAN](https://github.com/xinntao/ESRGAN)
|
||||
* [Ip-Adapter](https://github.com/tencent-ailab/IP-Adapter)
|
||||
* [AnimateDiff](https://github.com/guoyww/animatediff/)
|
||||
* [ControlNet](https://github.com/lllyasviel/ControlNet)
|
||||
* [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
||||
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
|
||||
|
||||
## News
|
||||
|
||||
|
||||
- **June 21, 2024.** 🔥🔥🔥 We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
|
||||
- [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
|
||||
- Source code is released in this repo. See [`examples/ExVideo`](./examples/ExVideo/).
|
||||
- Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
|
||||
- Technical report is released on [arXiv](https://arxiv.org/abs/2406.14130).
|
||||
- You can try ExVideo in this [Demo](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1)!
|
||||
|
||||
- **June 13, 2024.** DiffSynth Studio is transferred to ModelScope. The developers have transitioned from "I" to "we". Of course, I will still participate in development and maintenance.
|
||||
|
||||
- **Jan 29, 2024.** We propose Diffutoon, a fantastic solution for toon shading.
|
||||
- [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
|
||||
- The source codes are released in this project.
|
||||
- The technical report (IJCAI 2024) is released on [arXiv](https://arxiv.org/abs/2401.16224).
|
||||
|
||||
- **Dec 8, 2023.** We decide to develop a new Project, aiming to release the potential of diffusion models, especially in video synthesis. The development of this project is started.
|
||||
|
||||
- **Nov 15, 2023.** We propose FastBlend, a powerful video deflickering algorithm.
|
||||
- The sd-webui extension is released on [GitHub](https://github.com/Artiprocher/sd-webui-fastblend).
|
||||
- Demo videos are shown on Bilibili, including three tasks.
|
||||
- [Video deflickering](https://www.bilibili.com/video/BV1d94y1W7PE)
|
||||
- [Video interpolation](https://www.bilibili.com/video/BV1Lw411m71p)
|
||||
- [Image-driven video rendering](https://www.bilibili.com/video/BV1RB4y1Z7LF)
|
||||
- The technical report is released on [arXiv](https://arxiv.org/abs/2311.09265).
|
||||
- An unofficial ComfyUI extension developed by other users is released on [GitHub](https://github.com/AInseven/ComfyUI-fastblend).
|
||||
|
||||
- **Oct 1, 2023.** We release an early version of this project, namely FastSDXL. A try for building a diffusion engine.
|
||||
- The source codes are released on [GitHub](https://github.com/Artiprocher/FastSDXL).
|
||||
- FastSDXL includes a trainable OLSS scheduler for efficiency improvement.
|
||||
- The original repo of OLSS is [here](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler).
|
||||
- The technical report (CIKM 2023) is released on [arXiv](https://arxiv.org/abs/2305.14677).
|
||||
- A demo video is shown on [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj).
|
||||
- Since OLSS requires additional training, we don't implement it in this project.
|
||||
|
||||
- **Aug 29, 2023.** We propose DiffSynth, a video synthesis framework.
|
||||
- [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/).
|
||||
- The source codes are released in [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth).
|
||||
- The technical report (ECML PKDD 2024) is released on [arXiv](https://arxiv.org/abs/2308.03463).
|
||||
## Roadmap
|
||||
|
||||
* Aug 29, 2023. We propose DiffSynth, a video synthesis framework.
|
||||
* [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/).
|
||||
* The source codes are released in [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth).
|
||||
* The technical report (ECML PKDD 2024) is released on [arXiv](https://arxiv.org/abs/2308.03463).
|
||||
* Oct 1, 2023. We release an early version of this project, namely FastSDXL. A try for building a diffusion engine.
|
||||
* The source codes are released on [GitHub](https://github.com/Artiprocher/FastSDXL).
|
||||
* FastSDXL includes a trainable OLSS scheduler for efficiency improvement.
|
||||
* The original repo of OLSS is [here](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler).
|
||||
* The technical report (CIKM 2023) is released on [arXiv](https://arxiv.org/abs/2305.14677).
|
||||
* A demo video is shown on [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj).
|
||||
* Since OLSS requires additional training, we don't implement it in this project.
|
||||
* Nov 15, 2023. We propose FastBlend, a powerful video deflickering algorithm.
|
||||
* The sd-webui extension is released on [GitHub](https://github.com/Artiprocher/sd-webui-fastblend).
|
||||
* Demo videos are shown on Bilibili, including three tasks.
|
||||
* [Video deflickering](https://www.bilibili.com/video/BV1d94y1W7PE)
|
||||
* [Video interpolation](https://www.bilibili.com/video/BV1Lw411m71p)
|
||||
* [Image-driven video rendering](https://www.bilibili.com/video/BV1RB4y1Z7LF)
|
||||
* The technical report is released on [arXiv](https://arxiv.org/abs/2311.09265).
|
||||
* An unofficial ComfyUI extension developed by other users is released on [GitHub](https://github.com/AInseven/ComfyUI-fastblend).
|
||||
* Dec 8, 2023. We decide to develop a new Project, aiming to release the potential of diffusion models, especially in video synthesis. The development of this project is started.
|
||||
* Jan 29, 2024. We propose Diffutoon, a fantastic solution for toon shading.
|
||||
* [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/).
|
||||
* The source codes are released in this project.
|
||||
* The technical report (IJCAI 2024) is released on [arXiv](https://arxiv.org/abs/2401.16224).
|
||||
* June 13, 2024. DiffSynth Studio is transferred to ModelScope. The developers have transitioned from "I" to "we". Of course, I will still participate in development and maintenance.
|
||||
* June 21, 2024. We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
|
||||
* [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/).
|
||||
* Source code is released in this repo. See [`examples/ExVideo`](./examples/ExVideo/).
|
||||
* Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
|
||||
* Technical report is released on [arXiv](https://arxiv.org/abs/2406.14130).
|
||||
* Until now, DiffSynth Studio has supported the following models:
|
||||
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
|
||||
* [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
||||
* [ControlNet](https://github.com/lllyasviel/ControlNet)
|
||||
* [AnimateDiff](https://github.com/guoyww/animatediff/)
|
||||
* [Ip-Adapter](https://github.com/tencent-ailab/IP-Adapter)
|
||||
* [ESRGAN](https://github.com/xinntao/ESRGAN)
|
||||
* [RIFE](https://github.com/hzwer/ECCV2022-RIFE)
|
||||
* [Hunyuan-DiT](https://github.com/Tencent/HunyuanDiT)
|
||||
* [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt)
|
||||
* [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
|
||||
|
||||
## Installation
|
||||
|
||||
Create Python environment:
|
||||
|
||||
```
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
conda env create -f environment.yml
|
||||
```
|
||||
|
||||
We find that sometimes `conda` cannot install `cupy` correctly, please install it manually. See [this document](https://docs.cupy.dev/en/stable/install.html) for more details.
|
||||
|
||||
Enter the Python environment:
|
||||
|
||||
```
|
||||
conda activate DiffSynthStudio
|
||||
```
|
||||
|
||||
## Usage (in Python code)
|
||||
@@ -85,17 +76,15 @@ https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5
|
||||
|
||||
### Image Synthesis
|
||||
|
||||
Generate high-resolution images, by breaking the limitation of diffusion models! [`examples/image_synthesis`](./examples/image_synthesis/).
|
||||
Generate high-resolution images, by breaking the limitation of diffusion models! [`examples/image_synthesis`](./examples/image_synthesis/)
|
||||
|
||||
LoRA fine-tuning is supported in [`examples/train`](./examples/train/).
|
||||
|512*512|1024*1024|2048*2048|4096*4096|
|
||||
|-|-|-|-|
|
||||
|||||
|
||||
|
||||
|Model|Example|
|
||||
|1024*1024|2048*2048|
|
||||
|-|-|
|
||||
|Stable Diffusion||
|
||||
|Stable Diffusion XL||
|
||||
|Stable Diffusion 3||
|
||||
|Kolors||
|
||||
|Hunyuan-DiT||
|
||||
|||
|
||||
|
||||
### Toon Shading
|
||||
|
||||
@@ -111,6 +100,22 @@ Video stylization without video models. [`examples/diffsynth`](./examples/diffsy
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
|
||||
|
||||
### Chinese Models
|
||||
|
||||
Use Hunyuan-DiT to generate images with Chinese prompts. We also support LoRA fine-tuning of this model. [`examples/hunyuan_dit`](./examples/hunyuan_dit/)
|
||||
|
||||
Prompt: 少女手捧鲜花,坐在公园的长椅上,夕阳的余晖洒在少女的脸庞,整个画面充满诗意的美感
|
||||
|
||||
|1024x1024|2048x2048 (highres-fix)|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
Prompt: 一只小狗蹦蹦跳跳,周围是姹紫嫣红的鲜花,远处是山脉
|
||||
|
||||
|Without LoRA|With LoRA|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
## Usage (in WebUI)
|
||||
|
||||
```
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from .data import *
|
||||
from .models import *
|
||||
from .prompters import *
|
||||
from .prompts import *
|
||||
from .schedulers import *
|
||||
from .pipelines import *
|
||||
from .controlnets import *
|
||||
|
||||
@@ -1,243 +0,0 @@
|
||||
from typing_extensions import Literal, TypeAlias
|
||||
|
||||
from ..models.sd_text_encoder import SDTextEncoder
|
||||
from ..models.sd_unet import SDUNet
|
||||
from ..models.sd_vae_encoder import SDVAEEncoder
|
||||
from ..models.sd_vae_decoder import SDVAEDecoder
|
||||
|
||||
from ..models.sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
||||
from ..models.sdxl_unet import SDXLUNet
|
||||
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
|
||||
from ..models.sdxl_vae_encoder import SDXLVAEEncoder
|
||||
|
||||
from ..models.sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
|
||||
from ..models.sd3_dit import SD3DiT
|
||||
from ..models.sd3_vae_decoder import SD3VAEDecoder
|
||||
from ..models.sd3_vae_encoder import SD3VAEEncoder
|
||||
|
||||
from ..models.sd_controlnet import SDControlNet
|
||||
|
||||
from ..models.sd_motion import SDMotionModel
|
||||
from ..models.sdxl_motion import SDXLMotionModel
|
||||
|
||||
from ..models.svd_image_encoder import SVDImageEncoder
|
||||
from ..models.svd_unet import SVDUNet
|
||||
from ..models.svd_vae_decoder import SVDVAEDecoder
|
||||
from ..models.svd_vae_encoder import SVDVAEEncoder
|
||||
|
||||
from ..models.sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
|
||||
from ..models.sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
||||
|
||||
from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
||||
from ..models.hunyuan_dit import HunyuanDiT
|
||||
|
||||
|
||||
|
||||
# Detection table for local checkpoint files: the hash of a checkpoint's
# state_dict keys identifies which model class(es) to build from it.
model_loader_configs = [
    # These configs are provided for detecting model type automatically.
    # The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
    (None, "091b0e30e77c76626b3ba62acdf95343", ["sd_controlnet"], [SDControlNet], "civitai"),
    (None, "4a6c8306a27d916dea81263c8c88f450", ["hunyuan_dit_clip_text_encoder"], [HunyuanDiTCLIPTextEncoder], "civitai"),
    (None, "f4aec400fe394297961218c768004521", ["hunyuan_dit"], [HunyuanDiT], "civitai"),
    (None, "9e6e58043a5a2e332803ed42f6ee7181", ["hunyuan_dit_t5_text_encoder"], [HunyuanDiTT5TextEncoder], "civitai"),
    (None, "13115dd45a6e1c39860f91ab073b8a78", ["sdxl_vae_encoder", "sdxl_vae_decoder"], [SDXLVAEEncoder, SDXLVAEDecoder], "diffusers"),
    (None, "d78aa6797382a6d455362358a3295ea9", ["sd_ipadapter_clip_image_encoder"], [IpAdapterCLIPImageEmbedder], "diffusers"),
    (None, "e291636cc15e803186b47404262ef812", ["sd_ipadapter"], [SDIpAdapter], "civitai"),
    (None, "399c81f2f8de8d1843d0127a00f3c224", ["sdxl_ipadapter_clip_image_encoder"], [IpAdapterXLCLIPImageEmbedder], "diffusers"),
    (None, "a64eac9aa0db4b9602213bc0131281c7", ["sdxl_ipadapter"], [SDXLIpAdapter], "civitai"),
    (None, "52817e4fdd89df154f02749ca6f692ac", ["sdxl_unet"], [SDXLUNet], "diffusers"),
    # A single checkpoint may contain several sub-models (e.g. a full SD bundle).
    (None, "03343c606f16d834d6411d0902b53636", ["sd_text_encoder", "sd_unet", "sd_vae_decoder", "sd_vae_encoder"], [SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder], "civitai"),
    (None, "d4ba77a7ece070679b4a987f58f201e9", ["sd_text_encoder"], [SDTextEncoder], "civitai"),
    (None, "d0c89e55c5a57cf3981def0cb1c9e65a", ["sd_vae_decoder", "sd_vae_encoder"], [SDVAEDecoder, SDVAEEncoder], "civitai"),
    (None, "3926bf373b39a67eeafd7901478a47a7", ["sd_unet"], [SDUNet], "civitai"),
    (None, "1e0c39ec176b9007c05f76d52b554a4d", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
    (None, "d9e0290829ba8d98e28e1a2b1407db4a", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
    (None, "5072d0b24e406b49507abe861cf97691", ["sd3_text_encoder_3"], [SD3TextEncoder3], "civitai"),
    (None, "4cf64a799d04260df438c6f33c9a047e", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"),
    (None, "d9b008a867c498ab12ad24042eff8e3f", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"), # SDXL-Turbo
    (None, "025bb7452e531a3853d951d77c63f032", ["sdxl_text_encoder", "sdxl_text_encoder_2"], [SDXLTextEncoder, SDXLTextEncoder2], "civitai"),
    (None, "298997b403a4245c04102c9f36aac348", ["sdxl_unet"], [SDXLUNet], "civitai"),
    (None, "2a07abce74b4bdc696b76254ab474da6", ["svd_image_encoder", "svd_unet", "svd_vae_decoder", "svd_vae_encoder"], [SVDImageEncoder, SVDUNet, SVDVAEDecoder, SVDVAEEncoder], "civitai"),
    (None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
    (None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
]
# Detection table for models loaded via a HuggingFace-style library, keyed on
# the `architectures` entry of the repo's config.json.
huggingface_model_loader_configs = [
    # These configs are provided for detecting model type automatically.
    # The format is (architecture_in_huggingface_config, huggingface_lib, model_name)
    ("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder"),
    ("MarianMTModel", "transformers.models.marian.modeling_marian", "translator"),
    ("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt"),
]
# Detection table for checkpoints that patch an existing model class with
# extra constructor kwargs (e.g. ExVideo's extended SVD UNet).
patch_model_loader_configs = [
    # These configs are provided for detecting model type automatically.
    # The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
    ("9a4ab6869ac9b7d6e31f9854e397c867", ["svd_unet"], [SVDUNet], {"add_positional_conv": 128}),
]
|
||||
|
||||
preset_models_on_huggingface = {
|
||||
"HunyuanDiT": [
|
||||
("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
|
||||
("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
|
||||
("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
|
||||
("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
|
||||
],
|
||||
"stable-video-diffusion-img2vid-xt": [
|
||||
("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
|
||||
],
|
||||
"ExVideo-SVD-128f-v1": [
|
||||
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
|
||||
],
|
||||
}
|
||||
preset_models_on_modelscope = {
|
||||
# Hunyuan DiT
|
||||
"HunyuanDiT": [
|
||||
("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
|
||||
("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
|
||||
("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
|
||||
("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
|
||||
],
|
||||
# Stable Video Diffusion
|
||||
"stable-video-diffusion-img2vid-xt": [
|
||||
("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
|
||||
],
|
||||
# ExVideo
|
||||
"ExVideo-SVD-128f-v1": [
|
||||
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
|
||||
],
|
||||
# Stable Diffusion
|
||||
"StableDiffusion_v15": [
|
||||
("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
|
||||
],
|
||||
"DreamShaper_8": [
|
||||
("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
|
||||
],
|
||||
"AingDiffusion_v12": [
|
||||
("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
|
||||
],
|
||||
"Flat2DAnimerge_v45Sharp": [
|
||||
("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
|
||||
],
|
||||
# Textual Inversion
|
||||
"TextualInversion_VeryBadImageNegative_v1.3": [
|
||||
("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
|
||||
],
|
||||
# Stable Diffusion XL
|
||||
"StableDiffusionXL_v1": [
|
||||
("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
|
||||
],
|
||||
"BluePencilXL_v200": [
|
||||
("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
|
||||
],
|
||||
"StableDiffusionXL_Turbo": [
|
||||
("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
|
||||
],
|
||||
# Stable Diffusion 3
|
||||
"StableDiffusion3": [
|
||||
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
|
||||
],
|
||||
"StableDiffusion3_without_T5": [
|
||||
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
|
||||
],
|
||||
# ControlNet
|
||||
"ControlNet_v11f1p_sd15_depth": [
|
||||
("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
|
||||
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
||||
],
|
||||
"ControlNet_v11p_sd15_softedge": [
|
||||
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
|
||||
("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
|
||||
],
|
||||
"ControlNet_v11f1e_sd15_tile": [
|
||||
("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
|
||||
],
|
||||
"ControlNet_v11p_sd15_lineart": [
|
||||
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
|
||||
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
|
||||
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
|
||||
],
|
||||
# AnimateDiff
|
||||
"AnimateDiff_v2": [
|
||||
("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
|
||||
],
|
||||
"AnimateDiff_xl_beta": [
|
||||
("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
|
||||
],
|
||||
# RIFE
|
||||
"RIFE": [
|
||||
("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
|
||||
],
|
||||
# Beautiful Prompt
|
||||
"BeautifulPrompt": [
|
||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
],
|
||||
# Translator
|
||||
"opus-mt-zh-en": [
|
||||
("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
|
||||
],
|
||||
# IP-Adapter
|
||||
"IP-Adapter-SD": [
|
||||
("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
|
||||
("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
|
||||
],
|
||||
"IP-Adapter-SDXL": [
|
||||
("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
|
||||
("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
|
||||
],
|
||||
# Kolors
|
||||
"Kolors": [
|
||||
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
|
||||
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
|
||||
],
|
||||
"SDXL-vae-fp16-fix": [
|
||||
("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
|
||||
],
|
||||
}
|
||||
Preset_model_id: TypeAlias = Literal[
|
||||
"HunyuanDiT",
|
||||
"stable-video-diffusion-img2vid-xt",
|
||||
"ExVideo-SVD-128f-v1",
|
||||
"StableDiffusion_v15",
|
||||
"DreamShaper_8",
|
||||
"AingDiffusion_v12",
|
||||
"Flat2DAnimerge_v45Sharp",
|
||||
"TextualInversion_VeryBadImageNegative_v1.3",
|
||||
"StableDiffusionXL_v1",
|
||||
"BluePencilXL_v200",
|
||||
"StableDiffusionXL_Turbo",
|
||||
"ControlNet_v11f1p_sd15_depth",
|
||||
"ControlNet_v11p_sd15_softedge",
|
||||
"ControlNet_v11f1e_sd15_tile",
|
||||
"ControlNet_v11p_sd15_lineart",
|
||||
"AnimateDiff_v2",
|
||||
"AnimateDiff_xl_beta",
|
||||
"RIFE",
|
||||
"BeautifulPrompt",
|
||||
"opus-mt-zh-en",
|
||||
"IP-Adapter-SD",
|
||||
"IP-Adapter-SDXL",
|
||||
"StableDiffusion3",
|
||||
"StableDiffusion3_without_T5",
|
||||
"Kolors",
|
||||
"SDXL-vae-fp16-fix",
|
||||
]
|
||||
@@ -12,24 +12,24 @@ Processor_id: TypeAlias = Literal[
|
||||
]
|
||||
|
||||
class Annotator:
|
||||
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda'):
|
||||
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None):
|
||||
if processor_id == "canny":
|
||||
self.processor = CannyDetector()
|
||||
elif processor_id == "depth":
|
||||
self.processor = MidasDetector.from_pretrained(model_path).to(device)
|
||||
self.processor = MidasDetector.from_pretrained(model_path).to("cuda")
|
||||
elif processor_id == "softedge":
|
||||
self.processor = HEDdetector.from_pretrained(model_path).to(device)
|
||||
self.processor = HEDdetector.from_pretrained(model_path).to("cuda")
|
||||
elif processor_id == "lineart":
|
||||
self.processor = LineartDetector.from_pretrained(model_path).to(device)
|
||||
self.processor = LineartDetector.from_pretrained(model_path).to("cuda")
|
||||
elif processor_id == "lineart_anime":
|
||||
self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
|
||||
self.processor = LineartAnimeDetector.from_pretrained(model_path).to("cuda")
|
||||
elif processor_id == "openpose":
|
||||
self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
|
||||
self.processor = OpenposeDetector.from_pretrained(model_path).to("cuda")
|
||||
elif processor_id == "tile":
|
||||
self.processor = None
|
||||
else:
|
||||
raise ValueError(f"Unsupported processor_id: {processor_id}")
|
||||
|
||||
|
||||
self.processor_id = processor_id
|
||||
self.detect_resolution = detect_resolution
|
||||
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
import torch, os
|
||||
from torchvision import transforms
|
||||
import pandas as pd
|
||||
from PIL import Image
|
||||
|
||||
|
||||
|
||||
class TextImageDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
|
||||
self.steps_per_epoch = steps_per_epoch
|
||||
metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
|
||||
self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
|
||||
self.text = metadata["text"].to_list()
|
||||
self.image_processor = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(max(height, width), interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
|
||||
transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
data_id = torch.randint(0, len(self.path), (1,))[0]
|
||||
data_id = (data_id + index) % len(self.path) # For fixed seed.
|
||||
text = self.text[data_id]
|
||||
image = Image.open(self.path[data_id]).convert("RGB")
|
||||
image = self.image_processor(image)
|
||||
return {"text": text, "image": image}
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return self.steps_per_epoch
|
||||
@@ -99,8 +99,7 @@ class IFNet(nn.Module):
|
||||
merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
|
||||
return flow_list, mask_list[2], merged
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
def state_dict_converter(self):
|
||||
return IFNetStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -1 +1,482 @@
|
||||
from .model_manager import *
|
||||
import torch, os
|
||||
from safetensors import safe_open
|
||||
|
||||
from .sd_text_encoder import SDTextEncoder
|
||||
from .sd_unet import SDUNet
|
||||
from .sd_vae_encoder import SDVAEEncoder
|
||||
from .sd_vae_decoder import SDVAEDecoder
|
||||
from .sd_lora import SDLoRA
|
||||
|
||||
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
||||
from .sdxl_unet import SDXLUNet
|
||||
from .sdxl_vae_decoder import SDXLVAEDecoder
|
||||
from .sdxl_vae_encoder import SDXLVAEEncoder
|
||||
|
||||
from .sd_controlnet import SDControlNet
|
||||
|
||||
from .sd_motion import SDMotionModel
|
||||
from .sdxl_motion import SDXLMotionModel
|
||||
|
||||
from .svd_image_encoder import SVDImageEncoder
|
||||
from .svd_unet import SVDUNet
|
||||
from .svd_vae_decoder import SVDVAEDecoder
|
||||
from .svd_vae_encoder import SVDVAEEncoder
|
||||
|
||||
from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
|
||||
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
||||
|
||||
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
||||
from .hunyuan_dit import HunyuanDiT
|
||||
|
||||
|
||||
class ModelManager:
|
||||
def __init__(self, torch_dtype=torch.float16, device="cuda"):
|
||||
self.torch_dtype = torch_dtype
|
||||
self.device = device
|
||||
self.model = {}
|
||||
self.model_path = {}
|
||||
self.textual_inversion_dict = {}
|
||||
|
||||
def is_stable_video_diffusion(self, state_dict):
|
||||
param_name = "model.diffusion_model.output_blocks.9.1.time_stack.0.norm_in.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_RIFE(self, state_dict):
|
||||
param_name = "block_tea.convblock3.0.1.weight"
|
||||
return param_name in state_dict or ("module." + param_name) in state_dict
|
||||
|
||||
def is_beautiful_prompt(self, state_dict):
|
||||
param_name = "transformer.h.9.self_attention.query_key_value.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_stabe_diffusion_xl(self, state_dict):
|
||||
param_name = "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_stable_diffusion(self, state_dict):
|
||||
if self.is_stabe_diffusion_xl(state_dict):
|
||||
return False
|
||||
param_name = "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_controlnet(self, state_dict):
|
||||
param_name = "control_model.time_embed.0.weight"
|
||||
param_name_2 = "mid_block.resnets.1.time_emb_proj.weight" # For controlnets in diffusers format
|
||||
return param_name in state_dict or param_name_2 in state_dict
|
||||
|
||||
def is_animatediff(self, state_dict):
|
||||
param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_animatediff_xl(self, state_dict):
|
||||
param_name = "up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.ff_norm.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_sd_lora(self, state_dict):
|
||||
param_name = "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_ff_net_2.lora_up.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_translator(self, state_dict):
|
||||
param_name = "model.encoder.layers.5.self_attn_layer_norm.weight"
|
||||
return param_name in state_dict and len(state_dict) == 254
|
||||
|
||||
def is_ipadapter(self, state_dict):
|
||||
return "image_proj" in state_dict and "ip_adapter" in state_dict and state_dict["image_proj"]["proj.weight"].shape == torch.Size([3072, 1024])
|
||||
|
||||
def is_ipadapter_image_encoder(self, state_dict):
|
||||
param_name = "vision_model.encoder.layers.31.self_attn.v_proj.weight"
|
||||
return param_name in state_dict and len(state_dict) == 521
|
||||
|
||||
def is_ipadapter_xl(self, state_dict):
|
||||
return "image_proj" in state_dict and "ip_adapter" in state_dict and state_dict["image_proj"]["proj.weight"].shape == torch.Size([8192, 1280])
|
||||
|
||||
def is_ipadapter_xl_image_encoder(self, state_dict):
|
||||
param_name = "vision_model.encoder.layers.47.self_attn.v_proj.weight"
|
||||
return param_name in state_dict and len(state_dict) == 777
|
||||
|
||||
def is_hunyuan_dit_clip_text_encoder(self, state_dict):
|
||||
param_name = "bert.encoder.layer.23.attention.output.dense.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_hunyuan_dit_t5_text_encoder(self, state_dict):
|
||||
param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_hunyuan_dit(self, state_dict):
|
||||
param_name = "final_layer.adaLN_modulation.1.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_diffusers_vae(self, state_dict):
|
||||
param_name = "quant_conv.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_ExVideo_StableVideoDiffusion(self, state_dict):
|
||||
param_name = "blocks.185.positional_embedding.embeddings"
|
||||
return param_name in state_dict
|
||||
|
||||
def load_stable_video_diffusion(self, state_dict, components=None, file_path="", add_positional_conv=None):
|
||||
component_dict = {
|
||||
"image_encoder": SVDImageEncoder,
|
||||
"unet": SVDUNet,
|
||||
"vae_decoder": SVDVAEDecoder,
|
||||
"vae_encoder": SVDVAEEncoder,
|
||||
}
|
||||
if components is None:
|
||||
components = ["image_encoder", "unet", "vae_decoder", "vae_encoder"]
|
||||
for component in components:
|
||||
if component == "unet":
|
||||
self.model[component] = component_dict[component](add_positional_conv=add_positional_conv)
|
||||
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict, add_positional_conv=add_positional_conv), strict=False)
|
||||
else:
|
||||
self.model[component] = component_dict[component]()
|
||||
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
||||
self.model[component].to(self.torch_dtype).to(self.device)
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_stable_diffusion(self, state_dict, components=None, file_path=""):
|
||||
component_dict = {
|
||||
"text_encoder": SDTextEncoder,
|
||||
"unet": SDUNet,
|
||||
"vae_decoder": SDVAEDecoder,
|
||||
"vae_encoder": SDVAEEncoder,
|
||||
"refiner": SDXLUNet,
|
||||
}
|
||||
if components is None:
|
||||
components = ["text_encoder", "unet", "vae_decoder", "vae_encoder"]
|
||||
for component in components:
|
||||
if component == "text_encoder":
|
||||
# Add additional token embeddings to text encoder
|
||||
token_embeddings = [state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"]]
|
||||
for keyword in self.textual_inversion_dict:
|
||||
_, embeddings = self.textual_inversion_dict[keyword]
|
||||
token_embeddings.append(embeddings.to(dtype=token_embeddings[0].dtype))
|
||||
token_embeddings = torch.concat(token_embeddings, dim=0)
|
||||
state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"] = token_embeddings
|
||||
self.model[component] = component_dict[component](vocab_size=token_embeddings.shape[0])
|
||||
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
||||
self.model[component].to(self.torch_dtype).to(self.device)
|
||||
else:
|
||||
self.model[component] = component_dict[component]()
|
||||
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
||||
self.model[component].to(self.torch_dtype).to(self.device)
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_stable_diffusion_xl(self, state_dict, components=None, file_path=""):
|
||||
component_dict = {
|
||||
"text_encoder": SDXLTextEncoder,
|
||||
"text_encoder_2": SDXLTextEncoder2,
|
||||
"unet": SDXLUNet,
|
||||
"vae_decoder": SDXLVAEDecoder,
|
||||
"vae_encoder": SDXLVAEEncoder,
|
||||
}
|
||||
if components is None:
|
||||
components = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder"]
|
||||
for component in components:
|
||||
self.model[component] = component_dict[component]()
|
||||
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
||||
if component in ["vae_decoder", "vae_encoder"]:
|
||||
# These two model will output nan when float16 is enabled.
|
||||
# The precision problem happens in the last three resnet blocks.
|
||||
# I do not know how to solve this problem.
|
||||
self.model[component].to(torch.float32).to(self.device)
|
||||
else:
|
||||
self.model[component].to(self.torch_dtype).to(self.device)
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_controlnet(self, state_dict, file_path=""):
|
||||
component = "controlnet"
|
||||
if component not in self.model:
|
||||
self.model[component] = []
|
||||
self.model_path[component] = []
|
||||
model = SDControlNet()
|
||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component].append(model)
|
||||
self.model_path[component].append(file_path)
|
||||
|
||||
def load_animatediff(self, state_dict, file_path="", add_positional_conv=None):
|
||||
component = "motion_modules"
|
||||
model = SDMotionModel(add_positional_conv=add_positional_conv)
|
||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict, add_positional_conv=add_positional_conv))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_animatediff_xl(self, state_dict, file_path=""):
|
||||
component = "motion_modules_xl"
|
||||
model = SDXLMotionModel()
|
||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_beautiful_prompt(self, state_dict, file_path=""):
|
||||
component = "beautiful_prompt"
|
||||
from transformers import AutoModelForCausalLM
|
||||
model_folder = os.path.dirname(file_path)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_folder, state_dict=state_dict, local_files_only=True, torch_dtype=self.torch_dtype
|
||||
).to(self.device).eval()
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_RIFE(self, state_dict, file_path=""):
|
||||
component = "RIFE"
|
||||
from ..extensions.RIFE import IFNet
|
||||
model = IFNet().eval()
|
||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
||||
model.to(torch.float32).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_sd_lora(self, state_dict, alpha):
|
||||
SDLoRA().add_lora_to_text_encoder(self.model["text_encoder"], state_dict, alpha=alpha, device=self.device)
|
||||
SDLoRA().add_lora_to_unet(self.model["unet"], state_dict, alpha=alpha, device=self.device)
|
||||
|
||||
def load_translator(self, state_dict, file_path=""):
|
||||
# This model is lightweight, we do not place it on GPU.
|
||||
component = "translator"
|
||||
from transformers import AutoModelForSeq2SeqLM
|
||||
model_folder = os.path.dirname(file_path)
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(model_folder).eval()
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_ipadapter(self, state_dict, file_path=""):
|
||||
component = "ipadapter"
|
||||
model = SDIpAdapter()
|
||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_ipadapter_image_encoder(self, state_dict, file_path=""):
|
||||
component = "ipadapter_image_encoder"
|
||||
model = IpAdapterCLIPImageEmbedder()
|
||||
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_ipadapter_xl(self, state_dict, file_path=""):
|
||||
component = "ipadapter_xl"
|
||||
model = SDXLIpAdapter()
|
||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_ipadapter_xl_image_encoder(self, state_dict, file_path=""):
|
||||
component = "ipadapter_xl_image_encoder"
|
||||
model = IpAdapterXLCLIPImageEmbedder()
|
||||
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_hunyuan_dit_clip_text_encoder(self, state_dict, file_path=""):
|
||||
component = "hunyuan_dit_clip_text_encoder"
|
||||
model = HunyuanDiTCLIPTextEncoder()
|
||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_hunyuan_dit_t5_text_encoder(self, state_dict, file_path=""):
|
||||
component = "hunyuan_dit_t5_text_encoder"
|
||||
model = HunyuanDiTT5TextEncoder()
|
||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_hunyuan_dit(self, state_dict, file_path=""):
|
||||
component = "hunyuan_dit"
|
||||
model = HunyuanDiT()
|
||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_diffusers_vae(self, state_dict, file_path=""):
|
||||
# TODO: detect SD and SDXL
|
||||
component = "vae_encoder"
|
||||
model = SDXLVAEEncoder()
|
||||
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
component = "vae_decoder"
|
||||
model = SDXLVAEDecoder()
|
||||
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_ExVideo_StableVideoDiffusion(self, state_dict, file_path=""):
|
||||
unet_state_dict = self.model["unet"].state_dict()
|
||||
self.model["unet"].to("cpu")
|
||||
del self.model["unet"]
|
||||
add_positional_conv = state_dict["blocks.185.positional_embedding.embeddings"].shape[0]
|
||||
self.model["unet"] = SVDUNet(add_positional_conv=add_positional_conv)
|
||||
self.model["unet"].load_state_dict(unet_state_dict, strict=False)
|
||||
self.model["unet"].load_state_dict(state_dict, strict=False)
|
||||
self.model["unet"].to(self.torch_dtype).to(self.device)
|
||||
|
||||
def search_for_embeddings(self, state_dict):
|
||||
embeddings = []
|
||||
for k in state_dict:
|
||||
if isinstance(state_dict[k], torch.Tensor):
|
||||
embeddings.append(state_dict[k])
|
||||
elif isinstance(state_dict[k], dict):
|
||||
embeddings += self.search_for_embeddings(state_dict[k])
|
||||
return embeddings
|
||||
|
||||
def load_textual_inversions(self, folder):
|
||||
# Store additional tokens here
|
||||
self.textual_inversion_dict = {}
|
||||
|
||||
# Load every textual inversion file
|
||||
for file_name in os.listdir(folder):
|
||||
if file_name.endswith(".txt"):
|
||||
continue
|
||||
keyword = os.path.splitext(file_name)[0]
|
||||
state_dict = load_state_dict(os.path.join(folder, file_name))
|
||||
|
||||
# Search for embeddings
|
||||
for embeddings in self.search_for_embeddings(state_dict):
|
||||
if len(embeddings.shape) == 2 and embeddings.shape[1] == 768:
|
||||
tokens = [f"{keyword}_{i}" for i in range(embeddings.shape[0])]
|
||||
self.textual_inversion_dict[keyword] = (tokens, embeddings)
|
||||
break
|
||||
|
||||
def load_model(self, file_path, components=None, lora_alphas=[]):
|
||||
state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype)
|
||||
if self.is_stable_video_diffusion(state_dict):
|
||||
self.load_stable_video_diffusion(state_dict, file_path=file_path)
|
||||
elif self.is_animatediff(state_dict):
|
||||
self.load_animatediff(state_dict, file_path=file_path)
|
||||
elif self.is_animatediff_xl(state_dict):
|
||||
self.load_animatediff_xl(state_dict, file_path=file_path)
|
||||
elif self.is_controlnet(state_dict):
|
||||
self.load_controlnet(state_dict, file_path=file_path)
|
||||
elif self.is_stabe_diffusion_xl(state_dict):
|
||||
self.load_stable_diffusion_xl(state_dict, components=components, file_path=file_path)
|
||||
elif self.is_stable_diffusion(state_dict):
|
||||
self.load_stable_diffusion(state_dict, components=components, file_path=file_path)
|
||||
elif self.is_sd_lora(state_dict):
|
||||
self.load_sd_lora(state_dict, alpha=lora_alphas.pop(0))
|
||||
elif self.is_beautiful_prompt(state_dict):
|
||||
self.load_beautiful_prompt(state_dict, file_path=file_path)
|
||||
elif self.is_RIFE(state_dict):
|
||||
self.load_RIFE(state_dict, file_path=file_path)
|
||||
elif self.is_translator(state_dict):
|
||||
self.load_translator(state_dict, file_path=file_path)
|
||||
elif self.is_ipadapter(state_dict):
|
||||
self.load_ipadapter(state_dict, file_path=file_path)
|
||||
elif self.is_ipadapter_image_encoder(state_dict):
|
||||
self.load_ipadapter_image_encoder(state_dict, file_path=file_path)
|
||||
elif self.is_ipadapter_xl(state_dict):
|
||||
self.load_ipadapter_xl(state_dict, file_path=file_path)
|
||||
elif self.is_ipadapter_xl_image_encoder(state_dict):
|
||||
self.load_ipadapter_xl_image_encoder(state_dict, file_path=file_path)
|
||||
elif self.is_hunyuan_dit_clip_text_encoder(state_dict):
|
||||
self.load_hunyuan_dit_clip_text_encoder(state_dict, file_path=file_path)
|
||||
elif self.is_hunyuan_dit_t5_text_encoder(state_dict):
|
||||
self.load_hunyuan_dit_t5_text_encoder(state_dict, file_path=file_path)
|
||||
elif self.is_hunyuan_dit(state_dict):
|
||||
self.load_hunyuan_dit(state_dict, file_path=file_path)
|
||||
elif self.is_diffusers_vae(state_dict):
|
||||
self.load_diffusers_vae(state_dict, file_path=file_path)
|
||||
elif self.is_ExVideo_StableVideoDiffusion(state_dict):
|
||||
self.load_ExVideo_StableVideoDiffusion(state_dict, file_path=file_path)
|
||||
|
||||
def load_models(self, file_path_list, lora_alphas=[]):
|
||||
for file_path in file_path_list:
|
||||
self.load_model(file_path, lora_alphas=lora_alphas)
|
||||
|
||||
def to(self, device):
|
||||
for component in self.model:
|
||||
if isinstance(self.model[component], list):
|
||||
for model in self.model[component]:
|
||||
model.to(device)
|
||||
else:
|
||||
self.model[component].to(device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_model_with_model_path(self, model_path):
|
||||
for component in self.model_path:
|
||||
if isinstance(self.model_path[component], str):
|
||||
if os.path.samefile(self.model_path[component], model_path):
|
||||
return self.model[component]
|
||||
elif isinstance(self.model_path[component], list):
|
||||
for i, model_path_ in enumerate(self.model_path[component]):
|
||||
if os.path.samefile(model_path_, model_path):
|
||||
return self.model[component][i]
|
||||
raise ValueError(f"Please load model {model_path} before you use it.")
|
||||
|
||||
def __getattr__(self, __name):
|
||||
if __name in self.model:
|
||||
return self.model[__name]
|
||||
else:
|
||||
return super.__getattribute__(__name)
|
||||
|
||||
|
||||
def load_state_dict(file_path, torch_dtype=None):
|
||||
if file_path.endswith(".safetensors"):
|
||||
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
|
||||
else:
|
||||
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
|
||||
|
||||
|
||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
|
||||
state_dict = {}
|
||||
with safe_open(file_path, framework="pt", device="cpu") as f:
|
||||
for k in f.keys():
|
||||
state_dict[k] = f.get_tensor(k)
|
||||
if torch_dtype is not None:
|
||||
state_dict[k] = state_dict[k].to(torch_dtype)
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_state_dict_from_bin(file_path, torch_dtype=None):
|
||||
state_dict = torch.load(file_path, map_location="cpu")
|
||||
if torch_dtype is not None:
|
||||
for i in state_dict:
|
||||
if isinstance(state_dict[i], torch.Tensor):
|
||||
state_dict[i] = state_dict[i].to(torch_dtype)
|
||||
return state_dict
|
||||
|
||||
|
||||
def search_parameter(param, state_dict):
|
||||
for name, param_ in state_dict.items():
|
||||
if param.numel() == param_.numel():
|
||||
if param.shape == param_.shape:
|
||||
if torch.dist(param, param_) < 1e-6:
|
||||
return name
|
||||
else:
|
||||
if torch.dist(param.flatten(), param_.flatten()) < 1e-6:
|
||||
return name
|
||||
return None
|
||||
|
||||
|
||||
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
|
||||
matched_keys = set()
|
||||
with torch.no_grad():
|
||||
for name in source_state_dict:
|
||||
rename = search_parameter(source_state_dict[name], target_state_dict)
|
||||
if rename is not None:
|
||||
print(f'"{name}": "{rename}",')
|
||||
matched_keys.add(rename)
|
||||
elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
|
||||
length = source_state_dict[name].shape[0] // 3
|
||||
rename = []
|
||||
for i in range(3):
|
||||
rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
|
||||
if None not in rename:
|
||||
print(f'"{name}": {rename},')
|
||||
for rename_ in rename:
|
||||
matched_keys.add(rename_)
|
||||
for name in target_state_dict:
|
||||
if name not in matched_keys:
|
||||
print("Cannot find", name, target_state_dict[name].shape)
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
from huggingface_hub import hf_hub_download
|
||||
from modelscope import snapshot_download
|
||||
import os, shutil
|
||||
from typing_extensions import Literal, TypeAlias
|
||||
from typing import List
|
||||
from ..configs.model_config import preset_models_on_huggingface, preset_models_on_modelscope, Preset_model_id
|
||||
|
||||
|
||||
def download_from_modelscope(model_id, origin_file_path, local_dir):
    """Download one file of a ModelScope model repo into ``local_dir``.

    Skips the download when a file with the same basename already exists in
    ``local_dir``. ModelScope mirrors the repo's folder structure under
    ``local_dir``, so the file is afterwards moved up into ``local_dir``
    itself and the leftover repo subdirectory is removed.
    """
    os.makedirs(local_dir, exist_ok=True)
    # Basename-only check: a pre-existing file is assumed to be complete.
    if os.path.basename(origin_file_path) in os.listdir(local_dir):
        print(f" {os.path.basename(origin_file_path)} has been already in {local_dir}.")
        return
    else:
        print(f" Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
    snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
    downloaded_file_path = os.path.join(local_dir, origin_file_path)
    target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
    if downloaded_file_path != target_file_path:
        # Flatten: move the file out of the mirrored repo subdirectory ...
        shutil.move(downloaded_file_path, target_file_path)
        # ... and delete the now-unneeded top-level subdirectory.
        shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
|
||||
|
||||
|
||||
def download_from_huggingface(model_id, origin_file_path, local_dir):
    """Download one file of a Hugging Face model repo into ``local_dir``.

    Skips the download when a file with the same basename already exists in
    ``local_dir``.
    """
    os.makedirs(local_dir, exist_ok=True)
    # Basename-only check: a pre-existing file is assumed to be complete.
    if os.path.basename(origin_file_path) in os.listdir(local_dir):
        print(f" {os.path.basename(origin_file_path)} has been already in {local_dir}.")
        return
    else:
        print(f" Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
    hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
|
||||
|
||||
|
||||
# Names of the supported download mirrors.
Preset_model_website: TypeAlias = Literal[
    "HuggingFace",
    "ModelScope",
]
# Per-website registry of preset models: preset id -> list of
# (repo id, file path inside the repo, local target directory) tuples.
website_to_preset_models = {
    "HuggingFace": preset_models_on_huggingface,
    "ModelScope": preset_models_on_modelscope,
}
# Per-website download backend; both share the same call signature.
website_to_download_fn = {
    "HuggingFace": download_from_huggingface,
    "ModelScope": download_from_modelscope,
}
|
||||
|
||||
|
||||
def download_models(
    model_id_list: List[Preset_model_id] = [],
    downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
):
    """Download every preset model in ``model_id_list``.

    Each preset id is looked up on the websites in ``downloading_priority``
    order; the first website that knows the model serves all of its files.

    Args:
        model_id_list: Preset model identifiers to fetch. (The list defaults
            are read-only here, so the shared-mutable-default pitfall does
            not bite.)
        downloading_priority: Mirror preference order.

    Returns:
        List of local file paths present after downloading.
    """
    print(f"Downloading models: {model_id_list}")
    downloaded_files = []
    for model_id in model_id_list:
        for website in downloading_priority:
            if model_id not in website_to_preset_models[website]:
                continue
            # BUGFIX: the tuple's first element is the repo id on that
            # website. The original unpacked it into ``model_id``, clobbering
            # the loop variable and corrupting later preset lookups.
            for repo_id, origin_file_path, local_dir in website_to_preset_models[website][model_id]:
                # Check if the file is downloaded.
                file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
                if file_to_download in downloaded_files:
                    continue
                # Download
                website_to_download_fn[website](repo_id, origin_file_path, local_dir)
                if os.path.basename(origin_file_path) in os.listdir(local_dir):
                    downloaded_files.append(file_to_download)
            # Respect the priority order: stop after the first serving website.
            break
    return downloaded_files
|
||||
@@ -1,4 +1,5 @@
|
||||
from .attention import Attention
|
||||
from .tiler import TileWorker
|
||||
from einops import repeat, rearrange
|
||||
import math
|
||||
import torch
|
||||
@@ -398,8 +399,7 @@ class HunyuanDiT(torch.nn.Module):
|
||||
hidden_states, _ = hidden_states.chunk(2, dim=1)
|
||||
return hidden_states
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
def state_dict_converter(self):
|
||||
return HunyuanDiTStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -79,8 +79,7 @@ class HunyuanDiTCLIPTextEncoder(BertModel):
|
||||
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
|
||||
return prompt_emb
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
def state_dict_converter(self):
|
||||
return HunyuanDiTCLIPTextEncoderStateDictConverter()
|
||||
|
||||
|
||||
@@ -132,8 +131,7 @@ class HunyuanDiTT5TextEncoder(T5EncoderModel):
|
||||
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
|
||||
return prompt_emb
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
def state_dict_converter(self):
|
||||
return HunyuanDiTT5TextEncoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,195 +0,0 @@
|
||||
import torch
|
||||
from .sd_unet import SDUNet
|
||||
from .sdxl_unet import SDXLUNet
|
||||
from .sd_text_encoder import SDTextEncoder
|
||||
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
||||
from .sd3_dit import SD3DiT
|
||||
from .hunyuan_dit import HunyuanDiT
|
||||
|
||||
|
||||
|
||||
class LoRAFromCivitai:
    """Base helper that merges CivitAI-style LoRA weights into a model.

    Subclasses fill in ``supported_model_classes`` / ``lora_prefix`` /
    ``renamed_lora_prefix`` / ``special_keys`` for a concrete architecture.
    """
    def __init__(self):
        # Model classes this LoRA flavour can be applied to, paired 1:1
        # with the key prefixes in ``lora_prefix``.
        self.supported_model_classes = []
        self.lora_prefix = []
        # Optional replacement text per prefix, applied after stripping.
        self.renamed_lora_prefix = {}
        # Ordered substring renames applied to reconstructed names.
        self.special_keys = {}


    def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
        """Collapse LoRA up/down factor pairs into full-rank weight deltas.

        For every ``*.lora_up`` key carrying ``lora_prefix``, computes
        ``alpha * up @ down`` (restoring 1x1 conv shapes when needed) and
        maps the key back to the base model's dotted parameter name.

        NOTE(review): the matmul runs on "cuda" in float16 unconditionally;
        on CPU-only machines this raises — ``match`` appears to rely on its
        bare except to absorb that. TODO confirm intended.
        """
        renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "")
        state_dict_ = {}
        for key in state_dict:
            if ".lora_up" not in key:
                continue
            if not key.startswith(lora_prefix):
                continue
            weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
            weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
            # 4-D factors come from conv layers: drop the trailing 1x1
            # spatial dims, multiply, then restore them.
            if len(weight_up.shape) == 4:
                weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
                weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
                lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
            else:
                lora_weight = alpha * torch.mm(weight_up, weight_down)
            # Rebuild the target name: strip/replace the prefix and turn the
            # underscore-joined path back into dotted attribute form, then
            # apply the architecture-specific substring fixups.
            target_name = key.split(".")[0].replace(lora_prefix, renamed_lora_prefix).replace("_", ".") + ".weight"
            for special_key in self.special_keys:
                target_name = target_name.replace(special_key, self.special_keys[special_key])
            state_dict_[target_name] = lora_weight.cpu()
        return state_dict_


    def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None):
        """Add the converted LoRA deltas onto ``model``'s weights in place.

        ``model_resource`` selects which state-dict converter flavour
        ("diffusers" or "civitai") renames the delta keys.
        """
        state_dict_model = model.state_dict()
        state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=alpha)
        if model_resource == "diffusers":
            state_dict_lora = model.__class__.state_dict_converter().from_diffusers(state_dict_lora)
        elif model_resource == "civitai":
            state_dict_lora = model.__class__.state_dict_converter().from_civitai(state_dict_lora)
        if len(state_dict_lora) > 0:
            print(f" {len(state_dict_lora)} tensors are updated.")
        for name in state_dict_lora:
            # Cast each delta to the target parameter's dtype/device.
            state_dict_model[name] += state_dict_lora[name].to(
                dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
        model.load_state_dict(state_dict_model)


    def match(self, model, state_dict_lora):
        """Probe whether this LoRA fits ``model``.

        Returns ``(lora_prefix, model_resource)`` on success, else ``None``.
        Conversion errors are deliberately swallowed: a failed probe just
        means "does not match".
        """
        for lora_prefix, model_class in zip(self.lora_prefix, self.supported_model_classes):
            if not isinstance(model, model_class):
                continue
            state_dict_model = model.state_dict()
            for model_resource in ["diffusers", "civitai"]:
                try:
                    state_dict_lora_ = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=1.0)
                    converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == "diffusers" \
                        else model.__class__.state_dict_converter().from_civitai
                    state_dict_lora_ = converter_fn(state_dict_lora_)
                    if len(state_dict_lora_) == 0:
                        continue
                    # for/else: succeed only when every delta name exists in
                    # the model's own state dict.
                    for name in state_dict_lora_:
                        if name not in state_dict_model:
                            break
                    else:
                        return lora_prefix, model_resource
                except:
                    pass
        return None
|
||||
|
||||
|
||||
|
||||
class SDLoRAFromCivitai(LoRAFromCivitai):
    """LoRA loader configuration for Stable Diffusion 1.x checkpoints."""

    def __init__(self):
        super().__init__()
        self.supported_model_classes = [SDUNet, SDTextEncoder]
        self.lora_prefix = ["lora_unet_", "lora_te_"]
        # Ordered substring renames mapping underscore-mangled LoRA names
        # back to the dotted names the state-dict converters expect.
        # Order is preserved because replacements are applied sequentially.
        rename_pairs = [
            ("down.blocks", "down_blocks"),
            ("up.blocks", "up_blocks"),
            ("mid.block", "mid_block"),
            ("proj.in", "proj_in"),
            ("proj.out", "proj_out"),
            ("transformer.blocks", "transformer_blocks"),
            ("to.q", "to_q"),
            ("to.k", "to_k"),
            ("to.v", "to_v"),
            ("to.out", "to_out"),
            ("text.model", "text_model"),
            ("self.attn.q.proj", "self_attn.q_proj"),
            ("self.attn.k.proj", "self_attn.k_proj"),
            ("self.attn.v.proj", "self_attn.v_proj"),
            ("self.attn.out.proj", "self_attn.out_proj"),
            ("input.blocks", "model.diffusion_model.input_blocks"),
            ("middle.block", "model.diffusion_model.middle_block"),
            ("output.blocks", "model.diffusion_model.output_blocks"),
        ]
        self.special_keys = dict(rename_pairs)
|
||||
|
||||
|
||||
class SDXLLoRAFromCivitai(LoRAFromCivitai):
    """LoRA loader configuration for SDXL checkpoints."""

    def __init__(self):
        super().__init__()
        self.supported_model_classes = [SDXLUNet, SDXLTextEncoder, SDXLTextEncoder2]
        self.lora_prefix = ["lora_unet_", "lora_te1_", "lora_te2_"]
        # The second text encoder's keys get a literal "2" prepended after
        # the LoRA prefix is stripped.
        self.renamed_lora_prefix = {"lora_te2_": "2"}
        # Ordered substring renames; order is preserved because the
        # replacements are applied sequentially.
        rename_pairs = [
            ("down.blocks", "down_blocks"),
            ("up.blocks", "up_blocks"),
            ("mid.block", "mid_block"),
            ("proj.in", "proj_in"),
            ("proj.out", "proj_out"),
            ("transformer.blocks", "transformer_blocks"),
            ("to.q", "to_q"),
            ("to.k", "to_k"),
            ("to.v", "to_v"),
            ("to.out", "to_out"),
            ("text.model", "conditioner.embedders.0.transformer.text_model"),
            ("self.attn.q.proj", "self_attn.q_proj"),
            ("self.attn.k.proj", "self_attn.k_proj"),
            ("self.attn.v.proj", "self_attn.v_proj"),
            ("self.attn.out.proj", "self_attn.out_proj"),
            ("input.blocks", "model.diffusion_model.input_blocks"),
            ("middle.block", "model.diffusion_model.middle_block"),
            ("output.blocks", "model.diffusion_model.output_blocks"),
            ("2conditioner.embedders.0.transformer.text_model.encoder.layers", "text_model.encoder.layers"),
        ]
        self.special_keys = dict(rename_pairs)
|
||||
|
||||
|
||||
|
||||
class GeneralLoRAFromPeft:
    """Merges PEFT-style LoRA weights (``lora_A``/``lora_B`` key pairs)
    into a supported model, deriving target parameter names directly from
    the LoRA key paths (no per-architecture rename table needed)."""

    def __init__(self):
        # Model classes this loader is willing to patch.
        self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT]


    def convert_state_dict(self, state_dict, alpha=1.0, device="cuda", torch_dtype=torch.float16):
        """Collapse each lora_B/lora_A factor pair into ``alpha * B @ A``.

        Results are keyed by the base parameter name: the original key path
        with the ``lora_B`` segment and the segment right after it (the
        adapter name) removed.
        """
        state_dict_ = {}
        for key in state_dict:
            if ".lora_B." not in key:
                continue
            weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
            weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
            # Conv factors: strip the 1x1 spatial dims, matmul, restore them.
            if len(weight_up.shape) == 4:
                weight_up = weight_up.squeeze(3).squeeze(2)
                weight_down = weight_down.squeeze(3).squeeze(2)
                lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
            else:
                lora_weight = alpha * torch.mm(weight_up, weight_down)
            # Drop "lora_B" and the adapter-name segment after it to recover
            # the base model's parameter name.
            keys = key.split(".")
            keys.pop(keys.index("lora_B") + 1)
            keys.pop(keys.index("lora_B"))
            target_name = ".".join(keys)
            state_dict_[target_name] = lora_weight.cpu()
        return state_dict_


    def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
        """Add the LoRA deltas onto ``model`` in place.

        ``lora_prefix`` and ``model_resource`` exist only for interface
        parity with the CivitAI loaders and are ignored here.
        """
        state_dict_model = model.state_dict()
        # Borrow dtype/device from the model's first parameter so the
        # factor multiplication happens where the model lives.
        for name, param in state_dict_model.items():
            torch_dtype = param.dtype
            device = param.device
            break
        state_dict_lora = self.convert_state_dict(state_dict_lora, alpha=alpha, device=device, torch_dtype=torch_dtype)
        if len(state_dict_lora) > 0:
            print(f" {len(state_dict_lora)} tensors are updated.")
        for name in state_dict_lora:
            state_dict_model[name] += state_dict_lora[name].to(
                dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
        model.load_state_dict(state_dict_model)


    def match(self, model, state_dict_lora):
        """Probe whether every LoRA target name exists in ``model``.

        Returns ``("", "")`` on success (empty prefix/resource markers),
        else ``None``.

        NOTE(review): this probe calls convert_state_dict with its default
        device="cuda"; on CPU-only machines the conversion raises and the
        bare except turns that into "no match" — TODO confirm intended.
        """
        for model_class in self.supported_model_classes:
            if not isinstance(model, model_class):
                continue
            state_dict_model = model.state_dict()
            try:
                state_dict_lora_ = self.convert_state_dict(state_dict_lora, alpha=1.0)
                if len(state_dict_lora_) == 0:
                    continue
                # for/else: all names must be present in the model.
                for name in state_dict_lora_:
                    if name not in state_dict_model:
                        break
                else:
                    return "", ""
            except:
                pass
        return None
|
||||
@@ -1,536 +0,0 @@
|
||||
import os, torch, hashlib, json, importlib
|
||||
from safetensors import safe_open
|
||||
from torch import Tensor
|
||||
from typing_extensions import Literal, TypeAlias
|
||||
from typing import List
|
||||
|
||||
from .downloader import download_models, Preset_model_id, Preset_model_website
|
||||
|
||||
from .sd_text_encoder import SDTextEncoder
|
||||
from .sd_unet import SDUNet
|
||||
from .sd_vae_encoder import SDVAEEncoder
|
||||
from .sd_vae_decoder import SDVAEDecoder
|
||||
from .lora import SDLoRAFromCivitai, SDXLLoRAFromCivitai, GeneralLoRAFromPeft
|
||||
|
||||
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
||||
from .sdxl_unet import SDXLUNet
|
||||
from .sdxl_vae_decoder import SDXLVAEDecoder
|
||||
from .sdxl_vae_encoder import SDXLVAEEncoder
|
||||
|
||||
from .sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
|
||||
from .sd3_dit import SD3DiT
|
||||
from .sd3_vae_decoder import SD3VAEDecoder
|
||||
from .sd3_vae_encoder import SD3VAEEncoder
|
||||
|
||||
from .sd_controlnet import SDControlNet
|
||||
|
||||
from .sd_motion import SDMotionModel
|
||||
from .sdxl_motion import SDXLMotionModel
|
||||
|
||||
from .svd_image_encoder import SVDImageEncoder
|
||||
from .svd_unet import SVDUNet
|
||||
from .svd_vae_decoder import SVDVAEDecoder
|
||||
from .svd_vae_encoder import SVDVAEEncoder
|
||||
|
||||
from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
|
||||
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
||||
|
||||
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
||||
from .hunyuan_dit import HunyuanDiT
|
||||
|
||||
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
|
||||
|
||||
|
||||
|
||||
def load_state_dict(file_path, torch_dtype=None):
    """Load a checkpoint, dispatching on the file extension.

    ``.safetensors`` files go through the safetensors reader; everything
    else is treated as a pickled PyTorch checkpoint.
    """
    loader = (
        load_state_dict_from_safetensors
        if file_path.endswith(".safetensors")
        else load_state_dict_from_bin
    )
    return loader(file_path, torch_dtype=torch_dtype)
|
||||
|
||||
|
||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
    """Read every tensor from a ``.safetensors`` file onto the CPU,
    optionally casting each one to ``torch_dtype``."""
    state_dict = {}
    with safe_open(file_path, framework="pt", device="cpu") as archive:
        for key in archive.keys():
            tensor = archive.get_tensor(key)
            state_dict[key] = tensor if torch_dtype is None else tensor.to(torch_dtype)
    return state_dict
|
||||
|
||||
|
||||
def load_state_dict_from_bin(file_path, torch_dtype=None):
    """Load a pickled PyTorch checkpoint on the CPU; when ``torch_dtype``
    is given, cast every tensor entry (non-tensor entries are kept)."""
    state_dict = torch.load(file_path, map_location="cpu")
    if torch_dtype is not None:
        state_dict = {
            key: value.to(torch_dtype) if isinstance(value, torch.Tensor) else value
            for key, value in state_dict.items()
        }
    return state_dict
|
||||
|
||||
|
||||
def search_for_embeddings(state_dict):
    """Collect every tensor stored at any nesting depth in ``state_dict``.

    Dict values are recursed into; anything that is neither a tensor nor a
    dict is ignored. Returns the tensors in traversal order.
    """
    embeddings = []
    for value in state_dict.values():
        if isinstance(value, torch.Tensor):
            embeddings.append(value)
        elif isinstance(value, dict):
            embeddings.extend(search_for_embeddings(value))
    return embeddings
|
||||
|
||||
|
||||
def search_parameter(param, state_dict):
    """Find which entry of ``state_dict`` holds the same values as ``param``.

    Returns the matching key, or ``None`` when no tensor agrees within a
    1e-6 distance. Same-size tensors with a different shape are compared
    element-wise after flattening.
    """
    def is_same(other):
        if param.numel() != other.numel():
            return False
        if param.shape == other.shape:
            return bool(torch.dist(param, other) < 1e-6)
        return bool(torch.dist(param.flatten(), other.flatten()) < 1e-6)

    return next((name for name, other in state_dict.items() if is_same(other)), None)
|
||||
|
||||
|
||||
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
    """Print candidate renames mapping source parameter names to target
    parameter names by locating numerically identical tensors.

    Development aid only: output goes to stdout in copy-pasteable
    ``"src": "dst",`` form; the function returns nothing.
    """
    matched_keys = set()
    with torch.no_grad():
        for name in source_state_dict:
            rename = search_parameter(source_state_dict[name], target_state_dict)
            if rename is not None:
                print(f'"{name}": "{rename}",')
                matched_keys.add(rename)
            # Fused-QKV handling: try the tensor as three equal dim-0 chunks.
            elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
                length = source_state_dict[name].shape[0] // 3
                rename = []
                for i in range(3):
                    rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
                # Only report when all three chunks were located.
                if None not in rename:
                    print(f'"{name}": {rename},')
                    for rename_ in rename:
                        matched_keys.add(rename_)
    # Report target parameters that no source tensor accounted for.
    for name in target_state_dict:
        if name not in matched_keys:
            print("Cannot find", name, target_state_dict[name].shape)
|
||||
|
||||
|
||||
def search_for_files(folder, extensions):
    """Recursively list files under ``folder`` whose names end with one of
    ``extensions``. A plain file path is accepted too. Directory entries
    are visited in sorted order, so the result is deterministic.
    """
    if os.path.isdir(folder):
        collected = []
        for entry in sorted(os.listdir(folder)):
            collected += search_for_files(os.path.join(folder, entry), extensions)
        return collected
    if os.path.isfile(folder):
        # str.endswith accepts a tuple: "any extension matches".
        return [folder] if folder.endswith(tuple(extensions)) else []
    # Nonexistent path (or special file): nothing to report.
    return []
|
||||
|
||||
|
||||
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
    """Serialize a (possibly nested) state dict's key structure into one
    deterministic string, used for fingerprinting model architectures.

    A tensor-valued key contributes the key itself and, when ``with_shape``
    is on, an extra ``key:d0_d1_...`` entry. Dict values recurse, joined to
    their key with ``|``. Entries are sorted, so the result is independent
    of dict ordering. Non-string keys are skipped.
    """
    entries = []
    for key, value in state_dict.items():
        if not isinstance(key, str):
            continue
        if isinstance(value, Tensor):
            if with_shape:
                dims = "_".join(str(d) for d in value.shape)
                entries.append(f"{key}:{dims}")
            entries.append(key)
        elif isinstance(value, dict):
            entries.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
    entries.sort()
    return ",".join(entries)
|
||||
|
||||
|
||||
def split_state_dict_with_prefix(state_dict):
    """Partition a state dict by the first dotted component of each key.

    Keys are processed in sorted order; every group of keys sharing a
    top-level prefix becomes its own sub-dict. Non-string keys are ignored.
    Returns the sub-dicts in order of first appearance of their prefix.
    """
    grouped = {}
    for key in sorted(k for k in state_dict if isinstance(k, str)):
        prefix = key.split(".")[0] if "." in key else key
        grouped.setdefault(prefix, {})[key] = state_dict[key]
    return list(grouped.values())
|
||||
|
||||
|
||||
def hash_state_dict_keys(state_dict, with_shape=True):
    """MD5 fingerprint of a state dict's key structure (see
    ``convert_state_dict_keys_to_single_str``); used to recognize known
    model architectures without comparing any weight values."""
    serialized = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
    return hashlib.md5(serialized.encode(encoding="UTF-8")).hexdigest()
|
||||
|
||||
|
||||
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
    """Instantiate every (name, class) pair from one checkpoint state dict.

    Each model class adapts the raw state dict with its own
    ``state_dict_converter`` ("civitai"- or "diffusers"-flavoured keys).
    A converter may return ``(state_dict, extra_kwargs)`` to customize the
    constructor; ``upcast_to_float32`` in those kwargs forces fp32.

    Returns:
        ``(loaded_model_names, loaded_models)`` parallel lists.
    """
    loaded_model_names, loaded_models = [], []
    for model_name, model_class in zip(model_names, model_classes):
        print(f" model_name: {model_name} model_class: {model_class.__name__}")
        state_dict_converter = model_class.state_dict_converter()
        # NOTE(review): an unrecognized model_resource leaves
        # state_dict_results unbound and raises below — confirm callers only
        # ever pass "civitai" or "diffusers".
        if model_resource == "civitai":
            state_dict_results = state_dict_converter.from_civitai(state_dict)
        elif model_resource == "diffusers":
            state_dict_results = state_dict_converter.from_diffusers(state_dict)
        if isinstance(state_dict_results, tuple):
            model_state_dict, extra_kwargs = state_dict_results
            print(f" This model is initialized with extra kwargs: {extra_kwargs}")
        else:
            model_state_dict, extra_kwargs = state_dict_results, {}
        # Converters can demand fp32 regardless of the requested dtype.
        torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
        model = model_class(**extra_kwargs).to(dtype=torch_dtype, device=device)
        model.load_state_dict(model_state_dict)
        loaded_model_names.append(model_name)
        loaded_models.append(model)
    return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
    """Instantiate every (name, class) pair via ``from_pretrained`` on a
    Hugging Face-style folder, in eval mode, moved to ``device``.

    Returns:
        ``(loaded_model_names, loaded_models)`` parallel lists.
    """
    loaded_model_names, loaded_models = [], []
    for model_name, model_class in zip(model_names, model_classes):
        model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
        # When fp16 was requested, additionally call .half() explicitly
        # (guarded, since not every class exposes it).
        if torch_dtype == torch.float16 and hasattr(model, "half"):
            model = model.half()
        model = model.to(device=device)
        loaded_model_names.append(model_name)
        loaded_models.append(model)
    return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
    """Build a patch model: a fresh ``model_class`` initialized first from
    ``base_model``'s weights, then overlaid with ``state_dict``.

    ``base_model`` is moved to the CPU and dropped; the returned patched
    model is meant to replace it. Both ``load_state_dict`` calls use
    ``strict=False`` because each source covers only part of the new
    architecture.
    """
    print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
    base_state_dict = base_model.state_dict()
    # Free the base model's device memory before instantiating the patch.
    base_model.to("cpu")
    del base_model
    model = model_class(**extra_kwargs)
    model.load_state_dict(base_state_dict, strict=False)
    model.load_state_dict(state_dict, strict=False)
    model.to(dtype=torch_dtype, device=device)
    return model
|
||||
|
||||
|
||||
def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
    """Patch every already-loaded model in ``model_manager`` whose name is
    in ``model_names``, replacing it with a patched instance.

    The ``while True`` + ``for``/``else`` idiom repeatedly scans the
    manager's model list: each scan patches and removes one matching model
    (indices shift after ``pop``, so the scan restarts from the top); the
    ``else`` branch fires when a full scan finds no match, ending the loop.

    Returns:
        ``(loaded_model_names, loaded_models)`` of patched models.
    """
    loaded_model_names, loaded_models = [], []
    for model_name, model_class in zip(model_names, model_classes):
        while True:
            for model_id in range(len(model_manager.model)):
                base_model_name = model_manager.model_name[model_id]
                if base_model_name == model_name:
                    base_model_path = model_manager.model_path[model_id]
                    base_model = model_manager.model[model_id]
                    print(f" Adding patch model to {base_model_name} ({base_model_path})")
                    patched_model = load_single_patch_model_from_single_file(
                        state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
                    loaded_model_names.append(base_model_name)
                    loaded_models.append(patched_model)
                    # Remove the consumed base model from the manager.
                    model_manager.model.pop(model_id)
                    model_manager.model_path.pop(model_id)
                    model_manager.model_name.pop(model_id)
                    break
            else:
                break
    return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
|
||||
class ModelDetectorTemplate:
    """Interface template for model detectors.

    A detector recognizes a checkpoint (``match``) and builds the models it
    contains (``load``). This base implementation matches nothing and loads
    nothing; concrete detectors override both methods.
    """
    def __init__(self):
        pass

    def match(self, file_path="", state_dict={}):
        # Whether this detector can handle the given file / state dict.
        return False

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        # Returns parallel (model_names, models) lists; empty here.
        return [], []
|
||||
|
||||
|
||||
|
||||
class ModelDetectorFromSingleFile:
    """Detects and loads models stored as a single checkpoint file by
    hashing the state dict's key structure against a registry."""

    def __init__(self, model_loader_configs=[]):
        # keys_hash_with_shape -> (model_names, model_classes, model_resource)
        self.keys_hash_with_shape_dict = {}
        # Shape-insensitive fallback registry (entries may omit this hash).
        self.keys_hash_dict = {}
        for metadata in model_loader_configs:
            self.add_model_metadata(*metadata)


    def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
        """Register one architecture. ``keys_hash`` may be ``None`` when
        only strict (shape-aware) matching is supported."""
        self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
        if keys_hash is not None:
            self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)


    def match(self, file_path="", state_dict={}):
        """Return True when the file's key-structure hash is registered
        (strict shape-aware hash first, then the shape-insensitive one)."""
        if os.path.isdir(file_path):
            return False
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            return True
        keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
        if keys_hash in self.keys_hash_dict:
            return True
        return False


    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        """Load the registered models for this checkpoint.

        Tries the strict shape-aware hash first, then the shape-insensitive
        one. Returns ``([], [])`` when nothing matches.

        BUGFIX: the original fell through to ``return loaded_model_names,
        loaded_models`` with both names unbound when neither hash matched,
        raising UnboundLocalError instead of returning empty lists.
        """
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)

        # Load models with strict matching
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
            return load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)

        # Load models without strict matching
        # (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
        keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
        if keys_hash in self.keys_hash_dict:
            model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
            return load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)

        # Nothing matched: explicit empty result.
        return [], []
|
||||
|
||||
|
||||
|
||||
class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
    """Variant of ``ModelDetectorFromSingleFile`` for checkpoints bundling
    several models: each top-level key prefix is tested independently."""

    def __init__(self, model_loader_configs=[]):
        super().__init__(model_loader_configs)


    def match(self, file_path="", state_dict={}):
        """True when any prefix-split sub-dict matches a known hash."""
        if os.path.isdir(file_path):
            return False
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        splited_state_dict = split_state_dict_with_prefix(state_dict)
        for sub_state_dict in splited_state_dict:
            if super().match(file_path, sub_state_dict):
                return True
        return False


    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        """Load models from the matching components of a bundled checkpoint.

        The union of all matching sub-dicts is preferred when it matches as
        a whole; otherwise each matching sub-dict is loaded on its own.
        Returns parallel (model_names, models) lists.
        """
        # Consistency fix: like match(), read the file when no state dict
        # was supplied (the original silently split an empty dict here).
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        # Split the state_dict and load from each component
        splited_state_dict = split_state_dict_with_prefix(state_dict)
        valid_state_dict = {}
        for sub_state_dict in splited_state_dict:
            if super().match(file_path, sub_state_dict):
                valid_state_dict.update(sub_state_dict)
        if super().match(file_path, valid_state_dict):
            loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
        else:
            loaded_model_names, loaded_models = [], []
            for sub_state_dict in splited_state_dict:
                if super().match(file_path, sub_state_dict):
                    # BUGFIX: load the matching sub-dict itself; the original
                    # passed valid_state_dict, re-loading the merged dict
                    # (which had just failed to match) for every component.
                    loaded_model_names_, loaded_models_ = super().load(file_path, sub_state_dict, device, torch_dtype)
                    loaded_model_names += loaded_model_names_
                    loaded_models += loaded_models_
        return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
|
||||
class ModelDetectorFromHuggingfaceFolder:
    """Detects Hugging Face-style model folders via their ``config.json``
    and loads each listed architecture through its library class."""

    def __init__(self, model_loader_configs=[]):
        # architecture name -> (module path implementing it, model name)
        self.architecture_dict = {}
        for metadata in model_loader_configs:
            self.add_model_metadata(*metadata)


    def add_model_metadata(self, architecture, huggingface_lib, model_name):
        """Register one architecture with the module that implements it."""
        self.architecture_dict[architecture] = (huggingface_lib, model_name)


    def match(self, file_path="", state_dict={}):
        """True for a directory containing a ``config.json`` that declares
        an ``architectures`` field.

        NOTE(review): the declared architectures are not checked against
        the registry here; ``load`` raises KeyError for unregistered ones
        — confirm intended.
        """
        if os.path.isfile(file_path):
            return False
        file_list = os.listdir(file_path)
        if "config.json" not in file_list:
            return False
        with open(os.path.join(file_path, "config.json"), "r") as f:
            config = json.load(f)
        if "architectures" not in config:
            return False
        return True


    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        """Import and instantiate every architecture named in config.json.

        Returns:
            ``(loaded_model_names, loaded_models)`` parallel lists.
        """
        with open(os.path.join(file_path, "config.json"), "r") as f:
            config = json.load(f)
        loaded_model_names, loaded_models = [], []
        for architecture in config["architectures"]:
            huggingface_lib, model_name = self.architecture_dict[architecture]
            # Resolve the class dynamically from its registered module.
            model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
            loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
            loaded_model_names += loaded_model_names_
            loaded_models += loaded_models_
        return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
|
||||
class ModelDetectorFromPatchedSingleFile:
    """Detects checkpoints that patch an already-loaded base model, matched
    by the strict (shape-aware) key-structure hash."""

    def __init__(self, model_loader_configs=[]):
        # keys_hash_with_shape -> (model_name, model_class, extra_kwargs)
        self.keys_hash_with_shape_dict = {}
        for metadata in model_loader_configs:
            self.add_model_metadata(*metadata)


    def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
        """Register one patch architecture under its state-dict hash."""
        self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)


    def match(self, file_path="", state_dict={}):
        """True when the file's shape-aware key hash is registered."""
        if os.path.isdir(file_path):
            return False
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            return True
        return False


    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
        """Patch the matching base model(s) held by ``model_manager``.

        Returns parallel (model_names, models) lists of patched models.

        NOTE(review): the registry stores singular (model_name, model_class)
        entries but they are unpacked here as plurals and zipped downstream;
        presumably the configs actually store lists — verify against
        patch_model_loader_configs.
        """
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)

        # Load models with strict matching
        loaded_model_names, loaded_models = [], []
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
            loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
                state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
            loaded_model_names += loaded_model_names_
            loaded_models += loaded_models_
        return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
|
||||
class ModelManager:
|
||||
def __init__(
|
||||
self,
|
||||
torch_dtype=torch.float16,
|
||||
device="cuda",
|
||||
model_id_list: List[Preset_model_id] = [],
|
||||
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
|
||||
file_path_list: List[str] = [],
|
||||
):
|
||||
self.torch_dtype = torch_dtype
|
||||
self.device = device
|
||||
self.model = []
|
||||
self.model_path = []
|
||||
self.model_name = []
|
||||
downloaded_files = download_models(model_id_list, downloading_priority) if len(model_id_list) > 0 else []
|
||||
self.model_detector = [
|
||||
ModelDetectorFromSingleFile(model_loader_configs),
|
||||
ModelDetectorFromSplitedSingleFile(model_loader_configs),
|
||||
ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
|
||||
ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
|
||||
]
|
||||
self.load_models(downloaded_files + file_path_list)
|
||||
|
||||
|
||||
def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
|
||||
print(f"Loading models from file: {file_path}")
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device)
|
||||
for model_name, model in zip(model_names, models):
|
||||
self.model.append(model)
|
||||
self.model_path.append(file_path)
|
||||
self.model_name.append(model_name)
|
||||
print(f" The following models are loaded: {model_names}.")
|
||||
|
||||
|
||||
def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
|
||||
print(f"Loading models from folder: {file_path}")
|
||||
model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
|
||||
for model_name, model in zip(model_names, models):
|
||||
self.model.append(model)
|
||||
self.model_path.append(file_path)
|
||||
self.model_name.append(model_name)
|
||||
print(f" The following models are loaded: {model_names}.")
|
||||
|
||||
|
||||
def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
    """Load "patch" models (models built on top of already-loaded ones) from one file."""
    print(f"Loading patch models from file: {file_path}")
    # The helper needs the manager itself so it can locate base models to patch.
    loaded_names, loaded_models = load_patch_model_from_single_file(
        state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
    for name, loaded in zip(loaded_names, loaded_models):
        self.model.append(loaded)
        self.model_path.append(file_path)
        self.model_name.append(name)
    print(f" The following patched models are loaded: {loaded_names}.")
|
||||
|
||||
|
||||
def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
    """Apply a LoRA checkpoint to every registered model it matches.

    Each model is tested against the known LoRA formats in turn; the first
    format that matches is applied and the remaining formats are skipped
    for that model.
    """
    print(f"Loading LoRA models from file: {file_path}")
    if not state_dict:
        state_dict = load_state_dict(file_path)
    for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
        for lora in [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), GeneralLoRAFromPeft()]:
            match_results = lora.match(model, state_dict)
            if match_results is None:
                continue
            print(f" Adding LoRA to {model_name} ({model_path}).")
            lora_prefix, model_resource = match_results
            lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
            break
|
||||
|
||||
|
||||
def load_model(self, file_path, model_names=None):
    """Detect the checkpoint type at ``file_path`` and load it.

    Each registered detector is tried in order; the first one whose
    ``match`` succeeds performs the load. If no detector matches, nothing
    is registered and a message is printed.
    """
    print(f"Loading models from: {file_path}")
    # Folders cannot be read as a single state dict; detectors handle them directly.
    state_dict = load_state_dict(file_path) if os.path.isfile(file_path) else None
    for detector in self.model_detector:
        if not detector.match(file_path, state_dict):
            continue
        loaded_names, loaded_models = detector.load(
            file_path, state_dict,
            device=self.device, torch_dtype=self.torch_dtype,
            allowed_model_names=model_names, model_manager=self
        )
        for name, loaded in zip(loaded_names, loaded_models):
            self.model.append(loaded)
            self.model_path.append(file_path)
            self.model_name.append(name)
        print(f" The following models are loaded: {loaded_names}.")
        break
    else:
        # for/else: no detector recognized the file.
        print(f" We cannot detect the model type. No models are loaded.")
|
||||
|
||||
|
||||
def load_models(self, file_path_list, model_names=None):
    """Load every checkpoint in ``file_path_list`` via :meth:`load_model`."""
    for path in file_path_list:
        self.load_model(path, model_names)
|
||||
|
||||
|
||||
def fetch_model(self, model_name, file_path=None, require_model_path=False):
    """Return the first registered model whose name matches ``model_name``.

    Optionally restrict the search to models loaded from ``file_path``.
    Returns ``None`` when nothing matches; with ``require_model_path`` the
    matching model's source path is returned as well.
    """
    candidates = [
        (model, path)
        for model, path, name in zip(self.model, self.model_path, self.model_name)
        if name == model_name and (file_path is None or file_path == path)
    ]
    if not candidates:
        print(f"No {model_name} models available.")
        return None
    chosen_model, chosen_path = candidates[0]
    if len(candidates) == 1:
        print(f"Using {model_name} from {chosen_path}.")
    else:
        # Several matches: warn, but keep deterministic first-loaded-wins behavior.
        paths = [path for _, path in candidates]
        print(f"More than one {model_name} models are loaded in model manager: {paths}. Using {model_name} from {paths[0]}.")
    if require_model_path:
        return chosen_model, chosen_path
    return chosen_model
|
||||
|
||||
|
||||
def to(self, device):
    """Move every registered model to ``device``."""
    for registered in self.model:
        registered.to(device)
|
||||
|
||||
@@ -1,798 +0,0 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from .svd_unet import TemporalTimesteps
|
||||
from .tiler import TileWorker
|
||||
|
||||
|
||||
|
||||
class PatchEmbed(torch.nn.Module):
    """Patchify a latent image and add a cropped learnable positional embedding.

    The positional table covers ``pos_embed_max_size`` patches per side; at
    forward time the centered crop matching the input resolution is added to
    the projected patch tokens.
    """
    def __init__(self, patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192):
        super().__init__()
        self.pos_embed_max_size = pos_embed_max_size
        self.patch_size = patch_size

        self.proj = torch.nn.Conv2d(in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size)
        # Fix: the embedding width was hard-coded to 1536, breaking any
        # non-default embed_dim. Use embed_dim (default behavior unchanged).
        self.pos_embed = torch.nn.Parameter(torch.zeros(1, self.pos_embed_max_size, self.pos_embed_max_size, embed_dim))

    def cropped_pos_embed(self, height, width):
        """Return the centered (height//P, width//P) crop of the positional table, flattened to token order."""
        height = height // self.patch_size
        width = width // self.patch_size
        top = (self.pos_embed_max_size - height) // 2
        left = (self.pos_embed_max_size - width) // 2
        spatial_pos_embed = self.pos_embed[:, top : top + height, left : left + width, :].flatten(1, 2)
        return spatial_pos_embed

    def forward(self, latent):
        # latent: (B, C, H, W) -> tokens (B, H//P * W//P, embed_dim)
        height, width = latent.shape[-2:]
        latent = self.proj(latent)
        latent = latent.flatten(2).transpose(1, 2)
        pos_embed = self.cropped_pos_embed(height, width)
        return latent + pos_embed
|
||||
|
||||
|
||||
|
||||
class TimestepEmbeddings(torch.nn.Module):
    """Map a scalar timestep to a learned embedding: sinusoidal features + 2-layer MLP."""
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = torch.nn.Sequential(
            torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
        )

    def forward(self, timestep, dtype):
        # Sinusoidal projection, cast to the model dtype, then the MLP.
        emb = self.time_proj(timestep).to(dtype)
        return self.timestep_embedder(emb)
|
||||
|
||||
|
||||
|
||||
class AdaLayerNorm(torch.nn.Module):
    """LayerNorm whose modulation is predicted from a conditioning embedding.

    With ``single=True`` the conditioning yields (scale, shift) and only the
    modulated tensor is returned; otherwise six chunks are produced and the
    attention gate plus MLP shift/scale/gate tensors are returned alongside
    the modulated tensor.
    """
    def __init__(self, dim, single=False):
        super().__init__()
        self.single = single
        self.linear = torch.nn.Linear(dim, dim * (2 if single else 6))
        self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        emb = self.linear(torch.nn.functional.silu(emb))
        if self.single:
            scale, shift = emb.unsqueeze(1).chunk(2, dim=2)
            return self.norm(x) * (1 + scale) + shift
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)
        modulated = self.norm(x) * (1 + scale_msa) + shift_msa
        return modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
||||
|
||||
|
||||
|
||||
class JointAttention(torch.nn.Module):
    """Joint self-attention over the concatenation of two token streams.

    Streams a and b are projected to QKV separately, attended jointly, then
    split back by sequence position. With ``only_out_a`` only stream a has
    (and returns through) an output projection.
    """
    def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.only_out_a = only_out_a

        self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
        self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)

        self.a_to_out = torch.nn.Linear(dim_a, dim_a)
        if not only_out_a:
            self.b_to_out = torch.nn.Linear(dim_b, dim_b)

    def forward(self, hidden_states_a, hidden_states_b):
        batch_size = hidden_states_a.shape[0]
        len_a = hidden_states_a.shape[1]

        # Concatenate the per-stream QKV projections along the sequence axis.
        qkv = torch.concat([self.a_to_qkv(hidden_states_a), self.b_to_qkv(hidden_states_b)], dim=1)
        qkv = qkv.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
        q, k, v = qkv.chunk(3, dim=1)

        attn = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        attn = attn.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        attn = attn.to(q.dtype)

        # Split the joint sequence back into the two streams.
        out_a = self.a_to_out(attn[:, :len_a])
        if self.only_out_a:
            return out_a
        out_b = self.b_to_out(attn[:, len_a:])
        return out_a, out_b
|
||||
|
||||
|
||||
|
||||
class JointTransformerBlock(torch.nn.Module):
    """Dual-stream MM-DiT block: joint attention across both streams, then per-stream gated MLPs.

    Both streams are modulated by AdaLayerNorm conditioned on ``temb``.
    """
    def __init__(self, dim, num_attention_heads):
        super().__init__()
        self.norm1_a = AdaLayerNorm(dim)
        self.norm1_b = AdaLayerNorm(dim)

        self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads)

        self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_a = torch.nn.Sequential(
            torch.nn.Linear(dim, dim*4),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(dim*4, dim)
        )

        self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_b = torch.nn.Sequential(
            torch.nn.Linear(dim, dim*4),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(dim*4, dim)
        )

    @staticmethod
    def _apply_residuals(x, attn_out, gate_msa, norm2, ff, shift_mlp, scale_mlp, gate_mlp):
        # Residual gated attention, then residual gated MLP over the modulated norm.
        x = x + gate_msa * attn_out
        modulated = norm2(x) * (1 + scale_mlp) + shift_mlp
        return x + gate_mlp * ff(modulated)

    def forward(self, hidden_states_a, hidden_states_b, temb):
        norm_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
        norm_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)

        # Joint attention over both modulated streams.
        attn_a, attn_b = self.attn(norm_a, norm_b)

        hidden_states_a = self._apply_residuals(
            hidden_states_a, attn_a, gate_msa_a, self.norm2_a, self.ff_a, shift_mlp_a, scale_mlp_a, gate_mlp_a)
        hidden_states_b = self._apply_residuals(
            hidden_states_b, attn_b, gate_msa_b, self.norm2_b, self.ff_b, shift_mlp_b, scale_mlp_b, gate_mlp_b)

        return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
|
||||
class JointTransformerFinalBlock(torch.nn.Module):
    """Final MM-DiT block: joint attention, but only stream a is updated.

    Stream b is modulated by a single-mode AdaLayerNorm (no gates) and
    passed through unchanged.
    """
    def __init__(self, dim, num_attention_heads):
        super().__init__()
        self.norm1_a = AdaLayerNorm(dim)
        self.norm1_b = AdaLayerNorm(dim, single=True)

        self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, only_out_a=True)

        self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_a = torch.nn.Sequential(
            torch.nn.Linear(dim, dim*4),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(dim*4, dim)
        )

    def forward(self, hidden_states_a, hidden_states_b, temb):
        norm_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
        norm_b = self.norm1_b(hidden_states_b, emb=temb)

        # Joint attention; only stream a receives an output.
        attn_a = self.attn(norm_a, norm_b)

        # Residual gated attention followed by residual gated MLP on stream a.
        hidden_states_a = hidden_states_a + gate_msa_a * attn_a
        modulated = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
        hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(modulated)

        return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
|
||||
class SD3DiT(torch.nn.Module):
    """Stable Diffusion 3 diffusion transformer (MM-DiT).

    Patchifies a 16-channel latent into 1536-dim tokens, runs them jointly
    with per-token text embeddings through 23 dual-stream blocks plus one
    final block, conditioned on timestep + pooled text embedding, then
    unpatchifies back to latent space (64 = 2x2 patch x 16 channels).
    """
    def __init__(self):
        super().__init__()
        # Patchify + learned positional grid for the latent input.
        self.pos_embedder = PatchEmbed(patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192)
        self.time_embedder = TimestepEmbeddings(256, 1536)
        # Pooled text embedding (2048-dim) -> conditioning vector.
        self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(2048, 1536), torch.nn.SiLU(), torch.nn.Linear(1536, 1536))
        # Per-token text features (4096-dim; presumably T5 — TODO confirm) -> transformer width.
        self.context_embedder = torch.nn.Linear(4096, 1536)
        self.blocks = torch.nn.ModuleList([JointTransformerBlock(1536, 24) for _ in range(23)] + [JointTransformerFinalBlock(1536, 24)])
        self.norm_out = AdaLayerNorm(1536, single=True)
        self.proj_out = torch.nn.Linear(1536, 64)

    def tiled_forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size=128, tile_stride=64):
        """Run the whole forward pass tile-by-tile over the latent to bound memory."""
        # Due to the global positional embedding, we cannot implement layer-wise tiled forward.
        hidden_states = TileWorker().tiled_forward(
            lambda x: self.forward(x, timestep, prompt_emb, pooled_prompt_emb),
            hidden_states,
            tile_size,
            tile_stride,
            tile_device=hidden_states.device,
            tile_dtype=hidden_states.dtype
        )
        return hidden_states

    def forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tiled=False, tile_size=128, tile_stride=64, use_gradient_checkpointing=False):
        """Predict the denoising output for ``hidden_states`` at ``timestep``.

        Returns a tensor with the same spatial size as the input latent.
        """
        if tiled:
            return self.tiled_forward(hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size, tile_stride)
        # Conditioning vector combines timestep and pooled text embeddings.
        conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
        prompt_emb = self.context_embedder(prompt_emb)

        height, width = hidden_states.shape[-2:]
        hidden_states = self.pos_embedder(hidden_states)

        def create_custom_forward(module):
            # Wrapper so torch.utils.checkpoint can re-invoke the block in backward.
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward

        for block in self.blocks:
            if self.training and use_gradient_checkpointing:
                # Trade compute for memory: recompute block activations during backward.
                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states, prompt_emb, conditioning,
                    use_reentrant=False,
                )
            else:
                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning)

        hidden_states = self.norm_out(hidden_states, conditioning)
        hidden_states = self.proj_out(hidden_states)
        # Unpatchify: each token carries a 2x2 spatial patch of 16 channels.
        hidden_states = rearrange(hidden_states, "B (H W) (P Q C) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
        return hidden_states

    @staticmethod
    def state_dict_converter():
        """Return the converter that maps external checkpoint naming to this module's."""
        return SD3DiTStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class SD3DiTStateDictConverter:
|
||||
def __init__(self):
    # Stateless converter; nothing to initialize.
    pass
|
||||
|
||||
def from_diffusers(self, state_dict):
    """Convert a Diffusers-format SD3 DiT state dict to this module's naming.

    Three key shapes are handled: exact matches in the rename table (the
    positional embedding, which is also reshaped to its 2-D grid), top-level
    ``.weight``/``.bias`` parameters, and per-block
    ``transformer_blocks.<i>.*`` parameters. Keys with no mapping are dropped.
    """
    rename_dict = {
        "context_embedder": "context_embedder",
        "pos_embed.pos_embed": "pos_embedder.pos_embed",
        "pos_embed.proj": "pos_embedder.proj",
        "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
        "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
        "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
        "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
        "norm_out.linear": "norm_out.linear",
        "proj_out": "proj_out",

        "norm1.linear": "norm1_a.linear",
        "norm1_context.linear": "norm1_b.linear",
        "attn.to_q": "attn.a_to_q",
        "attn.to_k": "attn.a_to_k",
        "attn.to_v": "attn.a_to_v",
        "attn.to_out.0": "attn.a_to_out",
        "attn.add_q_proj": "attn.b_to_q",
        "attn.add_k_proj": "attn.b_to_k",
        "attn.add_v_proj": "attn.b_to_v",
        "attn.to_add_out": "attn.b_to_out",
        "ff.net.0.proj": "ff_a.0",
        "ff.net.2": "ff_a.2",
        "ff_context.net.0.proj": "ff_b.0",
        "ff_context.net.2": "ff_b.2",
    }
    converted = {}
    for name, param in state_dict.items():
        if name in rename_dict:
            if name == "pos_embed.pos_embed":
                # Stored flat in Diffusers; our PatchEmbed keeps a 2-D grid.
                param = param.reshape((1, 192, 192, 1536))
            converted[rename_dict[name]] = param
            continue
        for suffix in (".weight", ".bias"):
            if not name.endswith(suffix):
                continue
            prefix = name[:-len(suffix)]
            if prefix in rename_dict:
                converted[rename_dict[prefix] + suffix] = param
            elif prefix.startswith("transformer_blocks."):
                parts = prefix.split(".")
                middle = ".".join(parts[2:])
                if middle in rename_dict:
                    converted["blocks." + parts[1] + "." + rename_dict[middle] + suffix] = param
            break
    return converted
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
"model.diffusion_model.context_embedder.bias": "context_embedder.bias",
|
||||
"model.diffusion_model.context_embedder.weight": "context_embedder.weight",
|
||||
"model.diffusion_model.final_layer.linear.bias": "proj_out.bias",
|
||||
"model.diffusion_model.final_layer.linear.weight": "proj_out.weight",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias": "blocks.0.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.weight": "blocks.0.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.attn.proj.bias": "blocks.0.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.attn.proj.weight": "blocks.0.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.attn.qkv.bias": ['blocks.0.attn.b_to_q.bias', 'blocks.0.attn.b_to_k.bias', 'blocks.0.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.0.context_block.attn.qkv.weight": ['blocks.0.attn.b_to_q.weight', 'blocks.0.attn.b_to_k.weight', 'blocks.0.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc1.bias": "blocks.0.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc1.weight": "blocks.0.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc2.bias": "blocks.0.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc2.weight": "blocks.0.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.0.x_block.adaLN_modulation.1.bias": "blocks.0.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.0.x_block.adaLN_modulation.1.weight": "blocks.0.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.0.x_block.attn.proj.bias": "blocks.0.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.0.x_block.attn.proj.weight": "blocks.0.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.0.x_block.attn.qkv.bias": ['blocks.0.attn.a_to_q.bias', 'blocks.0.attn.a_to_k.bias', 'blocks.0.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.0.x_block.attn.qkv.weight": ['blocks.0.attn.a_to_q.weight', 'blocks.0.attn.a_to_k.weight', 'blocks.0.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc1.bias": "blocks.0.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc1.weight": "blocks.0.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc2.bias": "blocks.0.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc2.weight": "blocks.0.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.1.context_block.adaLN_modulation.1.bias": "blocks.1.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.1.context_block.adaLN_modulation.1.weight": "blocks.1.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.1.context_block.attn.proj.bias": "blocks.1.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.1.context_block.attn.proj.weight": "blocks.1.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.1.context_block.attn.qkv.bias": ['blocks.1.attn.b_to_q.bias', 'blocks.1.attn.b_to_k.bias', 'blocks.1.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.1.context_block.attn.qkv.weight": ['blocks.1.attn.b_to_q.weight', 'blocks.1.attn.b_to_k.weight', 'blocks.1.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc1.bias": "blocks.1.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc1.weight": "blocks.1.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc2.bias": "blocks.1.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc2.weight": "blocks.1.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.1.x_block.adaLN_modulation.1.bias": "blocks.1.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.1.x_block.adaLN_modulation.1.weight": "blocks.1.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.1.x_block.attn.proj.bias": "blocks.1.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.1.x_block.attn.proj.weight": "blocks.1.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.1.x_block.attn.qkv.bias": ['blocks.1.attn.a_to_q.bias', 'blocks.1.attn.a_to_k.bias', 'blocks.1.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.1.x_block.attn.qkv.weight": ['blocks.1.attn.a_to_q.weight', 'blocks.1.attn.a_to_k.weight', 'blocks.1.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc1.bias": "blocks.1.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc1.weight": "blocks.1.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc2.bias": "blocks.1.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc2.weight": "blocks.1.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.10.context_block.adaLN_modulation.1.bias": "blocks.10.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.10.context_block.adaLN_modulation.1.weight": "blocks.10.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.10.context_block.attn.proj.bias": "blocks.10.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.10.context_block.attn.proj.weight": "blocks.10.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.10.context_block.attn.qkv.bias": ['blocks.10.attn.b_to_q.bias', 'blocks.10.attn.b_to_k.bias', 'blocks.10.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.10.context_block.attn.qkv.weight": ['blocks.10.attn.b_to_q.weight', 'blocks.10.attn.b_to_k.weight', 'blocks.10.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc1.bias": "blocks.10.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc1.weight": "blocks.10.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc2.bias": "blocks.10.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc2.weight": "blocks.10.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.10.x_block.adaLN_modulation.1.bias": "blocks.10.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.10.x_block.adaLN_modulation.1.weight": "blocks.10.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.10.x_block.attn.proj.bias": "blocks.10.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.10.x_block.attn.proj.weight": "blocks.10.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.10.x_block.attn.qkv.bias": ['blocks.10.attn.a_to_q.bias', 'blocks.10.attn.a_to_k.bias', 'blocks.10.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.10.x_block.attn.qkv.weight": ['blocks.10.attn.a_to_q.weight', 'blocks.10.attn.a_to_k.weight', 'blocks.10.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc1.bias": "blocks.10.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc1.weight": "blocks.10.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc2.bias": "blocks.10.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc2.weight": "blocks.10.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.11.context_block.adaLN_modulation.1.bias": "blocks.11.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.11.context_block.adaLN_modulation.1.weight": "blocks.11.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.11.context_block.attn.proj.bias": "blocks.11.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.11.context_block.attn.proj.weight": "blocks.11.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.11.context_block.attn.qkv.bias": ['blocks.11.attn.b_to_q.bias', 'blocks.11.attn.b_to_k.bias', 'blocks.11.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.11.context_block.attn.qkv.weight": ['blocks.11.attn.b_to_q.weight', 'blocks.11.attn.b_to_k.weight', 'blocks.11.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc1.bias": "blocks.11.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc1.weight": "blocks.11.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc2.bias": "blocks.11.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc2.weight": "blocks.11.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.11.x_block.adaLN_modulation.1.bias": "blocks.11.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.11.x_block.adaLN_modulation.1.weight": "blocks.11.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.11.x_block.attn.proj.bias": "blocks.11.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.11.x_block.attn.proj.weight": "blocks.11.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.11.x_block.attn.qkv.bias": ['blocks.11.attn.a_to_q.bias', 'blocks.11.attn.a_to_k.bias', 'blocks.11.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.11.x_block.attn.qkv.weight": ['blocks.11.attn.a_to_q.weight', 'blocks.11.attn.a_to_k.weight', 'blocks.11.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc1.bias": "blocks.11.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc1.weight": "blocks.11.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc2.bias": "blocks.11.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc2.weight": "blocks.11.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.12.context_block.adaLN_modulation.1.bias": "blocks.12.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.12.context_block.adaLN_modulation.1.weight": "blocks.12.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.12.context_block.attn.proj.bias": "blocks.12.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.12.context_block.attn.proj.weight": "blocks.12.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.12.context_block.attn.qkv.bias": ['blocks.12.attn.b_to_q.bias', 'blocks.12.attn.b_to_k.bias', 'blocks.12.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.12.context_block.attn.qkv.weight": ['blocks.12.attn.b_to_q.weight', 'blocks.12.attn.b_to_k.weight', 'blocks.12.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc1.bias": "blocks.12.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc1.weight": "blocks.12.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc2.bias": "blocks.12.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc2.weight": "blocks.12.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.12.x_block.adaLN_modulation.1.bias": "blocks.12.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.12.x_block.adaLN_modulation.1.weight": "blocks.12.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.12.x_block.attn.proj.bias": "blocks.12.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.12.x_block.attn.proj.weight": "blocks.12.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.12.x_block.attn.qkv.bias": ['blocks.12.attn.a_to_q.bias', 'blocks.12.attn.a_to_k.bias', 'blocks.12.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.12.x_block.attn.qkv.weight": ['blocks.12.attn.a_to_q.weight', 'blocks.12.attn.a_to_k.weight', 'blocks.12.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc1.bias": "blocks.12.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc1.weight": "blocks.12.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc2.bias": "blocks.12.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc2.weight": "blocks.12.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.13.context_block.adaLN_modulation.1.bias": "blocks.13.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.13.context_block.adaLN_modulation.1.weight": "blocks.13.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.13.context_block.attn.proj.bias": "blocks.13.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.13.context_block.attn.proj.weight": "blocks.13.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.13.context_block.attn.qkv.bias": ['blocks.13.attn.b_to_q.bias', 'blocks.13.attn.b_to_k.bias', 'blocks.13.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.13.context_block.attn.qkv.weight": ['blocks.13.attn.b_to_q.weight', 'blocks.13.attn.b_to_k.weight', 'blocks.13.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc1.bias": "blocks.13.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc1.weight": "blocks.13.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc2.bias": "blocks.13.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc2.weight": "blocks.13.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.13.x_block.adaLN_modulation.1.bias": "blocks.13.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.13.x_block.adaLN_modulation.1.weight": "blocks.13.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.13.x_block.attn.proj.bias": "blocks.13.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.13.x_block.attn.proj.weight": "blocks.13.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.13.x_block.attn.qkv.bias": ['blocks.13.attn.a_to_q.bias', 'blocks.13.attn.a_to_k.bias', 'blocks.13.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.13.x_block.attn.qkv.weight": ['blocks.13.attn.a_to_q.weight', 'blocks.13.attn.a_to_k.weight', 'blocks.13.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc1.bias": "blocks.13.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc1.weight": "blocks.13.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc2.bias": "blocks.13.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc2.weight": "blocks.13.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.14.context_block.adaLN_modulation.1.bias": "blocks.14.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.14.context_block.adaLN_modulation.1.weight": "blocks.14.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.14.context_block.attn.proj.bias": "blocks.14.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.14.context_block.attn.proj.weight": "blocks.14.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.14.context_block.attn.qkv.bias": ['blocks.14.attn.b_to_q.bias', 'blocks.14.attn.b_to_k.bias', 'blocks.14.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.14.context_block.attn.qkv.weight": ['blocks.14.attn.b_to_q.weight', 'blocks.14.attn.b_to_k.weight', 'blocks.14.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc1.bias": "blocks.14.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc1.weight": "blocks.14.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc2.bias": "blocks.14.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc2.weight": "blocks.14.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.14.x_block.adaLN_modulation.1.bias": "blocks.14.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.14.x_block.adaLN_modulation.1.weight": "blocks.14.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.14.x_block.attn.proj.bias": "blocks.14.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.14.x_block.attn.proj.weight": "blocks.14.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.14.x_block.attn.qkv.bias": ['blocks.14.attn.a_to_q.bias', 'blocks.14.attn.a_to_k.bias', 'blocks.14.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.14.x_block.attn.qkv.weight": ['blocks.14.attn.a_to_q.weight', 'blocks.14.attn.a_to_k.weight', 'blocks.14.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc1.bias": "blocks.14.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc1.weight": "blocks.14.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc2.bias": "blocks.14.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc2.weight": "blocks.14.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.15.context_block.adaLN_modulation.1.bias": "blocks.15.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.15.context_block.adaLN_modulation.1.weight": "blocks.15.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.15.context_block.attn.proj.bias": "blocks.15.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.15.context_block.attn.proj.weight": "blocks.15.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.15.context_block.attn.qkv.bias": ['blocks.15.attn.b_to_q.bias', 'blocks.15.attn.b_to_k.bias', 'blocks.15.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.15.context_block.attn.qkv.weight": ['blocks.15.attn.b_to_q.weight', 'blocks.15.attn.b_to_k.weight', 'blocks.15.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc1.bias": "blocks.15.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc1.weight": "blocks.15.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc2.bias": "blocks.15.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc2.weight": "blocks.15.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.15.x_block.adaLN_modulation.1.bias": "blocks.15.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.15.x_block.adaLN_modulation.1.weight": "blocks.15.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.15.x_block.attn.proj.bias": "blocks.15.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.15.x_block.attn.proj.weight": "blocks.15.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.15.x_block.attn.qkv.bias": ['blocks.15.attn.a_to_q.bias', 'blocks.15.attn.a_to_k.bias', 'blocks.15.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.15.x_block.attn.qkv.weight": ['blocks.15.attn.a_to_q.weight', 'blocks.15.attn.a_to_k.weight', 'blocks.15.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc1.bias": "blocks.15.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc1.weight": "blocks.15.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc2.bias": "blocks.15.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc2.weight": "blocks.15.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.16.context_block.adaLN_modulation.1.bias": "blocks.16.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.16.context_block.adaLN_modulation.1.weight": "blocks.16.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.16.context_block.attn.proj.bias": "blocks.16.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.16.context_block.attn.proj.weight": "blocks.16.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.16.context_block.attn.qkv.bias": ['blocks.16.attn.b_to_q.bias', 'blocks.16.attn.b_to_k.bias', 'blocks.16.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.16.context_block.attn.qkv.weight": ['blocks.16.attn.b_to_q.weight', 'blocks.16.attn.b_to_k.weight', 'blocks.16.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc1.bias": "blocks.16.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc1.weight": "blocks.16.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc2.bias": "blocks.16.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc2.weight": "blocks.16.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.16.x_block.adaLN_modulation.1.bias": "blocks.16.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.16.x_block.adaLN_modulation.1.weight": "blocks.16.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.16.x_block.attn.proj.bias": "blocks.16.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.16.x_block.attn.proj.weight": "blocks.16.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.16.x_block.attn.qkv.bias": ['blocks.16.attn.a_to_q.bias', 'blocks.16.attn.a_to_k.bias', 'blocks.16.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.16.x_block.attn.qkv.weight": ['blocks.16.attn.a_to_q.weight', 'blocks.16.attn.a_to_k.weight', 'blocks.16.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc1.bias": "blocks.16.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc1.weight": "blocks.16.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc2.bias": "blocks.16.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc2.weight": "blocks.16.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.17.context_block.adaLN_modulation.1.bias": "blocks.17.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.17.context_block.adaLN_modulation.1.weight": "blocks.17.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.17.context_block.attn.proj.bias": "blocks.17.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.17.context_block.attn.proj.weight": "blocks.17.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.17.context_block.attn.qkv.bias": ['blocks.17.attn.b_to_q.bias', 'blocks.17.attn.b_to_k.bias', 'blocks.17.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.17.context_block.attn.qkv.weight": ['blocks.17.attn.b_to_q.weight', 'blocks.17.attn.b_to_k.weight', 'blocks.17.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc1.bias": "blocks.17.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc1.weight": "blocks.17.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc2.bias": "blocks.17.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc2.weight": "blocks.17.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.17.x_block.adaLN_modulation.1.bias": "blocks.17.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.17.x_block.adaLN_modulation.1.weight": "blocks.17.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.17.x_block.attn.proj.bias": "blocks.17.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.17.x_block.attn.proj.weight": "blocks.17.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.17.x_block.attn.qkv.bias": ['blocks.17.attn.a_to_q.bias', 'blocks.17.attn.a_to_k.bias', 'blocks.17.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.17.x_block.attn.qkv.weight": ['blocks.17.attn.a_to_q.weight', 'blocks.17.attn.a_to_k.weight', 'blocks.17.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc1.bias": "blocks.17.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc1.weight": "blocks.17.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc2.bias": "blocks.17.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc2.weight": "blocks.17.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.18.context_block.adaLN_modulation.1.bias": "blocks.18.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.18.context_block.adaLN_modulation.1.weight": "blocks.18.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.18.context_block.attn.proj.bias": "blocks.18.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.18.context_block.attn.proj.weight": "blocks.18.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.18.context_block.attn.qkv.bias": ['blocks.18.attn.b_to_q.bias', 'blocks.18.attn.b_to_k.bias', 'blocks.18.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.18.context_block.attn.qkv.weight": ['blocks.18.attn.b_to_q.weight', 'blocks.18.attn.b_to_k.weight', 'blocks.18.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc1.bias": "blocks.18.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc1.weight": "blocks.18.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc2.bias": "blocks.18.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc2.weight": "blocks.18.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.18.x_block.adaLN_modulation.1.bias": "blocks.18.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.18.x_block.adaLN_modulation.1.weight": "blocks.18.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.18.x_block.attn.proj.bias": "blocks.18.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.18.x_block.attn.proj.weight": "blocks.18.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.18.x_block.attn.qkv.bias": ['blocks.18.attn.a_to_q.bias', 'blocks.18.attn.a_to_k.bias', 'blocks.18.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.18.x_block.attn.qkv.weight": ['blocks.18.attn.a_to_q.weight', 'blocks.18.attn.a_to_k.weight', 'blocks.18.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc1.bias": "blocks.18.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc1.weight": "blocks.18.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc2.bias": "blocks.18.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc2.weight": "blocks.18.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.19.context_block.adaLN_modulation.1.bias": "blocks.19.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.19.context_block.adaLN_modulation.1.weight": "blocks.19.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.19.context_block.attn.proj.bias": "blocks.19.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.19.context_block.attn.proj.weight": "blocks.19.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.19.context_block.attn.qkv.bias": ['blocks.19.attn.b_to_q.bias', 'blocks.19.attn.b_to_k.bias', 'blocks.19.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.19.context_block.attn.qkv.weight": ['blocks.19.attn.b_to_q.weight', 'blocks.19.attn.b_to_k.weight', 'blocks.19.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc1.bias": "blocks.19.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc1.weight": "blocks.19.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc2.bias": "blocks.19.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc2.weight": "blocks.19.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.19.x_block.adaLN_modulation.1.bias": "blocks.19.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.19.x_block.adaLN_modulation.1.weight": "blocks.19.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.19.x_block.attn.proj.bias": "blocks.19.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.19.x_block.attn.proj.weight": "blocks.19.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.19.x_block.attn.qkv.bias": ['blocks.19.attn.a_to_q.bias', 'blocks.19.attn.a_to_k.bias', 'blocks.19.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.19.x_block.attn.qkv.weight": ['blocks.19.attn.a_to_q.weight', 'blocks.19.attn.a_to_k.weight', 'blocks.19.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc1.bias": "blocks.19.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc1.weight": "blocks.19.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc2.bias": "blocks.19.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc2.weight": "blocks.19.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.2.context_block.adaLN_modulation.1.bias": "blocks.2.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.2.context_block.adaLN_modulation.1.weight": "blocks.2.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.2.context_block.attn.proj.bias": "blocks.2.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.2.context_block.attn.proj.weight": "blocks.2.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.2.context_block.attn.qkv.bias": ['blocks.2.attn.b_to_q.bias', 'blocks.2.attn.b_to_k.bias', 'blocks.2.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.2.context_block.attn.qkv.weight": ['blocks.2.attn.b_to_q.weight', 'blocks.2.attn.b_to_k.weight', 'blocks.2.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc1.bias": "blocks.2.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc1.weight": "blocks.2.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc2.bias": "blocks.2.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc2.weight": "blocks.2.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.2.x_block.adaLN_modulation.1.bias": "blocks.2.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.2.x_block.adaLN_modulation.1.weight": "blocks.2.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.2.x_block.attn.proj.bias": "blocks.2.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.2.x_block.attn.proj.weight": "blocks.2.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.2.x_block.attn.qkv.bias": ['blocks.2.attn.a_to_q.bias', 'blocks.2.attn.a_to_k.bias', 'blocks.2.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.2.x_block.attn.qkv.weight": ['blocks.2.attn.a_to_q.weight', 'blocks.2.attn.a_to_k.weight', 'blocks.2.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc1.bias": "blocks.2.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc1.weight": "blocks.2.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc2.bias": "blocks.2.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc2.weight": "blocks.2.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.20.context_block.adaLN_modulation.1.bias": "blocks.20.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.20.context_block.adaLN_modulation.1.weight": "blocks.20.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.20.context_block.attn.proj.bias": "blocks.20.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.20.context_block.attn.proj.weight": "blocks.20.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.20.context_block.attn.qkv.bias": ['blocks.20.attn.b_to_q.bias', 'blocks.20.attn.b_to_k.bias', 'blocks.20.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.20.context_block.attn.qkv.weight": ['blocks.20.attn.b_to_q.weight', 'blocks.20.attn.b_to_k.weight', 'blocks.20.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc1.bias": "blocks.20.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc1.weight": "blocks.20.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc2.bias": "blocks.20.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc2.weight": "blocks.20.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.20.x_block.adaLN_modulation.1.bias": "blocks.20.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.20.x_block.adaLN_modulation.1.weight": "blocks.20.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.20.x_block.attn.proj.bias": "blocks.20.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.20.x_block.attn.proj.weight": "blocks.20.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.20.x_block.attn.qkv.bias": ['blocks.20.attn.a_to_q.bias', 'blocks.20.attn.a_to_k.bias', 'blocks.20.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.20.x_block.attn.qkv.weight": ['blocks.20.attn.a_to_q.weight', 'blocks.20.attn.a_to_k.weight', 'blocks.20.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc1.bias": "blocks.20.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc1.weight": "blocks.20.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc2.bias": "blocks.20.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc2.weight": "blocks.20.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.21.context_block.adaLN_modulation.1.bias": "blocks.21.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.21.context_block.adaLN_modulation.1.weight": "blocks.21.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.21.context_block.attn.proj.bias": "blocks.21.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.21.context_block.attn.proj.weight": "blocks.21.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.21.context_block.attn.qkv.bias": ['blocks.21.attn.b_to_q.bias', 'blocks.21.attn.b_to_k.bias', 'blocks.21.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.21.context_block.attn.qkv.weight": ['blocks.21.attn.b_to_q.weight', 'blocks.21.attn.b_to_k.weight', 'blocks.21.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc1.bias": "blocks.21.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc1.weight": "blocks.21.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc2.bias": "blocks.21.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc2.weight": "blocks.21.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.21.x_block.adaLN_modulation.1.bias": "blocks.21.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.21.x_block.adaLN_modulation.1.weight": "blocks.21.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.21.x_block.attn.proj.bias": "blocks.21.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.21.x_block.attn.proj.weight": "blocks.21.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.21.x_block.attn.qkv.bias": ['blocks.21.attn.a_to_q.bias', 'blocks.21.attn.a_to_k.bias', 'blocks.21.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.21.x_block.attn.qkv.weight": ['blocks.21.attn.a_to_q.weight', 'blocks.21.attn.a_to_k.weight', 'blocks.21.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc1.bias": "blocks.21.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc1.weight": "blocks.21.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc2.bias": "blocks.21.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc2.weight": "blocks.21.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.22.context_block.adaLN_modulation.1.bias": "blocks.22.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.22.context_block.adaLN_modulation.1.weight": "blocks.22.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.22.context_block.attn.proj.bias": "blocks.22.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.22.context_block.attn.proj.weight": "blocks.22.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.22.context_block.attn.qkv.bias": ['blocks.22.attn.b_to_q.bias', 'blocks.22.attn.b_to_k.bias', 'blocks.22.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.22.context_block.attn.qkv.weight": ['blocks.22.attn.b_to_q.weight', 'blocks.22.attn.b_to_k.weight', 'blocks.22.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc1.bias": "blocks.22.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc1.weight": "blocks.22.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc2.bias": "blocks.22.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc2.weight": "blocks.22.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.22.x_block.adaLN_modulation.1.bias": "blocks.22.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.22.x_block.adaLN_modulation.1.weight": "blocks.22.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.22.x_block.attn.proj.bias": "blocks.22.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.22.x_block.attn.proj.weight": "blocks.22.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.22.x_block.attn.qkv.bias": ['blocks.22.attn.a_to_q.bias', 'blocks.22.attn.a_to_k.bias', 'blocks.22.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.22.x_block.attn.qkv.weight": ['blocks.22.attn.a_to_q.weight', 'blocks.22.attn.a_to_k.weight', 'blocks.22.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc1.bias": "blocks.22.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc1.weight": "blocks.22.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc2.bias": "blocks.22.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc2.weight": "blocks.22.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.23.context_block.attn.qkv.bias": ['blocks.23.attn.b_to_q.bias', 'blocks.23.attn.b_to_k.bias', 'blocks.23.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.23.context_block.attn.qkv.weight": ['blocks.23.attn.b_to_q.weight', 'blocks.23.attn.b_to_k.weight', 'blocks.23.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.23.x_block.adaLN_modulation.1.bias": "blocks.23.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.23.x_block.adaLN_modulation.1.weight": "blocks.23.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.23.x_block.attn.proj.bias": "blocks.23.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.23.x_block.attn.proj.weight": "blocks.23.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.23.x_block.attn.qkv.bias": ['blocks.23.attn.a_to_q.bias', 'blocks.23.attn.a_to_k.bias', 'blocks.23.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.23.x_block.attn.qkv.weight": ['blocks.23.attn.a_to_q.weight', 'blocks.23.attn.a_to_k.weight', 'blocks.23.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc1.bias": "blocks.23.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc1.weight": "blocks.23.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc2.bias": "blocks.23.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc2.weight": "blocks.23.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.3.context_block.adaLN_modulation.1.bias": "blocks.3.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.3.context_block.adaLN_modulation.1.weight": "blocks.3.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.3.context_block.attn.proj.bias": "blocks.3.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.3.context_block.attn.proj.weight": "blocks.3.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.3.context_block.attn.qkv.bias": ['blocks.3.attn.b_to_q.bias', 'blocks.3.attn.b_to_k.bias', 'blocks.3.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.3.context_block.attn.qkv.weight": ['blocks.3.attn.b_to_q.weight', 'blocks.3.attn.b_to_k.weight', 'blocks.3.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc1.bias": "blocks.3.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc1.weight": "blocks.3.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc2.bias": "blocks.3.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc2.weight": "blocks.3.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.3.x_block.adaLN_modulation.1.bias": "blocks.3.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.3.x_block.adaLN_modulation.1.weight": "blocks.3.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.3.x_block.attn.proj.bias": "blocks.3.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.3.x_block.attn.proj.weight": "blocks.3.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.3.x_block.attn.qkv.bias": ['blocks.3.attn.a_to_q.bias', 'blocks.3.attn.a_to_k.bias', 'blocks.3.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.3.x_block.attn.qkv.weight": ['blocks.3.attn.a_to_q.weight', 'blocks.3.attn.a_to_k.weight', 'blocks.3.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc1.bias": "blocks.3.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc1.weight": "blocks.3.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc2.bias": "blocks.3.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc2.weight": "blocks.3.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.4.context_block.adaLN_modulation.1.bias": "blocks.4.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.4.context_block.adaLN_modulation.1.weight": "blocks.4.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.4.context_block.attn.proj.bias": "blocks.4.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.4.context_block.attn.proj.weight": "blocks.4.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.4.context_block.attn.qkv.bias": ['blocks.4.attn.b_to_q.bias', 'blocks.4.attn.b_to_k.bias', 'blocks.4.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.4.context_block.attn.qkv.weight": ['blocks.4.attn.b_to_q.weight', 'blocks.4.attn.b_to_k.weight', 'blocks.4.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc1.bias": "blocks.4.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc1.weight": "blocks.4.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc2.bias": "blocks.4.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc2.weight": "blocks.4.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.4.x_block.adaLN_modulation.1.bias": "blocks.4.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.4.x_block.adaLN_modulation.1.weight": "blocks.4.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.4.x_block.attn.proj.bias": "blocks.4.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.4.x_block.attn.proj.weight": "blocks.4.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.4.x_block.attn.qkv.bias": ['blocks.4.attn.a_to_q.bias', 'blocks.4.attn.a_to_k.bias', 'blocks.4.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.4.x_block.attn.qkv.weight": ['blocks.4.attn.a_to_q.weight', 'blocks.4.attn.a_to_k.weight', 'blocks.4.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc1.bias": "blocks.4.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc1.weight": "blocks.4.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc2.bias": "blocks.4.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc2.weight": "blocks.4.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.5.context_block.adaLN_modulation.1.bias": "blocks.5.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.5.context_block.adaLN_modulation.1.weight": "blocks.5.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.5.context_block.attn.proj.bias": "blocks.5.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.5.context_block.attn.proj.weight": "blocks.5.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.5.context_block.attn.qkv.bias": ['blocks.5.attn.b_to_q.bias', 'blocks.5.attn.b_to_k.bias', 'blocks.5.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.5.context_block.attn.qkv.weight": ['blocks.5.attn.b_to_q.weight', 'blocks.5.attn.b_to_k.weight', 'blocks.5.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc1.bias": "blocks.5.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc1.weight": "blocks.5.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc2.bias": "blocks.5.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc2.weight": "blocks.5.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.5.x_block.adaLN_modulation.1.bias": "blocks.5.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.5.x_block.adaLN_modulation.1.weight": "blocks.5.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.5.x_block.attn.proj.bias": "blocks.5.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.5.x_block.attn.proj.weight": "blocks.5.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.5.x_block.attn.qkv.bias": ['blocks.5.attn.a_to_q.bias', 'blocks.5.attn.a_to_k.bias', 'blocks.5.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.5.x_block.attn.qkv.weight": ['blocks.5.attn.a_to_q.weight', 'blocks.5.attn.a_to_k.weight', 'blocks.5.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc1.bias": "blocks.5.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc1.weight": "blocks.5.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc2.bias": "blocks.5.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc2.weight": "blocks.5.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.6.context_block.adaLN_modulation.1.bias": "blocks.6.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.6.context_block.adaLN_modulation.1.weight": "blocks.6.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.6.context_block.attn.proj.bias": "blocks.6.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.6.context_block.attn.proj.weight": "blocks.6.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.6.context_block.attn.qkv.bias": ['blocks.6.attn.b_to_q.bias', 'blocks.6.attn.b_to_k.bias', 'blocks.6.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.6.context_block.attn.qkv.weight": ['blocks.6.attn.b_to_q.weight', 'blocks.6.attn.b_to_k.weight', 'blocks.6.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc1.bias": "blocks.6.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc1.weight": "blocks.6.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc2.bias": "blocks.6.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc2.weight": "blocks.6.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.6.x_block.adaLN_modulation.1.bias": "blocks.6.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.6.x_block.adaLN_modulation.1.weight": "blocks.6.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.6.x_block.attn.proj.bias": "blocks.6.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.6.x_block.attn.proj.weight": "blocks.6.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.6.x_block.attn.qkv.bias": ['blocks.6.attn.a_to_q.bias', 'blocks.6.attn.a_to_k.bias', 'blocks.6.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.6.x_block.attn.qkv.weight": ['blocks.6.attn.a_to_q.weight', 'blocks.6.attn.a_to_k.weight', 'blocks.6.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc1.bias": "blocks.6.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc1.weight": "blocks.6.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc2.bias": "blocks.6.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc2.weight": "blocks.6.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.7.context_block.adaLN_modulation.1.bias": "blocks.7.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.7.context_block.adaLN_modulation.1.weight": "blocks.7.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.7.context_block.attn.proj.bias": "blocks.7.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.7.context_block.attn.proj.weight": "blocks.7.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.7.context_block.attn.qkv.bias": ['blocks.7.attn.b_to_q.bias', 'blocks.7.attn.b_to_k.bias', 'blocks.7.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.7.context_block.attn.qkv.weight": ['blocks.7.attn.b_to_q.weight', 'blocks.7.attn.b_to_k.weight', 'blocks.7.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc1.bias": "blocks.7.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc1.weight": "blocks.7.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc2.bias": "blocks.7.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc2.weight": "blocks.7.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.7.x_block.adaLN_modulation.1.bias": "blocks.7.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.7.x_block.adaLN_modulation.1.weight": "blocks.7.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.7.x_block.attn.proj.bias": "blocks.7.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.7.x_block.attn.proj.weight": "blocks.7.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.7.x_block.attn.qkv.bias": ['blocks.7.attn.a_to_q.bias', 'blocks.7.attn.a_to_k.bias', 'blocks.7.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.7.x_block.attn.qkv.weight": ['blocks.7.attn.a_to_q.weight', 'blocks.7.attn.a_to_k.weight', 'blocks.7.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc1.bias": "blocks.7.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc1.weight": "blocks.7.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc2.bias": "blocks.7.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc2.weight": "blocks.7.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.8.context_block.adaLN_modulation.1.bias": "blocks.8.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.8.context_block.adaLN_modulation.1.weight": "blocks.8.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.8.context_block.attn.proj.bias": "blocks.8.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.8.context_block.attn.proj.weight": "blocks.8.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.8.context_block.attn.qkv.bias": ['blocks.8.attn.b_to_q.bias', 'blocks.8.attn.b_to_k.bias', 'blocks.8.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.8.context_block.attn.qkv.weight": ['blocks.8.attn.b_to_q.weight', 'blocks.8.attn.b_to_k.weight', 'blocks.8.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc1.bias": "blocks.8.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc1.weight": "blocks.8.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc2.bias": "blocks.8.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc2.weight": "blocks.8.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.8.x_block.adaLN_modulation.1.bias": "blocks.8.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.8.x_block.adaLN_modulation.1.weight": "blocks.8.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.8.x_block.attn.proj.bias": "blocks.8.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.8.x_block.attn.proj.weight": "blocks.8.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.8.x_block.attn.qkv.bias": ['blocks.8.attn.a_to_q.bias', 'blocks.8.attn.a_to_k.bias', 'blocks.8.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.8.x_block.attn.qkv.weight": ['blocks.8.attn.a_to_q.weight', 'blocks.8.attn.a_to_k.weight', 'blocks.8.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc1.bias": "blocks.8.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc1.weight": "blocks.8.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc2.bias": "blocks.8.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc2.weight": "blocks.8.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.9.context_block.adaLN_modulation.1.bias": "blocks.9.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.9.context_block.adaLN_modulation.1.weight": "blocks.9.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.9.context_block.attn.proj.bias": "blocks.9.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.9.context_block.attn.proj.weight": "blocks.9.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.9.context_block.attn.qkv.bias": ['blocks.9.attn.b_to_q.bias', 'blocks.9.attn.b_to_k.bias', 'blocks.9.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.9.context_block.attn.qkv.weight": ['blocks.9.attn.b_to_q.weight', 'blocks.9.attn.b_to_k.weight', 'blocks.9.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc1.bias": "blocks.9.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc1.weight": "blocks.9.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc2.bias": "blocks.9.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc2.weight": "blocks.9.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.9.x_block.adaLN_modulation.1.bias": "blocks.9.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.9.x_block.adaLN_modulation.1.weight": "blocks.9.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.9.x_block.attn.proj.bias": "blocks.9.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.9.x_block.attn.proj.weight": "blocks.9.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.9.x_block.attn.qkv.bias": ['blocks.9.attn.a_to_q.bias', 'blocks.9.attn.a_to_k.bias', 'blocks.9.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.9.x_block.attn.qkv.weight": ['blocks.9.attn.a_to_q.weight', 'blocks.9.attn.a_to_k.weight', 'blocks.9.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc1.bias": "blocks.9.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc1.weight": "blocks.9.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.bias": "blocks.9.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.weight": "blocks.9.ff_a.2.weight",
|
||||
"model.diffusion_model.pos_embed": "pos_embedder.pos_embed",
|
||||
"model.diffusion_model.t_embedder.mlp.0.bias": "time_embedder.timestep_embedder.0.bias",
|
||||
"model.diffusion_model.t_embedder.mlp.0.weight": "time_embedder.timestep_embedder.0.weight",
|
||||
"model.diffusion_model.t_embedder.mlp.2.bias": "time_embedder.timestep_embedder.2.bias",
|
||||
"model.diffusion_model.t_embedder.mlp.2.weight": "time_embedder.timestep_embedder.2.weight",
|
||||
"model.diffusion_model.x_embedder.proj.bias": "pos_embedder.proj.bias",
|
||||
"model.diffusion_model.x_embedder.proj.weight": "pos_embedder.proj.weight",
|
||||
"model.diffusion_model.y_embedder.mlp.0.bias": "pooled_text_embedder.0.bias",
|
||||
"model.diffusion_model.y_embedder.mlp.0.weight": "pooled_text_embedder.0.weight",
|
||||
"model.diffusion_model.y_embedder.mlp.2.bias": "pooled_text_embedder.2.bias",
|
||||
"model.diffusion_model.y_embedder.mlp.2.weight": "pooled_text_embedder.2.weight",
|
||||
|
||||
"model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1.weight": "blocks.23.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1.bias": "blocks.23.norm1_b.linear.bias",
|
||||
"model.diffusion_model.final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
|
||||
"model.diffusion_model.final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if name.startswith("model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1."):
|
||||
param = torch.concat([param[1536:], param[:1536]], axis=0)
|
||||
elif name.startswith("model.diffusion_model.final_layer.adaLN_modulation.1."):
|
||||
param = torch.concat([param[1536:], param[:1536]], axis=0)
|
||||
elif name == "model.diffusion_model.pos_embed":
|
||||
param = param.reshape((1, 192, 192, 1536))
|
||||
if isinstance(rename_dict[name], str):
|
||||
state_dict_[rename_dict[name]] = param
|
||||
else:
|
||||
name_ = rename_dict[name][0].replace(".a_to_q.", ".a_to_qkv.").replace(".b_to_q.", ".b_to_qkv.")
|
||||
state_dict_[name_] = param
|
||||
return state_dict_
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,81 +0,0 @@
|
||||
import torch
|
||||
from .sd_vae_decoder import VAEAttentionBlock, SDVAEDecoderStateDictConverter
|
||||
from .sd_unet import ResnetBlock, UpSampler
|
||||
from .tiler import TileWorker
|
||||
|
||||
|
||||
|
||||
class SD3VAEDecoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.scaling_factor = 1.5305 # Different from SD 1.x
|
||||
self.shift_factor = 0.0609 # Different from SD 1.x
|
||||
self.conv_in = torch.nn.Conv2d(16, 512, kernel_size=3, padding=1) # Different from SD 1.x
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
# UNetMidBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
# UpDecoderBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
UpSampler(512),
|
||||
# UpDecoderBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
UpSampler(512),
|
||||
# UpDecoderBlock2D
|
||||
ResnetBlock(512, 256, eps=1e-6),
|
||||
ResnetBlock(256, 256, eps=1e-6),
|
||||
ResnetBlock(256, 256, eps=1e-6),
|
||||
UpSampler(256),
|
||||
# UpDecoderBlock2D
|
||||
ResnetBlock(256, 128, eps=1e-6),
|
||||
ResnetBlock(128, 128, eps=1e-6),
|
||||
ResnetBlock(128, 128, eps=1e-6),
|
||||
])
|
||||
|
||||
self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6)
|
||||
self.conv_act = torch.nn.SiLU()
|
||||
self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
|
||||
|
||||
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: self.forward(x),
|
||||
sample,
|
||||
tile_size,
|
||||
tile_stride,
|
||||
tile_device=sample.device,
|
||||
tile_dtype=sample.dtype
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
||||
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
||||
if tiled:
|
||||
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
# 1. pre-process
|
||||
hidden_states = sample / self.scaling_factor + self.shift_factor
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
time_emb = None
|
||||
text_emb = None
|
||||
res_stack = None
|
||||
|
||||
# 2. blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
|
||||
# 3. output
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDVAEDecoderStateDictConverter()
|
||||
@@ -1,95 +0,0 @@
|
||||
import torch
|
||||
from .sd_unet import ResnetBlock, DownSampler
|
||||
from .sd_vae_encoder import VAEAttentionBlock, SDVAEEncoderStateDictConverter
|
||||
from .tiler import TileWorker
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class SD3VAEEncoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.scaling_factor = 1.5305 # Different from SD 1.x
|
||||
self.shift_factor = 0.0609 # Different from SD 1.x
|
||||
self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
# DownEncoderBlock2D
|
||||
ResnetBlock(128, 128, eps=1e-6),
|
||||
ResnetBlock(128, 128, eps=1e-6),
|
||||
DownSampler(128, padding=0, extra_padding=True),
|
||||
# DownEncoderBlock2D
|
||||
ResnetBlock(128, 256, eps=1e-6),
|
||||
ResnetBlock(256, 256, eps=1e-6),
|
||||
DownSampler(256, padding=0, extra_padding=True),
|
||||
# DownEncoderBlock2D
|
||||
ResnetBlock(256, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
DownSampler(512, padding=0, extra_padding=True),
|
||||
# DownEncoderBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
# UNetMidBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
])
|
||||
|
||||
self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
|
||||
self.conv_act = torch.nn.SiLU()
|
||||
self.conv_out = torch.nn.Conv2d(512, 32, kernel_size=3, padding=1)
|
||||
|
||||
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: self.forward(x),
|
||||
sample,
|
||||
tile_size,
|
||||
tile_stride,
|
||||
tile_device=sample.device,
|
||||
tile_dtype=sample.dtype
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
||||
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
||||
if tiled:
|
||||
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
# 1. pre-process
|
||||
hidden_states = self.conv_in(sample)
|
||||
time_emb = None
|
||||
text_emb = None
|
||||
res_stack = None
|
||||
|
||||
# 2. blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
|
||||
# 3. output
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
hidden_states = hidden_states[:, :16]
|
||||
hidden_states = (hidden_states - self.shift_factor) * self.scaling_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
def encode_video(self, sample, batch_size=8):
|
||||
B = sample.shape[0]
|
||||
hidden_states = []
|
||||
|
||||
for i in range(0, sample.shape[2], batch_size):
|
||||
|
||||
j = min(i + batch_size, sample.shape[2])
|
||||
sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W")
|
||||
|
||||
hidden_states_batch = self(sample_batch)
|
||||
hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)
|
||||
|
||||
hidden_states.append(hidden_states_batch)
|
||||
|
||||
hidden_states = torch.concat(hidden_states, dim=2)
|
||||
return hidden_states
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDVAEEncoderStateDictConverter()
|
||||
@@ -99,7 +99,7 @@ class SDControlNet(torch.nn.Module):
|
||||
tiled=False, tile_size=64, tile_stride=32,
|
||||
):
|
||||
# 1. time
|
||||
time_emb = self.time_proj(timestep).to(sample.dtype)
|
||||
time_emb = self.time_proj(timestep[None]).to(sample.dtype)
|
||||
time_emb = self.time_embedding(time_emb)
|
||||
time_emb = time_emb.repeat(sample.shape[0], 1)
|
||||
|
||||
@@ -134,8 +134,7 @@ class SDControlNet(torch.nn.Module):
|
||||
|
||||
return controlnet_res_stack
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
def state_dict_converter(self):
|
||||
return SDControlNetStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ class SDIpAdapter(torch.nn.Module):
|
||||
|
||||
def set_less_adapter(self):
|
||||
# IP-Adapter for SD v1.5 doesn't support this feature.
|
||||
self.set_full_adapter()
|
||||
self.set_full_adapter(self)
|
||||
|
||||
def forward(self, hidden_states, scale=1.0):
|
||||
hidden_states = self.image_proj(hidden_states)
|
||||
@@ -47,8 +47,7 @@ class SDIpAdapter(torch.nn.Module):
|
||||
}
|
||||
return ip_kv_dict
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
def state_dict_converter(self):
|
||||
return SDIpAdapterStateDictConverter()
|
||||
|
||||
|
||||
|
||||
60
diffsynth/models/sd_lora.py
Normal file
60
diffsynth/models/sd_lora.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import torch
|
||||
from .sd_unet import SDUNetStateDictConverter, SDUNet
|
||||
from .sd_text_encoder import SDTextEncoderStateDictConverter, SDTextEncoder
|
||||
|
||||
|
||||
class SDLoRA:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0, device="cuda"):
|
||||
special_keys = {
|
||||
"down.blocks": "down_blocks",
|
||||
"up.blocks": "up_blocks",
|
||||
"mid.block": "mid_block",
|
||||
"proj.in": "proj_in",
|
||||
"proj.out": "proj_out",
|
||||
"transformer.blocks": "transformer_blocks",
|
||||
"to.q": "to_q",
|
||||
"to.k": "to_k",
|
||||
"to.v": "to_v",
|
||||
"to.out": "to_out",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for key in state_dict:
|
||||
if ".lora_up" not in key:
|
||||
continue
|
||||
if not key.startswith(lora_prefix):
|
||||
continue
|
||||
weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
|
||||
weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
|
||||
if len(weight_up.shape) == 4:
|
||||
weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
|
||||
weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down)
|
||||
target_name = key.split(".")[0].replace("_", ".")[len(lora_prefix):] + ".weight"
|
||||
for special_key in special_keys:
|
||||
target_name = target_name.replace(special_key, special_keys[special_key])
|
||||
state_dict_[target_name] = lora_weight.cpu()
|
||||
return state_dict_
|
||||
|
||||
def add_lora_to_unet(self, unet: SDUNet, state_dict_lora, alpha=1.0, device="cuda"):
|
||||
state_dict_unet = unet.state_dict()
|
||||
state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_unet_", alpha=alpha, device=device)
|
||||
state_dict_lora = SDUNetStateDictConverter().from_diffusers(state_dict_lora)
|
||||
if len(state_dict_lora) > 0:
|
||||
for name in state_dict_lora:
|
||||
state_dict_unet[name] += state_dict_lora[name].to(device=device)
|
||||
unet.load_state_dict(state_dict_unet)
|
||||
|
||||
def add_lora_to_text_encoder(self, text_encoder: SDTextEncoder, state_dict_lora, alpha=1.0, device="cuda"):
|
||||
state_dict_text_encoder = text_encoder.state_dict()
|
||||
state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_te_", alpha=alpha, device=device)
|
||||
state_dict_lora = SDTextEncoderStateDictConverter().from_diffusers(state_dict_lora)
|
||||
if len(state_dict_lora) > 0:
|
||||
for name in state_dict_lora:
|
||||
state_dict_text_encoder[name] += state_dict_lora[name].to(device=device)
|
||||
text_encoder.load_state_dict(state_dict_text_encoder)
|
||||
|
||||
@@ -1,20 +1,28 @@
|
||||
from .sd_unet import SDUNet, Attention, GEGLU
|
||||
from .svd_unet import get_timestep_embedding
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
class TemporalTransformerBlock(torch.nn.Module):
|
||||
|
||||
def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32):
|
||||
def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32, add_positional_conv=None):
|
||||
super().__init__()
|
||||
self.add_positional_conv = add_positional_conv
|
||||
|
||||
# 1. Self-Attn
|
||||
self.pe1 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
|
||||
emb = get_timestep_embedding(torch.arange(max_position_embeddings), dim, True, 0).reshape(1, max_position_embeddings, dim)
|
||||
self.pe1 = torch.nn.Parameter(emb)
|
||||
if add_positional_conv:
|
||||
self.positional_conv_1 = torch.nn.Conv1d(dim, dim, kernel_size=3, padding=1, padding_mode="reflect")
|
||||
self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True)
|
||||
self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
|
||||
|
||||
# 2. Cross-Attn
|
||||
self.pe2 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
|
||||
emb = get_timestep_embedding(torch.arange(max_position_embeddings), dim, True, 0).reshape(1, max_position_embeddings, dim)
|
||||
self.pe2 = torch.nn.Parameter(emb)
|
||||
if add_positional_conv:
|
||||
self.positional_conv_2 = torch.nn.Conv1d(dim, dim, kernel_size=3, padding=1, padding_mode="reflect")
|
||||
self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True)
|
||||
self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
|
||||
|
||||
@@ -24,19 +32,47 @@ class TemporalTransformerBlock(torch.nn.Module):
|
||||
self.ff = torch.nn.Linear(dim * 4, dim)
|
||||
|
||||
|
||||
def frame_id_to_position_id(self, frame_id, max_id, repeat_length):
|
||||
if frame_id < max_id:
|
||||
position_id = frame_id
|
||||
else:
|
||||
position_id = (frame_id - max_id) % (repeat_length * 2)
|
||||
if position_id < repeat_length:
|
||||
position_id = max_id - 2 - position_id
|
||||
else:
|
||||
position_id = max_id - 2 * repeat_length + position_id
|
||||
return position_id
|
||||
|
||||
|
||||
def positional_ids(self, num_frames):
|
||||
max_id = self.pe1.shape[1]
|
||||
positional_ids = torch.IntTensor([self.frame_id_to_position_id(i, max_id, max_id - 1) for i in range(num_frames)])
|
||||
return positional_ids
|
||||
|
||||
|
||||
def forward(self, hidden_states, batch_size=1):
|
||||
|
||||
# 1. Self-Attention
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
|
||||
attn_output = self.attn1(norm_hidden_states + self.pe1[:, :norm_hidden_states.shape[1]])
|
||||
norm_hidden_states = norm_hidden_states + self.pe1[:, self.positional_ids(norm_hidden_states.shape[1])]
|
||||
if self.add_positional_conv:
|
||||
norm_hidden_states = rearrange(norm_hidden_states, "(b h) f c -> (b h) c f", b=batch_size)
|
||||
norm_hidden_states = self.positional_conv_1(norm_hidden_states)
|
||||
norm_hidden_states = rearrange(norm_hidden_states, "(b h) c f -> (b h) f c", b=batch_size)
|
||||
attn_output = self.attn1(norm_hidden_states)
|
||||
attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 2. Cross-Attention
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
|
||||
attn_output = self.attn2(norm_hidden_states + self.pe2[:, :norm_hidden_states.shape[1]])
|
||||
norm_hidden_states = norm_hidden_states + self.pe2[:, self.positional_ids(norm_hidden_states.shape[1])]
|
||||
if self.add_positional_conv:
|
||||
norm_hidden_states = rearrange(norm_hidden_states, "(b h) f c -> (b h) c f", b=batch_size)
|
||||
norm_hidden_states = self.positional_conv_2(norm_hidden_states)
|
||||
norm_hidden_states = rearrange(norm_hidden_states, "(b h) c f -> (b h) f c", b=batch_size)
|
||||
attn_output = self.attn2(norm_hidden_states)
|
||||
attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
@@ -51,7 +87,7 @@ class TemporalTransformerBlock(torch.nn.Module):
|
||||
|
||||
class TemporalBlock(torch.nn.Module):
|
||||
|
||||
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
|
||||
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5, add_positional_conv=None):
|
||||
super().__init__()
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
@@ -62,7 +98,9 @@ class TemporalBlock(torch.nn.Module):
|
||||
TemporalTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim
|
||||
attention_head_dim,
|
||||
max_position_embeddings=32 if add_positional_conv is None else add_positional_conv,
|
||||
add_positional_conv=add_positional_conv
|
||||
)
|
||||
for d in range(num_layers)
|
||||
])
|
||||
@@ -92,30 +130,30 @@ class TemporalBlock(torch.nn.Module):
|
||||
|
||||
|
||||
class SDMotionModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, add_positional_conv=None):
|
||||
super().__init__()
|
||||
self.motion_modules = torch.nn.ModuleList([
|
||||
TemporalBlock(8, 40, 320, eps=1e-6),
|
||||
TemporalBlock(8, 40, 320, eps=1e-6),
|
||||
TemporalBlock(8, 80, 640, eps=1e-6),
|
||||
TemporalBlock(8, 80, 640, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 80, 640, eps=1e-6),
|
||||
TemporalBlock(8, 80, 640, eps=1e-6),
|
||||
TemporalBlock(8, 80, 640, eps=1e-6),
|
||||
TemporalBlock(8, 40, 320, eps=1e-6),
|
||||
TemporalBlock(8, 40, 320, eps=1e-6),
|
||||
TemporalBlock(8, 40, 320, eps=1e-6),
|
||||
TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||
TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||
TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||
TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||
TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||
TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||
TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||
TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||
TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||
TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||
])
|
||||
self.call_block_id = {
|
||||
1: 0,
|
||||
@@ -144,8 +182,7 @@ class SDMotionModel(torch.nn.Module):
|
||||
def forward(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
def state_dict_converter(self):
|
||||
return SDMotionModelStateDictConverter()
|
||||
|
||||
|
||||
@@ -153,7 +190,42 @@ class SDMotionModelStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
def frame_id_to_position_id(self, frame_id, max_id, repeat_length):
|
||||
if frame_id < max_id:
|
||||
position_id = frame_id
|
||||
else:
|
||||
position_id = (frame_id - max_id) % (repeat_length * 2)
|
||||
if position_id < repeat_length:
|
||||
position_id = max_id - 2 - position_id
|
||||
else:
|
||||
position_id = max_id - 2 * repeat_length + position_id
|
||||
return position_id
|
||||
|
||||
def process_positional_conv_parameters(self, state_dict, add_positional_conv):
|
||||
ids = [self.frame_id_to_position_id(i, 16, 15) for i in range(add_positional_conv)]
|
||||
for i in range(21):
|
||||
# Extend positional embedding
|
||||
name = f"motion_modules.{i}.transformer_blocks.0.pe1"
|
||||
state_dict[name] = state_dict[name][:, ids]
|
||||
name = f"motion_modules.{i}.transformer_blocks.0.pe2"
|
||||
state_dict[name] = state_dict[name][:, ids]
|
||||
# add post convolution
|
||||
dim = state_dict[f"motion_modules.{i}.transformer_blocks.0.pe1"].shape[-1]
|
||||
name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_1.bias"
|
||||
state_dict[name] = torch.zeros((dim,))
|
||||
name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_2.bias"
|
||||
state_dict[name] = torch.zeros((dim,))
|
||||
name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_1.weight"
|
||||
param = torch.zeros((dim, dim, 3))
|
||||
param[:, :, 1] = torch.eye(dim, dim)
|
||||
state_dict[name] = param
|
||||
name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_2.weight"
|
||||
param = torch.zeros((dim, dim, 3))
|
||||
param[:, :, 1] = torch.eye(dim, dim)
|
||||
state_dict[name] = param
|
||||
return state_dict
|
||||
|
||||
def from_diffusers(self, state_dict, add_positional_conv=None):
|
||||
rename_dict = {
|
||||
"norm": "norm",
|
||||
"proj_in": "proj_in",
|
||||
@@ -193,7 +265,9 @@ class SDMotionModelStateDictConverter:
|
||||
else:
|
||||
rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
|
||||
state_dict_[rename] = state_dict[name]
|
||||
if add_positional_conv is not None:
|
||||
state_dict_ = self.process_positional_conv_parameters(state_dict_, add_positional_conv)
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
def from_civitai(self, state_dict, add_positional_conv=None):
|
||||
return self.from_diffusers(state_dict, add_positional_conv=add_positional_conv)
|
||||
|
||||
115
diffsynth/models/sd_motion_ex.py
Normal file
115
diffsynth/models/sd_motion_ex.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from .attention import Attention
|
||||
from .svd_unet import get_timestep_embedding
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
|
||||
class ExVideoMotionBlock(torch.nn.Module):
|
||||
|
||||
def __init__(self, num_attention_heads, attention_head_dim, in_channels, max_position_embeddings=16, num_layers=1, add_positional_conv=None):
|
||||
super().__init__()
|
||||
|
||||
emb = get_timestep_embedding(torch.arange(max_position_embeddings), in_channels, True, 0).reshape(max_position_embeddings, in_channels, 1, 1)
|
||||
self.positional_embedding = torch.nn.Parameter(emb)
|
||||
self.positional_conv = torch.nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) if add_positional_conv is not None else None
|
||||
self.norms = torch.nn.ModuleList([torch.nn.LayerNorm(in_channels) for _ in range(num_layers)])
|
||||
self.attns = torch.nn.ModuleList([Attention(q_dim=in_channels, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True) for _ in range(num_layers)])
|
||||
|
||||
def frame_id_to_position_id(self, frame_id, max_id, repeat_length):
|
||||
if frame_id < max_id:
|
||||
position_id = frame_id
|
||||
else:
|
||||
position_id = (frame_id - max_id) % (repeat_length * 2)
|
||||
if position_id < repeat_length:
|
||||
position_id = max_id - 2 - position_id
|
||||
else:
|
||||
position_id = max_id - 2 * repeat_length + position_id
|
||||
return position_id
|
||||
|
||||
def positional_ids(self, num_frames):
|
||||
max_id = self.positional_embedding.shape[0]
|
||||
positional_ids = torch.IntTensor([self.frame_id_to_position_id(i, max_id, max_id - 1) for i in range(num_frames)])
|
||||
return positional_ids
|
||||
|
||||
def forward(self, hidden_states, time_emb, text_emb, res_stack, batch_size=1, **kwargs):
|
||||
batch, inner_dim, height, width = hidden_states.shape
|
||||
residual = hidden_states
|
||||
|
||||
pos_emb = self.positional_ids(batch // batch_size)
|
||||
pos_emb = self.positional_embedding[pos_emb]
|
||||
pos_emb = pos_emb.repeat(batch_size)
|
||||
hidden_states = hidden_states + pos_emb
|
||||
if self.positional_conv is not None:
|
||||
hidden_states = rearrange(hidden_states, "(B T) C H W -> B C T H W", B=batch_size)
|
||||
hidden_states = self.positional_conv(hidden_states)
|
||||
hidden_states = rearrange(hidden_states, "B C T H W -> (B H W) T C")
|
||||
else:
|
||||
hidden_states = rearrange(hidden_states, "(B T) C H W -> (B H W) T C", B=batch_size)
|
||||
|
||||
for norm, attn in zip(self.norms, self.attns):
|
||||
norm_hidden_states = norm(hidden_states)
|
||||
attn_output = attn(norm_hidden_states)
|
||||
hidden_states = hidden_states + attn_output
|
||||
|
||||
hidden_states = rearrange(hidden_states, "(B H W) T C -> (B T) C H W", B=batch_size, H=height, W=width)
|
||||
hidden_states = hidden_states + residual
|
||||
return hidden_states, time_emb, text_emb, res_stack
|
||||
|
||||
|
||||
|
||||
class ExVideoMotionModel(torch.nn.Module):
|
||||
def __init__(self, num_layers=2):
|
||||
super().__init__()
|
||||
self.motion_modules = torch.nn.ModuleList([
|
||||
ExVideoMotionBlock(8, 40, 320, num_layers=num_layers),
|
||||
ExVideoMotionBlock(8, 40, 320, num_layers=num_layers),
|
||||
ExVideoMotionBlock(8, 80, 640, num_layers=num_layers),
|
||||
ExVideoMotionBlock(8, 80, 640, num_layers=num_layers),
|
||||
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
|
||||
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
|
||||
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
|
||||
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
|
||||
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
|
||||
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
|
||||
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
|
||||
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
|
||||
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
|
||||
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
|
||||
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
|
||||
ExVideoMotionBlock(8, 80, 640, num_layers=num_layers),
|
||||
ExVideoMotionBlock(8, 80, 640, num_layers=num_layers),
|
||||
ExVideoMotionBlock(8, 80, 640, num_layers=num_layers),
|
||||
ExVideoMotionBlock(8, 40, 320, num_layers=num_layers),
|
||||
ExVideoMotionBlock(8, 40, 320, num_layers=num_layers),
|
||||
ExVideoMotionBlock(8, 40, 320, num_layers=num_layers),
|
||||
])
|
||||
self.call_block_id = {
|
||||
1: 0,
|
||||
4: 1,
|
||||
9: 2,
|
||||
12: 3,
|
||||
17: 4,
|
||||
20: 5,
|
||||
24: 6,
|
||||
26: 7,
|
||||
29: 8,
|
||||
32: 9,
|
||||
34: 10,
|
||||
36: 11,
|
||||
40: 12,
|
||||
43: 13,
|
||||
46: 14,
|
||||
50: 15,
|
||||
53: 16,
|
||||
56: 17,
|
||||
60: 18,
|
||||
63: 19,
|
||||
66: 20
|
||||
}
|
||||
|
||||
def forward(self):
|
||||
pass
|
||||
|
||||
def state_dict_converter(self):
|
||||
pass
|
||||
@@ -71,8 +71,7 @@ class SDTextEncoder(torch.nn.Module):
|
||||
embeds = self.final_layer_norm(embeds)
|
||||
return embeds
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
def state_dict_converter(self):
|
||||
return SDTextEncoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -323,7 +323,7 @@ class SDUNet(torch.nn.Module):
|
||||
|
||||
def forward(self, sample, timestep, encoder_hidden_states, **kwargs):
|
||||
# 1. time
|
||||
time_emb = self.time_proj(timestep).to(sample.dtype)
|
||||
time_emb = self.time_proj(timestep[None]).to(sample.dtype)
|
||||
time_emb = self.time_embedding(time_emb)
|
||||
|
||||
# 2. pre-process
|
||||
@@ -342,8 +342,7 @@ class SDUNet(torch.nn.Module):
|
||||
|
||||
return hidden_states
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
def state_dict_converter(self):
|
||||
return SDUNetStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -90,8 +90,6 @@ class SDVAEDecoder(torch.nn.Module):
|
||||
return hidden_states
|
||||
|
||||
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
||||
original_dtype = sample.dtype
|
||||
sample = sample.to(dtype=next(iter(self.parameters())).dtype)
|
||||
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
||||
if tiled:
|
||||
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
||||
@@ -112,12 +110,10 @@ class SDVAEDecoder(torch.nn.Module):
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
hidden_states = hidden_states.to(original_dtype)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
def state_dict_converter(self):
|
||||
return SDVAEDecoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -50,8 +50,6 @@ class SDVAEEncoder(torch.nn.Module):
|
||||
return hidden_states
|
||||
|
||||
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
||||
original_dtype = sample.dtype
|
||||
sample = sample.to(dtype=next(iter(self.parameters())).dtype)
|
||||
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
||||
if tiled:
|
||||
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
||||
@@ -73,7 +71,6 @@ class SDVAEEncoder(torch.nn.Module):
|
||||
hidden_states = self.quant_conv(hidden_states)
|
||||
hidden_states = hidden_states[:, :4]
|
||||
hidden_states *= self.scaling_factor
|
||||
hidden_states = hidden_states.to(original_dtype)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -94,8 +91,7 @@ class SDVAEEncoder(torch.nn.Module):
|
||||
hidden_states = torch.concat(hidden_states, dim=2)
|
||||
return hidden_states
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
def state_dict_converter(self):
|
||||
return SDVAEEncoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -96,8 +96,7 @@ class SDXLIpAdapter(torch.nn.Module):
|
||||
}
|
||||
return ip_kv_dict
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
def state_dict_converter(self):
|
||||
return SDXLIpAdapterStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -49,8 +49,7 @@ class SDXLMotionModel(torch.nn.Module):
|
||||
def forward(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
def state_dict_converter(self):
|
||||
return SDMotionModelStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -36,8 +36,7 @@ class SDXLTextEncoder(torch.nn.Module):
|
||||
break
|
||||
return embeds
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
def state_dict_converter(self):
|
||||
return SDXLTextEncoderStateDictConverter()
|
||||
|
||||
|
||||
@@ -81,8 +80,7 @@ class SDXLTextEncoder2(torch.nn.Module):
|
||||
pooled_embeds = self.text_projection(pooled_embeds)
|
||||
return pooled_embeds, hidden_states
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
def state_dict_converter(self):
|
||||
return SDXLTextEncoder2StateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, PopBlock
|
||||
|
||||
|
||||
class SDXLUNet(torch.nn.Module):
|
||||
def __init__(self, is_kolors=False):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(320)
|
||||
self.time_embedding = torch.nn.Sequential(
|
||||
@@ -13,12 +13,11 @@ class SDXLUNet(torch.nn.Module):
|
||||
)
|
||||
self.add_time_proj = Timesteps(256)
|
||||
self.add_time_embedding = torch.nn.Sequential(
|
||||
torch.nn.Linear(5632 if is_kolors else 2816, 1280),
|
||||
torch.nn.Linear(2816, 1280),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(1280, 1280)
|
||||
)
|
||||
self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
|
||||
self.text_intermediate_proj = torch.nn.Linear(4096, 2048) if is_kolors else None
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
# DownBlock2D
|
||||
@@ -86,12 +85,10 @@ class SDXLUNet(torch.nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
sample, timestep, encoder_hidden_states, add_time_id, add_text_embeds,
|
||||
tiled=False, tile_size=64, tile_stride=8,
|
||||
use_gradient_checkpointing=False,
|
||||
**kwargs
|
||||
tiled=False, tile_size=64, tile_stride=8, **kwargs
|
||||
):
|
||||
# 1. time
|
||||
t_emb = self.time_proj(timestep).to(sample.dtype)
|
||||
t_emb = self.time_proj(timestep[None]).to(sample.dtype)
|
||||
t_emb = self.time_embedding(t_emb)
|
||||
|
||||
time_embeds = self.add_time_proj(add_time_id)
|
||||
@@ -105,26 +102,15 @@ class SDXLUNet(torch.nn.Module):
|
||||
# 2. pre-process
|
||||
height, width = sample.shape[2], sample.shape[3]
|
||||
hidden_states = self.conv_in(sample)
|
||||
text_emb = encoder_hidden_states if self.text_intermediate_proj is None else self.text_intermediate_proj(encoder_hidden_states)
|
||||
text_emb = encoder_hidden_states
|
||||
res_stack = [hidden_states]
|
||||
|
||||
# 3. blocks
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
for i, block in enumerate(self.blocks):
|
||||
if self.training and use_gradient_checkpointing and not (isinstance(block, PushBlock) or isinstance(block, PopBlock)):
|
||||
hidden_states, time_emb, text_emb, res_stack = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, time_emb, text_emb, res_stack,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states, time_emb, text_emb, res_stack = block(
|
||||
hidden_states, time_emb, text_emb, res_stack,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
||||
)
|
||||
hidden_states, time_emb, text_emb, res_stack = block(
|
||||
hidden_states, time_emb, text_emb, res_stack,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
||||
)
|
||||
|
||||
# 4. output
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
@@ -133,8 +119,7 @@ class SDXLUNet(torch.nn.Module):
|
||||
|
||||
return hidden_states
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
def state_dict_converter(self):
|
||||
return SDXLUNetStateDictConverter()
|
||||
|
||||
|
||||
@@ -163,8 +148,6 @@ class SDXLUNetStateDictConverter:
|
||||
names = name.split(".")
|
||||
if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
|
||||
pass
|
||||
elif names[0] in ["encoder_hid_proj"]:
|
||||
names[0] = "text_intermediate_proj"
|
||||
elif names[0] in ["time_embedding", "add_embedding"]:
|
||||
if names[0] == "add_embedding":
|
||||
names[0] = "add_time_embedding"
|
||||
@@ -198,10 +181,7 @@ class SDXLUNetStateDictConverter:
|
||||
if ".proj_in." in name or ".proj_out." in name:
|
||||
param = param.squeeze()
|
||||
state_dict_[rename_dict[name]] = param
|
||||
if "text_intermediate_proj.weight" in state_dict_:
|
||||
return state_dict_, {"is_kolors": True}
|
||||
else:
|
||||
return state_dict_
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
@@ -1893,7 +1873,4 @@ class SDXLUNetStateDictConverter:
|
||||
if ".proj_in." in name or ".proj_out." in name:
|
||||
param = param.squeeze()
|
||||
state_dict_[rename_dict[name]] = param
|
||||
if "text_intermediate_proj.weight" in state_dict_:
|
||||
return state_dict_, {"is_kolors": True}
|
||||
else:
|
||||
return state_dict_
|
||||
return state_dict_
|
||||
|
||||
@@ -2,23 +2,14 @@ from .sd_vae_decoder import SDVAEDecoder, SDVAEDecoderStateDictConverter
|
||||
|
||||
|
||||
class SDXLVAEDecoder(SDVAEDecoder):
|
||||
def __init__(self, upcast_to_float32=True):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.13025
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
def state_dict_converter(self):
|
||||
return SDXLVAEDecoderStateDictConverter()
|
||||
|
||||
|
||||
class SDXLVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict = super().from_diffusers(state_dict)
|
||||
return state_dict, {"upcast_to_float32": True}
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
state_dict = super().from_civitai(state_dict)
|
||||
return state_dict, {"upcast_to_float32": True}
|
||||
|
||||
@@ -2,23 +2,14 @@ from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder
|
||||
|
||||
|
||||
class SDXLVAEEncoder(SDVAEEncoder):
|
||||
def __init__(self, upcast_to_float32=True):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.13025
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
def state_dict_converter(self):
|
||||
return SDXLVAEEncoderStateDictConverter()
|
||||
|
||||
|
||||
class SDXLVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict = super().from_diffusers(state_dict)
|
||||
return state_dict, {"upcast_to_float32": True}
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
state_dict = super().from_civitai(state_dict)
|
||||
return state_dict, {"upcast_to_float32": True}
|
||||
|
||||
@@ -44,8 +44,7 @@ class SVDImageEncoder(torch.nn.Module):
|
||||
embeds = self.visual_projection(embeds)
|
||||
return embeds
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
def state_dict_converter(self):
|
||||
return SVDImageEncoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -407,8 +407,7 @@ class SVDUNet(torch.nn.Module):
|
||||
|
||||
return hidden_states
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
def state_dict_converter(self):
|
||||
return SVDUNetStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -199,8 +199,7 @@ class SVDVAEDecoder(torch.nn.Module):
|
||||
return values
|
||||
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
def state_dict_converter(self):
|
||||
return SVDVAEDecoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -6,8 +6,7 @@ class SVDVAEEncoder(SDVAEEncoder):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.13025
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
def state_dict_converter(self):
|
||||
return SVDVAEEncoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
from .sd_image import SDImagePipeline
|
||||
from .sd_video import SDVideoPipeline
|
||||
from .sdxl_image import SDXLImagePipeline
|
||||
from .sdxl_video import SDXLVideoPipeline
|
||||
from .sd3_image import SD3ImagePipeline
|
||||
from .hunyuan_image import HunyuanDiTImagePipeline
|
||||
from .svd_video import SVDVideoPipeline
|
||||
from .pipeline_runner import SDVideoPipelineRunner
|
||||
KolorsImagePipeline = SDXLImagePipeline
|
||||
from .stable_diffusion import SDImagePipeline
|
||||
from .stable_diffusion_xl import SDXLImagePipeline
|
||||
from .stable_diffusion_video import SDVideoPipeline, SDVideoPipelineRunner
|
||||
from .stable_diffusion_xl_video import SDXLVideoPipeline
|
||||
from .stable_video_diffusion import SVDVideoPipeline
|
||||
from .hunyuan_dit import HunyuanDiTImagePipeline
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
|
||||
class BasePipeline(torch.nn.Module):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
|
||||
|
||||
def preprocess_image(self, image):
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
|
||||
return image
|
||||
|
||||
|
||||
def preprocess_images(self, images):
|
||||
return [self.preprocess_image(image) for image in images]
|
||||
|
||||
|
||||
def vae_output_to_image(self, vae_output):
|
||||
image = vae_output[0].cpu().permute(1, 2, 0).numpy()
|
||||
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||
return image
|
||||
|
||||
|
||||
def vae_output_to_video(self, vae_output):
|
||||
video = vae_output.cpu().permute(1, 2, 0).numpy()
|
||||
video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video]
|
||||
return video
|
||||
|
||||
@@ -22,10 +22,6 @@ def lets_dance(
|
||||
device = "cuda",
|
||||
vram_limit_level = 0,
|
||||
):
|
||||
# 0. Text embedding alignment (only for video processing)
|
||||
if encoder_hidden_states.shape[0] != sample.shape[0]:
|
||||
encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0], 1, 1, 1)
|
||||
|
||||
# 1. ControlNet
|
||||
# This part will be repeated on overlapping frames if animatediff_batch_size > animatediff_stride.
|
||||
# I leave it here because I intend to do something interesting on the ControlNets.
|
||||
@@ -54,7 +50,7 @@ def lets_dance(
|
||||
additional_res_stack = None
|
||||
|
||||
# 2. time
|
||||
time_emb = unet.time_proj(timestep).to(sample.dtype)
|
||||
time_emb = unet.time_proj(timestep[None]).to(sample.dtype)
|
||||
time_emb = unet.time_embedding(time_emb)
|
||||
|
||||
# 3. pre-process
|
||||
@@ -137,7 +133,7 @@ def lets_dance_xl(
|
||||
vram_limit_level = 0,
|
||||
):
|
||||
# 2. time
|
||||
t_emb = unet.time_proj(timestep).to(sample.dtype)
|
||||
t_emb = unet.time_proj(timestep[None]).to(sample.dtype)
|
||||
t_emb = unet.time_embedding(t_emb)
|
||||
|
||||
time_embeds = unet.add_time_proj(add_time_id)
|
||||
@@ -151,7 +147,7 @@ def lets_dance_xl(
|
||||
# 3. pre-process
|
||||
height, width = sample.shape[2], sample.shape[3]
|
||||
hidden_states = unet.conv_in(sample)
|
||||
text_emb = encoder_hidden_states if unet.text_intermediate_proj is None else unet.text_intermediate_proj(encoder_hidden_states)
|
||||
text_emb = encoder_hidden_states
|
||||
res_stack = [hidden_states]
|
||||
|
||||
# 4. blocks
|
||||
|
||||
@@ -3,11 +3,11 @@ from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, Hunyuan
|
||||
from ..models.sdxl_vae_encoder import SDXLVAEEncoder
|
||||
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
|
||||
from ..models import ModelManager
|
||||
from ..prompters import HunyuanDiTPrompter
|
||||
from ..prompts import HunyuanDiTPrompter
|
||||
from ..schedulers import EnhancedDDIMScheduler
|
||||
from .base import BasePipeline
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
@@ -122,12 +122,14 @@ class ImageSizeManager:
|
||||
|
||||
|
||||
|
||||
class HunyuanDiTImagePipeline(BasePipeline):
|
||||
class HunyuanDiTImagePipeline(torch.nn.Module):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
super().__init__()
|
||||
self.scheduler = EnhancedDDIMScheduler(prediction_type="v_prediction", beta_start=0.00085, beta_end=0.03)
|
||||
self.prompter = HunyuanDiTPrompter()
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
self.image_size_manager = ImageSizeManager()
|
||||
# models
|
||||
self.text_encoder: HunyuanDiTCLIPTextEncoder = None
|
||||
@@ -137,60 +139,42 @@ class HunyuanDiTImagePipeline(BasePipeline):
|
||||
self.vae_encoder: SDXLVAEEncoder = None
|
||||
|
||||
|
||||
def denoising_model(self):
|
||||
return self.dit
|
||||
def fetch_main_models(self, model_manager: ModelManager):
|
||||
self.text_encoder = model_manager.hunyuan_dit_clip_text_encoder
|
||||
self.text_encoder_t5 = model_manager.hunyuan_dit_t5_text_encoder
|
||||
self.dit = model_manager.hunyuan_dit
|
||||
self.vae_decoder = model_manager.vae_decoder
|
||||
self.vae_encoder = model_manager.vae_encoder
|
||||
|
||||
|
||||
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
|
||||
# Main models
|
||||
self.text_encoder = model_manager.fetch_model("hunyuan_dit_clip_text_encoder")
|
||||
self.text_encoder_t5 = model_manager.fetch_model("hunyuan_dit_t5_text_encoder")
|
||||
self.dit = model_manager.fetch_model("hunyuan_dit")
|
||||
self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
|
||||
self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
|
||||
self.prompter.fetch_models(self.text_encoder, self.text_encoder_t5)
|
||||
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
|
||||
def fetch_prompter(self, model_manager: ModelManager):
|
||||
self.prompter.load_from_model_manager(model_manager)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]):
|
||||
def from_model_manager(model_manager: ModelManager):
|
||||
pipe = HunyuanDiTImagePipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_models(model_manager, prompt_refiner_classes)
|
||||
pipe.fetch_main_models(model_manager)
|
||||
pipe.fetch_prompter(model_manager)
|
||||
return pipe
|
||||
|
||||
|
||||
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
|
||||
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return latents
|
||||
|
||||
|
||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
image = self.vae_output_to_image(image)
|
||||
def preprocess_image(self, image):
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
|
||||
return image
|
||||
|
||||
|
||||
def encode_prompt(self, prompt, clip_skip=1, clip_skip_2=1, positive=True):
|
||||
text_emb, text_emb_mask, text_emb_t5, text_emb_mask_t5 = self.prompter.encode_prompt(
|
||||
prompt,
|
||||
clip_skip=clip_skip,
|
||||
clip_skip_2=clip_skip_2,
|
||||
positive=positive,
|
||||
device=self.device
|
||||
)
|
||||
return {
|
||||
"text_emb": text_emb,
|
||||
"text_emb_mask": text_emb_mask,
|
||||
"text_emb_t5": text_emb_t5,
|
||||
"text_emb_mask_t5": text_emb_mask_t5
|
||||
}
|
||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
image = image.cpu().permute(1, 2, 0).numpy()
|
||||
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||
return image
|
||||
|
||||
|
||||
def prepare_extra_input(self, latents=None, tiled=False, tile_size=64, tile_stride=32):
|
||||
batch_size, height, width = latents.shape[0], latents.shape[2] * 8, latents.shape[3] * 8
|
||||
def prepare_extra_input(self, height=1024, width=1024, tiled=False, tile_size=64, tile_stride=32, batch_size=1):
|
||||
if tiled:
|
||||
height, width = tile_size * 16, tile_size * 16
|
||||
image_meta_size = torch.as_tensor([width, height, width, height, 0, 0]).to(device=self.device)
|
||||
@@ -214,6 +198,7 @@ class HunyuanDiTImagePipeline(BasePipeline):
|
||||
clip_skip=1,
|
||||
clip_skip_2=1,
|
||||
input_image=None,
|
||||
reference_images=[],
|
||||
reference_strengths=[0.4],
|
||||
denoising_strength=1.0,
|
||||
height=1024,
|
||||
@@ -231,32 +216,71 @@ class HunyuanDiTImagePipeline(BasePipeline):
|
||||
# Prepare latent tensors
|
||||
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
if input_image is not None:
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=torch.float32)
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
latents = noise.clone()
|
||||
|
||||
# Prepare reference latents
|
||||
reference_latents = []
|
||||
for reference_image in reference_images:
|
||||
reference_image = self.preprocess_image(reference_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
reference_latents.append(self.vae_encoder(reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype))
|
||||
|
||||
# Encode prompts
|
||||
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
|
||||
prompt_emb_posi, attention_mask_posi, prompt_emb_t5_posi, attention_mask_t5_posi = self.prompter.encode_prompt(
|
||||
self.text_encoder,
|
||||
self.text_encoder_t5,
|
||||
prompt,
|
||||
clip_skip=clip_skip,
|
||||
clip_skip_2=clip_skip_2,
|
||||
positive=True,
|
||||
device=self.device
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
|
||||
prompt_emb_nega, attention_mask_nega, prompt_emb_t5_nega, attention_mask_t5_nega = self.prompter.encode_prompt(
|
||||
self.text_encoder,
|
||||
self.text_encoder_t5,
|
||||
negative_prompt,
|
||||
clip_skip=clip_skip,
|
||||
clip_skip_2=clip_skip_2,
|
||||
positive=False,
|
||||
device=self.device
|
||||
)
|
||||
|
||||
# Prepare positional id
|
||||
extra_input = self.prepare_extra_input(latents, tiled, tile_size)
|
||||
extra_input = self.prepare_extra_input(height, width, tiled, tile_size)
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = torch.tensor([timestep]).to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
# In-context reference
|
||||
for reference_latents_, reference_strength in zip(reference_latents, reference_strengths):
|
||||
if progress_id < num_inference_steps * reference_strength:
|
||||
noisy_reference_latents = self.scheduler.add_noise(reference_latents_, noise, self.scheduler.timesteps[progress_id])
|
||||
self.dit(
|
||||
noisy_reference_latents,
|
||||
prompt_emb_posi, prompt_emb_t5_posi, attention_mask_posi, attention_mask_t5_posi,
|
||||
timestep,
|
||||
**extra_input,
|
||||
to_cache=True
|
||||
)
|
||||
# Positive side
|
||||
noise_pred_posi = self.dit(
|
||||
latents, timestep=timestep, **prompt_emb_posi, **extra_input,
|
||||
latents,
|
||||
prompt_emb_posi, prompt_emb_t5_posi, attention_mask_posi, attention_mask_t5_posi,
|
||||
timestep,
|
||||
**extra_input,
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
# Negative side
|
||||
noise_pred_nega = self.dit(
|
||||
latents, timestep=timestep, **prompt_emb_nega, **extra_input,
|
||||
latents,
|
||||
prompt_emb_nega, prompt_emb_t5_nega, attention_mask_nega, attention_mask_t5_nega,
|
||||
timestep,
|
||||
**extra_input
|
||||
)
|
||||
# Classifier-free guidance
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
@@ -269,6 +293,6 @@ class HunyuanDiTImagePipeline(BasePipeline):
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
return image
|
||||
@@ -1,105 +0,0 @@
|
||||
import os, torch, json
|
||||
from .sd_video import ModelManager, SDVideoPipeline, ControlNetConfigUnit
|
||||
from ..processors.sequencial_processor import SequencialProcessor
|
||||
from ..data import VideoData, save_frames, save_video
|
||||
|
||||
|
||||
|
||||
class SDVideoPipelineRunner:
|
||||
def __init__(self, in_streamlit=False):
|
||||
self.in_streamlit = in_streamlit
|
||||
|
||||
|
||||
def load_pipeline(self, model_list, textual_inversion_folder, device, lora_alphas, controlnet_units):
|
||||
# Load models
|
||||
model_manager = ModelManager(torch_dtype=torch.float16, device=device)
|
||||
model_manager.load_models(model_list)
|
||||
pipe = SDVideoPipeline.from_model_manager(
|
||||
model_manager,
|
||||
[
|
||||
ControlNetConfigUnit(
|
||||
processor_id=unit["processor_id"],
|
||||
model_path=unit["model_path"],
|
||||
scale=unit["scale"]
|
||||
) for unit in controlnet_units
|
||||
]
|
||||
)
|
||||
textual_inversion_paths = []
|
||||
for file_name in os.listdir(textual_inversion_folder):
|
||||
if file_name.endswith(".pt") or file_name.endswith(".bin") or file_name.endswith(".pth") or file_name.endswith(".safetensors"):
|
||||
textual_inversion_paths.append(os.path.join(textual_inversion_folder, file_name))
|
||||
pipe.prompter.load_textual_inversions(textual_inversion_paths)
|
||||
return model_manager, pipe
|
||||
|
||||
|
||||
def load_smoother(self, model_manager, smoother_configs):
|
||||
smoother = SequencialProcessor.from_model_manager(model_manager, smoother_configs)
|
||||
return smoother
|
||||
|
||||
|
||||
def synthesize_video(self, model_manager, pipe, seed, smoother, **pipeline_inputs):
|
||||
torch.manual_seed(seed)
|
||||
if self.in_streamlit:
|
||||
import streamlit as st
|
||||
progress_bar_st = st.progress(0.0)
|
||||
output_video = pipe(**pipeline_inputs, smoother=smoother, progress_bar_st=progress_bar_st)
|
||||
progress_bar_st.progress(1.0)
|
||||
else:
|
||||
output_video = pipe(**pipeline_inputs, smoother=smoother)
|
||||
model_manager.to("cpu")
|
||||
return output_video
|
||||
|
||||
|
||||
def load_video(self, video_file, image_folder, height, width, start_frame_id, end_frame_id):
|
||||
video = VideoData(video_file=video_file, image_folder=image_folder, height=height, width=width)
|
||||
if start_frame_id is None:
|
||||
start_frame_id = 0
|
||||
if end_frame_id is None:
|
||||
end_frame_id = len(video)
|
||||
frames = [video[i] for i in range(start_frame_id, end_frame_id)]
|
||||
return frames
|
||||
|
||||
|
||||
def add_data_to_pipeline_inputs(self, data, pipeline_inputs):
|
||||
pipeline_inputs["input_frames"] = self.load_video(**data["input_frames"])
|
||||
pipeline_inputs["num_frames"] = len(pipeline_inputs["input_frames"])
|
||||
pipeline_inputs["width"], pipeline_inputs["height"] = pipeline_inputs["input_frames"][0].size
|
||||
if len(data["controlnet_frames"]) > 0:
|
||||
pipeline_inputs["controlnet_frames"] = [self.load_video(**unit) for unit in data["controlnet_frames"]]
|
||||
return pipeline_inputs
|
||||
|
||||
|
||||
def save_output(self, video, output_folder, fps, config):
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
save_frames(video, os.path.join(output_folder, "frames"))
|
||||
save_video(video, os.path.join(output_folder, "video.mp4"), fps=fps)
|
||||
config["pipeline"]["pipeline_inputs"]["input_frames"] = []
|
||||
config["pipeline"]["pipeline_inputs"]["controlnet_frames"] = []
|
||||
with open(os.path.join(output_folder, "config.json"), 'w') as file:
|
||||
json.dump(config, file, indent=4)
|
||||
|
||||
|
||||
def run(self, config):
|
||||
if self.in_streamlit:
|
||||
import streamlit as st
|
||||
if self.in_streamlit: st.markdown("Loading videos ...")
|
||||
config["pipeline"]["pipeline_inputs"] = self.add_data_to_pipeline_inputs(config["data"], config["pipeline"]["pipeline_inputs"])
|
||||
if self.in_streamlit: st.markdown("Loading videos ... done!")
|
||||
if self.in_streamlit: st.markdown("Loading models ...")
|
||||
model_manager, pipe = self.load_pipeline(**config["models"])
|
||||
if self.in_streamlit: st.markdown("Loading models ... done!")
|
||||
if "smoother_configs" in config:
|
||||
if self.in_streamlit: st.markdown("Loading smoother ...")
|
||||
smoother = self.load_smoother(model_manager, config["smoother_configs"])
|
||||
if self.in_streamlit: st.markdown("Loading smoother ... done!")
|
||||
else:
|
||||
smoother = None
|
||||
if self.in_streamlit: st.markdown("Synthesizing videos ...")
|
||||
output_video = self.synthesize_video(model_manager, pipe, config["pipeline"]["seed"], smoother, **config["pipeline"]["pipeline_inputs"])
|
||||
if self.in_streamlit: st.markdown("Synthesizing videos ... done!")
|
||||
if self.in_streamlit: st.markdown("Saving videos ...")
|
||||
self.save_output(output_video, config["data"]["output_folder"], config["data"]["fps"], config)
|
||||
if self.in_streamlit: st.markdown("Saving videos ... done!")
|
||||
if self.in_streamlit: st.markdown("Finished!")
|
||||
video_file = open(os.path.join(os.path.join(config["data"]["output_folder"], "video.mp4")), 'rb')
|
||||
if self.in_streamlit: st.video(video_file.read())
|
||||
@@ -1,132 +0,0 @@
|
||||
from ..models import ModelManager, SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEDecoder, SD3VAEEncoder
|
||||
from ..prompters import SD3Prompter
|
||||
from ..schedulers import FlowMatchScheduler
|
||||
from .base import BasePipeline
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
|
||||
class SD3ImagePipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
self.scheduler = FlowMatchScheduler()
|
||||
self.prompter = SD3Prompter()
|
||||
# models
|
||||
self.text_encoder_1: SD3TextEncoder1 = None
|
||||
self.text_encoder_2: SD3TextEncoder2 = None
|
||||
self.text_encoder_3: SD3TextEncoder3 = None
|
||||
self.dit: SD3DiT = None
|
||||
self.vae_decoder: SD3VAEDecoder = None
|
||||
self.vae_encoder: SD3VAEEncoder = None
|
||||
|
||||
|
||||
def denoising_model(self):
|
||||
return self.dit
|
||||
|
||||
|
||||
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
|
||||
self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
|
||||
self.text_encoder_2 = model_manager.fetch_model("sd3_text_encoder_2")
|
||||
if "sd3_text_encoder_3" in model_manager.model:
|
||||
self.text_encoder_3 = model_manager.fetch_model("sd3_text_encoder_3")
|
||||
self.dit = model_manager.fetch_model("sd3_dit")
|
||||
self.vae_decoder = model_manager.fetch_model("sd3_vae_decoder")
|
||||
self.vae_encoder = model_manager.fetch_model("sd3_vae_encoder")
|
||||
self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2, self.text_encoder_3)
|
||||
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]):
|
||||
pipe = SD3ImagePipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_models(model_manager, prompt_refiner_classes)
|
||||
return pipe
|
||||
|
||||
|
||||
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
|
||||
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return latents
|
||||
|
||||
|
||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
image = self.vae_output_to_image(image)
|
||||
return image
|
||||
|
||||
|
||||
def encode_prompt(self, prompt, positive=True):
|
||||
prompt_emb, pooled_prompt_emb = self.prompter.encode_prompt(
|
||||
prompt, device=self.device, positive=positive
|
||||
)
|
||||
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb}
|
||||
|
||||
|
||||
def prepare_extra_input(self, latents=None):
|
||||
return {}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
cfg_scale=7.5,
|
||||
input_image=None,
|
||||
denoising_strength=1.0,
|
||||
height=1024,
|
||||
width=1024,
|
||||
num_inference_steps=20,
|
||||
tiled=False,
|
||||
tile_size=128,
|
||||
tile_stride=64,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
if input_image is not None:
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.encode_image(image, **tiler_kwargs)
|
||||
noise = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# Encode prompts
|
||||
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
|
||||
# Classifier-free guidance
|
||||
noise_pred_posi = self.dit(
|
||||
latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs,
|
||||
)
|
||||
noise_pred_nega = self.dit(
|
||||
latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs,
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
|
||||
# DDIM
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
|
||||
# UI
|
||||
if progress_bar_st is not None:
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
return image
|
||||
@@ -1,266 +0,0 @@
|
||||
from ..models import SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder, SDMotionModel
|
||||
from ..models.model_manager import ModelManager
|
||||
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
||||
from ..prompters import SDPrompter
|
||||
from ..schedulers import EnhancedDDIMScheduler
|
||||
from .sd_image import SDImagePipeline
|
||||
from .dancer import lets_dance
|
||||
from typing import List
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
|
||||
def lets_dance_with_long_video(
|
||||
unet: SDUNet,
|
||||
motion_modules: SDMotionModel = None,
|
||||
controlnet: MultiControlNetManager = None,
|
||||
sample = None,
|
||||
timestep = None,
|
||||
encoder_hidden_states = None,
|
||||
ipadapter_kwargs_list = {},
|
||||
controlnet_frames = None,
|
||||
unet_batch_size = 1,
|
||||
controlnet_batch_size = 1,
|
||||
cross_frame_attention = False,
|
||||
tiled=False,
|
||||
tile_size=64,
|
||||
tile_stride=32,
|
||||
device="cuda",
|
||||
animatediff_batch_size=16,
|
||||
animatediff_stride=8,
|
||||
):
|
||||
num_frames = sample.shape[0]
|
||||
hidden_states_output = [(torch.zeros(sample[0].shape, dtype=sample[0].dtype), 0) for i in range(num_frames)]
|
||||
|
||||
for batch_id in range(0, num_frames, animatediff_stride):
|
||||
batch_id_ = min(batch_id + animatediff_batch_size, num_frames)
|
||||
|
||||
# process this batch
|
||||
hidden_states_batch = lets_dance(
|
||||
unet, motion_modules, controlnet,
|
||||
sample[batch_id: batch_id_].to(device),
|
||||
timestep,
|
||||
encoder_hidden_states,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list,
|
||||
controlnet_frames=controlnet_frames[:, batch_id: batch_id_].to(device) if controlnet_frames is not None else None,
|
||||
unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
|
||||
cross_frame_attention=cross_frame_attention,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, device=device
|
||||
).cpu()
|
||||
|
||||
# update hidden_states
|
||||
for i, hidden_states_updated in zip(range(batch_id, batch_id_), hidden_states_batch):
|
||||
bias = max(1 - abs(i - (batch_id + batch_id_ - 1) / 2) / ((batch_id_ - batch_id - 1 + 1e-2) / 2), 1e-2)
|
||||
hidden_states, num = hidden_states_output[i]
|
||||
hidden_states = hidden_states * (num / (num + bias)) + hidden_states_updated * (bias / (num + bias))
|
||||
hidden_states_output[i] = (hidden_states, num + bias)
|
||||
|
||||
if batch_id_ == num_frames:
|
||||
break
|
||||
|
||||
# output
|
||||
hidden_states = torch.stack([h for h, _ in hidden_states_output])
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
class SDVideoPipeline(SDImagePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16, use_original_animatediff=True):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_original_animatediff else "scaled_linear")
|
||||
self.prompter = SDPrompter()
|
||||
# models
|
||||
self.text_encoder: SDTextEncoder = None
|
||||
self.unet: SDUNet = None
|
||||
self.vae_decoder: SDVAEDecoder = None
|
||||
self.vae_encoder: SDVAEEncoder = None
|
||||
self.controlnet: MultiControlNetManager = None
|
||||
self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None
|
||||
self.ipadapter: SDIpAdapter = None
|
||||
self.motion_modules: SDMotionModel = None
|
||||
|
||||
|
||||
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
|
||||
# Main models
|
||||
self.text_encoder = model_manager.fetch_model("sd_text_encoder")
|
||||
self.unet = model_manager.fetch_model("sd_unet")
|
||||
self.vae_decoder = model_manager.fetch_model("sd_vae_decoder")
|
||||
self.vae_encoder = model_manager.fetch_model("sd_vae_encoder")
|
||||
self.prompter.fetch_models(self.text_encoder)
|
||||
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
|
||||
|
||||
# ControlNets
|
||||
controlnet_units = []
|
||||
for config in controlnet_config_units:
|
||||
controlnet_unit = ControlNetUnit(
|
||||
Annotator(config.processor_id, device=self.device),
|
||||
model_manager.fetch_model("sd_controlnet", config.model_path),
|
||||
config.scale
|
||||
)
|
||||
controlnet_units.append(controlnet_unit)
|
||||
self.controlnet = MultiControlNetManager(controlnet_units)
|
||||
|
||||
# IP-Adapters
|
||||
self.ipadapter = model_manager.fetch_model("sd_ipadapter")
|
||||
self.ipadapter_image_encoder = model_manager.fetch_model("sd_ipadapter_clip_image_encoder")
|
||||
|
||||
# Motion Modules
|
||||
self.motion_modules = model_manager.fetch_model("sd_motion_modules")
|
||||
if self.motion_modules is None:
|
||||
self.scheduler = EnhancedDDIMScheduler(beta_schedule="scaled_linear")
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
|
||||
pipe = SDVideoPipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
|
||||
return pipe
|
||||
|
||||
|
||||
def decode_video(self, latents, tiled=False, tile_size=64, tile_stride=32):
|
||||
images = [
|
||||
self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
for frame_id in range(latents.shape[0])
|
||||
]
|
||||
return images
|
||||
|
||||
|
||||
def encode_video(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
|
||||
latents = []
|
||||
for image in processed_images:
|
||||
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latent = self.encode_image(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
latents.append(latent.cpu())
|
||||
latents = torch.concat(latents, dim=0)
|
||||
return latents
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
cfg_scale=7.5,
|
||||
clip_skip=1,
|
||||
num_frames=None,
|
||||
input_frames=None,
|
||||
ipadapter_images=None,
|
||||
ipadapter_scale=1.0,
|
||||
controlnet_frames=None,
|
||||
denoising_strength=1.0,
|
||||
height=512,
|
||||
width=512,
|
||||
num_inference_steps=20,
|
||||
animatediff_batch_size = 16,
|
||||
animatediff_stride = 8,
|
||||
unet_batch_size = 1,
|
||||
controlnet_batch_size = 1,
|
||||
cross_frame_attention = False,
|
||||
smoother=None,
|
||||
smoother_progress_ids=[],
|
||||
tiled=False,
|
||||
tile_size=64,
|
||||
tile_stride=32,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Tiler parameters, batch size ...
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
other_kwargs = {
|
||||
"animatediff_batch_size": animatediff_batch_size, "animatediff_stride": animatediff_stride,
|
||||
"unet_batch_size": unet_batch_size, "controlnet_batch_size": controlnet_batch_size,
|
||||
"cross_frame_attention": cross_frame_attention,
|
||||
}
|
||||
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
if self.motion_modules is None:
|
||||
noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
|
||||
else:
|
||||
noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype)
|
||||
if input_frames is None or denoising_strength == 1.0:
|
||||
latents = noise
|
||||
else:
|
||||
latents = self.encode_video(input_frames, **tiler_kwargs)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
|
||||
# Encode prompts
|
||||
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)
|
||||
|
||||
# IP-Adapter
|
||||
if ipadapter_images is not None:
|
||||
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
|
||||
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
|
||||
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
|
||||
else:
|
||||
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
|
||||
|
||||
# Prepare ControlNets
|
||||
if controlnet_frames is not None:
|
||||
if isinstance(controlnet_frames[0], list):
|
||||
controlnet_frames_ = []
|
||||
for processor_id in range(len(controlnet_frames)):
|
||||
controlnet_frames_.append(
|
||||
torch.stack([
|
||||
self.controlnet.process_image(controlnet_frame, processor_id=processor_id).to(self.torch_dtype)
|
||||
for controlnet_frame in progress_bar_cmd(controlnet_frames[processor_id])
|
||||
], dim=1)
|
||||
)
|
||||
controlnet_frames = torch.concat(controlnet_frames_, dim=0)
|
||||
else:
|
||||
controlnet_frames = torch.stack([
|
||||
self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
|
||||
for controlnet_frame in progress_bar_cmd(controlnet_frames)
|
||||
], dim=1)
|
||||
controlnet_kwargs = {"controlnet_frames": controlnet_frames}
|
||||
else:
|
||||
controlnet_kwargs = {"controlnet_frames": None}
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
|
||||
# Classifier-free guidance
|
||||
noise_pred_posi = lets_dance_with_long_video(
|
||||
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
|
||||
sample=latents, timestep=timestep,
|
||||
**prompt_emb_posi, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **other_kwargs, **tiler_kwargs,
|
||||
device=self.device,
|
||||
)
|
||||
noise_pred_nega = lets_dance_with_long_video(
|
||||
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
|
||||
sample=latents, timestep=timestep,
|
||||
**prompt_emb_nega, **controlnet_kwargs, **ipadapter_kwargs_list_nega, **other_kwargs, **tiler_kwargs,
|
||||
device=self.device,
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
|
||||
# DDIM and smoother
|
||||
if smoother is not None and progress_id in smoother_progress_ids:
|
||||
rendered_frames = self.scheduler.step(noise_pred, timestep, latents, to_final=True)
|
||||
rendered_frames = self.decode_video(rendered_frames)
|
||||
rendered_frames = smoother(rendered_frames, original_frames=input_frames)
|
||||
target_latents = self.encode_video(rendered_frames)
|
||||
noise_pred = self.scheduler.return_to_timestep(timestep, latents, target_latents)
|
||||
latents = self.scheduler.step(noise_pred, timestep, latents)
|
||||
|
||||
# UI
|
||||
if progress_bar_st is not None:
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
output_frames = self.decode_video(latents, **tiler_kwargs)
|
||||
|
||||
# Post-process
|
||||
if smoother is not None and (num_inference_steps in smoother_progress_ids or -1 in smoother_progress_ids):
|
||||
output_frames = smoother(output_frames, original_frames=input_frames)
|
||||
|
||||
return output_frames
|
||||
@@ -1,191 +0,0 @@
|
||||
from ..models import SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
||||
from ..models.kolors_text_encoder import ChatGLMModel
|
||||
from ..models.model_manager import ModelManager
|
||||
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
||||
from ..prompters import SDXLPrompter, KolorsPrompter
|
||||
from ..schedulers import EnhancedDDIMScheduler
|
||||
from .base import BasePipeline
|
||||
from .dancer import lets_dance_xl
|
||||
from typing import List
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
|
||||
class SDXLImagePipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
self.scheduler = EnhancedDDIMScheduler()
|
||||
self.prompter = SDXLPrompter()
|
||||
# models
|
||||
self.text_encoder: SDXLTextEncoder = None
|
||||
self.text_encoder_2: SDXLTextEncoder2 = None
|
||||
self.text_encoder_kolors: ChatGLMModel = None
|
||||
self.unet: SDXLUNet = None
|
||||
self.vae_decoder: SDXLVAEDecoder = None
|
||||
self.vae_encoder: SDXLVAEEncoder = None
|
||||
# self.controlnet: MultiControlNetManager = None (TODO)
|
||||
self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
|
||||
self.ipadapter: SDXLIpAdapter = None
|
||||
|
||||
|
||||
def denoising_model(self):
|
||||
return self.unet
|
||||
|
||||
|
||||
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
|
||||
# Main models
|
||||
self.text_encoder = model_manager.fetch_model("sdxl_text_encoder")
|
||||
self.text_encoder_2 = model_manager.fetch_model("sdxl_text_encoder_2")
|
||||
self.text_encoder_kolors = model_manager.fetch_model("kolors_text_encoder")
|
||||
self.unet = model_manager.fetch_model("sdxl_unet")
|
||||
self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
|
||||
self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
|
||||
|
||||
# ControlNets (TODO)
|
||||
|
||||
# IP-Adapters
|
||||
self.ipadapter = model_manager.fetch_model("sdxl_ipadapter")
|
||||
self.ipadapter_image_encoder = model_manager.fetch_model("sdxl_ipadapter_clip_image_encoder")
|
||||
|
||||
# Kolors
|
||||
if self.text_encoder_kolors is not None:
|
||||
print("Switch to Kolors. The prompter and scheduler will be replaced.")
|
||||
self.prompter = KolorsPrompter()
|
||||
self.prompter.fetch_models(self.text_encoder_kolors)
|
||||
self.scheduler = EnhancedDDIMScheduler(beta_end=0.014, num_train_timesteps=1100)
|
||||
else:
|
||||
self.prompter.fetch_models(self.text_encoder, self.text_encoder_2)
|
||||
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
|
||||
pipe = SDXLImagePipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
|
||||
return pipe
|
||||
|
||||
|
||||
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
|
||||
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return latents
|
||||
|
||||
|
||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
image = self.vae_output_to_image(image)
|
||||
return image
|
||||
|
||||
|
||||
def encode_prompt(self, prompt, clip_skip=1, clip_skip_2=2, positive=True):
|
||||
add_prompt_emb, prompt_emb = self.prompter.encode_prompt(
|
||||
prompt,
|
||||
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
|
||||
device=self.device,
|
||||
positive=positive,
|
||||
)
|
||||
return {"encoder_hidden_states": prompt_emb, "add_text_embeds": add_prompt_emb}
|
||||
|
||||
|
||||
def prepare_extra_input(self, latents=None):
|
||||
height, width = latents.shape[2] * 8, latents.shape[3] * 8
|
||||
return {"add_time_id": torch.tensor([height, width, 0, 0, height, width], device=self.device)}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
cfg_scale=7.5,
|
||||
clip_skip=1,
|
||||
clip_skip_2=2,
|
||||
input_image=None,
|
||||
ipadapter_images=None,
|
||||
ipadapter_scale=1.0,
|
||||
ipadapter_use_instant_style=False,
|
||||
controlnet_image=None,
|
||||
denoising_strength=1.0,
|
||||
height=1024,
|
||||
width=1024,
|
||||
num_inference_steps=20,
|
||||
tiled=False,
|
||||
tile_size=64,
|
||||
tile_stride=32,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
if input_image is not None:
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.encode_image(image, **tiler_kwargs)
|
||||
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# Encode prompts
|
||||
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=False)
|
||||
|
||||
# IP-Adapter
|
||||
if ipadapter_images is not None:
|
||||
if ipadapter_use_instant_style:
|
||||
self.ipadapter.set_less_adapter()
|
||||
else:
|
||||
self.ipadapter.set_full_adapter()
|
||||
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
|
||||
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
|
||||
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
|
||||
else:
|
||||
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
|
||||
|
||||
# Prepare ControlNets (TODO)
|
||||
controlnet_kwargs = {"controlnet_frames": None}
|
||||
|
||||
# Prepare extra input
|
||||
extra_input = self.prepare_extra_input(latents)
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
|
||||
# Classifier-free guidance
|
||||
noise_pred_posi = lets_dance_xl(
|
||||
self.unet, motion_modules=None, controlnet=None,
|
||||
sample=latents, timestep=timestep, **extra_input,
|
||||
**prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi,
|
||||
device=self.device,
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = lets_dance_xl(
|
||||
self.unet, motion_modules=None, controlnet=None,
|
||||
sample=latents, timestep=timestep, **extra_input,
|
||||
**prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega,
|
||||
device=self.device,
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
# DDIM
|
||||
latents = self.scheduler.step(noise_pred, timestep, latents)
|
||||
|
||||
# UI
|
||||
if progress_bar_st is not None:
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
return image
|
||||
@@ -1,223 +0,0 @@
|
||||
from ..models import SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder, SDXLMotionModel
|
||||
from ..models.kolors_text_encoder import ChatGLMModel
|
||||
from ..models.model_manager import ModelManager
|
||||
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
||||
from ..prompters import SDXLPrompter, KolorsPrompter
|
||||
from ..schedulers import EnhancedDDIMScheduler
|
||||
from .sdxl_image import SDXLImagePipeline
|
||||
from .dancer import lets_dance_xl
|
||||
from typing import List
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
|
||||
class SDXLVideoPipeline(SDXLImagePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16, use_original_animatediff=True):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_original_animatediff else "scaled_linear")
|
||||
self.prompter = SDXLPrompter()
|
||||
# models
|
||||
self.text_encoder: SDXLTextEncoder = None
|
||||
self.text_encoder_2: SDXLTextEncoder2 = None
|
||||
self.text_encoder_kolors: ChatGLMModel = None
|
||||
self.unet: SDXLUNet = None
|
||||
self.vae_decoder: SDXLVAEDecoder = None
|
||||
self.vae_encoder: SDXLVAEEncoder = None
|
||||
# self.controlnet: MultiControlNetManager = None (TODO)
|
||||
self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
|
||||
self.ipadapter: SDXLIpAdapter = None
|
||||
self.motion_modules: SDXLMotionModel = None
|
||||
|
||||
|
||||
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
|
||||
# Main models
|
||||
self.text_encoder = model_manager.fetch_model("sdxl_text_encoder")
|
||||
self.text_encoder_2 = model_manager.fetch_model("sdxl_text_encoder_2")
|
||||
self.text_encoder_kolors = model_manager.fetch_model("kolors_text_encoder")
|
||||
self.unet = model_manager.fetch_model("sdxl_unet")
|
||||
self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
|
||||
self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
|
||||
self.prompter.fetch_models(self.text_encoder)
|
||||
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
|
||||
|
||||
# ControlNets (TODO)
|
||||
|
||||
# IP-Adapters
|
||||
self.ipadapter = model_manager.fetch_model("sdxl_ipadapter")
|
||||
self.ipadapter_image_encoder = model_manager.fetch_model("sdxl_ipadapter_clip_image_encoder")
|
||||
|
||||
# Motion Modules
|
||||
self.motion_modules = model_manager.fetch_model("sdxl_motion_modules")
|
||||
if self.motion_modules is None:
|
||||
self.scheduler = EnhancedDDIMScheduler(beta_schedule="scaled_linear")
|
||||
|
||||
# Kolors
|
||||
if self.text_encoder_kolors is not None:
|
||||
print("Switch to Kolors. The prompter will be replaced.")
|
||||
self.prompter = KolorsPrompter()
|
||||
self.prompter.fetch_models(self.text_encoder_kolors)
|
||||
# The schedulers of AniamteDiff and Kolors are incompatible. We align it with AniamteDiff.
|
||||
if self.motion_modules is None:
|
||||
self.scheduler = EnhancedDDIMScheduler(beta_end=0.014, num_train_timesteps=1100)
|
||||
else:
|
||||
self.prompter.fetch_models(self.text_encoder, self.text_encoder_2)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
|
||||
pipe = SDXLVideoPipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
|
||||
return pipe
|
||||
|
||||
|
||||
def decode_video(self, latents, tiled=False, tile_size=64, tile_stride=32):
|
||||
images = [
|
||||
self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
for frame_id in range(latents.shape[0])
|
||||
]
|
||||
return images
|
||||
|
||||
|
||||
def encode_video(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
|
||||
latents = []
|
||||
for image in processed_images:
|
||||
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latent = self.encode_image(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
latents.append(latent.cpu())
|
||||
latents = torch.concat(latents, dim=0)
|
||||
return latents
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
cfg_scale=7.5,
|
||||
clip_skip=1,
|
||||
num_frames=None,
|
||||
input_frames=None,
|
||||
ipadapter_images=None,
|
||||
ipadapter_scale=1.0,
|
||||
ipadapter_use_instant_style=False,
|
||||
controlnet_frames=None,
|
||||
denoising_strength=1.0,
|
||||
height=512,
|
||||
width=512,
|
||||
num_inference_steps=20,
|
||||
animatediff_batch_size = 16,
|
||||
animatediff_stride = 8,
|
||||
unet_batch_size = 1,
|
||||
controlnet_batch_size = 1,
|
||||
cross_frame_attention = False,
|
||||
smoother=None,
|
||||
smoother_progress_ids=[],
|
||||
tiled=False,
|
||||
tile_size=64,
|
||||
tile_stride=32,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Tiler parameters, batch size ...
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
if self.motion_modules is None:
|
||||
noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
|
||||
else:
|
||||
noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype)
|
||||
if input_frames is None or denoising_strength == 1.0:
|
||||
latents = noise
|
||||
else:
|
||||
latents = self.encode_video(input_frames, **tiler_kwargs)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
latents = latents.to(self.device) # will be deleted for supporting long videos
|
||||
|
||||
# Encode prompts
|
||||
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)
|
||||
|
||||
# IP-Adapter
|
||||
if ipadapter_images is not None:
|
||||
if ipadapter_use_instant_style:
|
||||
self.ipadapter.set_less_adapter()
|
||||
else:
|
||||
self.ipadapter.set_full_adapter()
|
||||
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
|
||||
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
|
||||
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
|
||||
else:
|
||||
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
|
||||
|
||||
# Prepare ControlNets
|
||||
if controlnet_frames is not None:
|
||||
if isinstance(controlnet_frames[0], list):
|
||||
controlnet_frames_ = []
|
||||
for processor_id in range(len(controlnet_frames)):
|
||||
controlnet_frames_.append(
|
||||
torch.stack([
|
||||
self.controlnet.process_image(controlnet_frame, processor_id=processor_id).to(self.torch_dtype)
|
||||
for controlnet_frame in progress_bar_cmd(controlnet_frames[processor_id])
|
||||
], dim=1)
|
||||
)
|
||||
controlnet_frames = torch.concat(controlnet_frames_, dim=0)
|
||||
else:
|
||||
controlnet_frames = torch.stack([
|
||||
self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
|
||||
for controlnet_frame in progress_bar_cmd(controlnet_frames)
|
||||
], dim=1)
|
||||
controlnet_kwargs = {"controlnet_frames": controlnet_frames}
|
||||
else:
|
||||
controlnet_kwargs = {"controlnet_frames": None}
|
||||
|
||||
# Prepare extra input
|
||||
extra_input = self.prepare_extra_input(latents)
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
|
||||
# Classifier-free guidance
|
||||
noise_pred_posi = lets_dance_xl(
|
||||
self.unet, motion_modules=self.motion_modules, controlnet=None,
|
||||
sample=latents, timestep=timestep,
|
||||
**prompt_emb_posi, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **extra_input, **tiler_kwargs,
|
||||
device=self.device,
|
||||
)
|
||||
noise_pred_nega = lets_dance_xl(
|
||||
self.unet, motion_modules=self.motion_modules, controlnet=None,
|
||||
sample=latents, timestep=timestep,
|
||||
**prompt_emb_nega, **controlnet_kwargs, **ipadapter_kwargs_list_nega, **extra_input, **tiler_kwargs,
|
||||
device=self.device,
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
|
||||
# DDIM and smoother
|
||||
if smoother is not None and progress_id in smoother_progress_ids:
|
||||
rendered_frames = self.scheduler.step(noise_pred, timestep, latents, to_final=True)
|
||||
rendered_frames = self.decode_video(rendered_frames)
|
||||
rendered_frames = smoother(rendered_frames, original_frames=input_frames)
|
||||
target_latents = self.encode_video(rendered_frames)
|
||||
noise_pred = self.scheduler.return_to_timestep(timestep, latents, target_latents)
|
||||
latents = self.scheduler.step(noise_pred, timestep, latents)
|
||||
|
||||
# UI
|
||||
if progress_bar_st is not None:
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
output_frames = self.decode_video(latents, **tiler_kwargs)
|
||||
|
||||
# Post-process
|
||||
if smoother is not None and (num_inference_steps in smoother_progress_ids or -1 in smoother_progress_ids):
|
||||
output_frames = smoother(output_frames, original_frames=input_frames)
|
||||
|
||||
return output_frames
|
||||
@@ -1,22 +1,23 @@
|
||||
from ..models import SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder
|
||||
from ..models.model_manager import ModelManager
|
||||
from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder
|
||||
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
||||
from ..prompters import SDPrompter
|
||||
from ..prompts import SDPrompter
|
||||
from ..schedulers import EnhancedDDIMScheduler
|
||||
from .base import BasePipeline
|
||||
from .dancer import lets_dance
|
||||
from typing import List
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
||||
class SDImagePipeline(BasePipeline):
|
||||
class SDImagePipeline(torch.nn.Module):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
super().__init__()
|
||||
self.scheduler = EnhancedDDIMScheduler()
|
||||
self.prompter = SDPrompter()
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
# models
|
||||
self.text_encoder: SDTextEncoder = None
|
||||
self.unet: SDUNet = None
|
||||
@@ -27,63 +28,59 @@ class SDImagePipeline(BasePipeline):
|
||||
self.ipadapter: SDIpAdapter = None
|
||||
|
||||
|
||||
def denoising_model(self):
|
||||
return self.unet
|
||||
def fetch_main_models(self, model_manager: ModelManager):
|
||||
self.text_encoder = model_manager.text_encoder
|
||||
self.unet = model_manager.unet
|
||||
self.vae_decoder = model_manager.vae_decoder
|
||||
self.vae_encoder = model_manager.vae_encoder
|
||||
|
||||
|
||||
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
|
||||
# Main models
|
||||
self.text_encoder = model_manager.fetch_model("sd_text_encoder")
|
||||
self.unet = model_manager.fetch_model("sd_unet")
|
||||
self.vae_decoder = model_manager.fetch_model("sd_vae_decoder")
|
||||
self.vae_encoder = model_manager.fetch_model("sd_vae_encoder")
|
||||
self.prompter.fetch_models(self.text_encoder)
|
||||
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
|
||||
|
||||
# ControlNets
|
||||
def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
|
||||
controlnet_units = []
|
||||
for config in controlnet_config_units:
|
||||
controlnet_unit = ControlNetUnit(
|
||||
Annotator(config.processor_id, device=self.device),
|
||||
model_manager.fetch_model("sd_controlnet", config.model_path),
|
||||
Annotator(config.processor_id),
|
||||
model_manager.get_model_with_model_path(config.model_path),
|
||||
config.scale
|
||||
)
|
||||
controlnet_units.append(controlnet_unit)
|
||||
self.controlnet = MultiControlNetManager(controlnet_units)
|
||||
|
||||
# IP-Adapters
|
||||
self.ipadapter = model_manager.fetch_model("sd_ipadapter")
|
||||
self.ipadapter_image_encoder = model_manager.fetch_model("sd_ipadapter_clip_image_encoder")
|
||||
|
||||
def fetch_ipadapter(self, model_manager: ModelManager):
|
||||
if "ipadapter" in model_manager.model:
|
||||
self.ipadapter = model_manager.ipadapter
|
||||
if "ipadapter_image_encoder" in model_manager.model:
|
||||
self.ipadapter_image_encoder = model_manager.ipadapter_image_encoder
|
||||
|
||||
|
||||
def fetch_prompter(self, model_manager: ModelManager):
|
||||
self.prompter.load_from_model_manager(model_manager)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
|
||||
pipe = SDImagePipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes=[])
|
||||
pipe.fetch_main_models(model_manager)
|
||||
pipe.fetch_prompter(model_manager)
|
||||
pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
|
||||
pipe.fetch_ipadapter(model_manager)
|
||||
return pipe
|
||||
|
||||
|
||||
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
|
||||
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return latents
|
||||
|
||||
|
||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
image = self.vae_output_to_image(image)
|
||||
def preprocess_image(self, image):
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
|
||||
return image
|
||||
|
||||
|
||||
def encode_prompt(self, prompt, clip_skip=1, positive=True):
|
||||
prompt_emb = self.prompter.encode_prompt(prompt, clip_skip=clip_skip, device=self.device, positive=positive)
|
||||
return {"encoder_hidden_states": prompt_emb}
|
||||
|
||||
|
||||
def prepare_extra_input(self, latents=None):
|
||||
return {}
|
||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
image = image.cpu().permute(1, 2, 0).numpy()
|
||||
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||
return image
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -107,56 +104,53 @@ class SDImagePipeline(BasePipeline):
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
if input_image is not None:
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.encode_image(image, **tiler_kwargs)
|
||||
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# Encode prompts
|
||||
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)
|
||||
prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True)
|
||||
prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False)
|
||||
|
||||
# IP-Adapter
|
||||
if ipadapter_images is not None:
|
||||
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
|
||||
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
|
||||
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
|
||||
ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)
|
||||
ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding))
|
||||
else:
|
||||
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
|
||||
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {}
|
||||
|
||||
# Prepare ControlNets
|
||||
if controlnet_image is not None:
|
||||
controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
controlnet_image = controlnet_image.unsqueeze(1)
|
||||
controlnet_kwargs = {"controlnet_frames": controlnet_image}
|
||||
else:
|
||||
controlnet_kwargs = {"controlnet_frames": None}
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
timestep = torch.IntTensor((timestep,))[0].to(self.device)
|
||||
|
||||
# Classifier-free guidance
|
||||
noise_pred_posi = lets_dance(
|
||||
self.unet, motion_modules=None, controlnet=self.controlnet,
|
||||
sample=latents, timestep=timestep,
|
||||
**prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi,
|
||||
device=self.device,
|
||||
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_image,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list_posi,
|
||||
device=self.device, vram_limit_level=0
|
||||
)
|
||||
noise_pred_nega = lets_dance(
|
||||
self.unet, motion_modules=None, controlnet=self.controlnet,
|
||||
sample=latents, timestep=timestep, **prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega,
|
||||
device=self.device,
|
||||
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_image,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list_nega,
|
||||
device=self.device, vram_limit_level=0
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
|
||||
371
diffsynth/pipelines/stable_diffusion_video.py
Normal file
371
diffsynth/pipelines/stable_diffusion_video.py
Normal file
@@ -0,0 +1,371 @@
|
||||
from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDMotionModel
|
||||
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
||||
from ..prompts import SDPrompter
|
||||
from ..schedulers import EnhancedDDIMScheduler
|
||||
from ..data import VideoData, save_frames, save_video
|
||||
from .dancer import lets_dance
|
||||
from ..processors.sequencial_processor import SequencialProcessor
|
||||
from typing import List
|
||||
import torch, os, json
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
def lets_dance_with_long_video(
|
||||
unet: SDUNet,
|
||||
motion_modules: SDMotionModel = None,
|
||||
controlnet: MultiControlNetManager = None,
|
||||
sample = None,
|
||||
timestep = None,
|
||||
encoder_hidden_states = None,
|
||||
controlnet_frames = None,
|
||||
animatediff_batch_size = 16,
|
||||
animatediff_stride = 8,
|
||||
unet_batch_size = 1,
|
||||
controlnet_batch_size = 1,
|
||||
cross_frame_attention = False,
|
||||
device = "cuda",
|
||||
vram_limit_level = 0,
|
||||
):
|
||||
num_frames = sample.shape[0]
|
||||
hidden_states_output = [(torch.zeros(sample[0].shape, dtype=sample[0].dtype), 0) for i in range(num_frames)]
|
||||
|
||||
for batch_id in range(0, num_frames, animatediff_stride):
|
||||
batch_id_ = min(batch_id + animatediff_batch_size, num_frames)
|
||||
|
||||
# process this batch
|
||||
hidden_states_batch = lets_dance(
|
||||
unet, motion_modules, controlnet,
|
||||
sample[batch_id: batch_id_].to(device),
|
||||
timestep,
|
||||
encoder_hidden_states[batch_id: batch_id_].to(device),
|
||||
controlnet_frames=controlnet_frames[:, batch_id: batch_id_].to(device) if controlnet_frames is not None else None,
|
||||
unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
|
||||
cross_frame_attention=cross_frame_attention,
|
||||
device=device, vram_limit_level=vram_limit_level
|
||||
).cpu()
|
||||
|
||||
# update hidden_states
|
||||
for i, hidden_states_updated in zip(range(batch_id, batch_id_), hidden_states_batch):
|
||||
bias = max(1 - abs(i - (batch_id + batch_id_ - 1) / 2) / ((batch_id_ - batch_id - 1 + 1e-2) / 2), 1e-2)
|
||||
hidden_states, num = hidden_states_output[i]
|
||||
hidden_states = hidden_states * (num / (num + bias)) + hidden_states_updated * (bias / (num + bias))
|
||||
hidden_states_output[i] = (hidden_states, num + bias)
|
||||
|
||||
if batch_id_ == num_frames:
|
||||
break
|
||||
|
||||
# output
|
||||
hidden_states = torch.stack([h for h, _ in hidden_states_output])
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SDVideoPipeline(torch.nn.Module):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16, use_animatediff=True):
|
||||
super().__init__()
|
||||
self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_animatediff else "scaled_linear")
|
||||
self.prompter = SDPrompter()
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
# models
|
||||
self.text_encoder: SDTextEncoder = None
|
||||
self.unet: SDUNet = None
|
||||
self.vae_decoder: SDVAEDecoder = None
|
||||
self.vae_encoder: SDVAEEncoder = None
|
||||
self.controlnet: MultiControlNetManager = None
|
||||
self.motion_modules: SDMotionModel = None
|
||||
|
||||
|
||||
def fetch_main_models(self, model_manager: ModelManager):
|
||||
self.text_encoder = model_manager.text_encoder
|
||||
self.unet = model_manager.unet
|
||||
self.vae_decoder = model_manager.vae_decoder
|
||||
self.vae_encoder = model_manager.vae_encoder
|
||||
|
||||
|
||||
def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
|
||||
controlnet_units = []
|
||||
for config in controlnet_config_units:
|
||||
controlnet_unit = ControlNetUnit(
|
||||
Annotator(config.processor_id),
|
||||
model_manager.get_model_with_model_path(config.model_path),
|
||||
config.scale
|
||||
)
|
||||
controlnet_units.append(controlnet_unit)
|
||||
self.controlnet = MultiControlNetManager(controlnet_units)
|
||||
|
||||
|
||||
def fetch_motion_modules(self, model_manager: ModelManager):
|
||||
if "motion_modules" in model_manager.model:
|
||||
self.motion_modules = model_manager.motion_modules
|
||||
|
||||
|
||||
def fetch_prompter(self, model_manager: ModelManager):
|
||||
self.prompter.load_from_model_manager(model_manager)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
|
||||
pipe = SDVideoPipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
use_animatediff="motion_modules" in model_manager.model
|
||||
)
|
||||
pipe.fetch_main_models(model_manager)
|
||||
pipe.fetch_motion_modules(model_manager)
|
||||
pipe.fetch_prompter(model_manager)
|
||||
pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
|
||||
return pipe
|
||||
|
||||
|
||||
def preprocess_image(self, image):
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
|
||||
return image
|
||||
|
||||
|
||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
image = image.cpu().permute(1, 2, 0).numpy()
|
||||
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||
return image
|
||||
|
||||
|
||||
def decode_images(self, latents, tiled=False, tile_size=64, tile_stride=32):
|
||||
images = [
|
||||
self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
for frame_id in range(latents.shape[0])
|
||||
]
|
||||
return images
|
||||
|
||||
|
||||
def encode_images(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
|
||||
latents = []
|
||||
for image in processed_images:
|
||||
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latent = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).cpu()
|
||||
latents.append(latent)
|
||||
latents = torch.concat(latents, dim=0)
|
||||
return latents
|
||||
|
||||
|
||||
def post_process_latents(self, latents, post_normalize=True, contrast_enhance_scale=1.0):
|
||||
if post_normalize:
|
||||
mean, std = latents.mean(), latents.std()
|
||||
latents = (latents - latents.mean(dim=[1, 2, 3], keepdim=True)) / latents.std(dim=[1, 2, 3], keepdim=True) * std + mean
|
||||
latents = latents * contrast_enhance_scale
|
||||
return latents
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
cfg_scale=7.5,
|
||||
clip_skip=1,
|
||||
num_frames=None,
|
||||
input_frames=None,
|
||||
controlnet_frames=None,
|
||||
denoising_strength=1.0,
|
||||
height=512,
|
||||
width=512,
|
||||
num_inference_steps=20,
|
||||
animatediff_batch_size = 16,
|
||||
animatediff_stride = 8,
|
||||
unet_batch_size = 1,
|
||||
controlnet_batch_size = 1,
|
||||
cross_frame_attention = False,
|
||||
smoother=None,
|
||||
smoother_progress_ids=[],
|
||||
vram_limit_level=0,
|
||||
post_normalize=False,
|
||||
contrast_enhance_scale=1.0,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
if self.motion_modules is None:
|
||||
noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
|
||||
else:
|
||||
noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype)
|
||||
if input_frames is None or denoising_strength == 1.0:
|
||||
latents = noise
|
||||
else:
|
||||
latents = self.encode_images(input_frames)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
|
||||
# Encode prompts
|
||||
prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True).cpu()
|
||||
prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False).cpu()
|
||||
prompt_emb_posi = prompt_emb_posi.repeat(num_frames, 1, 1)
|
||||
prompt_emb_nega = prompt_emb_nega.repeat(num_frames, 1, 1)
|
||||
|
||||
# Prepare ControlNets
|
||||
if controlnet_frames is not None:
|
||||
if isinstance(controlnet_frames[0], list):
|
||||
controlnet_frames_ = []
|
||||
for processor_id in range(len(controlnet_frames)):
|
||||
controlnet_frames_.append(
|
||||
torch.stack([
|
||||
self.controlnet.process_image(controlnet_frame, processor_id=processor_id).to(self.torch_dtype)
|
||||
for controlnet_frame in progress_bar_cmd(controlnet_frames[processor_id])
|
||||
], dim=1)
|
||||
)
|
||||
controlnet_frames = torch.concat(controlnet_frames_, dim=0)
|
||||
else:
|
||||
controlnet_frames = torch.stack([
|
||||
self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
|
||||
for controlnet_frame in progress_bar_cmd(controlnet_frames)
|
||||
], dim=1)
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = torch.IntTensor((timestep,))[0].to(self.device)
|
||||
|
||||
# Classifier-free guidance
|
||||
noise_pred_posi = lets_dance_with_long_video(
|
||||
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
|
||||
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_frames,
|
||||
animatediff_batch_size=animatediff_batch_size, animatediff_stride=animatediff_stride,
|
||||
unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
|
||||
cross_frame_attention=cross_frame_attention,
|
||||
device=self.device, vram_limit_level=vram_limit_level
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = lets_dance_with_long_video(
|
||||
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
|
||||
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames,
|
||||
animatediff_batch_size=animatediff_batch_size, animatediff_stride=animatediff_stride,
|
||||
unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
|
||||
cross_frame_attention=cross_frame_attention,
|
||||
device=self.device, vram_limit_level=vram_limit_level
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
# DDIM and smoother
|
||||
if smoother is not None and progress_id in smoother_progress_ids:
|
||||
rendered_frames = self.scheduler.step(noise_pred, timestep, latents, to_final=True)
|
||||
rendered_frames = self.decode_images(rendered_frames)
|
||||
rendered_frames = smoother(rendered_frames, original_frames=input_frames)
|
||||
target_latents = self.encode_images(rendered_frames)
|
||||
noise_pred = self.scheduler.return_to_timestep(timestep, latents, target_latents)
|
||||
latents = self.scheduler.step(noise_pred, timestep, latents)
|
||||
|
||||
# UI
|
||||
if progress_bar_st is not None:
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
latents = self.post_process_latents(latents, post_normalize=post_normalize, contrast_enhance_scale=contrast_enhance_scale)
|
||||
output_frames = self.decode_images(latents)
|
||||
|
||||
# Post-process
|
||||
if smoother is not None and (num_inference_steps in smoother_progress_ids or -1 in smoother_progress_ids):
|
||||
output_frames = smoother(output_frames, original_frames=input_frames)
|
||||
|
||||
return output_frames
|
||||
|
||||
|
||||
|
||||
class SDVideoPipelineRunner:
|
||||
def __init__(self, in_streamlit=False):
|
||||
self.in_streamlit = in_streamlit
|
||||
|
||||
|
||||
def load_pipeline(self, model_list, textual_inversion_folder, device, lora_alphas, controlnet_units):
|
||||
# Load models
|
||||
model_manager = ModelManager(torch_dtype=torch.float16, device=device)
|
||||
model_manager.load_textual_inversions(textual_inversion_folder)
|
||||
model_manager.load_models(model_list, lora_alphas=lora_alphas)
|
||||
pipe = SDVideoPipeline.from_model_manager(
|
||||
model_manager,
|
||||
[
|
||||
ControlNetConfigUnit(
|
||||
processor_id=unit["processor_id"],
|
||||
model_path=unit["model_path"],
|
||||
scale=unit["scale"]
|
||||
) for unit in controlnet_units
|
||||
]
|
||||
)
|
||||
return model_manager, pipe
|
||||
|
||||
|
||||
def load_smoother(self, model_manager, smoother_configs):
|
||||
smoother = SequencialProcessor.from_model_manager(model_manager, smoother_configs)
|
||||
return smoother
|
||||
|
||||
|
||||
def synthesize_video(self, model_manager, pipe, seed, smoother, **pipeline_inputs):
|
||||
torch.manual_seed(seed)
|
||||
if self.in_streamlit:
|
||||
import streamlit as st
|
||||
progress_bar_st = st.progress(0.0)
|
||||
output_video = pipe(**pipeline_inputs, smoother=smoother, progress_bar_st=progress_bar_st)
|
||||
progress_bar_st.progress(1.0)
|
||||
else:
|
||||
output_video = pipe(**pipeline_inputs, smoother=smoother)
|
||||
model_manager.to("cpu")
|
||||
return output_video
|
||||
|
||||
|
||||
def load_video(self, video_file, image_folder, height, width, start_frame_id, end_frame_id):
|
||||
video = VideoData(video_file=video_file, image_folder=image_folder, height=height, width=width)
|
||||
if start_frame_id is None:
|
||||
start_frame_id = 0
|
||||
if end_frame_id is None:
|
||||
end_frame_id = len(video)
|
||||
frames = [video[i] for i in range(start_frame_id, end_frame_id)]
|
||||
return frames
|
||||
|
||||
|
||||
def add_data_to_pipeline_inputs(self, data, pipeline_inputs):
|
||||
pipeline_inputs["input_frames"] = self.load_video(**data["input_frames"])
|
||||
pipeline_inputs["num_frames"] = len(pipeline_inputs["input_frames"])
|
||||
pipeline_inputs["width"], pipeline_inputs["height"] = pipeline_inputs["input_frames"][0].size
|
||||
if len(data["controlnet_frames"]) > 0:
|
||||
pipeline_inputs["controlnet_frames"] = [self.load_video(**unit) for unit in data["controlnet_frames"]]
|
||||
return pipeline_inputs
|
||||
|
||||
|
||||
def save_output(self, video, output_folder, fps, config):
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
save_frames(video, os.path.join(output_folder, "frames"))
|
||||
save_video(video, os.path.join(output_folder, "video.mp4"), fps=fps)
|
||||
config["pipeline"]["pipeline_inputs"]["input_frames"] = []
|
||||
config["pipeline"]["pipeline_inputs"]["controlnet_frames"] = []
|
||||
with open(os.path.join(output_folder, "config.json"), 'w') as file:
|
||||
json.dump(config, file, indent=4)
|
||||
|
||||
|
||||
def run(self, config):
|
||||
if self.in_streamlit:
|
||||
import streamlit as st
|
||||
if self.in_streamlit: st.markdown("Loading videos ...")
|
||||
config["pipeline"]["pipeline_inputs"] = self.add_data_to_pipeline_inputs(config["data"], config["pipeline"]["pipeline_inputs"])
|
||||
if self.in_streamlit: st.markdown("Loading videos ... done!")
|
||||
if self.in_streamlit: st.markdown("Loading models ...")
|
||||
model_manager, pipe = self.load_pipeline(**config["models"])
|
||||
if self.in_streamlit: st.markdown("Loading models ... done!")
|
||||
if "smoother_configs" in config:
|
||||
if self.in_streamlit: st.markdown("Loading smoother ...")
|
||||
smoother = self.load_smoother(model_manager, config["smoother_configs"])
|
||||
if self.in_streamlit: st.markdown("Loading smoother ... done!")
|
||||
else:
|
||||
smoother = None
|
||||
if self.in_streamlit: st.markdown("Synthesizing videos ...")
|
||||
output_video = self.synthesize_video(model_manager, pipe, config["pipeline"]["seed"], smoother, **config["pipeline"]["pipeline_inputs"])
|
||||
if self.in_streamlit: st.markdown("Synthesizing videos ... done!")
|
||||
if self.in_streamlit: st.markdown("Saving videos ...")
|
||||
self.save_output(output_video, config["data"]["output_folder"], config["data"]["fps"], config)
|
||||
if self.in_streamlit: st.markdown("Saving videos ... done!")
|
||||
if self.in_streamlit: st.markdown("Finished!")
|
||||
video_file = open(os.path.join(os.path.join(config["data"]["output_folder"], "video.mp4")), 'rb')
|
||||
if self.in_streamlit: st.video(video_file.read())
|
||||
175
diffsynth/pipelines/stable_diffusion_xl.py
Normal file
175
diffsynth/pipelines/stable_diffusion_xl.py
Normal file
@@ -0,0 +1,175 @@
|
||||
from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
||||
# TODO: SDXL ControlNet
|
||||
from ..prompts import SDXLPrompter
|
||||
from ..schedulers import EnhancedDDIMScheduler
|
||||
from .dancer import lets_dance_xl
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
class SDXLImagePipeline(torch.nn.Module):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__()
|
||||
self.scheduler = EnhancedDDIMScheduler()
|
||||
self.prompter = SDXLPrompter()
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
# models
|
||||
self.text_encoder: SDXLTextEncoder = None
|
||||
self.text_encoder_2: SDXLTextEncoder2 = None
|
||||
self.unet: SDXLUNet = None
|
||||
self.vae_decoder: SDXLVAEDecoder = None
|
||||
self.vae_encoder: SDXLVAEEncoder = None
|
||||
self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
|
||||
self.ipadapter: SDXLIpAdapter = None
|
||||
# TODO: SDXL ControlNet
|
||||
|
||||
def fetch_main_models(self, model_manager: ModelManager):
|
||||
self.text_encoder = model_manager.text_encoder
|
||||
self.text_encoder_2 = model_manager.text_encoder_2
|
||||
self.unet = model_manager.unet
|
||||
self.vae_decoder = model_manager.vae_decoder
|
||||
self.vae_encoder = model_manager.vae_encoder
|
||||
|
||||
|
||||
def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs):
|
||||
# TODO: SDXL ControlNet
|
||||
pass
|
||||
|
||||
|
||||
def fetch_ipadapter(self, model_manager: ModelManager):
|
||||
if "ipadapter_xl" in model_manager.model:
|
||||
self.ipadapter = model_manager.ipadapter_xl
|
||||
if "ipadapter_xl_image_encoder" in model_manager.model:
|
||||
self.ipadapter_image_encoder = model_manager.ipadapter_xl_image_encoder
|
||||
|
||||
|
||||
def fetch_prompter(self, model_manager: ModelManager):
|
||||
self.prompter.load_from_model_manager(model_manager)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units = [], **kwargs):
|
||||
pipe = SDXLImagePipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_main_models(model_manager)
|
||||
pipe.fetch_prompter(model_manager)
|
||||
pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units)
|
||||
pipe.fetch_ipadapter(model_manager)
|
||||
return pipe
|
||||
|
||||
|
||||
def preprocess_image(self, image):
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
|
||||
return image
|
||||
|
||||
|
||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
image = image.cpu().permute(1, 2, 0).numpy()
|
||||
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||
return image
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
cfg_scale=7.5,
|
||||
clip_skip=1,
|
||||
clip_skip_2=2,
|
||||
input_image=None,
|
||||
ipadapter_images=None,
|
||||
ipadapter_scale=1.0,
|
||||
controlnet_image=None,
|
||||
denoising_strength=1.0,
|
||||
height=1024,
|
||||
width=1024,
|
||||
num_inference_steps=20,
|
||||
tiled=False,
|
||||
tile_size=64,
|
||||
tile_stride=32,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
if input_image is not None:
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.vae_encoder(image.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
|
||||
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# Encode prompts
|
||||
add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt(
|
||||
self.text_encoder,
|
||||
self.text_encoder_2,
|
||||
prompt,
|
||||
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
|
||||
device=self.device,
|
||||
positive=True,
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
|
||||
self.text_encoder,
|
||||
self.text_encoder_2,
|
||||
negative_prompt,
|
||||
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
|
||||
device=self.device,
|
||||
positive=False,
|
||||
)
|
||||
|
||||
# Prepare positional id
|
||||
add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
|
||||
|
||||
# IP-Adapter
|
||||
if ipadapter_images is not None:
|
||||
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
|
||||
ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)
|
||||
ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding))
|
||||
else:
|
||||
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {}
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = torch.IntTensor((timestep,))[0].to(self.device)
|
||||
|
||||
# Classifier-free guidance
|
||||
noise_pred_posi = lets_dance_xl(
|
||||
self.unet,
|
||||
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi,
|
||||
add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list_posi,
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = lets_dance_xl(
|
||||
self.unet,
|
||||
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega,
|
||||
add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list_nega,
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
latents = self.scheduler.step(noise_pred, timestep, latents)
|
||||
|
||||
if progress_bar_st is not None:
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
return image
|
||||
190
diffsynth/pipelines/stable_diffusion_xl_video.py
Normal file
190
diffsynth/pipelines/stable_diffusion_xl_video.py
Normal file
@@ -0,0 +1,190 @@
|
||||
from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLMotionModel
|
||||
from .dancer import lets_dance_xl
|
||||
# TODO: SDXL ControlNet
|
||||
from ..prompts import SDXLPrompter
|
||||
from ..schedulers import EnhancedDDIMScheduler
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
class SDXLVideoPipeline(torch.nn.Module):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16, use_animatediff=True):
|
||||
super().__init__()
|
||||
self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_animatediff else "scaled_linear")
|
||||
self.prompter = SDXLPrompter()
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
# models
|
||||
self.text_encoder: SDXLTextEncoder = None
|
||||
self.text_encoder_2: SDXLTextEncoder2 = None
|
||||
self.unet: SDXLUNet = None
|
||||
self.vae_decoder: SDXLVAEDecoder = None
|
||||
self.vae_encoder: SDXLVAEEncoder = None
|
||||
# TODO: SDXL ControlNet
|
||||
self.motion_modules: SDXLMotionModel = None
|
||||
|
||||
|
||||
def fetch_main_models(self, model_manager: ModelManager):
|
||||
self.text_encoder = model_manager.text_encoder
|
||||
self.text_encoder_2 = model_manager.text_encoder_2
|
||||
self.unet = model_manager.unet
|
||||
self.vae_decoder = model_manager.vae_decoder
|
||||
self.vae_encoder = model_manager.vae_encoder
|
||||
|
||||
|
||||
def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs):
|
||||
# TODO: SDXL ControlNet
|
||||
pass
|
||||
|
||||
|
||||
def fetch_motion_modules(self, model_manager: ModelManager):
|
||||
if "motion_modules_xl" in model_manager.model:
|
||||
self.motion_modules = model_manager.motion_modules_xl
|
||||
|
||||
|
||||
def fetch_prompter(self, model_manager: ModelManager):
|
||||
self.prompter.load_from_model_manager(model_manager)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units = [], **kwargs):
|
||||
pipe = SDXLVideoPipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
use_animatediff="motion_modules_xl" in model_manager.model
|
||||
)
|
||||
pipe.fetch_main_models(model_manager)
|
||||
pipe.fetch_motion_modules(model_manager)
|
||||
pipe.fetch_prompter(model_manager)
|
||||
pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units)
|
||||
return pipe
|
||||
|
||||
|
||||
def preprocess_image(self, image):
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
|
||||
return image
|
||||
|
||||
|
||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
image = image.cpu().permute(1, 2, 0).numpy()
|
||||
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||
return image
|
||||
|
||||
|
||||
def decode_images(self, latents, tiled=False, tile_size=64, tile_stride=32):
|
||||
images = [
|
||||
self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
for frame_id in range(latents.shape[0])
|
||||
]
|
||||
return images
|
||||
|
||||
|
||||
def encode_images(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
|
||||
latents = []
|
||||
for image in processed_images:
|
||||
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latent = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).cpu()
|
||||
latents.append(latent)
|
||||
latents = torch.concat(latents, dim=0)
|
||||
return latents
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
cfg_scale=7.5,
|
||||
clip_skip=1,
|
||||
clip_skip_2=2,
|
||||
num_frames=None,
|
||||
input_frames=None,
|
||||
controlnet_frames=None,
|
||||
denoising_strength=1.0,
|
||||
height=512,
|
||||
width=512,
|
||||
num_inference_steps=20,
|
||||
animatediff_batch_size = 16,
|
||||
animatediff_stride = 8,
|
||||
unet_batch_size = 1,
|
||||
controlnet_batch_size = 1,
|
||||
cross_frame_attention = False,
|
||||
smoother=None,
|
||||
smoother_progress_ids=[],
|
||||
vram_limit_level=0,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
if self.motion_modules is None:
|
||||
noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
|
||||
else:
|
||||
noise = torch.randn((num_frames, 4, height//8, width//8), device="cuda", dtype=self.torch_dtype)
|
||||
if input_frames is None or denoising_strength == 1.0:
|
||||
latents = noise
|
||||
else:
|
||||
latents = self.encode_images(input_frames)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
|
||||
# Encode prompts
|
||||
add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt(
|
||||
self.text_encoder,
|
||||
self.text_encoder_2,
|
||||
prompt,
|
||||
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
|
||||
device=self.device,
|
||||
positive=True,
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
|
||||
self.text_encoder,
|
||||
self.text_encoder_2,
|
||||
negative_prompt,
|
||||
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
|
||||
device=self.device,
|
||||
positive=False,
|
||||
)
|
||||
|
||||
# Prepare positional id
|
||||
add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = torch.IntTensor((timestep,))[0].to(self.device)
|
||||
|
||||
# Classifier-free guidance
|
||||
noise_pred_posi = lets_dance_xl(
|
||||
self.unet, motion_modules=self.motion_modules, controlnet=None,
|
||||
sample=latents, add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
|
||||
timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_frames,
|
||||
cross_frame_attention=cross_frame_attention,
|
||||
device=self.device, vram_limit_level=vram_limit_level
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = lets_dance_xl(
|
||||
self.unet, motion_modules=self.motion_modules, controlnet=None,
|
||||
sample=latents, add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega,
|
||||
timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames,
|
||||
cross_frame_attention=cross_frame_attention,
|
||||
device=self.device, vram_limit_level=vram_limit_level
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
latents = self.scheduler.step(noise_pred, timestep, latents)
|
||||
|
||||
if progress_bar_st is not None:
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
image = self.decode_images(latents.to(torch.float32))
|
||||
|
||||
return image
|
||||
@@ -1,6 +1,5 @@
|
||||
from ..models import ModelManager, SVDImageEncoder, SVDUNet, SVDVAEEncoder, SVDVAEDecoder
|
||||
from ..schedulers import ContinuousODEScheduler
|
||||
from .base import BasePipeline
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
@@ -9,11 +8,13 @@ from einops import rearrange, repeat
|
||||
|
||||
|
||||
|
||||
class SVDVideoPipeline(BasePipeline):
|
||||
class SVDVideoPipeline(torch.nn.Module):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
super().__init__()
|
||||
self.scheduler = ContinuousODEScheduler()
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
# models
|
||||
self.image_encoder: SVDImageEncoder = None
|
||||
self.unet: SVDUNet = None
|
||||
@@ -21,23 +22,32 @@ class SVDVideoPipeline(BasePipeline):
|
||||
self.vae_decoder: SVDVAEDecoder = None
|
||||
|
||||
|
||||
def fetch_models(self, model_manager: ModelManager):
|
||||
self.image_encoder = model_manager.fetch_model("svd_image_encoder")
|
||||
self.unet = model_manager.fetch_model("svd_unet")
|
||||
self.vae_encoder = model_manager.fetch_model("svd_vae_encoder")
|
||||
self.vae_decoder = model_manager.fetch_model("svd_vae_decoder")
|
||||
def fetch_main_models(self, model_manager: ModelManager):
|
||||
self.image_encoder = model_manager.image_encoder
|
||||
self.unet = model_manager.unet
|
||||
self.vae_encoder = model_manager.vae_encoder
|
||||
self.vae_decoder = model_manager.vae_decoder
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, **kwargs):
|
||||
pipe = SVDVideoPipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype
|
||||
)
|
||||
pipe.fetch_models(model_manager)
|
||||
pipe = SVDVideoPipeline(device=model_manager.device, torch_dtype=model_manager.torch_dtype)
|
||||
pipe.fetch_main_models(model_manager)
|
||||
return pipe
|
||||
|
||||
|
||||
def preprocess_image(self, image):
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
|
||||
return image
|
||||
|
||||
|
||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
image = image.cpu().permute(1, 2, 0).numpy()
|
||||
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||
return image
|
||||
|
||||
|
||||
def encode_image_with_clip(self, image):
|
||||
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
|
||||
image = SVDCLIPImageProcessor().resize_with_antialiasing(image, (224, 224))
|
||||
@@ -1,6 +0,0 @@
|
||||
from .prompt_refiners import Translator, BeautifulPrompt
|
||||
from .sd_prompter import SDPrompter
|
||||
from .sdxl_prompter import SDXLPrompter
|
||||
from .sd3_prompter import SD3Prompter
|
||||
from .hunyuan_dit_prompter import HunyuanDiTPrompter
|
||||
from .kolors_prompter import KolorsPrompter
|
||||
@@ -1,57 +0,0 @@
|
||||
from ..models.model_manager import ModelManager
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
def tokenize_long_prompt(tokenizer, prompt, max_length=None):
|
||||
# Get model_max_length from self.tokenizer
|
||||
length = tokenizer.model_max_length if max_length is None else max_length
|
||||
|
||||
# To avoid the warning. set self.tokenizer.model_max_length to +oo.
|
||||
tokenizer.model_max_length = 99999999
|
||||
|
||||
# Tokenize it!
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
||||
|
||||
# Determine the real length.
|
||||
max_length = (input_ids.shape[1] + length - 1) // length * length
|
||||
|
||||
# Restore tokenizer.model_max_length
|
||||
tokenizer.model_max_length = length
|
||||
|
||||
# Tokenize it again with fixed length.
|
||||
input_ids = tokenizer(
|
||||
prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True
|
||||
).input_ids
|
||||
|
||||
# Reshape input_ids to fit the text encoder.
|
||||
num_sentence = input_ids.shape[1] // length
|
||||
input_ids = input_ids.reshape((num_sentence, length))
|
||||
|
||||
return input_ids
|
||||
|
||||
|
||||
|
||||
class BasePrompter:
|
||||
def __init__(self, refiners=[]):
|
||||
self.refiners = refiners
|
||||
|
||||
|
||||
def load_prompt_refiners(self, model_nameger: ModelManager, refiner_classes=[]):
|
||||
for refiner_class in refiner_classes:
|
||||
refiner = refiner_class.from_model_manager(model_nameger)
|
||||
self.refiners.append(refiner)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def process_prompt(self, prompt, positive=True):
|
||||
if isinstance(prompt, list):
|
||||
prompt = [self.process_prompt(prompt_, positive=positive) for prompt_ in prompt]
|
||||
else:
|
||||
for refiner in self.refiners:
|
||||
prompt = refiner(prompt, positive=positive)
|
||||
return prompt
|
||||
@@ -1,353 +0,0 @@
|
||||
from .base_prompter import BasePrompter
|
||||
from ..models.model_manager import ModelManager
|
||||
import json, os, re
|
||||
from typing import List, Optional, Union, Dict
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
from transformers import PreTrainedTokenizer
|
||||
from transformers.utils import PaddingStrategy
|
||||
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
|
||||
from ..models.kolors_text_encoder import ChatGLMModel
|
||||
|
||||
|
||||
class SPTokenizer:
|
||||
def __init__(self, model_path: str):
|
||||
# reload tokenizer
|
||||
assert os.path.isfile(model_path), model_path
|
||||
self.sp_model = SentencePieceProcessor(model_file=model_path)
|
||||
|
||||
# BOS / EOS token IDs
|
||||
self.n_words: int = self.sp_model.vocab_size()
|
||||
self.bos_id: int = self.sp_model.bos_id()
|
||||
self.eos_id: int = self.sp_model.eos_id()
|
||||
self.pad_id: int = self.sp_model.unk_id()
|
||||
assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
|
||||
|
||||
role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
|
||||
special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
|
||||
self.special_tokens = {}
|
||||
self.index_special_tokens = {}
|
||||
for token in special_tokens:
|
||||
self.special_tokens[token] = self.n_words
|
||||
self.index_special_tokens[self.n_words] = token
|
||||
self.n_words += 1
|
||||
self.role_special_token_expression = "|".join([re.escape(token) for token in role_special_tokens])
|
||||
|
||||
def tokenize(self, s: str, encode_special_tokens=False):
|
||||
if encode_special_tokens:
|
||||
last_index = 0
|
||||
t = []
|
||||
for match in re.finditer(self.role_special_token_expression, s):
|
||||
if last_index < match.start():
|
||||
t.extend(self.sp_model.EncodeAsPieces(s[last_index:match.start()]))
|
||||
t.append(s[match.start():match.end()])
|
||||
last_index = match.end()
|
||||
if last_index < len(s):
|
||||
t.extend(self.sp_model.EncodeAsPieces(s[last_index:]))
|
||||
return t
|
||||
else:
|
||||
return self.sp_model.EncodeAsPieces(s)
|
||||
|
||||
def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
|
||||
assert type(s) is str
|
||||
t = self.sp_model.encode(s)
|
||||
if bos:
|
||||
t = [self.bos_id] + t
|
||||
if eos:
|
||||
t = t + [self.eos_id]
|
||||
return t
|
||||
|
||||
def decode(self, t: List[int]) -> str:
|
||||
text, buffer = "", []
|
||||
for token in t:
|
||||
if token in self.index_special_tokens:
|
||||
if buffer:
|
||||
text += self.sp_model.decode(buffer)
|
||||
buffer = []
|
||||
text += self.index_special_tokens[token]
|
||||
else:
|
||||
buffer.append(token)
|
||||
if buffer:
|
||||
text += self.sp_model.decode(buffer)
|
||||
return text
|
||||
|
||||
def decode_tokens(self, tokens: List[str]) -> str:
|
||||
text = self.sp_model.DecodePieces(tokens)
|
||||
return text
|
||||
|
||||
def convert_token_to_id(self, token):
|
||||
""" Converts a token (str) in an id using the vocab. """
|
||||
if token in self.special_tokens:
|
||||
return self.special_tokens[token]
|
||||
return self.sp_model.PieceToId(token)
|
||||
|
||||
def convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||
if index in self.index_special_tokens:
|
||||
return self.index_special_tokens[index]
|
||||
if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
|
||||
return ""
|
||||
return self.sp_model.IdToPiece(index)
|
||||
|
||||
|
||||
|
||||
class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||
vocab_files_names = {"vocab_file": "tokenizer.model"}
|
||||
|
||||
model_input_names = ["input_ids", "attention_mask", "position_ids"]
|
||||
|
||||
def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, encode_special_tokens=False,
|
||||
**kwargs):
|
||||
self.name = "GLMTokenizer"
|
||||
|
||||
self.vocab_file = vocab_file
|
||||
self.tokenizer = SPTokenizer(vocab_file)
|
||||
self.special_tokens = {
|
||||
"<bos>": self.tokenizer.bos_id,
|
||||
"<eos>": self.tokenizer.eos_id,
|
||||
"<pad>": self.tokenizer.pad_id
|
||||
}
|
||||
self.encode_special_tokens = encode_special_tokens
|
||||
super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
encode_special_tokens=encode_special_tokens,
|
||||
**kwargs)
|
||||
|
||||
def get_command(self, token):
|
||||
if token in self.special_tokens:
|
||||
return self.special_tokens[token]
|
||||
assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
|
||||
return self.tokenizer.special_tokens[token]
|
||||
|
||||
@property
|
||||
def unk_token(self) -> str:
|
||||
return "<unk>"
|
||||
|
||||
@property
|
||||
def pad_token(self) -> str:
|
||||
return "<unk>"
|
||||
|
||||
@property
|
||||
def pad_token_id(self):
|
||||
return self.get_command("<pad>")
|
||||
|
||||
@property
|
||||
def eos_token(self) -> str:
|
||||
return "</s>"
|
||||
|
||||
@property
|
||||
def eos_token_id(self):
|
||||
return self.get_command("<eos>")
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return self.tokenizer.n_words
|
||||
|
||||
def get_vocab(self):
|
||||
""" Returns vocab as a dict """
|
||||
vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
|
||||
vocab.update(self.added_tokens_encoder)
|
||||
return vocab
|
||||
|
||||
def _tokenize(self, text, **kwargs):
|
||||
return self.tokenizer.tokenize(text, encode_special_tokens=self.encode_special_tokens)
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
""" Converts a token (str) in an id using the vocab. """
|
||||
return self.tokenizer.convert_token_to_id(token)
|
||||
|
||||
def _convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||
return self.tokenizer.convert_id_to_token(index)
|
||||
|
||||
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
||||
return self.tokenizer.decode_tokens(tokens)
|
||||
|
||||
def save_vocabulary(self, save_directory, filename_prefix=None):
|
||||
"""
|
||||
Save the vocabulary and special tokens file to a directory.
|
||||
|
||||
Args:
|
||||
save_directory (`str`):
|
||||
The directory in which to save the vocabulary.
|
||||
filename_prefix (`str`, *optional*):
|
||||
An optional prefix to add to the named of the saved files.
|
||||
|
||||
Returns:
|
||||
`Tuple(str)`: Paths to the files saved.
|
||||
"""
|
||||
if os.path.isdir(save_directory):
|
||||
vocab_file = os.path.join(
|
||||
save_directory, self.vocab_files_names["vocab_file"]
|
||||
)
|
||||
else:
|
||||
vocab_file = save_directory
|
||||
|
||||
with open(self.vocab_file, 'rb') as fin:
|
||||
proto_str = fin.read()
|
||||
|
||||
with open(vocab_file, "wb") as writer:
|
||||
writer.write(proto_str)
|
||||
|
||||
return (vocab_file,)
|
||||
|
||||
def get_prefix_tokens(self):
|
||||
prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
|
||||
return prefix_tokens
|
||||
|
||||
def build_single_message(self, role, metadata, message):
|
||||
assert role in ["system", "user", "assistant", "observation"], role
|
||||
role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
|
||||
message_tokens = self.tokenizer.encode(message)
|
||||
tokens = role_tokens + message_tokens
|
||||
return tokens
|
||||
|
||||
def build_chat_input(self, query, history=None, role="user"):
|
||||
if history is None:
|
||||
history = []
|
||||
input_ids = []
|
||||
for item in history:
|
||||
content = item["content"]
|
||||
if item["role"] == "system" and "tools" in item:
|
||||
content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
|
||||
input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
|
||||
input_ids.extend(self.build_single_message(role, "", query))
|
||||
input_ids.extend([self.get_command("<|assistant|>")])
|
||||
return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)
|
||||
|
||||
def build_inputs_with_special_tokens(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
||||
adding special tokens. A BERT sequence has the following format:
|
||||
|
||||
- single sequence: `[CLS] X [SEP]`
|
||||
- pair of sequences: `[CLS] A [SEP] B [SEP]`
|
||||
|
||||
Args:
|
||||
token_ids_0 (`List[int]`):
|
||||
List of IDs to which the special tokens will be added.
|
||||
token_ids_1 (`List[int]`, *optional*):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
|
||||
Returns:
|
||||
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
||||
"""
|
||||
prefix_tokens = self.get_prefix_tokens()
|
||||
token_ids_0 = prefix_tokens + token_ids_0
|
||||
if token_ids_1 is not None:
|
||||
token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
|
||||
return token_ids_0
|
||||
|
||||
def _pad(
|
||||
self,
|
||||
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
|
||||
max_length: Optional[int] = None,
|
||||
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
|
||||
|
||||
Args:
|
||||
encoded_inputs:
|
||||
Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
|
||||
max_length: maximum length of the returned list and optionally padding length (see below).
|
||||
Will truncate by taking into account the special tokens.
|
||||
padding_strategy: PaddingStrategy to use for padding.
|
||||
|
||||
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
|
||||
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
|
||||
- PaddingStrategy.DO_NOT_PAD: Do not pad
|
||||
The tokenizer padding sides are defined in self.padding_side:
|
||||
|
||||
- 'left': pads on the left of the sequences
|
||||
- 'right': pads on the right of the sequences
|
||||
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
|
||||
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
|
||||
`>= 7.5` (Volta).
|
||||
return_attention_mask:
|
||||
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
|
||||
"""
|
||||
# Load from model defaults
|
||||
assert self.padding_side == "left"
|
||||
|
||||
required_input = encoded_inputs[self.model_input_names[0]]
|
||||
seq_length = len(required_input)
|
||||
|
||||
if padding_strategy == PaddingStrategy.LONGEST:
|
||||
max_length = len(required_input)
|
||||
|
||||
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
|
||||
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
||||
|
||||
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
|
||||
|
||||
# Initialize attention mask if not present.
|
||||
if "attention_mask" not in encoded_inputs:
|
||||
encoded_inputs["attention_mask"] = [1] * seq_length
|
||||
|
||||
if "position_ids" not in encoded_inputs:
|
||||
encoded_inputs["position_ids"] = list(range(seq_length))
|
||||
|
||||
if needs_to_be_padded:
|
||||
difference = max_length - len(required_input)
|
||||
|
||||
if "attention_mask" in encoded_inputs:
|
||||
encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
|
||||
if "position_ids" in encoded_inputs:
|
||||
encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
|
||||
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
|
||||
|
||||
return encoded_inputs
|
||||
|
||||
|
||||
|
||||
class KolorsPrompter(BasePrompter):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer_path=None
|
||||
):
|
||||
if tokenizer_path is None:
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_path = os.path.join(base_path, "tokenizer_configs/kolors/tokenizer")
|
||||
super().__init__()
|
||||
self.tokenizer = ChatGLMTokenizer.from_pretrained(tokenizer_path)
|
||||
self.text_encoder: ChatGLMModel = None
|
||||
|
||||
|
||||
def fetch_models(self, text_encoder: ChatGLMModel = None):
|
||||
self.text_encoder = text_encoder
|
||||
|
||||
|
||||
def encode_prompt_using_ChatGLM(self, prompt, text_encoder, tokenizer, max_length, clip_skip, device):
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
).to(device)
|
||||
output = text_encoder(
|
||||
input_ids=text_inputs['input_ids'] ,
|
||||
attention_mask=text_inputs['attention_mask'],
|
||||
position_ids=text_inputs['position_ids'],
|
||||
output_hidden_states=True
|
||||
)
|
||||
prompt_emb = output.hidden_states[-clip_skip].permute(1, 0, 2).clone()
|
||||
pooled_prompt_emb = output.hidden_states[-1][-1, :, :].clone()
|
||||
return prompt_emb, pooled_prompt_emb
|
||||
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
clip_skip=1,
|
||||
clip_skip_2=2,
|
||||
positive=True,
|
||||
device="cuda"
|
||||
):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
prompt_emb, pooled_prompt_emb = self.encode_prompt_using_ChatGLM(prompt, self.text_encoder, self.tokenizer, 256, clip_skip_2, device)
|
||||
|
||||
return pooled_prompt_emb, prompt_emb
|
||||
@@ -1,77 +0,0 @@
|
||||
from transformers import AutoTokenizer
|
||||
from ..models.model_manager import ModelManager
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
class BeautifulPrompt(torch.nn.Module):
|
||||
def __init__(self, tokenizer_path=None, model=None, template=""):
|
||||
super().__init__()
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
self.model = model
|
||||
self.template = template
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_nameger: ModelManager):
|
||||
model, model_path = model_nameger.fetch_model("beautiful_prompt", require_model_path=True)
|
||||
template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'
|
||||
if model_path.endswith("v2"):
|
||||
template = """Converts a simple image description into a prompt. \
|
||||
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
|
||||
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
|
||||
but make sure there is a correlation between the input and output.\n\
|
||||
### Input: {raw_prompt}\n### Output:"""
|
||||
beautiful_prompt = BeautifulPrompt(
|
||||
tokenizer_path=model_path,
|
||||
model=model,
|
||||
template=template
|
||||
)
|
||||
return beautiful_prompt
|
||||
|
||||
|
||||
def __call__(self, raw_prompt, positive=True, **kwargs):
|
||||
if positive:
|
||||
model_input = self.template.format(raw_prompt=raw_prompt)
|
||||
input_ids = self.tokenizer.encode(model_input, return_tensors='pt').to(self.model.device)
|
||||
outputs = self.model.generate(
|
||||
input_ids,
|
||||
max_new_tokens=384,
|
||||
do_sample=True,
|
||||
temperature=0.9,
|
||||
top_k=50,
|
||||
top_p=0.95,
|
||||
repetition_penalty=1.1,
|
||||
num_return_sequences=1
|
||||
)
|
||||
prompt = raw_prompt + ", " + self.tokenizer.batch_decode(
|
||||
outputs[:, input_ids.size(1):],
|
||||
skip_special_tokens=True
|
||||
)[0].strip()
|
||||
print(f"Your prompt is refined by BeautifulPrompt: {prompt}")
|
||||
return prompt
|
||||
else:
|
||||
return raw_prompt
|
||||
|
||||
|
||||
|
||||
class Translator(torch.nn.Module):
|
||||
def __init__(self, tokenizer_path=None, model=None):
|
||||
super().__init__()
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
self.model = model
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_nameger: ModelManager):
|
||||
model, model_path = model_nameger.fetch_model("translator", require_model_path=True)
|
||||
translator = Translator(tokenizer_path=model_path, model=model)
|
||||
return translator
|
||||
|
||||
|
||||
def __call__(self, prompt, **kwargs):
|
||||
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.model.device)
|
||||
output_ids = self.model.generate(input_ids)
|
||||
prompt = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
|
||||
print(f"Your prompt is translated: {prompt}")
|
||||
return prompt
|
||||
@@ -1,92 +0,0 @@
|
||||
from .base_prompter import BasePrompter
|
||||
from ..models.model_manager import ModelManager
|
||||
from ..models import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
|
||||
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||
import os, torch
|
||||
|
||||
|
||||
class SD3Prompter(BasePrompter):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer_1_path=None,
|
||||
tokenizer_2_path=None,
|
||||
tokenizer_3_path=None
|
||||
):
|
||||
if tokenizer_1_path is None:
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_1_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_3/tokenizer_1")
|
||||
if tokenizer_2_path is None:
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_3/tokenizer_2")
|
||||
if tokenizer_3_path is None:
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_3_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_3/tokenizer_3")
|
||||
super().__init__()
|
||||
self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
|
||||
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
|
||||
self.tokenizer_3 = T5TokenizerFast.from_pretrained(tokenizer_3_path)
|
||||
self.text_encoder_1: SD3TextEncoder1 = None
|
||||
self.text_encoder_2: SD3TextEncoder2 = None
|
||||
self.text_encoder_3: SD3TextEncoder3 = None
|
||||
|
||||
|
||||
def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: SD3TextEncoder2 = None, text_encoder_3: SD3TextEncoder3 = None):
|
||||
self.text_encoder_1 = text_encoder_1
|
||||
self.text_encoder_2 = text_encoder_2
|
||||
self.text_encoder_3 = text_encoder_3
|
||||
|
||||
|
||||
def encode_prompt_using_clip(self, prompt, text_encoder, tokenizer, max_length, device):
|
||||
input_ids = tokenizer(
|
||||
prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True
|
||||
).input_ids.to(device)
|
||||
pooled_prompt_emb, prompt_emb = text_encoder(input_ids)
|
||||
return pooled_prompt_emb, prompt_emb
|
||||
|
||||
|
||||
def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_length, device):
|
||||
input_ids = tokenizer(
|
||||
prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
).input_ids.to(device)
|
||||
prompt_emb = text_encoder(input_ids)
|
||||
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
||||
|
||||
return prompt_emb
|
||||
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
positive=True,
|
||||
device="cuda"
|
||||
):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
|
||||
# CLIP
|
||||
pooled_prompt_emb_1, prompt_emb_1 = self.encode_prompt_using_clip(prompt, self.text_encoder_1, self.tokenizer_1, 77, device)
|
||||
pooled_prompt_emb_2, prompt_emb_2 = self.encode_prompt_using_clip(prompt, self.text_encoder_2, self.tokenizer_2, 77, device)
|
||||
|
||||
# T5
|
||||
if self.text_encoder_3 is None:
|
||||
prompt_emb_3 = torch.zeros((prompt_emb_1.shape[0], 256, 4096), dtype=prompt_emb_1.dtype, device=device)
|
||||
else:
|
||||
prompt_emb_3 = self.encode_prompt_using_t5(prompt, self.text_encoder_3, self.tokenizer_3, 256, device)
|
||||
prompt_emb_3 = prompt_emb_3.to(prompt_emb_1.dtype) # float32 -> float16
|
||||
|
||||
# Merge
|
||||
prompt_emb = torch.cat([
|
||||
torch.nn.functional.pad(torch.cat([prompt_emb_1, prompt_emb_2], dim=-1), (0, 4096 - 768 - 1280)),
|
||||
prompt_emb_3
|
||||
], dim=-2)
|
||||
pooled_prompt_emb = torch.cat([pooled_prompt_emb_1, pooled_prompt_emb_2], dim=-1)
|
||||
|
||||
return prompt_emb, pooled_prompt_emb
|
||||
@@ -1,73 +0,0 @@
|
||||
from .base_prompter import BasePrompter, tokenize_long_prompt
|
||||
from ..models.model_manager import ModelManager, load_state_dict, search_for_embeddings
|
||||
from ..models import SDTextEncoder
|
||||
from transformers import CLIPTokenizer
|
||||
import torch, os
|
||||
|
||||
|
||||
|
||||
class SDPrompter(BasePrompter):
|
||||
def __init__(self, tokenizer_path=None):
|
||||
if tokenizer_path is None:
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion/tokenizer")
|
||||
super().__init__()
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
self.text_encoder: SDTextEncoder = None
|
||||
self.textual_inversion_dict = {}
|
||||
self.keyword_dict = {}
|
||||
|
||||
|
||||
def fetch_models(self, text_encoder: SDTextEncoder = None):
|
||||
self.text_encoder = text_encoder
|
||||
|
||||
|
||||
def add_textual_inversions_to_model(self, textual_inversion_dict, text_encoder):
|
||||
dtype = next(iter(text_encoder.parameters())).dtype
|
||||
state_dict = text_encoder.token_embedding.state_dict()
|
||||
token_embeddings = [state_dict["weight"]]
|
||||
for keyword in textual_inversion_dict:
|
||||
_, embeddings = textual_inversion_dict[keyword]
|
||||
token_embeddings.append(embeddings.to(dtype=dtype, device=token_embeddings[0].device))
|
||||
token_embeddings = torch.concat(token_embeddings, dim=0)
|
||||
state_dict["weight"] = token_embeddings
|
||||
text_encoder.token_embedding = torch.nn.Embedding(token_embeddings.shape[0], token_embeddings.shape[1])
|
||||
text_encoder.token_embedding = text_encoder.token_embedding.to(dtype=dtype, device=token_embeddings[0].device)
|
||||
text_encoder.token_embedding.load_state_dict(state_dict)
|
||||
|
||||
|
||||
def add_textual_inversions_to_tokenizer(self, textual_inversion_dict, tokenizer):
|
||||
additional_tokens = []
|
||||
for keyword in textual_inversion_dict:
|
||||
tokens, _ = textual_inversion_dict[keyword]
|
||||
additional_tokens += tokens
|
||||
self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
|
||||
tokenizer.add_tokens(additional_tokens)
|
||||
|
||||
|
||||
def load_textual_inversions(self, model_paths):
|
||||
for model_path in model_paths:
|
||||
keyword = os.path.splitext(os.path.split(model_path)[-1])[0]
|
||||
state_dict = load_state_dict(model_path)
|
||||
|
||||
# Search for embeddings
|
||||
for embeddings in search_for_embeddings(state_dict):
|
||||
if len(embeddings.shape) == 2 and embeddings.shape[1] == 768:
|
||||
tokens = [f"{keyword}_{i}" for i in range(embeddings.shape[0])]
|
||||
self.textual_inversion_dict[keyword] = (tokens, embeddings)
|
||||
|
||||
self.add_textual_inversions_to_model(self.textual_inversion_dict, self.text_encoder)
|
||||
self.add_textual_inversions_to_tokenizer(self.textual_inversion_dict, self.tokenizer)
|
||||
|
||||
|
||||
def encode_prompt(self, prompt, clip_skip=1, device="cuda", positive=True):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
for keyword in self.keyword_dict:
|
||||
if keyword in prompt:
|
||||
print(f"Textual inversion {keyword} is enabled.")
|
||||
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
|
||||
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
|
||||
prompt_emb = self.text_encoder(input_ids, clip_skip=clip_skip)
|
||||
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
||||
|
||||
return prompt_emb
|
||||
@@ -1,61 +0,0 @@
|
||||
from .base_prompter import BasePrompter, tokenize_long_prompt
|
||||
from ..models.model_manager import ModelManager
|
||||
from ..models import SDXLTextEncoder, SDXLTextEncoder2
|
||||
from transformers import CLIPTokenizer
|
||||
import torch, os
|
||||
|
||||
|
||||
|
||||
class SDXLPrompter(BasePrompter):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer_path=None,
|
||||
tokenizer_2_path=None
|
||||
):
|
||||
if tokenizer_path is None:
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion/tokenizer")
|
||||
if tokenizer_2_path is None:
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_xl/tokenizer_2")
|
||||
super().__init__()
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
|
||||
self.text_encoder: SDXLTextEncoder = None
|
||||
self.text_encoder_2: SDXLTextEncoder2 = None
|
||||
|
||||
|
||||
def fetch_models(self, text_encoder: SDXLTextEncoder = None, text_encoder_2: SDXLTextEncoder2 = None):
|
||||
self.text_encoder = text_encoder
|
||||
self.text_encoder_2 = text_encoder_2
|
||||
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
clip_skip=1,
|
||||
clip_skip_2=2,
|
||||
positive=True,
|
||||
device="cuda"
|
||||
):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
|
||||
# 1
|
||||
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
|
||||
prompt_emb_1 = self.text_encoder(input_ids, clip_skip=clip_skip)
|
||||
|
||||
# 2
|
||||
input_ids_2 = tokenize_long_prompt(self.tokenizer_2, prompt).to(device)
|
||||
add_text_embeds, prompt_emb_2 = self.text_encoder_2(input_ids_2, clip_skip=clip_skip_2)
|
||||
|
||||
# Merge
|
||||
if prompt_emb_1.shape[0] != prompt_emb_2.shape[0]:
|
||||
max_batch_size = min(prompt_emb_1.shape[0], prompt_emb_2.shape[0])
|
||||
prompt_emb_1 = prompt_emb_1[: max_batch_size]
|
||||
prompt_emb_2 = prompt_emb_2[: max_batch_size]
|
||||
prompt_emb = torch.concatenate([prompt_emb_1, prompt_emb_2], dim=-1)
|
||||
|
||||
# For very long prompt, we only use the first 77 tokens to compute `add_text_embeds`.
|
||||
add_text_embeds = add_text_embeds[0:1]
|
||||
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
||||
return add_text_embeds, prompt_emb
|
||||
3
diffsynth/prompts/__init__.py
Normal file
3
diffsynth/prompts/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .sd_prompter import SDPrompter
|
||||
from .sdxl_prompter import SDXLPrompter
|
||||
from .hunyuan_dit_prompter import HunyuanDiTPrompter
|
||||
@@ -1,34 +1,19 @@
|
||||
from .base_prompter import BasePrompter
|
||||
from ..models.model_manager import ModelManager
|
||||
from ..models import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
||||
from transformers import BertTokenizer, AutoTokenizer
|
||||
import warnings, os
|
||||
from .utils import Prompter
|
||||
from transformers import BertModel, T5EncoderModel, BertTokenizer, AutoTokenizer
|
||||
import warnings
|
||||
|
||||
|
||||
class HunyuanDiTPrompter(BasePrompter):
|
||||
class HunyuanDiTPrompter(Prompter):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer_path=None,
|
||||
tokenizer_t5_path=None
|
||||
tokenizer_path="configs/hunyuan_dit/tokenizer",
|
||||
tokenizer_t5_path="configs/hunyuan_dit/tokenizer_t5"
|
||||
):
|
||||
if tokenizer_path is None:
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_path = os.path.join(base_path, "tokenizer_configs/hunyuan_dit/tokenizer")
|
||||
if tokenizer_t5_path is None:
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_t5_path = os.path.join(base_path, "tokenizer_configs/hunyuan_dit/tokenizer_t5")
|
||||
super().__init__()
|
||||
self.tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
self.tokenizer_t5 = AutoTokenizer.from_pretrained(tokenizer_t5_path)
|
||||
self.text_encoder: HunyuanDiTCLIPTextEncoder = None
|
||||
self.text_encoder_t5: HunyuanDiTT5TextEncoder = None
|
||||
|
||||
|
||||
def fetch_models(self, text_encoder: HunyuanDiTCLIPTextEncoder = None, text_encoder_t5: HunyuanDiTT5TextEncoder = None):
|
||||
self.text_encoder = text_encoder
|
||||
self.text_encoder_t5 = text_encoder_t5
|
||||
|
||||
|
||||
def encode_prompt_using_signle_model(self, prompt, text_encoder, tokenizer, max_length, clip_skip, device):
|
||||
@@ -52,6 +37,8 @@ class HunyuanDiTPrompter(BasePrompter):
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
text_encoder: BertModel,
|
||||
text_encoder_t5: T5EncoderModel,
|
||||
prompt,
|
||||
clip_skip=1,
|
||||
clip_skip_2=1,
|
||||
@@ -61,9 +48,9 @@ class HunyuanDiTPrompter(BasePrompter):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
|
||||
# CLIP
|
||||
prompt_emb, attention_mask = self.encode_prompt_using_signle_model(prompt, self.text_encoder, self.tokenizer, self.tokenizer.model_max_length, clip_skip, device)
|
||||
prompt_emb, attention_mask = self.encode_prompt_using_signle_model(prompt, text_encoder, self.tokenizer, self.tokenizer.model_max_length, clip_skip, device)
|
||||
|
||||
# T5
|
||||
prompt_emb_t5, attention_mask_t5 = self.encode_prompt_using_signle_model(prompt, self.text_encoder_t5, self.tokenizer_t5, self.tokenizer_t5.model_max_length, clip_skip_2, device)
|
||||
prompt_emb_t5, attention_mask_t5 = self.encode_prompt_using_signle_model(prompt, text_encoder_t5, self.tokenizer_t5, self.tokenizer_t5.model_max_length, clip_skip_2, device)
|
||||
|
||||
return prompt_emb, attention_mask, prompt_emb_t5, attention_mask_t5
|
||||
17
diffsynth/prompts/sd_prompter.py
Normal file
17
diffsynth/prompts/sd_prompter.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from .utils import Prompter, tokenize_long_prompt
|
||||
from transformers import CLIPTokenizer
|
||||
from ..models import SDTextEncoder
|
||||
|
||||
|
||||
class SDPrompter(Prompter):
|
||||
def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"):
|
||||
super().__init__()
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
|
||||
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True, max_length=99999999):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
input_ids = tokenize_long_prompt(self.tokenizer, prompt, max_length=max_length).to(device)
|
||||
prompt_emb = text_encoder(input_ids, clip_skip=clip_skip)
|
||||
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
||||
|
||||
return prompt_emb
|
||||
43
diffsynth/prompts/sdxl_prompter.py
Normal file
43
diffsynth/prompts/sdxl_prompter.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from .utils import Prompter, tokenize_long_prompt
|
||||
from transformers import CLIPTokenizer
|
||||
from ..models import SDXLTextEncoder, SDXLTextEncoder2
|
||||
import torch
|
||||
|
||||
|
||||
class SDXLPrompter(Prompter):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer_path="configs/stable_diffusion/tokenizer",
|
||||
tokenizer_2_path="configs/stable_diffusion_xl/tokenizer_2"
|
||||
):
|
||||
super().__init__()
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
text_encoder: SDXLTextEncoder,
|
||||
text_encoder_2: SDXLTextEncoder2,
|
||||
prompt,
|
||||
clip_skip=1,
|
||||
clip_skip_2=2,
|
||||
positive=True,
|
||||
device="cuda"
|
||||
):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
|
||||
# 1
|
||||
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
|
||||
prompt_emb_1 = text_encoder(input_ids, clip_skip=clip_skip)
|
||||
|
||||
# 2
|
||||
input_ids_2 = tokenize_long_prompt(self.tokenizer_2, prompt).to(device)
|
||||
add_text_embeds, prompt_emb_2 = text_encoder_2(input_ids_2, clip_skip=clip_skip_2)
|
||||
|
||||
# Merge
|
||||
prompt_emb = torch.concatenate([prompt_emb_1, prompt_emb_2], dim=-1)
|
||||
|
||||
# For very long prompt, we only use the first 77 tokens to compute `add_text_embeds`.
|
||||
add_text_embeds = add_text_embeds[0:1]
|
||||
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
||||
return add_text_embeds, prompt_emb
|
||||
123
diffsynth/prompts/utils.py
Normal file
123
diffsynth/prompts/utils.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from transformers import CLIPTokenizer, AutoTokenizer
|
||||
from ..models import ModelManager
|
||||
import os
|
||||
|
||||
|
||||
def tokenize_long_prompt(tokenizer, prompt, max_length=99999999):
|
||||
# Get model_max_length from self.tokenizer
|
||||
length = tokenizer.model_max_length
|
||||
|
||||
# To avoid the warning. set self.tokenizer.model_max_length to +oo.
|
||||
tokenizer.model_max_length = max_length
|
||||
|
||||
# Tokenize it!
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
||||
|
||||
# Determine the real length.
|
||||
max_length = (input_ids.shape[1] + length - 1) // length * length
|
||||
|
||||
# Restore tokenizer.model_max_length
|
||||
tokenizer.model_max_length = length
|
||||
|
||||
# Tokenize it again with fixed length.
|
||||
input_ids = tokenizer(
|
||||
prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True
|
||||
).input_ids
|
||||
|
||||
# Reshape input_ids to fit the text encoder.
|
||||
num_sentence = input_ids.shape[1] // length
|
||||
input_ids = input_ids.reshape((num_sentence, length))
|
||||
|
||||
return input_ids
|
||||
|
||||
|
||||
class BeautifulPrompt:
|
||||
def __init__(self, tokenizer_path="configs/beautiful_prompt/tokenizer", model=None):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
self.model = model
|
||||
self.template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'
|
||||
|
||||
def __call__(self, raw_prompt):
|
||||
model_input = self.template.format(raw_prompt=raw_prompt)
|
||||
input_ids = self.tokenizer.encode(model_input, return_tensors='pt').to(self.model.device)
|
||||
outputs = self.model.generate(
|
||||
input_ids,
|
||||
max_new_tokens=384,
|
||||
do_sample=True,
|
||||
temperature=0.9,
|
||||
top_k=50,
|
||||
top_p=0.95,
|
||||
repetition_penalty=1.1,
|
||||
num_return_sequences=1
|
||||
)
|
||||
prompt = raw_prompt + ", " + self.tokenizer.batch_decode(
|
||||
outputs[:, input_ids.size(1):],
|
||||
skip_special_tokens=True
|
||||
)[0].strip()
|
||||
return prompt
|
||||
|
||||
|
||||
class Translator:
|
||||
def __init__(self, tokenizer_path="configs/translator/tokenizer", model=None):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
self.model = model
|
||||
|
||||
def __call__(self, prompt):
|
||||
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.model.device)
|
||||
output_ids = self.model.generate(input_ids)
|
||||
prompt = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
|
||||
return prompt
|
||||
|
||||
|
||||
class Prompter:
|
||||
def __init__(self):
|
||||
self.tokenizer: CLIPTokenizer = None
|
||||
self.keyword_dict = {}
|
||||
self.translator: Translator = None
|
||||
self.beautiful_prompt: BeautifulPrompt = None
|
||||
|
||||
def load_textual_inversion(self, textual_inversion_dict):
|
||||
self.keyword_dict = {}
|
||||
additional_tokens = []
|
||||
for keyword in textual_inversion_dict:
|
||||
tokens, _ = textual_inversion_dict[keyword]
|
||||
additional_tokens += tokens
|
||||
self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
|
||||
self.tokenizer.add_tokens(additional_tokens)
|
||||
|
||||
def load_beautiful_prompt(self, model, model_path):
|
||||
model_folder = os.path.dirname(model_path)
|
||||
self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model)
|
||||
if model_folder.endswith("v2"):
|
||||
self.beautiful_prompt.template = """Converts a simple image description into a prompt. \
|
||||
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
|
||||
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
|
||||
but make sure there is a correlation between the input and output.\n\
|
||||
### Input: {raw_prompt}\n### Output:"""
|
||||
|
||||
def load_translator(self, model, model_path):
|
||||
model_folder = os.path.dirname(model_path)
|
||||
self.translator = Translator(tokenizer_path=model_folder, model=model)
|
||||
|
||||
def load_from_model_manager(self, model_manager: ModelManager):
|
||||
self.load_textual_inversion(model_manager.textual_inversion_dict)
|
||||
if "translator" in model_manager.model:
|
||||
self.load_translator(model_manager.model["translator"], model_manager.model_path["translator"])
|
||||
if "beautiful_prompt" in model_manager.model:
|
||||
self.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
|
||||
|
||||
def process_prompt(self, prompt, positive=True):
|
||||
for keyword in self.keyword_dict:
|
||||
if keyword in prompt:
|
||||
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
|
||||
if positive and self.translator is not None:
|
||||
prompt = self.translator(prompt)
|
||||
print(f"Your prompt is translated: \"{prompt}\"")
|
||||
if positive and self.beautiful_prompt is not None:
|
||||
prompt = self.beautiful_prompt(prompt)
|
||||
print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
|
||||
return prompt
|
||||
@@ -1,3 +1,2 @@
|
||||
from .ddim import EnhancedDDIMScheduler
|
||||
from .continuous_ode import ContinuousODEScheduler
|
||||
from .flow_match import FlowMatchScheduler
|
||||
|
||||
@@ -22,10 +22,10 @@ class EnhancedDDIMScheduler():
|
||||
max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0)
|
||||
num_inference_steps = min(num_inference_steps, max_timestep + 1)
|
||||
if num_inference_steps == 1:
|
||||
self.timesteps = torch.Tensor([max_timestep])
|
||||
self.timesteps = [max_timestep]
|
||||
else:
|
||||
step_length = max_timestep / (num_inference_steps - 1)
|
||||
self.timesteps = torch.Tensor([round(max_timestep - i*step_length) for i in range(num_inference_steps)])
|
||||
self.timesteps = [round(max_timestep - i*step_length) for i in range(num_inference_steps)]
|
||||
|
||||
|
||||
def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
|
||||
@@ -43,37 +43,31 @@ class EnhancedDDIMScheduler():
|
||||
|
||||
|
||||
def step(self, model_output, timestep, sample, to_final=False):
|
||||
alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.cpu()
|
||||
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||
alpha_prod_t = self.alphas_cumprod[timestep]
|
||||
timestep_id = self.timesteps.index(timestep)
|
||||
if to_final or timestep_id + 1 >= len(self.timesteps):
|
||||
alpha_prod_t_prev = 1.0
|
||||
else:
|
||||
timestep_prev = int(self.timesteps[timestep_id + 1])
|
||||
timestep_prev = self.timesteps[timestep_id + 1]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[timestep_prev]
|
||||
|
||||
return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev)
|
||||
|
||||
|
||||
def return_to_timestep(self, timestep, sample, sample_stablized):
|
||||
alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
|
||||
alpha_prod_t = self.alphas_cumprod[timestep]
|
||||
noise_pred = (sample - math.sqrt(alpha_prod_t) * sample_stablized) / math.sqrt(1 - alpha_prod_t)
|
||||
return noise_pred
|
||||
|
||||
|
||||
def add_noise(self, original_samples, noise, timestep):
|
||||
sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
|
||||
sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
|
||||
sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[timestep])
|
||||
sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[timestep])
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
|
||||
def training_target(self, sample, noise, timestep):
|
||||
if self.prediction_type == "epsilon":
|
||||
return noise
|
||||
else:
|
||||
sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
|
||||
sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
|
||||
target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
|
||||
return target
|
||||
sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[timestep])
|
||||
sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[timestep])
|
||||
target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
|
||||
return target
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
class FlowMatchScheduler():
|
||||
|
||||
def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002):
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
self.shift = shift
|
||||
self.sigma_max = sigma_max
|
||||
self.sigma_min = sigma_min
|
||||
self.set_timesteps(num_inference_steps)
|
||||
|
||||
|
||||
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0):
|
||||
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
|
||||
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
|
||||
self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
|
||||
self.timesteps = self.sigmas * self.num_train_timesteps
|
||||
|
||||
|
||||
def step(self, model_output, timestep, sample, to_final=False):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.cpu()
|
||||
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||
sigma = self.sigmas[timestep_id]
|
||||
if to_final or timestep_id + 1 >= len(self.timesteps):
|
||||
sigma_ = 0
|
||||
else:
|
||||
sigma_ = self.sigmas[timestep_id + 1]
|
||||
prev_sample = sample + model_output * (sigma_ - sigma)
|
||||
return prev_sample
|
||||
|
||||
|
||||
def return_to_timestep(self, timestep, sample, sample_stablized):
|
||||
# This scheduler doesn't support this function.
|
||||
pass
|
||||
|
||||
|
||||
def add_noise(self, original_samples, noise, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.cpu()
|
||||
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||
sigma = self.sigmas[timestep_id]
|
||||
sample = (1 - sigma) * original_samples + sigma * noise
|
||||
return sample
|
||||
|
||||
|
||||
def training_target(self, sample, noise, timestep):
|
||||
target = noise - sample
|
||||
return target
|
||||
Binary file not shown.
@@ -1,12 +0,0 @@
|
||||
{
|
||||
"name_or_path": "THUDM/chatglm3-6b-base",
|
||||
"remove_space": false,
|
||||
"do_lower_case": false,
|
||||
"tokenizer_class": "ChatGLMTokenizer",
|
||||
"auto_map": {
|
||||
"AutoTokenizer": [
|
||||
"tokenization_chatglm.ChatGLMTokenizer",
|
||||
null
|
||||
]
|
||||
}
|
||||
}
|
||||
Binary file not shown.
File diff suppressed because it is too large
Load Diff
@@ -1,30 +0,0 @@
|
||||
{
|
||||
"bos_token": {
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"eos_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"pad_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"unk_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
{
|
||||
"add_prefix_space": false,
|
||||
"added_tokens_decoder": {
|
||||
"49406": {
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"49407": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"bos_token": "<|startoftext|>",
|
||||
"clean_up_tokenization_spaces": true,
|
||||
"do_lower_case": true,
|
||||
"eos_token": "<|endoftext|>",
|
||||
"errors": "replace",
|
||||
"model_max_length": 77,
|
||||
"pad_token": "<|endoftext|>",
|
||||
"tokenizer_class": "CLIPTokenizer",
|
||||
"unk_token": "<|endoftext|>"
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,30 +0,0 @@
|
||||
{
|
||||
"bos_token": {
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"eos_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"pad_token": {
|
||||
"content": "!",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"unk_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
@@ -1,38 +0,0 @@
|
||||
{
|
||||
"add_prefix_space": false,
|
||||
"added_tokens_decoder": {
|
||||
"0": {
|
||||
"content": "!",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"49406": {
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"49407": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"bos_token": "<|startoftext|>",
|
||||
"clean_up_tokenization_spaces": true,
|
||||
"do_lower_case": true,
|
||||
"eos_token": "<|endoftext|>",
|
||||
"errors": "replace",
|
||||
"model_max_length": 77,
|
||||
"pad_token": "!",
|
||||
"tokenizer_class": "CLIPTokenizer",
|
||||
"unk_token": "<|endoftext|>"
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,125 +0,0 @@
|
||||
{
|
||||
"additional_special_tokens": [
|
||||
"<extra_id_0>",
|
||||
"<extra_id_1>",
|
||||
"<extra_id_2>",
|
||||
"<extra_id_3>",
|
||||
"<extra_id_4>",
|
||||
"<extra_id_5>",
|
||||
"<extra_id_6>",
|
||||
"<extra_id_7>",
|
||||
"<extra_id_8>",
|
||||
"<extra_id_9>",
|
||||
"<extra_id_10>",
|
||||
"<extra_id_11>",
|
||||
"<extra_id_12>",
|
||||
"<extra_id_13>",
|
||||
"<extra_id_14>",
|
||||
"<extra_id_15>",
|
||||
"<extra_id_16>",
|
||||
"<extra_id_17>",
|
||||
"<extra_id_18>",
|
||||
"<extra_id_19>",
|
||||
"<extra_id_20>",
|
||||
"<extra_id_21>",
|
||||
"<extra_id_22>",
|
||||
"<extra_id_23>",
|
||||
"<extra_id_24>",
|
||||
"<extra_id_25>",
|
||||
"<extra_id_26>",
|
||||
"<extra_id_27>",
|
||||
"<extra_id_28>",
|
||||
"<extra_id_29>",
|
||||
"<extra_id_30>",
|
||||
"<extra_id_31>",
|
||||
"<extra_id_32>",
|
||||
"<extra_id_33>",
|
||||
"<extra_id_34>",
|
||||
"<extra_id_35>",
|
||||
"<extra_id_36>",
|
||||
"<extra_id_37>",
|
||||
"<extra_id_38>",
|
||||
"<extra_id_39>",
|
||||
"<extra_id_40>",
|
||||
"<extra_id_41>",
|
||||
"<extra_id_42>",
|
||||
"<extra_id_43>",
|
||||
"<extra_id_44>",
|
||||
"<extra_id_45>",
|
||||
"<extra_id_46>",
|
||||
"<extra_id_47>",
|
||||
"<extra_id_48>",
|
||||
"<extra_id_49>",
|
||||
"<extra_id_50>",
|
||||
"<extra_id_51>",
|
||||
"<extra_id_52>",
|
||||
"<extra_id_53>",
|
||||
"<extra_id_54>",
|
||||
"<extra_id_55>",
|
||||
"<extra_id_56>",
|
||||
"<extra_id_57>",
|
||||
"<extra_id_58>",
|
||||
"<extra_id_59>",
|
||||
"<extra_id_60>",
|
||||
"<extra_id_61>",
|
||||
"<extra_id_62>",
|
||||
"<extra_id_63>",
|
||||
"<extra_id_64>",
|
||||
"<extra_id_65>",
|
||||
"<extra_id_66>",
|
||||
"<extra_id_67>",
|
||||
"<extra_id_68>",
|
||||
"<extra_id_69>",
|
||||
"<extra_id_70>",
|
||||
"<extra_id_71>",
|
||||
"<extra_id_72>",
|
||||
"<extra_id_73>",
|
||||
"<extra_id_74>",
|
||||
"<extra_id_75>",
|
||||
"<extra_id_76>",
|
||||
"<extra_id_77>",
|
||||
"<extra_id_78>",
|
||||
"<extra_id_79>",
|
||||
"<extra_id_80>",
|
||||
"<extra_id_81>",
|
||||
"<extra_id_82>",
|
||||
"<extra_id_83>",
|
||||
"<extra_id_84>",
|
||||
"<extra_id_85>",
|
||||
"<extra_id_86>",
|
||||
"<extra_id_87>",
|
||||
"<extra_id_88>",
|
||||
"<extra_id_89>",
|
||||
"<extra_id_90>",
|
||||
"<extra_id_91>",
|
||||
"<extra_id_92>",
|
||||
"<extra_id_93>",
|
||||
"<extra_id_94>",
|
||||
"<extra_id_95>",
|
||||
"<extra_id_96>",
|
||||
"<extra_id_97>",
|
||||
"<extra_id_98>",
|
||||
"<extra_id_99>"
|
||||
],
|
||||
"eos_token": {
|
||||
"content": "</s>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"pad_token": {
|
||||
"content": "<pad>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"unk_token": {
|
||||
"content": "<unk>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
Binary file not shown.
File diff suppressed because one or more lines are too long
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user