mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support wan-flf2v
This commit is contained in:
@@ -126,6 +126,7 @@ model_loader_configs = [
|
||||
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
|
||||
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
|
||||
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
|
||||
|
||||
@@ -223,7 +223,7 @@ class DiTBlock(nn.Module):
|
||||
|
||||
|
||||
class MLP(torch.nn.Module):
|
||||
def __init__(self, in_dim, out_dim):
|
||||
def __init__(self, in_dim, out_dim, has_pos_emb=False):
|
||||
super().__init__()
|
||||
self.proj = torch.nn.Sequential(
|
||||
nn.LayerNorm(in_dim),
|
||||
@@ -232,8 +232,13 @@ class MLP(torch.nn.Module):
|
||||
nn.Linear(in_dim, out_dim),
|
||||
nn.LayerNorm(out_dim)
|
||||
)
|
||||
self.has_pos_emb = has_pos_emb
|
||||
if has_pos_emb:
|
||||
self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))
|
||||
|
||||
def forward(self, x):
|
||||
if self.has_pos_emb:
|
||||
x = x + self.emb_pos.to(dtype=x.dtype, device=x.device)
|
||||
return self.proj(x)
|
||||
|
||||
|
||||
@@ -266,6 +271,7 @@ class WanModel(torch.nn.Module):
|
||||
num_heads: int,
|
||||
num_layers: int,
|
||||
has_image_input: bool,
|
||||
has_image_pos_emb: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
@@ -296,7 +302,8 @@ class WanModel(torch.nn.Module):
|
||||
self.freqs = precompute_freqs_cis_3d(head_dim)
|
||||
|
||||
if has_image_input:
|
||||
self.img_emb = MLP(1280, dim) # clip_feature_dim = 1280
|
||||
self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280
|
||||
self.has_image_pos_emb = has_image_pos_emb
|
||||
|
||||
def patchify(self, x: torch.Tensor):
|
||||
x = self.patch_embedding(x)
|
||||
@@ -552,6 +559,21 @@ class WanModelStateDictConverter:
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "3ef3b1f8e1dab83d5b71fd7b617f859f":
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 36,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6,
|
||||
"has_image_pos_emb": True
|
||||
}
|
||||
else:
|
||||
config = {}
|
||||
return state_dict, config
|
||||
|
||||
@@ -211,6 +211,8 @@ class WanVideoPipeline(BasePipeline):
|
||||
if end_image is not None:
|
||||
end_image = self.preprocess_image(end_image.resize((width, height))).to(self.device)
|
||||
vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
|
||||
if self.dit.has_image_pos_emb:
|
||||
clip_context = torch.concat([clip_context, self.image_encoder.encode_image([end_image])], dim=1)
|
||||
msk[:, -1:] = 1
|
||||
else:
|
||||
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
|
||||
|
||||
Reference in New Issue
Block a user