Files
DiffSynth-Studio/diffsynth/models/sd_motion.py
2024-07-26 14:35:18 +08:00

274 lines
13 KiB
Python

from .sd_unet import SDUNet, Attention, GEGLU
from .svd_unet import get_timestep_embedding
import torch
from einops import rearrange, repeat
class TemporalTransformerBlock(torch.nn.Module):
    """Temporal transformer layer that mixes information across frames.

    ``forward`` takes ``hidden_states`` laid out as ``(b * f, h, c)`` (batch times frames,
    a per-frame token axis, channels) and returns a tensor of the same shape. It applies
    two pre-norm temporal attention sub-layers (each adding a learned positional table and,
    optionally, a 3-tap positional Conv1d over frames) and a GEGLU feed-forward sub-layer,
    each with a residual connection.
    """

    def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32, add_positional_conv=None):
        super().__init__()
        self.add_positional_conv = add_positional_conv
        # 1. First temporal attention: position table initialised from get_timestep_embedding.
        emb = get_timestep_embedding(torch.arange(max_position_embeddings), dim, True, 0).reshape(1, max_position_embeddings, dim)
        self.pe1 = torch.nn.Parameter(emb)
        if add_positional_conv:
            self.positional_conv_1 = torch.nn.Conv1d(dim, dim, kernel_size=3, padding=1, padding_mode="reflect")
        self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True)
        self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
        # 2. Second temporal attention. It is called with the frame sequence only (no encoder
        #    states), so despite the historical "cross-attn" naming it attends over frames too.
        emb = get_timestep_embedding(torch.arange(max_position_embeddings), dim, True, 0).reshape(1, max_position_embeddings, dim)
        self.pe2 = torch.nn.Parameter(emb)
        if add_positional_conv:
            self.positional_conv_2 = torch.nn.Conv1d(dim, dim, kernel_size=3, padding=1, padding_mode="reflect")
        self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True)
        self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
        # 3. Feed-forward (GEGLU expansion to 4 * dim, then projection back to dim).
        self.norm3 = torch.nn.LayerNorm(dim, elementwise_affine=True)
        self.act_fn = GEGLU(dim, dim * 4)
        self.ff = torch.nn.Linear(dim * 4, dim)

    def frame_id_to_position_id(self, frame_id, max_id, repeat_length):
        """Map a frame index to a row of the positional-embedding table.

        Frames below ``max_id`` use their own index. Later frames bounce back and forth over
        the table with period ``2 * repeat_length`` instead of running past its end; e.g. with
        ``max_id=16, repeat_length=15`` frames 16, 17, ..., 30, 31 map to 14, 13, ..., 0, 1.
        """
        if frame_id < max_id:
            position_id = frame_id
        else:
            position_id = (frame_id - max_id) % (repeat_length * 2)
            if position_id < repeat_length:
                position_id = max_id - 2 - position_id
            else:
                position_id = max_id - 2 * repeat_length + position_id
        return position_id

    def positional_ids(self, num_frames):
        """Return an int32 CPU tensor holding the table row used for each of ``num_frames`` frames."""
        max_id = self.pe1.shape[1]
        positional_ids = torch.IntTensor([self.frame_id_to_position_id(i, max_id, max_id - 1) for i in range(num_frames)])
        return positional_ids

    def _attention_sublayer(self, hidden_states, norm, pe, positional_conv, attn, position_ids, batch_size):
        """Run one pre-norm temporal attention sub-layer and add the residual."""
        norm_hidden_states = norm(hidden_states)
        # Put frames on the sequence axis so attention mixes information across time.
        norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
        norm_hidden_states = norm_hidden_states + pe[:, position_ids]
        if positional_conv is not None:
            # Conv1d wants channels before the (frame) length axis.
            norm_hidden_states = rearrange(norm_hidden_states, "(b h) f c -> (b h) c f", b=batch_size)
            norm_hidden_states = positional_conv(norm_hidden_states)
            norm_hidden_states = rearrange(norm_hidden_states, "(b h) c f -> (b h) f c", b=batch_size)
        attn_output = attn(norm_hidden_states)
        attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
        return attn_output + hidden_states

    def forward(self, hidden_states, batch_size=1):
        """Apply the layer to ``hidden_states`` of shape ``(batch_size * frames, h, c)``."""
        num_frames = hidden_states.shape[0] // batch_size
        # Computed once and shared by both sub-layers (pe1 and pe2 are built with the same
        # length). int64 on the table's device: older PyTorch rejects int32 index tensors,
        # and this replaces an implicit host-to-device copy on every indexing.
        position_ids = self.positional_ids(num_frames).to(device=self.pe1.device, dtype=torch.long)
        conv_1 = self.positional_conv_1 if self.add_positional_conv else None
        conv_2 = self.positional_conv_2 if self.add_positional_conv else None
        # 1. First temporal attention
        hidden_states = self._attention_sublayer(
            hidden_states, self.norm1, self.pe1, conv_1, self.attn1, position_ids, batch_size
        )
        # 2. Second temporal attention
        hidden_states = self._attention_sublayer(
            hidden_states, self.norm2, self.pe2, conv_2, self.attn2, position_ids, batch_size
        )
        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)
        ff_output = self.ff(self.act_fn(norm_hidden_states))
        return ff_output + hidden_states
