support ascend npu

This commit is contained in:
Artiprocher
2025-12-15 15:48:42 +08:00
parent 78d8842ddf
commit 2883bc1b76
11 changed files with 242 additions and 9 deletions

View File

@@ -2,6 +2,7 @@ import torch, copy
from typing import Union
from .initialization import skip_model_initialization
from .disk_map import DiskMap
from ..device import parse_device_type
class AutoTorchModule(torch.nn.Module):
@@ -32,6 +33,7 @@ class AutoTorchModule(torch.nn.Module):
)
self.state = 0
self.name = ""
self.computation_device_type = parse_device_type(self.computation_device)
def set_dtype_and_device(
self,
@@ -61,7 +63,7 @@ class AutoTorchModule(torch.nn.Module):
return r
def check_free_vram(self):
gpu_mem_state = torch.cuda.mem_get_info(self.computation_device)
gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(self.computation_device)
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
return used_memory < self.vram_limit