diff --git a/diffsynth/core/loader/model.py b/diffsynth/core/loader/model.py index 5d9b052..1abfd02 100644 --- a/diffsynth/core/loader/model.py +++ b/diffsynth/core/loader/model.py @@ -3,24 +3,20 @@ from ..vram.disk_map import DiskMap from ..vram.layers import enable_vram_management from .file import load_state_dict import torch -from contextlib import contextmanager from transformers.integrations import is_deepspeed_zero3_enabled from transformers.utils import ContextManagers -def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, - use_disk_map=False, module_map=None, vram_config=None, vram_limit=None): +def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None): config = {} if config is None else config with ContextManagers(get_init_context(torch_dtype=torch_dtype, device=device)): model = model_class(**config) # What is `module_map`? # This is a module mapping table for VRAM management. if module_map is not None: - devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"], - vram_config["computation_device"]] + devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"], vram_config["computation_device"]] device = [d for d in devices if d != "disk"][0] - dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], - vram_config["computation_dtype"]] + dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]] dtype = [d for d in dtypes if d != "disk"][0] if vram_config["offload_device"] != "disk": state_dict = DiskMap(path, device, torch_dtype=dtype) @@ -29,12 +25,10 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic else: state_dict = {i: state_dict[i] for i in state_dict} model.load_state_dict(state_dict, assign=True) - model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None, - vram_limit=vram_limit) + model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None, vram_limit=vram_limit) else: disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter) - model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, - vram_limit=vram_limit) + model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=vram_limit) else: # Why do we use `DiskMap`? # Sometimes a model file contains multiple models, @@ -51,6 +45,9 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic state_dict = state_dict_converter(state_dict) else: state_dict = {i: state_dict[i] for i in state_dict} + # Why does DeepSpeed ZeRO Stage 3 need to be handled separately? + # Because at this stage, model parameters are partitioned across multiple GPUs. + # Loading them directly could lead to excessive GPU memory consumption. if is_deepspeed_zero3_enabled(): from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model _load_state_dict_into_zero3_model(model, state_dict) @@ -65,8 +62,7 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic return model -def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", - state_dict_converter=None, module_map=None): +def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, module_map=None): if isinstance(path, str): path = [path] config = {} if config is None else config