class TemporalBlock(torch.nn.Module):
    """Run a stack of temporal transformer layers over a ``(N, C, H, W)`` feature map.

    The map is group-normalised, flattened to ``(N, H * W, C)`` tokens, projected to
    ``num_attention_heads * attention_head_dim`` channels, passed through
    ``num_layers`` TemporalTransformerBlock layers, projected back to ``in_channels``
    and added to the input as a residual.
    """

    def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5, add_positional_conv=None):
        super().__init__()
        hidden_dim = num_attention_heads * attention_head_dim
        # When add_positional_conv is given it doubles as the positional-table length.
        table_length = 32 if add_positional_conv is None else add_positional_conv
        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
        self.proj_in = torch.nn.Linear(in_channels, hidden_dim)
        layers = [
            TemporalTransformerBlock(
                hidden_dim,
                num_attention_heads,
                attention_head_dim,
                max_position_embeddings=table_length,
                add_positional_conv=add_positional_conv,
            )
            for _ in range(num_layers)
        ]
        self.transformer_blocks = torch.nn.ModuleList(layers)
        self.proj_out = torch.nn.Linear(hidden_dim, in_channels)

    def forward(self, hidden_states, time_emb, text_emb, res_stack, batch_size=1):
        """Apply the temporal layers; ``time_emb``, ``text_emb`` and ``res_stack`` pass through unchanged."""
        skip = hidden_states
        num_images, channels, height, width = hidden_states.shape
        tokens = self.norm(hidden_states).permute(0, 2, 3, 1).reshape(num_images, height * width, channels)
        tokens = self.proj_in(tokens)
        for layer in self.transformer_blocks:
            tokens = layer(tokens, batch_size=batch_size)
        tokens = self.proj_out(tokens)
        feature_map = tokens.reshape(num_images, height, width, channels).permute(0, 3, 1, 2).contiguous()
        return feature_map + skip, time_emb, text_emb, res_stack
class SDMotionModel(torch.nn.Module):
    """Holder for the 21 temporal motion modules attached to an SD UNet.

    The modules are stored in ``motion_modules`` in the order the state-dict converter
    enumerates checkpoint modules (down_blocks, then mid_block, then up_blocks).
    ``call_block_id`` maps an integer key to a motion-module index; the keys presumably
    are UNet block indices after which each module is run by the caller — confirm there.
    """

    def __init__(self, add_positional_conv=None):
        super().__init__()
        # (attention_head_dim, channels) per motion module; every module uses 8 heads,
        # so 8 * attention_head_dim equals the channel count in each row.
        layout = (
            [(40, 320)] * 2
            + [(80, 640)] * 2
            + [(160, 1280)] * 11
            + [(80, 640)] * 3
            + [(40, 320)] * 3
        )
        self.motion_modules = torch.nn.ModuleList([
            TemporalBlock(8, head_dim, channels, eps=1e-6, add_positional_conv=add_positional_conv)
            for head_dim, channels in layout
        ])
        unet_block_ids = (1, 4, 9, 12, 17, 20, 24, 26, 29, 32, 34, 36, 40, 43, 46, 50, 53, 56, 60, 63, 66)
        self.call_block_id = {block_id: module_id for module_id, block_id in enumerate(unet_block_ids)}

    def forward(self):
        """Unused: the motion modules are invoked individually, never through this method."""
        return None

    def state_dict_converter(self):
        """Return a converter for loading checkpoints into this model."""
        return SDMotionModelStateDictConverter()
