From 848bfd6993a0e8c1c542ab4c821dbd14d96f07b8 Mon Sep 17 00:00:00 2001 From: feng0w0 Date: Wed, 21 Jan 2026 10:25:31 +0800 Subject: [PATCH] [NPU]:Support USP feature in NPU --- diffsynth/utils/xfuser/xdit_context_parallel.py | 15 ++++++++++++--- docs/en/Pipeline_Usage/GPU_support.md | 8 ++++++++ docs/zh/Pipeline_Usage/GPU_support.md | 7 +++++++ 3 files changed, 27 insertions(+), 3 deletions(-) diff --git a/diffsynth/utils/xfuser/xdit_context_parallel.py b/diffsynth/utils/xfuser/xdit_context_parallel.py index 21dc3b3..94b92c7 100644 --- a/diffsynth/utils/xfuser/xdit_context_parallel.py +++ b/diffsynth/utils/xfuser/xdit_context_parallel.py @@ -1,10 +1,13 @@ import torch from typing import Optional from einops import rearrange +from yunchang.kernels import AttnType from xfuser.core.distributed import (get_sequence_parallel_rank, get_sequence_parallel_world_size, get_sp_group) from xfuser.core.long_ctx_attention import xFuserLongContextAttention + +from ... import IS_NPU_AVAILABLE from ...core.device import parse_nccl_backend, parse_device_type @@ -35,8 +38,9 @@ def pad_freqs(original_tensor, target_len): s1, s2, dtype=original_tensor.dtype, - device=original_tensor.device) - padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0) + device='cpu') + original_tensor_device = original_tensor.device + padded_tensor = torch.cat([original_tensor.cpu(), padding_tensor], dim=0).to(device=original_tensor_device) return padded_tensor def rope_apply(x, freqs, num_heads): @@ -133,7 +137,12 @@ def usp_attn_forward(self, x, freqs): k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads) v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads) - x = xFuserLongContextAttention()( + attn_type = AttnType.FA + ring_impl_type = "basic" + if IS_NPU_AVAILABLE: + attn_type = AttnType.NPU + ring_impl_type = "basic_npu" + x = xFuserLongContextAttention(attn_type=attn_type, ring_impl_type=ring_impl_type)( None, query=q, key=k, diff --git a/docs/en/Pipeline_Usage/GPU_support.md b/docs/en/Pipeline_Usage/GPU_support.md index aba5706..2c67fed 100644 --- a/docs/en/Pipeline_Usage/GPU_support.md +++ b/docs/en/Pipeline_Usage/GPU_support.md @@ -58,6 +58,14 @@ video = pipe( save_video(video, "video.mp4", fps=15, quality=5) ``` +#### USP(Unified Sequence Parallel) +If you want to use this feature on NPU, please install additional third-party libraries as follows: +```shell +pip install git+https://github.com/feifeibear/long-context-attention.git +pip install git+https://github.com/xdit-project/xDiT.git +``` + + ### Training NPU startup script samples have been added for each type of model,the scripts are stored in the `examples/xxx/special/npu_training`, for example `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`. diff --git a/docs/zh/Pipeline_Usage/GPU_support.md b/docs/zh/Pipeline_Usage/GPU_support.md index 8124147..b5c0e33 100644 --- a/docs/zh/Pipeline_Usage/GPU_support.md +++ b/docs/zh/Pipeline_Usage/GPU_support.md @@ -58,6 +58,13 @@ video = pipe( save_video(video, "video.mp4", fps=15, quality=5) ``` +#### USP(Unified Sequence Parallel) +如果想要在NPU上使用该特性,请通过如下方式安装额外的第三方库: +```shell +pip install git+https://github.com/feifeibear/long-context-attention.git +pip install git+https://github.com/xdit-project/xDiT.git +``` + ### 训练 当前已为每类模型添加NPU的启动脚本样例,脚本存放在`examples/xxx/special/npu_training`目录下,例如 `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`。