mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support wan-flf2v
This commit is contained in:
@@ -223,7 +223,7 @@ class DiTBlock(nn.Module):
|
||||
|
||||
|
||||
class MLP(torch.nn.Module):
|
||||
def __init__(self, in_dim, out_dim):
|
||||
def __init__(self, in_dim, out_dim, has_pos_emb=False):
|
||||
super().__init__()
|
||||
self.proj = torch.nn.Sequential(
|
||||
nn.LayerNorm(in_dim),
|
||||
@@ -232,8 +232,13 @@ class MLP(torch.nn.Module):
|
||||
nn.Linear(in_dim, out_dim),
|
||||
nn.LayerNorm(out_dim)
|
||||
)
|
||||
self.has_pos_emb = has_pos_emb
|
||||
if has_pos_emb:
|
||||
self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))
|
||||
|
||||
def forward(self, x):
|
||||
if self.has_pos_emb:
|
||||
x = x + self.emb_pos.to(dtype=x.dtype, device=x.device)
|
||||
return self.proj(x)
|
||||
|
||||
|
||||
@@ -266,6 +271,7 @@ class WanModel(torch.nn.Module):
|
||||
num_heads: int,
|
||||
num_layers: int,
|
||||
has_image_input: bool,
|
||||
has_image_pos_emb: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
@@ -296,7 +302,8 @@ class WanModel(torch.nn.Module):
|
||||
self.freqs = precompute_freqs_cis_3d(head_dim)
|
||||
|
||||
if has_image_input:
|
||||
self.img_emb = MLP(1280, dim) # clip_feature_dim = 1280
|
||||
self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280
|
||||
self.has_image_pos_emb = has_image_pos_emb
|
||||
|
||||
def patchify(self, x: torch.Tensor):
|
||||
x = self.patch_embedding(x)
|
||||
@@ -552,6 +559,21 @@ class WanModelStateDictConverter:
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "3ef3b1f8e1dab83d5b71fd7b617f859f":
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 36,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6,
|
||||
"has_image_pos_emb": True
|
||||
}
|
||||
else:
|
||||
config = {}
|
||||
return state_dict, config
|
||||
|
||||
Reference in New Issue
Block a user