mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 23:58:12 +00:00
initial version
This commit is contained in:
126
diffsynth/models/__init__.py
Normal file
126
diffsynth/models/__init__.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import torch
|
||||
from safetensors import safe_open
|
||||
|
||||
from .sd_text_encoder import SDTextEncoder
|
||||
from .sd_unet import SDUNet
|
||||
from .sd_vae_encoder import SDVAEEncoder
|
||||
from .sd_vae_decoder import SDVAEDecoder
|
||||
|
||||
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
||||
from .sdxl_unet import SDXLUNet
|
||||
from .sdxl_vae_decoder import SDXLVAEDecoder
|
||||
from .sdxl_vae_encoder import SDXLVAEEncoder
|
||||
|
||||
|
||||
class ModelManager:
    """Loads Stable Diffusion / SDXL model components from a checkpoint and
    stores them in a name -> module mapping.

    Loaded components are reachable both through ``self.model[name]`` and as
    attributes (e.g. ``manager.unet``) via ``__getattr__``.
    """

    def __init__(self, torch_type=torch.float16, device="cuda"):
        # torch_type: dtype most loaded components are cast to after loading.
        # device: target device for every loaded component.
        self.torch_type = torch_type
        self.device = device
        self.model = {}

    def is_stabe_diffusion_xl(self, state_dict):
        """Return True if ``state_dict`` looks like an SDXL checkpoint.

        SDXL checkpoints carry a second text encoder under ``conditioner.*``;
        its position embedding is used as the fingerprint. (The method name
        keeps the original, misspelled public spelling for compatibility.)
        """
        param_name = "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight"
        return param_name in state_dict

    def is_stable_diffusion(self, state_dict):
        # Fallback detector: any checkpoint that is not SDXL is treated as SD 1.x.
        return True

    def load_stable_diffusion(self, state_dict, components=None):
        """Instantiate and load the requested SD 1.x components from ``state_dict``.

        components: list of component names to load; defaults to the four
        core SD components (refiner must be requested explicitly).
        """
        component_dict = {
            "text_encoder": SDTextEncoder,
            "unet": SDUNet,
            "vae_decoder": SDVAEDecoder,
            "vae_encoder": SDVAEEncoder,
            "refiner": SDXLUNet,
        }
        if components is None:
            components = ["text_encoder", "unet", "vae_decoder", "vae_encoder"]
        for component in components:
            self.model[component] = component_dict[component]()
            self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
            self.model[component].to(self.torch_type).to(self.device)

    def load_stable_diffusion_xl(self, state_dict, components=None):
        """Instantiate and load the requested SDXL components from ``state_dict``."""
        component_dict = {
            "text_encoder": SDXLTextEncoder,
            "text_encoder_2": SDXLTextEncoder2,
            "unet": SDXLUNet,
            "vae_decoder": SDXLVAEDecoder,
            "vae_encoder": SDXLVAEEncoder,
            "refiner": SDXLUNet,
        }
        if components is None:
            components = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder"]
        for component in components:
            self.model[component] = component_dict[component]()
            self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
            if component in ["vae_decoder", "vae_encoder"]:
                # These two models output NaN when float16 is enabled.
                # The precision problem happens in the last three resnet blocks,
                # so they are kept in float32 regardless of self.torch_type.
                self.model[component].to(torch.float32).to(self.device)
            else:
                self.model[component].to(self.torch_type).to(self.device)

    def load_from_safetensors(self, file_path, components=None):
        """Read a .safetensors checkpoint and dispatch to the matching loader."""
        state_dict = load_state_dict_from_safetensors(file_path)
        if self.is_stabe_diffusion_xl(state_dict):
            self.load_stable_diffusion_xl(state_dict, components=components)
        elif self.is_stable_diffusion(state_dict):
            self.load_stable_diffusion(state_dict, components=components)

    def to(self, device):
        """Move every loaded component to ``device``."""
        for component in self.model:
            self.model[component].to(device)

    def __getattr__(self, __name):
        # Called only when normal attribute lookup fails; expose loaded
        # components as attributes. Read "model" via __dict__ to avoid
        # infinite recursion when the mapping itself is not set yet.
        model = self.__dict__.get("model", {})
        if __name in model:
            return model[__name]
        # Bug fix: the original called ``super.__getattribute__(__name)`` on the
        # builtin ``super`` type itself, which raises TypeError instead of the
        # AttributeError that Python's attribute protocol expects here.
        return super().__getattribute__(__name)
|
||||
|
||||
|
||||
def load_state_dict_from_safetensors(file_path):
    """Read every tensor in a .safetensors file into a plain dict on CPU."""
    with safe_open(file_path, framework="pt", device="cpu") as archive:
        return {key: archive.get_tensor(key) for key in archive.keys()}
|
||||
|
||||
|
||||
def load_state_dict_from_bin(file_path):
    """Load a pickled PyTorch checkpoint (.bin / .pt) onto the CPU.

    NOTE(security): ``torch.load`` unpickles arbitrary objects; only call
    this on checkpoint files from a trusted source.
    """
    state_dict = torch.load(file_path, map_location="cpu")
    return state_dict
|
||||
|
||||
|
||||
def search_parameter(param, state_dict):
    """Return the name of the first entry of ``state_dict`` numerically equal
    to ``param``, or None if there is no match.

    Element counts must agree; when the shapes differ, both tensors are
    flattened before the distance comparison (tolerance 1e-6).
    """
    for candidate_name, candidate in state_dict.items():
        if candidate.numel() != param.numel():
            continue
        if param.shape == candidate.shape:
            is_match = torch.dist(param, candidate) < 1e-6
        else:
            is_match = torch.dist(param.flatten(), candidate.flatten()) < 1e-6
        if is_match:
            return candidate_name
    return None
|
||||
|
||||
|
||||
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
    """Print a source-name -> target-name mapping between two state dicts.

    For every source parameter, the numerically matching target parameter is
    located with ``search_parameter`` and printed as a dict entry. When
    ``split_qkv`` is True, a source tensor whose leading dimension is a
    multiple of 3 is additionally tried as a fused qkv tensor: each third is
    matched separately and printed as a list of three target names. Finally,
    every target parameter that was never matched is reported.
    """
    matched_keys = set()
    with torch.no_grad():
        for source_name, source_param in source_state_dict.items():
            target_name = search_parameter(source_param, target_state_dict)
            if target_name is not None:
                print(f'"{source_name}": "{target_name}",')
                matched_keys.add(target_name)
                continue
            # Optionally try to match a fused qkv tensor as three equal chunks.
            if split_qkv and len(source_param.shape) >= 1 and source_param.shape[0] % 3 == 0:
                chunk = source_param.shape[0] // 3
                pieces = [
                    search_parameter(source_param[i * chunk:(i + 1) * chunk], target_state_dict)
                    for i in range(3)
                ]
                if None not in pieces:
                    print(f'"{source_name}": {pieces},')
                    matched_keys.update(pieces)
    for target_name in target_state_dict:
        if target_name not in matched_keys:
            print("Cannot find", target_name, target_state_dict[target_name].shape)
|
||||
38
diffsynth/models/attention.py
Normal file
38
diffsynth/models/attention.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import torch
|
||||
|
||||
|
||||
class Attention(torch.nn.Module):
|
||||
|
||||
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
|
||||
super().__init__()
|
||||
dim_inner = head_dim * num_heads
|
||||
kv_dim = kv_dim if kv_dim is not None else q_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
|
||||
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
|
||||
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
batch_size = encoder_hidden_states.shape[0]
|
||||
|
||||
q = self.to_q(hidden_states)
|
||||
k = self.to_k(encoder_hidden_states)
|
||||
v = self.to_v(encoder_hidden_states)
|
||||
|
||||
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
hidden_states = hidden_states.transpose(1, 2).view(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
|
||||
hidden_states = self.to_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
320
diffsynth/models/sd_text_encoder.py
Normal file
320
diffsynth/models/sd_text_encoder.py
Normal file
@@ -0,0 +1,320 @@
|
||||
import torch
|
||||
from .attention import Attention
|
||||
|
||||
|
||||
class CLIPEncoderLayer(torch.nn.Module):
    """One pre-norm CLIP transformer block: self-attention and a two-layer
    MLP, each wrapped in a residual connection."""

    def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
        super().__init__()
        self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)
        self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
        self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
        self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
        self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)

        # Original CLIP weights were trained with the "quick" GELU
        # approximation; vanilla GELU is used otherwise.
        self.use_quick_gelu = use_quick_gelu

    def quickGELU(self, x):
        # Sigmoid-based approximation of GELU used by OpenAI's CLIP.
        return x * torch.sigmoid(1.702 * x)

    def forward(self, hidden_states, attn_mask):
        # Attention sub-block (pre-norm + residual).
        attn_out = self.attn(self.layer_norm1(hidden_states), attn_mask=attn_mask)
        hidden_states = hidden_states + attn_out

        # MLP sub-block (pre-norm + residual).
        mlp_out = self.fc1(self.layer_norm2(hidden_states))
        if self.use_quick_gelu:
            mlp_out = self.quickGELU(mlp_out)
        else:
            mlp_out = torch.nn.functional.gelu(mlp_out)
        mlp_out = self.fc2(mlp_out)

        return hidden_states + mlp_out
|
||||
|
||||
|
||||
class SDTextEncoder(torch.nn.Module):
    """CLIP-style text encoder used by Stable Diffusion 1.x.

    Maps token ids to per-token embeddings through a stack of causal
    transformer layers. ``clip_skip`` stops the stack early (``clip_skip=1``
    runs every layer); the final layer norm is applied in all cases.
    """

    def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
        super().__init__()

        # Token embedding table.
        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)

        # Position embeddings, shaped (1, max_positions, dim). NOTE(review):
        # registered as a Parameter although described as fixed upstream —
        # it is populated from the checkpoint via the state-dict converter.
        self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))

        # Transformer layer stack.
        self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])

        # Causal mask, built once; moved to the input's device/dtype per call.
        self.attn_mask = self.attention_mask(max_position_embeddings)

        # Final layer norm.
        self.final_layer_norm = torch.nn.LayerNorm(embed_dim)

    def attention_mask(self, length):
        """Build a (length, length) causal mask: -inf above the diagonal, 0 elsewhere."""
        mask = torch.full((length, length), float("-inf"))
        mask.triu_(1)
        return mask

    def forward(self, input_ids, clip_skip=1):
        embeds = self.token_embedding(input_ids) + self.position_embeds
        attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
        # Stop after layer (num_layers - clip_skip); clip_skip=1 runs them all.
        last_layer = len(self.encoders) - clip_skip
        for layer_id, layer in enumerate(self.encoders):
            embeds = layer(embeds, attn_mask=attn_mask)
            if layer_id == last_layer:
                break
        return self.final_layer_norm(embeds)

    def state_dict_converter(self):
        """Return the converter that renames external checkpoint keys to ours."""
        return SDTextEncoderStateDictConverter()
