diff --git a/diffsynth/core/loader/config.py b/diffsynth/core/loader/config.py
index 88b46a0..010483d 100644
--- a/diffsynth/core/loader/config.py
+++ b/diffsynth/core/loader/config.py
@@ -1,5 +1,5 @@
 import torch, glob, os
-from typing import Optional, Union
+from typing import Optional, Union, Dict
 from dataclasses import dataclass
 from modelscope import snapshot_download
 from huggingface_hub import snapshot_download as hf_snapshot_download
@@ -23,6 +23,7 @@ class ModelConfig:
     computation_device: Optional[Union[str, torch.device]] = None
     computation_dtype: Optional[torch.dtype] = None
     clear_parameters: bool = False
+    state_dict: Dict[str, torch.Tensor] = None
 
     def check_input(self):
         if self.path is None and self.model_id is None:
diff --git a/diffsynth/core/loader/file.py b/diffsynth/core/loader/file.py
index 8f66961..67d8815 100644
--- a/diffsynth/core/loader/file.py
+++ b/diffsynth/core/loader/file.py
@@ -2,16 +2,25 @@ from safetensors import safe_open
 import torch, hashlib
 
 
-def load_state_dict(file_path, torch_dtype=None, device="cpu"):
+def load_state_dict(file_path, torch_dtype=None, device="cpu", pin_memory=False, verbose=0):
     if isinstance(file_path, list):
         state_dict = {}
         for file_path_ in file_path:
-            state_dict.update(load_state_dict(file_path_, torch_dtype, device))
-        return state_dict
-    if file_path.endswith(".safetensors"):
-        return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
+            state_dict.update(load_state_dict(file_path_, torch_dtype, device, pin_memory=pin_memory, verbose=verbose))
     else:
-        return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
+        if verbose >= 1:
+            print(f"Loading file [started]: {file_path}")
+        if file_path.endswith(".safetensors"):
+            state_dict = load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
+        else:
+            state_dict = load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
+    # If the state dict is loaded into CPU memory, `pin_memory=True` will make `model.to("cuda")` faster.
+    if pin_memory:
+        for i in state_dict:
+            state_dict[i] = state_dict[i].pin_memory()
+    if verbose >= 1:
+        print(f"Loading file [done]: {file_path}")
+    return state_dict
 
 
 def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
diff --git a/diffsynth/core/loader/model.py b/diffsynth/core/loader/model.py
index 56fa7d3..1f920ab 100644
--- a/diffsynth/core/loader/model.py
+++ b/diffsynth/core/loader/model.py
@@ -5,7 +5,7 @@ from .file import load_state_dict
 import torch
 
 
-def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None):
+def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
     config = {} if config is None else config
     # Why do we use `skip_model_initialization`?
     # It skips the random initialization of model parameters,
@@ -20,7 +20,7 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
         dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
         dtype = [d for d in dtypes if d != "disk"][0]
         if vram_config["offload_device"] != "disk":
-            state_dict = DiskMap(path, device, torch_dtype=dtype)
+            if state_dict is None: state_dict = DiskMap(path, device, torch_dtype=dtype)
             if state_dict_converter is not None:
                 state_dict = state_dict_converter(state_dict)
         else:
@@ -35,7 +35,9 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
         # Sometimes a model file contains multiple models,
         # and DiskMap can load only the parameters of a single model,
         # avoiding the need to load all parameters in the file.
-        if use_disk_map:
+        if state_dict is not None:
+            pass
+        elif use_disk_map:
             state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
         else:
             state_dict = load_state_dict(path, torch_dtype, device)
diff --git a/diffsynth/diffusion/base_pipeline.py b/diffsynth/diffusion/base_pipeline.py
index d4731fd..094ec55 100644
--- a/diffsynth/diffusion/base_pipeline.py
+++ b/diffsynth/diffusion/base_pipeline.py
@@ -296,6 +296,7 @@ class BasePipeline(torch.nn.Module):
                 vram_config=vram_config,
                 vram_limit=vram_limit,
                 clear_parameters=model_config.clear_parameters,
+                state_dict=model_config.state_dict,
             )
         return model_pool
 
diff --git a/diffsynth/models/model_loader.py b/diffsynth/models/model_loader.py
index 16d72dd..6a58c89 100644
--- a/diffsynth/models/model_loader.py
+++ b/diffsynth/models/model_loader.py
@@ -29,7 +29,7 @@ class ModelPool:
             module_map = None
         return module_map
 
-    def load_model_file(self, config, path, vram_config, vram_limit=None):
+    def load_model_file(self, config, path, vram_config, vram_limit=None, state_dict=None):
        model_class = self.import_model_class(config["model_class"])
        model_config = config.get("extra_kwargs", {})
        if "state_dict_converter" in config:
@@ -43,6 +43,7 @@ class ModelPool:
            state_dict_converter, use_disk_map=True,
            vram_config=vram_config, module_map=module_map, vram_limit=vram_limit,
+           state_dict=state_dict,
        )
        return model
 
 
@@ -59,7 +60,7 @@ class ModelPool:
        }
        return vram_config
 
-    def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False):
+    def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False, state_dict=None):
        print(f"Loading models from: {json.dumps(path, indent=4)}")
        if vram_config is None:
            vram_config = self.default_vram_config()
@@ -67,7 +68,7 @@ class ModelPool:
        loaded = False
        for config in MODEL_CONFIGS:
            if config["model_hash"] == model_hash:
-                model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit)
+                model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit, state_dict=state_dict)
                if clear_parameters: self.clear_parameters(model)
                self.model.append(model)
                model_name = config["model_name"]
diff --git a/pyproject.toml b/pyproject.toml
index de82279..9a5075b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "diffsynth"
-version = "2.0.3"
+version = "2.0.4"
 description = "Enjoy the magic of Diffusion models!"
 authors = [{name = "ModelScope Team"}]
 license = {text = "Apache-2.0"}