Mirror of https://github.com/modelscope/DiffSynth-Studio.git, synced 2026-03-18 22:08:13 +00:00
Support Ascend NPU
@@ -5,19 +5,20 @@ from xfuser.core.distributed import (get_sequence_parallel_rank,
                                      get_sequence_parallel_world_size,
                                      get_sp_group)
 from xfuser.core.long_ctx_attention import xFuserLongContextAttention
+from ...core.device import parse_nccl_backend, parse_device_type
 
 
-def initialize_usp():
+def initialize_usp(device_type):
     import torch.distributed as dist
     from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
-    dist.init_process_group(backend="nccl", init_method="env://")
+    dist.init_process_group(backend=parse_nccl_backend(device_type), init_method="env://")
     init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
     initialize_model_parallel(
         sequence_parallel_degree=dist.get_world_size(),
         ring_degree=1,
         ulysses_degree=dist.get_world_size(),
     )
-    torch.cuda.set_device(dist.get_rank())
+    getattr(torch, device_type).set_device(dist.get_rank())
 
 
 def sinusoidal_embedding_1d(dim, position):
@@ -141,5 +142,5 @@ def usp_attn_forward(self, x, freqs):
     x = x.flatten(2)
 
     del q, k, v
-    torch.cuda.empty_cache()
+    getattr(torch, parse_device_type(x.device)).empty_cache()
     return self.o(x)
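
The diff imports parse_nccl_backend and parse_device_type from ...core.device, but that file is not part of this hunk, so their implementations are not shown. A minimal sketch consistent with how they are called above, assuming NVIDIA GPUs use the NCCL backend and Ascend NPUs use HCCL (the collective backend shipped with torch_npu):

import torch

def parse_device_type(device):
    # Hypothetical sketch: normalize a torch.device or device string to its
    # bare type, e.g. torch.device("npu:0") -> "npu", "cuda:1" -> "cuda".
    return torch.device(device).type

def parse_nccl_backend(device_type):
    # Hypothetical mapping: NVIDIA GPUs -> NCCL, Ascend NPUs -> HCCL,
    # anything else -> gloo as a CPU-side fallback.
    return {"cuda": "nccl", "npu": "hccl"}.get(device_type, "gloo")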
Reference in New Issue
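
With the new signature, the caller picks the device type once at startup. A hedged launch sketch, run under torchrun so that init_method="env://" can read MASTER_ADDR, RANK, and WORLD_SIZE; the torch_npu availability check is an assumption, and initialize_usp is assumed to be imported from the module patched above:

import torch

try:
    import torch_npu  # registers the torch.npu namespace on Ascend machines
    has_npu = torch.npu.is_available()
except ImportError:
    has_npu = False

device_type = "npu" if has_npu else "cuda"
initialize_usp(device_type)  # resolves to hccl on NPU, nccl on CUDA (per the sketch above)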
Block a user
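
The second hunk clears the allocator cache via the tensor's own device instead of hard-coding torch.cuda, which is what getattr(torch, parse_device_type(x.device)) achieves. An illustration of that dispatch pattern; the hasattr guard here is an addition for devices without a cache API, not part of the commit:

import torch

def empty_cache_for(tensor):
    # Resolve the per-device namespace (torch.cuda, torch.npu, ...) from the
    # tensor itself; skip devices that expose no empty_cache().
    ns = getattr(torch, tensor.device.type, None)
    if ns is not None and hasattr(ns, "empty_cache"):
        ns.empty_cache()

empty_cache_for(torch.randn(4, 4))  # no-op on CPU; frees cached blocks on CUDA/NPU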