|
||||
|
||||
|
||||
class SDTextEncoderStateDictConverter:
    """Renames text-encoder weights from external checkpoint layouts
    (diffusers-style or civitai/original SD checkpoints) to the parameter
    names used by this package's SDTextEncoder.
    """

    # Per-layer sub-module renames shared by both checkpoint layouts.
    _LAYER_RENAME = {
        "self_attn.q_proj": "attn.to_q",
        "self_attn.k_proj": "attn.to_k",
        "self_attn.v_proj": "attn.to_v",
        "self_attn.out_proj": "attn.to_out",
        "layer_norm1": "layer_norm1",
        "layer_norm2": "layer_norm2",
        "mlp.fc1": "fc1",
        "mlp.fc2": "fc2",
    }

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Convert a diffusers-format CLIP text-encoder state dict.

        Unrecognized keys are silently dropped. The (77, dim) position
        embedding is reshaped to (1, 77, dim) to match ``position_embeds``.
        """
        rename_dict = {
            "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
            "text_model.embeddings.position_embedding.weight": "position_embeds",
            "text_model.final_layer_norm.weight": "final_layer_norm.weight",
            "text_model.final_layer_norm.bias": "final_layer_norm.bias",
        }
        state_dict_ = {}
        for name, param in state_dict.items():
            if name in rename_dict:
                if name == "text_model.embeddings.position_embedding.weight":
                    param = param.reshape((1, param.shape[0], param.shape[1]))
                state_dict_[rename_dict[name]] = param
            elif name.startswith("text_model.encoder.layers."):
                # e.g. text_model.encoder.layers.<i>.<sub.module>.<weight|bias>
                parts = name.split(".")
                layer_id, layer_type, tail = parts[3], ".".join(parts[4:-1]), parts[-1]
                name_ = ".".join(["encoders", layer_id, self._LAYER_RENAME[layer_type], tail])
                state_dict_[name_] = param
        return state_dict_

    def from_civitai(self, state_dict, num_layers=12):
        """Convert an original/civitai SD checkpoint's text-encoder weights.

        The original implementation spelled out all ~200 per-layer entries as
        a hand-written literal dict; the mapping is fully regular (layers
        ``0..num_layers-1`` x 8 sub-modules x weight/bias), so it is
        generated here instead. ``num_layers`` defaults to the 12 layers of
        the SD 1.x text encoder. Unrecognized keys are silently dropped.
        """
        prefix = "cond_stage_model.transformer.text_model."
        rename_dict = {
            prefix + "embeddings.token_embedding.weight": "token_embedding.weight",
            prefix + "embeddings.position_embedding.weight": "position_embeds",
            prefix + "final_layer_norm.weight": "final_layer_norm.weight",
            prefix + "final_layer_norm.bias": "final_layer_norm.bias",
        }
        for layer_id in range(num_layers):
            for src, dst in self._LAYER_RENAME.items():
                for tail in ("weight", "bias"):
                    source_name = f"{prefix}encoder.layers.{layer_id}.{src}.{tail}"
                    rename_dict[source_name] = f"encoders.{layer_id}.{dst}.{tail}"
        state_dict_ = {}
        for name, param in state_dict.items():
            if name not in rename_dict:
                continue
            if name == prefix + "embeddings.position_embedding.weight":
                # Stored as (77, dim); the model keeps it as (1, 77, dim).
                param = param.reshape((1, param.shape[0], param.shape[1]))
            state_dict_[rename_dict[name]] = param
        return state_dict_
|
||||
1090
diffsynth/models/sd_unet.py
Normal file
1090
diffsynth/models/sd_unet.py
Normal file
File diff suppressed because it is too large
Load Diff
330
diffsynth/models/sd_vae_decoder.py
Normal file
330
diffsynth/models/sd_vae_decoder.py
Normal file
@@ -0,0 +1,330 @@
|
||||
import torch
|
||||
from .attention import Attention
|
||||
from .sd_unet import ResnetBlock, UpSampler
|
||||
from .tiler import Tiler
|
||||
|
||||
|
||||
class VAEAttentionBlock(torch.nn.Module):
    """Spatial self-attention block used in the VAE mid-block.

    Normalizes the feature map, flattens it to a token sequence, runs it
    through one or more attention layers, and adds the result back to the
    input (residual connection). The extra (time_emb, text_emb, res_stack)
    arguments are passed through untouched so the block shares the common
    4-tuple interface of the other UNet/VAE blocks.
    """

    def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
        super().__init__()
        channels = num_attention_heads * attention_head_dim

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)

        attention_layers = [
            Attention(
                channels,
                num_attention_heads,
                attention_head_dim,
                bias_q=True,
                bias_kv=True,
                bias_out=True
            )
            for _ in range(num_layers)
        ]
        self.transformer_blocks = torch.nn.ModuleList(attention_layers)

    def forward(self, hidden_states, time_emb, text_emb, res_stack):
        shortcut = hidden_states
        batch, _, height, width = hidden_states.shape

        x = self.norm(hidden_states)
        channels = x.shape[1]
        # (B, C, H, W) -> (B, H*W, C): one token per spatial location.
        x = x.permute(0, 2, 3, 1).reshape(batch, height * width, channels)

        for attention_layer in self.transformer_blocks:
            x = attention_layer(x)

        # Back to (B, C, H, W) and apply the residual connection.
        x = x.reshape(batch, height, width, channels).permute(0, 3, 1, 2).contiguous()
        x = x + shortcut

        return x, time_emb, text_emb, res_stack
|
||||
|
||||
|
||||
class SDVAEDecoder(torch.nn.Module):
    """Decoder half of the Stable Diffusion 1.x VAE.

    Maps a scaled 4-channel latent (B, 4, h, w) to an RGB image
    (B, 3, 8*h, 8*w) via a mid-block followed by four up-decoder stages.
    """

    def __init__(self):
        super().__init__()
        # Incoming latents are stored pre-multiplied by this factor;
        # decoding divides it out first.
        self.scaling_factor = 0.18215
        self.post_quant_conv = torch.nn.Conv2d(4, 4, kernel_size=1)
        self.conv_in = torch.nn.Conv2d(4, 512, kernel_size=3, padding=1)

        self.blocks = torch.nn.ModuleList([
            # UNetMidBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            # UpDecoderBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            UpSampler(512),
            # UpDecoderBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            UpSampler(512),
            # UpDecoderBlock2D
            ResnetBlock(512, 256, eps=1e-6),
            ResnetBlock(256, 256, eps=1e-6),
            ResnetBlock(256, 256, eps=1e-6),
            UpSampler(256),
            # UpDecoderBlock2D
            ResnetBlock(256, 128, eps=1e-6),
            ResnetBlock(128, 128, eps=1e-6),
            ResnetBlock(128, 128, eps=1e-6),
        ])

        self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-5)
        self.conv_act = torch.nn.SiLU()
        self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)

    def tiled_forward(self, sample, tile_size=64, tile_stride=32):
        """Decode `sample` tile by tile to bound peak memory usage."""
        hidden_states = Tiler()(
            lambda x: self.forward(x),
            sample,
            tile_size,
            tile_stride
        )
        return hidden_states

    def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
        """Decode a latent batch to images.

        Args:
            sample: scaled latent tensor, 4 channels.
            tiled: if True, decode through the tiler instead of in one pass.
            tile_size, tile_stride: tiling parameters (used only when tiled).

        Returns:
            Decoded image tensor with 3 channels.
        """
        # For the VAE decoder the tiler wraps the whole forward pass;
        # it does not need to be applied per layer.
        if tiled:
            return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)

        # 1. pre-process: undo latent scaling, then lift to 512 channels.
        sample = sample / self.scaling_factor
        hidden_states = self.post_quant_conv(sample)
        hidden_states = self.conv_in(hidden_states)
        time_emb = None
        text_emb = None
        res_stack = None

        # 2. blocks (the index from the original enumerate() was unused)
        for block in self.blocks:
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)

        # 3. output head
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states

    def state_dict_converter(self):
        """Return the converter that adapts external checkpoints to this module."""
        return SDVAEDecoderStateDictConverter()
|
||||
|
||||
|
||||
class SDVAEDecoderStateDictConverter:
    """Renames VAE-decoder parameters from external checkpoint layouts
    (HuggingFace diffusers, civitai / original-LDM) to `SDVAEDecoder` names.

    Parameters whose names match neither rename rule are silently dropped.
    """

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Rename diffusers-format decoder parameters to this module's names.

        Mid-block and head/tail parameters use the explicit table below;
        up-block parameters are renamed by walking `block_types`, the
        flattened layout of `SDVAEDecoder.blocks`.
        """
        # architecture: flattened order of SDVAEDecoder.blocks
        block_types = [
            'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock',
            'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
            'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
            'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
            'ResnetBlock', 'ResnetBlock', 'ResnetBlock'
        ]

        # Rename each parameter
        local_rename_dict = {
            "post_quant_conv": "post_quant_conv",
            "decoder.conv_in": "conv_in",
            "decoder.mid_block.attentions.0.group_norm": "blocks.1.norm",
            "decoder.mid_block.attentions.0.to_q": "blocks.1.transformer_blocks.0.to_q",
            "decoder.mid_block.attentions.0.to_k": "blocks.1.transformer_blocks.0.to_k",
            "decoder.mid_block.attentions.0.to_v": "blocks.1.transformer_blocks.0.to_v",
            "decoder.mid_block.attentions.0.to_out.0": "blocks.1.transformer_blocks.0.to_out",
            "decoder.mid_block.resnets.0.norm1": "blocks.0.norm1",
            "decoder.mid_block.resnets.0.conv1": "blocks.0.conv1",
            "decoder.mid_block.resnets.0.norm2": "blocks.0.norm2",
            "decoder.mid_block.resnets.0.conv2": "blocks.0.conv2",
            "decoder.mid_block.resnets.1.norm1": "blocks.2.norm1",
            "decoder.mid_block.resnets.1.conv1": "blocks.2.conv1",
            "decoder.mid_block.resnets.1.norm2": "blocks.2.norm2",
            "decoder.mid_block.resnets.1.conv2": "blocks.2.conv2",
            "decoder.conv_norm_out": "conv_norm_out",
            "decoder.conv_out": "conv_out",
        }
        name_list = sorted([name for name in state_dict])
        rename_dict = {}
        # Counters start at 2: indices 0-2 are the mid-block, already covered
        # by the explicit table, so the first up-block slot found is index 3.
        block_id = {"ResnetBlock": 2, "DownSampler": 2, "UpSampler": 2}
        last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""}
        for name in name_list:
            names = name.split(".")
            name_prefix = ".".join(names[:-1])
            if name_prefix in local_rename_dict:
                rename_dict[name] = local_rename_dict[name_prefix] + "." + names[-1]
            elif name.startswith("decoder.up_blocks"):
                # names[3] distinguishes resnet / downsampler / upsampler entries.
                block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]]
                block_type_with_id = ".".join(names[:5])
                # A new sub-block id means we moved to the next block of this type.
                if block_type_with_id != last_block_type_with_id[block_type]:
                    block_id[block_type] += 1
                    last_block_type_with_id[block_type] = block_type_with_id
                # Advance the counter to the next slot of this type in the
                # flattened layout (no-op once it already points at one).
                while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
                    block_id[block_type] += 1
                block_type_with_id = ".".join(names[:5])  # NOTE(review): recomputes the same value — appears redundant
                names = ["blocks", str(block_id[block_type])] + names[5:]
                rename_dict[name] = ".".join(names)

        # Convert state_dict
        state_dict_ = {}
        for name, param in state_dict.items():
            if name in rename_dict:
                state_dict_[rename_dict[name]] = param
        return state_dict_

    def from_civitai(self, state_dict):
        """Rename civitai / original-LDM decoder parameters via a fixed table.

        Attention q/k/v/out weights in this layout carry extra singleton
        dims; they are squeezed so they load into the linear projections.
        """
        rename_dict = {
            "first_stage_model.decoder.conv_in.bias": "conv_in.bias",
            "first_stage_model.decoder.conv_in.weight": "conv_in.weight",
            "first_stage_model.decoder.conv_out.bias": "conv_out.bias",
            "first_stage_model.decoder.conv_out.weight": "conv_out.weight",
            "first_stage_model.decoder.mid.attn_1.k.bias": "blocks.1.transformer_blocks.0.to_k.bias",
            "first_stage_model.decoder.mid.attn_1.k.weight": "blocks.1.transformer_blocks.0.to_k.weight",
            "first_stage_model.decoder.mid.attn_1.norm.bias": "blocks.1.norm.bias",
            "first_stage_model.decoder.mid.attn_1.norm.weight": "blocks.1.norm.weight",
            "first_stage_model.decoder.mid.attn_1.proj_out.bias": "blocks.1.transformer_blocks.0.to_out.bias",
            "first_stage_model.decoder.mid.attn_1.proj_out.weight": "blocks.1.transformer_blocks.0.to_out.weight",
            "first_stage_model.decoder.mid.attn_1.q.bias": "blocks.1.transformer_blocks.0.to_q.bias",
            "first_stage_model.decoder.mid.attn_1.q.weight": "blocks.1.transformer_blocks.0.to_q.weight",
            "first_stage_model.decoder.mid.attn_1.v.bias": "blocks.1.transformer_blocks.0.to_v.bias",
            "first_stage_model.decoder.mid.attn_1.v.weight": "blocks.1.transformer_blocks.0.to_v.weight",
            "first_stage_model.decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
            "first_stage_model.decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
            "first_stage_model.decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
            "first_stage_model.decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
            "first_stage_model.decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
            "first_stage_model.decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
            "first_stage_model.decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
            "first_stage_model.decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
            "first_stage_model.decoder.mid.block_2.conv1.bias": "blocks.2.conv1.bias",
            "first_stage_model.decoder.mid.block_2.conv1.weight": "blocks.2.conv1.weight",
            "first_stage_model.decoder.mid.block_2.conv2.bias": "blocks.2.conv2.bias",
            "first_stage_model.decoder.mid.block_2.conv2.weight": "blocks.2.conv2.weight",
            "first_stage_model.decoder.mid.block_2.norm1.bias": "blocks.2.norm1.bias",
            "first_stage_model.decoder.mid.block_2.norm1.weight": "blocks.2.norm1.weight",
            "first_stage_model.decoder.mid.block_2.norm2.bias": "blocks.2.norm2.bias",
            "first_stage_model.decoder.mid.block_2.norm2.weight": "blocks.2.norm2.weight",
            "first_stage_model.decoder.norm_out.bias": "conv_norm_out.bias",
            "first_stage_model.decoder.norm_out.weight": "conv_norm_out.weight",
            # NOTE: civitai's up.{i} indices run in reverse relative to
            # SDVAEDecoder.blocks (up.0 is the last, lowest-channel stage).
            "first_stage_model.decoder.up.0.block.0.conv1.bias": "blocks.15.conv1.bias",
            "first_stage_model.decoder.up.0.block.0.conv1.weight": "blocks.15.conv1.weight",
            "first_stage_model.decoder.up.0.block.0.conv2.bias": "blocks.15.conv2.bias",
            "first_stage_model.decoder.up.0.block.0.conv2.weight": "blocks.15.conv2.weight",
            "first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "blocks.15.conv_shortcut.bias",
            "first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "blocks.15.conv_shortcut.weight",
            "first_stage_model.decoder.up.0.block.0.norm1.bias": "blocks.15.norm1.bias",
            "first_stage_model.decoder.up.0.block.0.norm1.weight": "blocks.15.norm1.weight",
            "first_stage_model.decoder.up.0.block.0.norm2.bias": "blocks.15.norm2.bias",
            "first_stage_model.decoder.up.0.block.0.norm2.weight": "blocks.15.norm2.weight",
            "first_stage_model.decoder.up.0.block.1.conv1.bias": "blocks.16.conv1.bias",
            "first_stage_model.decoder.up.0.block.1.conv1.weight": "blocks.16.conv1.weight",
            "first_stage_model.decoder.up.0.block.1.conv2.bias": "blocks.16.conv2.bias",
            "first_stage_model.decoder.up.0.block.1.conv2.weight": "blocks.16.conv2.weight",
            "first_stage_model.decoder.up.0.block.1.norm1.bias": "blocks.16.norm1.bias",
            "first_stage_model.decoder.up.0.block.1.norm1.weight": "blocks.16.norm1.weight",
            "first_stage_model.decoder.up.0.block.1.norm2.bias": "blocks.16.norm2.bias",
            "first_stage_model.decoder.up.0.block.1.norm2.weight": "blocks.16.norm2.weight",
            "first_stage_model.decoder.up.0.block.2.conv1.bias": "blocks.17.conv1.bias",
            "first_stage_model.decoder.up.0.block.2.conv1.weight": "blocks.17.conv1.weight",
            "first_stage_model.decoder.up.0.block.2.conv2.bias": "blocks.17.conv2.bias",
            "first_stage_model.decoder.up.0.block.2.conv2.weight": "blocks.17.conv2.weight",
            "first_stage_model.decoder.up.0.block.2.norm1.bias": "blocks.17.norm1.bias",
            "first_stage_model.decoder.up.0.block.2.norm1.weight": "blocks.17.norm1.weight",
            "first_stage_model.decoder.up.0.block.2.norm2.bias": "blocks.17.norm2.bias",
            "first_stage_model.decoder.up.0.block.2.norm2.weight": "blocks.17.norm2.weight",
            "first_stage_model.decoder.up.1.block.0.conv1.bias": "blocks.11.conv1.bias",
            "first_stage_model.decoder.up.1.block.0.conv1.weight": "blocks.11.conv1.weight",
            "first_stage_model.decoder.up.1.block.0.conv2.bias": "blocks.11.conv2.bias",
            "first_stage_model.decoder.up.1.block.0.conv2.weight": "blocks.11.conv2.weight",
            "first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "blocks.11.conv_shortcut.bias",
            "first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "blocks.11.conv_shortcut.weight",
            "first_stage_model.decoder.up.1.block.0.norm1.bias": "blocks.11.norm1.bias",
            "first_stage_model.decoder.up.1.block.0.norm1.weight": "blocks.11.norm1.weight",
            "first_stage_model.decoder.up.1.block.0.norm2.bias": "blocks.11.norm2.bias",
            "first_stage_model.decoder.up.1.block.0.norm2.weight": "blocks.11.norm2.weight",
            "first_stage_model.decoder.up.1.block.1.conv1.bias": "blocks.12.conv1.bias",
            "first_stage_model.decoder.up.1.block.1.conv1.weight": "blocks.12.conv1.weight",
            "first_stage_model.decoder.up.1.block.1.conv2.bias": "blocks.12.conv2.bias",
            "first_stage_model.decoder.up.1.block.1.conv2.weight": "blocks.12.conv2.weight",
            "first_stage_model.decoder.up.1.block.1.norm1.bias": "blocks.12.norm1.bias",
            "first_stage_model.decoder.up.1.block.1.norm1.weight": "blocks.12.norm1.weight",
            "first_stage_model.decoder.up.1.block.1.norm2.bias": "blocks.12.norm2.bias",
            "first_stage_model.decoder.up.1.block.1.norm2.weight": "blocks.12.norm2.weight",
            "first_stage_model.decoder.up.1.block.2.conv1.bias": "blocks.13.conv1.bias",
            "first_stage_model.decoder.up.1.block.2.conv1.weight": "blocks.13.conv1.weight",
            "first_stage_model.decoder.up.1.block.2.conv2.bias": "blocks.13.conv2.bias",
            "first_stage_model.decoder.up.1.block.2.conv2.weight": "blocks.13.conv2.weight",
            "first_stage_model.decoder.up.1.block.2.norm1.bias": "blocks.13.norm1.bias",
            "first_stage_model.decoder.up.1.block.2.norm1.weight": "blocks.13.norm1.weight",
            "first_stage_model.decoder.up.1.block.2.norm2.bias": "blocks.13.norm2.bias",
            "first_stage_model.decoder.up.1.block.2.norm2.weight": "blocks.13.norm2.weight",
            "first_stage_model.decoder.up.1.upsample.conv.bias": "blocks.14.conv.bias",
            "first_stage_model.decoder.up.1.upsample.conv.weight": "blocks.14.conv.weight",
            "first_stage_model.decoder.up.2.block.0.conv1.bias": "blocks.7.conv1.bias",
            "first_stage_model.decoder.up.2.block.0.conv1.weight": "blocks.7.conv1.weight",
            "first_stage_model.decoder.up.2.block.0.conv2.bias": "blocks.7.conv2.bias",
            "first_stage_model.decoder.up.2.block.0.conv2.weight": "blocks.7.conv2.weight",
            "first_stage_model.decoder.up.2.block.0.norm1.bias": "blocks.7.norm1.bias",
            "first_stage_model.decoder.up.2.block.0.norm1.weight": "blocks.7.norm1.weight",
            "first_stage_model.decoder.up.2.block.0.norm2.bias": "blocks.7.norm2.bias",
            "first_stage_model.decoder.up.2.block.0.norm2.weight": "blocks.7.norm2.weight",
            "first_stage_model.decoder.up.2.block.1.conv1.bias": "blocks.8.conv1.bias",
            "first_stage_model.decoder.up.2.block.1.conv1.weight": "blocks.8.conv1.weight",
            "first_stage_model.decoder.up.2.block.1.conv2.bias": "blocks.8.conv2.bias",
            "first_stage_model.decoder.up.2.block.1.conv2.weight": "blocks.8.conv2.weight",
            "first_stage_model.decoder.up.2.block.1.norm1.bias": "blocks.8.norm1.bias",
            "first_stage_model.decoder.up.2.block.1.norm1.weight": "blocks.8.norm1.weight",
            "first_stage_model.decoder.up.2.block.1.norm2.bias": "blocks.8.norm2.bias",
            "first_stage_model.decoder.up.2.block.1.norm2.weight": "blocks.8.norm2.weight",
            "first_stage_model.decoder.up.2.block.2.conv1.bias": "blocks.9.conv1.bias",
            "first_stage_model.decoder.up.2.block.2.conv1.weight": "blocks.9.conv1.weight",
            "first_stage_model.decoder.up.2.block.2.conv2.bias": "blocks.9.conv2.bias",
            "first_stage_model.decoder.up.2.block.2.conv2.weight": "blocks.9.conv2.weight",
            "first_stage_model.decoder.up.2.block.2.norm1.bias": "blocks.9.norm1.bias",
            "first_stage_model.decoder.up.2.block.2.norm1.weight": "blocks.9.norm1.weight",
            "first_stage_model.decoder.up.2.block.2.norm2.bias": "blocks.9.norm2.bias",
            "first_stage_model.decoder.up.2.block.2.norm2.weight": "blocks.9.norm2.weight",
            "first_stage_model.decoder.up.2.upsample.conv.bias": "blocks.10.conv.bias",
            "first_stage_model.decoder.up.2.upsample.conv.weight": "blocks.10.conv.weight",
            "first_stage_model.decoder.up.3.block.0.conv1.bias": "blocks.3.conv1.bias",
            "first_stage_model.decoder.up.3.block.0.conv1.weight": "blocks.3.conv1.weight",
            "first_stage_model.decoder.up.3.block.0.conv2.bias": "blocks.3.conv2.bias",
            "first_stage_model.decoder.up.3.block.0.conv2.weight": "blocks.3.conv2.weight",
            "first_stage_model.decoder.up.3.block.0.norm1.bias": "blocks.3.norm1.bias",
            "first_stage_model.decoder.up.3.block.0.norm1.weight": "blocks.3.norm1.weight",
            "first_stage_model.decoder.up.3.block.0.norm2.bias": "blocks.3.norm2.bias",
            "first_stage_model.decoder.up.3.block.0.norm2.weight": "blocks.3.norm2.weight",
            "first_stage_model.decoder.up.3.block.1.conv1.bias": "blocks.4.conv1.bias",
            "first_stage_model.decoder.up.3.block.1.conv1.weight": "blocks.4.conv1.weight",
            "first_stage_model.decoder.up.3.block.1.conv2.bias": "blocks.4.conv2.bias",
            "first_stage_model.decoder.up.3.block.1.conv2.weight": "blocks.4.conv2.weight",
            "first_stage_model.decoder.up.3.block.1.norm1.bias": "blocks.4.norm1.bias",
            "first_stage_model.decoder.up.3.block.1.norm1.weight": "blocks.4.norm1.weight",
            "first_stage_model.decoder.up.3.block.1.norm2.bias": "blocks.4.norm2.bias",
            "first_stage_model.decoder.up.3.block.1.norm2.weight": "blocks.4.norm2.weight",
            "first_stage_model.decoder.up.3.block.2.conv1.bias": "blocks.5.conv1.bias",
            "first_stage_model.decoder.up.3.block.2.conv1.weight": "blocks.5.conv1.weight",
            "first_stage_model.decoder.up.3.block.2.conv2.bias": "blocks.5.conv2.bias",
            "first_stage_model.decoder.up.3.block.2.conv2.weight": "blocks.5.conv2.weight",
            "first_stage_model.decoder.up.3.block.2.norm1.bias": "blocks.5.norm1.bias",
            "first_stage_model.decoder.up.3.block.2.norm1.weight": "blocks.5.norm1.weight",
            "first_stage_model.decoder.up.3.block.2.norm2.bias": "blocks.5.norm2.bias",
            "first_stage_model.decoder.up.3.block.2.norm2.weight": "blocks.5.norm2.weight",
            "first_stage_model.decoder.up.3.upsample.conv.bias": "blocks.6.conv.bias",
            "first_stage_model.decoder.up.3.upsample.conv.weight": "blocks.6.conv.weight",
            "first_stage_model.post_quant_conv.bias": "post_quant_conv.bias",
            "first_stage_model.post_quant_conv.weight": "post_quant_conv.weight",
        }
        state_dict_ = {}
        for name in state_dict:
            if name in rename_dict:
                param = state_dict[name]
                # Attention projections appear to be stored with extra
                # singleton dims in this layout; squeeze them to match the
                # linear layers in VAEAttentionBlock.
                if "transformer_blocks" in rename_dict[name]:
                    param = param.squeeze()
                state_dict_[rename_dict[name]] = param
        return state_dict_
