mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Merge branch 'modelscope:main' into wan_rope
This commit is contained in:
@@ -1 +1,2 @@
|
||||
from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type
|
||||
from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type, get_device_name
|
||||
from .npu_compatible_device import IS_NPU_AVAILABLE
|
||||
|
||||
@@ -2,7 +2,7 @@ import torch, copy
|
||||
from typing import Union
|
||||
from .initialization import skip_model_initialization
|
||||
from .disk_map import DiskMap
|
||||
from ..device import parse_device_type
|
||||
from ..device import parse_device_type, get_device_name, IS_NPU_AVAILABLE
|
||||
|
||||
|
||||
class AutoTorchModule(torch.nn.Module):
|
||||
@@ -63,7 +63,7 @@ class AutoTorchModule(torch.nn.Module):
|
||||
return r
|
||||
|
||||
def check_free_vram(self):
|
||||
device = self.computation_device if self.computation_device != "npu" else "npu:0"
|
||||
device = self.computation_device if not IS_NPU_AVAILABLE else get_device_name()
|
||||
gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(device)
|
||||
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
|
||||
return used_memory < self.vram_limit
|
||||
|
||||
@@ -7,6 +7,7 @@ from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelCon
|
||||
from ..utils.lora import GeneralLoRALoader
|
||||
from ..models.model_loader import ModelPool
|
||||
from ..utils.controlnet import ControlNetInput
|
||||
from ..core.device import get_device_name, IS_NPU_AVAILABLE
|
||||
|
||||
|
||||
class PipelineUnit:
|
||||
@@ -177,7 +178,7 @@ class BasePipeline(torch.nn.Module):
|
||||
|
||||
|
||||
def get_vram(self):
|
||||
device = self.device if self.device != "npu" else "npu:0"
|
||||
device = self.device if not IS_NPU_AVAILABLE else get_device_name()
|
||||
return getattr(torch, self.device_type).mem_get_info(device)[1] / (1024 ** 3)
|
||||
|
||||
def get_module(self, model, name):
|
||||
|
||||
@@ -8,6 +8,7 @@ from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from torch.nn import RMSNorm
|
||||
from ..core.attention import attention_forward
|
||||
from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE
|
||||
from ..core.gradient import gradient_checkpoint_forward
|
||||
|
||||
|
||||
@@ -315,7 +316,10 @@ class RopeEmbedder:
|
||||
result = []
|
||||
for i in range(len(self.axes_dims)):
|
||||
index = ids[:, i]
|
||||
result.append(self.freqs_cis[i][index])
|
||||
if IS_NPU_AVAILABLE:
|
||||
result.append(torch.index_select(self.freqs_cis[i], 0, index))
|
||||
else:
|
||||
result.append(self.freqs_cis[i][index])
|
||||
return torch.cat(result, dim=-1)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user