diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py
index 44a64c9..9086af2 100644
--- a/diffsynth/pipelines/wan_video_new.py
+++ b/diffsynth/pipelines/wan_video_new.py
@@ -143,10 +143,8 @@ class BasePipeline(torch.nn.Module):
         self.vram_management_enabled = True
 
 
-    def get_free_vram(self):
-        total_memory = torch.cuda.get_device_properties(self.device).total_memory
-        allocated_memory = torch.cuda.device_memory_used(self.device)
-        return (total_memory - allocated_memory) / (1024 ** 3)
+    def get_vram(self):
+        return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3)
 
 
     def freeze_except(self, model_names):
@@ -247,7 +245,7 @@ class WanVideoPipeline(BasePipeline):
             vram_limit = None
         else:
             if vram_limit is None:
-                vram_limit = self.get_free_vram()
+                vram_limit = self.get_vram()
             vram_limit = vram_limit - vram_buffer
         if self.text_encoder is not None:
             dtype = next(iter(self.text_encoder.parameters())).dtype
diff --git a/diffsynth/vram_management/layers.py b/diffsynth/vram_management/layers.py
index dd4a245..c0beaf8 100644
--- a/diffsynth/vram_management/layers.py
+++ b/diffsynth/vram_management/layers.py
@@ -13,7 +13,8 @@ class AutoTorchModule(torch.nn.Module):
         super().__init__()
 
     def check_free_vram(self):
-        used_memory = torch.cuda.device_memory_used(self.computation_device) / (1024 ** 3)
+        gpu_mem_state = torch.cuda.mem_get_info(self.computation_device)
+        used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024 ** 3)
         return used_memory < self.vram_limit
 
     def offload(self):
diff --git a/requirements.txt b/requirements.txt
index c25b688..92d8b48 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-torch>=2.6.0
+torch>=2.0.0
 torchvision
 cupy-cuda12x
 transformers