mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 13:58:15 +00:00
initial version
This commit is contained in:
264
Diffsynth_Studio.py
Normal file
264
Diffsynth_Studio.py
Normal file
@@ -0,0 +1,264 @@
|
||||
import torch, os, io
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import streamlit as st
|
||||
st.set_page_config(layout="wide")
|
||||
from streamlit_drawable_canvas import st_canvas
|
||||
from diffsynth.models import ModelManager
|
||||
from diffsynth.prompts import SDXLPrompter, SDPrompter
|
||||
from diffsynth.pipelines import SDXLPipeline, SDPipeline
|
||||
|
||||
|
||||
# Let this process use (almost) all of the memory of GPU 0. The call raises
# when no CUDA device is usable, so it is skipped on CPU-only machines to let
# the UI start (and show "No models selected.") instead of crashing on import.
if torch.cuda.is_available():
    torch.cuda.set_per_process_memory_fraction(0.999, 0)
|
||||
|
||||
|
||||
@st.cache_data
def load_model_list(folder):
    """Return the sorted ``.safetensors`` file names found directly in *folder*.

    A missing folder yields an empty list, so the UI still renders on a fresh
    checkout where no model directory has been created yet.

    Args:
        folder: Directory to scan (not recursive).

    Returns:
        Sorted list of file names (not full paths) ending in ``.safetensors``.
    """
    if not os.path.isdir(folder):
        return []
    return sorted(name for name in os.listdir(folder) if name.endswith(".safetensors"))
|
||||
|
||||
|
||||
def detect_model_path(sd_model_path, sdxl_model_path):
    """Return the checkpoint path for the current selection, or None.

    The string "None" means "nothing selected". A Stable Diffusion selection
    takes precedence over a Stable Diffusion XL selection.
    """
    candidates = (
        ("models/stable_diffusion", sd_model_path),
        ("models/stable_diffusion_xl", sdxl_model_path),
    )
    for folder, file_name in candidates:
        if file_name != "None":
            return os.path.join(folder, file_name)
    return None
|
||||
|
||||
|
||||
def load_model(sd_model_path, sdxl_model_path):
    """Load the selected checkpoint and build the matching prompter and pipeline.

    The string "None" means "nothing selected"; a Stable Diffusion selection
    takes precedence over a Stable Diffusion XL one.

    Returns:
        ``(model_path, model_manager, prompter, pipeline)``, or four ``None``
        values when neither model is selected.
    """
    if sd_model_path != "None":
        folder, file_name = "models/stable_diffusion", sd_model_path
        prompter_cls, pipeline_cls = SDPrompter, SDPipeline
    elif sdxl_model_path != "None":
        folder, file_name = "models/stable_diffusion_xl", sdxl_model_path
        prompter_cls, pipeline_cls = SDXLPrompter, SDXLPipeline
    else:
        return None, None, None, None

    model_path = os.path.join(folder, file_name)
    model_manager = ModelManager()
    model_manager.load_from_safetensors(model_path)
    prompter = prompter_cls()
    pipeline = pipeline_cls()
    return model_path, model_manager, prompter, pipeline
|
||||
|
||||
|
||||
def release_model():
    """Move the loaded model to the CPU and drop it from the session state.

    Does nothing when no "model_manager" entry is stored in the session.
    """
    state = st.session_state
    if "model_manager" not in state:
        return
    state["model_manager"].to("cpu")
    for key in ("loaded_model_path", "model_manager", "prompter", "pipeline"):
        del state[key]
    # Hand the now-unreferenced cached GPU memory back to the driver.
    torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def use_output_image_as_input():
    """Copy the output image whose "use as input" button fired into the input slot.

    Button states are read from ``st.session_state`` under the keys
    ``use_output_as_input_0``, ``use_output_as_input_1``, ... and scanned in
    order until a key is missing; the first pressed button wins.
    """
    state = st.session_state
    index = 0
    key = f"use_output_as_input_{index}"
    while key in state:
        if state[key]:
            chosen = state["output_images"][index]
            if chosen is not None:
                state["input_image"] = chosen
            return
        index += 1
        key = f"use_output_as_input_{index}"
|
||||
|
||||
|
||||
def apply_stroke_to_image(stroke_image, image):
    """Alpha-blend the canvas strokes onto *image* and return a new PIL image.

    Args:
        stroke_image: Array passed to ``Image.fromarray``; its last channel is
            used as the blend weight (0-255) and the remaining channels as the
            stroke colour.
        image: PIL image used as the background; converted to RGB.
    """
    base = np.array(image.convert("RGB")).astype(np.float32)
    height, width = base.shape[0], base.shape[1]

    # Bring the strokes to the background's resolution before blending.
    resized_strokes = Image.fromarray(stroke_image).resize((width, height))
    overlay = np.array(resized_strokes).astype(np.float32)
    alpha = overlay[:, :, -1:] / 255
    colour = overlay[:, :, :-1]

    blended = colour * alpha + base * (1 - alpha)
    return Image.fromarray(np.clip(blended, 0, 255).astype(np.uint8))
