initial version

Artiprocher
2023-12-08 01:03:30 +08:00
parent 5073cf6938
commit b459784171
30 changed files with 193803 additions and 1 deletion

264
Diffsynth_Studio.py Normal file

@@ -0,0 +1,264 @@
import torch, os, io
import numpy as np
from PIL import Image
import streamlit as st
st.set_page_config(layout="wide")
from streamlit_drawable_canvas import st_canvas
from diffsynth.models import ModelManager
from diffsynth.prompts import SDXLPrompter, SDPrompter
from diffsynth.pipelines import SDXLPipeline, SDPipeline
torch.cuda.set_per_process_memory_fraction(0.999, 0)
@st.cache_data
def load_model_list(folder):
file_list = os.listdir(folder)
file_list = [i for i in file_list if i.endswith(".safetensors")]
file_list = sorted(file_list)
return file_list
def detect_model_path(sd_model_path, sdxl_model_path):
if sd_model_path != "None":
model_path = os.path.join("models/stable_diffusion", sd_model_path)
elif sdxl_model_path != "None":
model_path = os.path.join("models/stable_diffusion_xl", sdxl_model_path)
else:
model_path = None
return model_path
def load_model(sd_model_path, sdxl_model_path):
if sd_model_path != "None":
model_path = os.path.join("models/stable_diffusion", sd_model_path)
model_manager = ModelManager()
model_manager.load_from_safetensors(model_path)
prompter = SDPrompter()
pipeline = SDPipeline()
elif sdxl_model_path != "None":
model_path = os.path.join("models/stable_diffusion_xl", sdxl_model_path)
model_manager = ModelManager()
model_manager.load_from_safetensors(model_path)
prompter = SDXLPrompter()
pipeline = SDXLPipeline()
else:
return None, None, None, None
return model_path, model_manager, prompter, pipeline
def release_model():
if "model_manager" in st.session_state:
st.session_state["model_manager"].to("cpu")
del st.session_state["loaded_model_path"]
del st.session_state["model_manager"]
del st.session_state["prompter"]
del st.session_state["pipeline"]
torch.cuda.empty_cache()
def use_output_image_as_input():
# Search for input image
output_image_id = 0
selected_output_image = None
while True:
if f"use_output_as_input_{output_image_id}" not in st.session_state:
break
if st.session_state[f"use_output_as_input_{output_image_id}"]:
selected_output_image = st.session_state["output_images"][output_image_id]
break
output_image_id += 1
if selected_output_image is not None:
st.session_state["input_image"] = selected_output_image
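# Composites the RGBA canvas strokes onto the image, using the alpha channel as the blending weight.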
def apply_stroke_to_image(stroke_image, image):
image = np.array(image.convert("RGB")).astype(np.float32)
height, width, _ = image.shape
stroke_image = np.array(Image.fromarray(stroke_image).resize((width, height))).astype(np.float32)
weight = stroke_image[:, :, -1:] / 255
stroke_image = stroke_image[:, :, :-1]
image = stroke_image * weight + image * (1 - weight)
image = np.clip(image, 0, 255).astype(np.uint8)
image = Image.fromarray(image)
return image
@st.cache_data
def image2bits(image):
image_byte = io.BytesIO()
image.save(image_byte, format="PNG")
image_byte = image_byte.getvalue()
return image_byte
def show_output_image(image):
st.image(image, use_column_width="always")
st.button("Use it as input image", key=f"use_output_as_input_{image_id}")
st.download_button("Download", data=image2bits(image), file_name="image.png", mime="image/png", key=f"download_output_{image_id}")
column_input, column_output = st.columns(2)
# with column_input:
with st.sidebar:
# Select a model
with st.expander("Model", expanded=True):
sd_model_list = ["None"] + load_model_list("models/stable_diffusion")
sd_model_path = st.selectbox(
"Stable Diffusion", sd_model_list
)
sdxl_model_list = ["None"] + load_model_list("models/stable_diffusion_xl")
sdxl_model_path = st.selectbox(
"Stable Diffusion XL", sdxl_model_list
)
# Load the model
model_path = detect_model_path(sd_model_path, sdxl_model_path)
if model_path is None:
st.markdown("No models selected.")
release_model()
elif st.session_state.get("loaded_model_path", "") != model_path:
st.markdown(f"Using model at {model_path}.")
release_model()
model_path, model_manager, prompter, pipeline = load_model(sd_model_path, sdxl_model_path)
st.session_state.loaded_model_path = model_path
st.session_state.model_manager = model_manager
st.session_state.prompter = prompter
st.session_state.pipeline = pipeline
else:
st.markdown(f"Using model at {model_path}.")
model_path, model_manager, prompter, pipeline = (
st.session_state.loaded_model_path,
st.session_state.model_manager,
st.session_state.prompter,
st.session_state.pipeline,
)
# Show parameters
with st.expander("Prompt", expanded=True):
column_positive, column_negative = st.columns(2)
prompt = st.text_area("Positive prompt")
negative_prompt = st.text_area("Negative prompt")
with st.expander("Classifier-free guidance", expanded=True):
use_cfg = st.checkbox("Use classifier-free guidance", value=True)
if use_cfg:
cfg_scale = st.slider("Classifier-free guidance scale", min_value=1.0, max_value=10.0, step=0.1, value=7.5)
else:
cfg_scale = 1.0
with st.expander("Inference steps", expanded=True):
num_inference_steps = st.slider("Inference steps", min_value=1, max_value=100, value=20, label_visibility="hidden")
with st.expander("Image size", expanded=True):
height = st.select_slider("Height", options=[256, 512, 768, 1024, 2048], value=512)
width = st.select_slider("Width", options=[256, 512, 768, 1024, 2048], value=512)
with st.expander("Seed", expanded=True):
use_fixed_seed = st.checkbox("Use fixed seed", value=False)
if use_fixed_seed:
seed = st.number_input("Random seed", value=0, label_visibility="hidden")
with st.expander("Number of images", expanded=True):
num_images = st.number_input("Number of images", value=4, label_visibility="hidden")
with st.expander("Tile (for high resolution)", expanded=True):
tiled = st.checkbox("Use tile", value=False)
tile_size = st.select_slider("Tile size", options=[64, 128], value=64)
tile_stride = st.select_slider("Tile stride", options=[8, 16, 32, 64], value=32)
# Show input image
with column_input:
with st.expander("Input image (Optional)", expanded=True):
with st.container(border=True):
column_white_board, column_upload_image = st.columns([1, 2])
with column_white_board:
create_white_board = st.button("Create white board")
delete_input_image = st.button("Delete input image")
with column_upload_image:
upload_image = st.file_uploader("Upload image", type=["png", "jpg"], key="upload_image")
if upload_image is not None:
st.session_state["input_image"] = Image.open(upload_image)
elif create_white_board:
st.session_state["input_image"] = Image.fromarray(np.ones((1024, 1024, 3), dtype=np.uint8) * 255)
else:
use_output_image_as_input()
if delete_input_image and "input_image" in st.session_state:
del st.session_state.input_image
if delete_input_image and "upload_image" in st.session_state:
del st.session_state.upload_image
input_image = st.session_state.get("input_image", None)
if input_image is not None:
with st.container(border=True):
column_drawing_mode, column_color_1, column_color_2 = st.columns([4, 1, 1])
with column_drawing_mode:
drawing_mode = st.radio("Drawing tool", ["transform", "freedraw", "line", "rect"], horizontal=True, index=1)
with column_color_1:
stroke_color = st.color_picker("Stroke color")
with column_color_2:
fill_color = st.color_picker("Fill color")
stroke_width = st.slider("Stroke width", min_value=1, max_value=50, value=10)
with st.container(border=True):
denoising_strength = st.slider("Denoising strength", min_value=0.0, max_value=1.0, value=0.7)
with st.container(border=True):
input_width, input_height = input_image.size
canvas_result = st_canvas(
fill_color=fill_color,
stroke_width=stroke_width,
stroke_color=stroke_color,
background_color="rgba(255, 255, 255, 0)",
background_image=input_image,
update_streamlit=True,
height=int(512 / input_width * input_height),
width=512,
drawing_mode=drawing_mode,
key="canvas"
)
with column_output:
run_button = st.button("Generate image", type="primary")
auto_update = st.checkbox("Auto update", value=False)
num_image_columns = st.slider("Columns", min_value=1, max_value=8, value=2)
image_columns = st.columns(num_image_columns)
# Run
if (run_button or auto_update) and model_path is not None:
if not use_fixed_seed:
torch.manual_seed(np.random.randint(0, 10**9))
output_images = []
for image_id in range(num_images):
if use_fixed_seed:
torch.manual_seed(seed + image_id)
if input_image is not None:
input_image = input_image.resize((width, height))
if canvas_result.image_data is not None:
input_image = apply_stroke_to_image(canvas_result.image_data, input_image)
else:
denoising_strength = 1.0
with image_columns[image_id % num_image_columns]:
progress_bar = st.progress(0.0)
image = pipeline(
model_manager, prompter,
prompt, negative_prompt=negative_prompt, cfg_scale=cfg_scale,
num_inference_steps=num_inference_steps,
height=height, width=width,
init_image=input_image, denoising_strength=denoising_strength,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
progress_bar_st=progress_bar
)
output_images.append(image)
progress_bar.progress(1.0)
show_output_image(image)
st.session_state["output_images"] = output_images
elif "output_images" in st.session_state:
for image_id in range(len(st.session_state.output_images)):
with image_columns[image_id % num_image_columns]:
image = st.session_state.output_images[image_id]
progress_bar = st.progress(1.0)
show_output_image(image)


@@ -186,7 +186,7 @@
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Copyright [2023] [Zhongjie Duan]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

53
README-zh.md Normal file

@@ -0,0 +1,53 @@
# DiffSynth Studio
## Introduction
DiffSynth is a brand-new Diffusion engine. We have restructured the Text Encoder, UNet, VAE and other components, maintaining compatibility with open-source community models while improving computational performance. The current release is only an initial version that implements text-to-image and image-to-image generation and supports the SD and SDXL architectures. In the future we plan to build more interesting features on top of this new codebase.
## Installation
If you only want to call DiffSynth Studio from Python code, you only need to install `torch` (the deep learning framework) and `transformers` (used only for the tokenizer).
```
pip install torch transformers
```
If you want to use the UI, you also need to install `streamlit` (a web UI framework) and `streamlit-drawable-canvas` (used for the image-to-image drawing board).
```
pip install streamlit streamlit-drawable-canvas
```
## Usage
Calling DiffSynth Studio from Python code:
```python
from diffsynth.models import ModelManager
from diffsynth.prompts import SDPrompter, SDXLPrompter
from diffsynth.pipelines import SDPipeline, SDXLPipeline
model_manager = ModelManager()
model_manager.load_from_safetensors("xxxxxxxx.safetensors")
prompter = SDPrompter()
pipe = SDPipeline()
prompt = "a girl"
negative_prompt = ""
image = pipe(
model_manager, prompter,
prompt, negative_prompt=negative_prompt,
num_inference_steps=20, height=512, width=512,
)
image.save("image.png")
```
If you need an SDXL-architecture model, replace `SDPrompter` and `SDPipeline` with `SDXLPrompter` and `SDXLPipeline`.
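For instance, a minimal sketch of the swap (only the classes change; the rest of the example above stays the same):
```python
from diffsynth.prompts import SDXLPrompter
from diffsynth.pipelines import SDXLPipeline

prompter = SDXLPrompter()
pipe = SDXLPipeline()
```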
Of course, you can also use the UI we provide, but note that the UI is simple and may change significantly in the future.
```
python -m streamlit run Diffsynth_Studio.py
```

53
README.md Normal file

@@ -0,0 +1,53 @@
# DiffSynth Studio
## Introduction
DiffSynth is a new Diffusion engine. We have restructured components such as the Text Encoder, UNet, and VAE, maintaining compatibility with models from the open-source community while enhancing computational performance. This version is currently in its initial stage: it supports text-to-image and image-to-image generation with both the SD and SDXL architectures. In the future, we plan to develop more interesting features based on this new codebase.
## Installation
If you only want to use DiffSynth Studio at the Python code level, you just need to install `torch` (a deep learning framework) and `transformers` (used only for the tokenizer).
```
pip install torch transformers
```
If you wish to use the UI, you'll also need to additionally install `streamlit` (a web UI framework) and `streamlit-drawable-canvas` (used for the image-to-image canvas).
```
pip install streamlit streamlit-drawable-canvas
```
## Usage
Use DiffSynth Studio in Python:
```python
from diffsynth.models import ModelManager
from diffsynth.prompts import SDPrompter, SDXLPrompter
from diffsynth.pipelines import SDPipeline, SDXLPipeline
model_manager = ModelManager()
model_manager.load_from_safetensors("xxxxxxxx.safetensors")
prompter = SDPrompter()
pipe = SDPipeline()
prompt = "a girl"
negative_prompt = ""
image = pipe(
model_manager, prompter,
prompt, negative_prompt=negative_prompt,
num_inference_steps=20, height=512, width=512,
)
image.save("image.png")
```
If you want to use SDXL architecture models, replace `SDPrompter` and `SDPipeline` with `SDXLPrompter` and `SDXLPipeline`, respectively.
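For reference, a minimal sketch of the SDXL variant, assuming `SDXLPipeline` accepts the same arguments as the SD example above (the checkpoint filename is a placeholder):
```python
from diffsynth.models import ModelManager
from diffsynth.prompts import SDXLPrompter
from diffsynth.pipelines import SDXLPipeline

# Placeholder path: point this at your own SDXL .safetensors checkpoint.
model_manager = ModelManager()
model_manager.load_from_safetensors("your_sdxl_model.safetensors")
prompter = SDXLPrompter()
pipe = SDXLPipeline()

image = pipe(
    model_manager, prompter,
    "a girl", negative_prompt="",
    num_inference_steps=20,
    height=1024, width=1024,  # 1024x1024 is a typical SDXL resolution
)
image.save("image.png")
```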
Of course, you can also use the UI we provide, but note that the UI is simple and may change significantly in the future.
```
python -m streamlit run Diffsynth_Studio.py
```

File diff suppressed because it is too large


@@ -0,0 +1,24 @@
{
"bos_token": {
"content": "<|startoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"eos_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"pad_token": "<|endoftext|>",
"unk_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
}
}


@@ -0,0 +1,34 @@
{
"add_prefix_space": false,
"bos_token": {
"__type": "AddedToken",
"content": "<|startoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"do_lower_case": true,
"eos_token": {
"__type": "AddedToken",
"content": "<|endoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"errors": "replace",
"model_max_length": 77,
"name_or_path": "openai/clip-vit-large-patch14",
"pad_token": "<|endoftext|>",
"special_tokens_map_file": "./special_tokens_map.json",
"tokenizer_class": "CLIPTokenizer",
"unk_token": {
"__type": "AddedToken",
"content": "<|endoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
}
}

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -0,0 +1,24 @@
{
"bos_token": {
"content": "<|startoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"eos_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"pad_token": "!",
"unk_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
}
}


@@ -0,0 +1,38 @@
{
"add_prefix_space": false,
"added_tokens_decoder": {
"0": {
"content": "!",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"49406": {
"content": "<|startoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": true
},
"49407": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": true
}
},
"bos_token": "<|startoftext|>",
"clean_up_tokenization_spaces": true,
"do_lower_case": true,
"eos_token": "<|endoftext|>",
"errors": "replace",
"model_max_length": 77,
"pad_token": "!",
"tokenizer_class": "CLIPTokenizer",
"unk_token": "<|endoftext|>"
}

File diff suppressed because it is too large


@@ -0,0 +1,126 @@
import torch
from safetensors import safe_open
from .sd_text_encoder import SDTextEncoder
from .sd_unet import SDUNet
from .sd_vae_encoder import SDVAEEncoder
from .sd_vae_decoder import SDVAEDecoder
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from .sdxl_unet import SDXLUNet
from .sdxl_vae_decoder import SDXLVAEDecoder
from .sdxl_vae_encoder import SDXLVAEEncoder
class ModelManager:
def __init__(self, torch_type=torch.float16, device="cuda"):
self.torch_type = torch_type
self.device = device
self.model = {}
    def is_stable_diffusion_xl(self, state_dict):
param_name = "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight"
return param_name in state_dict
def is_stable_diffusion(self, state_dict):
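        # Fallback: any checkpoint not recognized as SDXL is treated as Stable Diffusion 1.x.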
return True
def load_stable_diffusion(self, state_dict, components=None):
component_dict = {
"text_encoder": SDTextEncoder,
"unet": SDUNet,
"vae_decoder": SDVAEDecoder,
"vae_encoder": SDVAEEncoder,
"refiner": SDXLUNet,
}
if components is None:
components = ["text_encoder", "unet", "vae_decoder", "vae_encoder"]
for component in components:
self.model[component] = component_dict[component]()
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
self.model[component].to(self.torch_type).to(self.device)
def load_stable_diffusion_xl(self, state_dict, components=None):
component_dict = {
"text_encoder": SDXLTextEncoder,
"text_encoder_2": SDXLTextEncoder2,
"unet": SDXLUNet,
"vae_decoder": SDXLVAEDecoder,
"vae_encoder": SDXLVAEEncoder,
"refiner": SDXLUNet,
}
if components is None:
components = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder"]
for component in components:
self.model[component] = component_dict[component]()
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
if component in ["vae_decoder", "vae_encoder"]:
                # These two models output NaN when float16 is enabled.
                # The precision problem occurs in the last three ResNet blocks,
                # and no fix is known yet, so they are kept in float32.
self.model[component].to(torch.float32).to(self.device)
else:
self.model[component].to(self.torch_type).to(self.device)
def load_from_safetensors(self, file_path, components=None):
state_dict = load_state_dict_from_safetensors(file_path)
        if self.is_stable_diffusion_xl(state_dict):
self.load_stable_diffusion_xl(state_dict, components=components)
elif self.is_stable_diffusion(state_dict):
self.load_stable_diffusion(state_dict, components=components)
def to(self, device):
for component in self.model:
self.model[component].to(device)
def __getattr__(self, __name):
if __name in self.model:
return self.model[__name]
else:
            return super().__getattribute__(__name)
def load_state_dict_from_safetensors(file_path):
state_dict = {}
with safe_open(file_path, framework="pt", device="cpu") as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
return state_dict
def load_state_dict_from_bin(file_path):
return torch.load(file_path, map_location="cpu")
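# Finds the name of a parameter in state_dict whose values (almost) exactly match `param`, or returns None.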
def search_parameter(param, state_dict):
for name, param_ in state_dict.items():
if param.numel() == param_.numel():
if param.shape == param_.shape:
if torch.dist(param, param_) < 1e-6:
return name
else:
if torch.dist(param.flatten(), param_.flatten()) < 1e-6:
return name
return None
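# Prints a source-to-target name mapping by value-matching parameters; with split_qkv, fused QKV weights are also matched in thirds.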
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
matched_keys = set()
with torch.no_grad():
for name in source_state_dict:
rename = search_parameter(source_state_dict[name], target_state_dict)
if rename is not None:
print(f'"{name}": "{rename}",')
matched_keys.add(rename)
elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
length = source_state_dict[name].shape[0] // 3
rename = []
for i in range(3):
rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
if None not in rename:
print(f'"{name}": {rename},')
for rename_ in rename:
matched_keys.add(rename_)
for name in target_state_dict:
if name not in matched_keys:
print("Cannot find", name, target_state_dict[name].shape)


@@ -0,0 +1,38 @@
import torch
class Attention(torch.nn.Module):
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
super().__init__()
dim_inner = head_dim * num_heads
kv_dim = kv_dim if kv_dim is not None else q_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
batch_size = encoder_hidden_states.shape[0]
q = self.to_q(hidden_states)
k = self.to_k(encoder_hidden_states)
v = self.to_v(encoder_hidden_states)
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
hidden_states = hidden_states.transpose(1, 2).view(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
hidden_states = self.to_out(hidden_states)
return hidden_states


@@ -0,0 +1,320 @@
import torch
from .attention import Attention
class CLIPEncoderLayer(torch.nn.Module):
def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
super().__init__()
self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)
self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)
self.use_quick_gelu = use_quick_gelu
def quickGELU(self, x):
return x * torch.sigmoid(1.702 * x)
def forward(self, hidden_states, attn_mask):
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.attn(hidden_states, attn_mask=attn_mask)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.fc1(hidden_states)
if self.use_quick_gelu:
hidden_states = self.quickGELU(hidden_states)
else:
hidden_states = torch.nn.functional.gelu(hidden_states)
hidden_states = self.fc2(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class SDTextEncoder(torch.nn.Module):
def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
super().__init__()
# token_embedding
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
# position_embeds (This is a fixed tensor)
self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
# encoders
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
# attn_mask
self.attn_mask = self.attention_mask(max_position_embeddings)
# final_layer_norm
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
def attention_mask(self, length):
mask = torch.empty(length, length)
mask.fill_(float("-inf"))
mask.triu_(1)
return mask
def forward(self, input_ids, clip_skip=1):
embeds = self.token_embedding(input_ids) + self.position_embeds
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
for encoder_id, encoder in enumerate(self.encoders):
embeds = encoder(embeds, attn_mask=attn_mask)
if encoder_id + clip_skip == len(self.encoders):
break
embeds = self.final_layer_norm(embeds)
return embeds
def state_dict_converter(self):
return SDTextEncoderStateDictConverter()
class SDTextEncoderStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"text_model.embeddings.token_embedding.weight": "token_embedding.weight",
"text_model.embeddings.position_embedding.weight": "position_embeds",
"text_model.final_layer_norm.weight": "final_layer_norm.weight",
"text_model.final_layer_norm.bias": "final_layer_norm.bias"
}
attn_rename_dict = {
"self_attn.q_proj": "attn.to_q",
"self_attn.k_proj": "attn.to_k",
"self_attn.v_proj": "attn.to_v",
"self_attn.out_proj": "attn.to_out",
"layer_norm1": "layer_norm1",
"layer_norm2": "layer_norm2",
"mlp.fc1": "fc1",
"mlp.fc2": "fc2",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if name == "text_model.embeddings.position_embedding.weight":
param = param.reshape((1, param.shape[0], param.shape[1]))
state_dict_[rename_dict[name]] = param
elif name.startswith("text_model.encoder.layers."):
param = state_dict[name]
names = name.split(".")
layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
state_dict_[name_] = param
return state_dict_
def from_civitai(self, state_dict):
rename_dict = {
"cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "encoders.11.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight",
"cond_stage_model.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias",
"cond_stage_model.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight",
"cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "position_embeds"
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight":
param = param.reshape((1, param.shape[0], param.shape[1]))
state_dict_[rename_dict[name]] = param
return state_dict_

1090
diffsynth/models/sd_unet.py Normal file

File diff suppressed because it is too large


@@ -0,0 +1,330 @@
import torch
from .attention import Attention
from .sd_unet import ResnetBlock, UpSampler
from .tiler import Tiler
class VAEAttentionBlock(torch.nn.Module):
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
self.transformer_blocks = torch.nn.ModuleList([
Attention(
inner_dim,
num_attention_heads,
attention_head_dim,
bias_q=True,
bias_kv=True,
bias_out=True
)
for d in range(num_layers)
])
def forward(self, hidden_states, time_emb, text_emb, res_stack):
batch, _, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
for block in self.transformer_blocks:
hidden_states = block(hidden_states)
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
hidden_states = hidden_states + residual
return hidden_states, time_emb, text_emb, res_stack
class SDVAEDecoder(torch.nn.Module):
def __init__(self):
super().__init__()
self.scaling_factor = 0.18215
self.post_quant_conv = torch.nn.Conv2d(4, 4, kernel_size=1)
self.conv_in = torch.nn.Conv2d(4, 512, kernel_size=3, padding=1)
self.blocks = torch.nn.ModuleList([
# UNetMidBlock2D
ResnetBlock(512, 512, eps=1e-6),
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
# UpDecoderBlock2D
ResnetBlock(512, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
UpSampler(512),
# UpDecoderBlock2D
ResnetBlock(512, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
UpSampler(512),
# UpDecoderBlock2D
ResnetBlock(512, 256, eps=1e-6),
ResnetBlock(256, 256, eps=1e-6),
ResnetBlock(256, 256, eps=1e-6),
UpSampler(256),
# UpDecoderBlock2D
ResnetBlock(256, 128, eps=1e-6),
ResnetBlock(128, 128, eps=1e-6),
ResnetBlock(128, 128, eps=1e-6),
])
self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-5)
self.conv_act = torch.nn.SiLU()
self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
hidden_states = Tiler()(
lambda x: self.forward(x),
sample,
tile_size,
tile_stride
)
return hidden_states
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
# For VAE Decoder, we do not need to apply the tiler on each layer.
if tiled:
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
# 1. pre-process
sample = sample / self.scaling_factor
hidden_states = self.post_quant_conv(sample)
hidden_states = self.conv_in(hidden_states)
time_emb = None
text_emb = None
res_stack = None
# 2. blocks
for i, block in enumerate(self.blocks):
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
# 3. output
hidden_states = self.conv_norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
def state_dict_converter(self):
return SDVAEDecoderStateDictConverter()
class SDVAEDecoderStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
# architecture
block_types = [
'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock',
'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
'ResnetBlock', 'ResnetBlock', 'ResnetBlock'
]
# Rename each parameter
local_rename_dict = {
"post_quant_conv": "post_quant_conv",
"decoder.conv_in": "conv_in",
"decoder.mid_block.attentions.0.group_norm": "blocks.1.norm",
"decoder.mid_block.attentions.0.to_q": "blocks.1.transformer_blocks.0.to_q",
"decoder.mid_block.attentions.0.to_k": "blocks.1.transformer_blocks.0.to_k",
"decoder.mid_block.attentions.0.to_v": "blocks.1.transformer_blocks.0.to_v",
"decoder.mid_block.attentions.0.to_out.0": "blocks.1.transformer_blocks.0.to_out",
"decoder.mid_block.resnets.0.norm1": "blocks.0.norm1",
"decoder.mid_block.resnets.0.conv1": "blocks.0.conv1",
"decoder.mid_block.resnets.0.norm2": "blocks.0.norm2",
"decoder.mid_block.resnets.0.conv2": "blocks.0.conv2",
"decoder.mid_block.resnets.1.norm1": "blocks.2.norm1",
"decoder.mid_block.resnets.1.conv1": "blocks.2.conv1",
"decoder.mid_block.resnets.1.norm2": "blocks.2.norm2",
"decoder.mid_block.resnets.1.conv2": "blocks.2.conv2",
"decoder.conv_norm_out": "conv_norm_out",
"decoder.conv_out": "conv_out",
}
name_list = sorted([name for name in state_dict])
rename_dict = {}
block_id = {"ResnetBlock": 2, "DownSampler": 2, "UpSampler": 2}
last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""}
for name in name_list:
names = name.split(".")
name_prefix = ".".join(names[:-1])
if name_prefix in local_rename_dict:
rename_dict[name] = local_rename_dict[name_prefix] + "." + names[-1]
elif name.startswith("decoder.up_blocks"):
block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]]
block_type_with_id = ".".join(names[:5])
if block_type_with_id != last_block_type_with_id[block_type]:
block_id[block_type] += 1
last_block_type_with_id[block_type] = block_type_with_id
while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
block_id[block_type] += 1
block_type_with_id = ".".join(names[:5])
names = ["blocks", str(block_id[block_type])] + names[5:]
rename_dict[name] = ".".join(names)
# Convert state_dict
state_dict_ = {}
for name, param in state_dict.items():
if name in rename_dict:
state_dict_[rename_dict[name]] = param
return state_dict_
def from_civitai(self, state_dict):
rename_dict = {
"first_stage_model.decoder.conv_in.bias": "conv_in.bias",
"first_stage_model.decoder.conv_in.weight": "conv_in.weight",
"first_stage_model.decoder.conv_out.bias": "conv_out.bias",
"first_stage_model.decoder.conv_out.weight": "conv_out.weight",
"first_stage_model.decoder.mid.attn_1.k.bias": "blocks.1.transformer_blocks.0.to_k.bias",
"first_stage_model.decoder.mid.attn_1.k.weight": "blocks.1.transformer_blocks.0.to_k.weight",
"first_stage_model.decoder.mid.attn_1.norm.bias": "blocks.1.norm.bias",
"first_stage_model.decoder.mid.attn_1.norm.weight": "blocks.1.norm.weight",
"first_stage_model.decoder.mid.attn_1.proj_out.bias": "blocks.1.transformer_blocks.0.to_out.bias",
"first_stage_model.decoder.mid.attn_1.proj_out.weight": "blocks.1.transformer_blocks.0.to_out.weight",
"first_stage_model.decoder.mid.attn_1.q.bias": "blocks.1.transformer_blocks.0.to_q.bias",
"first_stage_model.decoder.mid.attn_1.q.weight": "blocks.1.transformer_blocks.0.to_q.weight",
"first_stage_model.decoder.mid.attn_1.v.bias": "blocks.1.transformer_blocks.0.to_v.bias",
"first_stage_model.decoder.mid.attn_1.v.weight": "blocks.1.transformer_blocks.0.to_v.weight",
"first_stage_model.decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
"first_stage_model.decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
"first_stage_model.decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
"first_stage_model.decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
"first_stage_model.decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
"first_stage_model.decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
"first_stage_model.decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
"first_stage_model.decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
"first_stage_model.decoder.mid.block_2.conv1.bias": "blocks.2.conv1.bias",
"first_stage_model.decoder.mid.block_2.conv1.weight": "blocks.2.conv1.weight",
"first_stage_model.decoder.mid.block_2.conv2.bias": "blocks.2.conv2.bias",
"first_stage_model.decoder.mid.block_2.conv2.weight": "blocks.2.conv2.weight",
"first_stage_model.decoder.mid.block_2.norm1.bias": "blocks.2.norm1.bias",
"first_stage_model.decoder.mid.block_2.norm1.weight": "blocks.2.norm1.weight",
"first_stage_model.decoder.mid.block_2.norm2.bias": "blocks.2.norm2.bias",
"first_stage_model.decoder.mid.block_2.norm2.weight": "blocks.2.norm2.weight",
"first_stage_model.decoder.norm_out.bias": "conv_norm_out.bias",
"first_stage_model.decoder.norm_out.weight": "conv_norm_out.weight",
"first_stage_model.decoder.up.0.block.0.conv1.bias": "blocks.15.conv1.bias",
"first_stage_model.decoder.up.0.block.0.conv1.weight": "blocks.15.conv1.weight",
"first_stage_model.decoder.up.0.block.0.conv2.bias": "blocks.15.conv2.bias",
"first_stage_model.decoder.up.0.block.0.conv2.weight": "blocks.15.conv2.weight",
"first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "blocks.15.conv_shortcut.bias",
"first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "blocks.15.conv_shortcut.weight",
"first_stage_model.decoder.up.0.block.0.norm1.bias": "blocks.15.norm1.bias",
"first_stage_model.decoder.up.0.block.0.norm1.weight": "blocks.15.norm1.weight",
"first_stage_model.decoder.up.0.block.0.norm2.bias": "blocks.15.norm2.bias",
"first_stage_model.decoder.up.0.block.0.norm2.weight": "blocks.15.norm2.weight",
"first_stage_model.decoder.up.0.block.1.conv1.bias": "blocks.16.conv1.bias",
"first_stage_model.decoder.up.0.block.1.conv1.weight": "blocks.16.conv1.weight",
"first_stage_model.decoder.up.0.block.1.conv2.bias": "blocks.16.conv2.bias",
"first_stage_model.decoder.up.0.block.1.conv2.weight": "blocks.16.conv2.weight",
"first_stage_model.decoder.up.0.block.1.norm1.bias": "blocks.16.norm1.bias",
"first_stage_model.decoder.up.0.block.1.norm1.weight": "blocks.16.norm1.weight",
"first_stage_model.decoder.up.0.block.1.norm2.bias": "blocks.16.norm2.bias",
"first_stage_model.decoder.up.0.block.1.norm2.weight": "blocks.16.norm2.weight",
"first_stage_model.decoder.up.0.block.2.conv1.bias": "blocks.17.conv1.bias",
"first_stage_model.decoder.up.0.block.2.conv1.weight": "blocks.17.conv1.weight",
"first_stage_model.decoder.up.0.block.2.conv2.bias": "blocks.17.conv2.bias",
"first_stage_model.decoder.up.0.block.2.conv2.weight": "blocks.17.conv2.weight",
"first_stage_model.decoder.up.0.block.2.norm1.bias": "blocks.17.norm1.bias",
"first_stage_model.decoder.up.0.block.2.norm1.weight": "blocks.17.norm1.weight",
"first_stage_model.decoder.up.0.block.2.norm2.bias": "blocks.17.norm2.bias",
"first_stage_model.decoder.up.0.block.2.norm2.weight": "blocks.17.norm2.weight",
"first_stage_model.decoder.up.1.block.0.conv1.bias": "blocks.11.conv1.bias",
"first_stage_model.decoder.up.1.block.0.conv1.weight": "blocks.11.conv1.weight",
"first_stage_model.decoder.up.1.block.0.conv2.bias": "blocks.11.conv2.bias",
"first_stage_model.decoder.up.1.block.0.conv2.weight": "blocks.11.conv2.weight",
"first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "blocks.11.conv_shortcut.bias",
"first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "blocks.11.conv_shortcut.weight",
"first_stage_model.decoder.up.1.block.0.norm1.bias": "blocks.11.norm1.bias",
"first_stage_model.decoder.up.1.block.0.norm1.weight": "blocks.11.norm1.weight",
"first_stage_model.decoder.up.1.block.0.norm2.bias": "blocks.11.norm2.bias",
"first_stage_model.decoder.up.1.block.0.norm2.weight": "blocks.11.norm2.weight",
"first_stage_model.decoder.up.1.block.1.conv1.bias": "blocks.12.conv1.bias",
"first_stage_model.decoder.up.1.block.1.conv1.weight": "blocks.12.conv1.weight",
"first_stage_model.decoder.up.1.block.1.conv2.bias": "blocks.12.conv2.bias",
"first_stage_model.decoder.up.1.block.1.conv2.weight": "blocks.12.conv2.weight",
"first_stage_model.decoder.up.1.block.1.norm1.bias": "blocks.12.norm1.bias",
"first_stage_model.decoder.up.1.block.1.norm1.weight": "blocks.12.norm1.weight",
"first_stage_model.decoder.up.1.block.1.norm2.bias": "blocks.12.norm2.bias",
"first_stage_model.decoder.up.1.block.1.norm2.weight": "blocks.12.norm2.weight",
"first_stage_model.decoder.up.1.block.2.conv1.bias": "blocks.13.conv1.bias",
"first_stage_model.decoder.up.1.block.2.conv1.weight": "blocks.13.conv1.weight",
"first_stage_model.decoder.up.1.block.2.conv2.bias": "blocks.13.conv2.bias",
"first_stage_model.decoder.up.1.block.2.conv2.weight": "blocks.13.conv2.weight",
"first_stage_model.decoder.up.1.block.2.norm1.bias": "blocks.13.norm1.bias",
"first_stage_model.decoder.up.1.block.2.norm1.weight": "blocks.13.norm1.weight",
"first_stage_model.decoder.up.1.block.2.norm2.bias": "blocks.13.norm2.bias",
"first_stage_model.decoder.up.1.block.2.norm2.weight": "blocks.13.norm2.weight",
"first_stage_model.decoder.up.1.upsample.conv.bias": "blocks.14.conv.bias",
"first_stage_model.decoder.up.1.upsample.conv.weight": "blocks.14.conv.weight",
"first_stage_model.decoder.up.2.block.0.conv1.bias": "blocks.7.conv1.bias",
"first_stage_model.decoder.up.2.block.0.conv1.weight": "blocks.7.conv1.weight",
"first_stage_model.decoder.up.2.block.0.conv2.bias": "blocks.7.conv2.bias",
"first_stage_model.decoder.up.2.block.0.conv2.weight": "blocks.7.conv2.weight",
"first_stage_model.decoder.up.2.block.0.norm1.bias": "blocks.7.norm1.bias",
"first_stage_model.decoder.up.2.block.0.norm1.weight": "blocks.7.norm1.weight",
"first_stage_model.decoder.up.2.block.0.norm2.bias": "blocks.7.norm2.bias",
"first_stage_model.decoder.up.2.block.0.norm2.weight": "blocks.7.norm2.weight",
"first_stage_model.decoder.up.2.block.1.conv1.bias": "blocks.8.conv1.bias",
"first_stage_model.decoder.up.2.block.1.conv1.weight": "blocks.8.conv1.weight",
"first_stage_model.decoder.up.2.block.1.conv2.bias": "blocks.8.conv2.bias",
"first_stage_model.decoder.up.2.block.1.conv2.weight": "blocks.8.conv2.weight",
"first_stage_model.decoder.up.2.block.1.norm1.bias": "blocks.8.norm1.bias",
"first_stage_model.decoder.up.2.block.1.norm1.weight": "blocks.8.norm1.weight",
"first_stage_model.decoder.up.2.block.1.norm2.bias": "blocks.8.norm2.bias",
"first_stage_model.decoder.up.2.block.1.norm2.weight": "blocks.8.norm2.weight",
"first_stage_model.decoder.up.2.block.2.conv1.bias": "blocks.9.conv1.bias",
"first_stage_model.decoder.up.2.block.2.conv1.weight": "blocks.9.conv1.weight",
"first_stage_model.decoder.up.2.block.2.conv2.bias": "blocks.9.conv2.bias",
"first_stage_model.decoder.up.2.block.2.conv2.weight": "blocks.9.conv2.weight",
"first_stage_model.decoder.up.2.block.2.norm1.bias": "blocks.9.norm1.bias",
"first_stage_model.decoder.up.2.block.2.norm1.weight": "blocks.9.norm1.weight",
"first_stage_model.decoder.up.2.block.2.norm2.bias": "blocks.9.norm2.bias",
"first_stage_model.decoder.up.2.block.2.norm2.weight": "blocks.9.norm2.weight",
"first_stage_model.decoder.up.2.upsample.conv.bias": "blocks.10.conv.bias",
"first_stage_model.decoder.up.2.upsample.conv.weight": "blocks.10.conv.weight",
"first_stage_model.decoder.up.3.block.0.conv1.bias": "blocks.3.conv1.bias",
"first_stage_model.decoder.up.3.block.0.conv1.weight": "blocks.3.conv1.weight",
"first_stage_model.decoder.up.3.block.0.conv2.bias": "blocks.3.conv2.bias",
"first_stage_model.decoder.up.3.block.0.conv2.weight": "blocks.3.conv2.weight",
"first_stage_model.decoder.up.3.block.0.norm1.bias": "blocks.3.norm1.bias",
"first_stage_model.decoder.up.3.block.0.norm1.weight": "blocks.3.norm1.weight",
"first_stage_model.decoder.up.3.block.0.norm2.bias": "blocks.3.norm2.bias",
"first_stage_model.decoder.up.3.block.0.norm2.weight": "blocks.3.norm2.weight",
"first_stage_model.decoder.up.3.block.1.conv1.bias": "blocks.4.conv1.bias",
"first_stage_model.decoder.up.3.block.1.conv1.weight": "blocks.4.conv1.weight",
"first_stage_model.decoder.up.3.block.1.conv2.bias": "blocks.4.conv2.bias",
"first_stage_model.decoder.up.3.block.1.conv2.weight": "blocks.4.conv2.weight",
"first_stage_model.decoder.up.3.block.1.norm1.bias": "blocks.4.norm1.bias",
"first_stage_model.decoder.up.3.block.1.norm1.weight": "blocks.4.norm1.weight",
"first_stage_model.decoder.up.3.block.1.norm2.bias": "blocks.4.norm2.bias",
"first_stage_model.decoder.up.3.block.1.norm2.weight": "blocks.4.norm2.weight",
"first_stage_model.decoder.up.3.block.2.conv1.bias": "blocks.5.conv1.bias",
"first_stage_model.decoder.up.3.block.2.conv1.weight": "blocks.5.conv1.weight",
"first_stage_model.decoder.up.3.block.2.conv2.bias": "blocks.5.conv2.bias",
"first_stage_model.decoder.up.3.block.2.conv2.weight": "blocks.5.conv2.weight",
"first_stage_model.decoder.up.3.block.2.norm1.bias": "blocks.5.norm1.bias",
"first_stage_model.decoder.up.3.block.2.norm1.weight": "blocks.5.norm1.weight",
"first_stage_model.decoder.up.3.block.2.norm2.bias": "blocks.5.norm2.bias",
"first_stage_model.decoder.up.3.block.2.norm2.weight": "blocks.5.norm2.weight",
"first_stage_model.decoder.up.3.upsample.conv.bias": "blocks.6.conv.bias",
"first_stage_model.decoder.up.3.upsample.conv.weight": "blocks.6.conv.weight",
"first_stage_model.post_quant_conv.bias": "post_quant_conv.bias",
"first_stage_model.post_quant_conv.weight": "post_quant_conv.weight",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if "transformer_blocks" in rename_dict[name]:
param = param.squeeze()
state_dict_[rename_dict[name]] = param
return state_dict_
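The squeeze() above reflects how original-format (civitai-style) checkpoints store the VAE mid-block attention projections as 1x1 convolutions, so their weights carry two trailing singleton dimensions that must be dropped before they can be used as linear weights. A minimal sketch of that shape change, assuming a hypothetical 512-channel projection:

import torch

# 1x1 conv weight as stored in the checkpoint: shape [C, C, 1, 1]
conv_weight = torch.randn(512, 512, 1, 1)
# squeeze() drops the trailing singleton dims, matching a torch.nn.Linear weight
linear_weight = conv_weight.squeeze()
print(linear_weight.shape)   # torch.Size([512, 512])
# Bias tensors are already 1D, so squeeze() leaves them unchanged.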


@@ -0,0 +1,258 @@
import torch
from .sd_unet import ResnetBlock, DownSampler
from .sd_vae_decoder import VAEAttentionBlock
from .tiler import Tiler
class SDVAEEncoder(torch.nn.Module):
def __init__(self):
super().__init__()
self.scaling_factor = 0.18215
self.quant_conv = torch.nn.Conv2d(8, 8, kernel_size=1)
self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)
self.blocks = torch.nn.ModuleList([
# DownEncoderBlock2D
ResnetBlock(128, 128, eps=1e-6),
ResnetBlock(128, 128, eps=1e-6),
DownSampler(128, padding=0, extra_padding=True),
# DownEncoderBlock2D
ResnetBlock(128, 256, eps=1e-6),
ResnetBlock(256, 256, eps=1e-6),
DownSampler(256, padding=0, extra_padding=True),
# DownEncoderBlock2D
ResnetBlock(256, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
DownSampler(512, padding=0, extra_padding=True),
# DownEncoderBlock2D
ResnetBlock(512, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
# UNetMidBlock2D
ResnetBlock(512, 512, eps=1e-6),
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
])
self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
self.conv_act = torch.nn.SiLU()
self.conv_out = torch.nn.Conv2d(512, 8, kernel_size=3, padding=1)
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
hidden_states = Tiler()(
lambda x: self.forward(x),
sample,
tile_size,
tile_stride
)
return hidden_states
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
# For the VAE Encoder, we do not need to apply the tiler on each layer.
if tiled:
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
# 1. pre-process
hidden_states = self.conv_in(sample)
time_emb = None
text_emb = None
res_stack = None
# 2. blocks
for i, block in enumerate(self.blocks):
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
# 3. output
hidden_states = self.conv_norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
hidden_states = self.quant_conv(hidden_states)
hidden_states = hidden_states[:, :4]
hidden_states *= self.scaling_factor
return hidden_states
def state_dict_converter(self):
return SDVAEEncoderStateDictConverter()
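As a rough usage sketch (shapes only, random weights; real usage loads a converted state dict through the converters below): the three DownSampler blocks reduce the spatial resolution by a factor of 8, and the output keeps the first 4 of the 8 predicted channels (the mean of the latent distribution), scaled by scaling_factor.

import torch

encoder = SDVAEEncoder()                 # random weights in this sketch
image = torch.randn(1, 3, 512, 512)      # one RGB image
with torch.no_grad():
    latents = encoder(image)
print(latents.shape)                     # torch.Size([1, 4, 64, 64])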
class SDVAEEncoderStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
# architecture
block_types = [
'ResnetBlock', 'ResnetBlock', 'DownSampler',
'ResnetBlock', 'ResnetBlock', 'DownSampler',
'ResnetBlock', 'ResnetBlock', 'DownSampler',
'ResnetBlock', 'ResnetBlock',
'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock'
]
# Rename each parameter
local_rename_dict = {
"quant_conv": "quant_conv",
"encoder.conv_in": "conv_in",
"encoder.mid_block.attentions.0.group_norm": "blocks.12.norm",
"encoder.mid_block.attentions.0.to_q": "blocks.12.transformer_blocks.0.to_q",
"encoder.mid_block.attentions.0.to_k": "blocks.12.transformer_blocks.0.to_k",
"encoder.mid_block.attentions.0.to_v": "blocks.12.transformer_blocks.0.to_v",
"encoder.mid_block.attentions.0.to_out.0": "blocks.12.transformer_blocks.0.to_out",
"encoder.mid_block.resnets.0.norm1": "blocks.11.norm1",
"encoder.mid_block.resnets.0.conv1": "blocks.11.conv1",
"encoder.mid_block.resnets.0.norm2": "blocks.11.norm2",
"encoder.mid_block.resnets.0.conv2": "blocks.11.conv2",
"encoder.mid_block.resnets.1.norm1": "blocks.13.norm1",
"encoder.mid_block.resnets.1.conv1": "blocks.13.conv1",
"encoder.mid_block.resnets.1.norm2": "blocks.13.norm2",
"encoder.mid_block.resnets.1.conv2": "blocks.13.conv2",
"encoder.conv_norm_out": "conv_norm_out",
"encoder.conv_out": "conv_out",
}
name_list = sorted([name for name in state_dict])
rename_dict = {}
block_id = {"ResnetBlock": -1, "DownSampler": -1, "UpSampler": -1}
last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""}
for name in name_list:
names = name.split(".")
name_prefix = ".".join(names[:-1])
if name_prefix in local_rename_dict:
rename_dict[name] = local_rename_dict[name_prefix] + "." + names[-1]
elif name.startswith("encoder.down_blocks"):
block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]]
block_type_with_id = ".".join(names[:5])
if block_type_with_id != last_block_type_with_id[block_type]:
block_id[block_type] += 1
last_block_type_with_id[block_type] = block_type_with_id
while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
block_id[block_type] += 1
block_type_with_id = ".".join(names[:5])
names = ["blocks", str(block_id[block_type])] + names[5:]
rename_dict[name] = ".".join(names)
# Convert state_dict
state_dict_ = {}
for name, param in state_dict.items():
if name in rename_dict:
state_dict_[rename_dict[name]] = param
return state_dict_
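The loop above infers the flat block index by walking block_types with a per-type counter, since diffusers groups layers into down_blocks.{i}.resnets / downsamplers while this module keeps a single blocks list. A small sketch with a synthetic (hypothetical) state dict shows the resulting key mapping, assuming the converter class above:

import torch

converter = SDVAEEncoderStateDictConverter()
fake_state_dict = {
    "encoder.down_blocks.0.resnets.0.conv1.weight": torch.zeros(1),
    "encoder.down_blocks.0.resnets.1.conv1.weight": torch.zeros(1),
    "encoder.down_blocks.0.downsamplers.0.conv.weight": torch.zeros(1),
    "encoder.down_blocks.1.resnets.0.conv1.weight": torch.zeros(1),
}
print(list(converter.from_diffusers(fake_state_dict).keys()))
# ['blocks.0.conv1.weight', 'blocks.1.conv1.weight', 'blocks.2.conv.weight', 'blocks.3.conv1.weight']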
def from_civitai(self, state_dict):
rename_dict = {
"first_stage_model.encoder.conv_in.bias": "conv_in.bias",
"first_stage_model.encoder.conv_in.weight": "conv_in.weight",
"first_stage_model.encoder.conv_out.bias": "conv_out.bias",
"first_stage_model.encoder.conv_out.weight": "conv_out.weight",
"first_stage_model.encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
"first_stage_model.encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
"first_stage_model.encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
"first_stage_model.encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
"first_stage_model.encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
"first_stage_model.encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
"first_stage_model.encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
"first_stage_model.encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
"first_stage_model.encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
"first_stage_model.encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
"first_stage_model.encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
"first_stage_model.encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
"first_stage_model.encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
"first_stage_model.encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
"first_stage_model.encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
"first_stage_model.encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
"first_stage_model.encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
"first_stage_model.encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
"first_stage_model.encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
"first_stage_model.encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
"first_stage_model.encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
"first_stage_model.encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
"first_stage_model.encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
"first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
"first_stage_model.encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
"first_stage_model.encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
"first_stage_model.encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
"first_stage_model.encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
"first_stage_model.encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
"first_stage_model.encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
"first_stage_model.encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
"first_stage_model.encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
"first_stage_model.encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
"first_stage_model.encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
"first_stage_model.encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
"first_stage_model.encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
"first_stage_model.encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
"first_stage_model.encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
"first_stage_model.encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
"first_stage_model.encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
"first_stage_model.encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
"first_stage_model.encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
"first_stage_model.encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
"first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
"first_stage_model.encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
"first_stage_model.encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
"first_stage_model.encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
"first_stage_model.encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
"first_stage_model.encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
"first_stage_model.encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
"first_stage_model.encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
"first_stage_model.encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
"first_stage_model.encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
"first_stage_model.encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
"first_stage_model.encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
"first_stage_model.encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
"first_stage_model.encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
"first_stage_model.encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
"first_stage_model.encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
"first_stage_model.encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
"first_stage_model.encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
"first_stage_model.encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
"first_stage_model.encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
"first_stage_model.encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
"first_stage_model.encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
"first_stage_model.encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
"first_stage_model.encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
"first_stage_model.encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
"first_stage_model.encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
"first_stage_model.encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
"first_stage_model.encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
"first_stage_model.encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
"first_stage_model.encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
"first_stage_model.encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
"first_stage_model.encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
"first_stage_model.encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
"first_stage_model.encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
"first_stage_model.encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
"first_stage_model.encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
"first_stage_model.encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
"first_stage_model.encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
"first_stage_model.encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
"first_stage_model.encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
"first_stage_model.encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
"first_stage_model.encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
"first_stage_model.encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
"first_stage_model.encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
"first_stage_model.encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
"first_stage_model.encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
"first_stage_model.encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
"first_stage_model.encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
"first_stage_model.encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
"first_stage_model.encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
"first_stage_model.encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
"first_stage_model.encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
"first_stage_model.encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
"first_stage_model.encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
"first_stage_model.encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
"first_stage_model.encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
"first_stage_model.encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
"first_stage_model.encoder.norm_out.bias": "conv_norm_out.bias",
"first_stage_model.encoder.norm_out.weight": "conv_norm_out.weight",
"first_stage_model.quant_conv.bias": "quant_conv.bias",
"first_stage_model.quant_conv.weight": "quant_conv.weight",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if "transformer_blocks" in rename_dict[name]:
param = param.squeeze()
state_dict_[rename_dict[name]] = param
return state_dict_


@@ -0,0 +1,757 @@
import torch
from .sd_text_encoder import CLIPEncoderLayer
class SDXLTextEncoder(torch.nn.Module):
def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=11, encoder_intermediate_size=3072):
super().__init__()
# token_embedding
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
# position_embeds (This is a fixed tensor)
self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
# encoders
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
# attn_mask
self.attn_mask = self.attention_mask(max_position_embeddings)
# This text encoder is different from the one in Stable Diffusion 1.x:
# it does not include final_layer_norm.
def attention_mask(self, length):
mask = torch.empty(length, length)
mask.fill_(float("-inf"))
mask.triu_(1)
return mask
def forward(self, input_ids, clip_skip=1):
embeds = self.token_embedding(input_ids) + self.position_embeds
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
for encoder_id, encoder in enumerate(self.encoders):
embeds = encoder(embeds, attn_mask=attn_mask)
if encoder_id + clip_skip == len(self.encoders):
break
return embeds
def state_dict_converter(self):
return SDXLTextEncoderStateDictConverter()
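attention_mask above builds an additive causal mask: zeros on and below the diagonal, -inf strictly above, so each token can only attend to itself and earlier positions once the mask is added to the attention logits. A minimal sketch reproducing that construction for a sequence of length 3:

import torch

mask = torch.empty(3, 3)
mask.fill_(float("-inf"))
mask.triu_(1)
print(mask)
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])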
class SDXLTextEncoder2(torch.nn.Module):
def __init__(self, embed_dim=1280, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=32, encoder_intermediate_size=5120):
super().__init__()
# token_embedding
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
# position_embeds (This is a fixed tensor)
self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
# encoders
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=20, head_dim=64, use_quick_gelu=False) for _ in range(num_encoder_layers)])
# attn_mask
self.attn_mask = self.attention_mask(max_position_embeddings)
# final_layer_norm
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
# text_projection
self.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False)
def attention_mask(self, length):
mask = torch.empty(length, length)
mask.fill_(float("-inf"))
mask.triu_(1)
return mask
def forward(self, input_ids, clip_skip=2):
embeds = self.token_embedding(input_ids) + self.position_embeds
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
for encoder_id, encoder in enumerate(self.encoders):
embeds = encoder(embeds, attn_mask=attn_mask)
if encoder_id + clip_skip == len(self.encoders):
hidden_states = embeds
embeds = self.final_layer_norm(embeds)
pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
pooled_embeds = self.text_projection(pooled_embeds)
return pooled_embeds, hidden_states
def state_dict_converter(self):
return SDXLTextEncoder2StateDictConverter()
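The pooled embedding in SDXLTextEncoder2.forward is taken at the position of the largest token id in each sequence, which in the CLIP vocabulary (vocab_size 49408) is the end-of-text token, and is then passed through text_projection. A minimal sketch of that indexing, using hypothetical token ids:

import torch

input_ids = torch.tensor([[49406, 320, 1125, 49407, 0, 0]])   # <start>, two tokens, <end>, padding (hypothetical ids)
eot_position = input_ids.to(dtype=torch.int).argmax(dim=-1)
print(eot_position)                                           # tensor([3]): index of the <end> token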
class SDXLTextEncoderStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"text_model.embeddings.token_embedding.weight": "token_embedding.weight",
"text_model.embeddings.position_embedding.weight": "position_embeds",
"text_model.final_layer_norm.weight": "final_layer_norm.weight",
"text_model.final_layer_norm.bias": "final_layer_norm.bias"
}
attn_rename_dict = {
"self_attn.q_proj": "attn.to_q",
"self_attn.k_proj": "attn.to_k",
"self_attn.v_proj": "attn.to_v",
"self_attn.out_proj": "attn.to_out",
"layer_norm1": "layer_norm1",
"layer_norm2": "layer_norm2",
"mlp.fc1": "fc1",
"mlp.fc2": "fc2",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if name == "text_model.embeddings.position_embedding.weight":
param = param.reshape((1, param.shape[0], param.shape[1]))
state_dict_[rename_dict[name]] = param
elif name.startswith("text_model.encoder.layers."):
param = state_dict[name]
names = name.split(".")
layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
state_dict_[name_] = param
return state_dict_
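For the per-layer parameters, the converter splits the diffusers key, keeps the layer index, and maps the module path through attn_rename_dict. A short sketch of that renaming, assuming the dictionaries above:

name = "text_model.encoder.layers.5.self_attn.q_proj.weight"
names = name.split(".")
layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
# layer_id == "5", layer_type == "self_attn.q_proj", tail == "weight"
# joined as "encoders.5.attn.to_q.weight"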
def from_civitai(self, state_dict):
rename_dict = {
"conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight": "position_embeds",
"conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias",
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if name == "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight":
param = param.reshape((1, param.shape[0], param.shape[1]))
state_dict_[rename_dict[name]] = param
return state_dict_
class SDXLTextEncoder2StateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"text_model.embeddings.token_embedding.weight": "token_embedding.weight",
"text_model.embeddings.position_embedding.weight": "position_embeds",
"text_model.final_layer_norm.weight": "final_layer_norm.weight",
"text_model.final_layer_norm.bias": "final_layer_norm.bias",
"text_projection.weight": "text_projection.weight"
}
attn_rename_dict = {
"self_attn.q_proj": "attn.to_q",
"self_attn.k_proj": "attn.to_k",
"self_attn.v_proj": "attn.to_v",
"self_attn.out_proj": "attn.to_out",
"layer_norm1": "layer_norm1",
"layer_norm2": "layer_norm2",
"mlp.fc1": "fc1",
"mlp.fc2": "fc2",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if name == "text_model.embeddings.position_embedding.weight":
param = param.reshape((1, param.shape[0], param.shape[1]))
state_dict_[rename_dict[name]] = param
elif name.startswith("text_model.encoder.layers."):
param = state_dict[name]
names = name.split(".")
layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
state_dict_[name_] = param
return state_dict_
def from_civitai(self, state_dict):
rename_dict = {
"conditioner.embedders.1.model.ln_final.bias": "final_layer_norm.bias",
"conditioner.embedders.1.model.ln_final.weight": "final_layer_norm.weight",
"conditioner.embedders.1.model.positional_embedding": "position_embeds",
"conditioner.embedders.1.model.token_embedding.weight": "token_embedding.weight",
"conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias": ['encoders.0.attn.to_q.bias', 'encoders.0.attn.to_k.bias', 'encoders.0.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight": ['encoders.0.attn.to_q.weight', 'encoders.0.attn.to_k.weight', 'encoders.0.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.bias": "encoders.0.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.weight": "encoders.0.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias": "encoders.0.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.0.ln_1.weight": "encoders.0.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.0.ln_2.bias": "encoders.0.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.0.ln_2.weight": "encoders.0.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.bias": "encoders.0.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.weight": "encoders.0.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.bias": "encoders.0.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.weight": "encoders.0.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias": ['encoders.1.attn.to_q.bias', 'encoders.1.attn.to_k.bias', 'encoders.1.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight": ['encoders.1.attn.to_q.weight', 'encoders.1.attn.to_k.weight', 'encoders.1.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.bias": "encoders.1.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.weight": "encoders.1.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.1.ln_1.bias": "encoders.1.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.1.ln_1.weight": "encoders.1.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.1.ln_2.bias": "encoders.1.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.1.ln_2.weight": "encoders.1.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.bias": "encoders.1.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.weight": "encoders.1.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.bias": "encoders.1.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.weight": "encoders.1.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias": ['encoders.10.attn.to_q.bias', 'encoders.10.attn.to_k.bias', 'encoders.10.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight": ['encoders.10.attn.to_q.weight', 'encoders.10.attn.to_k.weight', 'encoders.10.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.bias": "encoders.10.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.weight": "encoders.10.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.10.ln_1.bias": "encoders.10.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.10.ln_1.weight": "encoders.10.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.10.ln_2.bias": "encoders.10.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.10.ln_2.weight": "encoders.10.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.bias": "encoders.10.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.weight": "encoders.10.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.bias": "encoders.10.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.weight": "encoders.10.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias": ['encoders.11.attn.to_q.bias', 'encoders.11.attn.to_k.bias', 'encoders.11.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight": ['encoders.11.attn.to_q.weight', 'encoders.11.attn.to_k.weight', 'encoders.11.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.bias": "encoders.11.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.weight": "encoders.11.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.11.ln_1.bias": "encoders.11.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.11.ln_1.weight": "encoders.11.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.11.ln_2.bias": "encoders.11.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.11.ln_2.weight": "encoders.11.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.bias": "encoders.11.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.weight": "encoders.11.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.bias": "encoders.11.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.weight": "encoders.11.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias": ['encoders.12.attn.to_q.bias', 'encoders.12.attn.to_k.bias', 'encoders.12.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight": ['encoders.12.attn.to_q.weight', 'encoders.12.attn.to_k.weight', 'encoders.12.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.bias": "encoders.12.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.weight": "encoders.12.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.12.ln_1.bias": "encoders.12.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.12.ln_1.weight": "encoders.12.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.12.ln_2.bias": "encoders.12.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.12.ln_2.weight": "encoders.12.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.bias": "encoders.12.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.weight": "encoders.12.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.bias": "encoders.12.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.weight": "encoders.12.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias": ['encoders.13.attn.to_q.bias', 'encoders.13.attn.to_k.bias', 'encoders.13.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight": ['encoders.13.attn.to_q.weight', 'encoders.13.attn.to_k.weight', 'encoders.13.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.bias": "encoders.13.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.weight": "encoders.13.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.13.ln_1.bias": "encoders.13.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.13.ln_1.weight": "encoders.13.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.13.ln_2.bias": "encoders.13.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.13.ln_2.weight": "encoders.13.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.bias": "encoders.13.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.weight": "encoders.13.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.bias": "encoders.13.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.weight": "encoders.13.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias": ['encoders.14.attn.to_q.bias', 'encoders.14.attn.to_k.bias', 'encoders.14.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight": ['encoders.14.attn.to_q.weight', 'encoders.14.attn.to_k.weight', 'encoders.14.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.bias": "encoders.14.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.weight": "encoders.14.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.14.ln_1.bias": "encoders.14.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.14.ln_1.weight": "encoders.14.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.14.ln_2.bias": "encoders.14.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.14.ln_2.weight": "encoders.14.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.bias": "encoders.14.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.weight": "encoders.14.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.bias": "encoders.14.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.weight": "encoders.14.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias": ['encoders.15.attn.to_q.bias', 'encoders.15.attn.to_k.bias', 'encoders.15.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight": ['encoders.15.attn.to_q.weight', 'encoders.15.attn.to_k.weight', 'encoders.15.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.bias": "encoders.15.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.weight": "encoders.15.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.15.ln_1.bias": "encoders.15.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.15.ln_1.weight": "encoders.15.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.15.ln_2.bias": "encoders.15.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.15.ln_2.weight": "encoders.15.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.bias": "encoders.15.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.weight": "encoders.15.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.bias": "encoders.15.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.weight": "encoders.15.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias": ['encoders.16.attn.to_q.bias', 'encoders.16.attn.to_k.bias', 'encoders.16.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight": ['encoders.16.attn.to_q.weight', 'encoders.16.attn.to_k.weight', 'encoders.16.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.bias": "encoders.16.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.weight": "encoders.16.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.16.ln_1.bias": "encoders.16.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.16.ln_1.weight": "encoders.16.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.16.ln_2.bias": "encoders.16.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.16.ln_2.weight": "encoders.16.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.bias": "encoders.16.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.weight": "encoders.16.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.bias": "encoders.16.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.weight": "encoders.16.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias": ['encoders.17.attn.to_q.bias', 'encoders.17.attn.to_k.bias', 'encoders.17.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight": ['encoders.17.attn.to_q.weight', 'encoders.17.attn.to_k.weight', 'encoders.17.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.bias": "encoders.17.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.weight": "encoders.17.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.17.ln_1.bias": "encoders.17.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.17.ln_1.weight": "encoders.17.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.17.ln_2.bias": "encoders.17.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.17.ln_2.weight": "encoders.17.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.bias": "encoders.17.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.weight": "encoders.17.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.bias": "encoders.17.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.weight": "encoders.17.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias": ['encoders.18.attn.to_q.bias', 'encoders.18.attn.to_k.bias', 'encoders.18.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight": ['encoders.18.attn.to_q.weight', 'encoders.18.attn.to_k.weight', 'encoders.18.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.bias": "encoders.18.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.weight": "encoders.18.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.18.ln_1.bias": "encoders.18.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.18.ln_1.weight": "encoders.18.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.18.ln_2.bias": "encoders.18.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.18.ln_2.weight": "encoders.18.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.bias": "encoders.18.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.weight": "encoders.18.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.bias": "encoders.18.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.weight": "encoders.18.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias": ['encoders.19.attn.to_q.bias', 'encoders.19.attn.to_k.bias', 'encoders.19.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight": ['encoders.19.attn.to_q.weight', 'encoders.19.attn.to_k.weight', 'encoders.19.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.bias": "encoders.19.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.weight": "encoders.19.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.19.ln_1.bias": "encoders.19.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.19.ln_1.weight": "encoders.19.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.19.ln_2.bias": "encoders.19.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.19.ln_2.weight": "encoders.19.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.bias": "encoders.19.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.weight": "encoders.19.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.bias": "encoders.19.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.weight": "encoders.19.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias": ['encoders.2.attn.to_q.bias', 'encoders.2.attn.to_k.bias', 'encoders.2.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight": ['encoders.2.attn.to_q.weight', 'encoders.2.attn.to_k.weight', 'encoders.2.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.bias": "encoders.2.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.weight": "encoders.2.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.2.ln_1.bias": "encoders.2.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.2.ln_1.weight": "encoders.2.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.2.ln_2.bias": "encoders.2.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.2.ln_2.weight": "encoders.2.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.bias": "encoders.2.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.weight": "encoders.2.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.bias": "encoders.2.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.weight": "encoders.2.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias": ['encoders.20.attn.to_q.bias', 'encoders.20.attn.to_k.bias', 'encoders.20.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight": ['encoders.20.attn.to_q.weight', 'encoders.20.attn.to_k.weight', 'encoders.20.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.bias": "encoders.20.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.weight": "encoders.20.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.20.ln_1.bias": "encoders.20.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.20.ln_1.weight": "encoders.20.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.20.ln_2.bias": "encoders.20.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.20.ln_2.weight": "encoders.20.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.bias": "encoders.20.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.weight": "encoders.20.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.bias": "encoders.20.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.weight": "encoders.20.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias": ['encoders.21.attn.to_q.bias', 'encoders.21.attn.to_k.bias', 'encoders.21.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight": ['encoders.21.attn.to_q.weight', 'encoders.21.attn.to_k.weight', 'encoders.21.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.bias": "encoders.21.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.weight": "encoders.21.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.21.ln_1.bias": "encoders.21.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.21.ln_1.weight": "encoders.21.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.21.ln_2.bias": "encoders.21.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.21.ln_2.weight": "encoders.21.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.bias": "encoders.21.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.weight": "encoders.21.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.bias": "encoders.21.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.weight": "encoders.21.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias": ['encoders.22.attn.to_q.bias', 'encoders.22.attn.to_k.bias', 'encoders.22.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight": ['encoders.22.attn.to_q.weight', 'encoders.22.attn.to_k.weight', 'encoders.22.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.bias": "encoders.22.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.weight": "encoders.22.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.22.ln_1.bias": "encoders.22.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.22.ln_1.weight": "encoders.22.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.22.ln_2.bias": "encoders.22.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.22.ln_2.weight": "encoders.22.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.bias": "encoders.22.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.weight": "encoders.22.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.bias": "encoders.22.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.weight": "encoders.22.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias": ['encoders.23.attn.to_q.bias', 'encoders.23.attn.to_k.bias', 'encoders.23.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight": ['encoders.23.attn.to_q.weight', 'encoders.23.attn.to_k.weight', 'encoders.23.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.bias": "encoders.23.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.weight": "encoders.23.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.23.ln_1.bias": "encoders.23.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.23.ln_1.weight": "encoders.23.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.23.ln_2.bias": "encoders.23.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.23.ln_2.weight": "encoders.23.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.bias": "encoders.23.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.weight": "encoders.23.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.bias": "encoders.23.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.weight": "encoders.23.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias": ['encoders.24.attn.to_q.bias', 'encoders.24.attn.to_k.bias', 'encoders.24.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight": ['encoders.24.attn.to_q.weight', 'encoders.24.attn.to_k.weight', 'encoders.24.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.bias": "encoders.24.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.weight": "encoders.24.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.24.ln_1.bias": "encoders.24.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.24.ln_1.weight": "encoders.24.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.24.ln_2.bias": "encoders.24.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.24.ln_2.weight": "encoders.24.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.bias": "encoders.24.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.weight": "encoders.24.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.bias": "encoders.24.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.weight": "encoders.24.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias": ['encoders.25.attn.to_q.bias', 'encoders.25.attn.to_k.bias', 'encoders.25.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight": ['encoders.25.attn.to_q.weight', 'encoders.25.attn.to_k.weight', 'encoders.25.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.bias": "encoders.25.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.weight": "encoders.25.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.25.ln_1.bias": "encoders.25.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.25.ln_1.weight": "encoders.25.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.25.ln_2.bias": "encoders.25.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.25.ln_2.weight": "encoders.25.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.bias": "encoders.25.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.weight": "encoders.25.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.bias": "encoders.25.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.weight": "encoders.25.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias": ['encoders.26.attn.to_q.bias', 'encoders.26.attn.to_k.bias', 'encoders.26.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight": ['encoders.26.attn.to_q.weight', 'encoders.26.attn.to_k.weight', 'encoders.26.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.bias": "encoders.26.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.weight": "encoders.26.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.26.ln_1.bias": "encoders.26.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.26.ln_1.weight": "encoders.26.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.26.ln_2.bias": "encoders.26.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.26.ln_2.weight": "encoders.26.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.bias": "encoders.26.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.weight": "encoders.26.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.bias": "encoders.26.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.weight": "encoders.26.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias": ['encoders.27.attn.to_q.bias', 'encoders.27.attn.to_k.bias', 'encoders.27.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight": ['encoders.27.attn.to_q.weight', 'encoders.27.attn.to_k.weight', 'encoders.27.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.bias": "encoders.27.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.weight": "encoders.27.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.27.ln_1.bias": "encoders.27.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.27.ln_1.weight": "encoders.27.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.27.ln_2.bias": "encoders.27.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.27.ln_2.weight": "encoders.27.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.bias": "encoders.27.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.weight": "encoders.27.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.bias": "encoders.27.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.weight": "encoders.27.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias": ['encoders.28.attn.to_q.bias', 'encoders.28.attn.to_k.bias', 'encoders.28.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight": ['encoders.28.attn.to_q.weight', 'encoders.28.attn.to_k.weight', 'encoders.28.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.bias": "encoders.28.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.weight": "encoders.28.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.28.ln_1.bias": "encoders.28.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.28.ln_1.weight": "encoders.28.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.28.ln_2.bias": "encoders.28.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.28.ln_2.weight": "encoders.28.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.bias": "encoders.28.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.weight": "encoders.28.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.bias": "encoders.28.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.weight": "encoders.28.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias": ['encoders.29.attn.to_q.bias', 'encoders.29.attn.to_k.bias', 'encoders.29.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight": ['encoders.29.attn.to_q.weight', 'encoders.29.attn.to_k.weight', 'encoders.29.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.bias": "encoders.29.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.weight": "encoders.29.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.29.ln_1.bias": "encoders.29.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.29.ln_1.weight": "encoders.29.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.29.ln_2.bias": "encoders.29.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.29.ln_2.weight": "encoders.29.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.bias": "encoders.29.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.weight": "encoders.29.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.bias": "encoders.29.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.weight": "encoders.29.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias": ['encoders.3.attn.to_q.bias', 'encoders.3.attn.to_k.bias', 'encoders.3.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight": ['encoders.3.attn.to_q.weight', 'encoders.3.attn.to_k.weight', 'encoders.3.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.bias": "encoders.3.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.weight": "encoders.3.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.3.ln_1.bias": "encoders.3.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.3.ln_1.weight": "encoders.3.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.3.ln_2.bias": "encoders.3.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.3.ln_2.weight": "encoders.3.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.bias": "encoders.3.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.weight": "encoders.3.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.bias": "encoders.3.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.weight": "encoders.3.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias": ['encoders.30.attn.to_q.bias', 'encoders.30.attn.to_k.bias', 'encoders.30.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight": ['encoders.30.attn.to_q.weight', 'encoders.30.attn.to_k.weight', 'encoders.30.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.bias": "encoders.30.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.weight": "encoders.30.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.30.ln_1.bias": "encoders.30.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.30.ln_1.weight": "encoders.30.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.30.ln_2.bias": "encoders.30.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.30.ln_2.weight": "encoders.30.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.bias": "encoders.30.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.weight": "encoders.30.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.bias": "encoders.30.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.weight": "encoders.30.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias": ['encoders.31.attn.to_q.bias', 'encoders.31.attn.to_k.bias', 'encoders.31.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight": ['encoders.31.attn.to_q.weight', 'encoders.31.attn.to_k.weight', 'encoders.31.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.bias": "encoders.31.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.weight": "encoders.31.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.31.ln_1.bias": "encoders.31.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.31.ln_1.weight": "encoders.31.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.31.ln_2.bias": "encoders.31.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.31.ln_2.weight": "encoders.31.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.bias": "encoders.31.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.weight": "encoders.31.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.bias": "encoders.31.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.weight": "encoders.31.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias": ['encoders.4.attn.to_q.bias', 'encoders.4.attn.to_k.bias', 'encoders.4.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight": ['encoders.4.attn.to_q.weight', 'encoders.4.attn.to_k.weight', 'encoders.4.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.bias": "encoders.4.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.weight": "encoders.4.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.4.ln_1.bias": "encoders.4.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.4.ln_1.weight": "encoders.4.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.4.ln_2.bias": "encoders.4.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.4.ln_2.weight": "encoders.4.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.bias": "encoders.4.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.weight": "encoders.4.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.bias": "encoders.4.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.weight": "encoders.4.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias": ['encoders.5.attn.to_q.bias', 'encoders.5.attn.to_k.bias', 'encoders.5.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight": ['encoders.5.attn.to_q.weight', 'encoders.5.attn.to_k.weight', 'encoders.5.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.bias": "encoders.5.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.weight": "encoders.5.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.5.ln_1.bias": "encoders.5.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.5.ln_1.weight": "encoders.5.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.5.ln_2.bias": "encoders.5.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.5.ln_2.weight": "encoders.5.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.bias": "encoders.5.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.weight": "encoders.5.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.bias": "encoders.5.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.weight": "encoders.5.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias": ['encoders.6.attn.to_q.bias', 'encoders.6.attn.to_k.bias', 'encoders.6.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight": ['encoders.6.attn.to_q.weight', 'encoders.6.attn.to_k.weight', 'encoders.6.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.bias": "encoders.6.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.weight": "encoders.6.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.6.ln_1.bias": "encoders.6.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.6.ln_1.weight": "encoders.6.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.6.ln_2.bias": "encoders.6.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.6.ln_2.weight": "encoders.6.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.bias": "encoders.6.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.weight": "encoders.6.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.bias": "encoders.6.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.weight": "encoders.6.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias": ['encoders.7.attn.to_q.bias', 'encoders.7.attn.to_k.bias', 'encoders.7.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight": ['encoders.7.attn.to_q.weight', 'encoders.7.attn.to_k.weight', 'encoders.7.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.bias": "encoders.7.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.weight": "encoders.7.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.7.ln_1.bias": "encoders.7.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.7.ln_1.weight": "encoders.7.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.7.ln_2.bias": "encoders.7.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.7.ln_2.weight": "encoders.7.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.bias": "encoders.7.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.weight": "encoders.7.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.bias": "encoders.7.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.weight": "encoders.7.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias": ['encoders.8.attn.to_q.bias', 'encoders.8.attn.to_k.bias', 'encoders.8.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight": ['encoders.8.attn.to_q.weight', 'encoders.8.attn.to_k.weight', 'encoders.8.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.bias": "encoders.8.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.weight": "encoders.8.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.8.ln_1.bias": "encoders.8.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.8.ln_1.weight": "encoders.8.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.8.ln_2.bias": "encoders.8.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.8.ln_2.weight": "encoders.8.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.bias": "encoders.8.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.weight": "encoders.8.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.bias": "encoders.8.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.weight": "encoders.8.fc2.weight",
"conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias": ['encoders.9.attn.to_q.bias', 'encoders.9.attn.to_k.bias', 'encoders.9.attn.to_v.bias'],
"conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight": ['encoders.9.attn.to_q.weight', 'encoders.9.attn.to_k.weight', 'encoders.9.attn.to_v.weight'],
"conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.bias": "encoders.9.attn.to_out.bias",
"conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.weight": "encoders.9.attn.to_out.weight",
"conditioner.embedders.1.model.transformer.resblocks.9.ln_1.bias": "encoders.9.layer_norm1.bias",
"conditioner.embedders.1.model.transformer.resblocks.9.ln_1.weight": "encoders.9.layer_norm1.weight",
"conditioner.embedders.1.model.transformer.resblocks.9.ln_2.bias": "encoders.9.layer_norm2.bias",
"conditioner.embedders.1.model.transformer.resblocks.9.ln_2.weight": "encoders.9.layer_norm2.weight",
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.bias": "encoders.9.fc1.bias",
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.weight": "encoders.9.fc1.weight",
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias": "encoders.9.fc2.bias",
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight": "encoders.9.fc2.weight",
"conditioner.embedders.1.model.text_projection": "text_projection.weight",
}
state_dict_ = {}
for name in state_dict:
    if name in rename_dict:
        param = state_dict[name]
        if name == "conditioner.embedders.1.model.positional_embedding":
            param = param.reshape((1, param.shape[0], param.shape[1]))
        elif name == "conditioner.embedders.1.model.text_projection":
            param = param.T
        if isinstance(rename_dict[name], str):
            state_dict_[rename_dict[name]] = param
        else:
            length = param.shape[0] // 3
            for i, rename in enumerate(rename_dict[name]):
                state_dict_[rename] = param[i*length: i*length+length]
return state_dict_
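The loop above splits OpenCLIP's fused attention projection into separate q/k/v tensors. A minimal standalone sketch of that slice (not part of the diff; the 1280-dimensional width of the SDXL second text encoder is an assumption here):

import torch

embed_dim = 1280  # assumed width of the second (OpenCLIP) text encoder
in_proj_weight = torch.randn(3 * embed_dim, embed_dim)  # fused [q; k; v] rows

length = in_proj_weight.shape[0] // 3
to_q = in_proj_weight[0 * length: 1 * length]
to_k = in_proj_weight[1 * length: 2 * length]
to_v = in_proj_weight[2 * length: 3 * length]
assert to_q.shape == to_k.shape == to_v.shape == (embed_dim, embed_dim)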

File diff suppressed because it is too large

View File

@@ -0,0 +1,15 @@
from .sd_vae_decoder import SDVAEDecoder, SDVAEDecoderStateDictConverter


class SDXLVAEDecoder(SDVAEDecoder):
    def __init__(self):
        super().__init__()
        self.scaling_factor = 0.13025

    def state_dict_converter(self):
        return SDXLVAEDecoderStateDictConverter()


class SDXLVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
    def __init__(self):
        super().__init__()

View File

@@ -0,0 +1,15 @@
from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder


class SDXLVAEEncoder(SDVAEEncoder):
    def __init__(self):
        super().__init__()
        self.scaling_factor = 0.13025

    def state_dict_converter(self):
        return SDXLVAEEncoderStateDictConverter()


class SDXLVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
    def __init__(self):
        super().__init__()

75
diffsynth/models/tiler.py Normal file
View File

@@ -0,0 +1,75 @@
import torch


class Tiler(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def mask(self, height, width, line_width):
        x = torch.arange(height).repeat(width, 1).T
        y = torch.arange(width).repeat(height, 1)
        mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values
        mask = (mask / line_width).clip(0, 1)
        return mask

    def forward(self, forward_fn, x, tile_size, tile_stride, batch_size=1, inter_device="cpu", inter_dtype=torch.float32):
        # Prepare
        device = x.device
        torch_dtype = x.dtype
        # tile
        b, c_in, h_in, w_in = x.shape
        x = x.to(device=inter_device, dtype=inter_dtype)
        fold_params = {
            "kernel_size": (tile_size, tile_size),
            "stride": (tile_stride, tile_stride)
        }
        unfold_operator = torch.nn.Unfold(**fold_params)
        x = unfold_operator(x)
        x = x.view((b, c_in, tile_size, tile_size, -1))
        # inference
        x_out_stack = []
        for tile_id in range(0, x.shape[-1], batch_size):
            # process input
            next_tile_id = min(tile_id + batch_size, x.shape[-1])
            x_in = x[:, :, :, :, tile_id: next_tile_id]
            x_in = x_in.to(device=device, dtype=torch_dtype)
            x_in = x_in.permute(4, 0, 1, 2, 3)
            x_in = x_in.view((x_in.shape[0]*x_in.shape[1], x_in.shape[2], x_in.shape[3], x_in.shape[4]))
            # process output
            x_out = forward_fn(x_in)
            x_out = x_out.view((next_tile_id - tile_id, b, x_out.shape[1], x_out.shape[2], x_out.shape[3]))
            x_out = x_out.permute(1, 2, 3, 4, 0)
            x_out = x_out.to(device=inter_device, dtype=inter_dtype)
            x_out_stack.append(x_out)
        x = torch.concat(x_out_stack, dim=-1)
        # untile
        in2out_scale = x.shape[2] / tile_size
        h_out, w_out = int(h_in * in2out_scale), int(w_in * in2out_scale)
        mask = self.mask(int(tile_size * in2out_scale), int(tile_size * in2out_scale), int(tile_stride * in2out_scale * 0.5))
        mask = mask.to(device=inter_device, dtype=inter_dtype)
        mask = mask.reshape((1, 1, mask.shape[0], mask.shape[1], 1))
        x = x * mask
        fold_params = {
            "kernel_size": (int(tile_size * in2out_scale), int(tile_size * in2out_scale)),
            "stride": (int(tile_stride * in2out_scale), int(tile_stride * in2out_scale))
        }
        fold_operator = torch.nn.Fold(output_size=(h_out, w_out), **fold_params)
        divisor = fold_operator(mask.repeat(1, 1, 1, 1, x.shape[-1]).view(b, -1, x.shape[-1]))
        x = x.view((b, -1, x.shape[-1]))
        x = fold_operator(x) / divisor
        x = x.to(device=device, dtype=torch_dtype)
        return x
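A hedged usage sketch of Tiler, not taken from the repository: a small convolution stands in for the VAE/UNet block that forward_fn would normally be, and the import path follows the module layout shown in this diff.

import torch
from diffsynth.models.tiler import Tiler

tiler = Tiler()
conv = torch.nn.Conv2d(4, 4, kernel_size=3, padding=1)  # stand-in for a real block

x = torch.randn(1, 4, 128, 128)
# 128 = 64 + 2*32, so 64x64 tiles with stride 32 cover the map exactly.
y_tiled = tiler(conv, x, tile_size=64, tile_stride=32, batch_size=4)
y_full = conv(x)
print(y_tiled.shape)                    # torch.Size([1, 4, 128, 128])
print((y_tiled - y_full).abs().max())   # small; differences concentrate near tile seams

The mask() ramp weights tile borders down to near zero, so overlapping tiles are blended linearly and folded back with a matching divisor instead of producing hard seams.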

View File

@@ -0,0 +1,2 @@
from .stable_diffusion import SDPipeline
from .stable_diffusion_xl import SDXLPipeline

View File

@@ -0,0 +1,75 @@
from ..models import ModelManager
from ..prompts import SDPrompter
from ..schedulers import EnhancedDDIMScheduler
import torch
from tqdm import tqdm
from PIL import Image
import numpy as np


class SDPipeline(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.scheduler = EnhancedDDIMScheduler()

    def preprocess_image(self, image):
        image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
        return image

    @torch.no_grad()
    def __call__(
        self,
        model_manager: ModelManager,
        prompter: SDPrompter,
        prompt,
        negative_prompt="",
        cfg_scale=7.5,
        clip_skip=1,
        init_image=None,
        denoising_strength=1.0,
        height=512,
        width=512,
        num_inference_steps=20,
        tiled=False,
        tile_size=64,
        tile_stride=32,
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        # Encode prompts
        prompt_emb = prompter.encode_prompt(model_manager.text_encoder, prompt, clip_skip=clip_skip, device=model_manager.device)
        negative_prompt_emb = prompter.encode_prompt(model_manager.text_encoder, negative_prompt, clip_skip=clip_skip, device=model_manager.device)
        # Prepare scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
        # Prepare latent tensors
        if init_image is not None:
            image = self.preprocess_image(init_image).to(device=model_manager.device, dtype=model_manager.torch_type)
            latents = model_manager.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
            noise = torch.randn((1, 4, height//8, width//8), device=model_manager.device, dtype=model_manager.torch_type)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
        else:
            latents = torch.randn((1, 4, height//8, width//8), device=model_manager.device, dtype=model_manager.torch_type)
        # Denoise
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = torch.IntTensor((timestep,))[0].to(model_manager.device)
            # Classifier-free guidance
            noise_pred_cond = model_manager.unet(latents, timestep, prompt_emb, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
            noise_pred_uncond = model_manager.unet(latents, timestep, negative_prompt_emb, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
            noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_cond - noise_pred_uncond)
            latents = self.scheduler.step(noise_pred, timestep, latents)
            if progress_bar_st is not None:
                progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
        # Decode image
        image = model_manager.vae_decoder(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
        image = image.cpu().permute(1, 2, 0).numpy()
        image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
        return image
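A hedged end-to-end sketch of how this pipeline is meant to be driven, assuming the ModelManager loading helpers introduced elsewhere in this commit; the checkpoint file name and output path are placeholders, not part of the diff.

from diffsynth.models import ModelManager
from diffsynth.prompts import SDPrompter
from diffsynth.pipelines import SDPipeline

model_manager = ModelManager()
model_manager.load_from_safetensors("models/stable_diffusion/your_checkpoint.safetensors")  # placeholder path
prompter = SDPrompter()
pipeline = SDPipeline()

image = pipeline(
    model_manager, prompter,
    prompt="a photo of a cat",
    negative_prompt="",
    num_inference_steps=20, height=512, width=512,
)
image.save("image.png")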

View File

@@ -0,0 +1,126 @@
from ..models import ModelManager
from ..prompts import SDXLPrompter
from ..schedulers import EnhancedDDIMScheduler
import torch
from tqdm import tqdm
from PIL import Image
import numpy as np


class SDXLPipeline(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.scheduler = EnhancedDDIMScheduler()

    def preprocess_image(self, image):
        image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
        return image

    @torch.no_grad()
    def __call__(
        self,
        model_manager: ModelManager,
        prompter: SDXLPrompter,
        prompt,
        negative_prompt="",
        cfg_scale=7.5,
        clip_skip=1,
        clip_skip_2=2,
        init_image=None,
        denoising_strength=1.0,
        refining_strength=0.0,
        height=1024,
        width=1024,
        num_inference_steps=20,
        tiled=False,
        tile_size=64,
        tile_stride=32,
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        # Encode prompts
        add_text_embeds, prompt_emb = prompter.encode_prompt(
            model_manager.text_encoder,
            model_manager.text_encoder_2,
            prompt,
            clip_skip=clip_skip, clip_skip_2=clip_skip_2,
            device=model_manager.device
        )
        if cfg_scale != 1.0:
            negative_add_text_embeds, negative_prompt_emb = prompter.encode_prompt(
                model_manager.text_encoder,
                model_manager.text_encoder_2,
                negative_prompt,
                clip_skip=clip_skip, clip_skip_2=clip_skip_2,
                device=model_manager.device
            )
        # Prepare scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
        # Prepare latent tensors
        if init_image is not None:
            image = self.preprocess_image(init_image).to(
                device=model_manager.device, dtype=model_manager.torch_type
            )
            latents = model_manager.vae_encoder(
                image.to(torch.float32),
                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
            )
            noise = torch.randn(
                (1, 4, height//8, width//8),
                device=model_manager.device, dtype=model_manager.torch_type
            )
            latents = self.scheduler.add_noise(
                latents.to(model_manager.torch_type),
                noise,
                timestep=self.scheduler.timesteps[0]
            )
        else:
            latents = torch.randn((1, 4, height//8, width//8), device=model_manager.device, dtype=model_manager.torch_type)
        # Prepare positional id
        add_time_id = torch.tensor([height, width, 0, 0, height, width], device=model_manager.device)
        # Denoise
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = torch.IntTensor((timestep,))[0].to(model_manager.device)
            # Classifier-free guidance
            if timestep >= 1000 * refining_strength:
                denoising_model = model_manager.unet
            else:
                denoising_model = model_manager.refiner
            if cfg_scale != 1.0:
                noise_pred_cond = denoising_model(
                    latents, timestep, prompt_emb,
                    add_time_id=add_time_id, add_text_embeds=add_text_embeds,
                    tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
                )
                noise_pred_uncond = denoising_model(
                    latents, timestep, negative_prompt_emb,
                    add_time_id=add_time_id, add_text_embeds=negative_add_text_embeds,
                    tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
                )
                noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_cond - noise_pred_uncond)
            else:
                noise_pred = denoising_model(
                    latents, timestep, prompt_emb,
                    add_time_id=add_time_id, add_text_embeds=add_text_embeds,
                    tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
                )
            latents = self.scheduler.step(noise_pred, timestep, latents)
            if progress_bar_st is not None:
                progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
        # Decode image
        latents = latents.to(torch.float32)
        image = model_manager.vae_decoder(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
        image = image.cpu().permute(1, 2, 0).numpy()
        image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
        return image
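The base/refiner hand-off above keys off the raw timestep: the scheduler counts down from 999 to 0, so the base UNet handles timesteps at or above 1000 * refining_strength and model_manager.refiner (which must be loaded whenever refining_strength > 0) takes the remaining steps. A small illustration of the split, assuming 20 inference steps:

from diffsynth.schedulers import EnhancedDDIMScheduler

scheduler = EnhancedDDIMScheduler()
scheduler.set_timesteps(20)
refining_strength = 0.2
base_steps = [t for t in scheduler.timesteps if t >= 1000 * refining_strength]
refiner_steps = [t for t in scheduler.timesteps if t < 1000 * refining_strength]
print(len(base_steps), len(refiner_steps))  # 16 4, roughly an 80/20 split

With the default refining_strength=0.0 every timestep satisfies the comparison, so the refiner is never touched.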

View File

@@ -0,0 +1,117 @@
from transformers import CLIPTokenizer
from ..models import SDTextEncoder, SDXLTextEncoder, SDXLTextEncoder2
import torch, os
from safetensors import safe_open


def tokenize_long_prompt(tokenizer, prompt):
    # Get model_max_length from the tokenizer.
    length = tokenizer.model_max_length
    # To avoid the truncation warning, temporarily set tokenizer.model_max_length to a huge value.
    tokenizer.model_max_length = 99999999
    # Tokenize it!
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    # Determine the real length.
    max_length = (input_ids.shape[1] + length - 1) // length * length
    # Restore tokenizer.model_max_length.
    tokenizer.model_max_length = length
    # Tokenize it again with fixed length.
    input_ids = tokenizer(
        prompt,
        return_tensors="pt",
        padding="max_length",
        max_length=max_length,
        truncation=True
    ).input_ids
    # Reshape input_ids to fit the text encoder.
    num_sentence = input_ids.shape[1] // length
    input_ids = input_ids.reshape((num_sentence, length))
    return input_ids
def load_textual_inversion(prompt):
    # TODO: This function is not enabled yet.
    textual_inversion_files = os.listdir("models/textual_inversion")
    embeddings_768 = []
    embeddings_1280 = []
    for file_name in textual_inversion_files:
        if not file_name.endswith(".safetensors"):
            continue
        keyword = file_name[:-len(".safetensors")]
        if keyword in prompt:
            prompt = prompt.replace(keyword, "")
            with safe_open(f"models/textual_inversion/{file_name}", framework="pt", device="cpu") as f:
                for k in f.keys():
                    embedding = f.get_tensor(k).to(torch.float32)
                    if embedding.shape[-1] == 768:
                        embeddings_768.append(embedding)
                    elif embedding.shape[-1] == 1280:
                        embeddings_1280.append(embedding)
    if len(embeddings_768) == 0:
        embeddings_768 = torch.zeros((0, 768))
    else:
        embeddings_768 = torch.concat(embeddings_768, dim=0)
    if len(embeddings_1280) == 0:
        embeddings_1280 = torch.zeros((0, 1280))
    else:
        embeddings_1280 = torch.concat(embeddings_1280, dim=0)
    return prompt, embeddings_768, embeddings_1280
class SDPrompter:
    def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"):
        # We use the tokenizer implemented by transformers
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)

    def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda"):
        input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
        prompt_emb = text_encoder(input_ids, clip_skip=clip_skip)
        prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
        return prompt_emb
class SDXLPrompter:
    def __init__(
        self,
        tokenizer_path="configs/stable_diffusion/tokenizer",
        tokenizer_2_path="configs/stable_diffusion_xl/tokenizer_2"
    ):
        # We use the tokenizer implemented by transformers
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
        self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)

    def encode_prompt(
        self,
        text_encoder: SDXLTextEncoder,
        text_encoder_2: SDXLTextEncoder2,
        prompt,
        clip_skip=1,
        clip_skip_2=2,
        device="cuda"
    ):
        # Text encoder 1
        input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
        prompt_emb_1 = text_encoder(input_ids, clip_skip=clip_skip)
        # Text encoder 2
        input_ids_2 = tokenize_long_prompt(self.tokenizer_2, prompt).to(device)
        add_text_embeds, prompt_emb_2 = text_encoder_2(input_ids_2, clip_skip=clip_skip_2)
        # Merge the embeddings of both text encoders.
        prompt_emb = torch.concatenate([prompt_emb_1, prompt_emb_2], dim=-1)
        # For very long prompts, we only use the first 77 tokens to compute `add_text_embeds`.
        add_text_embeds = add_text_embeds[0:1]
        prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
        return add_text_embeds, prompt_emb
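A rough illustration of the tokenize_long_prompt helper above, assuming the tokenizer config under configs/stable_diffusion/tokenizer is present and that the helper is importable from diffsynth.prompts (an assumption about the package exports): a prompt longer than 77 tokens is padded up to the next multiple of model_max_length and reshaped into several 77-token chunks, which the prompters then re-flatten after encoding.

from transformers import CLIPTokenizer
from diffsynth.prompts import tokenize_long_prompt  # assumed export

tokenizer = CLIPTokenizer.from_pretrained("configs/stable_diffusion/tokenizer")
long_prompt = ", ".join(["a highly detailed painting of a castle"] * 20)
input_ids = tokenize_long_prompt(tokenizer, long_prompt)
print(input_ids.shape)  # (num_sentence, 77): one row per 77-token chunk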

View File

@@ -0,0 +1,37 @@
from transformers import CLIPTokenizer


class SDTokenizer:
    def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"):
        # We use the tokenizer implemented by transformers
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)

    def __call__(self, prompt):
        # Get model_max_length from self.tokenizer.
        length = self.tokenizer.model_max_length
        # To avoid the truncation warning, temporarily set self.tokenizer.model_max_length to a huge value.
        self.tokenizer.model_max_length = 99999999
        # Tokenize it!
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
        # Determine the real length.
        max_length = (input_ids.shape[1] + length - 1) // length * length
        # Restore self.tokenizer.model_max_length.
        self.tokenizer.model_max_length = length
        # Tokenize it again with fixed length.
        input_ids = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding="max_length",
            max_length=max_length,
            truncation=True
        ).input_ids
        # Reshape input_ids to fit the text encoder.
        num_sentence = input_ids.shape[1] // length
        input_ids = input_ids.reshape((num_sentence, length))
        return input_ids

View File

@@ -0,0 +1,45 @@
from transformers import CLIPTokenizer
from .sd_tokenizer import SDTokenizer


class SDXLTokenizer(SDTokenizer):
    def __init__(self):
        super().__init__()


class SDXLTokenizer2:
    def __init__(self, tokenizer_path="configs/stable_diffusion_xl/tokenizer_2"):
        # We use the tokenizer implemented by transformers
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)

    def __call__(self, prompt):
        # Get model_max_length from self.tokenizer.
        length = self.tokenizer.model_max_length
        # To avoid the truncation warning, temporarily set self.tokenizer.model_max_length to a huge value.
        self.tokenizer.model_max_length = 99999999
        # Tokenize it!
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
        # Determine the real length.
        max_length = (input_ids.shape[1] + length - 1) // length * length
        # Restore self.tokenizer.model_max_length.
        self.tokenizer.model_max_length = length
        # Tokenize it again with fixed length.
        input_ids = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding="max_length",
            max_length=max_length,
            truncation=True
        ).input_ids
        # Reshape input_ids to fit the text encoder.
        num_sentence = input_ids.shape[1] // length
        input_ids = input_ids.reshape((num_sentence, length))
        return input_ids

View File

@@ -0,0 +1,60 @@
import torch, math


class EnhancedDDIMScheduler():
    def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012):
        self.num_train_timesteps = num_train_timesteps
        betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32))
        self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0).tolist()
        self.set_timesteps(10)

    def set_timesteps(self, num_inference_steps, denoising_strength=1.0):
        # The timesteps are aligned to 999...0, which is different from other implementations,
        # but I think this implementation is more reasonable in theory.
        max_timestep = round(self.num_train_timesteps * denoising_strength) - 1
        num_inference_steps = min(num_inference_steps, max_timestep + 1)
        if num_inference_steps == 1:
            self.timesteps = [max_timestep]
        else:
            step_length = max_timestep / (num_inference_steps - 1)
            self.timesteps = [round(max_timestep - i*step_length) for i in range(num_inference_steps)]

    def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
        # Deterministic DDIM update with epsilon prediction.
        weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t)
        weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t)
        prev_sample = sample * weight_x + model_output * weight_e
        return prev_sample

    def step(self, model_output, timestep, sample):
        alpha_prod_t = self.alphas_cumprod[timestep]
        timestep_id = self.timesteps.index(timestep)
        if timestep_id + 1 < len(self.timesteps):
            timestep_prev = self.timesteps[timestep_id + 1]
            alpha_prod_t_prev = self.alphas_cumprod[timestep_prev]
        else:
            alpha_prod_t_prev = 1.0
        return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev)

    def return_to_timestep(self, timestep, sample, sample_stablized):
        alpha_prod_t = self.alphas_cumprod[timestep]
        noise_pred = (sample - math.sqrt(alpha_prod_t) * sample_stablized) / math.sqrt(1 - alpha_prod_t)
        return noise_pred

    def add_noise(self, original_samples, noise, timestep):
        sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[timestep])
        sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[timestep])
        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples
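A quick numerical sanity check of the add_noise/step pair (not part of the commit): when the "model output" is exactly the noise that add_noise injected, a single DDIM step back to alpha_bar = 1 recovers the original sample, which follows directly from the weight_x/weight_e formula in denoise().

import torch
from diffsynth.schedulers import EnhancedDDIMScheduler

scheduler = EnhancedDDIMScheduler()
scheduler.set_timesteps(1)                     # timesteps == [999]
x0 = torch.randn(1, 4, 8, 8)
noise = torch.randn_like(x0)
xt = scheduler.add_noise(x0, noise, timestep=scheduler.timesteps[0])
x0_rec = scheduler.step(noise, scheduler.timesteps[0], xt)
print(torch.allclose(x0_rec, x0, atol=1e-4))   # True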