VRAM management: add support for torch<2.6.0 (#613)

support torch<2.6.0
This commit is contained in:
Zhongjie Duan
2025-06-16 13:08:29 +08:00
committed by GitHub
parent 8584e50309
commit c164519ef1
2 changed files with 5 additions and 6 deletions

View File

@@ -13,7 +13,8 @@ class AutoTorchModule(torch.nn.Module):
super().__init__()
def check_free_vram(self):
    """Return True if current GPU memory usage is below the configured VRAM limit.

    Uses ``torch.cuda.mem_get_info`` (available in all supported torch
    versions, unlike ``torch.cuda.device_memory_used`` which requires
    torch>=2.6.0) to read the free/total memory of ``self.computation_device``.

    Returns:
        bool: True when used memory (in GiB) is strictly below
        ``self.vram_limit`` (assumed to be expressed in GiB — confirm
        against where ``vram_limit`` is set).
    """
    # mem_get_info returns (free_bytes, total_bytes); used = total - free.
    free_bytes, total_bytes = torch.cuda.mem_get_info(self.computation_device)
    used_memory = (total_bytes - free_bytes) / (1024 ** 3)  # bytes -> GiB
    return used_memory < self.vram_limit
def offload(self):