Support WanToDance (#1361)

* support wantodance

* update docs

* bugfix
This commit is contained in:
Zhongjie Duan
2026-03-20 16:40:35 +08:00
committed by GitHub
parent ba0626e38f
commit 52ba5d414e
22 changed files with 1210 additions and 13 deletions

View File

@@ -307,6 +307,13 @@ wan_series = [
"model_class": "diffsynth.models.wav2vec.WanS2VAudioEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.wans2v_audio_encoder.WanS2VAudioEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="global_model.safetensors")
"model_hash": "eb18873fc0ba77b541eb7b62dbcd2059",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'wantodance_enable_music_inject': True, 'wantodance_music_inject_layers': [0, 4, 8, 12, 16, 20, 24, 27], 'wantodance_enable_refimage': True, 'has_ref_conv': True, 'wantodance_enable_refface': False, 'wantodance_enable_global': True, 'wantodance_enable_dynamicfps': True, 'wantodance_enable_unimodel': True}
},
]
flux_series = [

View File

@@ -6,6 +6,7 @@ from typing import Tuple, Optional
from einops import rearrange
from .wan_video_camera_controller import SimpleAdapter
from ..core.gradient import gradient_checkpoint_forward
from .wantodance import WanToDanceRotaryEmbedding, WanToDanceMusicEncoderLayer
try:
import flash_attn_interface
@@ -283,6 +284,57 @@ class Head(nn.Module):
return x
def wantodance_torch_dfs(model: nn.Module, parent_name='root'):
    """Depth-first traversal of *model* and every sub-module.

    Returns ``(modules, module_names)``: parallel lists where
    ``module_names[i]`` is the dotted path of ``modules[i]`` rooted at
    *parent_name* (an empty parent name falls back to ``'root'``).
    """
    modules = [model]
    module_names = [parent_name if parent_name else 'root']
    for child_key, child in model.named_children():
        qualified = f'{parent_name}.{child_key}' if parent_name else child_key
        sub_modules, sub_names = wantodance_torch_dfs(child, qualified)
        modules.extend(sub_modules)
        module_names.extend(sub_names)
    return modules, module_names
class WanToDanceInjector(nn.Module):
    """Per-layer music cross-attention modules injected into selected DiT blocks.

    ``injected_block_id`` maps a DiT block index (from *inject_layer*) to the
    index of the matching entry in ``self.injector`` /
    ``self.injector_pre_norm_feat`` / ``self.injector_pre_norm_vec``.

    Args:
        all_modules: modules from a DFS walk (see ``wantodance_torch_dfs``).
        all_modules_names: dotted names parallel to *all_modules*; DiT blocks
            are expected to be named ``root.transformer_blocks.<i>``.
        dim: hidden size of the cross-attention modules.
        num_heads: attention head count.
        inject_layer: DiT block indices that receive a music injector.
    """
    # Tuple default instead of a mutable list default (shared-state pitfall);
    # the value is only iterated, so callers passing lists are unaffected.
    def __init__(self, all_modules, all_modules_names, dim=2048, num_heads=32, inject_layer=(0, 27)):
        super().__init__()
        self.injected_block_id = {}
        injector_id = 0
        for mod_name, mod in zip(all_modules_names, all_modules):
            if isinstance(mod, DiTBlock):
                for inject_id in inject_layer:
                    if f'root.transformer_blocks.{inject_id}' == mod_name:
                        self.injected_block_id[inject_id] = injector_id
                        injector_id += 1
        self.injector = nn.ModuleList(
            [
                CrossAttention(
                    dim=dim,
                    num_heads=num_heads,
                )
                for _ in range(injector_id)
            ]
        )
        # Pre-norms are affine-free LayerNorms applied before cross-attention.
        self.injector_pre_norm_feat = nn.ModuleList(
            [
                nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6,)
                for _ in range(injector_id)
            ]
        )
        self.injector_pre_norm_vec = nn.ModuleList(
            [
                nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6,)
                for _ in range(injector_id)
            ]
        )
