support torch<2.6.0

Artiprocher
2025-06-16 13:05:54 +08:00
parent 3d2b51554a
commit c0706e3fbd
3 changed files with 6 additions and 7 deletions

View File

@@ -143,10 +143,8 @@ class BasePipeline(torch.nn.Module):
         self.vram_management_enabled = True
 
-    def get_free_vram(self):
-        total_memory = torch.cuda.get_device_properties(self.device).total_memory
-        allocated_memory = torch.cuda.device_memory_used(self.device)
-        return (total_memory - allocated_memory) / (1024 ** 3)
+    def get_vram(self):
+        return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3)
 
     def freeze_except(self, model_names):
@@ -247,7 +245,7 @@ class WanVideoPipeline(BasePipeline):
             vram_limit = None
         else:
             if vram_limit is None:
-                vram_limit = self.get_free_vram()
+                vram_limit = self.get_vram()
             vram_limit = vram_limit - vram_buffer
         if self.text_encoder is not None:
             dtype = next(iter(self.text_encoder.parameters())).dtype

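Note on the change above: torch.cuda.device_memory_used was only added in torch 2.6.0, while torch.cuda.mem_get_info is available on the torch>=2.0.0 floor this commit moves to, and returns a (free_bytes, total_bytes) tuple. The rename also reflects a semantic shift: get_free_vram reported currently free memory, whereas get_vram reports total device memory (mem_get_info(...)[1]), so the default vram_limit becomes total VRAM minus the buffer rather than free VRAM. A minimal sketch of both readings (helper names are illustrative, not part of the repository):

import torch

def total_vram_gb(device="cuda:0"):
    # mem_get_info returns (free_bytes, total_bytes); index 1 is the total.
    return torch.cuda.mem_get_info(device)[1] / (1024 ** 3)

def free_vram_gb(device="cuda:0"):
    # Index 0 is the memory currently free on the whole device.
    return torch.cuda.mem_get_info(device)[0] / (1024 ** 3)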
View File

@@ -13,7 +13,8 @@ class AutoTorchModule(torch.nn.Module):
         super().__init__()
 
     def check_free_vram(self):
-        used_memory = torch.cuda.device_memory_used(self.computation_device) / (1024 ** 3)
+        gpu_mem_state = torch.cuda.mem_get_info(self.computation_device)
+        used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024 ** 3)
         return used_memory < self.vram_limit
 
     def offload(self):

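check_free_vram now derives used memory as total minus free from mem_get_info, a device-wide figure close to what device_memory_used reported. A sketch of a version-agnostic shim (hypothetical helper, not in the repository) showing how the two APIs line up:

import torch

def device_used_gb(device):
    # torch.cuda.device_memory_used exists only on torch>=2.6.0; on older
    # versions, approximate it as total - free from mem_get_info, mirroring
    # the updated check_free_vram.
    if hasattr(torch.cuda, "device_memory_used"):
        used = torch.cuda.device_memory_used(device)
    else:
        free, total = torch.cuda.mem_get_info(device)
        used = total - free
    return used / (1024 ** 3)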
View File

@@ -1,4 +1,4 @@
-torch>=2.6.0
+torch>=2.0.0
 torchvision
 cupy-cuda12x
 transformers
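With the pin loosened to torch>=2.0.0, a quick sanity check (illustrative snippet) confirms the fallback API exists on an older install:

import torch

print(torch.__version__)                  # e.g. 2.3.1
free, total = torch.cuda.mem_get_info(0)  # available well before torch 2.6
print(f"free={free / 1024**3:.1f} GiB, total={total / 1024**3:.1f} GiB")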