fix vace usp

This commit is contained in:
mi804
2025-06-16 18:54:29 +08:00
parent 4c052e42bc
commit 46f052375f

View File

@@ -1154,7 +1154,10 @@ def model_fn_wan_video(
else:
x = block(x, context, t_mod, freqs)
if vace_context is not None and block_id in vace.vace_layers_mapping:
x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale
current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
x = x + current_vace_hint * vace_scale
if tea_cache is not None:
tea_cache.store(x)