diff --git a/diffsynth/pipelines/stable_video_diffusion.py b/diffsynth/pipelines/stable_video_diffusion.py
index 12b6852..a14aa85 100644
--- a/diffsynth/pipelines/stable_video_diffusion.py
+++ b/diffsynth/pipelines/stable_video_diffusion.py
@@ -114,6 +114,8 @@ class SVDVideoPipeline(torch.nn.Module):
         self,
         input_image=None,
         input_video=None,
+        mask_frames=[],
+        mask_frame_ids=[],
         min_cfg_scale=1.0,
         max_cfg_scale=3.0,
         denoising_strength=1.0,
@@ -133,11 +135,15 @@ class SVDVideoPipeline(torch.nn.Module):
         # Prepare latent tensors
         noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).to(self.device)
         if denoising_strength == 1.0:
-            latents = noise
+            latents = noise.clone()
         else:
             latents = self.encode_video_with_vae(input_video)
             latents = self.scheduler.add_noise(latents, noise, self.scheduler.timesteps[0])
 
+        # Prepare mask frames
+        if len(mask_frames) > 0:
+            mask_latents = self.encode_video_with_vae(mask_frames)
+
         # Encode image
         image_emb_clip_posi = self.encode_image_with_clip(input_image)
         image_emb_clip_nega = torch.zeros_like(image_emb_clip_posi)
@@ -154,6 +160,10 @@ class SVDVideoPipeline(torch.nn.Module):
         # Denoise
         for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
 
+            # Mask frames
+            for frame_id, mask_frame_id in enumerate(mask_frame_ids):
+                latents[mask_frame_id] = self.scheduler.add_noise(mask_latents[frame_id], noise[mask_frame_id], timestep)
+
             # Fetch model output
             noise_pred = self.calculate_noise_pred(
                 latents, timestep, add_time_id, cfg_scales,