Support loading models from a state dict

This commit is contained in:
Artiprocher
2026-01-30 13:47:36 +08:00
parent 22695e9be0
commit ee9a3b4405
5 changed files with 27 additions and 13 deletions

View File

@@ -1,5 +1,5 @@
import torch, glob, os import torch, glob, os
from typing import Optional, Union from typing import Optional, Union, Dict
from dataclasses import dataclass from dataclasses import dataclass
from modelscope import snapshot_download from modelscope import snapshot_download
from huggingface_hub import snapshot_download as hf_snapshot_download from huggingface_hub import snapshot_download as hf_snapshot_download
@@ -23,6 +23,7 @@ class ModelConfig:
computation_device: Optional[Union[str, torch.device]] = None computation_device: Optional[Union[str, torch.device]] = None
computation_dtype: Optional[torch.dtype] = None computation_dtype: Optional[torch.dtype] = None
clear_parameters: bool = False clear_parameters: bool = False
state_dict: Dict[str, torch.Tensor] = None
def check_input(self): def check_input(self):
if self.path is None and self.model_id is None: if self.path is None and self.model_id is None:

View File

@@ -2,16 +2,25 @@ from safetensors import safe_open
import torch, hashlib import torch, hashlib
def load_state_dict(file_path, torch_dtype=None, device="cpu"): def load_state_dict(file_path, torch_dtype=None, device="cpu", pin_memory=False, verbose=0):
if isinstance(file_path, list): if isinstance(file_path, list):
state_dict = {} state_dict = {}
for file_path_ in file_path: for file_path_ in file_path:
state_dict.update(load_state_dict(file_path_, torch_dtype, device)) state_dict.update(load_state_dict(file_path_, torch_dtype, device, pin_memory=pin_memory, verbose=verbose))
return state_dict
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
else: else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device) if verbose >= 1:
print(f"Loading file [started]: {file_path}")
if file_path.endswith(".safetensors"):
state_dict = load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
else:
state_dict = load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
# If load state dict in CPU memory, `pin_memory=True` will make `model.to("cuda")` faster.
if pin_memory:
for i in state_dict:
state_dict[i] = state_dict[i].pin_memory()
if verbose >= 1:
print(f"Loading file [done]: {file_path}")
return state_dict
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"): def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):

View File

@@ -5,7 +5,7 @@ from .file import load_state_dict
import torch import torch
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None): def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
config = {} if config is None else config config = {} if config is None else config
# Why do we use `skip_model_initialization`? # Why do we use `skip_model_initialization`?
# It skips the random initialization of model parameters, # It skips the random initialization of model parameters,
@@ -20,7 +20,7 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]] dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
dtype = [d for d in dtypes if d != "disk"][0] dtype = [d for d in dtypes if d != "disk"][0]
if vram_config["offload_device"] != "disk": if vram_config["offload_device"] != "disk":
state_dict = DiskMap(path, device, torch_dtype=dtype) if state_dict is None: state_dict = DiskMap(path, device, torch_dtype=dtype)
if state_dict_converter is not None: if state_dict_converter is not None:
state_dict = state_dict_converter(state_dict) state_dict = state_dict_converter(state_dict)
else: else:
@@ -35,7 +35,9 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
# Sometimes a model file contains multiple models, # Sometimes a model file contains multiple models,
# and DiskMap can load only the parameters of a single model, # and DiskMap can load only the parameters of a single model,
# avoiding the need to load all parameters in the file. # avoiding the need to load all parameters in the file.
if use_disk_map: if state_dict is not None:
pass
elif use_disk_map:
state_dict = DiskMap(path, device, torch_dtype=torch_dtype) state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
else: else:
state_dict = load_state_dict(path, torch_dtype, device) state_dict = load_state_dict(path, torch_dtype, device)

View File

@@ -296,6 +296,7 @@ class BasePipeline(torch.nn.Module):
vram_config=vram_config, vram_config=vram_config,
vram_limit=vram_limit, vram_limit=vram_limit,
clear_parameters=model_config.clear_parameters, clear_parameters=model_config.clear_parameters,
state_dict=model_config.state_dict,
) )
return model_pool return model_pool

View File

@@ -29,7 +29,7 @@ class ModelPool:
module_map = None module_map = None
return module_map return module_map
def load_model_file(self, config, path, vram_config, vram_limit=None): def load_model_file(self, config, path, vram_config, vram_limit=None, state_dict=None):
model_class = self.import_model_class(config["model_class"]) model_class = self.import_model_class(config["model_class"])
model_config = config.get("extra_kwargs", {}) model_config = config.get("extra_kwargs", {})
if "state_dict_converter" in config: if "state_dict_converter" in config:
@@ -43,6 +43,7 @@ class ModelPool:
state_dict_converter, state_dict_converter,
use_disk_map=True, use_disk_map=True,
vram_config=vram_config, module_map=module_map, vram_limit=vram_limit, vram_config=vram_config, module_map=module_map, vram_limit=vram_limit,
state_dict=state_dict,
) )
return model return model
@@ -59,7 +60,7 @@ class ModelPool:
} }
return vram_config return vram_config
def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False): def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False, state_dict=None):
print(f"Loading models from: {json.dumps(path, indent=4)}") print(f"Loading models from: {json.dumps(path, indent=4)}")
if vram_config is None: if vram_config is None:
vram_config = self.default_vram_config() vram_config = self.default_vram_config()
@@ -67,7 +68,7 @@ class ModelPool:
loaded = False loaded = False
for config in MODEL_CONFIGS: for config in MODEL_CONFIGS:
if config["model_hash"] == model_hash: if config["model_hash"] == model_hash:
model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit) model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit, state_dict=state_dict)
if clear_parameters: self.clear_parameters(model) if clear_parameters: self.clear_parameters(model)
self.model.append(model) self.model.append(model)
model_name = config["model_name"] model_name = config["model_name"]