|
||||
|
||||
|
||||
@st.cache_data
def image2bits(image):
    """Encode *image* as PNG and return the raw bytes."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return buffer.getvalue()
|
||||
|
||||
|
||||
def show_output_image(image):
    """Render one output image with its "use as input" and download buttons.

    NOTE: the widget keys are built from the module-level ``image_id`` loop
    variable, so this must be called from inside a loop that sets it.
    """
    use_key = f"use_output_as_input_{image_id}"
    download_key = f"download_output_{image_id}"
    st.image(image, use_column_width="always")
    st.button("Use it as input image", key=use_key)
    st.download_button(
        "Download",
        data=image2bits(image),
        file_name="image.png",
        mime="image/png",
        key=download_key,
    )
|
||||
|
||||
|
||||
column_input, column_output = st.columns(2)

# NOTE(review): block nesting below was reconstructed from a listing that lost
# its indentation - confirm against the repository before merging.
with st.sidebar:
    # Select a model
    with st.expander("Model", expanded=True):
        sd_model_list = ["None"] + load_model_list("models/stable_diffusion")
        sd_model_path = st.selectbox("Stable Diffusion", sd_model_list)
        sdxl_model_list = ["None"] + load_model_list("models/stable_diffusion_xl")
        sdxl_model_path = st.selectbox("Stable Diffusion XL", sdxl_model_list)

        # Load the model (only when the selection differs from what is loaded)
        model_path = detect_model_path(sd_model_path, sdxl_model_path)
        if model_path is None:
            st.markdown("No models selected.")
            release_model()
        elif st.session_state.get("loaded_model_path", "") != model_path:
            st.markdown(f"Using model at {model_path}.")
            release_model()
            model_path, model_manager, prompter, pipeline = load_model(sd_model_path, sdxl_model_path)
            st.session_state.loaded_model_path = model_path
            st.session_state.model_manager = model_manager
            st.session_state.prompter = prompter
            st.session_state.pipeline = pipeline
        else:
            # Same model as on the previous rerun: reuse the objects kept in the session.
            st.markdown(f"Using model at {model_path}.")
            model_path = st.session_state.loaded_model_path
            model_manager = st.session_state.model_manager
            prompter = st.session_state.prompter
            pipeline = st.session_state.pipeline

    # Show parameters
    with st.expander("Prompt", expanded=True):
        column_positive, column_negative = st.columns(2)
        prompt = st.text_area("Positive prompt")
        negative_prompt = st.text_area("Negative prompt")
    with st.expander("Classifier-free guidance", expanded=True):
        use_cfg = st.checkbox("Use classifier-free guidance", value=True)
        # The slider is only rendered when guidance is enabled.
        cfg_scale = (
            st.slider("Classifier-free guidance scale", min_value=1.0, max_value=10.0, step=0.1, value=7.5)
            if use_cfg
            else 1.0
        )
    with st.expander("Inference steps", expanded=True):
        num_inference_steps = st.slider("Inference steps", min_value=1, max_value=100, value=20, label_visibility="hidden")
    with st.expander("Image size", expanded=True):
        height = st.select_slider("Height", options=[256, 512, 768, 1024, 2048], value=512)
        width = st.select_slider("Width", options=[256, 512, 768, 1024, 2048], value=512)
    with st.expander("Seed", expanded=True):
        use_fixed_seed = st.checkbox("Use fixed seed", value=False)
        if use_fixed_seed:
            # `seed` only exists when a fixed seed is requested.
            seed = st.number_input("Random seed", value=0, label_visibility="hidden")
    with st.expander("Number of images", expanded=True):
        num_images = st.number_input("Number of images", value=4, label_visibility="hidden")
    with st.expander("Tile (for high resolution)", expanded=True):
        tiled = st.checkbox("Use tile", value=False)
        tile_size = st.select_slider("Tile size", options=[64, 128], value=64)
        tile_stride = st.select_slider("Tile stride", options=[8, 16, 32, 64], value=32)
|
||||
|
||||
|
||||
# Show input image
|
||||
# NOTE(review): block nesting below was reconstructed from a listing that lost
# its indentation - confirm against the repository before merging.
with column_input:
    with st.expander("Input image (Optional)", expanded=True):
        with st.container(border=True):
            column_white_board, column_upload_image = st.columns([1, 2])
            with column_white_board:
                create_white_board = st.button("Create white board")
                delete_input_image = st.button("Delete input image")
            with column_upload_image:
                upload_image = st.file_uploader("Upload image", type=["png", "jpg"], key="upload_image")

        # Pick the source of the input image: upload > white board > reused output.
        if upload_image is not None:
            st.session_state["input_image"] = Image.open(upload_image)
        elif create_white_board:
            st.session_state["input_image"] = Image.fromarray(np.ones((1024, 1024, 3), dtype=np.uint8) * 255)
        else:
            use_output_image_as_input()

        if delete_input_image:
            if "input_image" in st.session_state:
                del st.session_state.input_image
            if "upload_image" in st.session_state:
                del st.session_state.upload_image

        input_image = st.session_state.get("input_image", None)
        if input_image is not None:
            with st.container(border=True):
                column_drawing_mode, column_color_1, column_color_2 = st.columns([4, 1, 1])
                with column_drawing_mode:
                    drawing_mode = st.radio("Drawing tool", ["transform", "freedraw", "line", "rect"], horizontal=True, index=1)
                with column_color_1:
                    stroke_color = st.color_picker("Stroke color")
                with column_color_2:
                    fill_color = st.color_picker("Fill color")
                stroke_width = st.slider("Stroke width", min_value=1, max_value=50, value=10)
            with st.container(border=True):
                denoising_strength = st.slider("Denoising strength", min_value=0.0, max_value=1.0, value=0.7)
            with st.container(border=True):
                input_width, input_height = input_image.size
                # The canvas is 512 px wide; its height keeps the image's aspect ratio.
                canvas_result = st_canvas(
                    fill_color=fill_color,
                    stroke_width=stroke_width,
                    stroke_color=stroke_color,
                    background_color="rgba(255, 255, 255, 0)",
                    background_image=input_image,
                    update_streamlit=True,
                    height=int(512 / input_width * input_height),
                    width=512,
                    drawing_mode=drawing_mode,
                    key="canvas",
                )
|
||||
|
||||
|
||||
# NOTE(review): block nesting below was reconstructed from a listing that lost
# its indentation - confirm against the repository before merging.
with column_output:
    run_button = st.button("Generate image", type="primary")
    auto_update = st.checkbox("Auto update", value=False)
    num_image_columns = st.slider("Columns", min_value=1, max_value=8, value=2)
    image_columns = st.columns(num_image_columns)

    # Run
    if (run_button or auto_update) and model_path is not None:

        if not use_fixed_seed:
            torch.manual_seed(np.random.randint(0, 10**9))

        # Prepare the init image once, before the loop. Doing it per image
        # re-blended the canvas strokes onto an already-blended image, so
        # semi-transparent strokes grew stronger for every later image.
        if input_image is not None:
            input_image = input_image.resize((width, height))
            if canvas_result.image_data is not None:
                input_image = apply_stroke_to_image(canvas_result.image_data, input_image)
        else:
            # No init image: plain text-to-image, so denoise from pure noise.
            denoising_strength = 1.0

        output_images = []
        for image_id in range(num_images):
            if use_fixed_seed:
                # Distinct but reproducible seed for each image of the batch.
                torch.manual_seed(seed + image_id)
            with image_columns[image_id % num_image_columns]:
                progress_bar = st.progress(0.0)
                image = pipeline(
                    model_manager, prompter,
                    prompt, negative_prompt=negative_prompt, cfg_scale=cfg_scale,
                    num_inference_steps=num_inference_steps,
                    height=height, width=width,
                    init_image=input_image, denoising_strength=denoising_strength,
                    tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
                    progress_bar_st=progress_bar
                )
                output_images.append(image)
                progress_bar.progress(1.0)
                show_output_image(image)
        st.session_state["output_images"] = output_images

    elif "output_images" in st.session_state:
        # Nothing to generate on this rerun: redraw the previous results.
        # `image_id` must stay the loop variable; show_output_image reads it.
        for image_id, image in enumerate(st.session_state.output_images):
            with image_columns[image_id % num_image_columns]:
                progress_bar = st.progress(1.0)
                show_output_image(image)
|
||||
2
LICENSE
2
LICENSE
@@ -186,7 +186,7 @@
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
Copyright [2023] [Zhongjie Duan]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
53
README-zh.md
Normal file
53
README-zh.md
Normal file
@@ -0,0 +1,53 @@
|
||||
# DiffSynth Studio
|
||||
|
||||
## 介绍
|
||||
|
||||
DiffSynth 是一个全新的 Diffusion 引擎,我们重构了 Text Encoder、UNet、VAE 等架构,保持与开源社区模型兼容性的同时,提升了计算性能。目前这个版本仅仅是一个初始版本,实现了文生图和图生图功能,支持 SD 和 SDXL 架构。未来我们计划基于这个全新的代码库开发更多有趣的功能。
|
||||
|
||||
## 安装
|
||||
|
||||
如果你只想在 Python 代码层面调用 DiffSynth Studio,你只需要安装 `torch`(深度学习框架)和 `transformers`(仅用于实现分词器)。
|
||||
|
||||
```
|
||||
pip install torch transformers
|
||||
```
|
||||
|
||||
如果你想使用 UI,还需要额外安装 `streamlit`(一个 webui 框架)和 `streamlit-drawable-canvas`(用于图生图画板)。
|
||||
|
||||
```
|
||||
pip install streamlit streamlit-drawable-canvas
|
||||
```
|
||||
|
||||
## 使用
|
||||
|
||||
通过 Python 代码调用
|
||||
|
||||
```python
|
||||
from diffsynth.models import ModelManager
|
||||
from diffsynth.prompts import SDPrompter, SDXLPrompter
|
||||
from diffsynth.pipelines import SDPipeline, SDXLPipeline
|
||||
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_manager.load_from_safetensors("xxxxxxxx.safetensors")
|
||||
prompter = SDPrompter()
|
||||
pipe = SDPipeline()
|
||||
|
||||
prompt = "a girl"
|
||||
negative_prompt = ""
|
||||
|
||||
image = pipe(
|
||||
model_manager, prompter,
|
||||
prompt, negative_prompt=negative_prompt,
|
||||
num_inference_steps=20, height=512, width=512,
|
||||
)
|
||||
image.save("image.png")
|
||||
```
|
||||
|
||||
如果需要用 SDXL 架构模型,请把 `SDPrompter`、`SDPipeline` 换成 `SDXLPrompter`, `SDXLPipeline`。
|
||||
|
||||
当然,你也可以使用我们提供的 UI,但请注意,我们的 UI 程序很简单,且未来可能会大幅改变。
|
||||
|
||||
```
|
||||
python -m streamlit run Diffsynth_Studio.py
|
||||
```
|
||||
53
README.md
Normal file
53
README.md
Normal file
@@ -0,0 +1,53 @@
|
||||
# DiffSynth Studio
|
||||
|
||||
## Introduction
|
||||
|
||||
DiffSynth is a new Diffusion engine. We have restructured architectures including the Text Encoder, UNet, and VAE, maintaining compatibility with models from the open-source community while enhancing computational performance. This is only an initial version: it implements text-to-image and image-to-image generation and supports the SD and SDXL architectures. In the future, we plan to develop more interesting features based on this new codebase.
|
||||
|
||||
## Installation
|
||||
|
||||
If you only want to use DiffSynth Studio at the Python code level, you just need to install `torch` (a deep learning framework) and `transformers` (only used for implementing a tokenizer).
|
||||
|
||||
```
|
||||
pip install torch transformers
|
||||
```
|
||||
|
||||
If you wish to use the UI, you'll also need to additionally install `streamlit` (a web UI framework) and `streamlit-drawable-canvas` (used for the image-to-image canvas).
|
||||
|
||||
```
|
||||
pip install streamlit streamlit-drawable-canvas
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
Use DiffSynth Studio in Python
|
||||
|
||||
```python
|
||||
from diffsynth.models import ModelManager
|
||||
from diffsynth.prompts import SDPrompter, SDXLPrompter
|
||||
from diffsynth.pipelines import SDPipeline, SDXLPipeline
|
||||
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_manager.load_from_safetensors("xxxxxxxx.safetensors")
|
||||
prompter = SDPrompter()
|
||||
pipe = SDPipeline()
|
||||
|
||||
prompt = "a girl"
|
||||
negative_prompt = ""
|
||||
|
||||
image = pipe(
|
||||
model_manager, prompter,
|
||||
prompt, negative_prompt=negative_prompt,
|
||||
num_inference_steps=20, height=512, width=512,
|
||||
)
|
||||
image.save("image.png")
|
||||
```
|
||||
|
||||
If you want to use SDXL architecture models, replace `SDPrompter` and `SDPipeline` with `SDXLPrompter` and `SDXLPipeline`, respectively.
|
||||
|
||||
Of course, you can also use the UI we provide. Please note that the UI is very simple and may change significantly in the future.
|
||||
|
||||
```
|
||||
python -m streamlit run Diffsynth_Studio.py
|
||||
```
|
||||
48895
configs/stable_diffusion/tokenizer/merges.txt
Normal file
48895
configs/stable_diffusion/tokenizer/merges.txt
Normal file
File diff suppressed because it is too large
Load Diff
24
configs/stable_diffusion/tokenizer/special_tokens_map.json
Normal file
24
configs/stable_diffusion/tokenizer/special_tokens_map.json
Normal file
@@ -0,0 +1,24 @@
|
||||
{
|
||||
"bos_token": {
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"eos_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"pad_token": "<|endoftext|>",
|
||||
"unk_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
34
configs/stable_diffusion/tokenizer/tokenizer_config.json
Normal file
34
configs/stable_diffusion/tokenizer/tokenizer_config.json
Normal file
@@ -0,0 +1,34 @@
|
||||
{
|
||||
"add_prefix_space": false,
|
||||
"bos_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"do_lower_case": true,
|
||||
"eos_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"errors": "replace",
|
||||
"model_max_length": 77,
|
||||
"name_or_path": "openai/clip-vit-large-patch14",
|
||||
"pad_token": "<|endoftext|>",
|
||||
"special_tokens_map_file": "./special_tokens_map.json",
|
||||
"tokenizer_class": "CLIPTokenizer",
|
||||
"unk_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
49410
configs/stable_diffusion/tokenizer/vocab.json
Normal file
49410
configs/stable_diffusion/tokenizer/vocab.json
Normal file
File diff suppressed because it is too large
Load Diff
40213
configs/stable_diffusion_xl/tokenizer_2/merges.txt
Normal file
40213
configs/stable_diffusion_xl/tokenizer_2/merges.txt
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,24 @@
|
||||
{
|
||||
"bos_token": {
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"eos_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"pad_token": "!",
|
||||
"unk_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
{
|
||||
"add_prefix_space": false,
|
||||
"added_tokens_decoder": {
|
||||
"0": {
|
||||
"content": "!",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"49406": {
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"49407": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"bos_token": "<|startoftext|>",
|
||||
"clean_up_tokenization_spaces": true,
|
||||
"do_lower_case": true,
|
||||
"eos_token": "<|endoftext|>",
|
||||
"errors": "replace",
|
||||
"model_max_length": 77,
|
||||
"pad_token": "!",
|
||||
"tokenizer_class": "CLIPTokenizer",
|
||||
"unk_token": "<|endoftext|>"
|
||||
}
|
||||
49411
configs/stable_diffusion_xl/tokenizer_2/vocab.json
Normal file
49411
configs/stable_diffusion_xl/tokenizer_2/vocab.json
Normal file
File diff suppressed because it is too large
Load Diff
126
diffsynth/models/__init__.py
Normal file
126
diffsynth/models/__init__.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import torch
|
||||
from safetensors import safe_open
|
||||
|
||||
from .sd_text_encoder import SDTextEncoder
|
||||
from .sd_unet import SDUNet
|
||||
from .sd_vae_encoder import SDVAEEncoder
|
||||
from .sd_vae_decoder import SDVAEDecoder
|
||||
|
||||
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
||||
from .sdxl_unet import SDXLUNet
|
||||
from .sdxl_vae_decoder import SDXLVAEDecoder
|
||||
from .sdxl_vae_encoder import SDXLVAEEncoder
|
||||
|
||||
|
||||
class ModelManager:
    """Load Stable Diffusion / SDXL checkpoints and hold their components.

    Loaded components are stored in ``self.model`` under names such as
    ``"text_encoder"``, ``"unet"``, ``"vae_decoder"`` and ``"vae_encoder"``,
    and are also reachable as attributes (``manager.unet``).
    """

    def __init__(self, torch_type=torch.float16, device="cuda"):
        """Create an empty manager.

        Args:
            torch_type: dtype the loaded components are converted to.
            device: Device the loaded components are moved to.
        """
        self.torch_type = torch_type
        self.device = device
        self.model = {}

    def is_stabe_diffusion_xl(self, state_dict):
        """Return True when *state_dict* contains the SDXL marker parameter.

        The misspelled name is kept because callers may depend on it; prefer
        the ``is_stable_diffusion_xl`` alias in new code.
        """
        param_name = "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight"
        return param_name in state_dict

    # Correctly spelled alias of the method above.
    is_stable_diffusion_xl = is_stabe_diffusion_xl

    def is_stable_diffusion(self, state_dict):
        """Fallback check: anything not recognised as SDXL is treated as SD."""
        return True

    def _load_components(self, state_dict, component_dict, components, float32_components=()):
        """Build, load and place each requested component into ``self.model``.

        Components named in *float32_components* are kept in float32; all
        others are converted to ``self.torch_type``.
        """
        for component in components:
            model = component_dict[component]()
            model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
            dtype = torch.float32 if component in float32_components else self.torch_type
            self.model[component] = model.to(dtype).to(self.device)

    def load_stable_diffusion(self, state_dict, components=None):
        """Load Stable Diffusion components from a checkpoint state dict.

        Args:
            state_dict: Checkpoint parameters, converted via ``from_civitai``.
            components: Names of the components to load; defaults to the text
                encoder, UNet and both VAE halves.
        """
        component_dict = {
            "text_encoder": SDTextEncoder,
            "unet": SDUNet,
            "vae_decoder": SDVAEDecoder,
            "vae_encoder": SDVAEEncoder,
            "refiner": SDXLUNet,
        }
        if components is None:
            components = ["text_encoder", "unet", "vae_decoder", "vae_encoder"]
        self._load_components(state_dict, component_dict, components)

    def load_stable_diffusion_xl(self, state_dict, components=None):
        """Load Stable Diffusion XL components from a checkpoint state dict.

        Args:
            state_dict: Checkpoint parameters, converted via ``from_civitai``.
            components: Names of the components to load; defaults to both text
                encoders, the UNet and both VAE halves.
        """
        component_dict = {
            "text_encoder": SDXLTextEncoder,
            "text_encoder_2": SDXLTextEncoder2,
            "unet": SDXLUNet,
            "vae_decoder": SDXLVAEDecoder,
            "vae_encoder": SDXLVAEEncoder,
            "refiner": SDXLUNet,
        }
        if components is None:
            components = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder"]
        # These two models will output nan when float16 is enabled.
        # The precision problem happens in the last three resnet blocks.
        # I do not know how to solve this problem.
        self._load_components(
            state_dict, component_dict, components,
            float32_components=("vae_decoder", "vae_encoder"),
        )

    def load_from_safetensors(self, file_path, components=None):
        """Read a ``.safetensors`` checkpoint and load it as SDXL or SD."""
        state_dict = load_state_dict_from_safetensors(file_path)
        if self.is_stabe_diffusion_xl(state_dict):
            self.load_stable_diffusion_xl(state_dict, components=components)
        elif self.is_stable_diffusion(state_dict):
            self.load_stable_diffusion(state_dict, components=components)

    def to(self, device):
        """Move every loaded component to *device*."""
        for model in self.model.values():
            model.to(device)

    def __getattr__(self, __name):
        """Expose loaded components as attributes.

        Python calls this only after normal attribute lookup has failed.
        ``model`` is read through ``__dict__`` so that an instance without it
        (created via ``__new__``, copy or unpickling) does not recurse, and an
        unknown name raises ``AttributeError`` so ``hasattr`` and ``getattr``
        with a default behave normally.
        """
        model = self.__dict__.get("model")
        if model is not None and __name in model:
            return model[__name]
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {__name!r}")
|
||||
|
||||
|
||||
def load_state_dict_from_safetensors(file_path):
    """Read every tensor of a ``.safetensors`` file into a dict, on the CPU."""
    with safe_open(file_path, framework="pt", device="cpu") as reader:
        return {key: reader.get_tensor(key) for key in reader.keys()}
|
||||
|
||||
|
||||
def load_state_dict_from_bin(file_path):
    """Load a checkpoint saved with ``torch.save`` and map its tensors to the CPU."""
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code - only call this on checkpoints from a trusted source.
    return torch.load(file_path, map_location="cpu")
|
||||
|
||||
|
||||
def search_parameter(param, state_dict):
    """Return the name of the first tensor in *state_dict* that equals *param*.

    Tensors match when they have the same number of elements and their
    Euclidean distance is below 1e-6; tensors of a different shape but equal
    element count are compared after flattening. Returns None when nothing
    matches.
    """
    wanted_numel = param.numel()
    for key, candidate in state_dict.items():
        if candidate.numel() != wanted_numel:
            continue
        if param.shape == candidate.shape:
            distance = torch.dist(param, candidate)
        else:
            distance = torch.dist(param.flatten(), candidate.flatten())
        if distance < 1e-6:
            return key
    return None
|
||||
|
||||
|
||||
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
    """Print a source-name to target-name mapping found by comparing tensor values.

    Each source tensor is looked up in *target_state_dict* with
    ``search_parameter``. With *split_qkv*, an unmatched source tensor whose
    first dimension is divisible by three is cut into three equal slices that
    are looked up separately. Target names that were never matched are
    reported at the end. Nothing is returned; the result is printed only.
    """
    matched_keys = set()
    with torch.no_grad():
        for name in source_state_dict:
            tensor = source_state_dict[name]
            rename = search_parameter(tensor, target_state_dict)
            if rename is not None:
                print(f'"{name}": "{rename}",')
                matched_keys.add(rename)
                continue
            if not (split_qkv and len(tensor.shape) >= 1 and tensor.shape[0] % 3 == 0):
                continue
            # Possibly a fused q/k/v tensor: try each third on its own.
            length = tensor.shape[0] // 3
            rename = [
                search_parameter(tensor[i * length: i * length + length], target_state_dict)
                for i in range(3)
            ]
            if None not in rename:
                print(f'"{name}": {rename},')
                matched_keys.update(rename)
    for name in target_state_dict:
        if name not in matched_keys:
            print("Cannot find", name, target_state_dict[name].shape)
|
||||
38
diffsynth/models/attention.py
Normal file
38
diffsynth/models/attention.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import torch
|
||||
|
||||
|
||||
class Attention(torch.nn.Module):
    """Multi-head attention; self-attention unless a separate context is given."""

    def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
        """Create the q/k/v/output projections.

        Args:
            q_dim: Feature size of the query input (and of the output).
            num_heads: Number of attention heads.
            head_dim: Feature size of each head.
            kv_dim: Feature size of the key/value input; defaults to *q_dim*.
            bias_q: Whether the query projection has a bias.
            bias_kv: Whether the key and value projections have a bias.
            bias_out: Whether the output projection has a bias.
        """
        super().__init__()
        dim_inner = head_dim * num_heads
        kv_dim = kv_dim if kv_dim is not None else q_dim
        self.num_heads = num_heads
        self.head_dim = head_dim

        self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
        self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
        self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
        self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)

    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
        """Attend from *hidden_states* to *encoder_hidden_states*.

        Args:
            hidden_states: Query input; the code below fixes its layout as
                (batch, tokens, q_dim).
            encoder_hidden_states: Key/value input; defaults to
                *hidden_states* (self-attention).
            attn_mask: Optional mask forwarded unchanged to
                ``scaled_dot_product_attention``.

        Returns:
            Tensor with the same batch and token dimensions as the query and
            ``q_dim`` features.
        """
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        batch_size = encoder_hidden_states.shape[0]

        q = self.to_q(hidden_states)
        k = self.to_k(encoder_hidden_states)
        v = self.to_v(encoder_hidden_states)

        # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        # reshape, not view: after the transpose the tensor may be
        # non-contiguous (this depends on the attention backend), and view
        # raises in that case.
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)

        hidden_states = self.to_out(hidden_states)

        return hidden_states
|
||||
320
diffsynth/models/sd_text_encoder.py
Normal file
320
diffsynth/models/sd_text_encoder.py
Normal file
@@ -0,0 +1,320 @@
|
||||
import torch
|
||||
from .attention import Attention
|
||||
|
||||
|
||||
class CLIPEncoderLayer(torch.nn.Module):
    """One pre-norm transformer layer: self-attention followed by a two-layer MLP."""

    def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
        super().__init__()
        self.attn = Attention(
            q_dim=embed_dim,
            num_heads=num_heads,
            head_dim=head_dim,
            bias_q=True,
            bias_kv=True,
            bias_out=True,
        )
        self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
        self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
        self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
        self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)

        self.use_quick_gelu = use_quick_gelu

    def quickGELU(self, x):
        """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""
        return x * torch.sigmoid(1.702 * x)

    def forward(self, hidden_states, attn_mask):
        """Apply the attention and MLP sub-blocks, each with a residual connection."""
        # Attention sub-block.
        attended = self.attn(self.layer_norm1(hidden_states), attn_mask=attn_mask)
        hidden_states = hidden_states + attended

        # MLP sub-block.
        projected = self.fc1(self.layer_norm2(hidden_states))
        if self.use_quick_gelu:
            activated = self.quickGELU(projected)
        else:
            activated = torch.nn.functional.gelu(projected)
        return hidden_states + self.fc2(activated)
|
||||
|
||||
|
||||
class SDTextEncoder(torch.nn.Module):
    """Text encoder: token + position embeddings, a stack of CLIPEncoderLayer, final LayerNorm."""

    def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
        super().__init__()

        # token_embedding
        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)

        # position_embeds (This is a fixed tensor)
        self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))

        # encoders
        layers = [CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)]
        self.encoders = torch.nn.ModuleList(layers)

        # attn_mask
        self.attn_mask = self.attention_mask(max_position_embeddings)

        # final_layer_norm
        self.final_layer_norm = torch.nn.LayerNorm(embed_dim)

    def attention_mask(self, length):
        """Return a (length, length) mask: -inf strictly above the diagonal, 0 elsewhere."""
        return torch.full((length, length), float("-inf")).triu_(1)

    def forward(self, input_ids, clip_skip=1):
        """Encode token ids; ``clip_skip=k`` stops after the k-th layer from the end."""
        embeds = self.token_embedding(input_ids) + self.position_embeds
        attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
        # Index of the last layer to run (clip_skip=1 runs the whole stack).
        last_layer = len(self.encoders) - clip_skip
        for layer_index, layer in enumerate(self.encoders):
            embeds = layer(embeds, attn_mask=attn_mask)
            if layer_index == last_layer:
                break
        return self.final_layer_norm(embeds)

    def state_dict_converter(self):
        """Return the converter that maps external checkpoints onto this module."""
        return SDTextEncoderStateDictConverter()
|
||||
|
||||
|
||||
class SDTextEncoderStateDictConverter:
|
||||
def __init__(self):
    # The converter keeps no state; nothing to initialise.
    pass
|
||||
|
||||
def from_diffusers(self, state_dict):
    """Rename a ``text_model.*`` state dict to this encoder's parameter names.

    Keys that are neither one of the four top-level entries nor under
    ``text_model.encoder.layers.`` are dropped. The position embedding gets a
    leading dimension of size 1 to match ``position_embeds``.
    """
    position_key = "text_model.embeddings.position_embedding.weight"
    top_level_names = {
        "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
        position_key: "position_embeds",
        "text_model.final_layer_norm.weight": "final_layer_norm.weight",
        "text_model.final_layer_norm.bias": "final_layer_norm.bias",
    }
    layer_part_names = {
        "self_attn.q_proj": "attn.to_q",
        "self_attn.k_proj": "attn.to_k",
        "self_attn.v_proj": "attn.to_v",
        "self_attn.out_proj": "attn.to_out",
        "layer_norm1": "layer_norm1",
        "layer_norm2": "layer_norm2",
        "mlp.fc1": "fc1",
        "mlp.fc2": "fc2",
    }
    converted = {}
    for key, tensor in state_dict.items():
        if key in top_level_names:
            if key == position_key:
                tensor = tensor.reshape((1, tensor.shape[0], tensor.shape[1]))
            converted[top_level_names[key]] = tensor
        elif key.startswith("text_model.encoder.layers."):
            # text_model.encoder.layers.<id>.<layer part>.<weight|bias>
            pieces = key.split(".")
            layer_id, suffix = pieces[3], pieces[-1]
            layer_part = ".".join(pieces[4:-1])
            new_key = ".".join(["encoders", layer_id, layer_part_names[layer_part], suffix])
            converted[new_key] = tensor
    return converted
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
"cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "encoders.11.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias",
|
||||
"cond_stage_model.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight",
|
||||
"cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "position_embeds"
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight":
|
||||
param = param.reshape((1, param.shape[0], param.shape[1]))
|
||||
state_dict_[rename_dict[name]] = param
|
||||
return state_dict_
|
||||
1090
diffsynth/models/sd_unet.py
Normal file
1090
diffsynth/models/sd_unet.py
Normal file
File diff suppressed because it is too large
Load Diff
330
diffsynth/models/sd_vae_decoder.py
Normal file
330
diffsynth/models/sd_vae_decoder.py
Normal file
@@ -0,0 +1,330 @@
|
||||
import torch
|
||||
from .attention import Attention
|
||||
from .sd_unet import ResnetBlock, UpSampler
|
||||
from .tiler import Tiler
|
||||
|
||||
|
||||
class VAEAttentionBlock(torch.nn.Module):
    """Group-norm followed by ``Attention`` layers over flattened pixels, with a residual add.

    The ``forward`` signature carries ``time_emb``, ``text_emb`` and
    ``res_stack`` through unchanged so the block can be called the same way
    as the other entries of ``SDVAEDecoder.blocks``.
    """

    def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
        """Create the normalisation layer and ``num_layers`` attention layers.

        Args:
            num_attention_heads: Passed to each ``Attention`` layer; also
                multiplied by ``attention_head_dim`` to get its first argument.
            attention_head_dim: Passed to each ``Attention`` layer.
            in_channels: Channel count normalised by the ``GroupNorm``.
            num_layers: Number of ``Attention`` layers applied in sequence.
            norm_num_groups: ``num_groups`` of the ``GroupNorm``.
            eps: ``eps`` of the ``GroupNorm``.
        """
        super().__init__()
        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)

        attention_width = num_attention_heads * attention_head_dim
        layers = []
        for _ in range(num_layers):
            layers.append(
                Attention(
                    attention_width,
                    num_attention_heads,
                    attention_head_dim,
                    bias_q=True,
                    bias_kv=True,
                    bias_out=True
                )
            )
        self.transformer_blocks = torch.nn.ModuleList(layers)

    def forward(self, hidden_states, time_emb, text_emb, res_stack):
        """Apply norm + attention to a 4-D input and add the input back.

        Args:
            hidden_states: 4-D tensor unpacked as (batch, channels, height, width).
            time_emb, text_emb, res_stack: Returned unchanged.

        Returns:
            Tuple ``(hidden_states, time_emb, text_emb, res_stack)`` where
            the first item has the same shape as the input.
        """
        batch, _, height, width = hidden_states.shape

        x = self.norm(hidden_states)
        channels = x.shape[1]
        # (batch, channels, height, width) -> (batch, height * width, channels)
        x = x.permute(0, 2, 3, 1).reshape(batch, height * width, channels)

        for attention in self.transformer_blocks:
            x = attention(x)

        # Back to (batch, channels, height, width).
        x = x.reshape(batch, height, width, channels).permute(0, 3, 1, 2).contiguous()

        return x + hidden_states, time_emb, text_emb, res_stack
|
||||
|
||||
|
||||
class SDVAEDecoder(torch.nn.Module):
    """VAE decoder assembled from ``ResnetBlock``, ``VAEAttentionBlock`` and ``UpSampler``.

    The input is divided by ``scaling_factor``, passed through two input
    convolutions and the block list, then through a group-norm / SiLU /
    convolution head that outputs 3 channels.
    """

    def __init__(self):
        super().__init__()
        self.scaling_factor = 0.18215
        self.post_quant_conv = torch.nn.Conv2d(4, 4, kernel_size=1)
        self.conv_in = torch.nn.Conv2d(4, 512, kernel_size=3, padding=1)

        # UNetMidBlock2D
        layers = [
            ResnetBlock(512, 512, eps=1e-6),
            VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
        ]
        # Four UpDecoderBlock2D stages: (in_channels, out_channels, ends with UpSampler).
        stages = (
            (512, 512, True),
            (512, 512, True),
            (512, 256, True),
            (256, 128, False),
        )
        for in_channels, out_channels, has_upsampler in stages:
            layers.append(ResnetBlock(in_channels, out_channels, eps=1e-6))
            layers.append(ResnetBlock(out_channels, out_channels, eps=1e-6))
            layers.append(ResnetBlock(out_channels, out_channels, eps=1e-6))
            if has_upsampler:
                layers.append(UpSampler(out_channels))
        self.blocks = torch.nn.ModuleList(layers)

        self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-5)
        self.conv_act = torch.nn.SiLU()
        self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)

    def tiled_forward(self, sample, tile_size=64, tile_stride=32):
        """Run ``forward`` through a ``Tiler`` with the given tile size and stride."""
        tiler = Tiler()
        return tiler(
            lambda x: self.forward(x),
            sample,
            tile_size,
            tile_stride
        )

    def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
        """Decode ``sample``.

        Args:
            sample: Input tensor; divided by ``self.scaling_factor`` first.
            tiled: When true, delegate to ``tiled_forward``.
            tile_size, tile_stride: Forwarded to ``tiled_forward``.
            **kwargs: Accepted and ignored.

        Returns:
            The output of ``self.conv_out``.
        """
        # For VAE Decoder, we do not need to apply the tiler on each layer.
        if tiled:
            return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)

        # 1. pre-process
        hidden_states = self.conv_in(self.post_quant_conv(sample / self.scaling_factor))
        # The blocks share a four-argument call signature; the decoder has
        # nothing to feed into the last three slots.
        time_emb = text_emb = res_stack = None

        # 2. blocks
        for block in self.blocks:
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)

        # 3. output
        return self.conv_out(self.conv_act(self.conv_norm_out(hidden_states)))

    def state_dict_converter(self):
        """Return a fresh ``SDVAEDecoderStateDictConverter`` instance."""
        converter = SDVAEDecoderStateDictConverter()
        return converter
|
||||
|
||||
|
||||
class SDVAEDecoderStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
    """Rename decoder parameters to ``blocks.<i>...`` style names.

    Names whose prefix (everything before the last dot) is in the fixed
    table are renamed directly. Names starting with ``decoder.up_blocks``
    are assigned block indices by walking the sorted names and advancing a
    per-type cursor through ``block_types``. All other entries are dropped.

    Args:
        state_dict: Mapping from parameter name to parameter.

    Returns:
        A new dict with the renamed keys, in the iteration order of
        ``state_dict``.

    Raises:
        KeyError: If a ``decoder.up_blocks`` name has a fourth component
            other than ``resnets``, ``downsamplers`` or ``upsamplers``.
    """
    # architecture
    block_types = (
        ['ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock']
        + ['ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler'] * 3
        + ['ResnetBlock', 'ResnetBlock', 'ResnetBlock']
    )

    # Prefixes with a fixed target, independent of the up-block walk.
    fixed_targets = {
        "post_quant_conv": "post_quant_conv",
        "decoder.conv_in": "conv_in",
        "decoder.mid_block.attentions.0.group_norm": "blocks.1.norm",
        "decoder.mid_block.attentions.0.to_q": "blocks.1.transformer_blocks.0.to_q",
        "decoder.mid_block.attentions.0.to_k": "blocks.1.transformer_blocks.0.to_k",
        "decoder.mid_block.attentions.0.to_v": "blocks.1.transformer_blocks.0.to_v",
        "decoder.mid_block.attentions.0.to_out.0": "blocks.1.transformer_blocks.0.to_out",
        "decoder.mid_block.resnets.0.norm1": "blocks.0.norm1",
        "decoder.mid_block.resnets.0.conv1": "blocks.0.conv1",
        "decoder.mid_block.resnets.0.norm2": "blocks.0.norm2",
        "decoder.mid_block.resnets.0.conv2": "blocks.0.conv2",
        "decoder.mid_block.resnets.1.norm1": "blocks.2.norm1",
        "decoder.mid_block.resnets.1.conv1": "blocks.2.conv1",
        "decoder.mid_block.resnets.1.norm2": "blocks.2.norm2",
        "decoder.mid_block.resnets.1.conv2": "blocks.2.conv2",
        "decoder.conv_norm_out": "conv_norm_out",
        "decoder.conv_out": "conv_out",
    }
    kind_to_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}

    # Per block type: the index last handed out. Starting at 2 (the last
    # fixed mid-block entry) makes the first increment land on index 3.
    cursor = {"ResnetBlock": 2, "DownSampler": 2, "UpSampler": 2}
    # Per block type: the source block (first five name components) seen last.
    previous_source = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""}

    # Rename each parameter. Sorting groups all parameters of one source
    # block together, which the cursor logic below depends on.
    renames = {}
    for name in sorted(state_dict):
        parts = name.split(".")
        prefix = ".".join(parts[:-1])
        if prefix in fixed_targets:
            renames[name] = fixed_targets[prefix] + "." + parts[-1]
            continue
        if not name.startswith("decoder.up_blocks"):
            continue

        block_type = kind_to_type[parts[3]]
        source_block = ".".join(parts[:5])
        if source_block != previous_source[block_type]:
            cursor[block_type] += 1
        previous_source[block_type] = source_block

        # Skip forward to the next slot of the matching type.
        position = cursor[block_type]
        while position < len(block_types) and block_types[position] != block_type:
            position += 1
        cursor[block_type] = position

        renames[name] = ".".join(["blocks", str(position)] + parts[5:])

    # Convert state_dict
    return {
        renames[name]: param
        for name, param in state_dict.items()
        if name in renames
    }
|
||||
|
||||
def from_civitai(self, state_dict):
    """Translate a single-file checkpoint's VAE decoder weights to this module's names.

    Only keys listed in the table below are kept; every other entry of
    ``state_dict`` is dropped. Parameters that land in a ``transformer_blocks``
    sub-module are passed through ``squeeze()`` before being stored.

    Args:
        state_dict: mapping from checkpoint parameter name to parameter.

    Returns:
        A new dict keyed by the renamed parameter names, in the iteration
        order of ``state_dict``.
    """
    rename_dict = {
        "first_stage_model.decoder.conv_in.bias": "conv_in.bias",
        "first_stage_model.decoder.conv_in.weight": "conv_in.weight",
        "first_stage_model.decoder.conv_out.bias": "conv_out.bias",
        "first_stage_model.decoder.conv_out.weight": "conv_out.weight",
        "first_stage_model.decoder.mid.attn_1.k.bias": "blocks.1.transformer_blocks.0.to_k.bias",
        "first_stage_model.decoder.mid.attn_1.k.weight": "blocks.1.transformer_blocks.0.to_k.weight",
        "first_stage_model.decoder.mid.attn_1.norm.bias": "blocks.1.norm.bias",
        "first_stage_model.decoder.mid.attn_1.norm.weight": "blocks.1.norm.weight",
        "first_stage_model.decoder.mid.attn_1.proj_out.bias": "blocks.1.transformer_blocks.0.to_out.bias",
        "first_stage_model.decoder.mid.attn_1.proj_out.weight": "blocks.1.transformer_blocks.0.to_out.weight",
        "first_stage_model.decoder.mid.attn_1.q.bias": "blocks.1.transformer_blocks.0.to_q.bias",
        "first_stage_model.decoder.mid.attn_1.q.weight": "blocks.1.transformer_blocks.0.to_q.weight",
        "first_stage_model.decoder.mid.attn_1.v.bias": "blocks.1.transformer_blocks.0.to_v.bias",
        "first_stage_model.decoder.mid.attn_1.v.weight": "blocks.1.transformer_blocks.0.to_v.weight",
        "first_stage_model.decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
        "first_stage_model.decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
        "first_stage_model.decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
        "first_stage_model.decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
        "first_stage_model.decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
        "first_stage_model.decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
        "first_stage_model.decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
        "first_stage_model.decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
        "first_stage_model.decoder.mid.block_2.conv1.bias": "blocks.2.conv1.bias",
        "first_stage_model.decoder.mid.block_2.conv1.weight": "blocks.2.conv1.weight",
        "first_stage_model.decoder.mid.block_2.conv2.bias": "blocks.2.conv2.bias",
        "first_stage_model.decoder.mid.block_2.conv2.weight": "blocks.2.conv2.weight",
        "first_stage_model.decoder.mid.block_2.norm1.bias": "blocks.2.norm1.bias",
        "first_stage_model.decoder.mid.block_2.norm1.weight": "blocks.2.norm1.weight",
        "first_stage_model.decoder.mid.block_2.norm2.bias": "blocks.2.norm2.bias",
        "first_stage_model.decoder.mid.block_2.norm2.weight": "blocks.2.norm2.weight",
        "first_stage_model.decoder.norm_out.bias": "conv_norm_out.bias",
        "first_stage_model.decoder.norm_out.weight": "conv_norm_out.weight",
        "first_stage_model.decoder.up.0.block.0.conv1.bias": "blocks.15.conv1.bias",
        "first_stage_model.decoder.up.0.block.0.conv1.weight": "blocks.15.conv1.weight",
        "first_stage_model.decoder.up.0.block.0.conv2.bias": "blocks.15.conv2.bias",
        "first_stage_model.decoder.up.0.block.0.conv2.weight": "blocks.15.conv2.weight",
        "first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "blocks.15.conv_shortcut.bias",
        "first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "blocks.15.conv_shortcut.weight",
        "first_stage_model.decoder.up.0.block.0.norm1.bias": "blocks.15.norm1.bias",
        "first_stage_model.decoder.up.0.block.0.norm1.weight": "blocks.15.norm1.weight",
        "first_stage_model.decoder.up.0.block.0.norm2.bias": "blocks.15.norm2.bias",
        "first_stage_model.decoder.up.0.block.0.norm2.weight": "blocks.15.norm2.weight",
        "first_stage_model.decoder.up.0.block.1.conv1.bias": "blocks.16.conv1.bias",
        "first_stage_model.decoder.up.0.block.1.conv1.weight": "blocks.16.conv1.weight",
        "first_stage_model.decoder.up.0.block.1.conv2.bias": "blocks.16.conv2.bias",
        "first_stage_model.decoder.up.0.block.1.conv2.weight": "blocks.16.conv2.weight",
        "first_stage_model.decoder.up.0.block.1.norm1.bias": "blocks.16.norm1.bias",
        "first_stage_model.decoder.up.0.block.1.norm1.weight": "blocks.16.norm1.weight",
        "first_stage_model.decoder.up.0.block.1.norm2.bias": "blocks.16.norm2.bias",
        "first_stage_model.decoder.up.0.block.1.norm2.weight": "blocks.16.norm2.weight",
        "first_stage_model.decoder.up.0.block.2.conv1.bias": "blocks.17.conv1.bias",
        "first_stage_model.decoder.up.0.block.2.conv1.weight": "blocks.17.conv1.weight",
        "first_stage_model.decoder.up.0.block.2.conv2.bias": "blocks.17.conv2.bias",
        "first_stage_model.decoder.up.0.block.2.conv2.weight": "blocks.17.conv2.weight",
        "first_stage_model.decoder.up.0.block.2.norm1.bias": "blocks.17.norm1.bias",
        "first_stage_model.decoder.up.0.block.2.norm1.weight": "blocks.17.norm1.weight",
        "first_stage_model.decoder.up.0.block.2.norm2.bias": "blocks.17.norm2.bias",
        "first_stage_model.decoder.up.0.block.2.norm2.weight": "blocks.17.norm2.weight",
        "first_stage_model.decoder.up.1.block.0.conv1.bias": "blocks.11.conv1.bias",
        "first_stage_model.decoder.up.1.block.0.conv1.weight": "blocks.11.conv1.weight",
        "first_stage_model.decoder.up.1.block.0.conv2.bias": "blocks.11.conv2.bias",
        "first_stage_model.decoder.up.1.block.0.conv2.weight": "blocks.11.conv2.weight",
        "first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "blocks.11.conv_shortcut.bias",
        "first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "blocks.11.conv_shortcut.weight",
        "first_stage_model.decoder.up.1.block.0.norm1.bias": "blocks.11.norm1.bias",
        "first_stage_model.decoder.up.1.block.0.norm1.weight": "blocks.11.norm1.weight",
        "first_stage_model.decoder.up.1.block.0.norm2.bias": "blocks.11.norm2.bias",
        "first_stage_model.decoder.up.1.block.0.norm2.weight": "blocks.11.norm2.weight",
        "first_stage_model.decoder.up.1.block.1.conv1.bias": "blocks.12.conv1.bias",
        "first_stage_model.decoder.up.1.block.1.conv1.weight": "blocks.12.conv1.weight",
        "first_stage_model.decoder.up.1.block.1.conv2.bias": "blocks.12.conv2.bias",
        "first_stage_model.decoder.up.1.block.1.conv2.weight": "blocks.12.conv2.weight",
        "first_stage_model.decoder.up.1.block.1.norm1.bias": "blocks.12.norm1.bias",
        "first_stage_model.decoder.up.1.block.1.norm1.weight": "blocks.12.norm1.weight",
        "first_stage_model.decoder.up.1.block.1.norm2.bias": "blocks.12.norm2.bias",
        "first_stage_model.decoder.up.1.block.1.norm2.weight": "blocks.12.norm2.weight",
        "first_stage_model.decoder.up.1.block.2.conv1.bias": "blocks.13.conv1.bias",
        "first_stage_model.decoder.up.1.block.2.conv1.weight": "blocks.13.conv1.weight",
        "first_stage_model.decoder.up.1.block.2.conv2.bias": "blocks.13.conv2.bias",
        "first_stage_model.decoder.up.1.block.2.conv2.weight": "blocks.13.conv2.weight",
        "first_stage_model.decoder.up.1.block.2.norm1.bias": "blocks.13.norm1.bias",
        "first_stage_model.decoder.up.1.block.2.norm1.weight": "blocks.13.norm1.weight",
        "first_stage_model.decoder.up.1.block.2.norm2.bias": "blocks.13.norm2.bias",
        "first_stage_model.decoder.up.1.block.2.norm2.weight": "blocks.13.norm2.weight",
        "first_stage_model.decoder.up.1.upsample.conv.bias": "blocks.14.conv.bias",
        "first_stage_model.decoder.up.1.upsample.conv.weight": "blocks.14.conv.weight",
        "first_stage_model.decoder.up.2.block.0.conv1.bias": "blocks.7.conv1.bias",
        "first_stage_model.decoder.up.2.block.0.conv1.weight": "blocks.7.conv1.weight",
        "first_stage_model.decoder.up.2.block.0.conv2.bias": "blocks.7.conv2.bias",
        "first_stage_model.decoder.up.2.block.0.conv2.weight": "blocks.7.conv2.weight",
        "first_stage_model.decoder.up.2.block.0.norm1.bias": "blocks.7.norm1.bias",
        "first_stage_model.decoder.up.2.block.0.norm1.weight": "blocks.7.norm1.weight",
        "first_stage_model.decoder.up.2.block.0.norm2.bias": "blocks.7.norm2.bias",
        "first_stage_model.decoder.up.2.block.0.norm2.weight": "blocks.7.norm2.weight",
        "first_stage_model.decoder.up.2.block.1.conv1.bias": "blocks.8.conv1.bias",
        "first_stage_model.decoder.up.2.block.1.conv1.weight": "blocks.8.conv1.weight",
        "first_stage_model.decoder.up.2.block.1.conv2.bias": "blocks.8.conv2.bias",
        "first_stage_model.decoder.up.2.block.1.conv2.weight": "blocks.8.conv2.weight",
        "first_stage_model.decoder.up.2.block.1.norm1.bias": "blocks.8.norm1.bias",
        "first_stage_model.decoder.up.2.block.1.norm1.weight": "blocks.8.norm1.weight",
        "first_stage_model.decoder.up.2.block.1.norm2.bias": "blocks.8.norm2.bias",
        "first_stage_model.decoder.up.2.block.1.norm2.weight": "blocks.8.norm2.weight",
        "first_stage_model.decoder.up.2.block.2.conv1.bias": "blocks.9.conv1.bias",
        "first_stage_model.decoder.up.2.block.2.conv1.weight": "blocks.9.conv1.weight",
        "first_stage_model.decoder.up.2.block.2.conv2.bias": "blocks.9.conv2.bias",
        "first_stage_model.decoder.up.2.block.2.conv2.weight": "blocks.9.conv2.weight",
        "first_stage_model.decoder.up.2.block.2.norm1.bias": "blocks.9.norm1.bias",
        "first_stage_model.decoder.up.2.block.2.norm1.weight": "blocks.9.norm1.weight",
        "first_stage_model.decoder.up.2.block.2.norm2.bias": "blocks.9.norm2.bias",
        "first_stage_model.decoder.up.2.block.2.norm2.weight": "blocks.9.norm2.weight",
        "first_stage_model.decoder.up.2.upsample.conv.bias": "blocks.10.conv.bias",
        "first_stage_model.decoder.up.2.upsample.conv.weight": "blocks.10.conv.weight",
        "first_stage_model.decoder.up.3.block.0.conv1.bias": "blocks.3.conv1.bias",
        "first_stage_model.decoder.up.3.block.0.conv1.weight": "blocks.3.conv1.weight",
        "first_stage_model.decoder.up.3.block.0.conv2.bias": "blocks.3.conv2.bias",
        "first_stage_model.decoder.up.3.block.0.conv2.weight": "blocks.3.conv2.weight",
        "first_stage_model.decoder.up.3.block.0.norm1.bias": "blocks.3.norm1.bias",
        "first_stage_model.decoder.up.3.block.0.norm1.weight": "blocks.3.norm1.weight",
        "first_stage_model.decoder.up.3.block.0.norm2.bias": "blocks.3.norm2.bias",
        "first_stage_model.decoder.up.3.block.0.norm2.weight": "blocks.3.norm2.weight",
        "first_stage_model.decoder.up.3.block.1.conv1.bias": "blocks.4.conv1.bias",
        "first_stage_model.decoder.up.3.block.1.conv1.weight": "blocks.4.conv1.weight",
        "first_stage_model.decoder.up.3.block.1.conv2.bias": "blocks.4.conv2.bias",
        "first_stage_model.decoder.up.3.block.1.conv2.weight": "blocks.4.conv2.weight",
        "first_stage_model.decoder.up.3.block.1.norm1.bias": "blocks.4.norm1.bias",
        "first_stage_model.decoder.up.3.block.1.norm1.weight": "blocks.4.norm1.weight",
        "first_stage_model.decoder.up.3.block.1.norm2.bias": "blocks.4.norm2.bias",
        "first_stage_model.decoder.up.3.block.1.norm2.weight": "blocks.4.norm2.weight",
        "first_stage_model.decoder.up.3.block.2.conv1.bias": "blocks.5.conv1.bias",
        "first_stage_model.decoder.up.3.block.2.conv1.weight": "blocks.5.conv1.weight",
        "first_stage_model.decoder.up.3.block.2.conv2.bias": "blocks.5.conv2.bias",
        "first_stage_model.decoder.up.3.block.2.conv2.weight": "blocks.5.conv2.weight",
        "first_stage_model.decoder.up.3.block.2.norm1.bias": "blocks.5.norm1.bias",
        "first_stage_model.decoder.up.3.block.2.norm1.weight": "blocks.5.norm1.weight",
        "first_stage_model.decoder.up.3.block.2.norm2.bias": "blocks.5.norm2.bias",
        "first_stage_model.decoder.up.3.block.2.norm2.weight": "blocks.5.norm2.weight",
        "first_stage_model.decoder.up.3.upsample.conv.bias": "blocks.6.conv.bias",
        "first_stage_model.decoder.up.3.upsample.conv.weight": "blocks.6.conv.weight",
        "first_stage_model.post_quant_conv.bias": "post_quant_conv.bias",
        "first_stage_model.post_quant_conv.weight": "post_quant_conv.weight",
    }

    converted = {}
    for source_name in state_dict:
        target_name = rename_dict.get(source_name)
        if target_name is None:
            # Not a VAE decoder weight: leave it out of the result.
            continue
        tensor = state_dict[source_name]
        if "transformer_blocks" in target_name:
            # Attention projections have their size-1 dimensions removed.
            # NOTE(review): presumably these are stored as 1x1 conv kernels
            # in the checkpoint - confirm against a real file.
            tensor = tensor.squeeze()
        converted[target_name] = tensor
    return converted
|
||||
258
diffsynth/models/sd_vae_encoder.py
Normal file
258
diffsynth/models/sd_vae_encoder.py
Normal file
@@ -0,0 +1,258 @@
|
||||
import torch
|
||||
from .sd_unet import ResnetBlock, DownSampler
|
||||
from .sd_vae_decoder import VAEAttentionBlock
|
||||
from .tiler import Tiler
|
||||
|
||||
|
||||
class SDVAEEncoder(torch.nn.Module):
    """VAE encoder: stem conv, a flat list of blocks, output head and ``quant_conv``.

    ``forward`` keeps only the first four channels produced by ``quant_conv``
    and multiplies them by ``scaling_factor``.
    """

    def __init__(self):
        super().__init__()
        self.scaling_factor = 0.18215
        self.quant_conv = torch.nn.Conv2d(8, 8, kernel_size=1)
        self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)

        # Modules are created strictly in network order so that parameter
        # registration (and random initialisation) order is unchanged.
        stages = []
        # DownEncoderBlock2D
        stages.append(ResnetBlock(128, 128, eps=1e-6))
        stages.append(ResnetBlock(128, 128, eps=1e-6))
        stages.append(DownSampler(128, padding=0, extra_padding=True))
        # DownEncoderBlock2D
        stages.append(ResnetBlock(128, 256, eps=1e-6))
        stages.append(ResnetBlock(256, 256, eps=1e-6))
        stages.append(DownSampler(256, padding=0, extra_padding=True))
        # DownEncoderBlock2D
        stages.append(ResnetBlock(256, 512, eps=1e-6))
        stages.append(ResnetBlock(512, 512, eps=1e-6))
        stages.append(DownSampler(512, padding=0, extra_padding=True))
        # DownEncoderBlock2D
        stages.append(ResnetBlock(512, 512, eps=1e-6))
        stages.append(ResnetBlock(512, 512, eps=1e-6))
        # UNetMidBlock2D
        stages.append(ResnetBlock(512, 512, eps=1e-6))
        stages.append(VAEAttentionBlock(1, 512, 512, 1, eps=1e-6))
        stages.append(ResnetBlock(512, 512, eps=1e-6))
        self.blocks = torch.nn.ModuleList(stages)

        self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
        self.conv_act = torch.nn.SiLU()
        self.conv_out = torch.nn.Conv2d(512, 8, kernel_size=3, padding=1)

    def tiled_forward(self, sample, tile_size=64, tile_stride=32):
        """Run the plain (non-tiled) ``forward`` through ``Tiler``.

        ``Tiler`` receives a callable that encodes one input, the full
        ``sample``, and the tile size / stride.
        NOTE(review): how tiles are cut and merged is decided by ``Tiler``,
        which is defined elsewhere - confirm there.
        """

        def encode(tile):
            return self.forward(tile)

        return Tiler()(encode, sample, tile_size, tile_stride)

    def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
        """Encode ``sample`` and return the scaled first four output channels.

        Args:
            sample: input passed to ``conv_in``.
            tiled: when true, delegate the whole encoder to ``tiled_forward``.
            tile_size, tile_stride: forwarded to ``tiled_forward``.
            **kwargs: accepted and ignored.
        """
        # Tiling wraps the encoder as a whole rather than each layer.
        if tiled:
            return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)

        # 1. pre-process
        hidden_states = self.conv_in(sample)
        # Blocks share a 4-tuple calling convention; the encoder has no
        # time embedding, text embedding or residual stack to feed them.
        time_emb = text_emb = res_stack = None

        # 2. blocks
        for block in self.blocks:
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)

        # 3. output
        for layer in (self.conv_norm_out, self.conv_act, self.conv_out, self.quant_conv):
            hidden_states = layer(hidden_states)
        # Keep channels 0..3 (NOTE(review): presumably the latent mean - confirm)
        # and scale them in place, exactly as a view of the conv output.
        hidden_states = hidden_states[:, :4]
        hidden_states *= self.scaling_factor

        return hidden_states

    def state_dict_converter(self):
        """Return a converter that renames external checkpoint keys for this module."""
        return SDVAEEncoderStateDictConverter()
|
||||
|
||||
|
||||
class SDVAEEncoderStateDictConverter:
    """Renames VAE encoder parameters from external checkpoints to ``SDVAEEncoder`` names."""

    def __init__(self):
        """Create the converter; it keeps no state."""

    def from_diffusers(self, state_dict):
        """Convert a Diffusers-style VAE state dict.

        Parameters whose prefix is listed in the fixed table are renamed
        directly. ``encoder.down_blocks.*`` parameters are assigned to the next
        slot of matching type in the flat block layout, visiting names in
        sorted order. Everything else is dropped.

        Args:
            state_dict: mapping from parameter name to parameter.

        Returns:
            A new dict keyed by converted names, in ``state_dict`` order.
        """
        # Flat layout of SDVAEEncoder.blocks, by block type.
        layout = [
            'ResnetBlock', 'ResnetBlock', 'DownSampler',
            'ResnetBlock', 'ResnetBlock', 'DownSampler',
            'ResnetBlock', 'ResnetBlock', 'DownSampler',
            'ResnetBlock', 'ResnetBlock',
            'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock'
        ]

        # Prefix -> prefix renames for everything outside the down blocks.
        fixed_prefixes = {
            "quant_conv": "quant_conv",
            "encoder.conv_in": "conv_in",
            "encoder.mid_block.attentions.0.group_norm": "blocks.12.norm",
            "encoder.mid_block.attentions.0.to_q": "blocks.12.transformer_blocks.0.to_q",
            "encoder.mid_block.attentions.0.to_k": "blocks.12.transformer_blocks.0.to_k",
            "encoder.mid_block.attentions.0.to_v": "blocks.12.transformer_blocks.0.to_v",
            "encoder.mid_block.attentions.0.to_out.0": "blocks.12.transformer_blocks.0.to_out",
            "encoder.mid_block.resnets.0.norm1": "blocks.11.norm1",
            "encoder.mid_block.resnets.0.conv1": "blocks.11.conv1",
            "encoder.mid_block.resnets.0.norm2": "blocks.11.norm2",
            "encoder.mid_block.resnets.0.conv2": "blocks.11.conv2",
            "encoder.mid_block.resnets.1.norm1": "blocks.13.norm1",
            "encoder.mid_block.resnets.1.conv1": "blocks.13.conv1",
            "encoder.mid_block.resnets.1.norm2": "blocks.13.norm2",
            "encoder.mid_block.resnets.1.conv2": "blocks.13.conv2",
            "encoder.conv_norm_out": "conv_norm_out",
            "encoder.conv_out": "conv_out",
        }
        kind_of = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}

        # Per block type: index of the slot currently being filled, and the
        # source sub-module (first five name components) that occupies it.
        cursor = {"ResnetBlock": -1, "DownSampler": -1, "UpSampler": -1}
        current_group = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""}

        rename_dict = {}
        for name in sorted(state_dict):
            parts = name.split(".")
            prefix = ".".join(parts[:-1])
            if prefix in fixed_prefixes:
                rename_dict[name] = fixed_prefixes[prefix] + "." + parts[-1]
                continue
            if not name.startswith("encoder.down_blocks"):
                continue

            kind = kind_of[parts[3]]
            group = ".".join(parts[:5])
            if group != current_group[kind]:
                # A new source sub-module starts: move to the following slot.
                cursor[kind] += 1
                current_group[kind] = group
            # Skip slots that hold a different block type.
            while cursor[kind] < len(layout) and layout[cursor[kind]] != kind:
                cursor[kind] += 1
            rename_dict[name] = ".".join(["blocks", str(cursor[kind])] + parts[5:])

        # Convert state_dict
        return {
            rename_dict[name]: param
            for name, param in state_dict.items()
            if name in rename_dict
        }

    def from_civitai(self, state_dict):
        """Convert the VAE encoder part of a single-file checkpoint.

        Only keys present in the table are kept. Parameters mapped into a
        ``transformer_blocks`` sub-module are passed through ``squeeze()``.

        Args:
            state_dict: mapping from parameter name to parameter.

        Returns:
            A new dict keyed by converted names, in ``state_dict`` order.
        """
        rename_dict = {
            "first_stage_model.encoder.conv_in.bias": "conv_in.bias",
            "first_stage_model.encoder.conv_in.weight": "conv_in.weight",
            "first_stage_model.encoder.conv_out.bias": "conv_out.bias",
            "first_stage_model.encoder.conv_out.weight": "conv_out.weight",
            "first_stage_model.encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
            "first_stage_model.encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
            "first_stage_model.encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
            "first_stage_model.encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
            "first_stage_model.encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
            "first_stage_model.encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
            "first_stage_model.encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
            "first_stage_model.encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
            "first_stage_model.encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
            "first_stage_model.encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
            "first_stage_model.encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
            "first_stage_model.encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
            "first_stage_model.encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
            "first_stage_model.encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
            "first_stage_model.encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
            "first_stage_model.encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
            "first_stage_model.encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
            "first_stage_model.encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
            "first_stage_model.encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
            "first_stage_model.encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
            "first_stage_model.encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
            "first_stage_model.encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
            "first_stage_model.encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
            "first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
            "first_stage_model.encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
            "first_stage_model.encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
            "first_stage_model.encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
            "first_stage_model.encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
            "first_stage_model.encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
            "first_stage_model.encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
            "first_stage_model.encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
            "first_stage_model.encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
            "first_stage_model.encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
            "first_stage_model.encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
            "first_stage_model.encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
            "first_stage_model.encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
            "first_stage_model.encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
            "first_stage_model.encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
            "first_stage_model.encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
            "first_stage_model.encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
            "first_stage_model.encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
            "first_stage_model.encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
            "first_stage_model.encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
            "first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
            "first_stage_model.encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
            "first_stage_model.encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
            "first_stage_model.encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
            "first_stage_model.encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
            "first_stage_model.encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
            "first_stage_model.encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
            "first_stage_model.encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
            "first_stage_model.encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
            "first_stage_model.encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
            "first_stage_model.encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
            "first_stage_model.encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
            "first_stage_model.encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
            "first_stage_model.encoder.down.2.downs­ample.conv.bias".replace("\u00ad", ""): "blocks.8.conv.bias",
            "first_stage_model.encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
            "first_stage_model.encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
            "first_stage_model.encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
            "first_stage_model.encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
            "first_stage_model.encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
            "first_stage_model.encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
            "first_stage_model.encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
            "first_stage_model.encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
            "first_stage_model.encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
            "first_stage_model.encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
            "first_stage_model.encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
            "first_stage_model.encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
            "first_stage_model.encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
            "first_stage_model.encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
            "first_stage_model.encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
            "first_stage_model.encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
            "first_stage_model.encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
            "first_stage_model.encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
            "first_stage_model.encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
            "first_stage_model.encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
            "first_stage_model.encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
            "first_stage_model.encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
            "first_stage_model.encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
            "first_stage_model.encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
            "first_stage_model.encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
            "first_stage_model.encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
            "first_stage_model.encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
            "first_stage_model.encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
            "first_stage_model.encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
            "first_stage_model.encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
            "first_stage_model.encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
            "first_stage_model.encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
            "first_stage_model.encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
            "first_stage_model.encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
            "first_stage_model.encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
            "first_stage_model.encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
            "first_stage_model.encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
            "first_stage_model.encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
            "first_stage_model.encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
            "first_stage_model.encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
            "first_stage_model.encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
            "first_stage_model.encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
            "first_stage_model.encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
            "first_stage_model.encoder.norm_out.bias": "conv_norm_out.bias",
            "first_stage_model.encoder.norm_out.weight": "conv_norm_out.weight",
            "first_stage_model.quant_conv.bias": "quant_conv.bias",
            "first_stage_model.quant_conv.weight": "quant_conv.weight",
        }

        converted = {}
        for source_name in state_dict:
            target_name = rename_dict.get(source_name)
            if target_name is None:
                # Not a VAE encoder weight: leave it out of the result.
                continue
            tensor = state_dict[source_name]
            if "transformer_blocks" in target_name:
                # Attention projections have their size-1 dimensions removed.
                tensor = tensor.squeeze()
            converted[target_name] = tensor
        return converted
