mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 09:28:12 +00:00
@@ -75,6 +75,9 @@ class WanVideoPipeline(BasePipeline):
|
||||
WanVideoUnit_TeaCache(),
|
||||
WanVideoUnit_CfgMerger(),
|
||||
WanVideoUnit_LongCatVideo(),
|
||||
WanVideoUnit_WanToDance_ProcessInputs(),
|
||||
WanVideoUnit_WanToDance_RefImageEmbedder(),
|
||||
WanVideoUnit_WanToDance_ImageKeyframesEmbedder(),
|
||||
]
|
||||
self.post_units = [
|
||||
WanVideoPostUnit_S2V(),
|
||||
@@ -244,6 +247,13 @@ class WanVideoPipeline(BasePipeline):
|
||||
# Teacache
|
||||
tea_cache_l1_thresh: Optional[float] = None,
|
||||
tea_cache_model_id: Optional[str] = "",
|
||||
# WanToDance
|
||||
wantodance_music_path: Optional[str] = None,
|
||||
wantodance_reference_image: Optional[Image.Image] = None,
|
||||
wantodance_fps: Optional[float] = 30,
|
||||
wantodance_keyframes: Optional[list[Image.Image]] = None,
|
||||
wantodance_keyframes_mask: Optional[list[int]] = None,
|
||||
framewise_decoding: bool = False,
|
||||
# progress_bar
|
||||
progress_bar_cmd=tqdm,
|
||||
output_type: Optional[Literal["quantized", "floatpoint"]] = "quantized",
|
||||
@@ -280,6 +290,9 @@ class WanVideoPipeline(BasePipeline):
|
||||
"input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video,
|
||||
"animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video,
|
||||
"vap_video": vap_video,
|
||||
"wantodance_music_path": wantodance_music_path, "wantodance_reference_image": wantodance_reference_image, "wantodance_fps": wantodance_fps,
|
||||
"wantodance_keyframes": wantodance_keyframes, "wantodance_keyframes_mask": wantodance_keyframes_mask,
|
||||
"framewise_decoding": framewise_decoding,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
@@ -325,7 +338,10 @@ class WanVideoPipeline(BasePipeline):
|
||||
inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
if framewise_decoding:
|
||||
video = self.vae.decode_framewise(inputs_shared["latents"], device=self.device)
|
||||
else:
|
||||
video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
if output_type == "quantized":
|
||||
video = self.vae_output_to_video(video)
|
||||
elif output_type == "floatpoint":
|
||||
@@ -371,17 +387,20 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit):
|
||||
class WanVideoUnit_InputVideoEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image"),
|
||||
input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image", "framewise_decoding"),
|
||||
output_params=("latents", "input_latents"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image):
|
||||
def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image, framewise_decoding):
|
||||
if input_video is None:
|
||||
return {"latents": noise}
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
input_video = pipe.preprocess_video(input_video)
|
||||
input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
if framewise_decoding:
|
||||
input_latents = pipe.vae.encode_framewise(input_video, device=pipe.device)
|
||||
else:
|
||||
input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
if vace_reference_image is not None:
|
||||
if not isinstance(vace_reference_image, list):
|
||||
vace_reference_image = [vace_reference_image]
|
||||
@@ -1018,6 +1037,111 @@ class WanVideoUnit_LongCatVideo(PipelineUnit):
|
||||
return {"longcat_latents": longcat_latents}
|
||||
|
||||
|
||||
class WanVideoUnit_WanToDance_ProcessInputs(PipelineUnit):
    """Prepares WanToDance inputs: extracts music features and flags the
    negative branch to skip transformer block 9 when global mode is on.

    Runs in take-over mode, so it receives and returns the full
    (inputs_shared, inputs_posi, inputs_nega) triple.
    """

    def __init__(self):
        super().__init__(take_over=True)

    def get_music_base_feature(self, music_path, fps=30):
        """Extract a hand-crafted audio feature matrix from a music file.

        Returns a (T, 35) float tensor: onset envelope (1) + MFCC (20) +
        CENS chroma (12) + onset-peak one-hot (1) + beat one-hot (1),
        concatenated along the last axis.
        """
        import librosa

        hop_length = 512
        # Load at fps * hop_length so one envelope frame ≈ one video frame,
        # but downstream librosa calls use sr=22050.
        # NOTE(review): the sr mismatch looks intentional (features are
        # re-interpolated later in the model) — confirm against upstream.
        load_sr = fps * hop_length
        waveform, load_sr = librosa.load(music_path, sr=load_sr)
        feat_sr = 22050

        envelope = librosa.onset.onset_strength(y=waveform, sr=feat_sr)
        mfcc = librosa.feature.mfcc(y=waveform, sr=feat_sr, n_mfcc=20).T
        chroma = librosa.feature.chroma_cens(
            y=waveform, sr=feat_sr, hop_length=hop_length, n_chroma=12
        ).T

        # One-hot vector marking detected onset peaks.
        peak_idxs = librosa.onset.onset_detect(
            onset_envelope=envelope.flatten(), sr=feat_sr, hop_length=hop_length
        )
        peak_onehot = np.zeros_like(envelope, dtype=np.float32)
        peak_onehot[peak_idxs] = 1.0

        # One-hot vector marking tracked beats; the tempo estimate comes from
        # a second load of the file at librosa's default sample rate.
        start_bpm = librosa.beat.tempo(y=librosa.load(music_path)[0])[0]
        _, beat_idxs = librosa.beat.beat_track(
            onset_envelope=envelope,
            sr=feat_sr,
            hop_length=hop_length,
            start_bpm=start_bpm,
            tightness=100,
        )
        beat_onehot = np.zeros_like(envelope, dtype=np.float32)
        beat_onehot[beat_idxs] = 1.0

        stacked = np.concatenate(
            [envelope[:, None], mfcc, chroma, peak_onehot[:, None], beat_onehot[:, None]],
            axis=-1,
        )
        return torch.from_numpy(stacked)

    def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
        if pipe.dit.wantodance_enable_global:
            # Only the negative (unconditional) branch skips layer 9.
            inputs_nega["skip_9th_layer"] = True
        music_path = inputs_shared.get("wantodance_music_path", None)
        if music_path is not None:
            # NOTE(review): wantodance_fps is not forwarded here; features are
            # always extracted at the default 30 fps — confirm this is intended.
            feature = self.get_music_base_feature(music_path)
            inputs_shared["music_feature"] = feature.to(dtype=pipe.torch_dtype, device=pipe.device)
        return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
