from ..vram.initialization import skip_model_initialization
from ..vram.disk_map import DiskMap
from ..vram.layers import enable_vram_management
from .file import load_state_dict
import torch
from contextlib import contextmanager
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils import ContextManagers


def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
    config = {} if config is None else config
    with ContextManagers(get_init_context(torch_dtype=torch_dtype, device=device)):
        model = model_class(**config)
    # What is `module_map`?
    # This is a module mapping table for VRAM management.
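    # A plausible shape for `module_map` (illustrative only; the wrapper names
    # below are hypothetical and not guaranteed to match `..vram.layers`):
    #     module_map = {
    #         torch.nn.Linear: AutoWrappedLinear,
    #         torch.nn.LayerNorm: AutoWrappedModule,
    #     }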
    if module_map is not None:
        devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"], vram_config["computation_device"]]
        device = [d for d in devices if d != "disk"][0]
        dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
        dtype = [d for d in dtypes if d != "disk"][0]
        if vram_config["offload_device"] != "disk":
            state_dict = DiskMap(path, device, torch_dtype=dtype)
            if state_dict_converter is not None:
                state_dict = state_dict_converter(state_dict)
            else:
                state_dict = {i: state_dict[i] for i in state_dict}
            model.load_state_dict(state_dict, assign=True)
            model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None, vram_limit=vram_limit)
        else:
            disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
            model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=vram_limit)
    else:
        # Why do we use `DiskMap`?
        # Sometimes a model file contains multiple models,
        # and DiskMap can load only the parameters of a single model,
        # avoiding the need to load all parameters in the file.
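        # `DiskMap` serves as a lazy, dict-like view of the checkpoint:
        # iterating it yields parameter names, and indexing reads a single
        # tensor on demand (the key below is illustrative):
        #     weight = state_dict["blocks.0.attn.to_q.weight"]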
        if use_disk_map:
            state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
        else:
            state_dict = load_state_dict(path, torch_dtype, device)
        # Why do we use `state_dict_converter`?
        # Some models are saved in complex formats,
        # and we need to convert the state dict into the appropriate format.
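        # A converter is simply a callable from state dict to state dict.
        # A minimal sketch, assuming (hypothetically) that checkpoint keys
        # carry a "model." prefix the module tree does not expect:
        #     def strip_prefix(sd):
        #         return {k.removeprefix("model."): v for k, v in sd.items()}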
        if state_dict_converter is not None:
            state_dict = state_dict_converter(state_dict)
        else:
            state_dict = {i: state_dict[i] for i in state_dict}
        # Why does DeepSpeed ZeRO Stage 3 need to be handled separately?
        # Because at this stage, model parameters are partitioned across multiple GPUs.
        # Loading them directly could lead to excessive GPU memory consumption.
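        # `_load_state_dict_into_zero3_model` is a private transformers helper;
        # it copies weights into the ZeRO-3 partitioned parameters without
        # materializing the full model on a single device.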
        if is_deepspeed_zero3_enabled():
            from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model
            _load_state_dict_into_zero3_model(model, state_dict)
        else:
            model.load_state_dict(state_dict, assign=True)
        # Why do we call `to()`?
        # Because some models override the behavior of `to()`,
        # especially those from libraries like Transformers.
        model = model.to(dtype=torch_dtype, device=device)
    if hasattr(model, "eval"):
        model = model.eval()
    return model
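
# A minimal usage sketch for `load_model` (illustrative only; `MyDiT` and the
# checkpoint path are placeholders, not names defined in this repository):
#     model = load_model(
#         MyDiT,                        # any torch.nn.Module subclass
#         "models/dit.safetensors",     # checkpoint on disk
#         config={"dim": 1024},         # kwargs forwarded to MyDiT(...)
#         torch_dtype=torch.bfloat16,
#         device="cuda",
#     )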

def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, module_map=None):
    if isinstance(path, str):
        path = [path]
    config = {} if config is None else config
    with skip_model_initialization():
        model = model_class(**config)
    if hasattr(model, "eval"):
        model = model.eval()
    disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
    vram_config = {
        "offload_dtype": "disk",
        "offload_device": "disk",
        "onload_dtype": "disk",
        "onload_device": "disk",
        "preparing_dtype": torch.float8_e4m3fn,
        "preparing_device": device,
        "computation_dtype": torch_dtype,
        "computation_device": device,
    }
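    # Read as a parameter lifecycle: weights rest on "disk" while offloaded
    # and onloaded, are staged on `device` in float8 while being prepared,
    # and are cast to `torch_dtype` for computation.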
    enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80)
    return model
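
# Usage sketch for `load_model_with_disk_offload` (illustrative; `MyDiT`,
# the checkpoint path, and `my_module_map` are placeholders):
#     model = load_model_with_disk_offload(
#         MyDiT, "models/dit.safetensors",
#         torch_dtype=torch.bfloat16, device="cuda", module_map=my_module_map,
#     )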

def get_init_context(torch_dtype, device):
    if is_deepspeed_zero3_enabled():
        from transformers.modeling_utils import set_zero3_state
        import deepspeed
        # Why do we use `deepspeed.zero.Init`?
        # It lets the model weights be partitioned on the CPU side,
        # then loads each partition onto its target computing card.
        init_contexts = [deepspeed.zero.Init(remote_device=device, dtype=torch_dtype), set_zero3_state()]
    else:
        # Why do we use `skip_model_initialization`?
        # It skips the random initialization of model parameters,
        # thereby speeding up model loading and avoiding excessive memory usage.
        init_contexts = [skip_model_initialization()]
    return init_contexts
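
# The list returned by `get_init_context` is meant to be entered together via
# `transformers.utils.ContextManagers`, exactly as `load_model` does above
# (`SomeModelClass` is a placeholder):
#     with ContextManagers(get_init_context(torch_dtype=torch.bfloat16, device="cpu")):
#         model = SomeModelClass()  # parameters are created without random init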