From 02bbd18acf2bf030a7356d2adfc8e8d7d6cf1081 Mon Sep 17 00:00:00 2001 From: josc146 Date: Wed, 28 Feb 2024 23:25:46 +0800 Subject: [PATCH] fix convert_safetensors.py for rwkv6 --- backend-python/convert_safetensors.py | 28 ++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/backend-python/convert_safetensors.py b/backend-python/convert_safetensors.py index feb4874..876c4b5 100644 --- a/backend-python/convert_safetensors.py +++ b/backend-python/convert_safetensors.py @@ -54,19 +54,21 @@ def convert_file(pt_filename: str, sf_filename: str, rename={}, transpose_names= loaded[k].unsqueeze(1).repeat(1, n_emb // loaded[k].shape[0]) ) - for k in kk: - new_k = rename_key(rename, k).lower() - v = loaded[k].half() - del loaded[k] - for transpose_name in transpose_names: - if transpose_name in k: - v = v.transpose(0, 1) - print(f"{new_k}\t{v.shape}\t{v.dtype}") - loaded[new_k] = { - "dtype": str(v.dtype).split(".")[-1], - "shape": v.shape, - "data": v.numpy().tobytes(), - } + with torch.no_grad(): + for k in kk: + new_k = rename_key(rename, k).lower() + v = loaded[k].half() + del loaded[k] + for transpose_name in transpose_names: + if transpose_name in new_k: + dims = len(v.shape) + v = v.transpose(dims - 2, dims - 1) + print(f"{new_k}\t{v.shape}\t{v.dtype}") + loaded[new_k] = { + "dtype": str(v.dtype).split(".")[-1], + "shape": v.shape, + "data": v.numpy().tobytes(), + } dirname = os.path.dirname(sf_filename) os.makedirs(dirname, exist_ok=True)