diff --git a/diffsynth/distributed/xdit_context_parallel.py b/diffsynth/distributed/xdit_context_parallel.py index dc4cc62..1c44ffc 100644 --- a/diffsynth/distributed/xdit_context_parallel.py +++ b/diffsynth/distributed/xdit_context_parallel.py @@ -26,12 +26,15 @@ def pad_freqs(original_tensor, target_len): def rope_apply(x, freqs, num_heads): x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) + s_per_rank = x.shape[1] x_out = torch.view_as_complex(x.to(torch.float64).reshape( x.shape[0], x.shape[1], x.shape[2], -1, 2)) + sp_size = get_sequence_parallel_world_size() sp_rank = get_sequence_parallel_rank() - freqs_rank = torch.chunk(freqs, dim=0)[sp_rank] # chunk freqs like x + freqs = pad_freqs(freqs, s_per_rank * sp_size) + freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :] x_out = torch.view_as_real(x_out * freqs_rank).flatten(2) return x_out.to(x.dtype) diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index b9e020a..1cd716b 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -595,10 +595,9 @@ def model_fn_wan_video( if use_unified_sequence_parallel: if dist.is_initialized() and dist.get_world_size() > 1: chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1) - seq_lens = [chunk.shape[1] for chunk in chunks] - x = torch.chunk( - x, get_sequence_parallel_world_size(), - dim=1)[get_sequence_parallel_rank()] + pad_shape = chunks[0].shape[1] - chunks[-1].shape[1] + chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks] + x = chunks[get_sequence_parallel_rank()] if tea_cache_update: x = tea_cache.update(x) @@ -606,25 +605,21 @@ def model_fn_wan_video( for block_id, block in enumerate(dit.blocks): x = block(x, context, t_mod, freqs) if vace_context is not None and block_id in vace.vace_layers_mapping: - x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale + current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]] + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + x = x + current_vace_hint * vace_scale if tea_cache is not None: tea_cache.store(x) - if reference_latents is not None: - x = x[:, reference_latents.shape[1]:] - f -= 1 - x = dit.head(x, t) if use_unified_sequence_parallel: if dist.is_initialized() and dist.get_world_size() > 1: - max_len = seq_lens[0] - b, s, c = x.shape - if s != max_len: - padding_tensor = torch.ones(b, max_len - s, c, dtype=x.dtype, device=x.device) - x = torch.cat([x, padding_tensor], dim=1) - x = get_sp_group().all_gather(x, dim=1) - # remove pad - x = torch.cat([x[:,max_len*id:seq_lens[id],:] for id in range(seq_lens)]) + x = x[:, :-pad_shape] if pad_shape > 0 else x + # Remove reference latents + if reference_latents is not None: + x = x[:, reference_latents.shape[1]:] + f -= 1 x = dit.unpatchify(x, (f, h, w)) return x diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index e429456..167acbf 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -1074,7 +1074,10 @@ def model_fn_wan_video( # blocks if use_unified_sequence_parallel: if dist.is_initialized() and dist.get_world_size() > 1: - x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1) + pad_shape = chunks[0].shape[1] - chunks[-1].shape[1] + chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks] + x = chunks[get_sequence_parallel_rank()] if tea_cache_update: x = tea_cache.update(x) else: @@ -1111,6 +1114,7 @@ def model_fn_wan_video( if use_unified_sequence_parallel: if dist.is_initialized() and dist.get_world_size() > 1: x = get_sp_group().all_gather(x, dim=1) + x = x[:, :-pad_shape] if pad_shape > 0 else x # Remove reference latents if reference_latents is not None: x = x[:, reference_latents.shape[1]:]