mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 23:08:13 +00:00
accelerate load model
This commit is contained in:
@@ -50,7 +50,7 @@ from ..extensions.RIFE import IFNet
|
||||
from ..extensions.ESRGAN import RRDBNet
|
||||
|
||||
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
|
||||
from .utils import load_state_dict
|
||||
from .utils import load_state_dict, init_weights_on_device
|
||||
|
||||
|
||||
|
||||
@@ -106,8 +106,10 @@ def load_model_from_single_file(state_dict, model_names, model_classes, model_re
|
||||
else:
|
||||
model_state_dict, extra_kwargs = state_dict_results, {}
|
||||
torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
|
||||
model = model_class(**extra_kwargs).to(dtype=torch_dtype, device=device)
|
||||
model.load_state_dict(model_state_dict)
|
||||
with init_weights_on_device():
|
||||
model= model_class(**extra_kwargs)
|
||||
model.load_state_dict(model_state_dict, assign=True)
|
||||
model = model.to(dtype=torch_dtype, device=device)
|
||||
loaded_model_names.append(model_name)
|
||||
loaded_models.append(model)
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
Reference in New Issue
Block a user