Merge pull request #1034 from modelscope/video_as_prompt

Video as prompt
Zhongjie Duan
2025-11-04 17:32:50 +08:00
committed by GitHub
parent 401d7d74a5
commit 8332ecebb7
13 changed files with 635 additions and 16 deletions

View File

@@ -237,6 +237,7 @@ save_video(video, "video1.mp4", fps=15, quality=5)
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](./examples/wanvideo/model_inference/krea-realtime-video.py)|[code](./examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](./examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](./examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](./examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|
|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](./examples/wanvideo/model_inference/LongCat-Video.py)|[code](./examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](./examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](./examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](./examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|
|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](./examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](./examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](./examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
</details>
@@ -387,6 +388,8 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44
## Update History
- **November 4, 2025**: We support the [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) model, which is trained on top of Wan 2.1 and enables motion generation conditioned on a reference video; see the minimal usage sketch after this list.
- **October 30, 2025**: We support [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) model, which enables text-to-video, image-to-video, and video continuation capabilities. This model adopts Wan's framework for both inference and training in this project.
- **October 27, 2025**: We support [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) model, further expanding Wan's ecosystem.
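
The sketch below is a minimal, condensed version of the full example in `examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py`; the prompts and local file paths are illustrative placeholders, while the model IDs, file patterns, and pipeline arguments follow that script.

```python
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig

# Load the Video-As-Prompt DiT together with the Wan 2.1 I2V text encoder, VAE, and CLIP image encoder.
pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
    ],
)
pipe.enable_vram_management()

image = Image.open("input_image.jpg").convert("RGB")         # first frame of the target video (illustrative path)
ref_video = VideoData("vap_ref.mp4", height=480, width=832)  # reference clip that supplies the motion (illustrative path)
video = pipe(
    prompt="caption of the video to generate",
    negative_prompt="low quality, blurred details",
    vap_prompt="caption of the reference clip",
    input_image=image,
    vap_video=[ref_video[i] for i in range(49)],              # 49 reference frames, matching num_frames
    height=480, width=832, num_frames=49,
    seed=42, tiled=True,
)
save_video(video, "video.mp4", fps=15, quality=5)
```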

View File

@@ -237,6 +237,7 @@ save_video(video, "video1.mp4", fps=15, quality=5)
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](./examples/wanvideo/model_inference/krea-realtime-video.py)|[code](./examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](./examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](./examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](./examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|
|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](./examples/wanvideo/model_inference/LongCat-Video.py)|[code](./examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](./examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](./examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](./examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|
|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](./examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](./examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](./examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
</details>
@@ -403,6 +404,8 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44
## Update History
- **November 4, 2025**: We support the [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) model, which is trained on top of Wan 2.1 and generates motion conditioned on a reference video.
- **October 30, 2025**: We support the [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) model, which supports text-to-video, image-to-video, and video continuation. In this project, it reuses Wan's framework for both inference and training.
- **October 27, 2025**: We support the [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) model, another addition to the Wan model ecosystem.

View File

@@ -64,6 +64,7 @@ from ..models.wan_video_motion_controller import WanMotionControllerModel
from ..models.wan_video_vace import VaceWanModel
from ..models.wav2vec import WanS2VAudioEncoder
from ..models.wan_video_animate_adapter import WanAnimateAdapter
from ..models.wan_video_mot import MotWanModel
from ..models.step1x_connector import Qwen2Connector
@@ -157,6 +158,7 @@ model_loader_configs = [
(None, "2267d489f0ceb9f21836532952852ee5", ["wan_video_dit"], [WanModel], "civitai"),
(None, "5ec04e02b42d2580483ad69f4e76346a", ["wan_video_dit"], [WanModel], "civitai"),
(None, "47dbeab5e560db3180adf51dc0232fb1", ["wan_video_dit"], [WanModel], "civitai"),
(None, "5f90e66a0672219f12d9a626c8c21f61", ["wan_video_dit", "wan_video_vap"], [WanModel,MotWanModel], "diffusers"),
(None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
(None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),

View File

@@ -437,6 +437,11 @@ class WanModelStateDictConverter:
"blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
"blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
"blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
"blocks.0.attn2.add_k_proj.bias":"blocks.0.cross_attn.k_img.bias",
"blocks.0.attn2.add_k_proj.weight":"blocks.0.cross_attn.k_img.weight",
"blocks.0.attn2.add_v_proj.bias":"blocks.0.cross_attn.v_img.bias",
"blocks.0.attn2.add_v_proj.weight":"blocks.0.cross_attn.v_img.weight",
"blocks.0.attn2.norm_added_k.weight":"blocks.0.cross_attn.norm_k_img.weight",
"blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
"blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
"blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
@@ -454,6 +459,14 @@ class WanModelStateDictConverter:
"condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
"condition_embedder.time_proj.bias": "time_projection.1.bias",
"condition_embedder.time_proj.weight": "time_projection.1.weight",
"condition_embedder.image_embedder.ff.net.0.proj.bias":"img_emb.proj.1.bias",
"condition_embedder.image_embedder.ff.net.0.proj.weight":"img_emb.proj.1.weight",
"condition_embedder.image_embedder.ff.net.2.bias":"img_emb.proj.3.bias",
"condition_embedder.image_embedder.ff.net.2.weight":"img_emb.proj.3.weight",
"condition_embedder.image_embedder.norm1.bias":"img_emb.proj.0.bias",
"condition_embedder.image_embedder.norm1.weight":"img_emb.proj.0.weight",
"condition_embedder.image_embedder.norm2.bias":"img_emb.proj.4.bias",
"condition_embedder.image_embedder.norm2.weight":"img_emb.proj.4.weight",
"patch_embedding.bias": "patch_embedding.bias",
"patch_embedding.weight": "patch_embedding.weight",
"scale_shift_table": "head.modulation",
@@ -470,7 +483,7 @@ class WanModelStateDictConverter:
name_ = rename_dict[name_]
name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
state_dict_[name_] = param
if hash_state_dict_keys(state_dict_) == "cb104773c6c2cb6df4f9529ad5c60d0b":
config = {
"model_type": "t2v",
"patch_size": (1, 2, 2),
@@ -488,6 +501,20 @@ class WanModelStateDictConverter:
"cross_attn_norm": True,
"eps": 1e-6,
}
elif hash_state_dict_keys(state_dict_) == "6bfcfb3b342cb286ce886889d519a77e":
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 36,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6
}
else:
config = {}
return state_dict_, config

View File

@@ -0,0 +1,281 @@
import torch
from .wan_video_dit import DiTBlock, SelfAttention, rope_apply, flash_attention, modulate, MLP
from .utils import hash_state_dict_keys
import einops
import torch.nn as nn
class MotSelfAttention(SelfAttention):
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
super().__init__(dim, num_heads, eps)
def forward(self, x, freqs, is_before_attn=False):
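        # Two-phase interface: with is_before_attn=True, return RoPE-rotated q/k/v so the caller can
        # concatenate them with the base block's q/k/v for joint attention; otherwise apply the output projection.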
if is_before_attn:
q = self.norm_q(self.q(x))
k = self.norm_k(self.k(x))
v = self.v(x)
q = rope_apply(q, freqs, self.num_heads)
k = rope_apply(k, freqs, self.num_heads)
return q, k, v
else:
return self.o(x)
class MotWanAttentionBlock(DiTBlock):
def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps)
self.block_id = block_id
self.self_attn = MotSelfAttention(dim, num_heads, eps)
def forward(self, wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot):
# 1. prepare scale parameter
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
wan_block.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
scale_params_mot_ref = self.modulation + t_mod_mot.float()
scale_params_mot_ref = einops.rearrange(scale_params_mot_ref, '(b n) t c -> b n t c', n=1)
shift_msa_mot_ref, scale_msa_mot_ref, gate_msa_mot_ref, c_shift_msa_mot_ref, c_scale_msa_mot_ref, c_gate_msa_mot_ref = scale_params_mot_ref.chunk(6, dim=2)
# 2. Self-attention
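        # The base stream and the reference (mot) stream are attended jointly: their q/k/v are concatenated
        # along the sequence dimension, run through a single attention call, then split back into two streams.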
input_x = modulate(wan_block.norm1(x), shift_msa, scale_msa)
# original block self-attn
attn1 = wan_block.self_attn
q = attn1.norm_q(attn1.q(input_x))
k = attn1.norm_k(attn1.k(input_x))
v = attn1.v(input_x)
q = rope_apply(q, freqs, attn1.num_heads)
k = rope_apply(k, freqs, attn1.num_heads)
# mot block self-attn
norm_x_mot = einops.rearrange(self.norm1(x_mot.float()), 'b (n t) c -> b n t c', n=1)
norm_x_mot = modulate(norm_x_mot, shift_msa_mot_ref, scale_msa_mot_ref).type_as(x_mot)
norm_x_mot = einops.rearrange(norm_x_mot, 'b n t c -> b (n t) c', n=1)
q_mot,k_mot,v_mot = self.self_attn(norm_x_mot, freqs_mot, is_before_attn=True)
tmp_hidden_states = flash_attention(
torch.cat([q, q_mot], dim=-2),
torch.cat([k, k_mot], dim=-2),
torch.cat([v, v_mot], dim=-2),
num_heads=attn1.num_heads)
attn_output, attn_output_mot = torch.split(tmp_hidden_states, [q.shape[-2], q_mot.shape[-2]], dim=-2)
attn_output = attn1.o(attn_output)
x = wan_block.gate(x, gate_msa, attn_output)
attn_output_mot = self.self_attn(x=attn_output_mot,freqs=freqs_mot, is_before_attn=False)
# gate
attn_output_mot = einops.rearrange(attn_output_mot, 'b (n t) c -> b n t c', n=1)
attn_output_mot = attn_output_mot * gate_msa_mot_ref
attn_output_mot = einops.rearrange(attn_output_mot, 'b n t c -> b (n t) c', n=1)
x_mot = (x_mot.float() + attn_output_mot).type_as(x_mot)
# 3. cross-attention and feed-forward
x = x + wan_block.cross_attn(wan_block.norm3(x), context)
input_x = modulate(wan_block.norm2(x), shift_mlp, scale_mlp)
x = wan_block.gate(x, gate_mlp, wan_block.ffn(input_x))
x_mot = x_mot + self.cross_attn(self.norm3(x_mot),context_mot)
# modulate
norm_x_mot_ref = einops.rearrange(self.norm2(x_mot.float()), 'b (n t) c -> b n t c', n=1)
norm_x_mot_ref = (norm_x_mot_ref * (1 + c_scale_msa_mot_ref) + c_shift_msa_mot_ref).type_as(x_mot)
norm_x_mot_ref = einops.rearrange(norm_x_mot_ref, 'b n t c -> b (n t) c', n=1)
input_x_mot = self.ffn(norm_x_mot_ref)
# gate
input_x_mot = einops.rearrange(input_x_mot, 'b (n t) c -> b n t c', n=1)
input_x_mot = input_x_mot.float() * c_gate_msa_mot_ref
input_x_mot = einops.rearrange(input_x_mot, 'b n t c -> b (n t) c', n=1)
x_mot = (x_mot.float() + input_x_mot).type_as(x_mot)
return x, x_mot
class MotWanModel(torch.nn.Module):
def __init__(
self,
mot_layers=(0, 4, 8, 12, 16, 20, 24, 28, 32, 36),
patch_size=(1, 2, 2),
has_image_input=True,
has_image_pos_emb=False,
dim=5120,
num_heads=40,
ffn_dim=13824,
freq_dim=256,
text_dim=4096,
in_dim=36,
eps=1e-6,
):
super().__init__()
self.mot_layers = mot_layers
self.freq_dim = freq_dim
self.dim = dim
self.mot_layers_mapping = {i: n for n, i in enumerate(self.mot_layers)}
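        # maps a base DiT block index (0, 4, 8, ...) to the index of the corresponding MoT block in self.blocks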
self.head_dim = dim // num_heads
self.patch_embedding = nn.Conv3d(
in_dim, dim, kernel_size=patch_size, stride=patch_size)
self.text_embedding = nn.Sequential(
nn.Linear(text_dim, dim),
nn.GELU(approximate='tanh'),
nn.Linear(dim, dim)
)
self.time_embedding = nn.Sequential(
nn.Linear(freq_dim, dim),
nn.SiLU(),
nn.Linear(dim, dim)
)
self.time_projection = nn.Sequential(
nn.SiLU(), nn.Linear(dim, dim * 6))
if has_image_input:
self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb)
# mot blocks
self.blocks = torch.nn.ModuleList([
MotWanAttentionBlock(has_image_input, dim, num_heads, ffn_dim, eps, block_id=i)
for i in self.mot_layers
])
def patchify(self, x: torch.Tensor):
x = self.patch_embedding(x)
return x
def compute_freqs_mot(self, f, h, w, end: int = 1024, theta: float = 10000.0):
def precompute_freqs_cis(dim: int, start: int = 0, end: int = 1024, theta: float = 10000.0):
# 1d rope precompute
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)
[: (dim // 2)].double() / dim))
freqs = torch.outer(torch.arange(start, end, device=freqs.device), freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
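        # Temporal RoPE positions for the reference stream start at -f (note the start=-f argument below),
        # so they never overlap with the main video's temporal positions 0..f-1; spatial positions are shared.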
f_freqs_cis = precompute_freqs_cis(self.head_dim - 2 * (self.head_dim // 3), -f, end, theta)
h_freqs_cis = precompute_freqs_cis(self.head_dim // 3, 0, end, theta)
w_freqs_cis = precompute_freqs_cis(self.head_dim // 3, 0, end, theta)
freqs = torch.cat([
f_freqs_cis[:f].view(f, 1, 1, -1).expand(f, h, w, -1),
h_freqs_cis[:h].view(1, h, 1, -1).expand(f, h, w, -1),
w_freqs_cis[:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1)
return freqs
def forward(self, wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot, block_id):
block = self.blocks[self.mot_layers_mapping[block_id]]
x, x_mot = block(wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot)
return x, x_mot
@staticmethod
def state_dict_converter():
return MotWanModelDictConverter()
class MotWanModelDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
"blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
"blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
"blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
"blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
"blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
"blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
"blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
"blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
"blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
"blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
"blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
"blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
"blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
"blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
"blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
"blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
"blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
"blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
"blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
"blocks.0.attn2.add_k_proj.bias":"blocks.0.cross_attn.k_img.bias",
"blocks.0.attn2.add_k_proj.weight":"blocks.0.cross_attn.k_img.weight",
"blocks.0.attn2.add_v_proj.bias":"blocks.0.cross_attn.v_img.bias",
"blocks.0.attn2.add_v_proj.weight":"blocks.0.cross_attn.v_img.weight",
"blocks.0.attn2.norm_added_k.weight":"blocks.0.cross_attn.norm_k_img.weight",
"blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
"blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
"blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
"blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
"blocks.0.norm2.bias": "blocks.0.norm3.bias",
"blocks.0.norm2.weight": "blocks.0.norm3.weight",
"blocks.0.scale_shift_table": "blocks.0.modulation",
"condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
"condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
"condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
"condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
"condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
"condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
"condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
"condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
"condition_embedder.time_proj.bias": "time_projection.1.bias",
"condition_embedder.time_proj.weight": "time_projection.1.weight",
"condition_embedder.image_embedder.ff.net.0.proj.bias":"img_emb.proj.1.bias",
"condition_embedder.image_embedder.ff.net.0.proj.weight":"img_emb.proj.1.weight",
"condition_embedder.image_embedder.ff.net.2.bias":"img_emb.proj.3.bias",
"condition_embedder.image_embedder.ff.net.2.weight":"img_emb.proj.3.weight",
"condition_embedder.image_embedder.norm1.bias":"img_emb.proj.0.bias",
"condition_embedder.image_embedder.norm1.weight":"img_emb.proj.0.weight",
"condition_embedder.image_embedder.norm2.bias":"img_emb.proj.4.bias",
"condition_embedder.image_embedder.norm2.weight":"img_emb.proj.4.weight",
"patch_embedding.bias": "patch_embedding.bias",
"patch_embedding.weight": "patch_embedding.weight",
"scale_shift_table": "head.modulation",
"proj_out.bias": "head.head.bias",
"proj_out.weight": "head.head.weight",
}
state_dict = {name: param for name, param in state_dict.items() if '_mot_ref' in name}
if hash_state_dict_keys(state_dict) == '19debbdb7f4d5ba93b4ddb1cbe5788c7':
mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36)
else:
mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36)
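        # NOTE: both branches currently select the same MoT layer layout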
mot_layers_mapping = {i:n for n, i in enumerate(mot_layers)}
state_dict_ = {}
for name, param in state_dict.items():
name = name.replace("_mot_ref", "")
if name in rename_dict:
state_dict_[rename_dict[name]] = param
else:
if name.split(".")[1].isdigit():
block_id = int(name.split(".")[1])
name = name.replace(str(block_id), str(mot_layers_mapping[block_id]))
name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
if name_ in rename_dict:
name_ = rename_dict[name_]
name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
state_dict_[name_] = param
if hash_state_dict_keys(state_dict_) == '6507c8213a3c476df5958b01dcf302d0': # vap 14B
config = {
"mot_layers":(0, 4, 8, 12, 16, 20, 24, 28, 32, 36),
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 36,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"num_heads": 40,
"eps": 1e-6
}
else:
config = {}
return state_dict_, config

View File

@@ -22,6 +22,7 @@ from ..models.wan_video_image_encoder import WanImageEncoder
from ..models.wan_video_vace import VaceWanModel
from ..models.wan_video_motion_controller import WanMotionControllerModel
from ..models.wan_video_animate_adapter import WanAnimateAdapter
from ..models.wan_video_mot import MotWanModel
from ..models.longcat_video_dit import LongCatVideoTransformer3DModel
from ..schedulers.flow_match import FlowMatchScheduler
from ..prompters import WanPrompter
@@ -47,9 +48,10 @@ class WanVideoPipeline(BasePipeline):
self.motion_controller: WanMotionControllerModel = None
self.vace: VaceWanModel = None
self.vace2: VaceWanModel = None
self.vap: MotWanModel = None
self.animate_adapter: WanAnimateAdapter = None
self.in_iteration_models = ("dit", "motion_controller", "vace", "animate_adapter")
self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2", "animate_adapter")
self.in_iteration_models = ("dit", "motion_controller", "vace", "animate_adapter", "vap")
self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2", "animate_adapter", "vap")
self.unit_runner = PipelineUnitRunner()
self.units = [
WanVideoUnit_ShapeChecker(),
@@ -69,6 +71,7 @@ class WanVideoPipeline(BasePipeline):
WanVideoPostUnit_AnimatePoseLatents(),
WanVideoPostUnit_AnimateFacePixelValues(),
WanVideoPostUnit_AnimateInpaint(),
WanVideoUnit_VAP(),
WanVideoUnit_UnifiedSequenceParallel(),
WanVideoUnit_TeaCache(),
WanVideoUnit_CfgMerger(),
@@ -392,6 +395,7 @@ class WanVideoPipeline(BasePipeline):
pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
vace = model_manager.fetch_model("wan_video_vace", index=2)
pipe.vap = model_manager.fetch_model("wan_video_vap")
if isinstance(vace, list):
pipe.vace, pipe.vace2 = vace
else:
@@ -455,6 +459,10 @@ class WanVideoPipeline(BasePipeline):
animate_face_video: Optional[list[Image.Image]] = None,
animate_inpaint_video: Optional[list[Image.Image]] = None,
animate_mask_video: Optional[list[Image.Image]] = None,
# VAP
vap_video: Optional[list[Image.Image]] = None,
vap_prompt: Optional[str] = " ",
negative_vap_prompt: Optional[str] = " ",
# Randomness
seed: Optional[int] = None,
rand_device: Optional[str] = "cpu",
@@ -493,10 +501,12 @@ class WanVideoPipeline(BasePipeline):
# Inputs
inputs_posi = {
"prompt": prompt,
"vap_prompt": vap_prompt,
"tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps,
}
inputs_nega = {
"negative_prompt": negative_prompt,
"negative_vap_prompt": negative_vap_prompt,
"tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps,
}
inputs_shared = {
@@ -516,6 +526,7 @@ class WanVideoPipeline(BasePipeline):
"sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride,
"input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video,
"animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video,
"vap_video": vap_video,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -927,6 +938,71 @@ class WanVideoUnit_VACE(PipelineUnit):
else:
return {"vace_context": None, "vace_scale": vace_scale}
class WanVideoUnit_VAP(PipelineUnit):
def __init__(self):
super().__init__(
take_over=True,
onload_model_names=("text_encoder", "vae", "image_encoder")
)
def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
if inputs_shared.get("vap_video") is None:
return inputs_shared, inputs_posi, inputs_nega
else:
# 1. encode vap prompt
pipe.load_models_to_device(["text_encoder"])
vap_prompt, negative_vap_prompt = inputs_posi.get("vap_prompt", ""), inputs_nega.get("negative_vap_prompt", "")
vap_prompt_emb = pipe.prompter.encode_prompt(vap_prompt, positive=inputs_posi.get('positive',None), device=pipe.device)
negative_vap_prompt_emb = pipe.prompter.encode_prompt(negative_vap_prompt, positive=inputs_nega.get('positive',None), device=pipe.device)
inputs_posi.update({"context_vap":vap_prompt_emb})
inputs_nega.update({"context_vap":negative_vap_prompt_emb})
# 2. prepare vap image clip embedding
pipe.load_models_to_device(["vae", "image_encoder"])
vap_video, end_image = inputs_shared.get("vap_video"), inputs_shared.get("end_image")
num_frames, height, width, mot_num = inputs_shared.get("num_frames"),inputs_shared.get("height"), inputs_shared.get("width"), inputs_shared.get("mot_num",1)
image_vap = pipe.preprocess_image(vap_video[0].resize((width, height))).to(pipe.device)
vap_clip_context = pipe.image_encoder.encode_image([image_vap])
if end_image is not None:
vap_end_image = pipe.preprocess_image(vap_video[-1].resize((width, height))).to(pipe.device)
if pipe.dit.has_image_pos_emb:
vap_clip_context = torch.concat([vap_clip_context, pipe.image_encoder.encode_image([vap_end_image])], dim=1)
vap_clip_context = vap_clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
inputs_shared.update({"vap_clip_feature":vap_clip_context})
# 3. prepare vap latents
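            # I2V-style conditioning for the reference clip: only the first frame (and the last one when an
            # end image is given) is marked as known in the mask; the remaining frames of the VAE input are zeros.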
msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
msk[:, 1:] = 0
if end_image is not None:
msk[:, -1:] = 1
last_image_vap = pipe.preprocess_image(vap_video[-1].resize((width, height))).to(pipe.device)
vae_input = torch.concat([image_vap.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image_vap.device), last_image_vap.transpose(0,1)],dim=1)
else:
vae_input = torch.concat([image_vap.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image_vap.device)], dim=1)
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
msk = msk.transpose(1, 2)[0]
tiled,tile_size,tile_stride = inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride")
y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
y = torch.concat([msk, y])
y = y.unsqueeze(0)
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
vap_video = pipe.preprocess_video(vap_video)
vap_latent = pipe.vae.encode(vap_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
vap_latent = torch.concat([vap_latent,y], dim=1).to(dtype=pipe.torch_dtype, device=pipe.device)
inputs_shared.update({"vap_hidden_state":vap_latent})
pipe.load_models_to_device([])
return inputs_shared, inputs_posi, inputs_nega
class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit):
@@ -1285,6 +1361,7 @@ def model_fn_wan_video(
dit: WanModel,
motion_controller: WanMotionControllerModel = None,
vace: VaceWanModel = None,
vap: MotWanModel = None,
animate_adapter: WanAnimateAdapter = None,
latents: torch.Tensor = None,
timestep: torch.Tensor = None,
@@ -1297,6 +1374,9 @@ def model_fn_wan_video(
audio_embeds: Optional[torch.Tensor] = None,
motion_latents: Optional[torch.Tensor] = None,
s2v_pose_latents: Optional[torch.Tensor] = None,
vap_hidden_state = None,
vap_clip_feature = None,
context_vap = None,
drop_motion_frames: bool = True,
tea_cache: TeaCache = None,
use_unified_sequence_parallel: bool = False,
@@ -1406,7 +1486,7 @@ def model_fn_wan_video(
if clip_feature is not None and dit.require_clip_embedding:
clip_embdding = dit.img_emb(clip_feature)
context = torch.cat([clip_embdding, context], dim=1)
# Camera control
x = dit.patchify(x, control_camera_latents_input)
@@ -1431,6 +1511,25 @@ def model_fn_wan_video(
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
# VAP
if vap is not None:
# hidden state
x_vap = vap_hidden_state
x_vap = vap.patchify(x_vap)
x_vap = rearrange(x_vap, 'b c f h w -> b (f h w) c').contiguous()
# Timestep
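        # The reference latents are clean (no noise is added to them), so the VAP branch is conditioned on a
        # constant timestep of 1 rather than on the current denoising timestep.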
clean_timestep = torch.ones(timestep.shape, device=timestep.device).to(timestep.dtype)
t = vap.time_embedding(sinusoidal_embedding_1d(vap.freq_dim, clean_timestep))
t_mod_vap = vap.time_projection(t).unflatten(1, (6, vap.dim))
# rope
freqs_vap = vap.compute_freqs_mot(f,h,w).to(x.device)
# context
vap_clip_embedding = vap.img_emb(vap_clip_feature)
context_vap = vap.text_embedding(context_vap)
context_vap = torch.cat([vap_clip_embedding, context_vap], dim=1)
# TeaCache
if tea_cache is not None:
@@ -1460,23 +1559,45 @@ def model_fn_wan_video(
return module(*inputs)
return custom_forward
def create_custom_forward_vap(block, vap):
def custom_forward(*inputs):
return vap(block, *inputs)
return custom_forward
for block_id, block in enumerate(dit.blocks):
# Block
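        # When VAP is enabled, blocks listed in vap.mot_layers additionally run joint attention with the
        # reference (VAP) stream; all other blocks run the standard forward pass.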
        if vap is not None and block_id in vap.mot_layers_mapping:
            if use_gradient_checkpointing_offload:
                with torch.autograd.graph.save_on_cpu():
                    x, x_vap = torch.utils.checkpoint.checkpoint(
                        create_custom_forward_vap(block, vap),
                        x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
                        use_reentrant=False,
                    )
            elif use_gradient_checkpointing:
                x, x_vap = torch.utils.checkpoint.checkpoint(
                    create_custom_forward_vap(block, vap),
                    x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
                    use_reentrant=False,
                )
            else:
                x, x_vap = vap(block, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id)
        else:
            if use_gradient_checkpointing_offload:
                with torch.autograd.graph.save_on_cpu():
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        x, context, t_mod, freqs,
                        use_reentrant=False,
                    )
            elif use_gradient_checkpointing:
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    x, context, t_mod, freqs,
                    use_reentrant=False,
                )
            else:
                x = block(x, context, t_mod, freqs)
# VACE
if vace_context is not None and block_id in vace.vace_layers_mapping:

View File

@@ -78,7 +78,7 @@ save_video(video, "video1.mp4", fps=15, quality=5)
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](./model_inference/krea-realtime-video.py)|[code](./model_training/full/krea-realtime-video.sh)|[code](./model_training/validate_full/krea-realtime-video.py)|[code](./model_training/lora/krea-realtime-video.sh)|[code](./model_training/validate_lora/krea-realtime-video.py)|
|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](./model_inference/LongCat-Video.py)|[code](./model_training/full/LongCat-Video.sh)|[code](./model_training/validate_full/LongCat-Video.py)|[code](./model_training/lora/LongCat-Video.sh)|[code](./model_training/validate_lora/LongCat-Video.py)|
|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](./model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](./model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](./model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](./model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](./model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
## Model Inference

View File

@@ -78,7 +78,7 @@ save_video(video, "video1.mp4", fps=15, quality=5)
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](./model_inference/krea-realtime-video.py)|[code](./model_training/full/krea-realtime-video.sh)|[code](./model_training/validate_full/krea-realtime-video.py)|[code](./model_training/lora/krea-realtime-video.sh)|[code](./model_training/validate_lora/krea-realtime-video.py)|
|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](./model_inference/LongCat-Video.py)|[code](./model_training/full/LongCat-Video.sh)|[code](./model_training/validate_full/LongCat-Video.py)|[code](./model_training/lora/LongCat-Video.sh)|[code](./model_training/validate_lora/LongCat-Video.py)|
|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](./model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](./model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](./model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](./model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](./model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
## Model Inference

View File

@@ -0,0 +1,63 @@
import torch
import PIL
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
from typing import List
def select_frames(video_frames: List[PIL.Image.Image], num: int, mode: str) -> List[PIL.Image.Image]:
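    # Return `num` frames from the clip: the first `num` ("first"), evenly spaced indices ("evenly"),
    # or a random contiguous window of length `num` ("random"); unknown modes return all frames.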
if len(video_frames) == 0:
return []
if mode == "first":
return video_frames[:num]
if mode == "evenly":
import torch as _torch
idx = _torch.linspace(0, len(video_frames) - 1, num).long().tolist()
return [video_frames[i] for i in idx]
if mode == "random":
if len(video_frames) <= num:
return video_frames
import random as _random
start = _random.randint(0, len(video_frames) - num)
return video_frames[start:start+num]
return video_frames
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()
dataset_snapshot_download("DiffSynth-Studio/example_video_dataset", allow_file_pattern="wanvap/*", local_dir="data/example_video_dataset")
ref_video_path = 'data/example_video_dataset/wanvap/vap_ref.mp4'
target_image_path = 'data/example_video_dataset/wanvap/input_image.jpg'
image = Image.open(target_image_path).convert("RGB")
ref_video = VideoData(ref_video_path, height=480, width=832)
ref_frames = select_frames(ref_video, num=49, mode="evenly")
vap_prompt = "A man stands with his back to the camera on a dirt path overlooking sun-drenched, rolling green tea plantations. He wears a blue and green plaid shirt, dark pants, and white shoes. As he turns to face the camera and spreads his arms, a brief, magical burst of sparkling golden light particles envelops him. Through this shimmer, he seamlessly transforms into a Labubu toy character. His head morphs into the iconic large, furry-eared head of the toy, featuring a wide grin with pointed teeth and red cheek markings. The character retains the man's original plaid shirt and clothing, which now fit its stylized, cartoonish body. The camera remains static throughout the transformation, positioned low among the tea bushes, maintaining a consistent view of the subject and the expansive scenery."
prompt = "A young woman with curly hair, wearing a green hijab and a floral dress, plays a violin in front of a vintage green car on a tree-lined street. She executes a swift counter-clockwise turn to face the camera. During the turn, a brilliant shower of golden, sparkling particles erupts and momentarily obscures her figure. As the particles fade, she is revealed to have seamlessly transformed into a Labubu toy character. This new figure, now with the toy's signature large ears, big eyes, and toothy grin, maintains the original pose and continues playing the violin. The character's clothing—the green hijab, floral dress, and black overcoat—remains identical to the woman's. Throughout this transition, the camera stays static, and the street-side environment remains completely consistent."
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
video = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
input_image=image,
seed=42, tiled=True,
height=480, width=832,
num_frames=49,
vap_video=ref_frames,
vap_prompt=vap_prompt,
negative_vap_prompt=negative_prompt,
)
save_video(video, "video.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,16 @@
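# Full training of the Video-As-Prompt adapter: only the "vap" module is trainable, and the "pipe.vap."
# prefix is stripped from checkpoint keys so the saved weights can be loaded back via pipe.vap.load_state_dict.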
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
--dataset_metadata_path data/example_video_dataset/metadata_vap.csv \
--data_file_keys "video,vap_video" \
--height 480 \
--width 832 \
--num_frames 49 \
--dataset_repeat 100 \
--model_id_with_origin_paths "ByteDance/Video-As-Prompt-Wan2.1-14B:transformer/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.vap." \
--output_path "./models/train/Video-As-Prompt-Wan2.1-14B_full" \
--trainable_models "vap" \
--extra_inputs "vap_video,input_image" \
--use_gradient_checkpointing_offload

View File

@@ -0,0 +1,18 @@
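# LoRA fine-tuning: the LoRA layers are attached to the base DiT (lora_base_model "dit") rather than to the
# VAP adapter, and checkpoints are saved with the "pipe.dit." prefix stripped.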
accelerate launch examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
--dataset_metadata_path data/example_video_dataset/metadata_vap.csv \
--data_file_keys "video,vap_video" \
--height 480 \
--width 832 \
--num_frames 49 \
--dataset_repeat 10 \
--model_id_with_origin_paths "ByteDance/Video-As-Prompt-Wan2.1-14B:transformer/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Video-As-Prompt-Wan2.1-14B_lora" \
--lora_base_model "dit" \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--lora_rank 32 \
--extra_inputs "vap_video,input_image" \
--use_gradient_checkpointing_offload

View File

@@ -0,0 +1,43 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData, load_state_dict
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
],
)
state_dict = load_state_dict("models/train/Video-As-Prompt-Wan2.1-14B_full/epoch-1.safetensors")
pipe.vap.load_state_dict(state_dict)
pipe.enable_vram_management()
ref_video_path = 'data/example_video_dataset/wanvap/vap_ref.mp4'
target_image_path = 'data/example_video_dataset/wanvap/input_image.jpg'
image = Image.open(target_image_path).convert("RGB")
ref_video = VideoData(ref_video_path, height=480, width=832)
ref_frames = [ref_video[i] for i in range(49)]
vap_prompt = "A man stands with his back to the camera on a dirt path overlooking sun-drenched, rolling green tea plantations. He wears a blue and green plaid shirt, dark pants, and white shoes. As he turns to face the camera and spreads his arms, a brief, magical burst of sparkling golden light particles envelops him. Through this shimmer, he seamlessly transforms into a Labubu toy character. His head morphs into the iconic large, furry-eared head of the toy, featuring a wide grin with pointed teeth and red cheek markings. The character retains the man's original plaid shirt and clothing, which now fit its stylized, cartoonish body. The camera remains static throughout the transformation, positioned low among the tea bushes, maintaining a consistent view of the subject and the expansive scenery."
prompt = "A young woman with curly hair, wearing a green hijab and a floral dress, plays a violin in front of a vintage green car on a tree-lined street. She executes a swift counter-clockwise turn to face the camera. During the turn, a brilliant shower of golden, sparkling particles erupts and momentarily obscures her figure. As the particles fade, she is revealed to have seamlessly transformed into a Labubu toy character. This new figure, now with the toy's signature large ears, big eyes, and toothy grin, maintains the original pose and continues playing the violin. The character's clothing—the green hijab, floral dress, and black overcoat—remains identical to the woman's. Throughout this transition, the camera stays static, and the street-side environment remains completely consistent."
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
video = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
input_image=image,
seed=42, tiled=True,
height=480, width=832,
num_frames=49,
vap_video=ref_frames,
vap_prompt=vap_prompt,
negative_vap_prompt=negative_prompt,
)
save_video(video, "video_Video-As-Prompt-Wan2.1-14B.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,42 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData, load_state_dict
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
],
)
pipe.load_lora(pipe.dit, "models/train/Video-As-Prompt-Wan2.1-14B_lora/epoch-4.safetensors", alpha=1)
pipe.enable_vram_management()
ref_video_path = 'data/example_video_dataset/wanvap/vap_ref.mp4'
target_image_path = 'data/example_video_dataset/wanvap/input_image.jpg'
image = Image.open(target_image_path).convert("RGB")
ref_video = VideoData(ref_video_path, height=480, width=832)
ref_frames = [ref_video[i] for i in range(49)]
vap_prompt = "A man stands with his back to the camera on a dirt path overlooking sun-drenched, rolling green tea plantations. He wears a blue and green plaid shirt, dark pants, and white shoes. As he turns to face the camera and spreads his arms, a brief, magical burst of sparkling golden light particles envelops him. Through this shimmer, he seamlessly transforms into a Labubu toy character. His head morphs into the iconic large, furry-eared head of the toy, featuring a wide grin with pointed teeth and red cheek markings. The character retains the man's original plaid shirt and clothing, which now fit its stylized, cartoonish body. The camera remains static throughout the transformation, positioned low among the tea bushes, maintaining a consistent view of the subject and the expansive scenery."
prompt = "A young woman with curly hair, wearing a green hijab and a floral dress, plays a violin in front of a vintage green car on a tree-lined street. She executes a swift counter-clockwise turn to face the camera. During the turn, a brilliant shower of golden, sparkling particles erupts and momentarily obscures her figure. As the particles fade, she is revealed to have seamlessly transformed into a Labubu toy character. This new figure, now with the toy's signature large ears, big eyes, and toothy grin, maintains the original pose and continues playing the violin. The character's clothing—the green hijab, floral dress, and black overcoat—remains identical to the woman's. Throughout this transition, the camera stays static, and the street-side environment remains completely consistent."
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
video = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
input_image=image,
seed=42, tiled=True,
height=480, width=832,
num_frames=49,
vap_video=ref_frames,
vap_prompt=vap_prompt,
negative_vap_prompt=negative_prompt,
)
save_video(video, "video_Video-As-Prompt-Wan2.1-14B.mp4", fps=15, quality=5)