mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
15
DiffSynth_Studio.py
Normal file
15
DiffSynth_Studio.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Set web page format
|
||||
import streamlit as st
|
||||
st.set_page_config(layout="wide")
|
||||
# Diasble virtual VRAM on windows system
|
||||
import torch
|
||||
torch.cuda.set_per_process_memory_fraction(0.999, 0)
|
||||
|
||||
|
||||
st.markdown("""
|
||||
# DiffSynth Studio
|
||||
|
||||
[Source Code](https://github.com/Artiprocher/DiffSynth-Studio)
|
||||
|
||||
Welcome to DiffSynth Studio.
|
||||
""")
|
||||
53
README-zh.md
53
README-zh.md
@@ -1,53 +0,0 @@
|
||||
# DiffSynth Studio
|
||||
|
||||
## 介绍
|
||||
|
||||
DiffSynth 是一个全新的 Diffusion 引擎,我们重构了 Text Encoder、UNet、VAE 等架构,保持与开源社区模型兼容性的同时,提升了计算性能。目前这个版本仅仅是一个初始版本,实现了文生图和图生图功能,支持 SD 和 SDXL 架构。未来我们计划基于这个全新的代码库开发更多有趣的功能。
|
||||
|
||||
## 安装
|
||||
|
||||
如果你只想在 Python 代码层面调用 DiffSynth Studio,你只需要安装 `torch`(深度学习框架)和 `transformers`(仅用于实现分词器)。
|
||||
|
||||
```
|
||||
pip install torch transformers
|
||||
```
|
||||
|
||||
如果你想使用 UI,还需要额外安装 `streamlit`(一个 webui 框架)和 `streamlit-drawable-canvas`(用于图生图画板)。
|
||||
|
||||
```
|
||||
pip install streamlit streamlit-drawable-canvas
|
||||
```
|
||||
|
||||
## 使用
|
||||
|
||||
通过 Python 代码调用
|
||||
|
||||
```python
|
||||
from diffsynth.models import ModelManager
|
||||
from diffsynth.prompts import SDPrompter, SDXLPrompter
|
||||
from diffsynth.pipelines import SDPipeline, SDXLPipeline
|
||||
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_manager.load_from_safetensors("xxxxxxxx.safetensors")
|
||||
prompter = SDPrompter()
|
||||
pipe = SDPipeline()
|
||||
|
||||
prompt = "a girl"
|
||||
negative_prompt = ""
|
||||
|
||||
image = pipe(
|
||||
model_manager, prompter,
|
||||
prompt, negative_prompt=negative_prompt,
|
||||
num_inference_steps=20, height=512, width=512,
|
||||
)
|
||||
image.save("image.png")
|
||||
```
|
||||
|
||||
如果需要用 SDXL 架构模型,请把 `SDPrompter`、`SDPipeline` 换成 `SDXLPrompter`, `SDXLPipeline`。
|
||||
|
||||
当然,你也可以使用我们提供的 UI,但请注意,我们的 UI 程序很简单,且未来可能会大幅改变。
|
||||
|
||||
```
|
||||
python -m streamlit run Diffsynth_Studio.py
|
||||
```
|
||||
78
README.md
78
README.md
@@ -1,53 +1,59 @@
|
||||
# DiffSynth Studio
|
||||
|
||||
## 介绍
|
||||
## Introduction
|
||||
|
||||
DiffSynth is a new Diffusion engine. We have restructured architectures like Text Encoder, UNet, VAE, among others, maintaining compatibility with models from the open-source community while enhancing computational performance. This version is currently in its initial stage, supporting text-to-image and image-to-image functionalities, supporting SD and SDXL architectures. In the future, we plan to develop more interesting features based on this new codebase.
|
||||
DiffSynth is a new Diffusion engine. We have restructured architectures including Text Encoder, UNet, VAE, among others, maintaining compatibility with models from the open-source community while enhancing computational performance. This version is currently in its initial stage, supporting SD and SDXL architectures. In the future, we plan to develop more interesting features based on this new codebase.
|
||||
|
||||
## 安装
|
||||
## Installation
|
||||
|
||||
If you only want to use DiffSynth Studio at the Python code level, you just need to install torch (a deep learning framework) and transformers (only used for implementing a tokenizer).
|
||||
Create Python environment:
|
||||
|
||||
```
|
||||
pip install torch transformers
|
||||
conda env create -f environment.yml
|
||||
```
|
||||
|
||||
If you wish to use the UI, you'll also need to additionally install `streamlit` (a web UI framework) and `streamlit-drawable-canvas` (used for the image-to-image canvas).
|
||||
Enter the Python environment:
|
||||
|
||||
```
|
||||
pip install streamlit streamlit-drawable-canvas
|
||||
conda activate DiffSynthStudio
|
||||
```
|
||||
|
||||
## 使用
|
||||
|
||||
Use DiffSynth Studio in Python
|
||||
|
||||
```python
|
||||
from diffsynth.models import ModelManager
|
||||
from diffsynth.prompts import SDPrompter, SDXLPrompter
|
||||
from diffsynth.pipelines import SDPipeline, SDXLPipeline
|
||||
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_manager.load_from_safetensors("xxxxxxxx.safetensors")
|
||||
prompter = SDPrompter()
|
||||
pipe = SDPipeline()
|
||||
|
||||
prompt = "a girl"
|
||||
negative_prompt = ""
|
||||
|
||||
image = pipe(
|
||||
model_manager, prompter,
|
||||
prompt, negative_prompt=negative_prompt,
|
||||
num_inference_steps=20, height=512, width=512,
|
||||
)
|
||||
image.save("image.png")
|
||||
```
|
||||
|
||||
If you want to use SDXL architecture models, replace `SDPrompter` and `SDPipeline` with `SDXLPrompter` and `SDXLPipeline`, respectively.
|
||||
|
||||
Of course, you can also use the UI we provide. The UI is simple but may be changed in the future.
|
||||
## Usage (in WebUI)
|
||||
|
||||
```
|
||||
python -m streamlit run Diffsynth_Studio.py
|
||||
```
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/93085557-73f3-4eee-a205-9829591ef954
|
||||
|
||||
## Usage (in Python code)
|
||||
|
||||
### Example 1: Stable Diffusion
|
||||
|
||||
We can generate images with very high resolution. Please see `examples/sd_text_to_image.py` for more details.
|
||||
|
||||
|512*512|1024*1024|2048*2048|4096*4096|
|
||||
|-|-|-|-|
|
||||
|||||
|
||||
|
||||
### Example 2: Stable Diffusion XL
|
||||
|
||||
Generate images with Stable Diffusion XL. Please see `examples/sdxl_text_to_image.py` for more details.
|
||||
|
||||
|1024*1024|2048*2048|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
### Example 3: Stable Diffusion XL Turbo
|
||||
|
||||
Generate images with Stable Diffusion XL Turbo. You can see `examples/sdxl_turbo.py` for more details, but we highly recommend you to use it in the WebUI.
|
||||
|
||||
|"black car"|"red car"|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
### Example 4: Toon Shading
|
||||
|
||||
A very interesting example. Please see `examples/sd_toon_shading.py` for more details.
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/53532f0e-39b1-4791-b920-c975d52ec24a
|
||||
|
||||
6
diffsynth/__init__.py
Normal file
6
diffsynth/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .data import *
|
||||
from .models import *
|
||||
from .prompts import *
|
||||
from .schedulers import *
|
||||
from .pipelines import *
|
||||
from .controlnets import *
|
||||
2
diffsynth/controlnets/__init__.py
Normal file
2
diffsynth/controlnets/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager
|
||||
from .processors import Annotator
|
||||
55
diffsynth/controlnets/controlnet_unit.py
Normal file
55
diffsynth/controlnets/controlnet_unit.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from .processors import Processor_id
|
||||
|
||||
|
||||
class ControlNetConfigUnit:
|
||||
def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
|
||||
self.processor_id = processor_id
|
||||
self.model_path = model_path
|
||||
self.scale = scale
|
||||
|
||||
|
||||
class ControlNetUnit:
|
||||
def __init__(self, processor, model, scale=1.0):
|
||||
self.processor = processor
|
||||
self.model = model
|
||||
self.scale = scale
|
||||
|
||||
|
||||
class MultiControlNetManager:
|
||||
def __init__(self, controlnet_units=[]):
|
||||
self.processors = [unit.processor for unit in controlnet_units]
|
||||
self.models = [unit.model for unit in controlnet_units]
|
||||
self.scales = [unit.scale for unit in controlnet_units]
|
||||
|
||||
def process_image(self, image, return_image=False):
|
||||
processed_image = [
|
||||
processor(image)
|
||||
for processor in self.processors
|
||||
]
|
||||
if return_image:
|
||||
return processed_image
|
||||
processed_image = torch.concat([
|
||||
torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
|
||||
for image_ in processed_image
|
||||
], dim=0)
|
||||
return processed_image
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
sample, timestep, encoder_hidden_states, conditionings,
|
||||
tiled=False, tile_size=64, tile_stride=32
|
||||
):
|
||||
res_stack = None
|
||||
for conditioning, model, scale in zip(conditionings, self.models, self.scales):
|
||||
res_stack_ = model(
|
||||
sample, timestep, encoder_hidden_states, conditioning,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
||||
)
|
||||
res_stack_ = [res * scale for res in res_stack_]
|
||||
if res_stack is None:
|
||||
res_stack = res_stack_
|
||||
else:
|
||||
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
|
||||
return res_stack
|
||||
51
diffsynth/controlnets/processors.py
Normal file
51
diffsynth/controlnets/processors.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from typing_extensions import Literal, TypeAlias
|
||||
import warnings
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
from controlnet_aux.processor import (
|
||||
CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector
|
||||
)
|
||||
|
||||
|
||||
Processor_id: TypeAlias = Literal[
|
||||
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"
|
||||
]
|
||||
|
||||
class Annotator:
|
||||
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None):
|
||||
if processor_id == "canny":
|
||||
self.processor = CannyDetector()
|
||||
elif processor_id == "depth":
|
||||
self.processor = MidasDetector.from_pretrained(model_path)
|
||||
elif processor_id == "softedge":
|
||||
self.processor = HEDdetector.from_pretrained(model_path)
|
||||
elif processor_id == "lineart":
|
||||
self.processor = LineartDetector.from_pretrained(model_path)
|
||||
elif processor_id == "lineart_anime":
|
||||
self.processor = LineartAnimeDetector.from_pretrained(model_path)
|
||||
elif processor_id == "openpose":
|
||||
self.processor = OpenposeDetector.from_pretrained(model_path)
|
||||
elif processor_id == "tile":
|
||||
self.processor = None
|
||||
else:
|
||||
raise ValueError(f"Unsupported processor_id: {processor_id}")
|
||||
|
||||
self.processor_id = processor_id
|
||||
self.detect_resolution = detect_resolution
|
||||
|
||||
def __call__(self, image):
|
||||
width, height = image.size
|
||||
if self.processor_id == "openpose":
|
||||
kwargs = {
|
||||
"include_body": True,
|
||||
"include_hand": True,
|
||||
"include_face": True
|
||||
}
|
||||
else:
|
||||
kwargs = {}
|
||||
if self.processor is not None:
|
||||
detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height)
|
||||
image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs)
|
||||
image = image.resize((width, height))
|
||||
return image
|
||||
|
||||
1
diffsynth/data/__init__.py
Normal file
1
diffsynth/data/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .video import VideoData, save_video, save_frames
|
||||
148
diffsynth/data/video.py
Normal file
148
diffsynth/data/video.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import imageio, os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class LowMemoryVideo:
|
||||
def __init__(self, file_name):
|
||||
self.reader = imageio.get_reader(file_name)
|
||||
|
||||
def __len__(self):
|
||||
return self.reader.count_frames()
|
||||
|
||||
def __getitem__(self, item):
|
||||
return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB")
|
||||
|
||||
def __del__(self):
|
||||
self.reader.close()
|
||||
|
||||
|
||||
def split_file_name(file_name):
|
||||
result = []
|
||||
number = -1
|
||||
for i in file_name:
|
||||
if ord(i)>=ord("0") and ord(i)<=ord("9"):
|
||||
if number == -1:
|
||||
number = 0
|
||||
number = number*10 + ord(i) - ord("0")
|
||||
else:
|
||||
if number != -1:
|
||||
result.append(number)
|
||||
number = -1
|
||||
result.append(i)
|
||||
if number != -1:
|
||||
result.append(number)
|
||||
result = tuple(result)
|
||||
return result
|
||||
|
||||
|
||||
def search_for_images(folder):
|
||||
file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
|
||||
file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
|
||||
file_list = [i[1] for i in sorted(file_list)]
|
||||
file_list = [os.path.join(folder, i) for i in file_list]
|
||||
return file_list
|
||||
|
||||
|
||||
class LowMemoryImageFolder:
|
||||
def __init__(self, folder, file_list=None):
|
||||
if file_list is None:
|
||||
self.file_list = search_for_images(folder)
|
||||
else:
|
||||
self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.file_list)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return Image.open(self.file_list[item]).convert("RGB")
|
||||
|
||||
def __del__(self):
|
||||
pass
|
||||
|
||||
|
||||
def crop_and_resize(image, height, width):
|
||||
image = np.array(image)
|
||||
image_height, image_width, _ = image.shape
|
||||
if image_height / image_width < height / width:
|
||||
croped_width = int(image_height / height * width)
|
||||
left = (image_width - croped_width) // 2
|
||||
image = image[:, left: left+croped_width]
|
||||
image = Image.fromarray(image).resize((width, height))
|
||||
else:
|
||||
croped_height = int(image_width / width * height)
|
||||
left = (image_height - croped_height) // 2
|
||||
image = image[left: left+croped_height, :]
|
||||
image = Image.fromarray(image).resize((width, height))
|
||||
return image
|
||||
|
||||
|
||||
class VideoData:
|
||||
def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs):
|
||||
if video_file is not None:
|
||||
self.data_type = "video"
|
||||
self.data = LowMemoryVideo(video_file, **kwargs)
|
||||
elif image_folder is not None:
|
||||
self.data_type = "images"
|
||||
self.data = LowMemoryImageFolder(image_folder, **kwargs)
|
||||
else:
|
||||
raise ValueError("Cannot open video or image folder")
|
||||
self.length = None
|
||||
self.set_shape(height, width)
|
||||
|
||||
def raw_data(self):
|
||||
frames = []
|
||||
for i in range(self.__len__()):
|
||||
frames.append(self.__getitem__(i))
|
||||
return frames
|
||||
|
||||
def set_length(self, length):
|
||||
self.length = length
|
||||
|
||||
def set_shape(self, height, width):
|
||||
self.height = height
|
||||
self.width = width
|
||||
|
||||
def __len__(self):
|
||||
if self.length is None:
|
||||
return len(self.data)
|
||||
else:
|
||||
return self.length
|
||||
|
||||
def shape(self):
|
||||
if self.height is not None and self.width is not None:
|
||||
return self.height, self.width
|
||||
else:
|
||||
height, width, _ = self.__getitem__(0).shape
|
||||
return height, width
|
||||
|
||||
def __getitem__(self, item):
|
||||
frame = self.data.__getitem__(item)
|
||||
width, height = frame.size
|
||||
if self.height is not None and self.width is not None:
|
||||
if self.height != height or self.width != width:
|
||||
frame = crop_and_resize(frame, self.height, self.width)
|
||||
return frame
|
||||
|
||||
def __del__(self):
|
||||
pass
|
||||
|
||||
def save_images(self, folder):
|
||||
os.makedirs(folder, exist_ok=True)
|
||||
for i in tqdm(range(self.__len__()), desc="Saving images"):
|
||||
frame = self.__getitem__(i)
|
||||
frame.save(os.path.join(folder, f"{i}.png"))
|
||||
|
||||
|
||||
def save_video(frames, save_path, fps, quality=9):
|
||||
writer = imageio.get_writer(save_path, fps=fps, quality=quality)
|
||||
for frame in tqdm(frames, desc="Saving video"):
|
||||
frame = np.array(frame)
|
||||
writer.append_data(frame)
|
||||
writer.close()
|
||||
|
||||
def save_frames(frames, save_path):
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
for i, frame in enumerate(tqdm(frames, desc="Saving images")):
|
||||
frame.save(os.path.join(save_path, f"{i}.png"))
|
||||
@@ -1,4 +1,4 @@
|
||||
import torch
|
||||
import torch, os
|
||||
from safetensors import safe_open
|
||||
|
||||
from .sd_text_encoder import SDTextEncoder
|
||||
@@ -11,21 +11,44 @@ from .sdxl_unet import SDXLUNet
|
||||
from .sdxl_vae_decoder import SDXLVAEDecoder
|
||||
from .sdxl_vae_encoder import SDXLVAEEncoder
|
||||
|
||||
from .sd_controlnet import SDControlNet
|
||||
|
||||
from .sd_motion import SDMotionModel
|
||||
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
|
||||
class ModelManager:
|
||||
def __init__(self, torch_type=torch.float16, device="cuda"):
|
||||
self.torch_type = torch_type
|
||||
def __init__(self, torch_dtype=torch.float16, device="cuda"):
|
||||
self.torch_dtype = torch_dtype
|
||||
self.device = device
|
||||
self.model = {}
|
||||
self.model_path = {}
|
||||
self.textual_inversion_dict = {}
|
||||
|
||||
def is_beautiful_prompt(self, state_dict):
|
||||
param_name = "transformer.h.9.self_attention.query_key_value.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_stabe_diffusion_xl(self, state_dict):
|
||||
param_name = "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_stable_diffusion(self, state_dict):
|
||||
return True
|
||||
if self.is_stabe_diffusion_xl(state_dict):
|
||||
return False
|
||||
param_name = "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def load_stable_diffusion(self, state_dict, components=None):
|
||||
def is_controlnet(self, state_dict):
|
||||
param_name = "control_model.time_embed.0.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_animatediff(self, state_dict):
|
||||
param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def load_stable_diffusion(self, state_dict, components=None, file_path=""):
|
||||
component_dict = {
|
||||
"text_encoder": SDTextEncoder,
|
||||
"unet": SDUNet,
|
||||
@@ -36,18 +59,30 @@ class ModelManager:
|
||||
if components is None:
|
||||
components = ["text_encoder", "unet", "vae_decoder", "vae_encoder"]
|
||||
for component in components:
|
||||
self.model[component] = component_dict[component]()
|
||||
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
||||
self.model[component].to(self.torch_type).to(self.device)
|
||||
if component == "text_encoder":
|
||||
# Add additional token embeddings to text encoder
|
||||
token_embeddings = [state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"]]
|
||||
for keyword in self.textual_inversion_dict:
|
||||
_, embeddings = self.textual_inversion_dict[keyword]
|
||||
token_embeddings.append(embeddings.to(dtype=token_embeddings[0].dtype))
|
||||
token_embeddings = torch.concat(token_embeddings, dim=0)
|
||||
state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"] = token_embeddings
|
||||
self.model[component] = component_dict[component](vocab_size=token_embeddings.shape[0])
|
||||
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
||||
self.model[component].to(self.torch_dtype).to(self.device)
|
||||
else:
|
||||
self.model[component] = component_dict[component]()
|
||||
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
||||
self.model[component].to(self.torch_dtype).to(self.device)
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_stable_diffusion_xl(self, state_dict, components=None):
|
||||
def load_stable_diffusion_xl(self, state_dict, components=None, file_path=""):
|
||||
component_dict = {
|
||||
"text_encoder": SDXLTextEncoder,
|
||||
"text_encoder_2": SDXLTextEncoder2,
|
||||
"unet": SDXLUNet,
|
||||
"vae_decoder": SDXLVAEDecoder,
|
||||
"vae_encoder": SDXLVAEEncoder,
|
||||
"refiner": SDXLUNet,
|
||||
}
|
||||
if components is None:
|
||||
components = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder"]
|
||||
@@ -60,18 +95,97 @@ class ModelManager:
|
||||
# I do not know how to solve this problem.
|
||||
self.model[component].to(torch.float32).to(self.device)
|
||||
else:
|
||||
self.model[component].to(self.torch_type).to(self.device)
|
||||
|
||||
def load_from_safetensors(self, file_path, components=None):
|
||||
state_dict = load_state_dict_from_safetensors(file_path)
|
||||
if self.is_stabe_diffusion_xl(state_dict):
|
||||
self.load_stable_diffusion_xl(state_dict, components=components)
|
||||
elif self.is_stable_diffusion(state_dict):
|
||||
self.load_stable_diffusion(state_dict, components=components)
|
||||
self.model[component].to(self.torch_dtype).to(self.device)
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_controlnet(self, state_dict, file_path=""):
|
||||
component = "controlnet"
|
||||
if component not in self.model:
|
||||
self.model[component] = []
|
||||
self.model_path[component] = []
|
||||
model = SDControlNet()
|
||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component].append(model)
|
||||
self.model_path[component].append(file_path)
|
||||
|
||||
def load_animatediff(self, state_dict, file_path=""):
|
||||
component = "motion_modules"
|
||||
model = SDMotionModel()
|
||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_beautiful_prompt(self, state_dict, file_path=""):
|
||||
component = "beautiful_prompt"
|
||||
model_folder = os.path.dirname(file_path)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_folder, state_dict=state_dict, local_files_only=True, torch_dtype=self.torch_dtype
|
||||
).to(self.device).eval()
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def search_for_embeddings(self, state_dict):
|
||||
embeddings = []
|
||||
for k in state_dict:
|
||||
if isinstance(state_dict[k], torch.Tensor):
|
||||
embeddings.append(state_dict[k])
|
||||
elif isinstance(state_dict[k], dict):
|
||||
embeddings += self.search_for_embeddings(state_dict[k])
|
||||
return embeddings
|
||||
|
||||
def load_textual_inversions(self, folder):
|
||||
# Store additional tokens here
|
||||
self.textual_inversion_dict = {}
|
||||
|
||||
# Load every textual inversion file
|
||||
for file_name in os.listdir(folder):
|
||||
keyword = os.path.splitext(file_name)[0]
|
||||
state_dict = load_state_dict(os.path.join(folder, file_name))
|
||||
|
||||
# Search for embeddings
|
||||
for embeddings in self.search_for_embeddings(state_dict):
|
||||
if len(embeddings.shape) == 2 and embeddings.shape[1] == 768:
|
||||
tokens = [f"{keyword}_{i}" for i in range(embeddings.shape[0])]
|
||||
self.textual_inversion_dict[keyword] = (tokens, embeddings)
|
||||
break
|
||||
|
||||
def load_model(self, file_path, components=None):
|
||||
state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype)
|
||||
if self.is_animatediff(state_dict):
|
||||
self.load_animatediff(state_dict, file_path=file_path)
|
||||
elif self.is_controlnet(state_dict):
|
||||
self.load_controlnet(state_dict, file_path=file_path)
|
||||
elif self.is_stabe_diffusion_xl(state_dict):
|
||||
self.load_stable_diffusion_xl(state_dict, components=components, file_path=file_path)
|
||||
elif self.is_stable_diffusion(state_dict):
|
||||
self.load_stable_diffusion(state_dict, components=components, file_path=file_path)
|
||||
elif self.is_beautiful_prompt(state_dict):
|
||||
self.load_beautiful_prompt(state_dict, file_path=file_path)
|
||||
|
||||
def load_models(self, file_path_list):
|
||||
for file_path in file_path_list:
|
||||
self.load_model(file_path)
|
||||
|
||||
def to(self, device):
|
||||
for component in self.model:
|
||||
self.model[component].to(device)
|
||||
if isinstance(self.model[component], list):
|
||||
for model in self.model[component]:
|
||||
model.to(device)
|
||||
else:
|
||||
self.model[component].to(device)
|
||||
|
||||
def get_model_with_model_path(self, model_path):
|
||||
for component in self.model_path:
|
||||
if isinstance(self.model_path[component], str):
|
||||
if os.path.samefile(self.model_path[component], model_path):
|
||||
return self.model[component]
|
||||
elif isinstance(self.model_path[component], list):
|
||||
for i, model_path_ in enumerate(self.model_path[component]):
|
||||
if os.path.samefile(model_path_, model_path):
|
||||
return self.model[component][i]
|
||||
raise ValueError(f"Please load model {model_path} before you use it.")
|
||||
|
||||
def __getattr__(self, __name):
|
||||
if __name in self.model:
|
||||
@@ -80,16 +194,28 @@ class ModelManager:
|
||||
return super.__getattribute__(__name)
|
||||
|
||||
|
||||
def load_state_dict_from_safetensors(file_path):
|
||||
def load_state_dict(file_path, torch_dtype=None):
|
||||
if file_path.endswith(".safetensors"):
|
||||
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
|
||||
else:
|
||||
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
|
||||
|
||||
|
||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
|
||||
state_dict = {}
|
||||
with safe_open(file_path, framework="pt", device="cpu") as f:
|
||||
for k in f.keys():
|
||||
state_dict[k] = f.get_tensor(k)
|
||||
if torch_dtype is not None:
|
||||
state_dict[k] = state_dict[k].to(torch_dtype)
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_state_dict_from_bin(file_path):
|
||||
return torch.load(file_path, map_location="cpu")
|
||||
def load_state_dict_from_bin(file_path, torch_dtype=None):
|
||||
state_dict = torch.load(file_path, map_location="cpu")
|
||||
if torch_dtype is not None:
|
||||
state_dict = {i: state_dict[i].to(torch_dtype) for i in state_dict}
|
||||
return state_dict
|
||||
|
||||
|
||||
def search_parameter(param, state_dict):
|
||||
|
||||
@@ -1,4 +1,15 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
def low_version_attention(query, key, value, attn_bias=None):
|
||||
scale = 1 / query.shape[-1] ** 0.5
|
||||
query = query * scale
|
||||
attn = torch.matmul(query, key.transpose(-2, -1))
|
||||
if attn_bias is not None:
|
||||
attn = attn + attn_bias
|
||||
attn = attn.softmax(-1)
|
||||
return attn @ value
|
||||
|
||||
|
||||
class Attention(torch.nn.Module):
|
||||
@@ -15,7 +26,7 @@ class Attention(torch.nn.Module):
|
||||
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
|
||||
def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
@@ -30,9 +41,36 @@ class Attention(torch.nn.Module):
|
||||
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
hidden_states = hidden_states.transpose(1, 2).view(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
|
||||
hidden_states = self.to_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
q = self.to_q(hidden_states)
|
||||
k = self.to_k(encoder_hidden_states)
|
||||
v = self.to_v(encoder_hidden_states)
|
||||
|
||||
q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
|
||||
k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
|
||||
v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)
|
||||
|
||||
if attn_mask is not None:
|
||||
hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
|
||||
else:
|
||||
import xformers.ops as xops
|
||||
hidden_states = xops.memory_efficient_attention(q, k, v)
|
||||
hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)
|
||||
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
hidden_states = self.to_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
|
||||
return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask)
|
||||
584
diffsynth/models/sd_controlnet.py
Normal file
584
diffsynth/models/sd_controlnet.py
Normal file
@@ -0,0 +1,584 @@
|
||||
import torch
|
||||
from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler
|
||||
from .tiler import TileWorker
|
||||
|
||||
|
||||
class ControlNetConditioningLayer(torch.nn.Module):
|
||||
def __init__(self, channels = (3, 16, 32, 96, 256, 320)):
|
||||
super().__init__()
|
||||
self.blocks = torch.nn.ModuleList([])
|
||||
self.blocks.append(torch.nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1))
|
||||
self.blocks.append(torch.nn.SiLU())
|
||||
for i in range(1, len(channels) - 2):
|
||||
self.blocks.append(torch.nn.Conv2d(channels[i], channels[i], kernel_size=3, padding=1))
|
||||
self.blocks.append(torch.nn.SiLU())
|
||||
self.blocks.append(torch.nn.Conv2d(channels[i], channels[i+1], kernel_size=3, padding=1, stride=2))
|
||||
self.blocks.append(torch.nn.SiLU())
|
||||
self.blocks.append(torch.nn.Conv2d(channels[-2], channels[-1], kernel_size=3, padding=1))
|
||||
|
||||
def forward(self, conditioning):
|
||||
for block in self.blocks:
|
||||
conditioning = block(conditioning)
|
||||
return conditioning
|
||||
|
||||
|
||||
class SDControlNet(torch.nn.Module):
|
||||
def __init__(self, global_pool=False):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(320)
|
||||
self.time_embedding = torch.nn.Sequential(
|
||||
torch.nn.Linear(320, 1280),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(1280, 1280)
|
||||
)
|
||||
self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
|
||||
|
||||
self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320))
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
# CrossAttnDownBlock2D
|
||||
ResnetBlock(320, 320, 1280),
|
||||
AttentionBlock(8, 40, 320, 1, 768),
|
||||
PushBlock(),
|
||||
ResnetBlock(320, 320, 1280),
|
||||
AttentionBlock(8, 40, 320, 1, 768),
|
||||
PushBlock(),
|
||||
DownSampler(320),
|
||||
PushBlock(),
|
||||
# CrossAttnDownBlock2D
|
||||
ResnetBlock(320, 640, 1280),
|
||||
AttentionBlock(8, 80, 640, 1, 768),
|
||||
PushBlock(),
|
||||
ResnetBlock(640, 640, 1280),
|
||||
AttentionBlock(8, 80, 640, 1, 768),
|
||||
PushBlock(),
|
||||
DownSampler(640),
|
||||
PushBlock(),
|
||||
# CrossAttnDownBlock2D
|
||||
ResnetBlock(640, 1280, 1280),
|
||||
AttentionBlock(8, 160, 1280, 1, 768),
|
||||
PushBlock(),
|
||||
ResnetBlock(1280, 1280, 1280),
|
||||
AttentionBlock(8, 160, 1280, 1, 768),
|
||||
PushBlock(),
|
||||
DownSampler(1280),
|
||||
PushBlock(),
|
||||
# DownBlock2D
|
||||
ResnetBlock(1280, 1280, 1280),
|
||||
PushBlock(),
|
||||
ResnetBlock(1280, 1280, 1280),
|
||||
PushBlock(),
|
||||
# UNetMidBlock2DCrossAttn
|
||||
ResnetBlock(1280, 1280, 1280),
|
||||
AttentionBlock(8, 160, 1280, 1, 768),
|
||||
ResnetBlock(1280, 1280, 1280),
|
||||
PushBlock()
|
||||
])
|
||||
|
||||
self.controlnet_blocks = torch.nn.ModuleList([
|
||||
torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
|
||||
torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
|
||||
torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
|
||||
torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
|
||||
torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
|
||||
torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
|
||||
torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
|
||||
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
|
||||
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
|
||||
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
|
||||
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
|
||||
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
|
||||
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
|
||||
])
|
||||
|
||||
self.global_pool = global_pool
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample, timestep, encoder_hidden_states, conditioning,
|
||||
tiled=False, tile_size=64, tile_stride=32,
|
||||
):
|
||||
# 1. time
|
||||
time_emb = self.time_proj(timestep[None]).to(sample.dtype)
|
||||
time_emb = self.time_embedding(time_emb)
|
||||
time_emb = time_emb.repeat(sample.shape[0], 1)
|
||||
|
||||
# 2. pre-process
|
||||
height, width = sample.shape[2], sample.shape[3]
|
||||
hidden_states = self.conv_in(sample) + self.controlnet_conv_in(conditioning)
|
||||
text_emb = encoder_hidden_states
|
||||
res_stack = [hidden_states]
|
||||
|
||||
# 3. blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
if tiled and not isinstance(block, PushBlock):
|
||||
_, _, inter_height, _ = hidden_states.shape
|
||||
resize_scale = inter_height / height
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: block(x, time_emb, text_emb, res_stack)[0],
|
||||
hidden_states,
|
||||
int(tile_size * resize_scale),
|
||||
int(tile_stride * resize_scale),
|
||||
tile_device=hidden_states.device,
|
||||
tile_dtype=hidden_states.dtype
|
||||
)
|
||||
else:
|
||||
hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
|
||||
# 4. ControlNet blocks
|
||||
controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]
|
||||
|
||||
# pool
|
||||
if self.global_pool:
|
||||
controlnet_res_stack = [res.mean(dim=(2, 3), keepdim=True) for res in controlnet_res_stack]
|
||||
|
||||
return controlnet_res_stack
|
||||
|
||||
def state_dict_converter(self):
|
||||
return SDControlNetStateDictConverter()
|
||||
|
||||
|
||||
class SDControlNetStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
# architecture
|
||||
block_types = [
|
||||
'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
|
||||
'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
|
||||
'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
|
||||
'ResnetBlock', 'PushBlock', 'ResnetBlock', 'PushBlock',
|
||||
'ResnetBlock', 'AttentionBlock', 'ResnetBlock',
|
||||
'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'UpSampler',
|
||||
'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
|
||||
'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
|
||||
'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock'
|
||||
]
|
||||
|
||||
# controlnet_rename_dict
|
||||
controlnet_rename_dict = {
|
||||
"controlnet_cond_embedding.conv_in.weight": "controlnet_conv_in.blocks.0.weight",
|
||||
"controlnet_cond_embedding.conv_in.bias": "controlnet_conv_in.blocks.0.bias",
|
||||
"controlnet_cond_embedding.blocks.0.weight": "controlnet_conv_in.blocks.2.weight",
|
||||
"controlnet_cond_embedding.blocks.0.bias": "controlnet_conv_in.blocks.2.bias",
|
||||
"controlnet_cond_embedding.blocks.1.weight": "controlnet_conv_in.blocks.4.weight",
|
||||
"controlnet_cond_embedding.blocks.1.bias": "controlnet_conv_in.blocks.4.bias",
|
||||
"controlnet_cond_embedding.blocks.2.weight": "controlnet_conv_in.blocks.6.weight",
|
||||
"controlnet_cond_embedding.blocks.2.bias": "controlnet_conv_in.blocks.6.bias",
|
||||
"controlnet_cond_embedding.blocks.3.weight": "controlnet_conv_in.blocks.8.weight",
|
||||
"controlnet_cond_embedding.blocks.3.bias": "controlnet_conv_in.blocks.8.bias",
|
||||
"controlnet_cond_embedding.blocks.4.weight": "controlnet_conv_in.blocks.10.weight",
|
||||
"controlnet_cond_embedding.blocks.4.bias": "controlnet_conv_in.blocks.10.bias",
|
||||
"controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight",
|
||||
"controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias",
|
||||
"controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight",
|
||||
"controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
|
||||
}
|
||||
|
||||
# Rename each parameter
|
||||
name_list = sorted([name for name in state_dict])
|
||||
rename_dict = {}
|
||||
block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
|
||||
last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
|
||||
for name in name_list:
|
||||
names = name.split(".")
|
||||
if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
|
||||
pass
|
||||
elif name in controlnet_rename_dict:
|
||||
names = controlnet_rename_dict[name].split(".")
|
||||
elif names[0] == "controlnet_down_blocks":
|
||||
names[0] = "controlnet_blocks"
|
||||
elif names[0] == "controlnet_mid_block":
|
||||
names = ["controlnet_blocks", "12", names[-1]]
|
||||
elif names[0] in ["time_embedding", "add_embedding"]:
|
||||
if names[0] == "add_embedding":
|
||||
names[0] = "add_time_embedding"
|
||||
names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
|
||||
elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
|
||||
if names[0] == "mid_block":
|
||||
names.insert(1, "0")
|
||||
block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
|
||||
block_type_with_id = ".".join(names[:4])
|
||||
if block_type_with_id != last_block_type_with_id[block_type]:
|
||||
block_id[block_type] += 1
|
||||
last_block_type_with_id[block_type] = block_type_with_id
|
||||
while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
|
||||
block_id[block_type] += 1
|
||||
block_type_with_id = ".".join(names[:4])
|
||||
names = ["blocks", str(block_id[block_type])] + names[4:]
|
||||
if "ff" in names:
|
||||
ff_index = names.index("ff")
|
||||
component = ".".join(names[ff_index:ff_index+3])
|
||||
component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
|
||||
names = names[:ff_index] + [component] + names[ff_index+3:]
|
||||
if "to_out" in names:
|
||||
names.pop(names.index("to_out") + 1)
|
||||
else:
|
||||
raise ValueError(f"Unknown parameters: {name}")
|
||||
rename_dict[name] = ".".join(names)
|
||||
|
||||
# Convert state_dict
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if ".proj_in." in name or ".proj_out." in name:
|
||||
param = param.squeeze()
|
||||
if rename_dict[name] in [
|
||||
"controlnet_blocks.1.bias", "controlnet_blocks.2.bias", "controlnet_blocks.3.bias", "controlnet_blocks.5.bias", "controlnet_blocks.6.bias",
|
||||
"controlnet_blocks.8.bias", "controlnet_blocks.9.bias", "controlnet_blocks.10.bias", "controlnet_blocks.11.bias", "controlnet_blocks.12.bias"
|
||||
]:
|
||||
continue
|
||||
state_dict_[rename_dict[name]] = param
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
"control_model.time_embed.0.weight": "time_embedding.0.weight",
|
||||
"control_model.time_embed.0.bias": "time_embedding.0.bias",
|
||||
"control_model.time_embed.2.weight": "time_embedding.2.weight",
|
||||
"control_model.time_embed.2.bias": "time_embedding.2.bias",
|
||||
"control_model.input_blocks.0.0.weight": "conv_in.weight",
|
||||
"control_model.input_blocks.0.0.bias": "conv_in.bias",
|
||||
"control_model.input_blocks.1.0.in_layers.0.weight": "blocks.0.norm1.weight",
|
||||
"control_model.input_blocks.1.0.in_layers.0.bias": "blocks.0.norm1.bias",
|
||||
"control_model.input_blocks.1.0.in_layers.2.weight": "blocks.0.conv1.weight",
|
||||
"control_model.input_blocks.1.0.in_layers.2.bias": "blocks.0.conv1.bias",
|
||||
"control_model.input_blocks.1.0.emb_layers.1.weight": "blocks.0.time_emb_proj.weight",
|
||||
"control_model.input_blocks.1.0.emb_layers.1.bias": "blocks.0.time_emb_proj.bias",
|
||||
"control_model.input_blocks.1.0.out_layers.0.weight": "blocks.0.norm2.weight",
|
||||
"control_model.input_blocks.1.0.out_layers.0.bias": "blocks.0.norm2.bias",
|
||||
"control_model.input_blocks.1.0.out_layers.3.weight": "blocks.0.conv2.weight",
|
||||
"control_model.input_blocks.1.0.out_layers.3.bias": "blocks.0.conv2.bias",
|
||||
"control_model.input_blocks.1.1.norm.weight": "blocks.1.norm.weight",
|
||||
"control_model.input_blocks.1.1.norm.bias": "blocks.1.norm.bias",
|
||||
"control_model.input_blocks.1.1.proj_in.weight": "blocks.1.proj_in.weight",
|
||||
"control_model.input_blocks.1.1.proj_in.bias": "blocks.1.proj_in.bias",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "blocks.1.transformer_blocks.0.attn1.to_q.weight",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "blocks.1.transformer_blocks.0.attn1.to_k.weight",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "blocks.1.transformer_blocks.0.attn1.to_v.weight",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.1.transformer_blocks.0.attn1.to_out.weight",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.1.transformer_blocks.0.attn1.to_out.bias",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.1.transformer_blocks.0.act_fn.proj.weight",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.1.transformer_blocks.0.act_fn.proj.bias",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "blocks.1.transformer_blocks.0.ff.weight",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "blocks.1.transformer_blocks.0.ff.bias",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "blocks.1.transformer_blocks.0.attn2.to_q.weight",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "blocks.1.transformer_blocks.0.attn2.to_k.weight",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "blocks.1.transformer_blocks.0.attn2.to_v.weight",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.1.transformer_blocks.0.attn2.to_out.weight",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.1.transformer_blocks.0.attn2.to_out.bias",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.norm1.weight": "blocks.1.transformer_blocks.0.norm1.weight",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.norm1.bias": "blocks.1.transformer_blocks.0.norm1.bias",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.norm2.weight": "blocks.1.transformer_blocks.0.norm2.weight",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.norm2.bias": "blocks.1.transformer_blocks.0.norm2.bias",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.norm3.weight": "blocks.1.transformer_blocks.0.norm3.weight",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.norm3.bias": "blocks.1.transformer_blocks.0.norm3.bias",
|
||||
"control_model.input_blocks.1.1.proj_out.weight": "blocks.1.proj_out.weight",
|
||||
"control_model.input_blocks.1.1.proj_out.bias": "blocks.1.proj_out.bias",
|
||||
"control_model.input_blocks.2.0.in_layers.0.weight": "blocks.3.norm1.weight",
|
||||
"control_model.input_blocks.2.0.in_layers.0.bias": "blocks.3.norm1.bias",
|
||||
"control_model.input_blocks.2.0.in_layers.2.weight": "blocks.3.conv1.weight",
|
||||
"control_model.input_blocks.2.0.in_layers.2.bias": "blocks.3.conv1.bias",
|
||||
"control_model.input_blocks.2.0.emb_layers.1.weight": "blocks.3.time_emb_proj.weight",
|
||||
"control_model.input_blocks.2.0.emb_layers.1.bias": "blocks.3.time_emb_proj.bias",
|
||||
"control_model.input_blocks.2.0.out_layers.0.weight": "blocks.3.norm2.weight",
|
||||
"control_model.input_blocks.2.0.out_layers.0.bias": "blocks.3.norm2.bias",
|
||||
"control_model.input_blocks.2.0.out_layers.3.weight": "blocks.3.conv2.weight",
|
||||
"control_model.input_blocks.2.0.out_layers.3.bias": "blocks.3.conv2.bias",
|
||||
"control_model.input_blocks.2.1.norm.weight": "blocks.4.norm.weight",
|
||||
"control_model.input_blocks.2.1.norm.bias": "blocks.4.norm.bias",
|
||||
"control_model.input_blocks.2.1.proj_in.weight": "blocks.4.proj_in.weight",
|
||||
"control_model.input_blocks.2.1.proj_in.bias": "blocks.4.proj_in.bias",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "blocks.4.transformer_blocks.0.attn1.to_q.weight",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "blocks.4.transformer_blocks.0.attn1.to_k.weight",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "blocks.4.transformer_blocks.0.attn1.to_v.weight",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.4.transformer_blocks.0.attn1.to_out.weight",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.4.transformer_blocks.0.attn1.to_out.bias",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.4.transformer_blocks.0.act_fn.proj.weight",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.4.transformer_blocks.0.act_fn.proj.bias",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "blocks.4.transformer_blocks.0.ff.weight",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "blocks.4.transformer_blocks.0.ff.bias",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "blocks.4.transformer_blocks.0.attn2.to_q.weight",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "blocks.4.transformer_blocks.0.attn2.to_k.weight",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "blocks.4.transformer_blocks.0.attn2.to_v.weight",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.4.transformer_blocks.0.attn2.to_out.weight",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.4.transformer_blocks.0.attn2.to_out.bias",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.norm1.weight": "blocks.4.transformer_blocks.0.norm1.weight",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.norm1.bias": "blocks.4.transformer_blocks.0.norm1.bias",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.norm2.weight": "blocks.4.transformer_blocks.0.norm2.weight",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.norm2.bias": "blocks.4.transformer_blocks.0.norm2.bias",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.norm3.weight": "blocks.4.transformer_blocks.0.norm3.weight",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.norm3.bias": "blocks.4.transformer_blocks.0.norm3.bias",
|
||||
"control_model.input_blocks.2.1.proj_out.weight": "blocks.4.proj_out.weight",
|
||||
"control_model.input_blocks.2.1.proj_out.bias": "blocks.4.proj_out.bias",
|
||||
"control_model.input_blocks.3.0.op.weight": "blocks.6.conv.weight",
|
||||
"control_model.input_blocks.3.0.op.bias": "blocks.6.conv.bias",
|
||||
"control_model.input_blocks.4.0.in_layers.0.weight": "blocks.8.norm1.weight",
|
||||
"control_model.input_blocks.4.0.in_layers.0.bias": "blocks.8.norm1.bias",
|
||||
"control_model.input_blocks.4.0.in_layers.2.weight": "blocks.8.conv1.weight",
|
||||
"control_model.input_blocks.4.0.in_layers.2.bias": "blocks.8.conv1.bias",
|
||||
"control_model.input_blocks.4.0.emb_layers.1.weight": "blocks.8.time_emb_proj.weight",
|
||||
"control_model.input_blocks.4.0.emb_layers.1.bias": "blocks.8.time_emb_proj.bias",
|
||||
"control_model.input_blocks.4.0.out_layers.0.weight": "blocks.8.norm2.weight",
|
||||
"control_model.input_blocks.4.0.out_layers.0.bias": "blocks.8.norm2.bias",
|
||||
"control_model.input_blocks.4.0.out_layers.3.weight": "blocks.8.conv2.weight",
|
||||
"control_model.input_blocks.4.0.out_layers.3.bias": "blocks.8.conv2.bias",
|
||||
"control_model.input_blocks.4.0.skip_connection.weight": "blocks.8.conv_shortcut.weight",
|
||||
"control_model.input_blocks.4.0.skip_connection.bias": "blocks.8.conv_shortcut.bias",
|
||||
"control_model.input_blocks.4.1.norm.weight": "blocks.9.norm.weight",
|
||||
"control_model.input_blocks.4.1.norm.bias": "blocks.9.norm.bias",
|
||||
"control_model.input_blocks.4.1.proj_in.weight": "blocks.9.proj_in.weight",
|
||||
"control_model.input_blocks.4.1.proj_in.bias": "blocks.9.proj_in.bias",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "blocks.9.transformer_blocks.0.attn1.to_q.weight",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "blocks.9.transformer_blocks.0.attn1.to_k.weight",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "blocks.9.transformer_blocks.0.attn1.to_v.weight",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.9.transformer_blocks.0.attn1.to_out.weight",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.9.transformer_blocks.0.attn1.to_out.bias",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.9.transformer_blocks.0.act_fn.proj.weight",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.9.transformer_blocks.0.act_fn.proj.bias",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "blocks.9.transformer_blocks.0.ff.weight",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "blocks.9.transformer_blocks.0.ff.bias",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "blocks.9.transformer_blocks.0.attn2.to_q.weight",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "blocks.9.transformer_blocks.0.attn2.to_k.weight",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "blocks.9.transformer_blocks.0.attn2.to_v.weight",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.9.transformer_blocks.0.attn2.to_out.weight",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.9.transformer_blocks.0.attn2.to_out.bias",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "blocks.9.transformer_blocks.0.norm1.weight",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "blocks.9.transformer_blocks.0.norm1.bias",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "blocks.9.transformer_blocks.0.norm2.weight",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "blocks.9.transformer_blocks.0.norm2.bias",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "blocks.9.transformer_blocks.0.norm3.weight",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "blocks.9.transformer_blocks.0.norm3.bias",
|
||||
"control_model.input_blocks.4.1.proj_out.weight": "blocks.9.proj_out.weight",
|
||||
"control_model.input_blocks.4.1.proj_out.bias": "blocks.9.proj_out.bias",
|
||||
"control_model.input_blocks.5.0.in_layers.0.weight": "blocks.11.norm1.weight",
|
||||
"control_model.input_blocks.5.0.in_layers.0.bias": "blocks.11.norm1.bias",
|
||||
"control_model.input_blocks.5.0.in_layers.2.weight": "blocks.11.conv1.weight",
|
||||
"control_model.input_blocks.5.0.in_layers.2.bias": "blocks.11.conv1.bias",
|
||||
"control_model.input_blocks.5.0.emb_layers.1.weight": "blocks.11.time_emb_proj.weight",
|
||||
"control_model.input_blocks.5.0.emb_layers.1.bias": "blocks.11.time_emb_proj.bias",
|
||||
"control_model.input_blocks.5.0.out_layers.0.weight": "blocks.11.norm2.weight",
|
||||
"control_model.input_blocks.5.0.out_layers.0.bias": "blocks.11.norm2.bias",
|
||||
"control_model.input_blocks.5.0.out_layers.3.weight": "blocks.11.conv2.weight",
|
||||
"control_model.input_blocks.5.0.out_layers.3.bias": "blocks.11.conv2.bias",
|
||||
"control_model.input_blocks.5.1.norm.weight": "blocks.12.norm.weight",
|
||||
"control_model.input_blocks.5.1.norm.bias": "blocks.12.norm.bias",
|
||||
"control_model.input_blocks.5.1.proj_in.weight": "blocks.12.proj_in.weight",
|
||||
"control_model.input_blocks.5.1.proj_in.bias": "blocks.12.proj_in.bias",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "blocks.12.transformer_blocks.0.attn1.to_q.weight",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "blocks.12.transformer_blocks.0.attn1.to_k.weight",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "blocks.12.transformer_blocks.0.attn1.to_v.weight",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.12.transformer_blocks.0.attn1.to_out.weight",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.12.transformer_blocks.0.attn1.to_out.bias",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.12.transformer_blocks.0.act_fn.proj.weight",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.12.transformer_blocks.0.act_fn.proj.bias",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "blocks.12.transformer_blocks.0.ff.weight",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "blocks.12.transformer_blocks.0.ff.bias",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "blocks.12.transformer_blocks.0.attn2.to_q.weight",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "blocks.12.transformer_blocks.0.attn2.to_k.weight",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "blocks.12.transformer_blocks.0.attn2.to_v.weight",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.12.transformer_blocks.0.attn2.to_out.weight",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.12.transformer_blocks.0.attn2.to_out.bias",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "blocks.12.transformer_blocks.0.norm1.weight",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "blocks.12.transformer_blocks.0.norm1.bias",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "blocks.12.transformer_blocks.0.norm2.weight",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "blocks.12.transformer_blocks.0.norm2.bias",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "blocks.12.transformer_blocks.0.norm3.weight",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "blocks.12.transformer_blocks.0.norm3.bias",
|
||||
"control_model.input_blocks.5.1.proj_out.weight": "blocks.12.proj_out.weight",
|
||||
"control_model.input_blocks.5.1.proj_out.bias": "blocks.12.proj_out.bias",
|
||||
"control_model.input_blocks.6.0.op.weight": "blocks.14.conv.weight",
|
||||
"control_model.input_blocks.6.0.op.bias": "blocks.14.conv.bias",
|
||||
"control_model.input_blocks.7.0.in_layers.0.weight": "blocks.16.norm1.weight",
|
||||
"control_model.input_blocks.7.0.in_layers.0.bias": "blocks.16.norm1.bias",
|
||||
"control_model.input_blocks.7.0.in_layers.2.weight": "blocks.16.conv1.weight",
|
||||
"control_model.input_blocks.7.0.in_layers.2.bias": "blocks.16.conv1.bias",
|
||||
"control_model.input_blocks.7.0.emb_layers.1.weight": "blocks.16.time_emb_proj.weight",
|
||||
"control_model.input_blocks.7.0.emb_layers.1.bias": "blocks.16.time_emb_proj.bias",
|
||||
"control_model.input_blocks.7.0.out_layers.0.weight": "blocks.16.norm2.weight",
|
||||
"control_model.input_blocks.7.0.out_layers.0.bias": "blocks.16.norm2.bias",
|
||||
"control_model.input_blocks.7.0.out_layers.3.weight": "blocks.16.conv2.weight",
|
||||
"control_model.input_blocks.7.0.out_layers.3.bias": "blocks.16.conv2.bias",
|
||||
"control_model.input_blocks.7.0.skip_connection.weight": "blocks.16.conv_shortcut.weight",
|
||||
"control_model.input_blocks.7.0.skip_connection.bias": "blocks.16.conv_shortcut.bias",
|
||||
"control_model.input_blocks.7.1.norm.weight": "blocks.17.norm.weight",
|
||||
"control_model.input_blocks.7.1.norm.bias": "blocks.17.norm.bias",
|
||||
"control_model.input_blocks.7.1.proj_in.weight": "blocks.17.proj_in.weight",
|
||||
"control_model.input_blocks.7.1.proj_in.bias": "blocks.17.proj_in.bias",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "blocks.17.transformer_blocks.0.attn1.to_q.weight",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "blocks.17.transformer_blocks.0.attn1.to_k.weight",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "blocks.17.transformer_blocks.0.attn1.to_v.weight",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.17.transformer_blocks.0.attn1.to_out.weight",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.17.transformer_blocks.0.attn1.to_out.bias",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.17.transformer_blocks.0.act_fn.proj.weight",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.17.transformer_blocks.0.act_fn.proj.bias",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "blocks.17.transformer_blocks.0.ff.weight",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "blocks.17.transformer_blocks.0.ff.bias",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "blocks.17.transformer_blocks.0.attn2.to_q.weight",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "blocks.17.transformer_blocks.0.attn2.to_k.weight",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "blocks.17.transformer_blocks.0.attn2.to_v.weight",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.17.transformer_blocks.0.attn2.to_out.weight",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.17.transformer_blocks.0.attn2.to_out.bias",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "blocks.17.transformer_blocks.0.norm1.weight",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "blocks.17.transformer_blocks.0.norm1.bias",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "blocks.17.transformer_blocks.0.norm2.weight",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "blocks.17.transformer_blocks.0.norm2.bias",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "blocks.17.transformer_blocks.0.norm3.weight",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "blocks.17.transformer_blocks.0.norm3.bias",
|
||||
"control_model.input_blocks.7.1.proj_out.weight": "blocks.17.proj_out.weight",
|
||||
"control_model.input_blocks.7.1.proj_out.bias": "blocks.17.proj_out.bias",
|
||||
"control_model.input_blocks.8.0.in_layers.0.weight": "blocks.19.norm1.weight",
|
||||
"control_model.input_blocks.8.0.in_layers.0.bias": "blocks.19.norm1.bias",
|
||||
"control_model.input_blocks.8.0.in_layers.2.weight": "blocks.19.conv1.weight",
|
||||
"control_model.input_blocks.8.0.in_layers.2.bias": "blocks.19.conv1.bias",
|
||||
"control_model.input_blocks.8.0.emb_layers.1.weight": "blocks.19.time_emb_proj.weight",
|
||||
"control_model.input_blocks.8.0.emb_layers.1.bias": "blocks.19.time_emb_proj.bias",
|
||||
"control_model.input_blocks.8.0.out_layers.0.weight": "blocks.19.norm2.weight",
|
||||
"control_model.input_blocks.8.0.out_layers.0.bias": "blocks.19.norm2.bias",
|
||||
"control_model.input_blocks.8.0.out_layers.3.weight": "blocks.19.conv2.weight",
|
||||
"control_model.input_blocks.8.0.out_layers.3.bias": "blocks.19.conv2.bias",
|
||||
"control_model.input_blocks.8.1.norm.weight": "blocks.20.norm.weight",
|
||||
"control_model.input_blocks.8.1.norm.bias": "blocks.20.norm.bias",
|
||||
"control_model.input_blocks.8.1.proj_in.weight": "blocks.20.proj_in.weight",
|
||||
"control_model.input_blocks.8.1.proj_in.bias": "blocks.20.proj_in.bias",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "blocks.20.transformer_blocks.0.attn1.to_q.weight",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "blocks.20.transformer_blocks.0.attn1.to_k.weight",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "blocks.20.transformer_blocks.0.attn1.to_v.weight",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.20.transformer_blocks.0.attn1.to_out.weight",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.20.transformer_blocks.0.attn1.to_out.bias",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.20.transformer_blocks.0.act_fn.proj.weight",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.20.transformer_blocks.0.act_fn.proj.bias",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "blocks.20.transformer_blocks.0.ff.weight",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "blocks.20.transformer_blocks.0.ff.bias",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "blocks.20.transformer_blocks.0.attn2.to_q.weight",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "blocks.20.transformer_blocks.0.attn2.to_k.weight",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "blocks.20.transformer_blocks.0.attn2.to_v.weight",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.20.transformer_blocks.0.attn2.to_out.weight",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.20.transformer_blocks.0.attn2.to_out.bias",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "blocks.20.transformer_blocks.0.norm1.weight",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "blocks.20.transformer_blocks.0.norm1.bias",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "blocks.20.transformer_blocks.0.norm2.weight",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "blocks.20.transformer_blocks.0.norm2.bias",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "blocks.20.transformer_blocks.0.norm3.weight",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "blocks.20.transformer_blocks.0.norm3.bias",
|
||||
"control_model.input_blocks.8.1.proj_out.weight": "blocks.20.proj_out.weight",
|
||||
"control_model.input_blocks.8.1.proj_out.bias": "blocks.20.proj_out.bias",
|
||||
"control_model.input_blocks.9.0.op.weight": "blocks.22.conv.weight",
|
||||
"control_model.input_blocks.9.0.op.bias": "blocks.22.conv.bias",
|
||||
"control_model.input_blocks.10.0.in_layers.0.weight": "blocks.24.norm1.weight",
|
||||
"control_model.input_blocks.10.0.in_layers.0.bias": "blocks.24.norm1.bias",
|
||||
"control_model.input_blocks.10.0.in_layers.2.weight": "blocks.24.conv1.weight",
|
||||
"control_model.input_blocks.10.0.in_layers.2.bias": "blocks.24.conv1.bias",
|
||||
"control_model.input_blocks.10.0.emb_layers.1.weight": "blocks.24.time_emb_proj.weight",
|
||||
"control_model.input_blocks.10.0.emb_layers.1.bias": "blocks.24.time_emb_proj.bias",
|
||||
"control_model.input_blocks.10.0.out_layers.0.weight": "blocks.24.norm2.weight",
|
||||
"control_model.input_blocks.10.0.out_layers.0.bias": "blocks.24.norm2.bias",
|
||||
"control_model.input_blocks.10.0.out_layers.3.weight": "blocks.24.conv2.weight",
|
||||
"control_model.input_blocks.10.0.out_layers.3.bias": "blocks.24.conv2.bias",
|
||||
"control_model.input_blocks.11.0.in_layers.0.weight": "blocks.26.norm1.weight",
|
||||
"control_model.input_blocks.11.0.in_layers.0.bias": "blocks.26.norm1.bias",
|
||||
"control_model.input_blocks.11.0.in_layers.2.weight": "blocks.26.conv1.weight",
|
||||
"control_model.input_blocks.11.0.in_layers.2.bias": "blocks.26.conv1.bias",
|
||||
"control_model.input_blocks.11.0.emb_layers.1.weight": "blocks.26.time_emb_proj.weight",
|
||||
"control_model.input_blocks.11.0.emb_layers.1.bias": "blocks.26.time_emb_proj.bias",
|
||||
"control_model.input_blocks.11.0.out_layers.0.weight": "blocks.26.norm2.weight",
|
||||
"control_model.input_blocks.11.0.out_layers.0.bias": "blocks.26.norm2.bias",
|
||||
"control_model.input_blocks.11.0.out_layers.3.weight": "blocks.26.conv2.weight",
|
||||
"control_model.input_blocks.11.0.out_layers.3.bias": "blocks.26.conv2.bias",
|
||||
"control_model.zero_convs.0.0.weight": "controlnet_blocks.0.weight",
|
||||
"control_model.zero_convs.0.0.bias": "controlnet_blocks.0.bias",
|
||||
"control_model.zero_convs.1.0.weight": "controlnet_blocks.1.weight",
|
||||
"control_model.zero_convs.1.0.bias": "controlnet_blocks.0.bias",
|
||||
"control_model.zero_convs.2.0.weight": "controlnet_blocks.2.weight",
|
||||
"control_model.zero_convs.2.0.bias": "controlnet_blocks.0.bias",
|
||||
"control_model.zero_convs.3.0.weight": "controlnet_blocks.3.weight",
|
||||
"control_model.zero_convs.3.0.bias": "controlnet_blocks.0.bias",
|
||||
"control_model.zero_convs.4.0.weight": "controlnet_blocks.4.weight",
|
||||
"control_model.zero_convs.4.0.bias": "controlnet_blocks.4.bias",
|
||||
"control_model.zero_convs.5.0.weight": "controlnet_blocks.5.weight",
|
||||
"control_model.zero_convs.5.0.bias": "controlnet_blocks.4.bias",
|
||||
"control_model.zero_convs.6.0.weight": "controlnet_blocks.6.weight",
|
||||
"control_model.zero_convs.6.0.bias": "controlnet_blocks.4.bias",
|
||||
"control_model.zero_convs.7.0.weight": "controlnet_blocks.7.weight",
|
||||
"control_model.zero_convs.7.0.bias": "controlnet_blocks.7.bias",
|
||||
"control_model.zero_convs.8.0.weight": "controlnet_blocks.8.weight",
|
||||
"control_model.zero_convs.8.0.bias": "controlnet_blocks.7.bias",
|
||||
"control_model.zero_convs.9.0.weight": "controlnet_blocks.9.weight",
|
||||
"control_model.zero_convs.9.0.bias": "controlnet_blocks.7.bias",
|
||||
"control_model.zero_convs.10.0.weight": "controlnet_blocks.10.weight",
|
||||
"control_model.zero_convs.10.0.bias": "controlnet_blocks.7.bias",
|
||||
"control_model.zero_convs.11.0.weight": "controlnet_blocks.11.weight",
|
||||
"control_model.zero_convs.11.0.bias": "controlnet_blocks.7.bias",
|
||||
"control_model.input_hint_block.0.weight": "controlnet_conv_in.blocks.0.weight",
|
||||
"control_model.input_hint_block.0.bias": "controlnet_conv_in.blocks.0.bias",
|
||||
"control_model.input_hint_block.2.weight": "controlnet_conv_in.blocks.2.weight",
|
||||
"control_model.input_hint_block.2.bias": "controlnet_conv_in.blocks.2.bias",
|
||||
"control_model.input_hint_block.4.weight": "controlnet_conv_in.blocks.4.weight",
|
||||
"control_model.input_hint_block.4.bias": "controlnet_conv_in.blocks.4.bias",
|
||||
"control_model.input_hint_block.6.weight": "controlnet_conv_in.blocks.6.weight",
|
||||
"control_model.input_hint_block.6.bias": "controlnet_conv_in.blocks.6.bias",
|
||||
"control_model.input_hint_block.8.weight": "controlnet_conv_in.blocks.8.weight",
|
||||
"control_model.input_hint_block.8.bias": "controlnet_conv_in.blocks.8.bias",
|
||||
"control_model.input_hint_block.10.weight": "controlnet_conv_in.blocks.10.weight",
|
||||
"control_model.input_hint_block.10.bias": "controlnet_conv_in.blocks.10.bias",
|
||||
"control_model.input_hint_block.12.weight": "controlnet_conv_in.blocks.12.weight",
|
||||
"control_model.input_hint_block.12.bias": "controlnet_conv_in.blocks.12.bias",
|
||||
"control_model.input_hint_block.14.weight": "controlnet_conv_in.blocks.14.weight",
|
||||
"control_model.input_hint_block.14.bias": "controlnet_conv_in.blocks.14.bias",
|
||||
"control_model.middle_block.0.in_layers.0.weight": "blocks.28.norm1.weight",
|
||||
"control_model.middle_block.0.in_layers.0.bias": "blocks.28.norm1.bias",
|
||||
"control_model.middle_block.0.in_layers.2.weight": "blocks.28.conv1.weight",
|
||||
"control_model.middle_block.0.in_layers.2.bias": "blocks.28.conv1.bias",
|
||||
"control_model.middle_block.0.emb_layers.1.weight": "blocks.28.time_emb_proj.weight",
|
||||
"control_model.middle_block.0.emb_layers.1.bias": "blocks.28.time_emb_proj.bias",
|
||||
"control_model.middle_block.0.out_layers.0.weight": "blocks.28.norm2.weight",
|
||||
"control_model.middle_block.0.out_layers.0.bias": "blocks.28.norm2.bias",
|
||||
"control_model.middle_block.0.out_layers.3.weight": "blocks.28.conv2.weight",
|
||||
"control_model.middle_block.0.out_layers.3.bias": "blocks.28.conv2.bias",
|
||||
"control_model.middle_block.1.norm.weight": "blocks.29.norm.weight",
|
||||
"control_model.middle_block.1.norm.bias": "blocks.29.norm.bias",
|
||||
"control_model.middle_block.1.proj_in.weight": "blocks.29.proj_in.weight",
|
||||
"control_model.middle_block.1.proj_in.bias": "blocks.29.proj_in.bias",
|
||||
"control_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "blocks.29.transformer_blocks.0.attn1.to_q.weight",
|
||||
"control_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "blocks.29.transformer_blocks.0.attn1.to_k.weight",
|
||||
"control_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "blocks.29.transformer_blocks.0.attn1.to_v.weight",
|
||||
"control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.29.transformer_blocks.0.attn1.to_out.weight",
|
||||
"control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.29.transformer_blocks.0.attn1.to_out.bias",
|
||||
"control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.29.transformer_blocks.0.act_fn.proj.weight",
|
||||
"control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.29.transformer_blocks.0.act_fn.proj.bias",
|
||||
"control_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "blocks.29.transformer_blocks.0.ff.weight",
|
||||
"control_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "blocks.29.transformer_blocks.0.ff.bias",
|
||||
"control_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "blocks.29.transformer_blocks.0.attn2.to_q.weight",
|
||||
"control_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "blocks.29.transformer_blocks.0.attn2.to_k.weight",
|
||||
"control_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "blocks.29.transformer_blocks.0.attn2.to_v.weight",
|
||||
"control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.29.transformer_blocks.0.attn2.to_out.weight",
|
||||
"control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.29.transformer_blocks.0.attn2.to_out.bias",
|
||||
"control_model.middle_block.1.transformer_blocks.0.norm1.weight": "blocks.29.transformer_blocks.0.norm1.weight",
|
||||
"control_model.middle_block.1.transformer_blocks.0.norm1.bias": "blocks.29.transformer_blocks.0.norm1.bias",
|
||||
"control_model.middle_block.1.transformer_blocks.0.norm2.weight": "blocks.29.transformer_blocks.0.norm2.weight",
|
||||
"control_model.middle_block.1.transformer_blocks.0.norm2.bias": "blocks.29.transformer_blocks.0.norm2.bias",
|
||||
"control_model.middle_block.1.transformer_blocks.0.norm3.weight": "blocks.29.transformer_blocks.0.norm3.weight",
|
||||
"control_model.middle_block.1.transformer_blocks.0.norm3.bias": "blocks.29.transformer_blocks.0.norm3.bias",
|
||||
"control_model.middle_block.1.proj_out.weight": "blocks.29.proj_out.weight",
|
||||
"control_model.middle_block.1.proj_out.bias": "blocks.29.proj_out.bias",
|
||||
"control_model.middle_block.2.in_layers.0.weight": "blocks.30.norm1.weight",
|
||||
"control_model.middle_block.2.in_layers.0.bias": "blocks.30.norm1.bias",
|
||||
"control_model.middle_block.2.in_layers.2.weight": "blocks.30.conv1.weight",
|
||||
"control_model.middle_block.2.in_layers.2.bias": "blocks.30.conv1.bias",
|
||||
"control_model.middle_block.2.emb_layers.1.weight": "blocks.30.time_emb_proj.weight",
|
||||
"control_model.middle_block.2.emb_layers.1.bias": "blocks.30.time_emb_proj.bias",
|
||||
"control_model.middle_block.2.out_layers.0.weight": "blocks.30.norm2.weight",
|
||||
"control_model.middle_block.2.out_layers.0.bias": "blocks.30.norm2.bias",
|
||||
"control_model.middle_block.2.out_layers.3.weight": "blocks.30.conv2.weight",
|
||||
"control_model.middle_block.2.out_layers.3.bias": "blocks.30.conv2.bias",
|
||||
"control_model.middle_block_out.0.weight": "controlnet_blocks.12.weight",
|
||||
"control_model.middle_block_out.0.bias": "controlnet_blocks.7.bias",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if ".proj_in." in name or ".proj_out." in name:
|
||||
param = param.squeeze()
|
||||
state_dict_[rename_dict[name]] = param
|
||||
return state_dict_
|
||||
198
diffsynth/models/sd_motion.py
Normal file
198
diffsynth/models/sd_motion.py
Normal file
@@ -0,0 +1,198 @@
|
||||
from .sd_unet import SDUNet, Attention, GEGLU
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
class TemporalTransformerBlock(torch.nn.Module):
|
||||
|
||||
def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32):
|
||||
super().__init__()
|
||||
|
||||
# 1. Self-Attn
|
||||
self.pe1 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
|
||||
self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True)
|
||||
self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
|
||||
|
||||
# 2. Cross-Attn
|
||||
self.pe2 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
|
||||
self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True)
|
||||
self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
|
||||
|
||||
# 3. Feed-forward
|
||||
self.norm3 = torch.nn.LayerNorm(dim, elementwise_affine=True)
|
||||
self.act_fn = GEGLU(dim, dim * 4)
|
||||
self.ff = torch.nn.Linear(dim * 4, dim)
|
||||
|
||||
|
||||
def forward(self, hidden_states, batch_size=1):
|
||||
|
||||
# 1. Self-Attention
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
|
||||
attn_output = self.attn1(norm_hidden_states + self.pe1[:, :norm_hidden_states.shape[1]])
|
||||
attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 2. Cross-Attention
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
|
||||
attn_output = self.attn2(norm_hidden_states + self.pe2[:, :norm_hidden_states.shape[1]])
|
||||
attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 3. Feed-forward
|
||||
norm_hidden_states = self.norm3(hidden_states)
|
||||
ff_output = self.act_fn(norm_hidden_states)
|
||||
ff_output = self.ff(ff_output)
|
||||
hidden_states = ff_output + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class TemporalBlock(torch.nn.Module):
|
||||
|
||||
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
|
||||
super().__init__()
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
self.proj_in = torch.nn.Linear(in_channels, inner_dim)
|
||||
|
||||
self.transformer_blocks = torch.nn.ModuleList([
|
||||
TemporalTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim
|
||||
)
|
||||
for d in range(num_layers)
|
||||
])
|
||||
|
||||
self.proj_out = torch.nn.Linear(inner_dim, in_channels)
|
||||
|
||||
def forward(self, hidden_states, time_emb, text_emb, res_stack, batch_size=1):
|
||||
batch, _, height, width = hidden_states.shape
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
batch_size=batch_size
|
||||
)
|
||||
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
return hidden_states, time_emb, text_emb, res_stack
|
||||
|
||||
|
||||
class SDMotionModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.motion_modules = torch.nn.ModuleList([
|
||||
TemporalBlock(8, 40, 320, eps=1e-6),
|
||||
TemporalBlock(8, 40, 320, eps=1e-6),
|
||||
TemporalBlock(8, 80, 640, eps=1e-6),
|
||||
TemporalBlock(8, 80, 640, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 80, 640, eps=1e-6),
|
||||
TemporalBlock(8, 80, 640, eps=1e-6),
|
||||
TemporalBlock(8, 80, 640, eps=1e-6),
|
||||
TemporalBlock(8, 40, 320, eps=1e-6),
|
||||
TemporalBlock(8, 40, 320, eps=1e-6),
|
||||
TemporalBlock(8, 40, 320, eps=1e-6),
|
||||
])
|
||||
self.call_block_id = {
|
||||
1: 0,
|
||||
4: 1,
|
||||
9: 2,
|
||||
12: 3,
|
||||
17: 4,
|
||||
20: 5,
|
||||
24: 6,
|
||||
26: 7,
|
||||
29: 8,
|
||||
32: 9,
|
||||
34: 10,
|
||||
36: 11,
|
||||
40: 12,
|
||||
43: 13,
|
||||
46: 14,
|
||||
50: 15,
|
||||
53: 16,
|
||||
56: 17,
|
||||
60: 18,
|
||||
63: 19,
|
||||
66: 20
|
||||
}
|
||||
|
||||
def forward(self):
|
||||
pass
|
||||
|
||||
def state_dict_converter(self):
|
||||
return SDMotionModelStateDictConverter()
|
||||
|
||||
|
||||
class SDMotionModelStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
rename_dict = {
|
||||
"norm": "norm",
|
||||
"proj_in": "proj_in",
|
||||
"transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q",
|
||||
"transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k",
|
||||
"transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v",
|
||||
"transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out",
|
||||
"transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1",
|
||||
"transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q",
|
||||
"transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k",
|
||||
"transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v",
|
||||
"transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out",
|
||||
"transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2",
|
||||
"transformer_blocks.0.norms.0": "transformer_blocks.0.norm1",
|
||||
"transformer_blocks.0.norms.1": "transformer_blocks.0.norm2",
|
||||
"transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj",
|
||||
"transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff",
|
||||
"transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3",
|
||||
"proj_out": "proj_out",
|
||||
}
|
||||
name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")])
|
||||
name_list += sorted([i for i in state_dict if i.startswith("mid_block.")])
|
||||
name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")])
|
||||
state_dict_ = {}
|
||||
last_prefix, module_id = "", -1
|
||||
for name in name_list:
|
||||
names = name.split(".")
|
||||
prefix_index = names.index("temporal_transformer") + 1
|
||||
prefix = ".".join(names[:prefix_index])
|
||||
if prefix != last_prefix:
|
||||
last_prefix = prefix
|
||||
module_id += 1
|
||||
middle_name = ".".join(names[prefix_index:-1])
|
||||
suffix = names[-1]
|
||||
if "pos_encoder" in names:
|
||||
rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]])
|
||||
else:
|
||||
rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
|
||||
state_dict_[rename] = state_dict[name]
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
@@ -279,27 +279,19 @@ class SDUNet(torch.nn.Module):
|
||||
self.conv_act = torch.nn.SiLU()
|
||||
self.conv_out = torch.nn.Conv2d(320, 4, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, sample, timestep, encoder_hidden_states, tiled=False, tile_size=64, tile_stride=8, **kwargs):
|
||||
def forward(self, sample, timestep, encoder_hidden_states, **kwargs):
|
||||
# 1. time
|
||||
time_emb = self.time_proj(timestep[None]).to(sample.dtype)
|
||||
time_emb = self.time_embedding(time_emb)
|
||||
time_emb = time_emb.repeat(sample.shape[0], 1)
|
||||
|
||||
# 2. pre-process
|
||||
height, width = sample.shape[2], sample.shape[3]
|
||||
hidden_states = self.conv_in(sample)
|
||||
text_emb = encoder_hidden_states
|
||||
res_stack = [hidden_states]
|
||||
|
||||
# 3. blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
if tiled:
|
||||
hidden_states, time_emb, text_emb, res_stack = self.tiled_inference(
|
||||
block, hidden_states, time_emb, text_emb, res_stack,
|
||||
height, width, tile_size, tile_stride
|
||||
)
|
||||
else:
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
|
||||
# 4. output
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
@@ -308,23 +300,6 @@ class SDUNet(torch.nn.Module):
|
||||
|
||||
return hidden_states
|
||||
|
||||
def tiled_inference(self, block, hidden_states, time_emb, text_emb, res_stack, height, width, tile_size, tile_stride):
|
||||
if block.__class__.__name__ in ["ResnetBlock", "AttentionBlock", "DownSampler", "UpSampler"]:
|
||||
batch_size, inter_channel, inter_height, inter_width = hidden_states.shape
|
||||
resize_scale = inter_height / height
|
||||
|
||||
hidden_states = Tiler()(
|
||||
lambda x: block(x, time_emb, text_emb, res_stack)[0],
|
||||
hidden_states,
|
||||
int(tile_size * resize_scale),
|
||||
int(tile_stride * resize_scale),
|
||||
inter_device=hidden_states.device,
|
||||
inter_dtype=hidden_states.dtype
|
||||
)
|
||||
else:
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
return hidden_states, time_emb, text_emb, res_stack
|
||||
|
||||
def state_dict_converter(self):
|
||||
return SDUNetStateDictConverter()
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
from .attention import Attention
|
||||
from .sd_unet import ResnetBlock, UpSampler
|
||||
from .tiler import Tiler
|
||||
from .tiler import TileWorker
|
||||
|
||||
|
||||
class VAEAttentionBlock(torch.nn.Module):
|
||||
@@ -79,11 +79,13 @@ class SDVAEDecoder(torch.nn.Module):
|
||||
self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
|
||||
|
||||
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
|
||||
hidden_states = Tiler()(
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: self.forward(x),
|
||||
sample,
|
||||
tile_size,
|
||||
tile_stride
|
||||
tile_stride,
|
||||
tile_device=sample.device,
|
||||
tile_dtype=sample.dtype
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
from .sd_unet import ResnetBlock, DownSampler
|
||||
from .sd_vae_decoder import VAEAttentionBlock
|
||||
from .tiler import Tiler
|
||||
from .tiler import TileWorker
|
||||
|
||||
|
||||
class SDVAEEncoder(torch.nn.Module):
|
||||
@@ -38,11 +38,13 @@ class SDVAEEncoder(torch.nn.Module):
|
||||
self.conv_out = torch.nn.Conv2d(512, 8, kernel_size=3, padding=1)
|
||||
|
||||
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
|
||||
hidden_states = Tiler()(
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: self.forward(x),
|
||||
sample,
|
||||
tile_size,
|
||||
tile_stride
|
||||
tile_stride,
|
||||
tile_device=sample.device,
|
||||
tile_dtype=sample.dtype
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
class Tiler(torch.nn.Module):
|
||||
@@ -70,6 +71,106 @@ class Tiler(torch.nn.Module):
|
||||
|
||||
return x
|
||||
|
||||
|
||||
|
||||
|
||||
class TileWorker:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
def mask(self, height, width, border_width):
|
||||
# Create a mask with shape (height, width).
|
||||
# The centre area is filled with 1, and the border line is filled with values in range (0, 1].
|
||||
x = torch.arange(height).repeat(width, 1).T
|
||||
y = torch.arange(width).repeat(height, 1)
|
||||
mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values
|
||||
mask = (mask / border_width).clip(0, 1)
|
||||
return mask
|
||||
|
||||
|
||||
def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype):
|
||||
# Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num)
|
||||
batch_size, channel, _, _ = model_input.shape
|
||||
model_input = model_input.to(device=tile_device, dtype=tile_dtype)
|
||||
unfold_operator = torch.nn.Unfold(
|
||||
kernel_size=(tile_size, tile_size),
|
||||
stride=(tile_stride, tile_stride)
|
||||
)
|
||||
model_input = unfold_operator(model_input)
|
||||
model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1))
|
||||
|
||||
return model_input
|
||||
|
||||
|
||||
def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype):
|
||||
# Call y=forward_fn(x) for each tile
|
||||
tile_num = model_input.shape[-1]
|
||||
model_output_stack = []
|
||||
|
||||
for tile_id in range(0, tile_num, tile_batch_size):
|
||||
|
||||
# process input
|
||||
tile_id_ = min(tile_id + tile_batch_size, tile_num)
|
||||
x = model_input[:, :, :, :, tile_id: tile_id_]
|
||||
x = x.to(device=inference_device, dtype=inference_dtype)
|
||||
x = rearrange(x, "b c h w n -> (n b) c h w")
|
||||
|
||||
# process output
|
||||
y = forward_fn(x)
|
||||
y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_-tile_id)
|
||||
y = y.to(device=tile_device, dtype=tile_dtype)
|
||||
model_output_stack.append(y)
|
||||
|
||||
model_output = torch.concat(model_output_stack, dim=-1)
|
||||
return model_output
|
||||
|
||||
|
||||
def io_scale(self, model_output, tile_size):
|
||||
# Determine the size modification happend in forward_fn
|
||||
# We only consider the same scale on height and width.
|
||||
io_scale = model_output.shape[2] / tile_size
|
||||
return io_scale
|
||||
|
||||
|
||||
def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype):
|
||||
# The reversed function of tile
|
||||
mask = self.mask(tile_size, tile_size, border_width)
|
||||
mask = mask.to(device=tile_device, dtype=tile_dtype)
|
||||
mask = rearrange(mask, "h w -> 1 1 h w 1")
|
||||
model_output = model_output * mask
|
||||
|
||||
fold_operator = torch.nn.Fold(
|
||||
output_size=(height, width),
|
||||
kernel_size=(tile_size, tile_size),
|
||||
stride=(tile_stride, tile_stride)
|
||||
)
|
||||
mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1])
|
||||
model_output = rearrange(model_output, "b c h w n -> b (c h w) n")
|
||||
model_output = fold_operator(model_output) / fold_operator(mask)
|
||||
|
||||
return model_output
|
||||
|
||||
|
||||
def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
|
||||
# Prepare
|
||||
inference_device, inference_dtype = model_input.device, model_input.dtype
|
||||
height, width = model_input.shape[2], model_input.shape[3]
|
||||
border_width = int(tile_stride*0.5) if border_width is None else border_width
|
||||
|
||||
# tile
|
||||
model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype)
|
||||
|
||||
# inference
|
||||
model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype)
|
||||
|
||||
# resize
|
||||
io_scale = self.io_scale(model_output, tile_size)
|
||||
height, width = int(height*io_scale), int(width*io_scale)
|
||||
tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale)
|
||||
border_width = int(border_width*io_scale)
|
||||
|
||||
# untile
|
||||
model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype)
|
||||
|
||||
# Done!
|
||||
model_output = model_output.to(device=inference_device, dtype=inference_dtype)
|
||||
return model_output
|
||||
@@ -1,2 +1,3 @@
|
||||
from .stable_diffusion import SDPipeline
|
||||
from .stable_diffusion_xl import SDXLPipeline
|
||||
from .stable_diffusion import SDImagePipeline
|
||||
from .stable_diffusion_xl import SDXLImagePipeline
|
||||
from .stable_diffusion_video import SDVideoPipeline
|
||||
|
||||
113
diffsynth/pipelines/dancer.py
Normal file
113
diffsynth/pipelines/dancer.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import torch
|
||||
from ..models import SDUNet, SDMotionModel
|
||||
from ..models.sd_unet import PushBlock, PopBlock
|
||||
from ..models.tiler import TileWorker
|
||||
from ..controlnets import MultiControlNetManager
|
||||
|
||||
|
||||
def lets_dance(
|
||||
unet: SDUNet,
|
||||
motion_modules: SDMotionModel = None,
|
||||
controlnet: MultiControlNetManager = None,
|
||||
sample = None,
|
||||
timestep = None,
|
||||
encoder_hidden_states = None,
|
||||
controlnet_frames = None,
|
||||
unet_batch_size = 1,
|
||||
controlnet_batch_size = 1,
|
||||
tiled=False,
|
||||
tile_size=64,
|
||||
tile_stride=32,
|
||||
device = "cuda",
|
||||
vram_limit_level = 0,
|
||||
):
|
||||
# 1. ControlNet
|
||||
# This part will be repeated on overlapping frames if animatediff_batch_size > animatediff_stride.
|
||||
# I leave it here because I intend to do something interesting on the ControlNets.
|
||||
controlnet_insert_block_id = 30
|
||||
if controlnet is not None and controlnet_frames is not None:
|
||||
res_stacks = []
|
||||
# process controlnet frames with batch
|
||||
for batch_id in range(0, sample.shape[0], controlnet_batch_size):
|
||||
batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0])
|
||||
res_stack = controlnet(
|
||||
sample[batch_id: batch_id_],
|
||||
timestep,
|
||||
encoder_hidden_states[batch_id: batch_id_],
|
||||
controlnet_frames[:, batch_id: batch_id_],
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
||||
)
|
||||
if vram_limit_level >= 1:
|
||||
res_stack = [res.cpu() for res in res_stack]
|
||||
res_stacks.append(res_stack)
|
||||
# concat the residual
|
||||
additional_res_stack = []
|
||||
for i in range(len(res_stacks[0])):
|
||||
res = torch.concat([res_stack[i] for res_stack in res_stacks], dim=0)
|
||||
additional_res_stack.append(res)
|
||||
else:
|
||||
additional_res_stack = None
|
||||
|
||||
# 2. time
|
||||
time_emb = unet.time_proj(timestep[None]).to(sample.dtype)
|
||||
time_emb = unet.time_embedding(time_emb)
|
||||
|
||||
# 3. pre-process
|
||||
height, width = sample.shape[2], sample.shape[3]
|
||||
hidden_states = unet.conv_in(sample)
|
||||
text_emb = encoder_hidden_states
|
||||
res_stack = [hidden_states.cpu() if vram_limit_level>=1 else hidden_states]
|
||||
|
||||
# 4. blocks
|
||||
for block_id, block in enumerate(unet.blocks):
|
||||
# 4.1 UNet
|
||||
if isinstance(block, PushBlock):
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
if vram_limit_level>=1:
|
||||
res_stack[-1] = res_stack[-1].cpu()
|
||||
elif isinstance(block, PopBlock):
|
||||
if vram_limit_level>=1:
|
||||
res_stack[-1] = res_stack[-1].to(device)
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
else:
|
||||
hidden_states_input = hidden_states
|
||||
hidden_states_output = []
|
||||
for batch_id in range(0, sample.shape[0], unet_batch_size):
|
||||
batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
|
||||
if tiled:
|
||||
_, _, inter_height, _ = hidden_states.shape
|
||||
resize_scale = inter_height / height
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: block(x, time_emb, text_emb[batch_id: batch_id_], res_stack)[0],
|
||||
hidden_states_input[batch_id: batch_id_],
|
||||
int(tile_size * resize_scale),
|
||||
int(tile_stride * resize_scale),
|
||||
tile_device=hidden_states.device,
|
||||
tile_dtype=hidden_states.dtype
|
||||
)
|
||||
else:
|
||||
hidden_states, _, _, _ = block(hidden_states_input[batch_id: batch_id_], time_emb, text_emb[batch_id: batch_id_], res_stack)
|
||||
hidden_states_output.append(hidden_states)
|
||||
hidden_states = torch.concat(hidden_states_output, dim=0)
|
||||
# 4.2 AnimateDiff
|
||||
if motion_modules is not None:
|
||||
if block_id in motion_modules.call_block_id:
|
||||
motion_module_id = motion_modules.call_block_id[block_id]
|
||||
hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](
|
||||
hidden_states, time_emb, text_emb, res_stack,
|
||||
batch_size=1
|
||||
)
|
||||
# 4.3 ControlNet
|
||||
if block_id == controlnet_insert_block_id and additional_res_stack is not None:
|
||||
hidden_states += additional_res_stack.pop().to(device)
|
||||
if vram_limit_level>=1:
|
||||
res_stack = [(res.to(device) + additional_res.to(device)).cpu() for res, additional_res in zip(res_stack, additional_res_stack)]
|
||||
else:
|
||||
res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)]
|
||||
|
||||
# 5. output
|
||||
hidden_states = unet.conv_norm_out(hidden_states)
|
||||
hidden_states = unet.conv_act(hidden_states)
|
||||
hidden_states = unet.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
@@ -1,32 +1,90 @@
|
||||
from ..models import ModelManager
|
||||
from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder
|
||||
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
||||
from ..prompts import SDPrompter
|
||||
from ..schedulers import EnhancedDDIMScheduler
|
||||
from .dancer import lets_dance
|
||||
from typing import List
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
class SDPipeline(torch.nn.Module):
|
||||
class SDImagePipeline(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__()
|
||||
self.scheduler = EnhancedDDIMScheduler()
|
||||
self.prompter = SDPrompter()
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
# models
|
||||
self.text_encoder: SDTextEncoder = None
|
||||
self.unet: SDUNet = None
|
||||
self.vae_decoder: SDVAEDecoder = None
|
||||
self.vae_encoder: SDVAEEncoder = None
|
||||
self.controlnet: MultiControlNetManager = None
|
||||
|
||||
|
||||
def fetch_main_models(self, model_manager: ModelManager):
|
||||
self.text_encoder = model_manager.text_encoder
|
||||
self.unet = model_manager.unet
|
||||
self.vae_decoder = model_manager.vae_decoder
|
||||
self.vae_encoder = model_manager.vae_encoder
|
||||
# load textual inversion
|
||||
self.prompter.load_textual_inversion(model_manager.textual_inversion_dict)
|
||||
|
||||
|
||||
def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
|
||||
controlnet_units = []
|
||||
for config in controlnet_config_units:
|
||||
controlnet_unit = ControlNetUnit(
|
||||
Annotator(config.processor_id),
|
||||
model_manager.get_model_with_model_path(config.model_path),
|
||||
config.scale
|
||||
)
|
||||
controlnet_units.append(controlnet_unit)
|
||||
self.controlnet = MultiControlNetManager(controlnet_units)
|
||||
|
||||
|
||||
def fetch_beautiful_prompt(self, model_manager: ModelManager):
|
||||
if "beautiful_prompt" in model_manager.model:
|
||||
self.prompter.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
|
||||
pipe = SDImagePipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_main_models(model_manager)
|
||||
pipe.fetch_beautiful_prompt(model_manager)
|
||||
pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
|
||||
return pipe
|
||||
|
||||
|
||||
def preprocess_image(self, image):
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
|
||||
return image
|
||||
|
||||
|
||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
image = image.cpu().permute(1, 2, 0).numpy()
|
||||
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||
return image
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
model_manager: ModelManager,
|
||||
prompter: SDPrompter,
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
cfg_scale=7.5,
|
||||
clip_skip=1,
|
||||
init_image=None,
|
||||
input_image=None,
|
||||
controlnet_image=None,
|
||||
denoising_strength=1.0,
|
||||
height=512,
|
||||
width=512,
|
||||
@@ -37,39 +95,54 @@ class SDPipeline(torch.nn.Module):
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Encode prompts
|
||||
prompt_emb = prompter.encode_prompt(model_manager.text_encoder, prompt, clip_skip=clip_skip, device=model_manager.device)
|
||||
negative_prompt_emb = prompter.encode_prompt(model_manager.text_encoder, negative_prompt, clip_skip=clip_skip, device=model_manager.device)
|
||||
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
if init_image is not None:
|
||||
image = self.preprocess_image(init_image).to(device=model_manager.device, dtype=model_manager.torch_type)
|
||||
latents = model_manager.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
noise = torch.randn((1, 4, height//8, width//8), device=model_manager.device, dtype=model_manager.torch_type)
|
||||
if input_image is not None:
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
latents = torch.randn((1, 4, height//8, width//8), device=model_manager.device, dtype=model_manager.torch_type)
|
||||
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# Encode prompts
|
||||
prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True)
|
||||
prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False)
|
||||
|
||||
# Prepare ControlNets
|
||||
if controlnet_image is not None:
|
||||
controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
controlnet_image = controlnet_image.unsqueeze(1)
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = torch.IntTensor((timestep,))[0].to(model_manager.device)
|
||||
timestep = torch.IntTensor((timestep,))[0].to(self.device)
|
||||
|
||||
# Classifier-free guidance
|
||||
noise_pred_cond = model_manager.unet(latents, timestep, prompt_emb, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
noise_pred_uncond = model_manager.unet(latents, timestep, negative_prompt_emb, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_cond - noise_pred_uncond)
|
||||
noise_pred_posi = lets_dance(
|
||||
self.unet, motion_modules=None, controlnet=self.controlnet,
|
||||
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_image,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
|
||||
device=self.device, vram_limit_level=0
|
||||
)
|
||||
noise_pred_nega = lets_dance(
|
||||
self.unet, motion_modules=None, controlnet=self.controlnet,
|
||||
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_image,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
|
||||
device=self.device, vram_limit_level=0
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
|
||||
# DDIM
|
||||
latents = self.scheduler.step(noise_pred, timestep, latents)
|
||||
|
||||
# UI
|
||||
if progress_bar_st is not None:
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
image = model_manager.vae_decoder(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
image = image.cpu().permute(1, 2, 0).numpy()
|
||||
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
return image
|
||||
|
||||
217
diffsynth/pipelines/stable_diffusion_video.py
Normal file
217
diffsynth/pipelines/stable_diffusion_video.py
Normal file
@@ -0,0 +1,217 @@
|
||||
from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDMotionModel
|
||||
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
||||
from ..prompts import SDPrompter
|
||||
from ..schedulers import EnhancedDDIMScheduler
|
||||
from .dancer import lets_dance
|
||||
from typing import List
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
def lets_dance_with_long_video(
|
||||
unet: SDUNet,
|
||||
motion_modules: SDMotionModel = None,
|
||||
controlnet: MultiControlNetManager = None,
|
||||
sample = None,
|
||||
timestep = None,
|
||||
encoder_hidden_states = None,
|
||||
controlnet_frames = None,
|
||||
unet_batch_size = 1,
|
||||
controlnet_batch_size = 1,
|
||||
animatediff_batch_size = 16,
|
||||
animatediff_stride = 8,
|
||||
device = "cuda",
|
||||
vram_limit_level = 0,
|
||||
):
|
||||
num_frames = sample.shape[0]
|
||||
hidden_states_output = [(torch.zeros(sample[0].shape, dtype=sample[0].dtype), 0) for i in range(num_frames)]
|
||||
|
||||
for batch_id in range(0, num_frames, animatediff_stride):
|
||||
batch_id_ = min(batch_id + animatediff_batch_size, num_frames)
|
||||
|
||||
# process this batch
|
||||
hidden_states_batch = lets_dance(
|
||||
unet, motion_modules, controlnet,
|
||||
sample[batch_id: batch_id_].to(device),
|
||||
timestep,
|
||||
encoder_hidden_states[batch_id: batch_id_].to(device),
|
||||
controlnet_frames[:, batch_id: batch_id_].to(device) if controlnet_frames is not None else None,
|
||||
unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size, device=device, vram_limit_level=vram_limit_level
|
||||
).cpu()
|
||||
|
||||
# update hidden_states
|
||||
for i, hidden_states_updated in zip(range(batch_id, batch_id_), hidden_states_batch):
|
||||
bias = max(1 - abs(i - (batch_id + batch_id_ - 1) / 2) / ((batch_id_ - batch_id - 1) / 2), 1e-2)
|
||||
hidden_states, num = hidden_states_output[i]
|
||||
hidden_states = hidden_states * (num / (num + bias)) + hidden_states_updated * (bias / (num + bias))
|
||||
hidden_states_output[i] = (hidden_states, num + 1)
|
||||
|
||||
# output
|
||||
hidden_states = torch.stack([h for h, _ in hidden_states_output])
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SDVideoPipeline(torch.nn.Module):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16, use_animatediff=True):
|
||||
super().__init__()
|
||||
self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_animatediff else "scaled_linear")
|
||||
self.prompter = SDPrompter()
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
# models
|
||||
self.text_encoder: SDTextEncoder = None
|
||||
self.unet: SDUNet = None
|
||||
self.vae_decoder: SDVAEDecoder = None
|
||||
self.vae_encoder: SDVAEEncoder = None
|
||||
self.controlnet: MultiControlNetManager = None
|
||||
self.motion_modules: SDMotionModel = None
|
||||
|
||||
|
||||
def fetch_main_models(self, model_manager: ModelManager):
|
||||
self.text_encoder = model_manager.text_encoder
|
||||
self.unet = model_manager.unet
|
||||
self.vae_decoder = model_manager.vae_decoder
|
||||
self.vae_encoder = model_manager.vae_encoder
|
||||
# load textual inversion
|
||||
self.prompter.load_textual_inversion(model_manager.textual_inversion_dict)
|
||||
|
||||
|
||||
def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
|
||||
controlnet_units = []
|
||||
for config in controlnet_config_units:
|
||||
controlnet_unit = ControlNetUnit(
|
||||
Annotator(config.processor_id),
|
||||
model_manager.get_model_with_model_path(config.model_path),
|
||||
config.scale
|
||||
)
|
||||
controlnet_units.append(controlnet_unit)
|
||||
self.controlnet = MultiControlNetManager(controlnet_units)
|
||||
|
||||
|
||||
def fetch_motion_modules(self, model_manager: ModelManager):
|
||||
if "motion_modules" in model_manager.model:
|
||||
self.motion_modules = model_manager.motion_modules
|
||||
|
||||
|
||||
def fetch_beautiful_prompt(self, model_manager: ModelManager):
|
||||
if "beautiful_prompt" in model_manager.model:
|
||||
self.prompter.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
|
||||
pipe = SDVideoPipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
use_animatediff="motion_modules" in model_manager.model
|
||||
)
|
||||
pipe.fetch_main_models(model_manager)
|
||||
pipe.fetch_motion_modules(model_manager)
|
||||
pipe.fetch_beautiful_prompt(model_manager)
|
||||
pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
|
||||
return pipe
|
||||
|
||||
|
||||
def preprocess_image(self, image):
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
|
||||
return image
|
||||
|
||||
|
||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
image = image.cpu().permute(1, 2, 0).numpy()
|
||||
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||
return image
|
||||
|
||||
|
||||
def decode_images(self, latents, tiled=False, tile_size=64, tile_stride=32):
|
||||
images = [
|
||||
self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
for frame_id in range(latents.shape[0])
|
||||
]
|
||||
return images
|
||||
|
||||
|
||||
def encode_images(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
|
||||
latents = []
|
||||
for image in processed_images:
|
||||
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latent = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).cpu()
|
||||
latents.append(latent)
|
||||
latents = torch.concat(latents, dim=0)
|
||||
return latents
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
cfg_scale=7.5,
|
||||
clip_skip=1,
|
||||
num_frames=None,
|
||||
input_frames=None,
|
||||
controlnet_frames=None,
|
||||
denoising_strength=1.0,
|
||||
height=512,
|
||||
width=512,
|
||||
num_inference_steps=20,
|
||||
vram_limit_level=0,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype)
|
||||
if input_frames is None:
|
||||
latents = noise
|
||||
else:
|
||||
latents = self.encode_images(input_frames)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
|
||||
# Encode prompts
|
||||
prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True).cpu()
|
||||
prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False).cpu()
|
||||
prompt_emb_posi = prompt_emb_posi.repeat(num_frames, 1, 1)
|
||||
prompt_emb_nega = prompt_emb_nega.repeat(num_frames, 1, 1)
|
||||
|
||||
# Prepare ControlNets
|
||||
if controlnet_frames is not None:
|
||||
controlnet_frames = torch.stack([
|
||||
self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
|
||||
for controlnet_frame in progress_bar_cmd(controlnet_frames)
|
||||
], dim=1)
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = torch.IntTensor((timestep,))[0].to(self.device)
|
||||
|
||||
# Classifier-free guidance
|
||||
noise_pred_posi = lets_dance_with_long_video(
|
||||
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
|
||||
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_frames,
|
||||
device=self.device, vram_limit_level=vram_limit_level
|
||||
)
|
||||
noise_pred_nega = lets_dance_with_long_video(
|
||||
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
|
||||
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames,
|
||||
device=self.device, vram_limit_level=vram_limit_level
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
|
||||
# DDIM
|
||||
latents = self.scheduler.step(noise_pred, timestep, latents)
|
||||
|
||||
# UI
|
||||
if progress_bar_st is not None:
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
output_frames = self.decode_images(latents)
|
||||
|
||||
return output_frames
|
||||
@@ -1,4 +1,5 @@
|
||||
from ..models import ModelManager
|
||||
from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder
|
||||
# TODO: SDXL ControlNet
|
||||
from ..prompts import SDXLPrompter
|
||||
from ..schedulers import EnhancedDDIMScheduler
|
||||
import torch
|
||||
@@ -7,29 +8,77 @@ from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
class SDXLPipeline(torch.nn.Module):
|
||||
class SDXLImagePipeline(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__()
|
||||
self.scheduler = EnhancedDDIMScheduler()
|
||||
self.prompter = SDXLPrompter()
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
# models
|
||||
self.text_encoder: SDXLTextEncoder = None
|
||||
self.text_encoder_2: SDXLTextEncoder2 = None
|
||||
self.unet: SDXLUNet = None
|
||||
self.vae_decoder: SDXLVAEDecoder = None
|
||||
self.vae_encoder: SDXLVAEEncoder = None
|
||||
# TODO: SDXL ControlNet
|
||||
|
||||
def fetch_main_models(self, model_manager: ModelManager):
|
||||
self.text_encoder = model_manager.text_encoder
|
||||
self.text_encoder_2 = model_manager.text_encoder_2
|
||||
self.unet = model_manager.unet
|
||||
self.vae_decoder = model_manager.vae_decoder
|
||||
self.vae_encoder = model_manager.vae_encoder
|
||||
# load textual inversion
|
||||
self.prompter.load_textual_inversion(model_manager.textual_inversion_dict)
|
||||
|
||||
|
||||
def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs):
|
||||
# TODO: SDXL ControlNet
|
||||
pass
|
||||
|
||||
|
||||
def fetch_beautiful_prompt(self, model_manager: ModelManager):
|
||||
if "beautiful_prompt" in model_manager.model:
|
||||
self.prompter.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units = [], **kwargs):
|
||||
pipe = SDXLImagePipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_main_models(model_manager)
|
||||
pipe.fetch_beautiful_prompt(model_manager)
|
||||
pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units)
|
||||
return pipe
|
||||
|
||||
|
||||
def preprocess_image(self, image):
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
|
||||
return image
|
||||
|
||||
|
||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
image = image.cpu().permute(1, 2, 0).numpy()
|
||||
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||
return image
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
model_manager: ModelManager,
|
||||
prompter: SDXLPrompter,
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
cfg_scale=7.5,
|
||||
clip_skip=1,
|
||||
clip_skip_2=2,
|
||||
init_image=None,
|
||||
input_image=None,
|
||||
controlnet_image=None,
|
||||
denoising_strength=1.0,
|
||||
refining_strength=0.0,
|
||||
height=1024,
|
||||
width=1024,
|
||||
num_inference_steps=20,
|
||||
@@ -39,76 +88,62 @@ class SDXLPipeline(torch.nn.Module):
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Encode prompts
|
||||
add_text_embeds, prompt_emb = prompter.encode_prompt(
|
||||
model_manager.text_encoder,
|
||||
model_manager.text_encoder_2,
|
||||
prompt,
|
||||
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
|
||||
device=model_manager.device
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
negative_add_text_embeds, negative_prompt_emb = prompter.encode_prompt(
|
||||
model_manager.text_encoder,
|
||||
model_manager.text_encoder_2,
|
||||
negative_prompt,
|
||||
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
|
||||
device=model_manager.device
|
||||
)
|
||||
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
if init_image is not None:
|
||||
image = self.preprocess_image(init_image).to(
|
||||
device=model_manager.device, dtype=model_manager.torch_type
|
||||
)
|
||||
latents = model_manager.vae_encoder(
|
||||
image.to(torch.float32),
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
||||
)
|
||||
noise = torch.randn(
|
||||
(1, 4, height//8, width//8),
|
||||
device=model_manager.device, dtype=model_manager.torch_type
|
||||
)
|
||||
latents = self.scheduler.add_noise(
|
||||
latents.to(model_manager.torch_type),
|
||||
noise,
|
||||
timestep=self.scheduler.timesteps[0]
|
||||
)
|
||||
if input_image is not None:
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.vae_encoder(image.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
|
||||
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
latents = torch.randn((1, 4, height//8, width//8), device=model_manager.device, dtype=model_manager.torch_type)
|
||||
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# Encode prompts
|
||||
add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt(
|
||||
self.text_encoder,
|
||||
self.text_encoder_2,
|
||||
prompt,
|
||||
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
|
||||
device=self.device
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
|
||||
self.text_encoder,
|
||||
self.text_encoder_2,
|
||||
negative_prompt,
|
||||
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
|
||||
device=self.device
|
||||
)
|
||||
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare positional id
|
||||
add_time_id = torch.tensor([height, width, 0, 0, height, width], device=model_manager.device)
|
||||
add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = torch.IntTensor((timestep,))[0].to(model_manager.device)
|
||||
timestep = torch.IntTensor((timestep,))[0].to(self.device)
|
||||
|
||||
# Classifier-free guidance
|
||||
if timestep >= 1000 * refining_strength:
|
||||
denoising_model = model_manager.unet
|
||||
else:
|
||||
denoising_model = model_manager.refiner
|
||||
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_cond = denoising_model(
|
||||
latents, timestep, prompt_emb,
|
||||
add_time_id=add_time_id, add_text_embeds=add_text_embeds,
|
||||
noise_pred_posi = self.unet(
|
||||
latents, timestep, prompt_emb_posi,
|
||||
add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
||||
)
|
||||
noise_pred_uncond = denoising_model(
|
||||
latents, timestep, negative_prompt_emb,
|
||||
add_time_id=add_time_id, add_text_embeds=negative_add_text_embeds,
|
||||
noise_pred_nega = self.unet(
|
||||
latents, timestep, prompt_emb_nega,
|
||||
add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
||||
)
|
||||
noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_cond - noise_pred_uncond)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = denoising_model(
|
||||
latents, timestep, prompt_emb,
|
||||
add_time_id=add_time_id, add_text_embeds=add_text_embeds,
|
||||
noise_pred = self.unet(
|
||||
latents, timestep, prompt_emb_posi,
|
||||
add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
||||
)
|
||||
|
||||
@@ -118,9 +153,6 @@ class SDXLPipeline(torch.nn.Module):
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
latents = latents.to(torch.float32)
|
||||
image = model_manager.vae_decoder(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
image = image.cpu().permute(1, 2, 0).numpy()
|
||||
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||
image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
return image
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from transformers import CLIPTokenizer
|
||||
from transformers import CLIPTokenizer, AutoTokenizer
|
||||
from ..models import SDTextEncoder, SDXLTextEncoder, SDXLTextEncoder2
|
||||
import torch, os
|
||||
from safetensors import safe_open
|
||||
|
||||
|
||||
def tokenize_long_prompt(tokenizer, prompt):
|
||||
@@ -36,49 +35,75 @@ def tokenize_long_prompt(tokenizer, prompt):
|
||||
return input_ids
|
||||
|
||||
|
||||
def load_textual_inversion(prompt):
|
||||
# TODO: This module is not enabled now.
|
||||
textual_inversion_files = os.listdir("models/textual_inversion")
|
||||
embeddings_768 = []
|
||||
embeddings_1280 = []
|
||||
for file_name in textual_inversion_files:
|
||||
if not file_name.endswith(".safetensors"):
|
||||
continue
|
||||
keyword = file_name[:-len(".safetensors")]
|
||||
if keyword in prompt:
|
||||
prompt = prompt.replace(keyword, "")
|
||||
with safe_open(f"models/textual_inversion/{file_name}", framework="pt", device="cpu") as f:
|
||||
for k in f.keys():
|
||||
embedding = f.get_tensor(k).to(torch.float32)
|
||||
if embedding.shape[-1] == 768:
|
||||
embeddings_768.append(embedding)
|
||||
elif embedding.shape[-1] == 1280:
|
||||
embeddings_1280.append(embedding)
|
||||
|
||||
if len(embeddings_768)==0:
|
||||
embeddings_768 = torch.zeros((0, 768))
|
||||
else:
|
||||
embeddings_768 = torch.concat(embeddings_768, dim=0)
|
||||
|
||||
if len(embeddings_1280)==0:
|
||||
embeddings_1280 = torch.zeros((0, 1280))
|
||||
else:
|
||||
embeddings_1280 = torch.concat(embeddings_1280, dim=0)
|
||||
|
||||
return prompt, embeddings_768, embeddings_1280
|
||||
class BeautifulPrompt:
|
||||
def __init__(self, tokenizer_path="configs/beautiful_prompt/tokenizer", model=None):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
self.model = model
|
||||
self.template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'
|
||||
|
||||
def __call__(self, raw_prompt):
|
||||
model_input = self.template.format(raw_prompt=raw_prompt)
|
||||
input_ids = self.tokenizer.encode(model_input, return_tensors='pt').to(self.model.device)
|
||||
outputs = self.model.generate(
|
||||
input_ids,
|
||||
max_new_tokens=384,
|
||||
do_sample=True,
|
||||
temperature=0.9,
|
||||
top_k=50,
|
||||
top_p=0.95,
|
||||
repetition_penalty=1.1,
|
||||
num_return_sequences=1
|
||||
)
|
||||
prompt = raw_prompt + ", " + self.tokenizer.batch_decode(
|
||||
outputs[:, input_ids.size(1):],
|
||||
skip_special_tokens=True
|
||||
)[0].strip()
|
||||
return prompt
|
||||
|
||||
|
||||
class SDPrompter:
|
||||
def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"):
|
||||
# We use the tokenizer implemented by transformers
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
self.keyword_dict = {}
|
||||
self.beautiful_prompt: BeautifulPrompt = None
|
||||
|
||||
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda"):
|
||||
|
||||
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True):
|
||||
# Textual Inversion
|
||||
for keyword in self.keyword_dict:
|
||||
if keyword in prompt:
|
||||
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
|
||||
|
||||
# Beautiful Prompt
|
||||
if positive and self.beautiful_prompt is not None:
|
||||
prompt = self.beautiful_prompt(prompt)
|
||||
print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
|
||||
|
||||
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
|
||||
prompt_emb = text_encoder(input_ids, clip_skip=clip_skip)
|
||||
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
||||
|
||||
return prompt_emb
|
||||
|
||||
def load_textual_inversion(self, textual_inversion_dict):
|
||||
self.keyword_dict = {}
|
||||
additional_tokens = []
|
||||
for keyword in textual_inversion_dict:
|
||||
tokens, _ = textual_inversion_dict[keyword]
|
||||
additional_tokens += tokens
|
||||
self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
|
||||
self.tokenizer.add_tokens(additional_tokens)
|
||||
|
||||
def load_beautiful_prompt(self, model, model_path):
|
||||
model_folder = os.path.dirname(model_path)
|
||||
self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model)
|
||||
if model_folder.endswith("v2"):
|
||||
self.beautiful_prompt.template = """Converts a simple image description into a prompt. \
|
||||
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
|
||||
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
|
||||
but make sure there is a correlation between the input and output.\n\
|
||||
### Input: {raw_prompt}\n### Output:"""
|
||||
|
||||
|
||||
class SDXLPrompter:
|
||||
@@ -90,6 +115,8 @@ class SDXLPrompter:
|
||||
# We use the tokenizer implemented by transformers
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
|
||||
self.keyword_dict = {}
|
||||
self.beautiful_prompt: BeautifulPrompt = None
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
@@ -98,8 +125,19 @@ class SDXLPrompter:
|
||||
prompt,
|
||||
clip_skip=1,
|
||||
clip_skip_2=2,
|
||||
positive=True,
|
||||
device="cuda"
|
||||
):
|
||||
# Textual Inversion
|
||||
for keyword in self.keyword_dict:
|
||||
if keyword in prompt:
|
||||
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
|
||||
|
||||
# Beautiful Prompt
|
||||
if positive and self.beautiful_prompt is not None:
|
||||
prompt = self.beautiful_prompt(prompt)
|
||||
print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
|
||||
|
||||
# 1
|
||||
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
|
||||
prompt_emb_1 = text_encoder(input_ids, clip_skip=clip_skip)
|
||||
@@ -115,3 +153,22 @@ class SDXLPrompter:
|
||||
add_text_embeds = add_text_embeds[0:1]
|
||||
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
||||
return add_text_embeds, prompt_emb
|
||||
|
||||
def load_textual_inversion(self, textual_inversion_dict):
|
||||
self.keyword_dict = {}
|
||||
additional_tokens = []
|
||||
for keyword in textual_inversion_dict:
|
||||
tokens, _ = textual_inversion_dict[keyword]
|
||||
additional_tokens += tokens
|
||||
self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
|
||||
self.tokenizer.add_tokens(additional_tokens)
|
||||
|
||||
def load_beautiful_prompt(self, model, model_path):
|
||||
model_folder = os.path.dirname(model_path)
|
||||
self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model)
|
||||
if model_folder.endswith("v2"):
|
||||
self.beautiful_prompt.template = """Converts a simple image description into a prompt. \
|
||||
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
|
||||
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
|
||||
but make sure there is a correlation between the input and output.\n\
|
||||
### Input: {raw_prompt}\n### Output:"""
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
class SDTokenizer:
|
||||
def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"):
|
||||
# We use the tokenizer implemented by transformers
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
|
||||
def __call__(self, prompt):
|
||||
# Get model_max_length from self.tokenizer
|
||||
length = self.tokenizer.model_max_length
|
||||
|
||||
# To avoid the warning. set self.tokenizer.model_max_length to +oo.
|
||||
self.tokenizer.model_max_length = 99999999
|
||||
|
||||
# Tokenize it!
|
||||
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
|
||||
|
||||
# Determine the real length.
|
||||
max_length = (input_ids.shape[1] + length - 1) // length * length
|
||||
|
||||
# Restore self.tokenizer.model_max_length
|
||||
self.tokenizer.model_max_length = length
|
||||
|
||||
# Tokenize it again with fixed length.
|
||||
input_ids = self.tokenizer(
|
||||
prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True
|
||||
).input_ids
|
||||
|
||||
# Reshape input_ids to fit the text encoder.
|
||||
num_sentence = input_ids.shape[1] // length
|
||||
input_ids = input_ids.reshape((num_sentence, length))
|
||||
|
||||
return input_ids
|
||||
@@ -1,45 +0,0 @@
|
||||
from transformers import CLIPTokenizer
|
||||
from .sd_tokenizer import SDTokenizer
|
||||
|
||||
|
||||
class SDXLTokenizer(SDTokenizer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
||||
class SDXLTokenizer2:
|
||||
def __init__(self, tokenizer_path="configs/stable_diffusion_xl/tokenizer_2"):
|
||||
# We use the tokenizer implemented by transformers
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
|
||||
def __call__(self, prompt):
|
||||
# Get model_max_length from self.tokenizer
|
||||
length = self.tokenizer.model_max_length
|
||||
|
||||
# To avoid the warning. set self.tokenizer.model_max_length to +oo.
|
||||
self.tokenizer.model_max_length = 99999999
|
||||
|
||||
# Tokenize it!
|
||||
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
|
||||
|
||||
# Determine the real length.
|
||||
max_length = (input_ids.shape[1] + length - 1) // length * length
|
||||
|
||||
# Restore self.tokenizer.model_max_length
|
||||
self.tokenizer.model_max_length = length
|
||||
|
||||
# Tokenize it again with fixed length.
|
||||
input_ids = self.tokenizer(
|
||||
prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True
|
||||
).input_ids
|
||||
|
||||
# Reshape input_ids to fit the text encoder.
|
||||
num_sentence = input_ids.shape[1] // length
|
||||
input_ids = input_ids.reshape((num_sentence, length))
|
||||
|
||||
return input_ids
|
||||
|
||||
@@ -3,9 +3,14 @@ import torch, math
|
||||
|
||||
class EnhancedDDIMScheduler():
|
||||
|
||||
def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012):
|
||||
def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"):
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32))
|
||||
if beta_schedule == "scaled_linear":
|
||||
betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32))
|
||||
elif beta_schedule == "linear":
|
||||
betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
else:
|
||||
raise NotImplementedError(f"{beta_schedule} is not implemented")
|
||||
self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0).tolist()
|
||||
self.set_timesteps(10)
|
||||
|
||||
@@ -13,7 +18,7 @@ class EnhancedDDIMScheduler():
|
||||
def set_timesteps(self, num_inference_steps, denoising_strength=1.0):
|
||||
# The timesteps are aligned to 999...0, which is different from other implementations,
|
||||
# but I think this implementation is more reasonable in theory.
|
||||
max_timestep = round(self.num_train_timesteps * denoising_strength) - 1
|
||||
max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0)
|
||||
num_inference_steps = min(num_inference_steps, max_timestep + 1)
|
||||
if num_inference_steps == 1:
|
||||
self.timesteps = [max_timestep]
|
||||
@@ -34,14 +39,14 @@ class EnhancedDDIMScheduler():
|
||||
return prev_sample
|
||||
|
||||
|
||||
def step(self, model_output, timestep, sample):
|
||||
def step(self, model_output, timestep, sample, to_final=False):
|
||||
alpha_prod_t = self.alphas_cumprod[timestep]
|
||||
timestep_id = self.timesteps.index(timestep)
|
||||
if timestep_id + 1 < len(self.timesteps):
|
||||
if to_final or timestep_id + 1 >= len(self.timesteps):
|
||||
alpha_prod_t_prev = 1.0
|
||||
else:
|
||||
timestep_prev = self.timesteps[timestep_id + 1]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[timestep_prev]
|
||||
else:
|
||||
alpha_prod_t_prev = 1.0
|
||||
|
||||
return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev)
|
||||
|
||||
|
||||
19
environment.yml
Normal file
19
environment.yml
Normal file
@@ -0,0 +1,19 @@
|
||||
name: DiffSynthStudio
|
||||
channels:
|
||||
- pytorch
|
||||
- nvidia
|
||||
- defaults
|
||||
dependencies:
|
||||
- python=3.9.16
|
||||
- pip=23.0.1
|
||||
- cudatoolkit
|
||||
- pytorch
|
||||
- pip:
|
||||
- transformers
|
||||
- controlnet-aux==0.0.7
|
||||
- streamlit
|
||||
- streamlit-drawable-canvas
|
||||
- imageio
|
||||
- imageio[ffmpeg]
|
||||
- safetensors
|
||||
- einops
|
||||
75
examples/sd_text_to_image.py
Normal file
75
examples/sd_text_to_image.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from diffsynth import ModelManager, SDImagePipeline, ControlNetConfigUnit
|
||||
import torch
|
||||
|
||||
|
||||
# Download models
|
||||
# `models/stable_diffusion/aingdiffusion_v12.safetensors`: [link](https://civitai.com/api/download/models/229575?type=Model&format=SafeTensor&size=full&fp=fp16)
|
||||
# `models/ControlNet/control_v11p_sd15_lineart.pth`: [link](https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_lineart.pth)
|
||||
# `models/ControlNet/control_v11f1e_sd15_tile.pth`: [link](https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11f1e_sd15_tile.pth)
|
||||
# `models/Annotators/sk_model.pth`: [link](https://huggingface.co/lllyasviel/Annotators/resolve/main/sk_model.pth)
|
||||
# `models/Annotators/sk_model2.pth`: [link](https://huggingface.co/lllyasviel/Annotators/resolve/main/sk_model2.pth)
|
||||
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
|
||||
model_manager.load_textual_inversions("models/textual_inversion")
|
||||
model_manager.load_models([
|
||||
"models/stable_diffusion/aingdiffusion_v12.safetensors",
|
||||
"models/ControlNet/control_v11f1e_sd15_tile.pth",
|
||||
"models/ControlNet/control_v11p_sd15_lineart.pth"
|
||||
])
|
||||
pipe = SDImagePipeline.from_model_manager(
|
||||
model_manager,
|
||||
[
|
||||
ControlNetConfigUnit(
|
||||
processor_id="tile",
|
||||
model_path=rf"models/ControlNet/control_v11f1e_sd15_tile.pth",
|
||||
scale=0.5
|
||||
),
|
||||
ControlNetConfigUnit(
|
||||
processor_id="lineart",
|
||||
model_path=rf"models/ControlNet/control_v11p_sd15_lineart.pth",
|
||||
scale=0.7
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
prompt = "masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait,"
|
||||
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,",
|
||||
|
||||
torch.manual_seed(0)
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
cfg_scale=7.5, clip_skip=1,
|
||||
height=512, width=512, num_inference_steps=80,
|
||||
)
|
||||
image.save("512.jpg")
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
cfg_scale=7.5, clip_skip=1,
|
||||
input_image=image.resize((1024, 1024)), controlnet_image=image.resize((1024, 1024)),
|
||||
height=1024, width=1024, num_inference_steps=40, denoising_strength=0.7,
|
||||
)
|
||||
image.save("1024.jpg")
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
cfg_scale=7.5, clip_skip=1,
|
||||
input_image=image.resize((2048, 2048)), controlnet_image=image.resize((2048, 2048)),
|
||||
height=2048, width=2048, num_inference_steps=20, denoising_strength=0.7,
|
||||
)
|
||||
image.save("2048.jpg")
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
cfg_scale=7.5, clip_skip=1,
|
||||
input_image=image.resize((4096, 4096)), controlnet_image=image.resize((4096, 4096)),
|
||||
height=4096, width=4096, num_inference_steps=10, denoising_strength=0.5,
|
||||
tiled=True, tile_size=128, tile_stride=64
|
||||
)
|
||||
image.save("4096.jpg")
|
||||
56
examples/sd_toon_shading.py
Normal file
56
examples/sd_toon_shading.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from diffsynth import ModelManager, SDVideoPipeline, ControlNetConfigUnit, VideoData, save_video, save_frames
|
||||
import torch
|
||||
|
||||
|
||||
# Download models
|
||||
# `models/stable_diffusion/flat2DAnimerge_v45Sharp.safetensors`: [link](https://civitai.com/api/download/models/266360?type=Model&format=SafeTensor&size=pruned&fp=fp16)
|
||||
# `models/AnimateDiff/mm_sd_v15_v2.ckpt`: [link](https://huggingface.co/guoyww/animatediff/resolve/main/mm_sd_v15_v2.ckpt)
|
||||
# `models/ControlNet/control_v11p_sd15_lineart.pth`: [link](https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_lineart.pth)
|
||||
# `models/ControlNet/control_v11f1e_sd15_tile.pth`: [link](https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11f1e_sd15_tile.pth)
|
||||
# `models/Annotators/sk_model.pth`: [link](https://huggingface.co/lllyasviel/Annotators/resolve/main/sk_model.pth)
|
||||
# `models/Annotators/sk_model2.pth`: [link](https://huggingface.co/lllyasviel/Annotators/resolve/main/sk_model2.pth)
|
||||
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
|
||||
model_manager.load_textual_inversions("models/textual_inversion")
|
||||
model_manager.load_models([
|
||||
"models/stable_diffusion/flat2DAnimerge_v45Sharp.safetensors",
|
||||
"models/AnimateDiff/mm_sd_v15_v2.ckpt",
|
||||
"models/ControlNet/control_v11p_sd15_lineart.pth",
|
||||
"models/ControlNet/control_v11f1e_sd15_tile.pth",
|
||||
])
|
||||
pipe = SDVideoPipeline.from_model_manager(
|
||||
model_manager,
|
||||
[
|
||||
ControlNetConfigUnit(
|
||||
processor_id="lineart",
|
||||
model_path="models/ControlNet/control_v11p_sd15_lineart.pth",
|
||||
scale=1.0
|
||||
),
|
||||
ControlNetConfigUnit(
|
||||
processor_id="tile",
|
||||
model_path="models/ControlNet/control_v11f1e_sd15_tile.pth",
|
||||
scale=0.5
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# Load video (we only use 16 frames in this example for testing)
|
||||
video = VideoData(video_file="input_video.mp4", height=1536, width=1536)
|
||||
input_video = [video[i] for i in range(16)]
|
||||
|
||||
# Toon shading
|
||||
torch.manual_seed(0)
|
||||
output_video = pipe(
|
||||
prompt="best quality, perfect anime illustration, light, a girl is dancing, smile, solo",
|
||||
negative_prompt="verybadimagenegative_v1.3",
|
||||
cfg_scale=5, clip_skip=2,
|
||||
controlnet_frames=input_video, num_frames=len(input_video),
|
||||
num_inference_steps=10, height=1536, width=1536,
|
||||
vram_limit_level=0,
|
||||
)
|
||||
|
||||
# Save images and video
|
||||
save_frames(output_video, "output_frames")
|
||||
save_video(output_video, "output_video.mp4", fps=16)
|
||||
34
examples/sdxl_text_to_image.py
Normal file
34
examples/sdxl_text_to_image.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from diffsynth import ModelManager, SDXLImagePipeline
|
||||
import torch
|
||||
|
||||
|
||||
# Download models
|
||||
# `models/stable_diffusion_xl/bluePencilXL_v200.safetensors`: [link](https://civitai.com/api/download/models/245614?type=Model&format=SafeTensor&size=pruned&fp=fp16)
|
||||
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
|
||||
model_manager.load_models(["models/stable_diffusion_xl/bluePencilXL_v200.safetensors"])
|
||||
pipe = SDXLImagePipeline.from_model_manager(model_manager)
|
||||
|
||||
prompt = "masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait,"
|
||||
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,",
|
||||
|
||||
torch.manual_seed(0)
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
cfg_scale=6,
|
||||
height=1024, width=1024, num_inference_steps=60,
|
||||
)
|
||||
image.save("1024.jpg")
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
cfg_scale=6,
|
||||
input_image=image.resize((2048, 2048)),
|
||||
height=2048, width=2048, num_inference_steps=60, denoising_strength=0.5
|
||||
)
|
||||
image.save("2048.jpg")
|
||||
|
||||
31
examples/sdxl_turbo.py
Normal file
31
examples/sdxl_turbo.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from diffsynth import ModelManager, SDXLImagePipeline
|
||||
import torch
|
||||
|
||||
|
||||
# Download models
|
||||
# `models/stable_diffusion_xl_turbo/sd_xl_turbo_1.0_fp16.safetensors`: [link](https://huggingface.co/stabilityai/sdxl-turbo/resolve/main/sd_xl_turbo_1.0_fp16.safetensors)
|
||||
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
|
||||
model_manager.load_models(["models/stable_diffusion_xl_turbo/sd_xl_turbo_1.0_fp16.safetensors"])
|
||||
pipe = SDXLImagePipeline.from_model_manager(model_manager)
|
||||
|
||||
# Text to image
|
||||
torch.manual_seed(0)
|
||||
image = pipe(
|
||||
prompt="black car",
|
||||
# Do not modify the following parameters!
|
||||
cfg_scale=1, height=512, width=512, num_inference_steps=1, progress_bar_cmd=lambda x:x
|
||||
)
|
||||
image.save(f"black_car.jpg")
|
||||
|
||||
# Image to image
|
||||
torch.manual_seed(0)
|
||||
image = pipe(
|
||||
prompt="red car",
|
||||
input_image=image, denoising_strength=0.7,
|
||||
# Do not modify the following parameters!
|
||||
cfg_scale=1, height=512, width=512, num_inference_steps=1, progress_bar_cmd=lambda x:x
|
||||
)
|
||||
image.save(f"black_car_to_red_car.jpg")
|
||||
@@ -5,59 +5,61 @@ import streamlit as st
|
||||
st.set_page_config(layout="wide")
|
||||
from streamlit_drawable_canvas import st_canvas
|
||||
from diffsynth.models import ModelManager
|
||||
from diffsynth.prompts import SDXLPrompter, SDPrompter
|
||||
from diffsynth.pipelines import SDXLPipeline, SDPipeline
|
||||
from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline
|
||||
|
||||
|
||||
torch.cuda.set_per_process_memory_fraction(0.999, 0)
|
||||
config = {
|
||||
"Stable Diffusion": {
|
||||
"model_folder": "models/stable_diffusion",
|
||||
"pipeline_class": SDImagePipeline,
|
||||
"fixed_parameters": {}
|
||||
},
|
||||
"Stable Diffusion XL": {
|
||||
"model_folder": "models/stable_diffusion_xl",
|
||||
"pipeline_class": SDXLImagePipeline,
|
||||
"fixed_parameters": {}
|
||||
},
|
||||
"Stable Diffusion XL Turbo": {
|
||||
"model_folder": "models/stable_diffusion_xl_turbo",
|
||||
"pipeline_class": SDXLImagePipeline,
|
||||
"fixed_parameters": {
|
||||
"negative_prompt": "",
|
||||
"cfg_scale": 1.0,
|
||||
"num_inference_steps": 1,
|
||||
"height": 512,
|
||||
"width": 512,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@st.cache_data
|
||||
def load_model_list(folder):
|
||||
def load_model_list(model_type):
|
||||
folder = config[model_type]["model_folder"]
|
||||
file_list = os.listdir(folder)
|
||||
file_list = [i for i in file_list if i.endswith(".safetensors")]
|
||||
file_list = sorted(file_list)
|
||||
return file_list
|
||||
|
||||
|
||||
def detect_model_path(sd_model_path, sdxl_model_path):
|
||||
if sd_model_path != "None":
|
||||
model_path = os.path.join("models/stable_diffusion", sd_model_path)
|
||||
elif sdxl_model_path != "None":
|
||||
model_path = os.path.join("models/stable_diffusion_xl", sdxl_model_path)
|
||||
else:
|
||||
model_path = None
|
||||
return model_path
|
||||
|
||||
|
||||
def load_model(sd_model_path, sdxl_model_path):
|
||||
if sd_model_path != "None":
|
||||
model_path = os.path.join("models/stable_diffusion", sd_model_path)
|
||||
model_manager = ModelManager()
|
||||
model_manager.load_from_safetensors(model_path)
|
||||
prompter = SDPrompter()
|
||||
pipeline = SDPipeline()
|
||||
elif sdxl_model_path != "None":
|
||||
model_path = os.path.join("models/stable_diffusion_xl", sdxl_model_path)
|
||||
model_manager = ModelManager()
|
||||
model_manager.load_from_safetensors(model_path)
|
||||
prompter = SDXLPrompter()
|
||||
pipeline = SDXLPipeline()
|
||||
else:
|
||||
return None, None, None, None
|
||||
return model_path, model_manager, prompter, pipeline
|
||||
|
||||
|
||||
def release_model():
|
||||
if "model_manager" in st.session_state:
|
||||
st.session_state["model_manager"].to("cpu")
|
||||
del st.session_state["loaded_model_path"]
|
||||
del st.session_state["model_manager"]
|
||||
del st.session_state["prompter"]
|
||||
del st.session_state["pipeline"]
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def load_model(model_type, model_path):
|
||||
model_manager = ModelManager()
|
||||
model_manager.load_model(model_path)
|
||||
pipeline = config[model_type]["pipeline_class"].from_model_manager(model_manager)
|
||||
st.session_state.loaded_model_path = model_path
|
||||
st.session_state.model_manager = model_manager
|
||||
st.session_state.pipeline = pipeline
|
||||
return model_manager, pipeline
|
||||
|
||||
|
||||
def use_output_image_as_input():
|
||||
# Search for input image
|
||||
output_image_id = 0
|
||||
@@ -102,68 +104,64 @@ def show_output_image(image):
|
||||
|
||||
|
||||
column_input, column_output = st.columns(2)
|
||||
|
||||
# with column_input:
|
||||
with st.sidebar:
|
||||
# Select a model
|
||||
with st.expander("Model", expanded=True):
|
||||
sd_model_list = ["None"] + load_model_list("models/stable_diffusion")
|
||||
sd_model_path = st.selectbox(
|
||||
"Stable Diffusion", sd_model_list
|
||||
)
|
||||
sdxl_model_list = ["None"] + load_model_list("models/stable_diffusion_xl")
|
||||
sdxl_model_path = st.selectbox(
|
||||
"Stable Diffusion XL", sdxl_model_list
|
||||
)
|
||||
model_type = st.selectbox("Model type", ["Stable Diffusion", "Stable Diffusion XL", "Stable Diffusion XL Turbo"])
|
||||
fixed_parameters = config[model_type]["fixed_parameters"]
|
||||
model_path_list = ["None"] + load_model_list(model_type)
|
||||
model_path = st.selectbox("Model path", model_path_list)
|
||||
|
||||
# Load the model
|
||||
model_path = detect_model_path(sd_model_path, sdxl_model_path)
|
||||
if model_path is None:
|
||||
st.markdown("No models selected.")
|
||||
if model_path == "None":
|
||||
# No models are selected. Release VRAM.
|
||||
st.markdown("No models are selected.")
|
||||
release_model()
|
||||
elif st.session_state.get("loaded_model_path", "") != model_path:
|
||||
st.markdown(f"Using model at {model_path}.")
|
||||
release_model()
|
||||
model_path, model_manager, prompter, pipeline = load_model(sd_model_path, sdxl_model_path)
|
||||
st.session_state.loaded_model_path = model_path
|
||||
st.session_state.model_manager = model_manager
|
||||
st.session_state.prompter = prompter
|
||||
st.session_state.pipeline = pipeline
|
||||
else:
|
||||
st.markdown(f"Using model at {model_path}.")
|
||||
model_path, model_manager, prompter, pipeline = (
|
||||
st.session_state.loaded_model_path,
|
||||
st.session_state.model_manager,
|
||||
st.session_state.prompter,
|
||||
st.session_state.pipeline,
|
||||
)
|
||||
# A model is selected.
|
||||
model_path = os.path.join(config[model_type]["model_folder"], model_path)
|
||||
if st.session_state.get("loaded_model_path", "") != model_path:
|
||||
# The loaded model is not the selected model. Reload it.
|
||||
st.markdown(f"Using model at {model_path}.")
|
||||
release_model()
|
||||
model_manager, pipeline = load_model(model_type, model_path)
|
||||
else:
|
||||
# The loaded model is not the selected model. Fetch it from `st.session_state`.
|
||||
st.markdown(f"Using model at {model_path}.")
|
||||
model_manager, pipeline = st.session_state.model_manager, st.session_state.pipeline
|
||||
|
||||
# Show parameters
|
||||
with st.expander("Prompt", expanded=True):
|
||||
column_positive, column_negative = st.columns(2)
|
||||
prompt = st.text_area("Positive prompt")
|
||||
negative_prompt = st.text_area("Negative prompt")
|
||||
with st.expander("Classifier-free guidance", expanded=True):
|
||||
use_cfg = st.checkbox("Use classifier-free guidance", value=True)
|
||||
if use_cfg:
|
||||
cfg_scale = st.slider("Classifier-free guidance scale", min_value=1.0, max_value=10.0, step=0.1, value=7.5)
|
||||
if "negative_prompt" in fixed_parameters:
|
||||
negative_prompt = fixed_parameters["negative_prompt"]
|
||||
else:
|
||||
cfg_scale = 1.0
|
||||
with st.expander("Inference steps", expanded=True):
|
||||
num_inference_steps = st.slider("Inference steps", min_value=1, max_value=100, value=20, label_visibility="hidden")
|
||||
with st.expander("Image size", expanded=True):
|
||||
height = st.select_slider("Height", options=[256, 512, 768, 1024, 2048], value=512)
|
||||
width = st.select_slider("Width", options=[256, 512, 768, 1024, 2048], value=512)
|
||||
with st.expander("Seed", expanded=True):
|
||||
negative_prompt = st.text_area("Negative prompt")
|
||||
if "cfg_scale" in fixed_parameters:
|
||||
cfg_scale = fixed_parameters["cfg_scale"]
|
||||
else:
|
||||
cfg_scale = st.slider("Classifier-free guidance scale", min_value=1.0, max_value=10.0, value=7.5)
|
||||
with st.expander("Image", expanded=True):
|
||||
if "num_inference_steps" in fixed_parameters:
|
||||
num_inference_steps = fixed_parameters["num_inference_steps"]
|
||||
else:
|
||||
num_inference_steps = st.slider("Inference steps", min_value=1, max_value=100, value=20)
|
||||
if "height" in fixed_parameters:
|
||||
height = fixed_parameters["height"]
|
||||
else:
|
||||
height = st.select_slider("Height", options=[256, 512, 768, 1024, 2048], value=512)
|
||||
if "width" in fixed_parameters:
|
||||
width = fixed_parameters["width"]
|
||||
else:
|
||||
width = st.select_slider("Width", options=[256, 512, 768, 1024, 2048], value=512)
|
||||
num_images = st.number_input("Number of images", value=2)
|
||||
use_fixed_seed = st.checkbox("Use fixed seed", value=False)
|
||||
if use_fixed_seed:
|
||||
seed = st.number_input("Random seed", value=0, label_visibility="hidden")
|
||||
with st.expander("Number of images", expanded=True):
|
||||
num_images = st.number_input("Number of images", value=4, label_visibility="hidden")
|
||||
with st.expander("Tile (for high resolution)", expanded=True):
|
||||
tiled = st.checkbox("Use tile", value=False)
|
||||
tile_size = st.select_slider("Tile size", options=[64, 128], value=64)
|
||||
tile_stride = st.select_slider("Tile stride", options=[8, 16, 32, 64], value=32)
|
||||
seed = st.number_input("Random seed", min_value=0, max_value=10**9, step=1, value=0)
|
||||
|
||||
# Other fixed parameters
|
||||
denoising_strength = 1.0
|
||||
repetition = 1
|
||||
|
||||
|
||||
# Show input image
|
||||
@@ -180,7 +178,7 @@ with column_input:
|
||||
if upload_image is not None:
|
||||
st.session_state["input_image"] = Image.open(upload_image)
|
||||
elif create_white_board:
|
||||
st.session_state["input_image"] = Image.fromarray(np.ones((1024, 1024, 3), dtype=np.uint8) * 255)
|
||||
st.session_state["input_image"] = Image.fromarray(np.ones((height, width, 3), dtype=np.uint8) * 255)
|
||||
else:
|
||||
use_output_image_as_input()
|
||||
|
||||
@@ -202,6 +200,7 @@ with column_input:
|
||||
stroke_width = st.slider("Stroke width", min_value=1, max_value=50, value=10)
|
||||
with st.container(border=True):
|
||||
denoising_strength = st.slider("Denoising strength", min_value=0.0, max_value=1.0, value=0.7)
|
||||
repetition = st.slider("Repetition", min_value=1, max_value=8, value=1)
|
||||
with st.container(border=True):
|
||||
input_width, input_height = input_image.size
|
||||
canvas_result = st_canvas(
|
||||
@@ -225,34 +224,32 @@ with column_output:
|
||||
image_columns = st.columns(num_image_columns)
|
||||
|
||||
# Run
|
||||
if (run_button or auto_update) and model_path is not None:
|
||||
if (run_button or auto_update) and model_path != "None":
|
||||
|
||||
if not use_fixed_seed:
|
||||
torch.manual_seed(np.random.randint(0, 10**9))
|
||||
if input_image is not None:
|
||||
input_image = input_image.resize((width, height))
|
||||
if canvas_result.image_data is not None:
|
||||
input_image = apply_stroke_to_image(canvas_result.image_data, input_image)
|
||||
|
||||
output_images = []
|
||||
for image_id in range(num_images):
|
||||
for image_id in range(num_images * repetition):
|
||||
if use_fixed_seed:
|
||||
torch.manual_seed(seed + image_id)
|
||||
if input_image is not None:
|
||||
input_image = input_image.resize((width, height))
|
||||
if canvas_result.image_data is not None:
|
||||
input_image = apply_stroke_to_image(canvas_result.image_data, input_image)
|
||||
else:
|
||||
denoising_strength = 1.0
|
||||
torch.manual_seed(np.random.randint(0, 10**9))
|
||||
if image_id >= num_images:
|
||||
input_image = output_images[image_id - num_images]
|
||||
with image_columns[image_id % num_image_columns]:
|
||||
progress_bar = st.progress(0.0)
|
||||
progress_bar_st = st.progress(0.0)
|
||||
image = pipeline(
|
||||
model_manager, prompter,
|
||||
prompt, negative_prompt=negative_prompt, cfg_scale=cfg_scale,
|
||||
num_inference_steps=num_inference_steps,
|
||||
prompt, negative_prompt=negative_prompt,
|
||||
cfg_scale=cfg_scale, num_inference_steps=num_inference_steps,
|
||||
height=height, width=width,
|
||||
init_image=input_image, denoising_strength=denoising_strength,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
|
||||
progress_bar_st=progress_bar
|
||||
input_image=input_image, denoising_strength=denoising_strength,
|
||||
progress_bar_st=progress_bar_st
|
||||
)
|
||||
output_images.append(image)
|
||||
progress_bar.progress(1.0)
|
||||
progress_bar_st.progress(1.0)
|
||||
show_output_image(image)
|
||||
st.session_state["output_images"] = output_images
|
||||
|
||||
4
pages/2_Video_Creator.py
Normal file
4
pages/2_Video_Creator.py
Normal file
@@ -0,0 +1,4 @@
|
||||
import streamlit as st
|
||||
st.set_page_config(layout="wide")
|
||||
|
||||
st.markdown("# Coming soon")
|
||||
Reference in New Issue
Block a user