mirror of https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00

accelerate load model

@@ -2,6 +2,7 @@ import torch
 from .sd3_dit import TimestepEmbeddings, AdaLayerNorm
 from einops import rearrange
 from .tiler import TileWorker
+from .utils import init_weights_on_device
 
 
 
@@ -466,10 +467,11 @@ class FluxDiT(torch.nn.Module):
         def replace_layer(model):
             for name, module in model.named_children():
                 if isinstance(module, torch.nn.Linear):
-                    new_layer = quantized_layer.Linear(module.in_features, module.out_features)
-                    new_layer.weight.data = module.weight.data
+                    with init_weights_on_device():
+                        new_layer = quantized_layer.Linear(module.in_features, module.out_features)
+                    new_layer.weight = module.weight
                     if module.bias is not None:
-                        new_layer.bias.data = module.bias.data
+                        new_layer.bias = module.bias
                     # del module
                     setattr(model, name, new_layer)
                 elif isinstance(module, RMSNorm):
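
The net effect of this hunk: the replacement quantized layer is now constructed under init_weights_on_device(), so its freshly initialized parameters land on the meta device and cost no memory, and the original weight and bias Parameters are rebound onto it instead of having their data copied (new_layer.weight = module.weight rather than new_layer.weight.data = module.weight.data). A minimal sketch of the same swap pattern, with a hypothetical Int8Linear standing in for the repo's quantized_layer.Linear and PyTorch 2.x's torch.device("meta") context standing in for the helper:

import torch

class Int8Linear(torch.nn.Linear):
    # hypothetical stand-in for quantized_layer.Linear; a real implementation
    # would quantize the shared weight at forward time
    pass

def replace_linear(model):
    for name, module in model.named_children():
        if isinstance(module, torch.nn.Linear) and not isinstance(module, Int8Linear):
            # construct on meta: the fresh weight/bias allocate nothing
            with torch.device("meta"):
                new_layer = Int8Linear(module.in_features, module.out_features,
                                       bias=module.bias is not None)
            # rebind the existing Parameters; no tensor data is copied
            new_layer.weight = module.weight
            if module.bias is not None:
                new_layer.bias = module.bias
            setattr(model, name, new_layer)
        else:
            replace_linear(module)  # recurse into submodules

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
replace_linear(model)
print(type(model[0]).__name__, model[0].weight.is_meta)  # Int8Linear False

Because the Parameters are shared rather than copied, swapping the module in via setattr neither copies nor frees any weight memory; the existing tensors simply get a new owner.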

@@ -50,7 +50,7 @@ from ..extensions.RIFE import IFNet
 from ..extensions.ESRGAN import RRDBNet
 
 from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
-from .utils import load_state_dict
+from .utils import load_state_dict, init_weights_on_device
 
 
 
@@ -106,8 +106,10 @@ def load_model_from_single_file(state_dict, model_names, model_classes, model_re
         else:
             model_state_dict, extra_kwargs = state_dict_results, {}
         torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
-        model = model_class(**extra_kwargs).to(dtype=torch_dtype, device=device)
-        model.load_state_dict(model_state_dict)
+        with init_weights_on_device():
+            model = model_class(**extra_kwargs)
+        model.load_state_dict(model_state_dict, assign=True)
+        model = model.to(dtype=torch_dtype, device=device)
         loaded_model_names.append(model_name)
         loaded_models.append(model)
     return loaded_model_names, loaded_models
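
Before this hunk, the constructor allocated every parameter for real and load_state_dict then overwrote them, so two full copies of the weights were alive at the peak; now the skeleton is built on the meta device for free, load_state_dict(..., assign=True) (PyTorch 2.1+) adopts the checkpoint tensors as the parameters, and the single surviving copy is cast and moved once at the end. A self-contained sketch of the before/after, with a toy model standing in for the repo's model classes:

import torch

def build_model():
    return torch.nn.Sequential(torch.nn.Linear(1024, 1024),
                               torch.nn.Linear(1024, 1024))

checkpoint = build_model().state_dict()  # stands in for tensors read from disk

# before: real allocation in the constructor, then a full copy from the checkpoint
model = build_model().to(dtype=torch.float16)
model.load_state_dict(checkpoint)

# after: the meta skeleton is free, and assign=True reuses the checkpoint
# tensors as the parameters, so only one copy of the weights ever exists
with torch.device("meta"):
    model = build_model()
model.load_state_dict(checkpoint, assign=True)
model = model.to(dtype=torch.float16)
print(next(model.parameters()).dtype)  # torch.float16

One caveat of assign=True worth knowing: the model's parameters share storage with the state dict that was passed in, so mutating one mutates the other until the final .to() makes the cast copy.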

@@ -1,7 +1,55 @@
 import torch, os
 from safetensors import safe_open
+from contextlib import contextmanager
 
+@contextmanager
+def init_weights_on_device(device=torch.device("meta"), include_buffers: bool = False):
+    old_register_parameter = torch.nn.Module.register_parameter
+    if include_buffers:
+        old_register_buffer = torch.nn.Module.register_buffer
+
+    def register_empty_parameter(module, name, param):
+        old_register_parameter(module, name, param)
+        if param is not None:
+            param_cls = type(module._parameters[name])
+            kwargs = module._parameters[name].__dict__
+            kwargs["requires_grad"] = param.requires_grad
+            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
+
+    def register_empty_buffer(module, name, buffer, persistent=True):
+        old_register_buffer(module, name, buffer, persistent=persistent)
+        if buffer is not None:
+            module._buffers[name] = module._buffers[name].to(device)
+
+    def patch_tensor_constructor(fn):
+        def wrapper(*args, **kwargs):
+            kwargs["device"] = device
+            return fn(*args, **kwargs)
+        return wrapper
+
+    if include_buffers:
+        tensor_constructors_to_patch = {
+            torch_function_name: getattr(torch, torch_function_name)
+            for torch_function_name in ["empty", "zeros", "ones", "full"]
+        }
+    else:
+        tensor_constructors_to_patch = {}
+
+    try:
+        torch.nn.Module.register_parameter = register_empty_parameter
+        if include_buffers:
+            torch.nn.Module.register_buffer = register_empty_buffer
+        for torch_function_name in tensor_constructors_to_patch.keys():
+            setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
+        yield
+    finally:
+        torch.nn.Module.register_parameter = old_register_parameter
+        if include_buffers:
+            torch.nn.Module.register_buffer = old_register_buffer
+        for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
+            setattr(torch, torch_function_name, old_torch_function)
+
 def load_state_dict_from_folder(file_path, torch_dtype=None):
     state_dict = {}
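
This context manager mirrors accelerate's init_empty_weights: while it is active, every parameter a module registers is immediately moved to the target device (meta by default), and with include_buffers=True the torch.empty/zeros/ones/full constructors are redirected there as well, so whole models can be "constructed" without allocating a byte. A quick usage sketch (the import path is illustrative; use whatever matches your checkout):

import torch
# from diffsynth.models.utils import init_weights_on_device  # illustrative path

with init_weights_on_device():
    big = torch.nn.Linear(65536, 65536)  # ~17 GB of fp32 weights if real

print(big.weight.is_meta)  # True: nothing was allocated
# parameters become real only once a checkpoint is attached, e.g.
# big.load_state_dict(state_dict_from_disk, assign=True)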