support ascend npu

This commit is contained in:
Artiprocher
2025-12-15 15:48:42 +08:00
parent 78d8842ddf
commit 2883bc1b76
11 changed files with 242 additions and 9 deletions

View File

@@ -3,3 +3,4 @@ from .data import *
from .gradient import *
from .loader import *
from .vram import *
from .device import *

View File

@@ -0,0 +1 @@
from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type

View File

@@ -0,0 +1,107 @@
import importlib
import torch
from typing import Any
def is_torch_npu_available():
return importlib.util.find_spec("torch_npu") is not None
IS_CUDA_AVAILABLE = torch.cuda.is_available()
IS_NPU_AVAILABLE = is_torch_npu_available() and torch.npu.is_available()
if IS_NPU_AVAILABLE:
import torch_npu
torch.npu.config.allow_internal_format = False
def get_device_type() -> str:
"""Get device type based on current machine, currently only support CPU, CUDA, NPU."""
if IS_CUDA_AVAILABLE:
device = "cuda"
elif IS_NPU_AVAILABLE:
device = "npu"
else:
device = "cpu"
return device
def get_torch_device() -> Any:
"""Get torch attribute based on device type, e.g. torch.cuda or torch.npu"""
device_name = get_device_type()
try:
return getattr(torch, device_name)
except AttributeError:
print(f"Device namespace '{device_name}' not found in torch, try to load 'torch.cuda'.")
return torch.cuda
def get_device_id() -> int:
"""Get current device id based on device type."""
return get_torch_device().current_device()
def get_device_name() -> str:
"""Get current device name based on device type."""
return f"{get_device_type()}:{get_device_id()}"
def synchronize() -> None:
"""Execute torch synchronize operation."""
get_torch_device().synchronize()
def empty_cache() -> None:
"""Execute torch empty cache operation."""
get_torch_device().empty_cache()
def get_nccl_backend() -> str:
"""Return distributed communication backend type based on device type."""
if IS_CUDA_AVAILABLE:
return "nccl"
elif IS_NPU_AVAILABLE:
return "hccl"
else:
raise RuntimeError(f"No available distributed communication backend found on device type {get_device_type()}.")
def enable_high_precision_for_bf16():
"""
Set high accumulation dtype for matmul and reduction.
"""
if IS_CUDA_AVAILABLE:
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
if IS_NPU_AVAILABLE:
torch.npu.matmul.allow_tf32 = False
torch.npu.matmul.allow_bf16_reduced_precision_reduction = False
def parse_device_type(device):
if isinstance(device, str):
if device.startswith("cuda"):
return "cuda"
elif device.startswith("npu"):
return "npu"
else:
return "cpu"
elif isinstance(device, torch.device):
return device.type
def parse_nccl_backend(device_type):
if device_type == "cuda":
return "nccl"
elif device_type == "npu":
return "hccl"
else:
raise RuntimeError(f"No available distributed communication backend found on device type {device_type}.")
def get_available_device_type():
return get_device_type()

View File

@@ -2,6 +2,7 @@ import torch, copy
from typing import Union
from .initialization import skip_model_initialization
from .disk_map import DiskMap
from ..device import parse_device_type
class AutoTorchModule(torch.nn.Module):
@@ -32,6 +33,7 @@ class AutoTorchModule(torch.nn.Module):
)
self.state = 0
self.name = ""
self.computation_device_type = parse_device_type(self.computation_device)
def set_dtype_and_device(
self,
@@ -61,7 +63,7 @@ class AutoTorchModule(torch.nn.Module):
return r
def check_free_vram(self):
gpu_mem_state = torch.cuda.mem_get_info(self.computation_device)
gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(self.computation_device)
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
return used_memory < self.vram_limit

View File

@@ -3,7 +3,7 @@ import torch
import numpy as np
from einops import repeat, reduce
from typing import Union
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
from ..utils.lora import GeneralLoRALoader
from ..models.model_loader import ModelPool
from ..utils.controlnet import ControlNetInput
@@ -68,6 +68,7 @@ class BasePipeline(torch.nn.Module):
# The device and torch_dtype is used for the storage of intermediate variables, not models.
self.device = device
self.torch_dtype = torch_dtype
self.device_type = parse_device_type(device)
# The following parameters are used for shape check.
self.height_division_factor = height_division_factor
self.width_division_factor = width_division_factor
@@ -154,7 +155,7 @@ class BasePipeline(torch.nn.Module):
for module in model.modules():
if hasattr(module, "offload"):
module.offload()
torch.cuda.empty_cache()
getattr(torch, self.device_type).empty_cache()
# onload models
for name, model in self.named_children():
if name in model_names:
@@ -176,7 +177,7 @@ class BasePipeline(torch.nn.Module):
def get_vram(self):
return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3)
return getattr(torch, self.device_type).mem_get_info(self.device)[1] / (1024 ** 3)
def get_module(self, model, name):
if "." in name:

View File

@@ -126,7 +126,7 @@ class WanVideoPipeline(BasePipeline):
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
if use_usp:
from ..utils.xfuser import initialize_usp
initialize_usp()
initialize_usp(device)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
# Fetch models

View File

@@ -5,19 +5,20 @@ from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from ...core.device import parse_nccl_backend, parse_device_type
def initialize_usp():
def initialize_usp(device_type):
import torch.distributed as dist
from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
dist.init_process_group(backend="nccl", init_method="env://")
dist.init_process_group(backend=parse_nccl_backend(device_type), init_method="env://")
init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
initialize_model_parallel(
sequence_parallel_degree=dist.get_world_size(),
ring_degree=1,
ulysses_degree=dist.get_world_size(),
)
torch.cuda.set_device(dist.get_rank())
getattr(torch, device_type).set_device(dist.get_rank())
def sinusoidal_embedding_1d(dim, position):
@@ -141,5 +142,5 @@ def usp_attn_forward(self, x, freqs):
x = x.flatten(2)
del q, k, v
torch.cuda.empty_cache()
getattr(torch, parse_device_type(x.device)).empty_cache()
return self.o(x)