fix bug for usp with refimage

This commit is contained in:
mi804
2025-06-16 19:38:45 +08:00
parent 46f052375f
commit 551721658b
2 changed files with 5 additions and 5 deletions

View File

@@ -1161,13 +1161,13 @@ def model_fn_wan_video(
if tea_cache is not None:
tea_cache.store(x)
if reference_latents is not None:
x = x[:, reference_latents.shape[1]:]
f -= 1
x = dit.head(x, t)
if use_unified_sequence_parallel:
if dist.is_initialized() and dist.get_world_size() > 1:
x = get_sp_group().all_gather(x, dim=1)
# Remove reference latents
if reference_latents is not None:
x = x[:, reference_latents.shape[1]:]
f -= 1
x = dit.unpatchify(x, (f, h, w))
return x