mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
support lora
This commit is contained in:
@@ -5,6 +5,7 @@ from .sd_text_encoder import SDTextEncoder
|
|||||||
from .sd_unet import SDUNet
|
from .sd_unet import SDUNet
|
||||||
from .sd_vae_encoder import SDVAEEncoder
|
from .sd_vae_encoder import SDVAEEncoder
|
||||||
from .sd_vae_decoder import SDVAEDecoder
|
from .sd_vae_decoder import SDVAEDecoder
|
||||||
|
from .sd_lora import SDLoRA
|
||||||
|
|
||||||
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
||||||
from .sdxl_unet import SDXLUNet
|
from .sdxl_unet import SDXLUNet
|
||||||
@@ -50,6 +51,10 @@ class ModelManager:
|
|||||||
param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
|
param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
|
||||||
return param_name in state_dict
|
return param_name in state_dict
|
||||||
|
|
||||||
|
def is_sd_lora(self, state_dict):
    """Return True if `state_dict` looks like a kohya-style SD LoRA checkpoint.

    Detection is done by probing for one characteristic LoRA-up weight key.
    """
    probe_key = "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_ff_net_2.lora_up.weight"
    return probe_key in state_dict
|
||||||
|
|
||||||
def load_stable_diffusion(self, state_dict, components=None, file_path=""):
|
def load_stable_diffusion(self, state_dict, components=None, file_path=""):
|
||||||
component_dict = {
|
component_dict = {
|
||||||
"text_encoder": SDTextEncoder,
|
"text_encoder": SDTextEncoder,
|
||||||
@@ -138,6 +143,10 @@ class ModelManager:
|
|||||||
self.model[component] = model
|
self.model[component] = model
|
||||||
self.model_path[component] = file_path
|
self.model_path[component] = file_path
|
||||||
|
|
||||||
|
def load_sd_lora(self, state_dict, alpha):
    """Merge a LoRA checkpoint into the already-loaded text encoder and UNet.

    `alpha` scales the LoRA delta weights before they are added in place.
    Requires "text_encoder" and "unet" to be present in self.model.
    """
    lora = SDLoRA()
    lora.add_lora_to_text_encoder(self.model["text_encoder"], state_dict, alpha=alpha, device=self.device)
    lora.add_lora_to_unet(self.model["unet"], state_dict, alpha=alpha, device=self.device)
|
||||||
|
|
||||||
def search_for_embeddings(self, state_dict):
|
def search_for_embeddings(self, state_dict):
|
||||||
embeddings = []
|
embeddings = []
|
||||||
for k in state_dict:
|
for k in state_dict:
|
||||||
@@ -165,7 +174,7 @@ class ModelManager:
|
|||||||
self.textual_inversion_dict[keyword] = (tokens, embeddings)
|
self.textual_inversion_dict[keyword] = (tokens, embeddings)
|
||||||
break
|
break
|
||||||
|
|
||||||
def load_model(self, file_path, components=None):
|
def load_model(self, file_path, components=None, lora_alphas=[]):
|
||||||
state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype)
|
state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype)
|
||||||
if self.is_animatediff(state_dict):
|
if self.is_animatediff(state_dict):
|
||||||
self.load_animatediff(state_dict, file_path=file_path)
|
self.load_animatediff(state_dict, file_path=file_path)
|
||||||
@@ -175,14 +184,16 @@ class ModelManager:
|
|||||||
self.load_stable_diffusion_xl(state_dict, components=components, file_path=file_path)
|
self.load_stable_diffusion_xl(state_dict, components=components, file_path=file_path)
|
||||||
elif self.is_stable_diffusion(state_dict):
|
elif self.is_stable_diffusion(state_dict):
|
||||||
self.load_stable_diffusion(state_dict, components=components, file_path=file_path)
|
self.load_stable_diffusion(state_dict, components=components, file_path=file_path)
|
||||||
|
elif self.is_sd_lora(state_dict):
|
||||||
|
self.load_sd_lora(state_dict, alpha=lora_alphas.pop(0))
|
||||||
elif self.is_beautiful_prompt(state_dict):
|
elif self.is_beautiful_prompt(state_dict):
|
||||||
self.load_beautiful_prompt(state_dict, file_path=file_path)
|
self.load_beautiful_prompt(state_dict, file_path=file_path)
|
||||||
elif self.is_RIFE(state_dict):
|
elif self.is_RIFE(state_dict):
|
||||||
self.load_RIFE(state_dict, file_path=file_path)
|
self.load_RIFE(state_dict, file_path=file_path)
|
||||||
|
|
||||||
def load_models(self, file_path_list, lora_alphas=None):
    """Load every checkpoint in `file_path_list` via load_model.

    Args:
        file_path_list: paths to checkpoint files, loaded in order.
        lora_alphas: optional list of LoRA scales; entries are consumed in
            order as LoRA checkpoints are encountered by load_model.

    Note: the default was previously a mutable `[]`; since load_model pops
    entries from this list, a shared module-level default is a latent bug —
    use None as the sentinel instead.
    """
    if lora_alphas is None:
        lora_alphas = []
    for file_path in file_path_list:
        self.load_model(file_path, lora_alphas=lora_alphas)
