mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:32:27 +00:00
wans2v refactor
This commit is contained in:
@@ -4,7 +4,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Tuple
|
||||
from .utils import hash_state_dict_keys
|
||||
from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, SelfAttention, Head, CrossAttention
|
||||
from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, Head, CrossAttention, modulate, sinusoidal_embedding_1d
|
||||
|
||||
|
||||
def torch_dfs(model: nn.Module, parent_name='root'):
|
||||
@@ -24,22 +24,6 @@ def torch_dfs(model: nn.Module, parent_name='root'):
|
||||
return modules, module_names
|
||||
|
||||
|
||||
def rope_apply(x, freqs):
    """Apply precomputed rotary position factors to ``x``.

    The last dimension of ``x`` is read as interleaved (real, imaginary)
    pairs, each pair is multiplied by the matching complex factor in
    ``freqs``, and the result is written back in the same interleaved layout.

    Args:
        x: Tensor indexed as ``(batch, seq, heads, head_dim)``. ``head_dim``
            must be even so it can be split into pairs.
        freqs: Complex tensor of rotation factors. Its first ``batch`` entries
            along dim 0 and first ``seq`` entries along dim 1 must broadcast
            against ``(batch, seq, heads, head_dim // 2)``; a size-1 batch
            dimension is shared by every sample.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    batch, seq_len, num_heads = x.shape[:3]
    half_dim = x.size(3) // 2
    # The rotation is done in float64/complex128 and cast back at the end, so
    # low-precision inputs do not lose accuracy in the complex product.
    pairs = x.to(torch.float64).reshape(batch, seq_len, num_heads, half_dim, 2)
    # One batched multiply replaces the per-sample loop; the values are the
    # same because the product is purely element-wise.
    rotated = torch.view_as_complex(pairs) * freqs[:batch, :seq_len]
    return torch.view_as_real(rotated).flatten(3).to(x.dtype)
|
||||
|
||||
|
||||
def rope_precompute(x, grid_sizes, freqs, start=None):
|
||||
b, s, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2
|
||||
|
||||
@@ -135,11 +119,8 @@ class MotionEncoder_tc(nn.Module):
|
||||
self.final_linear = nn.Linear(hidden_dim, hidden_dim, **factory_kwargs)
|
||||
|
||||
self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
|
||||
self.norm2 = nn.LayerNorm(hidden_dim // 2, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
|
||||
self.norm3 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
|
||||
self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))
|
||||
|
||||
def forward(self, x):
|
||||
@@ -358,71 +339,21 @@ class CausalAudioEncoder(nn.Module):
|
||||
return res # b f n dim
|
||||
|
||||
|
||||
class WanS2VSelfAttention(SelfAttention):
    """Self-attention whose queries and keys are rotated by ``rope_apply``."""

    def forward(self, x, freqs):
        """Run self-attention over ``x`` with rotary factors ``freqs``.

        Args:
            x: Input whose first two dimensions are read as (batch, sequence).
            freqs: Rotation factors handed unchanged to ``rope_apply`` for both
                the query and the key projection.

        Returns:
            The output projection ``self.o`` applied to the attention result.
        """
        batch, seq_len = x.shape[:2]
        per_head_shape = (batch, seq_len, self.num_heads, self.head_dim)
        merged_shape = (batch, seq_len, self.num_heads * self.head_dim)

        # Queries and keys are normalised, split per head and rotated; values
        # are left in the merged layout and are not rotated.
        query = rope_apply(self.norm_q(self.q(x)).view(per_head_shape), freqs)
        key = rope_apply(self.norm_k(self.k(x)).view(per_head_shape), freqs)
        value = self.v(x)

        attended = self.attn(query.view(merged_shape), key.view(merged_shape), value)
        return self.o(attended)
|
||||
|
||||
|
||||
class WanS2VDiTBlock(DiTBlock):
    """DiT block that modulates two consecutive token segments separately."""

    def __init__(self, dim, num_heads, ffn_dim, eps=1e-6, has_image_input=False):
        """Build the base block, then swap in ``WanS2VSelfAttention``."""
        super().__init__(has_image_input=has_image_input, dim=dim, num_heads=num_heads, ffn_dim=ffn_dim, eps=eps)
        self.self_attn = WanS2VSelfAttention(dim, num_heads, eps)

    def forward(self, x, context, e, freqs):
        """Apply self-attention, cross-attention and the FFN to ``x``.

        Args:
            x: Hidden states; dim 1 is the token axis that gets split in two.
            context: Passed to ``self.cross_attn`` as its second argument.
            e: Indexable pair. ``e[0]`` is the modulation input added to
                ``self.modulation``; ``e[1]`` is a one-element tensor holding
                the token index where the second segment starts.
            freqs: Rotary factors forwarded to ``self.self_attn``.

        Returns:
            The updated hidden states.
        """
        total_len = x.size(1)
        # Clamp the split point into [0, total_len] so both slices are valid.
        split_at = min(max(0, e[1].item()), total_len)
        spans = ((0, split_at), (split_at, total_len))

        mod_input = e[0]
        table = self.modulation.unsqueeze(2).to(dtype=mod_input.dtype, device=mod_input.device)
        # Six chunks, used below as: shift/scale/gate for self-attention
        # (0, 1, 2) and shift/scale/gate for the FFN (3, 4, 5).
        mods = [chunk.squeeze(1) for chunk in (table + mod_input).chunk(6, dim=1)]

        def modulate_segments(tensor, shift, scale):
            # Segment k uses column k of the shift and scale tensors.
            return torch.cat(
                [
                    tensor[:, lo:hi] * (1 + scale[:, k:k + 1]) + shift[:, k:k + 1]
                    for k, (lo, hi) in enumerate(spans)
                ],
                dim=1,
            )

        def gate_segments(tensor, gate):
            # Segment k is scaled by column k of the gate tensor.
            return torch.cat(
                [tensor[:, lo:hi] * gate[:, k:k + 1] for k, (lo, hi) in enumerate(spans)],
                dim=1,
            )

        # self-attention
        attn_out = self.self_attn(modulate_segments(self.norm1(x), mods[0], mods[1]), freqs)
        x = x + gate_segments(attn_out, mods[2])

        # cross-attention & ffn
        x = x + self.cross_attn(self.norm3(x), context)
        ffn_out = self.ffn(modulate_segments(self.norm2(x), mods[3], mods[4]))
        return x + gate_segments(ffn_out, mods[5])
|
||||
|
||||
|
||||
class S2VHead(Head):
|
||||
|
||||
def forward(self, x, t_mod):
|
||||
t_mod = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(1)).chunk(2, dim=1)
|
||||
x = (self.head(self.norm(x) * (1 + t_mod[1]) + t_mod[0]))
|
||||
def forward(self, x, context, t_mod, seq_len_x, freqs):
    """Run one block with separate modulation for the first ``seq_len_x`` tokens.

    Args:
        x: Hidden states. The ``expand(1, ...)`` calls below require the
            leading (batch) dimension of the modulation to be 1.
        context: Passed to ``self.cross_attn`` as its second argument.
        t_mod: Modulation input added to ``self.modulation.unsqueeze(2)``.
            After chunking, index 0 along dim 2 is used for the first
            ``seq_len_x`` tokens and index 1 for all remaining tokens
            (reference, motion, etc.).
        seq_len_x: Number of leading tokens that receive the index-0 modulation.
        freqs: Rotary factors forwarded to ``self.self_attn``.

    Returns:
        The updated hidden states.
    """
    hidden_dim = x.shape[-1]
    tail_len = x.shape[1] - seq_len_x
    table = self.modulation.unsqueeze(2).to(dtype=t_mod.dtype, device=t_mod.device)

    def spread_over_tokens(chunk):
        # Broadcast the two modulation rows over their token ranges and join
        # them into one per-token tensor.
        head_part = chunk[:, :, 0].expand(1, seq_len_x, hidden_dim)
        tail_part = chunk[:, :, 1].expand(1, tail_len, hidden_dim)
        return torch.cat([head_part, tail_part], dim=1)

    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
        spread_over_tokens(chunk) for chunk in (table + t_mod).chunk(6, dim=1)
    ]

    attn_in = modulate(self.norm1(x), shift_msa, scale_msa)
    x = self.gate(x, gate_msa, self.self_attn(attn_in, freqs))
    x = x + self.cross_attn(self.norm3(x), context)
    ffn_in = modulate(self.norm2(x), shift_mlp, scale_mlp)
    return self.gate(x, gate_mlp, self.ffn(ffn_in))
|
||||
|
||||
|
||||
@@ -472,9 +403,9 @@ class WanS2VModel(torch.nn.Module):
|
||||
self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
|
||||
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
|
||||
|
||||
self.blocks = nn.ModuleList([WanS2VDiTBlock(dim, num_heads, ffn_dim, eps) for _ in range(num_layers)])
|
||||
self.head = S2VHead(dim, out_dim, patch_size, eps)
|
||||
self.freqs = precompute_freqs_cis_3d(dim // num_heads)
|
||||
self.blocks = nn.ModuleList([WanS2VDiTBlock(False, dim, num_heads, ffn_dim, eps) for _ in range(num_layers)])
|
||||
self.head = Head(dim, out_dim, patch_size, eps)
|
||||
self.freqs = torch.cat(precompute_freqs_cis_3d(dim // num_heads), dim=1)
|
||||
|
||||
self.cond_encoder = nn.Conv3d(cond_dim, dim, kernel_size=patch_size, stride=patch_size)
|
||||
self.casual_audio_encoder = CausalAudioEncoder(dim=audio_dim, out_dim=dim, num_token=num_audio_token, need_global=enable_adain)
|
||||
@@ -516,17 +447,17 @@ class WanS2VModel(torch.nn.Module):
|
||||
else:
|
||||
return flattern_mot, mot_remb
|
||||
|
||||
def inject_motion(self, x, seq_lens, rope_embs, mask_input, motion_latents, drop_motion_frames=True, add_last_motion=2):
|
||||
def inject_motion(self, x, rope_embs, mask_input, motion_latents, drop_motion_frames=True, add_last_motion=2):
|
||||
# inject the motion frames token to the hidden states
|
||||
# TODO: check drop_motion_frames = False
|
||||
mot, mot_remb = self.process_motion_frame_pack(motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=add_last_motion)
|
||||
if len(mot) > 0:
|
||||
x = torch.cat([x, mot[0]], dim=1)
|
||||
seq_lens = seq_lens + torch.tensor([r.size(1) for r in mot], dtype=torch.long)
|
||||
rope_embs = torch.cat([rope_embs, mot_remb[0]], dim=1)
|
||||
mask_input = torch.cat(
|
||||
[mask_input, 2 * torch.ones([1, x.shape[1] - mask_input.shape[1]], device=mask_input.device, dtype=mask_input.dtype)], dim=1
|
||||
)
|
||||
return x, seq_lens, rope_embs, mask_input
|
||||
return x, rope_embs, mask_input
|
||||
|
||||
def after_transformer_block(self, block_idx, hidden_states, audio_emb_global, audio_emb, original_seq_len):
|
||||
if block_idx in self.audio_injector.injected_block_id.keys():
|
||||
@@ -548,6 +479,118 @@ class WanS2VModel(torch.nn.Module):
|
||||
|
||||
return hidden_states
|
||||
|
||||
def cal_audio_emb(self, audio_input, motion_frames=(73, 19)):
    """Encode audio after left-padding it with copies of its first frame.

    Args:
        audio_input: Audio feature tensor with at most four dimensions (the
            padding uses a four-factor ``repeat``); the last dimension is the
            one that gets padded.
        motion_frames: Two-item sequence. ``motion_frames[0]`` is how many
            copies of the first frame are prepended along the last dimension.
            ``motion_frames[1]`` is how many leading entries along dim 1 are
            dropped from both encoder outputs. The default is an immutable
            tuple so it cannot be mutated between calls.

    Returns:
        Tuple ``(audio_emb_global, merged_audio_emb)``: the two outputs of
        ``self.casual_audio_encoder`` with the leading entries removed. The
        global embedding is returned as a copy (``clone``).
    """
    pad_frames, skip_steps = motion_frames[0], motion_frames[1]
    lead_in = audio_input[..., 0:1].repeat(1, 1, 1, pad_frames)
    audio_input = torch.cat([lead_in, audio_input], dim=-1)
    audio_emb_global, audio_emb = self.casual_audio_encoder(audio_input)
    audio_emb_global = audio_emb_global[:, skip_steps:].clone()
    merged_audio_emb = audio_emb[:, skip_steps:, :]
    return audio_emb_global, merged_audio_emb
|
||||
|
||||
def get_grid_sizes(self, grid_size_x, grid_size_ref):
    """Build the grid descriptors for the video tokens and the reference image.

    Args:
        grid_size_x: ``(f, h, w)`` grid of the video tokens.
        grid_size_ref: ``(rf, rh, rw)`` grid of the reference tokens; only
            ``rh`` and ``rw`` are used.

    Returns:
        A two-item list. Each item is a list of three ``(1, 3)`` tensors.
        The video item is ``[zeros, (f, h, w), (f, h, w)]``; the reference
        item is ``[(30, 0, 0), (31, rh, rw), (1, rh, rw)]``.
        NOTE(review): the triples look like (start, end, extent) inputs for
        ``rope_precompute``, with the reference pinned at frame index 30 --
        confirm against that function.
    """
    f, h, w = grid_size_x
    _, ref_h, ref_w = grid_size_ref

    def as_row(values, **tensor_kwargs):
        # Wrap a 3-element list as a (1, 3) tensor.
        return torch.tensor(values, **tensor_kwargs).unsqueeze(0)

    video_extent = as_row([f, h, w], dtype=torch.long)
    video_entry = [torch.zeros_like(video_extent), video_extent, video_extent]
    ref_entry = [
        as_row([30, 0, 0]),
        as_row([31, ref_h, ref_w]),
        as_row([1, ref_h, ref_w]),
    ]
    return [video_entry, ref_entry]
|
||||
|
||||
def forward(
    self,
    latents,
    timestep,
    context,
    audio_input,
    motion_latents,
    pose_cond,
    use_gradient_checkpointing_offload=False,
    use_gradient_checkpointing=False
):
    """Denoise ``latents`` conditioned on text, audio, motion and pose.

    Args:
        latents: Latent tensor; index 0 along dim 2 is the reference-image
            latent and the remaining entries are the frames to denoise.
        timestep: 1-D timestep tensor. A zero timestep is appended so the
            modulation carries a second row for the non-video tokens.
        context: Text features, embedded with ``self.text_embedding``.
        audio_input: Audio features passed to ``self.cal_audio_emb``.
        motion_latents: Motion-frame latents passed to ``self.inject_motion``.
        pose_cond: Pose condition with the same shape as the frame latents,
            or ``None`` to use zeros.
        use_gradient_checkpointing_offload: Checkpoint every block and keep
            saved tensors on CPU (``torch.autograd.graph.save_on_cpu``).
        use_gradient_checkpointing: Checkpoint every block without offload;
            ignored when the offload flag is set.

    Returns:
        The reference latent concatenated with the predicted frame latents
        along dim 2, matching the layout of ``latents``.

    Note:
        The mask, the ``view(1, ...)`` before ``rope_precompute`` and the
        modulation layout all hard-code a leading dimension of 1, so this
        path assumes batch size 1.
    """
    origin_ref_latents = latents[:, :, 0:1]
    x = latents[:, :, 1:]

    # context embedding
    context = self.text_embedding(context)

    # audio encode
    audio_emb_global, merged_audio_emb = self.cal_audio_emb(audio_input)

    # x and pose_cond
    pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond
    x, (f, h, w) = self.patchify(self.patch_embedding(x) + self.cond_encoder(pose_cond))  # e.g. torch.Size([1, 29120, 5120])
    seq_len_x = x.shape[1]

    # reference image
    ref_latents, (rf, rh, rw) = self.patchify(self.patch_embedding(origin_ref_latents))  # e.g. torch.Size([1, 1456, 5120])
    grid_sizes = self.get_grid_sizes((f, h, w), (rf, rh, rw))
    x = torch.cat([x, ref_latents], dim=1)

    # mask: 0 for video tokens, 1 for reference tokens (inject_motion appends 2 for motion tokens)
    mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device)

    # freqs
    pre_compute_freqs = rope_precompute(
        x.detach().view(1, x.size(1), self.num_heads, self.dim // self.num_heads), grid_sizes, self.freqs, start=None
    )

    # motion
    x, pre_compute_freqs, mask = self.inject_motion(x, pre_compute_freqs, mask, motion_latents, add_last_motion=2)

    x = x + self.trainable_cond_mask(mask).to(x.dtype)

    # t_mod: row 0 comes from the real timestep, row 1 from the appended zero timestep
    timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)])
    t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
    t_mod = self.time_projection(t).unflatten(1, (6, self.dim)).unsqueeze(2).transpose(0, 2)

    for block_id, block in enumerate(self.blocks):
        # Bind block_id as a default argument. With use_reentrant=False the
        # checkpointed callable is re-run during backward, after this loop
        # has finished; a plain closure would then see the last block index
        # and recompute a different graph from the one used in forward.
        def after_block(hidden_states, block_id=block_id):
            return self.after_transformer_block(block_id, hidden_states, audio_emb_global, merged_audio_emb, seq_len_x)

        if use_gradient_checkpointing_offload:
            with torch.autograd.graph.save_on_cpu():
                x = torch.utils.checkpoint.checkpoint(
                    block,
                    x,
                    context,
                    t_mod,
                    seq_len_x,
                    pre_compute_freqs,
                    use_reentrant=False,
                )
                x = torch.utils.checkpoint.checkpoint(
                    after_block,
                    x,
                    use_reentrant=False,
                )
        elif use_gradient_checkpointing:
            x = torch.utils.checkpoint.checkpoint(
                block,
                x,
                context,
                t_mod,
                seq_len_x,
                pre_compute_freqs,
                use_reentrant=False,
            )
            x = torch.utils.checkpoint.checkpoint(
                after_block,
                x,
                use_reentrant=False,
            )
        else:
            x = block(x, context, t_mod, seq_len_x, pre_compute_freqs)
            x = after_block(x)

    # Keep only the video tokens; reference and motion tokens are dropped.
    x = x[:, :seq_len_x]
    # t[:-1] removes the row of the appended zero timestep.
    x = self.head(x, t[:-1])
    x = self.unpatchify(x, (f, h, w))
    # make compatible with wan video
    x = torch.cat([origin_ref_latents, x], dim=2)
    return x
|
||||
|
||||
@staticmethod
def state_dict_converter():
    """Return a new ``WanS2VModelStateDictConverter`` instance."""
    converter = WanS2VModelStateDictConverter()
    return converter
|
||||
|
||||
Reference in New Issue
Block a user