wans2v refactor
Mirror of https://github.com/modelscope/DiffSynth-Studio.git
@@ -927,24 +927,23 @@ class WanVideoUnit_S2V(PipelineUnit):
 
     def process_motion_latents(self, pipe: WanVideoPipeline, height, width, tiled, tile_size, tile_stride):
         pipe.load_models_to_device(["vae"])
-        # TODO: may support input motion latents
+        # TODO: may support input motion latents, which related to `drop_motion_frames = False`
         motion_frames = 73
-        lat_motion_frames = (motion_frames + 3) // 4 # 19
         motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device)
+        lat_motion_frames = (motion_frames + 3) // 4
         motion_latents = pipe.vae.encode(motion_latents, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
-        return {"motion_latents": motion_latents, "motion_frames": [motion_frames, lat_motion_frames]}
+        return {"motion_latents": motion_latents}
 
     def process_pose_cond(self, pipe: WanVideoPipeline, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride):
-        pipe.load_models_to_device(["vae"])
         if s2v_pose_video is None:
-            input_video = -torch.ones(1, 3, num_frames, height, width, device=pipe.device, dtype=pipe.torch_dtype)
-        else:
-            input_video = pipe.preprocess_video(s2v_pose_video)
-            # get num_frames-1 frames
-            input_video = input_video[:, :, :num_frames]
-            # pad if not enough frames
-            padding_frames = num_frames - input_video.shape[2]
-            input_video = torch.cat([input_video, -torch.ones(1, 3, padding_frames, height, width, device=input_video.device, dtype=input_video.dtype)], dim=2)
+            return {"pose_cond": None}
+        pipe.load_models_to_device(["vae"])
+        input_video = pipe.preprocess_video(s2v_pose_video)
+        # get num_frames-1 frames
+        input_video = input_video[:, :, :num_frames]
+        # pad if not enough frames
+        padding_frames = num_frames - input_video.shape[2]
+        input_video = torch.cat([input_video, -torch.ones(1, 3, padding_frames, height, width, device=input_video.device, dtype=input_video.dtype)], dim=2)
+        # encode to latents
         input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
         return {"pose_cond": input_latents[:,:,1:]}
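
The `(motion_frames + 3) // 4` arithmetic above reflects the Wan VAE's 4x temporal compression, where a clip of 4k+1 pixel frames encodes to k+1 latent frames. A minimal standalone check:

    def latent_frame_count(num_frames: int) -> int:
        # Wan's causal video VAE maps 4k+1 pixel frames to k+1 latent frames,
        # i.e. (num_frames + 3) // 4 latent frames
        return (num_frames + 3) // 4

    assert latent_frame_count(73) == 19  # the motion_frames = 73 case above
    assert latent_frame_count(81) == 21  # a typical 81-frame Wan clip
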
@@ -1084,7 +1083,6 @@ def model_fn_wan_video(
     vace_scale = 1.0,
     audio_input: Optional[torch.Tensor] = None,
     motion_latents: Optional[torch.Tensor] = None,
-    motion_frames: Optional[list] = None,
     pose_cond: Optional[torch.Tensor] = None,
     tea_cache: TeaCache = None,
     use_unified_sequence_parallel: bool = False,
@@ -1132,10 +1130,10 @@ def model_fn_wan_video(
             context=context,
             audio_input=audio_input,
             motion_latents=motion_latents,
-            motion_frames=motion_frames,
             pose_cond=pose_cond,
             use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
             use_gradient_checkpointing=use_gradient_checkpointing,
+            use_unified_sequence_parallel=use_unified_sequence_parallel,
         )
 
     if use_unified_sequence_parallel:
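
The call site now forwards `use_unified_sequence_parallel` into the S2V model function. For orientation, a rough sketch of the sequence-parallel pattern this flag enables, using the xfuser helpers imported in the next hunk; the `shard_sequence`/`gather_sequence` helper names are hypothetical, not from the repo:

    import torch
    from xfuser.core.distributed import (get_sequence_parallel_rank,
                                         get_sequence_parallel_world_size,
                                         get_sp_group)

    def shard_sequence(x: torch.Tensor) -> torch.Tensor:
        # hypothetical helper: each rank keeps only its slice of the token dimension
        rank, world = get_sequence_parallel_rank(), get_sequence_parallel_world_size()
        return torch.chunk(x, world, dim=1)[rank]

    def gather_sequence(x: torch.Tensor) -> torch.Tensor:
        # hypothetical helper: reassemble the full sequence after the blocks run
        return get_sp_group().all_gather(x, dim=1)
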
@@ -1265,62 +1263,47 @@ def model_fn_wans2v(
     context,
     audio_input,
     motion_latents,
-    motion_frames,
     pose_cond,
     use_gradient_checkpointing_offload=False,
-    use_gradient_checkpointing=False
+    use_gradient_checkpointing=False,
+    use_unified_sequence_parallel=False,
 ):
+    if use_unified_sequence_parallel:
+        import torch.distributed as dist
+        from xfuser.core.distributed import (get_sequence_parallel_rank,
+                                             get_sequence_parallel_world_size,
+                                             get_sp_group)
     origin_ref_latents = latents[:, :, 0:1]
-    latents = latents[:, :, 1:]
+    x = latents[:, :, 1:]
 
-    audio_input = torch.cat([audio_input[..., 0:1].repeat(1, 1, 1, motion_frames[0]), audio_input], dim=-1)
-    audio_emb_global, audio_emb = dit.casual_audio_encoder(audio_input)
-    audio_emb_global = audio_emb_global[:, motion_frames[1]:].clone()
-    merged_audio_emb = audio_emb[:, motion_frames[1]:, :]
+    # context embedding
+    context = dit.text_embedding(context)
+
+    # audio encode
+    audio_emb_global, merged_audio_emb = dit.cal_audio_emb(audio_input)
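
The four removed lines are folded into the new `dit.cal_audio_emb` helper. A sketch of what it plausibly wraps, reconstructed from the removed code; the `self.motion_frames` and `self.lat_motion_frames` attribute names are assumptions, since the model now owns the motion-frame constants instead of receiving them as an argument:

    import torch

    def cal_audio_emb(self, audio_input):
        # left-pad by repeating the first audio frame across the motion window,
        # giving the causal encoder a warm-up prefix (assumed attribute names)
        audio_input = torch.cat(
            [audio_input[..., 0:1].repeat(1, 1, 1, self.motion_frames), audio_input], dim=-1)
        audio_emb_global, audio_emb = self.casual_audio_encoder(audio_input)
        # drop the motion prefix so embeddings align with the generated frames
        audio_emb_global = audio_emb_global[:, self.lat_motion_frames:].clone()
        merged_audio_emb = audio_emb[:, self.lat_motion_frames:, :]
        return audio_emb_global, merged_audio_emb
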
-    x = latents
-    pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond
-    x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(pose_cond))
-
-    grid_sizes = torch.tensor([f, h, w], dtype=torch.long).unsqueeze(0)
-    seq_lens = torch.tensor([x.size(1)], dtype=torch.long)
-    grid_sizes = [[torch.zeros_like(grid_sizes), grid_sizes, grid_sizes]]
-
-    ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents))
-
-    ref_grid_sizes = [[
-        torch.tensor([30, 0, 0]).unsqueeze(0),
-        torch.tensor([31, rh, rw]).unsqueeze(0),
-        torch.tensor([1, rh, rw]).unsqueeze(0),
-    ]]
-    original_seq_len = seq_lens[0]
-    seq_lens = seq_lens + torch.tensor([ref_latents.shape[1]], dtype=torch.long)
-    grid_sizes = grid_sizes + ref_grid_sizes
+    # x and pose_cond
+    pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond
+    x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(pose_cond)) # torch.Size([1, 29120, 5120])
+    seq_len_x = x.shape[1]
+
+    # reference image
+    ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents)) # torch.Size([1, 1456, 5120])
+    grid_sizes = dit.get_grid_sizes((f, h, w), (rf, rh, rw))
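
The inline grid-size construction removed above is replaced by `dit.get_grid_sizes`. A sketch of what it presumably computes, under the assumption that it reproduces the removed lists of [start, end, size] RoPE ranges (the 30/31 constants, which place the reference image at a distant temporal position, come from the old code):

    import torch

    def get_grid_sizes(self, x_fhw, ref_fhw):
        f, h, w = x_fhw
        _, rh, rw = ref_fhw
        grid = torch.tensor([f, h, w], dtype=torch.long).unsqueeze(0)
        x_ranges = [torch.zeros_like(grid), grid, grid]    # video tokens: [0,0,0] .. [f,h,w]
        ref_ranges = [
            torch.tensor([30, 0, 0]).unsqueeze(0),         # start: frame index 30
            torch.tensor([31, rh, rw]).unsqueeze(0),       # end: frame index 31
            torch.tensor([1, rh, rw]).unsqueeze(0),        # size: one frame of rh x rw patches
        ]
        return [x_ranges, ref_ranges]
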
     x = torch.cat([x, ref_latents], dim=1)
-    mask = torch.zeros([1, x.shape[1]], dtype=torch.long, device=x.device)
-    mask[:, -ref_latents.shape[1]:] = 1
-
-    b, s, n, d = x.size(0), x.size(1), dit.num_heads, dit.dim // dit.num_heads
-    pre_compute_freqs = rope_precompute(x.detach().view(b, s, n, d), grid_sizes, torch.cat(dit.freqs, dim=1), start=None)
-
-    x, seq_lens, pre_compute_freqs, mask = dit.inject_motion(x, seq_lens, pre_compute_freqs, mask, motion_latents, add_last_motion=2)
+    # mask
+    mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device)
+    # freqs
+    pre_compute_freqs = rope_precompute(x.detach().view(1, x.size(1), dit.num_heads, dit.dim // dit.num_heads), grid_sizes, dit.freqs, start=None)
+    # motion
+    x, pre_compute_freqs, mask = dit.inject_motion(x, pre_compute_freqs, mask, motion_latents, add_last_motion=2)
     x = x + dit.trainable_cond_mask(mask).to(x.dtype)
 
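
The mask rewrite is behavior-preserving: zeros with the trailing block set to one equals a concatenation of zeros and ones. A quick self-contained check with illustrative sizes:

    import torch

    seq_len_x, ref_len = 4, 2
    old = torch.zeros([1, seq_len_x + ref_len], dtype=torch.long)
    old[:, -ref_len:] = 1
    new = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_len])], dim=1).to(torch.long)
    assert torch.equal(old, new)  # 1 marks the reference tokens appended after the video tokens
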
-    # t_mod
-    if dit.zero_timestep:
-        timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)])
-    e = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
-    e0 = dit.time_projection(e).unflatten(1, (6, dit.dim))
-    if dit.zero_timestep:
-        e = e[:-1]
-        zero_e0 = e0[-1:]
-        e0 = e0[:-1]
-        e0 = torch.cat([e0.unsqueeze(2), zero_e0.unsqueeze(2).repeat(e0.size(0), 1, 1, 1)], dim=2)
-    e0 = [e0, original_seq_len]
-    # context
-    context = dit.text_embedding(context)
+    # tmod
+    timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)])
+    t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
+    t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)).unsqueeze(2).transpose(0, 2)
 
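
The refactor drops the conditional `zero_timestep` branch and always appends a zero timestep, so the modulation tensor carries a second (t=0) entry, apparently for the non-video tokens; `dit.head(x, t[:-1])` in the next hunk drops it again. A runnable shape sketch with stand-in layers (sizes illustrative; the shape comments above suggest dim=5120 in the real model):

    import torch

    dim, freq_dim = 64, 32  # illustrative sizes

    def sinusoidal_embedding_1d(d, t):
        # minimal stand-in for the repo's helper of the same name
        freqs = torch.pow(10000, -torch.arange(d // 2, dtype=torch.float32) / (d // 2))
        x = t.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat([torch.cos(x), torch.sin(x)], dim=1)

    time_embedding = torch.nn.Linear(freq_dim, dim)   # stand-in for dit.time_embedding
    time_projection = torch.nn.Linear(dim, 6 * dim)   # stand-in for dit.time_projection

    timestep = torch.tensor([981.0])                        # one denoising step
    timestep = torch.cat([timestep, torch.zeros([1])])      # appended t=0 entry
    t = time_embedding(sinusoidal_embedding_1d(freq_dim, timestep))  # [2, dim]
    t_mod = time_projection(t).unflatten(1, (6, dim)).unsqueeze(2).transpose(0, 2)
    print(t_mod.shape)  # torch.Size([1, 6, 2, 64])
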
     def create_custom_forward(module):
         def custom_forward(*inputs):
@@ -1332,31 +1315,32 @@ def model_fn_wans2v(
             with torch.autograd.graph.save_on_cpu():
                 x = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(block),
-                    x, context, e0, pre_compute_freqs,
+                    x, context, t_mod, seq_len_x, pre_compute_freqs,
                     use_reentrant=False,
                 )
                 x = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, original_seq_len)),
+                    create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
                     x,
                     use_reentrant=False,
                 )
         elif use_gradient_checkpointing:
             x = torch.utils.checkpoint.checkpoint(
                 create_custom_forward(block),
-                x, context, e0, pre_compute_freqs,
+                x, context, t_mod, seq_len_x, pre_compute_freqs,
                 use_reentrant=False,
             )
             x = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, original_seq_len)),
+                create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
                 x,
                 use_reentrant=False,
             )
         else:
-            x = block(x, context, e0, pre_compute_freqs)
-            x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, original_seq_len)
+            x = block(x, context, t_mod, seq_len_x, pre_compute_freqs)
+            x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)
 
-    x = x[:, :original_seq_len]
-    x = dit.head(x, e)
+    x = x[:, :seq_len_x]
+    x = dit.head(x, t[:-1])
     x = dit.unpatchify(x, (f, h, w))
+    # make compatible with wan video
     x = torch.cat([origin_ref_latents, x], dim=2)
     return x
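
The `create_custom_forward` definition shown truncated in the context above is the standard PyTorch shim for non-reentrant activation checkpointing. A self-contained sketch of the full pattern used in this hunk (the linear block is an illustrative stand-in for a DiT block):

    import torch
    import torch.utils.checkpoint

    def create_custom_forward(module):
        # checkpoint() wants a plain callable over positional tensors
        def custom_forward(*inputs):
            return module(*inputs)
        return custom_forward

    block = torch.nn.Linear(16, 16)  # stand-in for a transformer block
    x = torch.randn(2, 16, requires_grad=True)

    # recompute activations during backward instead of storing them
    y = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)

    # the offload variant additionally keeps saved tensors in host memory
    with torch.autograd.graph.save_on_cpu():
        y = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
    y.sum().backward()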