fix wan2.2 5B usp (#763)

This commit is contained in:
Zhongjie Duan
2025-08-08 16:26:04 +08:00
committed by GitHub

View File

@@ -1021,6 +1021,10 @@ def model_fn_wan_video(
torch.ones((latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep
]).flatten()
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0))
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
t_chunks = torch.chunk(t, get_sequence_parallel_world_size(), dim=1)
t_chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, t_chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in t_chunks]
t = t_chunks[get_sequence_parallel_rank()]
t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim))
else:
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))