mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 16:50:47 +00:00
compatibility update
This commit is contained in:
@@ -15,6 +15,8 @@ from .sd_controlnet import SDControlNet
|
||||
|
||||
from .sd_motion import SDMotionModel
|
||||
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
|
||||
class ModelManager:
|
||||
def __init__(self, torch_dtype=torch.float16, device="cuda"):
|
||||
@@ -24,12 +26,19 @@ class ModelManager:
|
||||
self.model_path = {}
|
||||
self.textual_inversion_dict = {}
|
||||
|
||||
def is_beautiful_prompt(self, state_dict):
|
||||
param_name = "transformer.h.9.self_attention.query_key_value.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_stabe_diffusion_xl(self, state_dict):
|
||||
param_name = "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_stable_diffusion(self, state_dict):
|
||||
return True
|
||||
if self.is_stabe_diffusion_xl(state_dict):
|
||||
return False
|
||||
param_name = "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_controlnet(self, state_dict):
|
||||
param_name = "control_model.time_embed.0.weight"
|
||||
@@ -74,7 +83,6 @@ class ModelManager:
|
||||
"unet": SDXLUNet,
|
||||
"vae_decoder": SDXLVAEDecoder,
|
||||
"vae_encoder": SDXLVAEEncoder,
|
||||
"refiner": SDXLUNet,
|
||||
}
|
||||
if components is None:
|
||||
components = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder"]
|
||||
@@ -109,6 +117,15 @@ class ModelManager:
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_beautiful_prompt(self, state_dict, file_path=""):
|
||||
component = "beautiful_prompt"
|
||||
model_folder = os.path.dirname(file_path)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_folder, state_dict=state_dict, local_files_only=True, torch_dtype=self.torch_dtype
|
||||
).to(self.device).eval()
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def search_for_embeddings(self, state_dict):
|
||||
embeddings = []
|
||||
for k in state_dict:
|
||||
@@ -144,6 +161,8 @@ class ModelManager:
|
||||
self.load_stable_diffusion_xl(state_dict, components=components, file_path=file_path)
|
||||
elif self.is_stable_diffusion(state_dict):
|
||||
self.load_stable_diffusion(state_dict, components=components, file_path=file_path)
|
||||
elif self.is_beautiful_prompt(state_dict):
|
||||
self.load_beautiful_prompt(state_dict, file_path=file_path)
|
||||
|
||||
def load_models(self, file_path_list):
|
||||
for file_path in file_path_list:
|
||||
|
||||
Reference in New Issue
Block a user