diff --git a/diffsynth/distributed/xdit_context_parallel.py b/diffsynth/distributed/xdit_context_parallel.py index a144e08..2c1a257 100644 --- a/diffsynth/distributed/xdit_context_parallel.py +++ b/diffsynth/distributed/xdit_context_parallel.py @@ -124,4 +124,6 @@ def usp_attn_forward(self, x, freqs): ) x = x.flatten(2) + del q, k, v + torch.cuda.empty_cache() return self.o(x) \ No newline at end of file diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index b344423..b6f2c74 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -396,13 +396,13 @@ def model_fn_wan_video( else: tea_cache_update = False + # blocks + if use_unified_sequence_parallel: + if dist.is_initialized() and dist.get_world_size() > 1: + x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] if tea_cache_update: x = tea_cache.update(x) else: - # blocks - if use_unified_sequence_parallel: - if dist.is_initialized() and dist.get_world_size() > 1: - x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] for block in dit.blocks: x = block(x, context, t_mod, freqs) if tea_cache is not None: