This commit is contained in:
Artiprocher
2025-12-16 16:09:29 +08:00
parent bfaaf12bf4
commit 1547c3f786
4 changed files with 7 additions and 4 deletions

View File

@@ -63,7 +63,8 @@ class AutoTorchModule(torch.nn.Module):
return r
def check_free_vram(self):
    """Return whether memory in use on the computation device is under the limit.

    Queries ``mem_get_info`` on the torch backend module named by
    ``self.computation_device_type`` (looked up as an attribute of ``torch``)
    and treats the result as a ``(free, total)`` pair in bytes, which is what
    ``torch.cuda.mem_get_info`` returns; other backends are assumed to follow
    the same convention.

    Returns:
        bool: ``True`` when ``(total - free) / 1024**3`` is strictly less
        than ``self.vram_limit``, ``False`` otherwise.
    """
    # A bare "npu" device string is replaced by the explicit "npu:0" before the
    # query; presumably the NPU backend needs a device index here. The device
    # is normalised first so the backend is queried exactly once, and never
    # with the un-normalised string.
    device = self.computation_device if self.computation_device != "npu" else "npu:0"
    gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(device)
    # Index 1 minus index 0 is total minus free, i.e. the amount in use,
    # converted from bytes to GiB.
    used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
    return used_memory < self.vram_limit
@@ -309,6 +310,7 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
self.lora_B_weights = []
self.lora_merger = None
self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
self.computation_device_type = parse_device_type(self.computation_device)
if offload_dtype == "disk":
self.disk_map = disk_map