mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-21 08:08:13 +00:00
release ExVideo
This commit is contained in:
@@ -110,7 +110,11 @@ class ModelManager:
|
||||
param_name = "quant_conv.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def load_stable_video_diffusion(self, state_dict, components=None, file_path=""):
|
||||
def is_ExVideo_StableVideoDiffusion(self, state_dict):
|
||||
param_name = "blocks.185.positional_embedding.embeddings"
|
||||
return param_name in state_dict
|
||||
|
||||
def load_stable_video_diffusion(self, state_dict, components=None, file_path="", add_positional_conv=None):
|
||||
component_dict = {
|
||||
"image_encoder": SVDImageEncoder,
|
||||
"unet": SVDUNet,
|
||||
@@ -120,8 +124,12 @@ class ModelManager:
|
||||
if components is None:
|
||||
components = ["image_encoder", "unet", "vae_decoder", "vae_encoder"]
|
||||
for component in components:
|
||||
self.model[component] = component_dict[component]()
|
||||
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
||||
if component == "unet":
|
||||
self.model[component] = component_dict[component](add_positional_conv=add_positional_conv)
|
||||
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict, add_positional_conv=add_positional_conv), strict=False)
|
||||
else:
|
||||
self.model[component] = component_dict[component]()
|
||||
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
||||
self.model[component].to(self.torch_dtype).to(self.device)
|
||||
self.model_path[component] = file_path
|
||||
|
||||
@@ -305,6 +313,16 @@ class ModelManager:
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_ExVideo_StableVideoDiffusion(self, state_dict, file_path=""):
|
||||
unet_state_dict = self.model["unet"].state_dict()
|
||||
self.model["unet"].to("cpu")
|
||||
del self.model["unet"]
|
||||
add_positional_conv = state_dict["blocks.185.positional_embedding.embeddings"].shape[0]
|
||||
self.model["unet"] = SVDUNet(add_positional_conv=add_positional_conv)
|
||||
self.model["unet"].load_state_dict(unet_state_dict, strict=False)
|
||||
self.model["unet"].load_state_dict(state_dict, strict=False)
|
||||
self.model["unet"].to(self.torch_dtype).to(self.device)
|
||||
|
||||
def search_for_embeddings(self, state_dict):
|
||||
embeddings = []
|
||||
for k in state_dict:
|
||||
@@ -370,6 +388,8 @@ class ModelManager:
|
||||
self.load_hunyuan_dit(state_dict, file_path=file_path)
|
||||
elif self.is_diffusers_vae(state_dict):
|
||||
self.load_diffusers_vae(state_dict, file_path=file_path)
|
||||
elif self.is_ExVideo_StableVideoDiffusion(state_dict):
|
||||
self.load_ExVideo_StableVideoDiffusion(state_dict, file_path=file_path)
|
||||
|
||||
def load_models(self, file_path_list, lora_alphas=[]):
|
||||
for file_path in file_path_list:
|
||||
|
||||
Reference in New Issue
Block a user