This commit is contained in:
Artiprocher
2024-05-05 22:48:38 +08:00
parent cc37860438
commit 0965477750
15 changed files with 2991 additions and 79 deletions

View File

@@ -16,6 +16,11 @@ from .sd_controlnet import SDControlNet
from .sd_motion import SDMotionModel
from .svd_image_encoder import SVDImageEncoder
from .svd_unet import SVDUNet
from .svd_vae_decoder import SVDVAEDecoder
from .svd_vae_encoder import SVDVAEEncoder
class ModelManager:
def __init__(self, torch_dtype=torch.float16, device="cuda"):
@@ -25,6 +30,10 @@ class ModelManager:
self.model_path = {}
self.textual_inversion_dict = {}
def is_stable_video_diffusion(self, state_dict):
param_name = "model.diffusion_model.output_blocks.9.1.time_stack.0.norm_in.weight"
return param_name in state_dict
def is_RIFE(self, state_dict):
param_name = "block_tea.convblock3.0.1.weight"
return param_name in state_dict or ("module." + param_name) in state_dict
@@ -60,6 +69,21 @@ class ModelManager:
param_name = "model.encoder.layers.5.self_attn_layer_norm.weight"
return param_name in state_dict and len(state_dict) == 254
def load_stable_video_diffusion(self, state_dict, components=None, file_path=""):
component_dict = {
"image_encoder": SVDImageEncoder,
"unet": SVDUNet,
"vae_decoder": SVDVAEDecoder,
"vae_encoder": SVDVAEEncoder,
}
if components is None:
components = ["image_encoder", "unet", "vae_decoder", "vae_encoder"]
for component in components:
self.model[component] = component_dict[component]()
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
self.model[component].to(self.torch_dtype).to(self.device)
self.model_path[component] = file_path
def load_stable_diffusion(self, state_dict, components=None, file_path=""):
component_dict = {
"text_encoder": SDTextEncoder,
@@ -190,7 +214,9 @@ class ModelManager:
def load_model(self, file_path, components=None, lora_alphas=[]):
state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype)
if self.is_animatediff(state_dict):
if self.is_stable_video_diffusion(state_dict):
self.load_stable_video_diffusion(state_dict, file_path=file_path)
elif self.is_animatediff(state_dict):
self.load_animatediff(state_dict, file_path=file_path)
elif self.is_controlnet(state_dict):
self.load_controlnet(state_dict, file_path=file_path)