import os

import torch

try:
    from safetensors import safe_open
except ImportError:  # safetensors is only needed for *.safetensors files
    safe_open = None


# Final file extensions that load_state_dict_from_folder treats as checkpoints.
_CHECKPOINT_EXTENSIONS = ("safetensors", "bin", "ckpt", "pth", "pt")


def load_state_dict_from_folder(file_path, torch_dtype=None):
    """Merge the state dicts of every checkpoint file directly inside a folder.

    Only regular files whose final extension is one of ``safetensors``,
    ``bin``, ``ckpt``, ``pth`` or ``pt`` are loaded; subfolders are not
    searched. Files are processed in sorted name order so that, when several
    files define the same key, the result is deterministic (the last file in
    sorted order wins).

    Args:
        file_path: Path to the folder.
        torch_dtype: Optional dtype that loaded tensors are cast to.

    Returns:
        One dict containing the entries of all loaded files.
    """
    state_dict = {}
    for file_name in sorted(os.listdir(file_path)):
        if "." not in file_name or file_name.rsplit(".", 1)[-1] not in _CHECKPOINT_EXTENSIONS:
            continue
        full_path = os.path.join(file_path, file_name)
        if not os.path.isfile(full_path):
            # e.g. a directory called "model.pt"; trying to load it would raise.
            continue
        state_dict.update(load_state_dict(full_path, torch_dtype=torch_dtype))
    return state_dict


def load_state_dict(file_path, torch_dtype=None):
    """Load one checkpoint file onto the CPU, dispatching on its extension.

    ``.safetensors`` files go through safetensors; anything else is handed to
    ``torch.load``.
    """
    if file_path.endswith(".safetensors"):
        return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
    return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)


def load_state_dict_from_safetensors(file_path, torch_dtype=None):
    """Read every tensor of a ``.safetensors`` file onto the CPU.

    Args:
        file_path: Path to the ``.safetensors`` file.
        torch_dtype: Optional dtype every tensor is cast to.

    Raises:
        ImportError: If the ``safetensors`` package is not installed.
    """
    if safe_open is None:
        raise ImportError("safetensors is required to load .safetensors files")
    state_dict = {}
    with safe_open(file_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensor = f.get_tensor(key)
            if torch_dtype is not None:
                tensor = tensor.to(torch_dtype)
            state_dict[key] = tensor
    return state_dict


def load_state_dict_from_bin(file_path, torch_dtype=None):
    """Load a pickled PyTorch checkpoint (``.bin``/``.pt``/``.pth``/``.ckpt``) onto the CPU.

    When ``torch_dtype`` is given, top-level tensor values are cast to it;
    non-tensor values (and anything nested inside sub-dicts) are left as is.
    """
    # SECURITY NOTE(review): torch.load unpickles the file. Depending on the
    # installed torch version and its ``weights_only`` default this can
    # execute arbitrary code, so only load checkpoints from trusted sources.
    state_dict = torch.load(file_path, map_location="cpu")
    if torch_dtype is not None:
        for key, value in state_dict.items():
            if isinstance(value, torch.Tensor):
                state_dict[key] = value.to(torch_dtype)
    return state_dict


def search_for_embeddings(state_dict):
    """Collect every tensor in ``state_dict``, recursing into nested dicts.

    Returns:
        A flat list of tensors, in dict iteration order (depth first).
    """
    embeddings = []
    for value in state_dict.values():
        if isinstance(value, torch.Tensor):
            embeddings.append(value)
        elif isinstance(value, dict):
            embeddings += search_for_embeddings(value)
    return embeddings


def _tensors_close(a, b, tolerance):
    """Return True if the L2 distance between flattened ``a`` and ``b`` is below ``tolerance``."""
    a, b = a.flatten(), b.flatten()
    if a.dtype != b.dtype or not a.is_floating_point():
        # torch.dist needs floating-point inputs of a single dtype; integer,
        # bool or mixed-precision pairs are compared in float64 instead.
        a, b = a.to(torch.float64), b.to(torch.float64)
    return bool(torch.dist(a, b) < tolerance)


def search_parameter(param, state_dict, tolerance=1e-3):
    """Find the first entry of ``state_dict`` whose values equal ``param``.

    Two tensors match when they have the same number of elements and the L2
    distance between their flattened values is below ``tolerance``. Shapes do
    not have to agree, so a reshaped copy still matches. Non-tensor entries
    are skipped.

    Args:
        param: Tensor to look up.
        state_dict: Mapping of names to candidate values.
        tolerance: Maximum L2 distance for a match.

    Returns:
        The name of the first matching entry, or ``None``.
    """
    with torch.no_grad():
        for name, candidate in state_dict.items():
            if not isinstance(candidate, torch.Tensor) or candidate.numel() != param.numel():
                continue
            if _tensors_close(param, candidate, tolerance):
                return name
    return None


def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
    """Match source parameters to identically valued target parameters.

    For every tensor in ``source_state_dict`` a line ``"src": "dst",`` is
    printed (a ready-to-paste dict entry). If there is no direct match and
    ``split_qkv`` is set, a tensor whose first dimension is divisible by 3 is
    split into three equal chunks along dim 0. When all three chunks match, a
    line ``"src": ['q', 'k', 'v'],`` is printed. Finally every target entry
    that nothing matched is reported as ``Cannot find <name> <shape>``.

    Args:
        source_state_dict: Mapping of source names to tensors.
        target_state_dict: Mapping of target names to tensors.
        split_qkv: Also try matching fused tensors chunk by chunk.

    Returns:
        A dict mapping source names to a target name, or to a list of three
        target names for split tensors. Unmatched source names are absent.
    """
    rename_dict = {}
    matched_keys = set()
    with torch.no_grad():
        for name, param in source_state_dict.items():
            if not isinstance(param, torch.Tensor):
                continue  # non-tensor metadata cannot be matched
            rename = search_parameter(param, target_state_dict)
            if rename is not None:
                print(f'"{name}": "{rename}",')
                rename_dict[name] = rename
                matched_keys.add(rename)
            elif split_qkv and param.dim() >= 1 and param.shape[0] % 3 == 0:
                length = param.shape[0] // 3
                renames = [
                    search_parameter(param[i * length:(i + 1) * length], target_state_dict)
                    for i in range(3)
                ]
                if None not in renames:
                    print(f'"{name}": {renames},')
                    rename_dict[name] = renames
                    matched_keys.update(renames)
    for name, value in target_state_dict.items():
        if name not in matched_keys:
            shape = value.shape if isinstance(value, torch.Tensor) else type(value).__name__
            print("Cannot find", name, shape)
    return rename_dict


def search_for_files(folder, extensions):
    """Recursively list files under ``folder`` whose names end with any of ``extensions``.

    Directory entries are visited in sorted order. If ``folder`` is itself a
    matching file, a one-element list holding it is returned.
    """
    if os.path.isdir(folder):
        files = []
        for entry in sorted(os.listdir(folder)):
            files += search_for_files(os.path.join(folder, entry), extensions)
        return files
    if os.path.isfile(folder) and any(folder.endswith(ext) for ext in extensions):
        return [folder]
    return []