Merge pull request #1240 from modelscope/loader-update

Loader update
This commit is contained in:
Zhongjie Duan
2026-01-30 13:51:17 +08:00
committed by GitHub
6 changed files with 28 additions and 14 deletions

View File

@@ -1,5 +1,5 @@
import torch, glob, os
from typing import Optional, Union
from typing import Optional, Union, Dict
from dataclasses import dataclass
from modelscope import snapshot_download
from huggingface_hub import snapshot_download as hf_snapshot_download
@@ -23,6 +23,7 @@ class ModelConfig:
computation_device: Optional[Union[str, torch.device]] = None
computation_dtype: Optional[torch.dtype] = None
clear_parameters: bool = False
state_dict: Dict[str, torch.Tensor] = None
def check_input(self):
if self.path is None and self.model_id is None:

View File

@@ -2,16 +2,25 @@ from safetensors import safe_open
import torch, hashlib
def load_state_dict(file_path, torch_dtype=None, device="cpu"):
def load_state_dict(file_path, torch_dtype=None, device="cpu", pin_memory=False, verbose=0):
if isinstance(file_path, list):
state_dict = {}
for file_path_ in file_path:
state_dict.update(load_state_dict(file_path_, torch_dtype, device))
return state_dict
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
state_dict.update(load_state_dict(file_path_, torch_dtype, device, pin_memory=pin_memory, verbose=verbose))
else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
if verbose >= 1:
print(f"Loading file [started]: {file_path}")
if file_path.endswith(".safetensors"):
state_dict = load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
else:
state_dict = load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
# If load state dict in CPU memory, `pin_memory=True` will make `model.to("cuda")` faster.
if pin_memory:
for i in state_dict:
state_dict[i] = state_dict[i].pin_memory()
if verbose >= 1:
print(f"Loading file [done]: {file_path}")
return state_dict
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):

View File

@@ -5,7 +5,7 @@ from .file import load_state_dict
import torch
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None):
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
config = {} if config is None else config
# Why do we use `skip_model_initialization`?
# It skips the random initialization of model parameters,
@@ -20,7 +20,7 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
dtype = [d for d in dtypes if d != "disk"][0]
if vram_config["offload_device"] != "disk":
state_dict = DiskMap(path, device, torch_dtype=dtype)
if state_dict is None: state_dict = DiskMap(path, device, torch_dtype=dtype)
if state_dict_converter is not None:
state_dict = state_dict_converter(state_dict)
else:
@@ -35,7 +35,9 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
# Sometimes a model file contains multiple models,
# and DiskMap can load only the parameters of a single model,
# avoiding the need to load all parameters in the file.
if use_disk_map:
if state_dict is not None:
pass
elif use_disk_map:
state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
else:
state_dict = load_state_dict(path, torch_dtype, device)

View File

@@ -296,6 +296,7 @@ class BasePipeline(torch.nn.Module):
vram_config=vram_config,
vram_limit=vram_limit,
clear_parameters=model_config.clear_parameters,
state_dict=model_config.state_dict,
)
return model_pool

View File

@@ -29,7 +29,7 @@ class ModelPool:
module_map = None
return module_map
def load_model_file(self, config, path, vram_config, vram_limit=None):
def load_model_file(self, config, path, vram_config, vram_limit=None, state_dict=None):
model_class = self.import_model_class(config["model_class"])
model_config = config.get("extra_kwargs", {})
if "state_dict_converter" in config:
@@ -43,6 +43,7 @@ class ModelPool:
state_dict_converter,
use_disk_map=True,
vram_config=vram_config, module_map=module_map, vram_limit=vram_limit,
state_dict=state_dict,
)
return model
@@ -59,7 +60,7 @@ class ModelPool:
}
return vram_config
def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False):
def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False, state_dict=None):
print(f"Loading models from: {json.dumps(path, indent=4)}")
if vram_config is None:
vram_config = self.default_vram_config()
@@ -67,7 +68,7 @@ class ModelPool:
loaded = False
for config in MODEL_CONFIGS:
if config["model_hash"] == model_hash:
model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit)
model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit, state_dict=state_dict)
if clear_parameters: self.clear_parameters(model)
self.model.append(model)
model_name = config["model_name"]

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "diffsynth"
version = "2.0.3"
version = "2.0.4"
description = "Enjoy the magic of Diffusion models!"
authors = [{name = "ModelScope Team"}]
license = {text = "Apache-2.0"}