Merge pull request #2 from Artiprocher/dev

v1.1
This commit is contained in:
Artiprocher
2023-12-23 20:44:06 +08:00
committed by GitHub
34 changed files with 2346 additions and 457 deletions

15
DiffSynth_Studio.py Normal file
View File

@@ -0,0 +1,15 @@
# Set web page format
import streamlit as st
st.set_page_config(layout="wide")
# Diasble virtual VRAM on windows system
import torch
torch.cuda.set_per_process_memory_fraction(0.999, 0)
st.markdown("""
# DiffSynth Studio
[Source Code](https://github.com/Artiprocher/DiffSynth-Studio)
Welcome to DiffSynth Studio.
""")

View File

@@ -1,53 +0,0 @@
# DiffSynth Studio
## 介绍
DiffSynth 是一个全新的 Diffusion 引擎,我们重构了 Text Encoder、UNet、VAE 等架构,保持与开源社区模型兼容性的同时,提升了计算性能。目前这个版本仅仅是一个初始版本,实现了文生图和图生图功能,支持 SD 和 SDXL 架构。未来我们计划基于这个全新的代码库开发更多有趣的功能。
## 安装
如果你只想在 Python 代码层面调用 DiffSynth Studio你只需要安装 `torch`(深度学习框架)和 `transformers`(仅用于实现分词器)。
```
pip install torch transformers
```
如果你想使用 UI还需要额外安装 `streamlit`(一个 webui 框架)和 `streamlit-drawable-canvas`(用于图生图画板)。
```
pip install streamlit streamlit-drawable-canvas
```
## 使用
通过 Python 代码调用
```python
from diffsynth.models import ModelManager
from diffsynth.prompts import SDPrompter, SDXLPrompter
from diffsynth.pipelines import SDPipeline, SDXLPipeline
model_manager = ModelManager()
model_manager.load_from_safetensors("xxxxxxxx.safetensors")
prompter = SDPrompter()
pipe = SDPipeline()
prompt = "a girl"
negative_prompt = ""
image = pipe(
model_manager, prompter,
prompt, negative_prompt=negative_prompt,
num_inference_steps=20, height=512, width=512,
)
image.save("image.png")
```
如果需要用 SDXL 架构模型,请把 `SDPrompter``SDPipeline` 换成 `SDXLPrompter`, `SDXLPipeline`
当然,你也可以使用我们提供的 UI但请注意我们的 UI 程序很简单,且未来可能会大幅改变。
```
python -m streamlit run Diffsynth_Studio.py
```

View File

@@ -1,53 +1,59 @@
# DiffSynth Studio
## 介绍
## Introduction
DiffSynth is a new Diffusion engine. We have restructured architectures like Text Encoder, UNet, VAE, among others, maintaining compatibility with models from the open-source community while enhancing computational performance. This version is currently in its initial stage, supporting text-to-image and image-to-image functionalities, supporting SD and SDXL architectures. In the future, we plan to develop more interesting features based on this new codebase.
DiffSynth is a new Diffusion engine. We have restructured architectures including Text Encoder, UNet, VAE, among others, maintaining compatibility with models from the open-source community while enhancing computational performance. This version is currently in its initial stage, supporting SD and SDXL architectures. In the future, we plan to develop more interesting features based on this new codebase.
## 安装
## Installation
If you only want to use DiffSynth Studio at the Python code level, you just need to install torch (a deep learning framework) and transformers (only used for implementing a tokenizer).
Create Python environment:
```
pip install torch transformers
conda env create -f environment.yml
```
If you wish to use the UI, you'll also need to additionally install `streamlit` (a web UI framework) and `streamlit-drawable-canvas` (used for the image-to-image canvas).
Enter the Python environment:
```
pip install streamlit streamlit-drawable-canvas
conda activate DiffSynthStudio
```
## 使用
Use DiffSynth Studio in Python
```python
from diffsynth.models import ModelManager
from diffsynth.prompts import SDPrompter, SDXLPrompter
from diffsynth.pipelines import SDPipeline, SDXLPipeline
model_manager = ModelManager()
model_manager.load_from_safetensors("xxxxxxxx.safetensors")
prompter = SDPrompter()
pipe = SDPipeline()
prompt = "a girl"
negative_prompt = ""
image = pipe(
model_manager, prompter,
prompt, negative_prompt=negative_prompt,
num_inference_steps=20, height=512, width=512,
)
image.save("image.png")
```
If you want to use SDXL architecture models, replace `SDPrompter` and `SDPipeline` with `SDXLPrompter` and `SDXLPipeline`, respectively.
Of course, you can also use the UI we provide. The UI is simple but may be changed in the future.
## Usage (in WebUI)
```
python -m streamlit run Diffsynth_Studio.py
```
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/93085557-73f3-4eee-a205-9829591ef954
## Usage (in Python code)
### Example 1: Stable Diffusion
We can generate images with very high resolution. Please see `examples/sd_text_to_image.py` for more details.
|512*512|1024*1024|2048*2048|4096*4096|
|-|-|-|-|
|![512](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/55f679e9-7445-4605-9315-302e93d11370)|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/6fc84611-8da6-4a1f-8fee-9a34eba3b4a5)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/9087a73c-9164-4c58-b2a0-effc694143fb)|![4096](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/edee9e71-fc39-4d1c-9ca9-fa52002c67ac)|
### Example 2: Stable Diffusion XL
Generate images with Stable Diffusion XL. Please see `examples/sdxl_text_to_image.py` for more details.
|1024*1024|2048*2048|
|-|-|
|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/67687748-e738-438c-aee5-96096f09ac90)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/584186bc-9855-4140-878e-99541f9a757f)|
### Example 3: Stable Diffusion XL Turbo
Generate images with Stable Diffusion XL Turbo. You can see `examples/sdxl_turbo.py` for more details, but we highly recommend you to use it in the WebUI.
|"black car"|"red car"|
|-|-|
|![black_car](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/7fbfd803-68d4-44f3-8713-8c925fec47d0)|![black_car_to_red_car](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/aaf886e4-c33c-4fd8-98e2-29eef117ba00)|
### Example 4: Toon Shading
A very interesting example. Please see `examples/sd_toon_shading.py` for more details.
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/53532f0e-39b1-4791-b920-c975d52ec24a

6
diffsynth/__init__.py Normal file
View File

@@ -0,0 +1,6 @@
from .data import *
from .models import *
from .prompts import *
from .schedulers import *
from .pipelines import *
from .controlnets import *

View File

@@ -0,0 +1,2 @@
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager
from .processors import Annotator

View File

@@ -0,0 +1,55 @@
import torch
import numpy as np
from .processors import Processor_id
class ControlNetConfigUnit:
def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
self.processor_id = processor_id
self.model_path = model_path
self.scale = scale
class ControlNetUnit:
def __init__(self, processor, model, scale=1.0):
self.processor = processor
self.model = model
self.scale = scale
class MultiControlNetManager:
def __init__(self, controlnet_units=[]):
self.processors = [unit.processor for unit in controlnet_units]
self.models = [unit.model for unit in controlnet_units]
self.scales = [unit.scale for unit in controlnet_units]
def process_image(self, image, return_image=False):
processed_image = [
processor(image)
for processor in self.processors
]
if return_image:
return processed_image
processed_image = torch.concat([
torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
for image_ in processed_image
], dim=0)
return processed_image
def __call__(
self,
sample, timestep, encoder_hidden_states, conditionings,
tiled=False, tile_size=64, tile_stride=32
):
res_stack = None
for conditioning, model, scale in zip(conditionings, self.models, self.scales):
res_stack_ = model(
sample, timestep, encoder_hidden_states, conditioning,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
)
res_stack_ = [res * scale for res in res_stack_]
if res_stack is None:
res_stack = res_stack_
else:
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
return res_stack

View File

@@ -0,0 +1,51 @@
from typing_extensions import Literal, TypeAlias
import warnings
with warnings.catch_warnings():
warnings.simplefilter("ignore")
from controlnet_aux.processor import (
CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector
)
Processor_id: TypeAlias = Literal[
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"
]
class Annotator:
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None):
if processor_id == "canny":
self.processor = CannyDetector()
elif processor_id == "depth":
self.processor = MidasDetector.from_pretrained(model_path)
elif processor_id == "softedge":
self.processor = HEDdetector.from_pretrained(model_path)
elif processor_id == "lineart":
self.processor = LineartDetector.from_pretrained(model_path)
elif processor_id == "lineart_anime":
self.processor = LineartAnimeDetector.from_pretrained(model_path)
elif processor_id == "openpose":
self.processor = OpenposeDetector.from_pretrained(model_path)
elif processor_id == "tile":
self.processor = None
else:
raise ValueError(f"Unsupported processor_id: {processor_id}")
self.processor_id = processor_id
self.detect_resolution = detect_resolution
def __call__(self, image):
width, height = image.size
if self.processor_id == "openpose":
kwargs = {
"include_body": True,
"include_hand": True,
"include_face": True
}
else:
kwargs = {}
if self.processor is not None:
detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height)
image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs)
image = image.resize((width, height))
return image

View File

@@ -0,0 +1 @@
from .video import VideoData, save_video, save_frames

148
diffsynth/data/video.py Normal file
View File

@@ -0,0 +1,148 @@
import imageio, os
import numpy as np
from PIL import Image
from tqdm import tqdm
class LowMemoryVideo:
def __init__(self, file_name):
self.reader = imageio.get_reader(file_name)
def __len__(self):
return self.reader.count_frames()
def __getitem__(self, item):
return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB")
def __del__(self):
self.reader.close()
def split_file_name(file_name):
result = []
number = -1
for i in file_name:
if ord(i)>=ord("0") and ord(i)<=ord("9"):
if number == -1:
number = 0
number = number*10 + ord(i) - ord("0")
else:
if number != -1:
result.append(number)
number = -1
result.append(i)
if number != -1:
result.append(number)
result = tuple(result)
return result
def search_for_images(folder):
file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
file_list = [i[1] for i in sorted(file_list)]
file_list = [os.path.join(folder, i) for i in file_list]
return file_list
class LowMemoryImageFolder:
def __init__(self, folder, file_list=None):
if file_list is None:
self.file_list = search_for_images(folder)
else:
self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
def __len__(self):
return len(self.file_list)
def __getitem__(self, item):
return Image.open(self.file_list[item]).convert("RGB")
def __del__(self):
pass
def crop_and_resize(image, height, width):
image = np.array(image)
image_height, image_width, _ = image.shape
if image_height / image_width < height / width:
croped_width = int(image_height / height * width)
left = (image_width - croped_width) // 2
image = image[:, left: left+croped_width]
image = Image.fromarray(image).resize((width, height))
else:
croped_height = int(image_width / width * height)
left = (image_height - croped_height) // 2
image = image[left: left+croped_height, :]
image = Image.fromarray(image).resize((width, height))
return image
class VideoData:
def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs):
if video_file is not None:
self.data_type = "video"
self.data = LowMemoryVideo(video_file, **kwargs)
elif image_folder is not None:
self.data_type = "images"
self.data = LowMemoryImageFolder(image_folder, **kwargs)
else:
raise ValueError("Cannot open video or image folder")
self.length = None
self.set_shape(height, width)
def raw_data(self):
frames = []
for i in range(self.__len__()):
frames.append(self.__getitem__(i))
return frames
def set_length(self, length):
self.length = length
def set_shape(self, height, width):
self.height = height
self.width = width
def __len__(self):
if self.length is None:
return len(self.data)
else:
return self.length
def shape(self):
if self.height is not None and self.width is not None:
return self.height, self.width
else:
height, width, _ = self.__getitem__(0).shape
return height, width
def __getitem__(self, item):
frame = self.data.__getitem__(item)
width, height = frame.size
if self.height is not None and self.width is not None:
if self.height != height or self.width != width:
frame = crop_and_resize(frame, self.height, self.width)
return frame
def __del__(self):
pass
def save_images(self, folder):
os.makedirs(folder, exist_ok=True)
for i in tqdm(range(self.__len__()), desc="Saving images"):
frame = self.__getitem__(i)
frame.save(os.path.join(folder, f"{i}.png"))
def save_video(frames, save_path, fps, quality=9):
writer = imageio.get_writer(save_path, fps=fps, quality=quality)
for frame in tqdm(frames, desc="Saving video"):
frame = np.array(frame)
writer.append_data(frame)
writer.close()
def save_frames(frames, save_path):
os.makedirs(save_path, exist_ok=True)
for i, frame in enumerate(tqdm(frames, desc="Saving images")):
frame.save(os.path.join(save_path, f"{i}.png"))

View File

