Support WanToDance (#1361)

* support wantodance

* update docs

* bugfix
This commit is contained in:
Zhongjie Duan
2026-03-20 16:40:35 +08:00
committed by GitHub
parent ba0626e38f
commit 52ba5d414e
22 changed files with 1210 additions and 13 deletions

View File

@@ -6,6 +6,7 @@ from typing import Tuple, Optional
from einops import rearrange
from .wan_video_camera_controller import SimpleAdapter
from ..core.gradient import gradient_checkpoint_forward
from .wantodance import WanToDanceRotaryEmbedding, WanToDanceMusicEncoderLayer
try:
import flash_attn_interface
@@ -283,6 +284,57 @@ class Head(nn.Module):
return x
def wantodance_torch_dfs(model: nn.Module, parent_name='root'):
module_names, modules = [], []
current_name = parent_name if parent_name else 'root'
module_names.append(current_name)
modules.append(model)
for name, child in model.named_children():
if parent_name:
child_name = f'{parent_name}.{name}'
else:
child_name = name
child_modules, child_names = wantodance_torch_dfs(child, child_name)
module_names += child_names
modules += child_modules
return modules, module_names
class WanToDanceInjector(nn.Module):
def __init__(self, all_modules, all_modules_names, dim=2048, num_heads=32, inject_layer=[0, 27]):
super().__init__()
self.injected_block_id = {}
injector_id = 0
for mod_name, mod in zip(all_modules_names, all_modules):
if isinstance(mod, DiTBlock):
for inject_id in inject_layer:
if f'root.transformer_blocks.{inject_id}' == mod_name:
self.injected_block_id[inject_id] = injector_id
injector_id += 1
self.injector = nn.ModuleList(
[
CrossAttention(
dim=dim,
num_heads=num_heads,
)
for _ in range(injector_id)
]
)
self.injector_pre_norm_feat = nn.ModuleList(
[
nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6,)
for _ in range(injector_id)
]
)
self.injector_pre_norm_vec = nn.ModuleList(
[
nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6,)
for _ in range(injector_id)
]
)
class WanModel(torch.nn.Module):
def __init__(
self,
@@ -305,6 +357,13 @@ class WanModel(torch.nn.Module):
require_vae_embedding: bool = True,
require_clip_embedding: bool = True,
fuse_vae_embedding_in_latents: bool = False,
wantodance_enable_music_inject: bool = False,
wantodance_music_inject_layers = [0, 4, 8, 12, 16, 20, 24, 27],
wantodance_enable_refimage: bool = False,
wantodance_enable_refface: bool = False,
wantodance_enable_global: bool = False,
wantodance_enable_dynamicfps: bool = False,
wantodance_enable_unimodel: bool = False,
):
super().__init__()
self.dim = dim
@@ -337,7 +396,12 @@ class WanModel(torch.nn.Module):
])
self.head = Head(dim, out_dim, patch_size, eps)
head_dim = dim // num_heads
self.freqs = precompute_freqs_cis_3d(head_dim)
if wantodance_enable_dynamicfps or wantodance_enable_unimodel:
end = int(22350 / 8 + 0.5) # 149f * 30fps * 5s = 22350
self.freqs = precompute_freqs_cis_3d(head_dim, end=end)
else:
self.freqs = precompute_freqs_cis_3d(head_dim)
if has_image_input:
self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280
@@ -350,8 +414,83 @@ class WanModel(torch.nn.Module):
else:
self.control_adapter = None
def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None):
x = self.patch_embedding(x)
self.prepare_wantodance(in_dim, dim, num_heads, has_image_pos_emb, out_dim, patch_size, eps,
wantodance_enable_music_inject, wantodance_music_inject_layers, wantodance_enable_refimage, wantodance_enable_refface,
wantodance_enable_global, wantodance_enable_dynamicfps, wantodance_enable_unimodel)
def prepare_wantodance(
self,
in_dim, dim, num_heads, has_image_pos_emb, out_dim, patch_size, eps,
wantodance_enable_music_inject: bool = False,
wantodance_music_inject_layers = [0, 4, 8, 12, 16, 20, 24, 27],
wantodance_enable_refimage: bool = False,
wantodance_enable_refface: bool = False,
wantodance_enable_global: bool = False,
wantodance_enable_dynamicfps: bool = False,
wantodance_enable_unimodel: bool = False,
):
if wantodance_enable_music_inject:
all_modules, all_modules_names = wantodance_torch_dfs(self.blocks, parent_name="root.transformer_blocks")
self.music_injector = WanToDanceInjector(all_modules, all_modules_names, dim=dim, num_heads=num_heads, inject_layer=wantodance_music_inject_layers)
if wantodance_enable_refimage:
self.img_emb_refimage = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280
if wantodance_enable_refface:
self.img_emb_refface = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280
if wantodance_enable_global or wantodance_enable_dynamicfps or wantodance_enable_unimodel:
music_feature_dim = 35
ff_size = 1024
dropout = 0.1
latent_dim = 256
nhead = 4
activation = F.gelu
rotary = WanToDanceRotaryEmbedding(dim=latent_dim)
self.music_projection = nn.Linear(music_feature_dim, latent_dim)
self.music_encoder = nn.Sequential()
for _ in range(2):
self.music_encoder.append(
WanToDanceMusicEncoderLayer(
d_model=latent_dim,
nhead=nhead,
dim_feedforward=ff_size,
dropout=dropout,
activation=activation,
batch_first=True,
rotary=rotary,
device='cuda',
)
)
if wantodance_enable_unimodel:
self.patch_embedding_global = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
if wantodance_enable_unimodel:
self.head_global = Head(dim, out_dim, patch_size, eps)
self.wantodance_enable_music_inject = wantodance_enable_music_inject
self.wantodance_enable_refimage = wantodance_enable_refimage
self.wantodance_enable_refface = wantodance_enable_refface
self.wantodance_enable_global = wantodance_enable_global
self.wantodance_enable_dynamicfps = wantodance_enable_dynamicfps
self.wantodance_enable_unimodel = wantodance_enable_unimodel
def wantodance_after_transformer_block(self, block_idx, hidden_states):
if self.wantodance_enable_music_inject:
if block_idx in self.music_injector.injected_block_id.keys():
audio_attn_id = self.music_injector.injected_block_id[block_idx]
audio_emb = self.merged_audio_emb # b f n c
num_frames = audio_emb.shape[1]
input_hidden_states = hidden_states.clone() # b (f h w) c
input_hidden_states = rearrange(input_hidden_states, "b (t n) c -> (b t) n c", t=num_frames)
attn_hidden_states = self.music_injector.injector_pre_norm_feat[audio_attn_id](input_hidden_states)
audio_emb = rearrange(audio_emb, "b t c -> (b t) 1 c", t=num_frames)
attn_audio_emb = audio_emb
residual_out = self.music_injector.injector[audio_attn_id](attn_hidden_states, attn_audio_emb)
residual_out = rearrange(residual_out, "(b t) n c -> b (t n) c", t=num_frames)
hidden_states = hidden_states + residual_out
return hidden_states
def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None, enable_wantodance_global=False):
if enable_wantodance_global:
x = self.patch_embedding_global(x)
else:
x = self.patch_embedding(x)
if self.control_adapter is not None and control_camera_latents_input is not None:
y_camera = self.control_adapter(control_camera_latents_input)
x = [u + v for u, v in zip(x, y_camera)]

View File

@@ -1247,6 +1247,22 @@ class WanVideoVAE(nn.Module):
return videos
def encode_framewise(self, videos, device):
hidden_states = []
for i in range(videos.shape[2]):
hidden_states.append(self.single_encode(videos[:, :, i:i+1], device))
hidden_states = torch.concat(hidden_states, dim=2)
return hidden_states
def decode_framewise(self, hidden_states, device):
video = []
for i in range(hidden_states.shape[2]):
video.append(self.single_decode(hidden_states[:, :, i:i+1], device))
video = torch.concat(video, dim=2)
return video
@staticmethod
def state_dict_converter():
return WanVideoVAEStateDictConverter()

View File

@@ -0,0 +1,209 @@
from inspect import isfunction
from math import log, pi
import torch
from einops import rearrange, repeat
from torch import einsum, nn
from typing import Any, Callable, List, Optional, Union
from torch import Tensor
import torch.nn.functional as F
# helper functions
def exists(val):
return val is not None
def broadcat(tensors, dim=-1):
num_tensors = len(tensors)
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
shape_len = list(shape_lens)[0]
dim = (dim + shape_len) if dim < 0 else dim
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
assert all(
[*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
), "invalid dimensions for broadcastable concatentation"
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
expanded_dims.insert(dim, (dim, dims[dim]))
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
return torch.cat(tensors, dim=dim)
# rotary embedding helper functions
def rotate_half(x):
x = rearrange(x, "... (d r) -> ... d r", r=2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, "... d r -> ... (d r)")
def apply_rotary_emb(freqs, t, start_index=0):
freqs = freqs.to(t)
rot_dim = freqs.shape[-1]
end_index = start_index + rot_dim
assert (
rot_dim <= t.shape[-1]
), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
t_left, t, t_right = (
t[..., :start_index],
t[..., start_index:end_index],
t[..., end_index:],
)
t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
return torch.cat((t_left, t, t_right), dim=-1)
# learned rotation helpers
def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
if exists(freq_ranges):
rotations = einsum("..., f -> ... f", rotations, freq_ranges)
rotations = rearrange(rotations, "... r f -> ... (r f)")
rotations = repeat(rotations, "... n -> ... (n r)", r=2)
return apply_rotary_emb(rotations, t, start_index=start_index)
# classes
class WanToDanceRotaryEmbedding(nn.Module):
def __init__(
self,
dim,
custom_freqs=None,
freqs_for="lang",
theta=10000,
max_freq=10,
num_freqs=1,
learned_freq=False,
):
super().__init__()
if exists(custom_freqs):
freqs = custom_freqs
elif freqs_for == "lang":
freqs = 1.0 / (
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
)
elif freqs_for == "pixel":
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
elif freqs_for == "constant":
freqs = torch.ones(num_freqs).float()
else:
raise ValueError(f"unknown modality {freqs_for}")
self.cache = dict()
if learned_freq:
self.freqs = nn.Parameter(freqs)
else:
self.register_buffer("freqs", freqs, persistent=False)
def rotate_queries_or_keys(self, t, seq_dim=-2):
device = t.device
seq_len = t.shape[seq_dim]
freqs = self.forward(
lambda: torch.arange(seq_len, device=device), cache_key=seq_len
)
return apply_rotary_emb(freqs, t)
def forward(self, t, cache_key=None):
if exists(cache_key) and cache_key in self.cache:
return self.cache[cache_key]
if isfunction(t):
t = t()
# freqs = self.freqs
freqs = self.freqs.to(t.device)
freqs = torch.einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
if exists(cache_key):
self.cache[cache_key] = freqs
return freqs
class WanToDanceMusicEncoderLayer(nn.Module):
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
layer_norm_eps: float = 1e-5,
batch_first: bool = False,
norm_first: bool = True,
device=None,
dtype=None,
rotary=None,
) -> None:
super().__init__()
self.self_attn = nn.MultiheadAttention(
d_model, nhead, dropout=dropout, batch_first=batch_first, device=device, dtype=dtype
)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm_first = norm_first
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = activation
self.rotary = rotary
self.use_rotary = rotary is not None
# self-attention block
def _sa_block(
self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]
) -> Tensor:
qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
x = self.self_attn(
qk,
qk,
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
)[0]
return self.dropout1(x)
# feed forward block
def _ff_block(self, x: Tensor) -> Tensor:
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
return self.dropout2(x)
def forward(
self,
src: Tensor,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
x = src
if self.norm_first:
self.norm1.to(device=x.device)
self.norm2.to(device=x.device)
x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
x = x + self._ff_block(self.norm2(x))
else:
x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
x = self.norm2(x + self._ff_block(x))
return x