class WanVideoUnit_WanToDance_RefImageEmbedder(PipelineUnit):
    """Encodes the WanToDance reference image with the image encoder.

    Emits ``wantodance_refimage_feature`` (cast to the pipeline dtype/device),
    or nothing when no reference image is supplied.
    """

    def __init__(self):
        super().__init__(
            input_params=("wantodance_reference_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
            output_params=("wantodance_refimage_feature",),
            # NOTE(review): "vae" is onloaded but not used in process() — confirm.
            onload_model_names=("image_encoder", "vae")
        )

    def process(self, pipe: WanVideoPipeline, wantodance_reference_image, num_frames, height, width, tiled, tile_size, tile_stride):
        if wantodance_reference_image is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        # A list input means "use the first image as the reference".
        if isinstance(wantodance_reference_image, list):
            wantodance_reference_image = wantodance_reference_image[0]
        # Resize to the target resolution, then preprocess to a (1, C, H, W) tensor.
        ref_tensor = pipe.preprocess_image(wantodance_reference_image.resize((width, height))).to(pipe.device)
        embedding = pipe.image_encoder.encode_image([ref_tensor])
        embedding = embedding.to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"wantodance_refimage_feature": embedding}
|
||||
|
||||
|
||||
class WanVideoUnit_WanToDance_ImageKeyframesEmbedder(PipelineUnit):
    """Encodes WanToDance keyframe images into conditioning inputs.

    Produces:
      - ``clip_feature``: CLIP embedding of the first keyframe.
      - ``y``: VAE latents of the keyframe stack concatenated with a
        temporally-compressed binary mask marking which frames are keyframes.
    """

    def __init__(self):
        super().__init__(
            input_params=("wantodance_keyframes", "wantodance_keyframes_mask", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
            output_params=("clip_feature", "y"),
            onload_model_names=("image_encoder", "vae")
        )

    def process(self, pipe: WanVideoPipeline, wantodance_keyframes, wantodance_keyframes_mask, num_frames, height, width, tiled, tile_size, tile_stride):
        if wantodance_keyframes is None:
            return {}
        # NOTE(review): assumes len(wantodance_keyframes_mask) == num_frames — confirm caller contract.
        wantodance_keyframes_mask = torch.tensor(wantodance_keyframes_mask)
        pipe.load_models_to_device(self.onload_model_names)

        # Preprocess every keyframe to a (1, C, H, W) tensor on the device.
        frames = [
            pipe.preprocess_image(img.resize((width, height))).to(pipe.device)
            for img in wantodance_keyframes
        ]

        # Use only the first keyframe as the CLIP conditioning input.
        clip_context = pipe.image_encoder.encode_image(frames[:1])

        # Binary mask at latent resolution: 1 where the frame is a keyframe.
        lat_h, lat_w = height // 8, width // 8
        msk = torch.zeros(1, num_frames, lat_h, lat_w, device=pipe.device)
        msk[:, wantodance_keyframes_mask == 1, :, :] = torch.ones(1, lat_h, lat_w, device=pipe.device)

        # Stack frames along time: (C, num_frames, H, W) for the VAE.
        vae_input = torch.concat([f.transpose(0, 1) for f in frames], dim=1)

        # Repeat the first-frame mask 4x (temporal compression expands frame 0
        # to 4 latent slots), then fold time by 4: -> (4, T/4, lat_h, lat_w).
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]

        y = pipe.vae.encode(
            [vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)],
            device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
        )[0]
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        # Concatenating the fp32 mask may promote the dtype, hence the re-cast below.
        y = torch.concat([msk, y])
        y = y.unsqueeze(0)
        clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"clip_feature": clip_context, "y": y}
