feat: enable Wan video USP (unified sequence parallel) for arbitrary sequence lengths

This commit is contained in:
handoku
2025-07-08 16:43:43 +08:00
parent 89397c755a
commit 00279a8375
2 changed files with 26 additions and 5 deletions

View File

@@ -594,7 +594,12 @@ def model_fn_wan_video(
# blocks
if use_unified_sequence_parallel:
if dist.is_initialized() and dist.get_world_size() > 1:
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
# Split the sequence dimension (dim=1) across the sequence-parallel ranks.
# The chunks may differ in length when the sequence is not evenly divisible,
# so each chunk's length is recorded for stripping padding after the gather.
chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
seq_lens = [chunk.shape[1] for chunk in chunks]
# Reuse the chunks computed above instead of calling torch.chunk a second time.
# NOTE(review): torch.chunk may return fewer chunks than requested for short
# sequences, in which case this index is out of range on the highest ranks.
x = chunks[get_sequence_parallel_rank()]
if tea_cache_update:
x = tea_cache.update(x)
else:
@@ -612,6 +617,14 @@ def model_fn_wan_video(
x = dit.head(x, t)
if use_unified_sequence_parallel:
if dist.is_initialized() and dist.get_world_size() > 1:
# all_gather needs an equally sized tensor from every rank, so a rank holding
# a shorter chunk pads its sequence dimension up to the longest chunk length.
max_len = max(seq_lens)
b, s, c = x.shape
if s != max_len:
    padding_tensor = torch.ones(b, max_len - s, c, dtype=x.dtype, device=x.device)
    x = torch.cat([x, padding_tensor], dim=1)
x = get_sp_group().all_gather(x, dim=1)
# Remove the padding. Assumes all_gather concatenates the per-rank tensors in
# rank order along dim=1 (TODO confirm), so rank r's real tokens occupy
# [max_len * r, max_len * r + seq_lens[r]) in the gathered tensor.
# The pieces are re-joined along the sequence dimension, not the batch one.
x = torch.cat(
    [
        x[:, max_len * rank:max_len * rank + seq_len, :]
        for rank, seq_len in enumerate(seq_lens)
    ],
    dim=1,
)
x = dit.unpatchify(x, (f, h, w))
return x