mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 14:58:12 +00:00
fix wan2.2 5B usp
This commit is contained in:
@@ -1021,6 +1021,10 @@ def model_fn_wan_video(
|
||||
torch.ones((latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep
|
||||
]).flatten()
|
||||
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0))
|
||||
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
|
||||
t_chunks = torch.chunk(t, get_sequence_parallel_world_size(), dim=1)
|
||||
t_chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, t_chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in t_chunks]
|
||||
t = t_chunks[get_sequence_parallel_rank()]
|
||||
t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim))
|
||||
else:
|
||||
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
|
||||
|
||||
Reference in New Issue
Block a user