mirror of https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00

support ascend npu
@@ -3,3 +3,4 @@ from .data import *
 from .gradient import *
 from .loader import *
 from .vram import *
+from .device import *
diffsynth/core/device/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
+from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type
diffsynth/core/device/npu_compatible_device.py (new file, 107 lines)
@@ -0,0 +1,107 @@
+import importlib
+import torch
+from typing import Any
+
+
+def is_torch_npu_available():
+    return importlib.util.find_spec("torch_npu") is not None
+
+
+IS_CUDA_AVAILABLE = torch.cuda.is_available()
+IS_NPU_AVAILABLE = is_torch_npu_available() and torch.npu.is_available()
+
+if IS_NPU_AVAILABLE:
+    import torch_npu
+
+    torch.npu.config.allow_internal_format = False
+
+
+def get_device_type() -> str:
+    """Get device type based on current machine; currently only supports CPU, CUDA, NPU."""
+    if IS_CUDA_AVAILABLE:
+        device = "cuda"
+    elif IS_NPU_AVAILABLE:
+        device = "npu"
+    else:
+        device = "cpu"
+
+    return device
+
+
+def get_torch_device() -> Any:
+    """Get torch attribute based on device type, e.g. torch.cuda or torch.npu."""
+    device_name = get_device_type()
+
+    try:
+        return getattr(torch, device_name)
+    except AttributeError:
+        print(f"Device namespace '{device_name}' not found in torch, trying to load 'torch.cuda'.")
+        return torch.cuda
+
+
+def get_device_id() -> int:
+    """Get current device id based on device type."""
+    return get_torch_device().current_device()
+
+
+def get_device_name() -> str:
+    """Get current device name based on device type."""
+    return f"{get_device_type()}:{get_device_id()}"
+
+
+def synchronize() -> None:
+    """Execute torch synchronize operation."""
+    get_torch_device().synchronize()
+
+
+def empty_cache() -> None:
+    """Execute torch empty cache operation."""
+    get_torch_device().empty_cache()
+
+
+def get_nccl_backend() -> str:
+    """Return distributed communication backend type based on device type."""
+    if IS_CUDA_AVAILABLE:
+        return "nccl"
+    elif IS_NPU_AVAILABLE:
+        return "hccl"
+    else:
+        raise RuntimeError(f"No available distributed communication backend found on device type {get_device_type()}.")
+
+
+def enable_high_precision_for_bf16():
+    """
+    Set high accumulation dtype for matmul and reduction.
+    """
+    if IS_CUDA_AVAILABLE:
+        torch.backends.cuda.matmul.allow_tf32 = False
+        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
+
+    if IS_NPU_AVAILABLE:
+        torch.npu.matmul.allow_tf32 = False
+        torch.npu.matmul.allow_bf16_reduced_precision_reduction = False
+
+
+def parse_device_type(device):
+    if isinstance(device, str):
+        if device.startswith("cuda"):
+            return "cuda"
+        elif device.startswith("npu"):
+            return "npu"
+        else:
+            return "cpu"
+    elif isinstance(device, torch.device):
+        return device.type
+
+
+def parse_nccl_backend(device_type):
+    if device_type == "cuda":
+        return "nccl"
+    elif device_type == "npu":
+        return "hccl"
+    else:
+        raise RuntimeError(f"No available distributed communication backend found on device type {device_type}.")
+
+
+def get_available_device_type():
+    return get_device_type()
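For orientation, these helpers are what let the rest of the codebase stay device-agnostic: call sites ask for the active backend namespace instead of hard-coding torch.cuda. A minimal usage sketch, assuming torch (plus torch_npu on Ascend machines) is installed and the package is importable under the path in the file header; shapes are arbitrary:

import torch
from diffsynth.core.device.npu_compatible_device import (
    get_device_name, synchronize, empty_cache,
)

device = get_device_name()   # e.g. "cuda:0" on NVIDIA, "npu:0" on Ascend
x = torch.randn(1024, 1024, device=device)
y = x @ x                    # kernels dispatch to whichever backend is active

synchronize()                # torch.cuda.synchronize() or torch.npu.synchronize()
empty_cache()                # frees cached blocks on the active backend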
@@ -2,6 +2,7 @@ import torch, copy
 from typing import Union
 from .initialization import skip_model_initialization
 from .disk_map import DiskMap
+from ..device import parse_device_type


 class AutoTorchModule(torch.nn.Module):
@@ -32,6 +33,7 @@ class AutoTorchModule(torch.nn.Module):
         )
         self.state = 0
         self.name = ""
+        self.computation_device_type = parse_device_type(self.computation_device)

     def set_dtype_and_device(
         self,
@@ -61,7 +63,7 @@ class AutoTorchModule(torch.nn.Module):
         return r

     def check_free_vram(self):
-        gpu_mem_state = torch.cuda.mem_get_info(self.computation_device)
+        gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(self.computation_device)
         used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
         return used_memory < self.vram_limit

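Worth noting for the hunk above: mem_get_info returns a (free_bytes, total_bytes) tuple, so the rewritten line computes used memory in GiB as total minus free, on CUDA and NPU alike. A quick illustration with made-up numbers:

# Hypothetical 24 GiB card with 20 GiB free; mem_get_info -> (free, total) in bytes.
gpu_mem_state = (20 * 1024**3, 24 * 1024**3)
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
assert used_memory == 4.0    # GiB in use, compared against self.vram_limit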
@@ -3,7 +3,7 @@ import torch
 import numpy as np
 from einops import repeat, reduce
 from typing import Union
-from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig
+from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
 from ..utils.lora import GeneralLoRALoader
 from ..models.model_loader import ModelPool
 from ..utils.controlnet import ControlNetInput
@@ -68,6 +68,7 @@ class BasePipeline(torch.nn.Module):
         # The device and torch_dtype is used for the storage of intermediate variables, not models.
         self.device = device
         self.torch_dtype = torch_dtype
+        self.device_type = parse_device_type(device)
         # The following parameters are used for shape check.
         self.height_division_factor = height_division_factor
         self.width_division_factor = width_division_factor
@@ -154,7 +155,7 @@ class BasePipeline(torch.nn.Module):
         for module in model.modules():
             if hasattr(module, "offload"):
                 module.offload()
-        torch.cuda.empty_cache()
+        getattr(torch, self.device_type).empty_cache()
         # onload models
         for name, model in self.named_children():
             if name in model_names:
@@ -176,7 +177,7 @@ class BasePipeline(torch.nn.Module):


     def get_vram(self):
-        return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3)
+        return getattr(torch, self.device_type).mem_get_info(self.device)[1] / (1024 ** 3)

     def get_module(self, model, name):
         if "." in name:
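For reference, parse_device_type normalizes both string and torch.device inputs, which is what makes getattr(torch, self.device_type) safe at these call sites; a small sketch of the behavior implied by its definition above:

import torch
from diffsynth.core.device.npu_compatible_device import parse_device_type

assert parse_device_type("cuda:1") == "cuda"                 # device index is dropped
assert parse_device_type("npu:0") == "npu"
assert parse_device_type(torch.device("cuda", 0)) == "cuda"
assert parse_device_type("mps") == "cpu"                     # unrecognized strings fall back to "cpu"

getattr(torch, "cuda") and getattr(torch, "npu") then expose the same empty_cache/mem_get_info surface, so every former torch.cuda call site works unchanged on Ascend.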
@@ -126,7 +126,7 @@ class WanVideoPipeline(BasePipeline):
         pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
         if use_usp:
             from ..utils.xfuser import initialize_usp
-            initialize_usp()
+            initialize_usp(device)
         model_pool = pipe.download_and_load_models(model_configs, vram_limit)

         # Fetch models
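The call-site change simply threads the user's device into USP setup. A hedged reduction of this flow (module paths assumed from the relative imports; note that initialize_usp's body expects a bare type string such as "cuda" or "npu", since parse_nccl_backend and getattr(torch, ...) match on exact type names):

import torch
from diffsynth.pipelines.wan_video import WanVideoPipeline   # module path assumed
from diffsynth.utils.xfuser import initialize_usp            # resolved from ..utils.xfuser

device = "npu"   # bare type string, not "npu:0"
pipe = WanVideoPipeline(device=device, torch_dtype=torch.bfloat16)
initialize_usp(device)   # previously initialize_usp(); the backend now follows the device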
@@ -5,19 +5,20 @@ from xfuser.core.distributed import (get_sequence_parallel_rank,
                                      get_sequence_parallel_world_size,
                                      get_sp_group)
 from xfuser.core.long_ctx_attention import xFuserLongContextAttention
+from ...core.device import parse_nccl_backend, parse_device_type


-def initialize_usp():
+def initialize_usp(device_type):
     import torch.distributed as dist
     from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
-    dist.init_process_group(backend="nccl", init_method="env://")
+    dist.init_process_group(backend=parse_nccl_backend(device_type), init_method="env://")
     init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
     initialize_model_parallel(
         sequence_parallel_degree=dist.get_world_size(),
         ring_degree=1,
         ulysses_degree=dist.get_world_size(),
     )
-    torch.cuda.set_device(dist.get_rank())
+    getattr(torch, device_type).set_device(dist.get_rank())


 def sinusoidal_embedding_1d(dim, position):
@@ -141,5 +142,5 @@ def usp_attn_forward(self, x, freqs):
     x = x.flatten(2)

     del q, k, v
-    torch.cuda.empty_cache()
+    getattr(torch, parse_device_type(x.device)).empty_cache()
     return self.o(x)
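To round this out, a sketch of how the reworked initialize_usp would typically be driven, assuming a torchrun-style launch that sets the env:// rendezvous variables (script name hypothetical):

# Hypothetical launch: torchrun --nproc_per_node=8 generate_video.py
import torch.distributed as dist
from diffsynth.utils.xfuser import initialize_usp   # path assumed from the relative imports

initialize_usp("npu")   # selects "hccl" on Ascend; "cuda" would select "nccl"
print(f"rank {dist.get_rank()} of {dist.get_world_size()} bound to its local accelerator")

Each rank ends up pinned via getattr(torch, device_type).set_device(rank), with ring_degree=1 and ulysses_degree equal to the world size, i.e. pure Ulysses-style sequence parallelism.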