Merge pull request #1136 from modelscope/bugfix-device

bugfix
This commit is contained in:
Zhongjie Duan
2025-12-16 16:12:05 +08:00
committed by GitHub
4 changed files with 7 additions and 4 deletions

View File

@@ -63,7 +63,8 @@ class AutoTorchModule(torch.nn.Module):
return r
def check_free_vram(self):
gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(self.computation_device)
device = self.computation_device if self.computation_device != "npu" else "npu:0"
gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(device)
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
return used_memory < self.vram_limit
@@ -309,6 +310,7 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
self.lora_B_weights = []
self.lora_merger = None
self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
self.computation_device_type = parse_device_type(self.computation_device)
if offload_dtype == "disk":
self.disk_map = disk_map

View File

@@ -177,7 +177,8 @@ class BasePipeline(torch.nn.Module):
def get_vram(self):
return getattr(torch, self.device_type).mem_get_info(self.device)[1] / (1024 ** 3)
device = self.device if self.device != "npu" else "npu:0"
return getattr(torch, self.device_type).mem_get_info(device)[1] / (1024 ** 3)
def get_module(self, model, name):
if "." in name:

View File

@@ -46,7 +46,7 @@ pipe = WanVideoPipeline.from_pretrained(
],
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
- vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
+ vram_limit=torch.npu.mem_get_info("npu")[1] / (1024 ** 3) - 2,
+ vram_limit=torch.npu.mem_get_info("npu:0")[1] / (1024 ** 3) - 2,
)
video = pipe(

View File

@@ -46,7 +46,7 @@ pipe = WanVideoPipeline.from_pretrained(
],
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
- vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
+ vram_limit=torch.npu.mem_get_info("npu")[1] / (1024 ** 3) - 2,
+ vram_limit=torch.npu.mem_get_info("npu:0")[1] / (1024 ** 3) - 2,
)
video = pipe(