wans2v refactor

This commit is contained in:
mi804
2025-08-27 16:18:22 +08:00
parent 8a0bd7c377
commit 4147473c81
2 changed files with 183 additions and 156 deletions

View File

@@ -4,7 +4,7 @@ import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
from .utils import hash_state_dict_keys
from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, SelfAttention, Head, CrossAttention
from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, Head, CrossAttention, modulate, sinusoidal_embedding_1d
def torch_dfs(model: nn.Module, parent_name='root'):
@@ -24,22 +24,6 @@ def torch_dfs(model: nn.Module, parent_name='root'):
return modules, module_names
def rope_apply(x, freqs):
n, c = x.size(2), x.size(3) // 2
# loop over samples
output = []
for i, _ in enumerate(x):
s = x.size(1)
x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(s, n, -1, 2))
freqs_i = freqs[i, :s]
# apply rotary embedding
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
x_i = torch.cat([x_i, x[i, s:]])
# append to collection
output.append(x_i)
return torch.stack(output).to(x.dtype)
def rope_precompute(x, grid_sizes, freqs, start=None):
b, s, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2
@@ -135,11 +119,8 @@ class MotionEncoder_tc(nn.Module):
self.final_linear = nn.Linear(hidden_dim, hidden_dim, **factory_kwargs)
self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.norm2 = nn.LayerNorm(hidden_dim // 2, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.norm3 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))
def forward(self, x):
@@ -358,71 +339,21 @@ class CausalAudioEncoder(nn.Module):
return res # b f n dim
class WanS2VSelfAttention(SelfAttention):
def forward(self, x, freqs):
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
q = self.norm_q(self.q(x)).view(b, s, n, d)
k = self.norm_k(self.k(x)).view(b, s, n, d)
v = self.v(x)
q = rope_apply(q, freqs)
k = rope_apply(k, freqs)
x = self.attn(q.view(b, s, n * d), k.view(b, s, n * d), v)
return self.o(x)
class WanS2VDiTBlock(DiTBlock):
def __init__(self, dim, num_heads, ffn_dim, eps=1e-6, has_image_input=False):
super().__init__(has_image_input=has_image_input, dim=dim, num_heads=num_heads, ffn_dim=ffn_dim, eps=eps)
self.self_attn = WanS2VSelfAttention(dim, num_heads, eps)
def forward(self, x, context, e, freqs):
seg_idx = e[1].item()
seg_idx = min(max(0, seg_idx), x.size(1))
seg_idx = [0, seg_idx, x.size(1)]
e = e[0]
modulation = self.modulation.unsqueeze(2).to(dtype=e.dtype, device=e.device)
e = (modulation + e).chunk(6, dim=1)
e = [element.squeeze(1) for element in e]
norm_x = self.norm1(x)
parts = []
for i in range(2):
parts.append(norm_x[:, seg_idx[i]:seg_idx[i + 1]] * (1 + e[1][:, i:i + 1]) + e[0][:, i:i + 1])
norm_x = torch.cat(parts, dim=1)
# self-attention
y = self.self_attn(norm_x, freqs)
z = []
for i in range(2):
z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[2][:, i:i + 1])
y = torch.cat(z, dim=1)
x = x + y
# cross-attention & ffn function
def cross_attn_ffn(x, context, e):
x = x + self.cross_attn(self.norm3(x), context)
norm2_x = self.norm2(x)
parts = []
for i in range(2):
parts.append(norm2_x[:, seg_idx[i]:seg_idx[i + 1]] * (1 + e[4][:, i:i + 1]) + e[3][:, i:i + 1])
norm2_x = torch.cat(parts, dim=1)
y = self.ffn(norm2_x)
z = []
for i in range(2):
z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[5][:, i:i + 1])
y = torch.cat(z, dim=1)
x = x + y
return x
x = cross_attn_ffn(x, context, e)
return x
class S2VHead(Head):
def forward(self, x, t_mod):
t_mod = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(1)).chunk(2, dim=1)
x = (self.head(self.norm(x) * (1 + t_mod[1]) + t_mod[0]))
def forward(self, x, context, t_mod, seq_len_x, freqs):
t_mod = (self.modulation.unsqueeze(2).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
# t_mod[:, :, 0] for x, t_mod[:, :, 1] for other like ref, motion, etc.
t_mod = [
torch.cat([element[:, :, 0].expand(1, seq_len_x, x.shape[-1]), element[:, :, 1].expand(1, x.shape[1] - seq_len_x, x.shape[-1])], dim=1)
for element in t_mod
]
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = t_mod
input_x = modulate(self.norm1(x), shift_msa, scale_msa)
x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
x = x + self.cross_attn(self.norm3(x), context)
input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
x = self.gate(x, gate_mlp, self.ffn(input_x))
return x
@@ -472,9 +403,9 @@ class WanS2VModel(torch.nn.Module):
self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
self.blocks = nn.ModuleList([WanS2VDiTBlock(dim, num_heads, ffn_dim, eps) for _ in range(num_layers)])
self.head = S2VHead(dim, out_dim, patch_size, eps)
self.freqs = precompute_freqs_cis_3d(dim // num_heads)
self.blocks = nn.ModuleList([WanS2VDiTBlock(False, dim, num_heads, ffn_dim, eps) for _ in range(num_layers)])
self.head = Head(dim, out_dim, patch_size, eps)
self.freqs = torch.cat(precompute_freqs_cis_3d(dim // num_heads), dim=1)
self.cond_encoder = nn.Conv3d(cond_dim, dim, kernel_size=patch_size, stride=patch_size)
self.casual_audio_encoder = CausalAudioEncoder(dim=audio_dim, out_dim=dim, num_token=num_audio_token, need_global=enable_adain)
@@ -516,17 +447,17 @@ class WanS2VModel(torch.nn.Module):
else:
return flattern_mot, mot_remb
def inject_motion(self, x, seq_lens, rope_embs, mask_input, motion_latents, drop_motion_frames=True, add_last_motion=2):
def inject_motion(self, x, rope_embs, mask_input, motion_latents, drop_motion_frames=True, add_last_motion=2):
# inject the motion frames token to the hidden states
# TODO: check drop_motion_frames = False
mot, mot_remb = self.process_motion_frame_pack(motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=add_last_motion)
if len(mot) > 0:
x = torch.cat([x, mot[0]], dim=1)
seq_lens = seq_lens + torch.tensor([r.size(1) for r in mot], dtype=torch.long)
rope_embs = torch.cat([rope_embs, mot_remb[0]], dim=1)
mask_input = torch.cat(
[mask_input, 2 * torch.ones([1, x.shape[1] - mask_input.shape[1]], device=mask_input.device, dtype=mask_input.dtype)], dim=1
)
return x, seq_lens, rope_embs, mask_input
return x, rope_embs, mask_input
def after_transformer_block(self, block_idx, hidden_states, audio_emb_global, audio_emb, original_seq_len):
if block_idx in self.audio_injector.injected_block_id.keys():
@@ -548,6 +479,118 @@ class WanS2VModel(torch.nn.Module):
return hidden_states
def cal_audio_emb(self, audio_input, motion_frames=[73, 19]):
audio_input = torch.cat([audio_input[..., 0:1].repeat(1, 1, 1, motion_frames[0]), audio_input], dim=-1)
audio_emb_global, audio_emb = self.casual_audio_encoder(audio_input)
audio_emb_global = audio_emb_global[:, motion_frames[1]:].clone()
merged_audio_emb = audio_emb[:, motion_frames[1]:, :]
return audio_emb_global, merged_audio_emb
def get_grid_sizes(self, grid_size_x, grid_size_ref):
f, h, w = grid_size_x
rf, rh, rw = grid_size_ref
grid_sizes_x = torch.tensor([f, h, w], dtype=torch.long).unsqueeze(0)
grid_sizes_x = [[torch.zeros_like(grid_sizes_x), grid_sizes_x, grid_sizes_x]]
grid_sizes_ref = [[
torch.tensor([30, 0, 0]).unsqueeze(0),
torch.tensor([31, rh, rw]).unsqueeze(0),
torch.tensor([1, rh, rw]).unsqueeze(0),
]]
return grid_sizes_x + grid_sizes_ref
def forward(
self,
latents,
timestep,
context,
audio_input,
motion_latents,
pose_cond,
use_gradient_checkpointing_offload=False,
use_gradient_checkpointing=False
):
origin_ref_latents = latents[:, :, 0:1]
x = latents[:, :, 1:]
# context embedding
context = self.text_embedding(context)
# audio encode
audio_emb_global, merged_audio_emb = self.cal_audio_emb(audio_input)
# x and pose_cond
pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond
x, (f, h, w) = self.patchify(self.patch_embedding(x) + self.cond_encoder(pose_cond)) # torch.Size([1, 29120, 5120])
seq_len_x = x.shape[1]
# reference image
ref_latents, (rf, rh, rw) = self.patchify(self.patch_embedding(origin_ref_latents)) # torch.Size([1, 1456, 5120])
grid_sizes = self.get_grid_sizes((f, h, w), (rf, rh, rw))
x = torch.cat([x, ref_latents], dim=1)
# mask
mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device)
# freqs
pre_compute_freqs = rope_precompute(
x.detach().view(1, x.size(1), self.num_heads, self.dim // self.num_heads), grid_sizes, self.freqs, start=None
)
# motion
x, pre_compute_freqs, mask = self.inject_motion(x, pre_compute_freqs, mask, motion_latents, add_last_motion=2)
x = x + self.trainable_cond_mask(mask).to(x.dtype)
# t_mod
timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)])
t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
t_mod = self.time_projection(t).unflatten(1, (6, self.dim)).unsqueeze(2).transpose(0, 2)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block_id, block in enumerate(self.blocks):
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x,
context,
t_mod,
seq_len_x,
pre_compute_freqs,
use_reentrant=False,
)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
x,
use_reentrant=False,
)
elif use_gradient_checkpointing:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x,
context,
t_mod,
seq_len_x,
pre_compute_freqs,
use_reentrant=False,
)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
x,
use_reentrant=False,
)
else:
x = block(x, context, t_mod, seq_len_x, pre_compute_freqs)
x = self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)
x = x[:, :seq_len_x]
x = self.head(x, t[:-1])
x = self.unpatchify(x, (f, h, w))
# make compatible with wan video
x = torch.cat([origin_ref_latents, x], dim=2)
return x
@staticmethod
def state_dict_converter():
return WanS2VModelStateDictConverter()