Merge pull request #1221 from Feng0w0/usp_npu

[NPU]:Support USP feature in NPU
This commit is contained in:
Zhongjie Duan
2026-02-04 13:25:24 +08:00
committed by GitHub
3 changed files with 28 additions and 2 deletions

View File

@@ -1,10 +1,13 @@
import torch import torch
from typing import Optional from typing import Optional
from einops import rearrange from einops import rearrange
from yunchang.kernels import AttnType
from xfuser.core.distributed import (get_sequence_parallel_rank, from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size, get_sequence_parallel_world_size,
get_sp_group) get_sp_group)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from ... import IS_NPU_AVAILABLE
from ...core.device import parse_nccl_backend, parse_device_type from ...core.device import parse_nccl_backend, parse_device_type
@@ -30,13 +33,16 @@ def sinusoidal_embedding_1d(dim, position):
def pad_freqs(original_tensor, target_len): def pad_freqs(original_tensor, target_len):
seq_len, s1, s2 = original_tensor.shape seq_len, s1, s2 = original_tensor.shape
pad_size = target_len - seq_len pad_size = target_len - seq_len
original_tensor_device = original_tensor.device
if original_tensor.device == "npu":
original_tensor = original_tensor.cpu()
padding_tensor = torch.ones( padding_tensor = torch.ones(
pad_size, pad_size,
s1, s1,
s2, s2,
dtype=original_tensor.dtype, dtype=original_tensor.dtype,
device=original_tensor.device) device=original_tensor.device)
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0) padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0).to(device=original_tensor_device)
return padded_tensor return padded_tensor
def rope_apply(x, freqs, num_heads): def rope_apply(x, freqs, num_heads):
@@ -133,7 +139,12 @@ def usp_attn_forward(self, x, freqs):
k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads) k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads) v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
x = xFuserLongContextAttention()( attn_type = AttnType.FA
ring_impl_type = "basic"
if IS_NPU_AVAILABLE:
attn_type = AttnType.NPU
ring_impl_type = "basic_npu"
x = xFuserLongContextAttention(attn_type=attn_type, ring_impl_type=ring_impl_type)(
None, None,
query=q, query=q,
key=k, key=k,

View File

@@ -58,6 +58,14 @@ video = pipe(
save_video(video, "video.mp4", fps=15, quality=5) save_video(video, "video.mp4", fps=15, quality=5)
``` ```
#### USP(Unified Sequence Parallel)
If you want to use this feature on NPU, please install additional third-party libraries as follows:
```shell
pip install git+https://github.com/feifeibear/long-context-attention.git
pip install git+https://github.com/xdit-project/xDiT.git
```
### Training ### Training
NPU startup script samples have been added for each type of model,the scripts are stored in the `examples/xxx/special/npu_training`, for example `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`. NPU startup script samples have been added for each type of model,the scripts are stored in the `examples/xxx/special/npu_training`, for example `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`.

View File

@@ -58,6 +58,13 @@ video = pipe(
save_video(video, "video.mp4", fps=15, quality=5) save_video(video, "video.mp4", fps=15, quality=5)
``` ```
#### USP(Unified Sequence Parallel)
如果想要在NPU上使用该特性请通过如下方式安装额外的第三方库
```shell
pip install git+https://github.com/feifeibear/long-context-attention.git
pip install git+https://github.com/xdit-project/xDiT.git
```
### 训练 ### 训练
当前已为每类模型添加NPU的启动脚本样例脚本存放在`examples/xxx/special/npu_training`目录下,例如 `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh` 当前已为每类模型添加NPU的启动脚本样例脚本存放在`examples/xxx/special/npu_training`目录下,例如 `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`