fix bug for usp with refimage

This commit is contained in:
mi804
2025-06-16 19:38:45 +08:00
parent 46f052375f
commit 551721658b
2 changed files with 5 additions and 5 deletions

View File

@@ -1161,13 +1161,13 @@ def model_fn_wan_video(
if tea_cache is not None:
tea_cache.store(x)
if reference_latents is not None:
x = x[:, reference_latents.shape[1]:]
f -= 1
x = dit.head(x, t)
if use_unified_sequence_parallel:
if dist.is_initialized() and dist.get_world_size() > 1:
x = get_sp_group().all_gather(x, dim=1)
# Remove reference latents
if reference_latents is not None:
x = x[:, reference_latents.shape[1]:]
f -= 1
x = dit.unpatchify(x, (f, h, w))
return x

View File

@@ -173,7 +173,7 @@ Wan supports multiple acceleration techniques, including:
* **Unified Sequence Parallel**: Sequence parallelism based on [xDiT](https://github.com/xdit-project/xDiT). Please refer to [this example](./acceleration/unified_sequence_parallel.py), and run it using the command:
```shell
pip install xfuser>=0.4.3
pip install "xfuser[flash-attn]>=0.4.3"
torchrun --standalone --nproc_per_node=8 examples/wanvideo/acceleration/unified_sequence_parallel.py
```