This commit is contained in:
Artiprocher
2023-12-30 21:01:24 +08:00
parent b9771db163
commit d24ddaacaa
19 changed files with 2252 additions and 34 deletions

View File

@@ -15,8 +15,6 @@ from .sd_controlnet import SDControlNet
from .sd_motion import SDMotionModel
from transformers import AutoModelForCausalLM
class ModelManager:
def __init__(self, torch_dtype=torch.float16, device="cuda"):
@@ -26,6 +24,10 @@ class ModelManager:
self.model_path = {}
self.textual_inversion_dict = {}
def is_RIFE(self, state_dict):
param_name = "block_tea.convblock3.0.1.weight"
return param_name in state_dict or ("module." + param_name) in state_dict
def is_beautiful_prompt(self, state_dict):
param_name = "transformer.h.9.self_attention.query_key_value.weight"
return param_name in state_dict
@@ -119,6 +121,7 @@ class ModelManager:
def load_beautiful_prompt(self, state_dict, file_path=""):
component = "beautiful_prompt"
from transformers import AutoModelForCausalLM
model_folder = os.path.dirname(file_path)
model = AutoModelForCausalLM.from_pretrained(
model_folder, state_dict=state_dict, local_files_only=True, torch_dtype=self.torch_dtype
@@ -126,6 +129,15 @@ class ModelManager:
self.model[component] = model
self.model_path[component] = file_path
def load_RIFE(self, state_dict, file_path=""):
component = "RIFE"
from ..extensions.RIFE import IFNet
model = IFNet().eval()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(torch.float32).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def search_for_embeddings(self, state_dict):
embeddings = []
for k in state_dict:
@@ -163,6 +175,8 @@ class ModelManager:
self.load_stable_diffusion(state_dict, components=components, file_path=file_path)
elif self.is_beautiful_prompt(state_dict):
self.load_beautiful_prompt(state_dict, file_path=file_path)
elif self.is_RIFE(state_dict):
self.load_RIFE(state_dict, file_path=file_path)
def load_models(self, file_path_list):
for file_path in file_path_list: