rebuild base modules

Artiprocher
2024-07-26 12:15:40 +08:00
parent 9471bff8a4
commit e3f8a576cf
76 changed files with 3253 additions and 3563 deletions

View File

@@ -1,814 +1 @@
import torch, os, json
from safetensors import safe_open
from typing_extensions import Literal, TypeAlias
from typing import List
from .downloader import download_from_huggingface, download_from_modelscope
from .sd_text_encoder import SDTextEncoder
from .sd_unet import SDUNet
from .sd_vae_encoder import SDVAEEncoder
from .sd_vae_decoder import SDVAEDecoder
from .sd_lora import SDLoRA
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from .sdxl_unet import SDXLUNet
from .sdxl_vae_decoder import SDXLVAEDecoder
from .sdxl_vae_encoder import SDXLVAEEncoder
from .sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
from .sd3_dit import SD3DiT
from .sd3_vae_decoder import SD3VAEDecoder
from .sd3_vae_encoder import SD3VAEEncoder
from .sd_controlnet import SDControlNet
from .sd_motion import SDMotionModel
from .sdxl_motion import SDXLMotionModel
from .svd_image_encoder import SVDImageEncoder
from .svd_unet import SVDUNet
from .svd_vae_decoder import SVDVAEDecoder
from .svd_vae_encoder import SVDVAEEncoder
from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from .hunyuan_dit import HunyuanDiT
from .kolors_text_encoder import ChatGLMModel
preset_models_on_huggingface = {
"HunyuanDiT": [
("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
],
"stable-video-diffusion-img2vid-xt": [
("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
],
"ExVideo-SVD-128f-v1": [
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
],
}
preset_models_on_modelscope = {
# Hunyuan DiT
"HunyuanDiT": [
("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
],
# Stable Video Diffusion
"stable-video-diffusion-img2vid-xt": [
("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
],
# ExVideo
"ExVideo-SVD-128f-v1": [
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
],
# Stable Diffusion
"StableDiffusion_v15": [
("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
],
"DreamShaper_8": [
("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
],
"AingDiffusion_v12": [
("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
],
"Flat2DAnimerge_v45Sharp": [
("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
],
# Textual Inversion
"TextualInversion_VeryBadImageNegative_v1.3": [
("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
],
# Stable Diffusion XL
"StableDiffusionXL_v1": [
("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
],
"BluePencilXL_v200": [
("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
],
"StableDiffusionXL_Turbo": [
("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
],
# Stable Diffusion 3
"StableDiffusion3": [
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
],
"StableDiffusion3_without_T5": [
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
],
# ControlNet
"ControlNet_v11f1p_sd15_depth": [
("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
],
"ControlNet_v11p_sd15_softedge": [
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
],
"ControlNet_v11f1e_sd15_tile": [
("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
],
"ControlNet_v11p_sd15_lineart": [
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
],
# AnimateDiff
"AnimateDiff_v2": [
("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
],
"AnimateDiff_xl_beta": [
("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
],
# RIFE
"RIFE": [
("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
],
# Beautiful Prompt
"BeautifulPrompt": [
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
],
# Translator
"opus-mt-zh-en": [
("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
],
# IP-Adapter
"IP-Adapter-SD": [
("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
],
"IP-Adapter-SDXL": [
("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
],
# Kolors
"Kolors": [
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
],
"SDXL-vae-fp16-fix": [
("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
],
}
Preset_model_id: TypeAlias = Literal[
"HunyuanDiT",
"stable-video-diffusion-img2vid-xt",
"ExVideo-SVD-128f-v1",
"StableDiffusion_v15",
"DreamShaper_8",
"AingDiffusion_v12",
"Flat2DAnimerge_v45Sharp",
"TextualInversion_VeryBadImageNegative_v1.3",
"StableDiffusionXL_v1",
"BluePencilXL_v200",
"StableDiffusionXL_Turbo",
"ControlNet_v11f1p_sd15_depth",
"ControlNet_v11p_sd15_softedge",
"ControlNet_v11f1e_sd15_tile",
"ControlNet_v11p_sd15_lineart",
"AnimateDiff_v2",
"AnimateDiff_xl_beta",
"RIFE",
"BeautifulPrompt",
"opus-mt-zh-en",
"IP-Adapter-SD",
"IP-Adapter-SDXL",
"StableDiffusion3",
"StableDiffusion3_without_T5",
"Kolors",
"SDXL-vae-fp16-fix",
]
Preset_model_website: TypeAlias = Literal[
"HuggingFace",
"ModelScope",
]
website_to_preset_models = {
"HuggingFace": preset_models_on_huggingface,
"ModelScope": preset_models_on_modelscope,
}
website_to_download_fn = {
"HuggingFace": download_from_huggingface,
"ModelScope": download_from_modelscope,
}
def download_models(
model_id_list: List[Preset_model_id] = [],
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
):
downloaded_files = []
for model_id in model_id_list:
for website in downloading_priority:
if model_id in website_to_preset_models[website]:
for model_id, origin_file_path, local_dir in website_to_preset_models[website][model_id]:
# Check if the file has already been downloaded.
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
if file_to_download in downloaded_files:
continue
# Download
website_to_download_fn[website](model_id, origin_file_path, local_dir)
if os.path.basename(origin_file_path) in os.listdir(local_dir):
downloaded_files.append(file_to_download)
return downloaded_files
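A minimal usage sketch of `download_models` (the preset IDs below come from the tables in this file; network access is required, and the printed result is illustrative):
# Sketch: download two presets, trying ModelScope first.
files = download_models(
    model_id_list=["StableDiffusion_v15", "AnimateDiff_v2"],
    downloading_priority=["ModelScope", "HuggingFace"],
)
print(files)  # e.g. ["models/stable_diffusion/v1-5-pruned-emaonly.safetensors", ...]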
class ModelManager:
def __init__(
self,
torch_dtype=torch.float16,
device="cuda",
model_id_list: List[Preset_model_id] = [],
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
file_path_list: List[str] = [],
):
self.torch_dtype = torch_dtype
self.device = device
self.model = {}
self.model_path = {}
self.textual_inversion_dict = {}
downloaded_files = download_models(model_id_list, downloading_priority)
self.load_models(downloaded_files + file_path_list)
def load_model_from_origin(
self,
download_from: Preset_model_website = "ModelScope",
model_id = "",
origin_file_path = "",
local_dir = ""
):
website_to_download_fn[download_from](model_id, origin_file_path, local_dir)
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
self.load_model(file_to_download)
def is_stable_video_diffusion(self, state_dict):
param_name = "model.diffusion_model.output_blocks.9.1.time_stack.0.norm_in.weight"
return param_name in state_dict
def is_RIFE(self, state_dict):
param_name = "block_tea.convblock3.0.1.weight"
return param_name in state_dict or ("module." + param_name) in state_dict
def is_beautiful_prompt(self, state_dict):
param_name = "transformer.h.9.self_attention.query_key_value.weight"
return param_name in state_dict
def is_stable_diffusion_xl(self, state_dict):
param_name = "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight"
return param_name in state_dict
def is_stable_diffusion(self, state_dict):
if self.is_stable_diffusion_xl(state_dict):
return False
param_name = "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight"
return param_name in state_dict
def is_controlnet(self, state_dict):
param_name = "control_model.time_embed.0.weight"
return param_name in state_dict
def is_animatediff(self, state_dict):
param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
return param_name in state_dict
def is_animatediff_xl(self, state_dict):
param_name = "up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.ff_norm.weight"
return param_name in state_dict
def is_sd_lora(self, state_dict):
param_name = "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_ff_net_2.lora_up.weight"
return param_name in state_dict
def is_translator(self, state_dict):
param_name = "model.encoder.layers.5.self_attn_layer_norm.weight"
return param_name in state_dict and len(state_dict) == 258
def is_ipadapter(self, state_dict):
return "image_proj" in state_dict and "ip_adapter" in state_dict and state_dict["image_proj"]["proj.weight"].shape == torch.Size([3072, 1024])
def is_ipadapter_image_encoder(self, state_dict):
param_name = "vision_model.encoder.layers.31.self_attn.v_proj.weight"
return param_name in state_dict and len(state_dict) == 521
def is_ipadapter_xl(self, state_dict):
return "image_proj" in state_dict and "ip_adapter" in state_dict and state_dict["image_proj"]["proj.weight"].shape == torch.Size([8192, 1280])
def is_ipadapter_xl_image_encoder(self, state_dict):
param_name = "vision_model.encoder.layers.47.self_attn.v_proj.weight"
return param_name in state_dict and len(state_dict) == 777
def is_hunyuan_dit_clip_text_encoder(self, state_dict):
param_name = "bert.encoder.layer.23.attention.output.dense.weight"
return param_name in state_dict
def is_hunyuan_dit_t5_text_encoder(self, state_dict):
param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
param_name_ = "decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
return param_name in state_dict and param_name_ in state_dict
def is_hunyuan_dit(self, state_dict):
param_name = "final_layer.adaLN_modulation.1.weight"
return param_name in state_dict
def is_diffusers_vae(self, state_dict):
param_name = "quant_conv.weight"
return param_name in state_dict
def is_ExVideo_StableVideoDiffusion(self, state_dict):
param_name = "blocks.185.positional_embedding.embeddings"
return param_name in state_dict
def is_stable_diffusion_3(self, state_dict):
param_names = [
"text_encoders.clip_l.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight",
"text_encoders.clip_g.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight",
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.weight",
"first_stage_model.encoder.mid.block_2.norm2.weight",
"first_stage_model.decoder.mid.block_2.norm2.weight",
]
for param_name in param_names:
if param_name not in state_dict:
return False
return True
def is_stable_diffusion_3_t5(self, state_dict):
param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
return param_name in state_dict
def is_kolors_text_encoder(self, file_path):
file_list = os.listdir(file_path)
if "config.json" in file_list:
try:
with open(os.path.join(file_path, "config.json"), "r") as f:
config = json.load(f)
if config.get("model_type") == "chatglm":
return True
except Exception:
pass
return False
def is_kolors_unet(self, state_dict):
return "up_blocks.2.resnets.2.time_emb_proj.weight" in state_dict and "encoder_hid_proj.weight" in state_dict
def load_stable_video_diffusion(self, state_dict, components=None, file_path="", add_positional_conv=None):
component_dict = {
"image_encoder": SVDImageEncoder,
"unet": SVDUNet,
"vae_decoder": SVDVAEDecoder,
"vae_encoder": SVDVAEEncoder,
}
if components is None:
components = ["image_encoder", "unet", "vae_decoder", "vae_encoder"]
for component in components:
if component == "unet":
self.model[component] = component_dict[component](add_positional_conv=add_positional_conv)
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict, add_positional_conv=add_positional_conv), strict=False)
else:
self.model[component] = component_dict[component]()
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
self.model[component].to(self.torch_dtype).to(self.device)
self.model_path[component] = file_path
def load_stable_diffusion(self, state_dict, components=None, file_path=""):
component_dict = {
"text_encoder": SDTextEncoder,
"unet": SDUNet,
"vae_decoder": SDVAEDecoder,
"vae_encoder": SDVAEEncoder,
"refiner": SDXLUNet,
}
if components is None:
components = ["text_encoder", "unet", "vae_decoder", "vae_encoder"]
for component in components:
if component == "text_encoder":
# Add additional token embeddings to text encoder
token_embeddings = [state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"]]
for keyword in self.textual_inversion_dict:
_, embeddings = self.textual_inversion_dict[keyword]
token_embeddings.append(embeddings.to(dtype=token_embeddings[0].dtype))
token_embeddings = torch.concat(token_embeddings, dim=0)
state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"] = token_embeddings
self.model[component] = component_dict[component](vocab_size=token_embeddings.shape[0])
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
self.model[component].to(self.torch_dtype).to(self.device)
else:
self.model[component] = component_dict[component]()
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
self.model[component].to(self.torch_dtype).to(self.device)
self.model_path[component] = file_path
def load_stable_diffusion_xl(self, state_dict, components=None, file_path=""):
component_dict = {
"text_encoder": SDXLTextEncoder,
"text_encoder_2": SDXLTextEncoder2,
"unet": SDXLUNet,
"vae_decoder": SDXLVAEDecoder,
"vae_encoder": SDXLVAEEncoder,
}
if components is None:
components = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder"]
for component in components:
self.model[component] = component_dict[component]()
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
if component in ["vae_decoder", "vae_encoder"]:
# These two models output NaN when float16 is enabled.
# The precision problem occurs in the last three ResNet blocks;
# as a workaround, they are kept in float32.
self.model[component].to(torch.float32).to(self.device)
else:
self.model[component].to(self.torch_dtype).to(self.device)
self.model_path[component] = file_path
def load_controlnet(self, state_dict, file_path=""):
component = "controlnet"
if component not in self.model:
self.model[component] = []
self.model_path[component] = []
model = SDControlNet()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component].append(model)
self.model_path[component].append(file_path)
def load_animatediff(self, state_dict, file_path=""):
component = "motion_modules"
model = SDMotionModel()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_animatediff_xl(self, state_dict, file_path=""):
component = "motion_modules_xl"
model = SDXLMotionModel()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_beautiful_prompt(self, state_dict, file_path=""):
component = "beautiful_prompt"
from transformers import AutoModelForCausalLM
model_folder = os.path.dirname(file_path)
model = AutoModelForCausalLM.from_pretrained(
model_folder, state_dict=state_dict, local_files_only=True, torch_dtype=self.torch_dtype
).to(self.device).eval()
self.model[component] = model
self.model_path[component] = file_path
def load_RIFE(self, state_dict, file_path=""):
component = "RIFE"
from ..extensions.RIFE import IFNet
model = IFNet().eval()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(torch.float32).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_sd_lora(self, state_dict, alpha):
SDLoRA().add_lora_to_text_encoder(self.model["text_encoder"], state_dict, alpha=alpha, device=self.device)
SDLoRA().add_lora_to_unet(self.model["unet"], state_dict, alpha=alpha, device=self.device)
def load_translator(self, state_dict, file_path=""):
# This model is lightweight, so we do not place it on the GPU.
component = "translator"
from transformers import AutoModelForSeq2SeqLM
model_folder = os.path.dirname(file_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_folder).eval()
self.model[component] = model
self.model_path[component] = file_path
def load_ipadapter(self, state_dict, file_path=""):
component = "ipadapter"
model = SDIpAdapter()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_ipadapter_image_encoder(self, state_dict, file_path=""):
component = "ipadapter_image_encoder"
model = IpAdapterCLIPImageEmbedder()
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_ipadapter_xl(self, state_dict, file_path=""):
component = "ipadapter_xl"
model = SDXLIpAdapter()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_ipadapter_xl_image_encoder(self, state_dict, file_path=""):
component = "ipadapter_xl_image_encoder"
model = IpAdapterXLCLIPImageEmbedder()
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_hunyuan_dit_clip_text_encoder(self, state_dict, file_path=""):
component = "hunyuan_dit_clip_text_encoder"
model = HunyuanDiTCLIPTextEncoder()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_hunyuan_dit_t5_text_encoder(self, state_dict, file_path=""):
component = "hunyuan_dit_t5_text_encoder"
model = HunyuanDiTT5TextEncoder()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_hunyuan_dit(self, state_dict, file_path=""):
component = "hunyuan_dit"
model = HunyuanDiT()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_diffusers_vae(self, state_dict, file_path=""):
# TODO: detect SD and SDXL
component = "vae_encoder"
model = SDXLVAEEncoder()
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
model.to(torch.float32).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
component = "vae_decoder"
model = SDXLVAEDecoder()
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
model.to(torch.float32).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_ExVideo_StableVideoDiffusion(self, state_dict, file_path=""):
unet_state_dict = self.model["unet"].state_dict()
self.model["unet"].to("cpu")
del self.model["unet"]
add_positional_conv = state_dict["blocks.185.positional_embedding.embeddings"].shape[0]
self.model["unet"] = SVDUNet(add_positional_conv=add_positional_conv)
self.model["unet"].load_state_dict(unet_state_dict, strict=False)
self.model["unet"].load_state_dict(state_dict, strict=False)
self.model["unet"].to(self.torch_dtype).to(self.device)
def load_stable_diffusion_3(self, state_dict, components=None, file_path=""):
component_dict = {
"sd3_text_encoder_1": SD3TextEncoder1,
"sd3_text_encoder_2": SD3TextEncoder2,
"sd3_text_encoder_3": SD3TextEncoder3,
"sd3_dit": SD3DiT,
"sd3_vae_decoder": SD3VAEDecoder,
"sd3_vae_encoder": SD3VAEEncoder,
}
if components is None:
components = ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_decoder", "sd3_vae_encoder"]
for component in components:
if component == "sd3_text_encoder_3":
if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight" not in state_dict:
continue
if component == "sd3_text_encoder_1":
# Add additional token embeddings to text encoder
token_embeddings = [state_dict["text_encoders.clip_l.transformer.text_model.embeddings.token_embedding.weight"]]
for keyword in self.textual_inversion_dict:
_, embeddings = self.textual_inversion_dict[keyword]
token_embeddings.append(embeddings.to(dtype=token_embeddings[0].dtype))
token_embeddings = torch.concat(token_embeddings, dim=0)
state_dict["text_encoders.clip_l.transformer.text_model.embeddings.token_embedding.weight"] = token_embeddings
self.model[component] = component_dict[component](vocab_size=token_embeddings.shape[0])
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
self.model[component].to(self.torch_dtype).to(self.device)
else:
self.model[component] = component_dict[component]()
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
self.model[component].to(self.torch_dtype).to(self.device)
self.model_path[component] = file_path
def load_stable_diffusion_3_t5(self, state_dict, file_path=""):
component = "sd3_text_encoder_3"
model = SD3TextEncoder3()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_kolors_text_encoder(self, state_dict=None, file_path=""):
component = "kolors_text_encoder"
model = ChatGLMModel.from_pretrained(file_path, torch_dtype=self.torch_dtype)
model = model.to(dtype=self.torch_dtype, device=self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_kolors_unet(self, state_dict, file_path=""):
component = "kolors_unet"
model = SDXLUNet(is_kolors=True)
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def search_for_embeddings(self, state_dict):
embeddings = []
for k in state_dict:
if isinstance(state_dict[k], torch.Tensor):
embeddings.append(state_dict[k])
elif isinstance(state_dict[k], dict):
embeddings += self.search_for_embeddings(state_dict[k])
return embeddings
def load_textual_inversions(self, folder):
# Store additional tokens here
self.textual_inversion_dict = {}
# Load every textual inversion file
for file_name in os.listdir(folder):
if os.path.isdir(os.path.join(folder, file_name)) or \
not (file_name.endswith(".bin") or \
file_name.endswith(".safetensors") or \
file_name.endswith(".pth") or \
file_name.endswith(".pt")):
continue
keyword = os.path.splitext(file_name)[0]
state_dict = load_state_dict(os.path.join(folder, file_name))
# Search for embeddings
for embeddings in self.search_for_embeddings(state_dict):
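# Note: 768 is the hidden size of the SD 1.x CLIP text encoder; embeddings of other widths are skipped.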
if len(embeddings.shape) == 2 and embeddings.shape[1] == 768:
tokens = [f"{keyword}_{i}" for i in range(embeddings.shape[0])]
self.textual_inversion_dict[keyword] = (tokens, embeddings)
break
def load_model(self, file_path, components=None, lora_alphas=[]):
if os.path.isdir(file_path):
if self.is_kolors_text_encoder(file_path):
self.load_kolors_text_encoder(file_path=file_path)
return
state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype)
if self.is_stable_video_diffusion(state_dict):
self.load_stable_video_diffusion(state_dict, file_path=file_path)
elif self.is_animatediff(state_dict):
self.load_animatediff(state_dict, file_path=file_path)
elif self.is_animatediff_xl(state_dict):
self.load_animatediff_xl(state_dict, file_path=file_path)
elif self.is_controlnet(state_dict):
self.load_controlnet(state_dict, file_path=file_path)
elif self.is_stable_diffusion_xl(state_dict):
self.load_stable_diffusion_xl(state_dict, components=components, file_path=file_path)
elif self.is_stable_diffusion(state_dict):
self.load_stable_diffusion(state_dict, components=components, file_path=file_path)
elif self.is_sd_lora(state_dict):
self.load_sd_lora(state_dict, alpha=lora_alphas.pop(0))
elif self.is_beautiful_prompt(state_dict):
self.load_beautiful_prompt(state_dict, file_path=file_path)
elif self.is_RIFE(state_dict):
self.load_RIFE(state_dict, file_path=file_path)
elif self.is_translator(state_dict):
self.load_translator(state_dict, file_path=file_path)
elif self.is_ipadapter(state_dict):
self.load_ipadapter(state_dict, file_path=file_path)
elif self.is_ipadapter_image_encoder(state_dict):
self.load_ipadapter_image_encoder(state_dict, file_path=file_path)
elif self.is_ipadapter_xl(state_dict):
self.load_ipadapter_xl(state_dict, file_path=file_path)
elif self.is_ipadapter_xl_image_encoder(state_dict):
self.load_ipadapter_xl_image_encoder(state_dict, file_path=file_path)
elif self.is_hunyuan_dit_clip_text_encoder(state_dict):
self.load_hunyuan_dit_clip_text_encoder(state_dict, file_path=file_path)
elif self.is_hunyuan_dit_t5_text_encoder(state_dict):
self.load_hunyuan_dit_t5_text_encoder(state_dict, file_path=file_path)
elif self.is_hunyuan_dit(state_dict):
self.load_hunyuan_dit(state_dict, file_path=file_path)
elif self.is_diffusers_vae(state_dict):
self.load_diffusers_vae(state_dict, file_path=file_path)
elif self.is_ExVideo_StableVideoDiffusion(state_dict):
self.load_ExVideo_StableVideoDiffusion(state_dict, file_path=file_path)
elif self.is_stable_diffusion_3(state_dict):
self.load_stable_diffusion_3(state_dict, components=components, file_path=file_path)
elif self.is_stable_diffusion_3_t5(state_dict):
self.load_stable_diffusion_3_t5(state_dict, file_path=file_path)
elif self.is_kolors_unet(state_dict):
self.load_kolors_unet(state_dict, file_path=file_path)
def load_models(self, file_path_list, lora_alphas=[]):
for file_path in file_path_list:
self.load_model(file_path, lora_alphas=lora_alphas)
def to(self, device):
for component in self.model:
if isinstance(self.model[component], list):
for model in self.model[component]:
model.to(device)
else:
self.model[component].to(device)
torch.cuda.empty_cache()
def get_model_with_model_path(self, model_path):
for component in self.model_path:
if isinstance(self.model_path[component], str):
if os.path.samefile(self.model_path[component], model_path):
return self.model[component]
elif isinstance(self.model_path[component], list):
for i, model_path_ in enumerate(self.model_path[component]):
if os.path.samefile(model_path_, model_path):
return self.model[component][i]
raise ValueError(f"Please load the model {model_path} before using it.")
def __getattr__(self, __name):
if __name in self.model:
return self.model[__name]
else:
return super().__getattribute__(__name)
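Thanks to `__getattr__`, every loaded component is reachable as an attribute as well as through the `model` dict. A usage sketch (the checkpoint path is a placeholder):
# Sketch: attribute access to loaded components.
manager = ModelManager(torch_dtype=torch.float16, device="cuda")
manager.load_model("models/stable_diffusion/v1-5-pruned-emaonly.safetensors")  # placeholder path
unet = manager.unet  # resolved via __getattr__
assert unet is manager.model["unet"]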
def load_state_dict(file_path, torch_dtype=None):
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
state_dict = {}
with safe_open(file_path, framework="pt", device="cpu") as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
if torch_dtype is not None:
state_dict[k] = state_dict[k].to(torch_dtype)
return state_dict
def load_state_dict_from_bin(file_path, torch_dtype=None):
state_dict = torch.load(file_path, map_location="cpu")
if torch_dtype is not None:
for i in state_dict:
if isinstance(state_dict[i], torch.Tensor):
state_dict[i] = state_dict[i].to(torch_dtype)
return state_dict
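Both loaders above normalize a checkpoint into a plain dict of CPU tensors, optionally cast to a target dtype. A minimal sketch (the path is a placeholder):
# Sketch: load a checkpoint as float16 tensors on CPU.
state_dict = load_state_dict("models/AnimateDiff/mm_sd_v15_v2.ckpt", torch_dtype=torch.float16)
print(len(state_dict), next(iter(state_dict)))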
def search_parameter(param, state_dict):
for name, param_ in state_dict.items():
if param.numel() == param_.numel():
if param.shape == param_.shape:
if torch.dist(param, param_) < 1e-6:
return name
else:
if torch.dist(param.flatten(), param_.flatten()) < 1e-6:
return name
return None
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
matched_keys = set()
with torch.no_grad():
for name in source_state_dict:
rename = search_parameter(source_state_dict[name], target_state_dict)
if rename is not None:
print(f'"{name}": "{rename}",')
matched_keys.add(rename)
elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
length = source_state_dict[name].shape[0] // 3
rename = []
for i in range(3):
rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
if None not in rename:
print(f'"{name}": {rename},')
for rename_ in rename:
matched_keys.add(rename_)
for name in target_state_dict:
if name not in matched_keys:
print("Cannot find", name, target_state_dict[name].shape)
from .model_manager import *

View File

@@ -1,15 +1,18 @@
from huggingface_hub import hf_hub_download
from modelscope import snapshot_download
import os, shutil
from typing_extensions import Literal, TypeAlias
from typing import List
from ..configs.model_config import preset_models_on_huggingface, preset_models_on_modelscope, Preset_model_id
def download_from_modelscope(model_id, origin_file_path, local_dir):
os.makedirs(local_dir, exist_ok=True)
if os.path.basename(origin_file_path) in os.listdir(local_dir):
print(f"{os.path.basename(origin_file_path)} has been already in {local_dir}.")
print(f" {os.path.basename(origin_file_path)} has been already in {local_dir}.")
return
else:
print(f"Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
print(f" Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
downloaded_file_path = os.path.join(local_dir, origin_file_path)
target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
@@ -21,8 +24,43 @@ def download_from_modelscope(model_id, origin_file_path, local_dir):
def download_from_huggingface(model_id, origin_file_path, local_dir):
os.makedirs(local_dir, exist_ok=True)
if os.path.basename(origin_file_path) in os.listdir(local_dir):
print(f"{os.path.basename(origin_file_path)} has been already in {local_dir}.")
print(f" {os.path.basename(origin_file_path)} has been already in {local_dir}.")
return
else:
print(f"Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
print(f" Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
Preset_model_website: TypeAlias = Literal[
"HuggingFace",
"ModelScope",
]
website_to_preset_models = {
"HuggingFace": preset_models_on_huggingface,
"ModelScope": preset_models_on_modelscope,
}
website_to_download_fn = {
"HuggingFace": download_from_huggingface,
"ModelScope": download_from_modelscope,
}
def download_models(
model_id_list: List[Preset_model_id] = [],
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
):
print(f"Downloading models: {model_id_list}")
downloaded_files = []
for model_id in model_id_list:
for website in downloading_priority:
if model_id in website_to_preset_models[website]:
for model_id, origin_file_path, local_dir in website_to_preset_models[website][model_id]:
# Check if the file has already been downloaded.
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
if file_to_download in downloaded_files:
continue
# Download
website_to_download_fn[website](model_id, origin_file_path, local_dir)
if os.path.basename(origin_file_path) in os.listdir(local_dir):
downloaded_files.append(file_to_download)
return downloaded_files

View File

@@ -1,5 +1,4 @@
from .attention import Attention
from .tiler import TileWorker
from einops import repeat, rearrange
import math
import torch
@@ -399,7 +398,8 @@ class HunyuanDiT(torch.nn.Module):
hidden_states, _ = hidden_states.chunk(2, dim=1)
return hidden_states
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return HunyuanDiTStateDictConverter()

View File

@@ -79,7 +79,8 @@ class HunyuanDiTCLIPTextEncoder(BertModel):
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
return prompt_emb
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return HunyuanDiTCLIPTextEncoderStateDictConverter()
@@ -131,7 +132,8 @@ class HunyuanDiTT5TextEncoder(T5EncoderModel):
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
return prompt_emb
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return HunyuanDiTT5TextEncoderStateDictConverter()

diffsynth/models/lora.py (new file, 195 lines)
View File

@@ -0,0 +1,195 @@
import torch
from .sd_unet import SDUNet
from .sdxl_unet import SDXLUNet
from .sd_text_encoder import SDTextEncoder
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from .sd3_dit import SD3DiT
from .hunyuan_dit import HunyuanDiT
class LoRAFromCivitai:
def __init__(self):
self.supported_model_classes = []
self.lora_prefix = []
self.renamed_lora_prefix = {}
self.special_keys = {}
def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "")
state_dict_ = {}
for key in state_dict:
if ".lora_up" not in key:
continue
if not key.startswith(lora_prefix):
continue
weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
if len(weight_up.shape) == 4:
weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
lora_weight = alpha * torch.mm(weight_up, weight_down)
target_name = key.split(".")[0].replace(lora_prefix, renamed_lora_prefix).replace("_", ".") + ".weight"
for special_key in self.special_keys:
target_name = target_name.replace(special_key, self.special_keys[special_key])
state_dict_[target_name] = lora_weight.cpu()
return state_dict_
def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None):
state_dict_model = model.state_dict()
state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=alpha)
if model_resource == "diffusers":
state_dict_lora = model.__class__.state_dict_converter().from_diffusers(state_dict_lora)
elif model_resource == "civitai":
state_dict_lora = model.__class__.state_dict_converter().from_civitai(state_dict_lora)
if len(state_dict_lora) > 0:
print(f" {len(state_dict_lora)} tensors are updated.")
for name in state_dict_lora:
state_dict_model[name] += state_dict_lora[name].to(
dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
model.load_state_dict(state_dict_model)
def match(self, model, state_dict_lora):
for lora_prefix, model_class in zip(self.lora_prefix, self.supported_model_classes):
if not isinstance(model, model_class):
continue
state_dict_model = model.state_dict()
for model_resource in ["diffusers", "civitai"]:
try:
state_dict_lora_ = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=1.0)
converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == "diffusers" \
else model.__class__.state_dict_converter().from_civitai
state_dict_lora_ = converter_fn(state_dict_lora_)
if len(state_dict_lora_) == 0:
continue
for name in state_dict_lora_:
if name not in state_dict_model:
break
else:
return lora_prefix, model_resource
except Exception:
pass
return None
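The merge in `convert_state_dict` above is the standard LoRA composition, W' = W + alpha * (up @ down), applied offline to the base weights. A toy shape check with fabricated rank-4 factors:
# Sketch: LoRA merge on fabricated factors.
alpha = 0.8
up, down = torch.randn(8, 4), torch.randn(4, 8)
base = torch.randn(8, 8)
merged = base + alpha * torch.mm(up, down)
assert merged.shape == base.shape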
class SDLoRAFromCivitai(LoRAFromCivitai):
def __init__(self):
super().__init__()
self.supported_model_classes = [SDUNet, SDTextEncoder]
self.lora_prefix = ["lora_unet_", "lora_te_"]
self.special_keys = {
"down.blocks": "down_blocks",
"up.blocks": "up_blocks",
"mid.block": "mid_block",
"proj.in": "proj_in",
"proj.out": "proj_out",
"transformer.blocks": "transformer_blocks",
"to.q": "to_q",
"to.k": "to_k",
"to.v": "to_v",
"to.out": "to_out",
"text.model": "text_model",
"self.attn.q.proj": "self_attn.q_proj",
"self.attn.k.proj": "self_attn.k_proj",
"self.attn.v.proj": "self_attn.v_proj",
"self.attn.out.proj": "self_attn.out_proj",
"input.blocks": "model.diffusion_model.input_blocks",
"middle.block": "model.diffusion_model.middle_block",
"output.blocks": "model.diffusion_model.output_blocks",
}
class SDXLLoRAFromCivitai(LoRAFromCivitai):
def __init__(self):
super().__init__()
self.supported_model_classes = [SDXLUNet, SDXLTextEncoder, SDXLTextEncoder2]
self.lora_prefix = ["lora_unet_", "lora_te1_", "lora_te2_"]
self.renamed_lora_prefix = {"lora_te2_": "2"}
self.special_keys = {
"down.blocks": "down_blocks",
"up.blocks": "up_blocks",
"mid.block": "mid_block",
"proj.in": "proj_in",
"proj.out": "proj_out",
"transformer.blocks": "transformer_blocks",
"to.q": "to_q",
"to.k": "to_k",
"to.v": "to_v",
"to.out": "to_out",
"text.model": "conditioner.embedders.0.transformer.text_model",
"self.attn.q.proj": "self_attn.q_proj",
"self.attn.k.proj": "self_attn.k_proj",
"self.attn.v.proj": "self_attn.v_proj",
"self.attn.out.proj": "self_attn.out_proj",
"input.blocks": "model.diffusion_model.input_blocks",
"middle.block": "model.diffusion_model.middle_block",
"output.blocks": "model.diffusion_model.output_blocks",
"2conditioner.embedders.0.transformer.text_model.encoder.layers": "text_model.encoder.layers"
}
class GeneralLoRAFromPeft:
def __init__(self):
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT]
def convert_state_dict(self, state_dict, alpha=1.0, device="cuda", torch_dtype=torch.float16):
state_dict_ = {}
for key in state_dict:
if ".lora_B." not in key:
continue
weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
if len(weight_up.shape) == 4:
weight_up = weight_up.squeeze(3).squeeze(2)
weight_down = weight_down.squeeze(3).squeeze(2)
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
lora_weight = alpha * torch.mm(weight_up, weight_down)
keys = key.split(".")
keys.pop(keys.index("lora_B") + 1)
keys.pop(keys.index("lora_B"))
target_name = ".".join(keys)
state_dict_[target_name] = lora_weight.cpu()
return state_dict_
def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
state_dict_model = model.state_dict()
for name, param in state_dict_model.items():
torch_dtype = param.dtype
device = param.device
break
state_dict_lora = self.convert_state_dict(state_dict_lora, alpha=alpha, device=device, torch_dtype=torch_dtype)
if len(state_dict_lora) > 0:
print(f" {len(state_dict_lora)} tensors are updated.")
for name in state_dict_lora:
state_dict_model[name] += state_dict_lora[name].to(
dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
model.load_state_dict(state_dict_model)
def match(self, model, state_dict_lora):
for model_class in self.supported_model_classes:
if not isinstance(model, model_class):
continue
state_dict_model = model.state_dict()
try:
state_dict_lora_ = self.convert_state_dict(state_dict_lora, alpha=1.0)
if len(state_dict_lora_) == 0:
continue
for name in state_dict_lora_:
if name not in state_dict_model:
break
else:
return "", ""
except Exception:
pass
return None
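`GeneralLoRAFromPeft.convert_state_dict` maps a PEFT-style key to its base-model key by deleting the `lora_B` segment together with the adapter name that follows it. A sketch, assuming PEFT's default adapter name "default":
# Sketch: "...to_q.lora_B.default.weight" -> "...to_q.weight"
keys = "down_blocks.0.attn.to_q.lora_B.default.weight".split(".")
keys.pop(keys.index("lora_B") + 1)  # drop the adapter name ("default")
keys.pop(keys.index("lora_B"))      # drop "lora_B" itself
print(".".join(keys))  # down_blocks.0.attn.to_q.weight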

View File

@@ -0,0 +1,536 @@
import os, torch, hashlib, json, importlib
from safetensors import safe_open
from torch import Tensor
from typing_extensions import Literal, TypeAlias
from typing import List
from .downloader import download_models, Preset_model_id, Preset_model_website
from .sd_text_encoder import SDTextEncoder
from .sd_unet import SDUNet
from .sd_vae_encoder import SDVAEEncoder
from .sd_vae_decoder import SDVAEDecoder
from .lora import SDLoRAFromCivitai, SDXLLoRAFromCivitai, GeneralLoRAFromPeft
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from .sdxl_unet import SDXLUNet
from .sdxl_vae_decoder import SDXLVAEDecoder
from .sdxl_vae_encoder import SDXLVAEEncoder
from .sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
from .sd3_dit import SD3DiT
from .sd3_vae_decoder import SD3VAEDecoder
from .sd3_vae_encoder import SD3VAEEncoder
from .sd_controlnet import SDControlNet
from .sd_motion import SDMotionModel
from .sdxl_motion import SDXLMotionModel
from .svd_image_encoder import SVDImageEncoder
from .svd_unet import SVDUNet
from .svd_vae_decoder import SVDVAEDecoder
from .svd_vae_encoder import SVDVAEEncoder
from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from .hunyuan_dit import HunyuanDiT
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
def load_state_dict(file_path, torch_dtype=None):
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
state_dict = {}
with safe_open(file_path, framework="pt", device="cpu") as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
if torch_dtype is not None:
state_dict[k] = state_dict[k].to(torch_dtype)
return state_dict
def load_state_dict_from_bin(file_path, torch_dtype=None):
state_dict = torch.load(file_path, map_location="cpu")
if torch_dtype is not None:
for i in state_dict:
if isinstance(state_dict[i], torch.Tensor):
state_dict[i] = state_dict[i].to(torch_dtype)
return state_dict
def search_for_embeddings(state_dict):
embeddings = []
for k in state_dict:
if isinstance(state_dict[k], torch.Tensor):
embeddings.append(state_dict[k])
elif isinstance(state_dict[k], dict):
embeddings += search_for_embeddings(state_dict[k])
return embeddings
def search_parameter(param, state_dict):
for name, param_ in state_dict.items():
if param.numel() == param_.numel():
if param.shape == param_.shape:
if torch.dist(param, param_) < 1e-6:
return name
else:
if torch.dist(param.flatten(), param_.flatten()) < 1e-6:
return name
return None
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
matched_keys = set()
with torch.no_grad():
for name in source_state_dict:
rename = search_parameter(source_state_dict[name], target_state_dict)
if rename is not None:
print(f'"{name}": "{rename}",')
matched_keys.add(rename)
elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
length = source_state_dict[name].shape[0] // 3
rename = []
for i in range(3):
rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
if None not in rename:
print(f'"{name}": {rename},')
for rename_ in rename:
matched_keys.add(rename_)
for name in target_state_dict:
if name not in matched_keys:
print("Cannot find", name, target_state_dict[name].shape)
def search_for_files(folder, extensions):
files = []
if os.path.isdir(folder):
for file in sorted(os.listdir(folder)):
files += search_for_files(os.path.join(folder, file), extensions)
elif os.path.isfile(folder):
for extension in extensions:
if folder.endswith(extension):
files.append(folder)
break
return files
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
keys = []
for key, value in state_dict.items():
if isinstance(key, str):
if isinstance(value, Tensor):
if with_shape:
shape = "_".join(map(str, list(value.shape)))
keys.append(key + ":" + shape)
keys.append(key)
elif isinstance(value, dict):
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
keys.sort()
keys_str = ",".join(keys)
return keys_str
def split_state_dict_with_prefix(state_dict):
keys = sorted([key for key in state_dict if isinstance(key, str)])
prefix_dict = {}
for key in keys:
prefix = key if "." not in key else key.split(".")[0]
if prefix not in prefix_dict:
prefix_dict[prefix] = []
prefix_dict[prefix].append(key)
state_dicts = []
for prefix, keys in prefix_dict.items():
sub_state_dict = {key: state_dict[key] for key in keys}
state_dicts.append(sub_state_dict)
return state_dicts
def hash_state_dict_keys(state_dict, with_shape=True):
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
keys_str = keys_str.encode(encoding="UTF-8")
return hashlib.md5(keys_str).hexdigest()
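Model detection below is hash-based: the sorted key-and-shape string is MD5-hashed and looked up in the loader configs, so two checkpoints with identical keys and shapes count as the same architecture regardless of their values. A toy sketch:
# Sketch: only keys and shapes are hashed; values are ignored.
a = {"w": torch.zeros(2, 3)}
b = {"w": torch.ones(2, 3)}
assert hash_state_dict_keys(a) == hash_state_dict_keys(b)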
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
loaded_model_names, loaded_models = [], []
for model_name, model_class in zip(model_names, model_classes):
print(f" model_name: {model_name} model_class: {model_class.__name__}")
state_dict_converter = model_class.state_dict_converter()
if model_resource == "civitai":
state_dict_results = state_dict_converter.from_civitai(state_dict)
elif model_resource == "diffusers":
state_dict_results = state_dict_converter.from_diffusers(state_dict)
if isinstance(state_dict_results, tuple):
model_state_dict, extra_kwargs = state_dict_results
print(f" This model is initialized with extra kwargs: {extra_kwargs}")
else:
model_state_dict, extra_kwargs = state_dict_results, {}
torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
model = model_class(**extra_kwargs).to(dtype=torch_dtype, device=device)
model.load_state_dict(model_state_dict)
loaded_model_names.append(model_name)
loaded_models.append(model)
return loaded_model_names, loaded_models
def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
loaded_model_names, loaded_models = [], []
for model_name, model_class in zip(model_names, model_classes):
model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
if torch_dtype == torch.float16 and hasattr(model, "half"):
model = model.half()
model = model.to(device=device)
loaded_model_names.append(model_name)
loaded_models.append(model)
return loaded_model_names, loaded_models
def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
base_state_dict = base_model.state_dict()
base_model.to("cpu")
del base_model
model = model_class(**extra_kwargs)
model.load_state_dict(base_state_dict, strict=False)
model.load_state_dict(state_dict, strict=False)
model.to(dtype=torch_dtype, device=device)
return model
def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
loaded_model_names, loaded_models = [], []
for model_name, model_class in zip(model_names, model_classes):
while True:
for model_id in range(len(model_manager.model)):
base_model_name = model_manager.model_name[model_id]
if base_model_name == model_name:
base_model_path = model_manager.model_path[model_id]
base_model = model_manager.model[model_id]
print(f" Adding patch model to {base_model_name} ({base_model_path})")
patched_model = load_single_patch_model_from_single_file(
state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
loaded_model_names.append(base_model_name)
loaded_models.append(patched_model)
model_manager.model.pop(model_id)
model_manager.model_path.pop(model_id)
model_manager.model_name.pop(model_id)
break
else:
break
return loaded_model_names, loaded_models
class ModelDetectorTemplate:
def __init__(self):
pass
def match(self, file_path="", state_dict={}):
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
return [], []
class ModelDetectorFromSingleFile:
def __init__(self, model_loader_configs=[]):
self.keys_hash_with_shape_dict = {}
self.keys_hash_dict = {}
for metadata in model_loader_configs:
self.add_model_metadata(*metadata)
def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
if keys_hash is not None:
self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)
def match(self, file_path="", state_dict={}):
if os.path.isdir(file_path):
return False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
return True
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
if keys_hash in self.keys_hash_dict:
return True
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
# Load models with strict matching
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
return loaded_model_names, loaded_models
# Load models without strict matching
# (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
if keys_hash in self.keys_hash_dict:
model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
return loaded_model_names, loaded_models
# No hash matched; nothing to load.
return [], []
class ModelDetectorFromSplitSingleFile(ModelDetectorFromSingleFile):
def __init__(self, model_loader_configs=[]):
super().__init__(model_loader_configs)
def match(self, file_path="", state_dict={}):
if os.path.isdir(file_path):
return False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
split_state_dicts = split_state_dict_with_prefix(state_dict)
for sub_state_dict in split_state_dicts:
if super().match(file_path, sub_state_dict):
return True
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
# Split the state_dict and load from each component
split_state_dicts = split_state_dict_with_prefix(state_dict)
valid_state_dict = {}
for sub_state_dict in split_state_dicts:
if super().match(file_path, sub_state_dict):
valid_state_dict.update(sub_state_dict)
if super().match(file_path, valid_state_dict):
loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
else:
loaded_model_names, loaded_models = [], []
for sub_state_dict in split_state_dicts:
if super().match(file_path, sub_state_dict):
loaded_model_names_, loaded_models_ = super().load(file_path, sub_state_dict, device, torch_dtype)
loaded_model_names += loaded_model_names_
loaded_models += loaded_models_
return loaded_model_names, loaded_models
class ModelDetectorFromHuggingfaceFolder:
def __init__(self, model_loader_configs=[]):
self.architecture_dict = {}
for metadata in model_loader_configs:
self.add_model_metadata(*metadata)
def add_model_metadata(self, architecture, huggingface_lib, model_name):
self.architecture_dict[architecture] = (huggingface_lib, model_name)
def match(self, file_path="", state_dict={}):
if os.path.isfile(file_path):
return False
file_list = os.listdir(file_path)
if "config.json" not in file_list:
return False
with open(os.path.join(file_path, "config.json"), "r") as f:
config = json.load(f)
if "architectures" not in config:
return False
return True
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
with open(os.path.join(file_path, "config.json"), "r") as f:
config = json.load(f)
loaded_model_names, loaded_models = [], []
for architecture in config["architectures"]:
huggingface_lib, model_name = self.architecture_dict[architecture]
            model_class = getattr(importlib.import_module(huggingface_lib), architecture)
loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
loaded_model_names += loaded_model_names_
loaded_models += loaded_models_
return loaded_model_names, loaded_models
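# The folder detector keys off the "architectures" field of a standard Hugging
# Face config.json, e.g. {"architectures": ["CLIPTextModel"], ...}; each name is
# resolved to a class inside the registered library via a dynamic import. An
# illustrative resolution (the library/name pair below is an example, not this
# repo's actual registration table):
import importlib

huggingface_lib, model_name = "transformers", "text_encoder"
model_class = getattr(importlib.import_module(huggingface_lib), "CLIPTextModel")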
class ModelDetectorFromPatchedSingleFile:
def __init__(self, model_loader_configs=[]):
self.keys_hash_with_shape_dict = {}
for metadata in model_loader_configs:
self.add_model_metadata(*metadata)
def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)
def match(self, file_path="", state_dict={}):
if os.path.isdir(file_path):
return False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
return True
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
# Load models with strict matching
loaded_model_names, loaded_models = [], []
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
loaded_model_names += loaded_model_names_
loaded_models += loaded_models_
return loaded_model_names, loaded_models
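# Note: "patch" checkpoints (e.g. ExVideo weights that extend an SVD UNet) are
# not standalone models, which is why load() takes the model_manager: the loader
# can fetch the already-loaded base model and rebuild it with the patch applied.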
class ModelManager:
def __init__(
self,
torch_dtype=torch.float16,
device="cuda",
model_id_list: List[Preset_model_id] = [],
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
file_path_list: List[str] = [],
):
self.torch_dtype = torch_dtype
self.device = device
self.model = []
self.model_path = []
self.model_name = []
downloaded_files = download_models(model_id_list, downloading_priority) if len(model_id_list) > 0 else []
self.model_detector = [
ModelDetectorFromSingleFile(model_loader_configs),
ModelDetectorFromSplitedSingleFile(model_loader_configs),
ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
]
self.load_models(downloaded_files + file_path_list)
def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
print(f"Loading models from file: {file_path}")
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
print(f" The following models are loaded: {model_names}.")
def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
print(f"Loading models from folder: {file_path}")
model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
print(f" The following models are loaded: {model_names}.")
def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
print(f"Loading patch models from file: {file_path}")
model_names, models = load_patch_model_from_single_file(
state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
print(f" The following patched models are loaded: {model_names}.")
def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
print(f"Loading LoRA models from file: {file_path}")
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
for lora in [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), GeneralLoRAFromPeft()]:
match_results = lora.match(model, state_dict)
if match_results is not None:
print(f" Adding LoRA to {model_name} ({model_path}).")
lora_prefix, model_resource = match_results
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
break
def load_model(self, file_path, model_names=None):
print(f"Loading models from: {file_path}")
if os.path.isfile(file_path):
state_dict = load_state_dict(file_path)
else:
state_dict = None
for model_detector in self.model_detector:
if model_detector.match(file_path, state_dict):
model_names, models = model_detector.load(
file_path, state_dict,
device=self.device, torch_dtype=self.torch_dtype,
allowed_model_names=model_names, model_manager=self
)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
print(f" The following models are loaded: {model_names}.")
break
else:
print(f" We cannot detect the model type. No models are loaded.")
def load_models(self, file_path_list, model_names=None):
for file_path in file_path_list:
self.load_model(file_path, model_names)
def fetch_model(self, model_name, file_path=None, require_model_path=False):
fetched_models = []
fetched_model_paths = []
for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
if file_path is not None and file_path != model_path:
continue
if model_name == model_name_:
fetched_models.append(model)
fetched_model_paths.append(model_path)
if len(fetched_models) == 0:
print(f"No {model_name} models available.")
return None
if len(fetched_models) == 1:
print(f"Using {model_name} from {fetched_model_paths[0]}.")
else:
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
if require_model_path:
return fetched_models[0], fetched_model_paths[0]
else:
return fetched_models[0]
def to(self, device):
for model in self.model:
model.to(device)
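# End-to-end usage sketch. The checkpoint path and the fetch_model key
# ("sd_unet") are illustrative assumptions; real keys are whatever names the
# loader configs register.
model_manager = ModelManager(
    torch_dtype=torch.float16,
    device="cuda",
    file_path_list=["models/stable_diffusion/v1-5-pruned-emaonly.safetensors"],
)
unet = model_manager.fetch_model("sd_unet")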

View File

@@ -228,7 +228,8 @@ class SD3DiT(torch.nn.Module):
         hidden_states = rearrange(hidden_states, "B (H W) (P Q C) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
         return hidden_states
 
-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SD3DiTStateDictConverter()
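# The change above recurs in nearly every file below: state_dict_converter
# becomes a @staticmethod, so a converter is reachable from the class itself
# without first constructing the (large) module:
converter = SD3DiT.state_dict_converter()  # no SD3DiT() instantiation needed
# state_dict = converter.from_diffusers(raw_state_dict)  # typical next step; see the converters below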

View File

@@ -19,7 +19,8 @@ class SD3TextEncoder1(SDTextEncoder):
         pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
         return pooled_embeds, hidden_states
 
-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SD3TextEncoder1StateDictConverter()
 
@@ -28,7 +29,8 @@ class SD3TextEncoder2(SDXLTextEncoder2):
     def __init__(self):
         super().__init__()
 
-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SD3TextEncoder2StateDictConverter()
 
@@ -72,7 +74,8 @@ class SD3TextEncoder3(T5EncoderModel):
         prompt_emb = outputs.last_hidden_state
         return prompt_emb
 
-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SD3TextEncoder3StateDictConverter()

View File

@@ -76,5 +76,6 @@ class SD3VAEDecoder(torch.nn.Module):
         return hidden_states
 
-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SDVAEDecoderStateDictConverter()

View File

@@ -90,5 +90,6 @@ class SD3VAEEncoder(torch.nn.Module):
         hidden_states = torch.concat(hidden_states, dim=2)
         return hidden_states
 
-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SDVAEEncoderStateDictConverter()

View File

@@ -99,7 +99,7 @@ class SDControlNet(torch.nn.Module):
         tiled=False, tile_size=64, tile_stride=32,
     ):
         # 1. time
-        time_emb = self.time_proj(timestep[None]).to(sample.dtype)
+        time_emb = self.time_proj(timestep).to(sample.dtype)
         time_emb = self.time_embedding(time_emb)
         time_emb = time_emb.repeat(sample.shape[0], 1)
@@ -134,7 +134,8 @@ class SDControlNet(torch.nn.Module):
         return controlnet_res_stack
 
-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SDControlNetStateDictConverter()
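# Dropping the [None] above (and in SDUNet/SDXLUNet below) changes the calling
# convention: timestep is now expected to arrive as a tensor that already has a
# batch axis, e.g. torch.tensor((981,)), instead of a 0-d value the forward pass
# used to wrap itself. (The example value is illustrative.)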

View File

@@ -47,7 +47,8 @@ class SDIpAdapter(torch.nn.Module):
         }
         return ip_kv_dict
 
-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SDIpAdapterStateDictConverter()

View File

@@ -1,60 +0,0 @@
-import torch
-from .sd_unet import SDUNetStateDictConverter, SDUNet
-from .sd_text_encoder import SDTextEncoderStateDictConverter, SDTextEncoder
-
-
-class SDLoRA:
-    def __init__(self):
-        pass
-
-    def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0, device="cuda"):
-        special_keys = {
-            "down.blocks": "down_blocks",
-            "up.blocks": "up_blocks",
-            "mid.block": "mid_block",
-            "proj.in": "proj_in",
-            "proj.out": "proj_out",
-            "transformer.blocks": "transformer_blocks",
-            "to.q": "to_q",
-            "to.k": "to_k",
-            "to.v": "to_v",
-            "to.out": "to_out",
-        }
-        state_dict_ = {}
-        for key in state_dict:
-            if ".lora_up" not in key:
-                continue
-            if not key.startswith(lora_prefix):
-                continue
-            weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
-            weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
-            if len(weight_up.shape) == 4:
-                weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
-                weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
-                lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
-            else:
-                lora_weight = alpha * torch.mm(weight_up, weight_down)
-            target_name = key.split(".")[0].replace("_", ".")[len(lora_prefix):] + ".weight"
-            for special_key in special_keys:
-                target_name = target_name.replace(special_key, special_keys[special_key])
-            state_dict_[target_name] = lora_weight.cpu()
-        return state_dict_
-
-    def add_lora_to_unet(self, unet: SDUNet, state_dict_lora, alpha=1.0, device="cuda"):
-        state_dict_unet = unet.state_dict()
-        state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_unet_", alpha=alpha, device=device)
-        state_dict_lora = SDUNetStateDictConverter().from_diffusers(state_dict_lora)
-        if len(state_dict_lora) > 0:
-            for name in state_dict_lora:
-                state_dict_unet[name] += state_dict_lora[name].to(device=device)
-            unet.load_state_dict(state_dict_unet)
-
-    def add_lora_to_text_encoder(self, text_encoder: SDTextEncoder, state_dict_lora, alpha=1.0, device="cuda"):
-        state_dict_text_encoder = text_encoder.state_dict()
-        state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_te_", alpha=alpha, device=device)
-        state_dict_lora = SDTextEncoderStateDictConverter().from_diffusers(state_dict_lora)
-        if len(state_dict_lora) > 0:
-            for name in state_dict_lora:
-                state_dict_text_encoder[name] += state_dict_lora[name].to(device=device)
-            text_encoder.load_state_dict(state_dict_text_encoder)
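# The deleted SDLoRA class is superseded by the matcher-based LoRA loaders used
# in ModelManager.load_lora above. The merge rule itself is unchanged:
# W' = W + alpha * (weight_up @ weight_down). A self-contained sketch:
import torch

weight = torch.zeros(320, 320)
weight_up, weight_down = torch.randn(320, 4), torch.randn(4, 320)
merged = weight + 1.0 * (weight_up @ weight_down)  # alpha = 1.0; rank-4 update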

View File

@@ -144,7 +144,8 @@ class SDMotionModel(torch.nn.Module):
     def forward(self):
         pass
 
-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SDMotionModelStateDictConverter()

View File

@@ -71,7 +71,8 @@ class SDTextEncoder(torch.nn.Module):
         embeds = self.final_layer_norm(embeds)
         return embeds
 
-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SDTextEncoderStateDictConverter()

View File

@@ -323,7 +323,7 @@ class SDUNet(torch.nn.Module):
     def forward(self, sample, timestep, encoder_hidden_states, **kwargs):
         # 1. time
-        time_emb = self.time_proj(timestep[None]).to(sample.dtype)
+        time_emb = self.time_proj(timestep).to(sample.dtype)
         time_emb = self.time_embedding(time_emb)
 
         # 2. pre-process
@@ -342,7 +342,8 @@ class SDUNet(torch.nn.Module):
         return hidden_states
 
-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SDUNetStateDictConverter()

View File

@@ -90,6 +90,8 @@ class SDVAEDecoder(torch.nn.Module):
         return hidden_states
 
     def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
+        original_dtype = sample.dtype
+        sample = sample.to(dtype=next(iter(self.parameters())).dtype)
         # For VAE Decoder, we do not need to apply the tiler on each layer.
         if tiled:
             return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
@@ -110,10 +112,12 @@
         hidden_states = self.conv_norm_out(hidden_states)
         hidden_states = self.conv_act(hidden_states)
         hidden_states = self.conv_out(hidden_states)
+        hidden_states = hidden_states.to(original_dtype)
         return hidden_states
 
-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SDVAEDecoderStateDictConverter()
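# The new dtype round-trip lets a half-precision pipeline drive a VAE whose
# weights stay in float32 (see the upcast_to_float32 flag in the SDXL VAE files
# below): the input is cast to the parameters' dtype on entry and the output is
# cast back to the caller's dtype on exit.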

View File

@@ -50,6 +50,8 @@ class SDVAEEncoder(torch.nn.Module):
         return hidden_states
 
     def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
+        original_dtype = sample.dtype
+        sample = sample.to(dtype=next(iter(self.parameters())).dtype)
         # For VAE Encoder, we do not need to apply the tiler on each layer.
         if tiled:
             return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
@@ -71,6 +73,7 @@
         hidden_states = self.quant_conv(hidden_states)
         hidden_states = hidden_states[:, :4]
         hidden_states *= self.scaling_factor
+        hidden_states = hidden_states.to(original_dtype)
         return hidden_states
@@ -91,7 +94,8 @@
         hidden_states = torch.concat(hidden_states, dim=2)
         return hidden_states
 
-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SDVAEEncoderStateDictConverter()

View File

@@ -96,7 +96,8 @@ class SDXLIpAdapter(torch.nn.Module):
         }
         return ip_kv_dict
 
-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SDXLIpAdapterStateDictConverter()

View File

@@ -49,7 +49,8 @@ class SDXLMotionModel(torch.nn.Module):
     def forward(self):
         pass
 
-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SDMotionModelStateDictConverter()

View File

@@ -36,7 +36,8 @@ class SDXLTextEncoder(torch.nn.Module):
                 break
         return embeds
 
-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SDXLTextEncoderStateDictConverter()
 
@@ -80,7 +81,8 @@ class SDXLTextEncoder2(torch.nn.Module):
         pooled_embeds = self.text_projection(pooled_embeds)
         return pooled_embeds, hidden_states
 
-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SDXLTextEncoder2StateDictConverter()

View File

@@ -91,7 +91,7 @@ class SDXLUNet(torch.nn.Module):
         **kwargs
     ):
         # 1. time
-        t_emb = self.time_proj(timestep[None]).to(sample.dtype)
+        t_emb = self.time_proj(timestep).to(sample.dtype)
         t_emb = self.time_embedding(t_emb)
 
         time_embeds = self.add_time_proj(add_time_id)
@@ -133,7 +133,8 @@
         return hidden_states
 
-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SDXLUNetStateDictConverter()
@@ -197,7 +198,10 @@ class SDXLUNetStateDictConverter:
             if ".proj_in." in name or ".proj_out." in name:
                 param = param.squeeze()
             state_dict_[rename_dict[name]] = param
-        return state_dict_
+        if "text_intermediate_proj.weight" in state_dict_:
+            return state_dict_, {"is_kolors": True}
+        else:
+            return state_dict_
 
     def from_civitai(self, state_dict):
         rename_dict = {
@@ -1889,4 +1893,7 @@ class SDXLUNetStateDictConverter:
             if ".proj_in." in name or ".proj_out." in name:
                 param = param.squeeze()
             state_dict_[rename_dict[name]] = param
-        return state_dict_
+        if "text_intermediate_proj.weight" in state_dict_:
+            return state_dict_, {"is_kolors": True}
+        else:
+            return state_dict_
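# Converters may now return either a plain state_dict or a (state_dict,
# extra_kwargs) pair; presence of the Kolors-specific projection weight flags
# the checkpoint, and the loader presumably forwards {"is_kolors": True} into
# the SDXLUNet constructor (the exact constructor kwarg is an assumption).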

View File

@@ -2,14 +2,23 @@ from .sd_vae_decoder import SDVAEDecoder, SDVAEDecoderStateDictConverter
 
 class SDXLVAEDecoder(SDVAEDecoder):
-    def __init__(self):
+    def __init__(self, upcast_to_float32=True):
         super().__init__()
         self.scaling_factor = 0.13025
 
-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SDXLVAEDecoderStateDictConverter()
 
 
 class SDXLVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
     def __init__(self):
         super().__init__()
+
+    def from_diffusers(self, state_dict):
+        state_dict = super().from_diffusers(state_dict)
+        return state_dict, {"upcast_to_float32": True}
+
+    def from_civitai(self, state_dict):
+        state_dict = super().from_civitai(state_dict)
+        return state_dict, {"upcast_to_float32": True}

View File

@@ -2,14 +2,23 @@ from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder
 
 class SDXLVAEEncoder(SDVAEEncoder):
-    def __init__(self):
+    def __init__(self, upcast_to_float32=True):
         super().__init__()
         self.scaling_factor = 0.13025
 
-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SDXLVAEEncoderStateDictConverter()
 
 
 class SDXLVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
     def __init__(self):
         super().__init__()
+
+    def from_diffusers(self, state_dict):
+        state_dict = super().from_diffusers(state_dict)
+        return state_dict, {"upcast_to_float32": True}
+
+    def from_civitai(self, state_dict):
+        state_dict = super().from_civitai(state_dict)
+        return state_dict, {"upcast_to_float32": True}

View File

@@ -44,7 +44,8 @@ class SVDImageEncoder(torch.nn.Module):
         embeds = self.visual_projection(embeds)
         return embeds
 
-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SVDImageEncoderStateDictConverter()

View File

@@ -407,7 +407,8 @@ class SVDUNet(torch.nn.Module):
         return hidden_states
 
-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SVDUNetStateDictConverter()

View File

@@ -199,7 +199,8 @@ class SVDVAEDecoder(torch.nn.Module):
         return values
 
-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SVDVAEDecoderStateDictConverter()

View File

@@ -6,7 +6,8 @@ class SVDVAEEncoder(SDVAEEncoder):
         super().__init__()
         self.scaling_factor = 0.13025
 
-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return SVDVAEEncoderStateDictConverter()