|
||||
258
diffsynth/models/sd_vae_encoder.py
Normal file
258
diffsynth/models/sd_vae_encoder.py
Normal file
@@ -0,0 +1,258 @@
|
||||
import torch
|
||||
from .sd_unet import ResnetBlock, DownSampler
|
||||
from .sd_vae_decoder import VAEAttentionBlock
|
||||
from .tiler import Tiler
|
||||
|
||||
|
||||
class SDVAEEncoder(torch.nn.Module):
    """Encoder half of the Stable Diffusion 1.x VAE.

    Maps an RGB image (B, 3, H, W) to a scaled 4-channel latent
    (B, 4, H/8, W/8).
    """

    def __init__(self):
        super().__init__()
        # Latents are multiplied by this factor before being returned,
        # matching the scaling SDVAEDecoder divides out.
        self.scaling_factor = 0.18215
        self.quant_conv = torch.nn.Conv2d(8, 8, kernel_size=1)
        self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)

        self.blocks = torch.nn.ModuleList([
            # DownEncoderBlock2D
            ResnetBlock(128, 128, eps=1e-6),
            ResnetBlock(128, 128, eps=1e-6),
            DownSampler(128, padding=0, extra_padding=True),
            # DownEncoderBlock2D
            ResnetBlock(128, 256, eps=1e-6),
            ResnetBlock(256, 256, eps=1e-6),
            DownSampler(256, padding=0, extra_padding=True),
            # DownEncoderBlock2D
            ResnetBlock(256, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            DownSampler(512, padding=0, extra_padding=True),
            # DownEncoderBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            # UNetMidBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
        ])

        self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
        self.conv_act = torch.nn.SiLU()
        self.conv_out = torch.nn.Conv2d(512, 8, kernel_size=3, padding=1)

    def tiled_forward(self, sample, tile_size=64, tile_stride=32):
        """Encode `sample` tile by tile to bound peak memory usage."""
        hidden_states = Tiler()(
            lambda x: self.forward(x),
            sample,
            tile_size,
            tile_stride
        )
        return hidden_states

    def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
        """Encode an image batch into scaled latents.

        Args:
            sample: image tensor with 3 channels.
            tiled: if True, encode through the tiler instead of in one pass.
            tile_size, tile_stride: tiling parameters (used only when tiled).

        Returns:
            Latent tensor with 4 channels, multiplied by `scaling_factor`.
        """
        # The tiler wraps the whole encoder; it does not need to be
        # applied per layer.
        if tiled:
            return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)

        # 1. pre-process
        hidden_states = self.conv_in(sample)
        time_emb = None
        text_emb = None
        res_stack = None

        # 2. blocks (the index from the original enumerate() was unused)
        for block in self.blocks:
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)

        # 3. output head
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)
        hidden_states = self.quant_conv(hidden_states)
        # Keep the first 4 of 8 channels (presumably the distribution mean
        # of the (mean, logvar) pair — confirm against the VAE definition),
        # then apply the latent scaling. Non-inplace multiply avoids
        # mutating the sliced view.
        hidden_states = hidden_states[:, :4]
        hidden_states = hidden_states * self.scaling_factor

        return hidden_states

    def state_dict_converter(self):
        """Return the converter that adapts external checkpoints to this module."""
        return SDVAEEncoderStateDictConverter()
