diff --git a/diffsynth/core/vram/layers.py b/diffsynth/core/vram/layers.py
index 01ade0e..751792d 100644
--- a/diffsynth/core/vram/layers.py
+++ b/diffsynth/core/vram/layers.py
@@ -63,7 +63,8 @@ class AutoTorchModule(torch.nn.Module):
         return r

     def check_free_vram(self):
-        gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(self.computation_device)
+        device = self.computation_device if self.computation_device != "npu" else "npu:0"
+        gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(device)
         used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
         return used_memory < self.vram_limit

@@ -309,6 +310,7 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
         self.lora_B_weights = []
         self.lora_merger = None
         self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
+        self.computation_device_type = parse_device_type(self.computation_device)
         if offload_dtype == "disk":
             self.disk_map = disk_map

diff --git a/diffsynth/diffusion/base_pipeline.py b/diffsynth/diffusion/base_pipeline.py
index 0140497..fa355a1 100644
--- a/diffsynth/diffusion/base_pipeline.py
+++ b/diffsynth/diffusion/base_pipeline.py
@@ -177,7 +177,8 @@ class BasePipeline(torch.nn.Module):


     def get_vram(self):
-        return getattr(torch, self.device_type).mem_get_info(self.device)[1] / (1024 ** 3)
+        device = self.device if self.device != "npu" else "npu:0"
+        return getattr(torch, self.device_type).mem_get_info(device)[1] / (1024 ** 3)

     def get_module(self, model, name):
         if "." in name:
diff --git a/docs/en/Pipeline_Usage/GPU_support.md b/docs/en/Pipeline_Usage/GPU_support.md
index 2f206eb..789d26a 100644
--- a/docs/en/Pipeline_Usage/GPU_support.md
+++ b/docs/en/Pipeline_Usage/GPU_support.md
@@ -46,7 +46,7 @@ pipe = WanVideoPipeline.from_pretrained(
     ],
     tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
 -    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
-+    vram_limit=torch.npu.mem_get_info("npu")[1] / (1024 ** 3) - 2,
++    vram_limit=torch.npu.mem_get_info("npu:0")[1] / (1024 ** 3) - 2,
 )

 video = pipe(
diff --git a/docs/zh/Pipeline_Usage/GPU_support.md b/docs/zh/Pipeline_Usage/GPU_support.md
index 3ba76fc..56d78f7 100644
--- a/docs/zh/Pipeline_Usage/GPU_support.md
+++ b/docs/zh/Pipeline_Usage/GPU_support.md
@@ -46,7 +46,7 @@ pipe = WanVideoPipeline.from_pretrained(
     ],
     tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
 -    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
-+    vram_limit=torch.npu.mem_get_info("npu")[1] / (1024 ** 3) - 2,
++    vram_limit=torch.npu.mem_get_info("npu:0")[1] / (1024 ** 3) - 2,
 )

 video = pipe(
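
For reference, a minimal sketch of the normalization this patch applies, assuming `torch_npu` is installed so that `torch.npu` is available. The helper names `normalized_device` and `used_vram_gb` are illustrative only, not part of the repository:

```python
# Sketch of the device normalization introduced by this patch.
# Assumption: torch_npu is installed, which registers torch.npu;
# normalized_device and used_vram_gb are hypothetical helper names.
import torch

def normalized_device(device: str) -> str:
    # The patch suggests torch.npu.mem_get_info needs an explicit device
    # index, so a bare "npu" is pinned to "npu:0"; any other device string
    # (e.g. "cuda", "cuda:0", "npu:1") passes through unchanged.
    return "npu:0" if device == "npu" else device

def used_vram_gb(device_type: str, device: str) -> float:
    # mem_get_info returns (free_bytes, total_bytes); used = total - free,
    # mirroring the patched check_free_vram computation.
    free, total = getattr(torch, device_type).mem_get_info(normalized_device(device))
    return (total - free) / (1024 ** 3)

# e.g. used_vram_gb("cuda", "cuda") on NVIDIA, used_vram_gb("npu", "npu") on Ascend.
```

Because only the bare `"npu"` string is rewritten, the existing CUDA paths and any explicitly indexed device strings behave exactly as before; the docs change likewise pins the example's `torch.npu.mem_get_info` call to `"npu:0"` for the same reason.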