Support Ascend NPU

commit 2883bc1b76 (parent 78d8842ddf)
Author: Artiprocher
Date: 2025-12-15 15:48:42 +08:00

11 changed files with 242 additions and 9 deletions


```diff
@@ -3,7 +3,7 @@ import torch
 import numpy as np
 from einops import repeat, reduce
 from typing import Union
-from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig
+from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
 from ..utils.lora import GeneralLoRALoader
 from ..models.model_loader import ModelPool
 from ..utils.controlnet import ControlNetInput
```
```diff
@@ -68,6 +68,7 @@ class BasePipeline(torch.nn.Module):
         # The device and torch_dtype is used for the storage of intermediate variables, not models.
         self.device = device
         self.torch_dtype = torch_dtype
+        self.device_type = parse_device_type(device)
         # The following parameters are used for shape check.
         self.height_division_factor = height_division_factor
         self.width_division_factor = width_division_factor
```
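`parse_device_type` itself is defined in `..core` and is not shown in this diff. As a minimal sketch, assuming the helper only needs to map a device string or `torch.device` such as `"cuda:0"` or `"npu:1"` to its backend name, it could look like this:

```python
import torch

def parse_device_type(device) -> str:
    # Hypothetical sketch; the real implementation lives in ..core.
    # torch.device("cuda:0").type == "cuda", torch.device("npu:1").type == "npu".
    # Note: "npu" only becomes a valid device type once `import torch_npu`
    # has registered the Ascend backend with PyTorch.
    return torch.device(device).type
```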
```diff
@@ -154,7 +155,7 @@ class BasePipeline(torch.nn.Module):
                 for module in model.modules():
                     if hasattr(module, "offload"):
                         module.offload()
-        torch.cuda.empty_cache()
+        getattr(torch, self.device_type).empty_cache()
         # onload models
         for name, model in self.named_children():
             if name in model_names:
```
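The key change here is the dispatch: `getattr(torch, self.device_type)` resolves to `torch.cuda` on NVIDIA GPUs and, once `torch_npu` is imported, to `torch.npu` on Ascend hardware, both of which expose `empty_cache()`. A standalone sketch of the same pattern (the helper name and the guard for backends without `empty_cache` are additions, not part of this commit):

```python
import torch

def empty_device_cache(device_type: str) -> None:
    # torch.cuda and torch.npu (registered by `import torch_npu`) both
    # provide empty_cache(); guard against backends that do not, e.g. "cpu".
    backend = getattr(torch, device_type, None)
    if backend is not None and hasattr(backend, "empty_cache"):
        backend.empty_cache()

empty_device_cache("cuda")
```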
```diff
@@ -176,7 +177,7 @@ class BasePipeline(torch.nn.Module):
     def get_vram(self):
-        return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3)
+        return getattr(torch, self.device_type).mem_get_info(self.device)[1] / (1024 ** 3)

     def get_module(self, model, name):
         if "." in name:
```