mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 18:28:10 +00:00
support arbitrary seq len
This commit is contained in:
@@ -1074,7 +1074,10 @@ def model_fn_wan_video(
|
||||
# blocks
|
||||
if use_unified_sequence_parallel:
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
||||
chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
|
||||
pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
|
||||
chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]
|
||||
x = chunks[get_sequence_parallel_rank()]
|
||||
if tea_cache_update:
|
||||
x = tea_cache.update(x)
|
||||
else:
|
||||
@@ -1111,6 +1114,7 @@ def model_fn_wan_video(
|
||||
if use_unified_sequence_parallel:
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
x = get_sp_group().all_gather(x, dim=1)
|
||||
x = x[:, :-pad_shape] if pad_shape > 0 else x
|
||||
# Remove reference latents
|
||||
if reference_latents is not None:
|
||||
x = x[:, reference_latents.shape[1]:]
|
||||
|
||||
Reference in New Issue
Block a user