fix convert_safetensors.py for rwkv6

This commit is contained in:
josc146 2024-02-28 23:25:46 +08:00
parent 18ab8b141f
commit 02bbd18acf

View File

@ -54,13 +54,15 @@ def convert_file(pt_filename: str, sf_filename: str, rename={}, transpose_names=
loaded[k].unsqueeze(1).repeat(1, n_emb // loaded[k].shape[0]) loaded[k].unsqueeze(1).repeat(1, n_emb // loaded[k].shape[0])
) )
with torch.no_grad():
for k in kk: for k in kk:
new_k = rename_key(rename, k).lower() new_k = rename_key(rename, k).lower()
v = loaded[k].half() v = loaded[k].half()
del loaded[k] del loaded[k]
for transpose_name in transpose_names: for transpose_name in transpose_names:
if transpose_name in k: if transpose_name in new_k:
v = v.transpose(0, 1) dims = len(v.shape)
v = v.transpose(dims - 2, dims - 1)
print(f"{new_k}\t{v.shape}\t{v.dtype}") print(f"{new_k}\t{v.shape}\t{v.dtype}")
loaded[new_k] = { loaded[new_k] = {
"dtype": str(v.dtype).split(".")[-1], "dtype": str(v.dtype).split(".")[-1],