diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 1113980..7ef4e2e 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -1154,7 +1154,10 @@ def model_fn_wan_video( else: x = block(x, context, t_mod, freqs) if vace_context is not None and block_id in vace.vace_layers_mapping: - x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale + current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]] + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + x = x + current_vace_hint * vace_scale if tea_cache is not None: tea_cache.store(x)