@@ -1,4 +1,4 @@
import torch
import torch, os
from safetensors import safe_open
from .sd_text_encoder import SDTextEncoder
@@ -11,21 +11,44 @@ from .sdxl_unet import SDXLUNet
from .sdxl_vae_decoder import SDXLVAEDecoder
from .sdxl_vae_encoder import SDXLVAEEncoder
from .sd_controlnet import SDControlNet
from .sd_motion import SDMotionModel
from transformers import AutoModelForCausalLM
class ModelManager:
def __init__(self, torch_type=torch.float16, device="cuda"):
self.torch_type = torch_type
def __init__(self, torch_dtype=torch.float16, device="cuda"):
self.torch_dtype = torch_dtype
self.device = device
self.model = {}
self.model_path = {}
self.textual_inversion_dict = {}
def is_beautiful_prompt(self, state_dict):
param_name = "transformer.h.9.self_attention.query_key_value.weight"
return param_name in state_dict
def is_stabe_diffusion_xl(self, state_dict):
param_name = "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight"
return param_name in state_dict
def is_stable_diffusion(self, state_dict):
return True
if self.is_stabe_diffusion_xl(state_dict):
return False
param_name = "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight"
return param_name in state_dict
def load_stable_diffusion(self, state_dict, components=None):
def is_controlnet(self, state_dict):
param_name = "control_model.time_embed.0.weight"
return param_name in state_dict
def is_animatediff(self, state_dict):
param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
return param_name in state_dict
def load_stable_diffusion(self, state_dict, components=None, file_path=""):
component_dict = {
"text_encoder": SDTextEncoder,
"unet": SDUNet,
@@ -36,18 +59,30 @@ class ModelManager:
if components is None:
components = ["text_encoder", "unet", "vae_decoder", "vae_encoder"]
for component in components:
self.model[component] = component_dict[component]()
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
self.model[component].to(self.torch_type).to(self.device)
if component == "text_encoder":
# Add additional token embeddings to text encoder
token_embeddings = [state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"]]
for keyword in self.textual_inversion_dict:
_, embeddings = self.textual_inversion_dict[keyword]
token_embeddings.append(embeddings.to(dtype=token_embeddings[0].dtype))
token_embeddings = torch.concat(token_embeddings, dim=0)
state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"] = token_embeddings
self.model[component] = component_dict[component](vocab_size=token_embeddings.shape[0])
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
self.model[component].to(self.torch_dtype).to(self.device)
else:
self.model[component] = component_dict[component]()
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
self.model[component].to(self.torch_dtype).to(self.device)
self.model_path[component] = file_path
def load_stable_diffusion_xl(self, state_dict, components=None):
def load_stable_diffusion_xl(self, state_dict, components=None, file_path=""):
component_dict = {
"text_encoder": SDXLTextEncoder,
"text_encoder_2": SDXLTextEncoder2,
"unet": SDXLUNet,
"vae_decoder": SDXLVAEDecoder,
"vae_encoder": SDXLVAEEncoder,
"refiner": SDXLUNet,
}
if components is None:
components = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder"]
@@ -60,18 +95,97 @@ class ModelManager:
# I do not know how to solve this problem.
self.model[component].to(torch.float32).to(self.device)
else:
self.model[component].to(self.torch_type).to(self.device)
def load_from_safetensors(self, file_path, components=None):
state_dict = load_state_dict_from_safetensors(file_path)
if self.is_stabe_diffusion_xl(state_dict):
self.load_stable_diffusion_xl(state_dict, components=components)
elif self.is_stable_diffusion(state_dict):
self.load_stable_diffusion(state_dict, components=components)
self.model[component].to(self.torch_dtype).to(self.device)
self.model_path[component] = file_path
def load_controlnet(self, state_dict, file_path=""):
component = "controlnet"
if component not in self.model:
self.model[component] = []
self.model_path[component] = []
model = SDControlNet()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component].append(model)
self.model_path[component].append(file_path)
def load_animatediff(self, state_dict, file_path=""):
component = "motion_modules"
model = SDMotionModel()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_beautiful_prompt(self, state_dict, file_path=""):
component = "beautiful_prompt"
model_folder = os.path.dirname(file_path)
model = AutoModelForCausalLM.from_pretrained(
model_folder, state_dict=state_dict, local_files_only=True, torch_dtype=self.torch_dtype
).to(self.device).eval()
self.model[component] = model
self.model_path[component] = file_path
def search_for_embeddings(self, state_dict):
embeddings = []
for k in state_dict:
if isinstance(state_dict[k], torch.Tensor):
embeddings.append(state_dict[k])
elif isinstance(state_dict[k], dict):
embeddings += self.search_for_embeddings(state_dict[k])
return embeddings
def load_textual_inversions(self, folder):
# Store additional tokens here
self.textual_inversion_dict = {}
# Load every textual inversion file
for file_name in os.listdir(folder):
keyword = os.path.splitext(file_name)[0]
state_dict = load_state_dict(os.path.join(folder, file_name))
# Search for embeddings
for embeddings in self.search_for_embeddings(state_dict):
if len(embeddings.shape) == 2 and embeddings.shape[1] == 768:
tokens = [f"{keyword}_{i}" for i in range(embeddings.shape[0])]
self.textual_inversion_dict[keyword] = (tokens, embeddings)
break
def load_model(self, file_path, components=None):
state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype)
if self.is_animatediff(state_dict):
self.load_animatediff(state_dict, file_path=file_path)
elif self.is_controlnet(state_dict):
self.load_controlnet(state_dict, file_path=file_path)
elif self.is_stabe_diffusion_xl(state_dict):
self.load_stable_diffusion_xl(state_dict, components=components, file_path=file_path)
elif self.is_stable_diffusion(state_dict):
self.load_stable_diffusion(state_dict, components=components, file_path=file_path)
elif self.is_beautiful_prompt(state_dict):
self.load_beautiful_prompt(state_dict, file_path=file_path)
def load_models(self, file_path_list):
for file_path in file_path_list:
self.load_model(file_path)
def to(self, device):
for component in self.model:
self.model[component].to(device)
if isinstance(self.model[component], list):
for model in self.model[component]:
model.to(device)
else:
self.model[component].to(device)
def get_model_with_model_path(self, model_path):
for component in self.model_path:
if isinstance(self.model_path[component], str):
if os.path.samefile(self.model_path[component], model_path):
return self.model[component]
elif isinstance(self.model_path[component], list):
for i, model_path_ in enumerate(self.model_path[component]):
if os.path.samefile(model_path_, model_path):
return self.model[component][i]
raise ValueError(f"Please load model {model_path} before you use it.")
def __getattr__(self, __name):
if __name in self.model:
@@ -80,16 +194,28 @@ class ModelManager:
return super.__getattribute__(__name)
def load_state_dict_from_safetensors(file_path):
def load_state_dict(file_path, torch_dtype=None):
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
state_dict = {}
with safe_open(file_path, framework="pt", device="cpu") as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
if torch_dtype is not None:
state_dict[k] = state_dict[k].to(torch_dtype)
return state_dict
def load_state_dict_from_bin(file_path):
return torch.load(file_path, map_location="cpu")
def load_state_dict_from_bin(file_path, torch_dtype=None):
state_dict = torch.load(file_path, map_location="cpu")
if torch_dtype is not None:
state_dict = {i: state_dict[i].to(torch_dtype) for i in state_dict}
return state_dict
def search_parameter(param, state_dict):

View File

@@ -1,4 +1,15 @@
import torch
from einops import rearrange
def low_version_attention(query, key, value, attn_bias=None):
scale = 1 / query.shape[-1] ** 0.5
query = query * scale
attn = torch.matmul(query, key.transpose(-2, -1))
if attn_bias is not None:
attn = attn + attn_bias
attn = attn.softmax(-1)
return attn @ value
class Attention(torch.nn.Module):
@@ -15,7 +26,7 @@ class Attention(torch.nn.Module):
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
@@ -30,9 +41,36 @@ class Attention(torch.nn.Module):
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
hidden_states = hidden_states.transpose(1, 2).view(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
hidden_states = self.to_out(hidden_states)
return hidden_states
def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
q = self.to_q(hidden_states)
k = self.to_k(encoder_hidden_states)
v = self.to_v(encoder_hidden_states)
q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)
if attn_mask is not None:
hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
else:
import xformers.ops as xops
hidden_states = xops.memory_efficient_attention(q, k, v)
hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)
hidden_states = hidden_states.to(q.dtype)
hidden_states = self.to_out(hidden_states)
return hidden_states
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask)

View File

