From e4178e2501609e6b5eecba4c7b4385a939016e38 Mon Sep 17 00:00:00 2001
From: mi804 <1576993271@qq.com>
Date: Wed, 30 Jul 2025 19:21:21 +0800
Subject: [PATCH] fix usp dit_forward

---
 diffsynth/distributed/xdit_context_parallel.py | 17 ++++-------------
 1 file changed, 4 insertions(+), 13 deletions(-)

diff --git a/diffsynth/distributed/xdit_context_parallel.py b/diffsynth/distributed/xdit_context_parallel.py
index 1c44ffc..4887e2f 100644
--- a/diffsynth/distributed/xdit_context_parallel.py
+++ b/diffsynth/distributed/xdit_context_parallel.py
@@ -74,11 +74,9 @@ def usp_dit_forward(self,
 
     # Context Parallel
     chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
-    seq_lens = [chunk.shape[1] for chunk in chunks]
-
-    x = torch.chunk(
-        x, get_sequence_parallel_world_size(),
-        dim=1)[get_sequence_parallel_rank()]
+    pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
+    chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]
+    x = chunks[get_sequence_parallel_rank()]
 
     for block in self.blocks:
         if self.training and use_gradient_checkpointing:
@@ -101,15 +99,8 @@ def usp_dit_forward(self,
     x = self.head(x, t)
 
     # Context Parallel
-    max_len = seq_lens[0]
-    b, s, c = x.shape
-    if s != max_len:
-        padding_tensor = torch.ones(b, max_len - s, c, dtype=x.dtype, device=x.device)
-        x = torch.cat([x, padding_tensor], dim=1)
-
     x = get_sp_group().all_gather(x, dim=1)
-    # remove pad
-    x = torch.cat([x[:,max_len*id:seq_lens[id],:] for id in range(seq_lens)])
+    x = x[:, :-pad_shape] if pad_shape > 0 else x
 
     # unpatchify
     x = self.unpatchify(x, (f, h, w))
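
Reviewer note (not part of the patch): below is a minimal, single-process sketch of the
split/pad/gather/trim pattern the new code relies on. In the real forward, x is sharded across
sequence-parallel ranks via get_sequence_parallel_world_size()/get_sequence_parallel_rank() and
recombined with get_sp_group().all_gather(x, dim=1); here the gather is emulated with torch.cat so
the shape round-trip can be checked locally. The helper name pad_chunk_trim is hypothetical.

    import torch
    import torch.nn.functional as F

    def pad_chunk_trim(x, world_size):
        # torch.chunk splits dim=1 into equal-sized chunks; only the last one may be shorter.
        chunks = torch.chunk(x, world_size, dim=1)
        pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
        # Right-pad every chunk along dim=1 to the first chunk's length so all ranks hold equal shapes.
        chunks = [F.pad(c, (0, 0, 0, chunks[0].shape[1] - c.shape[1]), value=0) for c in chunks]
        # Stand-in for get_sp_group().all_gather(x, dim=1): concatenate the per-rank chunks.
        gathered = torch.cat(chunks, dim=1)
        # Drop the zero padding that was appended to the last chunk.
        return gathered[:, :-pad_shape] if pad_shape > 0 else gathered

    x = torch.randn(2, 10, 16)  # (batch, sequence, channels); 10 is not divisible by 4
    assert torch.equal(pad_chunk_trim(x, world_size=4), x)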