# Mirror of https://github.com/modelscope/DiffSynth-Studio.git
# synced 2026-03-19 06:23:43 +00:00
# 108 lines, 2.8 KiB, Python
import importlib
import importlib.util
from typing import Any

import torch
def is_torch_npu_available():
    """Return True when the Ascend NPU plugin package (``torch_npu``) is installed.

    This only checks that the package can be located on the import path; it
    neither imports it nor verifies that NPU hardware is actually usable
    (see ``IS_NPU_AVAILABLE`` for the runtime check).
    """
    # Requires `import importlib.util` at the top of the file: plain
    # `import importlib` does not guarantee the `util` submodule is loaded.
    return importlib.util.find_spec("torch_npu") is not None
# True when a usable CUDA device (and the CUDA build of torch) is present.
IS_CUDA_AVAILABLE = torch.cuda.is_available()

# True when the torch_npu package is installed AND an Ascend NPU is usable.
# NOTE(review): `torch.npu` presumably exists here only because torch_npu
# registers itself on installation/import — confirm on an NPU host.
IS_NPU_AVAILABLE = is_torch_npu_available() and torch.npu.is_available()

if IS_NPU_AVAILABLE:
    import torch_npu

    # Disable NPU-private internal tensor formats so tensors keep standard
    # layouts that interoperate with ordinary torch code.
    torch.npu.config.allow_internal_format = False
def get_device_type() -> str:
    """Get device type based on current machine, currently only support CPU, CUDA, NPU."""
    # CUDA takes precedence over NPU; CPU is the universal fallback.
    if IS_CUDA_AVAILABLE:
        return "cuda"
    if IS_NPU_AVAILABLE:
        return "npu"
    return "cpu"
def get_torch_device() -> Any:
    """Get torch attribute based on device type, e.g. torch.cuda or torch.npu"""
    namespace = get_device_type()
    # Look the namespace up on torch; fall back to torch.cuda if absent.
    if hasattr(torch, namespace):
        return getattr(torch, namespace)
    print(f"Device namespace '{namespace}' not found in torch, try to load 'torch.cuda'.")
    return torch.cuda
def get_device_id() -> int:
    """Return the index of the currently selected device for this machine's backend."""
    device_module = get_torch_device()
    return device_module.current_device()
def get_device_name() -> str:
    """Return the fully qualified device string, e.g. ``"cuda:0"`` or ``"npu:1"``."""
    device_type = get_device_type()
    device_index = get_device_id()
    return "{}:{}".format(device_type, device_index)
def synchronize() -> None:
    """Block until all queued work on the current device backend has completed."""
    device_module = get_torch_device()
    device_module.synchronize()
def empty_cache() -> None:
    """Release cached memory held by the current device backend's allocator."""
    device_module = get_torch_device()
    device_module.empty_cache()
def get_nccl_backend() -> str:
    """Return distributed communication backend type based on device type."""
    # Same precedence as get_device_type(): CUDA first, then NPU.
    if IS_CUDA_AVAILABLE:
        return "nccl"
    if IS_NPU_AVAILABLE:
        return "hccl"
    raise RuntimeError(f"No available distributed communication backend found on device type {get_device_type()}.")
def enable_high_precision_for_bf16():
    """
    Set high accumulation dtype for matmul and reduction.

    Disables TF32 matmuls and reduced-precision bf16 reductions on whichever
    backends (CUDA and/or NPU) are available, trading speed for accuracy.
    """
    if IS_CUDA_AVAILABLE:
        cuda_matmul = torch.backends.cuda.matmul
        cuda_matmul.allow_tf32 = False
        cuda_matmul.allow_bf16_reduced_precision_reduction = False

    if IS_NPU_AVAILABLE:
        npu_matmul = torch.npu.matmul
        npu_matmul.allow_tf32 = False
        npu_matmul.allow_bf16_reduced_precision_reduction = False
def parse_device_type(device):
    """Normalize a device specifier to its bare device-type string.

    Args:
        device: Either a device string (e.g. ``"cuda:0"``, ``"npu:1"``,
            ``"cpu"``) or a ``torch.device`` instance.

    Returns:
        str: ``"cuda"`` or ``"npu"`` for matching string prefixes, ``"cpu"``
        for any other string; for a ``torch.device``, its ``.type``.

    Raises:
        TypeError: If ``device`` is neither a string nor a ``torch.device``.
    """
    if isinstance(device, str):
        if device.startswith("cuda"):
            return "cuda"
        if device.startswith("npu"):
            return "npu"
        # Any other string (e.g. "cpu", "mps") is treated as CPU.
        return "cpu"
    if isinstance(device, torch.device):
        return device.type
    # Previously fell through and implicitly returned None; fail loudly instead.
    raise TypeError(f"Unsupported device specifier type: {type(device).__name__}")
def parse_nccl_backend(device_type):
    """Map a device-type string to its collective-communication backend name."""
    backend_by_device = {"cuda": "nccl", "npu": "hccl"}
    if device_type in backend_by_device:
        return backend_by_device[device_type]
    raise RuntimeError(f"No available distributed communication backend found on device type {device_type}.")
def get_available_device_type():
    """Alias of :func:`get_device_type`, kept for API compatibility."""
    return get_device_type()