@@ -0,0 +1,584 @@
import torch
from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler
from .tiler import TileWorker
class ControlNetConditioningLayer(torch.nn.Module):
def __init__(self, channels = (3, 16, 32, 96, 256, 320)):
super().__init__()
self.blocks = torch.nn.ModuleList([])
self.blocks.append(torch.nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1))
self.blocks.append(torch.nn.SiLU())
for i in range(1, len(channels) - 2):
self.blocks.append(torch.nn.Conv2d(channels[i], channels[i], kernel_size=3, padding=1))
self.blocks.append(torch.nn.SiLU())
self.blocks.append(torch.nn.Conv2d(channels[i], channels[i+1], kernel_size=3, padding=1, stride=2))
self.blocks.append(torch.nn.SiLU())
self.blocks.append(torch.nn.Conv2d(channels[-2], channels[-1], kernel_size=3, padding=1))
def forward(self, conditioning):
for block in self.blocks:
conditioning = block(conditioning)
return conditioning
class SDControlNet(torch.nn.Module):
def __init__(self, global_pool=False):
super().__init__()
self.time_proj = Timesteps(320)
self.time_embedding = torch.nn.Sequential(
torch.nn.Linear(320, 1280),
torch.nn.SiLU(),
torch.nn.Linear(1280, 1280)
)
self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320))
self.blocks = torch.nn.ModuleList([
# CrossAttnDownBlock2D
ResnetBlock(320, 320, 1280),
AttentionBlock(8, 40, 320, 1, 768),
PushBlock(),
ResnetBlock(320, 320, 1280),
AttentionBlock(8, 40, 320, 1, 768),
PushBlock(),
DownSampler(320),
PushBlock(),
# CrossAttnDownBlock2D
ResnetBlock(320, 640, 1280),
AttentionBlock(8, 80, 640, 1, 768),
PushBlock(),
ResnetBlock(640, 640, 1280),
AttentionBlock(8, 80, 640, 1, 768),
PushBlock(),
DownSampler(640),
PushBlock(),
# CrossAttnDownBlock2D
ResnetBlock(640, 1280, 1280),
AttentionBlock(8, 160, 1280, 1, 768),
PushBlock(),
ResnetBlock(1280, 1280, 1280),
AttentionBlock(8, 160, 1280, 1, 768),
PushBlock(),
DownSampler(1280),
PushBlock(),
# DownBlock2D
ResnetBlock(1280, 1280, 1280),
PushBlock(),
ResnetBlock(1280, 1280, 1280),
PushBlock(),
# UNetMidBlock2DCrossAttn
ResnetBlock(1280, 1280, 1280),
AttentionBlock(8, 160, 1280, 1, 768),
ResnetBlock(1280, 1280, 1280),
PushBlock()
])
self.controlnet_blocks = torch.nn.ModuleList([
torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
])
self.global_pool = global_pool
def forward(
self,
sample, timestep, encoder_hidden_states, conditioning,
tiled=False, tile_size=64, tile_stride=32,
):
# 1. time
time_emb = self.time_proj(timestep[None]).to(sample.dtype)
time_emb = self.time_embedding(time_emb)
time_emb = time_emb.repeat(sample.shape[0], 1)
# 2. pre-process
height, width = sample.shape[2], sample.shape[3]
hidden_states = self.conv_in(sample) + self.controlnet_conv_in(conditioning)
text_emb = encoder_hidden_states
res_stack = [hidden_states]
# 3. blocks
for i, block in enumerate(self.blocks):
if tiled and not isinstance(block, PushBlock):
_, _, inter_height, _ = hidden_states.shape
resize_scale = inter_height / height
hidden_states = TileWorker().tiled_forward(
lambda x: block(x, time_emb, text_emb, res_stack)[0],
hidden_states,
int(tile_size * resize_scale),
int(tile_stride * resize_scale),
tile_device=hidden_states.device,
tile_dtype=hidden_states.dtype
)
else:
hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack)
# 4. ControlNet blocks
controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]
# pool
if self.global_pool:
controlnet_res_stack = [res.mean(dim=(2, 3), keepdim=True) for res in controlnet_res_stack]
return controlnet_res_stack
def state_dict_converter(self):
return SDControlNetStateDictConverter()
class SDControlNetStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
# architecture
block_types = [
'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
'ResnetBlock', 'PushBlock', 'ResnetBlock', 'PushBlock',
'ResnetBlock', 'AttentionBlock', 'ResnetBlock',
'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'UpSampler',
'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock'
]
# controlnet_rename_dict
controlnet_rename_dict = {
"controlnet_cond_embedding.conv_in.weight": "controlnet_conv_in.blocks.0.weight",
"controlnet_cond_embedding.conv_in.bias": "controlnet_conv_in.blocks.0.bias",
"controlnet_cond_embedding.blocks.0.weight": "controlnet_conv_in.blocks.2.weight",
"controlnet_cond_embedding.blocks.0.bias": "controlnet_conv_in.blocks.2.bias",
"controlnet_cond_embedding.blocks.1.weight": "controlnet_conv_in.blocks.4.weight",
"controlnet_cond_embedding.blocks.1.bias": "controlnet_conv_in.blocks.4.bias",
"controlnet_cond_embedding.blocks.2.weight": "controlnet_conv_in.blocks.6.weight",
"controlnet_cond_embedding.blocks.2.bias": "controlnet_conv_in.blocks.6.bias",
"controlnet_cond_embedding.blocks.3.weight": "controlnet_conv_in.blocks.8.weight",
"controlnet_cond_embedding.blocks.3.bias": "controlnet_conv_in.blocks.8.bias",
"controlnet_cond_embedding.blocks.4.weight": "controlnet_conv_in.blocks.10.weight",
"controlnet_cond_embedding.blocks.4.bias": "controlnet_conv_in.blocks.10.bias",
"controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight",
"controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias",
"controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight",
"controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
}
# Rename each parameter
name_list = sorted([name for name in state_dict])
rename_dict = {}
block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
for name in name_list:
names = name.split(".")
if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
pass
elif name in controlnet_rename_dict:
names = controlnet_rename_dict[name].split(".")
elif names[0] == "controlnet_down_blocks":
names[0] = "controlnet_blocks"
elif names[0] == "controlnet_mid_block":
names = ["controlnet_blocks", "12", names[-1]]
elif names[0] in ["time_embedding", "add_embedding"]:
if names[0] == "add_embedding":
names[0] = "add_time_embedding"
names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
if names[0] == "mid_block":
names.insert(1, "0")
block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
block_type_with_id = ".".join(names[:4])
if block_type_with_id != last_block_type_with_id[block_type]:
block_id[block_type] += 1
last_block_type_with_id[block_type] = block_type_with_id
while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
block_id[block_type] += 1
block_type_with_id = ".".join(names[:4])
names = ["blocks", str(block_id[block_type])] + names[4:]
if "ff" in names:
ff_index = names.index("ff")
component = ".".join(names[ff_index:ff_index+3])
component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
names = names[:ff_index] + [component] + names[ff_index+3:]
if "to_out" in names:
names.pop(names.index("to_out") + 1)
else:
raise ValueError(f"Unknown parameters: {name}")
rename_dict[name] = ".".join(names)
# Convert state_dict
state_dict_ = {}
for name, param in state_dict.items():
if ".proj_in." in name or ".proj_out." in name:
param = param.squeeze()
if rename_dict[name] in [
"controlnet_blocks.1.bias", "controlnet_blocks.2.bias", "controlnet_blocks.3.bias", "controlnet_blocks.5.bias", "controlnet_blocks.6.bias",
"controlnet_blocks.8.bias", "controlnet_blocks.9.bias", "controlnet_blocks.10.bias", "controlnet_blocks.11.bias", "controlnet_blocks.12.bias"
]:
continue
state_dict_[rename_dict[name]] = param
return state_dict_
def from_civitai(self, state_dict):
rename_dict = {
"control_model.time_embed.0.weight": "time_embedding.0.weight",
"control_model.time_embed.0.bias": "time_embedding.0.bias",
"control_model.time_embed.2.weight": "time_embedding.2.weight",
"control_model.time_embed.2.bias": "time_embedding.2.bias",
"control_model.input_blocks.0.0.weight": "conv_in.weight",
"control_model.input_blocks.0.0.bias": "conv_in.bias",
"control_model.input_blocks.1.0.in_layers.0.weight": "blocks.0.norm1.weight",
"control_model.input_blocks.1.0.in_layers.0.bias": "blocks.0.norm1.bias",
"control_model.input_blocks.1.0.in_layers.2.weight": "blocks.0.conv1.weight",
"control_model.input_blocks.1.0.in_layers.2.bias": "blocks.0.conv1.bias",
"control_model.input_blocks.1.0.emb_layers.1.weight": "blocks.0.time_emb_proj.weight",
"control_model.input_blocks.1.0.emb_layers.1.bias": "blocks.0.time_emb_proj.bias",
"control_model.input_blocks.1.0.out_layers.0.weight": "blocks.0.norm2.weight",
"control_model.input_blocks.1.0.out_layers.0.bias": "blocks.0.norm2.bias",
"control_model.input_blocks.1.0.out_layers.3.weight": "blocks.0.conv2.weight",
"control_model.input_blocks.1.0.out_layers.3.bias": "blocks.0.conv2.bias",
"control_model.input_blocks.1.1.norm.weight": "blocks.1.norm.weight",
"control_model.input_blocks.1.1.norm.bias": "blocks.1.norm.bias",
"control_model.input_blocks.1.1.proj_in.weight": "blocks.1.proj_in.weight",
"control_model.input_blocks.1.1.proj_in.bias": "blocks.1.proj_in.bias",
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "blocks.1.transformer_blocks.0.attn1.to_q.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "blocks.1.transformer_blocks.0.attn1.to_k.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "blocks.1.transformer_blocks.0.attn1.to_v.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.1.transformer_blocks.0.attn1.to_out.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.1.transformer_blocks.0.attn1.to_out.bias",
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.1.transformer_blocks.0.act_fn.proj.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.1.transformer_blocks.0.act_fn.proj.bias",
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "blocks.1.transformer_blocks.0.ff.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "blocks.1.transformer_blocks.0.ff.bias",
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "blocks.1.transformer_blocks.0.attn2.to_q.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "blocks.1.transformer_blocks.0.attn2.to_k.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "blocks.1.transformer_blocks.0.attn2.to_v.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.1.transformer_blocks.0.attn2.to_out.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.1.transformer_blocks.0.attn2.to_out.bias",
"control_model.input_blocks.1.1.transformer_blocks.0.norm1.weight": "blocks.1.transformer_blocks.0.norm1.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.norm1.bias": "blocks.1.transformer_blocks.0.norm1.bias",
"control_model.input_blocks.1.1.transformer_blocks.0.norm2.weight": "blocks.1.transformer_blocks.0.norm2.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.norm2.bias": "blocks.1.transformer_blocks.0.norm2.bias",
"control_model.input_blocks.1.1.transformer_blocks.0.norm3.weight": "blocks.1.transformer_blocks.0.norm3.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.norm3.bias": "blocks.1.transformer_blocks.0.norm3.bias",
"control_model.input_blocks.1.1.proj_out.weight": "blocks.1.proj_out.weight",
"control_model.input_blocks.1.1.proj_out.bias": "blocks.1.proj_out.bias",
"control_model.input_blocks.2.0.in_layers.0.weight": "blocks.3.norm1.weight",
"control_model.input_blocks.2.0.in_layers.0.bias": "blocks.3.norm1.bias",
"control_model.input_blocks.2.0.in_layers.2.weight": "blocks.3.conv1.weight",
"control_model.input_blocks.2.0.in_layers.2.bias": "blocks.3.conv1.bias",
"control_model.input_blocks.2.0.emb_layers.1.weight": "blocks.3.time_emb_proj.weight",
"control_model.input_blocks.2.0.emb_layers.1.bias": "blocks.3.time_emb_proj.bias",
"control_model.input_blocks.2.0.out_layers.0.weight": "blocks.3.norm2.weight",
"control_model.input_blocks.2.0.out_layers.0.bias": "blocks.3.norm2.bias",
"control_model.input_blocks.2.0.out_layers.3.weight": "blocks.3.conv2.weight",
"control_model.input_blocks.2.0.out_layers.3.bias": "blocks.3.conv2.bias",
"control_model.input_blocks.2.1.norm.weight": "blocks.4.norm.weight",
"control_model.input_blocks.2.1.norm.bias": "blocks.4.norm.bias",
"control_model.input_blocks.2.1.proj_in.weight": "blocks.4.proj_in.weight",
"control_model.input_blocks.2.1.proj_in.bias": "blocks.4.proj_in.bias",
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "blocks.4.transformer_blocks.0.attn1.to_q.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "blocks.4.transformer_blocks.0.attn1.to_k.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "blocks.4.transformer_blocks.0.attn1.to_v.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.4.transformer_blocks.0.attn1.to_out.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.4.transformer_blocks.0.attn1.to_out.bias",
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.4.transformer_blocks.0.act_fn.proj.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.4.transformer_blocks.0.act_fn.proj.bias",
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "blocks.4.transformer_blocks.0.ff.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "blocks.4.transformer_blocks.0.ff.bias",
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "blocks.4.transformer_blocks.0.attn2.to_q.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "blocks.4.transformer_blocks.0.attn2.to_k.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "blocks.4.transformer_blocks.0.attn2.to_v.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.4.transformer_blocks.0.attn2.to_out.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.4.transformer_blocks.0.attn2.to_out.bias",
"control_model.input_blocks.2.1.transformer_blocks.0.norm1.weight": "blocks.4.transformer_blocks.0.norm1.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.norm1.bias": "blocks.4.transformer_blocks.0.norm1.bias",
"control_model.input_blocks.2.1.transformer_blocks.0.norm2.weight": "blocks.4.transformer_blocks.0.norm2.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.norm2.bias": "blocks.4.transformer_blocks.0.norm2.bias",
"control_model.input_blocks.2.1.transformer_blocks.0.norm3.weight": "blocks.4.transformer_blocks.0.norm3.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.norm3.bias": "blocks.4.transformer_blocks.0.norm3.bias",
"control_model.input_blocks.2.1.proj_out.weight": "blocks.4.proj_out.weight",
"control_model.input_blocks.2.1.proj_out.bias": "blocks.4.proj_out.bias",
"control_model.input_blocks.3.0.op.weight": "blocks.6.conv.weight",
"control_model.input_blocks.3.0.op.bias": "blocks.6.conv.bias",
"control_model.input_blocks.4.0.in_layers.0.weight": "blocks.8.norm1.weight",
"control_model.input_blocks.4.0.in_layers.0.bias": "blocks.8.norm1.bias",
"control_model.input_blocks.4.0.in_layers.2.weight": "blocks.8.conv1.weight",
"control_model.input_blocks.4.0.in_layers.2.bias": "blocks.8.conv1.bias",
"control_model.input_blocks.4.0.emb_layers.1.weight": "blocks.8.time_emb_proj.weight",
"control_model.input_blocks.4.0.emb_layers.1.bias": "blocks.8.time_emb_proj.bias",
"control_model.input_blocks.4.0.out_layers.0.weight": "blocks.8.norm2.weight",
"control_model.input_blocks.4.0.out_layers.0.bias": "blocks.8.norm2.bias",
"control_model.input_blocks.4.0.out_layers.3.weight": "blocks.8.conv2.weight",
"control_model.input_blocks.4.0.out_layers.3.bias": "blocks.8.conv2.bias",
"control_model.input_blocks.4.0.skip_connection.weight": "blocks.8.conv_shortcut.weight",
"control_model.input_blocks.4.0.skip_connection.bias": "blocks.8.conv_shortcut.bias",
"control_model.input_blocks.4.1.norm.weight": "blocks.9.norm.weight",
"control_model.input_blocks.4.1.norm.bias": "blocks.9.norm.bias",
"control_model.input_blocks.4.1.proj_in.weight": "blocks.9.proj_in.weight",
"control_model.input_blocks.4.1.proj_in.bias": "blocks.9.proj_in.bias",
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "blocks.9.transformer_blocks.0.attn1.to_q.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "blocks.9.transformer_blocks.0.attn1.to_k.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "blocks.9.transformer_blocks.0.attn1.to_v.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.9.transformer_blocks.0.attn1.to_out.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.9.transformer_blocks.0.attn1.to_out.bias",
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.9.transformer_blocks.0.act_fn.proj.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.9.transformer_blocks.0.act_fn.proj.bias",
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "blocks.9.transformer_blocks.0.ff.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "blocks.9.transformer_blocks.0.ff.bias",
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "blocks.9.transformer_blocks.0.attn2.to_q.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "blocks.9.transformer_blocks.0.attn2.to_k.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "blocks.9.transformer_blocks.0.attn2.to_v.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.9.transformer_blocks.0.attn2.to_out.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.9.transformer_blocks.0.attn2.to_out.bias",
"control_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "blocks.9.transformer_blocks.0.norm1.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "blocks.9.transformer_blocks.0.norm1.bias",
"control_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "blocks.9.transformer_blocks.0.norm2.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "blocks.9.transformer_blocks.0.norm2.bias",
"control_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "blocks.9.transformer_blocks.0.norm3.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "blocks.9.transformer_blocks.0.norm3.bias",
"control_model.input_blocks.4.1.proj_out.weight": "blocks.9.proj_out.weight",
"control_model.input_blocks.4.1.proj_out.bias": "blocks.9.proj_out.bias",
"control_model.input_blocks.5.0.in_layers.0.weight": "blocks.11.norm1.weight",
"control_model.input_blocks.5.0.in_layers.0.bias": "blocks.11.norm1.bias",
"control_model.input_blocks.5.0.in_layers.2.weight": "blocks.11.conv1.weight",
"control_model.input_blocks.5.0.in_layers.2.bias": "blocks.11.conv1.bias",
"control_model.input_blocks.5.0.emb_layers.1.weight": "blocks.11.time_emb_proj.weight",
"control_model.input_blocks.5.0.emb_layers.1.bias": "blocks.11.time_emb_proj.bias",
"control_model.input_blocks.5.0.out_layers.0.weight": "blocks.11.norm2.weight",
"control_model.input_blocks.5.0.out_layers.0.bias": "blocks.11.norm2.bias",
"control_model.input_blocks.5.0.out_layers.3.weight": "blocks.11.conv2.weight",
"control_model.input_blocks.5.0.out_layers.3.bias": "blocks.11.conv2.bias",
"control_model.input_blocks.5.1.norm.weight": "blocks.12.norm.weight",
"control_model.input_blocks.5.1.norm.bias": "blocks.12.norm.bias",
"control_model.input_blocks.5.1.proj_in.weight": "blocks.12.proj_in.weight",
"control_model.input_blocks.5.1.proj_in.bias": "blocks.12.proj_in.bias",
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "blocks.12.transformer_blocks.0.attn1.to_q.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "blocks.12.transformer_blocks.0.attn1.to_k.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "blocks.12.transformer_blocks.0.attn1.to_v.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.12.transformer_blocks.0.attn1.to_out.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.12.transformer_blocks.0.attn1.to_out.bias",
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.12.transformer_blocks.0.act_fn.proj.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.12.transformer_blocks.0.act_fn.proj.bias",
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "blocks.12.transformer_blocks.0.ff.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "blocks.12.transformer_blocks.0.ff.bias",
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "blocks.12.transformer_blocks.0.attn2.to_q.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "blocks.12.transformer_blocks.0.attn2.to_k.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "blocks.12.transformer_blocks.0.attn2.to_v.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.12.transformer_blocks.0.attn2.to_out.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.12.transformer_blocks.0.attn2.to_out.bias",
"control_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "blocks.12.transformer_blocks.0.norm1.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "blocks.12.transformer_blocks.0.norm1.bias",
"control_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "blocks.12.transformer_blocks.0.norm2.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "blocks.12.transformer_blocks.0.norm2.bias",
"control_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "blocks.12.transformer_blocks.0.norm3.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "blocks.12.transformer_blocks.0.norm3.bias",
"control_model.input_blocks.5.1.proj_out.weight": "blocks.12.proj_out.weight",
"control_model.input_blocks.5.1.proj_out.bias": "blocks.12.proj_out.bias",
"control_model.input_blocks.6.0.op.weight": "blocks.14.conv.weight",
"control_model.input_blocks.6.0.op.bias": "blocks.14.conv.bias",
"control_model.input_blocks.7.0.in_layers.0.weight": "blocks.16.norm1.weight",
"control_model.input_blocks.7.0.in_layers.0.bias": "blocks.16.norm1.bias",
"control_model.input_blocks.7.0.in_layers.2.weight": "blocks.16.conv1.weight",
"control_model.input_blocks.7.0.in_layers.2.bias": "blocks.16.conv1.bias",
"control_model.input_blocks.7.0.emb_layers.1.weight": "blocks.16.time_emb_proj.weight",
"control_model.input_blocks.7.0.emb_layers.1.bias": "blocks.16.time_emb_proj.bias",
"control_model.input_blocks.7.0.out_layers.0.weight": "blocks.16.norm2.weight",
"control_model.input_blocks.7.0.out_layers.0.bias": "blocks.16.norm2.bias",
"control_model.input_blocks.7.0.out_layers.3.weight": "blocks.16.conv2.weight",
"control_model.input_blocks.7.0.out_layers.3.bias": "blocks.16.conv2.bias",
"control_model.input_blocks.7.0.skip_connection.weight": "blocks.16.conv_shortcut.weight",
"control_model.input_blocks.7.0.skip_connection.bias": "blocks.16.conv_shortcut.bias",
"control_model.input_blocks.7.1.norm.weight": "blocks.17.norm.weight",
"control_model.input_blocks.7.1.norm.bias": "blocks.17.norm.bias",
"control_model.input_blocks.7.1.proj_in.weight": "blocks.17.proj_in.weight",
"control_model.input_blocks.7.1.proj_in.bias": "blocks.17.proj_in.bias",
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "blocks.17.transformer_blocks.0.attn1.to_q.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "blocks.17.transformer_blocks.0.attn1.to_k.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "blocks.17.transformer_blocks.0.attn1.to_v.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.17.transformer_blocks.0.attn1.to_out.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.17.transformer_blocks.0.attn1.to_out.bias",
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.17.transformer_blocks.0.act_fn.proj.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.17.transformer_blocks.0.act_fn.proj.bias",
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "blocks.17.transformer_blocks.0.ff.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "blocks.17.transformer_blocks.0.ff.bias",
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "blocks.17.transformer_blocks.0.attn2.to_q.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "blocks.17.transformer_blocks.0.attn2.to_k.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "blocks.17.transformer_blocks.0.attn2.to_v.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.17.transformer_blocks.0.attn2.to_out.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.17.transformer_blocks.0.attn2.to_out.bias",
"control_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "blocks.17.transformer_blocks.0.norm1.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "blocks.17.transformer_blocks.0.norm1.bias",
"control_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "blocks.17.transformer_blocks.0.norm2.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "blocks.17.transformer_blocks.0.norm2.bias",
"control_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "blocks.17.transformer_blocks.0.norm3.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "blocks.17.transformer_blocks.0.norm3.bias",
"control_model.input_blocks.7.1.proj_out.weight": "blocks.17.proj_out.weight",
"control_model.input_blocks.7.1.proj_out.bias": "blocks.17.proj_out.bias",
"control_model.input_blocks.8.0.in_layers.0.weight": "blocks.19.norm1.weight",
"control_model.input_blocks.8.0.in_layers.0.bias": "blocks.19.norm1.bias",
"control_model.input_blocks.8.0.in_layers.2.weight": "blocks.19.conv1.weight",
"control_model.input_blocks.8.0.in_layers.2.bias": "blocks.19.conv1.bias",
"control_model.input_blocks.8.0.emb_layers.1.weight": "blocks.19.time_emb_proj.weight",
"control_model.input_blocks.8.0.emb_layers.1.bias": "blocks.19.time_emb_proj.bias",
"control_model.input_blocks.8.0.out_layers.0.weight": "blocks.19.norm2.weight",
"control_model.input_blocks.8.0.out_layers.0.bias": "blocks.19.norm2.bias",
"control_model.input_blocks.8.0.out_layers.3.weight": "blocks.19.conv2.weight",
"control_model.input_blocks.8.0.out_layers.3.bias": "blocks.19.conv2.bias",
"control_model.input_blocks.8.1.norm.weight": "blocks.20.norm.weight",
"control_model.input_blocks.8.1.norm.bias": "blocks.20.norm.bias",
"control_model.input_blocks.8.1.proj_in.weight": "blocks.20.proj_in.weight",
"control_model.input_blocks.8.1.proj_in.bias": "blocks.20.proj_in.bias",
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "blocks.20.transformer_blocks.0.attn1.to_q.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "blocks.20.transformer_blocks.0.attn1.to_k.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "blocks.20.transformer_blocks.0.attn1.to_v.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.20.transformer_blocks.0.attn1.to_out.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.20.transformer_blocks.0.attn1.to_out.bias",
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.20.transformer_blocks.0.act_fn.proj.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.20.transformer_blocks.0.act_fn.proj.bias",
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "blocks.20.transformer_blocks.0.ff.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "blocks.20.transformer_blocks.0.ff.bias",
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "blocks.20.transformer_blocks.0.attn2.to_q.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "blocks.20.transformer_blocks.0.attn2.to_k.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "blocks.20.transformer_blocks.0.attn2.to_v.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.20.transformer_blocks.0.attn2.to_out.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.20.transformer_blocks.0.attn2.to_out.bias",
"control_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "blocks.20.transformer_blocks.0.norm1.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "blocks.20.transformer_blocks.0.norm1.bias",
"control_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "blocks.20.transformer_blocks.0.norm2.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "blocks.20.transformer_blocks.0.norm2.bias",
"control_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "blocks.20.transformer_blocks.0.norm3.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "blocks.20.transformer_blocks.0.norm3.bias",
"control_model.input_blocks.8.1.proj_out.weight": "blocks.20.proj_out.weight",
"control_model.input_blocks.8.1.proj_out.bias": "blocks.20.proj_out.bias",
"control_model.input_blocks.9.0.op.weight": "blocks.22.conv.weight",
"control_model.input_blocks.9.0.op.bias": "blocks.22.conv.bias",
"control_model.input_blocks.10.0.in_layers.0.weight": "blocks.24.norm1.weight",
"control_model.input_blocks.10.0.in_layers.0.bias": "blocks.24.norm1.bias",
"control_model.input_blocks.10.0.in_layers.2.weight": "blocks.24.conv1.weight",
"control_model.input_blocks.10.0.in_layers.2.bias": "blocks.24.conv1.bias",
"control_model.input_blocks.10.0.emb_layers.1.weight": "blocks.24.time_emb_proj.weight",
"control_model.input_blocks.10.0.emb_layers.1.bias": "blocks.24.time_emb_proj.bias",
"control_model.input_blocks.10.0.out_layers.0.weight": "blocks.24.norm2.weight",
"control_model.input_blocks.10.0.out_layers.0.bias": "blocks.24.norm2.bias",
"control_model.input_blocks.10.0.out_layers.3.weight": "blocks.24.conv2.weight",
"control_model.input_blocks.10.0.out_layers.3.bias": "blocks.24.conv2.bias",
"control_model.input_blocks.11.0.in_layers.0.weight": "blocks.26.norm1.weight",
"control_model.input_blocks.11.0.in_layers.0.bias": "blocks.26.norm1.bias",
"control_model.input_blocks.11.0.in_layers.2.weight": "blocks.26.conv1.weight",
"control_model.input_blocks.11.0.in_layers.2.bias": "blocks.26.conv1.bias",
"control_model.input_blocks.11.0.emb_layers.1.weight": "blocks.26.time_emb_proj.weight",
"control_model.input_blocks.11.0.emb_layers.1.bias": "blocks.26.time_emb_proj.bias",
"control_model.input_blocks.11.0.out_layers.0.weight": "blocks.26.norm2.weight",
"control_model.input_blocks.11.0.out_layers.0.bias": "blocks.26.norm2.bias",
"control_model.input_blocks.11.0.out_layers.3.weight": "blocks.26.conv2.weight",
"control_model.input_blocks.11.0.out_layers.3.bias": "blocks.26.conv2.bias",
"control_model.zero_convs.0.0.weight": "controlnet_blocks.0.weight",
"control_model.zero_convs.0.0.bias": "controlnet_blocks.0.bias",
"control_model.zero_convs.1.0.weight": "controlnet_blocks.1.weight",
"control_model.zero_convs.1.0.bias": "controlnet_blocks.0.bias",
"control_model.zero_convs.2.0.weight": "controlnet_blocks.2.weight",
"control_model.zero_convs.2.0.bias": "controlnet_blocks.0.bias",
"control_model.zero_convs.3.0.weight": "controlnet_blocks.3.weight",
"control_model.zero_convs.3.0.bias": "controlnet_blocks.0.bias",
"control_model.zero_convs.4.0.weight": "controlnet_blocks.4.weight",
"control_model.zero_convs.4.0.bias": "controlnet_blocks.4.bias",
"control_model.zero_convs.5.0.weight": "controlnet_blocks.5.weight",
"control_model.zero_convs.5.0.bias": "controlnet_blocks.4.bias",
"control_model.zero_convs.6.0.weight": "controlnet_blocks.6.weight",
"control_model.zero_convs.6.0.bias": "controlnet_blocks.4.bias",
"control_model.zero_convs.7.0.weight": "controlnet_blocks.7.weight",
"control_model.zero_convs.7.0.bias": "controlnet_blocks.7.bias",
"control_model.zero_convs.8.0.weight": "controlnet_blocks.8.weight",
"control_model.zero_convs.8.0.bias": "controlnet_blocks.7.bias",
"control_model.zero_convs.9.0.weight": "controlnet_blocks.9.weight",
"control_model.zero_convs.9.0.bias": "controlnet_blocks.7.bias",
"control_model.zero_convs.10.0.weight": "controlnet_blocks.10.weight",
"control_model.zero_convs.10.0.bias": "controlnet_blocks.7.bias",
"control_model.zero_convs.11.0.weight": "controlnet_blocks.11.weight",
"control_model.zero_convs.11.0.bias": "controlnet_blocks.7.bias",
"control_model.input_hint_block.0.weight": "controlnet_conv_in.blocks.0.weight",
"control_model.input_hint_block.0.bias": "controlnet_conv_in.blocks.0.bias",
"control_model.input_hint_block.2.weight": "controlnet_conv_in.blocks.2.weight",
"control_model.input_hint_block.2.bias": "controlnet_conv_in.blocks.2.bias",
"control_model.input_hint_block.4.weight": "controlnet_conv_in.blocks.4.weight",
"control_model.input_hint_block.4.bias": "controlnet_conv_in.blocks.4.bias",
"control_model.input_hint_block.6.weight": "controlnet_conv_in.blocks.6.weight",
"control_model.input_hint_block.6.bias": "controlnet_conv_in.blocks.6.bias",
"control_model.input_hint_block.8.weight": "controlnet_conv_in.blocks.8.weight",
"control_model.input_hint_block.8.bias": "controlnet_conv_in.blocks.8.bias",
"control_model.input_hint_block.10.weight": "controlnet_conv_in.blocks.10.weight",
"control_model.input_hint_block.10.bias": "controlnet_conv_in.blocks.10.bias",
"control_model.input_hint_block.12.weight": "controlnet_conv_in.blocks.12.weight",
"control_model.input_hint_block.12.bias": "controlnet_conv_in.blocks.12.bias",
"control_model.input_hint_block.14.weight": "controlnet_conv_in.blocks.14.weight",
"control_model.input_hint_block.14.bias": "controlnet_conv_in.blocks.14.bias",
"control_model.middle_block.0.in_layers.0.weight": "blocks.28.norm1.weight",
"control_model.middle_block.0.in_layers.0.bias": "blocks.28.norm1.bias",
"control_model.middle_block.0.in_layers.2.weight": "blocks.28.conv1.weight",
"control_model.middle_block.0.in_layers.2.bias": "blocks.28.conv1.bias",
"control_model.middle_block.0.emb_layers.1.weight": "blocks.28.time_emb_proj.weight",
"control_model.middle_block.0.emb_layers.1.bias": "blocks.28.time_emb_proj.bias",
"control_model.middle_block.0.out_layers.0.weight": "blocks.28.norm2.weight",
"control_model.middle_block.0.out_layers.0.bias": "blocks.28.norm2.bias",
"control_model.middle_block.0.out_layers.3.weight": "blocks.28.conv2.weight",
"control_model.middle_block.0.out_layers.3.bias": "blocks.28.conv2.bias",
"control_model.middle_block.1.norm.weight": "blocks.29.norm.weight",
"control_model.middle_block.1.norm.bias": "blocks.29.norm.bias",
"control_model.middle_block.1.proj_in.weight": "blocks.29.proj_in.weight",
"control_model.middle_block.1.proj_in.bias": "blocks.29.proj_in.bias",
"control_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "blocks.29.transformer_blocks.0.attn1.to_q.weight",
"control_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "blocks.29.transformer_blocks.0.attn1.to_k.weight",
"control_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "blocks.29.transformer_blocks.0.attn1.to_v.weight",
"control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.29.transformer_blocks.0.attn1.to_out.weight",
"control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.29.transformer_blocks.0.attn1.to_out.bias",
"control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.29.transformer_blocks.0.act_fn.proj.weight",
"control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.29.transformer_blocks.0.act_fn.proj.bias",
"control_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "blocks.29.transformer_blocks.0.ff.weight",
"control_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "blocks.29.transformer_blocks.0.ff.bias",
"control_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "blocks.29.transformer_blocks.0.attn2.to_q.weight",
"control_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "blocks.29.transformer_blocks.0.attn2.to_k.weight",
"control_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "blocks.29.transformer_blocks.0.attn2.to_v.weight",
"control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.29.transformer_blocks.0.attn2.to_out.weight",
"control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.29.transformer_blocks.0.attn2.to_out.bias",
"control_model.middle_block.1.transformer_blocks.0.norm1.weight": "blocks.29.transformer_blocks.0.norm1.weight",
"control_model.middle_block.1.transformer_blocks.0.norm1.bias": "blocks.29.transformer_blocks.0.norm1.bias",
"control_model.middle_block.1.transformer_blocks.0.norm2.weight": "blocks.29.transformer_blocks.0.norm2.weight",
"control_model.middle_block.1.transformer_blocks.0.norm2.bias": "blocks.29.transformer_blocks.0.norm2.bias",
"control_model.middle_block.1.transformer_blocks.0.norm3.weight": "blocks.29.transformer_blocks.0.norm3.weight",
"control_model.middle_block.1.transformer_blocks.0.norm3.bias": "blocks.29.transformer_blocks.0.norm3.bias",
"control_model.middle_block.1.proj_out.weight": "blocks.29.proj_out.weight",
"control_model.middle_block.1.proj_out.bias": "blocks.29.proj_out.bias",
"control_model.middle_block.2.in_layers.0.weight": "blocks.30.norm1.weight",
"control_model.middle_block.2.in_layers.0.bias": "blocks.30.norm1.bias",
"control_model.middle_block.2.in_layers.2.weight": "blocks.30.conv1.weight",
"control_model.middle_block.2.in_layers.2.bias": "blocks.30.conv1.bias",
"control_model.middle_block.2.emb_layers.1.weight": "blocks.30.time_emb_proj.weight",
"control_model.middle_block.2.emb_layers.1.bias": "blocks.30.time_emb_proj.bias",
"control_model.middle_block.2.out_layers.0.weight": "blocks.30.norm2.weight",
"control_model.middle_block.2.out_layers.0.bias": "blocks.30.norm2.bias",
"control_model.middle_block.2.out_layers.3.weight": "blocks.30.conv2.weight",
"control_model.middle_block.2.out_layers.3.bias": "blocks.30.conv2.bias",
"control_model.middle_block_out.0.weight": "controlnet_blocks.12.weight",
"control_model.middle_block_out.0.bias": "controlnet_blocks.7.bias",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if ".proj_in." in name or ".proj_out." in name:
param = param.squeeze()
state_dict_[rename_dict[name]] = param
return state_dict_

View File

@@ -0,0 +1,198 @@
from .sd_unet import SDUNet, Attention, GEGLU
import torch
from einops import rearrange, repeat
class TemporalTransformerBlock(torch.nn.Module):
def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32):
super().__init__()
# 1. Self-Attn
self.pe1 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True)
self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
# 2. Cross-Attn
self.pe2 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True)
self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
# 3. Feed-forward
self.norm3 = torch.nn.LayerNorm(dim, elementwise_affine=True)
self.act_fn = GEGLU(dim, dim * 4)
self.ff = torch.nn.Linear(dim * 4, dim)
def forward(self, hidden_states, batch_size=1):
# 1. Self-Attention
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
attn_output = self.attn1(norm_hidden_states + self.pe1[:, :norm_hidden_states.shape[1]])
attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
hidden_states = attn_output + hidden_states
# 2. Cross-Attention
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
attn_output = self.attn2(norm_hidden_states + self.pe2[:, :norm_hidden_states.shape[1]])
attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
hidden_states = attn_output + hidden_states
# 3. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
ff_output = self.act_fn(norm_hidden_states)
ff_output = self.ff(ff_output)
hidden_states = ff_output + hidden_states
return hidden_states
class TemporalBlock(torch.nn.Module):
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
self.proj_in = torch.nn.Linear(in_channels, inner_dim)
self.transformer_blocks = torch.nn.ModuleList([
TemporalTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim
)
for d in range(num_layers)
])
self.proj_out = torch.nn.Linear(inner_dim, in_channels)
def forward(self, hidden_states, time_emb, text_emb, res_stack, batch_size=1):
batch, _, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
hidden_states = self.proj_in(hidden_states)
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
batch_size=batch_size
)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
hidden_states = hidden_states + residual
return hidden_states, time_emb, text_emb, res_stack
class SDMotionModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.motion_modules = torch.nn.ModuleList([
TemporalBlock(8, 40, 320, eps=1e-6),
TemporalBlock(8, 40, 320, eps=1e-6),
TemporalBlock(8, 80, 640, eps=1e-6),
TemporalBlock(8, 80, 640, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 80, 640, eps=1e-6),
TemporalBlock(8, 80, 640, eps=1e-6),
TemporalBlock(8, 80, 640, eps=1e-6),
TemporalBlock(8, 40, 320, eps=1e-6),
TemporalBlock(8, 40, 320, eps=1e-6),
TemporalBlock(8, 40, 320, eps=1e-6),
])
self.call_block_id = {
1: 0,
4: 1,
9: 2,
12: 3,
17: 4,
20: 5,
24: 6,
26: 7,
29: 8,
32: 9,
34: 10,
36: 11,
40: 12,
43: 13,
46: 14,
50: 15,
53: 16,
56: 17,
60: 18,
63: 19,
66: 20
}
def forward(self):
pass
def state_dict_converter(self):
return SDMotionModelStateDictConverter()
class SDMotionModelStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"norm": "norm",
"proj_in": "proj_in",
"transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q",
"transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k",
"transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v",
"transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out",
"transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1",
"transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q",
"transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k",
"transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v",
"transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out",
"transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2",
"transformer_blocks.0.norms.0": "transformer_blocks.0.norm1",
"transformer_blocks.0.norms.1": "transformer_blocks.0.norm2",
"transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj",
"transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff",
"transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3",
"proj_out": "proj_out",
}
name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")])
name_list += sorted([i for i in state_dict if i.startswith("mid_block.")])
name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")])
state_dict_ = {}
last_prefix, module_id = "", -1
for name in name_list:
names = name.split(".")
prefix_index = names.index("temporal_transformer") + 1
prefix = ".".join(names[:prefix_index])
if prefix != last_prefix:
last_prefix = prefix
module_id += 1
middle_name = ".".join(names[prefix_index:-1])
suffix = names[-1]
if "pos_encoder" in names:
rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]])
else:
rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
state_dict_[rename] = state_dict[name]
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -279,27 +279,19 @@ class SDUNet(torch.nn.Module):
self.conv_act = torch.nn.SiLU()
self.conv_out = torch.nn.Conv2d(320, 4, kernel_size=3, padding=1)
def forward(self, sample, timestep, encoder_hidden_states, tiled=False, tile_size=64, tile_stride=8, **kwargs):
def forward(self, sample, timestep, encoder_hidden_states, **kwargs):
# 1. time
time_emb = self.time_proj(timestep[None]).to(sample.dtype)
time_emb = self.time_embedding(time_emb)
time_emb = time_emb.repeat(sample.shape[0], 1)
# 2. pre-process
height, width = sample.shape[2], sample.shape[3]
hidden_states = self.conv_in(sample)
text_emb = encoder_hidden_states
res_stack = [hidden_states]
# 3. blocks
for i, block in enumerate(self.blocks):
if tiled:
hidden_states, time_emb, text_emb, res_stack = self.tiled_inference(
block, hidden_states, time_emb, text_emb, res_stack,
height, width, tile_size, tile_stride
)
else:
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
# 4. output
hidden_states = self.conv_norm_out(hidden_states)
@@ -308,23 +300,6 @@ class SDUNet(torch.nn.Module):
return hidden_states
def tiled_inference(self, block, hidden_states, time_emb, text_emb, res_stack, height, width, tile_size, tile_stride):
if block.__class__.__name__ in ["ResnetBlock", "AttentionBlock", "DownSampler", "UpSampler"]:
batch_size, inter_channel, inter_height, inter_width = hidden_states.shape
resize_scale = inter_height / height
hidden_states = Tiler()(
lambda x: block(x, time_emb, text_emb, res_stack)[0],
hidden_states,
int(tile_size * resize_scale),
int(tile_stride * resize_scale),
inter_device=hidden_states.device,
inter_dtype=hidden_states.dtype
)
else:
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
return hidden_states, time_emb, text_emb, res_stack
def state_dict_converter(self):
return SDUNetStateDictConverter()

