support loading models from state dict

This commit is contained in:
Artiprocher
2026-01-30 13:47:36 +08:00
parent 22695e9be0
commit ee9a3b4405
5 changed files with 27 additions and 13 deletions

View File

@@ -2,16 +2,25 @@ from safetensors import safe_open
import torch, hashlib
def load_state_dict(file_path, torch_dtype=None, device="cpu"):
def load_state_dict(file_path, torch_dtype=None, device="cpu", pin_memory=False, verbose=0):
if isinstance(file_path, list):
state_dict = {}
for file_path_ in file_path:
state_dict.update(load_state_dict(file_path_, torch_dtype, device))
return state_dict
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
state_dict.update(load_state_dict(file_path_, torch_dtype, device, pin_memory=pin_memory, verbose=verbose))
else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
if verbose >= 1:
print(f"Loading file [started]: {file_path}")
if file_path.endswith(".safetensors"):
state_dict = load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
else:
state_dict = load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
# If load state dict in CPU memory, `pin_memory=True` will make `model.to("cuda")` faster.
if pin_memory:
for i in state_dict:
state_dict[i] = state_dict[i].pin_memory()
if verbose >= 1:
print(f"Loading file [done]: {file_path}")
return state_dict
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):