Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-23 09:28:12 +00:00)
svd mask
This commit is contained in:
@@ -114,6 +114,8 @@ class SVDVideoPipeline(torch.nn.Module):
|
|||||||
self,
|
self,
|
||||||
input_image=None,
|
input_image=None,
|
||||||
input_video=None,
|
input_video=None,
|
||||||
|
mask_frames=[],
|
||||||
|
mask_frame_ids=[],
|
||||||
min_cfg_scale=1.0,
|
min_cfg_scale=1.0,
|
||||||
max_cfg_scale=3.0,
|
max_cfg_scale=3.0,
|
||||||
denoising_strength=1.0,
|
denoising_strength=1.0,
|
||||||
@@ -133,11 +135,15 @@ class SVDVideoPipeline(torch.nn.Module):
|
|||||||
# Prepare latent tensors
|
# Prepare latent tensors
|
||||||
noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).to(self.device)
|
noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).to(self.device)
|
||||||
if denoising_strength == 1.0:
|
if denoising_strength == 1.0:
|
||||||
latents = noise
|
latents = noise.clone()
|
||||||
else:
|
else:
|
||||||
latents = self.encode_video_with_vae(input_video)
|
latents = self.encode_video_with_vae(input_video)
|
||||||
latents = self.scheduler.add_noise(latents, noise, self.scheduler.timesteps[0])
|
latents = self.scheduler.add_noise(latents, noise, self.scheduler.timesteps[0])
|
||||||
|
|
||||||
|
# Prepare mask frames
|
||||||
|
if len(mask_frames) > 0:
|
||||||
|
mask_latents = self.encode_video_with_vae(mask_frames)
|
||||||
|
|
||||||
# Encode image
|
# Encode image
|
||||||
image_emb_clip_posi = self.encode_image_with_clip(input_image)
|
image_emb_clip_posi = self.encode_image_with_clip(input_image)
|
||||||
image_emb_clip_nega = torch.zeros_like(image_emb_clip_posi)
|
image_emb_clip_nega = torch.zeros_like(image_emb_clip_posi)
|
||||||
@@ -154,6 +160,10 @@ class SVDVideoPipeline(torch.nn.Module):
|
|||||||
# Denoise
|
# Denoise
|
||||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||||
|
|
||||||
|
# Mask frames
|
||||||
|
for frame_id, mask_frame_id in enumerate(mask_frame_ids):
|
||||||
|
latents[mask_frame_id] = self.scheduler.add_noise(mask_latents[frame_id], noise[mask_frame_id], timestep)
|
||||||
|
|
||||||
# Fetch model output
|
# Fetch model output
|
||||||
noise_pred = self.calculate_noise_pred(
|
noise_pred = self.calculate_noise_pred(
|
||||||
latents, timestep, add_time_id, cfg_scales,
|
latents, timestep, add_time_id, cfg_scales,
|
||||||
|
|||||||
Reference in New Issue
Block a user