from .sd_unet import SDUNet, Attention, GEGLU from .svd_unet import get_timestep_embedding import torch from einops import rearrange, repeat class TemporalTransformerBlock(torch.nn.Module): def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32, add_positional_conv=None): super().__init__() self.add_positional_conv = add_positional_conv # 1. Self-Attn emb = get_timestep_embedding(torch.arange(max_position_embeddings), dim, True, 0).reshape(1, max_position_embeddings, dim) self.pe1 = torch.nn.Parameter(emb) if add_positional_conv: self.positional_conv_1 = torch.nn.Conv1d(dim, dim, kernel_size=3, padding=1, padding_mode="reflect") self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True) self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True) # 2. Cross-Attn emb = get_timestep_embedding(torch.arange(max_position_embeddings), dim, True, 0).reshape(1, max_position_embeddings, dim) self.pe2 = torch.nn.Parameter(emb) if add_positional_conv: self.positional_conv_2 = torch.nn.Conv1d(dim, dim, kernel_size=3, padding=1, padding_mode="reflect") self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True) self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True) # 3. Feed-forward self.norm3 = torch.nn.LayerNorm(dim, elementwise_affine=True) self.act_fn = GEGLU(dim, dim * 4) self.ff = torch.nn.Linear(dim * 4, dim) def frame_id_to_position_id(self, frame_id, max_id, repeat_length): if frame_id < max_id: position_id = frame_id else: position_id = (frame_id - max_id) % (repeat_length * 2) if position_id < repeat_length: position_id = max_id - 2 - position_id else: position_id = max_id - 2 * repeat_length + position_id return position_id def positional_ids(self, num_frames): max_id = self.pe1.shape[1] positional_ids = torch.IntTensor([self.frame_id_to_position_id(i, max_id, max_id - 1) for i in range(num_frames)]) return positional_ids def forward(self, hidden_states, batch_size=1): # 1. Self-Attention norm_hidden_states = self.norm1(hidden_states) norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size) norm_hidden_states = norm_hidden_states + self.pe1[:, self.positional_ids(norm_hidden_states.shape[1])] if self.add_positional_conv: norm_hidden_states = rearrange(norm_hidden_states, "(b h) f c -> (b h) c f", b=batch_size) norm_hidden_states = self.positional_conv_1(norm_hidden_states) norm_hidden_states = rearrange(norm_hidden_states, "(b h) c f -> (b h) f c", b=batch_size) attn_output = self.attn1(norm_hidden_states) attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size) hidden_states = attn_output + hidden_states # 2. Cross-Attention norm_hidden_states = self.norm2(hidden_states) norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size) norm_hidden_states = norm_hidden_states + self.pe2[:, self.positional_ids(norm_hidden_states.shape[1])] if self.add_positional_conv: norm_hidden_states = rearrange(norm_hidden_states, "(b h) f c -> (b h) c f", b=batch_size) norm_hidden_states = self.positional_conv_2(norm_hidden_states) norm_hidden_states = rearrange(norm_hidden_states, "(b h) c f -> (b h) f c", b=batch_size) attn_output = self.attn2(norm_hidden_states) attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size) hidden_states = attn_output + hidden_states # 3. Feed-forward norm_hidden_states = self.norm3(hidden_states) ff_output = self.act_fn(norm_hidden_states) ff_output = self.ff(ff_output) hidden_states = ff_output + hidden_states return hidden_states class TemporalBlock(torch.nn.Module): def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5, add_positional_conv=None): super().__init__() inner_dim = num_attention_heads * attention_head_dim self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True) self.proj_in = torch.nn.Linear(in_channels, inner_dim) self.transformer_blocks = torch.nn.ModuleList([ TemporalTransformerBlock( inner_dim, num_attention_heads, attention_head_dim, max_position_embeddings=32 if add_positional_conv is None else add_positional_conv, add_positional_conv=add_positional_conv ) for d in range(num_layers) ]) self.proj_out = torch.nn.Linear(inner_dim, in_channels) def forward(self, hidden_states, time_emb, text_emb, res_stack, batch_size=1): batch, _, height, width = hidden_states.shape residual = hidden_states hidden_states = self.norm(hidden_states) inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) hidden_states = self.proj_in(hidden_states) for block in self.transformer_blocks: hidden_states = block( hidden_states, batch_size=batch_size ) hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() hidden_states = hidden_states + residual return hidden_states, time_emb, text_emb, res_stack class SDMotionModel(torch.nn.Module): def __init__(self, add_positional_conv=None): super().__init__() self.motion_modules = torch.nn.ModuleList([ TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv), TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv), TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv), TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv), TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv), TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv), TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv), TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv), TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv), TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv), TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv), TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv), TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv), TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv), TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv), TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv), TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv), TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv), TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv), TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv), TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv), ]) self.call_block_id = { 1: 0, 4: 1, 9: 2, 12: 3, 17: 4, 20: 5, 24: 6, 26: 7, 29: 8, 32: 9, 34: 10, 36: 11, 40: 12, 43: 13, 46: 14, 50: 15, 53: 16, 56: 17, 60: 18, 63: 19, 66: 20 } def forward(self): pass def state_dict_converter(self): return SDMotionModelStateDictConverter() class SDMotionModelStateDictConverter: def __init__(self): pass def frame_id_to_position_id(self, frame_id, max_id, repeat_length): if frame_id < max_id: position_id = frame_id else: position_id = (frame_id - max_id) % (repeat_length * 2) if position_id < repeat_length: position_id = max_id - 2 - position_id else: position_id = max_id - 2 * repeat_length + position_id return position_id def process_positional_conv_parameters(self, state_dict, add_positional_conv): ids = [self.frame_id_to_position_id(i, 16, 15) for i in range(add_positional_conv)] for i in range(21): # Extend positional embedding name = f"motion_modules.{i}.transformer_blocks.0.pe1" state_dict[name] = state_dict[name][:, ids] name = f"motion_modules.{i}.transformer_blocks.0.pe2" state_dict[name] = state_dict[name][:, ids] # add post convolution dim = state_dict[f"motion_modules.{i}.transformer_blocks.0.pe1"].shape[-1] name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_1.bias" state_dict[name] = torch.zeros((dim,)) name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_2.bias" state_dict[name] = torch.zeros((dim,)) name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_1.weight" param = torch.zeros((dim, dim, 3)) param[:, :, 1] = torch.eye(dim, dim) state_dict[name] = param name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_2.weight" param = torch.zeros((dim, dim, 3)) param[:, :, 1] = torch.eye(dim, dim) state_dict[name] = param return state_dict def from_diffusers(self, state_dict, add_positional_conv=None): rename_dict = { "norm": "norm", "proj_in": "proj_in", "transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q", "transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k", "transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v", "transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out", "transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1", "transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q", "transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k", "transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v", "transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out", "transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2", "transformer_blocks.0.norms.0": "transformer_blocks.0.norm1", "transformer_blocks.0.norms.1": "transformer_blocks.0.norm2", "transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj", "transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff", "transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3", "proj_out": "proj_out", } name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")]) name_list += sorted([i for i in state_dict if i.startswith("mid_block.")]) name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")]) state_dict_ = {} last_prefix, module_id = "", -1 for name in name_list: names = name.split(".") prefix_index = names.index("temporal_transformer") + 1 prefix = ".".join(names[:prefix_index]) if prefix != last_prefix: last_prefix = prefix module_id += 1 middle_name = ".".join(names[prefix_index:-1]) suffix = names[-1] if "pos_encoder" in names: rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]]) else: rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix]) state_dict_[rename] = state_dict[name] if add_positional_conv is not None: state_dict_ = self.process_positional_conv_parameters(state_dict_, add_positional_conv) return state_dict_ def from_civitai(self, state_dict, add_positional_conv=None): return self.from_diffusers(state_dict, add_positional_conv=add_positional_conv)