# Vendored file (170 lines, 4.9 KiB, Python).
# Converts an RWKV model checkpoint in PyTorch format to an rwkv.cpp compatible file.
# Usage: python convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M-FP16.bin FP16
# Get model checkpoints from https://huggingface.co/BlinkDL
# See FILE_FORMAT.md for the documentation on the file format.
import argparse
import struct

import torch

from typing import Dict

def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for the conversion script.

    Positional arguments:
        src_path:  path to the PyTorch checkpoint file.
        dest_path: path to the rwkv.cpp output file (overwritten if present).
        data_type: target data type; optional, defaults to "FP16".

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Convert an RWKV model checkpoint in PyTorch format to an rwkv.cpp compatible file"
    )
    parser.add_argument("src_path", help="Path to PyTorch checkpoint file")
    parser.add_argument(
        "dest_path", help="Path to rwkv.cpp checkpoint file, will be overwritten"
    )
    parser.add_argument(
        "data_type",
        help="Data type, FP16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0",
        type=str,
        # nargs="?" makes this positional truly optional; without it argparse
        # treats the argument as required and the default below is never used.
        nargs="?",
        choices=[
            "FP16",
            "Q4_0",
            "Q4_1",
            "Q5_0",
            "Q5_1",
            "Q8_0",
        ],
        default="FP16",
    )
    return parser.parse_args()
|
|
|
|
|
|
def get_layer_count(state_dict: Dict[str, torch.Tensor]) -> int:
    """Count the transformer blocks present in a checkpoint.

    Probes for ``blocks.{i}.ln1.weight`` keys starting at index 0 and stops
    at the first missing index; the checkpoint must contain at least one block.
    """
    layer_index: int = 0

    while True:
        if f"blocks.{layer_index}.ln1.weight" not in state_dict:
            break
        layer_index += 1

    assert layer_index > 0

    return layer_index
|
|
|
|
|
|
def write_state_dict(
    state_dict: Dict[str, torch.Tensor], dest_path: str, data_type: str
) -> None:
    """Serialize ``state_dict`` to ``dest_path`` in the rwkv.cpp (ggml) binary format.

    Writes a fixed header (magic, version, vocab/embed/layer sizes, dtype flag)
    followed by one record per tensor: shape metadata, the UTF-8 key, then the
    raw tensor bytes. Only FP32/FP16 are written here; quantized types are
    produced later by the rwkv.cpp quantizer (see caller).

    Args:
        state_dict: checkpoint tensors keyed by parameter name (RWKV naming,
            e.g. ``blocks.N.att...``).
        dest_path: output file path; overwritten if it exists.
        data_type: "FP16"/"float16" selects FP16 storage for large matrices;
            anything else keeps everything FP32.
    """
    emb_weight: torch.Tensor = state_dict["emb.weight"]

    n_layer: int = get_layer_count(state_dict)
    # emb.weight is (vocab, embedding) — rows index tokens, columns features.
    n_vocab: int = emb_weight.shape[0]
    n_embed: int = emb_weight.shape[1]

    # Architecture detection by probe keys: v5.1+ adds att.ln_x, v5.2 adds att.gate.
    is_v5_1_or_2: bool = "blocks.0.att.ln_x.weight" in state_dict
    is_v5_2: bool = "blocks.0.att.gate.weight" in state_dict

    if is_v5_2:
        print("Detected RWKV v5.2")
    elif is_v5_1_or_2:
        print("Detected RWKV v5.1")
    else:
        print("Detected RWKV v4")

    with open(dest_path, "wb") as out_file:
        is_FP16: bool = data_type == "FP16" or data_type == "float16"

        out_file.write(
            struct.pack(
                # Disable padding with '='
                "=iiiiii",
                # Magic: 'ggmf' in hex
                0x67676D66,
                101,
                n_vocab,
                n_embed,
                n_layer,
                1 if is_FP16 else 0,
            )
        )

        for k in state_dict.keys():
            # Work on an FP32 copy; per-tensor dtype is decided below.
            tensor: torch.Tensor = state_dict[k].float()

            # time_* tensors are stored with singleton dims in the checkpoint;
            # collapse them before the version-specific reshapes below.
            if ".time_" in k:
                tensor = tensor.squeeze()

            if is_v5_1_or_2:
                if ".time_decay" in k:
                    # Pre-compute exp(-exp(w)) so inference can use the decay
                    # directly; v5.2 and v5.1 expect different trailing shapes.
                    if is_v5_2:
                        tensor = torch.exp(-torch.exp(tensor)).unsqueeze(-1)
                    else:
                        tensor = torch.exp(-torch.exp(tensor)).reshape(-1, 1, 1)

                if ".time_first" in k:
                    tensor = torch.exp(tensor).reshape(-1, 1, 1)

                if ".time_faaaa" in k:
                    tensor = tensor.unsqueeze(-1)
            else:
                # v4 stores time_decay as -exp(w).
                if ".time_decay" in k:
                    tensor = -torch.exp(tensor)

            # Keep 1-dim vectors and small matrices in FP32
            if is_FP16 and len(tensor.shape) > 1 and ".time_" not in k:
                tensor = tensor.half()

            shape = tensor.shape

            print(f"Writing {k}, shape {shape}, type {tensor.dtype}")

            k_encoded: bytes = k.encode("utf-8")

            # Per-tensor record header: dim count, key length, dtype flag
            # (1 = FP16, 0 = FP32).
            out_file.write(
                struct.pack(
                    "=iii",
                    len(shape),
                    len(k_encoded),
                    1 if tensor.dtype == torch.float16 else 0,
                )
            )

            # Dimension order is reversed here:
            # * PyTorch shape is (x rows, y columns)
            # * ggml shape is (y elements in a row, x elements in a column)
            # Both shapes represent the same tensor.
            for dim in reversed(tensor.shape):
                out_file.write(struct.pack("=i", dim))

            out_file.write(k_encoded)

            # Raw little-endian tensor bytes, row-major, straight after the key.
            tensor.numpy().tofile(out_file)
