fix convert_safetensors.py for rwkv6

This commit is contained in:
josc146 2024-02-28 23:25:46 +08:00
parent 18ab8b141f
commit 02bbd18acf

View File

@@ -54,19 +54,21 @@ def convert_file(pt_filename: str, sf_filename: str, rename={}, transpose_names=
loaded[k].unsqueeze(1).repeat(1, n_emb // loaded[k].shape[0]) loaded[k].unsqueeze(1).repeat(1, n_emb // loaded[k].shape[0])
) )
for k in kk: with torch.no_grad():
new_k = rename_key(rename, k).lower() for k in kk:
v = loaded[k].half() new_k = rename_key(rename, k).lower()
del loaded[k] v = loaded[k].half()
for transpose_name in transpose_names: del loaded[k]
if transpose_name in k: for transpose_name in transpose_names:
v = v.transpose(0, 1) if transpose_name in new_k:
print(f"{new_k}\t{v.shape}\t{v.dtype}") dims = len(v.shape)
loaded[new_k] = { v = v.transpose(dims - 2, dims - 1)
"dtype": str(v.dtype).split(".")[-1], print(f"{new_k}\t{v.shape}\t{v.dtype}")
"shape": v.shape, loaded[new_k] = {
"data": v.numpy().tobytes(), "dtype": str(v.dtype).split(".")[-1],
} "shape": v.shape,
"data": v.numpy().tobytes(),
}
dirname = os.path.dirname(sf_filename) dirname = os.path.dirname(sf_filename)
os.makedirs(dirname, exist_ok=True) os.makedirs(dirname, exist_ok=True)