mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-19 06:32:27 +00:00)

support loading models from state dict
@@ -5,7 +5,7 @@ from .file import load_state_dict
 import torch
 
 
-def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None):
+def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
     config = {} if config is None else config
     # Why do we use `skip_model_initialization`?
     # It skips the random initialization of model parameters,
@@ -20,7 +20,7 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
         dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
         dtype = [d for d in dtypes if d != "disk"][0]
         if vram_config["offload_device"] != "disk":
-            state_dict = DiskMap(path, device, torch_dtype=dtype)
+            if state_dict is None: state_dict = DiskMap(path, device, torch_dtype=dtype)
             if state_dict_converter is not None:
                 state_dict = state_dict_converter(state_dict)
         else:
@@ -35,7 +35,9 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
     # Sometimes a model file contains multiple models,
     # and DiskMap can load only the parameters of a single model,
     # avoiding the need to load all parameters in the file.
-    if use_disk_map:
+    if state_dict is not None:
+        pass
+    elif use_disk_map:
         state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
     else:
         state_dict = load_state_dict(path, torch_dtype, device)
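
The change above lets a caller pass weights that are already in memory: an explicitly supplied state_dict now takes priority in both branches, and the file at path is only read when none was given. Below is a minimal, self-contained sketch of that calling pattern; TinyModel and load_model_sketch are hypothetical stand-ins for a real model class and DiffSynth's load_model, and only the "explicit state_dict wins over the file path" logic mirrors the diff.

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

def load_model_sketch(model_class, path, torch_dtype=torch.bfloat16, device="cpu", state_dict=None):
    # A caller-supplied state_dict takes priority; the file at `path`
    # is only read as a fallback.
    if state_dict is None:
        state_dict = torch.load(path, map_location="cpu")
    model = model_class()
    model.load_state_dict(state_dict)
    return model.to(dtype=torch_dtype, device=device)

# Load or build the weights once, then reuse them across calls
# without touching the disk again.
weights = TinyModel().state_dict()
model = load_model_sketch(TinyModel, path="unused.pth", state_dict=weights)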
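The context lines also carry the comment about skip_model_initialization: constructing a model only to overwrite every parameter from a checkpoint wastes time on random initialization. DiffSynth's own mechanism is not shown in this hunk; as an illustration of the general trick, plain PyTorch can achieve the same effect with the meta device.

import torch
import torch.nn as nn

# Building the module on the meta device allocates no real storage and
# runs no random initializers.
with torch.device("meta"):
    model = nn.Linear(4096, 4096)

# Materialize empty (uninitialized) storage, then fill it from a
# checkpoint-style state dict.
model = model.to_empty(device="cpu")
model.load_state_dict({
    "weight": torch.randn(4096, 4096),
    "bias": torch.randn(4096),
})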
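Finally, the comment about DiskMap ("a model file contains multiple models, and DiskMap can load only the parameters of a single model") describes lazy, per-tensor reads. DiskMap is internal to DiffSynth, but the safetensors library exposes the same idea and can stand in as an illustration; combined.safetensors and the unet. prefix below are made-up examples.

from safetensors import safe_open

state_dict = {}
with safe_open("combined.safetensors", framework="pt", device="cpu") as f:
    for key in f.keys():
        if key.startswith("unet."):              # pick out one sub-model
            state_dict[key] = f.get_tensor(key)  # reads only this tensor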