From 00279a83754474a9ce8f782f777a8d1608291fdc Mon Sep 17 00:00:00 2001
From: handoku
Date: Tue, 8 Jul 2025 16:43:43 +0800
Subject: [PATCH] fea : enable wan video usp for arbitrary seq len

---
Review notes (text in this section is not part of the commit message):

  * rope_apply: torch.chunk() was called without its required `chunks`
    argument; it now receives get_sequence_parallel_world_size() so that
    `freqs` is split exactly like `x`.
  * Padding removal after all_gather (both files): the original iterated
    `range(seq_lens)` over a list, sliced with an absolute end index
    (`seq_lens[id]`) instead of `max_len * i + seq_lens[i]`, concatenated
    along dim 0 instead of the sequence dim 1, and shadowed the builtin
    `id`.
  * Every fix replaces an added line in place, so hunk sizes and the
    diffstat below are unchanged. The post-image blob ids on the `index`
    lines predate these fixes.
  * Known limitation, not addressed here: torch.chunk() may return fewer
    chunks than requested for very short sequences, in which case the rank
    index is out of range on the trailing ranks.
  * Leading whitespace of context lines was reconstructed; verify against
    the target tree before applying.

 diffsynth/distributed/xdit_context_parallel.py | 16 ++++++++++++----
 diffsynth/pipelines/wan_video.py               | 15 ++++++++++++++-
 2 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/diffsynth/distributed/xdit_context_parallel.py b/diffsynth/distributed/xdit_context_parallel.py
index 2c1a257..dc4cc62 100644
--- a/diffsynth/distributed/xdit_context_parallel.py
+++ b/diffsynth/distributed/xdit_context_parallel.py
@@ -26,15 +26,12 @@ def pad_freqs(original_tensor, target_len):
 
 def rope_apply(x, freqs, num_heads):
     x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
-    s_per_rank = x.shape[1]
 
     x_out = torch.view_as_complex(x.to(torch.float64).reshape(
         x.shape[0], x.shape[1], x.shape[2], -1, 2))
 
-    sp_size = get_sequence_parallel_world_size()
     sp_rank = get_sequence_parallel_rank()
-    freqs = pad_freqs(freqs, s_per_rank * sp_size)
-    freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
+    freqs_rank = torch.chunk(freqs, get_sequence_parallel_world_size(), dim=0)[sp_rank]  # chunk freqs like x
 
     x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
     return x_out.to(x.dtype)
@@ -73,6 +70,9 @@ def usp_dit_forward(self,
         return custom_forward
 
     # Context Parallel
+    chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
+    seq_lens = [chunk.shape[1] for chunk in chunks]
+
     x = torch.chunk(
         x, get_sequence_parallel_world_size(),
         dim=1)[get_sequence_parallel_rank()]
@@ -98,7 +98,15 @@ def usp_dit_forward(self,
     x = self.head(x, t)
 
     # Context Parallel
+    max_len = seq_lens[0]
+    b, s, c = x.shape
+    if s != max_len:
+        padding_tensor = torch.ones(b, max_len - s, c, dtype=x.dtype, device=x.device)
+        x = torch.cat([x, padding_tensor], dim=1)
+
     x = get_sp_group().all_gather(x, dim=1)
+    # remove pad: rank i owns [max_len * i, max_len * i + seq_lens[i]) of the gathered sequence
+    x = torch.cat([x[:, max_len * i:max_len * i + n, :] for i, n in enumerate(seq_lens)], dim=1)
 
     # unpatchify
     x = self.unpatchify(x, (f, h, w))
diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py
index 8ac1e4e..b9e020a 100644
--- a/diffsynth/pipelines/wan_video.py
+++ b/diffsynth/pipelines/wan_video.py
@@ -594,7 +594,12 @@ def model_fn_wan_video(
     # blocks
     if use_unified_sequence_parallel:
         if dist.is_initialized() and dist.get_world_size() > 1:
-            x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
+            chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
+            seq_lens = [chunk.shape[1] for chunk in chunks]
+            x = torch.chunk(
+                x, get_sequence_parallel_world_size(),
+                dim=1)[get_sequence_parallel_rank()]
+
     if tea_cache_update:
         x = tea_cache.update(x)
     else:
@@ -612,6 +617,14 @@ def model_fn_wan_video(
     x = dit.head(x, t)
     if use_unified_sequence_parallel:
         if dist.is_initialized() and dist.get_world_size() > 1:
+            max_len = seq_lens[0]
+            b, s, c = x.shape
+            if s != max_len:
+                padding_tensor = torch.ones(b, max_len - s, c, dtype=x.dtype, device=x.device)
+                x = torch.cat([x, padding_tensor], dim=1)
+
             x = get_sp_group().all_gather(x, dim=1)
+            # remove pad: rank i owns [max_len * i, max_len * i + seq_lens[i]) of the gathered sequence
+            x = torch.cat([x[:, max_len * i:max_len * i + n, :] for i, n in enumerate(seq_lens)], dim=1)
     x = dit.unpatchify(x, (f, h, w))
     return x