mirror of https://github.com/modelscope/DiffSynth-Studio.git
Merge pull request #1272 from modelscope/zero3-fix
Support DeepSpeed ZeRO 3
@@ -3,14 +3,14 @@ from ..vram.disk_map import DiskMap
 from ..vram.layers import enable_vram_management
 from .file import load_state_dict
 import torch
+from contextlib import contextmanager
+from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.utils import ContextManagers


 def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
     config = {} if config is None else config
-    # Why do we use `skip_model_initialization`?
-    # It skips the random initialization of model parameters,
-    # thereby speeding up model loading and avoiding excessive memory usage.
-    with skip_model_initialization():
+    with ContextManagers(get_init_context(torch_dtype=torch_dtype, device=device)):
         model = model_class(**config)
     # What is `module_map`?
     # This is a module mapping table for VRAM management.
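For readers unfamiliar with `ContextManagers`: the helper from `transformers.utils` simply enters a list of context managers together. Roughly, the new `with` line behaves like this `ExitStack`-based sketch, where `model_class` and `config` stand in for the surrounding function's arguments:

from contextlib import ExitStack
import torch

with ExitStack() as stack:
    # Enter every context returned by get_init_context, then construct the
    # model inside all of them, much as ContextManagers does internally.
    for ctx in get_init_context(torch_dtype=torch.bfloat16, device="cpu"):
        stack.enter_context(ctx)
    model = model_class(**config)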
@@ -48,7 +48,14 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
         state_dict = state_dict_converter(state_dict)
     else:
         state_dict = {i: state_dict[i] for i in state_dict}
-    model.load_state_dict(state_dict, assign=True)
+    # Why does DeepSpeed ZeRO Stage 3 need to be handled separately?
+    # Because at this stage, model parameters are partitioned across multiple GPUs.
+    # Loading them directly could lead to excessive GPU memory consumption.
+    if is_deepspeed_zero3_enabled():
+        from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model
+        _load_state_dict_into_zero3_model(model, state_dict)
+    else:
+        model.load_state_dict(state_dict, assign=True)
     # Why do we call `to()`?
     # Because some models override the behavior of `to()`,
     # especially those from libraries like Transformers.
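`_load_state_dict_into_zero3_model` is a private helper inside `transformers`; conceptually, loading into a ZeRO-3 model means gathering each partitioned parameter before writing to it. A minimal sketch of that idea (the function name here is invented, and an initialized `torch.distributed` process group is assumed):

import torch
import deepspeed

def load_into_zero3_model_sketch(model, state_dict):
    # Under ZeRO Stage 3 each rank holds only a shard of every parameter,
    # so the full tensor must be gathered before it can be overwritten.
    for name, param in model.named_parameters():
        if name not in state_dict:
            continue
        with deepspeed.zero.GatheredParameters([param], modifier_rank=0):
            # Only rank 0 writes; DeepSpeed re-partitions the parameter
            # across ranks when the context exits.
            if torch.distributed.get_rank() == 0:
                param.data.copy_(state_dict[name].to(dtype=param.dtype))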
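The `to()` caveat above is easy to reproduce: a module may override `to()` to keep part of itself in a fixed dtype, which is why going through the model's own `to()` is the safe path. A toy illustration (this class and its dtype policy are invented for the example):

import torch

class MixedPrecisionBlock(torch.nn.Module):
    # Hypothetical module that keeps its LayerNorm in float32 for stability,
    # the kind of custom `to()` behavior some Transformers models implement.
    def __init__(self):
        super().__init__()
        self.norm = torch.nn.LayerNorm(64)
        self.proj = torch.nn.Linear(64, 64)

    def to(self, *args, **kwargs):
        self.proj = self.proj.to(*args, **kwargs)
        return self  # self.norm deliberately stays in float32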
@@ -79,3 +86,20 @@ def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=tor
     }
     enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80)
     return model
+
+
+def get_init_context(torch_dtype, device):
+    if is_deepspeed_zero3_enabled():
+        from transformers.modeling_utils import set_zero3_state
+        import deepspeed
+        # Why do we use `deepspeed.zero.Init`?
+        # It partitions the model's weights on the CPU side as they are created,
+        # and then loads each weight shard onto its computing card.
+        init_contexts = [deepspeed.zero.Init(remote_device=device, dtype=torch_dtype), set_zero3_state()]
+    else:
+        # Why do we use `skip_model_initialization`?
+        # It skips the random initialization of model parameters,
+        # thereby speeding up model loading and avoiding excessive memory usage.
+        init_contexts = [skip_model_initialization()]
+
+    return init_contexts
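`deepspeed.zero.Init` shards parameters at allocation time, so no rank ever materializes the full model. A small usage sketch, assuming DeepSpeed is running with a ZeRO-3 configuration:

import torch
import deepspeed

# Parameters created inside the context are partitioned across ranks immediately.
with deepspeed.zero.Init(remote_device="cpu", dtype=torch.bfloat16):
    layer = torch.nn.Linear(4096, 4096)

# Each rank now stores only its shard; DeepSpeed records the full element
# count on the parameter itself.
print(layer.weight.ds_numel)  # 4096 * 4096, even though the local shard is smaller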
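`skip_model_initialization` is defined elsewhere in the repository; a common way to implement such a context manager is to temporarily replace PyTorch's in-place initializers with no-ops. A minimal sketch of that pattern (not the repository's actual implementation):

from contextlib import contextmanager
import torch

@contextmanager
def skip_model_initialization_sketch():
    # Temporarily turn the common in-place initializers into no-ops so that
    # constructing a module allocates storage without filling it.
    names = ["uniform_", "normal_", "constant_", "kaiming_uniform_",
             "kaiming_normal_", "xavier_uniform_", "xavier_normal_"]
    saved = {n: getattr(torch.nn.init, n) for n in names}
    try:
        for n in names:
            setattr(torch.nn.init, n, lambda tensor, *args, **kwargs: tensor)
        yield
    finally:
        # Restore the real initializers on exit.
        for n, fn in saved.items():
            setattr(torch.nn.init, n, fn)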