Update svd_video.py

This commit is contained in:
Qianyi Zhao
2024-10-23 03:27:59 -05:00
committed by GitHub
parent 747572e62c
commit d381c7b186

View File

@@ -49,7 +49,7 @@ class SVDVideoPipeline(BasePipeline):
return image_emb
def encode_image_with_vae(self, image, noise_aug_strength):
def encode_image_with_vae(self, image, noise_aug_strength, seed):
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
noise = self.generate_noise(image.shape, seed=seed, device=self.device, dtype=self.torch_dtype)
image = image + noise_aug_strength * noise
@@ -148,7 +148,7 @@ class SVDVideoPipeline(BasePipeline):
# Encode image
image_emb_clip_posi = self.encode_image_with_clip(input_image)
image_emb_clip_nega = torch.zeros_like(image_emb_clip_posi)
image_emb_vae_posi = repeat(self.encode_image_with_vae(input_image, noise_aug_strength), "B C H W -> (B T) C H W", T=num_frames)
image_emb_vae_posi = repeat(self.encode_image_with_vae(input_image, noise_aug_strength, seed=seed), "B C H W -> (B T) C H W", T=num_frames)
image_emb_vae_nega = torch.zeros_like(image_emb_vae_posi)
# Prepare classifier-free guidance