wan-fun-v1.1 reference control

This commit is contained in:
Artiprocher
2025-04-30 11:38:17 +08:00
parent ef2a7abad4
commit 3edf3583b1
4 changed files with 105 additions and 2 deletions

View File

@@ -272,6 +272,7 @@ class WanModel(torch.nn.Module):
num_layers: int,
has_image_input: bool,
has_image_pos_emb: bool = False,
has_ref_conv: bool = False,
):
super().__init__()
self.dim = dim
@@ -303,7 +304,10 @@ class WanModel(torch.nn.Module):
if has_image_input:
self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280
if has_ref_conv:
self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2))
self.has_image_pos_emb = has_image_pos_emb
self.has_ref_conv = has_ref_conv
def patchify(self, x: torch.Tensor):
x = self.patch_embedding(x)
@@ -532,6 +536,7 @@ class WanModelStateDictConverter:
"eps": 1e-6
}
elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
# 1.3B PAI control
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
@@ -546,6 +551,7 @@ class WanModelStateDictConverter:
"eps": 1e-6
}
elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
# 14B PAI control
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
@@ -574,6 +580,38 @@ class WanModelStateDictConverter:
"eps": 1e-6,
"has_image_pos_emb": True
}
elif hash_state_dict_keys(state_dict) == "70ddad9d3a133785da5ea371aae09504":
# 1.3B PAI control v1.1
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 48,
"dim": 1536,
"ffn_dim": 8960,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 12,
"num_layers": 30,
"eps": 1e-6,
"has_ref_conv": True
}
elif hash_state_dict_keys(state_dict) == "26bde73488a92e64cc20b0a7485b9e5b":
# 14B PAI control v1.1
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 48,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6,
"has_ref_conv": True
}
else:
config = {}
return state_dict, config