mirror of https://github.com/modelscope/DiffSynth-Studio.git

support s2v framepack
@@ -410,7 +410,6 @@ class WanS2VModel(torch.nn.Module):
         self.cond_encoder = nn.Conv3d(cond_dim, dim, kernel_size=patch_size, stride=patch_size)
         self.casual_audio_encoder = CausalAudioEncoder(dim=audio_dim, out_dim=dim, num_token=num_audio_token, need_global=enable_adain)
         all_modules, all_modules_names = torch_dfs(self.blocks, parent_name="root.transformer_blocks")
-        # TODO: refactor dfs
         self.audio_injector = AudioInjector_WAN(
             all_modules,
             all_modules_names,
@@ -449,7 +448,6 @@ class WanS2VModel(torch.nn.Module):

     def inject_motion(self, x, rope_embs, mask_input, motion_latents, drop_motion_frames=True, add_last_motion=2):
         # inject the motion frames token to the hidden states
-        # TODO: check drop_motion_frames = False
         mot, mot_remb = self.process_motion_frame_pack(motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=add_last_motion)
         if len(mot) > 0:
             x = torch.cat([x, mot[0]], dim=1)
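The frame-pack path above appends motion-frame tokens to the end of the video token sequence so the transformer blocks can attend to them; the seq_len_x bookkeeping in model_fn_wans2v further down suggests the extra tokens are sliced off again after the blocks. A minimal sketch of that append-then-slice pattern with toy shapes (illustrative, not the actual model code):

    import torch

    seq_len_x, seq_len_motion, dim = 1560, 380, 64
    x = torch.randn(1, seq_len_x, dim)         # video tokens
    mot = torch.randn(1, seq_len_motion, dim)  # packed motion-frame tokens

    x = torch.cat([x, mot], dim=1)             # inject motion tokens, as in inject_motion
    # ... transformer blocks attend over the combined sequence ...
    x = x[:, :seq_len_x]                       # assumed: motion tokens dropped before unpatchify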
@@ -183,6 +183,13 @@ class WanS2VAudioEncoder(torch.nn.Module):

         return batch_audio_eb, min_batch_num

+    def get_audio_feats_per_inference(self, input_audio, sample_rate, processor, fps=16, batch_frames=80, m=0, dtype=torch.float32, device='cpu'):
+        audio_feat = self.extract_audio_feat(input_audio, sample_rate, processor, return_all_layers=True, dtype=dtype, device=device)
+        audio_embed_bucket, min_batch_num = self.get_audio_embed_bucket_fps(audio_feat, fps=fps, batch_frames=batch_frames, m=m)
+        audio_embed_bucket = audio_embed_bucket.unsqueeze(0).permute(0, 2, 3, 1).to(device, dtype)
+        audio_embeds = [audio_embed_bucket[..., i * batch_frames:(i + 1) * batch_frames] for i in range(min_batch_num)]
+        return audio_embeds
+
     @staticmethod
     def state_dict_converter():
         return WanS2VAudioEncoderStateDictConverter()
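The new get_audio_feats_per_inference helper returns a list of fixed-size chunks of the audio embedding bucket, one chunk per denoising run, instead of a single truncated tensor. A small sketch of the chunking step, with assumed tensor sizes:

    import torch

    batch_frames, min_batch_num = 80, 3
    # assumed layout after unsqueeze(0).permute(0, 2, 3, 1): (1, dim, num_layers, total_frames)
    audio_embed_bucket = torch.randn(1, 1024, 25, batch_frames * min_batch_num)
    audio_embeds = [audio_embed_bucket[..., i * batch_frames:(i + 1) * batch_frames]
                    for i in range(min_batch_num)]
    assert all(chunk.shape[-1] == batch_frames for chunk in audio_embeds)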
@@ -65,6 +65,9 @@ class WanVideoPipeline(BasePipeline):
             WanVideoUnit_TeaCache(),
             WanVideoUnit_CfgMerger(),
         ]
+        self.post_units = [
+            WanVideoPostUnit_S2V(),
+        ]
         self.model_fn = model_fn_wan_video

@@ -391,9 +394,12 @@ class WanVideoPipeline(BasePipeline):
         input_video: Optional[list[Image.Image]] = None,
         denoising_strength: Optional[float] = 1.0,
         # Speech-to-video
-        input_audio: Optional[str] = None,
+        input_audio: Optional[np.array] = None,
+        audio_embeds: Optional[torch.Tensor] = None,
         audio_sample_rate: Optional[int] = 16000,
         s2v_pose_video: Optional[list[Image.Image]] = None,
+        s2v_pose_latents: Optional[torch.Tensor] = None,
+        motion_video: Optional[list[Image.Image]] = None,
         # ControlNet
         control_video: Optional[list[Image.Image]] = None,
         reference_image: Optional[Image.Image] = None,
@@ -462,7 +468,7 @@ class WanVideoPipeline(BasePipeline):
             "motion_bucket_id": motion_bucket_id,
             "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
             "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride,
-            "input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video,
+            "input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video,
         }
         for unit in self.units:
             inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -498,7 +504,9 @@ class WanVideoPipeline(BasePipeline):
         # VACE (TODO: remove it)
         if vace_reference_image is not None:
             inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:]

+        # post-denoising, pre-decoding processing logic
+        for unit in self.post_units:
+            inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
         # Decode
         self.load_models_to_device(['vae'])
         video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
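The new post_units list runs after denoising but before VAE decoding; each post unit sees the shared inputs, and the dict it returns appears to be merged back into them (WanVideoPostUnit_S2V in the next hunk is the first such unit). A hypothetical post unit following that same apparent contract (the name and body are illustrative, not part of this commit):

    class ClampLatentsPostUnit(PipelineUnit):  # hypothetical example
        def __init__(self):
            super().__init__(input_params=("latents",))

        def process(self, pipe, latents):
            # return a dict of fields to merge back into the shared inputs
            return {"latents": latents.clamp(-10, 10)}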
@@ -909,59 +917,91 @@ class WanVideoUnit_S2V(PipelineUnit):
             onload_model_names=("audio_encoder", "vae",)
         )

-    def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames):
-        if input_audio is None or pipe.audio_encoder is None or pipe.audio_processor is None:
-            return {}
+    def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames, fps=16, audio_embeds=None, return_all=False):
+        if audio_embeds is not None:
+            return {"audio_embeds": audio_embeds}
         pipe.load_models_to_device(["audio_encoder"])
-        z = pipe.audio_encoder.extract_audio_feat(input_audio, audio_sample_rate, pipe.audio_processor, return_all_layers=True, dtype=pipe.torch_dtype, device=pipe.device)
-        audio_embed_bucket, num_repeat = pipe.audio_encoder.get_audio_embed_bucket_fps(
-            z, fps=16, batch_frames=num_frames - 1, m=0
-        )
-        audio_embed_bucket = audio_embed_bucket.unsqueeze(0).to(pipe.device, pipe.torch_dtype)
-        if len(audio_embed_bucket.shape) == 3:
-            audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1)
-        elif len(audio_embed_bucket.shape) == 4:
-            audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)
-        audio_embed_bucket = audio_embed_bucket[..., 0:num_frames-1]
-        return {"audio_input": audio_embed_bucket}
+        audio_embeds = pipe.audio_encoder.get_audio_feats_per_inference(input_audio, audio_sample_rate, pipe.audio_processor, fps=fps, batch_frames=num_frames-1, dtype=pipe.torch_dtype, device=pipe.device)
+        if return_all:
+            return audio_embeds
+        else:
+            return {"audio_embeds": audio_embeds[0]}

-    def process_motion_latents(self, pipe: WanVideoPipeline, height, width, tiled, tile_size, tile_stride):
+    def process_motion_latents(self, pipe: WanVideoPipeline, height, width, tiled, tile_size, tile_stride, motion_video=None):
         pipe.load_models_to_device(["vae"])
-        # TODO: may support input motion latents, which related to `drop_motion_frames = False`
         motion_frames = 73
         lat_motion_frames = (motion_frames + 3) // 4 # 19
-        motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device)
+        kwargs = {}
+        if motion_video is not None and len(motion_video) > 0:
+            assert len(motion_video) == motion_frames, f"motion video must have {motion_frames} frames, but got {len(motion_video)}"
+            motion_latents = pipe.preprocess_video(motion_video)
+            kwargs["drop_motion_frames"] = False
+        else:
+            motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device)
+            kwargs["drop_motion_frames"] = True
         motion_latents = pipe.vae.encode(motion_latents, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
-        return {"motion_latents": motion_latents}
+        kwargs.update({"motion_latents": motion_latents})
+        return kwargs

-    def process_pose_cond(self, pipe: WanVideoPipeline, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride):
+    def process_pose_cond(self, pipe: WanVideoPipeline, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride, s2v_pose_latents=None, num_repeats=1, return_all=False):
+        if s2v_pose_latents is not None:
+            return {"s2v_pose_latents": s2v_pose_latents}
         if s2v_pose_video is None:
-            return {"pose_cond": None}
+            return {"s2v_pose_latents": None}
         pipe.load_models_to_device(["vae"])
-        input_video = pipe.preprocess_video(s2v_pose_video)
-        # get num_frames-1 frames
-        input_video = input_video[:, :, :num_frames]
+        infer_frames = num_frames - 1
+        input_video = pipe.preprocess_video(s2v_pose_video)[:, :, :infer_frames * num_repeats]
         # pad if not enough frames
-        padding_frames = num_frames - input_video.shape[2]
+        padding_frames = infer_frames * num_repeats - input_video.shape[2]
         input_video = torch.cat([input_video, -torch.ones(1, 3, padding_frames, height, width, device=input_video.device, dtype=input_video.dtype)], dim=2)
-        # encode to latents
-        input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
-        return {"pose_cond": input_latents[:,:,1:]}
+        input_videos = input_video.chunk(num_repeats, dim=2)
+        pose_conds = []
+        for r in range(num_repeats):
+            cond = input_videos[r]
+            cond = torch.cat([cond[:, :, 0:1].repeat(1, 1, 1, 1, 1), cond], dim=2)
+            cond_latents = pipe.vae.encode(cond, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
+            pose_conds.append(cond_latents[:,:,1:])
+        if return_all:
+            return pose_conds
+        else:
+            return {"s2v_pose_latents": pose_conds[0]}

     def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
-        if inputs_shared.get("input_audio") is None or pipe.audio_encoder is None or pipe.audio_processor is None:
+        if (inputs_shared.get("input_audio") is None and inputs_shared.get("audio_embeds") is None) or pipe.audio_encoder is None or pipe.audio_processor is None:
             return inputs_shared, inputs_posi, inputs_nega
-        input_audio, audio_sample_rate, s2v_pose_video, num_frames, height, width = inputs_shared.get("input_audio"), inputs_shared.get("audio_sample_rate"), inputs_shared.get("s2v_pose_video"), inputs_shared.get("num_frames"), inputs_shared.get("height"), inputs_shared.get("width")
-        tiled, tile_size, tile_stride = inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride")
+        num_frames, height, width, tiled, tile_size, tile_stride = inputs_shared.get("num_frames"), inputs_shared.get("height"), inputs_shared.get("width"), inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride")
+        input_audio, audio_embeds, audio_sample_rate = inputs_shared.pop("input_audio"), inputs_shared.pop("audio_embeds"), inputs_shared.get("audio_sample_rate")
+        s2v_pose_video, s2v_pose_latents, motion_video = inputs_shared.pop("s2v_pose_video"), inputs_shared.pop("s2v_pose_latents"), inputs_shared.pop("motion_video")

-        audio_input_positive = self.process_audio(pipe, input_audio, audio_sample_rate, num_frames)
+        audio_input_positive = self.process_audio(pipe, input_audio, audio_sample_rate, num_frames, audio_embeds=audio_embeds)
         inputs_posi.update(audio_input_positive)
-        inputs_nega.update({"audio_input": 0.0 * audio_input_positive["audio_input"]})
+        inputs_nega.update({"audio_embeds": 0.0 * audio_input_positive["audio_embeds"]})

-        inputs_shared.update(self.process_motion_latents(pipe, height, width, tiled, tile_size, tile_stride))
-        inputs_shared.update(self.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride))
+        inputs_shared.update(self.process_motion_latents(pipe, height, width, tiled, tile_size, tile_stride, motion_video))
+        inputs_shared.update(self.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride, s2v_pose_latents=s2v_pose_latents))
         return inputs_shared, inputs_posi, inputs_nega

+    @staticmethod
+    def pre_calculate_audio_pose(pipe: WanVideoPipeline, input_audio=None, audio_sample_rate=16000, s2v_pose_video=None, num_frames=81, height=448, width=832, fps=16, tiled=True, tile_size=(30, 52), tile_stride=(15, 26)):
+        assert pipe.audio_encoder is not None and pipe.audio_processor is not None, "Please load audio encoder and audio processor first."
+        shapes = WanVideoUnit_ShapeChecker().process(pipe, height, width, num_frames)
+        height, width, num_frames = shapes["height"], shapes["width"], shapes["num_frames"]
+        unit = WanVideoUnit_S2V()
+        audio_embeds = unit.process_audio(pipe, input_audio, audio_sample_rate, num_frames, fps, return_all=True)
+        pose_latents = unit.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, num_repeats=len(audio_embeds), return_all=True, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+        pose_latents = None if s2v_pose_video is None else pose_latents
+        return audio_embeds, pose_latents, len(audio_embeds)
+
+
+class WanVideoPostUnit_S2V(PipelineUnit):
+    def __init__(self):
+        super().__init__(input_params=("latents", "motion_latents", "drop_motion_frames"))
+
+    def process(self, pipe: WanVideoPipeline, latents, motion_latents, drop_motion_frames):
+        if pipe.audio_encoder is None or motion_latents is None or drop_motion_frames:
+            return {}
+        latents = torch.cat([motion_latents, latents[:,:,1:]], dim=2)
+        return {"latents": latents}


 class TeaCache:
     def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
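Taken together, pre_calculate_audio_pose and the new audio_embeds / s2v_pose_latents / motion_video call arguments suggest a chunked long-video ("framepack") loop: precompute one audio chunk and one pose-latent chunk per inference, then call the pipeline once per chunk, feeding the last 73 generated frames back in as motion_video. A hedged sketch of that loop (the surrounding setup and variable names are illustrative, not from this commit):

    audio_embeds, pose_latents, num_chunks = WanVideoUnit_S2V.pre_calculate_audio_pose(
        pipe, input_audio=audio, audio_sample_rate=16000,
        s2v_pose_video=pose_frames, num_frames=81, height=448, width=832,
    )
    frames, motion_video = [], None
    for i in range(num_chunks):
        video = pipe(
            prompt=prompt, num_frames=81, height=448, width=832,
            audio_embeds=audio_embeds[i],
            s2v_pose_latents=None if pose_latents is None else pose_latents[i],
            motion_video=motion_video,  # None on the first chunk -> zero motion latents
        )
        frames.extend(video)
        motion_video = frames[-73:]     # process_motion_latents asserts exactly 73 frames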
@@ -1081,9 +1121,10 @@ def model_fn_wan_video(
     reference_latents = None,
     vace_context = None,
     vace_scale = 1.0,
-    audio_input: Optional[torch.Tensor] = None,
+    audio_embeds: Optional[torch.Tensor] = None,
     motion_latents: Optional[torch.Tensor] = None,
-    pose_cond: Optional[torch.Tensor] = None,
+    s2v_pose_latents: Optional[torch.Tensor] = None,
+    drop_motion_frames: bool = True,
     tea_cache: TeaCache = None,
     use_unified_sequence_parallel: bool = False,
     motion_bucket_id: Optional[torch.Tensor] = None,
@@ -1122,15 +1163,16 @@ def model_fn_wan_video(
             batch_size=2 if cfg_merge else 1
         )
     # wan2.2 s2v
-    if audio_input is not None:
+    if audio_embeds is not None:
         return model_fn_wans2v(
             dit=dit,
             latents=latents,
             timestep=timestep,
             context=context,
-            audio_input=audio_input,
+            audio_embeds=audio_embeds,
             motion_latents=motion_latents,
-            pose_cond=pose_cond,
+            s2v_pose_latents=s2v_pose_latents,
+            drop_motion_frames=drop_motion_frames,
             use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
             use_gradient_checkpointing=use_gradient_checkpointing,
             use_unified_sequence_parallel=use_unified_sequence_parallel,
@@ -1261,9 +1303,10 @@ def model_fn_wans2v(
     latents,
     timestep,
     context,
-    audio_input,
+    audio_embeds,
     motion_latents,
-    pose_cond,
+    s2v_pose_latents,
+    drop_motion_frames=True,
     use_gradient_checkpointing_offload=False,
     use_gradient_checkpointing=False,
     use_unified_sequence_parallel=False,
@@ -1280,11 +1323,11 @@ def model_fn_wans2v(
     context = dit.text_embedding(context)

     # audio encode
-    audio_emb_global, merged_audio_emb = dit.cal_audio_emb(audio_input)
+    audio_emb_global, merged_audio_emb = dit.cal_audio_emb(audio_embeds)

-    # x and pose_cond
-    pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond
-    x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(pose_cond))
+    # x and s2v_pose_latents
+    s2v_pose_latents = torch.zeros_like(x) if s2v_pose_latents is None else s2v_pose_latents
+    x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(s2v_pose_latents))
     seq_len_x = seq_len_x_global = x.shape[1] # global used for unified sequence parallel

     # reference image
@@ -1296,7 +1339,7 @@ def model_fn_wans2v(
     # freqs
     pre_compute_freqs = rope_precompute(x.detach().view(1, x.size(1), dit.num_heads, dit.dim // dit.num_heads), grid_sizes, dit.freqs, start=None)
     # motion
-    x, pre_compute_freqs, mask = dit.inject_motion(x, pre_compute_freqs, mask, motion_latents, add_last_motion=2)
+    x, pre_compute_freqs, mask = dit.inject_motion(x, pre_compute_freqs, mask, motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=2)

     x = x + dit.trainable_cond_mask(mask).to(x.dtype)