mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Merge pull request #1034 from modelscope/video_as_prompt
Video as prompt
This commit is contained in:
@@ -64,6 +64,7 @@ from ..models.wan_video_motion_controller import WanMotionControllerModel
|
||||
from ..models.wan_video_vace import VaceWanModel
|
||||
from ..models.wav2vec import WanS2VAudioEncoder
|
||||
from ..models.wan_video_animate_adapter import WanAnimateAdapter
|
||||
from ..models.wan_video_mot import MotWanModel
|
||||
|
||||
from ..models.step1x_connector import Qwen2Connector
|
||||
|
||||
@@ -157,6 +158,7 @@ model_loader_configs = [
|
||||
(None, "2267d489f0ceb9f21836532952852ee5", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "5ec04e02b42d2580483ad69f4e76346a", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "47dbeab5e560db3180adf51dc0232fb1", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "5f90e66a0672219f12d9a626c8c21f61", ["wan_video_dit", "wan_video_vap"], [WanModel,MotWanModel], "diffusers"),
|
||||
(None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
|
||||
(None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
|
||||
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
|
||||
|
||||
@@ -437,6 +437,11 @@ class WanModelStateDictConverter:
|
||||
"blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
|
||||
"blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
|
||||
"blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
|
||||
"blocks.0.attn2.add_k_proj.bias":"blocks.0.cross_attn.k_img.bias",
|
||||
"blocks.0.attn2.add_k_proj.weight":"blocks.0.cross_attn.k_img.weight",
|
||||
"blocks.0.attn2.add_v_proj.bias":"blocks.0.cross_attn.v_img.bias",
|
||||
"blocks.0.attn2.add_v_proj.weight":"blocks.0.cross_attn.v_img.weight",
|
||||
"blocks.0.attn2.norm_added_k.weight":"blocks.0.cross_attn.norm_k_img.weight",
|
||||
"blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
|
||||
"blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
|
||||
"blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
|
||||
@@ -454,6 +459,14 @@ class WanModelStateDictConverter:
|
||||
"condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
|
||||
"condition_embedder.time_proj.bias": "time_projection.1.bias",
|
||||
"condition_embedder.time_proj.weight": "time_projection.1.weight",
|
||||
"condition_embedder.image_embedder.ff.net.0.proj.bias":"img_emb.proj.1.bias",
|
||||
"condition_embedder.image_embedder.ff.net.0.proj.weight":"img_emb.proj.1.weight",
|
||||
"condition_embedder.image_embedder.ff.net.2.bias":"img_emb.proj.3.bias",
|
||||
"condition_embedder.image_embedder.ff.net.2.weight":"img_emb.proj.3.weight",
|
||||
"condition_embedder.image_embedder.norm1.bias":"img_emb.proj.0.bias",
|
||||
"condition_embedder.image_embedder.norm1.weight":"img_emb.proj.0.weight",
|
||||
"condition_embedder.image_embedder.norm2.bias":"img_emb.proj.4.bias",
|
||||
"condition_embedder.image_embedder.norm2.weight":"img_emb.proj.4.weight",
|
||||
"patch_embedding.bias": "patch_embedding.bias",
|
||||
"patch_embedding.weight": "patch_embedding.weight",
|
||||
"scale_shift_table": "head.modulation",
|
||||
@@ -470,7 +483,7 @@ class WanModelStateDictConverter:
|
||||
name_ = rename_dict[name_]
|
||||
name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
|
||||
state_dict_[name_] = param
|
||||
if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b":
|
||||
if hash_state_dict_keys(state_dict_) == "cb104773c6c2cb6df4f9529ad5c60d0b":
|
||||
config = {
|
||||
"model_type": "t2v",
|
||||
"patch_size": (1, 2, 2),
|
||||
@@ -488,6 +501,20 @@ class WanModelStateDictConverter:
|
||||
"cross_attn_norm": True,
|
||||
"eps": 1e-6,
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict_) == "6bfcfb3b342cb286ce886889d519a77e":
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 36,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6
|
||||
}
|
||||
else:
|
||||
config = {}
|
||||
return state_dict_, config
|
||||
|
||||
281
diffsynth/models/wan_video_mot.py
Normal file
281
diffsynth/models/wan_video_mot.py
Normal file
@@ -0,0 +1,281 @@
|
||||
import torch
|
||||
from .wan_video_dit import DiTBlock, SelfAttention, rope_apply, flash_attention, modulate, MLP
|
||||
from .utils import hash_state_dict_keys
|
||||
import einops
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class MotSelfAttention(SelfAttention):
|
||||
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
|
||||
super().__init__(dim, num_heads, eps)
|
||||
def forward(self, x, freqs, is_before_attn=False):
|
||||
if is_before_attn:
|
||||
q = self.norm_q(self.q(x))
|
||||
k = self.norm_k(self.k(x))
|
||||
v = self.v(x)
|
||||
q = rope_apply(q, freqs, self.num_heads)
|
||||
k = rope_apply(k, freqs, self.num_heads)
|
||||
return q, k, v
|
||||
else:
|
||||
return self.o(x)
|
||||
|
||||
|
||||
class MotWanAttentionBlock(DiTBlock):
|
||||
def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
|
||||
super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps)
|
||||
self.block_id = block_id
|
||||
|
||||
self.self_attn = MotSelfAttention(dim, num_heads, eps)
|
||||
|
||||
|
||||
def forward(self, wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot):
|
||||
|
||||
# 1. prepare scale parameter
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
wan_block.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
|
||||
|
||||
scale_params_mot_ref = self.modulation + t_mod_mot.float()
|
||||
scale_params_mot_ref = einops.rearrange(scale_params_mot_ref, '(b n) t c -> b n t c', n=1)
|
||||
shift_msa_mot_ref, scale_msa_mot_ref, gate_msa_mot_ref, c_shift_msa_mot_ref, c_scale_msa_mot_ref, c_gate_msa_mot_ref = scale_params_mot_ref.chunk(6, dim=2)
|
||||
|
||||
# 2. Self-attention
|
||||
input_x = modulate(wan_block.norm1(x), shift_msa, scale_msa)
|
||||
# original block self-attn
|
||||
attn1 = wan_block.self_attn
|
||||
q = attn1.norm_q(attn1.q(input_x))
|
||||
k = attn1.norm_k(attn1.k(input_x))
|
||||
v = attn1.v(input_x)
|
||||
q = rope_apply(q, freqs, attn1.num_heads)
|
||||
k = rope_apply(k, freqs, attn1.num_heads)
|
||||
|
||||
# mot block self-attn
|
||||
norm_x_mot = einops.rearrange(self.norm1(x_mot.float()), 'b (n t) c -> b n t c', n=1)
|
||||
norm_x_mot = modulate(norm_x_mot, shift_msa_mot_ref, scale_msa_mot_ref).type_as(x_mot)
|
||||
norm_x_mot = einops.rearrange(norm_x_mot, 'b n t c -> b (n t) c', n=1)
|
||||
q_mot,k_mot,v_mot = self.self_attn(norm_x_mot, freqs_mot, is_before_attn=True)
|
||||
|
||||
tmp_hidden_states = flash_attention(
|
||||
torch.cat([q, q_mot], dim=-2),
|
||||
torch.cat([k, k_mot], dim=-2),
|
||||
torch.cat([v, v_mot], dim=-2),
|
||||
num_heads=attn1.num_heads)
|
||||
|
||||
attn_output, attn_output_mot = torch.split(tmp_hidden_states, [q.shape[-2], q_mot.shape[-2]], dim=-2)
|
||||
|
||||
attn_output = attn1.o(attn_output)
|
||||
x = wan_block.gate(x, gate_msa, attn_output)
|
||||
|
||||
attn_output_mot = self.self_attn(x=attn_output_mot,freqs=freqs_mot, is_before_attn=False)
|
||||
# gate
|
||||
attn_output_mot = einops.rearrange(attn_output_mot, 'b (n t) c -> b n t c', n=1)
|
||||
attn_output_mot = attn_output_mot * gate_msa_mot_ref
|
||||
attn_output_mot = einops.rearrange(attn_output_mot, 'b n t c -> b (n t) c', n=1)
|
||||
x_mot = (x_mot.float() + attn_output_mot).type_as(x_mot)
|
||||
|
||||
# 3. cross-attention and feed-forward
|
||||
x = x + wan_block.cross_attn(wan_block.norm3(x), context)
|
||||
input_x = modulate(wan_block.norm2(x), shift_mlp, scale_mlp)
|
||||
x = wan_block.gate(x, gate_mlp, wan_block.ffn(input_x))
|
||||
|
||||
x_mot = x_mot + self.cross_attn(self.norm3(x_mot),context_mot)
|
||||
# modulate
|
||||
norm_x_mot_ref = einops.rearrange(self.norm2(x_mot.float()), 'b (n t) c -> b n t c', n=1)
|
||||
norm_x_mot_ref = (norm_x_mot_ref * (1 + c_scale_msa_mot_ref) + c_shift_msa_mot_ref).type_as(x_mot)
|
||||
norm_x_mot_ref = einops.rearrange(norm_x_mot_ref, 'b n t c -> b (n t) c', n=1)
|
||||
input_x_mot = self.ffn(norm_x_mot_ref)
|
||||
# gate
|
||||
input_x_mot = einops.rearrange(input_x_mot, 'b (n t) c -> b n t c', n=1)
|
||||
input_x_mot = input_x_mot.float() * c_gate_msa_mot_ref
|
||||
input_x_mot = einops.rearrange(input_x_mot, 'b n t c -> b (n t) c', n=1)
|
||||
x_mot = (x_mot.float() + input_x_mot).type_as(x_mot)
|
||||
|
||||
return x, x_mot
|
||||
|
||||
|
||||
class MotWanModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
mot_layers=(0, 4, 8, 12, 16, 20, 24, 28, 32, 36),
|
||||
patch_size=(1, 2, 2),
|
||||
has_image_input=True,
|
||||
has_image_pos_emb=False,
|
||||
dim=5120,
|
||||
num_heads=40,
|
||||
ffn_dim=13824,
|
||||
freq_dim=256,
|
||||
text_dim=4096,
|
||||
in_dim=36,
|
||||
eps=1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
self.mot_layers = mot_layers
|
||||
self.freq_dim = freq_dim
|
||||
self.dim = dim
|
||||
|
||||
self.mot_layers_mapping = {i: n for n, i in enumerate(self.mot_layers)}
|
||||
self.head_dim = dim // num_heads
|
||||
|
||||
self.patch_embedding = nn.Conv3d(
|
||||
in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
self.text_embedding = nn.Sequential(
|
||||
nn.Linear(text_dim, dim),
|
||||
nn.GELU(approximate='tanh'),
|
||||
nn.Linear(dim, dim)
|
||||
)
|
||||
self.time_embedding = nn.Sequential(
|
||||
nn.Linear(freq_dim, dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, dim)
|
||||
)
|
||||
self.time_projection = nn.Sequential(
|
||||
nn.SiLU(), nn.Linear(dim, dim * 6))
|
||||
if has_image_input:
|
||||
self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb)
|
||||
|
||||
# mot blocks
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
MotWanAttentionBlock(has_image_input, dim, num_heads, ffn_dim, eps, block_id=i)
|
||||
for i in self.mot_layers
|
||||
])
|
||||
|
||||
|
||||
def patchify(self, x: torch.Tensor):
|
||||
x = self.patch_embedding(x)
|
||||
return x
|
||||
|
||||
def compute_freqs_mot(self, f, h, w, end: int = 1024, theta: float = 10000.0):
|
||||
def precompute_freqs_cis(dim: int, start: int = 0, end: int = 1024, theta: float = 10000.0):
|
||||
# 1d rope precompute
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)
|
||||
[: (dim // 2)].double() / dim))
|
||||
freqs = torch.outer(torch.arange(start, end, device=freqs.device), freqs)
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
||||
return freqs_cis
|
||||
|
||||
f_freqs_cis = precompute_freqs_cis(self.head_dim - 2 * (self.head_dim // 3), -f, end, theta)
|
||||
h_freqs_cis = precompute_freqs_cis(self.head_dim // 3, 0, end, theta)
|
||||
w_freqs_cis = precompute_freqs_cis(self.head_dim // 3, 0, end, theta)
|
||||
|
||||
freqs = torch.cat([
|
||||
f_freqs_cis[:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||
h_freqs_cis[:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
w_freqs_cis[:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||
], dim=-1).reshape(f * h * w, 1, -1)
|
||||
return freqs
|
||||
|
||||
def forward(self, wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot, block_id):
|
||||
block = self.blocks[self.mot_layers_mapping[block_id]]
|
||||
x, x_mot = block(wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot)
|
||||
return x, x_mot
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return MotWanModelDictConverter()
|
||||
|
||||
|
||||
class MotWanModelDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
|
||||
rename_dict = {
|
||||
"blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
|
||||
"blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
|
||||
"blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
|
||||
"blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
|
||||
"blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
|
||||
"blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
|
||||
"blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
|
||||
"blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
|
||||
"blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
|
||||
"blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
|
||||
"blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
|
||||
"blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
|
||||
"blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
|
||||
"blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
|
||||
"blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
|
||||
"blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
|
||||
"blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
|
||||
"blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
|
||||
"blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
|
||||
"blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
|
||||
"blocks.0.attn2.add_k_proj.bias":"blocks.0.cross_attn.k_img.bias",
|
||||
"blocks.0.attn2.add_k_proj.weight":"blocks.0.cross_attn.k_img.weight",
|
||||
"blocks.0.attn2.add_v_proj.bias":"blocks.0.cross_attn.v_img.bias",
|
||||
"blocks.0.attn2.add_v_proj.weight":"blocks.0.cross_attn.v_img.weight",
|
||||
"blocks.0.attn2.norm_added_k.weight":"blocks.0.cross_attn.norm_k_img.weight",
|
||||
"blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
|
||||
"blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
|
||||
"blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
|
||||
"blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
|
||||
"blocks.0.norm2.bias": "blocks.0.norm3.bias",
|
||||
"blocks.0.norm2.weight": "blocks.0.norm3.weight",
|
||||
"blocks.0.scale_shift_table": "blocks.0.modulation",
|
||||
"condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
|
||||
"condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
|
||||
"condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
|
||||
"condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
|
||||
"condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
|
||||
"condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
|
||||
"condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
|
||||
"condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
|
||||
"condition_embedder.time_proj.bias": "time_projection.1.bias",
|
||||
"condition_embedder.time_proj.weight": "time_projection.1.weight",
|
||||
"condition_embedder.image_embedder.ff.net.0.proj.bias":"img_emb.proj.1.bias",
|
||||
"condition_embedder.image_embedder.ff.net.0.proj.weight":"img_emb.proj.1.weight",
|
||||
"condition_embedder.image_embedder.ff.net.2.bias":"img_emb.proj.3.bias",
|
||||
"condition_embedder.image_embedder.ff.net.2.weight":"img_emb.proj.3.weight",
|
||||
"condition_embedder.image_embedder.norm1.bias":"img_emb.proj.0.bias",
|
||||
"condition_embedder.image_embedder.norm1.weight":"img_emb.proj.0.weight",
|
||||
"condition_embedder.image_embedder.norm2.bias":"img_emb.proj.4.bias",
|
||||
"condition_embedder.image_embedder.norm2.weight":"img_emb.proj.4.weight",
|
||||
"patch_embedding.bias": "patch_embedding.bias",
|
||||
"patch_embedding.weight": "patch_embedding.weight",
|
||||
"scale_shift_table": "head.modulation",
|
||||
"proj_out.bias": "head.head.bias",
|
||||
"proj_out.weight": "head.head.weight",
|
||||
}
|
||||
state_dict = {name: param for name, param in state_dict.items() if '_mot_ref' in name}
|
||||
if hash_state_dict_keys(state_dict) == '19debbdb7f4d5ba93b4ddb1cbe5788c7':
|
||||
mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36)
|
||||
else:
|
||||
mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36)
|
||||
mot_layers_mapping = {i:n for n, i in enumerate(mot_layers)}
|
||||
|
||||
state_dict_ = {}
|
||||
|
||||
for name, param in state_dict.items():
|
||||
name = name.replace("_mot_ref", "")
|
||||
if name in rename_dict:
|
||||
state_dict_[rename_dict[name]] = param
|
||||
else:
|
||||
if name.split(".")[1].isdigit():
|
||||
block_id = int(name.split(".")[1])
|
||||
name = name.replace(str(block_id), str(mot_layers_mapping[block_id]))
|
||||
name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
|
||||
if name_ in rename_dict:
|
||||
name_ = rename_dict[name_]
|
||||
name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
|
||||
state_dict_[name_] = param
|
||||
|
||||
if hash_state_dict_keys(state_dict_) == '6507c8213a3c476df5958b01dcf302d0': # vap 14B
|
||||
config = {
|
||||
"mot_layers":(0, 4, 8, 12, 16, 20, 24, 28, 32, 36),
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 36,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"num_heads": 40,
|
||||
"eps": 1e-6
|
||||
}
|
||||
else:
|
||||
config = {}
|
||||
return state_dict_, config
|
||||
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ from ..models.wan_video_image_encoder import WanImageEncoder
|
||||
from ..models.wan_video_vace import VaceWanModel
|
||||
from ..models.wan_video_motion_controller import WanMotionControllerModel
|
||||
from ..models.wan_video_animate_adapter import WanAnimateAdapter
|
||||
from ..models.wan_video_mot import MotWanModel
|
||||
from ..models.longcat_video_dit import LongCatVideoTransformer3DModel
|
||||
from ..schedulers.flow_match import FlowMatchScheduler
|
||||
from ..prompters import WanPrompter
|
||||
@@ -47,9 +48,10 @@ class WanVideoPipeline(BasePipeline):
|
||||
self.motion_controller: WanMotionControllerModel = None
|
||||
self.vace: VaceWanModel = None
|
||||
self.vace2: VaceWanModel = None
|
||||
self.vap: MotWanModel = None
|
||||
self.animate_adapter: WanAnimateAdapter = None
|
||||
self.in_iteration_models = ("dit", "motion_controller", "vace", "animate_adapter")
|
||||
self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2", "animate_adapter")
|
||||
self.in_iteration_models = ("dit", "motion_controller", "vace", "animate_adapter", "vap")
|
||||
self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2", "animate_adapter", "vap")
|
||||
self.unit_runner = PipelineUnitRunner()
|
||||
self.units = [
|
||||
WanVideoUnit_ShapeChecker(),
|
||||
@@ -69,6 +71,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
WanVideoPostUnit_AnimatePoseLatents(),
|
||||
WanVideoPostUnit_AnimateFacePixelValues(),
|
||||
WanVideoPostUnit_AnimateInpaint(),
|
||||
WanVideoUnit_VAP(),
|
||||
WanVideoUnit_UnifiedSequenceParallel(),
|
||||
WanVideoUnit_TeaCache(),
|
||||
WanVideoUnit_CfgMerger(),
|
||||
@@ -392,6 +395,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
|
||||
pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
|
||||
vace = model_manager.fetch_model("wan_video_vace", index=2)
|
||||
pipe.vap = model_manager.fetch_model("wan_video_vap")
|
||||
if isinstance(vace, list):
|
||||
pipe.vace, pipe.vace2 = vace
|
||||
else:
|
||||
@@ -455,6 +459,10 @@ class WanVideoPipeline(BasePipeline):
|
||||
animate_face_video: Optional[list[Image.Image]] = None,
|
||||
animate_inpaint_video: Optional[list[Image.Image]] = None,
|
||||
animate_mask_video: Optional[list[Image.Image]] = None,
|
||||
# VAP
|
||||
vap_video: Optional[list[Image.Image]] = None,
|
||||
vap_prompt: Optional[str] = " ",
|
||||
negative_vap_prompt: Optional[str] = " ",
|
||||
# Randomness
|
||||
seed: Optional[int] = None,
|
||||
rand_device: Optional[str] = "cpu",
|
||||
@@ -493,10 +501,12 @@ class WanVideoPipeline(BasePipeline):
|
||||
# Inputs
|
||||
inputs_posi = {
|
||||
"prompt": prompt,
|
||||
"vap_prompt": vap_prompt,
|
||||
"tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps,
|
||||
}
|
||||
inputs_nega = {
|
||||
"negative_prompt": negative_prompt,
|
||||
"negative_vap_prompt": negative_vap_prompt,
|
||||
"tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps,
|
||||
}
|
||||
inputs_shared = {
|
||||
@@ -516,6 +526,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
"sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride,
|
||||
"input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video,
|
||||
"animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video,
|
||||
"vap_video": vap_video,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
@@ -927,6 +938,71 @@ class WanVideoUnit_VACE(PipelineUnit):
|
||||
else:
|
||||
return {"vace_context": None, "vace_scale": vace_scale}
|
||||
|
||||
class WanVideoUnit_VAP(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
take_over=True,
|
||||
onload_model_names=("text_encoder", "vae", "image_encoder")
|
||||
)
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||
if inputs_shared.get("vap_video") is None:
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
else:
|
||||
# 1. encode vap prompt
|
||||
pipe.load_models_to_device(["text_encoder"])
|
||||
vap_prompt, negative_vap_prompt = inputs_posi.get("vap_prompt", ""), inputs_nega.get("negative_vap_prompt", "")
|
||||
vap_prompt_emb = pipe.prompter.encode_prompt(vap_prompt, positive=inputs_posi.get('positive',None), device=pipe.device)
|
||||
negative_vap_prompt_emb = pipe.prompter.encode_prompt(negative_vap_prompt, positive=inputs_nega.get('positive',None), device=pipe.device)
|
||||
inputs_posi.update({"context_vap":vap_prompt_emb})
|
||||
inputs_nega.update({"context_vap":negative_vap_prompt_emb})
|
||||
# 2. prepare vap image clip embedding
|
||||
pipe.load_models_to_device(["vae", "image_encoder"])
|
||||
vap_video, end_image = inputs_shared.get("vap_video"), inputs_shared.get("end_image")
|
||||
|
||||
num_frames, height, width, mot_num = inputs_shared.get("num_frames"),inputs_shared.get("height"), inputs_shared.get("width"), inputs_shared.get("mot_num",1)
|
||||
|
||||
image_vap = pipe.preprocess_image(vap_video[0].resize((width, height))).to(pipe.device)
|
||||
|
||||
vap_clip_context = pipe.image_encoder.encode_image([image_vap])
|
||||
if end_image is not None:
|
||||
vap_end_image = pipe.preprocess_image(vap_video[-1].resize((width, height))).to(pipe.device)
|
||||
if pipe.dit.has_image_pos_emb:
|
||||
vap_clip_context = torch.concat([vap_clip_context, pipe.image_encoder.encode_image([vap_end_image])], dim=1)
|
||||
vap_clip_context = vap_clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
inputs_shared.update({"vap_clip_feature":vap_clip_context})
|
||||
|
||||
# 3. prepare vap latents
|
||||
msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
|
||||
msk[:, 1:] = 0
|
||||
if end_image is not None:
|
||||
msk[:, -1:] = 1
|
||||
last_image_vap = pipe.preprocess_image(vap_video[-1].resize((width, height))).to(pipe.device)
|
||||
vae_input = torch.concat([image_vap.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image_vap.device), last_image_vap.transpose(0,1)],dim=1)
|
||||
else:
|
||||
vae_input = torch.concat([image_vap.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image_vap.device)], dim=1)
|
||||
|
||||
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
||||
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
|
||||
msk = msk.transpose(1, 2)[0]
|
||||
|
||||
tiled,tile_size,tile_stride = inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride")
|
||||
|
||||
y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
y = torch.concat([msk, y])
|
||||
y = y.unsqueeze(0)
|
||||
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
|
||||
vap_video = pipe.preprocess_video(vap_video)
|
||||
vap_latent = pipe.vae.encode(vap_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
|
||||
vap_latent = torch.concat([vap_latent,y], dim=1).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
inputs_shared.update({"vap_hidden_state":vap_latent})
|
||||
pipe.load_models_to_device([])
|
||||
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
|
||||
class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit):
|
||||
@@ -1285,6 +1361,7 @@ def model_fn_wan_video(
|
||||
dit: WanModel,
|
||||
motion_controller: WanMotionControllerModel = None,
|
||||
vace: VaceWanModel = None,
|
||||
vap: MotWanModel = None,
|
||||
animate_adapter: WanAnimateAdapter = None,
|
||||
latents: torch.Tensor = None,
|
||||
timestep: torch.Tensor = None,
|
||||
@@ -1297,6 +1374,9 @@ def model_fn_wan_video(
|
||||
audio_embeds: Optional[torch.Tensor] = None,
|
||||
motion_latents: Optional[torch.Tensor] = None,
|
||||
s2v_pose_latents: Optional[torch.Tensor] = None,
|
||||
vap_hidden_state = None,
|
||||
vap_clip_feature = None,
|
||||
context_vap = None,
|
||||
drop_motion_frames: bool = True,
|
||||
tea_cache: TeaCache = None,
|
||||
use_unified_sequence_parallel: bool = False,
|
||||
@@ -1406,7 +1486,7 @@ def model_fn_wan_video(
|
||||
if clip_feature is not None and dit.require_clip_embedding:
|
||||
clip_embdding = dit.img_emb(clip_feature)
|
||||
context = torch.cat([clip_embdding, context], dim=1)
|
||||
|
||||
|
||||
# Camera control
|
||||
x = dit.patchify(x, control_camera_latents_input)
|
||||
|
||||
@@ -1431,6 +1511,25 @@ def model_fn_wan_video(
|
||||
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
||||
|
||||
# VAP
|
||||
if vap is not None:
|
||||
# hidden state
|
||||
x_vap = vap_hidden_state
|
||||
x_vap = vap.patchify(x_vap)
|
||||
x_vap = rearrange(x_vap, 'b c f h w -> b (f h w) c').contiguous()
|
||||
# Timestep
|
||||
clean_timestep = torch.ones(timestep.shape, device=timestep.device).to(timestep.dtype)
|
||||
t = vap.time_embedding(sinusoidal_embedding_1d(vap.freq_dim, clean_timestep))
|
||||
t_mod_vap = vap.time_projection(t).unflatten(1, (6, vap.dim))
|
||||
|
||||
# rope
|
||||
freqs_vap = vap.compute_freqs_mot(f,h,w).to(x.device)
|
||||
|
||||
# context
|
||||
vap_clip_embedding = vap.img_emb(vap_clip_feature)
|
||||
context_vap = vap.text_embedding(context_vap)
|
||||
context_vap = torch.cat([vap_clip_embedding, context_vap], dim=1)
|
||||
|
||||
# TeaCache
|
||||
if tea_cache is not None:
|
||||
@@ -1460,23 +1559,45 @@ def model_fn_wan_video(
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
def create_custom_forward_vap(block, vap):
|
||||
def custom_forward(*inputs):
|
||||
return vap(block, *inputs)
|
||||
return custom_forward
|
||||
|
||||
for block_id, block in enumerate(dit.blocks):
|
||||
# Block
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
if vap is not None and block_id in vap.mot_layers_mapping:
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
x, x_vap = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward_vap(block, vap),
|
||||
x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
|
||||
use_reentrant=False,
|
||||
)
|
||||
elif use_gradient_checkpointing:
|
||||
x, x_vap = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward_vap(block, vap),
|
||||
x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x, x_vap = vap(block, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id)
|
||||
else:
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
elif use_gradient_checkpointing:
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
elif use_gradient_checkpointing:
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = block(x, context, t_mod, freqs)
|
||||
else:
|
||||
x = block(x, context, t_mod, freqs)
|
||||
|
||||
# VACE
|
||||
if vace_context is not None and block_id in vace.vace_layers_mapping:
|
||||
|
||||
Reference in New Issue
Block a user