support wan-flf2v

This commit is contained in:
Artiprocher
2025-04-17 14:47:55 +08:00
parent e9e24b8cf1
commit 553b341f5f
5 changed files with 86 additions and 2 deletions

View File

@@ -126,6 +126,7 @@ model_loader_configs = [
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
(None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
(None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
(None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"),
(None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),

View File

@@ -223,7 +223,7 @@ class DiTBlock(nn.Module):
class MLP(torch.nn.Module):
def __init__(self, in_dim, out_dim):
def __init__(self, in_dim, out_dim, has_pos_emb=False):
super().__init__()
self.proj = torch.nn.Sequential(
nn.LayerNorm(in_dim),
@@ -232,8 +232,13 @@ class MLP(torch.nn.Module):
nn.Linear(in_dim, out_dim),
nn.LayerNorm(out_dim)
)
self.has_pos_emb = has_pos_emb
if has_pos_emb:
self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))
def forward(self, x):
if self.has_pos_emb:
x = x + self.emb_pos.to(dtype=x.dtype, device=x.device)
return self.proj(x)
@@ -266,6 +271,7 @@ class WanModel(torch.nn.Module):
num_heads: int,
num_layers: int,
has_image_input: bool,
has_image_pos_emb: bool = False,
):
super().__init__()
self.dim = dim
@@ -296,7 +302,8 @@ class WanModel(torch.nn.Module):
self.freqs = precompute_freqs_cis_3d(head_dim)
if has_image_input:
self.img_emb = MLP(1280, dim) # clip_feature_dim = 1280
self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280
self.has_image_pos_emb = has_image_pos_emb
def patchify(self, x: torch.Tensor):
x = self.patch_embedding(x)
@@ -552,6 +559,21 @@ class WanModelStateDictConverter:
"num_layers": 40,
"eps": 1e-6
}
elif hash_state_dict_keys(state_dict) == "3ef3b1f8e1dab83d5b71fd7b617f859f":
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 36,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6,
"has_image_pos_emb": True
}
else:
config = {}
return state_dict, config

View File

@@ -211,6 +211,8 @@ class WanVideoPipeline(BasePipeline):
if end_image is not None:
end_image = self.preprocess_image(end_image.resize((width, height))).to(self.device)
vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
if self.dit.has_image_pos_emb:
clip_context = torch.concat([clip_context, self.image_encoder.encode_image([end_image])], dim=1)
msk[:, -1:] = 1
else:
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)

View File

@@ -18,6 +18,7 @@ pip install -e .
|Wan Team|14B text-to-video|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)|[wan_14b_text_to_video.py](./wan_14b_text_to_video.py)|
|Wan Team|14B image-to-video 480P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)|
|Wan Team|14B image-to-video 720P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)|
|Wan Team|14B first-last-frame-to-video 720P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|[wan_14B_flf2v.py](./wan_14B_flf2v.py)|
|DiffSynth-Studio Team|1.3B aesthetics LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1).|
|DiffSynth-Studio Team|1.3B Highres-fix LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1).|
|DiffSynth-Studio Team|1.3B ExVideo LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1).|
@@ -110,6 +111,12 @@ https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f
https://github.com/user-attachments/assets/c0bdd5ca-292f-45ed-b9bc-afe193156e75
14B first-last-frame-to-video
|First frame|Last frame|Video|
|-|-|-|
|![Image](https://github.com/user-attachments/assets/b0d8225b-aee0-4129-b8e5-58c8523221a6)|![Image](https://github.com/user-attachments/assets/2f0c9bc5-07e2-45fa-8320-53d63a4fd203)|https://github.com/user-attachments/assets/2a6a2681-622c-4512-b852-5f22e73830b1|
## Train
We support Wan-Video LoRA training and full training. Here is a tutorial. This is an experimental feature. Below is a video sample generated from the character Keqing LoRA:

View File

@@ -0,0 +1,52 @@
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download, dataset_snapshot_download
from PIL import Image
# Download models
snapshot_download("Wan-AI/Wan2.1-FLF2V-14B-720P", local_dir="models/Wan-AI/Wan2.1-FLF2V-14B-720P")
# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
["models/Wan-AI/Wan2.1-FLF2V-14B-720P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"],
torch_dtype=torch.float32, # Image Encoder is loaded with float32
)
model_manager.load_models(
[
[
"models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00001-of-00007.safetensors",
"models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00002-of-00007.safetensors",
"models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00003-of-00007.safetensors",
"models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00004-of-00007.safetensors",
"models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00005-of-00007.safetensors",
"models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00006-of-00007.safetensors",
"models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00007-of-00007.safetensors",
],
"models/Wan-AI/Wan2.1-FLF2V-14B-720P/models_t5_umt5-xxl-enc-bf16.pth",
"models/Wan-AI/Wan2.1-FLF2V-14B-720P/Wan2.1_VAE.pth",
],
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)
# Download example image
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=["data/examples/wan/first_frame.jpeg", "data/examples/wan/last_frame.jpeg"]
)
# First and last frame to video
video = pipe(
prompt="写实风格,一个女生手持枯萎的花站在花园中,镜头逐渐拉远,记录下花园的全貌。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
num_inference_steps=30,
input_image=Image.open("data/examples/wan/first_frame.jpeg").resize((960, 960)),
end_image=Image.open("data/examples/wan/last_frame.jpeg").resize((960, 960)),
height=960, width=960,
seed=1, tiled=True
)
save_video(video, "video.mp4", fps=15, quality=5)