class SDMotionModelStateDictConverter:
    """Rename motion-module checkpoint keys to ``SDMotionModel`` parameter names."""

    def __init__(self):
        pass

    def frame_id_to_position_id(self, frame_id, max_id, repeat_length):
        """Map a frame index to a positional-embedding row.

        Frames below ``max_id`` keep their index; later frames reflect back and forth over
        the table with period ``2 * repeat_length`` (e.g. ``max_id=16, repeat_length=15``
        maps frames 16, 17, ... to 14, 13, ...).
        """
        if frame_id < max_id:
            return frame_id
        offset = (frame_id - max_id) % (repeat_length * 2)
        if offset < repeat_length:
            return max_id - 2 - offset
        return max_id - 2 * repeat_length + offset

    def process_positional_conv_parameters(self, state_dict, add_positional_conv):
        """Resize positional tables to ``add_positional_conv`` frames and add identity convs.

        For every transformer block with a ``pe1`` entry (its ``pe2`` is required too):

        * ``pe1`` / ``pe2`` are re-indexed to ``add_positional_conv`` rows with the
          reflecting frame->position map over the first 16 positions;
        * ``positional_conv_1`` / ``positional_conv_2`` weights and biases are created so the
          3-tap conv initially passes its input through unchanged (identity on the centre
          tap, zeros elsewhere, zero bias), in the positional table's dtype and device.

        Blocks are discovered from the keys present, so checkpoints with any number of
        motion modules or transformer blocks are handled (previously fixed at 21 x block 0).

        Args:
            state_dict: converted state dict; updated in place.
            add_positional_conv: number of frames the resized positional tables cover.

        Returns:
            The same ``state_dict`` object.

        Raises:
            KeyError: if a block has ``pe1`` but no ``pe2``.
        """
        # NOTE(review): 16 presumably matches the clip length the source checkpoints were
        # trained on — confirm before changing.
        ids = [self.frame_id_to_position_id(i, 16, 15) for i in range(add_positional_conv)]
        # Snapshot prefixes first: the loop below inserts new keys into state_dict.
        block_prefixes = [name[: -len(".pe1")] for name in state_dict if name.endswith(".pe1")]
        for prefix in block_prefixes:
            # Extend positional embeddings
            for table in ("pe1", "pe2"):
                name = f"{prefix}.{table}"
                state_dict[name] = state_dict[name][:, ids]
            pe = state_dict[f"{prefix}.pe1"]
            dim = pe.shape[-1]
            # Add positional convolutions initialised to the identity map
            for conv in ("positional_conv_1", "positional_conv_2"):
                weight = torch.zeros((dim, dim, 3), dtype=pe.dtype, device=pe.device)
                weight[:, :, 1] = torch.eye(dim, dim)
                state_dict[f"{prefix}.{conv}.weight"] = weight
                state_dict[f"{prefix}.{conv}.bias"] = torch.zeros((dim,), dtype=pe.dtype, device=pe.device)
        return state_dict

    def from_diffusers(self, state_dict, add_positional_conv=None):
        """Convert ``...temporal_transformer.*`` keys to ``motion_modules.<id>.*`` keys.

        Keys under ``down_blocks.``, ``mid_block.`` and ``up_blocks.`` are visited in that
        order (each group sorted as strings). A new module id is allocated every time the
        key prefix up to and including ``temporal_transformer`` changes. Keys outside those
        groups are dropped. A key without a ``temporal_transformer`` component raises
        ``ValueError``; an unknown sub-module path raises ``KeyError``.

        NOTE(review): the string sort orders "10" before "2", so module ids are only in
        numeric order while block/module indices stay single-digit.

        If ``add_positional_conv`` is not None, positional tables are resized and identity
        positional convs are added (see ``process_positional_conv_parameters``).
        """
        rename_dict = {
            "norm": "norm",
            "proj_in": "proj_in",
            "transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q",
            "transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k",
            "transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v",
            "transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out",
            "transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1",
            "transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q",
            "transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k",
            "transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v",
            "transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out",
            "transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2",
            "transformer_blocks.0.norms.0": "transformer_blocks.0.norm1",
            "transformer_blocks.0.norms.1": "transformer_blocks.0.norm2",
            "transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj",
            "transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff",
            "transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3",
            "proj_out": "proj_out",
        }
        name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")])
        name_list += sorted([i for i in state_dict if i.startswith("mid_block.")])
        name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")])
        state_dict_ = {}
        last_prefix, module_id = "", -1
        for name in name_list:
            names = name.split(".")
            prefix_index = names.index("temporal_transformer") + 1
            prefix = ".".join(names[:prefix_index])
            if prefix != last_prefix:
                last_prefix = prefix
                module_id += 1
            middle_name = ".".join(names[prefix_index:-1])
            suffix = names[-1]
            if "pos_encoder" in names:
                # The positional table maps directly onto a Parameter (pe1/pe2), so the
                # trailing tensor name is dropped.
                rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]])
            else:
                rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
            state_dict_[rename] = state_dict[name]
        if add_positional_conv is not None:
            state_dict_ = self.process_positional_conv_parameters(state_dict_, add_positional_conv)
        return state_dict_

    def from_civitai(self, state_dict, add_positional_conv=None):
        """Same key layout as ``from_diffusers``; delegates to it."""
        return self.from_diffusers(state_dict, add_positional_conv=add_positional_conv)