mirror of https://github.com/modelscope/DiffSynth-Studio.git
fix usp dit_forward
@@ -74,11 +74,9 @@ def usp_dit_forward(self,
     # Context Parallel
     chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
-    seq_lens = [chunk.shape[1] for chunk in chunks]
-    x = torch.chunk(
-        x, get_sequence_parallel_world_size(),
-        dim=1)[get_sequence_parallel_rank()]
+    pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
+    chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]
+    x = chunks[get_sequence_parallel_rank()]

     for block in self.blocks:
         if self.training and use_gradient_checkpointing:
@@ -101,15 +99,8 @@ def usp_dit_forward(self,
     x = self.head(x, t)

     # Context Parallel
-    max_len = seq_lens[0]
-    b, s, c = x.shape
-    if s != max_len:
-        padding_tensor = torch.ones(b, max_len - s, c, dtype=x.dtype, device=x.device)
-        x = torch.cat([x, padding_tensor], dim=1)
-
     x = get_sp_group().all_gather(x, dim=1)
     # remove pad
-    x = torch.cat([x[:,max_len*id:seq_lens[id],:] for id in range(seq_lens)])
+    x = x[:, :-pad_shape] if pad_shape > 0 else x

     # unpatchify
     x = self.unpatchify(x, (f, h, w))
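Reading the two hunks together: the old path recorded per-chunk lengths (seq_lens), padded only the local shard back up to max_len before the gather, and then tried to strip the padding with torch.cat([x[:,max_len*id:seq_lens[id],:] for id in range(seq_lens)]), which iterates range() over a list and mis-slices the gathered tensor; that appears to be the bug this commit fixes. The new path pads every chunk up to the first chunk's length before sharding, so each rank holds an identically shaped shard for all_gather, and since torch.chunk only ever makes the last chunk shorter, the padding ends up as a single trailing block that x[:, :-pad_shape] removes. Below is a minimal single-process sketch of that round trip under stated assumptions: shapes are (batch, seq_len, channels), torch.cat stands in for the distributed all_gather(dim=1), and shard_with_pad / gather_and_trim are illustrative names, not DiffSynth-Studio functions.

# Minimal sketch of the pad-to-equal-chunks / gather / trim round trip.
# Assumptions: one chunk per rank; torch.cat emulates all_gather along dim=1.
import torch

def shard_with_pad(x, world_size):
    # torch.chunk makes all chunks equal length except (possibly) the last.
    chunks = torch.chunk(x, world_size, dim=1)
    # Number of padded positions the gathered sequence will carry at its tail.
    pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
    # Right-pad along the sequence dim (the pad tuple covers dims -1 then -2),
    # so every rank's shard has an identical shape, as all_gather requires.
    chunks = [torch.nn.functional.pad(c, (0, 0, 0, chunks[0].shape[1] - c.shape[1]), value=0)
              for c in chunks]
    return chunks, pad_shape

def gather_and_trim(chunks, pad_shape):
    x = torch.cat(chunks, dim=1)  # emulates get_sp_group().all_gather(x, dim=1)
    # Only the last chunk was actually padded, so the pad is one trailing block.
    return x[:, :-pad_shape] if pad_shape > 0 else x

x = torch.randn(2, 10, 4)                            # seq_len 10, 4 "ranks"
chunks, pad_shape = shard_with_pad(x, world_size=4)  # chunk lengths 3,3,3,1 -> pad_shape 2
assert all(c.shape[1] == 3 for c in chunks)          # uniform shards for all_gather
restored = gather_and_trim(chunks, pad_shape)
assert torch.equal(restored, x)                      # round trip recovers the input

Two side observations, hedged since they go beyond the diff itself: padding before sharding also means every rank runs the transformer blocks on an identically shaped shard, which sequence-parallel attention generally assumes; and torch.chunk can return fewer than world_size chunks for very short sequences (e.g. seq_len 5 over 4 ranks yields 3 chunks), in which case chunks[get_sequence_parallel_rank()] would index out of range — the sketch assumes the practical case where enough chunks exist.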