wans2v refactor

This commit is contained in:
mi804
2025-08-27 16:18:22 +08:00
parent 8a0bd7c377
commit 4147473c81
2 changed files with 183 additions and 156 deletions

View File

@@ -927,24 +927,23 @@ class WanVideoUnit_S2V(PipelineUnit):
def process_motion_latents(self, pipe: WanVideoPipeline, height, width, tiled, tile_size, tile_stride):
pipe.load_models_to_device(["vae"])
# TODO: may support input motion latents
# TODO: may support input motion latents, which related to `drop_motion_frames = False`
motion_frames = 73
lat_motion_frames = (motion_frames + 3) // 4 # 19
motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device)
lat_motion_frames = (motion_frames + 3) // 4
motion_latents = pipe.vae.encode(motion_latents, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
return {"motion_latents": motion_latents, "motion_frames": [motion_frames, lat_motion_frames]}
return {"motion_latents": motion_latents}
def process_pose_cond(self, pipe: WanVideoPipeline, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride):
pipe.load_models_to_device(["vae"])
if s2v_pose_video is None:
input_video = -torch.ones(1, 3, num_frames, height, width, device=pipe.device, dtype=pipe.torch_dtype)
else:
input_video = pipe.preprocess_video(s2v_pose_video)
# get num_frames-1 frames
input_video = input_video[:, :, :num_frames]
# pad if not enough frames
padding_frames = num_frames - input_video.shape[2]
input_video = torch.cat([input_video, -torch.ones(1, 3, padding_frames, height, width, device=input_video.device, dtype=input_video.dtype)], dim=2)
return {"pose_cond": None}
pipe.load_models_to_device(["vae"])
input_video = pipe.preprocess_video(s2v_pose_video)
# get num_frames-1 frames
input_video = input_video[:, :, :num_frames]
# pad if not enough frames
padding_frames = num_frames - input_video.shape[2]
input_video = torch.cat([input_video, -torch.ones(1, 3, padding_frames, height, width, device=input_video.device, dtype=input_video.dtype)], dim=2)
# encode to latents
input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
return {"pose_cond": input_latents[:,:,1:]}
@@ -1084,7 +1083,6 @@ def model_fn_wan_video(
vace_scale = 1.0,
audio_input: Optional[torch.Tensor] = None,
motion_latents: Optional[torch.Tensor] = None,
motion_frames: Optional[list] = None,
pose_cond: Optional[torch.Tensor] = None,
tea_cache: TeaCache = None,
use_unified_sequence_parallel: bool = False,
@@ -1132,10 +1130,10 @@ def model_fn_wan_video(
context=context,
audio_input=audio_input,
motion_latents=motion_latents,
motion_frames=motion_frames,
pose_cond=pose_cond,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
use_gradient_checkpointing=use_gradient_checkpointing,
use_unified_sequence_parallel=use_unified_sequence_parallel,
)
if use_unified_sequence_parallel:
@@ -1265,62 +1263,47 @@ def model_fn_wans2v(
context,
audio_input,
motion_latents,
motion_frames,
pose_cond,
use_gradient_checkpointing_offload=False,
use_gradient_checkpointing=False
use_gradient_checkpointing=False,
use_unified_sequence_parallel=False,
):
if use_unified_sequence_parallel:
import torch.distributed as dist
from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
origin_ref_latents = latents[:, :, 0:1]
latents = latents[:, :, 1:]
x = latents[:, :, 1:]
audio_input = torch.cat([audio_input[..., 0:1].repeat(1, 1, 1, motion_frames[0]), audio_input], dim=-1)
audio_emb_global, audio_emb = dit.casual_audio_encoder(audio_input)
audio_emb_global = audio_emb_global[:, motion_frames[1]:].clone()
merged_audio_emb = audio_emb[:, motion_frames[1]:, :]
# context embedding
context = dit.text_embedding(context)
# audio encode
audio_emb_global, merged_audio_emb = dit.cal_audio_emb(audio_input)
# x and pose_cond
pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond
x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(pose_cond)) # torch.Size([1, 29120, 5120])
seq_len_x = x.shape[1]
# reference image
x = latents
pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond
x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(pose_cond))
grid_sizes = torch.tensor([f, h, w], dtype=torch.long).unsqueeze(0)
seq_lens = torch.tensor([x.size(1)], dtype=torch.long)
grid_sizes = [[torch.zeros_like(grid_sizes), grid_sizes, grid_sizes]]
ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents))
ref_grid_sizes = [[
torch.tensor([30, 0, 0]).unsqueeze(0),
torch.tensor([31, rh, rw]).unsqueeze(0),
torch.tensor([1, rh, rw]).unsqueeze(0),
]]
original_seq_len = seq_lens[0]
seq_lens = seq_lens + torch.tensor([ref_latents.shape[1]], dtype=torch.long)
grid_sizes = grid_sizes + ref_grid_sizes
ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents)) # torch.Size([1, 1456, 5120])
grid_sizes = dit.get_grid_sizes((f, h, w), (rf, rh, rw))
x = torch.cat([x, ref_latents], dim=1)
mask = torch.zeros([1, x.shape[1]], dtype=torch.long, device=x.device)
mask[:, -ref_latents.shape[1]:] = 1
# mask
mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device)
# freqs
pre_compute_freqs = rope_precompute(x.detach().view(1, x.size(1), dit.num_heads, dit.dim // dit.num_heads), grid_sizes, dit.freqs, start=None)
# motion
x, pre_compute_freqs, mask = dit.inject_motion(x, pre_compute_freqs, mask, motion_latents, add_last_motion=2)
b, s, n, d = x.size(0), x.size(1), dit.num_heads, dit.dim // dit.num_heads
pre_compute_freqs = rope_precompute(x.detach().view(b, s, n, d), grid_sizes, torch.cat(dit.freqs, dim=1), start=None)
x, seq_lens, pre_compute_freqs, mask = dit.inject_motion(x, seq_lens, pre_compute_freqs, mask, motion_latents, add_last_motion=2)
x = x + dit.trainable_cond_mask(mask).to(x.dtype)
# t_mod
if dit.zero_timestep:
timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)])
e = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
e0 = dit.time_projection(e).unflatten(1, (6, dit.dim))
if dit.zero_timestep:
e = e[:-1]
zero_e0 = e0[-1:]
e0 = e0[:-1]
e0 = torch.cat([e0.unsqueeze(2), zero_e0.unsqueeze(2).repeat(e0.size(0), 1, 1, 1)], dim=2)
e0 = [e0, original_seq_len]
# context
context = dit.text_embedding(context)
# tmod
timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)])
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)).unsqueeze(2).transpose(0, 2)
def create_custom_forward(module):
def custom_forward(*inputs):
@@ -1332,31 +1315,32 @@ def model_fn_wans2v(
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, e0, pre_compute_freqs,
x, context, t_mod, seq_len_x, pre_compute_freqs,
use_reentrant=False,
)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, original_seq_len)),
create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
x,
use_reentrant=False,
)
elif use_gradient_checkpointing:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, e0, pre_compute_freqs,
x, context, t_mod, seq_len_x, pre_compute_freqs,
use_reentrant=False,
)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, original_seq_len)),
create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
x,
use_reentrant=False,
)
else:
x = block(x, context, e0, pre_compute_freqs)
x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, original_seq_len)
x = block(x, context, t_mod, seq_len_x, pre_compute_freqs)
x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)
x = x[:, :original_seq_len]
x = dit.head(x, e)
x = x[:, :seq_len_x]
x = dit.head(x, t[:-1])
x = dit.unpatchify(x, (f, h, w))
# make compatible with wan video
x = torch.cat([origin_ref_latents, x], dim=2)
return x