Support Wan2.2 A14B I2V & T2V

This commit is contained in:
mi804
2025-07-25 17:09:53 +08:00
parent 3aed244c6f
commit 9015d08927
6 changed files with 175 additions and 9 deletions

View File

@@ -352,6 +352,7 @@ class WanModel(torch.nn.Module):
context: torch.Tensor,
clip_feature: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
fused_y: Optional[torch.Tensor] = None,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
**kwargs,
@@ -365,6 +366,8 @@ class WanModel(torch.nn.Module):
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
clip_embdding = self.img_emb(clip_feature)
context = torch.cat([clip_embdding, context], dim=1)
if fused_y is not None:
x = torch.cat([x, fused_y], dim=1) # (b, c_x + c_y + c_fused_y, f, h, w)
x, (f, h, w) = self.patchify(x)
@@ -673,6 +676,7 @@ class WanModelStateDictConverter:
"in_dim_control_adapter": 24,
}
elif hash_state_dict_keys(state_dict) == "1f5ab7703c6fc803fdded85ff040c316":
# Wan-AI/Wan2.2-TI2V-5B
config = {
"has_image_input": False,
"patch_size": [1, 2, 2],
@@ -687,6 +691,21 @@ class WanModelStateDictConverter:
"eps": 1e-6,
"seperated_timestep": True,
}
elif hash_state_dict_keys(state_dict) == "5b013604280dd715f8457c6ed6d6a626":
# Wan-AI/Wan2.2-I2V-A14B
config = {
"has_image_input": False,
"patch_size": [1, 2, 2],
"in_dim": 36,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6,
}
else:
config = {}
return state_dict, config