Support WanToDance (#1361)

* support wantodance

* update docs

* bugfix
This commit is contained in:
Zhongjie Duan
2026-03-20 16:40:35 +08:00
committed by GitHub
parent ba0626e38f
commit 52ba5d414e
22 changed files with 1210 additions and 13 deletions

View File

@@ -75,6 +75,9 @@ class WanVideoPipeline(BasePipeline):
WanVideoUnit_TeaCache(),
WanVideoUnit_CfgMerger(),
WanVideoUnit_LongCatVideo(),
WanVideoUnit_WanToDance_ProcessInputs(),
WanVideoUnit_WanToDance_RefImageEmbedder(),
WanVideoUnit_WanToDance_ImageKeyframesEmbedder(),
]
self.post_units = [
WanVideoPostUnit_S2V(),
@@ -244,6 +247,13 @@ class WanVideoPipeline(BasePipeline):
# Teacache
tea_cache_l1_thresh: Optional[float] = None,
tea_cache_model_id: Optional[str] = "",
# WanToDance
wantodance_music_path: Optional[str] = None,
wantodance_reference_image: Optional[Image.Image] = None,
wantodance_fps: Optional[float] = 30,
wantodance_keyframes: Optional[list[Image.Image]] = None,
wantodance_keyframes_mask: Optional[list[int]] = None,
framewise_decoding: bool = False,
# progress_bar
progress_bar_cmd=tqdm,
output_type: Optional[Literal["quantized", "floatpoint"]] = "quantized",
@@ -280,6 +290,9 @@ class WanVideoPipeline(BasePipeline):
"input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video,
"animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video,
"vap_video": vap_video,
"wantodance_music_path": wantodance_music_path, "wantodance_reference_image": wantodance_reference_image, "wantodance_fps": wantodance_fps,
"wantodance_keyframes": wantodance_keyframes, "wantodance_keyframes_mask": wantodance_keyframes_mask,
"framewise_decoding": framewise_decoding,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -325,7 +338,10 @@ class WanVideoPipeline(BasePipeline):
inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# Decode
self.load_models_to_device(['vae'])
video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
if framewise_decoding:
video = self.vae.decode_framewise(inputs_shared["latents"], device=self.device)
else:
video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
if output_type == "quantized":
video = self.vae_output_to_video(video)
elif output_type == "floatpoint":
@@ -371,17 +387,20 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit):
class WanVideoUnit_InputVideoEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image"),
input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image", "framewise_decoding"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",)
)
def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image):
def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image, framewise_decoding):
if input_video is None:
return {"latents": noise}
pipe.load_models_to_device(self.onload_model_names)
input_video = pipe.preprocess_video(input_video)
input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
if framewise_decoding:
input_latents = pipe.vae.encode_framewise(input_video, device=pipe.device)
else:
input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
if vace_reference_image is not None:
if not isinstance(vace_reference_image, list):
vace_reference_image = [vace_reference_image]
@@ -1018,6 +1037,111 @@ class WanVideoUnit_LongCatVideo(PipelineUnit):
return {"longcat_latents": longcat_latents}
class WanVideoUnit_WanToDance_ProcessInputs(PipelineUnit):
def __init__(self):
super().__init__(
take_over=True,
)
def get_music_base_feature(self, music_path, fps=30):
import librosa
hop_length = 512
sr = fps * hop_length
data, sr = librosa.load(music_path, sr=sr)
sr = 22050
envelope = librosa.onset.onset_strength(y=data, sr=sr)
mfcc = librosa.feature.mfcc(y=data, sr=sr, n_mfcc=20).T
chroma = librosa.feature.chroma_cens(
y=data, sr=sr, hop_length=hop_length, n_chroma=12
).T
peak_idxs = librosa.onset.onset_detect(
onset_envelope=envelope.flatten(), sr=sr, hop_length=hop_length
)
peak_onehot = np.zeros_like(envelope, dtype=np.float32)
peak_onehot[peak_idxs] = 1.0
start_bpm = librosa.beat.tempo(y=librosa.load(music_path)[0])[0]
_, beat_idxs = librosa.beat.beat_track(
onset_envelope=envelope,
sr=sr,
hop_length=hop_length,
start_bpm=start_bpm,
tightness=100,
)
beat_onehot = np.zeros_like(envelope, dtype=np.float32)
beat_onehot[beat_idxs] = 1.0
audio_feature = np.concatenate(
[envelope[:, None], mfcc, chroma, peak_onehot[:, None], beat_onehot[:, None]],
axis=-1,
)
return torch.from_numpy(audio_feature)
def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
if pipe.dit.wantodance_enable_global:
inputs_nega["skip_9th_layer"] = True
if inputs_shared.get("wantodance_music_path", None) is not None:
inputs_shared["music_feature"] = self.get_music_base_feature(inputs_shared["wantodance_music_path"]).to(dtype=pipe.torch_dtype, device=pipe.device)
return inputs_shared, inputs_posi, inputs_nega
class WanVideoUnit_WanToDance_RefImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("wantodance_reference_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
output_params=("wantodance_refimage_feature",),
onload_model_names=("image_encoder", "vae")
)
def process(self, pipe: WanVideoPipeline, wantodance_reference_image, num_frames, height, width, tiled, tile_size, tile_stride):
if wantodance_reference_image is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
if isinstance(wantodance_reference_image, list):
wantodance_reference_image = wantodance_reference_image[0]
image = pipe.preprocess_image(wantodance_reference_image.resize((width, height))).to(pipe.device) # B,C,H,W;B=1
refimage_feature = pipe.image_encoder.encode_image([image])
refimage_feature = refimage_feature.to(dtype=pipe.torch_dtype, device=pipe.device)
return {"wantodance_refimage_feature": refimage_feature}
class WanVideoUnit_WanToDance_ImageKeyframesEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("wantodance_keyframes", "wantodance_keyframes_mask", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
output_params=("clip_feature", "y"),
onload_model_names=("image_encoder", "vae")
)
def process(self, pipe: WanVideoPipeline, wantodance_keyframes, wantodance_keyframes_mask, num_frames, height, width, tiled, tile_size, tile_stride):
if wantodance_keyframes is None:
return {}
wantodance_keyframes_mask = torch.tensor(wantodance_keyframes_mask)
pipe.load_models_to_device(self.onload_model_names)
images = []
for input_image in wantodance_keyframes:
input_image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
images.append(input_image)
clip_context = pipe.image_encoder.encode_image(images[:1]) # 取第一帧作为clip输入
msk = torch.zeros(1, num_frames, height//8, width//8, device=pipe.device)
msk[:, wantodance_keyframes_mask==1, :, :] = torch.ones(1, height//8, width//8, device=pipe.device) # set keyframes mask to 1
images = [image.transpose(0, 1) for image in images] # 3, num_frames, h, w
images = torch.concat(images, dim=1)
vae_input = images
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) # expand first frame mask, N to N + 3
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
msk = msk.transpose(1, 2)[0]
y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
y = torch.concat([msk, y])
y = y.unsqueeze(0)
clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
return {"clip_feature": clip_context, "y": y}
class TeaCache:
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
self.num_inference_steps = num_inference_steps
@@ -1123,6 +1247,22 @@ class TemporalTiler_BCTHW:
return value
def wantodance_get_single_freqs(freqs, frame_num, fps):
total_frame = int(30.0 / (fps + 1e-6) * frame_num + 0.5)
interval_frame = 30.0 / (fps + 1e-6)
freqs_0 = freqs[:total_frame]
freqs_new = torch.zeros((frame_num, freqs_0.shape[1]), device=freqs_0.device, dtype=freqs_0.dtype)
freqs_new[0] = freqs_0[0]
freqs_new[-1] = freqs_0[total_frame - 1]
for i in range(1, frame_num-1):
pos = i * interval_frame
low_idx = int(pos)
high_idx = min(low_idx + 1, total_frame - 1)
weight_high = pos - low_idx
weight_low = 1.0 - weight_high
freqs_new[i] = freqs_0[low_idx] * weight_low + freqs_0[high_idx] * weight_high
return freqs_new
def model_fn_wan_video(
dit: WanModel,
@@ -1158,6 +1298,10 @@ def model_fn_wan_video(
use_gradient_checkpointing_offload: bool = False,
control_camera_latents_input = None,
fuse_vae_embedding_in_latents: bool = False,
wantodance_refimage_feature = None,
wantodance_fps: float = 30.0,
music_feature = None,
skip_9th_layer: bool = False,
**kwargs,
):
if sliding_window_size is not None and sliding_window_stride is not None:
@@ -1255,7 +1399,10 @@ def model_fn_wan_video(
context = torch.cat([clip_embdding, context], dim=1)
# Camera control
x = dit.patchify(x, control_camera_latents_input)
if hasattr(dit, "wantodance_enable_global") and dit.wantodance_enable_global and int(wantodance_fps + 0.5) != 30:
x = dit.patchify(x, control_camera_latents_input, enable_wantodance_global=True)
else:
x = dit.patchify(x, control_camera_latents_input)
# Animate
if pose_latents is not None and face_pixel_values is not None:
@@ -1310,7 +1457,61 @@ def model_fn_wan_video(
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload
)
# WanToDance
if hasattr(dit, "wantodance_enable_global") and dit.wantodance_enable_global:
if wantodance_refimage_feature is not None:
refimage_feature_embedding = dit.img_emb_refimage(wantodance_refimage_feature)
context = torch.cat([refimage_feature_embedding, context], dim=1)
if (dit.wantodance_enable_dynamicfps or dit.wantodance_enable_unimodel) and int(wantodance_fps + 0.5) != 30:
freqs_0 = wantodance_get_single_freqs(dit.freqs[0], f, wantodance_fps)
freqs = torch.cat([
freqs_0.view(f, 1, 1, -1).expand(f, h, w, -1),
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
if dit.wantodance_enable_global or dit.wantodance_enable_dynamicfps or dit.wantodance_enable_unimodel:
if use_unified_sequence_parallel:
length = int(float(music_feature.shape[0]) / get_sequence_parallel_world_size()) * get_sequence_parallel_world_size()
music_feature = music_feature[:length]
music_feature = torch.chunk(music_feature, get_sequence_parallel_world_size(), dim=0)[get_sequence_parallel_rank()]
if not dit.training:
dit.music_encoder.to(x.device, dtype=x.dtype) # only evaluation
music_feature = music_feature.to(x.device, dtype=x.dtype)
music_feature = dit.music_projection(music_feature)
music_feature = dit.music_encoder(music_feature)
if music_feature.dim() == 2:
music_feature = music_feature.unsqueeze(0)
if use_unified_sequence_parallel:
if dist.is_initialized() and dist.get_world_size() > 1:
music_feature = get_sp_group().all_gather(music_feature, dim=1)
music_feature = music_feature.unsqueeze(1) # [1, 1, 149, 4800]
N = 149
M = 4800
music_feature = torch.nn.functional.interpolate(music_feature, size=(N, M), mode='bilinear')
music_feature = music_feature.squeeze(1) # shape: [1, 149, 4800]
if music_feature is not None:
if music_feature.dim() == 2:
music_feature = music_feature.unsqueeze(0)
music_feature = music_feature.to(x.device, dtype=x.dtype)
interp_mode = 'bilinear'
if interp_mode == 'bilinear':
frame_num = latents.shape[2] if len(latents.shape) == 5 else latents.shape[1] # 21
context_shape_end = context.shape[2] ## 14B 5120
music_feature = music_feature.unsqueeze(1) # shape: [1, 1, 149, 4800]
if use_unified_sequence_parallel:
N = int(float(frame_num * 8) / get_sequence_parallel_world_size()) * get_sequence_parallel_world_size()
else:
N = frame_num * 8
music_feature = torch.nn.functional.interpolate(music_feature, size=(N, context_shape_end), mode='bilinear')
music_feature = music_feature.squeeze(1) # shape: [1, N, context_shape_end]
if use_unified_sequence_parallel:
dit.merged_audio_emb = torch.chunk(music_feature, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
else:
dit.merged_audio_emb = music_feature
else:
dit.merged_audio_emb = music_feature
# blocks
if use_unified_sequence_parallel:
if dist.is_initialized() and dist.get_world_size() > 1:
@@ -1326,8 +1527,12 @@ def model_fn_wan_video(
return vap(block, *inputs)
return custom_forward
# Block
for block_id, block in enumerate(dit.blocks):
# Block
if skip_9th_layer:
# This is only used in WanToDance
if block_id == 9:
continue
if vap is not None and block_id in vap.mot_layers_mapping:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
@@ -1364,10 +1569,18 @@ def model_fn_wan_video(
# Animate
if pose_latents is not None and face_pixel_values is not None:
x = animate_adapter.after_transformer_block(block_id, x, motion_vec)
# WanToDance
if hasattr(dit, "wantodance_enable_music_inject") and dit.wantodance_enable_music_inject:
x = dit.wantodance_after_transformer_block(block_id, x)
if tea_cache is not None:
tea_cache.store(x)
x = dit.head(x, t)
if hasattr(dit, "wantodance_enable_unimodel") and dit.wantodance_enable_unimodel and int(wantodance_fps + 0.5) != 30:
x = dit.head_global(x, t)
else:
x = dit.head(x, t)
if use_unified_sequence_parallel:
if dist.is_initialized() and dist.get_world_size() > 1:
x = get_sp_group().all_gather(x, dim=1)