From 7cfadc2ca8c1af873a4aef528b3fb48b416f1b59 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Thu, 7 Aug 2025 23:06:52 +0800 Subject: [PATCH] fix wan2.2 5B usp --- diffsynth/pipelines/wan_video_new.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 2317422..89adbdf 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -1021,6 +1021,10 @@ def model_fn_wan_video( torch.ones((latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep ]).flatten() t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0)) + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + t_chunks = torch.chunk(t, get_sequence_parallel_world_size(), dim=1) + t_chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, t_chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in t_chunks] + t = t_chunks[get_sequence_parallel_rank()] t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim)) else: t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))