support lora

This commit is contained in:
Artiprocher
2024-01-10 14:34:02 +08:00
parent 8a460497fa
commit 9698e3988f
4 changed files with 87 additions and 5 deletions

View File

@@ -5,6 +5,7 @@ from .sd_text_encoder import SDTextEncoder
from .sd_unet import SDUNet
from .sd_vae_encoder import SDVAEEncoder
from .sd_vae_decoder import SDVAEDecoder
from .sd_lora import SDLoRA
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from .sdxl_unet import SDXLUNet
@@ -50,6 +51,10 @@ class ModelManager:
param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
return param_name in state_dict
def is_sd_lora(self, state_dict):
param_name = "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_ff_net_2.lora_up.weight"
return param_name in state_dict
def load_stable_diffusion(self, state_dict, components=None, file_path=""):
component_dict = {
"text_encoder": SDTextEncoder,
@@ -138,6 +143,10 @@ class ModelManager:
self.model[component] = model
self.model_path[component] = file_path
def load_sd_lora(self, state_dict, alpha):
SDLoRA().add_lora_to_text_encoder(self.model["text_encoder"], state_dict, alpha=alpha, device=self.device)
SDLoRA().add_lora_to_unet(self.model["unet"], state_dict, alpha=alpha, device=self.device)
def search_for_embeddings(self, state_dict):
embeddings = []
for k in state_dict:
@@ -165,7 +174,7 @@ class ModelManager:
self.textual_inversion_dict[keyword] = (tokens, embeddings)
break
def load_model(self, file_path, components=None):
def load_model(self, file_path, components=None, lora_alphas=[]):
state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype)
if self.is_animatediff(state_dict):
self.load_animatediff(state_dict, file_path=file_path)
@@ -175,14 +184,16 @@ class ModelManager:
self.load_stable_diffusion_xl(state_dict, components=components, file_path=file_path)
elif self.is_stable_diffusion(state_dict):
self.load_stable_diffusion(state_dict, components=components, file_path=file_path)
elif self.is_sd_lora(state_dict):
self.load_sd_lora(state_dict, alpha=lora_alphas.pop(0))
elif self.is_beautiful_prompt(state_dict):
self.load_beautiful_prompt(state_dict, file_path=file_path)
elif self.is_RIFE(state_dict):
self.load_RIFE(state_dict, file_path=file_path)
def load_models(self, file_path_list):
def load_models(self, file_path_list, lora_alphas=[]):
for file_path in file_path_list:
self.load_model(file_path)
self.load_model(file_path, lora_alphas=lora_alphas)
def to(self, device):
for component in self.model: