From 127cc9007a858a23c293f2170dfdce7c890f2ab6 Mon Sep 17 00:00:00 2001 From: Junming Chen <72211694+Leoooo333@users.noreply.github.com> Date: Sun, 14 Dec 2025 20:30:34 +0000 Subject: [PATCH 1/2] Fixed: S2V Long video severe quality downgrade --- diffsynth/pipelines/wan_video.py | 13 +++--- .../Wan2.2-S2V-14B_multi_clips.py | 45 +++++++++++-------- .../Wan2.2-S2V-14B_multi_clips.py | 45 ++++++++++--------- 3 files changed, 59 insertions(+), 44 deletions(-) diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index fa43db1..6e73251 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -241,6 +241,7 @@ class WanVideoPipeline(BasePipeline): tea_cache_model_id: Optional[str] = "", # progress_bar progress_bar_cmd=tqdm, + output_type: Optional[Literal["quantized", "floatpoint"]] = "quantized", ): # Scheduler self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) @@ -320,9 +321,11 @@ class WanVideoPipeline(BasePipeline): # Decode self.load_models_to_device(['vae']) video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) - video = self.vae_output_to_video(video) + if output_type == "quantized": + video = self.vae_output_to_video(video) + elif output_type == "floatpoint": + pass self.load_models_to_device([]) - return video @@ -823,9 +826,9 @@ class WanVideoUnit_S2V(PipelineUnit): pipe.load_models_to_device(["vae"]) motion_frames = 73 kwargs = {} - if motion_video is not None and len(motion_video) > 0: - assert len(motion_video) == motion_frames, f"motion video must have {motion_frames} frames, but got {len(motion_video)}" - motion_latents = pipe.preprocess_video(motion_video) + if motion_video is not None: + assert motion_video.shape[2] == motion_frames, f"motion video must have {motion_frames} frames, but got {motion_video.shape[2]}" + motion_latents = motion_video kwargs["drop_motion_frames"] = False else: motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device) diff --git a/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py b/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py index 35d42ba..a98486f 100644 --- a/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py +++ b/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py @@ -27,23 +27,24 @@ def speech_to_video( # s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps. pose_video = VideoData(pose_video_path, height=height, width=width) if pose_video_path is not None else None - audio_embeds, pose_latents, num_repeat = WanVideoUnit_S2V.pre_calculate_audio_pose( - pipe=pipe, - input_audio=input_audio, - audio_sample_rate=sample_rate, - s2v_pose_video=pose_video, - num_frames=infer_frames + 1, - height=height, - width=width, - fps=fps, - ) + with torch.no_grad(): + audio_embeds, pose_latents, num_repeat = WanVideoUnit_S2V.pre_calculate_audio_pose( + pipe=pipe, + input_audio=input_audio, + audio_sample_rate=sample_rate, + s2v_pose_video=pose_video, + num_frames=infer_frames + 1, + height=height, + width=width, + fps=fps, + ) num_repeat = min(num_repeat, num_clip) if num_clip is not None else num_repeat print(f"Generating {num_repeat} video clips...") - motion_videos = [] + motion_video = None video = [] for r in range(num_repeat): s2v_pose_latents = pose_latents[r] if pose_latents is not None else None - current_clip = pipe( + current_clip_tensor = pipe( prompt=prompt, input_image=input_image, negative_prompt=negative_prompt, @@ -53,16 +54,22 @@ def speech_to_video( width=width, audio_embeds=audio_embeds[r], s2v_pose_latents=s2v_pose_latents, - motion_video=motion_videos, + motion_video=motion_video, num_inference_steps=num_inference_steps, + output_type="floatpoint", ) - current_clip = current_clip[-infer_frames:] + # (B, C, T, H, W) + current_clip_tensor = current_clip_tensor[:,:,-infer_frames:,:,:] if r == 0: - current_clip = current_clip[3:] - overlap_frames_num = min(motion_frames, len(current_clip)) - motion_videos = motion_videos[overlap_frames_num:] + current_clip[-overlap_frames_num:] - video.extend(current_clip) - save_video_with_audio(video, save_path, audio_path, fps=16, quality=5) + current_clip_tensor = current_clip_tensor[:,:,3:,:,:] + overlap_frames_num = min(motion_frames, current_clip_tensor.shape[2]) + motion_video = current_clip_tensor[:,:,-overlap_frames_num:,:,:].clone() + else: + overlap_frames_num = min(motion_frames, current_clip_tensor.shape[2]) + motion_video = torch.cat((motion_video[:,:,overlap_frames_num:,:,:], current_clip_tensor[:,:,-overlap_frames_num:,:,:]), dim=2) + current_clip_quantized = pipe.vae_output_to_video(current_clip_tensor) + video.extend(current_clip_quantized) + save_video_with_audio(video, save_path, audio_path, fps=16, quality=10) print(f"processed the {r+1}th clip of total {num_repeat} clips.") return video diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.2-S2V-14B_multi_clips.py b/examples/wanvideo/model_inference_low_vram/Wan2.2-S2V-14B_multi_clips.py index c1995cf..8a20cdf 100644 --- a/examples/wanvideo/model_inference_low_vram/Wan2.2-S2V-14B_multi_clips.py +++ b/examples/wanvideo/model_inference_low_vram/Wan2.2-S2V-14B_multi_clips.py @@ -27,23 +27,24 @@ def speech_to_video( # s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps. pose_video = VideoData(pose_video_path, height=height, width=width) if pose_video_path is not None else None - audio_embeds, pose_latents, num_repeat = WanVideoUnit_S2V.pre_calculate_audio_pose( - pipe=pipe, - input_audio=input_audio, - audio_sample_rate=sample_rate, - s2v_pose_video=pose_video, - num_frames=infer_frames + 1, - height=height, - width=width, - fps=fps, - ) + with torch.no_grad(): + audio_embeds, pose_latents, num_repeat = WanVideoUnit_S2V.pre_calculate_audio_pose( + pipe=pipe, + input_audio=input_audio, + audio_sample_rate=sample_rate, + s2v_pose_video=pose_video, + num_frames=infer_frames + 1, + height=height, + width=width, + fps=fps, + ) num_repeat = min(num_repeat, num_clip) if num_clip is not None else num_repeat print(f"Generating {num_repeat} video clips...") - motion_videos = [] + motion_video = None video = [] for r in range(num_repeat): s2v_pose_latents = pose_latents[r] if pose_latents is not None else None - current_clip = pipe( + current_clip_tensor = pipe( prompt=prompt, input_image=input_image, negative_prompt=negative_prompt, @@ -53,20 +54,24 @@ def speech_to_video( width=width, audio_embeds=audio_embeds[r], s2v_pose_latents=s2v_pose_latents, - motion_video=motion_videos, + motion_video=motion_video, num_inference_steps=num_inference_steps, + output_type="floatpoint", ) - current_clip = current_clip[-infer_frames:] + current_clip_tensor = current_clip_tensor[:,:,-infer_frames:,:,:] if r == 0: - current_clip = current_clip[3:] - overlap_frames_num = min(motion_frames, len(current_clip)) - motion_videos = motion_videos[overlap_frames_num:] + current_clip[-overlap_frames_num:] - video.extend(current_clip) - save_video_with_audio(video, save_path, audio_path, fps=16, quality=5) + current_clip_tensor = current_clip_tensor[:,:,3:,:,:] + overlap_frames_num = min(motion_frames, current_clip_tensor.shape[2]) + motion_video = current_clip_tensor[:,:,-overlap_frames_num:,:,:].clone() + else: + overlap_frames_num = min(motion_frames, current_clip_tensor.shape[2]) + motion_video = torch.cat((motion_video[:,:,overlap_frames_num:,:,:], current_clip_tensor[:,:,-overlap_frames_num:,:,:]), dim=2) + current_clip_quantized = pipe.vae_output_to_video(current_clip_tensor) + video.extend(current_clip_quantized) + save_video_with_audio(video, save_path, audio_path, fps=16, quality=10) print(f"processed the {r+1}th clip of total {num_repeat} clips.") return video - vram_config = { "offload_dtype": "disk", "offload_device": "disk", From a4d34d9f3d79d14f840d8ef8b815ef67ca12a212 Mon Sep 17 00:00:00 2001 From: Junming Chen <72211694+Leoooo333@users.noreply.github.com> Date: Sun, 14 Dec 2025 20:53:26 +0000 Subject: [PATCH 2/2] Append: set video compress quality as original version. --- examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py | 2 +- .../model_inference_low_vram/Wan2.2-S2V-14B_multi_clips.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py b/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py index a98486f..43da5b6 100644 --- a/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py +++ b/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py @@ -69,7 +69,7 @@ def speech_to_video( motion_video = torch.cat((motion_video[:,:,overlap_frames_num:,:,:], current_clip_tensor[:,:,-overlap_frames_num:,:,:]), dim=2) current_clip_quantized = pipe.vae_output_to_video(current_clip_tensor) video.extend(current_clip_quantized) - save_video_with_audio(video, save_path, audio_path, fps=16, quality=10) + save_video_with_audio(video, save_path, audio_path, fps=16, quality=5) print(f"processed the {r+1}th clip of total {num_repeat} clips.") return video diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.2-S2V-14B_multi_clips.py b/examples/wanvideo/model_inference_low_vram/Wan2.2-S2V-14B_multi_clips.py index 8a20cdf..e57d65e 100644 --- a/examples/wanvideo/model_inference_low_vram/Wan2.2-S2V-14B_multi_clips.py +++ b/examples/wanvideo/model_inference_low_vram/Wan2.2-S2V-14B_multi_clips.py @@ -68,7 +68,7 @@ def speech_to_video( motion_video = torch.cat((motion_video[:,:,overlap_frames_num:,:,:], current_clip_tensor[:,:,-overlap_frames_num:,:,:]), dim=2) current_clip_quantized = pipe.vae_output_to_video(current_clip_tensor) video.extend(current_clip_quantized) - save_video_with_audio(video, save_path, audio_path, fps=16, quality=10) + save_video_with_audio(video, save_path, audio_path, fps=16, quality=5) print(f"processed the {r+1}th clip of total {num_repeat} clips.") return video