Support Ascend NPU

Artiprocher committed 2025-12-15 15:48:42 +08:00
parent 78d8842ddf
commit 2883bc1b76
11 changed files with 242 additions and 9 deletions


@@ -5,19 +5,20 @@ from xfuser.core.distributed import (get_sequence_parallel_rank,
                                      get_sequence_parallel_world_size,
                                      get_sp_group)
 from xfuser.core.long_ctx_attention import xFuserLongContextAttention
+from ...core.device import parse_nccl_backend, parse_device_type
 
-def initialize_usp():
+def initialize_usp(device_type):
     import torch.distributed as dist
     from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
-    dist.init_process_group(backend="nccl", init_method="env://")
+    dist.init_process_group(backend=parse_nccl_backend(device_type), init_method="env://")
     init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
     initialize_model_parallel(
         sequence_parallel_degree=dist.get_world_size(),
         ring_degree=1,
         ulysses_degree=dist.get_world_size(),
     )
-    torch.cuda.set_device(dist.get_rank())
+    getattr(torch, device_type).set_device(dist.get_rank())
 
 def sinusoidal_embedding_1d(dim, position):
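
The two helpers imported from ...core.device are added elsewhere in this commit and are not shown in this excerpt. A minimal sketch of what they plausibly do, assuming the standard mapping of PyTorch device types to collective backends (NCCL for CUDA, HCCL for Ascend NPU); the bodies below are illustrative, not the commit's actual code:

    import torch

    def parse_device_type(device):
        # Accept a torch.device or a string like "cuda:0" / "npu:0" and
        # return the bare device type ("cuda", "npu", ...).
        return torch.device(device).type

    def parse_nccl_backend(device_type):
        # Map a device type to its process-group backend: NVIDIA GPUs use
        # NCCL, Ascend NPUs use HCCL (shipped with torch_npu).
        return {"cuda": "nccl", "npu": "hccl"}.get(device_type, "gloo")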
@@ -141,5 +142,5 @@ def usp_attn_forward(self, x, freqs):
     x = x.flatten(2)
     del q, k, v
-    torch.cuda.empty_cache()
+    getattr(torch, parse_device_type(x.device)).empty_cache()
     return self.o(x)
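
The getattr(torch, device_type) dispatch is what keeps the per-device calls backend-neutral: torch.cuda and, once torch_npu is imported, torch.npu expose the same set_device / empty_cache surface. A usage sketch on an Ascend machine, assuming the torch_npu package is installed:

    import torch
    import torch_npu  # registers the torch.npu module on Ascend systems

    device_type = "npu"  # would be "cuda" on NVIDIA GPUs
    getattr(torch, device_type).set_device(0)   # same as torch.npu.set_device(0)
    getattr(torch, device_type).empty_cache()   # same as torch.npu.empty_cache()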