View File

@@ -1,7 +1,7 @@
import torch
from .attention import Attention
from .sd_unet import ResnetBlock, UpSampler
from .tiler import Tiler
from .tiler import TileWorker
class VAEAttentionBlock(torch.nn.Module):
@@ -79,11 +79,13 @@ class SDVAEDecoder(torch.nn.Module):
self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
hidden_states = Tiler()(
hidden_states = TileWorker().tiled_forward(
lambda x: self.forward(x),
sample,
tile_size,
tile_stride
tile_stride,
tile_device=sample.device,
tile_dtype=sample.dtype
)
return hidden_states

View File

@@ -1,7 +1,7 @@
import torch
from .sd_unet import ResnetBlock, DownSampler
from .sd_vae_decoder import VAEAttentionBlock
from .tiler import Tiler
from .tiler import TileWorker
class SDVAEEncoder(torch.nn.Module):
@@ -38,11 +38,13 @@ class SDVAEEncoder(torch.nn.Module):
self.conv_out = torch.nn.Conv2d(512, 8, kernel_size=3, padding=1)
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
hidden_states = Tiler()(
hidden_states = TileWorker().tiled_forward(
lambda x: self.forward(x),
sample,
tile_size,
tile_stride
tile_stride,
tile_device=sample.device,
tile_dtype=sample.dtype
)
return hidden_states

View File

@@ -1,4 +1,5 @@
import torch
from einops import rearrange, repeat
class Tiler(torch.nn.Module):
@@ -70,6 +71,106 @@ class Tiler(torch.nn.Module):
return x
class TileWorker:
def __init__(self):
pass
def mask(self, height, width, border_width):
# Create a mask with shape (height, width).
# The centre area is filled with 1, and the border line is filled with values in range (0, 1].
x = torch.arange(height).repeat(width, 1).T
y = torch.arange(width).repeat(height, 1)
mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values
mask = (mask / border_width).clip(0, 1)
return mask
def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype):
# Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num)
batch_size, channel, _, _ = model_input.shape
model_input = model_input.to(device=tile_device, dtype=tile_dtype)
unfold_operator = torch.nn.Unfold(
kernel_size=(tile_size, tile_size),
stride=(tile_stride, tile_stride)
)
model_input = unfold_operator(model_input)
model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1))
return model_input
def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype):
# Call y=forward_fn(x) for each tile
tile_num = model_input.shape[-1]
model_output_stack = []
for tile_id in range(0, tile_num, tile_batch_size):
# process input
tile_id_ = min(tile_id + tile_batch_size, tile_num)
x = model_input[:, :, :, :, tile_id: tile_id_]
x = x.to(device=inference_device, dtype=inference_dtype)
x = rearrange(x, "b c h w n -> (n b) c h w")
# process output
y = forward_fn(x)
y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_-tile_id)
y = y.to(device=tile_device, dtype=tile_dtype)
model_output_stack.append(y)
model_output = torch.concat(model_output_stack, dim=-1)
return model_output
def io_scale(self, model_output, tile_size):
# Determine the size modification happend in forward_fn
# We only consider the same scale on height and width.
io_scale = model_output.shape[2] / tile_size
return io_scale
def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype):
# The reversed function of tile
mask = self.mask(tile_size, tile_size, border_width)
mask = mask.to(device=tile_device, dtype=tile_dtype)
mask = rearrange(mask, "h w -> 1 1 h w 1")
model_output = model_output * mask
fold_operator = torch.nn.Fold(
output_size=(height, width),
kernel_size=(tile_size, tile_size),
stride=(tile_stride, tile_stride)
)
mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1])
model_output = rearrange(model_output, "b c h w n -> b (c h w) n")
model_output = fold_operator(model_output) / fold_operator(mask)
return model_output
def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
# Prepare
inference_device, inference_dtype = model_input.device, model_input.dtype
height, width = model_input.shape[2], model_input.shape[3]
border_width = int(tile_stride*0.5) if border_width is None else border_width
# tile
model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype)
# inference
model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype)
# resize
io_scale = self.io_scale(model_output, tile_size)
height, width = int(height*io_scale), int(width*io_scale)
tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale)
border_width = int(border_width*io_scale)
# untile
model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype)
# Done!
model_output = model_output.to(device=inference_device, dtype=inference_dtype)
return model_output

