From 053a08f5b71beba4cd866e86577f9274106a51a3 Mon Sep 17 00:00:00 2001 From: josc146 Date: Wed, 6 Dec 2023 23:08:40 +0800 Subject: [PATCH] update convert_safetensors.py --- backend-python/convert_safetensors.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/backend-python/convert_safetensors.py b/backend-python/convert_safetensors.py index bb54de4..6637b2e 100644 --- a/backend-python/convert_safetensors.py +++ b/backend-python/convert_safetensors.py @@ -25,7 +25,7 @@ def rename_key(rename, name): return name -def convert_file(pt_filename: str, sf_filename: str, transpose_names=[], rename={}): +def convert_file(pt_filename: str, sf_filename: str, rename={}, transpose_names=[]): loaded = torch.load(pt_filename, map_location="cpu") if "state_dict" in loaded: loaded = loaded["state_dict"] @@ -34,12 +34,14 @@ def convert_file(pt_filename: str, sf_filename: str, transpose_names=[], rename= # for k, v in loaded.items(): # print(f'{k}\t{v.shape}\t{v.dtype}') + loaded = {rename_key(rename, k).lower(): v.contiguous() for k, v in loaded.items()} # For tensors to be contiguous for k, v in loaded.items(): for transpose_name in transpose_names: if transpose_name in k: loaded[k] = v.transpose(0, 1) - loaded = {rename_key(rename, k).lower(): v.contiguous() for k, v in loaded.items()} + + loaded = {k: v.clone().half().contiguous() for k, v in loaded.items()} for k, v in loaded.items(): print(f"{k}\t{v.shape}\t{v.dtype}") @@ -60,8 +62,18 @@ if __name__ == "__main__": convert_file( args.input, args.output, - ["lora_A"], - {"time_faaaa": "time_first", "lora_A": "lora.0", "lora_B": "lora.1"}, + rename={ + "time_faaaa": "time_first", + "time_maa": "time_mix", + "lora_A": "lora.0", + "lora_B": "lora.1", + }, + transpose_names=[ + "time_mix_w1", + "time_mix_w2", + "time_decay_w1", + "time_decay_w2", + ], ) print(f"Saved to {args.output}") except Exception as e: