Merge branch 'modelscope:main' into wan_rope

This commit is contained in:
Feng
2026-01-12 11:21:09 +08:00
committed by GitHub
17 changed files with 322 additions and 19 deletions

View File

@@ -1 +1,2 @@
from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type
from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type, get_device_name
from .npu_compatible_device import IS_NPU_AVAILABLE

View File

@@ -2,7 +2,7 @@ import torch, copy
from typing import Union
from .initialization import skip_model_initialization
from .disk_map import DiskMap
from ..device import parse_device_type
from ..device import parse_device_type, get_device_name, IS_NPU_AVAILABLE
class AutoTorchModule(torch.nn.Module):
@@ -63,7 +63,7 @@ class AutoTorchModule(torch.nn.Module):
return r
def check_free_vram(self):
device = self.computation_device if self.computation_device != "npu" else "npu:0"
device = self.computation_device if not IS_NPU_AVAILABLE else get_device_name()
gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(device)
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
return used_memory < self.vram_limit

View File

@@ -7,6 +7,7 @@ from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelCon
from ..utils.lora import GeneralLoRALoader
from ..models.model_loader import ModelPool
from ..utils.controlnet import ControlNetInput
from ..core.device import get_device_name, IS_NPU_AVAILABLE
class PipelineUnit:
@@ -177,7 +178,7 @@ class BasePipeline(torch.nn.Module):
def get_vram(self):
device = self.device if self.device != "npu" else "npu:0"
device = self.device if not IS_NPU_AVAILABLE else get_device_name()
return getattr(torch, self.device_type).mem_get_info(device)[1] / (1024 ** 3)
def get_module(self, model, name):

View File

@@ -8,6 +8,7 @@ from torch.nn.utils.rnn import pad_sequence
from torch.nn import RMSNorm
from ..core.attention import attention_forward
from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE
from ..core.gradient import gradient_checkpoint_forward
@@ -315,7 +316,10 @@ class RopeEmbedder:
result = []
for i in range(len(self.axes_dims)):
index = ids[:, i]
result.append(self.freqs_cis[i][index])
if IS_NPU_AVAILABLE:
result.append(torch.index_select(self.freqs_cis[i], 0, index))
else:
result.append(self.freqs_cis[i][index])
return torch.cat(result, dim=-1)