fix vace usp

This commit is contained in:
mi804
2025-07-30 19:40:08 +08:00
parent e4178e2501
commit 0954e8a017
2 changed files with 2 additions and 0 deletions

View File

@@ -608,6 +608,7 @@ def model_fn_wan_video(
current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0)
x = x + current_vace_hint * vace_scale
if tea_cache is not None:
tea_cache.store(x)

View File

@@ -1106,6 +1106,7 @@ def model_fn_wan_video(
current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0)
x = x + current_vace_hint * vace_scale
if tea_cache is not None:
tea_cache.store(x)