wans2v usp

This commit is contained in:
mi804
2025-08-27 19:50:33 +08:00
parent 4147473c81
commit fdeb363fa2
3 changed files with 42 additions and 20 deletions

View File

@@ -1284,11 +1284,11 @@ def model_fn_wans2v(
# x and pose_cond
pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond
x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(pose_cond)) # torch.Size([1, 29120, 5120])
seq_len_x = x.shape[1]
x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(pose_cond))
seq_len_x = seq_len_x_global = x.shape[1] # global used for unified sequence parallel
# reference image
ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents)) # torch.Size([1, 1456, 5120])
ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents))
grid_sizes = dit.get_grid_sizes((f, h, w), (rf, rh, rw))
x = torch.cat([x, ref_latents], dim=1)
# mask
@@ -1305,6 +1305,14 @@ def model_fn_wans2v(
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)).unsqueeze(2).transpose(0, 2)
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
world_size, sp_rank = get_sequence_parallel_world_size(), get_sequence_parallel_rank()
assert x.shape[1] % world_size == 0, f"the dimension after chunk must be divisible by world size, but got {x.shape[1]} and {get_sequence_parallel_world_size()}"
x = torch.chunk(x, world_size, dim=1)[sp_rank]
seg_idxs = [0] + list(torch.cumsum(torch.tensor([x.shape[1]] * world_size), dim=0).cpu().numpy())
seq_len_x_list = [min(max(0, seq_len_x - seg_idxs[i]), x.shape[1]) for i in range(len(seg_idxs)-1)]
seq_len_x = seq_len_x_list[sp_rank]
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
@@ -1315,7 +1323,7 @@ def model_fn_wans2v(
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, seq_len_x, pre_compute_freqs,
x, context, t_mod, seq_len_x, pre_compute_freqs[0],
use_reentrant=False,
)
x = torch.utils.checkpoint.checkpoint(
@@ -1326,7 +1334,7 @@ def model_fn_wans2v(
elif use_gradient_checkpointing:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, seq_len_x, pre_compute_freqs,
x, context, t_mod, seq_len_x, pre_compute_freqs[0],
use_reentrant=False,
)
x = torch.utils.checkpoint.checkpoint(
@@ -1335,10 +1343,13 @@ def model_fn_wans2v(
use_reentrant=False,
)
else:
x = block(x, context, t_mod, seq_len_x, pre_compute_freqs)
x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)
x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0])
x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel)
x = x[:, :seq_len_x]
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
x = get_sp_group().all_gather(x, dim=1)
x = x[:, :seq_len_x_global]
x = dit.head(x, t[:-1])
x = dit.unpatchify(x, (f, h, w))
# make compatible with wan video