View File

@@ -1,2 +1,3 @@
from .stable_diffusion import SDPipeline
from .stable_diffusion_xl import SDXLPipeline
from .stable_diffusion import SDImagePipeline
from .stable_diffusion_xl import SDXLImagePipeline
from .stable_diffusion_video import SDVideoPipeline

View File

@@ -0,0 +1,113 @@
import torch
from ..models import SDUNet, SDMotionModel
from ..models.sd_unet import PushBlock, PopBlock
from ..models.tiler import TileWorker
from ..controlnets import MultiControlNetManager
def lets_dance(
unet: SDUNet,
motion_modules: SDMotionModel = None,
controlnet: MultiControlNetManager = None,
sample = None,
timestep = None,
encoder_hidden_states = None,
controlnet_frames = None,
unet_batch_size = 1,
controlnet_batch_size = 1,
tiled=False,
tile_size=64,
tile_stride=32,
device = "cuda",
vram_limit_level = 0,
):
# 1. ControlNet
# This part will be repeated on overlapping frames if animatediff_batch_size > animatediff_stride.
# I leave it here because I intend to do something interesting on the ControlNets.
controlnet_insert_block_id = 30
if controlnet is not None and controlnet_frames is not None:
res_stacks = []
# process controlnet frames with batch
for batch_id in range(0, sample.shape[0], controlnet_batch_size):
batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0])
res_stack = controlnet(
sample[batch_id: batch_id_],
timestep,
encoder_hidden_states[batch_id: batch_id_],
controlnet_frames[:, batch_id: batch_id_],
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
)
if vram_limit_level >= 1:
res_stack = [res.cpu() for res in res_stack]
res_stacks.append(res_stack)
# concat the residual
additional_res_stack = []
for i in range(len(res_stacks[0])):
res = torch.concat([res_stack[i] for res_stack in res_stacks], dim=0)
additional_res_stack.append(res)
else:
additional_res_stack = None
# 2. time
time_emb = unet.time_proj(timestep[None]).to(sample.dtype)
time_emb = unet.time_embedding(time_emb)
# 3. pre-process
height, width = sample.shape[2], sample.shape[3]
hidden_states = unet.conv_in(sample)
text_emb = encoder_hidden_states
res_stack = [hidden_states.cpu() if vram_limit_level>=1 else hidden_states]
# 4. blocks
for block_id, block in enumerate(unet.blocks):
# 4.1 UNet
if isinstance(block, PushBlock):
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
if vram_limit_level>=1:
res_stack[-1] = res_stack[-1].cpu()
elif isinstance(block, PopBlock):
if vram_limit_level>=1:
res_stack[-1] = res_stack[-1].to(device)
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
else:
hidden_states_input = hidden_states
hidden_states_output = []
for batch_id in range(0, sample.shape[0], unet_batch_size):
batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
if tiled:
_, _, inter_height, _ = hidden_states.shape
resize_scale = inter_height / height
hidden_states = TileWorker().tiled_forward(
lambda x: block(x, time_emb, text_emb[batch_id: batch_id_], res_stack)[0],
hidden_states_input[batch_id: batch_id_],
int(tile_size * resize_scale),
int(tile_stride * resize_scale),
tile_device=hidden_states.device,
tile_dtype=hidden_states.dtype
)
else:
hidden_states, _, _, _ = block(hidden_states_input[batch_id: batch_id_], time_emb, text_emb[batch_id: batch_id_], res_stack)
hidden_states_output.append(hidden_states)
hidden_states = torch.concat(hidden_states_output, dim=0)
# 4.2 AnimateDiff
if motion_modules is not None:
if block_id in motion_modules.call_block_id:
motion_module_id = motion_modules.call_block_id[block_id]
hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](
hidden_states, time_emb, text_emb, res_stack,
batch_size=1
)
# 4.3 ControlNet
if block_id == controlnet_insert_block_id and additional_res_stack is not None:
hidden_states += additional_res_stack.pop().to(device)
if vram_limit_level>=1:
res_stack = [(res.to(device) + additional_res.to(device)).cpu() for res, additional_res in zip(res_stack, additional_res_stack)]
else:
res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)]
# 5. output
hidden_states = unet.conv_norm_out(hidden_states)
hidden_states = unet.conv_act(hidden_states)
hidden_states = unet.conv_out(hidden_states)
return hidden_states

