This commit is contained in:
Artiprocher
2025-12-16 16:09:29 +08:00
parent bfaaf12bf4
commit 1547c3f786
4 changed files with 7 additions and 4 deletions

View File

@@ -177,7 +177,8 @@ class BasePipeline(torch.nn.Module):
def get_vram(self):
return getattr(torch, self.device_type).mem_get_info(self.device)[1] / (1024 ** 3)
device = self.device if self.device != "npu" else "npu:0"
return getattr(torch, self.device_type).mem_get_info(device)[1] / (1024 ** 3)
def get_module(self, model, name):
if "." in name: