mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 23:08:13 +00:00
22 lines
766 B
Python
22 lines
766 B
Python
import torch
|
|
from contextlib import contextmanager
|
|
|
|
|
|
@contextmanager
def skip_model_initialization(device=torch.device("meta"), include_buffers=False):
    """Redirect newly registered ``torch.nn.Module`` parameters to ``device``.

    While the context is active, ``torch.nn.Module.register_parameter`` is
    replaced so that every parameter registered on any module is re-created on
    ``device`` (the ``meta`` device by default). This allows a model skeleton
    to be built without materializing its weights on the original device. The
    original method is restored on exit, even if the ``with`` body raises.

    Args:
        device: Target device for newly registered parameters (and, when
            ``include_buffers`` is True, buffers).
        include_buffers: If True, also replace ``torch.nn.Module.register_buffer``
            so that registered buffers are moved to ``device`` as well.
            Defaults to False, which leaves buffers untouched (the previous
            behavior).

    Note:
        The patch is applied to the ``torch.nn.Module`` class attribute. It is
        therefore visible process-wide while the context is active, not only
        to modules constructed lexically inside the ``with`` body.
    """
    old_register_parameter = torch.nn.Module.register_parameter
    old_register_buffer = torch.nn.Module.register_buffer

    def register_empty_parameter(module, name, param):
        # Let torch perform its normal validation/registration first.
        old_register_parameter(module, name, param)
        if param is not None:
            registered = module._parameters[name]
            param_cls = type(registered)
            # Copy the instance __dict__ instead of mutating it. Its entries are
            # forwarded as constructor kwargs below, and writing into the live
            # dict would leak a "requires_grad" entry onto the caller's original
            # parameter object (register_parameter stores `param` itself).
            kwargs = dict(registered.__dict__)
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(registered.to(device), **kwargs)

    def register_empty_buffer(module, name, tensor, persistent=True):
        # Same parameter names as torch.nn.Module.register_buffer, so keyword
        # callers (e.g. tensor=..., persistent=...) keep working.
        old_register_buffer(module, name, tensor, persistent=persistent)
        if tensor is not None:
            module._buffers[name] = module._buffers[name].to(device)

    torch.nn.Module.register_parameter = register_empty_parameter
    if include_buffers:
        torch.nn.Module.register_buffer = register_empty_buffer
    try:
        yield
    finally:
        # Restore exactly what was patched on entry.
        torch.nn.Module.register_parameter = old_register_parameter
        if include_buffers:
            torch.nn.Module.register_buffer = old_register_buffer