View File

@@ -1,32 +1,90 @@
from ..models import ModelManager
from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompts import SDPrompter
from ..schedulers import EnhancedDDIMScheduler
from .dancer import lets_dance
from typing import List
import torch
from tqdm import tqdm
from PIL import Image
import numpy as np
class SDPipeline(torch.nn.Module):
class SDImagePipeline(torch.nn.Module):
def __init__(self):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__()
self.scheduler = EnhancedDDIMScheduler()
self.prompter = SDPrompter()
self.device = device
self.torch_dtype = torch_dtype
# models
self.text_encoder: SDTextEncoder = None
self.unet: SDUNet = None
self.vae_decoder: SDVAEDecoder = None
self.vae_encoder: SDVAEEncoder = None
self.controlnet: MultiControlNetManager = None
def fetch_main_models(self, model_manager: ModelManager):
self.text_encoder = model_manager.text_encoder
self.unet = model_manager.unet
self.vae_decoder = model_manager.vae_decoder
self.vae_encoder = model_manager.vae_encoder
# load textual inversion
self.prompter.load_textual_inversion(model_manager.textual_inversion_dict)
def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
controlnet_units = []
for config in controlnet_config_units:
controlnet_unit = ControlNetUnit(
Annotator(config.processor_id),
model_manager.get_model_with_model_path(config.model_path),
config.scale
)
controlnet_units.append(controlnet_unit)
self.controlnet = MultiControlNetManager(controlnet_units)
def fetch_beautiful_prompt(self, model_manager: ModelManager):
if "beautiful_prompt" in model_manager.model:
self.prompter.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
pipe = SDImagePipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_main_models(model_manager)
pipe.fetch_beautiful_prompt(model_manager)
pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
return pipe
def preprocess_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
return image
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
image = image.cpu().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
return image
@torch.no_grad()
def __call__(
self,
model_manager: ModelManager,
prompter: SDPrompter,
prompt,
negative_prompt="",
cfg_scale=7.5,
clip_skip=1,
init_image=None,
input_image=None,
controlnet_image=None,
denoising_strength=1.0,
height=512,
width=512,
@@ -37,39 +95,54 @@ class SDPipeline(torch.nn.Module):
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Encode prompts
prompt_emb = prompter.encode_prompt(model_manager.text_encoder, prompt, clip_skip=clip_skip, device=model_manager.device)
negative_prompt_emb = prompter.encode_prompt(model_manager.text_encoder, negative_prompt, clip_skip=clip_skip, device=model_manager.device)
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if init_image is not None:
image = self.preprocess_image(init_image).to(device=model_manager.device, dtype=model_manager.torch_type)
latents = model_manager.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
noise = torch.randn((1, 4, height//8, width//8), device=model_manager.device, dtype=model_manager.torch_type)
if input_image is not None:
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = torch.randn((1, 4, height//8, width//8), device=model_manager.device, dtype=model_manager.torch_type)
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
# Encode prompts
prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True)
prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False)
# Prepare ControlNets
if controlnet_image is not None:
controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
controlnet_image = controlnet_image.unsqueeze(1)
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = torch.IntTensor((timestep,))[0].to(model_manager.device)
timestep = torch.IntTensor((timestep,))[0].to(self.device)
# Classifier-free guidance
noise_pred_cond = model_manager.unet(latents, timestep, prompt_emb, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
noise_pred_uncond = model_manager.unet(latents, timestep, negative_prompt_emb, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_cond - noise_pred_uncond)
noise_pred_posi = lets_dance(
self.unet, motion_modules=None, controlnet=self.controlnet,
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_image,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
device=self.device, vram_limit_level=0
)
noise_pred_nega = lets_dance(
self.unet, motion_modules=None, controlnet=self.controlnet,
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_image,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
device=self.device, vram_limit_level=0
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
# DDIM
latents = self.scheduler.step(noise_pred, timestep, latents)
# UI
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
image = model_manager.vae_decoder(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
image = image.cpu().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return image

View File

@@ -0,0 +1,217 @@
from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDMotionModel
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompts import SDPrompter
from ..schedulers import EnhancedDDIMScheduler
from .dancer import lets_dance
from typing import List
import torch
from tqdm import tqdm
from PIL import Image
import numpy as np
def lets_dance_with_long_video(
unet: SDUNet,
motion_modules: SDMotionModel = None,
controlnet: MultiControlNetManager = None,
sample = None,
timestep = None,
encoder_hidden_states = None,
controlnet_frames = None,
unet_batch_size = 1,
controlnet_batch_size = 1,
animatediff_batch_size = 16,
animatediff_stride = 8,
device = "cuda",
vram_limit_level = 0,
):
num_frames = sample.shape[0]
hidden_states_output = [(torch.zeros(sample[0].shape, dtype=sample[0].dtype), 0) for i in range(num_frames)]
for batch_id in range(0, num_frames, animatediff_stride):
batch_id_ = min(batch_id + animatediff_batch_size, num_frames)
# process this batch
hidden_states_batch = lets_dance(
unet, motion_modules, controlnet,
sample[batch_id: batch_id_].to(device),
timestep,
encoder_hidden_states[batch_id: batch_id_].to(device),
controlnet_frames[:, batch_id: batch_id_].to(device) if controlnet_frames is not None else None,
unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size, device=device, vram_limit_level=vram_limit_level
).cpu()
# update hidden_states
for i, hidden_states_updated in zip(range(batch_id, batch_id_), hidden_states_batch):
bias = max(1 - abs(i - (batch_id + batch_id_ - 1) / 2) / ((batch_id_ - batch_id - 1) / 2), 1e-2)
hidden_states, num = hidden_states_output[i]
hidden_states = hidden_states * (num / (num + bias)) + hidden_states_updated * (bias / (num + bias))
hidden_states_output[i] = (hidden_states, num + 1)
# output
hidden_states = torch.stack([h for h, _ in hidden_states_output])
return hidden_states
class SDVideoPipeline(torch.nn.Module):
def __init__(self, device="cuda", torch_dtype=torch.float16, use_animatediff=True):
super().__init__()
self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_animatediff else "scaled_linear")
self.prompter = SDPrompter()
self.device = device
self.torch_dtype = torch_dtype
# models
self.text_encoder: SDTextEncoder = None
self.unet: SDUNet = None
self.vae_decoder: SDVAEDecoder = None
self.vae_encoder: SDVAEEncoder = None
self.controlnet: MultiControlNetManager = None
self.motion_modules: SDMotionModel = None
def fetch_main_models(self, model_manager: ModelManager):
self.text_encoder = model_manager.text_encoder
self.unet = model_manager.unet
self.vae_decoder = model_manager.vae_decoder
self.vae_encoder = model_manager.vae_encoder
# load textual inversion
self.prompter.load_textual_inversion(model_manager.textual_inversion_dict)
def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
controlnet_units = []
for config in controlnet_config_units:
controlnet_unit = ControlNetUnit(
Annotator(config.processor_id),
model_manager.get_model_with_model_path(config.model_path),
config.scale
)
controlnet_units.append(controlnet_unit)
self.controlnet = MultiControlNetManager(controlnet_units)
def fetch_motion_modules(self, model_manager: ModelManager):
if "motion_modules" in model_manager.model:
self.motion_modules = model_manager.motion_modules
def fetch_beautiful_prompt(self, model_manager: ModelManager):
if "beautiful_prompt" in model_manager.model:
self.prompter.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
pipe = SDVideoPipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype,
use_animatediff="motion_modules" in model_manager.model
)
pipe.fetch_main_models(model_manager)
pipe.fetch_motion_modules(model_manager)
pipe.fetch_beautiful_prompt(model_manager)
pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
return pipe
def preprocess_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
return image
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
image = image.cpu().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
return image
def decode_images(self, latents, tiled=False, tile_size=64, tile_stride=32):
images = [
self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
for frame_id in range(latents.shape[0])
]
return images
def encode_images(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
latents = []
for image in processed_images:
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
latent = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).cpu()
latents.append(latent)
latents = torch.concat(latents, dim=0)
return latents
@torch.no_grad()
def __call__(
self,
prompt,
negative_prompt="",
cfg_scale=7.5,
clip_skip=1,
num_frames=None,
input_frames=None,
controlnet_frames=None,
denoising_strength=1.0,
height=512,
width=512,
num_inference_steps=20,
vram_limit_level=0,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype)
if input_frames is None:
latents = noise
else:
latents = self.encode_images(input_frames)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
# Encode prompts
prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True).cpu()
prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False).cpu()
prompt_emb_posi = prompt_emb_posi.repeat(num_frames, 1, 1)
prompt_emb_nega = prompt_emb_nega.repeat(num_frames, 1, 1)
# Prepare ControlNets
if controlnet_frames is not None:
controlnet_frames = torch.stack([
self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
for controlnet_frame in progress_bar_cmd(controlnet_frames)
], dim=1)
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = torch.IntTensor((timestep,))[0].to(self.device)
# Classifier-free guidance
noise_pred_posi = lets_dance_with_long_video(
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_frames,
device=self.device, vram_limit_level=vram_limit_level
)
noise_pred_nega = lets_dance_with_long_video(
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames,
device=self.device, vram_limit_level=vram_limit_level
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
# DDIM
latents = self.scheduler.step(noise_pred, timestep, latents)
# UI
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
output_frames = self.decode_images(latents)
return output_frames

View File

@@ -1,4 +1,5 @@
from ..models import ModelManager
from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder
# TODO: SDXL ControlNet
from ..prompts import SDXLPrompter
from ..schedulers import EnhancedDDIMScheduler
import torch
@@ -7,29 +8,77 @@ from PIL import Image
import numpy as np
class SDXLPipeline(torch.nn.Module):
class SDXLImagePipeline(torch.nn.Module):
def __init__(self):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__()
self.scheduler = EnhancedDDIMScheduler()
self.prompter = SDXLPrompter()
self.device = device
self.torch_dtype = torch_dtype
# models
self.text_encoder: SDXLTextEncoder = None
self.text_encoder_2: SDXLTextEncoder2 = None
self.unet: SDXLUNet = None
self.vae_decoder: SDXLVAEDecoder = None
self.vae_encoder: SDXLVAEEncoder = None
# TODO: SDXL ControlNet
def fetch_main_models(self, model_manager: ModelManager):
self.text_encoder = model_manager.text_encoder
self.text_encoder_2 = model_manager.text_encoder_2
self.unet = model_manager.unet
self.vae_decoder = model_manager.vae_decoder
self.vae_encoder = model_manager.vae_encoder
# load textual inversion
self.prompter.load_textual_inversion(model_manager.textual_inversion_dict)
def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs):
# TODO: SDXL ControlNet
pass
def fetch_beautiful_prompt(self, model_manager: ModelManager):
if "beautiful_prompt" in model_manager.model:
self.prompter.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units = [], **kwargs):
pipe = SDXLImagePipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_main_models(model_manager)
pipe.fetch_beautiful_prompt(model_manager)
pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units)
return pipe
def preprocess_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
return image
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
image = image.cpu().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
return image
@torch.no_grad()
def __call__(
self,
model_manager: ModelManager,
prompter: SDXLPrompter,
prompt,
negative_prompt="",
cfg_scale=7.5,
clip_skip=1,
clip_skip_2=2,
init_image=None,
input_image=None,
controlnet_image=None,
denoising_strength=1.0,
refining_strength=0.0,
height=1024,
width=1024,
num_inference_steps=20,
@@ -39,76 +88,62 @@ class SDXLPipeline(torch.nn.Module):
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Encode prompts
add_text_embeds, prompt_emb = prompter.encode_prompt(
model_manager.text_encoder,
model_manager.text_encoder_2,
prompt,
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
device=model_manager.device
)
if cfg_scale != 1.0:
negative_add_text_embeds, negative_prompt_emb = prompter.encode_prompt(
model_manager.text_encoder,
model_manager.text_encoder_2,
negative_prompt,
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
device=model_manager.device
)
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if init_image is not None:
image = self.preprocess_image(init_image).to(
device=model_manager.device, dtype=model_manager.torch_type
)
latents = model_manager.vae_encoder(
image.to(torch.float32),
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
)
noise = torch.randn(
(1, 4, height//8, width//8),
device=model_manager.device, dtype=model_manager.torch_type
)
latents = self.scheduler.add_noise(
latents.to(model_manager.torch_type),
noise,
timestep=self.scheduler.timesteps[0]
)
if input_image is not None:
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.vae_encoder(image.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = torch.randn((1, 4, height//8, width//8), device=model_manager.device, dtype=model_manager.torch_type)
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
# Encode prompts
add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt(
self.text_encoder,
self.text_encoder_2,
prompt,
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
device=self.device
)
if cfg_scale != 1.0:
add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
self.text_encoder,
self.text_encoder_2,
negative_prompt,
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
device=self.device
)
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare positional id
add_time_id = torch.tensor([height, width, 0, 0, height, width], device=model_manager.device)
add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = torch.IntTensor((timestep,))[0].to(model_manager.device)
timestep = torch.IntTensor((timestep,))[0].to(self.device)
# Classifier-free guidance
if timestep >= 1000 * refining_strength:
denoising_model = model_manager.unet
else:
denoising_model = model_manager.refiner
if cfg_scale != 1.0:
noise_pred_cond = denoising_model(
latents, timestep, prompt_emb,
add_time_id=add_time_id, add_text_embeds=add_text_embeds,
noise_pred_posi = self.unet(
latents, timestep, prompt_emb_posi,
add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
)
noise_pred_uncond = denoising_model(
latents, timestep, negative_prompt_emb,
add_time_id=add_time_id, add_text_embeds=negative_add_text_embeds,
noise_pred_nega = self.unet(
latents, timestep, prompt_emb_nega,
add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
)
noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_cond - noise_pred_uncond)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = denoising_model(
latents, timestep, prompt_emb,
add_time_id=add_time_id, add_text_embeds=add_text_embeds,
noise_pred = self.unet(
latents, timestep, prompt_emb_posi,
add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
)
@@ -118,9 +153,6 @@ class SDXLPipeline(torch.nn.Module):
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
latents = latents.to(torch.float32)
image = model_manager.vae_decoder(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
image = image.cpu().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return image

View File

@@ -1,7 +1,6 @@
from transformers import CLIPTokenizer
from transformers import CLIPTokenizer, AutoTokenizer
from ..models import SDTextEncoder, SDXLTextEncoder, SDXLTextEncoder2
import torch, os
from safetensors import safe_open
def tokenize_long_prompt(tokenizer, prompt):
@@ -36,49 +35,75 @@ def tokenize_long_prompt(tokenizer, prompt):
return input_ids
def load_textual_inversion(prompt):
# TODO: This module is not enabled now.
textual_inversion_files = os.listdir("models/textual_inversion")
embeddings_768 = []
embeddings_1280 = []
for file_name in textual_inversion_files:
if not file_name.endswith(".safetensors"):
continue
keyword = file_name[:-len(".safetensors")]
if keyword in prompt:
prompt = prompt.replace(keyword, "")
with safe_open(f"models/textual_inversion/{file_name}", framework="pt", device="cpu") as f:
for k in f.keys():
embedding = f.get_tensor(k).to(torch.float32)
if embedding.shape[-1] == 768:
embeddings_768.append(embedding)
elif embedding.shape[-1] == 1280:
embeddings_1280.append(embedding)
if len(embeddings_768)==0:
embeddings_768 = torch.zeros((0, 768))
else:
embeddings_768 = torch.concat(embeddings_768, dim=0)
if len(embeddings_1280)==0:
embeddings_1280 = torch.zeros((0, 1280))
else:
embeddings_1280 = torch.concat(embeddings_1280, dim=0)
return prompt, embeddings_768, embeddings_1280
class BeautifulPrompt:
def __init__(self, tokenizer_path="configs/beautiful_prompt/tokenizer", model=None):
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
self.model = model
self.template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'
def __call__(self, raw_prompt):
model_input = self.template.format(raw_prompt=raw_prompt)
input_ids = self.tokenizer.encode(model_input, return_tensors='pt').to(self.model.device)
outputs = self.model.generate(
input_ids,
max_new_tokens=384,
do_sample=True,
temperature=0.9,
top_k=50,
top_p=0.95,
repetition_penalty=1.1,
num_return_sequences=1
)
prompt = raw_prompt + ", " + self.tokenizer.batch_decode(
outputs[:, input_ids.size(1):],
skip_special_tokens=True
)[0].strip()
return prompt
class SDPrompter:
def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"):
# We use the tokenizer implemented by transformers
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
self.keyword_dict = {}
self.beautiful_prompt: BeautifulPrompt = None
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda"):
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True):
# Textual Inversion
for keyword in self.keyword_dict:
if keyword in prompt:
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
# Beautiful Prompt
if positive and self.beautiful_prompt is not None:
prompt = self.beautiful_prompt(prompt)
print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
prompt_emb = text_encoder(input_ids, clip_skip=clip_skip)
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
return prompt_emb
def load_textual_inversion(self, textual_inversion_dict):
self.keyword_dict = {}
additional_tokens = []
for keyword in textual_inversion_dict:
tokens, _ = textual_inversion_dict[keyword]
additional_tokens += tokens
self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
self.tokenizer.add_tokens(additional_tokens)
def load_beautiful_prompt(self, model, model_path):
model_folder = os.path.dirname(model_path)
self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model)
if model_folder.endswith("v2"):
self.beautiful_prompt.template = """Converts a simple image description into a prompt. \
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
but make sure there is a correlation between the input and output.\n\
### Input: {raw_prompt}\n### Output:"""
class SDXLPrompter:
@@ -90,6 +115,8 @@ class SDXLPrompter:
# We use the tokenizer implemented by transformers
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
self.keyword_dict = {}
self.beautiful_prompt: BeautifulPrompt = None
def encode_prompt(
self,
@@ -98,8 +125,19 @@ class SDXLPrompter:
prompt,
clip_skip=1,
clip_skip_2=2,
positive=True,
device="cuda"
):
# Textual Inversion
for keyword in self.keyword_dict:
if keyword in prompt:
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
# Beautiful Prompt
if positive and self.beautiful_prompt is not None:
prompt = self.beautiful_prompt(prompt)
print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
# 1
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
prompt_emb_1 = text_encoder(input_ids, clip_skip=clip_skip)
@@ -115,3 +153,22 @@ class SDXLPrompter:
add_text_embeds = add_text_embeds[0:1]
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
return add_text_embeds, prompt_emb
def load_textual_inversion(self, textual_inversion_dict):
self.keyword_dict = {}
additional_tokens = []
for keyword in textual_inversion_dict:
tokens, _ = textual_inversion_dict[keyword]
additional_tokens += tokens
self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
self.tokenizer.add_tokens(additional_tokens)
def load_beautiful_prompt(self, model, model_path):
model_folder = os.path.dirname(model_path)
self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model)
if model_folder.endswith("v2"):
self.beautiful_prompt.template = """Converts a simple image description into a prompt. \
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
but make sure there is a correlation between the input and output.\n\
### Input: {raw_prompt}\n### Output:"""

View File

@@ -1,37 +0,0 @@
from transformers import CLIPTokenizer
class SDTokenizer:
def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"):
# We use the tokenizer implemented by transformers
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
def __call__(self, prompt):
# Get model_max_length from self.tokenizer
length = self.tokenizer.model_max_length
# To avoid the warning. set self.tokenizer.model_max_length to +oo.
self.tokenizer.model_max_length = 99999999
# Tokenize it!
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
# Determine the real length.
max_length = (input_ids.shape[1] + length - 1) // length * length
# Restore self.tokenizer.model_max_length
self.tokenizer.model_max_length = length
# Tokenize it again with fixed length.
input_ids = self.tokenizer(
prompt,
return_tensors="pt",
padding="max_length",
max_length=max_length,
truncation=True
).input_ids
# Reshape input_ids to fit the text encoder.
num_sentence = input_ids.shape[1] // length
input_ids = input_ids.reshape((num_sentence, length))
return input_ids

View File

@@ -1,45 +0,0 @@
from transformers import CLIPTokenizer
from .sd_tokenizer import SDTokenizer
class SDXLTokenizer(SDTokenizer):
def __init__(self):
super().__init__()
class SDXLTokenizer2:
def __init__(self, tokenizer_path="configs/stable_diffusion_xl/tokenizer_2"):
# We use the tokenizer implemented by transformers
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
def __call__(self, prompt):
# Get model_max_length from self.tokenizer
length = self.tokenizer.model_max_length
# To avoid the warning. set self.tokenizer.model_max_length to +oo.
self.tokenizer.model_max_length = 99999999
# Tokenize it!
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
# Determine the real length.
max_length = (input_ids.shape[1] + length - 1) // length * length
# Restore self.tokenizer.model_max_length
self.tokenizer.model_max_length = length
# Tokenize it again with fixed length.
input_ids = self.tokenizer(
prompt,
return_tensors="pt",
padding="max_length",
max_length=max_length,
truncation=True
).input_ids
# Reshape input_ids to fit the text encoder.
num_sentence = input_ids.shape[1] // length
input_ids = input_ids.reshape((num_sentence, length))
return input_ids

View File

@@ -3,9 +3,14 @@ import torch, math
class EnhancedDDIMScheduler():
def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012):
def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"):
self.num_train_timesteps = num_train_timesteps
betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32))
if beta_schedule == "scaled_linear":
betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32))
elif beta_schedule == "linear":
betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
else:
raise NotImplementedError(f"{beta_schedule} is not implemented")
self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0).tolist()
self.set_timesteps(10)
@@ -13,7 +18,7 @@ class EnhancedDDIMScheduler():
def set_timesteps(self, num_inference_steps, denoising_strength=1.0):
# The timesteps are aligned to 999...0, which is different from other implementations,
# but I think this implementation is more reasonable in theory.
max_timestep = round(self.num_train_timesteps * denoising_strength) - 1
max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0)
num_inference_steps = min(num_inference_steps, max_timestep + 1)
if num_inference_steps == 1:
self.timesteps = [max_timestep]
@@ -34,14 +39,14 @@ class EnhancedDDIMScheduler():
return prev_sample
def step(self, model_output, timestep, sample):
def step(self, model_output, timestep, sample, to_final=False):
alpha_prod_t = self.alphas_cumprod[timestep]
timestep_id = self.timesteps.index(timestep)
if timestep_id + 1 < len(self.timesteps):
if to_final or timestep_id + 1 >= len(self.timesteps):
alpha_prod_t_prev = 1.0
else:
timestep_prev = self.timesteps[timestep_id + 1]
alpha_prod_t_prev = self.alphas_cumprod[timestep_prev]
else:
alpha_prod_t_prev = 1.0
return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev)

