mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 14:58:12 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6fe897883b |
@@ -3,24 +3,20 @@ from ..vram.disk_map import DiskMap
|
|||||||
from ..vram.layers import enable_vram_management
|
from ..vram.layers import enable_vram_management
|
||||||
from .file import load_state_dict
|
from .file import load_state_dict
|
||||||
import torch
|
import torch
|
||||||
from contextlib import contextmanager
|
|
||||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||||
from transformers.utils import ContextManagers
|
from transformers.utils import ContextManagers
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None,
|
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
|
||||||
use_disk_map=False, module_map=None, vram_config=None, vram_limit=None):
|
|
||||||
config = {} if config is None else config
|
config = {} if config is None else config
|
||||||
with ContextManagers(get_init_context(torch_dtype=torch_dtype, device=device)):
|
with ContextManagers(get_init_context(torch_dtype=torch_dtype, device=device)):
|
||||||
model = model_class(**config)
|
model = model_class(**config)
|
||||||
# What is `module_map`?
|
# What is `module_map`?
|
||||||
# This is a module mapping table for VRAM management.
|
# This is a module mapping table for VRAM management.
|
||||||
if module_map is not None:
|
if module_map is not None:
|
||||||
devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"],
|
devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"], vram_config["computation_device"]]
|
||||||
vram_config["computation_device"]]
|
|
||||||
device = [d for d in devices if d != "disk"][0]
|
device = [d for d in devices if d != "disk"][0]
|
||||||
dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"],
|
dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
|
||||||
vram_config["computation_dtype"]]
|
|
||||||
dtype = [d for d in dtypes if d != "disk"][0]
|
dtype = [d for d in dtypes if d != "disk"][0]
|
||||||
if vram_config["offload_device"] != "disk":
|
if vram_config["offload_device"] != "disk":
|
||||||
state_dict = DiskMap(path, device, torch_dtype=dtype)
|
state_dict = DiskMap(path, device, torch_dtype=dtype)
|
||||||
@@ -29,12 +25,10 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
|
|||||||
else:
|
else:
|
||||||
state_dict = {i: state_dict[i] for i in state_dict}
|
state_dict = {i: state_dict[i] for i in state_dict}
|
||||||
model.load_state_dict(state_dict, assign=True)
|
model.load_state_dict(state_dict, assign=True)
|
||||||
model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None,
|
model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None, vram_limit=vram_limit)
|
||||||
vram_limit=vram_limit)
|
|
||||||
else:
|
else:
|
||||||
disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
|
disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
|
||||||
model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map,
|
model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=vram_limit)
|
||||||
vram_limit=vram_limit)
|
|
||||||
else:
|
else:
|
||||||
# Why do we use `DiskMap`?
|
# Why do we use `DiskMap`?
|
||||||
# Sometimes a model file contains multiple models,
|
# Sometimes a model file contains multiple models,
|
||||||
@@ -51,6 +45,9 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
|
|||||||
state_dict = state_dict_converter(state_dict)
|
state_dict = state_dict_converter(state_dict)
|
||||||
else:
|
else:
|
||||||
state_dict = {i: state_dict[i] for i in state_dict}
|
state_dict = {i: state_dict[i] for i in state_dict}
|
||||||
|
# Why does DeepSpeed ZeRO Stage 3 need to be handled separately?
|
||||||
|
# Because at this stage, model parameters are partitioned across multiple GPUs.
|
||||||
|
# Loading them directly could lead to excessive GPU memory consumption.
|
||||||
if is_deepspeed_zero3_enabled():
|
if is_deepspeed_zero3_enabled():
|
||||||
from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model
|
from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model
|
||||||
_load_state_dict_into_zero3_model(model, state_dict)
|
_load_state_dict_into_zero3_model(model, state_dict)
|
||||||
@@ -65,8 +62,7 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu",
|
def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, module_map=None):
|
||||||
state_dict_converter=None, module_map=None):
|
|
||||||
if isinstance(path, str):
|
if isinstance(path, str):
|
||||||
path = [path]
|
path = [path]
|
||||||
config = {} if config is None else config
|
config = {} if config is None else config
|
||||||
|
|||||||
Reference in New Issue
Block a user