From 00279a83754474a9ce8f782f777a8d1608291fdc Mon Sep 17 00:00:00 2001 From: handoku Date: Tue, 8 Jul 2025 16:43:43 +0800 Subject: [PATCH 1/5] fea : enable wan video usp for arbitrary seq len --- diffsynth/distributed/xdit_context_parallel.py | 16 ++++++++++++---- diffsynth/pipelines/wan_video.py | 15 ++++++++++++++- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/diffsynth/distributed/xdit_context_parallel.py b/diffsynth/distributed/xdit_context_parallel.py index 2c1a257..dc4cc62 100644 --- a/diffsynth/distributed/xdit_context_parallel.py +++ b/diffsynth/distributed/xdit_context_parallel.py @@ -26,15 +26,12 @@ def pad_freqs(original_tensor, target_len): def rope_apply(x, freqs, num_heads): x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) - s_per_rank = x.shape[1] x_out = torch.view_as_complex(x.to(torch.float64).reshape( x.shape[0], x.shape[1], x.shape[2], -1, 2)) - sp_size = get_sequence_parallel_world_size() sp_rank = get_sequence_parallel_rank() - freqs = pad_freqs(freqs, s_per_rank * sp_size) - freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :] + freqs_rank = torch.chunk(freqs, dim=0)[sp_rank] # chunk freqs like x x_out = torch.view_as_real(x_out * freqs_rank).flatten(2) return x_out.to(x.dtype) @@ -73,6 +70,9 @@ def usp_dit_forward(self, return custom_forward # Context Parallel + chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1) + seq_lens = [chunk.shape[1] for chunk in chunks] + x = torch.chunk( x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] @@ -98,7 +98,15 @@ def usp_dit_forward(self, x = self.head(x, t) # Context Parallel + max_len = seq_lens[0] + b, s, c = x.shape + if s != max_len: + padding_tensor = torch.ones(b, max_len - s, c, dtype=x.dtype, device=x.device) + x = torch.cat([x, padding_tensor], dim=1) + x = get_sp_group().all_gather(x, dim=1) + # remove pad + x = torch.cat([x[:,max_len*id:seq_lens[id],:] for id in range(seq_lens)]) # unpatchify x = self.unpatchify(x, (f, h, w)) diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index 8ac1e4e..b9e020a 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -594,7 +594,12 @@ def model_fn_wan_video( # blocks if use_unified_sequence_parallel: if dist.is_initialized() and dist.get_world_size() > 1: - x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1) + seq_lens = [chunk.shape[1] for chunk in chunks] + x = torch.chunk( + x, get_sequence_parallel_world_size(), + dim=1)[get_sequence_parallel_rank()] + if tea_cache_update: x = tea_cache.update(x) else: @@ -612,6 +617,14 @@ def model_fn_wan_video( x = dit.head(x, t) if use_unified_sequence_parallel: if dist.is_initialized() and dist.get_world_size() > 1: + max_len = seq_lens[0] + b, s, c = x.shape + if s != max_len: + padding_tensor = torch.ones(b, max_len - s, c, dtype=x.dtype, device=x.device) + x = torch.cat([x, padding_tensor], dim=1) + x = get_sp_group().all_gather(x, dim=1) + # remove pad + x = torch.cat([x[:,max_len*id:seq_lens[id],:] for id in range(seq_lens)]) x = dit.unpatchify(x, (f, h, w)) return x From 8c558b35265b72c49deaa2c15ae3704fc8f8d412 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Wed, 30 Jul 2025 18:44:17 +0800 Subject: [PATCH 2/5] fix modelconfig --- diffsynth/utils/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/diffsynth/utils/__init__.py b/diffsynth/utils/__init__.py index 58733c1..97f3926 100644 --- a/diffsynth/utils/__init__.py +++ b/diffsynth/utils/__init__.py @@ -178,9 +178,9 @@ class ModelConfig: is_folder = False # Download + if self.local_model_path is None: + self.local_model_path = "./models" if not skip_download: - if self.local_model_path is None: - self.local_model_path = "./models" downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id)) snapshot_download( self.model_id, From 0b860abf1b4b80ba6e0ceaa1a8cfb585b940f333 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Wed, 30 Jul 2025 19:07:16 +0800 Subject: [PATCH 3/5] support arbitrary seq len --- .../distributed/xdit_context_parallel.py | 5 +++- diffsynth/pipelines/wan_video.py | 29 ++++++++----------- diffsynth/pipelines/wan_video_new.py | 6 +++- 3 files changed, 21 insertions(+), 19 deletions(-) diff --git a/diffsynth/distributed/xdit_context_parallel.py b/diffsynth/distributed/xdit_context_parallel.py index dc4cc62..1c44ffc 100644 --- a/diffsynth/distributed/xdit_context_parallel.py +++ b/diffsynth/distributed/xdit_context_parallel.py @@ -26,12 +26,15 @@ def pad_freqs(original_tensor, target_len): def rope_apply(x, freqs, num_heads): x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) + s_per_rank = x.shape[1] x_out = torch.view_as_complex(x.to(torch.float64).reshape( x.shape[0], x.shape[1], x.shape[2], -1, 2)) + sp_size = get_sequence_parallel_world_size() sp_rank = get_sequence_parallel_rank() - freqs_rank = torch.chunk(freqs, dim=0)[sp_rank] # chunk freqs like x + freqs = pad_freqs(freqs, s_per_rank * sp_size) + freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :] x_out = torch.view_as_real(x_out * freqs_rank).flatten(2) return x_out.to(x.dtype) diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index b9e020a..1cd716b 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -595,10 +595,9 @@ def model_fn_wan_video( if use_unified_sequence_parallel: if dist.is_initialized() and dist.get_world_size() > 1: chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1) - seq_lens = [chunk.shape[1] for chunk in chunks] - x = torch.chunk( - x, get_sequence_parallel_world_size(), - dim=1)[get_sequence_parallel_rank()] + pad_shape = chunks[0].shape[1] - chunks[-1].shape[1] + chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks] + x = chunks[get_sequence_parallel_rank()] if tea_cache_update: x = tea_cache.update(x) @@ -606,25 +605,21 @@ def model_fn_wan_video( for block_id, block in enumerate(dit.blocks): x = block(x, context, t_mod, freqs) if vace_context is not None and block_id in vace.vace_layers_mapping: - x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale + current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]] + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + x = x + current_vace_hint * vace_scale if tea_cache is not None: tea_cache.store(x) - if reference_latents is not None: - x = x[:, reference_latents.shape[1]:] - f -= 1 - x = dit.head(x, t) if use_unified_sequence_parallel: if dist.is_initialized() and dist.get_world_size() > 1: - max_len = seq_lens[0] - b, s, c = x.shape - if s != max_len: - padding_tensor = torch.ones(b, max_len - s, c, dtype=x.dtype, device=x.device) - x = torch.cat([x, padding_tensor], dim=1) - x = get_sp_group().all_gather(x, dim=1) - # remove pad - x = torch.cat([x[:,max_len*id:seq_lens[id],:] for id in range(seq_lens)]) + x = x[:, :-pad_shape] if pad_shape > 0 else x + # Remove reference latents + if reference_latents is not None: + x = x[:, reference_latents.shape[1]:] + f -= 1 x = dit.unpatchify(x, (f, h, w)) return x diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index e429456..167acbf 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -1074,7 +1074,10 @@ def model_fn_wan_video( # blocks if use_unified_sequence_parallel: if dist.is_initialized() and dist.get_world_size() > 1: - x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1) + pad_shape = chunks[0].shape[1] - chunks[-1].shape[1] + chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks] + x = chunks[get_sequence_parallel_rank()] if tea_cache_update: x = tea_cache.update(x) else: @@ -1111,6 +1114,7 @@ def model_fn_wan_video( if use_unified_sequence_parallel: if dist.is_initialized() and dist.get_world_size() > 1: x = get_sp_group().all_gather(x, dim=1) + x = x[:, :-pad_shape] if pad_shape > 0 else x # Remove reference latents if reference_latents is not None: x = x[:, reference_latents.shape[1]:] From e4178e2501609e6b5eecba4c7b4385a939016e38 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Wed, 30 Jul 2025 19:21:21 +0800 Subject: [PATCH 4/5] fix usp dit_forward --- diffsynth/distributed/xdit_context_parallel.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/diffsynth/distributed/xdit_context_parallel.py b/diffsynth/distributed/xdit_context_parallel.py index 1c44ffc..4887e2f 100644 --- a/diffsynth/distributed/xdit_context_parallel.py +++ b/diffsynth/distributed/xdit_context_parallel.py @@ -74,11 +74,9 @@ def usp_dit_forward(self, # Context Parallel chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1) - seq_lens = [chunk.shape[1] for chunk in chunks] - - x = torch.chunk( - x, get_sequence_parallel_world_size(), - dim=1)[get_sequence_parallel_rank()] + pad_shape = chunks[0].shape[1] - chunks[-1].shape[1] + chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks] + x = chunks[get_sequence_parallel_rank()] for block in self.blocks: if self.training and use_gradient_checkpointing: @@ -101,15 +99,8 @@ def usp_dit_forward(self, x = self.head(x, t) # Context Parallel - max_len = seq_lens[0] - b, s, c = x.shape - if s != max_len: - padding_tensor = torch.ones(b, max_len - s, c, dtype=x.dtype, device=x.device) - x = torch.cat([x, padding_tensor], dim=1) - x = get_sp_group().all_gather(x, dim=1) - # remove pad - x = torch.cat([x[:,max_len*id:seq_lens[id],:] for id in range(seq_lens)]) + x = x[:, :-pad_shape] if pad_shape > 0 else x # unpatchify x = self.unpatchify(x, (f, h, w)) From 0954e8a017f829f02a79edbb44e5293bdafda579 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Wed, 30 Jul 2025 19:40:08 +0800 Subject: [PATCH 5/5] fix vace usp --- diffsynth/pipelines/wan_video.py | 1 + diffsynth/pipelines/wan_video_new.py | 1 + 2 files changed, 2 insertions(+) diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index 1cd716b..e70e0cc 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -608,6 +608,7 @@ def model_fn_wan_video( current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]] if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0) x = x + current_vace_hint * vace_scale if tea_cache is not None: tea_cache.store(x) diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 167acbf..2317422 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -1106,6 +1106,7 @@ def model_fn_wan_video( current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]] if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0) x = x + current_vace_hint * vace_scale if tea_cache is not None: tea_cache.store(x)