|
|
|
|
|
|
def main() -> None:
    """Entry point: load the checkpoint, write an FP16 file, optionally quantize.

    For quantized targets (Q*), an FP16 intermediate is written first and then
    converted by the rwkv.cpp shared library into ``dest_path``.
    """
    args = parse_args()

    print(f"Reading {args.src_path}")

    state_dict: Dict[str, torch.Tensor] = torch.load(args.src_path, map_location="cpu")

    # Compute the quantization flag once instead of re-testing startswith("Q")
    # in two separate places.
    is_quantized: bool = args.data_type.startswith("Q")

    temp_output: str = args.dest_path
    if is_quantized:
        import re

        # Derive the FP16 intermediate name by swapping the quantization tag
        # in dest_path for "fp16". Character classes must not contain commas:
        # the original "Q[4,5,8]_[0,1]" also matched literal commas.
        # NOTE(review): if dest_path contains no Qx_y tag, temp_output equals
        # dest_path and the quantizer below reads and writes the same file.
        temp_output = re.sub(r"Q[458]_[01]", "fp16", temp_output)

    write_state_dict(state_dict, temp_output, "FP16")

    if is_quantized:
        import os
        import sys

        # Make the vendored rwkv.cpp bindings importable relative to this file.
        sys.path.append(os.path.dirname(os.path.realpath(__file__)))
        from rwkv_pip.cpp import rwkv_cpp_shared_library

        library = rwkv_cpp_shared_library.load_rwkv_shared_library()
        library.rwkv_quantize_model_file(temp_output, args.dest_path, args.data_type)

    print("Done")
|
|
|
|
|
|
if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        print(e)

        # Persist the failure message for post-mortem debugging, then re-raise
        # so the process exits with a non-zero status (and a full traceback)
        # instead of silently reporting success to calling scripts.
        with open("error.txt", "w") as f:
            f.write(str(e))

        raise