|
||||
|
||||
|
||||
class SDVAEEncoderStateDictConverter:
    """Renames VAE-encoder parameters from external checkpoint layouts
    (HuggingFace diffusers, civitai / original-LDM) to `SDVAEEncoder` names.

    Parameters whose names match neither rename rule are silently dropped.
    """

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Rename diffusers-format encoder parameters to this module's names.

        Mid-block and head/tail parameters use the explicit table below;
        down-block parameters are renamed by walking `block_types`, the
        flattened layout of `SDVAEEncoder.blocks`.
        """
        # architecture: flattened order of SDVAEEncoder.blocks
        block_types = [
            'ResnetBlock', 'ResnetBlock', 'DownSampler',
            'ResnetBlock', 'ResnetBlock', 'DownSampler',
            'ResnetBlock', 'ResnetBlock', 'DownSampler',
            'ResnetBlock', 'ResnetBlock',
            'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock'
        ]

        # Rename each parameter
        local_rename_dict = {
            "quant_conv": "quant_conv",
            "encoder.conv_in": "conv_in",
            "encoder.mid_block.attentions.0.group_norm": "blocks.12.norm",
            "encoder.mid_block.attentions.0.to_q": "blocks.12.transformer_blocks.0.to_q",
            "encoder.mid_block.attentions.0.to_k": "blocks.12.transformer_blocks.0.to_k",
            "encoder.mid_block.attentions.0.to_v": "blocks.12.transformer_blocks.0.to_v",
            "encoder.mid_block.attentions.0.to_out.0": "blocks.12.transformer_blocks.0.to_out",
            "encoder.mid_block.resnets.0.norm1": "blocks.11.norm1",
            "encoder.mid_block.resnets.0.conv1": "blocks.11.conv1",
            "encoder.mid_block.resnets.0.norm2": "blocks.11.norm2",
            "encoder.mid_block.resnets.0.conv2": "blocks.11.conv2",
            "encoder.mid_block.resnets.1.norm1": "blocks.13.norm1",
            "encoder.mid_block.resnets.1.conv1": "blocks.13.conv1",
            "encoder.mid_block.resnets.1.norm2": "blocks.13.norm2",
            "encoder.mid_block.resnets.1.conv2": "blocks.13.conv2",
            "encoder.conv_norm_out": "conv_norm_out",
            "encoder.conv_out": "conv_out",
        }
        name_list = sorted([name for name in state_dict])
        rename_dict = {}
        # Counters start at -1: the first down-block found becomes index 0.
        block_id = {"ResnetBlock": -1, "DownSampler": -1, "UpSampler": -1}
        last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""}
        for name in name_list:
            names = name.split(".")
            name_prefix = ".".join(names[:-1])
            if name_prefix in local_rename_dict:
                rename_dict[name] = local_rename_dict[name_prefix] + "." + names[-1]
            elif name.startswith("encoder.down_blocks"):
                # names[3] distinguishes resnet / downsampler / upsampler entries.
                block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]]
                block_type_with_id = ".".join(names[:5])
                # A new sub-block id means we moved to the next block of this type.
                if block_type_with_id != last_block_type_with_id[block_type]:
                    block_id[block_type] += 1
                    last_block_type_with_id[block_type] = block_type_with_id
                # Advance the counter to the next slot of this type in the
                # flattened layout (no-op once it already points at one).
                while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
                    block_id[block_type] += 1
                block_type_with_id = ".".join(names[:5])  # NOTE(review): recomputes the same value — appears redundant
                names = ["blocks", str(block_id[block_type])] + names[5:]
                rename_dict[name] = ".".join(names)

        # Convert state_dict
        state_dict_ = {}
        for name, param in state_dict.items():
            if name in rename_dict:
                state_dict_[rename_dict[name]] = param
        return state_dict_

    def from_civitai(self, state_dict):
        """Rename civitai / original-LDM encoder parameters via a fixed table.

        Attention q/k/v/out weights in this layout carry extra singleton
        dims; they are squeezed so they load into the linear projections.
        """
        rename_dict = {
            "first_stage_model.encoder.conv_in.bias": "conv_in.bias",
            "first_stage_model.encoder.conv_in.weight": "conv_in.weight",
            "first_stage_model.encoder.conv_out.bias": "conv_out.bias",
            "first_stage_model.encoder.conv_out.weight": "conv_out.weight",
            "first_stage_model.encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
            "first_stage_model.encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
            "first_stage_model.encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
            "first_stage_model.encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
            "first_stage_model.encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
            "first_stage_model.encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
            "first_stage_model.encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
            "first_stage_model.encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
            "first_stage_model.encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
            "first_stage_model.encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
            "first_stage_model.encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
            "first_stage_model.encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
            "first_stage_model.encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
            "first_stage_model.encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
            "first_stage_model.encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
            "first_stage_model.encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
            "first_stage_model.encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
            "first_stage_model.encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
            "first_stage_model.encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
            "first_stage_model.encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
            "first_stage_model.encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
            "first_stage_model.encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
            "first_stage_model.encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
            "first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
            "first_stage_model.encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
            "first_stage_model.encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
            "first_stage_model.encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
            "first_stage_model.encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
            "first_stage_model.encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
            "first_stage_model.encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
            "first_stage_model.encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
            "first_stage_model.encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
            "first_stage_model.encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
            "first_stage_model.encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
            "first_stage_model.encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
            "first_stage_model.encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
            "first_stage_model.encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
            "first_stage_model.encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
            "first_stage_model.encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
            "first_stage_model.encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
            "first_stage_model.encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
            "first_stage_model.encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
            "first_stage_model.encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
            "first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
            "first_stage_model.encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
            "first_stage_model.encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
            "first_stage_model.encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
            "first_stage_model.encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
            "first_stage_model.encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
            "first_stage_model.encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
            "first_stage_model.encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
            "first_stage_model.encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
            "first_stage_model.encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
            "first_stage_model.encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
            "first_stage_model.encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
            "first_stage_model.encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
            "first_stage_model.encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
            "first_stage_model.encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
            "first_stage_model.encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
            "first_stage_model.encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
            "first_stage_model.encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
            "first_stage_model.encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
            "first_stage_model.encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
            "first_stage_model.encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
            "first_stage_model.encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
            "first_stage_model.encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
            "first_stage_model.encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
            "first_stage_model.encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
            "first_stage_model.encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
            "first_stage_model.encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
            "first_stage_model.encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
            "first_stage_model.encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
            "first_stage_model.encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
            "first_stage_model.encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
            "first_stage_model.encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
            "first_stage_model.encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
            "first_stage_model.encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
            "first_stage_model.encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
            "first_stage_model.encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
            "first_stage_model.encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
            "first_stage_model.encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
            "first_stage_model.encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
            "first_stage_model.encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
            "first_stage_model.encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
            "first_stage_model.encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
            "first_stage_model.encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
            "first_stage_model.encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
            "first_stage_model.encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
            "first_stage_model.encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
            "first_stage_model.encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
            "first_stage_model.encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
            "first_stage_model.encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
            "first_stage_model.encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
            "first_stage_model.encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
            "first_stage_model.encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
            "first_stage_model.encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
            "first_stage_model.encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
            "first_stage_model.encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
            "first_stage_model.encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
            "first_stage_model.encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
            "first_stage_model.encoder.norm_out.bias": "conv_norm_out.bias",
            "first_stage_model.encoder.norm_out.weight": "conv_norm_out.weight",
            "first_stage_model.quant_conv.bias": "quant_conv.bias",
            "first_stage_model.quant_conv.weight": "quant_conv.weight",
        }
        state_dict_ = {}
        for name in state_dict:
            if name in rename_dict:
                param = state_dict[name]
                # Attention projections appear to be stored with extra
                # singleton dims in this layout; squeeze them to match the
                # linear layers in VAEAttentionBlock.
                if "transformer_blocks" in rename_dict[name]:
                    param = param.squeeze()
                state_dict_[rename_dict[name]] = param
        return state_dict_
|
||||
757
diffsynth/models/sdxl_text_encoder.py
Normal file
757
diffsynth/models/sdxl_text_encoder.py
Normal file
@@ -0,0 +1,757 @@
|
||||
import torch
|
||||
from .sd_text_encoder import CLIPEncoderLayer
|
||||
|
||||
|
||||
class SDXLTextEncoder(torch.nn.Module):
    """First text encoder of SDXL (CLIP ViT-L style).

    Unlike the Stable Diffusion 1.x text encoder, this module has no
    final_layer_norm, and only 11 encoder layers are instantiated by default.
    """

    def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=11, encoder_intermediate_size=3072):
        super().__init__()

        # Token embedding table.
        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)

        # Learned positional embeddings, stored with a leading batch axis.
        self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))

        # Transformer encoder stack.
        self.encoders = torch.nn.ModuleList(
            CLIPEncoderLayer(embed_dim, encoder_intermediate_size)
            for _ in range(num_encoder_layers)
        )

        # Causal attention mask, built once for the maximum sequence length.
        self.attn_mask = self.attention_mask(max_position_embeddings)

    def attention_mask(self, length):
        """Return a (length, length) causal mask: -inf above the diagonal, 0 elsewhere."""
        blocked = torch.full((length, length), float("-inf"))
        return torch.triu(blocked, diagonal=1)

    def forward(self, input_ids, clip_skip=1):
        """Encode token ids, stopping `clip_skip` layers before the end of the stack."""
        hidden = self.token_embedding(input_ids) + self.position_embeds
        mask = self.attn_mask.to(device=hidden.device, dtype=hidden.dtype)
        stop_index = len(self.encoders) - clip_skip
        for index, layer in enumerate(self.encoders):
            hidden = layer(hidden, attn_mask=mask)
            if index == stop_index:
                break
        return hidden

    def state_dict_converter(self):
        """Return the converter used to load external checkpoints into this module."""
        return SDXLTextEncoderStateDictConverter()
|
||||
|
||||
|
||||
class SDXLTextEncoder2(torch.nn.Module):
    """Second text encoder of SDXL (OpenCLIP ViT-bigG style).

    Returns both a pooled, projected embedding and the hidden states taken
    `clip_skip` layers before the end of the encoder stack.
    """

    def __init__(self, embed_dim=1280, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=32, encoder_intermediate_size=5120):
        super().__init__()

        # Token embedding table.
        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)

        # Learned positional embeddings, stored with a leading batch axis.
        self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))

        # Transformer encoder stack: 20 heads of width 64, plain (non-quick) GELU.
        self.encoders = torch.nn.ModuleList(
            CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=20, head_dim=64, use_quick_gelu=False)
            for _ in range(num_encoder_layers)
        )

        # Causal attention mask, built once for the maximum sequence length.
        self.attn_mask = self.attention_mask(max_position_embeddings)

        # Final layer norm, applied before pooling.
        self.final_layer_norm = torch.nn.LayerNorm(embed_dim)

        # Projection applied to the pooled embedding.
        self.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False)

    def attention_mask(self, length):
        """Return a (length, length) causal mask: -inf above the diagonal, 0 elsewhere."""
        blocked = torch.full((length, length), float("-inf"))
        return torch.triu(blocked, diagonal=1)

    def forward(self, input_ids, clip_skip=2):
        """Encode token ids; return (pooled_embeds, hidden_states).

        `hidden_states` is captured `clip_skip` layers before the end of the
        stack; the pooled embedding is taken at the position of the largest
        token id per row (presumably the EOS token) and then projected.
        """
        hidden = self.token_embedding(input_ids) + self.position_embeds
        mask = self.attn_mask.to(device=hidden.device, dtype=hidden.dtype)
        stop_index = len(self.encoders) - clip_skip
        for index, layer in enumerate(self.encoders):
            hidden = layer(hidden, attn_mask=mask)
            if index == stop_index:
                hidden_states = hidden
                break
        normed = self.final_layer_norm(hidden)
        batch_range = torch.arange(normed.shape[0])
        eos_positions = input_ids.to(dtype=torch.int).argmax(dim=-1)
        pooled_embeds = self.text_projection(normed[batch_range, eos_positions])
        return pooled_embeds, hidden_states

    def state_dict_converter(self):
        """Return the converter used to load external checkpoints into this module."""
        return SDXLTextEncoder2StateDictConverter()
|
||||
|
||||
|
||||
class SDXLTextEncoderStateDictConverter:
    """Convert external CLIP checkpoints into SDXLTextEncoder state dicts.

    Handles both the diffusers layout (`from_diffusers`) and the original
    single-file checkpoint layout found on civitai (`from_civitai`).
    """

    # Mapping from CLIP encoder-layer submodule names to the local names
    # used by SDXLTextEncoder.  Shared by both conversion routines; in the
    # original implementation `from_civitai` spelled out every one of the
    # 11 x 16 resulting entries by hand.
    _SUBMODULE_RENAME = {
        "self_attn.q_proj": "attn.to_q",
        "self_attn.k_proj": "attn.to_k",
        "self_attn.v_proj": "attn.to_v",
        "self_attn.out_proj": "attn.to_out",
        "layer_norm1": "layer_norm1",
        "layer_norm2": "layer_norm2",
        "mlp.fc1": "fc1",
        "mlp.fc2": "fc2",
    }

    def __init__(self):
        # Stateless: only pure name-mapping methods are provided.
        pass

    def from_diffusers(self, state_dict):
        """Convert a diffusers-format CLIP state dict.

        Parameters are renamed to SDXLTextEncoder's names and the position
        embedding gains a leading batch axis.  Keys that match neither the
        direct rename table nor the encoder-layer pattern are dropped.
        """
        rename_dict = {
            "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
            "text_model.embeddings.position_embedding.weight": "position_embeds",
            "text_model.final_layer_norm.weight": "final_layer_norm.weight",
            "text_model.final_layer_norm.bias": "final_layer_norm.bias",
        }
        state_dict_ = {}
        for name, param in state_dict.items():
            if name in rename_dict:
                if name == "text_model.embeddings.position_embedding.weight":
                    # position_embeds is stored as (1, length, dim).
                    param = param.reshape((1, param.shape[0], param.shape[1]))
                state_dict_[rename_dict[name]] = param
            elif name.startswith("text_model.encoder.layers."):
                names = name.split(".")
                layer_id, tail = names[3], names[-1]
                layer_type = ".".join(names[4:-1])
                name_ = ".".join(["encoders", layer_id, self._SUBMODULE_RENAME[layer_type], tail])
                state_dict_[name_] = param
        return state_dict_

    def from_civitai(self, state_dict):
        """Convert an original (single-file / civitai) SDXL checkpoint.

        The rename table is generated instead of hand-written: 11 encoder
        layers (SDXLTextEncoder's default depth) with the same submodule
        renames as `from_diffusers`.  Keys not in the table are dropped.
        """
        prefix = "conditioner.embedders.0.transformer.text_model."
        position_embedding_name = prefix + "embeddings.position_embedding.weight"
        rename_dict = {
            position_embedding_name: "position_embeds",
            prefix + "embeddings.token_embedding.weight": "token_embedding.weight",
        }
        # Layers 0-10: SDXL's first text encoder keeps only 11 layers.
        for layer_id in range(11):
            for src, dst in self._SUBMODULE_RENAME.items():
                for tail in ("bias", "weight"):
                    source_name = f"{prefix}encoder.layers.{layer_id}.{src}.{tail}"
                    rename_dict[source_name] = f"encoders.{layer_id}.{dst}.{tail}"
        state_dict_ = {}
        for name, param in state_dict.items():
            if name in rename_dict:
                if name == position_embedding_name:
                    # position_embeds is stored as (1, length, dim).
                    param = param.reshape((1, param.shape[0], param.shape[1]))
                state_dict_[rename_dict[name]] = param
        return state_dict_
|
||||
|
||||
|
||||
class SDXLTextEncoder2StateDictConverter:
|
||||
def __init__(self):
    # Stateless: the converter only exposes pure name-mapping methods.
    pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
rename_dict = {
|
||||
"text_model.embeddings.token_embedding.weight": "token_embedding.weight",
|
||||
"text_model.embeddings.position_embedding.weight": "position_embeds",
|
||||
"text_model.final_layer_norm.weight": "final_layer_norm.weight",
|
||||
"text_model.final_layer_norm.bias": "final_layer_norm.bias",
|
||||
"text_projection.weight": "text_projection.weight"
|
||||
}
|
||||
attn_rename_dict = {
|
||||
"self_attn.q_proj": "attn.to_q",
|
||||
"self_attn.k_proj": "attn.to_k",
|
||||
"self_attn.v_proj": "attn.to_v",
|
||||
"self_attn.out_proj": "attn.to_out",
|
||||
"layer_norm1": "layer_norm1",
|
||||
"layer_norm2": "layer_norm2",
|
||||
"mlp.fc1": "fc1",
|
||||
"mlp.fc2": "fc2",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if name == "text_model.embeddings.position_embedding.weight":
|
||||
param = param.reshape((1, param.shape[0], param.shape[1]))
|
||||
state_dict_[rename_dict[name]] = param
|
||||
elif name.startswith("text_model.encoder.layers."):
|
||||
param = state_dict[name]
|
||||
names = name.split(".")
|
||||
layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
|
||||
name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
|
||||
state_dict_[name_] = param
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
"conditioner.embedders.1.model.ln_final.bias": "final_layer_norm.bias",
|
||||
"conditioner.embedders.1.model.ln_final.weight": "final_layer_norm.weight",
|
||||
"conditioner.embedders.1.model.positional_embedding": "position_embeds",
|
||||
"conditioner.embedders.1.model.token_embedding.weight": "token_embedding.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias": ['encoders.0.attn.to_q.bias', 'encoders.0.attn.to_k.bias', 'encoders.0.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight": ['encoders.0.attn.to_q.weight', 'encoders.0.attn.to_k.weight', 'encoders.0.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.bias": "encoders.0.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.weight": "encoders.0.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias": "encoders.0.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.ln_1.weight": "encoders.0.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.ln_2.bias": "encoders.0.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.ln_2.weight": "encoders.0.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.bias": "encoders.0.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.weight": "encoders.0.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.bias": "encoders.0.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.weight": "encoders.0.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias": ['encoders.1.attn.to_q.bias', 'encoders.1.attn.to_k.bias', 'encoders.1.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight": ['encoders.1.attn.to_q.weight', 'encoders.1.attn.to_k.weight', 'encoders.1.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.bias": "encoders.1.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.weight": "encoders.1.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.ln_1.bias": "encoders.1.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.ln_1.weight": "encoders.1.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.ln_2.bias": "encoders.1.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.ln_2.weight": "encoders.1.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.bias": "encoders.1.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.weight": "encoders.1.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.bias": "encoders.1.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.weight": "encoders.1.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias": ['encoders.10.attn.to_q.bias', 'encoders.10.attn.to_k.bias', 'encoders.10.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight": ['encoders.10.attn.to_q.weight', 'encoders.10.attn.to_k.weight', 'encoders.10.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.bias": "encoders.10.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.weight": "encoders.10.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.ln_1.bias": "encoders.10.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.ln_1.weight": "encoders.10.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.ln_2.bias": "encoders.10.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.ln_2.weight": "encoders.10.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.bias": "encoders.10.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.weight": "encoders.10.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.bias": "encoders.10.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.weight": "encoders.10.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias": ['encoders.11.attn.to_q.bias', 'encoders.11.attn.to_k.bias', 'encoders.11.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight": ['encoders.11.attn.to_q.weight', 'encoders.11.attn.to_k.weight', 'encoders.11.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.bias": "encoders.11.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.weight": "encoders.11.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.ln_1.bias": "encoders.11.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.ln_1.weight": "encoders.11.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.ln_2.bias": "encoders.11.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.ln_2.weight": "encoders.11.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.bias": "encoders.11.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.weight": "encoders.11.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.bias": "encoders.11.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.weight": "encoders.11.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias": ['encoders.12.attn.to_q.bias', 'encoders.12.attn.to_k.bias', 'encoders.12.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight": ['encoders.12.attn.to_q.weight', 'encoders.12.attn.to_k.weight', 'encoders.12.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.bias": "encoders.12.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.weight": "encoders.12.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.ln_1.bias": "encoders.12.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.ln_1.weight": "encoders.12.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.ln_2.bias": "encoders.12.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.ln_2.weight": "encoders.12.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.bias": "encoders.12.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.weight": "encoders.12.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.bias": "encoders.12.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.weight": "encoders.12.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias": ['encoders.13.attn.to_q.bias', 'encoders.13.attn.to_k.bias', 'encoders.13.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight": ['encoders.13.attn.to_q.weight', 'encoders.13.attn.to_k.weight', 'encoders.13.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.bias": "encoders.13.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.weight": "encoders.13.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.ln_1.bias": "encoders.13.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.ln_1.weight": "encoders.13.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.ln_2.bias": "encoders.13.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.ln_2.weight": "encoders.13.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.bias": "encoders.13.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.weight": "encoders.13.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.bias": "encoders.13.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.weight": "encoders.13.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias": ['encoders.14.attn.to_q.bias', 'encoders.14.attn.to_k.bias', 'encoders.14.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight": ['encoders.14.attn.to_q.weight', 'encoders.14.attn.to_k.weight', 'encoders.14.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.bias": "encoders.14.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.weight": "encoders.14.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.ln_1.bias": "encoders.14.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.ln_1.weight": "encoders.14.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.ln_2.bias": "encoders.14.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.ln_2.weight": "encoders.14.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.bias": "encoders.14.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.weight": "encoders.14.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.bias": "encoders.14.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.weight": "encoders.14.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias": ['encoders.15.attn.to_q.bias', 'encoders.15.attn.to_k.bias', 'encoders.15.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight": ['encoders.15.attn.to_q.weight', 'encoders.15.attn.to_k.weight', 'encoders.15.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.bias": "encoders.15.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.weight": "encoders.15.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.ln_1.bias": "encoders.15.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.ln_1.weight": "encoders.15.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.ln_2.bias": "encoders.15.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.ln_2.weight": "encoders.15.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.bias": "encoders.15.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.weight": "encoders.15.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.bias": "encoders.15.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.weight": "encoders.15.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias": ['encoders.16.attn.to_q.bias', 'encoders.16.attn.to_k.bias', 'encoders.16.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight": ['encoders.16.attn.to_q.weight', 'encoders.16.attn.to_k.weight', 'encoders.16.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.bias": "encoders.16.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.weight": "encoders.16.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.ln_1.bias": "encoders.16.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.ln_1.weight": "encoders.16.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.ln_2.bias": "encoders.16.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.ln_2.weight": "encoders.16.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.bias": "encoders.16.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.weight": "encoders.16.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.bias": "encoders.16.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.weight": "encoders.16.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias": ['encoders.17.attn.to_q.bias', 'encoders.17.attn.to_k.bias', 'encoders.17.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight": ['encoders.17.attn.to_q.weight', 'encoders.17.attn.to_k.weight', 'encoders.17.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.bias": "encoders.17.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.weight": "encoders.17.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.ln_1.bias": "encoders.17.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.ln_1.weight": "encoders.17.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.ln_2.bias": "encoders.17.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.ln_2.weight": "encoders.17.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.bias": "encoders.17.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.weight": "encoders.17.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.bias": "encoders.17.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.weight": "encoders.17.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias": ['encoders.18.attn.to_q.bias', 'encoders.18.attn.to_k.bias', 'encoders.18.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight": ['encoders.18.attn.to_q.weight', 'encoders.18.attn.to_k.weight', 'encoders.18.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.bias": "encoders.18.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.weight": "encoders.18.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.ln_1.bias": "encoders.18.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.ln_1.weight": "encoders.18.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.ln_2.bias": "encoders.18.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.ln_2.weight": "encoders.18.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.bias": "encoders.18.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.weight": "encoders.18.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.bias": "encoders.18.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.weight": "encoders.18.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias": ['encoders.19.attn.to_q.bias', 'encoders.19.attn.to_k.bias', 'encoders.19.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight": ['encoders.19.attn.to_q.weight', 'encoders.19.attn.to_k.weight', 'encoders.19.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.bias": "encoders.19.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.weight": "encoders.19.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.ln_1.bias": "encoders.19.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.ln_1.weight": "encoders.19.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.ln_2.bias": "encoders.19.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.ln_2.weight": "encoders.19.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.bias": "encoders.19.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.weight": "encoders.19.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.bias": "encoders.19.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.weight": "encoders.19.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias": ['encoders.2.attn.to_q.bias', 'encoders.2.attn.to_k.bias', 'encoders.2.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight": ['encoders.2.attn.to_q.weight', 'encoders.2.attn.to_k.weight', 'encoders.2.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.bias": "encoders.2.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.weight": "encoders.2.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.ln_1.bias": "encoders.2.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.ln_1.weight": "encoders.2.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.ln_2.bias": "encoders.2.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.ln_2.weight": "encoders.2.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.bias": "encoders.2.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.weight": "encoders.2.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.bias": "encoders.2.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.weight": "encoders.2.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias": ['encoders.20.attn.to_q.bias', 'encoders.20.attn.to_k.bias', 'encoders.20.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight": ['encoders.20.attn.to_q.weight', 'encoders.20.attn.to_k.weight', 'encoders.20.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.bias": "encoders.20.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.weight": "encoders.20.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.ln_1.bias": "encoders.20.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.ln_1.weight": "encoders.20.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.ln_2.bias": "encoders.20.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.ln_2.weight": "encoders.20.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.bias": "encoders.20.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.weight": "encoders.20.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.bias": "encoders.20.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.weight": "encoders.20.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias": ['encoders.21.attn.to_q.bias', 'encoders.21.attn.to_k.bias', 'encoders.21.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight": ['encoders.21.attn.to_q.weight', 'encoders.21.attn.to_k.weight', 'encoders.21.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.bias": "encoders.21.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.weight": "encoders.21.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.ln_1.bias": "encoders.21.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.ln_1.weight": "encoders.21.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.ln_2.bias": "encoders.21.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.ln_2.weight": "encoders.21.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.bias": "encoders.21.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.weight": "encoders.21.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.bias": "encoders.21.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.weight": "encoders.21.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias": ['encoders.22.attn.to_q.bias', 'encoders.22.attn.to_k.bias', 'encoders.22.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight": ['encoders.22.attn.to_q.weight', 'encoders.22.attn.to_k.weight', 'encoders.22.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.bias": "encoders.22.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.weight": "encoders.22.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.ln_1.bias": "encoders.22.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.ln_1.weight": "encoders.22.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.ln_2.bias": "encoders.22.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.ln_2.weight": "encoders.22.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.bias": "encoders.22.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.weight": "encoders.22.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.bias": "encoders.22.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.weight": "encoders.22.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias": ['encoders.23.attn.to_q.bias', 'encoders.23.attn.to_k.bias', 'encoders.23.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight": ['encoders.23.attn.to_q.weight', 'encoders.23.attn.to_k.weight', 'encoders.23.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.bias": "encoders.23.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.weight": "encoders.23.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.ln_1.bias": "encoders.23.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.ln_1.weight": "encoders.23.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.ln_2.bias": "encoders.23.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.ln_2.weight": "encoders.23.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.bias": "encoders.23.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.weight": "encoders.23.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.bias": "encoders.23.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.weight": "encoders.23.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias": ['encoders.24.attn.to_q.bias', 'encoders.24.attn.to_k.bias', 'encoders.24.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight": ['encoders.24.attn.to_q.weight', 'encoders.24.attn.to_k.weight', 'encoders.24.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.bias": "encoders.24.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.weight": "encoders.24.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.ln_1.bias": "encoders.24.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.ln_1.weight": "encoders.24.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.ln_2.bias": "encoders.24.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.ln_2.weight": "encoders.24.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.bias": "encoders.24.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.weight": "encoders.24.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.bias": "encoders.24.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.weight": "encoders.24.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias": ['encoders.25.attn.to_q.bias', 'encoders.25.attn.to_k.bias', 'encoders.25.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight": ['encoders.25.attn.to_q.weight', 'encoders.25.attn.to_k.weight', 'encoders.25.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.bias": "encoders.25.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.weight": "encoders.25.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.ln_1.bias": "encoders.25.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.ln_1.weight": "encoders.25.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.ln_2.bias": "encoders.25.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.ln_2.weight": "encoders.25.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.bias": "encoders.25.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.weight": "encoders.25.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.bias": "encoders.25.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.weight": "encoders.25.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias": ['encoders.26.attn.to_q.bias', 'encoders.26.attn.to_k.bias', 'encoders.26.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight": ['encoders.26.attn.to_q.weight', 'encoders.26.attn.to_k.weight', 'encoders.26.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.bias": "encoders.26.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.weight": "encoders.26.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.ln_1.bias": "encoders.26.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.ln_1.weight": "encoders.26.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.ln_2.bias": "encoders.26.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.ln_2.weight": "encoders.26.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.bias": "encoders.26.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.weight": "encoders.26.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.bias": "encoders.26.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.weight": "encoders.26.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias": ['encoders.27.attn.to_q.bias', 'encoders.27.attn.to_k.bias', 'encoders.27.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight": ['encoders.27.attn.to_q.weight', 'encoders.27.attn.to_k.weight', 'encoders.27.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.bias": "encoders.27.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.weight": "encoders.27.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.ln_1.bias": "encoders.27.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.ln_1.weight": "encoders.27.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.ln_2.bias": "encoders.27.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.ln_2.weight": "encoders.27.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.bias": "encoders.27.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.weight": "encoders.27.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.bias": "encoders.27.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.weight": "encoders.27.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias": ['encoders.28.attn.to_q.bias', 'encoders.28.attn.to_k.bias', 'encoders.28.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight": ['encoders.28.attn.to_q.weight', 'encoders.28.attn.to_k.weight', 'encoders.28.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.bias": "encoders.28.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.weight": "encoders.28.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.ln_1.bias": "encoders.28.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.ln_1.weight": "encoders.28.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.ln_2.bias": "encoders.28.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.ln_2.weight": "encoders.28.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.bias": "encoders.28.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.weight": "encoders.28.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.bias": "encoders.28.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.weight": "encoders.28.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias": ['encoders.29.attn.to_q.bias', 'encoders.29.attn.to_k.bias', 'encoders.29.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight": ['encoders.29.attn.to_q.weight', 'encoders.29.attn.to_k.weight', 'encoders.29.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.bias": "encoders.29.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.weight": "encoders.29.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.ln_1.bias": "encoders.29.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.ln_1.weight": "encoders.29.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.ln_2.bias": "encoders.29.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.ln_2.weight": "encoders.29.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.bias": "encoders.29.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.weight": "encoders.29.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.bias": "encoders.29.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.weight": "encoders.29.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias": ['encoders.3.attn.to_q.bias', 'encoders.3.attn.to_k.bias', 'encoders.3.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight": ['encoders.3.attn.to_q.weight', 'encoders.3.attn.to_k.weight', 'encoders.3.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.bias": "encoders.3.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.weight": "encoders.3.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.ln_1.bias": "encoders.3.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.ln_1.weight": "encoders.3.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.ln_2.bias": "encoders.3.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.ln_2.weight": "encoders.3.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.bias": "encoders.3.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.weight": "encoders.3.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.bias": "encoders.3.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.weight": "encoders.3.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias": ['encoders.30.attn.to_q.bias', 'encoders.30.attn.to_k.bias', 'encoders.30.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight": ['encoders.30.attn.to_q.weight', 'encoders.30.attn.to_k.weight', 'encoders.30.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.bias": "encoders.30.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.weight": "encoders.30.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.ln_1.bias": "encoders.30.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.ln_1.weight": "encoders.30.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.ln_2.bias": "encoders.30.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.ln_2.weight": "encoders.30.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.bias": "encoders.30.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.weight": "encoders.30.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.bias": "encoders.30.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.weight": "encoders.30.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias": ['encoders.31.attn.to_q.bias', 'encoders.31.attn.to_k.bias', 'encoders.31.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight": ['encoders.31.attn.to_q.weight', 'encoders.31.attn.to_k.weight', 'encoders.31.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.bias": "encoders.31.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.weight": "encoders.31.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.ln_1.bias": "encoders.31.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.ln_1.weight": "encoders.31.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.ln_2.bias": "encoders.31.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.ln_2.weight": "encoders.31.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.bias": "encoders.31.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.weight": "encoders.31.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.bias": "encoders.31.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.weight": "encoders.31.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias": ['encoders.4.attn.to_q.bias', 'encoders.4.attn.to_k.bias', 'encoders.4.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight": ['encoders.4.attn.to_q.weight', 'encoders.4.attn.to_k.weight', 'encoders.4.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.bias": "encoders.4.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.weight": "encoders.4.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.ln_1.bias": "encoders.4.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.ln_1.weight": "encoders.4.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.ln_2.bias": "encoders.4.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.ln_2.weight": "encoders.4.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.bias": "encoders.4.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.weight": "encoders.4.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.bias": "encoders.4.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.weight": "encoders.4.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias": ['encoders.5.attn.to_q.bias', 'encoders.5.attn.to_k.bias', 'encoders.5.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight": ['encoders.5.attn.to_q.weight', 'encoders.5.attn.to_k.weight', 'encoders.5.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.bias": "encoders.5.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.weight": "encoders.5.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.ln_1.bias": "encoders.5.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.ln_1.weight": "encoders.5.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.ln_2.bias": "encoders.5.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.ln_2.weight": "encoders.5.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.bias": "encoders.5.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.weight": "encoders.5.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.bias": "encoders.5.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.weight": "encoders.5.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias": ['encoders.6.attn.to_q.bias', 'encoders.6.attn.to_k.bias', 'encoders.6.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight": ['encoders.6.attn.to_q.weight', 'encoders.6.attn.to_k.weight', 'encoders.6.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.bias": "encoders.6.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.weight": "encoders.6.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.ln_1.bias": "encoders.6.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.ln_1.weight": "encoders.6.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.ln_2.bias": "encoders.6.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.ln_2.weight": "encoders.6.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.bias": "encoders.6.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.weight": "encoders.6.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.bias": "encoders.6.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.weight": "encoders.6.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias": ['encoders.7.attn.to_q.bias', 'encoders.7.attn.to_k.bias', 'encoders.7.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight": ['encoders.7.attn.to_q.weight', 'encoders.7.attn.to_k.weight', 'encoders.7.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.bias": "encoders.7.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.weight": "encoders.7.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.ln_1.bias": "encoders.7.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.ln_1.weight": "encoders.7.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.ln_2.bias": "encoders.7.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.ln_2.weight": "encoders.7.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.bias": "encoders.7.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.weight": "encoders.7.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.bias": "encoders.7.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.weight": "encoders.7.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias": ['encoders.8.attn.to_q.bias', 'encoders.8.attn.to_k.bias', 'encoders.8.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight": ['encoders.8.attn.to_q.weight', 'encoders.8.attn.to_k.weight', 'encoders.8.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.bias": "encoders.8.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.weight": "encoders.8.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.ln_1.bias": "encoders.8.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.ln_1.weight": "encoders.8.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.ln_2.bias": "encoders.8.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.ln_2.weight": "encoders.8.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.bias": "encoders.8.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.weight": "encoders.8.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.bias": "encoders.8.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.weight": "encoders.8.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias": ['encoders.9.attn.to_q.bias', 'encoders.9.attn.to_k.bias', 'encoders.9.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight": ['encoders.9.attn.to_q.weight', 'encoders.9.attn.to_k.weight', 'encoders.9.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.bias": "encoders.9.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.weight": "encoders.9.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.ln_1.bias": "encoders.9.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.ln_1.weight": "encoders.9.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.ln_2.bias": "encoders.9.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.ln_2.weight": "encoders.9.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.bias": "encoders.9.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.weight": "encoders.9.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias": "encoders.9.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight": "encoders.9.fc2.weight",
|
||||
"conditioner.embedders.1.model.text_projection": "text_projection.weight",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if name == "conditioner.embedders.1.model.positional_embedding":
|
||||
param = param.reshape((1, param.shape[0], param.shape[1]))
|
||||
elif name == "conditioner.embedders.1.model.text_projection":
|
||||
param = param.T
|
||||
if isinstance(rename_dict[name], str):
|
||||
state_dict_[rename_dict[name]] = param
|
||||
else:
|
||||
length = param.shape[0] // 3
|
||||
for i, rename in enumerate(rename_dict[name]):
|
||||
state_dict_[rename] = param[i*length: i*length+length]
|
||||
return state_dict_
|
||||
1897
diffsynth/models/sdxl_unet.py
Normal file
1897
diffsynth/models/sdxl_unet.py
Normal file
File diff suppressed because it is too large
Load Diff
15
diffsynth/models/sdxl_vae_decoder.py
Normal file
15
diffsynth/models/sdxl_vae_decoder.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from .sd_vae_decoder import SDVAEDecoder, SDVAEDecoderStateDictConverter
|
||||
|
||||
|
||||
class SDXLVAEDecoder(SDVAEDecoder):
    """VAE decoder for Stable Diffusion XL.

    Inherits the full SD decoder implementation and only overrides the
    latent scaling factor, which differs between SD and SDXL checkpoints.
    """

    def __init__(self):
        super().__init__()
        # SDXL checkpoints are trained with this latent scaling factor.
        self.scaling_factor = 0.13025

    def state_dict_converter(self):
        """Return the state-dict converter used to load SDXL decoder weights."""
        return SDXLVAEDecoderStateDictConverter()
|
||||
|
||||
|
||||
class SDXLVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
    """State-dict converter for the SDXL VAE decoder.

    Reuses the SD decoder converter unchanged; it exists so the SDXL
    decoder has its own converter type.
    """

    def __init__(self):
        super().__init__()
|
||||
15
diffsynth/models/sdxl_vae_encoder.py
Normal file
15
diffsynth/models/sdxl_vae_encoder.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder
|
||||
|
||||
|
||||
class SDXLVAEEncoder(SDVAEEncoder):
    """VAE encoder for Stable Diffusion XL.

    Inherits the full SD encoder implementation and only overrides the
    latent scaling factor, which differs between SD and SDXL checkpoints.
    """

    def __init__(self):
        super().__init__()
        # SDXL checkpoints are trained with this latent scaling factor.
        self.scaling_factor = 0.13025

    def state_dict_converter(self):
        """Return the state-dict converter used to load SDXL encoder weights."""
        return SDXLVAEEncoderStateDictConverter()
|
||||
|
||||
|
||||
class SDXLVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
    """State-dict converter for the SDXL VAE encoder.

    Reuses the SD encoder converter unchanged; it exists so the SDXL
    encoder has its own converter type.
    """

    def __init__(self):
        super().__init__()
|
||||
75
diffsynth/models/tiler.py
Normal file
75
diffsynth/models/tiler.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import torch
|
||||
|
||||
|
||||
class Tiler(torch.nn.Module):
    """Run a model tile-by-tile over a large feature map and blend the results.

    The input is cut into overlapping square tiles with ``torch.nn.Unfold``,
    ``forward_fn`` is applied to mini-batches of tiles, and the per-tile
    outputs are folded back with ``torch.nn.Fold``. A linear feathering mask
    cross-fades overlapping tile borders so the reassembled map has no seams.
    """

    def __init__(self):
        super().__init__()

    def mask(self, height, width, line_width):
        """Return a ``(height, width)`` feathering mask.

        Each entry is the 1-based distance to the nearest border, divided by
        ``line_width`` and clipped to [0, 1]: the mask ramps up from the tile
        edges and is 1 in the interior.
        """
        x = torch.arange(height).repeat(width, 1).T
        y = torch.arange(width).repeat(height, 1)
        mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values
        mask = (mask / line_width).clip(0, 1)
        return mask

    def forward(self, forward_fn, x, tile_size, tile_stride, batch_size=1, inter_device="cpu", inter_dtype=torch.float32):
        """Apply ``forward_fn`` to ``x`` tile-by-tile and blend the outputs.

        Args:
            forward_fn: callable mapping a ``(N, C, tile, tile)`` batch of
                tiles to an output batch; it may rescale the spatial size, but
                must use the same scale for every tile.
            x: input tensor of shape ``(b, c, h, w)``.
            tile_size: side length of the square tiles cut from ``x``.
            tile_stride: stride between tile origins (<= tile_size, and the
                tiling is expected to cover ``x`` exactly).
            batch_size: number of tiles fed to ``forward_fn`` per call.
            inter_device: device holding the large intermediate tile stacks.
            inter_dtype: dtype of the intermediate tile stacks.

        Returns:
            Blended output tensor on the same device/dtype as ``x``.
        """
        # Remember the caller's device/dtype so the result can be restored.
        device = x.device
        torch_dtype = x.dtype

        # tile: cut x into overlapping tile_size x tile_size patches
        b, c_in, h_in, w_in = x.shape
        x = x.to(device=inter_device, dtype=inter_dtype)
        fold_params = {
            "kernel_size": (tile_size, tile_size),
            "stride": (tile_stride, tile_stride)
        }
        unfold_operator = torch.nn.Unfold(**fold_params)
        x = unfold_operator(x)
        x = x.view((b, c_in, tile_size, tile_size, -1))

        # inference: run forward_fn over the tiles in mini-batches
        x_out_stack = []
        for tile_id in range(0, x.shape[-1], batch_size):

            # process input: slice out a batch of tiles and flatten the
            # (tile, batch) axes into one leading batch dimension
            next_tile_id = min(tile_id + batch_size, x.shape[-1])
            x_in = x[:, :, :, :, tile_id: next_tile_id]
            x_in = x_in.to(device=device, dtype=torch_dtype)
            x_in = x_in.permute(4, 0, 1, 2, 3)
            x_in = x_in.view((x_in.shape[0]*x_in.shape[1], x_in.shape[2], x_in.shape[3], x_in.shape[4]))

            # process output: undo the flattening and stash on inter_device
            x_out = forward_fn(x_in)
            x_out = x_out.view((next_tile_id - tile_id, b, x_out.shape[1], x_out.shape[2], x_out.shape[3]))
            x_out = x_out.permute(1, 2, 3, 4, 0)
            x_out = x_out.to(device=inter_device, dtype=inter_dtype)
            x_out_stack.append(x_out)

        x = torch.concat(x_out_stack, dim=-1)

        # untile: forward_fn may have rescaled the tiles; infer the scale
        # from the output tile size and fold everything back together.
        in2out_scale = x.shape[2] / tile_size
        h_out, w_out = int(h_in * in2out_scale), int(w_in * in2out_scale)

        mask = self.mask(int(tile_size * in2out_scale), int(tile_size * in2out_scale), int(tile_stride * in2out_scale * 0.5))
        mask = mask.to(device=inter_device, dtype=inter_dtype)
        mask = mask.reshape((1, 1, mask.shape[0], mask.shape[1], 1))
        x = x * mask

        fold_params = {
            "kernel_size": (int(tile_size * in2out_scale), int(tile_size * in2out_scale)),
            "stride": (int(tile_stride * in2out_scale), int(tile_stride * in2out_scale))
        }
        fold_operator = torch.nn.Fold(output_size=(h_out, w_out), **fold_params)
        # Per-pixel sum of mask weights over all overlapping tiles, used to
        # normalize the folded output. The mask has a singleton batch dim, so
        # the view must use 1 (not b) — viewing with b broke batched inputs;
        # the division below broadcasts over the batch instead.
        divisor = fold_operator(mask.repeat(1, 1, 1, 1, x.shape[-1]).view(1, -1, x.shape[-1]))

        x = x.view((b, -1, x.shape[-1]))
        x = fold_operator(x) / divisor
        x = x.to(device=device, dtype=torch_dtype)

        return x
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user