This commit is contained in:
Artiprocher
2024-05-07 10:47:51 +08:00
parent 32991d8e3e
commit 8fa03aa997

View File

@@ -114,6 +114,8 @@ class SVDVideoPipeline(torch.nn.Module):
self,
input_image=None,
input_video=None,
mask_frames=[],
mask_frame_ids=[],
min_cfg_scale=1.0,
max_cfg_scale=3.0,
denoising_strength=1.0,
@@ -133,11 +135,15 @@ class SVDVideoPipeline(torch.nn.Module):
# Prepare latent tensors
noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).to(self.device)
if denoising_strength == 1.0:
latents = noise
latents = noise.clone()
else:
latents = self.encode_video_with_vae(input_video)
latents = self.scheduler.add_noise(latents, noise, self.scheduler.timesteps[0])
# Prepare mask frames
if len(mask_frames) > 0:
mask_latents = self.encode_video_with_vae(mask_frames)
# Encode image
image_emb_clip_posi = self.encode_image_with_clip(input_image)
image_emb_clip_nega = torch.zeros_like(image_emb_clip_posi)
@@ -154,6 +160,10 @@ class SVDVideoPipeline(torch.nn.Module):
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
# Mask frames
for frame_id, mask_frame_id in enumerate(mask_frame_ids):
latents[mask_frame_id] = self.scheduler.add_noise(mask_latents[frame_id], noise[mask_frame_id], timestep)
# Fetch model output
noise_pred = self.calculate_noise_pred(
latents, timestep, add_time_id, cfg_scales,