mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 15:48:20 +00:00
v1.2
This commit is contained in:
@@ -15,8 +15,6 @@ from .sd_controlnet import SDControlNet
|
||||
|
||||
from .sd_motion import SDMotionModel
|
||||
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
|
||||
class ModelManager:
|
||||
def __init__(self, torch_dtype=torch.float16, device="cuda"):
|
||||
@@ -26,6 +24,10 @@ class ModelManager:
|
||||
self.model_path = {}
|
||||
self.textual_inversion_dict = {}
|
||||
|
||||
def is_RIFE(self, state_dict):
|
||||
param_name = "block_tea.convblock3.0.1.weight"
|
||||
return param_name in state_dict or ("module." + param_name) in state_dict
|
||||
|
||||
def is_beautiful_prompt(self, state_dict):
|
||||
param_name = "transformer.h.9.self_attention.query_key_value.weight"
|
||||
return param_name in state_dict
|
||||
@@ -119,6 +121,7 @@ class ModelManager:
|
||||
|
||||
def load_beautiful_prompt(self, state_dict, file_path=""):
|
||||
component = "beautiful_prompt"
|
||||
from transformers import AutoModelForCausalLM
|
||||
model_folder = os.path.dirname(file_path)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_folder, state_dict=state_dict, local_files_only=True, torch_dtype=self.torch_dtype
|
||||
@@ -126,6 +129,15 @@ class ModelManager:
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_RIFE(self, state_dict, file_path=""):
|
||||
component = "RIFE"
|
||||
from ..extensions.RIFE import IFNet
|
||||
model = IFNet().eval()
|
||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
||||
model.to(torch.float32).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def search_for_embeddings(self, state_dict):
|
||||
embeddings = []
|
||||
for k in state_dict:
|
||||
@@ -163,6 +175,8 @@ class ModelManager:
|
||||
self.load_stable_diffusion(state_dict, components=components, file_path=file_path)
|
||||
elif self.is_beautiful_prompt(state_dict):
|
||||
self.load_beautiful_prompt(state_dict, file_path=file_path)
|
||||
elif self.is_RIFE(state_dict):
|
||||
self.load_RIFE(state_dict, file_path=file_path)
|
||||
|
||||
def load_models(self, file_path_list):
|
||||
for file_path in file_path_list:
|
||||
|
||||
Reference in New Issue
Block a user