mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 16:50:47 +00:00
ExVideo for AnimateDiff
This commit is contained in:
@@ -10,6 +10,7 @@ import torch, os, json
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
def lets_dance_with_long_video(
|
||||
@@ -150,6 +151,14 @@ class SDVideoPipeline(torch.nn.Module):
|
||||
return latents
|
||||
|
||||
|
||||
def post_process_latents(self, latents, post_normalize=True, contrast_enhance_scale=1.0):
    """Optionally re-normalize the latents per item, then apply a contrast gain.

    Args:
        latents: Floating-point tensor with at least four dimensions;
            statistics are reduced over dims 1, 2 and 3 for each index
            along dim 0 (presumably one entry per video frame -- confirm
            against the caller).
        post_normalize: When true, every dim-0 slice is shifted and scaled
            so that its mean and standard deviation match the mean and
            standard deviation of the whole tensor.
        contrast_enhance_scale: Multiplier applied to the latents after the
            optional normalization; ``1.0`` leaves them unchanged.

    Returns:
        The processed latents, same shape as the input.
    """
    if post_normalize:
        # Statistics of the whole tensor: the target every slice is mapped to.
        global_mean, global_std = latents.mean(), latents.std()
        slice_mean = latents.mean(dim=[1, 2, 3], keepdim=True)
        slice_std = latents.std(dim=[1, 2, 3], keepdim=True)
        # A perfectly flat slice has zero std; dividing by it would turn the
        # whole slice into NaN. Clamping to machine epsilon leaves every
        # ordinary slice untouched and maps a flat slice to the global mean.
        slice_std = slice_std.clamp_min(torch.finfo(latents.dtype).eps)
        latents = (latents - slice_mean) / slice_std * global_std + global_mean
    latents = latents * contrast_enhance_scale
    return latents
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
@@ -172,6 +181,8 @@ class SDVideoPipeline(torch.nn.Module):
|
||||
smoother=None,
|
||||
smoother_progress_ids=[],
|
||||
vram_limit_level=0,
|
||||
post_normalize=False,
|
||||
contrast_enhance_scale=1.0,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
@@ -226,15 +237,18 @@ class SDVideoPipeline(torch.nn.Module):
|
||||
cross_frame_attention=cross_frame_attention,
|
||||
device=self.device, vram_limit_level=vram_limit_level
|
||||
)
|
||||
noise_pred_nega = lets_dance_with_long_video(
|
||||
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
|
||||
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames,
|
||||
animatediff_batch_size=animatediff_batch_size, animatediff_stride=animatediff_stride,
|
||||
unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
|
||||
cross_frame_attention=cross_frame_attention,
|
||||
device=self.device, vram_limit_level=vram_limit_level
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = lets_dance_with_long_video(
|
||||
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
|
||||
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames,
|
||||
animatediff_batch_size=animatediff_batch_size, animatediff_stride=animatediff_stride,
|
||||
unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
|
||||
cross_frame_attention=cross_frame_attention,
|
||||
device=self.device, vram_limit_level=vram_limit_level
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
# DDIM and smoother
|
||||
if smoother is not None and progress_id in smoother_progress_ids:
|
||||
@@ -250,6 +264,7 @@ class SDVideoPipeline(torch.nn.Module):
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
latents = self.post_process_latents(latents, post_normalize=post_normalize, contrast_enhance_scale=contrast_enhance_scale)
|
||||
output_frames = self.decode_images(latents)
|
||||
|
||||
# Post-process
|
||||
|
||||
Reference in New Issue
Block a user