Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-18 22:08:13 +00:00)
@@ -143,10 +143,8 @@ class BasePipeline(torch.nn.Module):
         self.vram_management_enabled = True


-    def get_free_vram(self):
-        total_memory = torch.cuda.get_device_properties(self.device).total_memory
-        allocated_memory = torch.cuda.device_memory_used(self.device)
-        return (total_memory - allocated_memory) / (1024 ** 3)
+    def get_vram(self):
+        return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3)


     def freeze_except(self, model_names):
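Note (not part of the diff): torch.cuda.mem_get_info wraps cudaMemGetInfo and returns a (free_bytes, total_bytes) tuple, so indexing [1] as the new get_vram does yields the device's total VRAM, and dividing by 1024 ** 3 converts bytes to GiB. A minimal standalone sketch with a hypothetical helper name:

import torch

def total_vram_gib(device="cuda:0"):
    # torch.cuda.mem_get_info returns (free_bytes, total_bytes) for the device
    free_bytes, total_bytes = torch.cuda.mem_get_info(device)
    return total_bytes / (1024 ** 3)

if torch.cuda.is_available():
    print(f"Total VRAM: {total_vram_gib():.2f} GiB")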
@@ -247,7 +245,7 @@ class WanVideoPipeline(BasePipeline):
             vram_limit = None
         else:
             if vram_limit is None:
-                vram_limit = self.get_free_vram()
+                vram_limit = self.get_vram()
             vram_limit = vram_limit - vram_buffer
         if self.text_encoder is not None:
             dtype = next(iter(self.text_encoder.parameters())).dtype
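For illustration only (hypothetical function name, not from the repository): after this change an unspecified vram_limit falls back to the device's total VRAM rather than the currently free VRAM, and vram_buffer is then subtracted to leave headroom. A sketch of that resolution logic:

import torch

def resolve_vram_limit(vram_limit, vram_buffer, device="cuda:0"):
    # When no explicit limit is given, start from total VRAM in GiB
    # (as get_vram() now does), then reserve vram_buffer GiB of headroom.
    if vram_limit is None:
        vram_limit = torch.cuda.mem_get_info(device)[1] / (1024 ** 3)
    return vram_limit - vram_buffer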
@@ -13,7 +13,8 @@ class AutoTorchModule(torch.nn.Module):
         super().__init__()

     def check_free_vram(self):
-        used_memory = torch.cuda.device_memory_used(self.computation_device) / (1024 ** 3)
+        gpu_mem_state = torch.cuda.mem_get_info(self.computation_device)
+        used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024 ** 3)
         return used_memory < self.vram_limit

     def offload(self):
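As a reference point (hypothetical standalone helper, not part of the commit), the updated check_free_vram derives used memory as total minus free from mem_get_info and compares it against a limit expressed in GiB:

import torch

def under_vram_limit(vram_limit_gib, device="cuda:0"):
    # used memory = total - free, both reported in bytes by mem_get_info
    free_bytes, total_bytes = torch.cuda.mem_get_info(device)
    used_gib = (total_bytes - free_bytes) / (1024 ** 3)
    return used_gib < vram_limit_gib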