diff --git a/Diffsynth_Studio.py b/Diffsynth_Studio.py deleted file mode 100644 index dcdc205..0000000 --- a/Diffsynth_Studio.py +++ /dev/null @@ -1,264 +0,0 @@ -import torch, os, io -import numpy as np -from PIL import Image -import streamlit as st -st.set_page_config(layout="wide") -from streamlit_drawable_canvas import st_canvas -from diffsynth.models import ModelManager -from diffsynth.prompts import SDXLPrompter, SDPrompter -from diffsynth.pipelines import SDXLPipeline, SDPipeline - - -torch.cuda.set_per_process_memory_fraction(0.999, 0) - - -@st.cache_data -def load_model_list(folder): - file_list = os.listdir(folder) - file_list = [i for i in file_list if i.endswith(".safetensors")] - file_list = sorted(file_list) - return file_list - - -def detect_model_path(sd_model_path, sdxl_model_path): - if sd_model_path != "None": - model_path = os.path.join("models/stable_diffusion", sd_model_path) - elif sdxl_model_path != "None": - model_path = os.path.join("models/stable_diffusion_xl", sdxl_model_path) - else: - model_path = None - return model_path - - -def load_model(sd_model_path, sdxl_model_path): - if sd_model_path != "None": - model_path = os.path.join("models/stable_diffusion", sd_model_path) - model_manager = ModelManager() - model_manager.load_from_safetensors(model_path) - prompter = SDPrompter() - pipeline = SDPipeline() - elif sdxl_model_path != "None": - model_path = os.path.join("models/stable_diffusion_xl", sdxl_model_path) - model_manager = ModelManager() - model_manager.load_from_safetensors(model_path) - prompter = SDXLPrompter() - pipeline = SDXLPipeline() - else: - return None, None, None, None - return model_path, model_manager, prompter, pipeline - - -def release_model(): - if "model_manager" in st.session_state: - st.session_state["model_manager"].to("cpu") - del st.session_state["loaded_model_path"] - del st.session_state["model_manager"] - del st.session_state["prompter"] - del st.session_state["pipeline"] - torch.cuda.empty_cache() - - 
-def use_output_image_as_input(): - # Search for input image - output_image_id = 0 - selected_output_image = None - while True: - if f"use_output_as_input_{output_image_id}" not in st.session_state: - break - if st.session_state[f"use_output_as_input_{output_image_id}"]: - selected_output_image = st.session_state["output_images"][output_image_id] - break - output_image_id += 1 - if selected_output_image is not None: - st.session_state["input_image"] = selected_output_image - - -def apply_stroke_to_image(stroke_image, image): - image = np.array(image.convert("RGB")).astype(np.float32) - height, width, _ = image.shape - - stroke_image = np.array(Image.fromarray(stroke_image).resize((width, height))).astype(np.float32) - weight = stroke_image[:, :, -1:] / 255 - stroke_image = stroke_image[:, :, :-1] - - image = stroke_image * weight + image * (1 - weight) - image = np.clip(image, 0, 255).astype(np.uint8) - image = Image.fromarray(image) - return image - - -@st.cache_data -def image2bits(image): - image_byte = io.BytesIO() - image.save(image_byte, format="PNG") - image_byte = image_byte.getvalue() - return image_byte - - -def show_output_image(image): - st.image(image, use_column_width="always") - st.button("Use it as input image", key=f"use_output_as_input_{image_id}") - st.download_button("Download", data=image2bits(image), file_name="image.png", mime="image/png", key=f"download_output_{image_id}") - - -column_input, column_output = st.columns(2) - -# with column_input: -with st.sidebar: - # Select a model - with st.expander("Model", expanded=True): - sd_model_list = ["None"] + load_model_list("models/stable_diffusion") - sd_model_path = st.selectbox( - "Stable Diffusion", sd_model_list - ) - sdxl_model_list = ["None"] + load_model_list("models/stable_diffusion_xl") - sdxl_model_path = st.selectbox( - "Stable Diffusion XL", sdxl_model_list - ) - - # Load the model - model_path = detect_model_path(sd_model_path, sdxl_model_path) - if model_path is None: - 
st.markdown("No models selected.") - release_model() - elif st.session_state.get("loaded_model_path", "") != model_path: - st.markdown(f"Using model at {model_path}.") - release_model() - model_path, model_manager, prompter, pipeline = load_model(sd_model_path, sdxl_model_path) - st.session_state.loaded_model_path = model_path - st.session_state.model_manager = model_manager - st.session_state.prompter = prompter - st.session_state.pipeline = pipeline - else: - st.markdown(f"Using model at {model_path}.") - model_path, model_manager, prompter, pipeline = ( - st.session_state.loaded_model_path, - st.session_state.model_manager, - st.session_state.prompter, - st.session_state.pipeline, - ) - - # Show parameters - with st.expander("Prompt", expanded=True): - column_positive, column_negative = st.columns(2) - prompt = st.text_area("Positive prompt") - negative_prompt = st.text_area("Negative prompt") - with st.expander("Classifier-free guidance", expanded=True): - use_cfg = st.checkbox("Use classifier-free guidance", value=True) - if use_cfg: - cfg_scale = st.slider("Classifier-free guidance scale", min_value=1.0, max_value=10.0, step=0.1, value=7.5) - else: - cfg_scale = 1.0 - with st.expander("Inference steps", expanded=True): - num_inference_steps = st.slider("Inference steps", min_value=1, max_value=100, value=20, label_visibility="hidden") - with st.expander("Image size", expanded=True): - height = st.select_slider("Height", options=[256, 512, 768, 1024, 2048], value=512) - width = st.select_slider("Width", options=[256, 512, 768, 1024, 2048], value=512) - with st.expander("Seed", expanded=True): - use_fixed_seed = st.checkbox("Use fixed seed", value=False) - if use_fixed_seed: - seed = st.number_input("Random seed", value=0, label_visibility="hidden") - with st.expander("Number of images", expanded=True): - num_images = st.number_input("Number of images", value=4, label_visibility="hidden") - with st.expander("Tile (for high resolution)", expanded=True): - tiled 
= st.checkbox("Use tile", value=False) - tile_size = st.select_slider("Tile size", options=[64, 128], value=64) - tile_stride = st.select_slider("Tile stride", options=[8, 16, 32, 64], value=32) - - -# Show input image -with column_input: - with st.expander("Input image (Optional)", expanded=True): - with st.container(border=True): - column_white_board, column_upload_image = st.columns([1, 2]) - with column_white_board: - create_white_board = st.button("Create white board") - delete_input_image = st.button("Delete input image") - with column_upload_image: - upload_image = st.file_uploader("Upload image", type=["png", "jpg"], key="upload_image") - - if upload_image is not None: - st.session_state["input_image"] = Image.open(upload_image) - elif create_white_board: - st.session_state["input_image"] = Image.fromarray(np.ones((1024, 1024, 3), dtype=np.uint8) * 255) - else: - use_output_image_as_input() - - if delete_input_image and "input_image" in st.session_state: - del st.session_state.input_image - if delete_input_image and "upload_image" in st.session_state: - del st.session_state.upload_image - - input_image = st.session_state.get("input_image", None) - if input_image is not None: - with st.container(border=True): - column_drawing_mode, column_color_1, column_color_2 = st.columns([4, 1, 1]) - with column_drawing_mode: - drawing_mode = st.radio("Drawing tool", ["transform", "freedraw", "line", "rect"], horizontal=True, index=1) - with column_color_1: - stroke_color = st.color_picker("Stroke color") - with column_color_2: - fill_color = st.color_picker("Fill color") - stroke_width = st.slider("Stroke width", min_value=1, max_value=50, value=10) - with st.container(border=True): - denoising_strength = st.slider("Denoising strength", min_value=0.0, max_value=1.0, value=0.7) - with st.container(border=True): - input_width, input_height = input_image.size - canvas_result = st_canvas( - fill_color=fill_color, - stroke_width=stroke_width, - stroke_color=stroke_color, - 
background_color="rgba(255, 255, 255, 0)", - background_image=input_image, - update_streamlit=True, - height=int(512 / input_width * input_height), - width=512, - drawing_mode=drawing_mode, - key="canvas" - ) - - -with column_output: - run_button = st.button("Generate image", type="primary") - auto_update = st.checkbox("Auto update", value=False) - num_image_columns = st.slider("Columns", min_value=1, max_value=8, value=2) - image_columns = st.columns(num_image_columns) - - # Run - if (run_button or auto_update) and model_path is not None: - - if not use_fixed_seed: - torch.manual_seed(np.random.randint(0, 10**9)) - - output_images = [] - for image_id in range(num_images): - if use_fixed_seed: - torch.manual_seed(seed + image_id) - if input_image is not None: - input_image = input_image.resize((width, height)) - if canvas_result.image_data is not None: - input_image = apply_stroke_to_image(canvas_result.image_data, input_image) - else: - denoising_strength = 1.0 - with image_columns[image_id % num_image_columns]: - progress_bar = st.progress(0.0) - image = pipeline( - model_manager, prompter, - prompt, negative_prompt=negative_prompt, cfg_scale=cfg_scale, - num_inference_steps=num_inference_steps, - height=height, width=width, - init_image=input_image, denoising_strength=denoising_strength, - tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, - progress_bar_st=progress_bar - ) - output_images.append(image) - progress_bar.progress(1.0) - show_output_image(image) - st.session_state["output_images"] = output_images - - elif "output_images" in st.session_state: - for image_id in range(len(st.session_state.output_images)): - with image_columns[image_id % num_image_columns]: - image = st.session_state.output_images[image_id] - progress_bar = st.progress(1.0) - show_output_image(image) diff --git a/README-zh.md b/README-zh.md deleted file mode 100644 index 7b5ba10..0000000 --- a/README-zh.md +++ /dev/null @@ -1,53 +0,0 @@ -# DiffSynth Studio - -## 介绍 - -DiffSynth 
是一个全新的 Diffusion 引擎,我们重构了 Text Encoder、UNet、VAE 等架构,保持与开源社区模型兼容性的同时,提升了计算性能。目前这个版本仅仅是一个初始版本,实现了文生图和图生图功能,支持 SD 和 SDXL 架构。未来我们计划基于这个全新的代码库开发更多有趣的功能。 - -## 安装 - -如果你只想在 Python 代码层面调用 DiffSynth Studio,你只需要安装 `torch`(深度学习框架)和 `transformers`(仅用于实现分词器)。 - -``` -pip install torch transformers -``` - -如果你想使用 UI,还需要额外安装 `streamlit`(一个 webui 框架)和 `streamlit-drawable-canvas`(用于图生图画板)。 - -``` -pip install streamlit streamlit-drawable-canvas -``` - -## 使用 - -通过 Python 代码调用 - -```python -from diffsynth.models import ModelManager -from diffsynth.prompts import SDPrompter, SDXLPrompter -from diffsynth.pipelines import SDPipeline, SDXLPipeline - - -model_manager = ModelManager() -model_manager.load_from_safetensors("xxxxxxxx.safetensors") -prompter = SDPrompter() -pipe = SDPipeline() - -prompt = "a girl" -negative_prompt = "" - -image = pipe( - model_manager, prompter, - prompt, negative_prompt=negative_prompt, - num_inference_steps=20, height=512, width=512, -) -image.save("image.png") -``` - -如果需要用 SDXL 架构模型,请把 `SDPrompter`、`SDPipeline` 换成 `SDXLPrompter`, `SDXLPipeline`。 - -当然,你也可以使用我们提供的 UI,但请注意,我们的 UI 程序很简单,且未来可能会大幅改变。 - -``` -python -m streamlit run Diffsynth_Studio.py -``` diff --git a/README.md b/README.md index e5ab11b..5b167ea 100644 --- a/README.md +++ b/README.md @@ -1,53 +1,75 @@ # DiffSynth Studio -## 介绍 +## Introduction -DiffSynth is a new Diffusion engine. We have restructured architectures like Text Encoder, UNet, VAE, among others, maintaining compatibility with models from the open-source community while enhancing computational performance. This version is currently in its initial stage, supporting text-to-image and image-to-image functionalities, supporting SD and SDXL architectures. In the future, we plan to develop more interesting features based on this new codebase. +This branch supports video-to-video translation and is still under development. 
-## 安装 - -If you only want to use DiffSynth Studio at the Python code level, you just need to install torch (a deep learning framework) and transformers (only used for implementing a tokenizer). +## Installation ``` -pip install torch transformers +conda env create -f environment.yml ``` -If you wish to use the UI, you'll also need to additionally install `streamlit` (a web UI framework) and `streamlit-drawable-canvas` (used for the image-to-image canvas). +## Usage -``` -pip install streamlit streamlit-drawable-canvas -``` +### Example 1: Toon Shading -## 使用 +You can download the models as follows: -Use DiffSynth Studio in Python +* `models/stable_diffusion/flat2DAnimerge_v45Sharp.safetensors`: [link](https://civitai.com/api/download/models/266360?type=Model&format=SafeTensor&size=pruned&fp=fp16) +* `models/AnimateDiff/mm_sd_v15_v2.ckpt`: [link](https://huggingface.co/guoyww/animatediff/resolve/main/mm_sd_v15_v2.ckpt) +* `models/ControlNet/control_v11p_sd15_lineart.pth`: [link](https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_lineart.pth) +* `models/ControlNet/control_v11f1e_sd15_tile.pth`: [link](https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11f1e_sd15_tile.pth) +* `models/Annotators/sk_model.pth`: [link](https://huggingface.co/lllyasviel/Annotators/resolve/main/sk_model.pth) +* `models/Annotators/sk_model2.pth`: [link](https://huggingface.co/lllyasviel/Annotators/resolve/main/sk_model2.pth) ```python -from diffsynth.models import ModelManager -from diffsynth.prompts import SDPrompter, SDXLPrompter -from diffsynth.pipelines import SDPipeline, SDXLPipeline +from diffsynth import ModelManager, SDVideoPipeline, ControlNetConfigUnit, VideoData, save_video, save_frames +import torch -model_manager = ModelManager() -model_manager.load_from_safetensors("xxxxxxxx.safetensors") -prompter = SDPrompter() -pipe = SDPipeline() - -prompt = "a girl" -negative_prompt = "" - -image = pipe( - model_manager, prompter, - 
prompt, negative_prompt=negative_prompt, - num_inference_steps=20, height=512, width=512, +# Load models +model_manager = ModelManager(torch_dtype=torch.float16, device="cuda") +model_manager.load_textual_inversions("models/textual_inversion") +model_manager.load_models([ + "models/stable_diffusion/flat2DAnimerge_v45Sharp.safetensors", + "models/AnimateDiff/mm_sd_v15_v2.ckpt", + "models/ControlNet/control_v11p_sd15_lineart.pth", + "models/ControlNet/control_v11f1e_sd15_tile.pth", +]) +pipe = SDVideoPipeline.from_model_manager( + model_manager, + [ + ControlNetConfigUnit( + processor_id="lineart", + model_path="models/ControlNet/control_v11p_sd15_lineart.pth", + scale=1.0 + ), + ControlNetConfigUnit( + processor_id="tile", + model_path="models/ControlNet/control_v11f1e_sd15_tile.pth", + scale=0.5 + ), + ] ) -image.save("image.png") + +# Load video +video = VideoData(video_file="data/66dance/raw.mp4", height=1536, width=1536) +input_video = [video[i] for i in range(40*60, 40*60+16)] + +# Toon shading +torch.manual_seed(0) +output_video = pipe( + prompt="best quality, perfect anime illustration, light, a girl is dancing, smile, solo", + negative_prompt="verybadimagenegative_v1.3", + cfg_scale=5, clip_skip=2, + controlnet_frames=input_video, num_frames=16, + num_inference_steps=10, height=1536, width=1536, + vram_limit_level=0, +) + +# Save images and video +save_frames(output_video, "data/text2video/frames") +save_video(output_video, "data/text2video/video.mp4", fps=16) ``` -If you want to use SDXL architecture models, replace `SDPrompter` and `SDPipeline` with `SDXLPrompter` and `SDXLPipeline`, respectively. - -Of course, you can also use the UI we provide. The UI is simple but may be changed in the future. 
- -``` -python -m streamlit run Diffsynth_Studio.py -``` diff --git a/diffsynth/__init__.py b/diffsynth/__init__.py new file mode 100644 index 0000000..0f23517 --- /dev/null +++ b/diffsynth/__init__.py @@ -0,0 +1,6 @@ +from .data import * +from .models import * +from .prompts import * +from .schedulers import * +from .pipelines import * +from .controlnets import * diff --git a/diffsynth/controlnets/__init__.py b/diffsynth/controlnets/__init__.py new file mode 100644 index 0000000..b08ba4c --- /dev/null +++ b/diffsynth/controlnets/__init__.py @@ -0,0 +1,2 @@ +from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager +from .processors import Annotator diff --git a/diffsynth/controlnets/controlnet_unit.py b/diffsynth/controlnets/controlnet_unit.py new file mode 100644 index 0000000..d279bd7 --- /dev/null +++ b/diffsynth/controlnets/controlnet_unit.py @@ -0,0 +1,48 @@ +import torch +import numpy as np +from .processors import Processor_id + + +class ControlNetConfigUnit: + def __init__(self, processor_id: Processor_id, model_path, scale=1.0): + self.processor_id = processor_id + self.model_path = model_path + self.scale = scale + + +class ControlNetUnit: + def __init__(self, processor, model, scale=1.0): + self.processor = processor + self.model = model + self.scale = scale + + +class MultiControlNetManager: + def __init__(self, controlnet_units=[]): + self.processors = [unit.processor for unit in controlnet_units] + self.models = [unit.model for unit in controlnet_units] + self.scales = [unit.scale for unit in controlnet_units] + + def process_image(self, image, return_image=False): + processed_image = [ + processor(image) + for processor in self.processors + ] + if return_image: + return processed_image + processed_image = torch.concat([ + torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0) + for image_ in processed_image + ], dim=0) + return processed_image + + def __call__(self, sample, timestep, 
encoder_hidden_states, conditionings): + res_stack = None + for conditioning, model, scale in zip(conditionings, self.models, self.scales): + res_stack_ = model(sample, timestep, encoder_hidden_states, conditioning) + res_stack_ = [res * scale for res in res_stack_] + if res_stack is None: + res_stack = res_stack_ + else: + res_stack = [i + j for i, j in zip(res_stack, res_stack_)] + return res_stack diff --git a/diffsynth/controlnets/processors.py b/diffsynth/controlnets/processors.py new file mode 100644 index 0000000..36bceb3 --- /dev/null +++ b/diffsynth/controlnets/processors.py @@ -0,0 +1,50 @@ +from typing_extensions import Literal, TypeAlias +import warnings +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + from controlnet_aux.processor import ( + CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector + ) + + +Processor_id: TypeAlias = Literal[ + "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile" +] + +class Annotator: + def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=512): + if processor_id == "canny": + self.processor = CannyDetector() + elif processor_id == "depth": + self.processor = MidasDetector.from_pretrained(model_path) + elif processor_id == "softedge": + self.processor = HEDdetector.from_pretrained(model_path) + elif processor_id == "lineart": + self.processor = LineartDetector.from_pretrained(model_path) + elif processor_id == "lineart_anime": + self.processor = LineartAnimeDetector.from_pretrained(model_path) + elif processor_id == "openpose": + self.processor = OpenposeDetector.from_pretrained(model_path) + elif processor_id == "tile": + self.processor = None + else: + raise ValueError(f"Unsupported processor_id: {processor_id}") + + self.processor_id = processor_id + self.detect_resolution = detect_resolution + + def __call__(self, image): + width, height = image.size + if self.processor_id == "openpose": + 
kwargs = { + "include_body": True, + "include_hand": True, + "include_face": True + } + else: + kwargs = {} + if self.processor is not None: + image = self.processor(image, detect_resolution=self.detect_resolution, image_resolution=min(width, height), **kwargs) + image = image.resize((width, height)) + return image + diff --git a/diffsynth/data/__init__.py b/diffsynth/data/__init__.py new file mode 100644 index 0000000..de09a29 --- /dev/null +++ b/diffsynth/data/__init__.py @@ -0,0 +1 @@ +from .video import VideoData, save_video, save_frames diff --git a/diffsynth/data/video.py b/diffsynth/data/video.py new file mode 100644 index 0000000..16e1918 --- /dev/null +++ b/diffsynth/data/video.py @@ -0,0 +1,148 @@ +import imageio, os +import numpy as np +from PIL import Image +from tqdm import tqdm + + +class LowMemoryVideo: + def __init__(self, file_name): + self.reader = imageio.get_reader(file_name) + + def __len__(self): + return self.reader.count_frames() + + def __getitem__(self, item): + return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB") + + def __del__(self): + self.reader.close() + + +def split_file_name(file_name): + result = [] + number = -1 + for i in file_name: + if ord(i)>=ord("0") and ord(i)<=ord("9"): + if number == -1: + number = 0 + number = number*10 + ord(i) - ord("0") + else: + if number != -1: + result.append(number) + number = -1 + result.append(i) + if number != -1: + result.append(number) + result = tuple(result) + return result + + +def search_for_images(folder): + file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")] + file_list = [(split_file_name(file_name), file_name) for file_name in file_list] + file_list = [i[1] for i in sorted(file_list)] + file_list = [os.path.join(folder, i) for i in file_list] + return file_list + + +class LowMemoryImageFolder: + def __init__(self, folder, file_list=None): + if file_list is None: + self.file_list = search_for_images(folder) + else: + 
self.file_list = [os.path.join(folder, file_name) for file_name in file_list] + + def __len__(self): + return len(self.file_list) + + def __getitem__(self, item): + return Image.open(self.file_list[item]).convert("RGB") + + def __del__(self): + pass + + +def crop_and_resize(image, height, width): + image = np.array(image) + image_height, image_width, _ = image.shape + if image_height / image_width < height / width: + croped_width = int(image_height / height * width) + left = (image_width - croped_width) // 2 + image = image[:, left: left+croped_width] + image = Image.fromarray(image).resize((width, height)) + else: + croped_height = int(image_width / width * height) + left = (image_height - croped_height) // 2 + image = image[left: left+croped_height, :] + image = Image.fromarray(image).resize((width, height)) + return image + + +class VideoData: + def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs): + if video_file is not None: + self.data_type = "video" + self.data = LowMemoryVideo(video_file, **kwargs) + elif image_folder is not None: + self.data_type = "images" + self.data = LowMemoryImageFolder(image_folder, **kwargs) + else: + raise ValueError("Cannot open video or image folder") + self.length = None + self.set_shape(height, width) + + def raw_data(self): + frames = [] + for i in range(self.__len__()): + frames.append(self.__getitem__(i)) + return frames + + def set_length(self, length): + self.length = length + + def set_shape(self, height, width): + self.height = height + self.width = width + + def __len__(self): + if self.length is None: + return len(self.data) + else: + return self.length + + def shape(self): + if self.height is not None and self.width is not None: + return self.height, self.width + else: + height, width, _ = self.__getitem__(0).shape + return height, width + + def __getitem__(self, item): + frame = self.data.__getitem__(item) + width, height = frame.size + if self.height is not None and self.width is 
not None: + if self.height != height or self.width != width: + frame = crop_and_resize(frame, self.height, self.width) + return frame + + def __del__(self): + pass + + def save_images(self, folder): + os.makedirs(folder, exist_ok=True) + for i in tqdm(range(self.__len__()), desc="Saving images"): + frame = self.__getitem__(i) + frame.save(os.path.join(folder, f"{i}.png")) + + +def save_video(frames, save_path, fps, quality=9): + writer = imageio.get_writer(save_path, fps=fps, quality=quality) + for frame in tqdm(frames, desc="Saving video"): + frame = np.array(frame) + writer.append_data(frame) + writer.close() + +def save_frames(frames, save_path): + os.makedirs(save_path, exist_ok=True) + for i, frame in enumerate(tqdm(frames, desc="Saving images")): + frame.save(os.path.join(save_path, f"{i}.png")) diff --git a/diffsynth/models/__init__.py b/diffsynth/models/__init__.py index a169ccd..41fecb5 100644 --- a/diffsynth/models/__init__.py +++ b/diffsynth/models/__init__.py @@ -1,4 +1,4 @@ -import torch +import torch, os from safetensors import safe_open from .sd_text_encoder import SDTextEncoder @@ -11,12 +11,18 @@ from .sdxl_unet import SDXLUNet from .sdxl_vae_decoder import SDXLVAEDecoder from .sdxl_vae_encoder import SDXLVAEEncoder +from .sd_controlnet import SDControlNet + +from .sd_motion import SDMotionModel + class ModelManager: - def __init__(self, torch_type=torch.float16, device="cuda"): - self.torch_type = torch_type + def __init__(self, torch_dtype=torch.float16, device="cuda"): + self.torch_dtype = torch_dtype self.device = device self.model = {} + self.model_path = {} + self.textual_inversion_dict = {} def is_stabe_diffusion_xl(self, state_dict): param_name = "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight" @@ -25,7 +31,15 @@ class ModelManager: def is_stable_diffusion(self, state_dict): return True - def load_stable_diffusion(self, state_dict, components=None): + def is_controlnet(self, state_dict): + param_name = 
"control_model.time_embed.0.weight" + return param_name in state_dict + + def is_animatediff(self, state_dict): + param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight" + return param_name in state_dict + + def load_stable_diffusion(self, state_dict, components=None, file_path=""): component_dict = { "text_encoder": SDTextEncoder, "unet": SDUNet, @@ -36,11 +50,24 @@ class ModelManager: if components is None: components = ["text_encoder", "unet", "vae_decoder", "vae_encoder"] for component in components: - self.model[component] = component_dict[component]() - self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict)) - self.model[component].to(self.torch_type).to(self.device) + if component == "text_encoder": + # Add additional token embeddings to text encoder + token_embeddings = [state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"]] + for keyword in self.textual_inversion_dict: + _, embeddings = self.textual_inversion_dict[keyword] + token_embeddings.append(embeddings.to(dtype=token_embeddings[0].dtype)) + token_embeddings = torch.concat(token_embeddings, dim=0) + state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"] = token_embeddings + self.model[component] = component_dict[component](vocab_size=token_embeddings.shape[0]) + self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict)) + self.model[component].to(self.torch_dtype).to(self.device) + else: + self.model[component] = component_dict[component]() + self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict)) + self.model[component].to(self.torch_dtype).to(self.device) + self.model_path[component] = file_path - def load_stable_diffusion_xl(self, state_dict, components=None): + def load_stable_diffusion_xl(self, state_dict, components=None, file_path=""): component_dict = 
{ "text_encoder": SDXLTextEncoder, "text_encoder_2": SDXLTextEncoder2, @@ -60,18 +87,86 @@ class ModelManager: # I do not know how to solve this problem. self.model[component].to(torch.float32).to(self.device) else: - self.model[component].to(self.torch_type).to(self.device) - - def load_from_safetensors(self, file_path, components=None): - state_dict = load_state_dict_from_safetensors(file_path) - if self.is_stabe_diffusion_xl(state_dict): - self.load_stable_diffusion_xl(state_dict, components=components) - elif self.is_stable_diffusion(state_dict): - self.load_stable_diffusion(state_dict, components=components) + self.model[component].to(self.torch_dtype).to(self.device) + self.model_path[component] = file_path + def load_controlnet(self, state_dict, file_path=""): + component = "controlnet" + if component not in self.model: + self.model[component] = [] + self.model_path[component] = [] + model = SDControlNet() + model.load_state_dict(model.state_dict_converter().from_civitai(state_dict)) + model.to(self.torch_dtype).to(self.device) + self.model[component].append(model) + self.model_path[component].append(file_path) + + def load_animatediff(self, state_dict, file_path=""): + component = "motion_modules" + model = SDMotionModel() + model.load_state_dict(model.state_dict_converter().from_civitai(state_dict)) + model.to(self.torch_dtype).to(self.device) + self.model[component] = model + self.model_path[component] = file_path + + def search_for_embeddings(self, state_dict): + embeddings = [] + for k in state_dict: + if isinstance(state_dict[k], torch.Tensor): + embeddings.append(state_dict[k]) + elif isinstance(state_dict[k], dict): + embeddings += self.search_for_embeddings(state_dict[k]) + return embeddings + + def load_textual_inversions(self, folder): + # Store additional tokens here + self.textual_inversion_dict = {} + + # Load every textual inversion file + for file_name in os.listdir(folder): + keyword = os.path.splitext(file_name)[0] + state_dict = 
load_state_dict(os.path.join(folder, file_name)) + + # Search for embeddings + for embeddings in self.search_for_embeddings(state_dict): + if len(embeddings.shape) == 2 and embeddings.shape[1] == 768: + tokens = [f"{keyword}_{i}" for i in range(embeddings.shape[0])] + self.textual_inversion_dict[keyword] = (tokens, embeddings) + break + + def load_model(self, file_path, components=None): + state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype) + if self.is_animatediff(state_dict): + self.load_animatediff(state_dict, file_path=file_path) + elif self.is_controlnet(state_dict): + self.load_controlnet(state_dict, file_path=file_path) + elif self.is_stabe_diffusion_xl(state_dict): + self.load_stable_diffusion_xl(state_dict, components=components, file_path=file_path) + elif self.is_stable_diffusion(state_dict): + self.load_stable_diffusion(state_dict, components=components, file_path=file_path) + + def load_models(self, file_path_list): + for file_path in file_path_list: + self.load_model(file_path) + def to(self, device): for component in self.model: - self.model[component].to(device) + if isinstance(self.model[component], list): + for model in self.model[component]: + model.to(device) + else: + self.model[component].to(device) + + def get_model_with_model_path(self, model_path): + for component in self.model_path: + if isinstance(self.model_path[component], str): + if os.path.samefile(self.model_path[component], model_path): + return self.model[component] + elif isinstance(self.model_path[component], list): + for i, model_path_ in enumerate(self.model_path[component]): + if os.path.samefile(model_path_, model_path): + return self.model[component][i] + raise ValueError(f"Please load model {model_path} before you use it.") def __getattr__(self, __name): if __name in self.model: @@ -80,16 +175,28 @@ class ModelManager: return super.__getattribute__(__name) -def load_state_dict_from_safetensors(file_path): +def load_state_dict(file_path, torch_dtype=None): 
+ if file_path.endswith(".safetensors"): + return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype) + else: + return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype) + + +def load_state_dict_from_safetensors(file_path, torch_dtype=None): state_dict = {} with safe_open(file_path, framework="pt", device="cpu") as f: for k in f.keys(): state_dict[k] = f.get_tensor(k) + if torch_dtype is not None: + state_dict[k] = state_dict[k].to(torch_dtype) return state_dict -def load_state_dict_from_bin(file_path): - return torch.load(file_path, map_location="cpu") +def load_state_dict_from_bin(file_path, torch_dtype=None): + state_dict = torch.load(file_path, map_location="cpu") + if torch_dtype is not None: + state_dict = {i: state_dict[i].to(torch_dtype) for i in state_dict} + return state_dict def search_parameter(param, state_dict): diff --git a/diffsynth/models/attention.py b/diffsynth/models/attention.py index e1dbd78..1a2c110 100644 --- a/diffsynth/models/attention.py +++ b/diffsynth/models/attention.py @@ -1,4 +1,15 @@ import torch +from einops import rearrange + + +def low_version_attention(query, key, value, attn_bias=None): + scale = 1 / query.shape[-1] ** 0.5 + query = query * scale + attn = torch.matmul(query, key.transpose(-2, -1)) + if attn_bias is not None: + attn = attn + attn_bias + attn = attn.softmax(-1) + return attn @ value class Attention(torch.nn.Module): @@ -15,7 +26,7 @@ class Attention(torch.nn.Module): self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out) - def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None): + def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None): if encoder_hidden_states is None: encoder_hidden_states = hidden_states @@ -36,3 +47,30 @@ class Attention(torch.nn.Module): hidden_states = self.to_out(hidden_states) return hidden_states + + def xformers_forward(self, hidden_states, 
encoder_hidden_states=None, attn_mask=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads) + k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads) + v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads) + + if attn_mask is not None: + hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask) + else: + import xformers.ops as xops + hidden_states = xops.memory_efficient_attention(q, k, v) + hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads) + + hidden_states = hidden_states.to(q.dtype) + hidden_states = self.to_out(hidden_states) + + return hidden_states + + def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None): + return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask) \ No newline at end of file diff --git a/diffsynth/models/sd_controlnet.py b/diffsynth/models/sd_controlnet.py new file mode 100644 index 0000000..f43a6de --- /dev/null +++ b/diffsynth/models/sd_controlnet.py @@ -0,0 +1,566 @@ +import torch +from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, PopBlock, DownSampler, UpSampler + + +class ControlNetConditioningLayer(torch.nn.Module): + def __init__(self, channels = (3, 16, 32, 96, 256, 320)): + super().__init__() + self.blocks = torch.nn.ModuleList([]) + self.blocks.append(torch.nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1)) + self.blocks.append(torch.nn.SiLU()) + for i in range(1, len(channels) - 2): + self.blocks.append(torch.nn.Conv2d(channels[i], channels[i], kernel_size=3, padding=1)) + self.blocks.append(torch.nn.SiLU()) + self.blocks.append(torch.nn.Conv2d(channels[i], channels[i+1], kernel_size=3, padding=1, stride=2)) + self.blocks.append(torch.nn.SiLU()) + 
self.blocks.append(torch.nn.Conv2d(channels[-2], channels[-1], kernel_size=3, padding=1)) + + def forward(self, conditioning): + for block in self.blocks: + conditioning = block(conditioning) + return conditioning + + +class SDControlNet(torch.nn.Module): + def __init__(self, global_pool=False): + super().__init__() + self.time_proj = Timesteps(320) + self.time_embedding = torch.nn.Sequential( + torch.nn.Linear(320, 1280), + torch.nn.SiLU(), + torch.nn.Linear(1280, 1280) + ) + self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1) + + self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320)) + + self.blocks = torch.nn.ModuleList([ + # CrossAttnDownBlock2D + ResnetBlock(320, 320, 1280), + AttentionBlock(8, 40, 320, 1, 768), + PushBlock(), + ResnetBlock(320, 320, 1280), + AttentionBlock(8, 40, 320, 1, 768), + PushBlock(), + DownSampler(320), + PushBlock(), + # CrossAttnDownBlock2D + ResnetBlock(320, 640, 1280), + AttentionBlock(8, 80, 640, 1, 768), + PushBlock(), + ResnetBlock(640, 640, 1280), + AttentionBlock(8, 80, 640, 1, 768), + PushBlock(), + DownSampler(640), + PushBlock(), + # CrossAttnDownBlock2D + ResnetBlock(640, 1280, 1280), + AttentionBlock(8, 160, 1280, 1, 768), + PushBlock(), + ResnetBlock(1280, 1280, 1280), + AttentionBlock(8, 160, 1280, 1, 768), + PushBlock(), + DownSampler(1280), + PushBlock(), + # DownBlock2D + ResnetBlock(1280, 1280, 1280), + PushBlock(), + ResnetBlock(1280, 1280, 1280), + PushBlock(), + # UNetMidBlock2DCrossAttn + ResnetBlock(1280, 1280, 1280), + AttentionBlock(8, 160, 1280, 1, 768), + ResnetBlock(1280, 1280, 1280), + PushBlock() + ]) + + self.controlnet_blocks = torch.nn.ModuleList([ + torch.nn.Conv2d(320, 320, kernel_size=(1, 1)), + torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False), + torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False), + torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False), + torch.nn.Conv2d(640, 640, kernel_size=(1, 1)), + torch.nn.Conv2d(640, 640, 
kernel_size=(1, 1), bias=False), + torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False), + torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)), + torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False), + torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False), + torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False), + torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False), + torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False), + ]) + + self.global_pool = global_pool + + def forward(self, sample, timestep, encoder_hidden_states, conditioning): + # 1. time + time_emb = self.time_proj(timestep[None]).to(sample.dtype) + time_emb = self.time_embedding(time_emb) + time_emb = time_emb.repeat(sample.shape[0], 1) + + # 2. pre-process + hidden_states = self.conv_in(sample) + self.controlnet_conv_in(conditioning) + text_emb = encoder_hidden_states + res_stack = [hidden_states] + + # 3. blocks + for i, block in enumerate(self.blocks): + hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) + + # 4. 
ControlNet blocks + controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)] + + # pool + if self.global_pool: + controlnet_res_stack = [res.mean(dim=(2, 3), keepdim=True) for res in controlnet_res_stack] + + return controlnet_res_stack + + def state_dict_converter(self): + return SDControlNetStateDictConverter() + + +class SDControlNetStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + # architecture + block_types = [ + 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock', + 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock', + 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock', + 'ResnetBlock', 'PushBlock', 'ResnetBlock', 'PushBlock', + 'ResnetBlock', 'AttentionBlock', 'ResnetBlock', + 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'UpSampler', + 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler', + 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler', + 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock' + ] + + # controlnet_rename_dict + controlnet_rename_dict = { + "controlnet_cond_embedding.conv_in.weight": "controlnet_conv_in.blocks.0.weight", + "controlnet_cond_embedding.conv_in.bias": "controlnet_conv_in.blocks.0.bias", + "controlnet_cond_embedding.blocks.0.weight": "controlnet_conv_in.blocks.2.weight", + "controlnet_cond_embedding.blocks.0.bias": "controlnet_conv_in.blocks.2.bias", + "controlnet_cond_embedding.blocks.1.weight": "controlnet_conv_in.blocks.4.weight", + 
"controlnet_cond_embedding.blocks.1.bias": "controlnet_conv_in.blocks.4.bias", + "controlnet_cond_embedding.blocks.2.weight": "controlnet_conv_in.blocks.6.weight", + "controlnet_cond_embedding.blocks.2.bias": "controlnet_conv_in.blocks.6.bias", + "controlnet_cond_embedding.blocks.3.weight": "controlnet_conv_in.blocks.8.weight", + "controlnet_cond_embedding.blocks.3.bias": "controlnet_conv_in.blocks.8.bias", + "controlnet_cond_embedding.blocks.4.weight": "controlnet_conv_in.blocks.10.weight", + "controlnet_cond_embedding.blocks.4.bias": "controlnet_conv_in.blocks.10.bias", + "controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight", + "controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias", + "controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight", + "controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias", + } + + # Rename each parameter + name_list = sorted([name for name in state_dict]) + rename_dict = {} + block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1} + last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""} + for name in name_list: + names = name.split(".") + if names[0] in ["conv_in", "conv_norm_out", "conv_out"]: + pass + elif name in controlnet_rename_dict: + names = controlnet_rename_dict[name].split(".") + elif names[0] == "controlnet_down_blocks": + names[0] = "controlnet_blocks" + elif names[0] == "controlnet_mid_block": + names = ["controlnet_blocks", "12", names[-1]] + elif names[0] in ["time_embedding", "add_embedding"]: + if names[0] == "add_embedding": + names[0] = "add_time_embedding" + names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]] + elif names[0] in ["down_blocks", "mid_block", "up_blocks"]: + if names[0] == "mid_block": + names.insert(1, "0") + block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", 
"upsamplers": "UpSampler"}[names[2]] + block_type_with_id = ".".join(names[:4]) + if block_type_with_id != last_block_type_with_id[block_type]: + block_id[block_type] += 1 + last_block_type_with_id[block_type] = block_type_with_id + while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type: + block_id[block_type] += 1 + block_type_with_id = ".".join(names[:4]) + names = ["blocks", str(block_id[block_type])] + names[4:] + if "ff" in names: + ff_index = names.index("ff") + component = ".".join(names[ff_index:ff_index+3]) + component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component] + names = names[:ff_index] + [component] + names[ff_index+3:] + if "to_out" in names: + names.pop(names.index("to_out") + 1) + else: + raise ValueError(f"Unknown parameters: {name}") + rename_dict[name] = ".".join(names) + + # Convert state_dict + state_dict_ = {} + for name, param in state_dict.items(): + if ".proj_in." in name or ".proj_out." in name: + param = param.squeeze() + if rename_dict[name] in [ + "controlnet_blocks.1.bias", "controlnet_blocks.2.bias", "controlnet_blocks.3.bias", "controlnet_blocks.5.bias", "controlnet_blocks.6.bias", + "controlnet_blocks.8.bias", "controlnet_blocks.9.bias", "controlnet_blocks.10.bias", "controlnet_blocks.11.bias", "controlnet_blocks.12.bias" + ]: + continue + state_dict_[rename_dict[name]] = param + return state_dict_ + + def from_civitai(self, state_dict): + rename_dict = { + "control_model.time_embed.0.weight": "time_embedding.0.weight", + "control_model.time_embed.0.bias": "time_embedding.0.bias", + "control_model.time_embed.2.weight": "time_embedding.2.weight", + "control_model.time_embed.2.bias": "time_embedding.2.bias", + "control_model.input_blocks.0.0.weight": "conv_in.weight", + "control_model.input_blocks.0.0.bias": "conv_in.bias", + "control_model.input_blocks.1.0.in_layers.0.weight": "blocks.0.norm1.weight", + "control_model.input_blocks.1.0.in_layers.0.bias": "blocks.0.norm1.bias", + 
"control_model.input_blocks.1.0.in_layers.2.weight": "blocks.0.conv1.weight", + "control_model.input_blocks.1.0.in_layers.2.bias": "blocks.0.conv1.bias", + "control_model.input_blocks.1.0.emb_layers.1.weight": "blocks.0.time_emb_proj.weight", + "control_model.input_blocks.1.0.emb_layers.1.bias": "blocks.0.time_emb_proj.bias", + "control_model.input_blocks.1.0.out_layers.0.weight": "blocks.0.norm2.weight", + "control_model.input_blocks.1.0.out_layers.0.bias": "blocks.0.norm2.bias", + "control_model.input_blocks.1.0.out_layers.3.weight": "blocks.0.conv2.weight", + "control_model.input_blocks.1.0.out_layers.3.bias": "blocks.0.conv2.bias", + "control_model.input_blocks.1.1.norm.weight": "blocks.1.norm.weight", + "control_model.input_blocks.1.1.norm.bias": "blocks.1.norm.bias", + "control_model.input_blocks.1.1.proj_in.weight": "blocks.1.proj_in.weight", + "control_model.input_blocks.1.1.proj_in.bias": "blocks.1.proj_in.bias", + "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "blocks.1.transformer_blocks.0.attn1.to_q.weight", + "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "blocks.1.transformer_blocks.0.attn1.to_k.weight", + "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "blocks.1.transformer_blocks.0.attn1.to_v.weight", + "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.1.transformer_blocks.0.attn1.to_out.weight", + "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.1.transformer_blocks.0.attn1.to_out.bias", + "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.1.transformer_blocks.0.act_fn.proj.weight", + "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.1.transformer_blocks.0.act_fn.proj.bias", + "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "blocks.1.transformer_blocks.0.ff.weight", + 
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "blocks.1.transformer_blocks.0.ff.bias", + "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "blocks.1.transformer_blocks.0.attn2.to_q.weight", + "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "blocks.1.transformer_blocks.0.attn2.to_k.weight", + "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "blocks.1.transformer_blocks.0.attn2.to_v.weight", + "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.1.transformer_blocks.0.attn2.to_out.weight", + "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.1.transformer_blocks.0.attn2.to_out.bias", + "control_model.input_blocks.1.1.transformer_blocks.0.norm1.weight": "blocks.1.transformer_blocks.0.norm1.weight", + "control_model.input_blocks.1.1.transformer_blocks.0.norm1.bias": "blocks.1.transformer_blocks.0.norm1.bias", + "control_model.input_blocks.1.1.transformer_blocks.0.norm2.weight": "blocks.1.transformer_blocks.0.norm2.weight", + "control_model.input_blocks.1.1.transformer_blocks.0.norm2.bias": "blocks.1.transformer_blocks.0.norm2.bias", + "control_model.input_blocks.1.1.transformer_blocks.0.norm3.weight": "blocks.1.transformer_blocks.0.norm3.weight", + "control_model.input_blocks.1.1.transformer_blocks.0.norm3.bias": "blocks.1.transformer_blocks.0.norm3.bias", + "control_model.input_blocks.1.1.proj_out.weight": "blocks.1.proj_out.weight", + "control_model.input_blocks.1.1.proj_out.bias": "blocks.1.proj_out.bias", + "control_model.input_blocks.2.0.in_layers.0.weight": "blocks.3.norm1.weight", + "control_model.input_blocks.2.0.in_layers.0.bias": "blocks.3.norm1.bias", + "control_model.input_blocks.2.0.in_layers.2.weight": "blocks.3.conv1.weight", + "control_model.input_blocks.2.0.in_layers.2.bias": "blocks.3.conv1.bias", + "control_model.input_blocks.2.0.emb_layers.1.weight": "blocks.3.time_emb_proj.weight", + 
"control_model.input_blocks.2.0.emb_layers.1.bias": "blocks.3.time_emb_proj.bias", + "control_model.input_blocks.2.0.out_layers.0.weight": "blocks.3.norm2.weight", + "control_model.input_blocks.2.0.out_layers.0.bias": "blocks.3.norm2.bias", + "control_model.input_blocks.2.0.out_layers.3.weight": "blocks.3.conv2.weight", + "control_model.input_blocks.2.0.out_layers.3.bias": "blocks.3.conv2.bias", + "control_model.input_blocks.2.1.norm.weight": "blocks.4.norm.weight", + "control_model.input_blocks.2.1.norm.bias": "blocks.4.norm.bias", + "control_model.input_blocks.2.1.proj_in.weight": "blocks.4.proj_in.weight", + "control_model.input_blocks.2.1.proj_in.bias": "blocks.4.proj_in.bias", + "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "blocks.4.transformer_blocks.0.attn1.to_q.weight", + "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "blocks.4.transformer_blocks.0.attn1.to_k.weight", + "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "blocks.4.transformer_blocks.0.attn1.to_v.weight", + "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.4.transformer_blocks.0.attn1.to_out.weight", + "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.4.transformer_blocks.0.attn1.to_out.bias", + "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.4.transformer_blocks.0.act_fn.proj.weight", + "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.4.transformer_blocks.0.act_fn.proj.bias", + "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "blocks.4.transformer_blocks.0.ff.weight", + "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "blocks.4.transformer_blocks.0.ff.bias", + "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "blocks.4.transformer_blocks.0.attn2.to_q.weight", + 
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "blocks.4.transformer_blocks.0.attn2.to_k.weight", + "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "blocks.4.transformer_blocks.0.attn2.to_v.weight", + "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.4.transformer_blocks.0.attn2.to_out.weight", + "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.4.transformer_blocks.0.attn2.to_out.bias", + "control_model.input_blocks.2.1.transformer_blocks.0.norm1.weight": "blocks.4.transformer_blocks.0.norm1.weight", + "control_model.input_blocks.2.1.transformer_blocks.0.norm1.bias": "blocks.4.transformer_blocks.0.norm1.bias", + "control_model.input_blocks.2.1.transformer_blocks.0.norm2.weight": "blocks.4.transformer_blocks.0.norm2.weight", + "control_model.input_blocks.2.1.transformer_blocks.0.norm2.bias": "blocks.4.transformer_blocks.0.norm2.bias", + "control_model.input_blocks.2.1.transformer_blocks.0.norm3.weight": "blocks.4.transformer_blocks.0.norm3.weight", + "control_model.input_blocks.2.1.transformer_blocks.0.norm3.bias": "blocks.4.transformer_blocks.0.norm3.bias", + "control_model.input_blocks.2.1.proj_out.weight": "blocks.4.proj_out.weight", + "control_model.input_blocks.2.1.proj_out.bias": "blocks.4.proj_out.bias", + "control_model.input_blocks.3.0.op.weight": "blocks.6.conv.weight", + "control_model.input_blocks.3.0.op.bias": "blocks.6.conv.bias", + "control_model.input_blocks.4.0.in_layers.0.weight": "blocks.8.norm1.weight", + "control_model.input_blocks.4.0.in_layers.0.bias": "blocks.8.norm1.bias", + "control_model.input_blocks.4.0.in_layers.2.weight": "blocks.8.conv1.weight", + "control_model.input_blocks.4.0.in_layers.2.bias": "blocks.8.conv1.bias", + "control_model.input_blocks.4.0.emb_layers.1.weight": "blocks.8.time_emb_proj.weight", + "control_model.input_blocks.4.0.emb_layers.1.bias": "blocks.8.time_emb_proj.bias", + 
"control_model.input_blocks.4.0.out_layers.0.weight": "blocks.8.norm2.weight", + "control_model.input_blocks.4.0.out_layers.0.bias": "blocks.8.norm2.bias", + "control_model.input_blocks.4.0.out_layers.3.weight": "blocks.8.conv2.weight", + "control_model.input_blocks.4.0.out_layers.3.bias": "blocks.8.conv2.bias", + "control_model.input_blocks.4.0.skip_connection.weight": "blocks.8.conv_shortcut.weight", + "control_model.input_blocks.4.0.skip_connection.bias": "blocks.8.conv_shortcut.bias", + "control_model.input_blocks.4.1.norm.weight": "blocks.9.norm.weight", + "control_model.input_blocks.4.1.norm.bias": "blocks.9.norm.bias", + "control_model.input_blocks.4.1.proj_in.weight": "blocks.9.proj_in.weight", + "control_model.input_blocks.4.1.proj_in.bias": "blocks.9.proj_in.bias", + "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "blocks.9.transformer_blocks.0.attn1.to_q.weight", + "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "blocks.9.transformer_blocks.0.attn1.to_k.weight", + "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "blocks.9.transformer_blocks.0.attn1.to_v.weight", + "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.9.transformer_blocks.0.attn1.to_out.weight", + "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.9.transformer_blocks.0.attn1.to_out.bias", + "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.9.transformer_blocks.0.act_fn.proj.weight", + "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.9.transformer_blocks.0.act_fn.proj.bias", + "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "blocks.9.transformer_blocks.0.ff.weight", + "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "blocks.9.transformer_blocks.0.ff.bias", + "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": 
"blocks.9.transformer_blocks.0.attn2.to_q.weight", + "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "blocks.9.transformer_blocks.0.attn2.to_k.weight", + "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "blocks.9.transformer_blocks.0.attn2.to_v.weight", + "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.9.transformer_blocks.0.attn2.to_out.weight", + "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.9.transformer_blocks.0.attn2.to_out.bias", + "control_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "blocks.9.transformer_blocks.0.norm1.weight", + "control_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "blocks.9.transformer_blocks.0.norm1.bias", + "control_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "blocks.9.transformer_blocks.0.norm2.weight", + "control_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "blocks.9.transformer_blocks.0.norm2.bias", + "control_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "blocks.9.transformer_blocks.0.norm3.weight", + "control_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "blocks.9.transformer_blocks.0.norm3.bias", + "control_model.input_blocks.4.1.proj_out.weight": "blocks.9.proj_out.weight", + "control_model.input_blocks.4.1.proj_out.bias": "blocks.9.proj_out.bias", + "control_model.input_blocks.5.0.in_layers.0.weight": "blocks.11.norm1.weight", + "control_model.input_blocks.5.0.in_layers.0.bias": "blocks.11.norm1.bias", + "control_model.input_blocks.5.0.in_layers.2.weight": "blocks.11.conv1.weight", + "control_model.input_blocks.5.0.in_layers.2.bias": "blocks.11.conv1.bias", + "control_model.input_blocks.5.0.emb_layers.1.weight": "blocks.11.time_emb_proj.weight", + "control_model.input_blocks.5.0.emb_layers.1.bias": "blocks.11.time_emb_proj.bias", + "control_model.input_blocks.5.0.out_layers.0.weight": "blocks.11.norm2.weight", + 
"control_model.input_blocks.5.0.out_layers.0.bias": "blocks.11.norm2.bias", + "control_model.input_blocks.5.0.out_layers.3.weight": "blocks.11.conv2.weight", + "control_model.input_blocks.5.0.out_layers.3.bias": "blocks.11.conv2.bias", + "control_model.input_blocks.5.1.norm.weight": "blocks.12.norm.weight", + "control_model.input_blocks.5.1.norm.bias": "blocks.12.norm.bias", + "control_model.input_blocks.5.1.proj_in.weight": "blocks.12.proj_in.weight", + "control_model.input_blocks.5.1.proj_in.bias": "blocks.12.proj_in.bias", + "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "blocks.12.transformer_blocks.0.attn1.to_q.weight", + "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "blocks.12.transformer_blocks.0.attn1.to_k.weight", + "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "blocks.12.transformer_blocks.0.attn1.to_v.weight", + "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.12.transformer_blocks.0.attn1.to_out.weight", + "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.12.transformer_blocks.0.attn1.to_out.bias", + "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.12.transformer_blocks.0.act_fn.proj.weight", + "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.12.transformer_blocks.0.act_fn.proj.bias", + "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "blocks.12.transformer_blocks.0.ff.weight", + "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "blocks.12.transformer_blocks.0.ff.bias", + "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "blocks.12.transformer_blocks.0.attn2.to_q.weight", + "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "blocks.12.transformer_blocks.0.attn2.to_k.weight", + "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": 
"blocks.12.transformer_blocks.0.attn2.to_v.weight", + "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.12.transformer_blocks.0.attn2.to_out.weight", + "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.12.transformer_blocks.0.attn2.to_out.bias", + "control_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "blocks.12.transformer_blocks.0.norm1.weight", + "control_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "blocks.12.transformer_blocks.0.norm1.bias", + "control_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "blocks.12.transformer_blocks.0.norm2.weight", + "control_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "blocks.12.transformer_blocks.0.norm2.bias", + "control_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "blocks.12.transformer_blocks.0.norm3.weight", + "control_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "blocks.12.transformer_blocks.0.norm3.bias", + "control_model.input_blocks.5.1.proj_out.weight": "blocks.12.proj_out.weight", + "control_model.input_blocks.5.1.proj_out.bias": "blocks.12.proj_out.bias", + "control_model.input_blocks.6.0.op.weight": "blocks.14.conv.weight", + "control_model.input_blocks.6.0.op.bias": "blocks.14.conv.bias", + "control_model.input_blocks.7.0.in_layers.0.weight": "blocks.16.norm1.weight", + "control_model.input_blocks.7.0.in_layers.0.bias": "blocks.16.norm1.bias", + "control_model.input_blocks.7.0.in_layers.2.weight": "blocks.16.conv1.weight", + "control_model.input_blocks.7.0.in_layers.2.bias": "blocks.16.conv1.bias", + "control_model.input_blocks.7.0.emb_layers.1.weight": "blocks.16.time_emb_proj.weight", + "control_model.input_blocks.7.0.emb_layers.1.bias": "blocks.16.time_emb_proj.bias", + "control_model.input_blocks.7.0.out_layers.0.weight": "blocks.16.norm2.weight", + "control_model.input_blocks.7.0.out_layers.0.bias": "blocks.16.norm2.bias", + 
"control_model.input_blocks.7.0.out_layers.3.weight": "blocks.16.conv2.weight", + "control_model.input_blocks.7.0.out_layers.3.bias": "blocks.16.conv2.bias", + "control_model.input_blocks.7.0.skip_connection.weight": "blocks.16.conv_shortcut.weight", + "control_model.input_blocks.7.0.skip_connection.bias": "blocks.16.conv_shortcut.bias", + "control_model.input_blocks.7.1.norm.weight": "blocks.17.norm.weight", + "control_model.input_blocks.7.1.norm.bias": "blocks.17.norm.bias", + "control_model.input_blocks.7.1.proj_in.weight": "blocks.17.proj_in.weight", + "control_model.input_blocks.7.1.proj_in.bias": "blocks.17.proj_in.bias", + "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "blocks.17.transformer_blocks.0.attn1.to_q.weight", + "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "blocks.17.transformer_blocks.0.attn1.to_k.weight", + "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "blocks.17.transformer_blocks.0.attn1.to_v.weight", + "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.17.transformer_blocks.0.attn1.to_out.weight", + "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.17.transformer_blocks.0.attn1.to_out.bias", + "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.17.transformer_blocks.0.act_fn.proj.weight", + "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.17.transformer_blocks.0.act_fn.proj.bias", + "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "blocks.17.transformer_blocks.0.ff.weight", + "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "blocks.17.transformer_blocks.0.ff.bias", + "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "blocks.17.transformer_blocks.0.attn2.to_q.weight", + "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": 
"blocks.17.transformer_blocks.0.attn2.to_k.weight", + "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "blocks.17.transformer_blocks.0.attn2.to_v.weight", + "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.17.transformer_blocks.0.attn2.to_out.weight", + "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.17.transformer_blocks.0.attn2.to_out.bias", + "control_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "blocks.17.transformer_blocks.0.norm1.weight", + "control_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "blocks.17.transformer_blocks.0.norm1.bias", + "control_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "blocks.17.transformer_blocks.0.norm2.weight", + "control_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "blocks.17.transformer_blocks.0.norm2.bias", + "control_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "blocks.17.transformer_blocks.0.norm3.weight", + "control_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "blocks.17.transformer_blocks.0.norm3.bias", + "control_model.input_blocks.7.1.proj_out.weight": "blocks.17.proj_out.weight", + "control_model.input_blocks.7.1.proj_out.bias": "blocks.17.proj_out.bias", + "control_model.input_blocks.8.0.in_layers.0.weight": "blocks.19.norm1.weight", + "control_model.input_blocks.8.0.in_layers.0.bias": "blocks.19.norm1.bias", + "control_model.input_blocks.8.0.in_layers.2.weight": "blocks.19.conv1.weight", + "control_model.input_blocks.8.0.in_layers.2.bias": "blocks.19.conv1.bias", + "control_model.input_blocks.8.0.emb_layers.1.weight": "blocks.19.time_emb_proj.weight", + "control_model.input_blocks.8.0.emb_layers.1.bias": "blocks.19.time_emb_proj.bias", + "control_model.input_blocks.8.0.out_layers.0.weight": "blocks.19.norm2.weight", + "control_model.input_blocks.8.0.out_layers.0.bias": "blocks.19.norm2.bias", + "control_model.input_blocks.8.0.out_layers.3.weight": 
"blocks.19.conv2.weight", + "control_model.input_blocks.8.0.out_layers.3.bias": "blocks.19.conv2.bias", + "control_model.input_blocks.8.1.norm.weight": "blocks.20.norm.weight", + "control_model.input_blocks.8.1.norm.bias": "blocks.20.norm.bias", + "control_model.input_blocks.8.1.proj_in.weight": "blocks.20.proj_in.weight", + "control_model.input_blocks.8.1.proj_in.bias": "blocks.20.proj_in.bias", + "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "blocks.20.transformer_blocks.0.attn1.to_q.weight", + "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "blocks.20.transformer_blocks.0.attn1.to_k.weight", + "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "blocks.20.transformer_blocks.0.attn1.to_v.weight", + "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.20.transformer_blocks.0.attn1.to_out.weight", + "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.20.transformer_blocks.0.attn1.to_out.bias", + "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.20.transformer_blocks.0.act_fn.proj.weight", + "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.20.transformer_blocks.0.act_fn.proj.bias", + "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "blocks.20.transformer_blocks.0.ff.weight", + "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "blocks.20.transformer_blocks.0.ff.bias", + "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "blocks.20.transformer_blocks.0.attn2.to_q.weight", + "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "blocks.20.transformer_blocks.0.attn2.to_k.weight", + "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "blocks.20.transformer_blocks.0.attn2.to_v.weight", + "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": 
"blocks.20.transformer_blocks.0.attn2.to_out.weight", + "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.20.transformer_blocks.0.attn2.to_out.bias", + "control_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "blocks.20.transformer_blocks.0.norm1.weight", + "control_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "blocks.20.transformer_blocks.0.norm1.bias", + "control_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "blocks.20.transformer_blocks.0.norm2.weight", + "control_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "blocks.20.transformer_blocks.0.norm2.bias", + "control_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "blocks.20.transformer_blocks.0.norm3.weight", + "control_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "blocks.20.transformer_blocks.0.norm3.bias", + "control_model.input_blocks.8.1.proj_out.weight": "blocks.20.proj_out.weight", + "control_model.input_blocks.8.1.proj_out.bias": "blocks.20.proj_out.bias", + "control_model.input_blocks.9.0.op.weight": "blocks.22.conv.weight", + "control_model.input_blocks.9.0.op.bias": "blocks.22.conv.bias", + "control_model.input_blocks.10.0.in_layers.0.weight": "blocks.24.norm1.weight", + "control_model.input_blocks.10.0.in_layers.0.bias": "blocks.24.norm1.bias", + "control_model.input_blocks.10.0.in_layers.2.weight": "blocks.24.conv1.weight", + "control_model.input_blocks.10.0.in_layers.2.bias": "blocks.24.conv1.bias", + "control_model.input_blocks.10.0.emb_layers.1.weight": "blocks.24.time_emb_proj.weight", + "control_model.input_blocks.10.0.emb_layers.1.bias": "blocks.24.time_emb_proj.bias", + "control_model.input_blocks.10.0.out_layers.0.weight": "blocks.24.norm2.weight", + "control_model.input_blocks.10.0.out_layers.0.bias": "blocks.24.norm2.bias", + "control_model.input_blocks.10.0.out_layers.3.weight": "blocks.24.conv2.weight", + "control_model.input_blocks.10.0.out_layers.3.bias": "blocks.24.conv2.bias", + 
"control_model.input_blocks.11.0.in_layers.0.weight": "blocks.26.norm1.weight", + "control_model.input_blocks.11.0.in_layers.0.bias": "blocks.26.norm1.bias", + "control_model.input_blocks.11.0.in_layers.2.weight": "blocks.26.conv1.weight", + "control_model.input_blocks.11.0.in_layers.2.bias": "blocks.26.conv1.bias", + "control_model.input_blocks.11.0.emb_layers.1.weight": "blocks.26.time_emb_proj.weight", + "control_model.input_blocks.11.0.emb_layers.1.bias": "blocks.26.time_emb_proj.bias", + "control_model.input_blocks.11.0.out_layers.0.weight": "blocks.26.norm2.weight", + "control_model.input_blocks.11.0.out_layers.0.bias": "blocks.26.norm2.bias", + "control_model.input_blocks.11.0.out_layers.3.weight": "blocks.26.conv2.weight", + "control_model.input_blocks.11.0.out_layers.3.bias": "blocks.26.conv2.bias", + "control_model.zero_convs.0.0.weight": "controlnet_blocks.0.weight", + "control_model.zero_convs.0.0.bias": "controlnet_blocks.0.bias", + "control_model.zero_convs.1.0.weight": "controlnet_blocks.1.weight", + "control_model.zero_convs.1.0.bias": "controlnet_blocks.0.bias", + "control_model.zero_convs.2.0.weight": "controlnet_blocks.2.weight", + "control_model.zero_convs.2.0.bias": "controlnet_blocks.0.bias", + "control_model.zero_convs.3.0.weight": "controlnet_blocks.3.weight", + "control_model.zero_convs.3.0.bias": "controlnet_blocks.0.bias", + "control_model.zero_convs.4.0.weight": "controlnet_blocks.4.weight", + "control_model.zero_convs.4.0.bias": "controlnet_blocks.4.bias", + "control_model.zero_convs.5.0.weight": "controlnet_blocks.5.weight", + "control_model.zero_convs.5.0.bias": "controlnet_blocks.4.bias", + "control_model.zero_convs.6.0.weight": "controlnet_blocks.6.weight", + "control_model.zero_convs.6.0.bias": "controlnet_blocks.4.bias", + "control_model.zero_convs.7.0.weight": "controlnet_blocks.7.weight", + "control_model.zero_convs.7.0.bias": "controlnet_blocks.7.bias", + "control_model.zero_convs.8.0.weight": "controlnet_blocks.8.weight", + 
"control_model.zero_convs.8.0.bias": "controlnet_blocks.7.bias", + "control_model.zero_convs.9.0.weight": "controlnet_blocks.9.weight", + "control_model.zero_convs.9.0.bias": "controlnet_blocks.7.bias", + "control_model.zero_convs.10.0.weight": "controlnet_blocks.10.weight", + "control_model.zero_convs.10.0.bias": "controlnet_blocks.7.bias", + "control_model.zero_convs.11.0.weight": "controlnet_blocks.11.weight", + "control_model.zero_convs.11.0.bias": "controlnet_blocks.7.bias", + "control_model.input_hint_block.0.weight": "controlnet_conv_in.blocks.0.weight", + "control_model.input_hint_block.0.bias": "controlnet_conv_in.blocks.0.bias", + "control_model.input_hint_block.2.weight": "controlnet_conv_in.blocks.2.weight", + "control_model.input_hint_block.2.bias": "controlnet_conv_in.blocks.2.bias", + "control_model.input_hint_block.4.weight": "controlnet_conv_in.blocks.4.weight", + "control_model.input_hint_block.4.bias": "controlnet_conv_in.blocks.4.bias", + "control_model.input_hint_block.6.weight": "controlnet_conv_in.blocks.6.weight", + "control_model.input_hint_block.6.bias": "controlnet_conv_in.blocks.6.bias", + "control_model.input_hint_block.8.weight": "controlnet_conv_in.blocks.8.weight", + "control_model.input_hint_block.8.bias": "controlnet_conv_in.blocks.8.bias", + "control_model.input_hint_block.10.weight": "controlnet_conv_in.blocks.10.weight", + "control_model.input_hint_block.10.bias": "controlnet_conv_in.blocks.10.bias", + "control_model.input_hint_block.12.weight": "controlnet_conv_in.blocks.12.weight", + "control_model.input_hint_block.12.bias": "controlnet_conv_in.blocks.12.bias", + "control_model.input_hint_block.14.weight": "controlnet_conv_in.blocks.14.weight", + "control_model.input_hint_block.14.bias": "controlnet_conv_in.blocks.14.bias", + "control_model.middle_block.0.in_layers.0.weight": "blocks.28.norm1.weight", + "control_model.middle_block.0.in_layers.0.bias": "blocks.28.norm1.bias", + "control_model.middle_block.0.in_layers.2.weight": 
"blocks.28.conv1.weight", + "control_model.middle_block.0.in_layers.2.bias": "blocks.28.conv1.bias", + "control_model.middle_block.0.emb_layers.1.weight": "blocks.28.time_emb_proj.weight", + "control_model.middle_block.0.emb_layers.1.bias": "blocks.28.time_emb_proj.bias", + "control_model.middle_block.0.out_layers.0.weight": "blocks.28.norm2.weight", + "control_model.middle_block.0.out_layers.0.bias": "blocks.28.norm2.bias", + "control_model.middle_block.0.out_layers.3.weight": "blocks.28.conv2.weight", + "control_model.middle_block.0.out_layers.3.bias": "blocks.28.conv2.bias", + "control_model.middle_block.1.norm.weight": "blocks.29.norm.weight", + "control_model.middle_block.1.norm.bias": "blocks.29.norm.bias", + "control_model.middle_block.1.proj_in.weight": "blocks.29.proj_in.weight", + "control_model.middle_block.1.proj_in.bias": "blocks.29.proj_in.bias", + "control_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "blocks.29.transformer_blocks.0.attn1.to_q.weight", + "control_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "blocks.29.transformer_blocks.0.attn1.to_k.weight", + "control_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "blocks.29.transformer_blocks.0.attn1.to_v.weight", + "control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.29.transformer_blocks.0.attn1.to_out.weight", + "control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.29.transformer_blocks.0.attn1.to_out.bias", + "control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.29.transformer_blocks.0.act_fn.proj.weight", + "control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.29.transformer_blocks.0.act_fn.proj.bias", + "control_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "blocks.29.transformer_blocks.0.ff.weight", + "control_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "blocks.29.transformer_blocks.0.ff.bias", + 
"control_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "blocks.29.transformer_blocks.0.attn2.to_q.weight", + "control_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "blocks.29.transformer_blocks.0.attn2.to_k.weight", + "control_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "blocks.29.transformer_blocks.0.attn2.to_v.weight", + "control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.29.transformer_blocks.0.attn2.to_out.weight", + "control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.29.transformer_blocks.0.attn2.to_out.bias", + "control_model.middle_block.1.transformer_blocks.0.norm1.weight": "blocks.29.transformer_blocks.0.norm1.weight", + "control_model.middle_block.1.transformer_blocks.0.norm1.bias": "blocks.29.transformer_blocks.0.norm1.bias", + "control_model.middle_block.1.transformer_blocks.0.norm2.weight": "blocks.29.transformer_blocks.0.norm2.weight", + "control_model.middle_block.1.transformer_blocks.0.norm2.bias": "blocks.29.transformer_blocks.0.norm2.bias", + "control_model.middle_block.1.transformer_blocks.0.norm3.weight": "blocks.29.transformer_blocks.0.norm3.weight", + "control_model.middle_block.1.transformer_blocks.0.norm3.bias": "blocks.29.transformer_blocks.0.norm3.bias", + "control_model.middle_block.1.proj_out.weight": "blocks.29.proj_out.weight", + "control_model.middle_block.1.proj_out.bias": "blocks.29.proj_out.bias", + "control_model.middle_block.2.in_layers.0.weight": "blocks.30.norm1.weight", + "control_model.middle_block.2.in_layers.0.bias": "blocks.30.norm1.bias", + "control_model.middle_block.2.in_layers.2.weight": "blocks.30.conv1.weight", + "control_model.middle_block.2.in_layers.2.bias": "blocks.30.conv1.bias", + "control_model.middle_block.2.emb_layers.1.weight": "blocks.30.time_emb_proj.weight", + "control_model.middle_block.2.emb_layers.1.bias": "blocks.30.time_emb_proj.bias", + "control_model.middle_block.2.out_layers.0.weight": 
"blocks.30.norm2.weight", + "control_model.middle_block.2.out_layers.0.bias": "blocks.30.norm2.bias", + "control_model.middle_block.2.out_layers.3.weight": "blocks.30.conv2.weight", + "control_model.middle_block.2.out_layers.3.bias": "blocks.30.conv2.bias", + "control_model.middle_block_out.0.weight": "controlnet_blocks.12.weight", + "control_model.middle_block_out.0.bias": "controlnet_blocks.7.bias", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if ".proj_in." in name or ".proj_out." in name: + param = param.squeeze() + state_dict_[rename_dict[name]] = param + return state_dict_ diff --git a/diffsynth/models/sd_motion.py b/diffsynth/models/sd_motion.py new file mode 100644 index 0000000..b313e62 --- /dev/null +++ b/diffsynth/models/sd_motion.py @@ -0,0 +1,198 @@ +from .sd_unet import SDUNet, Attention, GEGLU +import torch +from einops import rearrange, repeat + + +class TemporalTransformerBlock(torch.nn.Module): + + def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32): + super().__init__() + + # 1. Self-Attn + self.pe1 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim)) + self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True) + self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True) + + # 2. Cross-Attn + self.pe2 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim)) + self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True) + self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True) + + # 3. Feed-forward + self.norm3 = torch.nn.LayerNorm(dim, elementwise_affine=True) + self.act_fn = GEGLU(dim, dim * 4) + self.ff = torch.nn.Linear(dim * 4, dim) + + + def forward(self, hidden_states, batch_size=1): + + # 1. 
Self-Attention + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size) + attn_output = self.attn1(norm_hidden_states + self.pe1[:, :norm_hidden_states.shape[1]]) + attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size) + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size) + attn_output = self.attn2(norm_hidden_states + self.pe2[:, :norm_hidden_states.shape[1]]) + attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + ff_output = self.act_fn(norm_hidden_states) + ff_output = self.ff(ff_output) + hidden_states = ff_output + hidden_states + + return hidden_states + + +class TemporalBlock(torch.nn.Module): + + def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5): + super().__init__() + inner_dim = num_attention_heads * attention_head_dim + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True) + self.proj_in = torch.nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = torch.nn.ModuleList([ + TemporalTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim + ) + for d in range(num_layers) + ]) + + self.proj_out = torch.nn.Linear(inner_dim, in_channels) + + def forward(self, hidden_states, time_emb, text_emb, res_stack, batch_size=1): + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states) 
+ + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + batch_size=batch_size + ) + + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = hidden_states + residual + + return hidden_states, time_emb, text_emb, res_stack + + +class SDMotionModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.motion_modules = torch.nn.ModuleList([ + TemporalBlock(8, 40, 320, eps=1e-6), + TemporalBlock(8, 40, 320, eps=1e-6), + TemporalBlock(8, 80, 640, eps=1e-6), + TemporalBlock(8, 80, 640, eps=1e-6), + TemporalBlock(8, 160, 1280, eps=1e-6), + TemporalBlock(8, 160, 1280, eps=1e-6), + TemporalBlock(8, 160, 1280, eps=1e-6), + TemporalBlock(8, 160, 1280, eps=1e-6), + TemporalBlock(8, 160, 1280, eps=1e-6), + TemporalBlock(8, 160, 1280, eps=1e-6), + TemporalBlock(8, 160, 1280, eps=1e-6), + TemporalBlock(8, 160, 1280, eps=1e-6), + TemporalBlock(8, 160, 1280, eps=1e-6), + TemporalBlock(8, 160, 1280, eps=1e-6), + TemporalBlock(8, 160, 1280, eps=1e-6), + TemporalBlock(8, 80, 640, eps=1e-6), + TemporalBlock(8, 80, 640, eps=1e-6), + TemporalBlock(8, 80, 640, eps=1e-6), + TemporalBlock(8, 40, 320, eps=1e-6), + TemporalBlock(8, 40, 320, eps=1e-6), + TemporalBlock(8, 40, 320, eps=1e-6), + ]) + self.call_block_id = { + 1: 0, + 4: 1, + 9: 2, + 12: 3, + 17: 4, + 20: 5, + 24: 6, + 26: 7, + 29: 8, + 32: 9, + 34: 10, + 36: 11, + 40: 12, + 43: 13, + 46: 14, + 50: 15, + 53: 16, + 56: 17, + 60: 18, + 63: 19, + 66: 20 + } + + def forward(self): + pass + + def state_dict_converter(self): + return SDMotionModelStateDictConverter() + + +class SDMotionModelStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + rename_dict = { + "norm": "norm", + "proj_in": "proj_in", + "transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q", + "transformer_blocks.0.attention_blocks.0.to_k": 
"transformer_blocks.0.attn1.to_k", + "transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v", + "transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out", + "transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1", + "transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q", + "transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k", + "transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v", + "transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out", + "transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2", + "transformer_blocks.0.norms.0": "transformer_blocks.0.norm1", + "transformer_blocks.0.norms.1": "transformer_blocks.0.norm2", + "transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj", + "transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff", + "transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3", + "proj_out": "proj_out", + } + name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")]) + name_list += sorted([i for i in state_dict if i.startswith("mid_block.")]) + name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")]) + state_dict_ = {} + last_prefix, module_id = "", -1 + for name in name_list: + names = name.split(".") + prefix_index = names.index("temporal_transformer") + 1 + prefix = ".".join(names[:prefix_index]) + if prefix != last_prefix: + last_prefix = prefix + module_id += 1 + middle_name = ".".join(names[prefix_index:-1]) + suffix = names[-1] + if "pos_encoder" in names: + rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]]) + else: + rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix]) + state_dict_[rename] = state_dict[name] + return state_dict_ + + def from_civitai(self, state_dict): + return 
self.from_diffusers(state_dict) diff --git a/diffsynth/models/sd_unet.py b/diffsynth/models/sd_unet.py index 999a103..1fd1f02 100644 --- a/diffsynth/models/sd_unet.py +++ b/diffsynth/models/sd_unet.py @@ -279,7 +279,7 @@ class SDUNet(torch.nn.Module): self.conv_act = torch.nn.SiLU() self.conv_out = torch.nn.Conv2d(320, 4, kernel_size=3, padding=1) - def forward(self, sample, timestep, encoder_hidden_states, tiled=False, tile_size=64, tile_stride=8, **kwargs): + def forward(self, sample, timestep, encoder_hidden_states, tiled=False, tile_size=64, tile_stride=8, additional_res_stack=None, **kwargs): # 1. time time_emb = self.time_proj(timestep[None]).to(sample.dtype) time_emb = self.time_embedding(time_emb) @@ -293,6 +293,10 @@ class SDUNet(torch.nn.Module): # 3. blocks for i, block in enumerate(self.blocks): + if additional_res_stack is not None and i==31: + hidden_states += additional_res_stack.pop() + res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)] + additional_res_stack = None if tiled: hidden_states, time_emb, text_emb, res_stack = self.tiled_inference( block, hidden_states, time_emb, text_emb, res_stack, diff --git a/diffsynth/pipelines/__init__.py b/diffsynth/pipelines/__init__.py index 55d511d..9fe28d1 100644 --- a/diffsynth/pipelines/__init__.py +++ b/diffsynth/pipelines/__init__.py @@ -1,2 +1,3 @@ from .stable_diffusion import SDPipeline from .stable_diffusion_xl import SDXLPipeline +from .stable_diffusion_video import SDVideoPipeline diff --git a/diffsynth/pipelines/stable_diffusion.py b/diffsynth/pipelines/stable_diffusion.py index 5759304..9d5d962 100644 --- a/diffsynth/pipelines/stable_diffusion.py +++ b/diffsynth/pipelines/stable_diffusion.py @@ -1,4 +1,5 @@ -from ..models import ModelManager +from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder +from ..controlnets.controlnet_unit import MultiControlNetManager from ..prompts import SDPrompter from ..schedulers import 
EnhancedDDIMScheduler import torch @@ -9,9 +10,29 @@ import numpy as np class SDPipeline(torch.nn.Module): - def __init__(self): + def __init__(self, device="cuda", torch_dtype=torch.float16): super().__init__() self.scheduler = EnhancedDDIMScheduler() + self.prompter = SDPrompter() + self.device = device + self.torch_dtype = torch_dtype + # models + self.text_encoder: SDTextEncoder = None + self.unet: SDUNet = None + self.vae_decoder: SDVAEDecoder = None + self.vae_encoder: SDVAEEncoder = None + self.controlnet: MultiControlNetManager = None + + def fetch_main_models(self, model_manager: ModelManager): + self.text_encoder = model_manager.text_encoder + self.unet = model_manager.unet + self.vae_decoder = model_manager.vae_decoder + self.vae_encoder = model_manager.vae_encoder + # load textual inversion + self.prompter.load_textual_inversion(model_manager.textual_inversion_dict) + + def fetch_controlnet_models(self, controlnet_units=[]): + self.controlnet = MultiControlNetManager(controlnet_units) def preprocess_image(self, image): image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0) @@ -20,13 +41,12 @@ class SDPipeline(torch.nn.Module): @torch.no_grad() def __call__( self, - model_manager: ModelManager, - prompter: SDPrompter, prompt, negative_prompt="", cfg_scale=7.5, clip_skip=1, init_image=None, + controlnet_image=None, denoising_strength=1.0, height=512, width=512, @@ -38,37 +58,59 @@ class SDPipeline(torch.nn.Module): progress_bar_st=None, ): # Encode prompts - prompt_emb = prompter.encode_prompt(model_manager.text_encoder, prompt, clip_skip=clip_skip, device=model_manager.device) - negative_prompt_emb = prompter.encode_prompt(model_manager.text_encoder, negative_prompt, clip_skip=clip_skip, device=model_manager.device) + prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device) + prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, 
negative_prompt, clip_skip=clip_skip, device=self.device) # Prepare scheduler self.scheduler.set_timesteps(num_inference_steps, denoising_strength) # Prepare latent tensors if init_image is not None: - image = self.preprocess_image(init_image).to(device=model_manager.device, dtype=model_manager.torch_type) - latents = model_manager.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) - noise = torch.randn((1, 4, height//8, width//8), device=model_manager.device, dtype=model_manager.torch_type) + image = self.preprocess_image(init_image).to(device=self.device, dtype=self.torch_dtype) + latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) else: - latents = torch.randn((1, 4, height//8, width//8), device=model_manager.device, dtype=model_manager.torch_type) + latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) + + # Prepare ControlNets + if controlnet_image is not None: + controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype) # Denoise for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): - timestep = torch.IntTensor((timestep,))[0].to(model_manager.device) + timestep = torch.IntTensor((timestep,))[0].to(self.device) + + # ControlNet + if controlnet_image is not None: + additional_res_stack_posi = self.controlnet(latents, timestep, prompt_emb_posi, controlnet_image) + additional_res_stack_nega = self.controlnet(latents, timestep, prompt_emb_nega, controlnet_image) + else: + additional_res_stack_posi = None + additional_res_stack_nega = None # Classifier-free guidance - noise_pred_cond = model_manager.unet(latents, timestep, prompt_emb, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) 
- noise_pred_uncond = model_manager.unet(latents, timestep, negative_prompt_emb, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) - noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_cond - noise_pred_uncond) + noise_pred_posi = self.unet( + latents, timestep, prompt_emb_posi, + additional_res_stack=additional_res_stack_posi, + tiled=tiled, tile_size=tile_size, tile_stride=tile_stride + ) + noise_pred_nega = self.unet( + latents, timestep, prompt_emb_nega, + additional_res_stack=additional_res_stack_nega, + tiled=tiled, tile_size=tile_size, tile_stride=tile_stride + ) + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + # DDIM latents = self.scheduler.step(noise_pred, timestep, latents) + # UI if progress_bar_st is not None: progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) # Decode image - image = model_manager.vae_decoder(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + image = self.vae_decoder(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] image = image.cpu().permute(1, 2, 0).numpy() image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) diff --git a/diffsynth/pipelines/stable_diffusion_video.py b/diffsynth/pipelines/stable_diffusion_video.py new file mode 100644 index 0000000..64e4783 --- /dev/null +++ b/diffsynth/pipelines/stable_diffusion_video.py @@ -0,0 +1,302 @@ +from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDMotionModel +from ..models.sd_unet import PushBlock, PopBlock +from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator +from ..prompts import SDPrompter +from ..schedulers import EnhancedDDIMScheduler +from typing import List +import torch +from tqdm import tqdm +from PIL import Image +import numpy as np + + +def lets_dance( + unet: SDUNet, + motion_modules: SDMotionModel = None, + controlnet: MultiControlNetManager = None, + sample = 
None, + timestep = None, + encoder_hidden_states = None, + controlnet_frames = None, + unet_batch_size = 1, + controlnet_batch_size = 1, + device = "cuda", + vram_limit_level = 0, +): + # 1. ControlNet + # This part will be repeated on overlapping frames if animatediff_batch_size > animatediff_stride. + # I leave it here because I intend to do something interesting on the ControlNets. + controlnet_insert_block_id = 30 + if controlnet is not None and controlnet_frames is not None: + res_stacks = [] + # process controlnet frames with batch + for batch_id in range(0, sample.shape[0], controlnet_batch_size): + batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0]) + res_stack = controlnet( + sample[batch_id: batch_id_], + timestep, + encoder_hidden_states[batch_id: batch_id_], + controlnet_frames[:, batch_id: batch_id_] + ) + if vram_limit_level >= 1: + res_stack = [res.cpu() for res in res_stack] + res_stacks.append(res_stack) + # concat the residual + additional_res_stack = [] + for i in range(len(res_stacks[0])): + res = torch.concat([res_stack[i] for res_stack in res_stacks], dim=0) + additional_res_stack.append(res) + else: + additional_res_stack = None + + # 2. time + time_emb = unet.time_proj(timestep[None]).to(sample.dtype) + time_emb = unet.time_embedding(time_emb) + + # 3. pre-process + hidden_states = unet.conv_in(sample) + text_emb = encoder_hidden_states + res_stack = [hidden_states.cpu() if vram_limit_level>=1 else hidden_states] + + # 4. 
blocks + for block_id, block in enumerate(unet.blocks): + # 4.1 UNet + if isinstance(block, PushBlock): + hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) + if vram_limit_level>=1: + res_stack[-1] = res_stack[-1].cpu() + elif isinstance(block, PopBlock): + if vram_limit_level>=1: + res_stack[-1] = res_stack[-1].to(device) + hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) + else: + hidden_states_input = hidden_states + hidden_states_output = [] + for batch_id in range(0, sample.shape[0], unet_batch_size): + batch_id_ = min(batch_id + unet_batch_size, sample.shape[0]) + hidden_states, _, _, _ = block(hidden_states_input[batch_id: batch_id_], time_emb, text_emb[batch_id: batch_id_], res_stack) + hidden_states_output.append(hidden_states) + hidden_states = torch.concat(hidden_states_output, dim=0) + # 4.2 AnimateDiff + if motion_modules is not None: + if block_id in motion_modules.call_block_id: + motion_module_id = motion_modules.call_block_id[block_id] + hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id]( + hidden_states, time_emb, text_emb, res_stack, + batch_size=1 + ) + # 4.3 ControlNet + if block_id == controlnet_insert_block_id and additional_res_stack is not None: + hidden_states += additional_res_stack.pop().to(device) + if vram_limit_level>=1: + res_stack = [(res.to(device) + additional_res.to(device)).cpu() for res, additional_res in zip(res_stack, additional_res_stack)] + else: + res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)] + + # 5. 
output + hidden_states = unet.conv_norm_out(hidden_states) + hidden_states = unet.conv_act(hidden_states) + hidden_states = unet.conv_out(hidden_states) + + return hidden_states + + +def lets_dance_with_long_video( + unet: SDUNet, + motion_modules: SDMotionModel = None, + controlnet: MultiControlNetManager = None, + sample = None, + timestep = None, + encoder_hidden_states = None, + controlnet_frames = None, + unet_batch_size = 1, + controlnet_batch_size = 1, + animatediff_batch_size = 16, + animatediff_stride = 8, + device = "cuda", + vram_limit_level = 0, +): + num_frames = sample.shape[0] + hidden_states_output = [(torch.zeros(sample[0].shape, dtype=sample[0].dtype), 0) for i in range(num_frames)] + + for batch_id in range(0, num_frames, animatediff_stride): + batch_id_ = min(batch_id + animatediff_batch_size, num_frames) + + # process this batch + hidden_states_batch = lets_dance( + unet, motion_modules, controlnet, + sample[batch_id: batch_id_].to(device), + timestep, + encoder_hidden_states[batch_id: batch_id_].to(device), + controlnet_frames[:, batch_id: batch_id_].to(device) if controlnet_frames is not None else None, + unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size, device=device, vram_limit_level=vram_limit_level + ).cpu() + + # update hidden_states + for i, hidden_states_updated in zip(range(batch_id, batch_id_), hidden_states_batch): + bias = max(1 - abs(i - (batch_id + batch_id_ - 1) / 2) / ((batch_id_ - batch_id - 1) / 2), 1e-2) + hidden_states, num = hidden_states_output[i] + hidden_states = hidden_states * (num / (num + bias)) + hidden_states_updated * (bias / (num + bias)) + hidden_states_output[i] = (hidden_states, num + 1) + + # output + hidden_states = torch.stack([h for h, _ in hidden_states_output]) + return hidden_states + + +class SDVideoPipeline(torch.nn.Module): + + def __init__(self, device="cuda", torch_dtype=torch.float16, use_animatediff=True): + super().__init__() + self.scheduler = 
class SDVideoPipeline(torch.nn.Module):
    """Video generation/editing pipeline for Stable Diffusion 1.x.

    Optionally uses AnimateDiff motion modules (temporal attention) and any
    number of ControlNets. Latents and prompt embeddings live on CPU between
    steps; `lets_dance_with_long_video` streams windows to `self.device`.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16, use_animatediff=True):
        super().__init__()
        # AnimateDiff checkpoints are trained against the "linear" beta schedule;
        # plain SD uses "scaled_linear".
        self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_animatediff else "scaled_linear")
        self.prompter = SDPrompter()
        self.device = device
        self.torch_dtype = torch_dtype
        # Models, populated by the fetch_* methods below.
        self.text_encoder: SDTextEncoder = None
        self.unet: SDUNet = None
        self.vae_decoder: SDVAEDecoder = None
        self.vae_encoder: SDVAEEncoder = None
        self.controlnet: MultiControlNetManager = None
        self.motion_modules: SDMotionModel = None


    def fetch_main_models(self, model_manager: ModelManager):
        """Take text encoder, UNet and VAE from `model_manager`; load textual inversions."""
        self.text_encoder = model_manager.text_encoder
        self.unet = model_manager.unet
        self.vae_decoder = model_manager.vae_decoder
        self.vae_encoder = model_manager.vae_encoder
        # load textual inversion
        self.prompter.load_textual_inversion(model_manager.textual_inversion_dict)


    def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
        """Build one ControlNetUnit (annotator + model + scale) per config unit."""
        controlnet_units = []
        for config in controlnet_config_units:
            controlnet_unit = ControlNetUnit(
                Annotator(config.processor_id),
                model_manager.get_model_with_model_path(config.model_path),
                config.scale
            )
            controlnet_units.append(controlnet_unit)
        self.controlnet = MultiControlNetManager(controlnet_units)


    def fetch_motion_modules(self, model_manager: ModelManager):
        """Attach AnimateDiff motion modules when the manager has loaded any."""
        if "motion_modules" in model_manager.model:
            self.motion_modules = model_manager.motion_modules


    @staticmethod
    def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
        """Construct a pipeline wired to the models held by `model_manager`."""
        pipe = SDVideoPipeline(
            device=model_manager.device,
            torch_dtype=model_manager.torch_dtype,
            use_animatediff="motion_modules" in model_manager.model
        )
        pipe.fetch_main_models(model_manager)
        pipe.fetch_motion_modules(model_manager)
        pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
        return pipe


    def preprocess_image(self, image):
        """Convert an HxWxC uint8 image to a 1xCxHxW float tensor in [-1, 1]."""
        image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
        return image


    def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
        """Decode one latent (1x4xh/8xw/8) into a PIL image."""
        image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
        image = image.cpu().permute(1, 2, 0).numpy()
        image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
        return image


    def decode_images(self, latents, tiled=False, tile_size=64, tile_stride=32):
        """Decode a batch of latents frame by frame into a list of PIL images."""
        images = [
            self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
            for frame_id in range(latents.shape[0])
        ]
        return images


    def encode_images(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
        """Encode a list of images into a (num_frames, 4, h/8, w/8) latent tensor on CPU."""
        latents = []
        for image in processed_images:
            image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
            latent = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).cpu()
            latents.append(latent)
        latents = torch.concat(latents, dim=0)
        return latents


    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        cfg_scale=7.5,
        clip_skip=1,
        num_frames=None,
        input_frames=None,
        controlnet_frames=None,
        denoising_strength=1.0,
        height=512,
        width=512,
        num_inference_steps=20,
        vram_limit_level=0,
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        """Generate `num_frames` video frames from `prompt`.

        When `input_frames` is given this is a video-to-video run:
        `denoising_strength` controls how much of the input survives.
        `controlnet_frames` are per-ControlNet conditioning frames.
        Returns a list of PIL images.
        """
        # Fix: infer the frame count from the input video when not given
        # explicitly; the original crashed on `repeat(None, 1, 1)` in that case.
        if num_frames is None and input_frames is not None:
            num_frames = len(input_frames)

        # Encode prompts (kept on CPU; windows are moved to GPU per batch).
        prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device).cpu()
        prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device).cpu()
        prompt_emb_posi = prompt_emb_posi.repeat(num_frames, 1, 1)
        prompt_emb_nega = prompt_emb_nega.repeat(num_frames, 1, 1)

        # Prepare scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

        # Prepare latent tensors
        noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype)
        if input_frames is None:
            latents = noise
        else:
            latents = self.encode_images(input_frames)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])

        # Prepare ControlNets: stack to (num_controlnets, num_frames, C, H, W).
        if controlnet_frames is not None:
            controlnet_frames = torch.stack([
                self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
                for controlnet_frame in progress_bar_cmd(controlnet_frames)
            ], dim=1)

        # Denoise
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = torch.IntTensor((timestep,))[0].to(self.device)

            # Classifier-free guidance: two passes, then extrapolate.
            noise_pred_posi = lets_dance_with_long_video(
                self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
                sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_frames,
                device=self.device, vram_limit_level=vram_limit_level
            )
            noise_pred_nega = lets_dance_with_long_video(
                self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
                sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames,
                device=self.device, vram_limit_level=vram_limit_level
            )
            noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)

            # DDIM
            latents = self.scheduler.step(noise_pred, timestep, latents)

            # UI: fix off-by-one so the bar reaches 100% on the last step
            # (the original reported progress_id/len, topping out below 1.0).
            if progress_bar_st is not None:
                progress_bar_st.progress((progress_id + 1) / len(self.scheduler.timesteps))

        # Decode image
        output_frames = self.decode_images(latents)

        return output_frames
def search_for_embeddings(state_dict):
    """Collect every tensor stored at any nesting depth in a state-dict tree.

    Walks `state_dict` in insertion order; tensor values are gathered directly,
    dict values are searched recursively, anything else is ignored.
    """
    embeddings = []
    for value in state_dict.values():
        if isinstance(value, torch.Tensor):
            embeddings.append(value)
        elif isinstance(value, dict):
            embeddings.extend(search_for_embeddings(value))
    return embeddings
class SDPrompter:
    """Encodes text prompts into CLIP embeddings for Stable Diffusion 1.x,
    expanding textual-inversion trigger words before tokenization."""

    def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"):
        # We use the tokenizer implemented by transformers.
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
        # Maps each textual-inversion trigger word to its token-sequence string.
        self.keyword_dict = {}

    def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda"):
        """Return a (1, tokens, dim) embedding of `prompt` from `text_encoder`."""
        # Replace trigger words with the tokens registered for them.
        for keyword, replacement in self.keyword_dict.items():
            if keyword in prompt:
                prompt = prompt.replace(keyword, replacement)
        input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
        embedding = text_encoder(input_ids, clip_skip=clip_skip)
        # Fold the chunked long-prompt batches back into one sequence.
        embedding = embedding.reshape((1, embedding.shape[0] * embedding.shape[1], -1))
        return embedding

    def load_textual_inversion(self, textual_inversion_dict):
        """Register textual inversions: add their tokens to the tokenizer and
        remember each trigger word's replacement string."""
        self.keyword_dict = {}
        new_tokens = []
        for keyword, (tokens, _) in textual_inversion_dict.items():
            new_tokens += tokens
            self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
        self.tokenizer.add_tokens(new_tokens)
alpha_prod_t_prev = self.alphas_cumprod[timestep_prev] - else: - alpha_prod_t_prev = 1.0 return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev) diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..bb3b5af --- /dev/null +++ b/environment.yml @@ -0,0 +1,19 @@ +name: DiffSynthStudio +channels: + - pytorch + - nvidia + - defaults +dependencies: + - python=3.9.16 + - pip=23.0.1 + - cudatoolkit + - pytorch + - pip: + - transformers + - controlnet-aux==0.0.7 + - streamlit + - streamlit-drawable-canvas + - imageio + - imageio[ffmpeg] + - safetensors + - einops