support arbitrary seq len

Author: mi804
Date: 2025-07-30 19:07:16 +08:00
Parent: 8c558b3526
Commit: 0b860abf1b
3 changed files with 21 additions and 19 deletions


@@ -595,10 +595,9 @@ def model_fn_wan_video(
     if use_unified_sequence_parallel:
         if dist.is_initialized() and dist.get_world_size() > 1:
             chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
-            seq_lens = [chunk.shape[1] for chunk in chunks]
-            x = torch.chunk(
-                x, get_sequence_parallel_world_size(),
-                dim=1)[get_sequence_parallel_rank()]
+            pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
+            chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]
+            x = chunks[get_sequence_parallel_rank()]
     if tea_cache_update:
         x = tea_cache.update(x)
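When the token count is not divisible by the sequence-parallel world size, `torch.chunk` returns equal-sized chunks except for a shorter final one; padding every chunk up to the first chunk's length gives each rank a same-shaped shard, and `pad_shape` records how many filler tokens were appended so they can be dropped after the final gather. A minimal single-process sketch of the shapes involved (the world size and tensor sizes are made up, and a plain integer stands in for `get_sequence_parallel_world_size()`):

```python
import torch
import torch.nn.functional as F

world_size = 4                      # stand-in for get_sequence_parallel_world_size()
x = torch.randn(1, 30, 16)          # (batch, seq_len, dim); 30 is not divisible by 4

chunks = torch.chunk(x, world_size, dim=1)
print([c.shape[1] for c in chunks])                    # [8, 8, 8, 6]
pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]   # 2 filler tokens for the last rank

# Only the last chunk actually receives padding; the others are already full length.
chunks = [F.pad(c, (0, 0, 0, chunks[0].shape[1] - c.shape[1]), value=0) for c in chunks]
print([c.shape[1] for c in chunks])                    # [8, 8, 8, 8]
```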
@@ -606,25 +605,21 @@ def model_fn_wan_video(
         for block_id, block in enumerate(dit.blocks):
             x = block(x, context, t_mod, freqs)
             if vace_context is not None and block_id in vace.vace_layers_mapping:
-                x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale
+                current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
+                if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
+                    current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
+                x = x + current_vace_hint * vace_scale
         if tea_cache is not None:
             tea_cache.store(x)
-    if reference_latents is not None:
-        x = x[:, reference_latents.shape[1]:]
-        f -= 1
     x = dit.head(x, t)
     if use_unified_sequence_parallel:
         if dist.is_initialized() and dist.get_world_size() > 1:
-            max_len = seq_lens[0]
-            b, s, c = x.shape
-            if s != max_len:
-                padding_tensor = torch.ones(b, max_len - s, c, dtype=x.dtype, device=x.device)
-                x = torch.cat([x, padding_tensor], dim=1)
             x = get_sp_group().all_gather(x, dim=1)
             # remove pad
-            x = torch.cat([x[:,max_len*id:seq_lens[id],:] for id in range(seq_lens)])
+            x = x[:, :-pad_shape] if pad_shape > 0 else x
+    # Remove reference latents
     if reference_latents is not None:
         x = x[:, reference_latents.shape[1]:]
         f -= 1
     x = dit.unpatchify(x, (f, h, w))
     return x
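After `dit.head` runs on each rank's shard, the all-gather along dim=1 reassembles the padded sequence, and the trailing `pad_shape` filler tokens (contributed only by the last rank) are sliced off. The `if pad_shape > 0` guard matters because `x[:, :-0]` is an empty slice, not a no-op. A minimal single-process check of that round trip, with made-up sizes and `torch.cat` standing in for `get_sp_group().all_gather`:

```python
import torch
import torch.nn.functional as F

world_size = 4
x = torch.randn(1, 30, 16)                             # (batch, seq_len, dim)

chunks = torch.chunk(x, world_size, dim=1)
pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
chunks = [F.pad(c, (0, 0, 0, chunks[0].shape[1] - c.shape[1])) for c in chunks]

# Emulate get_sp_group().all_gather(..., dim=1) with a plain concatenation.
gathered = torch.cat(chunks, dim=1)                    # (1, 32, 16): 30 real + 2 pad tokens
gathered = gathered[:, :-pad_shape] if pad_shape > 0 else gathered
assert torch.equal(gathered, x)                        # filler tokens only ever sit at the end
```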


@@ -1074,7 +1074,10 @@ def model_fn_wan_video(
     # blocks
     if use_unified_sequence_parallel:
         if dist.is_initialized() and dist.get_world_size() > 1:
-            x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
+            chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
+            pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
+            chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]
+            x = chunks[get_sequence_parallel_rank()]
     if tea_cache_update:
         x = tea_cache.update(x)
     else:
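In the `torch.nn.functional.pad` call, the pad tuple is read from the last dimension backwards: `(0, 0)` leaves the channel dimension untouched, and `(0, k)` appends `k` zero tokens on the right of the sequence dimension. Since only the last chunk is ever short, the filler ends up at the very end of the gathered sequence. A small shape check with arbitrary sizes:

```python
import torch
import torch.nn.functional as F

chunk = torch.randn(2, 6, 16)                 # (batch, seq_len, dim) shard that came up short
padded = F.pad(chunk, (0, 0, 0, 2), value=0)  # last dim: (0, 0); seq dim: (0, 2) on the right
print(padded.shape)                           # torch.Size([2, 8, 16])
assert torch.equal(padded[:, :6], chunk)      # original tokens are untouched
assert torch.all(padded[:, 6:] == 0)          # the two appended tokens are zeros
```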
@@ -1111,6 +1114,7 @@ def model_fn_wan_video(
     if use_unified_sequence_parallel:
         if dist.is_initialized() and dist.get_world_size() > 1:
             x = get_sp_group().all_gather(x, dim=1)
+            x = x[:, :-pad_shape] if pad_shape > 0 else x
     # Remove reference latents
     if reference_latents is not None:
         x = x[:, reference_latents.shape[1]:]
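Taken together, the pattern in both files is: chunk along the sequence dimension, pad the last shard, run the blocks per rank, gather, and trim. The sketch below simulates all ranks in one process to check only this bookkeeping; the transformer blocks are replaced by a pointwise stand-in, since the real blocks exchange information across ranks through the sequence-parallel attention, which this sketch does not model.

```python
import torch
import torch.nn.functional as F

def sp_roundtrip(x, world_size, fn):
    """Split x along dim=1, pad shards to equal length, apply fn per 'rank',
    then gather and strip the padding -- mirroring the pattern in the diff."""
    chunks = torch.chunk(x, world_size, dim=1)
    pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
    chunks = [F.pad(c, (0, 0, 0, chunks[0].shape[1] - c.shape[1])) for c in chunks]
    outputs = [fn(c) for c in chunks]               # each "rank" works on a same-shaped shard
    out = torch.cat(outputs, dim=1)                 # stands in for all_gather(dim=1)
    return out[:, :-pad_shape] if pad_shape > 0 else out

x = torch.randn(1, 30, 16)
fn = lambda t: t * 2.0 + 1.0                        # pointwise stand-in for dit.blocks
assert torch.equal(sp_roundtrip(x, 4, fn), fn(x))   # sharded result matches the unsharded one
```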