mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 16:50:47 +00:00
Fixed: S2V Long video severe quality downgrade
This commit is contained in:
@@ -241,6 +241,7 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
tea_cache_model_id: Optional[str] = "",
|
tea_cache_model_id: Optional[str] = "",
|
||||||
# progress_bar
|
# progress_bar
|
||||||
progress_bar_cmd=tqdm,
|
progress_bar_cmd=tqdm,
|
||||||
|
output_type: Optional[Literal["quantized", "floatpoint"]] = "quantized",
|
||||||
):
|
):
|
||||||
# Scheduler
|
# Scheduler
|
||||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
|
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
|
||||||
@@ -320,9 +321,11 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
# Decode
|
# Decode
|
||||||
self.load_models_to_device(['vae'])
|
self.load_models_to_device(['vae'])
|
||||||
video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||||
video = self.vae_output_to_video(video)
|
if output_type == "quantized":
|
||||||
|
video = self.vae_output_to_video(video)
|
||||||
|
elif output_type == "floatpoint":
|
||||||
|
pass
|
||||||
self.load_models_to_device([])
|
self.load_models_to_device([])
|
||||||
|
|
||||||
return video
|
return video
|
||||||
|
|
||||||
|
|
||||||
@@ -823,9 +826,9 @@ class WanVideoUnit_S2V(PipelineUnit):
|
|||||||
pipe.load_models_to_device(["vae"])
|
pipe.load_models_to_device(["vae"])
|
||||||
motion_frames = 73
|
motion_frames = 73
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if motion_video is not None and len(motion_video) > 0:
|
if motion_video is not None:
|
||||||
assert len(motion_video) == motion_frames, f"motion video must have {motion_frames} frames, but got {len(motion_video)}"
|
assert motion_video.shape[2] == motion_frames, f"motion video must have {motion_frames} frames, but got {motion_video.shape[2]}"
|
||||||
motion_latents = pipe.preprocess_video(motion_video)
|
motion_latents = motion_video
|
||||||
kwargs["drop_motion_frames"] = False
|
kwargs["drop_motion_frames"] = False
|
||||||
else:
|
else:
|
||||||
motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device)
|
motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device)
|
||||||
|
|||||||
@@ -27,23 +27,24 @@ def speech_to_video(
|
|||||||
# s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps.
|
# s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps.
|
||||||
pose_video = VideoData(pose_video_path, height=height, width=width) if pose_video_path is not None else None
|
pose_video = VideoData(pose_video_path, height=height, width=width) if pose_video_path is not None else None
|
||||||
|
|
||||||
audio_embeds, pose_latents, num_repeat = WanVideoUnit_S2V.pre_calculate_audio_pose(
|
with torch.no_grad():
|
||||||
pipe=pipe,
|
audio_embeds, pose_latents, num_repeat = WanVideoUnit_S2V.pre_calculate_audio_pose(
|
||||||
input_audio=input_audio,
|
pipe=pipe,
|
||||||
audio_sample_rate=sample_rate,
|
input_audio=input_audio,
|
||||||
s2v_pose_video=pose_video,
|
audio_sample_rate=sample_rate,
|
||||||
num_frames=infer_frames + 1,
|
s2v_pose_video=pose_video,
|
||||||
height=height,
|
num_frames=infer_frames + 1,
|
||||||
width=width,
|
height=height,
|
||||||
fps=fps,
|
width=width,
|
||||||
)
|
fps=fps,
|
||||||
|
)
|
||||||
num_repeat = min(num_repeat, num_clip) if num_clip is not None else num_repeat
|
num_repeat = min(num_repeat, num_clip) if num_clip is not None else num_repeat
|
||||||
print(f"Generating {num_repeat} video clips...")
|
print(f"Generating {num_repeat} video clips...")
|
||||||
motion_videos = []
|
motion_video = None
|
||||||
video = []
|
video = []
|
||||||
for r in range(num_repeat):
|
for r in range(num_repeat):
|
||||||
s2v_pose_latents = pose_latents[r] if pose_latents is not None else None
|
s2v_pose_latents = pose_latents[r] if pose_latents is not None else None
|
||||||
current_clip = pipe(
|
current_clip_tensor = pipe(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
input_image=input_image,
|
input_image=input_image,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
@@ -53,16 +54,22 @@ def speech_to_video(
|
|||||||
width=width,
|
width=width,
|
||||||
audio_embeds=audio_embeds[r],
|
audio_embeds=audio_embeds[r],
|
||||||
s2v_pose_latents=s2v_pose_latents,
|
s2v_pose_latents=s2v_pose_latents,
|
||||||
motion_video=motion_videos,
|
motion_video=motion_video,
|
||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=num_inference_steps,
|
||||||
|
output_type="floatpoint",
|
||||||
)
|
)
|
||||||
current_clip = current_clip[-infer_frames:]
|
# (B, C, T, H, W)
|
||||||
|
current_clip_tensor = current_clip_tensor[:,:,-infer_frames:,:,:]
|
||||||
if r == 0:
|
if r == 0:
|
||||||
current_clip = current_clip[3:]
|
current_clip_tensor = current_clip_tensor[:,:,3:,:,:]
|
||||||
overlap_frames_num = min(motion_frames, len(current_clip))
|
overlap_frames_num = min(motion_frames, current_clip_tensor.shape[2])
|
||||||
motion_videos = motion_videos[overlap_frames_num:] + current_clip[-overlap_frames_num:]
|
motion_video = current_clip_tensor[:,:,-overlap_frames_num:,:,:].clone()
|
||||||
video.extend(current_clip)
|
else:
|
||||||
save_video_with_audio(video, save_path, audio_path, fps=16, quality=5)
|
overlap_frames_num = min(motion_frames, current_clip_tensor.shape[2])
|
||||||
|
motion_video = torch.cat((motion_video[:,:,overlap_frames_num:,:,:], current_clip_tensor[:,:,-overlap_frames_num:,:,:]), dim=2)
|
||||||
|
current_clip_quantized = pipe.vae_output_to_video(current_clip_tensor)
|
||||||
|
video.extend(current_clip_quantized)
|
||||||
|
save_video_with_audio(video, save_path, audio_path, fps=16, quality=10)
|
||||||
print(f"processed the {r+1}th clip of total {num_repeat} clips.")
|
print(f"processed the {r+1}th clip of total {num_repeat} clips.")
|
||||||
return video
|
return video
|
||||||
|
|
||||||
|
|||||||
@@ -27,23 +27,24 @@ def speech_to_video(
|
|||||||
# s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps.
|
# s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps.
|
||||||
pose_video = VideoData(pose_video_path, height=height, width=width) if pose_video_path is not None else None
|
pose_video = VideoData(pose_video_path, height=height, width=width) if pose_video_path is not None else None
|
||||||
|
|
||||||
audio_embeds, pose_latents, num_repeat = WanVideoUnit_S2V.pre_calculate_audio_pose(
|
with torch.no_grad():
|
||||||
pipe=pipe,
|
audio_embeds, pose_latents, num_repeat = WanVideoUnit_S2V.pre_calculate_audio_pose(
|
||||||
input_audio=input_audio,
|
pipe=pipe,
|
||||||
audio_sample_rate=sample_rate,
|
input_audio=input_audio,
|
||||||
s2v_pose_video=pose_video,
|
audio_sample_rate=sample_rate,
|
||||||
num_frames=infer_frames + 1,
|
s2v_pose_video=pose_video,
|
||||||
height=height,
|
num_frames=infer_frames + 1,
|
||||||
width=width,
|
height=height,
|
||||||
fps=fps,
|
width=width,
|
||||||
)
|
fps=fps,
|
||||||
|
)
|
||||||
num_repeat = min(num_repeat, num_clip) if num_clip is not None else num_repeat
|
num_repeat = min(num_repeat, num_clip) if num_clip is not None else num_repeat
|
||||||
print(f"Generating {num_repeat} video clips...")
|
print(f"Generating {num_repeat} video clips...")
|
||||||
motion_videos = []
|
motion_video = None
|
||||||
video = []
|
video = []
|
||||||
for r in range(num_repeat):
|
for r in range(num_repeat):
|
||||||
s2v_pose_latents = pose_latents[r] if pose_latents is not None else None
|
s2v_pose_latents = pose_latents[r] if pose_latents is not None else None
|
||||||
current_clip = pipe(
|
current_clip_tensor = pipe(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
input_image=input_image,
|
input_image=input_image,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
@@ -53,20 +54,24 @@ def speech_to_video(
|
|||||||
width=width,
|
width=width,
|
||||||
audio_embeds=audio_embeds[r],
|
audio_embeds=audio_embeds[r],
|
||||||
s2v_pose_latents=s2v_pose_latents,
|
s2v_pose_latents=s2v_pose_latents,
|
||||||
motion_video=motion_videos,
|
motion_video=motion_video,
|
||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=num_inference_steps,
|
||||||
|
output_type="floatpoint",
|
||||||
)
|
)
|
||||||
current_clip = current_clip[-infer_frames:]
|
current_clip_tensor = current_clip_tensor[:,:,-infer_frames:,:,:]
|
||||||
if r == 0:
|
if r == 0:
|
||||||
current_clip = current_clip[3:]
|
current_clip_tensor = current_clip_tensor[:,:,3:,:,:]
|
||||||
overlap_frames_num = min(motion_frames, len(current_clip))
|
overlap_frames_num = min(motion_frames, current_clip_tensor.shape[2])
|
||||||
motion_videos = motion_videos[overlap_frames_num:] + current_clip[-overlap_frames_num:]
|
motion_video = current_clip_tensor[:,:,-overlap_frames_num:,:,:].clone()
|
||||||
video.extend(current_clip)
|
else:
|
||||||
save_video_with_audio(video, save_path, audio_path, fps=16, quality=5)
|
overlap_frames_num = min(motion_frames, current_clip_tensor.shape[2])
|
||||||
|
motion_video = torch.cat((motion_video[:,:,overlap_frames_num:,:,:], current_clip_tensor[:,:,-overlap_frames_num:,:,:]), dim=2)
|
||||||
|
current_clip_quantized = pipe.vae_output_to_video(current_clip_tensor)
|
||||||
|
video.extend(current_clip_quantized)
|
||||||
|
save_video_with_audio(video, save_path, audio_path, fps=16, quality=10)
|
||||||
print(f"processed the {r+1}th clip of total {num_repeat} clips.")
|
print(f"processed the {r+1}th clip of total {num_repeat} clips.")
|
||||||
return video
|
return video
|
||||||
|
|
||||||
|
|
||||||
vram_config = {
|
vram_config = {
|
||||||
"offload_dtype": "disk",
|
"offload_dtype": "disk",
|
||||||
"offload_device": "disk",
|
"offload_device": "disk",
|
||||||
|
|||||||
Reference in New Issue
Block a user