mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 08:40:47 +00:00
svd
This commit is contained in:
@@ -16,6 +16,11 @@ from .sd_controlnet import SDControlNet
|
||||
|
||||
from .sd_motion import SDMotionModel
|
||||
|
||||
from .svd_image_encoder import SVDImageEncoder
|
||||
from .svd_unet import SVDUNet
|
||||
from .svd_vae_decoder import SVDVAEDecoder
|
||||
from .svd_vae_encoder import SVDVAEEncoder
|
||||
|
||||
|
||||
class ModelManager:
|
||||
def __init__(self, torch_dtype=torch.float16, device="cuda"):
|
||||
@@ -25,6 +30,10 @@ class ModelManager:
|
||||
self.model_path = {}
|
||||
self.textual_inversion_dict = {}
|
||||
|
||||
def is_stable_video_diffusion(self, state_dict):
|
||||
param_name = "model.diffusion_model.output_blocks.9.1.time_stack.0.norm_in.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_RIFE(self, state_dict):
|
||||
param_name = "block_tea.convblock3.0.1.weight"
|
||||
return param_name in state_dict or ("module." + param_name) in state_dict
|
||||
@@ -60,6 +69,21 @@ class ModelManager:
|
||||
param_name = "model.encoder.layers.5.self_attn_layer_norm.weight"
|
||||
return param_name in state_dict and len(state_dict) == 254
|
||||
|
||||
def load_stable_video_diffusion(self, state_dict, components=None, file_path=""):
|
||||
component_dict = {
|
||||
"image_encoder": SVDImageEncoder,
|
||||
"unet": SVDUNet,
|
||||
"vae_decoder": SVDVAEDecoder,
|
||||
"vae_encoder": SVDVAEEncoder,
|
||||
}
|
||||
if components is None:
|
||||
components = ["image_encoder", "unet", "vae_decoder", "vae_encoder"]
|
||||
for component in components:
|
||||
self.model[component] = component_dict[component]()
|
||||
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
||||
self.model[component].to(self.torch_dtype).to(self.device)
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_stable_diffusion(self, state_dict, components=None, file_path=""):
|
||||
component_dict = {
|
||||
"text_encoder": SDTextEncoder,
|
||||
@@ -190,7 +214,9 @@ class ModelManager:
|
||||
|
||||
def load_model(self, file_path, components=None, lora_alphas=[]):
|
||||
state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype)
|
||||
if self.is_animatediff(state_dict):
|
||||
if self.is_stable_video_diffusion(state_dict):
|
||||
self.load_stable_video_diffusion(state_dict, file_path=file_path)
|
||||
elif self.is_animatediff(state_dict):
|
||||
self.load_animatediff(state_dict, file_path=file_path)
|
||||
elif self.is_controlnet(state_dict):
|
||||
self.load_controlnet(state_dict, file_path=file_path)
|
||||
|
||||
Reference in New Issue
Block a user