|
||||||
|
|
||||||
def to(self, device):
|
def to(self, device):
|
||||||
for component in self.model:
|
for component in self.model:
|
||||||
|
|||||||
60
diffsynth/models/sd_lora.py
Normal file
60
diffsynth/models/sd_lora.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
import torch
|
||||||
|
from .sd_unet import SDUNetStateDictConverter, SDUNet
|
||||||
|
from .sd_text_encoder import SDTextEncoderStateDictConverter, SDTextEncoder
|
||||||
|
|
||||||
|
|
||||||
|
class SDLoRA:
    """Merges kohya-style Stable Diffusion LoRA weights into SDUNet / SDTextEncoder.

    LoRA checkpoints store low-rank up/down factor pairs under flattened key
    names (e.g. "lora_unet_..."); this class reconstructs the full delta
    weights and adds them in place to the target model's state dict.
    """

    def __init__(self):
        pass

    def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0, device="cuda"):
        """Convert a kohya LoRA state dict into diffusers-style delta weights.

        Args:
            state_dict: raw LoRA checkpoint; only keys starting with
                `lora_prefix` and containing ".lora_up" are used, each paired
                with its ".lora_down" counterpart.
            lora_prefix: key prefix selecting the target model
                ("lora_unet_" or "lora_te_").
            alpha: scale applied to each reconstructed delta.
            device: device on which the matrix products are computed.

        Returns:
            Dict mapping diffusers parameter names to CPU delta tensors.
        """
        # Flattened-name fragments that must be mapped back to real
        # diffusers module names after underscores become dots.
        special_keys = {
            "down.blocks": "down_blocks",
            "up.blocks": "up_blocks",
            "mid.block": "mid_block",
            "proj.in": "proj_in",
            "proj.out": "proj_out",
            "transformer.blocks": "transformer_blocks",
            "to.q": "to_q",
            "to.k": "to_k",
            "to.v": "to_v",
            "to.out": "to_out",
        }
        state_dict_ = {}
        for key in state_dict:
            if ".lora_up" not in key:
                continue
            if not key.startswith(lora_prefix):
                continue
            # BUGFIX: device was previously hard-coded to "cuda", silently
            # ignoring the `device` argument and breaking CPU-only use.
            weight_up = state_dict[key].to(device=device, dtype=torch.float16)
            weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device=device, dtype=torch.float16)
            if len(weight_up.shape) == 4:
                # 1x1-conv LoRA: drop the spatial dims, multiply in fp32 for
                # accuracy, then restore the (out, in, 1, 1) conv shape.
                weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
                weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
                lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
            else:
                lora_weight = alpha * torch.mm(weight_up, weight_down)
            # Rebuild the diffusers parameter name: take the module part of the
            # key, dot-separate it, strip the prefix, then undo the dotting of
            # known compound names via `special_keys`.
            target_name = key.split(".")[0].replace("_", ".")[len(lora_prefix):] + ".weight"
            for special_key in special_keys:
                target_name = target_name.replace(special_key, special_keys[special_key])
            state_dict_[target_name] = lora_weight.cpu()
        return state_dict_

    def add_lora_to_unet(self, unet: "SDUNet", state_dict_lora, alpha=1.0, device="cuda"):
        """Add scaled LoRA deltas to `unet`'s weights in place.

        String annotation keeps the class importable when SDUNet is unavailable.
        """
        state_dict_unet = unet.state_dict()
        state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_unet_", alpha=alpha, device=device)
        state_dict_lora = SDUNetStateDictConverter().from_diffusers(state_dict_lora)
        if len(state_dict_lora) > 0:
            for name in state_dict_lora:
                state_dict_unet[name] += state_dict_lora[name].to(device=device)
            unet.load_state_dict(state_dict_unet)

    def add_lora_to_text_encoder(self, text_encoder: "SDTextEncoder", state_dict_lora, alpha=1.0, device="cuda"):
        """Add scaled LoRA deltas to `text_encoder`'s weights in place.

        String annotation keeps the class importable when SDTextEncoder is unavailable.
        """
        state_dict_text_encoder = text_encoder.state_dict()
        state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_te_", alpha=alpha, device=device)
        state_dict_lora = SDTextEncoderStateDictConverter().from_diffusers(state_dict_lora)
        if len(state_dict_lora) > 0:
            for name in state_dict_lora:
                state_dict_text_encoder[name] += state_dict_lora[name].to(device=device)
            text_encoder.load_state_dict(state_dict_text_encoder)