class WanModel(torch.nn.Module):
def __init__(
self,
@@ -305,6 +357,13 @@ class WanModel(torch.nn.Module):
require_vae_embedding: bool = True,
require_clip_embedding: bool = True,
fuse_vae_embedding_in_latents: bool = False,
wantodance_enable_music_inject: bool = False,
wantodance_music_inject_layers = [0, 4, 8, 12, 16, 20, 24, 27],
wantodance_enable_refimage: bool = False,
wantodance_enable_refface: bool = False,
wantodance_enable_global: bool = False,
wantodance_enable_dynamicfps: bool = False,
wantodance_enable_unimodel: bool = False,
):
super().__init__()
self.dim = dim
@@ -337,7 +396,12 @@ class WanModel(torch.nn.Module):
])
self.head = Head(dim, out_dim, patch_size, eps)
head_dim = dim // num_heads
self.freqs = precompute_freqs_cis_3d(head_dim)
if wantodance_enable_dynamicfps or wantodance_enable_unimodel:
end = int(22350 / 8 + 0.5) # 149f * 30fps * 5s = 22350
self.freqs = precompute_freqs_cis_3d(head_dim, end=end)
else:
self.freqs = precompute_freqs_cis_3d(head_dim)
if has_image_input:
self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280
@@ -350,8 +414,83 @@ class WanModel(torch.nn.Module):
else:
self.control_adapter = None
def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None):
x = self.patch_embedding(x)
self.prepare_wantodance(in_dim, dim, num_heads, has_image_pos_emb, out_dim, patch_size, eps,
wantodance_enable_music_inject, wantodance_music_inject_layers, wantodance_enable_refimage, wantodance_enable_refface,
wantodance_enable_global, wantodance_enable_dynamicfps, wantodance_enable_unimodel)
def prepare_wantodance(
    self,
    in_dim, dim, num_heads, has_image_pos_emb, out_dim, patch_size, eps,
    wantodance_enable_music_inject: bool = False,
    wantodance_music_inject_layers=None,
    wantodance_enable_refimage: bool = False,
    wantodance_enable_refface: bool = False,
    wantodance_enable_global: bool = False,
    wantodance_enable_dynamicfps: bool = False,
    wantodance_enable_unimodel: bool = False,
):
    """Instantiate the optional WanToDance sub-modules on this model.

    Depending on the enable flags this creates: a music cross-attention
    injector over the DiT blocks, CLIP-feature MLPs for the reference
    image/face, a two-layer rotary music encoder, and (unified model only)
    a dedicated global patch embedding and output head. All flags are also
    stored on ``self`` for use at forward time.
    """
    # None sentinel avoids a shared mutable default; the original default
    # layer list is preserved for callers that omit the argument.
    if wantodance_music_inject_layers is None:
        wantodance_music_inject_layers = [0, 4, 8, 12, 16, 20, 24, 27]
    if wantodance_enable_music_inject:
        all_modules, all_modules_names = wantodance_torch_dfs(self.blocks, parent_name="root.transformer_blocks")
        self.music_injector = WanToDanceInjector(all_modules, all_modules_names, dim=dim, num_heads=num_heads, inject_layer=wantodance_music_inject_layers)
    if wantodance_enable_refimage:
        self.img_emb_refimage = MLP(1280, dim, has_pos_emb=has_image_pos_emb)  # clip_feature_dim = 1280
    if wantodance_enable_refface:
        self.img_emb_refface = MLP(1280, dim, has_pos_emb=has_image_pos_emb)  # clip_feature_dim = 1280
    if wantodance_enable_global or wantodance_enable_dynamicfps or wantodance_enable_unimodel:
        # 35 = 1 (onset envelope) + 20 (MFCC) + 12 (chroma) + 1 (peak) + 1 (beat);
        # see WanVideoUnit_WanToDance_ProcessInputs.get_music_base_feature.
        music_feature_dim = 35
        ff_size = 1024
        dropout = 0.1
        latent_dim = 256
        nhead = 4
        activation = F.gelu
        rotary = WanToDanceRotaryEmbedding(dim=latent_dim)
        self.music_projection = nn.Linear(music_feature_dim, latent_dim)
        self.music_encoder = nn.Sequential()
        for _ in range(2):
            self.music_encoder.append(
                WanToDanceMusicEncoderLayer(
                    d_model=latent_dim,
                    nhead=nhead,
                    dim_feedforward=ff_size,
                    dropout=dropout,
                    activation=activation,
                    batch_first=True,
                    rotary=rotary,
                    # NOTE(review): hard-coded CUDA init. The caller moves the
                    # encoder to x.device before use, but construction fails on
                    # CUDA-less hosts — confirm whether this should be None.
                    device='cuda',
                )
            )
    if wantodance_enable_unimodel:
        # Merged from two consecutive identical `if wantodance_enable_unimodel:`
        # blocks in the original.
        self.patch_embedding_global = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.head_global = Head(dim, out_dim, patch_size, eps)
    self.wantodance_enable_music_inject = wantodance_enable_music_inject
    self.wantodance_enable_refimage = wantodance_enable_refimage
    self.wantodance_enable_refface = wantodance_enable_refface
    self.wantodance_enable_global = wantodance_enable_global
    self.wantodance_enable_dynamicfps = wantodance_enable_dynamicfps
    self.wantodance_enable_unimodel = wantodance_enable_unimodel