|
||||
|
||||
|
||||
class TeaCache:
|
||||
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
|
||||
self.num_inference_steps = num_inference_steps
|
||||
@@ -1123,6 +1247,22 @@ class TemporalTiler_BCTHW:
|
||||
return value
|
||||
|
||||
|
||||
def wantodance_get_single_freqs(freqs, frame_num, fps):
    """Resample a temporal RoPE frequency table for a non-30fps clip.

    Takes the first ``total_frame`` rows of ``freqs`` (the span the clip would
    occupy at the 30 fps reference rate) and linearly interpolates them down to
    ``frame_num`` rows, pinning the first and last rows exactly.

    Args:
        freqs: 2-D tensor of per-frame frequencies (rows = frames).
        frame_num: number of output frames.
        fps: actual frames-per-second of the clip.

    Returns:
        Tensor of shape (frame_num, freqs.shape[1]) on freqs' device/dtype.
    """
    # Spacing between output samples, measured in 30fps reference frames.
    # The 1e-6 guards against division by zero for fps == 0.
    step = 30.0 / (fps + 1e-6)
    total_frame = int(step * frame_num + 0.5)
    base = freqs[:total_frame]

    resampled = torch.zeros((frame_num, base.shape[1]), device=base.device, dtype=base.dtype)
    # Endpoints are copied verbatim rather than interpolated.
    resampled[0] = base[0]
    resampled[-1] = base[total_frame - 1]
    for i in range(1, frame_num - 1):
        pos = i * step
        lo = int(pos)
        hi = min(lo + 1, total_frame - 1)  # clamp to the last available row
        w_hi = pos - lo
        resampled[i] = base[lo] * (1.0 - w_hi) + base[hi] * w_hi
    return resampled