|
|
||||||
@@ -2,6 +2,7 @@ import torch, os, io
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
st.set_page_config(layout="wide")
|
||||||
from streamlit_drawable_canvas import st_canvas
|
from streamlit_drawable_canvas import st_canvas
|
||||||
from diffsynth.models import ModelManager
|
from diffsynth.models import ModelManager
|
||||||
from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline
|
from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
st.set_page_config(layout="wide")
|
||||||
from diffsynth import ModelManager, SDVideoPipeline, ControlNetConfigUnit, VideoData, save_video, save_frames
|
from diffsynth import ModelManager, SDVideoPipeline, ControlNetConfigUnit, VideoData, save_video, save_frames
|
||||||
import torch, os, json
|
import torch, os, json
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -9,11 +10,11 @@ class Runner:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def load_pipeline(self, model_list, textual_inversion_folder, device, controlnet_units):
|
def load_pipeline(self, model_list, textual_inversion_folder, device, lora_alphas, controlnet_units):
|
||||||
# Load models
|
# Load models
|
||||||
model_manager = ModelManager(torch_dtype=torch.float16, device=device)
|
model_manager = ModelManager(torch_dtype=torch.float16, device=device)
|
||||||
model_manager.load_textual_inversions(textual_inversion_folder)
|
model_manager.load_textual_inversions(textual_inversion_folder)
|
||||||
model_manager.load_models(model_list)
|
model_manager.load_models(model_list, lora_alphas=lora_alphas)
|
||||||
pipe = SDVideoPipeline.from_model_manager(
|
pipe = SDVideoPipeline.from_model_manager(
|
||||||
model_manager,
|
model_manager,
|
||||||
[
|
[
|
||||||
@@ -100,6 +101,7 @@ config = {
|
|||||||
"model_list": [],
|
"model_list": [],
|
||||||
"textual_inversion_folder": "models/textual_inversion",
|
"textual_inversion_folder": "models/textual_inversion",
|
||||||
"device": "cuda",
|
"device": "cuda",
|
||||||
|
"lora_alphas": [],
|
||||||
"controlnet_units": []
|
"controlnet_units": []
|
||||||
},
|
},
|
||||||
"data": {
|
"data": {
|
||||||
@@ -122,6 +124,14 @@ with st.expander("Model", expanded=True):
|
|||||||
animatediff_ckpt = st.selectbox("AnimateDiff", ["None"] + load_model_list("models/AnimateDiff"))
|
animatediff_ckpt = st.selectbox("AnimateDiff", ["None"] + load_model_list("models/AnimateDiff"))
|
||||||
if animatediff_ckpt != "None":
|
if animatediff_ckpt != "None":
|
||||||
config["models"]["model_list"].append(os.path.join("models/AnimateDiff", animatediff_ckpt))
|
config["models"]["model_list"].append(os.path.join("models/AnimateDiff", animatediff_ckpt))
|
||||||
|
column_lora, column_lora_alpha = st.columns([2, 1])
|
||||||
|
with column_lora:
|
||||||
|
sd_lora_ckpt = st.selectbox("LoRA", ["None"] + load_model_list("models/lora"))
|
||||||
|
with column_lora_alpha:
|
||||||
|
lora_alpha = st.slider("LoRA Alpha", min_value=-4.0, max_value=4.0, value=1.0, step=0.1)
|
||||||
|
if sd_lora_ckpt != "None":
|
||||||
|
config["models"]["model_list"].append(os.path.join("models/lora", sd_lora_ckpt))
|
||||||
|
config["models"]["lora_alphas"].append(lora_alpha)
|
||||||
|
|
||||||
|
|
||||||
with st.expander("Data", expanded=True):
|
with st.expander("Data", expanded=True):
|
||||||
|
|||||||
Reference in New Issue
Block a user