support torch<2.6.0

This commit is contained in:
Artiprocher
2025-06-16 13:05:54 +08:00
parent 3d2b51554a
commit c0706e3fbd
3 changed files with 6 additions and 7 deletions

View File

@@ -13,7 +13,8 @@ class AutoTorchModule(torch.nn.Module):
super().__init__()
def check_free_vram(self):
used_memory = torch.cuda.device_memory_used(self.computation_device) / (1024 ** 3)
gpu_mem_state = torch.cuda.mem_get_info(self.computation_device)
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024 ** 3)
return used_memory < self.vram_limit
def offload(self):