mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support animatediff on sdxl
This commit is contained in:
@@ -15,6 +15,7 @@ from .sdxl_vae_encoder import SDXLVAEEncoder
|
||||
from .sd_controlnet import SDControlNet
|
||||
|
||||
from .sd_motion import SDMotionModel
|
||||
from .sdxl_motion import SDXLMotionModel
|
||||
|
||||
from .svd_image_encoder import SVDImageEncoder
|
||||
from .svd_unet import SVDUNet
|
||||
@@ -61,6 +62,10 @@ class ModelManager:
|
||||
param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_animatediff_xl(self, state_dict):
|
||||
param_name = "up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.ff_norm.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_sd_lora(self, state_dict):
|
||||
param_name = "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_ff_net_2.lora_up.weight"
|
||||
return param_name in state_dict
|
||||
@@ -153,6 +158,14 @@ class ModelManager:
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_animatediff_xl(self, state_dict, file_path=""):
|
||||
component = "motion_modules_xl"
|
||||
model = SDXLMotionModel()
|
||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_beautiful_prompt(self, state_dict, file_path=""):
|
||||
component = "beautiful_prompt"
|
||||
from transformers import AutoModelForCausalLM
|
||||
@@ -218,6 +231,8 @@ class ModelManager:
|
||||
self.load_stable_video_diffusion(state_dict, file_path=file_path)
|
||||
elif self.is_animatediff(state_dict):
|
||||
self.load_animatediff(state_dict, file_path=file_path)
|
||||
elif self.is_animatediff_xl(state_dict):
|
||||
self.load_animatediff_xl(state_dict, file_path=file_path)
|
||||
elif self.is_controlnet(state_dict):
|
||||
self.load_controlnet(state_dict, file_path=file_path)
|
||||
elif self.is_stabe_diffusion_xl(state_dict):
|
||||
|
||||
103
diffsynth/models/sdxl_motion.py
Normal file
103
diffsynth/models/sdxl_motion.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from .sd_motion import TemporalBlock
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
class SDXLMotionModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.motion_modules = torch.nn.ModuleList([
|
||||
TemporalBlock(8, 320//8, 320, eps=1e-6),
|
||||
TemporalBlock(8, 320//8, 320, eps=1e-6),
|
||||
|
||||
TemporalBlock(8, 640//8, 640, eps=1e-6),
|
||||
TemporalBlock(8, 640//8, 640, eps=1e-6),
|
||||
|
||||
TemporalBlock(8, 1280//8, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 1280//8, 1280, eps=1e-6),
|
||||
|
||||
TemporalBlock(8, 1280//8, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 1280//8, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 1280//8, 1280, eps=1e-6),
|
||||
|
||||
TemporalBlock(8, 640//8, 640, eps=1e-6),
|
||||
TemporalBlock(8, 640//8, 640, eps=1e-6),
|
||||
TemporalBlock(8, 640//8, 640, eps=1e-6),
|
||||
|
||||
TemporalBlock(8, 320//8, 320, eps=1e-6),
|
||||
TemporalBlock(8, 320//8, 320, eps=1e-6),
|
||||
TemporalBlock(8, 320//8, 320, eps=1e-6),
|
||||
])
|
||||
self.call_block_id = {
|
||||
0: 0,
|
||||
2: 1,
|
||||
7: 2,
|
||||
10: 3,
|
||||
15: 4,
|
||||
18: 5,
|
||||
25: 6,
|
||||
28: 7,
|
||||
31: 8,
|
||||
35: 9,
|
||||
38: 10,
|
||||
41: 11,
|
||||
44: 12,
|
||||
46: 13,
|
||||
48: 14,
|
||||
}
|
||||
|
||||
def forward(self):
|
||||
pass
|
||||
|
||||
def state_dict_converter(self):
|
||||
return SDMotionModelStateDictConverter()
|
||||
|
||||
|
||||
class SDMotionModelStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
rename_dict = {
|
||||
"norm": "norm",
|
||||
"proj_in": "proj_in",
|
||||
"transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q",
|
||||
"transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k",
|
||||
"transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v",
|
||||
"transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out",
|
||||
"transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1",
|
||||
"transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q",
|
||||
"transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k",
|
||||
"transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v",
|
||||
"transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out",
|
||||
"transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2",
|
||||
"transformer_blocks.0.norms.0": "transformer_blocks.0.norm1",
|
||||
"transformer_blocks.0.norms.1": "transformer_blocks.0.norm2",
|
||||
"transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj",
|
||||
"transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff",
|
||||
"transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3",
|
||||
"proj_out": "proj_out",
|
||||
}
|
||||
name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")])
|
||||
name_list += sorted([i for i in state_dict if i.startswith("mid_block.")])
|
||||
name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")])
|
||||
state_dict_ = {}
|
||||
last_prefix, module_id = "", -1
|
||||
for name in name_list:
|
||||
names = name.split(".")
|
||||
prefix_index = names.index("temporal_transformer") + 1
|
||||
prefix = ".".join(names[:prefix_index])
|
||||
if prefix != last_prefix:
|
||||
last_prefix = prefix
|
||||
module_id += 1
|
||||
middle_name = ".".join(names[prefix_index:-1])
|
||||
suffix = names[-1]
|
||||
if "pos_encoder" in names:
|
||||
rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]])
|
||||
else:
|
||||
rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
|
||||
state_dict_[rename] = state_dict[name]
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
Reference in New Issue
Block a user