From 5cee326f92e9afce4084fcf9358b3bb927a34a6f Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Mon, 1 Sep 2025 16:48:46 +0800 Subject: [PATCH] support s2v framepack --- README.md | 2 +- README_zh.md | 2 +- diffsynth/models/wan_video_dit_s2v.py | 2 - diffsynth/models/wav2vec.py | 7 + diffsynth/pipelines/wan_video_new.py | 143 ++++++++++++------ examples/wanvideo/README.md | 2 +- examples/wanvideo/README_zh.md | 2 +- .../Wan2.2-S2V-14B_multi_clips.py | 116 ++++++++++++++ 8 files changed, 220 insertions(+), 56 deletions(-) create mode 100644 examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py diff --git a/README.md b/README.md index 96b307a..d014bac 100644 --- a/README.md +++ b/README.md @@ -205,7 +205,7 @@ save_video(video, "video1.mp4", fps=15, quality=5) | Model ID | Extra Parameters | Inference | Full Training | Validate After Full Training | LoRA Training | Validate After LoRA Training | |-|-|-|-|-|-|-| -|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B.py)|-|-|-|-| +|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-| |[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)| |[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)| |[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)| diff --git a/README_zh.md b/README_zh.md index abc276d..edba5ab 100644 --- a/README_zh.md +++ b/README_zh.md @@ -205,7 +205,7 @@ save_video(video, "video1.mp4", fps=15, quality=5) |模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-|-| -|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B.py)|-|-|-|-| +|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-| 
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)| |[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)| |[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)| diff --git a/diffsynth/models/wan_video_dit_s2v.py b/diffsynth/models/wan_video_dit_s2v.py index fa54591..70881e6 100644 --- a/diffsynth/models/wan_video_dit_s2v.py +++ b/diffsynth/models/wan_video_dit_s2v.py @@ -410,7 +410,6 @@ class WanS2VModel(torch.nn.Module): self.cond_encoder = nn.Conv3d(cond_dim, dim, kernel_size=patch_size, stride=patch_size) self.casual_audio_encoder = CausalAudioEncoder(dim=audio_dim, out_dim=dim, num_token=num_audio_token, need_global=enable_adain) all_modules, all_modules_names = torch_dfs(self.blocks, parent_name="root.transformer_blocks") - # TODO: refactor dfs self.audio_injector = AudioInjector_WAN( all_modules, all_modules_names, @@ -449,7 +448,6 @@ class WanS2VModel(torch.nn.Module): def inject_motion(self, x, rope_embs, mask_input, motion_latents, drop_motion_frames=True, add_last_motion=2): # inject the motion frames token to the hidden states - # TODO: check drop_motion_frames = False mot, mot_remb = self.process_motion_frame_pack(motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=add_last_motion) if len(mot) > 0: x = torch.cat([x, mot[0]], dim=1) diff --git a/diffsynth/models/wav2vec.py b/diffsynth/models/wav2vec.py index e17bd96..f98b721 100644 --- a/diffsynth/models/wav2vec.py +++ b/diffsynth/models/wav2vec.py @@ -183,6 +183,13 @@ class WanS2VAudioEncoder(torch.nn.Module): return batch_audio_eb, min_batch_num + def get_audio_feats_per_inference(self, input_audio, sample_rate, processor, fps=16, batch_frames=80, m=0, dtype=torch.float32, device='cpu'): + audio_feat = self.extract_audio_feat(input_audio, sample_rate, processor, return_all_layers=True, dtype=dtype, device=device) + audio_embed_bucket, min_batch_num = self.get_audio_embed_bucket_fps(audio_feat, fps=fps, batch_frames=batch_frames, m=m) + audio_embed_bucket = audio_embed_bucket.unsqueeze(0).permute(0, 2, 3, 1).to(device, dtype) + audio_embeds = [audio_embed_bucket[..., i * batch_frames:(i + 1) * batch_frames] for i in range(min_batch_num)] + return audio_embeds + @staticmethod def state_dict_converter(): return WanS2VAudioEncoderStateDictConverter() diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index cef7ae8..660a38e 100644 --- 
a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -65,6 +65,9 @@ class WanVideoPipeline(BasePipeline): WanVideoUnit_TeaCache(), WanVideoUnit_CfgMerger(), ] + self.post_units = [ + WanVideoPostUnit_S2V(), + ] self.model_fn = model_fn_wan_video @@ -391,9 +394,12 @@ class WanVideoPipeline(BasePipeline): input_video: Optional[list[Image.Image]] = None, denoising_strength: Optional[float] = 1.0, # Speech-to-video - input_audio: Optional[str] = None, + input_audio: Optional[np.array] = None, + audio_embeds: Optional[torch.Tensor] = None, audio_sample_rate: Optional[int] = 16000, s2v_pose_video: Optional[list[Image.Image]] = None, + s2v_pose_latents: Optional[torch.Tensor] = None, + motion_video: Optional[list[Image.Image]] = None, # ControlNet control_video: Optional[list[Image.Image]] = None, reference_image: Optional[Image.Image] = None, @@ -462,7 +468,7 @@ class WanVideoPipeline(BasePipeline): "motion_bucket_id": motion_bucket_id, "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride, - "input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, + "input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video, } for unit in self.units: inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) @@ -498,7 +504,9 @@ class WanVideoPipeline(BasePipeline): # VACE (TODO: remove it) if vace_reference_image is not None: inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:] - + # post-denoising, pre-decoding processing logic + for unit in self.post_units: + inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) # Decode self.load_models_to_device(['vae']) video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) @@ -909,59 +917,91 @@ class WanVideoUnit_S2V(PipelineUnit): onload_model_names=("audio_encoder", "vae",) ) - def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames): - if input_audio is None or pipe.audio_encoder is None or pipe.audio_processor is None: - return {} + def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames, fps=16, audio_embeds=None, return_all=False): + if audio_embeds is not None: + return {"audio_embeds": audio_embeds} pipe.load_models_to_device(["audio_encoder"]) - z = pipe.audio_encoder.extract_audio_feat(input_audio, audio_sample_rate, pipe.audio_processor, return_all_layers=True, dtype=pipe.torch_dtype, device=pipe.device) - audio_embed_bucket, num_repeat = pipe.audio_encoder.get_audio_embed_bucket_fps( - z, fps=16, batch_frames=num_frames - 1, m=0 - ) - audio_embed_bucket = audio_embed_bucket.unsqueeze(0).to(pipe.device, pipe.torch_dtype) - if len(audio_embed_bucket.shape) == 3: - audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1) - elif len(audio_embed_bucket.shape) == 4: - audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1) - audio_embed_bucket = audio_embed_bucket[..., 0:num_frames-1] - return {"audio_input": audio_embed_bucket} + audio_embeds = pipe.audio_encoder.get_audio_feats_per_inference(input_audio, audio_sample_rate, pipe.audio_processor, fps=fps, batch_frames=num_frames-1, dtype=pipe.torch_dtype, 
device=pipe.device) + if return_all: + return audio_embeds + else: + return {"audio_embeds": audio_embeds[0]} - def process_motion_latents(self, pipe: WanVideoPipeline, height, width, tiled, tile_size, tile_stride): + def process_motion_latents(self, pipe: WanVideoPipeline, height, width, tiled, tile_size, tile_stride, motion_video=None): pipe.load_models_to_device(["vae"]) - # TODO: may support input motion latents, which related to `drop_motion_frames = False` motion_frames = 73 - lat_motion_frames = (motion_frames + 3) // 4 # 19 - motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device) + kwargs = {} + if motion_video is not None and len(motion_video) > 0: + assert len(motion_video) == motion_frames, f"motion video must have {motion_frames} frames, but got {len(motion_video)}" + motion_latents = pipe.preprocess_video(motion_video) + kwargs["drop_motion_frames"] = False + else: + motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device) + kwargs["drop_motion_frames"] = True motion_latents = pipe.vae.encode(motion_latents, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) - return {"motion_latents": motion_latents} + kwargs.update({"motion_latents": motion_latents}) + return kwargs - def process_pose_cond(self, pipe: WanVideoPipeline, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride): + def process_pose_cond(self, pipe: WanVideoPipeline, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride, s2v_pose_latents=None, num_repeats=1, return_all=False): + if s2v_pose_latents is not None: + return {"s2v_pose_latents": s2v_pose_latents} if s2v_pose_video is None: - return {"pose_cond": None} + return {"s2v_pose_latents": None} pipe.load_models_to_device(["vae"]) - input_video = pipe.preprocess_video(s2v_pose_video) - # get num_frames-1 frames - input_video = input_video[:, :, :num_frames] + infer_frames = num_frames - 1 + input_video = pipe.preprocess_video(s2v_pose_video)[:, :, :infer_frames * num_repeats] # pad if not enough frames - padding_frames = num_frames - input_video.shape[2] + padding_frames = infer_frames * num_repeats - input_video.shape[2] input_video = torch.cat([input_video, -torch.ones(1, 3, padding_frames, height, width, device=input_video.device, dtype=input_video.dtype)], dim=2) - # encode to latents - input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) - return {"pose_cond": input_latents[:,:,1:]} + input_videos = input_video.chunk(num_repeats, dim=2) + pose_conds = [] + for r in range(num_repeats): + cond = input_videos[r] + cond = torch.cat([cond[:, :, 0:1].repeat(1, 1, 1, 1, 1), cond], dim=2) + cond_latents = pipe.vae.encode(cond, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + pose_conds.append(cond_latents[:,:,1:]) + if return_all: + return pose_conds + else: + return {"s2v_pose_latents": pose_conds[0]} def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): - if inputs_shared.get("input_audio") is None or pipe.audio_encoder is None or pipe.audio_processor is None: + if (inputs_shared.get("input_audio") is None and inputs_shared.get("audio_embeds") is None) or pipe.audio_encoder is None or pipe.audio_processor is None: return 
inputs_shared, inputs_posi, inputs_nega - input_audio, audio_sample_rate, s2v_pose_video, num_frames, height, width = inputs_shared.get("input_audio"), inputs_shared.get("audio_sample_rate"), inputs_shared.get("s2v_pose_video"), inputs_shared.get("num_frames"), inputs_shared.get("height"), inputs_shared.get("width") - tiled, tile_size, tile_stride = inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride") + num_frames, height, width, tiled, tile_size, tile_stride = inputs_shared.get("num_frames"), inputs_shared.get("height"), inputs_shared.get("width"), inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride") + input_audio, audio_embeds, audio_sample_rate = inputs_shared.pop("input_audio"), inputs_shared.pop("audio_embeds"), inputs_shared.get("audio_sample_rate") + s2v_pose_video, s2v_pose_latents, motion_video = inputs_shared.pop("s2v_pose_video"), inputs_shared.pop("s2v_pose_latents"), inputs_shared.pop("motion_video") - audio_input_positive = self.process_audio(pipe, input_audio, audio_sample_rate, num_frames) + audio_input_positive = self.process_audio(pipe, input_audio, audio_sample_rate, num_frames, audio_embeds=audio_embeds) inputs_posi.update(audio_input_positive) - inputs_nega.update({"audio_input": 0.0 * audio_input_positive["audio_input"]}) + inputs_nega.update({"audio_embeds": 0.0 * audio_input_positive["audio_embeds"]}) - inputs_shared.update(self.process_motion_latents(pipe, height, width, tiled, tile_size, tile_stride)) - inputs_shared.update(self.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride)) + inputs_shared.update(self.process_motion_latents(pipe, height, width, tiled, tile_size, tile_stride, motion_video)) + inputs_shared.update(self.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride, s2v_pose_latents=s2v_pose_latents)) return inputs_shared, inputs_posi, inputs_nega + @staticmethod + def pre_calculate_audio_pose(pipe: WanVideoPipeline, input_audio=None, audio_sample_rate=16000, s2v_pose_video=None, num_frames=81, height=448, width=832, fps=16, tiled=True, tile_size=(30, 52), tile_stride=(15, 26)): + assert pipe.audio_encoder is not None and pipe.audio_processor is not None, "Please load audio encoder and audio processor first." 
+ shapes = WanVideoUnit_ShapeChecker().process(pipe, height, width, num_frames) + height, width, num_frames = shapes["height"], shapes["width"], shapes["num_frames"] + unit = WanVideoUnit_S2V() + audio_embeds = unit.process_audio(pipe, input_audio, audio_sample_rate, num_frames, fps, return_all=True) + pose_latents = unit.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, num_repeats=len(audio_embeds), return_all=True, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + pose_latents = None if s2v_pose_video is None else pose_latents + return audio_embeds, pose_latents, len(audio_embeds) + + +class WanVideoPostUnit_S2V(PipelineUnit): + def __init__(self): + super().__init__(input_params=("latents", "motion_latents", "drop_motion_frames")) + + def process(self, pipe: WanVideoPipeline, latents, motion_latents, drop_motion_frames): + if pipe.audio_encoder is None or motion_latents is None or drop_motion_frames: + return {} + latents = torch.cat([motion_latents, latents[:,:,1:]], dim=2) + return {"latents": latents} + class TeaCache: def __init__(self, num_inference_steps, rel_l1_thresh, model_id): @@ -1081,9 +1121,10 @@ def model_fn_wan_video( reference_latents = None, vace_context = None, vace_scale = 1.0, - audio_input: Optional[torch.Tensor] = None, + audio_embeds: Optional[torch.Tensor] = None, motion_latents: Optional[torch.Tensor] = None, - pose_cond: Optional[torch.Tensor] = None, + s2v_pose_latents: Optional[torch.Tensor] = None, + drop_motion_frames: bool = True, tea_cache: TeaCache = None, use_unified_sequence_parallel: bool = False, motion_bucket_id: Optional[torch.Tensor] = None, @@ -1122,15 +1163,16 @@ def model_fn_wan_video( batch_size=2 if cfg_merge else 1 ) # wan2.2 s2v - if audio_input is not None: + if audio_embeds is not None: return model_fn_wans2v( dit=dit, latents=latents, timestep=timestep, context=context, - audio_input=audio_input, + audio_embeds=audio_embeds, motion_latents=motion_latents, - pose_cond=pose_cond, + s2v_pose_latents=s2v_pose_latents, + drop_motion_frames=drop_motion_frames, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, use_gradient_checkpointing=use_gradient_checkpointing, use_unified_sequence_parallel=use_unified_sequence_parallel, @@ -1261,9 +1303,10 @@ def model_fn_wans2v( latents, timestep, context, - audio_input, + audio_embeds, motion_latents, - pose_cond, + s2v_pose_latents, + drop_motion_frames=True, use_gradient_checkpointing_offload=False, use_gradient_checkpointing=False, use_unified_sequence_parallel=False, @@ -1280,11 +1323,11 @@ def model_fn_wans2v( context = dit.text_embedding(context) # audio encode - audio_emb_global, merged_audio_emb = dit.cal_audio_emb(audio_input) + audio_emb_global, merged_audio_emb = dit.cal_audio_emb(audio_embeds) - # x and pose_cond - pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond - x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(pose_cond)) + # x and s2v_pose_latents + s2v_pose_latents = torch.zeros_like(x) if s2v_pose_latents is None else s2v_pose_latents + x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(s2v_pose_latents)) seq_len_x = seq_len_x_global = x.shape[1] # global used for unified sequence parallel # reference image @@ -1296,7 +1339,7 @@ def model_fn_wans2v( # freqs pre_compute_freqs = rope_precompute(x.detach().view(1, x.size(1), dit.num_heads, dit.dim // dit.num_heads), grid_sizes, dit.freqs, start=None) # motion - x, pre_compute_freqs, mask = dit.inject_motion(x, pre_compute_freqs, 
mask, motion_latents, add_last_motion=2) + x, pre_compute_freqs, mask = dit.inject_motion(x, pre_compute_freqs, mask, motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=2) x = x + dit.trainable_cond_mask(mask).to(x.dtype) diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md index 9a38682..a893c0a 100644 --- a/examples/wanvideo/README.md +++ b/examples/wanvideo/README.md @@ -48,7 +48,7 @@ save_video(video, "video1.mp4", fps=15, quality=5) | Model ID | Extra Parameters | Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation | |-|-|-|-|-|-|-| -|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.2-S2V-14B.py)|-|-|-|-| +|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-| |[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)| |[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)| |[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)| diff --git a/examples/wanvideo/README_zh.md b/examples/wanvideo/README_zh.md index c0bd3dc..2ec2c48 100644 --- a/examples/wanvideo/README_zh.md +++ b/examples/wanvideo/README_zh.md @@ -48,7 +48,7 @@ save_video(video, "video1.mp4", fps=15, quality=5) |模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-|-| -|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.2-S2V-14B.py)|-|-|-|-| +|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-| |[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)| |[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)| 
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)| diff --git a/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py b/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py new file mode 100644 index 0000000..6ffff6d --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py @@ -0,0 +1,116 @@ +import torch +from PIL import Image +import librosa +from diffsynth import VideoData, save_video_with_audio +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig, WanVideoUnit_S2V +from modelscope import dataset_snapshot_download + + +def speech_to_video( + prompt, + input_image, + audio_path, + negative_prompt="", + num_clip=None, + audio_sample_rate=16000, + pose_video_path=None, + infer_frames=80, + height=448, + width=832, + num_inference_steps=40, + fps=16, # recommend fixing fps as 16 for s2v + motion_frames=73, # hyperparameter of wan2.2-s2v + save_path=None, +): + # s2v audio input, recommend 16kHz sampling rate + input_audio, sample_rate = librosa.load(audio_path, sr=audio_sample_rate) + # s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps. + pose_video = VideoData(pose_video_path, height=height, width=width) if pose_video_path is not None else None + + audio_embeds, pose_latents, num_repeat = WanVideoUnit_S2V.pre_calculate_audio_pose( + pipe=pipe, + input_audio=input_audio, + audio_sample_rate=sample_rate, + s2v_pose_video=pose_video, + num_frames=infer_frames + 1, + height=height, + width=width, + fps=fps, + ) + num_repeat = min(num_repeat, num_clip) if num_clip is not None else num_repeat + print(f"Generating {num_repeat} video clips...") + motion_videos = [] + video = [] + for r in range(num_repeat): + s2v_pose_latents = pose_latents[r] if pose_latents is not None else None + current_clip = pipe( + prompt=prompt, + input_image=input_image, + negative_prompt=negative_prompt, + seed=0, + num_frames=infer_frames + 1, + height=height, + width=width, + audio_embeds=audio_embeds[r], + s2v_pose_latents=s2v_pose_latents, + motion_video=motion_videos, + num_inference_steps=num_inference_steps, + ) + current_clip = current_clip[-infer_frames:] + if r == 0: + current_clip = current_clip[3:] + overlap_frames_num = min(motion_frames, len(current_clip)) + motion_videos = motion_videos[overlap_frames_num:] + current_clip[-overlap_frames_num:] + video.extend(current_clip) + save_video_with_audio(video, save_path, audio_path, fps=16, quality=5) + print(f"processed the {r+1}th clip of total {num_repeat} clips.") + return video + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="Wan2.1_VAE.pth"), + ], + audio_processor_config=ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", 
origin_file_pattern="wav2vec2-large-xlsr-53-english/"), +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_video_dataset", + local_dir="./data/example_video_dataset", + allow_file_pattern=f"wans2v/*", +) + +infer_frames = 80 # 4n +height = 448 +width = 832 + +prompt = "a person is singing" +negative_prompt = "画面模糊,最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" +input_image = Image.open("data/example_video_dataset/wans2v/pose.png").convert("RGB").resize((width, height)) + +video_with_audio = speech_to_video( + prompt=prompt, + input_image=input_image, + audio_path='data/example_video_dataset/wans2v/sing.MP3', + negative_prompt=negative_prompt, + pose_video_path='data/example_video_dataset/wans2v/pose.mp4', + save_path="video_with_audio_full.mp4", + infer_frames=infer_frames, + height=height, + width=width, +) +# num_clip means generating only the first n clips with n * infer_frames frames. +video_with_audio_pose = speech_to_video( + prompt=prompt, + input_image=input_image, + audio_path='data/example_video_dataset/wans2v/sing.MP3', + negative_prompt=negative_prompt, + pose_video_path='data/example_video_dataset/wans2v/pose.mp4', + save_path="video_with_audio_pose_clip_2.mp4", + num_clip=2 +)