19
environment.yml Normal file
View File

@@ -0,0 +1,19 @@
name: DiffSynthStudio
channels:
- pytorch
- nvidia
- defaults
dependencies:
- python=3.9.16
- pip=23.0.1
- cudatoolkit
- pytorch
- pip:
- transformers
- controlnet-aux==0.0.7
- streamlit
- streamlit-drawable-canvas
- imageio
- imageio[ffmpeg]
- safetensors
- einops

View File

@@ -0,0 +1,75 @@
from diffsynth import ModelManager, SDImagePipeline, ControlNetConfigUnit
import torch
# Download models
# `models/stable_diffusion/aingdiffusion_v12.safetensors`: [link](https://civitai.com/api/download/models/229575?type=Model&format=SafeTensor&size=full&fp=fp16)
# `models/ControlNet/control_v11p_sd15_lineart.pth`: [link](https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_lineart.pth)
# `models/ControlNet/control_v11f1e_sd15_tile.pth`: [link](https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11f1e_sd15_tile.pth)
# `models/Annotators/sk_model.pth`: [link](https://huggingface.co/lllyasviel/Annotators/resolve/main/sk_model.pth)
# `models/Annotators/sk_model2.pth`: [link](https://huggingface.co/lllyasviel/Annotators/resolve/main/sk_model2.pth)
# Load models
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
model_manager.load_textual_inversions("models/textual_inversion")
model_manager.load_models([
"models/stable_diffusion/aingdiffusion_v12.safetensors",
"models/ControlNet/control_v11f1e_sd15_tile.pth",
"models/ControlNet/control_v11p_sd15_lineart.pth"
])
pipe = SDImagePipeline.from_model_manager(
model_manager,
[
ControlNetConfigUnit(
processor_id="tile",
model_path=rf"models/ControlNet/control_v11f1e_sd15_tile.pth",
scale=0.5
),
ControlNetConfigUnit(
processor_id="lineart",
model_path=rf"models/ControlNet/control_v11p_sd15_lineart.pth",
scale=0.7
),
]
)
prompt = "masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait,"
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,",
torch.manual_seed(0)
image = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
cfg_scale=7.5, clip_skip=1,
height=512, width=512, num_inference_steps=80,
)
image.save("512.jpg")
image = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
cfg_scale=7.5, clip_skip=1,
input_image=image.resize((1024, 1024)), controlnet_image=image.resize((1024, 1024)),
height=1024, width=1024, num_inference_steps=40, denoising_strength=0.7,
)
image.save("1024.jpg")
image = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
cfg_scale=7.5, clip_skip=1,
input_image=image.resize((2048, 2048)), controlnet_image=image.resize((2048, 2048)),
height=2048, width=2048, num_inference_steps=20, denoising_strength=0.7,
)
image.save("2048.jpg")
image = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
cfg_scale=7.5, clip_skip=1,
input_image=image.resize((4096, 4096)), controlnet_image=image.resize((4096, 4096)),
height=4096, width=4096, num_inference_steps=10, denoising_strength=0.5,
tiled=True, tile_size=128, tile_stride=64
)
image.save("4096.jpg")

View File