|
||||
757
diffsynth/models/sdxl_text_encoder.py
Normal file
757
diffsynth/models/sdxl_text_encoder.py
Normal file
@@ -0,0 +1,757 @@
|
||||
import torch
|
||||
from .sd_text_encoder import CLIPEncoderLayer
|
||||
|
||||
|
||||
class SDXLTextEncoder(torch.nn.Module):
    """Token + position embedding followed by a stack of ``CLIPEncoderLayer``.

    Unlike the Stable Diffusion 1.x text encoder, this module defines no
    ``final_layer_norm``; ``forward`` returns the raw output of the last
    layer that was run.
    """

    def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=11, encoder_intermediate_size=3072):
        super().__init__()

        # Lookup table for token ids.
        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)

        # Position embeddings, one row per position, stored as a Parameter
        # initialised to zeros (filled when weights are loaded).
        self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))

        # Transformer layers.
        layers = []
        for _ in range(num_encoder_layers):
            layers.append(CLIPEncoderLayer(embed_dim, encoder_intermediate_size))
        self.encoders = torch.nn.ModuleList(layers)

        # Mask handed to every layer; kept as a plain tensor attribute.
        self.attn_mask = self.attention_mask(max_position_embeddings)

    def attention_mask(self, length):
        """Return a ``length`` x ``length`` mask: ``-inf`` above the diagonal, 0 elsewhere."""
        return torch.empty(length, length).fill_(float("-inf")).triu_(1)

    def forward(self, input_ids, clip_skip=1):
        """Embed ``input_ids`` and run the encoder layers.

        Args:
            input_ids: token ids fed to ``token_embedding``; the result must
                be addable to ``position_embeds``.
            clip_skip: the loop stops after the layer whose index plus
                ``clip_skip`` equals the number of layers (1 runs them all).

        Returns:
            The output of the last layer that was run.
        """
        embeds = self.token_embedding(input_ids) + self.position_embeds
        mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
        stop_index = len(self.encoders) - clip_skip
        for index, layer in enumerate(self.encoders):
            embeds = layer(embeds, attn_mask=mask)
            if index == stop_index:
                break
        return embeds

    def state_dict_converter(self):
        """Return a converter that renames external checkpoint keys for this module."""
        return SDXLTextEncoderStateDictConverter()
|
||||
|
||||
|
||||
class SDXLTextEncoder2(torch.nn.Module):
    """Second SDXL text encoder: embeddings, ``CLIPEncoderLayer`` stack, final norm and projection.

    ``forward`` returns both a pooled, projected embedding and the hidden
    states captured ``clip_skip`` layers from the end of the stack.
    """

    def __init__(self, embed_dim=1280, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=32, encoder_intermediate_size=5120):
        super().__init__()

        # token_embedding
        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)

        # position_embeds: a Parameter initialised to zeros (filled when weights are loaded)
        self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))

        # encoders
        self.encoders = torch.nn.ModuleList([
            CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=20, head_dim=64, use_quick_gelu=False)
            for _ in range(num_encoder_layers)
        ])

        # attn_mask
        self.attn_mask = self.attention_mask(max_position_embeddings)

        # final_layer_norm
        self.final_layer_norm = torch.nn.LayerNorm(embed_dim)

        # text_projection
        self.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False)

    def attention_mask(self, length):
        """Return a ``length`` x ``length`` mask: ``-inf`` above the diagonal, 0 elsewhere."""
        mask = torch.empty(length, length)
        mask.fill_(float("-inf"))
        mask.triu_(1)
        return mask

    def forward(self, input_ids, clip_skip=2):
        """Encode ``input_ids``.

        All layers are always run; the output of the layer whose index plus
        ``clip_skip`` equals the number of layers is additionally kept as
        ``hidden_states``.

        Args:
            input_ids: token ids fed to ``token_embedding``; the result must
                be addable to ``position_embeds``.
            clip_skip: which layer output to expose, counted from the end
                (1 = last layer). Must be in ``[1, number of layers]``.

        Returns:
            ``(pooled_embeds, hidden_states)``: the projected, layer-normed
            embedding taken at the position of the largest token id in each
            sequence, and the captured intermediate hidden states.

        Raises:
            ValueError: if ``clip_skip`` does not select any layer. Previously
                this surfaced as an ``UnboundLocalError`` at ``return``.
        """
        num_layers = len(self.encoders)
        if not 1 <= clip_skip <= num_layers:
            raise ValueError(f"clip_skip must be between 1 and {num_layers}, got {clip_skip}")

        embeds = self.token_embedding(input_ids) + self.position_embeds
        attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
        for encoder_id, encoder in enumerate(self.encoders):
            embeds = encoder(embeds, attn_mask=attn_mask)
            if encoder_id + clip_skip == num_layers:
                hidden_states = embeds
        embeds = self.final_layer_norm(embeds)
        # Pool at the index of the largest token id per sequence
        # (NOTE(review): presumably the end-of-text token - confirm with the tokenizer).
        pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
        pooled_embeds = self.text_projection(pooled_embeds)
        return pooled_embeds, hidden_states

    def state_dict_converter(self):
        """Return a converter that renames external checkpoint keys for this module."""
        return SDXLTextEncoder2StateDictConverter()
|
||||
|
||||
|
||||
class SDXLTextEncoderStateDictConverter:
    """Rename checkpoint keys to the ``token_embedding`` / ``position_embeds`` / ``encoders.N.*`` layout.

    Two source key layouts are supported, one per method. Each method returns a
    new dict and does not modify the input; values are passed through as-is
    except for the position-embedding table, which gains a leading axis of
    size 1.
    """

    # Sub-module names inside one encoder layer: source name -> converted name.
    _LAYER_RENAME = {
        "self_attn.q_proj": "attn.to_q",
        "self_attn.k_proj": "attn.to_k",
        "self_attn.v_proj": "attn.to_v",
        "self_attn.out_proj": "attn.to_out",
        "layer_norm1": "layer_norm1",
        "layer_norm2": "layer_norm2",
        "mlp.fc1": "fc1",
        "mlp.fc2": "fc2",
    }

    # Key prefix recognised by from_civitai().
    _CIVITAI_PREFIX = "conditioner.embedders.0.transformer.text_model."

    # from_civitai() converts encoder layers 0..10 only. Parameters of any
    # other layer index, and final_layer_norm parameters, are dropped.
    _CIVITAI_NUM_LAYERS = 11

    def __init__(self):
        pass

    @staticmethod
    def _with_leading_axis(param):
        """Reshape a 2-D ``(rows, cols)`` table to ``(1, rows, cols)``."""
        return param.reshape((1, param.shape[0], param.shape[1]))

    @classmethod
    def _civitai_rename_dict(cls):
        """Build the full source-key -> converted-key table used by from_civitai()."""
        prefix = cls._CIVITAI_PREFIX
        rename_dict = {
            prefix + "embeddings.position_embedding.weight": "position_embeds",
            prefix + "embeddings.token_embedding.weight": "token_embedding.weight",
        }
        for layer_id in range(cls._CIVITAI_NUM_LAYERS):
            for source, target in cls._LAYER_RENAME.items():
                for tail in ("weight", "bias"):
                    source_key = f"{prefix}encoder.layers.{layer_id}.{source}.{tail}"
                    rename_dict[source_key] = f"encoders.{layer_id}.{target}.{tail}"
        return rename_dict

    def from_diffusers(self, state_dict):
        """Convert keys that use the ``text_model.*`` layout.

        Embedding and ``final_layer_norm`` keys are renamed through a fixed
        table. Keys under ``text_model.encoder.layers.<id>.`` are renamed for
        any layer id. All other keys are dropped.

        Raises:
            KeyError: if a key under ``text_model.encoder.layers.`` names a
                sub-module that is not in the per-layer rename table.
        """
        position_key = "text_model.embeddings.position_embedding.weight"
        rename_dict = {
            "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
            position_key: "position_embeds",
            "text_model.final_layer_norm.weight": "final_layer_norm.weight",
            "text_model.final_layer_norm.bias": "final_layer_norm.bias",
        }
        state_dict_ = {}
        for name in state_dict:
            if name in rename_dict:
                param = state_dict[name]
                if name == position_key:
                    param = self._with_leading_axis(param)
                state_dict_[rename_dict[name]] = param
            elif name.startswith("text_model.encoder.layers."):
                param = state_dict[name]
                # text_model . encoder . layers . <id> . <sub-module...> . <weight|bias>
                names = name.split(".")
                layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
                name_ = ".".join(["encoders", layer_id, self._LAYER_RENAME[layer_type], tail])
                state_dict_[name_] = param
        return state_dict_

    def from_civitai(self, state_dict):
        """Convert keys that use the ``conditioner.embedders.0.transformer.text_model.*`` layout.

        Only the two embedding tables and the parameters of encoder layers
        0..10 are converted; every other key is dropped.
        """
        rename_dict = self._civitai_rename_dict()
        position_key = self._CIVITAI_PREFIX + "embeddings.position_embedding.weight"
        state_dict_ = {}
        for name in state_dict:
            if name not in rename_dict:
                continue
            param = state_dict[name]
            if name == position_key:
                param = self._with_leading_axis(param)
            state_dict_[rename_dict[name]] = param
        return state_dict_
