mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support loading models from state dict
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
import torch, glob, os
|
import torch, glob, os
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union, Dict
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from modelscope import snapshot_download
|
from modelscope import snapshot_download
|
||||||
from huggingface_hub import snapshot_download as hf_snapshot_download
|
from huggingface_hub import snapshot_download as hf_snapshot_download
|
||||||
@@ -23,6 +23,7 @@ class ModelConfig:
|
|||||||
computation_device: Optional[Union[str, torch.device]] = None
|
computation_device: Optional[Union[str, torch.device]] = None
|
||||||
computation_dtype: Optional[torch.dtype] = None
|
computation_dtype: Optional[torch.dtype] = None
|
||||||
clear_parameters: bool = False
|
clear_parameters: bool = False
|
||||||
|
state_dict: Dict[str, torch.Tensor] = None
|
||||||
|
|
||||||
def check_input(self):
|
def check_input(self):
|
||||||
if self.path is None and self.model_id is None:
|
if self.path is None and self.model_id is None:
|
||||||
|
|||||||
@@ -2,16 +2,25 @@ from safetensors import safe_open
|
|||||||
import torch, hashlib
|
import torch, hashlib
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict(file_path, torch_dtype=None, device="cpu"):
|
def load_state_dict(file_path, torch_dtype=None, device="cpu", pin_memory=False, verbose=0):
|
||||||
if isinstance(file_path, list):
|
if isinstance(file_path, list):
|
||||||
state_dict = {}
|
state_dict = {}
|
||||||
for file_path_ in file_path:
|
for file_path_ in file_path:
|
||||||
state_dict.update(load_state_dict(file_path_, torch_dtype, device))
|
state_dict.update(load_state_dict(file_path_, torch_dtype, device, pin_memory=pin_memory, verbose=verbose))
|
||||||
return state_dict
|
|
||||||
if file_path.endswith(".safetensors"):
|
|
||||||
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
|
|
||||||
else:
|
else:
|
||||||
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
|
if verbose >= 1:
|
||||||
|
print(f"Loading file [started]: {file_path}")
|
||||||
|
if file_path.endswith(".safetensors"):
|
||||||
|
state_dict = load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
|
||||||
|
else:
|
||||||
|
state_dict = load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
|
||||||
|
# If the state dict is loaded into CPU memory, `pin_memory=True` will make `model.to("cuda")` faster.
|
||||||
|
if pin_memory:
|
||||||
|
for i in state_dict:
|
||||||
|
state_dict[i] = state_dict[i].pin_memory()
|
||||||
|
if verbose >= 1:
|
||||||
|
print(f"Loading file [done]: {file_path}")
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
|
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from .file import load_state_dict
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None):
|
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
|
||||||
config = {} if config is None else config
|
config = {} if config is None else config
|
||||||
# Why do we use `skip_model_initialization`?
|
# Why do we use `skip_model_initialization`?
|
||||||
# It skips the random initialization of model parameters,
|
# It skips the random initialization of model parameters,
|
||||||
@@ -20,7 +20,7 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
|
|||||||
dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
|
dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
|
||||||
dtype = [d for d in dtypes if d != "disk"][0]
|
dtype = [d for d in dtypes if d != "disk"][0]
|
||||||
if vram_config["offload_device"] != "disk":
|
if vram_config["offload_device"] != "disk":
|
||||||
state_dict = DiskMap(path, device, torch_dtype=dtype)
|
if state_dict is None: state_dict = DiskMap(path, device, torch_dtype=dtype)
|
||||||
if state_dict_converter is not None:
|
if state_dict_converter is not None:
|
||||||
state_dict = state_dict_converter(state_dict)
|
state_dict = state_dict_converter(state_dict)
|
||||||
else:
|
else:
|
||||||
@@ -35,7 +35,9 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
|
|||||||
# Sometimes a model file contains multiple models,
|
# Sometimes a model file contains multiple models,
|
||||||
# and DiskMap can load only the parameters of a single model,
|
# and DiskMap can load only the parameters of a single model,
|
||||||
# avoiding the need to load all parameters in the file.
|
# avoiding the need to load all parameters in the file.
|
||||||
if use_disk_map:
|
if state_dict is not None:
|
||||||
|
pass
|
||||||
|
elif use_disk_map:
|
||||||
state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
|
state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
|
||||||
else:
|
else:
|
||||||
state_dict = load_state_dict(path, torch_dtype, device)
|
state_dict = load_state_dict(path, torch_dtype, device)
|
||||||
|
|||||||
@@ -296,6 +296,7 @@ class BasePipeline(torch.nn.Module):
|
|||||||
vram_config=vram_config,
|
vram_config=vram_config,
|
||||||
vram_limit=vram_limit,
|
vram_limit=vram_limit,
|
||||||
clear_parameters=model_config.clear_parameters,
|
clear_parameters=model_config.clear_parameters,
|
||||||
|
state_dict=model_config.state_dict,
|
||||||
)
|
)
|
||||||
return model_pool
|
return model_pool
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ class ModelPool:
|
|||||||
module_map = None
|
module_map = None
|
||||||
return module_map
|
return module_map
|
||||||
|
|
||||||
def load_model_file(self, config, path, vram_config, vram_limit=None):
|
def load_model_file(self, config, path, vram_config, vram_limit=None, state_dict=None):
|
||||||
model_class = self.import_model_class(config["model_class"])
|
model_class = self.import_model_class(config["model_class"])
|
||||||
model_config = config.get("extra_kwargs", {})
|
model_config = config.get("extra_kwargs", {})
|
||||||
if "state_dict_converter" in config:
|
if "state_dict_converter" in config:
|
||||||
@@ -43,6 +43,7 @@ class ModelPool:
|
|||||||
state_dict_converter,
|
state_dict_converter,
|
||||||
use_disk_map=True,
|
use_disk_map=True,
|
||||||
vram_config=vram_config, module_map=module_map, vram_limit=vram_limit,
|
vram_config=vram_config, module_map=module_map, vram_limit=vram_limit,
|
||||||
|
state_dict=state_dict,
|
||||||
)
|
)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@@ -59,7 +60,7 @@ class ModelPool:
|
|||||||
}
|
}
|
||||||
return vram_config
|
return vram_config
|
||||||
|
|
||||||
def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False):
|
def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False, state_dict=None):
|
||||||
print(f"Loading models from: {json.dumps(path, indent=4)}")
|
print(f"Loading models from: {json.dumps(path, indent=4)}")
|
||||||
if vram_config is None:
|
if vram_config is None:
|
||||||
vram_config = self.default_vram_config()
|
vram_config = self.default_vram_config()
|
||||||
@@ -67,7 +68,7 @@ class ModelPool:
|
|||||||
loaded = False
|
loaded = False
|
||||||
for config in MODEL_CONFIGS:
|
for config in MODEL_CONFIGS:
|
||||||
if config["model_hash"] == model_hash:
|
if config["model_hash"] == model_hash:
|
||||||
model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit)
|
model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit, state_dict=state_dict)
|
||||||
if clear_parameters: self.clear_parameters(model)
|
if clear_parameters: self.clear_parameters(model)
|
||||||
self.model.append(model)
|
self.model.append(model)
|
||||||
model_name = config["model_name"]
|
model_name = config["model_name"]
|
||||||
|
|||||||
Reference in New Issue
Block a user