@@ -0,0 +1,56 @@
from diffsynth import ModelManager, SDVideoPipeline, ControlNetConfigUnit, VideoData, save_video, save_frames
import torch
# Download models
# `models/stable_diffusion/flat2DAnimerge_v45Sharp.safetensors`: [link](https://civitai.com/api/download/models/266360?type=Model&format=SafeTensor&size=pruned&fp=fp16)
# `models/AnimateDiff/mm_sd_v15_v2.ckpt`: [link](https://huggingface.co/guoyww/animatediff/resolve/main/mm_sd_v15_v2.ckpt)
# `models/ControlNet/control_v11p_sd15_lineart.pth`: [link](https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_lineart.pth)
# `models/ControlNet/control_v11f1e_sd15_tile.pth`: [link](https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11f1e_sd15_tile.pth)
# `models/Annotators/sk_model.pth`: [link](https://huggingface.co/lllyasviel/Annotators/resolve/main/sk_model.pth)
# `models/Annotators/sk_model2.pth`: [link](https://huggingface.co/lllyasviel/Annotators/resolve/main/sk_model2.pth)
# Load models
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
model_manager.load_textual_inversions("models/textual_inversion")
model_manager.load_models([
"models/stable_diffusion/flat2DAnimerge_v45Sharp.safetensors",
"models/AnimateDiff/mm_sd_v15_v2.ckpt",
"models/ControlNet/control_v11p_sd15_lineart.pth",
"models/ControlNet/control_v11f1e_sd15_tile.pth",
])
pipe = SDVideoPipeline.from_model_manager(
model_manager,
[
ControlNetConfigUnit(
processor_id="lineart",
model_path="models/ControlNet/control_v11p_sd15_lineart.pth",
scale=1.0
),
ControlNetConfigUnit(
processor_id="tile",
model_path="models/ControlNet/control_v11f1e_sd15_tile.pth",
scale=0.5
),
]
)
# Load video (we only use 16 frames in this example for testing)
video = VideoData(video_file="input_video.mp4", height=1536, width=1536)
input_video = [video[i] for i in range(16)]
# Toon shading
torch.manual_seed(0)
output_video = pipe(
prompt="best quality, perfect anime illustration, light, a girl is dancing, smile, solo",
negative_prompt="verybadimagenegative_v1.3",
cfg_scale=5, clip_skip=2,
controlnet_frames=input_video, num_frames=len(input_video),
num_inference_steps=10, height=1536, width=1536,
vram_limit_level=0,
)
# Save images and video
save_frames(output_video, "output_frames")
save_video(output_video, "output_video.mp4", fps=16)

View File

@@ -0,0 +1,34 @@
from diffsynth import ModelManager, SDXLImagePipeline
import torch
# Download models
# `models/stable_diffusion_xl/bluePencilXL_v200.safetensors`: [link](https://civitai.com/api/download/models/245614?type=Model&format=SafeTensor&size=pruned&fp=fp16)
# Load models
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
model_manager.load_models(["models/stable_diffusion_xl/bluePencilXL_v200.safetensors"])
pipe = SDXLImagePipeline.from_model_manager(model_manager)
prompt = "masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait,"
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,",
torch.manual_seed(0)
image = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
cfg_scale=6,
height=1024, width=1024, num_inference_steps=60,
)
image.save("1024.jpg")
image = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
cfg_scale=6,
input_image=image.resize((2048, 2048)),
height=2048, width=2048, num_inference_steps=60, denoising_strength=0.5
)
image.save("2048.jpg")

31
examples/sdxl_turbo.py Normal file
View File

@@ -0,0 +1,31 @@
from diffsynth import ModelManager, SDXLImagePipeline
import torch
# Download models
# `models/stable_diffusion_xl_turbo/sd_xl_turbo_1.0_fp16.safetensors`: [link](https://huggingface.co/stabilityai/sdxl-turbo/resolve/main/sd_xl_turbo_1.0_fp16.safetensors)
# Load models
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
model_manager.load_models(["models/stable_diffusion_xl_turbo/sd_xl_turbo_1.0_fp16.safetensors"])
pipe = SDXLImagePipeline.from_model_manager(model_manager)
# Text to image
torch.manual_seed(0)
image = pipe(
prompt="black car",
# Do not modify the following parameters!
cfg_scale=1, height=512, width=512, num_inference_steps=1, progress_bar_cmd=lambda x:x
)
image.save(f"black_car.jpg")
# Image to image
torch.manual_seed(0)
image = pipe(
prompt="red car",
input_image=image, denoising_strength=0.7,
# Do not modify the following parameters!
cfg_scale=1, height=512, width=512, num_inference_steps=1, progress_bar_cmd=lambda x:x
)
image.save(f"black_car_to_red_car.jpg")

View File

@@ -5,59 +5,61 @@ import streamlit as st
st.set_page_config(layout="wide")
from streamlit_drawable_canvas import st_canvas
from diffsynth.models import ModelManager
from diffsynth.prompts import SDXLPrompter, SDPrompter
from diffsynth.pipelines import SDXLPipeline, SDPipeline
from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline
torch.cuda.set_per_process_memory_fraction(0.999, 0)
config = {
"Stable Diffusion": {
"model_folder": "models/stable_diffusion",
"pipeline_class": SDImagePipeline,
"fixed_parameters": {}
},
"Stable Diffusion XL": {
"model_folder": "models/stable_diffusion_xl",
"pipeline_class": SDXLImagePipeline,
"fixed_parameters": {}
},
"Stable Diffusion XL Turbo": {
"model_folder": "models/stable_diffusion_xl_turbo",
"pipeline_class": SDXLImagePipeline,
"fixed_parameters": {
"negative_prompt": "",
"cfg_scale": 1.0,
"num_inference_steps": 1,
"height": 512,
"width": 512,
}
}
}
@st.cache_data
def load_model_list(folder):
def load_model_list(model_type):
folder = config[model_type]["model_folder"]
file_list = os.listdir(folder)
file_list = [i for i in file_list if i.endswith(".safetensors")]
file_list = sorted(file_list)
return file_list
def detect_model_path(sd_model_path, sdxl_model_path):
if sd_model_path != "None":
model_path = os.path.join("models/stable_diffusion", sd_model_path)
elif sdxl_model_path != "None":
model_path = os.path.join("models/stable_diffusion_xl", sdxl_model_path)
else:
model_path = None
return model_path
def load_model(sd_model_path, sdxl_model_path):
if sd_model_path != "None":
model_path = os.path.join("models/stable_diffusion", sd_model_path)
model_manager = ModelManager()
model_manager.load_from_safetensors(model_path)
prompter = SDPrompter()
pipeline = SDPipeline()
elif sdxl_model_path != "None":
model_path = os.path.join("models/stable_diffusion_xl", sdxl_model_path)
model_manager = ModelManager()
model_manager.load_from_safetensors(model_path)
prompter = SDXLPrompter()
pipeline = SDXLPipeline()
else:
return None, None, None, None
return model_path, model_manager, prompter, pipeline
def release_model():
if "model_manager" in st.session_state:
st.session_state["model_manager"].to("cpu")
del st.session_state["loaded_model_path"]
del st.session_state["model_manager"]
del st.session_state["prompter"]
del st.session_state["pipeline"]
torch.cuda.empty_cache()
def load_model(model_type, model_path):
model_manager = ModelManager()
model_manager.load_model(model_path)
pipeline = config[model_type]["pipeline_class"].from_model_manager(model_manager)
st.session_state.loaded_model_path = model_path
st.session_state.model_manager = model_manager
st.session_state.pipeline = pipeline
return model_manager, pipeline
def use_output_image_as_input():
# Search for input image
output_image_id = 0
@@ -102,68 +104,64 @@ def show_output_image(image):
column_input, column_output = st.columns(2)
# with column_input:
with st.sidebar:
# Select a model
with st.expander("Model", expanded=True):
sd_model_list = ["None"] + load_model_list("models/stable_diffusion")
sd_model_path = st.selectbox(
"Stable Diffusion", sd_model_list
)
sdxl_model_list = ["None"] + load_model_list("models/stable_diffusion_xl")
sdxl_model_path = st.selectbox(
"Stable Diffusion XL", sdxl_model_list
)
model_type = st.selectbox("Model type", ["Stable Diffusion", "Stable Diffusion XL", "Stable Diffusion XL Turbo"])
fixed_parameters = config[model_type]["fixed_parameters"]
model_path_list = ["None"] + load_model_list(model_type)
model_path = st.selectbox("Model path", model_path_list)
# Load the model
model_path = detect_model_path(sd_model_path, sdxl_model_path)
if model_path is None:
st.markdown("No models selected.")
if model_path == "None":
# No models are selected. Release VRAM.
st.markdown("No models are selected.")
release_model()
elif st.session_state.get("loaded_model_path", "") != model_path:
st.markdown(f"Using model at {model_path}.")
release_model()
model_path, model_manager, prompter, pipeline = load_model(sd_model_path, sdxl_model_path)
st.session_state.loaded_model_path = model_path
st.session_state.model_manager = model_manager
st.session_state.prompter = prompter
st.session_state.pipeline = pipeline
else:
st.markdown(f"Using model at {model_path}.")
model_path, model_manager, prompter, pipeline = (
st.session_state.loaded_model_path,
st.session_state.model_manager,
st.session_state.prompter,
st.session_state.pipeline,
)
# A model is selected.
model_path = os.path.join(config[model_type]["model_folder"], model_path)
if st.session_state.get("loaded_model_path", "") != model_path:
# The loaded model is not the selected model. Reload it.
st.markdown(f"Using model at {model_path}.")
release_model()
model_manager, pipeline = load_model(model_type, model_path)
else:
# The loaded model is not the selected model. Fetch it from `st.session_state`.
st.markdown(f"Using model at {model_path}.")
model_manager, pipeline = st.session_state.model_manager, st.session_state.pipeline
# Show parameters
with st.expander("Prompt", expanded=True):
column_positive, column_negative = st.columns(2)
prompt = st.text_area("Positive prompt")
negative_prompt = st.text_area("Negative prompt")
with st.expander("Classifier-free guidance", expanded=True):
use_cfg = st.checkbox("Use classifier-free guidance", value=True)
if use_cfg:
cfg_scale = st.slider("Classifier-free guidance scale", min_value=1.0, max_value=10.0, step=0.1, value=7.5)
if "negative_prompt" in fixed_parameters:
negative_prompt = fixed_parameters["negative_prompt"]
else:
cfg_scale = 1.0
with st.expander("Inference steps", expanded=True):
num_inference_steps = st.slider("Inference steps", min_value=1, max_value=100, value=20, label_visibility="hidden")
with st.expander("Image size", expanded=True):
height = st.select_slider("Height", options=[256, 512, 768, 1024, 2048], value=512)
width = st.select_slider("Width", options=[256, 512, 768, 1024, 2048], value=512)
with st.expander("Seed", expanded=True):
negative_prompt = st.text_area("Negative prompt")
if "cfg_scale" in fixed_parameters:
cfg_scale = fixed_parameters["cfg_scale"]
else:
cfg_scale = st.slider("Classifier-free guidance scale", min_value=1.0, max_value=10.0, value=7.5)
with st.expander("Image", expanded=True):
if "num_inference_steps" in fixed_parameters:
num_inference_steps = fixed_parameters["num_inference_steps"]
else:
num_inference_steps = st.slider("Inference steps", min_value=1, max_value=100, value=20)
if "height" in fixed_parameters:
height = fixed_parameters["height"]
else:
height = st.select_slider("Height", options=[256, 512, 768, 1024, 2048], value=512)
if "width" in fixed_parameters:
width = fixed_parameters["width"]
else:
width = st.select_slider("Width", options=[256, 512, 768, 1024, 2048], value=512)
num_images = st.number_input("Number of images", value=2)
use_fixed_seed = st.checkbox("Use fixed seed", value=False)
if use_fixed_seed:
seed = st.number_input("Random seed", value=0, label_visibility="hidden")
with st.expander("Number of images", expanded=True):
num_images = st.number_input("Number of images", value=4, label_visibility="hidden")
with st.expander("Tile (for high resolution)", expanded=True):
tiled = st.checkbox("Use tile", value=False)
tile_size = st.select_slider("Tile size", options=[64, 128], value=64)
tile_stride = st.select_slider("Tile stride", options=[8, 16, 32, 64], value=32)
seed = st.number_input("Random seed", min_value=0, max_value=10**9, step=1, value=0)
# Other fixed parameters
denoising_strength = 1.0
repetition = 1
# Show input image
@@ -180,7 +178,7 @@ with column_input:
if upload_image is not None:
st.session_state["input_image"] = Image.open(upload_image)
elif create_white_board:
st.session_state["input_image"] = Image.fromarray(np.ones((1024, 1024, 3), dtype=np.uint8) * 255)
st.session_state["input_image"] = Image.fromarray(np.ones((height, width, 3), dtype=np.uint8) * 255)
else:
use_output_image_as_input()
@@ -202,6 +200,7 @@ with column_input:
stroke_width = st.slider("Stroke width", min_value=1, max_value=50, value=10)
with st.container(border=True):
denoising_strength = st.slider("Denoising strength", min_value=0.0, max_value=1.0, value=0.7)
repetition = st.slider("Repetition", min_value=1, max_value=8, value=1)
with st.container(border=True):
input_width, input_height = input_image.size
canvas_result = st_canvas(
@@ -225,34 +224,32 @@ with column_output:
image_columns = st.columns(num_image_columns)
# Run
if (run_button or auto_update) and model_path is not None:
if (run_button or auto_update) and model_path != "None":
if not use_fixed_seed:
torch.manual_seed(np.random.randint(0, 10**9))
if input_image is not None:
input_image = input_image.resize((width, height))
if canvas_result.image_data is not None:
input_image = apply_stroke_to_image(canvas_result.image_data, input_image)
output_images = []
for image_id in range(num_images):
for image_id in range(num_images * repetition):
if use_fixed_seed:
torch.manual_seed(seed + image_id)
if input_image is not None:
input_image = input_image.resize((width, height))
if canvas_result.image_data is not None:
input_image = apply_stroke_to_image(canvas_result.image_data, input_image)
else:
denoising_strength = 1.0
torch.manual_seed(np.random.randint(0, 10**9))
if image_id >= num_images:
input_image = output_images[image_id - num_images]
with image_columns[image_id % num_image_columns]:
progress_bar = st.progress(0.0)
progress_bar_st = st.progress(0.0)
image = pipeline(
model_manager, prompter,
prompt, negative_prompt=negative_prompt, cfg_scale=cfg_scale,
num_inference_steps=num_inference_steps,
prompt, negative_prompt=negative_prompt,
cfg_scale=cfg_scale, num_inference_steps=num_inference_steps,
height=height, width=width,
init_image=input_image, denoising_strength=denoising_strength,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
progress_bar_st=progress_bar
input_image=input_image, denoising_strength=denoising_strength,
progress_bar_st=progress_bar_st
)
output_images.append(image)
progress_bar.progress(1.0)
progress_bar_st.progress(1.0)
show_output_image(image)
st.session_state["output_images"] = output_images

4
pages/2_Video_Creator.py Normal file
View File

@@ -0,0 +1,4 @@
import streamlit as st
st.set_page_config(layout="wide")
st.markdown("# Coming soon")