def wantodance_after_transformer_block(self, block_idx, hidden_states):
    """Add the music cross-attention residual after selected DiT blocks.

    ``hidden_states`` is the b x (f*h*w) x c token sequence; it is returned
    unchanged unless music injection is enabled and *block_idx* is one of
    the injected layers.
    """
    if not self.wantodance_enable_music_inject:
        return hidden_states
    injected = self.music_injector.injected_block_id
    if block_idx not in injected:
        return hidden_states
    slot = injected[block_idx]
    # NOTE(review): original comment says "b f n c", but the rearrange below
    # expects a 3-D "b t c" tensor — confirm the upstream shape.
    audio_emb = self.merged_audio_emb
    num_frames = audio_emb.shape[1]
    tokens = hidden_states.clone()  # b (f h w) c
    tokens = rearrange(tokens, "b (t n) c -> (b t) n c", t=num_frames)
    normed_tokens = self.music_injector.injector_pre_norm_feat[slot](tokens)
    audio_tokens = rearrange(audio_emb, "b t c -> (b t) 1 c", t=num_frames)
    residual = self.music_injector.injector[slot](normed_tokens, audio_tokens)
    residual = rearrange(residual, "(b t) n c -> b (t n) c", t=num_frames)
    return hidden_states + residual
def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None, enable_wantodance_global=False):
if enable_wantodance_global:
x = self.patch_embedding_global(x)
else:
x = self.patch_embedding(x)
if self.control_adapter is not None and control_camera_latents_input is not None:
y_camera = self.control_adapter(control_camera_latents_input)
x = [u + v for u, v in zip(x, y_camera)]

View File

@@ -1247,6 +1247,22 @@ class WanVideoVAE(nn.Module):
return videos
def encode_framewise(self, videos, device):
    """Encode *videos* one temporal slice at a time along dim 2 and
    concatenate the per-frame latents back on the temporal axis."""
    per_frame = [
        self.single_encode(videos[:, :, idx:idx + 1], device)
        for idx in range(videos.shape[2])
    ]
    return torch.concat(per_frame, dim=2)
def decode_framewise(self, hidden_states, device):
    """Decode latents one temporal slice at a time along dim 2 and
    concatenate the per-frame outputs back on the temporal axis."""
    frames = [
        self.single_decode(hidden_states[:, :, idx:idx + 1], device)
        for idx in range(hidden_states.shape[2])
    ]
    return torch.concat(frames, dim=2)
@staticmethod
def state_dict_converter():
    # Factory for the converter used to map external VAE checkpoints onto
    # this module's state-dict layout.
    return WanVideoVAEStateDictConverter()

View File

@@ -0,0 +1,209 @@
from inspect import isfunction
from math import log, pi
import torch
from einops import rearrange, repeat
from torch import einsum, nn
from typing import Any, Callable, List, Optional, Union
from torch import Tensor
import torch.nn.functional as F
# helper functions
def exists(val):
    """Return True when *val* is not None (einops-style helper)."""
    if val is None:
        return False
    return True
def broadcat(tensors, dim=-1):
    """Concatenate *tensors* along *dim*, broadcasting every other dimension.

    All tensors must have the same rank; on each non-concat axis the sizes
    must be expandable to a common maximum (at most two distinct sizes per
    axis, matching the original's check).

    Fixes the typo "concatentation" in the original assertion message.
    """
    ranks = {t.dim() for t in tensors}
    assert len(ranks) == 1, "tensors must all have the same number of dimensions"
    rank = ranks.pop()
    if dim < 0:
        dim += rank
    # sizes per axis across all tensors, e.g. axis 0 -> (2, 1, 2, ...)
    per_axis_sizes = list(zip(*(t.shape for t in tensors)))
    for axis, sizes in enumerate(per_axis_sizes):
        if axis == dim:
            continue
        assert len(set(sizes)) <= 2, "invalid dimensions for broadcastable concatenation"
    # broadcast target: the max size on every non-concat axis
    target = [max(sizes) for sizes in per_axis_sizes]
    expanded = []
    for t in tensors:
        shape = list(target)
        shape[dim] = t.shape[dim]  # keep each tensor's own size on the concat axis
        expanded.append(t.expand(*shape))
    return torch.cat(expanded, dim=dim)