|
||||
|
||||
|
||||
def model_fn_wan_video(
|
||||
dit: WanModel,
|
||||
@@ -1158,6 +1298,10 @@ def model_fn_wan_video(
|
||||
use_gradient_checkpointing_offload: bool = False,
|
||||
control_camera_latents_input = None,
|
||||
fuse_vae_embedding_in_latents: bool = False,
|
||||
wantodance_refimage_feature = None,
|
||||
wantodance_fps: float = 30.0,
|
||||
music_feature = None,
|
||||
skip_9th_layer: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
if sliding_window_size is not None and sliding_window_stride is not None:
|
||||
@@ -1255,7 +1399,10 @@ def model_fn_wan_video(
|
||||
context = torch.cat([clip_embdding, context], dim=1)
|
||||
|
||||
# Camera control
|
||||
x = dit.patchify(x, control_camera_latents_input)
|
||||
if hasattr(dit, "wantodance_enable_global") and dit.wantodance_enable_global and int(wantodance_fps + 0.5) != 30:
|
||||
x = dit.patchify(x, control_camera_latents_input, enable_wantodance_global=True)
|
||||
else:
|
||||
x = dit.patchify(x, control_camera_latents_input)
|
||||
|
||||
# Animate
|
||||
if pose_latents is not None and face_pixel_values is not None:
|
||||
@@ -1310,7 +1457,61 @@ def model_fn_wan_video(
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload
|
||||
)
|
||||
|
||||
|
||||
# WanToDance
|
||||
if hasattr(dit, "wantodance_enable_global") and dit.wantodance_enable_global:
|
||||
if wantodance_refimage_feature is not None:
|
||||
refimage_feature_embedding = dit.img_emb_refimage(wantodance_refimage_feature)
|
||||
context = torch.cat([refimage_feature_embedding, context], dim=1)
|
||||
if (dit.wantodance_enable_dynamicfps or dit.wantodance_enable_unimodel) and int(wantodance_fps + 0.5) != 30:
|
||||
freqs_0 = wantodance_get_single_freqs(dit.freqs[0], f, wantodance_fps)
|
||||
freqs = torch.cat([
|
||||
freqs_0.view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
||||
if dit.wantodance_enable_global or dit.wantodance_enable_dynamicfps or dit.wantodance_enable_unimodel:
|
||||
if use_unified_sequence_parallel:
|
||||
length = int(float(music_feature.shape[0]) / get_sequence_parallel_world_size()) * get_sequence_parallel_world_size()
|
||||
music_feature = music_feature[:length]
|
||||
music_feature = torch.chunk(music_feature, get_sequence_parallel_world_size(), dim=0)[get_sequence_parallel_rank()]
|
||||
if not dit.training:
|
||||
dit.music_encoder.to(x.device, dtype=x.dtype) # only evaluation
|
||||
music_feature = music_feature.to(x.device, dtype=x.dtype)
|
||||
music_feature = dit.music_projection(music_feature)
|
||||
music_feature = dit.music_encoder(music_feature)
|
||||
if music_feature.dim() == 2:
|
||||
music_feature = music_feature.unsqueeze(0)
|
||||
if use_unified_sequence_parallel:
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
music_feature = get_sp_group().all_gather(music_feature, dim=1)
|
||||
music_feature = music_feature.unsqueeze(1) # [1, 1, 149, 4800]
|
||||
N = 149
|
||||
M = 4800
|
||||
music_feature = torch.nn.functional.interpolate(music_feature, size=(N, M), mode='bilinear')
|
||||
music_feature = music_feature.squeeze(1) # shape: [1, 149, 4800]
|
||||
if music_feature is not None:
|
||||
if music_feature.dim() == 2:
|
||||
music_feature = music_feature.unsqueeze(0)
|
||||
music_feature = music_feature.to(x.device, dtype=x.dtype)
|
||||
interp_mode = 'bilinear'
|
||||
if interp_mode == 'bilinear':
|
||||
frame_num = latents.shape[2] if len(latents.shape) == 5 else latents.shape[1] # 21
|
||||
context_shape_end = context.shape[2] ## 14B 5120
|
||||
music_feature = music_feature.unsqueeze(1) # shape: [1, 1, 149, 4800]
|
||||
if use_unified_sequence_parallel:
|
||||
N = int(float(frame_num * 8) / get_sequence_parallel_world_size()) * get_sequence_parallel_world_size()
|
||||
else:
|
||||
N = frame_num * 8
|
||||
music_feature = torch.nn.functional.interpolate(music_feature, size=(N, context_shape_end), mode='bilinear')
|
||||
music_feature = music_feature.squeeze(1) # shape: [1, N, context_shape_end]
|
||||
if use_unified_sequence_parallel:
|
||||
dit.merged_audio_emb = torch.chunk(music_feature, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
||||
else:
|
||||
dit.merged_audio_emb = music_feature
|
||||
else:
|
||||
dit.merged_audio_emb = music_feature
|
||||
|
||||
# blocks
|
||||
if use_unified_sequence_parallel:
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
@@ -1326,8 +1527,12 @@ def model_fn_wan_video(
|
||||
return vap(block, *inputs)
|
||||
return custom_forward
|
||||
|
||||
# Block
|
||||
for block_id, block in enumerate(dit.blocks):
|
||||
# Block
|
||||
if skip_9th_layer:
|
||||
# This is only used in WanToDance
|
||||
if block_id == 9:
|
||||
continue
|
||||
if vap is not None and block_id in vap.mot_layers_mapping:
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
@@ -1364,10 +1569,18 @@ def model_fn_wan_video(
|
||||
# Animate
|
||||
if pose_latents is not None and face_pixel_values is not None:
|
||||
x = animate_adapter.after_transformer_block(block_id, x, motion_vec)
|
||||
|
||||
# WanToDance
|
||||
if hasattr(dit, "wantodance_enable_music_inject") and dit.wantodance_enable_music_inject:
|
||||
x = dit.wantodance_after_transformer_block(block_id, x)
|
||||
if tea_cache is not None:
|
||||
tea_cache.store(x)
|
||||
|
||||
x = dit.head(x, t)
|
||||
if hasattr(dit, "wantodance_enable_unimodel") and dit.wantodance_enable_unimodel and int(wantodance_fps + 0.5) != 30:
|
||||
x = dit.head_global(x, t)
|
||||
else:
|
||||
x = dit.head(x, t)
|
||||
|
||||
if use_unified_sequence_parallel:
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
x = get_sp_group().all_gather(x, dim=1)
|
||||
|
||||
Reference in New Issue
Block a user