diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 7ef4e2e..668b16f 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -1161,13 +1161,13 @@ def model_fn_wan_video( if tea_cache is not None: tea_cache.store(x) - if reference_latents is not None: - x = x[:, reference_latents.shape[1]:] - f -= 1 - x = dit.head(x, t) if use_unified_sequence_parallel: if dist.is_initialized() and dist.get_world_size() > 1: x = get_sp_group().all_gather(x, dim=1) + # Remove reference latents + if reference_latents is not None: + x = x[:, reference_latents.shape[1]:] + f -= 1 x = dit.unpatchify(x, (f, h, w)) return x diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md index 3d71a70..a774773 100644 --- a/examples/wanvideo/README.md +++ b/examples/wanvideo/README.md @@ -173,7 +173,7 @@ Wan supports multiple acceleration techniques, including: * **Unified Sequence Parallel**: Sequence parallelism based on [xDiT](https://github.com/xdit-project/xDiT). Please refer to [this example](./acceleration/unified_sequence_parallel.py), and run it using the command: ```shell -pip install xfuser>=0.4.3 +pip install "xfuser[flash-attn]>=0.4.3" torchrun --standalone --nproc_per_node=8 examples/wanvideo/acceleration/unified_sequence_parallel.py ```