mirror of https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00

support ascend npu
@@ -3,3 +3,4 @@ from .data import *
 from .gradient import *
 from .loader import *
 from .vram import *
+from .device import *
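With this re-export in place, the device helpers surface at the package level. A minimal usage sketch, assuming the hunk above edits diffsynth/core/__init__.py (the file path is not shown in this view):

# Illustrative only; assumes the hunk above is diffsynth/core/__init__.py.
from diffsynth.core import parse_device_type, parse_nccl_backend

print(parse_device_type("npu:0"))  # -> "npu"
print(parse_nccl_backend("npu"))   # -> "hccl"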
1  diffsynth/core/device/__init__.py  Normal file
@@ -0,0 +1 @@
+from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type
107  diffsynth/core/device/npu_compatible_device.py  Normal file
@@ -0,0 +1,107 @@
+import importlib.util
+import torch
+from typing import Any
+
+
+def is_torch_npu_available():
+    return importlib.util.find_spec("torch_npu") is not None
+
+
+IS_CUDA_AVAILABLE = torch.cuda.is_available()
+IS_NPU_AVAILABLE = is_torch_npu_available() and torch.npu.is_available()
+
+if IS_NPU_AVAILABLE:
+    import torch_npu
+
+    torch.npu.config.allow_internal_format = False
+
+
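For orientation, a minimal probe of the detection logic above (illustrative, not part of the commit):

# Prints which accelerator the new helpers select on this machine.
# find_spec only checks that torch_npu is installed; torch_npu itself is
# imported, and NPU internal formats disabled, only when an NPU is usable.
from diffsynth.core.device.npu_compatible_device import get_device_type

print(get_device_type())  # "cuda" on NVIDIA machines, "npu" on Ascend, else "cpu"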
+def get_device_type() -> str:
+    """Get the device type of the current machine; currently only CPU, CUDA, and NPU are supported."""
+    if IS_CUDA_AVAILABLE:
+        device = "cuda"
+    elif IS_NPU_AVAILABLE:
+        device = "npu"
+    else:
+        device = "cpu"
+
+    return device
+
+
+def get_torch_device() -> Any:
+    """Get the torch namespace for the current device type, e.g. torch.cuda or torch.npu."""
+    device_name = get_device_type()
+
+    try:
+        return getattr(torch, device_name)
+    except AttributeError:
+        print(f"Device namespace '{device_name}' not found in torch, falling back to 'torch.cuda'.")
+        return torch.cuda
+
+
+def get_device_id() -> int:
+    """Get the current device id for the current device type."""
+    return get_torch_device().current_device()
+
+
+def get_device_name() -> str:
+    """Get the current device name, e.g. "cuda:0" or "npu:0"."""
+    return f"{get_device_type()}:{get_device_id()}"
+
+
+def synchronize() -> None:
+    """Execute the torch synchronize operation for the current device."""
+    get_torch_device().synchronize()
+
+
+def empty_cache() -> None:
+    """Release unused cached memory held by the current device's allocator."""
+    get_torch_device().empty_cache()
+
+
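Together these wrappers make accelerator code device-agnostic. A minimal timing sketch using them (illustrative, not part of the commit):

# Device-agnostic kernel timing via the new helpers.
import time
import torch
from diffsynth.core.device.npu_compatible_device import get_device_name, synchronize

x = torch.randn(1024, 1024, device=get_device_name())
synchronize()                      # flush pending kernels before timing
start = time.time()
y = x @ x
synchronize()                      # wait for the matmul to finish
print(f"matmul: {time.time() - start:.4f}s on {get_device_name()}")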
+def get_nccl_backend() -> str:
+    """Return the distributed communication backend type based on device type."""
+    if IS_CUDA_AVAILABLE:
+        return "nccl"
+    elif IS_NPU_AVAILABLE:
+        return "hccl"
+    else:
+        raise RuntimeError(f"No available distributed communication backend found on device type {get_device_type()}.")
+
+
+def enable_high_precision_for_bf16():
+    """Use high-precision accumulation for matmul and reduction (disable TF32 and reduced-precision bf16 reductions)."""
+    if IS_CUDA_AVAILABLE:
+        torch.backends.cuda.matmul.allow_tf32 = False
+        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
+
+    if IS_NPU_AVAILABLE:
+        torch.npu.matmul.allow_tf32 = False
+        torch.npu.matmul.allow_bf16_reduced_precision_reduction = False
+
+
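These flags trade matmul throughput for fp32 accumulation accuracy. A minimal sketch of a call site (build_model is a hypothetical placeholder, not part of the commit):

# Illustrative only: call once during setup so bf16 matmuls accumulate
# in fp32 on both CUDA and NPU backends.
from diffsynth.core.device.npu_compatible_device import enable_high_precision_for_bf16

enable_high_precision_for_bf16()
model = build_model()  # hypothetical setup helper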
+def parse_device_type(device):
+    if isinstance(device, str):
+        if device.startswith("cuda"):
+            return "cuda"
+        elif device.startswith("npu"):
+            return "npu"
+        else:
+            return "cpu"
+    elif isinstance(device, torch.device):
+        return device.type
+
+
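parse_device_type accepts both strings and torch.device objects; a few illustrative cases:

# Illustrative only: expected normalization behavior.
import torch
from diffsynth.core.device.npu_compatible_device import parse_device_type

assert parse_device_type("cuda:1") == "cuda"
assert parse_device_type("npu:0") == "npu"
assert parse_device_type("mps") == "cpu"   # anything unrecognized falls back to "cpu"
assert parse_device_type(torch.device("cuda", 0)) == "cuda"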
+def parse_nccl_backend(device_type):
+    if device_type == "cuda":
+        return "nccl"
+    elif device_type == "npu":
+        return "hccl"
+    else:
+        raise RuntimeError(f"No available distributed communication backend found on device type {device_type}.")
+
+
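A minimal sketch of how these two helpers combine when initializing torch.distributed (assumes a standard torchrun environment; illustrative, not part of the commit):

# Pick the collective backend that matches the compute device.
import torch.distributed as dist
from diffsynth.core.device.npu_compatible_device import parse_device_type, parse_nccl_backend

device_type = parse_device_type("npu:0")                           # -> "npu"
dist.init_process_group(backend=parse_nccl_backend(device_type))  # "hccl" on Ascend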
+def get_available_device_type():
+    return get_device_type()
@@ -2,6 +2,7 @@ import torch, copy
 from typing import Union
 from .initialization import skip_model_initialization
 from .disk_map import DiskMap
+from ..device import parse_device_type
 
 
 class AutoTorchModule(torch.nn.Module):
@@ -32,6 +33,7 @@ class AutoTorchModule(torch.nn.Module):
         )
         self.state = 0
         self.name = ""
+        self.computation_device_type = parse_device_type(self.computation_device)
 
     def set_dtype_and_device(
         self,
@@ -61,7 +63,7 @@ class AutoTorchModule(torch.nn.Module):
         return r
 
     def check_free_vram(self):
-        gpu_mem_state = torch.cuda.mem_get_info(self.computation_device)
+        gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(self.computation_device)
         used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
         return used_memory < self.vram_limit
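For reference, torch.cuda.mem_get_info (and its torch.npu counterpart, which this commit assumes mirrors the API) returns a (free_bytes, total_bytes) pair, so the subtraction above yields used GiB. An illustrative check on CUDA:

# Illustrative only: mem_get_info returns (free, total) in bytes.
import torch

free_b, total_b = torch.cuda.mem_get_info("cuda:0")
print(f"used {(total_b - free_b) / 1024**3:.2f} GiB of {total_b / 1024**3:.2f} GiB")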