diff --git a/backend-golang/rwkv.go b/backend-golang/rwkv.go index 98b4a61..e4b3d8b 100644 --- a/backend-golang/rwkv.go +++ b/backend-golang/rwkv.go @@ -28,8 +28,7 @@ func (a *App) StartServer(python string, port int, host string, rwkvBeta bool) ( func (a *App) StartWebGPUServer(port int, host string) (string, error) { args := []string{"./backend-rust/webgpu_server"} - args = append(args, "-a", "0", "-t", "backend-rust/assets/rwkv_vocab_v20230424.json", - "--port", strconv.Itoa(port), "--ip", host) + args = append(args, "--port", strconv.Itoa(port), "--ip", host) return Cmd(args...) } diff --git a/backend-python/convert_safetensors.py b/backend-python/convert_safetensors.py index 017f9cf..7d70994 100644 --- a/backend-python/convert_safetensors.py +++ b/backend-python/convert_safetensors.py @@ -18,20 +18,31 @@ parser.add_argument( args = parser.parse_args() -def convert_file( - pt_filename: str, - sf_filename: str, -): +def rename_key(rename, name): + for k, v in rename.items(): + if k in name: + name = name.replace(k, v) + return name + + +def convert_file(pt_filename: str, sf_filename: str, transpose_names=[], rename={}): loaded = torch.load(pt_filename, map_location="cpu") if "state_dict" in loaded: loaded = loaded["state_dict"] loaded = {k: v.clone().half() for k, v in loaded.items()} - for k, v in loaded.items(): - print(f"{k}\t{v.shape}\t{v.dtype}") + # for k, v in loaded.items(): + # print(f'{k}\t{v.shape}\t{v.dtype}') # For tensors to be contiguous - loaded = {k: v.contiguous() for k, v in loaded.items()} + for k, v in loaded.items(): + for transpose_name in transpose_names: + if transpose_name in k: + loaded[k] = v.transpose(0, 1) + loaded = {rename_key(rename, k).lower(): v.contiguous() for k, v in loaded.items()} + + for k, v in loaded.items(): + print(f"{k}\t{v.shape}\t{v.dtype}") dirname = os.path.dirname(sf_filename) os.makedirs(dirname, exist_ok=True) @@ -46,7 +57,12 @@ def convert_file( if __name__ == "__main__": try: - convert_file(args.input, args.output) + convert_file( + args.input, + args.output, + ["lora_A"], + {"time_faaaa": "time_first", "lora_A": "lora.0", "lora_B": "lora.1"}, + ) print(f"Saved to {args.output}") except Exception as e: with open("error.txt", "w") as f: