# Mirror of https://github.com/modelscope/DiffSynth-Studio.git
# Synced 2026-03-19 06:48:12 +00:00
import torch, os
from safetensors import safe_open


def load_state_dict_from_folder(file_path, torch_dtype=None):
|
|
state_dict = {}
|
|
for file_name in os.listdir(file_path):
|
|
if "." in file_name and file_name.split(".")[-1] in [
|
|
"safetensors", "bin", "ckpt", "pth", "pt"
|
|
]:
|
|
state_dict.update(load_state_dict(os.path.join(file_path, file_name), torch_dtype=torch_dtype))
|
|
return state_dict
|
|
|
|
|
|
def load_state_dict(file_path, torch_dtype=None):
|
|
if file_path.endswith(".safetensors"):
|
|
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
|
|
else:
|
|
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
|
|
|
|
|
|
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
|
|
state_dict = {}
|
|
with safe_open(file_path, framework="pt", device="cpu") as f:
|
|
for k in f.keys():
|
|
state_dict[k] = f.get_tensor(k)
|
|
if torch_dtype is not None:
|
|
state_dict[k] = state_dict[k].to(torch_dtype)
|
|
return state_dict
|
|
|
|
|
|
def load_state_dict_from_bin(file_path, torch_dtype=None):
|
|
state_dict = torch.load(file_path, map_location="cpu")
|
|
if torch_dtype is not None:
|
|
for i in state_dict:
|
|
if isinstance(state_dict[i], torch.Tensor):
|
|
state_dict[i] = state_dict[i].to(torch_dtype)
|
|
return state_dict
|
|
|
|
|
|
def search_for_embeddings(state_dict):
|
|
embeddings = []
|
|
for k in state_dict:
|
|
if isinstance(state_dict[k], torch.Tensor):
|
|
embeddings.append(state_dict[k])
|
|
elif isinstance(state_dict[k], dict):
|
|
embeddings += search_for_embeddings(state_dict[k])
|
|
return embeddings
|
|
|
|
|
|
def search_parameter(param, state_dict):
|
|
for name, param_ in state_dict.items():
|
|
if param.numel() == param_.numel():
|
|
if param.shape == param_.shape:
|
|
if torch.dist(param, param_) < 1e-3:
|
|
return name
|
|
else:
|
|
if torch.dist(param.flatten(), param_.flatten()) < 1e-3:
|
|
return name
|
|
return None
|
|
|
|
|
|
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
|
|
matched_keys = set()
|
|
with torch.no_grad():
|
|
for name in source_state_dict:
|
|
rename = search_parameter(source_state_dict[name], target_state_dict)
|
|
if rename is not None:
|
|
print(f'"{name}": "{rename}",')
|
|
matched_keys.add(rename)
|
|
elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
|
|
length = source_state_dict[name].shape[0] // 3
|
|
rename = []
|
|
for i in range(3):
|
|
rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
|
|
if None not in rename:
|
|
print(f'"{name}": {rename},')
|
|
for rename_ in rename:
|
|
matched_keys.add(rename_)
|
|
for name in target_state_dict:
|
|
if name not in matched_keys:
|
|
print("Cannot find", name, target_state_dict[name].shape)
|
|
|
|
|
|
def search_for_files(folder, extensions):
|
|
files = []
|
|
if os.path.isdir(folder):
|
|
for file in sorted(os.listdir(folder)):
|
|
files += search_for_files(os.path.join(folder, file), extensions)
|
|
elif os.path.isfile(folder):
|
|
for extension in extensions:
|
|
if folder.endswith(extension):
|
|
files.append(folder)
|
|
break
|
|
return files
|