Mirror of https://github.com/modelscope/DiffSynth-Studio.git, synced 2026-03-19 06:39:43 +00:00
Merge pull request #1169 from Feng0w0/sample_add
Docs: Supplement NPU training script samples and documentation instructions
@@ -1 +1,2 @@
-from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type
+from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type, get_device_name
+from .npu_compatible_device import IS_NPU_AVAILABLE
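The helpers re-exported here live in npu_compatible_device, which this diff does not touch. As a rough sketch of what they plausibly do — the names come from the import above, but every body below is an assumption, written against the Ascend torch_npu extension:

# Hedged sketch of the npu_compatible_device helpers named in the import
# above; the real module is not shown in this diff, so the bodies are
# assumptions based on the Ascend torch_npu extension.
import torch

try:
    import torch_npu  # patches the torch.npu namespace on Ascend builds
    IS_NPU_AVAILABLE = torch.npu.is_available()
except ImportError:
    IS_NPU_AVAILABLE = False

def get_available_device_type():
    # Prefer NPU, then CUDA, falling back to CPU.
    if IS_NPU_AVAILABLE:
        return "npu"
    return "cuda" if torch.cuda.is_available() else "cpu"

def parse_device_type(device):
    # Normalize "npu:0" or torch.device("cuda", 1) to a bare type string.
    return torch.device(device).type

def parse_nccl_backend():
    # Ascend devices use HCCL in place of NCCL for collectives.
    return "hccl" if IS_NPU_AVAILABLE else "nccl"

def get_device_name():
    # Fully qualified name of the active device, e.g. "npu:0".
    if IS_NPU_AVAILABLE:
        return f"npu:{torch.npu.current_device()}"
    return get_available_device_type()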
@@ -2,7 +2,7 @@ import torch, copy
 from typing import Union
 from .initialization import skip_model_initialization
 from .disk_map import DiskMap
-from ..device import parse_device_type
+from ..device import parse_device_type, get_device_name, IS_NPU_AVAILABLE


 class AutoTorchModule(torch.nn.Module):
@@ -63,7 +63,7 @@ class AutoTorchModule(torch.nn.Module):
         return r

     def check_free_vram(self):
-        device = self.computation_device if self.computation_device != "npu" else "npu:0"
+        device = self.computation_device if not IS_NPU_AVAILABLE else get_device_name()
         gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(device)
         used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
         return used_memory < self.vram_limit
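Both patched branches stop special-casing the literal string "npu" (which the old code pinned to "npu:0") and instead resolve the active card via get_device_name() whenever an NPU build is detected. The surrounding arithmetic relies on mem_get_info, which returns a (free_bytes, total_bytes) tuple, so total minus free is the memory currently in use. A standalone sketch of the same pattern; the device defaults and the 20 GiB limit are illustrative, not from the diff:

import torch

def used_vram_gb(device_type="cuda", device="cuda:0"):
    # mem_get_info returns a (free_bytes, total_bytes) tuple, so
    # total - free is the amount of memory currently in use.
    free, total = getattr(torch, device_type).mem_get_info(device)
    return (total - free) / (1024 ** 3)

# check_free_vram-style gate: True while usage stays under the limit.
vram_limit = 20.0  # GiB; illustrative threshold, not from the diff
if torch.cuda.is_available():
    print(used_vram_gb() < vram_limit)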
@@ -7,6 +7,7 @@ from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelCon
 from ..utils.lora import GeneralLoRALoader
 from ..models.model_loader import ModelPool
 from ..utils.controlnet import ControlNetInput
+from ..core.device import get_device_name, IS_NPU_AVAILABLE


 class PipelineUnit:
@@ -177,7 +178,7 @@ class BasePipeline(torch.nn.Module):


     def get_vram(self):
-        device = self.device if self.device != "npu" else "npu:0"
+        device = self.device if not IS_NPU_AVAILABLE else get_device_name()
         return getattr(torch, self.device_type).mem_get_info(device)[1] / (1024 ** 3)

     def get_module(self, model, name):
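For get_vram, note that mem_get_info(...)[1] is the total, not the free, byte count, so the method reports the card's capacity in GiB. A minimal consumption sketch; pipe, the policy strings, and the 8 GiB threshold are assumptions for illustration:

# Illustrative only: pipe stands in for a constructed BasePipeline subclass.
def pick_offload_policy(pipe):
    total_gb = pipe.get_vram()  # total capacity of the active device, in GiB
    # Small cards fall back to CPU offloading (threshold is an assumption).
    return "cpu_offload" if total_gb < 8 else "keep_on_device"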