# Mirror of https://github.com/modelscope/DiffSynth-Studio.git
# Synced 2026-03-19 06:48:12 +00:00, 116 lines, 4.9 KiB, Python
from .attention import Attention
|
|
from .svd_unet import get_timestep_embedding
|
|
import torch
|
|
from einops import rearrange, repeat
|
|
|
|
|
|
|
|
class ExVideoMotionBlock(torch.nn.Module):
    """Motion block that mixes information across the frames of a video.

    The block adds a learnable per-frame positional embedding to the incoming
    feature maps, optionally applies a 3D convolution, runs ``num_layers``
    pairs of (``LayerNorm``, ``Attention``) on tensors laid out as
    ``(B*H*W, T, C)`` and finally adds the block input back as a residual.

    Args:
        num_attention_heads: ``num_heads`` forwarded to every ``Attention``.
        attention_head_dim: ``head_dim`` forwarded to every ``Attention``.
        in_channels: Channel count ``C`` of the feature maps; also the width
            of the positional embedding and of each ``LayerNorm``.
        max_position_embeddings: Number of distinct learnable positions.
        num_layers: Number of (``LayerNorm``, ``Attention``) pairs.
        add_positional_conv: When not ``None``, a ``Conv3d`` (kernel 3,
            padding 1) is applied right after the positional embedding.
    """

    def __init__(self, num_attention_heads, attention_head_dim, in_channels, max_position_embeddings=16, num_layers=1, add_positional_conv=None):
        super().__init__()

        # Initial values come from the project's ``get_timestep_embedding``;
        # the result is stored as a trainable parameter of shape
        # (max_position_embeddings, in_channels, 1, 1) so that it broadcasts
        # over the spatial axes of a (frames, C, H, W) tensor.
        emb = get_timestep_embedding(torch.arange(max_position_embeddings), in_channels, True, 0).reshape(max_position_embeddings, in_channels, 1, 1)
        self.positional_embedding = torch.nn.Parameter(emb)
        self.positional_conv = torch.nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) if add_positional_conv is not None else None
        self.norms = torch.nn.ModuleList([torch.nn.LayerNorm(in_channels) for _ in range(num_layers)])
        self.attns = torch.nn.ModuleList([Attention(q_dim=in_channels, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True) for _ in range(num_layers)])

    def frame_id_to_position_id(self, frame_id, max_id, repeat_length):
        """Map a frame index to a position index inside ``[0, max_id)``.

        Frames below ``max_id`` map to themselves. Later frames walk back and
        forth over the available positions (descending from ``max_id - 2``,
        then ascending again) with a period of ``2 * repeat_length`` frames.

        Args:
            frame_id: Zero-based frame index.
            max_id: Number of available positions.
            repeat_length: Length of one descending/ascending run.

        Returns:
            The integer position index for ``frame_id``.
        """
        if frame_id < max_id:
            return frame_id
        if repeat_length <= 0:
            # Degenerate case (e.g. a single embedding): there is nothing to
            # walk over, and the modulo below would divide by zero.
            return max_id - 1
        offset = (frame_id - max_id) % (repeat_length * 2)
        if offset < repeat_length:
            # Descending run: max_id - 2, max_id - 3, ...
            return max_id - 2 - offset
        # Ascending run back towards max_id - 1.
        return max_id - 2 * repeat_length + offset

    def positional_ids(self, num_frames):
        """Return the position index of every frame as a 1-D int64 tensor.

        Args:
            num_frames: Number of frames in one video.

        Returns:
            Tensor of shape ``(num_frames,)`` with indices into
            ``self.positional_embedding``.
        """
        max_id = self.positional_embedding.shape[0]
        # int64 is the canonical dtype for index tensors in PyTorch.
        return torch.tensor(
            [self.frame_id_to_position_id(i, max_id, max_id - 1) for i in range(num_frames)],
            dtype=torch.long,
        )

    def _tiled_positional_embedding(self, num_frames, batch_size):
        """Return the positional embedding for ``batch_size`` videos.

        The result has shape ``(batch_size * num_frames, C, 1, 1)`` with the
        per-frame embeddings repeated once per video, which matches the
        ``(B T)`` ordering of the leading axis used in ``forward``.
        """
        pos_emb = self.positional_embedding[self.positional_ids(num_frames)]
        # ``Tensor.repeat`` needs one factor per dimension; only the leading
        # (frame) axis is tiled, the remaining axes are kept as they are.
        return pos_emb.repeat(batch_size, 1, 1, 1)

    def forward(self, hidden_states, time_emb, text_emb, res_stack, batch_size=1, **kwargs):
        """Apply the motion block to a stack of frame feature maps.

        Args:
            hidden_states: 4-D tensor ``(B*T, C, H, W)`` whose leading axis
                holds ``batch_size`` videos of ``T`` frames each, video-major.
            time_emb: Returned unchanged.
            text_emb: Returned unchanged.
            res_stack: Returned unchanged.
            batch_size: Number of videos ``B`` packed into the leading axis.
            **kwargs: Ignored.

        Returns:
            Tuple ``(hidden_states, time_emb, text_emb, res_stack)`` where
            ``hidden_states`` has the same shape as the input.
        """
        batch, _, height, width = hidden_states.shape
        residual = hidden_states

        hidden_states = hidden_states + self._tiled_positional_embedding(batch // batch_size, batch_size)
        if self.positional_conv is not None:
            # Conv3d expects (B, C, T, H, W); afterwards every spatial
            # location becomes its own sequence over the frame axis.
            hidden_states = rearrange(hidden_states, "(B T) C H W -> B C T H W", B=batch_size)
            hidden_states = self.positional_conv(hidden_states)
            hidden_states = rearrange(hidden_states, "B C T H W -> (B H W) T C")
        else:
            hidden_states = rearrange(hidden_states, "(B T) C H W -> (B H W) T C", B=batch_size)

        # Pre-norm residual attention layers.
        for norm, attn in zip(self.norms, self.attns):
            norm_hidden_states = norm(hidden_states)
            attn_output = attn(norm_hidden_states)
            hidden_states = hidden_states + attn_output

        hidden_states = rearrange(hidden_states, "(B H W) T C -> (B T) C H W", B=batch_size, H=height, W=width)
        hidden_states = hidden_states + residual
        return hidden_states, time_emb, text_emb, res_stack
|
|
|
|
|
|
|
|
class ExVideoMotionModel(torch.nn.Module):
    """Container for 21 ``ExVideoMotionBlock`` modules and their lookup table.

    ``motion_modules`` holds the blocks in a fixed order and
    ``call_block_id`` maps integer keys to indices into that list
    (presumably ids of the host UNet's blocks; verify against the caller).

    Args:
        num_layers: ``num_layers`` forwarded to every ``ExVideoMotionBlock``.
    """

    def __init__(self, num_layers=2):
        super().__init__()

        # Consecutive runs of identical blocks:
        # (attention_head_dim, in_channels, how many in a row).
        # Every block uses 8 attention heads.
        stage_layout = (
            (40, 320, 2),
            (80, 640, 2),
            (160, 1280, 11),
            (80, 640, 3),
            (40, 320, 3),
        )
        blocks = []
        for head_dim, channels, count in stage_layout:
            for _ in range(count):
                blocks.append(ExVideoMotionBlock(8, head_dim, channels, num_layers=num_layers))
        self.motion_modules = torch.nn.ModuleList(blocks)

        # Keys in the order of the motion module they select (0, 1, 2, ...).
        trigger_ids = (
            1, 4, 9, 12, 17, 20, 24,
            26, 29, 32, 34, 36, 40, 43,
            46, 50, 53, 56, 60, 63, 66,
        )
        self.call_block_id = dict(zip(trigger_ids, range(len(trigger_ids))))

    def forward(self):
        """Placeholder; does nothing and returns ``None``."""
        return None

    def state_dict_converter(self):
        """Placeholder; does nothing and returns ``None``."""
        return None
|