mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
[NPU]:Support USP feature in NPU
This commit is contained in:
@@ -1,10 +1,13 @@
|
||||
import torch
|
||||
from typing import Optional
|
||||
from einops import rearrange
|
||||
from yunchang.kernels import AttnType
|
||||
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
||||
get_sequence_parallel_world_size,
|
||||
get_sp_group)
|
||||
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
||||
|
||||
from ... import IS_NPU_AVAILABLE
|
||||
from ...core.device import parse_nccl_backend, parse_device_type
|
||||
|
||||
|
||||
@@ -35,8 +38,9 @@ def pad_freqs(original_tensor, target_len):
|
||||
s1,
|
||||
s2,
|
||||
dtype=original_tensor.dtype,
|
||||
device=original_tensor.device)
|
||||
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
|
||||
device='cpu')
|
||||
original_tensor_device = original_tensor.device
|
||||
padded_tensor = torch.cat([original_tensor.cpu(), padding_tensor], dim=0).to(device=original_tensor_device)
|
||||
return padded_tensor
|
||||
|
||||
def rope_apply(x, freqs, num_heads):
|
||||
@@ -133,7 +137,12 @@ def usp_attn_forward(self, x, freqs):
|
||||
k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
|
||||
v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
|
||||
|
||||
x = xFuserLongContextAttention()(
|
||||
attn_type = AttnType.FA
|
||||
ring_impl_type = "basic"
|
||||
if IS_NPU_AVAILABLE:
|
||||
attn_type = AttnType.NPU
|
||||
ring_impl_type = "basic_npu"
|
||||
x = xFuserLongContextAttention(attn_type=attn_type, ring_impl_type=ring_impl_type)(
|
||||
None,
|
||||
query=q,
|
||||
key=k,
|
||||
|
||||
Reference in New Issue
Block a user