diff --git a/.gitattributes b/.gitattributes
index 5493593..965add0 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -3,6 +3,7 @@ backend-python/wkv_cuda_utils/** linguist-vendored
backend-python/get-pip.py linguist-vendored
backend-python/convert_model.py linguist-vendored
backend-python/convert_safetensors.py linguist-vendored
+backend-python/convert_pytorch_to_ggml.py linguist-vendored
backend-python/utils/midi.py linguist-vendored
build/** linguist-vendored
finetune/lora/** linguist-vendored
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index a8739fa..778f021 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -65,6 +65,8 @@ jobs:
Copy-Item -Path "${{ steps.cp310.outputs.python-path }}/../libs" -Destination "py310/libs" -Recurse
./py310/python -m pip install cyac==1.9
go install github.com/wailsapp/wails/v2/cmd/wails@latest
+ del ./backend-python/rwkv_pip/cpp/librwkv.dylib
+ del ./backend-python/rwkv_pip/cpp/librwkv.so
(Get-Content -Path ./backend-golang/app.go) -replace "//go:custom_build windows ", "" | Set-Content -Path ./backend-golang/app.go
make
Rename-Item -Path "build/bin/RWKV-Runner.exe" -NewName "RWKV-Runner_windows_x64.exe"
@@ -93,6 +95,8 @@ jobs:
rm ./backend-python/rwkv_pip/rwkv6.pyd
rm ./backend-python/rwkv_pip/beta/wkv_cuda.pyd
rm ./backend-python/get-pip.py
+ rm ./backend-python/rwkv_pip/cpp/librwkv.dylib
+ rm ./backend-python/rwkv_pip/cpp/rwkv.dll
make
mv build/bin/RWKV-Runner build/bin/RWKV-Runner_linux_x64
@@ -117,6 +121,8 @@ jobs:
rm ./backend-python/rwkv_pip/rwkv6.pyd
rm ./backend-python/rwkv_pip/beta/wkv_cuda.pyd
rm ./backend-python/get-pip.py
+ rm ./backend-python/rwkv_pip/cpp/rwkv.dll
+ rm ./backend-python/rwkv_pip/cpp/librwkv.so
make
cp build/darwin/Readme_Install.txt build/bin/Readme_Install.txt
cp build/bin/RWKV-Runner.app/Contents/MacOS/RWKV-Runner build/bin/RWKV-Runner_darwin_universal
diff --git a/backend-golang/rwkv.go b/backend-golang/rwkv.go
index 482bd9a..a6c5d8f 100644
--- a/backend-golang/rwkv.go
+++ b/backend-golang/rwkv.go
@@ -10,7 +10,7 @@ import (
"strings"
)
-func (a *App) StartServer(python string, port int, host string, webui bool, rwkvBeta bool) (string, error) {
+func (a *App) StartServer(python string, port int, host string, webui bool, rwkvBeta bool, rwkvcpp bool) (string, error) {
var err error
if python == "" {
python, err = GetPython()
@@ -25,6 +25,9 @@ func (a *App) StartServer(python string, port int, host string, webui bool, rwkv
if rwkvBeta {
args = append(args, "--rwkv-beta")
}
+ if rwkvcpp {
+ args = append(args, "--rwkv.cpp")
+ }
args = append(args, "--port", strconv.Itoa(port), "--host", host)
return Cmd(args...)
}
@@ -52,6 +55,21 @@ func (a *App) ConvertSafetensors(modelPath string, outPath string) (string, erro
return Cmd(args...)
}
+// ConvertGGML runs backend-python/convert_pytorch_to_ggml.py through Cmd to
+// convert the checkpoint at modelPath into the file at outPath.
+// When python is empty the interpreter is looked up with GetPython.
+// The data type passed to the script is "Q5_1" if Q51 is set, "FP16" otherwise.
+func (a *App) ConvertGGML(python string, modelPath string, outPath string, Q51 bool) (string, error) {
+	if python == "" {
+		resolved, err := GetPython()
+		if err != nil {
+			return "", err
+		}
+		python = resolved
+	}
+
+	format := "FP16"
+	if Q51 {
+		format = "Q5_1"
+	}
+
+	script := "./backend-python/convert_pytorch_to_ggml.py"
+	return Cmd(python, script, modelPath, outPath, format)
+}
+
func (a *App) ConvertData(python string, input string, outputPrefix string, vocab string) (string, error) {
var err error
if python == "" {
diff --git a/backend-python/convert_pytorch_to_ggml.py b/backend-python/convert_pytorch_to_ggml.py
new file mode 100644
index 0000000..6c42e35
--- /dev/null
+++ b/backend-python/convert_pytorch_to_ggml.py
@@ -0,0 +1,169 @@
+# Converts an RWKV model checkpoint in PyTorch format to an rwkv.cpp compatible file.
+# Usage: python convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M-FP16.bin FP16
+# Get model checkpoints from https://huggingface.co/BlinkDL
+# See FILE_FORMAT.md for the documentation on the file format.
+
+import argparse
+import struct
+import torch
+from typing import Dict
+
+
+def parse_args():
+    """
+    Parse the command-line arguments of the converter.
+
+    Positional arguments:
+      src_path   path to the PyTorch checkpoint file
+      dest_path  path to the rwkv.cpp checkpoint file, will be overwritten
+      data_type  FP16, Q4_0, Q4_1, Q5_0, Q5_1 or Q8_0; optional, FP16 when omitted
+
+    Returns the argparse.Namespace holding these three values.
+    """
+    parser = argparse.ArgumentParser(
+        description="Convert an RWKV model checkpoint in PyTorch format to an rwkv.cpp compatible file"
+    )
+    parser.add_argument("src_path", help="Path to PyTorch checkpoint file")
+    parser.add_argument(
+        "dest_path", help="Path to rwkv.cpp checkpoint file, will be overwritten"
+    )
+    parser.add_argument(
+        "data_type",
+        help="Data type, FP16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0",
+        type=str,
+        # A positional argument is mandatory unless nargs="?" is given; only
+        # then does argparse fall back to `default` when it is left out.
+        nargs="?",
+        choices=["FP16", "Q4_0", "Q4_1", "Q5_0", "Q5_1", "Q8_0"],
+        default="FP16",
+    )
+    return parser.parse_args()
+
+
+def get_layer_count(state_dict: Dict[str, torch.Tensor]) -> int:
+    """
+    Count the consecutive layers present in the checkpoint.
+
+    Layers are probed by key: "blocks.0.ln1.weight", "blocks.1.ln1.weight", ...
+    The result is the first index whose key is missing. An AssertionError is
+    raised when not even layer 0 is present.
+    """
+    # A dict with N keys cannot contain N + 1 distinct layer keys, so the
+    # search below always finds a missing index.
+    layer_count: int = next(
+        index
+        for index in range(len(state_dict) + 1)
+        if f"blocks.{index}.ln1.weight" not in state_dict
+    )
+
+    assert layer_count > 0
+
+    return layer_count
+
+
+def write_state_dict(
+    state_dict: Dict[str, torch.Tensor], dest_path: str, data_type: str
+) -> None:
+    """
+    Serialize a checkpoint into the binary file at dest_path (overwritten).
+
+    File layout, all integers packed as unpadded native-order int32:
+      * header: magic 0x67676D66 ('ggmf'), the constant 101 (presumably the
+        format version, see FILE_FORMAT.md), n_vocab, n_embed, n_layer and a
+        data type flag (1 when data_type is "FP16" or "float16", else 0);
+      * per tensor: dimension count, key length in bytes, dtype flag
+        (1 = float16, 0 = float32), the dimensions in reversed order, the
+        UTF-8 key and the raw tensor data.
+
+    The printed model version is derived from marker keys: the gate key
+    selects the "v5.2" label, the ln_x key the "v5.1" label. The v5 tensor
+    transformations, however, are applied only when the ln_x key is present.
+    """
+    embedding: torch.Tensor = state_dict["emb.weight"]
+
+    layer_total: int = get_layer_count(state_dict)
+    vocab_size: int = embedding.shape[0]
+    embed_size: int = embedding.shape[1]
+
+    has_ln_x: bool = "blocks.0.att.ln_x.weight" in state_dict
+    has_gate: bool = "blocks.0.att.gate.weight" in state_dict
+
+    print(
+        "Detected RWKV v5.2"
+        if has_gate
+        else "Detected RWKV v5.1"
+        if has_ln_x
+        else "Detected RWKV v4"
+    )
+
+    with open(dest_path, "wb") as out_file:
+        store_half: bool = data_type in ("FP16", "float16")
+
+        # '=' disables struct padding; the first field is the magic 'ggmf' in hex.
+        header: bytes = struct.pack(
+            "=iiiiii",
+            0x67676D66,
+            101,
+            vocab_size,
+            embed_size,
+            layer_total,
+            int(store_half),
+        )
+        out_file.write(header)
+
+        for name, weights in state_dict.items():
+            tensor: torch.Tensor = weights.float()
+            is_time_param: bool = ".time_" in name
+
+            if is_time_param:
+                tensor = tensor.squeeze()
+
+            if not has_ln_x:
+                if ".time_decay" in name:
+                    tensor = -torch.exp(tensor)
+            else:
+                if ".time_decay" in name:
+                    decay = torch.exp(-torch.exp(tensor))
+                    tensor = decay.unsqueeze(-1) if has_gate else decay.reshape(-1, 1, 1)
+
+                if ".time_first" in name:
+                    tensor = torch.exp(tensor).reshape(-1, 1, 1)
+
+                if ".time_faaaa" in name:
+                    tensor = tensor.unsqueeze(-1)
+
+            # Keep 1-dim vectors and the time_* tensors in FP32
+            if store_half and tensor.dim() > 1 and not is_time_param:
+                tensor = tensor.half()
+
+            dims = tensor.shape
+
+            print(f"Writing {name}, shape {dims}, type {tensor.dtype}")
+
+            encoded_name: bytes = name.encode("utf-8")
+            dtype_flag: int = int(tensor.dtype == torch.float16)
+
+            out_file.write(struct.pack("=iii", len(dims), len(encoded_name), dtype_flag))
+
+            # Dimension order is reversed here:
+            # * PyTorch shape is (x rows, y columns)
+            # * ggml shape is (y elements in a row, x elements in a column)
+            # Both shapes represent the same tensor.
+            out_file.write(b"".join(struct.pack("=i", dim) for dim in reversed(dims)))
+
+            out_file.write(encoded_name)
+
+            tensor.numpy().tofile(out_file)
+
+
+def main() -> None:
+    """
+    Convert the checkpoint named on the command line.
+
+    The checkpoint is always written in FP16 first. For a quantized data type
+    (a name starting with "Q") that FP16 file is only an intermediate: the
+    rwkv.cpp shared library then quantizes it into dest_path. The intermediate
+    file is left on disk.
+    """
+    args = parse_args()
+    quantize: bool = args.data_type.startswith("Q")
+
+    print(f"Reading {args.src_path}")
+
+    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
+    # code. Only convert checkpoints that come from a trusted source.
+    state_dict: Dict[str, torch.Tensor] = torch.load(args.src_path, map_location="cpu")
+
+    fp16_path: str = args.dest_path
+    if quantize:
+        import os
+        import re
+
+        # Rewrite the quantization tag in the file name only, so a directory
+        # that happens to contain e.g. "Q5_1" is not changed.
+        directory, file_name = os.path.split(args.dest_path)
+        fp16_name = re.sub(r"Q[458]_[01]", "fp16", file_name)
+        if fp16_name == file_name:
+            # No tag to replace: pick a distinct name, otherwise the library
+            # would read from and write to the same file.
+            fp16_name = f"{file_name}.fp16"
+        fp16_path = os.path.join(directory, fp16_name)
+
+    write_state_dict(state_dict, fp16_path, "FP16")
+
+    if quantize:
+        import sys
+
+        # Make the bundled rwkv_pip package importable when run as a script.
+        sys.path.append(os.path.dirname(os.path.realpath(__file__)))
+        from rwkv_pip.cpp import rwkv_cpp_shared_library
+
+        library = rwkv_cpp_shared_library.load_rwkv_shared_library()
+        library.rwkv_quantize_model_file(fp16_path, args.dest_path, args.data_type)
+
+    print("Done")
+
+
+if __name__ == "__main__":
+    # Run the conversion. Any failure is printed and its message is also
+    # written to error.txt in the current working directory.
+    try:
+        main()
+    except Exception as error:
+        message = str(error)
+        print(message)
+        with open("error.txt", "w") as error_file:
+            error_file.write(message)
diff --git a/backend-python/main.py b/backend-python/main.py
index d098b1b..1ef38c9 100644
--- a/backend-python/main.py
+++ b/backend-python/main.py
@@ -32,6 +32,11 @@ def get_args(args: Union[Sequence[str], None] = None):
action="store_true",
help="whether to use rwkv-beta (default: False)",
)
+ group.add_argument(
+ "--rwkv.cpp",
+ action="store_true",
+ help="whether to use rwkv.cpp (default: False)",
+ )
args = parser.parse_args(args)
return args
diff --git a/backend-python/routes/config.py b/backend-python/routes/config.py
index da5bf28..93da367 100644
--- a/backend-python/routes/config.py
+++ b/backend-python/routes/config.py
@@ -49,19 +49,13 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
if body.model == "":
return "success"
- STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps|dml) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$"
- if not re.match(STRATEGY_REGEX, body.strategy):
- raise HTTPException(
- Status.HTTP_400_BAD_REQUEST,
- "Invalid strategy. Please read https://pypi.org/project/rwkv/",
- )
devices = set(
[
x.strip().split(" ")[0].replace("cuda:0", "cuda")
for x in body.strategy.split("->")
]
)
- print(f"Devices: {devices}")
+ print(f"Strategy Devices: {devices}")
# if len(devices) > 1:
# state_cache.disable_state_cache()
# else:
diff --git a/backend-python/routes/state_cache.py b/backend-python/routes/state_cache.py
index 5d02f57..16693c4 100644
--- a/backend-python/routes/state_cache.py
+++ b/backend-python/routes/state_cache.py
@@ -90,10 +90,15 @@ def add_state(body: AddStateBody):
try:
id: int = trie.insert(body.prompt)
- devices: List[torch.device] = [tensor.device for tensor in body.state]
+ devices: List[torch.device] = [
+ (tensor.device if hasattr(tensor, "device") else torch.device("cpu"))
+ for tensor in body.state
+ ]
dtrie[id] = {
"tokens": copy.deepcopy(body.tokens),
- "state": [tensor.cpu() for tensor in body.state],
+ "state": [tensor.cpu() for tensor in body.state]
+ if hasattr(body.state[0], "device")
+ else copy.deepcopy(body.state),
"logits": copy.deepcopy(body.logits),
"devices": devices,
}
@@ -185,7 +190,9 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
return {
"prompt": prompt,
"tokens": v["tokens"],
- "state": [tensor.to(devices[i]) for i, tensor in enumerate(v["state"])],
+ "state": [tensor.to(devices[i]) for i, tensor in enumerate(v["state"])]
+ if hasattr(v["state"][0], "device")
+ else v["state"],
"logits": v["logits"],
}
else:
diff --git a/backend-python/rwkv_pip/cpp/librwkv.dylib b/backend-python/rwkv_pip/cpp/librwkv.dylib
new file mode 100644
index 0000000..aeae5b2
Binary files /dev/null and b/backend-python/rwkv_pip/cpp/librwkv.dylib differ
diff --git a/backend-python/rwkv_pip/cpp/librwkv.so b/backend-python/rwkv_pip/cpp/librwkv.so
new file mode 100644
index 0000000..ed13248
Binary files /dev/null and b/backend-python/rwkv_pip/cpp/librwkv.so differ
diff --git a/backend-python/rwkv_pip/cpp/model.py b/backend-python/rwkv_pip/cpp/model.py
new file mode 100644
index 0000000..1a5a074
--- /dev/null
+++ b/backend-python/rwkv_pip/cpp/model.py
@@ -0,0 +1,14 @@
+from typing import Any, List
+from . import rwkv_cpp_model
+from . import rwkv_cpp_shared_library
+
+
+class RWKV:
+    """
+    Small adapter that loads an rwkv.cpp model and exposes `forward`.
+
+    Attributes:
+        library: shared library returned by load_rwkv_shared_library().
+        model: RWKVModel created from model_path.
+        w: placeholder weight table; it holds only "emb.weight", a list of
+           n_vocab zeros (presumably so callers can read the vocabulary size
+           from its length - confirm against the callers).
+    """
+
+    def __init__(self, model_path: str, strategy=None):
+        # `strategy` is accepted but not used by this class.
+        self.library = rwkv_cpp_shared_library.load_rwkv_shared_library()
+        self.model = rwkv_cpp_model.RWKVModel(self.library, model_path)
+        self.w = {"emb.weight": [0] * self.model.n_vocab}  # fake weight
+
+    # The annotation is a string on purpose: `Any | None` is evaluated when the
+    # method is defined and raises TypeError on Python versions before 3.10.
+    def forward(self, tokens: List[int], state: "Any | None"):
+        """
+        Evaluate `tokens` starting from `state` (None on the first pass).
+
+        Returns what RWKVModel.eval_sequence_in_chunks returns, with numpy
+        arrays requested for newly created buffers.
+        """
+        return self.model.eval_sequence_in_chunks(tokens, state, use_numpy=True)
diff --git a/backend-python/rwkv_pip/cpp/rwkv.dll b/backend-python/rwkv_pip/cpp/rwkv.dll
new file mode 100644
index 0000000..b84a98c
Binary files /dev/null and b/backend-python/rwkv_pip/cpp/rwkv.dll differ
diff --git a/backend-python/rwkv_pip/cpp/rwkv_cpp_model.py b/backend-python/rwkv_pip/cpp/rwkv_cpp_model.py
new file mode 100644
index 0000000..4b78c76
--- /dev/null
+++ b/backend-python/rwkv_pip/cpp/rwkv_cpp_model.py
@@ -0,0 +1,369 @@
+import os
+import multiprocessing
+
+# Pre-import PyTorch, if available.
+# This fixes "OSError: [WinError 127] The specified procedure could not be found".
+try:
+ import torch
+except ModuleNotFoundError:
+ pass
+
+# I'm sure this is not strictly correct, but let's keep this crutch for now.
+try:
+ import rwkv_cpp_shared_library
+except ModuleNotFoundError:
+ from . import rwkv_cpp_shared_library
+
+from typing import TypeVar, Optional, Tuple, List
+
+# A value of this type is either a numpy's ndarray or a PyTorch's Tensor.
+NumpyArrayOrPyTorchTensor: TypeVar = TypeVar('NumpyArrayOrPyTorchTensor')
+
+class RWKVModel:
+    """
+    An RWKV model managed by rwkv.cpp library.
+
+    `eval`, `eval_sequence` and `eval_sequence_in_chunks` share one buffer
+    preparation step, implemented in `_prepare_buffers`.
+    """
+
+    def __init__(
+        self,
+        shared_library: rwkv_cpp_shared_library.RWKVSharedLibrary,
+        model_path: str,
+        thread_count: int = max(1, multiprocessing.cpu_count() // 2),
+        gpu_layer_count: int = 0,
+        **kwargs
+    ) -> None:
+        """
+        Loads the model and prepares it for inference.
+        In case of any error, this method will throw an exception.
+
+        Parameters
+        ----------
+        shared_library : RWKVSharedLibrary
+            rwkv.cpp shared library.
+        model_path : str
+            Path to RWKV model file in ggml format.
+        thread_count : int
+            Thread count to use, must be > 0. If not set, defaults to half of the CPU count
+            (at least 1); this default is computed once, when the class is defined.
+        gpu_layer_count : int
+            Count of layers to offload onto the GPU, must be >= 0.
+            See documentation of `gpu_offload_layers` for details about layer offloading.
+        **kwargs
+            `gpu_layers_count` is accepted as an alternative spelling of `gpu_layer_count`
+            and takes precedence over it. Other keyword arguments are ignored.
+        """
+
+        if 'gpu_layers_count' in kwargs:
+            gpu_layer_count = kwargs['gpu_layers_count']
+
+        assert os.path.isfile(model_path), f'{model_path} is not a file'
+        assert thread_count > 0, 'Thread count must be > 0'
+        assert gpu_layer_count >= 0, 'GPU layer count must be >= 0'
+
+        self._library: rwkv_cpp_shared_library.RWKVSharedLibrary = shared_library
+
+        self._ctx: rwkv_cpp_shared_library.RWKVContext = self._library.rwkv_init_from_file(model_path, thread_count)
+
+        if gpu_layer_count > 0:
+            self.gpu_offload_layers(gpu_layer_count)
+
+        self._state_buffer_element_count: int = self._library.rwkv_get_state_buffer_element_count(self._ctx)
+        self._logits_buffer_element_count: int = self._library.rwkv_get_logits_buffer_element_count(self._ctx)
+
+        # Set last: __del__ only frees the context when construction got this far.
+        self._valid: bool = True
+
+    def gpu_offload_layers(self, layer_count: int) -> bool:
+        """
+        Offloads specified count of model layers onto the GPU. Offloaded layers are evaluated using cuBLAS or CLBlast.
+        For the purposes of this function, model head (unembedding matrix) is treated as an additional layer:
+        - pass `model.n_layer` to offload all layers except model head
+        - pass `model.n_layer + 1` to offload all layers, including model head
+
+        Returns true if at least one layer was offloaded.
+        If rwkv.cpp was compiled without cuBLAS and CLBlast support, this function is a no-op and always returns false.
+
+        Parameters
+        ----------
+        layer_count : int
+            Count of layers to offload onto the GPU, must be >= 0.
+        """
+
+        assert layer_count >= 0, 'Layer count must be >= 0'
+
+        return self._library.rwkv_gpu_offload_layers(self._ctx, layer_count)
+
+    @property
+    def n_vocab(self) -> int:
+        """Vocabulary size reported by the library for this context."""
+        return self._library.rwkv_get_n_vocab(self._ctx)
+
+    @property
+    def n_embed(self) -> int:
+        """Embedding size reported by the library for this context."""
+        return self._library.rwkv_get_n_embed(self._ctx)
+
+    @property
+    def n_layer(self) -> int:
+        """Layer count reported by the library for this context."""
+        return self._library.rwkv_get_n_layer(self._ctx)
+
+    def eval(
+        self,
+        token: int,
+        state_in: Optional[NumpyArrayOrPyTorchTensor],
+        state_out: Optional[NumpyArrayOrPyTorchTensor] = None,
+        logits_out: Optional[NumpyArrayOrPyTorchTensor] = None,
+        use_numpy: bool = False
+    ) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]:
+        """
+        Evaluates the model for a single token.
+        In case of any error, this method will throw an exception.
+
+        Parameters
+        ----------
+        token : int
+            Index of next token to be seen by the model. Must be in range 0 <= token < n_vocab.
+        state_in : Optional[NumpyArrayOrPyTorchTensor]
+            State from previous call of this method. If this is a first pass, set it to None.
+        state_out : Optional[NumpyArrayOrPyTorchTensor]
+            Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count).
+        logits_out : Optional[NumpyArrayOrPyTorchTensor]
+            Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count).
+        use_numpy : bool
+            If set to True, numpy's ndarrays will be created instead of PyTorch's Tensors.
+            This parameter is ignored if any tensor parameter is not None; in such case,
+            type of returned tensors will match the type of received tensors.
+
+        Returns
+        -------
+        logits, state
+            Logits vector of shape (n_vocab); state for the next step.
+        """
+
+        state_in_ptr, state_out, logits_out = self._prepare_buffers(state_in, state_out, logits_out, use_numpy)
+
+        self._library.rwkv_eval(
+            self._ctx,
+            token,
+            state_in_ptr,
+            self._get_data_ptr(state_out),
+            self._get_data_ptr(logits_out)
+        )
+
+        return logits_out, state_out
+
+    def eval_sequence(
+        self,
+        tokens: List[int],
+        state_in: Optional[NumpyArrayOrPyTorchTensor],
+        state_out: Optional[NumpyArrayOrPyTorchTensor] = None,
+        logits_out: Optional[NumpyArrayOrPyTorchTensor] = None,
+        use_numpy: bool = False
+    ) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]:
+        """
+        Evaluates the model for a sequence of tokens.
+
+        NOTE ON GGML NODE LIMIT
+
+        ggml has a hard-coded limit on max amount of nodes in a computation graph. The sequence graph is built in a way that quickly exceeds
+        this limit when using large models and/or large sequence lengths.
+        Fortunately, rwkv.cpp's fork of ggml has increased limit which was tested to work for sequence lengths up to 64 for 14B models.
+
+        If you get `GGML_ASSERT: ...\\ggml.c:16941: cgraph->n_nodes < GGML_MAX_NODES`, this means you've exceeded the limit.
+        To get rid of the assertion failure, reduce the model size and/or sequence length.
+
+        In case of any error, this method will throw an exception.
+
+        Parameters
+        ----------
+        tokens : List[int]
+            Indices of the next tokens to be seen by the model. Must be in range 0 <= token < n_vocab.
+        state_in : Optional[NumpyArrayOrPyTorchTensor]
+            State from previous call of this method. If this is a first pass, set it to None.
+        state_out : Optional[NumpyArrayOrPyTorchTensor]
+            Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count).
+        logits_out : Optional[NumpyArrayOrPyTorchTensor]
+            Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count).
+        use_numpy : bool
+            If set to True, numpy's ndarrays will be created instead of PyTorch's Tensors.
+            This parameter is ignored if any tensor parameter is not None; in such case,
+            type of returned tensors will match the type of received tensors.
+
+        Returns
+        -------
+        logits, state
+            Logits vector of shape (n_vocab); state for the next step.
+        """
+
+        state_in_ptr, state_out, logits_out = self._prepare_buffers(state_in, state_out, logits_out, use_numpy)
+
+        self._library.rwkv_eval_sequence(
+            self._ctx,
+            tokens,
+            state_in_ptr,
+            self._get_data_ptr(state_out),
+            self._get_data_ptr(logits_out)
+        )
+
+        return logits_out, state_out
+
+    def eval_sequence_in_chunks(
+        self,
+        tokens: List[int],
+        state_in: Optional[NumpyArrayOrPyTorchTensor],
+        state_out: Optional[NumpyArrayOrPyTorchTensor] = None,
+        logits_out: Optional[NumpyArrayOrPyTorchTensor] = None,
+        chunk_size: int = 16,
+        use_numpy: bool = False
+    ) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]:
+        """
+        Evaluates the model for a sequence of tokens using `eval_sequence`, splitting a potentially long sequence into fixed-length chunks.
+        This function is useful for processing complete prompts and user input in chat & role-playing use-cases.
+        It is recommended to use this function instead of `eval_sequence` to avoid mistakes and get maximum performance.
+
+        Chunking allows processing sequences of thousands of tokens, while not reaching the ggml's node limit and not consuming too much memory.
+        A reasonable and recommended value of chunk size is 16. If you want maximum performance, try different chunk sizes in range [2..64]
+        and choose one that works the best in your use case.
+
+        In case of any error, this method will throw an exception.
+
+        Parameters
+        ----------
+        tokens : List[int]
+            Indices of the next tokens to be seen by the model. Must be in range 0 <= token < n_vocab.
+        state_in : Optional[NumpyArrayOrPyTorchTensor]
+            State from previous call of this method. If this is a first pass, set it to None.
+        state_out : Optional[NumpyArrayOrPyTorchTensor]
+            Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count).
+        logits_out : Optional[NumpyArrayOrPyTorchTensor]
+            Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count).
+        chunk_size : int
+            Size of each chunk in tokens, must be positive.
+        use_numpy : bool
+            If set to True, numpy's ndarrays will be created instead of PyTorch's Tensors.
+            This parameter is ignored if any tensor parameter is not None; in such case,
+            type of returned tensors will match the type of received tensors.
+
+        Returns
+        -------
+        logits, state
+            Logits vector of shape (n_vocab); state for the next step.
+        """
+
+        state_in_ptr, state_out, logits_out = self._prepare_buffers(state_in, state_out, logits_out, use_numpy)
+
+        self._library.rwkv_eval_sequence_in_chunks(
+            self._ctx,
+            tokens,
+            chunk_size,
+            state_in_ptr,
+            self._get_data_ptr(state_out),
+            self._get_data_ptr(logits_out)
+        )
+
+        return logits_out, state_out
+
+    def free(self) -> None:
+        """
+        Frees all allocated resources.
+        In case of any error, this method will throw an exception.
+        The object must not be used anymore after calling this method.
+        """
+
+        assert self._valid, 'Already freed'
+
+        # Mark as freed before calling into the library, so the context is never freed twice.
+        self._valid = False
+
+        self._library.rwkv_free(self._ctx)
+
+    def __del__(self) -> None:
+        # Free the context on GC in case user forgot to call free() explicitly.
+        # `_valid` does not exist if __init__ raised before finishing.
+        if hasattr(self, '_valid') and self._valid:
+            self.free()
+
+    def _prepare_buffers(
+        self,
+        state_in: Optional[NumpyArrayOrPyTorchTensor],
+        state_out: Optional[NumpyArrayOrPyTorchTensor],
+        logits_out: Optional[NumpyArrayOrPyTorchTensor],
+        use_numpy: bool
+    ):
+        # Shared front half of the eval methods: checks that the model is still valid,
+        # validates the tensors that were provided and allocates the missing output buffers.
+        # Returns (state_in_ptr, state_out, logits_out); state_in_ptr is 0 when state_in is None.
+        assert self._valid, 'Model was freed'
+
+        use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
+
+        if state_in is not None:
+            self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count)
+
+            state_in_ptr = self._get_data_ptr(state_in)
+        else:
+            state_in_ptr = 0
+
+        if state_out is not None:
+            self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count)
+        else:
+            state_out = self._zeros_float32(self._state_buffer_element_count, use_numpy)
+
+        if logits_out is not None:
+            self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count)
+        else:
+            logits_out = self._zeros_float32(self._logits_buffer_element_count, use_numpy)
+
+        return state_in_ptr, state_out, logits_out
+
+    def _is_pytorch_tensor(self, tensor: NumpyArrayOrPyTorchTensor) -> bool:
+        # A value counts as a PyTorch tensor when its `__module__` is exactly 'torch'.
+        return hasattr(tensor, '__module__') and tensor.__module__ == 'torch'
+
+    def _detect_numpy_usage(self, tensors: List[Optional[NumpyArrayOrPyTorchTensor]], use_numpy_by_default: bool) -> bool:
+        # The first tensor that is not None decides; the default applies only when all are None.
+        for tensor in tensors:
+            if tensor is not None:
+                return not self._is_pytorch_tensor(tensor)
+
+        return use_numpy_by_default
+
+    def _validate_tensor(self, tensor: NumpyArrayOrPyTorchTensor, name: str, size: int) -> None:
+        # Asserts that the buffer is float32, contiguous and of shape (size,); PyTorch tensors must also be on CPU.
+        if self._is_pytorch_tensor(tensor):
+            tensor: torch.Tensor = tensor
+            assert tensor.device == torch.device('cpu'), f'{name} is not on CPU'
+            assert tensor.dtype == torch.float32, f'{name} is not of type float32'
+            assert tensor.shape == (size,), f'{name} has invalid shape {tensor.shape}, expected ({size})'
+            assert tensor.is_contiguous(), f'{name} is not contiguous'
+        else:
+            import numpy as np
+            tensor: np.ndarray = tensor
+            assert tensor.dtype == np.float32, f'{name} is not of type float32'
+            assert tensor.shape == (size,), f'{name} has invalid shape {tensor.shape}, expected ({size})'
+            assert tensor.data.contiguous, f'{name} is not contiguous'
+
+    def _get_data_ptr(self, tensor: NumpyArrayOrPyTorchTensor):
+        # Address of the first element of the buffer, as an integer.
+        if self._is_pytorch_tensor(tensor):
+            return tensor.data_ptr()
+        else:
+            return tensor.ctypes.data
+
+    def _zeros_float32(self, element_count: int, use_numpy: bool) -> NumpyArrayOrPyTorchTensor:
+        # Allocates a zero-filled float32 buffer; numpy is imported only when it is needed.
+        if use_numpy:
+            import numpy as np
+            return np.zeros(element_count, dtype=np.float32)
+        else:
+            return torch.zeros(element_count, dtype=torch.float32, device='cpu')
diff --git a/backend-python/rwkv_pip/cpp/rwkv_cpp_shared_library.py b/backend-python/rwkv_pip/cpp/rwkv_cpp_shared_library.py
new file mode 100644
index 0000000..60f387f
--- /dev/null
+++ b/backend-python/rwkv_pip/cpp/rwkv_cpp_shared_library.py
@@ -0,0 +1,444 @@
+import os
+import sys
+import ctypes
+import pathlib
+import platform
+from typing import Optional, List, Tuple, Callable
+
+# Names of the quantized data formats.
+# NOTE(review): presumably the format names accepted by rwkv_quantize_model_file - confirm.
+QUANTIZED_FORMAT_NAMES: Tuple[str, str, str, str, str] = (
+ 'Q4_0',
+ 'Q4_1',
+ 'Q5_0',
+ 'Q5_1',
+ 'Q8_0'
+)
+
+# ctypes pointer types used to declare and cast the float32 / int32 buffer
+# arguments of the native functions.
+P_FLOAT = ctypes.POINTER(ctypes.c_float)
+P_INT = ctypes.POINTER(ctypes.c_int32)
+
+class RWKVContext:
+    """
+    Handle to a native rwkv.cpp context.
+
+    `ptr` is the value returned by the library's rwkv_init_from_file, whose
+    return type is declared as c_void_p; ctypes delivers such a value as a
+    plain integer address.
+    """
+
+    def __init__(self, ptr: int) -> None:
+        # `ctypes.pointer` is a function, not a type, so it is not used as the annotation.
+        self.ptr: int = ptr
+
+class RWKVSharedLibrary:
+ """
+ Python wrapper around rwkv.cpp shared library.
+ """
+
+    def __init__(self, shared_library_path: str) -> None:
+        """
+        Loads the shared library from specified file and declares the argument and
+        return types of the rwkv.cpp functions configured below.
+        In case of any error, this method will throw an exception.
+
+        Parameters
+        ----------
+        shared_library_path : str
+            Path to rwkv.cpp shared library, e.g. 'rwkv.dll' on Windows,
+            'librwkv.so' on Linux or 'librwkv.dylib' on macOS.
+        """
+        # Since Python 3.8 the way ctypes resolves a DLL on Windows has changed;
+        # the library is loaded with winmode=0 to prevent loading failures.
+        # https://docs.python.org/3/whatsnew/3.8.html#ctypes
+        if platform.system().lower() == 'windows':
+            self.library = ctypes.CDLL(shared_library_path, winmode=0)
+        else:
+            self.library = ctypes.cdll.LoadLibrary(shared_library_path)
+
+        self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32]
+        self.library.rwkv_init_from_file.restype = ctypes.c_void_p
+
+        self.library.rwkv_gpu_offload_layers.argtypes = [ctypes.c_void_p, ctypes.c_uint32]
+        self.library.rwkv_gpu_offload_layers.restype = ctypes.c_bool
+
+        self.library.rwkv_eval.argtypes = [
+            ctypes.c_void_p,  # ctx
+            ctypes.c_int32,  # token
+            P_FLOAT,  # state_in
+            P_FLOAT,  # state_out
+            P_FLOAT  # logits_out
+        ]
+        self.library.rwkv_eval.restype = ctypes.c_bool
+
+        self.library.rwkv_eval_sequence.argtypes = [
+            ctypes.c_void_p,  # ctx
+            P_INT,  # tokens
+            ctypes.c_size_t,  # token count
+            P_FLOAT,  # state_in
+            P_FLOAT,  # state_out
+            P_FLOAT  # logits_out
+        ]
+        self.library.rwkv_eval_sequence.restype = ctypes.c_bool
+
+        self.library.rwkv_eval_sequence_in_chunks.argtypes = [
+            ctypes.c_void_p,  # ctx
+            P_INT,  # tokens
+            ctypes.c_size_t,  # token count
+            ctypes.c_size_t,  # chunk size
+            P_FLOAT,  # state_in
+            P_FLOAT,  # state_out
+            P_FLOAT  # logits_out
+        ]
+        self.library.rwkv_eval_sequence_in_chunks.restype = ctypes.c_bool
+
+        self.library.rwkv_get_n_vocab.argtypes = [ctypes.c_void_p]
+        self.library.rwkv_get_n_vocab.restype = ctypes.c_size_t
+
+        self.library.rwkv_get_n_embed.argtypes = [ctypes.c_void_p]
+        self.library.rwkv_get_n_embed.restype = ctypes.c_size_t
+
+        self.library.rwkv_get_n_layer.argtypes = [ctypes.c_void_p]
+        self.library.rwkv_get_n_layer.restype = ctypes.c_size_t
+
+        self.library.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p]
+        self.library.rwkv_get_state_buffer_element_count.restype = ctypes.c_uint32
+
+        self.library.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p]
+        self.library.rwkv_get_logits_buffer_element_count.restype = ctypes.c_uint32
+
+        # Configured once; the original repeated these two assignments.
+        self.library.rwkv_free.argtypes = [ctypes.c_void_p]
+        self.library.rwkv_free.restype = None
+
+        self.library.rwkv_quantize_model_file.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p]
+        self.library.rwkv_quantize_model_file.restype = ctypes.c_bool
+
+        self.library.rwkv_get_system_info_string.argtypes = []
+        self.library.rwkv_get_system_info_string.restype = ctypes.c_char_p
+
+        # Null pointer constant typed as void*.
+        self.nullptr = ctypes.cast(0, ctypes.c_void_p)
+
+    def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVContext:
+        """
+        Loads the model from a file and prepares it for inference.
+        Throws an exception in case of any error. Error messages would be printed to stderr.
+
+        Parameters
+        ----------
+        model_file_path : str
+            Path to model file in ggml format. It is passed to the library UTF-8 encoded.
+        thread_count : int
+            Count of threads to use, must be positive.
+
+        Raises
+        ------
+        AssertionError
+            If the library returns a null context.
+        """
+
+        ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count))
+
+        # The return type is c_void_p, so ctypes reports a NULL result as None.
+        # Raised explicitly instead of using `assert`, so that the check is
+        # still performed when Python runs with -O.
+        if ptr is None:
+            raise AssertionError('rwkv_init_from_file failed, check stderr')
+
+        return RWKVContext(ptr)
+
+    def rwkv_gpu_offload_layers(self, ctx: RWKVContext, layer_count: int) -> bool:
+        """
+        Asks the library to offload `layer_count` model layers onto the GPU, where they are evaluated using cuBLAS or CLBlast.
+        The model head (unembedding matrix) counts as one additional layer:
+        - `rwkv_get_n_layer(ctx)` offloads all layers except the model head
+        - `rwkv_get_n_layer(ctx) + 1` offloads all layers, including the model head
+        The result is true if at least one layer was offloaded.
+        A build of rwkv.cpp without cuBLAS and CLBlast support does nothing here and always returns false.
+
+        Parameters
+        ----------
+        ctx : RWKVContext
+            RWKV context obtained from rwkv_init_from_file.
+        layer_count : int
+            Count of layers to offload onto the GPU, must be >= 0.
+        """
+
+        assert layer_count >= 0, 'Layer count must be >= 0'
+
+        native_layer_count = ctypes.c_uint32(layer_count)
+        offloaded = self.library.rwkv_gpu_offload_layers(ctx.ptr, native_layer_count)
+
+        return offloaded
+
+    def rwkv_eval(
+        self,
+        ctx: RWKVContext,
+        token: int,
+        state_in_address: Optional[int],
+        state_out_address: int,
+        logits_out_address: int
+    ) -> None:
+        """
+        Evaluates the model for a single token.
+        Throws an exception in case of any error. Error messages would be printed to stderr.
+        Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
+
+        Parameters
+        ----------
+        ctx : RWKVContext
+            RWKV context obtained from rwkv_init_from_file.
+        token : int
+            Next token index, in range 0 <= token < n_vocab.
+        state_in_address : Optional[int]
+            Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass.
+        state_out_address : int
+            Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
+        logits_out_address : int
+            Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
+
+        Raises
+        ------
+        AssertionError
+            If the library reports failure.
+        """
+
+        # The native call must not live inside an `assert`: with `python -O`
+        # assert statements are removed and the evaluation would never run.
+        succeeded = self.library.rwkv_eval(
+            ctx.ptr,
+            ctypes.c_int32(token),
+            # None (first pass) is passed to the library as a null pointer.
+            ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
+            ctypes.cast(state_out_address, P_FLOAT),
+            ctypes.cast(logits_out_address, P_FLOAT)
+        )
+
+        if not succeeded:
+            raise AssertionError('rwkv_eval failed, check stderr')
+
+ def rwkv_eval_sequence(
+ self,
+ ctx: RWKVContext,
+ tokens: List[int],
+ state_in_address: Optional[int],
+ state_out_address: int,
+ logits_out_address: int
+ ) -> None:
+ """
+ Evaluates the model for a sequence of tokens.
+ Uses a faster algorithm than `rwkv_eval` if you do not need the state and logits for every token. Best used with sequence lengths of 64 or so.
+ Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length.
+
+ NOTE ON GGML NODE LIMIT
+
+ ggml has a hard-coded limit on the maximum number of nodes in a computation graph. The sequence graph is built in a way that quickly exceeds
+ this limit when using large models and/or large sequence lengths.
+ Fortunately, rwkv.cpp's fork of ggml has an increased limit, which was tested to work for sequence lengths up to 64 for 14B models.
+
+ If you get `GGML_ASSERT: ...\\ggml.c:16941: cgraph->n_nodes < GGML_MAX_NODES`, this means you've exceeded the limit.
+ To get rid of the assertion failure, reduce the model size and/or sequence length.
+
+ Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread.
+ Throws an exception in case of any error. Error messages would be printed to stderr.
+
+ Parameters
+ ----------
+ ctx : RWKVContext
+ RWKV context obtained from rwkv_init_from_file.
+ tokens : List[int]
+ Next token indices, in range 0 <= token < n_vocab.
+ state_in_address : Optional[int]
+ Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass.
+ state_out_address : int
+ Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
+ logits_out_address : int
+ Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
+ """
+
+ assert self.library.rwkv_eval_sequence(
+ ctx.ptr,
+ ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT),
+ ctypes.c_size_t(len(tokens)),
+ ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
+ ctypes.cast(state_out_address, P_FLOAT),
+ ctypes.cast(logits_out_address, P_FLOAT)
+ ), 'rwkv_eval_sequence failed, check stderr'
+
+ def rwkv_eval_sequence_in_chunks(
+ self,
+ ctx: RWKVContext,
+ tokens: List[int],
+ chunk_size: int,
+ state_in_address: Optional[int],
+ state_out_address: int,
+ logits_out_address: int
+ ) -> None:
+ """
+ Evaluates the model for a sequence of tokens using `rwkv_eval_sequence`, splitting a potentially long sequence into fixed-length chunks.
+ This function is useful for processing complete prompts and user input in chat & role-playing use-cases.
+ It is recommended to use this function instead of `rwkv_eval_sequence` to avoid mistakes and get maximum performance.
+
+ Chunking allows processing sequences of thousands of tokens, while not reaching the ggml's node limit and not consuming too much memory.
+ A reasonable and recommended value of chunk size is 16. If you want maximum performance, try different chunk sizes in range [2..64]
+ and choose one that works the best in your use case.
+
+ Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread.
+ Throws an exception in case of any error. Error messages would be printed to stderr.
+
+ Parameters
+ ----------
+ ctx : RWKVContext
+ RWKV context obtained from rwkv_init_from_file.
+ tokens : List[int]
+ Next token indices, in range 0 <= token < n_vocab.
+ chunk_size : int
+ Size of each chunk in tokens, must be positive.
+ state_in_address : Optional[int]
+ Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass.
+ state_out_address : int
+ Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
+ logits_out_address : int
+ Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
+ """
+
+ assert self.library.rwkv_eval_sequence_in_chunks(
+ ctx.ptr,
+ ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT),
+ ctypes.c_size_t(len(tokens)),
+ ctypes.c_size_t(chunk_size),
+ ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
+ ctypes.cast(state_out_address, P_FLOAT),
+ ctypes.cast(logits_out_address, P_FLOAT)
+ ), 'rwkv_eval_sequence_in_chunks failed, check stderr'
+
+ def rwkv_get_n_vocab(self, ctx: RWKVContext) -> int:
+ """
+ Returns the number of tokens in the given model's vocabulary.
+ Useful for telling 20B_tokenizer models (n_vocab = 50277) apart from World models (n_vocab = 65536).
+
+ Parameters
+ ----------
+ ctx : RWKVContext
+ RWKV context obtained from rwkv_init_from_file.
+ """
+
+ return self.library.rwkv_get_n_vocab(ctx.ptr)
+
+ def rwkv_get_n_embed(self, ctx: RWKVContext) -> int:
+ """
+ Returns the number of elements in the given model's embedding.
+ Useful for reading individual fields of a model's hidden state.
+
+ Parameters
+ ----------
+ ctx : RWKVContext
+ RWKV context obtained from rwkv_init_from_file.
+ """
+
+ return self.library.rwkv_get_n_embed(ctx.ptr)
+
+ def rwkv_get_n_layer(self, ctx: RWKVContext) -> int:
+ """
+ Returns the number of layers in the given model.
+ A layer is a pair of RWKV and FFN operations, stacked multiple times throughout the model.
+ Embedding matrix and model head (unembedding matrix) are NOT counted in `n_layer`.
+ Useful for always offloading the entire model to GPU.
+
+ Parameters
+ ----------
+ ctx : RWKVContext
+ RWKV context obtained from rwkv_init_from_file.
+ """
+
+ return self.library.rwkv_get_n_layer(ctx.ptr)
+
+ def rwkv_get_state_buffer_element_count(self, ctx: RWKVContext) -> int:
+ """
+ Returns count of FP32 elements in state buffer.
+
+ Parameters
+ ----------
+ ctx : RWKVContext
+ RWKV context obtained from rwkv_init_from_file.
+ """
+
+ return self.library.rwkv_get_state_buffer_element_count(ctx.ptr)
+
+ def rwkv_get_logits_buffer_element_count(self, ctx: RWKVContext) -> int:
+ """
+ Returns count of FP32 elements in logits buffer.
+
+ Parameters
+ ----------
+ ctx : RWKVContext
+ RWKV context obtained from rwkv_init_from_file.
+ """
+
+ return self.library.rwkv_get_logits_buffer_element_count(ctx.ptr)
+
+ def rwkv_free(self, ctx: RWKVContext) -> None:
+ """
+ Frees all allocated memory and the context.
+
+ Parameters
+ ----------
+ ctx : RWKVContext
+ RWKV context obtained from rwkv_init_from_file.
+ """
+
+ self.library.rwkv_free(ctx.ptr)
+
+ ctx.ptr = self.nullptr
+
+ def rwkv_quantize_model_file(self, model_file_path_in: str, model_file_path_out: str, format_name: str) -> None:
+ """
+ Quantizes an FP32 or FP16 model to one of the quantized formats listed in QUANTIZED_FORMAT_NAMES.
+ Throws an exception in case of any error. Error messages would be printed to stderr.
+
+ Parameters
+ ----------
+ model_file_path_in : str
+ Path to model file in ggml format, must be either FP32 or FP16.
+ model_file_path_out : str
+ Quantized model will be written here.
+ format_name : str
+ One of QUANTIZED_FORMAT_NAMES.
+ """
+
+ assert format_name in QUANTIZED_FORMAT_NAMES, f'Unknown format name {format_name}, use one of {QUANTIZED_FORMAT_NAMES}'
+
+ assert self.library.rwkv_quantize_model_file(
+ model_file_path_in.encode('utf-8'),
+ model_file_path_out.encode('utf-8'),
+ format_name.encode('utf-8')
+ ), 'rwkv_quantize_model_file failed, check stderr'
+
+ def rwkv_get_system_info_string(self) -> str:
+ """
+ Returns system information string.
+ """
+
+ return self.library.rwkv_get_system_info_string().decode('utf-8')
+
+def load_rwkv_shared_library() -> RWKVSharedLibrary:
+ """
+ Attempts to find rwkv.cpp shared library and load it.
+ To specify exact path to the library, create an instance of RWKVSharedLibrary explicitly.
+ """
+
+ file_name: str
+
+ if 'win32' in sys.platform or 'cygwin' in sys.platform:
+ file_name = 'rwkv.dll'
+ elif 'darwin' in sys.platform:
+ file_name = 'librwkv.dylib'
+ else:
+ file_name = 'librwkv.so'
+
+ # Possible sub-paths to the library relative to the repo dir.
+ child_paths: List[Callable[[pathlib.Path], pathlib.Path]] = [
+ # No lookup for Debug config here.
+ # I assume that if a user wants to debug the library,
+ # they will be able to find the library and set the exact path explicitly.
+ lambda p: p / 'backend-python' / 'rwkv_pip' / 'cpp' / file_name,
+ lambda p: p / 'bin' / 'Release' / file_name,
+ lambda p: p / 'bin' / file_name,
+ # Some people prefer to build in the "build" subdirectory.
+ lambda p: p / 'build' / 'bin' / 'Release' / file_name,
+ lambda p: p / 'build' / 'bin' / file_name,
+ lambda p: p / 'build' / file_name,
+ # Fallback.
+ lambda p: p / file_name
+ ]
+
+ working_dir: pathlib.Path = pathlib.Path(os.path.abspath(os.getcwd()))
+
+ parent_paths: List[pathlib.Path] = [
+ # Possible repo dirs relative to the working dir.
+ # ./python/rwkv_cpp
+ working_dir.parent.parent,
+ # ./python
+ working_dir.parent,
+ # .
+ working_dir,
+ # Repo dir relative to this Python file.
+ pathlib.Path(os.path.abspath(__file__)).parent.parent.parent
+ ]
+
+ for parent_path in parent_paths:
+ for child_path in child_paths:
+ full_path: pathlib.Path = child_path(parent_path)
+
+ if os.path.isfile(full_path):
+ return RWKVSharedLibrary(str(full_path))
+
+ assert False, (f'Failed to find {file_name} automatically; '
+ f'you need to find the library and create RWKVSharedLibrary specifying the path to it')
diff --git a/backend-python/rwkv_pip/utils.py b/backend-python/rwkv_pip/utils.py
index b09f230..a0f8ea4 100644
--- a/backend-python/rwkv_pip/utils.py
+++ b/backend-python/rwkv_pip/utils.py
@@ -78,12 +78,22 @@ class PIPELINE:
def decode(self, x):
return self.tokenizer.decode(x)
+ def np_softmax(self, x: np.ndarray, axis: int):
+ x -= x.max(axis=axis, keepdims=True)
+ e: np.ndarray = np.exp(x)
+ return e / e.sum(axis=axis, keepdims=True)
+
def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
- probs = F.softmax(logits.float(), dim=-1)
+ np_logits = type(logits) == np.ndarray
+ if np_logits:
+ probs = self.np_softmax(logits, axis=-1)
+ else:
+ probs = F.softmax(logits.float(), dim=-1)
top_k = int(top_k)
# 'privateuseone' is the type of custom devices like `torch_directml.device()`
- if probs.device.type in ["cpu", "privateuseone"]:
- probs = probs.cpu().numpy()
+ if np_logits or probs.device.type in ["cpu", "privateuseone"]:
+ if not np_logits:
+ probs = probs.cpu().numpy()
sorted_ids = np.argsort(probs)
sorted_probs = probs[sorted_ids][::-1]
cumulative_probs = np.cumsum(sorted_probs)
diff --git a/backend-python/utils/rwkv.py b/backend-python/utils/rwkv.py
index d6ca77e..ed60e81 100644
--- a/backend-python/utils/rwkv.py
+++ b/backend-python/utils/rwkv.py
@@ -510,15 +510,22 @@ def get_tokenizer(tokenizer_len: int):
def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV:
rwkv_beta = global_var.get(global_var.Args).rwkv_beta
+ rwkv_cpp = getattr(global_var.get(global_var.Args), "rwkv.cpp")
if "midi" in model.lower() or "abc" in model.lower():
os.environ["RWKV_RESCALE_LAYER"] = "999"
# dynamic import to make RWKV_CUDA_ON work
if rwkv_beta:
+ print("Using rwkv-beta")
from rwkv_pip.beta.model import (
RWKV as Model,
)
+ elif rwkv_cpp:
+ print("Using rwkv.cpp, strategy is ignored")
+ from rwkv_pip.cpp.model import (
+ RWKV as Model,
+ )
else:
from rwkv_pip.model import (
RWKV as Model,
diff --git a/frontend/src/_locales/ja/main.json b/frontend/src/_locales/ja/main.json
index 326c02b..8efdec8 100644
--- a/frontend/src/_locales/ja/main.json
+++ b/frontend/src/_locales/ja/main.json
@@ -128,7 +128,7 @@
"Chinese Kongfu": "中国武術",
"Allow external access to the API (service must be restarted)": "APIへの外部アクセスを許可する (サービスを再起動する必要があります)",
"Custom": "カスタム",
- "CUDA (Beta, Faster)": "CUDA (ベータ、高速)",
+ "CUDA (Beta, Faster)": "CUDA (Beta, 高速)",
"Reset All Configs": "すべての設定をリセット",
"Cancel": "キャンセル",
"Confirm": "確認",
@@ -313,5 +313,8 @@
"Music": "音楽",
"Other": "その他",
"Import MIDI": "MIDIをインポート",
- "Current Instrument": "現在の楽器"
+ "Current Instrument": "現在の楽器",
+ "Please convert model to GGML format first": "モデルをGGML形式に変換してください",
+ "Convert To GGML Format": "GGML形式に変換",
+ "CPU (rwkv.cpp, Faster)": "CPU (rwkv.cpp, 高速)"
}
\ No newline at end of file
diff --git a/frontend/src/_locales/zh-hans/main.json b/frontend/src/_locales/zh-hans/main.json
index f39640b..1862911 100644
--- a/frontend/src/_locales/zh-hans/main.json
+++ b/frontend/src/_locales/zh-hans/main.json
@@ -313,5 +313,8 @@
"Music": "音乐",
"Other": "其他",
"Import MIDI": "导入MIDI",
- "Current Instrument": "当前乐器"
+ "Current Instrument": "当前乐器",
+ "Please convert model to GGML format first": "请先将模型转换为GGML格式",
+ "Convert To GGML Format": "转换为GGML格式",
+ "CPU (rwkv.cpp, Faster)": "CPU (rwkv.cpp, 更快)"
}
\ No newline at end of file
diff --git a/frontend/src/components/RunButton.tsx b/frontend/src/components/RunButton.tsx
index 4e5867d..8cd5adf 100644
--- a/frontend/src/components/RunButton.tsx
+++ b/frontend/src/components/RunButton.tsx
@@ -17,7 +17,8 @@ import { ToolTipButton } from './ToolTipButton';
import { Play16Regular, Stop16Regular } from '@fluentui/react-icons';
import { useNavigate } from 'react-router';
import { WindowShow } from '../../wailsjs/runtime';
-import { convertToSt } from '../utils/convert-to-st';
+import { convertToGGML, convertToSt } from '../utils/convert-model';
+import { Precision } from '../types/configs';
const mainButtonText = {
[ModelStatus.Offline]: 'Run',
@@ -47,6 +48,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
const modelConfig = commonStore.getCurrentModelConfig();
const webgpu = modelConfig.modelParameters.device === 'WebGPU';
+ const cpp = modelConfig.modelParameters.device === 'CPU (rwkv.cpp)';
let modelName = '';
let modelPath = '';
if (modelConfig && modelConfig.modelParameters) {
@@ -112,6 +114,30 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
return;
}
+ if (cpp) {
+ if (!['.bin'].some(ext => modelPath.endsWith(ext))) {
+ const precision: Precision = modelConfig.modelParameters.precision === 'Q5_1' ? 'Q5_1' : 'fp16';
+ const ggmlModelPath = modelPath.replace(/\.pth$/, `-${precision}.bin`);
+ if (await FileExists(ggmlModelPath)) {
+ modelPath = ggmlModelPath;
+ } else if (!await FileExists(modelPath)) {
+ showDownloadPrompt(t('Model file not found'), modelName);
+ commonStore.setStatus({ status: ModelStatus.Offline });
+ return;
+ } else if (!currentModelSource?.isComplete) {
+ showDownloadPrompt(t('Model file download is not complete'), modelName);
+ commonStore.setStatus({ status: ModelStatus.Offline });
+ return;
+ } else {
+ toastWithButton(t('Please convert model to GGML format first'), t('Convert'), () => {
+ convertToGGML(modelConfig, navigate);
+ });
+ commonStore.setStatus({ status: ModelStatus.Offline });
+ return;
+ }
+ }
+ }
+
if (!await FileExists(modelPath)) {
showDownloadPrompt(t('Model file not found'), modelName);
commonStore.setStatus({ status: ModelStatus.Offline });
@@ -142,7 +168,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
const isUsingCudaBeta = modelConfig.modelParameters.device === 'CUDA-Beta';
startServer(commonStore.settings.customPythonPath, port, commonStore.settings.host !== '127.0.0.1' ? '0.0.0.0' : '127.0.0.1',
- !!modelConfig.enableWebUI, isUsingCudaBeta
+ !!modelConfig.enableWebUI, isUsingCudaBeta, cpp
).catch((e) => {
const errMsg = e.message || e;
if (errMsg.includes('path contains space'))
diff --git a/frontend/src/pages/Configs.tsx b/frontend/src/pages/Configs.tsx
index 0a787dd..d9bdc76 100644
--- a/frontend/src/pages/Configs.tsx
+++ b/frontend/src/pages/Configs.tsx
@@ -27,18 +27,19 @@ import { Page } from '../components/Page';
import { useNavigate } from 'react-router';
import { RunButton } from '../components/RunButton';
import { updateConfig } from '../apis';
-import { ConvertModel, FileExists, GetPyError } from '../../wailsjs/go/backend_golang/App';
-import { checkDependencies, getStrategy } from '../utils';
+import { getStrategy } from '../utils';
import { useTranslation } from 'react-i18next';
-import { WindowShow } from '../../wailsjs/runtime';
import strategyImg from '../assets/images/strategy.jpg';
import strategyZhImg from '../assets/images/strategy_zh.jpg';
import { ResetConfigsButton } from '../components/ResetConfigsButton';
import { useMediaQuery } from 'usehooks-ts';
import { ApiParameters, Device, ModelParameters, Precision } from '../types/configs';
-import { convertToSt } from '../utils/convert-to-st';
+import { convertModel, convertToGGML, convertToSt } from '../utils/convert-model';
-const ConfigSelector: FC<{ selectedIndex: number, updateSelectedIndex: (i: number) => void }> = observer(({ selectedIndex, updateSelectedIndex }) => {
+const ConfigSelector: FC<{
+ selectedIndex: number,
+ updateSelectedIndex: (i: number) => void
+}> = observer(({ selectedIndex, updateSelectedIndex }) => {
return (
{
} />
{
selectedConfig.modelParameters.device !== 'WebGPU' ?
- {
- if (commonStore.platform === 'darwin') {
- toast(t('MacOS is not yet supported for performing this operation, please do it manually.') + ' (backend-python/convert_model.py)', { type: 'info' });
- return;
- } else if (commonStore.platform === 'linux') {
- toast(t('Linux is not yet supported for performing this operation, please do it manually.') + ' (backend-python/convert_model.py)', { type: 'info' });
- return;
- }
-
- const ok = await checkDependencies(navigate);
- if (!ok)
- return;
-
- const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`;
- if (await FileExists(modelPath)) {
- const strategy = getStrategy(selectedConfig);
- const newModelPath = modelPath + '-' + strategy.replace(/[:> *+]/g, '-');
- toast(t('Start Converting'), { autoClose: 1000, type: 'info' });
- ConvertModel(commonStore.settings.customPythonPath, modelPath, strategy, newModelPath).then(async () => {
- if (!await FileExists(newModelPath + '.pth')) {
- toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' });
- } else {
- toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' });
- }
- }).catch(e => {
- const errMsg = e.message || e;
- if (errMsg.includes('path contains space'))
- toast(`${t('Convert Failed')} - ${t('File Path Cannot Contain Space')}`, { type: 'error' });
- else
- toast(`${t('Convert Failed')} - ${e.message || e}`, { type: 'error' });
- });
- setTimeout(WindowShow, 1000);
- } else {
- toast(`${t('Model Not Found')} - ${modelPath}`, { type: 'error' });
- }
- }} /> :
- convertModel(selectedConfig, navigate)} /> :
+ convertToGGML(selectedConfig, navigate)} />)
+ : convertToSt(selectedConfig)} />
}
@@ -299,6 +269,7 @@ const Configs: FC = observer(() => {
}
}}>
+
{commonStore.platform === 'darwin' && }
@@ -322,9 +293,11 @@ const Configs: FC = observer(() => {
}}>
{selectedConfig.modelParameters.device !== 'CPU' && selectedConfig.modelParameters.device !== 'MPS' &&
}
-
+ {selectedConfig.modelParameters.device !== 'CPU (rwkv.cpp)' && }
{selectedConfig.modelParameters.device === 'WebGPU' && }
- {selectedConfig.modelParameters.device !== 'WebGPU' && }
+ {selectedConfig.modelParameters.device !== 'CPU (rwkv.cpp)' && selectedConfig.modelParameters.device !== 'WebGPU' &&
+ }
+ {selectedConfig.modelParameters.device === 'CPU (rwkv.cpp)' && }
} />
}
diff --git a/frontend/src/types/configs.ts b/frontend/src/types/configs.ts
index 0814f66..ff156e4 100644
--- a/frontend/src/types/configs.ts
+++ b/frontend/src/types/configs.ts
@@ -6,8 +6,8 @@ export type ApiParameters = {
presencePenalty: number;
frequencyPenalty: number;
}
-export type Device = 'CPU' | 'CUDA' | 'CUDA-Beta' | 'WebGPU' | 'MPS' | 'Custom';
-export type Precision = 'fp16' | 'int8' | 'fp32' | 'nf4';
+export type Device = 'CPU' | 'CPU (rwkv.cpp)' | 'CUDA' | 'CUDA-Beta' | 'WebGPU' | 'MPS' | 'Custom';
+export type Precision = 'fp16' | 'int8' | 'fp32' | 'nf4' | 'Q5_1';
export type ModelParameters = {
// different models can not have the same name
modelName: string;
diff --git a/frontend/src/utils/convert-model.ts b/frontend/src/utils/convert-model.ts
new file mode 100644
index 0000000..6e9519f
--- /dev/null
+++ b/frontend/src/utils/convert-model.ts
@@ -0,0 +1,107 @@
+import { toast } from 'react-toastify';
+import commonStore from '../stores/commonStore';
+import { t } from 'i18next';
+import {
+ ConvertGGML,
+ ConvertModel,
+ ConvertSafetensors,
+ FileExists,
+ GetPyError
+} from '../../wailsjs/go/backend_golang/App';
+import { WindowShow } from '../../wailsjs/runtime';
+import { ModelConfig, Precision } from '../types/configs';
+import { checkDependencies, getStrategy } from './index';
+import { NavigateFunction } from 'react-router';
+
+export const convertModel = async (selectedConfig: ModelConfig, navigate: NavigateFunction) => {
+ if (commonStore.platform === 'darwin') {
+ toast(t('MacOS is not yet supported for performing this operation, please do it manually.') + ' (backend-python/convert_model.py)', { type: 'info' });
+ return;
+ } else if (commonStore.platform === 'linux') {
+ toast(t('Linux is not yet supported for performing this operation, please do it manually.') + ' (backend-python/convert_model.py)', { type: 'info' });
+ return;
+ }
+
+ const ok = await checkDependencies(navigate);
+ if (!ok)
+ return;
+
+ const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`;
+ if (await FileExists(modelPath)) {
+ const strategy = getStrategy(selectedConfig);
+ const newModelPath = modelPath + '-' + strategy.replace(/[:> *+]/g, '-');
+ toast(t('Start Converting'), { autoClose: 2000, type: 'info' });
+ ConvertModel(commonStore.settings.customPythonPath, modelPath, strategy, newModelPath).then(async () => {
+ if (!await FileExists(newModelPath + '.pth')) {
+ toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' });
+ } else {
+ toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' });
+ }
+ }).catch(e => {
+ const errMsg = e.message || e;
+ if (errMsg.includes('path contains space'))
+ toast(`${t('Convert Failed')} - ${t('File Path Cannot Contain Space')}`, { type: 'error' });
+ else
+ toast(`${t('Convert Failed')} - ${e.message || e}`, { type: 'error' });
+ });
+ setTimeout(WindowShow, 1000);
+ } else {
+ toast(`${t('Model Not Found')} - ${modelPath}`, { type: 'error' });
+ }
+};
+
+
+export const convertToSt = async (selectedConfig: ModelConfig) => {
+ const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`;
+ if (await FileExists(modelPath)) {
+ toast(t('Start Converting'), { autoClose: 2000, type: 'info' });
+ const newModelPath = modelPath.replace(/\.pth$/, '.st');
+ ConvertSafetensors(modelPath, newModelPath).then(async () => {
+ if (!await FileExists(newModelPath)) {
+ if (commonStore.platform === 'windows' || commonStore.platform === 'linux')
+ toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' });
+ } else {
+ toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' });
+ }
+ }).catch(e => {
+ const errMsg = e.message || e;
+ if (errMsg.includes('path contains space'))
+ toast(`${t('Convert Failed')} - ${t('File Path Cannot Contain Space')}`, { type: 'error' });
+ else
+ toast(`${t('Convert Failed')} - ${e.message || e}`, { type: 'error' });
+ });
+ setTimeout(WindowShow, 1000);
+ } else {
+ toast(`${t('Model Not Found')} - ${modelPath}`, { type: 'error' });
+ }
+};
+
+export const convertToGGML = async (selectedConfig: ModelConfig, navigate: NavigateFunction) => {
+ const ok = await checkDependencies(navigate);
+ if (!ok)
+ return;
+
+ const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`;
+ if (await FileExists(modelPath)) {
+ toast(t('Start Converting'), { autoClose: 2000, type: 'info' });
+ const precision: Precision = selectedConfig.modelParameters.precision === 'Q5_1' ? 'Q5_1' : 'fp16';
+ const newModelPath = modelPath.replace(/\.pth$/, `-${precision}.bin`);
+ ConvertGGML(commonStore.settings.customPythonPath, modelPath, newModelPath, precision === 'Q5_1').then(async () => {
+ if (!await FileExists(newModelPath)) {
+ if (commonStore.platform === 'windows' || commonStore.platform === 'linux')
+ toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' });
+ } else {
+ toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' });
+ }
+ }).catch(e => {
+ const errMsg = e.message || e;
+ if (errMsg.includes('path contains space'))
+ toast(`${t('Convert Failed')} - ${t('File Path Cannot Contain Space')}`, { type: 'error' });
+ else
+ toast(`${t('Convert Failed')} - ${e.message || e}`, { type: 'error' });
+ });
+ setTimeout(WindowShow, 1000);
+ } else {
+ toast(`${t('Model Not Found')} - ${modelPath}`, { type: 'error' });
+ }
+};
\ No newline at end of file
diff --git a/frontend/src/utils/convert-to-st.ts b/frontend/src/utils/convert-to-st.ts
deleted file mode 100644
index 32cf68b..0000000
--- a/frontend/src/utils/convert-to-st.ts
+++ /dev/null
@@ -1,31 +0,0 @@
-import { toast } from 'react-toastify';
-import commonStore from '../stores/commonStore';
-import { t } from 'i18next';
-import { ConvertSafetensors, FileExists, GetPyError } from '../../wailsjs/go/backend_golang/App';
-import { WindowShow } from '../../wailsjs/runtime';
-import { ModelConfig } from '../types/configs';
-
-export const convertToSt = async (selectedConfig: ModelConfig) => {
- const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`;
- if (await FileExists(modelPath)) {
- toast(t('Start Converting'), { autoClose: 2000, type: 'info' });
- const newModelPath = modelPath.replace(/\.pth$/, '.st');
- ConvertSafetensors(modelPath, newModelPath).then(async () => {
- if (!await FileExists(newModelPath)) {
- if (commonStore.platform === 'windows' || commonStore.platform === 'linux')
- toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' });
- } else {
- toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' });
- }
- }).catch(e => {
- const errMsg = e.message || e;
- if (errMsg.includes('path contains space'))
- toast(`${t('Convert Failed')} - ${t('File Path Cannot Contain Space')}`, { type: 'error' });
- else
- toast(`${t('Convert Failed')} - ${e.message || e}`, { type: 'error' });
- });
- setTimeout(WindowShow, 1000);
- } else {
- toast(`${t('Model Not Found')} - ${modelPath}`, { type: 'error' });
- }
-};
\ No newline at end of file
diff --git a/frontend/src/utils/index.tsx b/frontend/src/utils/index.tsx
index 7869cfe..8ae72eb 100644
--- a/frontend/src/utils/index.tsx
+++ b/frontend/src/utils/index.tsx
@@ -63,7 +63,7 @@ export async function refreshBuiltInModels(readCache: boolean = false) {
return cache;
}
-const modelSuffix = ['.pth', '.st', '.safetensors'];
+const modelSuffix = ['.pth', '.st', '.safetensors', '.bin'];
export async function refreshLocalModels(cache: {
models: ModelSourceItem[]
diff --git a/frontend/wailsjs/go/backend_golang/App.d.ts b/frontend/wailsjs/go/backend_golang/App.d.ts
index de1ff57..10e6503 100755
--- a/frontend/wailsjs/go/backend_golang/App.d.ts
+++ b/frontend/wailsjs/go/backend_golang/App.d.ts
@@ -10,6 +10,8 @@ export function ContinueDownload(arg1:string):Promise;
export function ConvertData(arg1:string,arg2:string,arg3:string,arg4:string):Promise;
+export function ConvertGGML(arg1:string,arg2:string,arg3:string,arg4:boolean):Promise;
+
export function ConvertModel(arg1:string,arg2:string,arg3:string,arg4:string):Promise;
export function ConvertSafetensors(arg1:string,arg2:string):Promise;
@@ -58,7 +60,7 @@ export function RestartApp():Promise;
export function SaveJson(arg1:string,arg2:any):Promise;
-export function StartServer(arg1:string,arg2:number,arg3:string,arg4:boolean,arg5:boolean):Promise;
+export function StartServer(arg1:string,arg2:number,arg3:string,arg4:boolean,arg5:boolean,arg6:boolean):Promise;
export function StartWebGPUServer(arg1:number,arg2:string):Promise;
diff --git a/frontend/wailsjs/go/backend_golang/App.js b/frontend/wailsjs/go/backend_golang/App.js
index e5d93c5..8d16582 100755
--- a/frontend/wailsjs/go/backend_golang/App.js
+++ b/frontend/wailsjs/go/backend_golang/App.js
@@ -18,6 +18,10 @@ export function ConvertData(arg1, arg2, arg3, arg4) {
return window['go']['backend_golang']['App']['ConvertData'](arg1, arg2, arg3, arg4);
}
+export function ConvertGGML(arg1, arg2, arg3, arg4) {
+ return window['go']['backend_golang']['App']['ConvertGGML'](arg1, arg2, arg3, arg4);
+}
+
export function ConvertModel(arg1, arg2, arg3, arg4) {
return window['go']['backend_golang']['App']['ConvertModel'](arg1, arg2, arg3, arg4);
}
@@ -114,8 +118,8 @@ export function SaveJson(arg1, arg2) {
return window['go']['backend_golang']['App']['SaveJson'](arg1, arg2);
}
-export function StartServer(arg1, arg2, arg3, arg4, arg5) {
- return window['go']['backend_golang']['App']['StartServer'](arg1, arg2, arg3, arg4, arg5);
+export function StartServer(arg1, arg2, arg3, arg4, arg5, arg6) {
+ return window['go']['backend_golang']['App']['StartServer'](arg1, arg2, arg3, arg4, arg5, arg6);
}
export function StartWebGPUServer(arg1, arg2) {
diff --git a/vendor.yml b/vendor.yml
index bc47abd..8b4553c 100644
--- a/vendor.yml
+++ b/vendor.yml
@@ -3,6 +3,7 @@
- ^backend-python/get-pip\.py
- ^backend-python/convert_model\.py
- ^backend-python/convert_safetensors\.py
+- ^backend-python/convert_pytorch_to_ggml\.py
- ^backend-python/utils/midi\.py
- ^build/
- ^finetune/lora/