mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
support loading models from state dict
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import torch, glob, os
|
||||
from typing import Optional, Union
|
||||
from typing import Optional, Union, Dict
|
||||
from dataclasses import dataclass
|
||||
from modelscope import snapshot_download
|
||||
from huggingface_hub import snapshot_download as hf_snapshot_download
|
||||
@@ -23,6 +23,7 @@ class ModelConfig:
|
||||
computation_device: Optional[Union[str, torch.device]] = None
|
||||
computation_dtype: Optional[torch.dtype] = None
|
||||
clear_parameters: bool = False
|
||||
state_dict: Dict[str, torch.Tensor] = None
|
||||
|
||||
def check_input(self):
|
||||
if self.path is None and self.model_id is None:
|
||||
|
||||
@@ -2,16 +2,25 @@ from safetensors import safe_open
|
||||
import torch, hashlib
|
||||
|
||||
|
||||
def load_state_dict(file_path, torch_dtype=None, device="cpu"):
|
||||
def load_state_dict(file_path, torch_dtype=None, device="cpu", pin_memory=False, verbose=0):
|
||||
if isinstance(file_path, list):
|
||||
state_dict = {}
|
||||
for file_path_ in file_path:
|
||||
state_dict.update(load_state_dict(file_path_, torch_dtype, device))
|
||||
return state_dict
|
||||
if file_path.endswith(".safetensors"):
|
||||
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
|
||||
state_dict.update(load_state_dict(file_path_, torch_dtype, device, pin_memory=pin_memory, verbose=verbose))
|
||||
else:
|
||||
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
|
||||
if verbose >= 1:
|
||||
print(f"Loading file [started]: {file_path}")
|
||||
if file_path.endswith(".safetensors"):
|
||||
state_dict = load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
|
||||
else:
|
||||
state_dict = load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
|
||||
# If load state dict in CPU memory, `pin_memory=True` will make `model.to("cuda")` faster.
|
||||
if pin_memory:
|
||||
for i in state_dict:
|
||||
state_dict[i] = state_dict[i].pin_memory()
|
||||
if verbose >= 1:
|
||||
print(f"Loading file [done]: {file_path}")
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
|
||||
|
||||
@@ -5,7 +5,7 @@ from .file import load_state_dict
|
||||
import torch
|
||||
|
||||
|
||||
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None):
|
||||
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
|
||||
config = {} if config is None else config
|
||||
# Why do we use `skip_model_initialization`?
|
||||
# It skips the random initialization of model parameters,
|
||||
@@ -20,7 +20,7 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
|
||||
dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
|
||||
dtype = [d for d in dtypes if d != "disk"][0]
|
||||
if vram_config["offload_device"] != "disk":
|
||||
state_dict = DiskMap(path, device, torch_dtype=dtype)
|
||||
if state_dict is None: state_dict = DiskMap(path, device, torch_dtype=dtype)
|
||||
if state_dict_converter is not None:
|
||||
state_dict = state_dict_converter(state_dict)
|
||||
else:
|
||||
@@ -35,7 +35,9 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
|
||||
# Sometimes a model file contains multiple models,
|
||||
# and DiskMap can load only the parameters of a single model,
|
||||
# avoiding the need to load all parameters in the file.
|
||||
if use_disk_map:
|
||||
if state_dict is not None:
|
||||
pass
|
||||
elif use_disk_map:
|
||||
state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
|
||||
else:
|
||||
state_dict = load_state_dict(path, torch_dtype, device)
|
||||
|
||||
@@ -296,6 +296,7 @@ class BasePipeline(torch.nn.Module):
|
||||
vram_config=vram_config,
|
||||
vram_limit=vram_limit,
|
||||
clear_parameters=model_config.clear_parameters,
|
||||
state_dict=model_config.state_dict,
|
||||
)
|
||||
return model_pool
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ class ModelPool:
|
||||
module_map = None
|
||||
return module_map
|
||||
|
||||
def load_model_file(self, config, path, vram_config, vram_limit=None):
|
||||
def load_model_file(self, config, path, vram_config, vram_limit=None, state_dict=None):
|
||||
model_class = self.import_model_class(config["model_class"])
|
||||
model_config = config.get("extra_kwargs", {})
|
||||
if "state_dict_converter" in config:
|
||||
@@ -43,6 +43,7 @@ class ModelPool:
|
||||
state_dict_converter,
|
||||
use_disk_map=True,
|
||||
vram_config=vram_config, module_map=module_map, vram_limit=vram_limit,
|
||||
state_dict=state_dict,
|
||||
)
|
||||
return model
|
||||
|
||||
@@ -59,7 +60,7 @@ class ModelPool:
|
||||
}
|
||||
return vram_config
|
||||
|
||||
def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False):
|
||||
def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False, state_dict=None):
|
||||
print(f"Loading models from: {json.dumps(path, indent=4)}")
|
||||
if vram_config is None:
|
||||
vram_config = self.default_vram_config()
|
||||
@@ -67,7 +68,7 @@ class ModelPool:
|
||||
loaded = False
|
||||
for config in MODEL_CONFIGS:
|
||||
if config["model_hash"] == model_hash:
|
||||
model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit)
|
||||
model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit, state_dict=state_dict)
|
||||
if clear_parameters: self.clear_parameters(model)
|
||||
self.model.append(model)
|
||||
model_name = config["model_name"]
|
||||
|
||||
Reference in New Issue
Block a user