From 9698e3988f4e5e3034fb99f2648e1e7baaff9021 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Wed, 10 Jan 2024 14:34:02 +0800 Subject: [PATCH] support lora --- diffsynth/models/__init__.py | 17 ++++++++-- diffsynth/models/sd_lora.py | 60 ++++++++++++++++++++++++++++++++++++ pages/1_Image_Creator.py | 1 + pages/2_Video_Creator.py | 14 +++++++-- 4 files changed, 87 insertions(+), 5 deletions(-) create mode 100644 diffsynth/models/sd_lora.py diff --git a/diffsynth/models/__init__.py b/diffsynth/models/__init__.py index e3521f3..35c6472 100644 --- a/diffsynth/models/__init__.py +++ b/diffsynth/models/__init__.py @@ -5,6 +5,7 @@ from .sd_text_encoder import SDTextEncoder from .sd_unet import SDUNet from .sd_vae_encoder import SDVAEEncoder from .sd_vae_decoder import SDVAEDecoder +from .sd_lora import SDLoRA from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2 from .sdxl_unet import SDXLUNet @@ -50,6 +51,10 @@ class ModelManager: param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight" return param_name in state_dict + def is_sd_lora(self, state_dict): + param_name = "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_ff_net_2.lora_up.weight" + return param_name in state_dict + def load_stable_diffusion(self, state_dict, components=None, file_path=""): component_dict = { "text_encoder": SDTextEncoder, @@ -138,6 +143,10 @@ class ModelManager: self.model[component] = model self.model_path[component] = file_path + def load_sd_lora(self, state_dict, alpha): + SDLoRA().add_lora_to_text_encoder(self.model["text_encoder"], state_dict, alpha=alpha, device=self.device) + SDLoRA().add_lora_to_unet(self.model["unet"], state_dict, alpha=alpha, device=self.device) + def search_for_embeddings(self, state_dict): embeddings = [] for k in state_dict: @@ -165,7 +174,7 @@ class ModelManager: self.textual_inversion_dict[keyword] = (tokens, embeddings) break - def load_model(self, file_path, components=None): + def load_model(self, file_path, 
components=None, lora_alphas=[]): state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype) if self.is_animatediff(state_dict): self.load_animatediff(state_dict, file_path=file_path) @@ -175,14 +184,16 @@ class ModelManager: self.load_stable_diffusion_xl(state_dict, components=components, file_path=file_path) elif self.is_stable_diffusion(state_dict): self.load_stable_diffusion(state_dict, components=components, file_path=file_path) + elif self.is_sd_lora(state_dict): + self.load_sd_lora(state_dict, alpha=lora_alphas.pop(0)) elif self.is_beautiful_prompt(state_dict): self.load_beautiful_prompt(state_dict, file_path=file_path) elif self.is_RIFE(state_dict): self.load_RIFE(state_dict, file_path=file_path) - def load_models(self, file_path_list): + def load_models(self, file_path_list, lora_alphas=[]): for file_path in file_path_list: - self.load_model(file_path) + self.load_model(file_path, lora_alphas=lora_alphas) def to(self, device): for component in self.model: diff --git a/diffsynth/models/sd_lora.py b/diffsynth/models/sd_lora.py new file mode 100644 index 0000000..3b7ecac --- /dev/null +++ b/diffsynth/models/sd_lora.py @@ -0,0 +1,60 @@ +import torch +from .sd_unet import SDUNetStateDictConverter, SDUNet +from .sd_text_encoder import SDTextEncoderStateDictConverter, SDTextEncoder + + +class SDLoRA: + def __init__(self): + pass + + def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0, device="cuda"): + special_keys = { + "down.blocks": "down_blocks", + "up.blocks": "up_blocks", + "mid.block": "mid_block", + "proj.in": "proj_in", + "proj.out": "proj_out", + "transformer.blocks": "transformer_blocks", + "to.q": "to_q", + "to.k": "to_k", + "to.v": "to_v", + "to.out": "to_out", + } + state_dict_ = {} + for key in state_dict: + if ".lora_up" not in key: + continue + if not key.startswith(lora_prefix): + continue + weight_up = state_dict[key].to(device=device, dtype=torch.float16) + weight_down = state_dict[key.replace(".lora_up", 
".lora_down")].to(device=device, dtype=torch.float16) + if len(weight_up.shape) == 4: + weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32) + weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32) + lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) + else: + lora_weight = alpha * torch.mm(weight_up, weight_down) + target_name = key.split(".")[0].replace("_", ".")[len(lora_prefix):] + ".weight" + for special_key in special_keys: + target_name = target_name.replace(special_key, special_keys[special_key]) + state_dict_[target_name] = lora_weight.cpu() + return state_dict_ + + def add_lora_to_unet(self, unet: SDUNet, state_dict_lora, alpha=1.0, device="cuda"): + state_dict_unet = unet.state_dict() + state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_unet_", alpha=alpha, device=device) + state_dict_lora = SDUNetStateDictConverter().from_diffusers(state_dict_lora) + if len(state_dict_lora) > 0: + for name in state_dict_lora: + state_dict_unet[name] += state_dict_lora[name].to(device=device) + unet.load_state_dict(state_dict_unet) + + def add_lora_to_text_encoder(self, text_encoder: SDTextEncoder, state_dict_lora, alpha=1.0, device="cuda"): + state_dict_text_encoder = text_encoder.state_dict() + state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_te_", alpha=alpha, device=device) + state_dict_lora = SDTextEncoderStateDictConverter().from_diffusers(state_dict_lora) + if len(state_dict_lora) > 0: + for name in state_dict_lora: + state_dict_text_encoder[name] += state_dict_lora[name].to(device=device) + text_encoder.load_state_dict(state_dict_text_encoder) + diff --git a/pages/1_Image_Creator.py b/pages/1_Image_Creator.py index 9314f53..8c735fa 100644 --- a/pages/1_Image_Creator.py +++ b/pages/1_Image_Creator.py @@ -2,6 +2,7 @@ import torch, os, io import numpy as np from PIL import Image import streamlit as st +st.set_page_config(layout="wide") from
streamlit_drawable_canvas import st_canvas from diffsynth.models import ModelManager from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline diff --git a/pages/2_Video_Creator.py b/pages/2_Video_Creator.py index ecb3bbc..15ebb5e 100644 --- a/pages/2_Video_Creator.py +++ b/pages/2_Video_Creator.py @@ -1,4 +1,5 @@ import streamlit as st +st.set_page_config(layout="wide") from diffsynth import ModelManager, SDVideoPipeline, ControlNetConfigUnit, VideoData, save_video, save_frames import torch, os, json import numpy as np @@ -9,11 +10,11 @@ class Runner: pass - def load_pipeline(self, model_list, textual_inversion_folder, device, controlnet_units): + def load_pipeline(self, model_list, textual_inversion_folder, device, lora_alphas, controlnet_units): # Load models model_manager = ModelManager(torch_dtype=torch.float16, device=device) model_manager.load_textual_inversions(textual_inversion_folder) - model_manager.load_models(model_list) + model_manager.load_models(model_list, lora_alphas=lora_alphas) pipe = SDVideoPipeline.from_model_manager( model_manager, [ @@ -100,6 +101,7 @@ config = { "model_list": [], "textual_inversion_folder": "models/textual_inversion", "device": "cuda", + "lora_alphas": [], "controlnet_units": [] }, "data": { @@ -122,6 +124,14 @@ with st.expander("Model", expanded=True): animatediff_ckpt = st.selectbox("AnimateDiff", ["None"] + load_model_list("models/AnimateDiff")) if animatediff_ckpt != "None": config["models"]["model_list"].append(os.path.join("models/AnimateDiff", animatediff_ckpt)) + column_lora, column_lora_alpha = st.columns([2, 1]) + with column_lora: + sd_lora_ckpt = st.selectbox("LoRA", ["None"] + load_model_list("models/lora")) + with column_lora_alpha: + lora_alpha = st.slider("LoRA Alpha", min_value=-4.0, max_value=4.0, value=1.0, step=0.1) + if sd_lora_ckpt != "None": + config["models"]["model_list"].append(os.path.join("models/lora", sd_lora_ckpt)) + config["models"]["lora_alphas"].append(lora_alpha) with 
st.expander("Data", expanded=True):