mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
281 lines
13 KiB
Python
281 lines
13 KiB
Python
import torch
|
|
from .wan_video_dit import DiTBlock, SelfAttention, rope_apply, flash_attention, modulate, MLP
|
|
from .utils import hash_state_dict_keys
|
|
import einops
|
|
import torch.nn as nn
|
|
|
|
|
|
class MotSelfAttention(SelfAttention):
|
|
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
|
|
super().__init__(dim, num_heads, eps)
|
|
def forward(self, x, freqs, is_before_attn=False):
|
|
if is_before_attn:
|
|
q = self.norm_q(self.q(x))
|
|
k = self.norm_k(self.k(x))
|
|
v = self.v(x)
|
|
q = rope_apply(q, freqs, self.num_heads)
|
|
k = rope_apply(k, freqs, self.num_heads)
|
|
return q, k, v
|
|
else:
|
|
return self.o(x)
|
|
|
|
|
|
class MotWanAttentionBlock(DiTBlock):
|
|
def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
|
|
super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps)
|
|
self.block_id = block_id
|
|
|
|
self.self_attn = MotSelfAttention(dim, num_heads, eps)
|
|
|
|
|
|
def forward(self, wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot):
|
|
|
|
# 1. prepare scale parameter
|
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
|
wan_block.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
|
|
|
|
scale_params_mot_ref = self.modulation + t_mod_mot.float()
|
|
scale_params_mot_ref = einops.rearrange(scale_params_mot_ref, '(b n) t c -> b n t c', n=1)
|
|
shift_msa_mot_ref, scale_msa_mot_ref, gate_msa_mot_ref, c_shift_msa_mot_ref, c_scale_msa_mot_ref, c_gate_msa_mot_ref = scale_params_mot_ref.chunk(6, dim=2)
|
|
|
|
# 2. Self-attention
|
|
input_x = modulate(wan_block.norm1(x), shift_msa, scale_msa)
|
|
# original block self-attn
|
|
attn1 = wan_block.self_attn
|
|
q = attn1.norm_q(attn1.q(input_x))
|
|
k = attn1.norm_k(attn1.k(input_x))
|
|
v = attn1.v(input_x)
|
|
q = rope_apply(q, freqs, attn1.num_heads)
|
|
k = rope_apply(k, freqs, attn1.num_heads)
|
|
|
|
# mot block self-attn
|
|
norm_x_mot = einops.rearrange(self.norm1(x_mot.float()), 'b (n t) c -> b n t c', n=1)
|
|
norm_x_mot = modulate(norm_x_mot, shift_msa_mot_ref, scale_msa_mot_ref).type_as(x_mot)
|
|
norm_x_mot = einops.rearrange(norm_x_mot, 'b n t c -> b (n t) c', n=1)
|
|
q_mot,k_mot,v_mot = self.self_attn(norm_x_mot, freqs_mot, is_before_attn=True)
|
|
|
|
tmp_hidden_states = flash_attention(
|
|
torch.cat([q, q_mot], dim=-2),
|
|
torch.cat([k, k_mot], dim=-2),
|
|
torch.cat([v, v_mot], dim=-2),
|
|
num_heads=attn1.num_heads)
|
|
|
|
attn_output, attn_output_mot = torch.split(tmp_hidden_states, [q.shape[-2], q_mot.shape[-2]], dim=-2)
|
|
|
|
attn_output = attn1.o(attn_output)
|
|
x = wan_block.gate(x, gate_msa, attn_output)
|
|
|
|
attn_output_mot = self.self_attn(x=attn_output_mot,freqs=freqs_mot, is_before_attn=False)
|
|
# gate
|
|
attn_output_mot = einops.rearrange(attn_output_mot, 'b (n t) c -> b n t c', n=1)
|
|
attn_output_mot = attn_output_mot * gate_msa_mot_ref
|
|
attn_output_mot = einops.rearrange(attn_output_mot, 'b n t c -> b (n t) c', n=1)
|
|
x_mot = (x_mot.float() + attn_output_mot).type_as(x_mot)
|
|
|
|
# 3. cross-attention and feed-forward
|
|
x = x + wan_block.cross_attn(wan_block.norm3(x), context)
|
|
input_x = modulate(wan_block.norm2(x), shift_mlp, scale_mlp)
|
|
x = wan_block.gate(x, gate_mlp, wan_block.ffn(input_x))
|
|
|
|
x_mot = x_mot + self.cross_attn(self.norm3(x_mot),context_mot)
|
|
# modulate
|
|
norm_x_mot_ref = einops.rearrange(self.norm2(x_mot.float()), 'b (n t) c -> b n t c', n=1)
|
|
norm_x_mot_ref = (norm_x_mot_ref * (1 + c_scale_msa_mot_ref) + c_shift_msa_mot_ref).type_as(x_mot)
|
|
norm_x_mot_ref = einops.rearrange(norm_x_mot_ref, 'b n t c -> b (n t) c', n=1)
|
|
input_x_mot = self.ffn(norm_x_mot_ref)
|
|
# gate
|
|
input_x_mot = einops.rearrange(input_x_mot, 'b (n t) c -> b n t c', n=1)
|
|
input_x_mot = input_x_mot.float() * c_gate_msa_mot_ref
|
|
input_x_mot = einops.rearrange(input_x_mot, 'b n t c -> b (n t) c', n=1)
|
|
x_mot = (x_mot.float() + input_x_mot).type_as(x_mot)
|
|
|
|
return x, x_mot
|
|
|
|
|
|
class MotWanModel(torch.nn.Module):
|
|
def __init__(
|
|
self,
|
|
mot_layers=(0, 4, 8, 12, 16, 20, 24, 28, 32, 36),
|
|
patch_size=(1, 2, 2),
|
|
has_image_input=True,
|
|
has_image_pos_emb=False,
|
|
dim=5120,
|
|
num_heads=40,
|
|
ffn_dim=13824,
|
|
freq_dim=256,
|
|
text_dim=4096,
|
|
in_dim=36,
|
|
eps=1e-6,
|
|
):
|
|
super().__init__()
|
|
self.mot_layers = mot_layers
|
|
self.freq_dim = freq_dim
|
|
self.dim = dim
|
|
|
|
self.mot_layers_mapping = {i: n for n, i in enumerate(self.mot_layers)}
|
|
self.head_dim = dim // num_heads
|
|
|
|
self.patch_embedding = nn.Conv3d(
|
|
in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
|
|
|
self.text_embedding = nn.Sequential(
|
|
nn.Linear(text_dim, dim),
|
|
nn.GELU(approximate='tanh'),
|
|
nn.Linear(dim, dim)
|
|
)
|
|
self.time_embedding = nn.Sequential(
|
|
nn.Linear(freq_dim, dim),
|
|
nn.SiLU(),
|
|
nn.Linear(dim, dim)
|
|
)
|
|
self.time_projection = nn.Sequential(
|
|
nn.SiLU(), nn.Linear(dim, dim * 6))
|
|
if has_image_input:
|
|
self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb)
|
|
|
|
# mot blocks
|
|
self.blocks = torch.nn.ModuleList([
|
|
MotWanAttentionBlock(has_image_input, dim, num_heads, ffn_dim, eps, block_id=i)
|
|
for i in self.mot_layers
|
|
])
|
|
|
|
|
|
def patchify(self, x: torch.Tensor):
|
|
x = self.patch_embedding(x)
|
|
return x
|
|
|
|
def compute_freqs_mot(self, f, h, w, end: int = 1024, theta: float = 10000.0):
|
|
def precompute_freqs_cis(dim: int, start: int = 0, end: int = 1024, theta: float = 10000.0):
|
|
# 1d rope precompute
|
|
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)
|
|
[: (dim // 2)].double() / dim))
|
|
freqs = torch.outer(torch.arange(start, end, device=freqs.device), freqs)
|
|
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
|
return freqs_cis
|
|
|
|
f_freqs_cis = precompute_freqs_cis(self.head_dim - 2 * (self.head_dim // 3), -f, end, theta)
|
|
h_freqs_cis = precompute_freqs_cis(self.head_dim // 3, 0, end, theta)
|
|
w_freqs_cis = precompute_freqs_cis(self.head_dim // 3, 0, end, theta)
|
|
|
|
freqs = torch.cat([
|
|
f_freqs_cis[:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
|
h_freqs_cis[:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
|
w_freqs_cis[:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
|
], dim=-1).reshape(f * h * w, 1, -1)
|
|
return freqs
|
|
|
|
def forward(self, wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot, block_id):
|
|
block = self.blocks[self.mot_layers_mapping[block_id]]
|
|
x, x_mot = block(wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot)
|
|
return x, x_mot
|
|
|
|
@staticmethod
|
|
def state_dict_converter():
|
|
return MotWanModelDictConverter()
|
|
|
|
|
|
class MotWanModelDictConverter:
|
|
def __init__(self):
|
|
pass
|
|
|
|
def from_diffusers(self, state_dict):
|
|
|
|
rename_dict = {
|
|
"blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
|
|
"blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
|
|
"blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
|
|
"blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
|
|
"blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
|
|
"blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
|
|
"blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
|
|
"blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
|
|
"blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
|
|
"blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
|
|
"blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
|
|
"blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
|
|
"blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
|
|
"blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
|
|
"blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
|
|
"blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
|
|
"blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
|
|
"blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
|
|
"blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
|
|
"blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
|
|
"blocks.0.attn2.add_k_proj.bias":"blocks.0.cross_attn.k_img.bias",
|
|
"blocks.0.attn2.add_k_proj.weight":"blocks.0.cross_attn.k_img.weight",
|
|
"blocks.0.attn2.add_v_proj.bias":"blocks.0.cross_attn.v_img.bias",
|
|
"blocks.0.attn2.add_v_proj.weight":"blocks.0.cross_attn.v_img.weight",
|
|
"blocks.0.attn2.norm_added_k.weight":"blocks.0.cross_attn.norm_k_img.weight",
|
|
"blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
|
|
"blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
|
|
"blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
|
|
"blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
|
|
"blocks.0.norm2.bias": "blocks.0.norm3.bias",
|
|
"blocks.0.norm2.weight": "blocks.0.norm3.weight",
|
|
"blocks.0.scale_shift_table": "blocks.0.modulation",
|
|
"condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
|
|
"condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
|
|
"condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
|
|
"condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
|
|
"condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
|
|
"condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
|
|
"condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
|
|
"condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
|
|
"condition_embedder.time_proj.bias": "time_projection.1.bias",
|
|
"condition_embedder.time_proj.weight": "time_projection.1.weight",
|
|
"condition_embedder.image_embedder.ff.net.0.proj.bias":"img_emb.proj.1.bias",
|
|
"condition_embedder.image_embedder.ff.net.0.proj.weight":"img_emb.proj.1.weight",
|
|
"condition_embedder.image_embedder.ff.net.2.bias":"img_emb.proj.3.bias",
|
|
"condition_embedder.image_embedder.ff.net.2.weight":"img_emb.proj.3.weight",
|
|
"condition_embedder.image_embedder.norm1.bias":"img_emb.proj.0.bias",
|
|
"condition_embedder.image_embedder.norm1.weight":"img_emb.proj.0.weight",
|
|
"condition_embedder.image_embedder.norm2.bias":"img_emb.proj.4.bias",
|
|
"condition_embedder.image_embedder.norm2.weight":"img_emb.proj.4.weight",
|
|
"patch_embedding.bias": "patch_embedding.bias",
|
|
"patch_embedding.weight": "patch_embedding.weight",
|
|
"scale_shift_table": "head.modulation",
|
|
"proj_out.bias": "head.head.bias",
|
|
"proj_out.weight": "head.head.weight",
|
|
}
|
|
state_dict = {name: param for name, param in state_dict.items() if '_mot_ref' in name}
|
|
if hash_state_dict_keys(state_dict) == '19debbdb7f4d5ba93b4ddb1cbe5788c7':
|
|
mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36)
|
|
else:
|
|
mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36)
|
|
mot_layers_mapping = {i:n for n, i in enumerate(mot_layers)}
|
|
|
|
state_dict_ = {}
|
|
|
|
for name, param in state_dict.items():
|
|
name = name.replace("_mot_ref", "")
|
|
if name in rename_dict:
|
|
state_dict_[rename_dict[name]] = param
|
|
else:
|
|
if name.split(".")[1].isdigit():
|
|
block_id = int(name.split(".")[1])
|
|
name = name.replace(str(block_id), str(mot_layers_mapping[block_id]))
|
|
name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
|
|
if name_ in rename_dict:
|
|
name_ = rename_dict[name_]
|
|
name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
|
|
state_dict_[name_] = param
|
|
|
|
if hash_state_dict_keys(state_dict_) == '6507c8213a3c476df5958b01dcf302d0': # vap 14B
|
|
config = {
|
|
"mot_layers":(0, 4, 8, 12, 16, 20, 24, 28, 32, 36),
|
|
"has_image_input": True,
|
|
"patch_size": [1, 2, 2],
|
|
"in_dim": 36,
|
|
"dim": 5120,
|
|
"ffn_dim": 13824,
|
|
"freq_dim": 256,
|
|
"text_dim": 4096,
|
|
"num_heads": 40,
|
|
"eps": 1e-6
|
|
}
|
|
else:
|
|
config = {}
|
|
return state_dict_, config
|
|
|
|
|
|
|