ExVideo for AnimateDiff

This commit is contained in:
Artiprocher
2024-07-26 14:35:18 +08:00
parent f094cae7e9
commit a076adf592
7 changed files with 520 additions and 48 deletions

View File

@@ -194,10 +194,10 @@ class ModelManager:
self.model[component].append(model)
self.model_path[component].append(file_path)
def load_animatediff(self, state_dict, file_path=""):
def load_animatediff(self, state_dict, file_path="", add_positional_conv=None):
component = "motion_modules"
model = SDMotionModel()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model = SDMotionModel(add_positional_conv=add_positional_conv)
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict, add_positional_conv=add_positional_conv))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
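Note: a minimal sketch of how the new `add_positional_conv` argument might be used (not part of this commit; the checkpoint path and the frame count 128 are placeholders, and `ModelManager` is assumed to be importable from the `diffsynth` package as elsewhere in the repository):

```python
import torch
from safetensors.torch import load_file
from diffsynth import ModelManager

model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
state_dict = load_file("models/AnimateDiff/mm_sd_v15_v2.safetensors")  # placeholder path

# Without add_positional_conv the standard motion module is built; with it,
# the positional embeddings are extended and identity-initialized positional
# convolutions are added so longer videos can be handled (ExVideo-style).
model_manager.load_animatediff(
    state_dict,
    file_path="mm_sd_v15_v2.safetensors",
    add_positional_conv=128,  # example value; depends on the target video length
)
```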

View File

@@ -1,20 +1,28 @@
from .sd_unet import SDUNet, Attention, GEGLU
from .svd_unet import get_timestep_embedding
import torch
from einops import rearrange, repeat
class TemporalTransformerBlock(torch.nn.Module):
def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32):
def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32, add_positional_conv=None):
super().__init__()
self.add_positional_conv = add_positional_conv
# 1. Self-Attn
self.pe1 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
emb = get_timestep_embedding(torch.arange(max_position_embeddings), dim, True, 0).reshape(1, max_position_embeddings, dim)
self.pe1 = torch.nn.Parameter(emb)
if add_positional_conv:
self.positional_conv_1 = torch.nn.Conv1d(dim, dim, kernel_size=3, padding=1, padding_mode="reflect")
self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True)
self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
# 2. Cross-Attn
self.pe2 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
emb = get_timestep_embedding(torch.arange(max_position_embeddings), dim, True, 0).reshape(1, max_position_embeddings, dim)
self.pe2 = torch.nn.Parameter(emb)
if add_positional_conv:
self.positional_conv_2 = torch.nn.Conv1d(dim, dim, kernel_size=3, padding=1, padding_mode="reflect")
self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True)
self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
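Note: the learnable positional embeddings `pe1`/`pe2` are now initialized with sinusoidal timestep embeddings instead of zeros. A standalone sketch of that initialization, using the equivalent diffusers helper as a stand-in for the repository's `svd_unet.get_timestep_embedding` (an assumption about equivalence, not code from this commit):

```python
import torch
from diffusers.models.embeddings import get_timestep_embedding

max_position_embeddings, dim = 32, 320
emb = get_timestep_embedding(
    torch.arange(max_position_embeddings), dim,
    flip_sin_to_cos=True, downscale_freq_shift=0,
)
pe = emb.reshape(1, max_position_embeddings, dim)  # shape (1, 32, 320)
print(pe.shape)
```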
@@ -24,19 +32,47 @@ class TemporalTransformerBlock(torch.nn.Module):
self.ff = torch.nn.Linear(dim * 4, dim)
def frame_id_to_position_id(self, frame_id, max_id, repeat_length):
if frame_id < max_id:
position_id = frame_id
else:
position_id = (frame_id - max_id) % (repeat_length * 2)
if position_id < repeat_length:
position_id = max_id - 2 - position_id
else:
position_id = max_id - 2 * repeat_length + position_id
return position_id
def positional_ids(self, num_frames):
max_id = self.pe1.shape[1]
positional_ids = torch.IntTensor([self.frame_id_to_position_id(i, max_id, max_id - 1) for i in range(num_frames)])
return positional_ids
def forward(self, hidden_states, batch_size=1):
# 1. Self-Attention
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
attn_output = self.attn1(norm_hidden_states + self.pe1[:, :norm_hidden_states.shape[1]])
norm_hidden_states = norm_hidden_states + self.pe1[:, self.positional_ids(norm_hidden_states.shape[1])]
if self.add_positional_conv:
norm_hidden_states = rearrange(norm_hidden_states, "(b h) f c -> (b h) c f", b=batch_size)
norm_hidden_states = self.positional_conv_1(norm_hidden_states)
norm_hidden_states = rearrange(norm_hidden_states, "(b h) c f -> (b h) f c", b=batch_size)
attn_output = self.attn1(norm_hidden_states)
attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
hidden_states = attn_output + hidden_states
# 2. Cross-Attention
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
attn_output = self.attn2(norm_hidden_states + self.pe2[:, :norm_hidden_states.shape[1]])
norm_hidden_states = norm_hidden_states + self.pe2[:, self.positional_ids(norm_hidden_states.shape[1])]
if self.add_positional_conv:
norm_hidden_states = rearrange(norm_hidden_states, "(b h) f c -> (b h) c f", b=batch_size)
norm_hidden_states = self.positional_conv_2(norm_hidden_states)
norm_hidden_states = rearrange(norm_hidden_states, "(b h) c f -> (b h) f c", b=batch_size)
attn_output = self.attn2(norm_hidden_states)
attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
hidden_states = attn_output + hidden_states
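Note: `frame_id_to_position_id` / `positional_ids` map an arbitrary frame index onto the fixed-size embedding table by reflecting back and forth once the table is exhausted, so videos longer than `max_position_embeddings` reuse embeddings in a ping-pong pattern. A self-contained sketch of the mapping:

```python
# Reflective ("ping-pong") position-id mapping, reproduced from the block above.
def frame_id_to_position_id(frame_id, max_id, repeat_length):
    if frame_id < max_id:
        return frame_id
    position_id = (frame_id - max_id) % (repeat_length * 2)
    if position_id < repeat_length:
        return max_id - 2 - position_id               # walk back towards 0
    return max_id - 2 * repeat_length + position_id   # walk forward again

max_id = 8  # small table for illustration; the block uses max_position_embeddings
print([frame_id_to_position_id(i, max_id, max_id - 1) for i in range(22)])
# [0, 1, 2, 3, 4, 5, 6, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7]
```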
@@ -51,7 +87,7 @@ class TemporalTransformerBlock(torch.nn.Module):
class TemporalBlock(torch.nn.Module):
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5, add_positional_conv=None):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
@@ -62,7 +98,9 @@ class TemporalBlock(torch.nn.Module):
TemporalTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim
attention_head_dim,
max_position_embeddings=32 if add_positional_conv is None else add_positional_conv,
add_positional_conv=add_positional_conv
)
for d in range(num_layers)
])
@@ -92,30 +130,30 @@ class TemporalBlock(torch.nn.Module):
class SDMotionModel(torch.nn.Module):
def __init__(self):
def __init__(self, add_positional_conv=None):
super().__init__()
self.motion_modules = torch.nn.ModuleList([
TemporalBlock(8, 40, 320, eps=1e-6),
TemporalBlock(8, 40, 320, eps=1e-6),
TemporalBlock(8, 80, 640, eps=1e-6),
TemporalBlock(8, 80, 640, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 160, 1280, eps=1e-6),
TemporalBlock(8, 80, 640, eps=1e-6),
TemporalBlock(8, 80, 640, eps=1e-6),
TemporalBlock(8, 80, 640, eps=1e-6),
TemporalBlock(8, 40, 320, eps=1e-6),
TemporalBlock(8, 40, 320, eps=1e-6),
TemporalBlock(8, 40, 320, eps=1e-6),
TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
])
self.call_block_id = {
1: 0,
@@ -152,7 +190,42 @@ class SDMotionModelStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
def frame_id_to_position_id(self, frame_id, max_id, repeat_length):
if frame_id < max_id:
position_id = frame_id
else:
position_id = (frame_id - max_id) % (repeat_length * 2)
if position_id < repeat_length:
position_id = max_id - 2 - position_id
else:
position_id = max_id - 2 * repeat_length + position_id
return position_id
def process_positional_conv_parameters(self, state_dict, add_positional_conv):
ids = [self.frame_id_to_position_id(i, 16, 15) for i in range(add_positional_conv)]
for i in range(21):
# Extend positional embedding
name = f"motion_modules.{i}.transformer_blocks.0.pe1"
state_dict[name] = state_dict[name][:, ids]
name = f"motion_modules.{i}.transformer_blocks.0.pe2"
state_dict[name] = state_dict[name][:, ids]
# Add positional convolution parameters (identity center kernel, zero bias, so the conv is a no-op at load time)
dim = state_dict[f"motion_modules.{i}.transformer_blocks.0.pe1"].shape[-1]
name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_1.bias"
state_dict[name] = torch.zeros((dim,))
name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_2.bias"
state_dict[name] = torch.zeros((dim,))
name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_1.weight"
param = torch.zeros((dim, dim, 3))
param[:, :, 1] = torch.eye(dim, dim)
state_dict[name] = param
name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_2.weight"
param = torch.zeros((dim, dim, 3))
param[:, :, 1] = torch.eye(dim, dim)
state_dict[name] = param
return state_dict
def from_diffusers(self, state_dict, add_positional_conv=None):
rename_dict = {
"norm": "norm",
"proj_in": "proj_in",
@@ -192,7 +265,9 @@ class SDMotionModelStateDictConverter:
else:
rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
state_dict_[rename] = state_dict[name]
if add_positional_conv is not None:
state_dict_ = self.process_positional_conv_parameters(state_dict_, add_positional_conv)
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)
def from_civitai(self, state_dict, add_positional_conv=None):
return self.from_diffusers(state_dict, add_positional_conv=add_positional_conv)
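Note: `process_positional_conv_parameters` extends `pe1`/`pe2` to `add_positional_conv` positions via the same ping-pong indexing and injects weights for the new positional convolutions: zeros everywhere except an identity matrix in the center tap, plus zero bias. At load time the convolution therefore passes its input through unchanged, and the extended model reproduces the original AnimateDiff behaviour before any fine-tuning. A quick standalone sanity check of that initialization (not part of the commit):

```python
import torch

dim = 320
conv = torch.nn.Conv1d(dim, dim, kernel_size=3, padding=1, padding_mode="reflect")
with torch.no_grad():
    conv.bias.zero_()
    conv.weight.zero_()
    conv.weight[:, :, 1] = torch.eye(dim)  # identity at the center tap only

x = torch.randn(2, dim, 17)                # (batch, channels, frames)
assert torch.allclose(conv(x), x, atol=1e-5)
```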

View File

@@ -0,0 +1,115 @@
from .attention import Attention
from .svd_unet import get_timestep_embedding
import torch
from einops import rearrange, repeat
class ExVideoMotionBlock(torch.nn.Module):
def __init__(self, num_attention_heads, attention_head_dim, in_channels, max_position_embeddings=16, num_layers=1, add_positional_conv=None):
super().__init__()
emb = get_timestep_embedding(torch.arange(max_position_embeddings), in_channels, True, 0).reshape(max_position_embeddings, in_channels, 1, 1)
self.positional_embedding = torch.nn.Parameter(emb)
self.positional_conv = torch.nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) if add_positional_conv is not None else None
self.norms = torch.nn.ModuleList([torch.nn.LayerNorm(in_channels) for _ in range(num_layers)])
self.attns = torch.nn.ModuleList([Attention(q_dim=in_channels, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True) for _ in range(num_layers)])
def frame_id_to_position_id(self, frame_id, max_id, repeat_length):
if frame_id < max_id:
position_id = frame_id
else:
position_id = (frame_id - max_id) % (repeat_length * 2)
if position_id < repeat_length:
position_id = max_id - 2 - position_id
else:
position_id = max_id - 2 * repeat_length + position_id
return position_id
def positional_ids(self, num_frames):
max_id = self.positional_embedding.shape[0]
positional_ids = torch.IntTensor([self.frame_id_to_position_id(i, max_id, max_id - 1) for i in range(num_frames)])
return positional_ids
def forward(self, hidden_states, time_emb, text_emb, res_stack, batch_size=1, **kwargs):
batch, inner_dim, height, width = hidden_states.shape
residual = hidden_states
position_ids = self.positional_ids(batch // batch_size)
pos_emb = self.positional_embedding[position_ids]
pos_emb = pos_emb.repeat(batch_size, 1, 1, 1)
hidden_states = hidden_states + pos_emb
if self.positional_conv is not None:
hidden_states = rearrange(hidden_states, "(B T) C H W -> B C T H W", B=batch_size)
hidden_states = self.positional_conv(hidden_states)
hidden_states = rearrange(hidden_states, "B C T H W -> (B H W) T C")
else:
hidden_states = rearrange(hidden_states, "(B T) C H W -> (B H W) T C", B=batch_size)
for norm, attn in zip(self.norms, self.attns):
norm_hidden_states = norm(hidden_states)
attn_output = attn(norm_hidden_states)
hidden_states = hidden_states + attn_output
hidden_states = rearrange(hidden_states, "(B H W) T C -> (B T) C H W", B=batch_size, H=height, W=width)
hidden_states = hidden_states + residual
return hidden_states, time_emb, text_emb, res_stack
class ExVideoMotionModel(torch.nn.Module):
def __init__(self, num_layers=2):
super().__init__()
self.motion_modules = torch.nn.ModuleList([
ExVideoMotionBlock(8, 40, 320, num_layers=num_layers),
ExVideoMotionBlock(8, 40, 320, num_layers=num_layers),
ExVideoMotionBlock(8, 80, 640, num_layers=num_layers),
ExVideoMotionBlock(8, 80, 640, num_layers=num_layers),
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers),
ExVideoMotionBlock(8, 80, 640, num_layers=num_layers),
ExVideoMotionBlock(8, 80, 640, num_layers=num_layers),
ExVideoMotionBlock(8, 80, 640, num_layers=num_layers),
ExVideoMotionBlock(8, 40, 320, num_layers=num_layers),
ExVideoMotionBlock(8, 40, 320, num_layers=num_layers),
ExVideoMotionBlock(8, 40, 320, num_layers=num_layers),
])
self.call_block_id = {
1: 0,
4: 1,
9: 2,
12: 3,
17: 4,
20: 5,
24: 6,
26: 7,
29: 8,
32: 9,
34: 10,
36: 11,
40: 12,
43: 13,
46: 14,
50: 15,
53: 16,
56: 17,
60: 18,
63: 19,
66: 20
}
def forward(self):
pass
def state_dict_converter(self):
pass
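Note: the new `ExVideoMotionBlock` adds a sinusoidal positional embedding (with the same ping-pong indexing), an optional 3D positional convolution, and a stack of temporal self-attention layers over the frame axis; `ExVideoMotionModel` mirrors the 21 insertion points of `SDMotionModel` via `call_block_id`, while `forward` and `state_dict_converter` are still stubs in this commit. A simplified standalone sketch of the tensor reshaping the block performs, with `torch.nn.MultiheadAttention` as a hypothetical stand-in for the repository's `Attention` class:

```python
import torch
from einops import rearrange

batch_size, num_frames, channels, height, width = 1, 20, 320, 8, 8
hidden_states = torch.randn(batch_size * num_frames, channels, height, width)

# Per-frame positional embedding, broadcast over the spatial dimensions.
pos_emb = torch.randn(num_frames, channels, 1, 1).repeat(batch_size, 1, 1, 1)
hidden_states = hidden_states + pos_emb

# Temporal self-attention: every spatial location attends across frames.
tokens = rearrange(hidden_states, "(B T) C H W -> (B H W) T C", B=batch_size)
attn = torch.nn.MultiheadAttention(embed_dim=channels, num_heads=8, batch_first=True)
attn_output, _ = attn(tokens, tokens, tokens)
tokens = tokens + attn_output

out = rearrange(tokens, "(B H W) T C -> (B T) C H W",
                B=batch_size, H=height, W=width)
print(out.shape)  # torch.Size([20, 320, 8, 8])
```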

View File

@@ -10,6 +10,7 @@ import torch, os, json
from tqdm import tqdm
from PIL import Image
import numpy as np
from einops import rearrange
def lets_dance_with_long_video(
@@ -150,6 +151,14 @@ class SDVideoPipeline(torch.nn.Module):
return latents
def post_process_latents(self, latents, post_normalize=True, contrast_enhance_scale=1.0):
if post_normalize:
mean, std = latents.mean(), latents.std()
latents = (latents - latents.mean(dim=[1, 2, 3], keepdim=True)) / latents.std(dim=[1, 2, 3], keepdim=True) * std + mean
latents = latents * contrast_enhance_scale
return latents
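Note: `post_process_latents` optionally re-normalizes each frame's latent to the mean and standard deviation of the whole clip, which counteracts per-frame brightness/contrast drift in long sequences, before applying a global contrast gain. A small numeric sketch of the normalization (self-contained, not the pipeline code):

```python
import torch

# Simulate latents whose per-frame scale drifts over 16 frames.
latents = torch.randn(16, 4, 64, 64) * torch.linspace(0.5, 2.0, 16).view(-1, 1, 1, 1)
mean, std = latents.mean(), latents.std()
normalized = (latents - latents.mean(dim=[1, 2, 3], keepdim=True)) \
    / latents.std(dim=[1, 2, 3], keepdim=True) * std + mean

# Per-frame std now matches the clip-wide value instead of drifting frame to frame.
print(latents.std(dim=[1, 2, 3])[:4], normalized.std(dim=[1, 2, 3])[:4])
```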
@torch.no_grad()
def __call__(
self,
@@ -172,6 +181,8 @@ class SDVideoPipeline(torch.nn.Module):
smoother=None,
smoother_progress_ids=[],
vram_limit_level=0,
post_normalize=False,
contrast_enhance_scale=1.0,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
@@ -226,15 +237,18 @@ class SDVideoPipeline(torch.nn.Module):
cross_frame_attention=cross_frame_attention,
device=self.device, vram_limit_level=vram_limit_level
)
noise_pred_nega = lets_dance_with_long_video(
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames,
animatediff_batch_size=animatediff_batch_size, animatediff_stride=animatediff_stride,
unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
cross_frame_attention=cross_frame_attention,
device=self.device, vram_limit_level=vram_limit_level
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
if cfg_scale != 1.0:
noise_pred_nega = lets_dance_with_long_video(
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames,
animatediff_batch_size=animatediff_batch_size, animatediff_stride=animatediff_stride,
unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
cross_frame_attention=cross_frame_attention,
device=self.device, vram_limit_level=vram_limit_level
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# DDIM and smoother
if smoother is not None and progress_id in smoother_progress_ids:
@@ -250,6 +264,7 @@ class SDVideoPipeline(torch.nn.Module):
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
latents = self.post_process_latents(latents, post_normalize=post_normalize, contrast_enhance_scale=contrast_enhance_scale)
output_frames = self.decode_images(latents)
# Post-process
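Note: the guidance step now skips the negative-prompt UNet pass entirely when cfg_scale == 1.0, since noise_pred_nega + 1.0 * (noise_pred_posi - noise_pred_nega) reduces to noise_pred_posi, and the two new keyword arguments are threaded through to post_process_latents before decoding. A hedged usage sketch; prompt, num_frames, and SDVideoPipeline.from_model_manager are assumed from the rest of the repository and are not part of this diff:

```python
from diffsynth import ModelManager, SDVideoPipeline

model_manager = ModelManager()                       # models assumed loaded elsewhere
pipe = SDVideoPipeline.from_model_manager(model_manager)

video = pipe(
    prompt="a boat sailing on a stormy sea",          # assumed argument name
    num_frames=128,                                   # assumed argument name
    cfg_scale=1.0,                 # skips the negative UNet pass entirely
    post_normalize=True,           # re-normalize per-frame latent statistics
    contrast_enhance_scale=1.05,   # mild global contrast boost before decoding
)
```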

View File

@@ -8,9 +8,9 @@ class SDPrompter(Prompter):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True):
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True, max_length=99999999):
prompt = self.process_prompt(prompt, positive=positive)
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
input_ids = tokenize_long_prompt(self.tokenizer, prompt, max_length=max_length).to(device)
prompt_emb = text_encoder(input_ids, clip_skip=clip_skip)
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))

View File

@@ -3,12 +3,12 @@ from ..models import ModelManager
import os
def tokenize_long_prompt(tokenizer, prompt):
def tokenize_long_prompt(tokenizer, prompt, max_length=99999999):
# Get model_max_length from self.tokenizer
length = tokenizer.model_max_length
# To avoid the truncation warning, raise the tokenizer's model_max_length.
tokenizer.model_max_length = 99999999
tokenizer.model_max_length = max_length
# Tokenize it!
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
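Note: tokenize_long_prompt raises model_max_length so that prompts longer than CLIP's 77-token window tokenize without a truncation warning; the new max_length parameter lets callers choose that cap (the default keeps the old effectively unbounded behaviour), and encode_prompt forwards it before reshaping the long sequence for the text encoder. A standalone sketch of the idea, assuming the Hugging Face CLIP tokenizer used by SDPrompter:

```python
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
long_prompt = "a detailed prompt, " * 60   # far beyond CLIP's 77-token window

# Raising model_max_length avoids the "sequence length is longer than ..." warning;
# the new max_length argument lets callers pick this cap instead of the hard-coded
# 99999999.
tokenizer.model_max_length = 99999999
input_ids = tokenizer(long_prompt, return_tensors="pt").input_ids
print(input_ids.shape)   # roughly (1, ~300) for this prompt
```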