# rotary embedding helper functions
def rotate_half(x):
    """Rotate adjacent pairs along the last dim: (x1, x2) -> (-x2, x1).

    The last dimension must be even. Pure-torch equivalent of the original
    einops version (``unflatten``/``flatten`` replace the two rearranges),
    removing the einops dependency from this helper.
    """
    pairs = x.unflatten(-1, (-1, 2))          # ... d 2
    x1, x2 = pairs.unbind(dim=-1)
    rotated = torch.stack((-x2, x1), dim=-1)  # ... d 2
    return rotated.flatten(-2)                # ... (d 2)
def apply_rotary_emb(freqs, t, start_index=0):
    """Apply rotary embedding *freqs* to the slice
    ``t[..., start_index:start_index+rot_dim]``; features outside that
    slice pass through unchanged."""
    freqs = freqs.to(t)
    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim
    assert rot_dim <= t.shape[-1], f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
    left = t[..., :start_index]
    middle = t[..., start_index:end_index]
    right = t[..., end_index:]
    rotated = (middle * freqs.cos()) + (rotate_half(middle) * freqs.sin())
    return torch.cat((left, rotated, right), dim=-1)
# learned rotation helpers
def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
    """Apply learned per-position rotations to *t*.

    When *freq_ranges* is given, each rotation is scaled across the
    frequency axis before being duplicated into (cos, sin) pairs.
    """
    if freq_ranges is not None:
        rotations = einsum("..., f -> ... f", rotations, freq_ranges)
        rotations = rearrange(rotations, "... r f -> ... (r f)")
    # duplicate each rotation so it covers both halves of a rotary pair
    rotations = repeat(rotations, "... n -> ... (n r)", r=2)
    return apply_rotary_emb(rotations, t, start_index=start_index)
# classes
class WanToDanceRotaryEmbedding(nn.Module):
    """Rotary positional embedding with a per-sequence-length cache.

    Frequency layout follows *freqs_for*: inverse-theta spectrum for
    "lang", linear-in-pi for "pixel", all-ones for "constant", or a
    user-supplied tensor via *custom_freqs*. Frequencies are a learnable
    Parameter when *learned_freq* is set, otherwise a non-persistent buffer.
    """

    def __init__(
        self,
        dim,
        custom_freqs=None,
        freqs_for="lang",
        theta=10000,
        max_freq=10,
        num_freqs=1,
        learned_freq=False,
    ):
        super().__init__()
        if custom_freqs is not None:
            freqs = custom_freqs
        elif freqs_for == "lang":
            exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
            freqs = 1.0 / (theta ** exponents)
        elif freqs_for == "pixel":
            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
        elif freqs_for == "constant":
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f"unknown modality {freqs_for}")
        # cache: sequence length -> precomputed rotation table
        self.cache = dict()
        if learned_freq:
            self.freqs = nn.Parameter(freqs)
        else:
            self.register_buffer("freqs", freqs, persistent=False)

    def rotate_queries_or_keys(self, t, seq_dim=-2):
        """Rotate *t* using positions 0..seq_len-1 along *seq_dim*."""
        seq_len = t.shape[seq_dim]
        device = t.device
        # lazily build (and cache) the table for this sequence length
        table = self.forward(
            lambda: torch.arange(seq_len, device=device), cache_key=seq_len
        )
        return apply_rotary_emb(table, t)

    def forward(self, t, cache_key=None):
        """Return the rotation table for positions *t* (tensor or thunk)."""
        if cache_key is not None and cache_key in self.cache:
            return self.cache[cache_key]
        if isfunction(t):
            t = t()
        freqs = self.freqs.to(t.device)
        freqs = torch.einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
        # duplicate each frequency for the (cos, sin) pair layout
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)
        if cache_key is not None:
            self.cache[cache_key] = freqs
        return freqs
