update doc and code

This commit is contained in:
Artiprocher
2025-11-05 20:37:11 +08:00
parent 3afecc65fc
commit 6a6eca7baf
3 changed files with 247 additions and 4 deletions

View File

@@ -26,6 +26,11 @@ def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
state_dict = torch.load(file_path, map_location=device, weights_only=True)
if len(state_dict) == 1:
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
elif "module" in state_dict:
state_dict = state_dict["module"]
if torch_dtype is not None:
for i in state_dict:
if isinstance(state_dict[i], torch.Tensor):
@@ -75,11 +80,19 @@ def load_keys_dict_from_safetensors(file_path):
return keys_dict
def load_keys_dict_from_bin(file_path):
state_dict = load_state_dict_from_bin(file_path)
def convert_state_dict_to_keys_dict(state_dict):
keys_dict = {}
for k, v in state_dict.items():
keys_dict[k] = list(v.shape)
if isinstance(v, torch.Tensor):
keys_dict[k] = list(v.shape)
else:
keys_dict[k] = convert_state_dict_to_keys_dict(v)
return keys_dict
def load_keys_dict_from_bin(file_path):
state_dict = load_state_dict_from_bin(file_path)
keys_dict = convert_state_dict_to_keys_dict(state_dict)
return keys_dict
@@ -88,7 +101,7 @@ def convert_keys_dict_to_single_str(state_dict, with_shape=True):
for key, value in state_dict.items():
if isinstance(key, str):
if isinstance(value, dict):
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
keys.append(key + "|" + convert_keys_dict_to_single_str(value, with_shape=with_shape))
else:
if with_shape:
shape = "_".join(map(str, list(value)))