support wan-flf2v

This commit is contained in:
Artiprocher
2025-04-17 14:47:55 +08:00
parent e9e24b8cf1
commit 553b341f5f
5 changed files with 86 additions and 2 deletions

View File

@@ -126,6 +126,7 @@ model_loader_configs = [
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
(None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
(None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
(None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"),
(None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),

View File

@@ -223,7 +223,7 @@ class DiTBlock(nn.Module):
class MLP(torch.nn.Module):
def __init__(self, in_dim, out_dim):
def __init__(self, in_dim, out_dim, has_pos_emb=False):
super().__init__()
self.proj = torch.nn.Sequential(
nn.LayerNorm(in_dim),
@@ -232,8 +232,13 @@ class MLP(torch.nn.Module):
nn.Linear(in_dim, out_dim),
nn.LayerNorm(out_dim)
)
self.has_pos_emb = has_pos_emb
if has_pos_emb:
self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))
def forward(self, x):
if self.has_pos_emb:
x = x + self.emb_pos.to(dtype=x.dtype, device=x.device)
return self.proj(x)
@@ -266,6 +271,7 @@ class WanModel(torch.nn.Module):
num_heads: int,
num_layers: int,
has_image_input: bool,
has_image_pos_emb: bool = False,
):
super().__init__()
self.dim = dim
@@ -296,7 +302,8 @@ class WanModel(torch.nn.Module):
self.freqs = precompute_freqs_cis_3d(head_dim)
if has_image_input:
self.img_emb = MLP(1280, dim) # clip_feature_dim = 1280
self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280
self.has_image_pos_emb = has_image_pos_emb
def patchify(self, x: torch.Tensor):
x = self.patch_embedding(x)
@@ -552,6 +559,21 @@ class WanModelStateDictConverter:
"num_layers": 40,
"eps": 1e-6
}
elif hash_state_dict_keys(state_dict) == "3ef3b1f8e1dab83d5b71fd7b617f859f":
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 36,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6,
"has_image_pos_emb": True
}
else:
config = {}
return state_dict, config

View File

@@ -211,6 +211,8 @@ class WanVideoPipeline(BasePipeline):
if end_image is not None:
end_image = self.preprocess_image(end_image.resize((width, height))).to(self.device)
vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
if self.dit.has_image_pos_emb:
clip_context = torch.concat([clip_context, self.image_encoder.encode_image([end_image])], dim=1)
msk[:, -1:] = 1
else:
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)