class WanToDanceMusicEncoderLayer(nn.Module):
    """Transformer encoder layer with optional rotary embedding on Q/K.

    Mirrors ``nn.TransformerEncoderLayer``: self-attention + feed-forward,
    pre-norm (*norm_first*, the default) or post-norm. When *rotary* is
    provided, queries and keys are rotated before attention; values are not.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        layer_norm_eps: float = 1e-5,
        batch_first: bool = False,
        norm_first: bool = True,
        device=None,
        dtype=None,
        rotary=None,
    ) -> None:
        super().__init__()
        # Sub-module creation order is kept identical to the original so
        # seeded parameter initialization matches exactly.
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first, device=device, dtype=dtype
        )
        # feed-forward sub-network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm_first = norm_first
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = activation
        self.rotary = rotary
        self.use_rotary = rotary is not None

    def _sa_block(
        self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]
    ) -> Tensor:
        """Self-attention block; rotary is applied to queries/keys only."""
        query_key = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
        attn_out, _ = self.self_attn(
            query_key,
            query_key,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        return self.dropout1(attn_out)

    def _ff_block(self, x: Tensor) -> Tensor:
        """Position-wise feed-forward block."""
        hidden = self.activation(self.linear1(x))
        return self.dropout2(self.linear2(self.dropout(hidden)))

    def forward(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        x = src
        if self.norm_first:
            # device sync of the norms before use (kept from the original)
            self.norm1.to(device=x.device)
            self.norm2.to(device=x.device)
            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
            x = x + self._ff_block(self.norm2(x))
        else:
            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
            x = self.norm2(x + self._ff_block(x))
        return x

View File

@@ -75,6 +75,9 @@ class WanVideoPipeline(BasePipeline):
WanVideoUnit_TeaCache(),
WanVideoUnit_CfgMerger(),
WanVideoUnit_LongCatVideo(),
WanVideoUnit_WanToDance_ProcessInputs(),
WanVideoUnit_WanToDance_RefImageEmbedder(),
WanVideoUnit_WanToDance_ImageKeyframesEmbedder(),
]
self.post_units = [
WanVideoPostUnit_S2V(),
@@ -244,6 +247,13 @@ class WanVideoPipeline(BasePipeline):
# Teacache
tea_cache_l1_thresh: Optional[float] = None,
tea_cache_model_id: Optional[str] = "",
# WanToDance
wantodance_music_path: Optional[str] = None,
wantodance_reference_image: Optional[Image.Image] = None,
wantodance_fps: Optional[float] = 30,
wantodance_keyframes: Optional[list[Image.Image]] = None,
wantodance_keyframes_mask: Optional[list[int]] = None,
framewise_decoding: bool = False,
# progress_bar
progress_bar_cmd=tqdm,
output_type: Optional[Literal["quantized", "floatpoint"]] = "quantized",
@@ -280,6 +290,9 @@ class WanVideoPipeline(BasePipeline):
"input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video,
"animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video,
"vap_video": vap_video,
"wantodance_music_path": wantodance_music_path, "wantodance_reference_image": wantodance_reference_image, "wantodance_fps": wantodance_fps,
"wantodance_keyframes": wantodance_keyframes, "wantodance_keyframes_mask": wantodance_keyframes_mask,
"framewise_decoding": framewise_decoding,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -325,7 +338,10 @@ class WanVideoPipeline(BasePipeline):
inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# Decode
self.load_models_to_device(['vae'])
video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
if framewise_decoding:
video = self.vae.decode_framewise(inputs_shared["latents"], device=self.device)
else:
video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
if output_type == "quantized":
video = self.vae_output_to_video(video)
elif output_type == "floatpoint":
@@ -371,17 +387,20 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit):
class WanVideoUnit_InputVideoEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image"),
input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image", "framewise_decoding"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",)
)
def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image):
def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image, framewise_decoding):
if input_video is None:
return {"latents": noise}
pipe.load_models_to_device(self.onload_model_names)
input_video = pipe.preprocess_video(input_video)
input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
if framewise_decoding:
input_latents = pipe.vae.encode_framewise(input_video, device=pipe.device)
else:
input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
if vace_reference_image is not None:
if not isinstance(vace_reference_image, list):
vace_reference_image = [vace_reference_image]
@@ -1018,6 +1037,111 @@ class WanVideoUnit_LongCatVideo(PipelineUnit):
return {"longcat_latents": longcat_latents}
class WanVideoUnit_WanToDance_ProcessInputs(PipelineUnit):
    """Take-over unit: extracts the music feature matrix and toggles the
    negative-branch layer skip used by the WanToDance global model."""
    def __init__(self):
        super().__init__(
            take_over=True,
        )
    def get_music_base_feature(self, music_path, fps=30):
        """Compute a per-frame music feature matrix from an audio file.

        Columns: onset envelope (1) + MFCC (20) + CENS chroma (12) +
        onset-peak one-hot (1) + beat one-hot (1) = 35 features, matching
        the model's music_feature_dim. One feature row per hop of 512
        samples, so the row rate is tied to *fps* via ``sr = fps * 512``.
        Returns a float torch tensor of shape (num_rows, 35).
        """
        import librosa
        hop_length = 512
        sr = fps * hop_length
        data, sr = librosa.load(music_path, sr=sr)
        # NOTE(review): sr is immediately overwritten with 22050, so the
        # features below treat the (fps*hop_length)-resampled audio as if it
        # were 22050 Hz — presumably intentional for frame alignment; confirm.
        sr = 22050
        envelope = librosa.onset.onset_strength(y=data, sr=sr)
        mfcc = librosa.feature.mfcc(y=data, sr=sr, n_mfcc=20).T
        chroma = librosa.feature.chroma_cens(
            y=data, sr=sr, hop_length=hop_length, n_chroma=12
        ).T
        peak_idxs = librosa.onset.onset_detect(
            onset_envelope=envelope.flatten(), sr=sr, hop_length=hop_length
        )
        peak_onehot = np.zeros_like(envelope, dtype=np.float32)
        peak_onehot[peak_idxs] = 1.0
        # NOTE(review): librosa.beat.tempo is deprecated in recent librosa
        # releases (moved to librosa.feature.rhythm.tempo) — verify the
        # pinned librosa version supports this call.
        start_bpm = librosa.beat.tempo(y=librosa.load(music_path)[0])[0]
        _, beat_idxs = librosa.beat.beat_track(
            onset_envelope=envelope,
            sr=sr,
            hop_length=hop_length,
            start_bpm=start_bpm,
            tightness=100,
        )
        beat_onehot = np.zeros_like(envelope, dtype=np.float32)
        beat_onehot[beat_idxs] = 1.0
        audio_feature = np.concatenate(
            [envelope[:, None], mfcc, chroma, peak_onehot[:, None], beat_onehot[:, None]],
            axis=-1,
        )
        return torch.from_numpy(audio_feature)
    def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
        # The global model skips DiT block 9 on the negative (CFG) branch;
        # the flag is consumed by model_fn_wan_video's skip_9th_layer path.
        if pipe.dit.wantodance_enable_global:
            inputs_nega["skip_9th_layer"] = True
        if inputs_shared.get("wantodance_music_path", None) is not None:
            inputs_shared["music_feature"] = self.get_music_base_feature(inputs_shared["wantodance_music_path"]).to(dtype=pipe.torch_dtype, device=pipe.device)
        return inputs_shared, inputs_posi, inputs_nega
class WanVideoUnit_WanToDance_RefImageEmbedder(PipelineUnit):
    """Encodes the WanToDance reference image with the CLIP image encoder."""

    def __init__(self):
        super().__init__(
            input_params=("wantodance_reference_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
            output_params=("wantodance_refimage_feature",),
            onload_model_names=("image_encoder", "vae")
        )

    def process(self, pipe: WanVideoPipeline, wantodance_reference_image, num_frames, height, width, tiled, tile_size, tile_stride):
        # No-op when no reference image was supplied.
        if wantodance_reference_image is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        # Accept either a single image or a list; only the first entry is used.
        if isinstance(wantodance_reference_image, list):
            wantodance_reference_image = wantodance_reference_image[0]
        resized = wantodance_reference_image.resize((width, height))
        pixel_values = pipe.preprocess_image(resized).to(pipe.device)  # B,C,H,W with B=1
        feature = pipe.image_encoder.encode_image([pixel_values])
        feature = feature.to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"wantodance_refimage_feature": feature}
class WanVideoUnit_WanToDance_ImageKeyframesEmbedder(PipelineUnit):
    """Builds the VAE conditioning tensor ``y`` (mask channels + latents) and
    the CLIP feature from keyframe images and a per-frame 0/1 keyframe mask."""
    def __init__(self):
        super().__init__(
            input_params=("wantodance_keyframes", "wantodance_keyframes_mask", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
            output_params=("clip_feature", "y"),
            onload_model_names=("image_encoder", "vae")
        )
    def process(self, pipe: WanVideoPipeline, wantodance_keyframes, wantodance_keyframes_mask, num_frames, height, width, tiled, tile_size, tile_stride):
        """Return {"clip_feature", "y"}, or {} when no keyframes are given.

        ``wantodance_keyframes_mask`` marks with 1 the frame positions that
        are keyframes; ``y`` stacks the reshaped mask on top of the VAE
        latents along the channel dim.
        """
        if wantodance_keyframes is None:
            return {}
        wantodance_keyframes_mask = torch.tensor(wantodance_keyframes_mask)
        pipe.load_models_to_device(self.onload_model_names)
        images = []
        for input_image in wantodance_keyframes:
            input_image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
            images.append(input_image)
        clip_context = pipe.image_encoder.encode_image(images[:1])  # use the first keyframe as the CLIP input
        # mask at latent spatial resolution (H/8, W/8), one entry per frame
        msk = torch.zeros(1, num_frames, height//8, width//8, device=pipe.device)
        msk[:, wantodance_keyframes_mask==1, :, :] = torch.ones(1, height//8, width//8, device=pipe.device) # set keyframes mask to 1
        images = [image.transpose(0, 1) for image in images] # 3, num_frames, h, w
        images = torch.concat(images, dim=1)
        vae_input = images
        # repeat the first-frame mask 4x so the frame count is divisible by the
        # VAE temporal compression before reshaping to (4, T/4, H/8, W/8)
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) # expand first frame mask, N to N + 3
        msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
        msk = msk.transpose(1, 2)[0]
        y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        # concatenate the 4 mask channels on top of the latent channels
        y = torch.concat([msk, y])
        y = y.unsqueeze(0)
        clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"clip_feature": clip_context, "y": y}
class TeaCache:
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
self.num_inference_steps = num_inference_steps
@@ -1123,6 +1247,22 @@ class TemporalTiler_BCTHW:
return value
def wantodance_get_single_freqs(freqs, frame_num, fps):
    """Resample the temporal rows of a RoPE table for a non-30fps clip.

    Linearly interpolates *frame_num* rows out of the first
    ``round(30/fps * frame_num)`` rows of *freqs*, pinning the first and
    last output rows to the first and last source rows.
    """
    total_frame = int(30.0 / (fps + 1e-6) * frame_num + 0.5)
    interval_frame = 30.0 / (fps + 1e-6)
    base = freqs[:total_frame]
    out = torch.zeros((frame_num, base.shape[1]), device=base.device, dtype=base.dtype)
    out[0] = base[0]
    out[-1] = base[total_frame - 1]
    for i in range(1, frame_num - 1):
        pos = i * interval_frame
        lo = int(pos)
        hi = min(lo + 1, total_frame - 1)
        frac = pos - lo
        out[i] = base[lo] * (1.0 - frac) + base[hi] * frac
    return out
def model_fn_wan_video(
dit: WanModel,
@@ -1158,6 +1298,10 @@ def model_fn_wan_video(
use_gradient_checkpointing_offload: bool = False,
control_camera_latents_input = None,
fuse_vae_embedding_in_latents: bool = False,
wantodance_refimage_feature = None,
wantodance_fps: float = 30.0,
music_feature = None,
skip_9th_layer: bool = False,
**kwargs,
):
if sliding_window_size is not None and sliding_window_stride is not None:
@@ -1255,7 +1399,10 @@ def model_fn_wan_video(
context = torch.cat([clip_embdding, context], dim=1)
# Camera control
x = dit.patchify(x, control_camera_latents_input)
if hasattr(dit, "wantodance_enable_global") and dit.wantodance_enable_global and int(wantodance_fps + 0.5) != 30:
x = dit.patchify(x, control_camera_latents_input, enable_wantodance_global=True)
else:
x = dit.patchify(x, control_camera_latents_input)
# Animate
if pose_latents is not None and face_pixel_values is not None:
@@ -1310,7 +1457,61 @@ def model_fn_wan_video(
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload
)
# WanToDance
if hasattr(dit, "wantodance_enable_global") and dit.wantodance_enable_global:
if wantodance_refimage_feature is not None:
refimage_feature_embedding = dit.img_emb_refimage(wantodance_refimage_feature)
context = torch.cat([refimage_feature_embedding, context], dim=1)
if (dit.wantodance_enable_dynamicfps or dit.wantodance_enable_unimodel) and int(wantodance_fps + 0.5) != 30:
freqs_0 = wantodance_get_single_freqs(dit.freqs[0], f, wantodance_fps)
freqs = torch.cat([
freqs_0.view(f, 1, 1, -1).expand(f, h, w, -1),
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
if dit.wantodance_enable_global or dit.wantodance_enable_dynamicfps or dit.wantodance_enable_unimodel:
if use_unified_sequence_parallel:
length = int(float(music_feature.shape[0]) / get_sequence_parallel_world_size()) * get_sequence_parallel_world_size()
music_feature = music_feature[:length]
music_feature = torch.chunk(music_feature, get_sequence_parallel_world_size(), dim=0)[get_sequence_parallel_rank()]
if not dit.training:
dit.music_encoder.to(x.device, dtype=x.dtype) # only evaluation
music_feature = music_feature.to(x.device, dtype=x.dtype)
music_feature = dit.music_projection(music_feature)
music_feature = dit.music_encoder(music_feature)
if music_feature.dim() == 2:
music_feature = music_feature.unsqueeze(0)
if use_unified_sequence_parallel:
if dist.is_initialized() and dist.get_world_size() > 1:
music_feature = get_sp_group().all_gather(music_feature, dim=1)
music_feature = music_feature.unsqueeze(1) # [1, 1, 149, 4800]
N = 149
M = 4800
music_feature = torch.nn.functional.interpolate(music_feature, size=(N, M), mode='bilinear')
music_feature = music_feature.squeeze(1) # shape: [1, 149, 4800]
if music_feature is not None:
if music_feature.dim() == 2:
music_feature = music_feature.unsqueeze(0)
music_feature = music_feature.to(x.device, dtype=x.dtype)
interp_mode = 'bilinear'
if interp_mode == 'bilinear':
frame_num = latents.shape[2] if len(latents.shape) == 5 else latents.shape[1] # 21
context_shape_end = context.shape[2] ## 14B 5120
music_feature = music_feature.unsqueeze(1) # shape: [1, 1, 149, 4800]
if use_unified_sequence_parallel:
N = int(float(frame_num * 8) / get_sequence_parallel_world_size()) * get_sequence_parallel_world_size()
else:
N = frame_num * 8
music_feature = torch.nn.functional.interpolate(music_feature, size=(N, context_shape_end), mode='bilinear')
music_feature = music_feature.squeeze(1) # shape: [1, N, context_shape_end]
if use_unified_sequence_parallel:
dit.merged_audio_emb = torch.chunk(music_feature, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
else:
dit.merged_audio_emb = music_feature
else:
dit.merged_audio_emb = music_feature
# blocks
if use_unified_sequence_parallel:
if dist.is_initialized() and dist.get_world_size() > 1:
@@ -1326,8 +1527,12 @@ def model_fn_wan_video(
return vap(block, *inputs)
return custom_forward
# Block
for block_id, block in enumerate(dit.blocks):
# Block
if skip_9th_layer:
# This is only used in WanToDance
if block_id == 9:
continue
if vap is not None and block_id in vap.mot_layers_mapping:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
@@ -1364,10 +1569,18 @@ def model_fn_wan_video(
# Animate
if pose_latents is not None and face_pixel_values is not None:
x = animate_adapter.after_transformer_block(block_id, x, motion_vec)
# WanToDance
if hasattr(dit, "wantodance_enable_music_inject") and dit.wantodance_enable_music_inject:
x = dit.wantodance_after_transformer_block(block_id, x)
if tea_cache is not None:
tea_cache.store(x)
x = dit.head(x, t)
if hasattr(dit, "wantodance_enable_unimodel") and dit.wantodance_enable_unimodel and int(wantodance_fps + 0.5) != 30:
x = dit.head_global(x, t)
else:
x = dit.head(x, t)
if use_unified_sequence_parallel:
if dist.is_initialized() and dist.get_world_size() > 1:
x = get_sp_group().all_gather(x, dim=1)