mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 23:58:12 +00:00
support lora
This commit is contained in:
@@ -5,6 +5,7 @@ from .sd_text_encoder import SDTextEncoder
|
||||
from .sd_unet import SDUNet
|
||||
from .sd_vae_encoder import SDVAEEncoder
|
||||
from .sd_vae_decoder import SDVAEDecoder
|
||||
from .sd_lora import SDLoRA
|
||||
|
||||
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
||||
from .sdxl_unet import SDXLUNet
|
||||
@@ -50,6 +51,10 @@ class ModelManager:
|
||||
param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_sd_lora(self, state_dict):
|
||||
param_name = "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_ff_net_2.lora_up.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def load_stable_diffusion(self, state_dict, components=None, file_path=""):
|
||||
component_dict = {
|
||||
"text_encoder": SDTextEncoder,
|
||||
@@ -138,6 +143,10 @@ class ModelManager:
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_sd_lora(self, state_dict, alpha):
|
||||
SDLoRA().add_lora_to_text_encoder(self.model["text_encoder"], state_dict, alpha=alpha, device=self.device)
|
||||
SDLoRA().add_lora_to_unet(self.model["unet"], state_dict, alpha=alpha, device=self.device)
|
||||
|
||||
def search_for_embeddings(self, state_dict):
|
||||
embeddings = []
|
||||
for k in state_dict:
|
||||
@@ -165,7 +174,7 @@ class ModelManager:
|
||||
self.textual_inversion_dict[keyword] = (tokens, embeddings)
|
||||
break
|
||||
|
||||
def load_model(self, file_path, components=None):
|
||||
def load_model(self, file_path, components=None, lora_alphas=[]):
|
||||
state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype)
|
||||
if self.is_animatediff(state_dict):
|
||||
self.load_animatediff(state_dict, file_path=file_path)
|
||||
@@ -175,14 +184,16 @@ class ModelManager:
|
||||
self.load_stable_diffusion_xl(state_dict, components=components, file_path=file_path)
|
||||
elif self.is_stable_diffusion(state_dict):
|
||||
self.load_stable_diffusion(state_dict, components=components, file_path=file_path)
|
||||
elif self.is_sd_lora(state_dict):
|
||||
self.load_sd_lora(state_dict, alpha=lora_alphas.pop(0))
|
||||
elif self.is_beautiful_prompt(state_dict):
|
||||
self.load_beautiful_prompt(state_dict, file_path=file_path)
|
||||
elif self.is_RIFE(state_dict):
|
||||
self.load_RIFE(state_dict, file_path=file_path)
|
||||
|
||||
def load_models(self, file_path_list):
|
||||
def load_models(self, file_path_list, lora_alphas=[]):
|
||||
for file_path in file_path_list:
|
||||
self.load_model(file_path)
|
||||
self.load_model(file_path, lora_alphas=lora_alphas)
|
||||
|
||||
def to(self, device):
|
||||
for component in self.model:
|
||||
|
||||
Reference in New Issue
Block a user