ExVideo for AnimateDiff
@@ -1,20 +1,28 @@
 from .sd_unet import SDUNet, Attention, GEGLU
+from .svd_unet import get_timestep_embedding
 import torch
 from einops import rearrange, repeat


 class TemporalTransformerBlock(torch.nn.Module):

-    def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32):
+    def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32, add_positional_conv=None):
         super().__init__()
+        self.add_positional_conv = add_positional_conv

         # 1. Self-Attn
-        self.pe1 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
+        emb = get_timestep_embedding(torch.arange(max_position_embeddings), dim, True, 0).reshape(1, max_position_embeddings, dim)
+        self.pe1 = torch.nn.Parameter(emb)
+        if add_positional_conv:
+            self.positional_conv_1 = torch.nn.Conv1d(dim, dim, kernel_size=3, padding=1, padding_mode="reflect")
         self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True)
         self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)

         # 2. Cross-Attn
-        self.pe2 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
+        emb = get_timestep_embedding(torch.arange(max_position_embeddings), dim, True, 0).reshape(1, max_position_embeddings, dim)
+        self.pe2 = torch.nn.Parameter(emb)
+        if add_positional_conv:
+            self.positional_conv_2 = torch.nn.Conv1d(dim, dim, kernel_size=3, padding=1, padding_mode="reflect")
         self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True)
         self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)

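Note: the hunk above swaps the zero-initialized positional embeddings pe1/pe2 for sinusoidal tables produced by get_timestep_embedding. A minimal standalone sketch of what this initialization yields, assuming the svd_unet helper follows the usual diffusers-style signature (timesteps, embedding_dim, flip_sin_to_cos, downscale_freq_shift); treat it as an illustration, not the repo's exact implementation:

    import torch

    def sinusoidal_embedding(positions, dim, flip_sin_to_cos=True, freq_shift=0):
        # Half the channels carry sin, half cos, over log-spaced frequencies.
        half_dim = dim // 2
        exponent = -torch.log(torch.tensor(10000.0)) * torch.arange(half_dim).float()
        emb = positions.float()[:, None] * torch.exp(exponent / (half_dim - freq_shift))[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
        if flip_sin_to_cos:
            emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
        return emb

    # A (1, max_position_embeddings, dim) table, as built in the hunk above.
    pe = sinusoidal_embedding(torch.arange(32), 320).reshape(1, 32, 320)

Starting from a smooth sinusoidal table instead of zeros gives every frame a distinct embedding before any finetuning, and it makes the position re-indexing introduced below well defined.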
@@ -24,19 +32,47 @@ class TemporalTransformerBlock(torch.nn.Module):
         self.ff = torch.nn.Linear(dim * 4, dim)

+    def frame_id_to_position_id(self, frame_id, max_id, repeat_length):
+        if frame_id < max_id:
+            position_id = frame_id
+        else:
+            position_id = (frame_id - max_id) % (repeat_length * 2)
+            if position_id < repeat_length:
+                position_id = max_id - 2 - position_id
+            else:
+                position_id = max_id - 2 * repeat_length + position_id
+        return position_id
+
+    def positional_ids(self, num_frames):
+        max_id = self.pe1.shape[1]
+        positional_ids = torch.IntTensor([self.frame_id_to_position_id(i, max_id, max_id - 1) for i in range(num_frames)])
+        return positional_ids
+
     def forward(self, hidden_states, batch_size=1):

         # 1. Self-Attention
         norm_hidden_states = self.norm1(hidden_states)
         norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
-        attn_output = self.attn1(norm_hidden_states + self.pe1[:, :norm_hidden_states.shape[1]])
+        norm_hidden_states = norm_hidden_states + self.pe1[:, self.positional_ids(norm_hidden_states.shape[1])]
+        if self.add_positional_conv:
+            norm_hidden_states = rearrange(norm_hidden_states, "(b h) f c -> (b h) c f", b=batch_size)
+            norm_hidden_states = self.positional_conv_1(norm_hidden_states)
+            norm_hidden_states = rearrange(norm_hidden_states, "(b h) c f -> (b h) f c", b=batch_size)
+        attn_output = self.attn1(norm_hidden_states)
         attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
         hidden_states = attn_output + hidden_states

         # 2. Cross-Attention
         norm_hidden_states = self.norm2(hidden_states)
         norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
-        attn_output = self.attn2(norm_hidden_states + self.pe2[:, :norm_hidden_states.shape[1]])
+        norm_hidden_states = norm_hidden_states + self.pe2[:, self.positional_ids(norm_hidden_states.shape[1])]
+        if self.add_positional_conv:
+            norm_hidden_states = rearrange(norm_hidden_states, "(b h) f c -> (b h) c f", b=batch_size)
+            norm_hidden_states = self.positional_conv_2(norm_hidden_states)
+            norm_hidden_states = rearrange(norm_hidden_states, "(b h) c f -> (b h) f c", b=batch_size)
+        attn_output = self.attn2(norm_hidden_states)
         attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
         hidden_states = attn_output + hidden_states

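The new frame_id_to_position_id/positional_ids pair is what lets the block run on more frames than the embedding table covers: once frame_id reaches max_id, the position ids reflect back and forth through the table (a ping-pong pattern) rather than running off its end. A standalone copy of the mapping, to make the pattern concrete:

    def frame_id_to_position_id(frame_id, max_id, repeat_length):
        # Identical logic to the method above, duplicated for a quick check.
        if frame_id < max_id:
            position_id = frame_id
        else:
            position_id = (frame_id - max_id) % (repeat_length * 2)
            if position_id < repeat_length:
                position_id = max_id - 2 - position_id
            else:
                position_id = max_id - 2 * repeat_length + position_id
        return position_id

    ids = [frame_id_to_position_id(i, 16, 15) for i in range(40)]
    # -> [0, 1, ..., 15, 14, 13, ..., 1, 0, 1, 2, ..., 9]

Because adjacent frames always map to adjacent position ids, temporal attention still sees locally consistent positions well beyond the trained frame range.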
@@ -51,7 +87,7 @@ class TemporalTransformerBlock(torch.nn.Module):


 class TemporalBlock(torch.nn.Module):

-    def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
+    def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5, add_positional_conv=None):
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim

@@ -62,7 +98,9 @@ class TemporalBlock(torch.nn.Module):
             TemporalTransformerBlock(
                 inner_dim,
                 num_attention_heads,
-                attention_head_dim
+                attention_head_dim,
+                max_position_embeddings=32 if add_positional_conv is None else add_positional_conv,
+                add_positional_conv=add_positional_conv
             )
             for d in range(num_layers)
         ])

@@ -92,30 +130,30 @@ class TemporalBlock(torch.nn.Module):


 class SDMotionModel(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, add_positional_conv=None):
         super().__init__()
         self.motion_modules = torch.nn.ModuleList([
-            TemporalBlock(8, 40, 320, eps=1e-6),
-            TemporalBlock(8, 40, 320, eps=1e-6),
-            TemporalBlock(8, 80, 640, eps=1e-6),
-            TemporalBlock(8, 80, 640, eps=1e-6),
-            TemporalBlock(8, 160, 1280, eps=1e-6),
-            TemporalBlock(8, 160, 1280, eps=1e-6),
-            TemporalBlock(8, 160, 1280, eps=1e-6),
-            TemporalBlock(8, 160, 1280, eps=1e-6),
-            TemporalBlock(8, 160, 1280, eps=1e-6),
-            TemporalBlock(8, 160, 1280, eps=1e-6),
-            TemporalBlock(8, 160, 1280, eps=1e-6),
-            TemporalBlock(8, 160, 1280, eps=1e-6),
-            TemporalBlock(8, 160, 1280, eps=1e-6),
-            TemporalBlock(8, 160, 1280, eps=1e-6),
-            TemporalBlock(8, 160, 1280, eps=1e-6),
-            TemporalBlock(8, 80, 640, eps=1e-6),
-            TemporalBlock(8, 80, 640, eps=1e-6),
-            TemporalBlock(8, 80, 640, eps=1e-6),
-            TemporalBlock(8, 40, 320, eps=1e-6),
-            TemporalBlock(8, 40, 320, eps=1e-6),
-            TemporalBlock(8, 40, 320, eps=1e-6),
+            TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
+            TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
+            TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
+            TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
+            TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
+            TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
+            TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
+            TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
+            TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
+            TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
+            TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
+            TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
+            TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
+            TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
+            TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
+            TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
+            TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
+            TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
+            TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
+            TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
+            TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
         ])
         self.call_block_id = {
             1: 0,

@@ -152,7 +190,42 @@ class SDMotionModelStateDictConverter:
     def __init__(self):
        pass

-    def from_diffusers(self, state_dict):
+    def frame_id_to_position_id(self, frame_id, max_id, repeat_length):
+        if frame_id < max_id:
+            position_id = frame_id
+        else:
+            position_id = (frame_id - max_id) % (repeat_length * 2)
+            if position_id < repeat_length:
+                position_id = max_id - 2 - position_id
+            else:
+                position_id = max_id - 2 * repeat_length + position_id
+        return position_id
+
+    def process_positional_conv_parameters(self, state_dict, add_positional_conv):
+        ids = [self.frame_id_to_position_id(i, 16, 15) for i in range(add_positional_conv)]
+        for i in range(21):
+            # Extend the positional embeddings to the new frame count
+            name = f"motion_modules.{i}.transformer_blocks.0.pe1"
+            state_dict[name] = state_dict[name][:, ids]
+            name = f"motion_modules.{i}.transformer_blocks.0.pe2"
+            state_dict[name] = state_dict[name][:, ids]
+            # Add positional convolutions, initialized as identity mappings
+            dim = state_dict[f"motion_modules.{i}.transformer_blocks.0.pe1"].shape[-1]
+            name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_1.bias"
+            state_dict[name] = torch.zeros((dim,))
+            name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_2.bias"
+            state_dict[name] = torch.zeros((dim,))
+            name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_1.weight"
+            param = torch.zeros((dim, dim, 3))
+            param[:, :, 1] = torch.eye(dim, dim)
+            state_dict[name] = param
+            name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_2.weight"
+            param = torch.zeros((dim, dim, 3))
+            param[:, :, 1] = torch.eye(dim, dim)
+            state_dict[name] = param
+        return state_dict
+
+    def from_diffusers(self, state_dict, add_positional_conv=None):
         rename_dict = {
             "norm": "norm",
             "proj_in": "proj_in",

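process_positional_conv_parameters adapts a pretrained motion-module checkpoint to the extended model: it re-indexes pe1/pe2 with the same ping-pong ids (max_id=16 and repeat_length=15 are hard-coded for the 16-frame AnimateDiff embeddings) and fabricates weights for the new positional convolutions that act as exact identity maps, so the extended model initially reproduces the pretrained behavior. A quick plain-PyTorch sketch verifying that this identity initialization really is a no-op:

    import torch

    dim, frames = 320, 24
    conv = torch.nn.Conv1d(dim, dim, kernel_size=3, padding=1, padding_mode="reflect")
    with torch.no_grad():
        conv.bias.zero_()
        conv.weight.zero_()
        conv.weight[:, :, 1] = torch.eye(dim)  # center tap = identity, side taps = 0

    x = torch.randn(2, dim, frames)    # (batch * heads, channels, frames)
    assert torch.allclose(conv(x), x)  # the conv passes features through unchanged

Only finetuning moves these layers away from the identity, which is the usual way to add capacity without disturbing a pretrained model at initialization.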
@@ -192,7 +265,9 @@ class SDMotionModelStateDictConverter:
             else:
                 rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
             state_dict_[rename] = state_dict[name]
+        if add_positional_conv is not None:
+            state_dict_ = self.process_positional_conv_parameters(state_dict_, add_positional_conv)
         return state_dict_

-    def from_civitai(self, state_dict):
-        return self.from_diffusers(state_dict)
+    def from_civitai(self, state_dict, add_positional_conv=None):
+        return self.from_diffusers(state_dict, add_positional_conv=add_positional_conv)
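Taken together, add_positional_conv plays a double role: it is the extended frame count (sizing the positional-embedding tables) and the switch that enables the positional convolutions. A hedged end-to-end sketch of loading a pretrained checkpoint into the extended model; the import path, the checkpoint filename, and the value 128 are illustrative assumptions, not taken from this commit:

    import torch
    from diffsynth.models.sd_motion import SDMotionModel, SDMotionModelStateDictConverter  # assumed path

    num_frames = 128  # hypothetical extended frame count
    model = SDMotionModel(add_positional_conv=num_frames)

    converter = SDMotionModelStateDictConverter()
    raw = torch.load("animatediff_motion_module.ckpt", map_location="cpu")  # placeholder checkpoint
    state_dict = converter.from_civitai(raw, add_positional_conv=num_frames)
    model.load_state_dict(state_dict)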