wans2v refactor

This commit is contained in:
mi804
2025-08-27 16:18:22 +08:00
parent 8a0bd7c377
commit 4147473c81
2 changed files with 183 additions and 156 deletions

View File

@@ -4,7 +4,7 @@ import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
from .utils import hash_state_dict_keys
from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, SelfAttention, Head, CrossAttention
from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, Head, CrossAttention, modulate, sinusoidal_embedding_1d
def torch_dfs(model: nn.Module, parent_name='root'):
@@ -24,22 +24,6 @@ def torch_dfs(model: nn.Module, parent_name='root'):
return modules, module_names
def rope_apply(x, freqs):
n, c = x.size(2), x.size(3) // 2
# loop over samples
output = []
for i, _ in enumerate(x):
s = x.size(1)
x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(s, n, -1, 2))
freqs_i = freqs[i, :s]
# apply rotary embedding
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
x_i = torch.cat([x_i, x[i, s:]])
# append to collection
output.append(x_i)
return torch.stack(output).to(x.dtype)
def rope_precompute(x, grid_sizes, freqs, start=None):
b, s, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2
@@ -135,11 +119,8 @@ class MotionEncoder_tc(nn.Module):
self.final_linear = nn.Linear(hidden_dim, hidden_dim, **factory_kwargs)
self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.norm2 = nn.LayerNorm(hidden_dim // 2, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.norm3 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))
def forward(self, x):
@@ -358,71 +339,21 @@ class CausalAudioEncoder(nn.Module):
return res # b f n dim
class WanS2VSelfAttention(SelfAttention):
def forward(self, x, freqs):
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
q = self.norm_q(self.q(x)).view(b, s, n, d)
k = self.norm_k(self.k(x)).view(b, s, n, d)
v = self.v(x)
q = rope_apply(q, freqs)
k = rope_apply(k, freqs)
x = self.attn(q.view(b, s, n * d), k.view(b, s, n * d), v)
return self.o(x)
class WanS2VDiTBlock(DiTBlock):
def __init__(self, dim, num_heads, ffn_dim, eps=1e-6, has_image_input=False):
super().__init__(has_image_input=has_image_input, dim=dim, num_heads=num_heads, ffn_dim=ffn_dim, eps=eps)
self.self_attn = WanS2VSelfAttention(dim, num_heads, eps)
def forward(self, x, context, e, freqs):
seg_idx = e[1].item()
seg_idx = min(max(0, seg_idx), x.size(1))
seg_idx = [0, seg_idx, x.size(1)]
e = e[0]
modulation = self.modulation.unsqueeze(2).to(dtype=e.dtype, device=e.device)
e = (modulation + e).chunk(6, dim=1)
e = [element.squeeze(1) for element in e]
norm_x = self.norm1(x)
parts = []
for i in range(2):
parts.append(norm_x[:, seg_idx[i]:seg_idx[i + 1]] * (1 + e[1][:, i:i + 1]) + e[0][:, i:i + 1])
norm_x = torch.cat(parts, dim=1)
# self-attention
y = self.self_attn(norm_x, freqs)
z = []
for i in range(2):
z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[2][:, i:i + 1])
y = torch.cat(z, dim=1)
x = x + y
# cross-attention & ffn function
def cross_attn_ffn(x, context, e):
x = x + self.cross_attn(self.norm3(x), context)
norm2_x = self.norm2(x)
parts = []
for i in range(2):
parts.append(norm2_x[:, seg_idx[i]:seg_idx[i + 1]] * (1 + e[4][:, i:i + 1]) + e[3][:, i:i + 1])
norm2_x = torch.cat(parts, dim=1)
y = self.ffn(norm2_x)
z = []
for i in range(2):
z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[5][:, i:i + 1])
y = torch.cat(z, dim=1)
x = x + y
return x
x = cross_attn_ffn(x, context, e)
return x
class S2VHead(Head):
def forward(self, x, t_mod):
t_mod = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(1)).chunk(2, dim=1)
x = (self.head(self.norm(x) * (1 + t_mod[1]) + t_mod[0]))
def forward(self, x, context, t_mod, seq_len_x, freqs):
t_mod = (self.modulation.unsqueeze(2).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
# t_mod[:, :, 0] for x, t_mod[:, :, 1] for other like ref, motion, etc.
t_mod = [
torch.cat([element[:, :, 0].expand(1, seq_len_x, x.shape[-1]), element[:, :, 1].expand(1, x.shape[1] - seq_len_x, x.shape[-1])], dim=1)
for element in t_mod
]
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = t_mod
input_x = modulate(self.norm1(x), shift_msa, scale_msa)
x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
x = x + self.cross_attn(self.norm3(x), context)
input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
x = self.gate(x, gate_mlp, self.ffn(input_x))
return x
@@ -472,9 +403,9 @@ class WanS2VModel(torch.nn.Module):
self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
self.blocks = nn.ModuleList([WanS2VDiTBlock(dim, num_heads, ffn_dim, eps) for _ in range(num_layers)])
self.head = S2VHead(dim, out_dim, patch_size, eps)
self.freqs = precompute_freqs_cis_3d(dim // num_heads)
self.blocks = nn.ModuleList([WanS2VDiTBlock(False, dim, num_heads, ffn_dim, eps) for _ in range(num_layers)])
self.head = Head(dim, out_dim, patch_size, eps)
self.freqs = torch.cat(precompute_freqs_cis_3d(dim // num_heads), dim=1)
self.cond_encoder = nn.Conv3d(cond_dim, dim, kernel_size=patch_size, stride=patch_size)
self.casual_audio_encoder = CausalAudioEncoder(dim=audio_dim, out_dim=dim, num_token=num_audio_token, need_global=enable_adain)
@@ -516,17 +447,17 @@ class WanS2VModel(torch.nn.Module):
else:
return flattern_mot, mot_remb
def inject_motion(self, x, seq_lens, rope_embs, mask_input, motion_latents, drop_motion_frames=True, add_last_motion=2):
def inject_motion(self, x, rope_embs, mask_input, motion_latents, drop_motion_frames=True, add_last_motion=2):
# inject the motion frames token to the hidden states
# TODO: check drop_motion_frames = False
mot, mot_remb = self.process_motion_frame_pack(motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=add_last_motion)
if len(mot) > 0:
x = torch.cat([x, mot[0]], dim=1)
seq_lens = seq_lens + torch.tensor([r.size(1) for r in mot], dtype=torch.long)
rope_embs = torch.cat([rope_embs, mot_remb[0]], dim=1)
mask_input = torch.cat(
[mask_input, 2 * torch.ones([1, x.shape[1] - mask_input.shape[1]], device=mask_input.device, dtype=mask_input.dtype)], dim=1
)
return x, seq_lens, rope_embs, mask_input
return x, rope_embs, mask_input
def after_transformer_block(self, block_idx, hidden_states, audio_emb_global, audio_emb, original_seq_len):
if block_idx in self.audio_injector.injected_block_id.keys():
@@ -548,6 +479,118 @@ class WanS2VModel(torch.nn.Module):
return hidden_states
def cal_audio_emb(self, audio_input, motion_frames=[73, 19]):
audio_input = torch.cat([audio_input[..., 0:1].repeat(1, 1, 1, motion_frames[0]), audio_input], dim=-1)
audio_emb_global, audio_emb = self.casual_audio_encoder(audio_input)
audio_emb_global = audio_emb_global[:, motion_frames[1]:].clone()
merged_audio_emb = audio_emb[:, motion_frames[1]:, :]
return audio_emb_global, merged_audio_emb
def get_grid_sizes(self, grid_size_x, grid_size_ref):
f, h, w = grid_size_x
rf, rh, rw = grid_size_ref
grid_sizes_x = torch.tensor([f, h, w], dtype=torch.long).unsqueeze(0)
grid_sizes_x = [[torch.zeros_like(grid_sizes_x), grid_sizes_x, grid_sizes_x]]
grid_sizes_ref = [[
torch.tensor([30, 0, 0]).unsqueeze(0),
torch.tensor([31, rh, rw]).unsqueeze(0),
torch.tensor([1, rh, rw]).unsqueeze(0),
]]
return grid_sizes_x + grid_sizes_ref
def forward(
self,
latents,
timestep,
context,
audio_input,
motion_latents,
pose_cond,
use_gradient_checkpointing_offload=False,
use_gradient_checkpointing=False
):
origin_ref_latents = latents[:, :, 0:1]
x = latents[:, :, 1:]
# context embedding
context = self.text_embedding(context)
# audio encode
audio_emb_global, merged_audio_emb = self.cal_audio_emb(audio_input)
# x and pose_cond
pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond
x, (f, h, w) = self.patchify(self.patch_embedding(x) + self.cond_encoder(pose_cond)) # torch.Size([1, 29120, 5120])
seq_len_x = x.shape[1]
# reference image
ref_latents, (rf, rh, rw) = self.patchify(self.patch_embedding(origin_ref_latents)) # torch.Size([1, 1456, 5120])
grid_sizes = self.get_grid_sizes((f, h, w), (rf, rh, rw))
x = torch.cat([x, ref_latents], dim=1)
# mask
mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device)
# freqs
pre_compute_freqs = rope_precompute(
x.detach().view(1, x.size(1), self.num_heads, self.dim // self.num_heads), grid_sizes, self.freqs, start=None
)
# motion
x, pre_compute_freqs, mask = self.inject_motion(x, pre_compute_freqs, mask, motion_latents, add_last_motion=2)
x = x + self.trainable_cond_mask(mask).to(x.dtype)
# t_mod
timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)])
t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
t_mod = self.time_projection(t).unflatten(1, (6, self.dim)).unsqueeze(2).transpose(0, 2)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block_id, block in enumerate(self.blocks):
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x,
context,
t_mod,
seq_len_x,
pre_compute_freqs,
use_reentrant=False,
)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
x,
use_reentrant=False,
)
elif use_gradient_checkpointing:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x,
context,
t_mod,
seq_len_x,
pre_compute_freqs,
use_reentrant=False,
)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
x,
use_reentrant=False,
)
else:
x = block(x, context, t_mod, seq_len_x, pre_compute_freqs)
x = self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)
x = x[:, :seq_len_x]
x = self.head(x, t[:-1])
x = self.unpatchify(x, (f, h, w))
# make compatible with wan video
x = torch.cat([origin_ref_latents, x], dim=2)
return x
@staticmethod
def state_dict_converter():
return WanS2VModelStateDictConverter()

