resolve conflicts

Artiprocher
2026-02-04 17:07:30 +08:00
parent ca9b5e64ea
commit 6fe897883b


@@ -3,24 +3,20 @@ from ..vram.disk_map import DiskMap
from ..vram.layers import enable_vram_management
from .file import load_state_dict
import torch
from contextlib import contextmanager
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils import ContextManagers
-def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None,
-use_disk_map=False, module_map=None, vram_config=None, vram_limit=None):
+def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
config = {} if config is None else config
with ContextManagers(get_init_context(torch_dtype=torch_dtype, device=device)):
model = model_class(**config)
# What is `module_map`?
# This is a module mapping table for VRAM management.
if module_map is not None:
-devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"],
-vram_config["computation_device"]]
+devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"], vram_config["computation_device"]]
device = [d for d in devices if d != "disk"][0]
-dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"],
-vram_config["computation_dtype"]]
+dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
dtype = [d for d in dtypes if d != "disk"][0]
if vram_config["offload_device"] != "disk":
state_dict = DiskMap(path, device, torch_dtype=dtype)
@@ -29,12 +25,10 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
else:
state_dict = {i: state_dict[i] for i in state_dict}
model.load_state_dict(state_dict, assign=True)
-model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None,
-vram_limit=vram_limit)
+model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None, vram_limit=vram_limit)
else:
disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
-model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map,
-vram_limit=vram_limit)
+model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=vram_limit)
else:
# Why do we use `DiskMap`?
# Sometimes a model file contains multiple models,
@@ -51,6 +45,9 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
state_dict = state_dict_converter(state_dict)
else:
state_dict = {i: state_dict[i] for i in state_dict}
+# Why does DeepSpeed ZeRO Stage 3 need to be handled separately?
+# Because at this stage, model parameters are partitioned across multiple GPUs.
+# Loading them directly could lead to excessive GPU memory consumption.
if is_deepspeed_zero3_enabled():
from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model
_load_state_dict_into_zero3_model(model, state_dict)
@@ -65,8 +62,7 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
return model
-def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu",
-state_dict_converter=None, module_map=None):
+def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, module_map=None):
if isinstance(path, str):
path = [path]
config = {} if config is None else config
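For readers skimming this diff, the device/dtype selection in the first hunk can be easy to miss. Below is a minimal, runnable sketch of that selection logic using a hypothetical `vram_config`; the concrete device and dtype values are placeholders chosen for illustration, not defaults from this repository.

import torch

# Hypothetical VRAM configuration; the keys mirror the ones read in the diff above.
vram_config = {
    "offload_device": "disk", "onload_device": "cpu",
    "preparing_device": "cuda", "computation_device": "cuda",
    "offload_dtype": "disk", "onload_dtype": torch.bfloat16,
    "preparing_dtype": torch.bfloat16, "computation_dtype": torch.bfloat16,
}

# Same selection as in `load_model`: take the first entry that is not "disk".
devices = [vram_config["offload_device"], vram_config["onload_device"],
           vram_config["preparing_device"], vram_config["computation_device"]]
device = [d for d in devices if d != "disk"][0]   # -> "cpu"

dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"],
          vram_config["preparing_dtype"], vram_config["computation_dtype"]]
dtype = [d for d in dtypes if d != "disk"][0]     # -> torch.bfloat16

print(device, dtype)

With the offload device set to "disk", the selection falls through to the onload entries, which is why the diff filters the lists rather than reading a single fixed key.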