|
||||
|
||||
|
||||
class SDXLTextEncoder2StateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
    """Convert keys that use the ``text_model.*`` / ``text_projection.weight`` layout.

    Embedding, ``final_layer_norm`` and ``text_projection`` keys are renamed
    through a fixed table; the position-embedding table additionally gains a
    leading axis of size 1. Keys under ``text_model.encoder.layers.<id>.`` are
    renamed to ``encoders.<id>.*`` for any layer id. Every other key is
    dropped. The input dict is not modified.

    Raises:
        KeyError: if a key under ``text_model.encoder.layers.`` names a
            sub-module that is not in the per-layer table below.
    """
    position_key = "text_model.embeddings.position_embedding.weight"
    layer_prefix = "text_model.encoder.layers."

    # Keys that map one-to-one onto a converted name.
    direct_names = {
        "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
        "text_model.embeddings.position_embedding.weight": "position_embeds",
        "text_model.final_layer_norm.weight": "final_layer_norm.weight",
        "text_model.final_layer_norm.bias": "final_layer_norm.bias",
        "text_projection.weight": "text_projection.weight",
    }
    # Sub-module names inside one encoder layer.
    sublayer_names = {
        "self_attn.q_proj": "attn.to_q",
        "self_attn.k_proj": "attn.to_k",
        "self_attn.v_proj": "attn.to_v",
        "self_attn.out_proj": "attn.to_out",
        "layer_norm1": "layer_norm1",
        "layer_norm2": "layer_norm2",
        "mlp.fc1": "fc1",
        "mlp.fc2": "fc2",
    }

    converted = {}
    for key in state_dict:
        if key in direct_names:
            tensor = state_dict[key]
            if key == position_key:
                rows, cols = tensor.shape[0], tensor.shape[1]
                tensor = tensor.reshape((1, rows, cols))
            converted[direct_names[key]] = tensor
            continue

        if not key.startswith(layer_prefix):
            continue

        tensor = state_dict[key]
        # text_model . encoder . layers . <id> . <sub-module...> . <weight|bias>
        pieces = key.split(".")
        layer_id = pieces[3]
        sublayer = ".".join(pieces[4:-1])
        tail = pieces[-1]
        converted[".".join(["encoders", layer_id, sublayer_names[sublayer], tail])] = tensor
    return converted
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
"conditioner.embedders.1.model.ln_final.bias": "final_layer_norm.bias",
|
||||
"conditioner.embedders.1.model.ln_final.weight": "final_layer_norm.weight",
|
||||
"conditioner.embedders.1.model.positional_embedding": "position_embeds",
|
||||
"conditioner.embedders.1.model.token_embedding.weight": "token_embedding.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias": ['encoders.0.attn.to_q.bias', 'encoders.0.attn.to_k.bias', 'encoders.0.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight": ['encoders.0.attn.to_q.weight', 'encoders.0.attn.to_k.weight', 'encoders.0.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.bias": "encoders.0.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.weight": "encoders.0.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias": "encoders.0.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.ln_1.weight": "encoders.0.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.ln_2.bias": "encoders.0.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.ln_2.weight": "encoders.0.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.bias": "encoders.0.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.weight": "encoders.0.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.bias": "encoders.0.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.weight": "encoders.0.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias": ['encoders.1.attn.to_q.bias', 'encoders.1.attn.to_k.bias', 'encoders.1.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight": ['encoders.1.attn.to_q.weight', 'encoders.1.attn.to_k.weight', 'encoders.1.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.bias": "encoders.1.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.weight": "encoders.1.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.ln_1.bias": "encoders.1.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.ln_1.weight": "encoders.1.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.ln_2.bias": "encoders.1.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.ln_2.weight": "encoders.1.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.bias": "encoders.1.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.weight": "encoders.1.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.bias": "encoders.1.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.weight": "encoders.1.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias": ['encoders.10.attn.to_q.bias', 'encoders.10.attn.to_k.bias', 'encoders.10.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight": ['encoders.10.attn.to_q.weight', 'encoders.10.attn.to_k.weight', 'encoders.10.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.bias": "encoders.10.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.weight": "encoders.10.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.ln_1.bias": "encoders.10.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.ln_1.weight": "encoders.10.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.ln_2.bias": "encoders.10.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.ln_2.weight": "encoders.10.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.bias": "encoders.10.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.weight": "encoders.10.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.bias": "encoders.10.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.weight": "encoders.10.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias": ['encoders.11.attn.to_q.bias', 'encoders.11.attn.to_k.bias', 'encoders.11.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight": ['encoders.11.attn.to_q.weight', 'encoders.11.attn.to_k.weight', 'encoders.11.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.bias": "encoders.11.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.weight": "encoders.11.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.ln_1.bias": "encoders.11.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.ln_1.weight": "encoders.11.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.ln_2.bias": "encoders.11.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.ln_2.weight": "encoders.11.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.bias": "encoders.11.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.weight": "encoders.11.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.bias": "encoders.11.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.weight": "encoders.11.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias": ['encoders.12.attn.to_q.bias', 'encoders.12.attn.to_k.bias', 'encoders.12.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight": ['encoders.12.attn.to_q.weight', 'encoders.12.attn.to_k.weight', 'encoders.12.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.bias": "encoders.12.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.weight": "encoders.12.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.ln_1.bias": "encoders.12.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.ln_1.weight": "encoders.12.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.ln_2.bias": "encoders.12.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.ln_2.weight": "encoders.12.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.bias": "encoders.12.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.weight": "encoders.12.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.bias": "encoders.12.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.weight": "encoders.12.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias": ['encoders.13.attn.to_q.bias', 'encoders.13.attn.to_k.bias', 'encoders.13.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight": ['encoders.13.attn.to_q.weight', 'encoders.13.attn.to_k.weight', 'encoders.13.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.bias": "encoders.13.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.weight": "encoders.13.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.ln_1.bias": "encoders.13.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.ln_1.weight": "encoders.13.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.ln_2.bias": "encoders.13.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.ln_2.weight": "encoders.13.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.bias": "encoders.13.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.weight": "encoders.13.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.bias": "encoders.13.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.weight": "encoders.13.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias": ['encoders.14.attn.to_q.bias', 'encoders.14.attn.to_k.bias', 'encoders.14.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight": ['encoders.14.attn.to_q.weight', 'encoders.14.attn.to_k.weight', 'encoders.14.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.bias": "encoders.14.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.weight": "encoders.14.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.ln_1.bias": "encoders.14.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.ln_1.weight": "encoders.14.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.ln_2.bias": "encoders.14.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.ln_2.weight": "encoders.14.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.bias": "encoders.14.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.weight": "encoders.14.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.bias": "encoders.14.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.weight": "encoders.14.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias": ['encoders.15.attn.to_q.bias', 'encoders.15.attn.to_k.bias', 'encoders.15.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight": ['encoders.15.attn.to_q.weight', 'encoders.15.attn.to_k.weight', 'encoders.15.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.bias": "encoders.15.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.weight": "encoders.15.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.ln_1.bias": "encoders.15.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.ln_1.weight": "encoders.15.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.ln_2.bias": "encoders.15.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.ln_2.weight": "encoders.15.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.bias": "encoders.15.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.weight": "encoders.15.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.bias": "encoders.15.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.weight": "encoders.15.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias": ['encoders.16.attn.to_q.bias', 'encoders.16.attn.to_k.bias', 'encoders.16.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight": ['encoders.16.attn.to_q.weight', 'encoders.16.attn.to_k.weight', 'encoders.16.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.bias": "encoders.16.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.weight": "encoders.16.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.ln_1.bias": "encoders.16.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.ln_1.weight": "encoders.16.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.ln_2.bias": "encoders.16.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.ln_2.weight": "encoders.16.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.bias": "encoders.16.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.weight": "encoders.16.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.bias": "encoders.16.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.weight": "encoders.16.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias": ['encoders.17.attn.to_q.bias', 'encoders.17.attn.to_k.bias', 'encoders.17.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight": ['encoders.17.attn.to_q.weight', 'encoders.17.attn.to_k.weight', 'encoders.17.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.bias": "encoders.17.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.weight": "encoders.17.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.ln_1.bias": "encoders.17.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.ln_1.weight": "encoders.17.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.ln_2.bias": "encoders.17.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.ln_2.weight": "encoders.17.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.bias": "encoders.17.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.weight": "encoders.17.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.bias": "encoders.17.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.weight": "encoders.17.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias": ['encoders.18.attn.to_q.bias', 'encoders.18.attn.to_k.bias', 'encoders.18.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight": ['encoders.18.attn.to_q.weight', 'encoders.18.attn.to_k.weight', 'encoders.18.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.bias": "encoders.18.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.weight": "encoders.18.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.ln_1.bias": "encoders.18.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.ln_1.weight": "encoders.18.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.ln_2.bias": "encoders.18.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.ln_2.weight": "encoders.18.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.bias": "encoders.18.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.weight": "encoders.18.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.bias": "encoders.18.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.weight": "encoders.18.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias": ['encoders.19.attn.to_q.bias', 'encoders.19.attn.to_k.bias', 'encoders.19.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight": ['encoders.19.attn.to_q.weight', 'encoders.19.attn.to_k.weight', 'encoders.19.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.bias": "encoders.19.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.weight": "encoders.19.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.ln_1.bias": "encoders.19.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.ln_1.weight": "encoders.19.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.ln_2.bias": "encoders.19.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.ln_2.weight": "encoders.19.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.bias": "encoders.19.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.weight": "encoders.19.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.bias": "encoders.19.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.weight": "encoders.19.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias": ['encoders.2.attn.to_q.bias', 'encoders.2.attn.to_k.bias', 'encoders.2.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight": ['encoders.2.attn.to_q.weight', 'encoders.2.attn.to_k.weight', 'encoders.2.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.bias": "encoders.2.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.weight": "encoders.2.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.ln_1.bias": "encoders.2.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.ln_1.weight": "encoders.2.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.ln_2.bias": "encoders.2.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.ln_2.weight": "encoders.2.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.bias": "encoders.2.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.weight": "encoders.2.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.bias": "encoders.2.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.weight": "encoders.2.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias": ['encoders.20.attn.to_q.bias', 'encoders.20.attn.to_k.bias', 'encoders.20.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight": ['encoders.20.attn.to_q.weight', 'encoders.20.attn.to_k.weight', 'encoders.20.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.bias": "encoders.20.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.weight": "encoders.20.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.ln_1.bias": "encoders.20.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.ln_1.weight": "encoders.20.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.ln_2.bias": "encoders.20.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.ln_2.weight": "encoders.20.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.bias": "encoders.20.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.weight": "encoders.20.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.bias": "encoders.20.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.weight": "encoders.20.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias": ['encoders.21.attn.to_q.bias', 'encoders.21.attn.to_k.bias', 'encoders.21.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight": ['encoders.21.attn.to_q.weight', 'encoders.21.attn.to_k.weight', 'encoders.21.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.bias": "encoders.21.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.weight": "encoders.21.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.ln_1.bias": "encoders.21.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.ln_1.weight": "encoders.21.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.ln_2.bias": "encoders.21.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.ln_2.weight": "encoders.21.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.bias": "encoders.21.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.weight": "encoders.21.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.bias": "encoders.21.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.weight": "encoders.21.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias": ['encoders.22.attn.to_q.bias', 'encoders.22.attn.to_k.bias', 'encoders.22.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight": ['encoders.22.attn.to_q.weight', 'encoders.22.attn.to_k.weight', 'encoders.22.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.bias": "encoders.22.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.weight": "encoders.22.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.ln_1.bias": "encoders.22.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.ln_1.weight": "encoders.22.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.ln_2.bias": "encoders.22.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.ln_2.weight": "encoders.22.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.bias": "encoders.22.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.weight": "encoders.22.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.bias": "encoders.22.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.weight": "encoders.22.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias": ['encoders.23.attn.to_q.bias', 'encoders.23.attn.to_k.bias', 'encoders.23.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight": ['encoders.23.attn.to_q.weight', 'encoders.23.attn.to_k.weight', 'encoders.23.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.bias": "encoders.23.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.weight": "encoders.23.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.ln_1.bias": "encoders.23.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.ln_1.weight": "encoders.23.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.ln_2.bias": "encoders.23.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.ln_2.weight": "encoders.23.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.bias": "encoders.23.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.weight": "encoders.23.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.bias": "encoders.23.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.weight": "encoders.23.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias": ['encoders.24.attn.to_q.bias', 'encoders.24.attn.to_k.bias', 'encoders.24.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight": ['encoders.24.attn.to_q.weight', 'encoders.24.attn.to_k.weight', 'encoders.24.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.bias": "encoders.24.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.weight": "encoders.24.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.ln_1.bias": "encoders.24.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.ln_1.weight": "encoders.24.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.ln_2.bias": "encoders.24.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.ln_2.weight": "encoders.24.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.bias": "encoders.24.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.weight": "encoders.24.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.bias": "encoders.24.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.weight": "encoders.24.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias": ['encoders.25.attn.to_q.bias', 'encoders.25.attn.to_k.bias', 'encoders.25.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight": ['encoders.25.attn.to_q.weight', 'encoders.25.attn.to_k.weight', 'encoders.25.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.bias": "encoders.25.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.weight": "encoders.25.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.ln_1.bias": "encoders.25.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.ln_1.weight": "encoders.25.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.ln_2.bias": "encoders.25.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.ln_2.weight": "encoders.25.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.bias": "encoders.25.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.weight": "encoders.25.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.bias": "encoders.25.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.weight": "encoders.25.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias": ['encoders.26.attn.to_q.bias', 'encoders.26.attn.to_k.bias', 'encoders.26.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight": ['encoders.26.attn.to_q.weight', 'encoders.26.attn.to_k.weight', 'encoders.26.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.bias": "encoders.26.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.weight": "encoders.26.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.ln_1.bias": "encoders.26.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.ln_1.weight": "encoders.26.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.ln_2.bias": "encoders.26.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.ln_2.weight": "encoders.26.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.bias": "encoders.26.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.weight": "encoders.26.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.bias": "encoders.26.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.weight": "encoders.26.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias": ['encoders.27.attn.to_q.bias', 'encoders.27.attn.to_k.bias', 'encoders.27.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight": ['encoders.27.attn.to_q.weight', 'encoders.27.attn.to_k.weight', 'encoders.27.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.bias": "encoders.27.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.weight": "encoders.27.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.ln_1.bias": "encoders.27.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.ln_1.weight": "encoders.27.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.ln_2.bias": "encoders.27.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.ln_2.weight": "encoders.27.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.bias": "encoders.27.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.weight": "encoders.27.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.bias": "encoders.27.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.weight": "encoders.27.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias": ['encoders.28.attn.to_q.bias', 'encoders.28.attn.to_k.bias', 'encoders.28.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight": ['encoders.28.attn.to_q.weight', 'encoders.28.attn.to_k.weight', 'encoders.28.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.bias": "encoders.28.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.weight": "encoders.28.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.ln_1.bias": "encoders.28.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.ln_1.weight": "encoders.28.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.ln_2.bias": "encoders.28.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.ln_2.weight": "encoders.28.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.bias": "encoders.28.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.weight": "encoders.28.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.bias": "encoders.28.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.weight": "encoders.28.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias": ['encoders.29.attn.to_q.bias', 'encoders.29.attn.to_k.bias', 'encoders.29.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight": ['encoders.29.attn.to_q.weight', 'encoders.29.attn.to_k.weight', 'encoders.29.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.bias": "encoders.29.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.weight": "encoders.29.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.ln_1.bias": "encoders.29.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.ln_1.weight": "encoders.29.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.ln_2.bias": "encoders.29.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.ln_2.weight": "encoders.29.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.bias": "encoders.29.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.weight": "encoders.29.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.bias": "encoders.29.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.weight": "encoders.29.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias": ['encoders.3.attn.to_q.bias', 'encoders.3.attn.to_k.bias', 'encoders.3.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight": ['encoders.3.attn.to_q.weight', 'encoders.3.attn.to_k.weight', 'encoders.3.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.bias": "encoders.3.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.weight": "encoders.3.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.ln_1.bias": "encoders.3.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.ln_1.weight": "encoders.3.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.ln_2.bias": "encoders.3.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.ln_2.weight": "encoders.3.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.bias": "encoders.3.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.weight": "encoders.3.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.bias": "encoders.3.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.weight": "encoders.3.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias": ['encoders.30.attn.to_q.bias', 'encoders.30.attn.to_k.bias', 'encoders.30.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight": ['encoders.30.attn.to_q.weight', 'encoders.30.attn.to_k.weight', 'encoders.30.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.bias": "encoders.30.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.weight": "encoders.30.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.ln_1.bias": "encoders.30.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.ln_1.weight": "encoders.30.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.ln_2.bias": "encoders.30.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.ln_2.weight": "encoders.30.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.bias": "encoders.30.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.weight": "encoders.30.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.bias": "encoders.30.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.weight": "encoders.30.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias": ['encoders.31.attn.to_q.bias', 'encoders.31.attn.to_k.bias', 'encoders.31.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight": ['encoders.31.attn.to_q.weight', 'encoders.31.attn.to_k.weight', 'encoders.31.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.bias": "encoders.31.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.weight": "encoders.31.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.ln_1.bias": "encoders.31.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.ln_1.weight": "encoders.31.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.ln_2.bias": "encoders.31.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.ln_2.weight": "encoders.31.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.bias": "encoders.31.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.weight": "encoders.31.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.bias": "encoders.31.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.weight": "encoders.31.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias": ['encoders.4.attn.to_q.bias', 'encoders.4.attn.to_k.bias', 'encoders.4.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight": ['encoders.4.attn.to_q.weight', 'encoders.4.attn.to_k.weight', 'encoders.4.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.bias": "encoders.4.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.weight": "encoders.4.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.ln_1.bias": "encoders.4.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.ln_1.weight": "encoders.4.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.ln_2.bias": "encoders.4.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.ln_2.weight": "encoders.4.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.bias": "encoders.4.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.weight": "encoders.4.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.bias": "encoders.4.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.weight": "encoders.4.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias": ['encoders.5.attn.to_q.bias', 'encoders.5.attn.to_k.bias', 'encoders.5.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight": ['encoders.5.attn.to_q.weight', 'encoders.5.attn.to_k.weight', 'encoders.5.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.bias": "encoders.5.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.weight": "encoders.5.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.ln_1.bias": "encoders.5.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.ln_1.weight": "encoders.5.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.ln_2.bias": "encoders.5.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.ln_2.weight": "encoders.5.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.bias": "encoders.5.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.weight": "encoders.5.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.bias": "encoders.5.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.weight": "encoders.5.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias": ['encoders.6.attn.to_q.bias', 'encoders.6.attn.to_k.bias', 'encoders.6.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight": ['encoders.6.attn.to_q.weight', 'encoders.6.attn.to_k.weight', 'encoders.6.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.bias": "encoders.6.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.weight": "encoders.6.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.ln_1.bias": "encoders.6.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.ln_1.weight": "encoders.6.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.ln_2.bias": "encoders.6.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.ln_2.weight": "encoders.6.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.bias": "encoders.6.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.weight": "encoders.6.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.bias": "encoders.6.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.weight": "encoders.6.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias": ['encoders.7.attn.to_q.bias', 'encoders.7.attn.to_k.bias', 'encoders.7.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight": ['encoders.7.attn.to_q.weight', 'encoders.7.attn.to_k.weight', 'encoders.7.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.bias": "encoders.7.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.weight": "encoders.7.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.ln_1.bias": "encoders.7.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.ln_1.weight": "encoders.7.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.ln_2.bias": "encoders.7.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.ln_2.weight": "encoders.7.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.bias": "encoders.7.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.weight": "encoders.7.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.bias": "encoders.7.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.weight": "encoders.7.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias": ['encoders.8.attn.to_q.bias', 'encoders.8.attn.to_k.bias', 'encoders.8.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight": ['encoders.8.attn.to_q.weight', 'encoders.8.attn.to_k.weight', 'encoders.8.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.bias": "encoders.8.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.weight": "encoders.8.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.ln_1.bias": "encoders.8.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.ln_1.weight": "encoders.8.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.ln_2.bias": "encoders.8.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.ln_2.weight": "encoders.8.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.bias": "encoders.8.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.weight": "encoders.8.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.bias": "encoders.8.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.weight": "encoders.8.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias": ['encoders.9.attn.to_q.bias', 'encoders.9.attn.to_k.bias', 'encoders.9.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight": ['encoders.9.attn.to_q.weight', 'encoders.9.attn.to_k.weight', 'encoders.9.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.bias": "encoders.9.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.weight": "encoders.9.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.ln_1.bias": "encoders.9.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.ln_1.weight": "encoders.9.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.ln_2.bias": "encoders.9.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.ln_2.weight": "encoders.9.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.bias": "encoders.9.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.weight": "encoders.9.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias": "encoders.9.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight": "encoders.9.fc2.weight",
|
||||
"conditioner.embedders.1.model.text_projection": "text_projection.weight",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if name == "conditioner.embedders.1.model.positional_embedding":
|
||||
param = param.reshape((1, param.shape[0], param.shape[1]))
|
||||
elif name == "conditioner.embedders.1.model.text_projection":
|
||||
param = param.T
|
||||
if isinstance(rename_dict[name], str):
|
||||
state_dict_[rename_dict[name]] = param
|
||||
else:
|
||||
length = param.shape[0] // 3
|
||||
for i, rename in enumerate(rename_dict[name]):
|
||||
state_dict_[rename] = param[i*length: i*length+length]
|
||||
return state_dict_
|
||||
1897
diffsynth/models/sdxl_unet.py
Normal file
1897
diffsynth/models/sdxl_unet.py
Normal file
File diff suppressed because it is too large
Load Diff
15
diffsynth/models/sdxl_vae_decoder.py
Normal file
15
diffsynth/models/sdxl_vae_decoder.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from .sd_vae_decoder import SDVAEDecoder, SDVAEDecoderStateDictConverter
|
||||
|
||||
|
||||
class SDXLVAEDecoder(SDVAEDecoder):
    """VAE decoder variant used with Stable Diffusion XL checkpoints.

    Behaves like `SDVAEDecoder` except that `scaling_factor` is overridden
    once the parent constructor has finished.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): presumably the latent scaling constant of the SDXL VAE;
        # how it is consumed is defined by the parent class - confirm there.
        self.scaling_factor = 0.13025

    def state_dict_converter(self):
        """Return a new converter object for this decoder's checkpoints."""
        converter = SDXLVAEDecoderStateDictConverter()
        return converter
|
||||
|
||||
|
||||
class SDXLVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
    """State-dict converter for `SDXLVAEDecoder`.

    Adds nothing of its own; all behaviour comes from
    `SDVAEDecoderStateDictConverter`.
    """

    def __init__(self):
        super().__init__()
|
||||
15
diffsynth/models/sdxl_vae_encoder.py
Normal file
15
diffsynth/models/sdxl_vae_encoder.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder
|
||||
|
||||
|
||||
class SDXLVAEEncoder(SDVAEEncoder):
    """VAE encoder variant used with Stable Diffusion XL checkpoints.

    Behaves like `SDVAEEncoder` except that `scaling_factor` is overridden
    once the parent constructor has finished.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): presumably the latent scaling constant of the SDXL VAE;
        # how it is consumed is defined by the parent class - confirm there.
        self.scaling_factor = 0.13025

    def state_dict_converter(self):
        """Return a new converter object for this encoder's checkpoints."""
        converter = SDXLVAEEncoderStateDictConverter()
        return converter
|
||||
|
||||
|
||||
class SDXLVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
    """State-dict converter for `SDXLVAEEncoder`.

    Adds nothing of its own; all behaviour comes from
    `SDVAEEncoderStateDictConverter`.
    """

    def __init__(self):
        super().__init__()
|
||||
75
diffsynth/models/tiler.py
Normal file
75
diffsynth/models/tiler.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import torch
|
||||
|
||||
|
||||
class Tiler(torch.nn.Module):
    """Apply a function to overlapping square tiles of a 4-D tensor and blend the results.

    The input is cut into tiles with `torch.nn.Unfold`, `forward_fn` is run on
    groups of tiles, and the outputs are recombined with `torch.nn.Fold` using
    a per-tile weight map so that overlapping regions are averaged.
    """

    def __init__(self):
        super().__init__()

    def mask(self, height, width, line_width):
        """Build the (height, width) blending weight map for one tile.

        Each entry is the 1-based distance to the nearest tile border divided
        by `line_width` and clipped to [0, 1], so weights ramp up from the
        border and saturate at 1 in the interior.
        """
        x = torch.arange(height).repeat(width, 1).T
        y = torch.arange(width).repeat(height, 1)
        mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values
        mask = (mask / line_width).clip(0, 1)
        return mask

    def forward(self, forward_fn, x, tile_size, tile_stride, batch_size=1, inter_device="cpu", inter_dtype=torch.float32):
        """Run `forward_fn` tile by tile over `x` and stitch the outputs.

        Args:
            forward_fn: callable taking a (N, C, tile_size, tile_size) tensor
                and returning a 4-D tensor with the same leading dimension N.
                The output tile may have a different channel count and
                spatial size; the ratio of output to input tile height is
                used as the scale for the stitched result.
            x: input tensor of shape (B, C, H, W).
            tile_size: side length of the square input tiles.
            tile_stride: step between neighbouring tiles.
            batch_size: number of tiles passed to `forward_fn` per call.
            inter_device: device holding the tiled intermediate tensors.
            inter_dtype: dtype of the tiled intermediate tensors.

        Returns:
            The blended output, moved back to the device and dtype of `x`.

        NOTE(review): positions not covered by any tile (when H or W minus
        `tile_size` is not a multiple of `tile_stride`) receive a zero
        divisor - confirm callers always pass compatible sizes.
        """
        # Remember where the caller's tensor lives so the result can go back there.
        device = x.device
        torch_dtype = x.dtype

        # Tile: (B, C, H, W) -> (B, C, tile, tile, num_tiles).
        b, c_in, h_in, w_in = x.shape
        x = x.to(device=inter_device, dtype=inter_dtype)
        fold_params = {
            "kernel_size": (tile_size, tile_size),
            "stride": (tile_stride, tile_stride)
        }
        unfold_operator = torch.nn.Unfold(**fold_params)
        x = unfold_operator(x)
        x = x.view((b, c_in, tile_size, tile_size, -1))

        # Inference, `batch_size` tiles at a time.
        x_out_stack = []
        for tile_id in range(0, x.shape[-1], batch_size):

            # Move the tile axis in front and merge it with the batch axis.
            next_tile_id = min(tile_id + batch_size, x.shape[-1])
            x_in = x[:, :, :, :, tile_id: next_tile_id]
            x_in = x_in.to(device=device, dtype=torch_dtype)
            x_in = x_in.permute(4, 0, 1, 2, 3)
            # reshape (not view): the permuted slice is non-contiguous, and
            # view raises as soon as both B > 1 and several tiles are batched.
            x_in = x_in.reshape((x_in.shape[0]*x_in.shape[1], x_in.shape[2], x_in.shape[3], x_in.shape[4]))

            # Split tile and batch axes again and put the tile axis last.
            x_out = forward_fn(x_in)
            x_out = x_out.reshape((next_tile_id - tile_id, b, x_out.shape[1], x_out.shape[2], x_out.shape[3]))
            x_out = x_out.permute(1, 2, 3, 4, 0)
            x_out = x_out.to(device=inter_device, dtype=inter_dtype)
            x_out_stack.append(x_out)

        x = torch.concat(x_out_stack, dim=-1)

        # Untile: the output tile height relative to the input tile gives the scale.
        in2out_scale = x.shape[2] / tile_size
        h_out, w_out = int(h_in * in2out_scale), int(w_in * in2out_scale)

        mask = self.mask(int(tile_size * in2out_scale), int(tile_size * in2out_scale), int(tile_stride * in2out_scale * 0.5))
        mask = mask.to(device=inter_device, dtype=inter_dtype)
        mask = mask.reshape((1, 1, mask.shape[0], mask.shape[1], 1))
        x = x * mask

        fold_params = {
            "kernel_size": (int(tile_size * in2out_scale), int(tile_size * in2out_scale)),
            "stride": (int(tile_stride * in2out_scale), int(tile_stride * in2out_scale))
        }
        fold_operator = torch.nn.Fold(output_size=(h_out, w_out), **fold_params)
        # The accumulated weights are the same for every sample and channel, so
        # fold a single (1, 1, h_out, w_out) map and let the division broadcast.
        # (Viewing the batch-1 mask as `b` samples mis-shapes it when B > 1.)
        divisor = fold_operator(mask.repeat(1, 1, 1, 1, x.shape[-1]).view(1, -1, x.shape[-1]))

        x = x.reshape((b, -1, x.shape[-1]))
        x = fold_operator(x) / divisor
        x = x.to(device=device, dtype=torch_dtype)

        return x
|
||||
|
||||
|
||||
|
||||
|
||||
2
diffsynth/pipelines/__init__.py
Normal file
2
diffsynth/pipelines/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .stable_diffusion import SDPipeline
|
||||
from .stable_diffusion_xl import SDXLPipeline
|
||||
75
diffsynth/pipelines/stable_diffusion.py
Normal file
75
diffsynth/pipelines/stable_diffusion.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from ..models import ModelManager
|
||||
from ..prompts import SDPrompter
|
||||
from ..schedulers import EnhancedDDIMScheduler
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
class SDPipeline(torch.nn.Module):
    """Sampling loop for Stable Diffusion (text-to-image and image-to-image).

    The pipeline itself only owns a scheduler; the text encoder, UNet and VAE
    are read from the `ModelManager` passed to `__call__`.
    """

    def __init__(self):
        super().__init__()
        self.scheduler = EnhancedDDIMScheduler()

    def preprocess_image(self, image):
        """Convert an image to a float tensor with a leading batch axis.

        The array is scaled by 2/255 and shifted by -1 (so 8-bit input maps
        to [-1, 1]), then permuted from (H, W, C) to (1, C, H, W).
        """
        image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
        return image

    @torch.no_grad()
    def __call__(
        self,
        model_manager: ModelManager,
        prompter: SDPrompter,
        prompt,
        negative_prompt="",
        cfg_scale=7.5,
        clip_skip=1,
        init_image=None,
        denoising_strength=1.0,
        height=512,
        width=512,
        num_inference_steps=20,
        tiled=False,
        tile_size=64,
        tile_stride=32,
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        """Generate one image.

        Args:
            model_manager: provides `text_encoder`, `unet`, `vae_encoder`,
                `vae_decoder`, `device` and `torch_type`.
            prompter: encodes prompts into embeddings.
            prompt: positive prompt text.
            negative_prompt: prompt used for the unconditional branch.
            cfg_scale: classifier-free guidance weight. When it equals 1.0
                the unconditional branch is skipped, since it would not
                change the prediction.
            clip_skip: forwarded to the prompter.
            init_image: optional image; when given it is encoded and noised
                to the first scheduled timestep instead of starting from
                pure noise.
            denoising_strength: forwarded to the scheduler's `set_timesteps`.
            height, width: output size; latents are created at 1/8 of it.
            num_inference_steps: number of scheduler steps requested.
            tiled, tile_size, tile_stride: forwarded to the UNet and VAE calls.
            progress_bar_cmd: wrapper applied to the timestep iterable.
            progress_bar_st: optional object with a `progress(fraction)`
                method (NOTE(review): presumably a Streamlit progress bar).

        Returns:
            A `PIL.Image.Image` built from the decoded latents.
        """
        # Encode prompts. The negative prompt is only needed for guidance.
        use_cfg = cfg_scale != 1.0
        prompt_emb = prompter.encode_prompt(model_manager.text_encoder, prompt, clip_skip=clip_skip, device=model_manager.device)
        if use_cfg:
            negative_prompt_emb = prompter.encode_prompt(model_manager.text_encoder, negative_prompt, clip_skip=clip_skip, device=model_manager.device)

        # Prepare scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

        # Prepare latent tensors
        if init_image is not None:
            image = self.preprocess_image(init_image).to(device=model_manager.device, dtype=model_manager.torch_type)
            latents = model_manager.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
            noise = torch.randn((1, 4, height//8, width//8), device=model_manager.device, dtype=model_manager.torch_type)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
        else:
            latents = torch.randn((1, 4, height//8, width//8), device=model_manager.device, dtype=model_manager.torch_type)

        # Denoise
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = torch.IntTensor((timestep,))[0].to(model_manager.device)

            # Classifier-free guidance
            noise_pred = model_manager.unet(latents, timestep, prompt_emb, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
            if use_cfg:
                noise_pred_uncond = model_manager.unet(latents, timestep, negative_prompt_emb, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
                noise_pred = noise_pred_uncond + cfg_scale * (noise_pred - noise_pred_uncond)

            latents = self.scheduler.step(noise_pred, timestep, latents)

            if progress_bar_st is not None:
                # Report completed steps so the bar reaches 1.0 after the last one.
                progress_bar_st.progress((progress_id + 1) / len(self.scheduler.timesteps))

        # Decode image
        image = model_manager.vae_decoder(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
        image = image.cpu().permute(1, 2, 0).numpy()
        image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))

        return image
|
||||
126
diffsynth/pipelines/stable_diffusion_xl.py
Normal file
126
diffsynth/pipelines/stable_diffusion_xl.py
Normal file
@@ -0,0 +1,126 @@
|
||||
from ..models import ModelManager
|
||||
from ..prompts import SDXLPrompter
|
||||
from ..schedulers import EnhancedDDIMScheduler
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
class SDXLPipeline(torch.nn.Module):
    """Sampling loop for Stable Diffusion XL (text-to-image and image-to-image).

    The pipeline itself only owns a scheduler; the two text encoders, the
    UNet, the optional refiner and the VAE are read from the `ModelManager`
    passed to `__call__`.
    """

    def __init__(self):
        super().__init__()
        self.scheduler = EnhancedDDIMScheduler()

    def preprocess_image(self, image):
        """Convert an image to a float tensor with a leading batch axis.

        The array is scaled by 2/255 and shifted by -1 (so 8-bit input maps
        to [-1, 1]), then permuted from (H, W, C) to (1, C, H, W).
        """
        image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
        return image

    @torch.no_grad()
    def __call__(
        self,
        model_manager: ModelManager,
        prompter: SDXLPrompter,
        prompt,
        negative_prompt="",
        cfg_scale=7.5,
        clip_skip=1,
        clip_skip_2=2,
        init_image=None,
        denoising_strength=1.0,
        refining_strength=0.0,
        height=1024,
        width=1024,
        num_inference_steps=20,
        tiled=False,
        tile_size=64,
        tile_stride=32,
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        """Generate one image.

        Args:
            model_manager: provides `text_encoder`, `text_encoder_2`, `unet`,
                `refiner`, `vae_encoder`, `vae_decoder`, `device` and
                `torch_type`.
            prompter: encodes prompts into `(add_text_embeds, prompt_emb)`.
            prompt: positive prompt text.
            negative_prompt: prompt for the unconditional branch; only
                encoded when `cfg_scale != 1.0`.
            cfg_scale: classifier-free guidance weight.
            clip_skip, clip_skip_2: forwarded to the prompter.
            init_image: optional image; when given it is encoded and noised
                to the first scheduled timestep instead of starting from
                pure noise.
            denoising_strength: forwarded to the scheduler's `set_timesteps`.
            refining_strength: timesteps below `1000 * refining_strength`
                are denoised with `model_manager.refiner` instead of the UNet.
            height, width: output size; latents are created at 1/8 of it.
            num_inference_steps: number of scheduler steps requested.
            tiled, tile_size, tile_stride: forwarded to the model calls.
            progress_bar_cmd: wrapper applied to the timestep iterable.
            progress_bar_st: optional object with a `progress(fraction)`
                method (NOTE(review): presumably a Streamlit progress bar).

        Returns:
            A `PIL.Image.Image` built from the decoded latents.
        """
        # Encode prompts
        add_text_embeds, prompt_emb = prompter.encode_prompt(
            model_manager.text_encoder,
            model_manager.text_encoder_2,
            prompt,
            clip_skip=clip_skip, clip_skip_2=clip_skip_2,
            device=model_manager.device
        )
        if cfg_scale != 1.0:
            negative_add_text_embeds, negative_prompt_emb = prompter.encode_prompt(
                model_manager.text_encoder,
                model_manager.text_encoder_2,
                negative_prompt,
                clip_skip=clip_skip, clip_skip_2=clip_skip_2,
                device=model_manager.device
            )

        # Prepare scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

        # Prepare latent tensors
        if init_image is not None:
            image = self.preprocess_image(init_image).to(
                device=model_manager.device, dtype=model_manager.torch_type
            )
            # The VAE is fed float32 here; the latents are cast back below.
            latents = model_manager.vae_encoder(
                image.to(torch.float32),
                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
            )
            noise = torch.randn(
                (1, 4, height//8, width//8),
                device=model_manager.device, dtype=model_manager.torch_type
            )
            latents = self.scheduler.add_noise(
                latents.to(model_manager.torch_type),
                noise,
                timestep=self.scheduler.timesteps[0]
            )
        else:
            latents = torch.randn((1, 4, height//8, width//8), device=model_manager.device, dtype=model_manager.torch_type)

        # Prepare positional id
        add_time_id = torch.tensor([height, width, 0, 0, height, width], device=model_manager.device)

        # Denoise
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = torch.IntTensor((timestep,))[0].to(model_manager.device)

            # Pick the base UNet for early steps and the refiner for late ones.
            if timestep >= 1000 * refining_strength:
                denoising_model = model_manager.unet
            else:
                denoising_model = model_manager.refiner

            # Classifier-free guidance
            if cfg_scale != 1.0:
                noise_pred_cond = denoising_model(
                    latents, timestep, prompt_emb,
                    add_time_id=add_time_id, add_text_embeds=add_text_embeds,
                    tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
                )
                noise_pred_uncond = denoising_model(
                    latents, timestep, negative_prompt_emb,
                    add_time_id=add_time_id, add_text_embeds=negative_add_text_embeds,
                    tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
                )
                noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_cond - noise_pred_uncond)
            else:
                noise_pred = denoising_model(
                    latents, timestep, prompt_emb,
                    add_time_id=add_time_id, add_text_embeds=add_text_embeds,
                    tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
                )

            latents = self.scheduler.step(noise_pred, timestep, latents)

            if progress_bar_st is not None:
                # Report completed steps so the bar reaches 1.0 after the last one.
                progress_bar_st.progress((progress_id + 1) / len(self.scheduler.timesteps))

        # Decode image
        latents = latents.to(torch.float32)
        image = model_manager.vae_decoder(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
        image = image.cpu().permute(1, 2, 0).numpy()
        image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))

        return image
|
||||
117
diffsynth/prompts/__init__.py
Normal file
117
diffsynth/prompts/__init__.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from transformers import CLIPTokenizer
|
||||
from ..models import SDTextEncoder, SDXLTextEncoder, SDXLTextEncoder2
|
||||
import torch, os
|
||||
from safetensors import safe_open
|
||||
|
||||
|
||||
def tokenize_long_prompt(tokenizer, prompt):
    """Tokenize a prompt of arbitrary length into fixed-size rows.

    The prompt is tokenized once without a length limit to measure it, then
    tokenized again padded to the next multiple of the tokenizer's
    `model_max_length`, and finally reshaped so each row has exactly
    `model_max_length` ids.

    Args:
        tokenizer: callable tokenizer exposing `model_max_length` and
            returning an object with an `input_ids` tensor.
        prompt: text to tokenize.

    Returns:
        Tensor of shape (num_rows, model_max_length).
    """
    # Remember the real window size of the tokenizer.
    length = tokenizer.model_max_length

    # Temporarily lift the limit so the measuring pass does not warn about
    # over-long input. Restore it even if tokenization raises, otherwise the
    # shared tokenizer object would stay in the modified state.
    tokenizer.model_max_length = 99999999
    try:
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    finally:
        tokenizer.model_max_length = length

    # Round the measured length up to a whole number of windows.
    max_length = (input_ids.shape[1] + length - 1) // length * length

    # Tokenize again, padded to that fixed length.
    input_ids = tokenizer(
        prompt,
        return_tensors="pt",
        padding="max_length",
        max_length=max_length,
        truncation=True
    ).input_ids

    # One row per window, as expected by the text encoder.
    num_sentence = input_ids.shape[1] // length
    input_ids = input_ids.reshape((num_sentence, length))

    return input_ids
|
||||
|
||||
|
||||
def load_textual_inversion(prompt):
    """Collect textual-inversion embeddings whose keyword occurs in `prompt`.

    Every `.safetensors` file in `models/textual_inversion` is treated as one
    keyword (the file name without its suffix). When the keyword appears in
    the prompt it is removed from the prompt, and the tensors stored in the
    file are gathered by their last dimension (768 or 1280); tensors of any
    other width are ignored.

    Returns:
        `(prompt, embeddings_768, embeddings_1280)`, where each embeddings
        value is the concatenation along dim 0 of the gathered tensors, or an
        empty `(0, width)` tensor when none were found.
    """
    # TODO: This module is not enabled now.
    suffix = ".safetensors"
    gathered = {768: [], 1280: []}

    for file_name in os.listdir("models/textual_inversion"):
        if not file_name.endswith(suffix):
            continue
        keyword = file_name[:-len(suffix)]
        if keyword not in prompt:
            continue
        prompt = prompt.replace(keyword, "")
        with safe_open(f"models/textual_inversion/{file_name}", framework="pt", device="cpu") as f:
            for k in f.keys():
                embedding = f.get_tensor(k).to(torch.float32)
                bucket = gathered.get(embedding.shape[-1])
                if bucket is not None:
                    bucket.append(embedding)

    if gathered[768]:
        embeddings_768 = torch.concat(gathered[768], dim=0)
    else:
        embeddings_768 = torch.zeros((0, 768))

    if gathered[1280]:
        embeddings_1280 = torch.concat(gathered[1280], dim=0)
    else:
        embeddings_1280 = torch.zeros((0, 1280))

    return prompt, embeddings_768, embeddings_1280
|
||||
|
||||
|
||||
class SDPrompter:
    """Turn prompt text into embeddings for the Stable Diffusion text encoder."""

    def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"):
        # The tokenizer implementation comes from transformers.
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)

    def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda"):
        """Tokenize `prompt`, encode it, and merge all rows into one sequence.

        The encoder output is reshaped so that its first two axes are merged
        into a single axis behind a leading batch axis of size 1.
        """
        token_ids = tokenize_long_prompt(self.tokenizer, prompt)
        token_ids = token_ids.to(device)
        encoded = text_encoder(token_ids, clip_skip=clip_skip)

        num_rows, row_length = encoded.shape[0], encoded.shape[1]
        return encoded.reshape((1, num_rows*row_length, -1))
|
||||
|
||||
|
||||
class SDXLPrompter:
    """Turn prompt text into embeddings for the two SDXL text encoders."""

    def __init__(
        self,
        tokenizer_path="configs/stable_diffusion/tokenizer",
        tokenizer_2_path="configs/stable_diffusion_xl/tokenizer_2"
    ):
        # We use the tokenizer implemented by transformers
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
        self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)

    def encode_prompt(
        self,
        text_encoder: SDXLTextEncoder,
        text_encoder_2: SDXLTextEncoder2,
        prompt,
        clip_skip=1,
        clip_skip_2=2,
        device="cuda"
    ):
        """Encode `prompt` with both text encoders.

        Args:
            text_encoder: first encoder, called as `(input_ids, clip_skip=...)`.
            text_encoder_2: second encoder; returns a pair
                `(add_text_embeds, prompt_emb)`.
            prompt: text to encode.
            clip_skip, clip_skip_2: forwarded to the respective encoder.
            device: device the token ids are moved to.

        Returns:
            `(add_text_embeds, prompt_emb)`: the first row of the second
            encoder's additional embedding, and the two encoders' outputs
            concatenated on the last axis and merged into one sequence behind
            a leading batch axis of size 1.
        """
        # 1
        input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
        prompt_emb_1 = text_encoder(input_ids, clip_skip=clip_skip)

        # 2
        input_ids_2 = tokenize_long_prompt(self.tokenizer_2, prompt).to(device)
        add_text_embeds, prompt_emb_2 = text_encoder_2(input_ids_2, clip_skip=clip_skip_2)

        # Merge. `torch.cat` is used instead of `torch.concatenate`, which only
        # exists in PyTorch >= 2.0.
        prompt_emb = torch.cat([prompt_emb_1, prompt_emb_2], dim=-1)

        # For very long prompt, we only use the first 77 tokens to compute `add_text_embeds`.
        add_text_embeds = add_text_embeds[0:1]
        prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
        return add_text_embeds, prompt_emb
|
||||
37
diffsynth/prompts/sd_tokenizer.py
Normal file
37
diffsynth/prompts/sd_tokenizer.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
class SDTokenizer:
    """Tokenize prompts of arbitrary length into fixed-size rows of token ids."""

    def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"):
        """Load the CLIP tokenizer.

        Args:
            tokenizer_path: Location handed to ``CLIPTokenizer.from_pretrained``.
        """
        # We use the tokenizer implemented by transformers.
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)

    def __call__(self, prompt):
        """Tokenize ``prompt`` and split the ids into rows of ``model_max_length``.

        The prompt is tokenized twice: once without a length limit to measure
        it, then again padded/truncated to the next multiple of the tokenizer's
        ``model_max_length``.

        Args:
            prompt: Prompt passed to the underlying tokenizer.

        Returns:
            Token ids reshaped to ``(num_sentence, model_max_length)``.
        """
        # Get model_max_length from self.tokenizer.
        length = self.tokenizer.model_max_length

        # Temporarily lift the limit so the measuring pass does not emit the
        # "sequence too long" warning. The limit is restored in `finally` so a
        # failing tokenization cannot leave the tokenizer permanently altered.
        self.tokenizer.model_max_length = 99999999
        try:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
        finally:
            self.tokenizer.model_max_length = length

        # Round the measured length up to a whole number of rows.
        max_length = (input_ids.shape[1] + length - 1) // length * length

        # Tokenize again, padded to that fixed length.
        input_ids = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding="max_length",
            max_length=max_length,
            truncation=True
        ).input_ids

        # Reshape input_ids to fit the text encoder.
        num_sentence = input_ids.shape[1] // length
        input_ids = input_ids.reshape((num_sentence, length))

        return input_ids
|
||||
45
diffsynth/prompts/sdxl_tokenizer.py
Normal file
45
diffsynth/prompts/sdxl_tokenizer.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from transformers import CLIPTokenizer
|
||||
from .sd_tokenizer import SDTokenizer
|
||||
|
||||
|
||||
class SDXLTokenizer(SDTokenizer):
    """Tokenizer for the first SDXL text encoder.

    Behaves exactly like ``SDTokenizer``; it exists so the SDXL code path has
    its own named type next to ``SDXLTokenizer2``.
    """

    def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"):
        """Load the CLIP tokenizer.

        Args:
            tokenizer_path: Location forwarded to ``SDTokenizer.__init__``.
                The default is the same path ``SDTokenizer`` uses, so calling
                ``SDXLTokenizer()`` with no arguments is unchanged.
        """
        super().__init__(tokenizer_path)
|
||||
|
||||
|
||||
class SDXLTokenizer2:
    """Tokenize prompts of arbitrary length for the second SDXL text encoder."""

    def __init__(self, tokenizer_path="configs/stable_diffusion_xl/tokenizer_2"):
        """Load the CLIP tokenizer.

        Args:
            tokenizer_path: Location handed to ``CLIPTokenizer.from_pretrained``.
        """
        # We use the tokenizer implemented by transformers.
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)

    def __call__(self, prompt):
        """Tokenize ``prompt`` and split the ids into rows of ``model_max_length``.

        The prompt is tokenized twice: once without a length limit to measure
        it, then again padded/truncated to the next multiple of the tokenizer's
        ``model_max_length``.

        Args:
            prompt: Prompt passed to the underlying tokenizer.

        Returns:
            Token ids reshaped to ``(num_sentence, model_max_length)``.
        """
        # Get model_max_length from self.tokenizer.
        length = self.tokenizer.model_max_length

        # Temporarily lift the limit so the measuring pass does not emit the
        # "sequence too long" warning. The limit is restored in `finally` so a
        # failing tokenization cannot leave the tokenizer permanently altered.
        self.tokenizer.model_max_length = 99999999
        try:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
        finally:
            self.tokenizer.model_max_length = length

        # Round the measured length up to a whole number of rows.
        max_length = (input_ids.shape[1] + length - 1) // length * length

        # Tokenize again, padded to that fixed length.
        input_ids = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding="max_length",
            max_length=max_length,
            truncation=True
        ).input_ids

        # Reshape input_ids to fit the text encoder.
        num_sentence = input_ids.shape[1] // length
        input_ids = input_ids.reshape((num_sentence, length))

        return input_ids
|
||||
|
||||
60
diffsynth/schedulers/__init__.py
Normal file
60
diffsynth/schedulers/__init__.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import torch, math
|
||||
|
||||
|
||||
class EnhancedDDIMScheduler:
    """Deterministic DDIM-style scheduler driven by a cumulative alpha table.

    Attributes are plain Python values: ``alphas_cumprod`` is a list of floats
    with one entry per training timestep, and ``timesteps`` is a list of ints
    in descending order.
    """

    def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012):
        """Build the alpha table and install a default 10-step schedule.

        Args:
            num_train_timesteps: Number of entries in the alpha table.
            beta_start: First beta value of the schedule.
            beta_end: Last beta value of the schedule.
        """
        self.num_train_timesteps = num_train_timesteps
        # Betas are spaced linearly in sqrt-space and then squared.
        betas = torch.square(
            torch.linspace(
                math.sqrt(beta_start),
                math.sqrt(beta_end),
                num_train_timesteps,
                dtype=torch.float32,
            )
        )
        self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0).tolist()
        self.set_timesteps(10)

    def set_timesteps(self, num_inference_steps, denoising_strength=1.0):
        """Choose the timesteps to visit during sampling.

        The timesteps are aligned to 999...0 (for the default table size),
        i.e. they run from the highest allowed timestep down to exactly 0.
        This differs from other implementations.

        Args:
            num_inference_steps: Requested number of steps; capped at the
                number of available timesteps.
            denoising_strength: Fraction of the training range to start from;
                1.0 starts at the last training timestep.
        """
        max_timestep = round(self.num_train_timesteps * denoising_strength) - 1
        num_inference_steps = min(num_inference_steps, max_timestep + 1)
        if num_inference_steps == 1:
            # A single step cannot be spaced; just use the starting timestep.
            self.timesteps = [max_timestep]
        else:
            step_length = max_timestep / (num_inference_steps - 1)
            self.timesteps = [
                round(max_timestep - i * step_length)
                for i in range(num_inference_steps)
            ]

    def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
        """Move ``sample`` from one noise level to the next.

        Args:
            model_output: Model prediction, used here as the noise estimate.
            sample: Current noisy sample.
            alpha_prod_t: Cumulative alpha at the current timestep.
            alpha_prod_t_prev: Cumulative alpha at the target timestep.

        Returns:
            ``sample * weight_x + model_output * weight_e``.
        """
        weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(
            alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t
        )
        weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t)
        return sample * weight_x + model_output * weight_e

    def step(self, model_output, timestep, sample):
        """Advance ``sample`` from ``timestep`` to the next scheduled timestep.

        Args:
            model_output: Model prediction at ``timestep``.
            timestep: Current timestep; must be present in ``self.timesteps``.
            sample: Current noisy sample.

        Returns:
            The sample at the next scheduled timestep. After the final
            scheduled timestep the target cumulative alpha is 1.0.

        Raises:
            ValueError: If ``timestep`` is not in ``self.timesteps``.
        """
        alpha_prod_t = self.alphas_cumprod[timestep]
        timestep_id = self.timesteps.index(timestep)
        if timestep_id + 1 < len(self.timesteps):
            timestep_prev = self.timesteps[timestep_id + 1]
            alpha_prod_t_prev = self.alphas_cumprod[timestep_prev]
        else:
            # No further timestep is scheduled: target the noise-free level.
            alpha_prod_t_prev = 1.0

        return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev)

    def return_to_timestep(self, timestep, sample, sample_stablized):
        """Recover the noise that maps ``sample_stablized`` to ``sample``.

        This inverts ``add_noise`` with respect to its ``noise`` argument.

        Args:
            timestep: Timestep whose cumulative alpha is used.
            sample: Noisy sample at ``timestep``.
            sample_stablized: Noise-free sample.

        Returns:
            The implied noise prediction.
        """
        alpha_prod_t = self.alphas_cumprod[timestep]
        noise_pred = (sample - math.sqrt(alpha_prod_t) * sample_stablized) / math.sqrt(1 - alpha_prod_t)
        return noise_pred

    def add_noise(self, original_samples, noise, timestep):
        """Blend ``noise`` into ``original_samples`` at the level of ``timestep``.

        Args:
            original_samples: Noise-free samples.
            noise: Noise to mix in.
            timestep: Timestep whose cumulative alpha sets the mixing weights.

        Returns:
            ``sqrt(alpha) * original_samples + sqrt(1 - alpha) * noise``.
        """
        sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[timestep])
        sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[timestep])
        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples
|
||||
|
||||
Reference in New Issue
Block a user