mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 00:38:11 +00:00
support diffusers format wan and other lora
This commit is contained in:
@@ -737,7 +737,80 @@ class WanModelStateDictConverter:
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
rename_dict = {"blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
|
||||
"blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
|
||||
"blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
|
||||
"blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
|
||||
"blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
|
||||
"blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
|
||||
"blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
|
||||
"blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
|
||||
"blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
|
||||
"blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
|
||||
"blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
|
||||
"blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
|
||||
"blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
|
||||
"blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
|
||||
"blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
|
||||
"blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
|
||||
"blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
|
||||
"blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
|
||||
"blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
|
||||
"blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
|
||||
"blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
|
||||
"blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
|
||||
"blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
|
||||
"blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
|
||||
"blocks.0.norm2.bias": "blocks.0.norm3.bias",
|
||||
"blocks.0.norm2.weight": "blocks.0.norm3.weight",
|
||||
"blocks.0.scale_shift_table": "blocks.0.modulation",
|
||||
"condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
|
||||
"condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
|
||||
"condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
|
||||
"condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
|
||||
"condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
|
||||
"condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
|
||||
"condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
|
||||
"condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
|
||||
"condition_embedder.time_proj.bias": "time_projection.1.bias",
|
||||
"condition_embedder.time_proj.weight": "time_projection.1.weight",
|
||||
"patch_embedding.bias": "patch_embedding.bias",
|
||||
"patch_embedding.weight": "patch_embedding.weight",
|
||||
"scale_shift_table": "head.modulation",
|
||||
"proj_out.bias": "head.head.bias",
|
||||
"proj_out.weight": "head.head.weight",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name in rename_dict:
|
||||
state_dict_[rename_dict[name]] = param
|
||||
else:
|
||||
name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
|
||||
if name_ in rename_dict:
|
||||
name_ = rename_dict[name_]
|
||||
name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
|
||||
state_dict_[name_] = param
|
||||
if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b":
|
||||
config = {
|
||||
"model_type": "t2v",
|
||||
"patch_size": (1, 2, 2),
|
||||
"text_len": 512,
|
||||
"in_dim": 16,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"window_size": (-1, -1),
|
||||
"qk_norm": True,
|
||||
"cross_attn_norm": True,
|
||||
"eps": 1e-6,
|
||||
}
|
||||
else:
|
||||
config = {}
|
||||
return state_dict_, config
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
|
||||
|
||||
Reference in New Issue
Block a user