support arbitrary seq len

This commit is contained in:
mi804
2025-07-30 19:07:16 +08:00
parent 8c558b3526
commit 0b860abf1b
3 changed files with 21 additions and 19 deletions

View File

@@ -1074,7 +1074,10 @@ def model_fn_wan_video(
# blocks
if use_unified_sequence_parallel:
if dist.is_initialized() and dist.get_world_size() > 1:
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]
x = chunks[get_sequence_parallel_rank()]
if tea_cache_update:
x = tea_cache.update(x)
else:
@@ -1111,6 +1114,7 @@ def model_fn_wan_video(
if use_unified_sequence_parallel:
if dist.is_initialized() and dist.get_world_size() > 1:
x = get_sp_group().all_gather(x, dim=1)
x = x[:, :-pad_shape] if pad_shape > 0 else x
# Remove reference latents
if reference_latents is not None:
x = x[:, reference_latents.shape[1]:]