View File

@@ -927,24 +927,23 @@ class WanVideoUnit_S2V(PipelineUnit):
def process_motion_latents(self, pipe: WanVideoPipeline, height, width, tiled, tile_size, tile_stride):
pipe.load_models_to_device(["vae"])
# TODO: may support input motion latents
# TODO: may support input motion latents, which related to `drop_motion_frames = False`
motion_frames = 73
lat_motion_frames = (motion_frames + 3) // 4 # 19
motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device)
lat_motion_frames = (motion_frames + 3) // 4
motion_latents = pipe.vae.encode(motion_latents, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
return {"motion_latents": motion_latents, "motion_frames": [motion_frames, lat_motion_frames]}
return {"motion_latents": motion_latents}
def process_pose_cond(self, pipe: WanVideoPipeline, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride):
pipe.load_models_to_device(["vae"])
if s2v_pose_video is None:
input_video = -torch.ones(1, 3, num_frames, height, width, device=pipe.device, dtype=pipe.torch_dtype)
else:
input_video = pipe.preprocess_video(s2v_pose_video)
# get num_frames-1 frames
input_video = input_video[:, :, :num_frames]
# pad if not enough frames
padding_frames = num_frames - input_video.shape[2]
input_video = torch.cat([input_video, -torch.ones(1, 3, padding_frames, height, width, device=input_video.device, dtype=input_video.dtype)], dim=2)
return {"pose_cond": None}
pipe.load_models_to_device(["vae"])
input_video = pipe.preprocess_video(s2v_pose_video)
# get num_frames-1 frames
input_video = input_video[:, :, :num_frames]
# pad if not enough frames
padding_frames = num_frames - input_video.shape[2]
input_video = torch.cat([input_video, -torch.ones(1, 3, padding_frames, height, width, device=input_video.device, dtype=input_video.dtype)], dim=2)
# encode to latents
input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
return {"pose_cond": input_latents[:,:,1:]}
@@ -1084,7 +1083,6 @@ def model_fn_wan_video(
vace_scale = 1.0,
audio_input: Optional[torch.Tensor] = None,
motion_latents: Optional[torch.Tensor] = None,
motion_frames: Optional[list] = None,
pose_cond: Optional[torch.Tensor] = None,
tea_cache: TeaCache = None,
use_unified_sequence_parallel: bool = False,
@@ -1132,10 +1130,10 @@ def model_fn_wan_video(
context=context,
audio_input=audio_input,
motion_latents=motion_latents,
motion_frames=motion_frames,
pose_cond=pose_cond,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
use_gradient_checkpointing=use_gradient_checkpointing,
use_unified_sequence_parallel=use_unified_sequence_parallel,
)
if use_unified_sequence_parallel:
@@ -1265,62 +1263,47 @@ def model_fn_wans2v(
context,
audio_input,
motion_latents,
motion_frames,
pose_cond,
use_gradient_checkpointing_offload=False,
use_gradient_checkpointing=False
use_gradient_checkpointing=False,
use_unified_sequence_parallel=False,
):
if use_unified_sequence_parallel:
import torch.distributed as dist
from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
origin_ref_latents = latents[:, :, 0:1]
latents = latents[:, :, 1:]
x = latents[:, :, 1:]
audio_input = torch.cat([audio_input[..., 0:1].repeat(1, 1, 1, motion_frames[0]), audio_input], dim=-1)
audio_emb_global, audio_emb = dit.casual_audio_encoder(audio_input)
audio_emb_global = audio_emb_global[:, motion_frames[1]:].clone()
merged_audio_emb = audio_emb[:, motion_frames[1]:, :]
# context embedding
context = dit.text_embedding(context)
# audio encode
audio_emb_global, merged_audio_emb = dit.cal_audio_emb(audio_input)
# x and pose_cond
pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond
x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(pose_cond)) # torch.Size([1, 29120, 5120])
seq_len_x = x.shape[1]
# reference image
x = latents
pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond
x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(pose_cond))
grid_sizes = torch.tensor([f, h, w], dtype=torch.long).unsqueeze(0)
seq_lens = torch.tensor([x.size(1)], dtype=torch.long)
grid_sizes = [[torch.zeros_like(grid_sizes), grid_sizes, grid_sizes]]
ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents))
ref_grid_sizes = [[
torch.tensor([30, 0, 0]).unsqueeze(0),
torch.tensor([31, rh, rw]).unsqueeze(0),
torch.tensor([1, rh, rw]).unsqueeze(0),
]]
original_seq_len = seq_lens[0]
seq_lens = seq_lens + torch.tensor([ref_latents.shape[1]], dtype=torch.long)
grid_sizes = grid_sizes + ref_grid_sizes
ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents)) # torch.Size([1, 1456, 5120])
grid_sizes = dit.get_grid_sizes((f, h, w), (rf, rh, rw))
x = torch.cat([x, ref_latents], dim=1)
mask = torch.zeros([1, x.shape[1]], dtype=torch.long, device=x.device)
mask[:, -ref_latents.shape[1]:] = 1
# mask
mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device)
# freqs
pre_compute_freqs = rope_precompute(x.detach().view(1, x.size(1), dit.num_heads, dit.dim // dit.num_heads), grid_sizes, dit.freqs, start=None)
# motion
x, pre_compute_freqs, mask = dit.inject_motion(x, pre_compute_freqs, mask, motion_latents, add_last_motion=2)
b, s, n, d = x.size(0), x.size(1), dit.num_heads, dit.dim // dit.num_heads
pre_compute_freqs = rope_precompute(x.detach().view(b, s, n, d), grid_sizes, torch.cat(dit.freqs, dim=1), start=None)
x, seq_lens, pre_compute_freqs, mask = dit.inject_motion(x, seq_lens, pre_compute_freqs, mask, motion_latents, add_last_motion=2)
x = x + dit.trainable_cond_mask(mask).to(x.dtype)
# t_mod
if dit.zero_timestep:
timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)])
e = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
e0 = dit.time_projection(e).unflatten(1, (6, dit.dim))
if dit.zero_timestep:
e = e[:-1]
zero_e0 = e0[-1:]
e0 = e0[:-1]
e0 = torch.cat([e0.unsqueeze(2), zero_e0.unsqueeze(2).repeat(e0.size(0), 1, 1, 1)], dim=2)
e0 = [e0, original_seq_len]
# context
context = dit.text_embedding(context)
# tmod
timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)])
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)).unsqueeze(2).transpose(0, 2)
def create_custom_forward(module):
def custom_forward(*inputs):
@@ -1332,31 +1315,32 @@ def model_fn_wans2v(
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, e0, pre_compute_freqs,
x, context, t_mod, seq_len_x, pre_compute_freqs,
use_reentrant=False,
)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, original_seq_len)),
create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
x,
use_reentrant=False,
)
elif use_gradient_checkpointing:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, e0, pre_compute_freqs,
x, context, t_mod, seq_len_x, pre_compute_freqs,
use_reentrant=False,
)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, original_seq_len)),
create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
x,
use_reentrant=False,
)
else:
x = block(x, context, e0, pre_compute_freqs)
x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, original_seq_len)
x = block(x, context, t_mod, seq_len_x, pre_compute_freqs)
x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)
x = x[:, :original_seq_len]
x = dit.head(x, e)
x = x[:, :seq_len_x]
x = dit.head(x, t[:-1])
x = dit.unpatchify(x, (f, h, w))
# make compatible with wan video
x = torch.cat([origin_ref_latents, x], dim=2)
return x