ExVideo for AnimateDiff

This commit is contained in:
Artiprocher
2024-07-26 14:35:18 +08:00
parent f094cae7e9
commit a076adf592
7 changed files with 520 additions and 48 deletions

View File

@@ -10,6 +10,7 @@ import torch, os, json
from tqdm import tqdm
from PIL import Image
import numpy as np
from einops import rearrange
def lets_dance_with_long_video(
@@ -150,6 +151,14 @@ class SDVideoPipeline(torch.nn.Module):
return latents
def post_process_latents(self, latents, post_normalize=True, contrast_enhance_scale=1.0):
if post_normalize:
mean, std = latents.mean(), latents.std()
latents = (latents - latents.mean(dim=[1, 2, 3], keepdim=True)) / latents.std(dim=[1, 2, 3], keepdim=True) * std + mean
latents = latents * contrast_enhance_scale
return latents
@torch.no_grad()
def __call__(
self,
@@ -172,6 +181,8 @@ class SDVideoPipeline(torch.nn.Module):
smoother=None,
smoother_progress_ids=[],
vram_limit_level=0,
post_normalize=False,
contrast_enhance_scale=1.0,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
@@ -226,15 +237,18 @@ class SDVideoPipeline(torch.nn.Module):
cross_frame_attention=cross_frame_attention,
device=self.device, vram_limit_level=vram_limit_level
)
noise_pred_nega = lets_dance_with_long_video(
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames,
animatediff_batch_size=animatediff_batch_size, animatediff_stride=animatediff_stride,
unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
cross_frame_attention=cross_frame_attention,
device=self.device, vram_limit_level=vram_limit_level
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
if cfg_scale != 1.0:
noise_pred_nega = lets_dance_with_long_video(
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames,
animatediff_batch_size=animatediff_batch_size, animatediff_stride=animatediff_stride,
unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
cross_frame_attention=cross_frame_attention,
device=self.device, vram_limit_level=vram_limit_level
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# DDIM and smoother
if smoother is not None and progress_id in smoother_progress_ids:
@@ -250,6 +264,7 @@ class SDVideoPipeline(torch.nn.Module):
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
latents = self.post_process_latents(latents, post_normalize=post_normalize, contrast_enhance_scale=contrast_enhance_scale)
output_frames = self.decode_images(latents)
# Post-process