Fixed: S2V Long video severe quality downgrade

This commit is contained in:
Junming Chen
2025-12-14 20:30:34 +00:00
parent e316fb717f
commit 127cc9007a
3 changed files with 59 additions and 44 deletions

View File

@@ -27,23 +27,24 @@ def speech_to_video(
# s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps.
pose_video = VideoData(pose_video_path, height=height, width=width) if pose_video_path is not None else None
audio_embeds, pose_latents, num_repeat = WanVideoUnit_S2V.pre_calculate_audio_pose(
pipe=pipe,
input_audio=input_audio,
audio_sample_rate=sample_rate,
s2v_pose_video=pose_video,
num_frames=infer_frames + 1,
height=height,
width=width,
fps=fps,
)
with torch.no_grad():
audio_embeds, pose_latents, num_repeat = WanVideoUnit_S2V.pre_calculate_audio_pose(
pipe=pipe,
input_audio=input_audio,
audio_sample_rate=sample_rate,
s2v_pose_video=pose_video,
num_frames=infer_frames + 1,
height=height,
width=width,
fps=fps,
)
num_repeat = min(num_repeat, num_clip) if num_clip is not None else num_repeat
print(f"Generating {num_repeat} video clips...")
motion_videos = []
motion_video = None
video = []
for r in range(num_repeat):
s2v_pose_latents = pose_latents[r] if pose_latents is not None else None
current_clip = pipe(
current_clip_tensor = pipe(
prompt=prompt,
input_image=input_image,
negative_prompt=negative_prompt,
@@ -53,16 +54,22 @@ def speech_to_video(
width=width,
audio_embeds=audio_embeds[r],
s2v_pose_latents=s2v_pose_latents,
motion_video=motion_videos,
motion_video=motion_video,
num_inference_steps=num_inference_steps,
output_type="floatpoint",
)
current_clip = current_clip[-infer_frames:]
# (B, C, T, H, W)
current_clip_tensor = current_clip_tensor[:,:,-infer_frames:,:,:]
if r == 0:
current_clip = current_clip[3:]
overlap_frames_num = min(motion_frames, len(current_clip))
motion_videos = motion_videos[overlap_frames_num:] + current_clip[-overlap_frames_num:]
video.extend(current_clip)
save_video_with_audio(video, save_path, audio_path, fps=16, quality=5)
current_clip_tensor = current_clip_tensor[:,:,3:,:,:]
overlap_frames_num = min(motion_frames, current_clip_tensor.shape[2])
motion_video = current_clip_tensor[:,:,-overlap_frames_num:,:,:].clone()
else:
overlap_frames_num = min(motion_frames, current_clip_tensor.shape[2])
motion_video = torch.cat((motion_video[:,:,overlap_frames_num:,:,:], current_clip_tensor[:,:,-overlap_frames_num:,:,:]), dim=2)
current_clip_quantized = pipe.vae_output_to_video(current_clip_tensor)
video.extend(current_clip_quantized)
save_video_with_audio(video, save_path, audio_path, fps=16, quality=10)
print(f"processed the {r+1}th clip of total {num_repeat} clips.")
return video

View File

@@ -27,23 +27,24 @@ def speech_to_video(
# s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps.
pose_video = VideoData(pose_video_path, height=height, width=width) if pose_video_path is not None else None
audio_embeds, pose_latents, num_repeat = WanVideoUnit_S2V.pre_calculate_audio_pose(
pipe=pipe,
input_audio=input_audio,
audio_sample_rate=sample_rate,
s2v_pose_video=pose_video,
num_frames=infer_frames + 1,
height=height,
width=width,
fps=fps,
)
with torch.no_grad():
audio_embeds, pose_latents, num_repeat = WanVideoUnit_S2V.pre_calculate_audio_pose(
pipe=pipe,
input_audio=input_audio,
audio_sample_rate=sample_rate,
s2v_pose_video=pose_video,
num_frames=infer_frames + 1,
height=height,
width=width,
fps=fps,
)
num_repeat = min(num_repeat, num_clip) if num_clip is not None else num_repeat
print(f"Generating {num_repeat} video clips...")
motion_videos = []
motion_video = None
video = []
for r in range(num_repeat):
s2v_pose_latents = pose_latents[r] if pose_latents is not None else None
current_clip = pipe(
current_clip_tensor = pipe(
prompt=prompt,
input_image=input_image,
negative_prompt=negative_prompt,
@@ -53,20 +54,24 @@ def speech_to_video(
width=width,
audio_embeds=audio_embeds[r],
s2v_pose_latents=s2v_pose_latents,
motion_video=motion_videos,
motion_video=motion_video,
num_inference_steps=num_inference_steps,
output_type="floatpoint",
)
current_clip = current_clip[-infer_frames:]
current_clip_tensor = current_clip_tensor[:,:,-infer_frames:,:,:]
if r == 0:
current_clip = current_clip[3:]
overlap_frames_num = min(motion_frames, len(current_clip))
motion_videos = motion_videos[overlap_frames_num:] + current_clip[-overlap_frames_num:]
video.extend(current_clip)
save_video_with_audio(video, save_path, audio_path, fps=16, quality=5)
current_clip_tensor = current_clip_tensor[:,:,3:,:,:]
overlap_frames_num = min(motion_frames, current_clip_tensor.shape[2])
motion_video = current_clip_tensor[:,:,-overlap_frames_num:,:,:].clone()
else:
overlap_frames_num = min(motion_frames, current_clip_tensor.shape[2])
motion_video = torch.cat((motion_video[:,:,overlap_frames_num:,:,:], current_clip_tensor[:,:,-overlap_frames_num:,:,:]), dim=2)
current_clip_quantized = pipe.vae_output_to_video(current_clip_tensor)
video.extend(current_clip_quantized)
save_video_with_audio(video, save_path, audio_path, fps=16, quality=10)
print(f"processed the {r+1}th clip of total {num_repeat} clips.")
return video
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",