diff --git a/.gitattributes b/.gitattributes index 5493593..965add0 100644 --- a/.gitattributes +++ b/.gitattributes @@ -3,6 +3,7 @@ backend-python/wkv_cuda_utils/** linguist-vendored backend-python/get-pip.py linguist-vendored backend-python/convert_model.py linguist-vendored backend-python/convert_safetensors.py linguist-vendored +backend-python/convert_pytorch_to_ggml.py linguist-vendored backend-python/utils/midi.py linguist-vendored build/** linguist-vendored finetune/lora/** linguist-vendored diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index a8739fa..778f021 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -65,6 +65,8 @@ jobs: Copy-Item -Path "${{ steps.cp310.outputs.python-path }}/../libs" -Destination "py310/libs" -Recurse ./py310/python -m pip install cyac==1.9 go install github.com/wailsapp/wails/v2/cmd/wails@latest + del ./backend-python/rwkv_pip/cpp/librwkv.dylib + del ./backend-python/rwkv_pip/cpp/librwkv.so (Get-Content -Path ./backend-golang/app.go) -replace "//go:custom_build windows ", "" | Set-Content -Path ./backend-golang/app.go make Rename-Item -Path "build/bin/RWKV-Runner.exe" -NewName "RWKV-Runner_windows_x64.exe" @@ -93,6 +95,8 @@ jobs: rm ./backend-python/rwkv_pip/rwkv6.pyd rm ./backend-python/rwkv_pip/beta/wkv_cuda.pyd rm ./backend-python/get-pip.py + rm ./backend-python/rwkv_pip/cpp/librwkv.dylib + rm ./backend-python/rwkv_pip/cpp/rwkv.dll make mv build/bin/RWKV-Runner build/bin/RWKV-Runner_linux_x64 @@ -117,6 +121,8 @@ jobs: rm ./backend-python/rwkv_pip/rwkv6.pyd rm ./backend-python/rwkv_pip/beta/wkv_cuda.pyd rm ./backend-python/get-pip.py + rm ./backend-python/rwkv_pip/cpp/rwkv.dll + rm ./backend-python/rwkv_pip/cpp/librwkv.so make cp build/darwin/Readme_Install.txt build/bin/Readme_Install.txt cp build/bin/RWKV-Runner.app/Contents/MacOS/RWKV-Runner build/bin/RWKV-Runner_darwin_universal diff --git a/backend-golang/rwkv.go b/backend-golang/rwkv.go index 482bd9a..a6c5d8f 100644 --- a/backend-golang/rwkv.go +++ b/backend-golang/rwkv.go @@ -10,7 +10,7 @@ import ( "strings" ) -func (a *App) StartServer(python string, port int, host string, webui bool, rwkvBeta bool) (string, error) { +func (a *App) StartServer(python string, port int, host string, webui bool, rwkvBeta bool, rwkvcpp bool) (string, error) { var err error if python == "" { python, err = GetPython() @@ -25,6 +25,9 @@ func (a *App) StartServer(python string, port int, host string, webui bool, rwkv if rwkvBeta { args = append(args, "--rwkv-beta") } + if rwkvcpp { + args = append(args, "--rwkv.cpp") + } args = append(args, "--port", strconv.Itoa(port), "--host", host) return Cmd(args...) } @@ -52,6 +55,21 @@ func (a *App) ConvertSafetensors(modelPath string, outPath string) (string, erro return Cmd(args...) } +func (a *App) ConvertGGML(python string, modelPath string, outPath string, Q51 bool) (string, error) { + var err error + if python == "" { + python, err = GetPython() + } + if err != nil { + return "", err + } + dataType := "FP16" + if Q51 { + dataType = "Q5_1" + } + return Cmd(python, "./backend-python/convert_pytorch_to_ggml.py", modelPath, outPath, dataType) +} + func (a *App) ConvertData(python string, input string, outputPrefix string, vocab string) (string, error) { var err error if python == "" { diff --git a/backend-python/convert_pytorch_to_ggml.py b/backend-python/convert_pytorch_to_ggml.py new file mode 100644 index 0000000..6c42e35 --- /dev/null +++ b/backend-python/convert_pytorch_to_ggml.py @@ -0,0 +1,169 @@ +# Converts an RWKV model checkpoint in PyTorch format to an rwkv.cpp compatible file. +# Usage: python convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M-FP16.bin FP16 +# Get model checkpoints from https://huggingface.co/BlinkDL +# See FILE_FORMAT.md for the documentation on the file format. + +import argparse +import struct +import torch +from typing import Dict + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Convert an RWKV model checkpoint in PyTorch format to an rwkv.cpp compatible file" + ) + parser.add_argument("src_path", help="Path to PyTorch checkpoint file") + parser.add_argument( + "dest_path", help="Path to rwkv.cpp checkpoint file, will be overwritten" + ) + parser.add_argument( + "data_type", + help="Data type, FP16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0", + type=str, + choices=[ + "FP16", + "Q4_0", + "Q4_1", + "Q5_0", + "Q5_1", + "Q8_0", + ], + default="FP16", + ) + return parser.parse_args() + + +def get_layer_count(state_dict: Dict[str, torch.Tensor]) -> int: + n_layer: int = 0 + + while f"blocks.{n_layer}.ln1.weight" in state_dict: + n_layer += 1 + + assert n_layer > 0 + + return n_layer + + +def write_state_dict( + state_dict: Dict[str, torch.Tensor], dest_path: str, data_type: str +) -> None: + emb_weight: torch.Tensor = state_dict["emb.weight"] + + n_layer: int = get_layer_count(state_dict) + n_vocab: int = emb_weight.shape[0] + n_embed: int = emb_weight.shape[1] + + is_v5_1_or_2: bool = "blocks.0.att.ln_x.weight" in state_dict + is_v5_2: bool = "blocks.0.att.gate.weight" in state_dict + + if is_v5_2: + print("Detected RWKV v5.2") + elif is_v5_1_or_2: + print("Detected RWKV v5.1") + else: + print("Detected RWKV v4") + + with open(dest_path, "wb") as out_file: + is_FP16: bool = data_type == "FP16" or data_type == "float16" + + out_file.write( + struct.pack( + # Disable padding with '=' + "=iiiiii", + # Magic: 'ggmf' in hex + 0x67676D66, + 101, + n_vocab, + n_embed, + n_layer, + 1 if is_FP16 else 0, + ) + ) + + for k in state_dict.keys(): + tensor: torch.Tensor = state_dict[k].float() + + if ".time_" in k: + tensor = tensor.squeeze() + + if is_v5_1_or_2: + if ".time_decay" in k: + if is_v5_2: + tensor = torch.exp(-torch.exp(tensor)).unsqueeze(-1) + else: + tensor = torch.exp(-torch.exp(tensor)).reshape(-1, 1, 1) + + if ".time_first" in k: + tensor = torch.exp(tensor).reshape(-1, 1, 1) + + if ".time_faaaa" in k: + tensor = tensor.unsqueeze(-1) + else: + if ".time_decay" in k: + tensor = -torch.exp(tensor) + + # Keep 1-dim vectors and small matrices in FP32 + if is_FP16 and len(tensor.shape) > 1 and ".time_" not in k: + tensor = tensor.half() + + shape = tensor.shape + + print(f"Writing {k}, shape {shape}, type {tensor.dtype}") + + k_encoded: bytes = k.encode("utf-8") + + out_file.write( + struct.pack( + "=iii", + len(shape), + len(k_encoded), + 1 if tensor.dtype == torch.float16 else 0, + ) + ) + + # Dimension order is reversed here: + # * PyTorch shape is (x rows, y columns) + # * ggml shape is (y elements in a row, x elements in a column) + # Both shapes represent the same tensor. + for dim in reversed(tensor.shape): + out_file.write(struct.pack("=i", dim)) + + out_file.write(k_encoded) + + tensor.numpy().tofile(out_file) + + +def main() -> None: + args = parse_args() + + print(f"Reading {args.src_path}") + + state_dict: Dict[str, torch.Tensor] = torch.load(args.src_path, map_location="cpu") + + temp_output: str = args.dest_path + if args.data_type.startswith("Q"): + import re + + temp_output = re.sub(r"Q[4,5,8]_[0,1]", "fp16", temp_output) + write_state_dict(state_dict, temp_output, "FP16") + if args.data_type.startswith("Q"): + import sys + import os + + sys.path.append(os.path.dirname(os.path.realpath(__file__))) + from rwkv_pip.cpp import rwkv_cpp_shared_library + + library = rwkv_cpp_shared_library.load_rwkv_shared_library() + library.rwkv_quantize_model_file(temp_output, args.dest_path, args.data_type) + + print("Done") + + +if __name__ == "__main__": + try: + main() + except Exception as e: + print(e) + with open("error.txt", "w") as f: + f.write(str(e)) diff --git a/backend-python/main.py b/backend-python/main.py index d098b1b..1ef38c9 100644 --- a/backend-python/main.py +++ b/backend-python/main.py @@ -32,6 +32,11 @@ def get_args(args: Union[Sequence[str], None] = None): action="store_true", help="whether to use rwkv-beta (default: False)", ) + group.add_argument( + "--rwkv.cpp", + action="store_true", + help="whether to use rwkv.cpp (default: False)", + ) args = parser.parse_args(args) return args diff --git a/backend-python/routes/config.py b/backend-python/routes/config.py index da5bf28..93da367 100644 --- a/backend-python/routes/config.py +++ b/backend-python/routes/config.py @@ -49,19 +49,13 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request): if body.model == "": return "success" - STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps|dml) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$" - if not re.match(STRATEGY_REGEX, body.strategy): - raise HTTPException( - Status.HTTP_400_BAD_REQUEST, - "Invalid strategy. Please read https://pypi.org/project/rwkv/", - ) devices = set( [ x.strip().split(" ")[0].replace("cuda:0", "cuda") for x in body.strategy.split("->") ] ) - print(f"Devices: {devices}") + print(f"Strategy Devices: {devices}") # if len(devices) > 1: # state_cache.disable_state_cache() # else: diff --git a/backend-python/routes/state_cache.py b/backend-python/routes/state_cache.py index 5d02f57..16693c4 100644 --- a/backend-python/routes/state_cache.py +++ b/backend-python/routes/state_cache.py @@ -90,10 +90,15 @@ def add_state(body: AddStateBody): try: id: int = trie.insert(body.prompt) - devices: List[torch.device] = [tensor.device for tensor in body.state] + devices: List[torch.device] = [ + (tensor.device if hasattr(tensor, "device") else torch.device("cpu")) + for tensor in body.state + ] dtrie[id] = { "tokens": copy.deepcopy(body.tokens), - "state": [tensor.cpu() for tensor in body.state], + "state": [tensor.cpu() for tensor in body.state] + if hasattr(body.state[0], "device") + else copy.deepcopy(body.state), "logits": copy.deepcopy(body.logits), "devices": devices, } @@ -185,7 +190,9 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request): return { "prompt": prompt, "tokens": v["tokens"], - "state": [tensor.to(devices[i]) for i, tensor in enumerate(v["state"])], + "state": [tensor.to(devices[i]) for i, tensor in enumerate(v["state"])] + if hasattr(v["state"][0], "device") + else v["state"], "logits": v["logits"], } else: diff --git a/backend-python/rwkv_pip/cpp/librwkv.dylib b/backend-python/rwkv_pip/cpp/librwkv.dylib new file mode 100644 index 0000000..aeae5b2 Binary files /dev/null and b/backend-python/rwkv_pip/cpp/librwkv.dylib differ diff --git a/backend-python/rwkv_pip/cpp/librwkv.so b/backend-python/rwkv_pip/cpp/librwkv.so new file mode 100644 index 0000000..ed13248 Binary files /dev/null and b/backend-python/rwkv_pip/cpp/librwkv.so differ diff --git a/backend-python/rwkv_pip/cpp/model.py b/backend-python/rwkv_pip/cpp/model.py new file mode 100644 index 0000000..1a5a074 --- /dev/null +++ b/backend-python/rwkv_pip/cpp/model.py @@ -0,0 +1,14 @@ +from typing import Any, List +from . import rwkv_cpp_model +from . import rwkv_cpp_shared_library + + +class RWKV: + def __init__(self, model_path: str, strategy=None): + self.library = rwkv_cpp_shared_library.load_rwkv_shared_library() + self.model = rwkv_cpp_model.RWKVModel(self.library, model_path) + self.w = {} # fake weight + self.w["emb.weight"] = [0] * self.model.n_vocab + + def forward(self, tokens: List[int], state: Any | None): + return self.model.eval_sequence_in_chunks(tokens, state, use_numpy=True) diff --git a/backend-python/rwkv_pip/cpp/rwkv.dll b/backend-python/rwkv_pip/cpp/rwkv.dll new file mode 100644 index 0000000..b84a98c Binary files /dev/null and b/backend-python/rwkv_pip/cpp/rwkv.dll differ diff --git a/backend-python/rwkv_pip/cpp/rwkv_cpp_model.py b/backend-python/rwkv_pip/cpp/rwkv_cpp_model.py new file mode 100644 index 0000000..4b78c76 --- /dev/null +++ b/backend-python/rwkv_pip/cpp/rwkv_cpp_model.py @@ -0,0 +1,369 @@ +import os +import multiprocessing + +# Pre-import PyTorch, if available. +# This fixes "OSError: [WinError 127] The specified procedure could not be found". +try: + import torch +except ModuleNotFoundError: + pass + +# I'm sure this is not strictly correct, but let's keep this crutch for now. +try: + import rwkv_cpp_shared_library +except ModuleNotFoundError: + from . import rwkv_cpp_shared_library + +from typing import TypeVar, Optional, Tuple, List + +# A value of this type is either a numpy's ndarray or a PyTorch's Tensor. +NumpyArrayOrPyTorchTensor: TypeVar = TypeVar('NumpyArrayOrPyTorchTensor') + +class RWKVModel: + """ + An RWKV model managed by rwkv.cpp library. + """ + + def __init__( + self, + shared_library: rwkv_cpp_shared_library.RWKVSharedLibrary, + model_path: str, + thread_count: int = max(1, multiprocessing.cpu_count() // 2), + gpu_layer_count: int = 0, + **kwargs + ) -> None: + """ + Loads the model and prepares it for inference. + In case of any error, this method will throw an exception. + + Parameters + ---------- + shared_library : RWKVSharedLibrary + rwkv.cpp shared library. + model_path : str + Path to RWKV model file in ggml format. + thread_count : int + Thread count to use. If not set, defaults to CPU count / 2. + gpu_layer_count : int + Count of layers to offload onto the GPU, must be >= 0. + See documentation of `gpu_offload_layers` for details about layer offloading. + """ + + if 'gpu_layers_count' in kwargs: + gpu_layer_count = kwargs['gpu_layers_count'] + + assert os.path.isfile(model_path), f'{model_path} is not a file' + assert thread_count > 0, 'Thread count must be > 0' + assert gpu_layer_count >= 0, 'GPU layer count must be >= 0' + + self._library: rwkv_cpp_shared_library.RWKVSharedLibrary = shared_library + + self._ctx: rwkv_cpp_shared_library.RWKVContext = self._library.rwkv_init_from_file(model_path, thread_count) + + if gpu_layer_count > 0: + self.gpu_offload_layers(gpu_layer_count) + + self._state_buffer_element_count: int = self._library.rwkv_get_state_buffer_element_count(self._ctx) + self._logits_buffer_element_count: int = self._library.rwkv_get_logits_buffer_element_count(self._ctx) + + self._valid: bool = True + + def gpu_offload_layers(self, layer_count: int) -> bool: + """ + Offloads specified count of model layers onto the GPU. Offloaded layers are evaluated using cuBLAS or CLBlast. + For the purposes of this function, model head (unembedding matrix) is treated as an additional layer: + - pass `model.n_layer` to offload all layers except model head + - pass `model.n_layer + 1` to offload all layers, including model head + + Returns true if at least one layer was offloaded. + If rwkv.cpp was compiled without cuBLAS and CLBlast support, this function is a no-op and always returns false. + + Parameters + ---------- + layer_count : int + Count of layers to offload onto the GPU, must be >= 0. + """ + + assert layer_count >= 0, 'Layer count must be >= 0' + + return self._library.rwkv_gpu_offload_layers(self._ctx, layer_count) + + @property + def n_vocab(self) -> int: + return self._library.rwkv_get_n_vocab(self._ctx) + + @property + def n_embed(self) -> int: + return self._library.rwkv_get_n_embed(self._ctx) + + @property + def n_layer(self) -> int: + return self._library.rwkv_get_n_layer(self._ctx) + + def eval( + self, + token: int, + state_in: Optional[NumpyArrayOrPyTorchTensor], + state_out: Optional[NumpyArrayOrPyTorchTensor] = None, + logits_out: Optional[NumpyArrayOrPyTorchTensor] = None, + use_numpy: bool = False + ) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]: + """ + Evaluates the model for a single token. + In case of any error, this method will throw an exception. + + Parameters + ---------- + token : int + Index of next token to be seen by the model. Must be in range 0 <= token < n_vocab. + state_in : Optional[NumpyArrayOrTorchTensor] + State from previous call of this method. If this is a first pass, set it to None. + state_out : Optional[NumpyArrayOrTorchTensor] + Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count). + logits_out : Optional[NumpyArrayOrTorchTensor] + Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count). + use_numpy : bool + If set to True, numpy's ndarrays will be created instead of PyTorch's Tensors. + This parameter is ignored if any tensor parameter is not None; in such case, + type of returned tensors will match the type of received tensors. + + Returns + ------- + logits, state + Logits vector of shape (n_vocab); state for the next step. + """ + + assert self._valid, 'Model was freed' + + use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy) + + if state_in is not None: + self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count) + + state_in_ptr = self._get_data_ptr(state_in) + else: + state_in_ptr = 0 + + if state_out is not None: + self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count) + else: + state_out = self._zeros_float32(self._state_buffer_element_count, use_numpy) + + if logits_out is not None: + self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count) + else: + logits_out = self._zeros_float32(self._logits_buffer_element_count, use_numpy) + + self._library.rwkv_eval( + self._ctx, + token, + state_in_ptr, + self._get_data_ptr(state_out), + self._get_data_ptr(logits_out) + ) + + return logits_out, state_out + + def eval_sequence( + self, + tokens: List[int], + state_in: Optional[NumpyArrayOrPyTorchTensor], + state_out: Optional[NumpyArrayOrPyTorchTensor] = None, + logits_out: Optional[NumpyArrayOrPyTorchTensor] = None, + use_numpy: bool = False + ) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]: + """ + Evaluates the model for a sequence of tokens. + + NOTE ON GGML NODE LIMIT + + ggml has a hard-coded limit on max amount of nodes in a computation graph. The sequence graph is built in a way that quickly exceedes + this limit when using large models and/or large sequence lengths. + Fortunately, rwkv.cpp's fork of ggml has increased limit which was tested to work for sequence lengths up to 64 for 14B models. + + If you get `GGML_ASSERT: ...\\ggml.c:16941: cgraph->n_nodes < GGML_MAX_NODES`, this means you've exceeded the limit. + To get rid of the assertion failure, reduce the model size and/or sequence length. + + In case of any error, this method will throw an exception. + + Parameters + ---------- + tokens : List[int] + Indices of the next tokens to be seen by the model. Must be in range 0 <= token < n_vocab. + state_in : Optional[NumpyArrayOrTorchTensor] + State from previous call of this method. If this is a first pass, set it to None. + state_out : Optional[NumpyArrayOrTorchTensor] + Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count). + logits_out : Optional[NumpyArrayOrTorchTensor] + Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count). + use_numpy : bool + If set to True, numpy's ndarrays will be created instead of PyTorch's Tensors. + This parameter is ignored if any tensor parameter is not None; in such case, + type of returned tensors will match the type of received tensors. + + Returns + ------- + logits, state + Logits vector of shape (n_vocab); state for the next step. + """ + + assert self._valid, 'Model was freed' + + use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy) + + if state_in is not None: + self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count) + + state_in_ptr = self._get_data_ptr(state_in) + else: + state_in_ptr = 0 + + if state_out is not None: + self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count) + else: + state_out = self._zeros_float32(self._state_buffer_element_count, use_numpy) + + if logits_out is not None: + self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count) + else: + logits_out = self._zeros_float32(self._logits_buffer_element_count, use_numpy) + + self._library.rwkv_eval_sequence( + self._ctx, + tokens, + state_in_ptr, + self._get_data_ptr(state_out), + self._get_data_ptr(logits_out) + ) + + return logits_out, state_out + + def eval_sequence_in_chunks( + self, + tokens: List[int], + state_in: Optional[NumpyArrayOrPyTorchTensor], + state_out: Optional[NumpyArrayOrPyTorchTensor] = None, + logits_out: Optional[NumpyArrayOrPyTorchTensor] = None, + chunk_size: int = 16, + use_numpy: bool = False + ) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]: + """ + Evaluates the model for a sequence of tokens using `eval_sequence`, splitting a potentially long sequence into fixed-length chunks. + This function is useful for processing complete prompts and user input in chat & role-playing use-cases. + It is recommended to use this function instead of `eval_sequence` to avoid mistakes and get maximum performance. + + Chunking allows processing sequences of thousands of tokens, while not reaching the ggml's node limit and not consuming too much memory. + A reasonable and recommended value of chunk size is 16. If you want maximum performance, try different chunk sizes in range [2..64] + and choose one that works the best in your use case. + + In case of any error, this method will throw an exception. + + Parameters + ---------- + tokens : List[int] + Indices of the next tokens to be seen by the model. Must be in range 0 <= token < n_vocab. + chunk_size : int + Size of each chunk in tokens, must be positive. + state_in : Optional[NumpyArrayOrTorchTensor] + State from previous call of this method. If this is a first pass, set it to None. + state_out : Optional[NumpyArrayOrTorchTensor] + Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count). + logits_out : Optional[NumpyArrayOrTorchTensor] + Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count). + use_numpy : bool + If set to True, numpy's ndarrays will be created instead of PyTorch's Tensors. + This parameter is ignored if any tensor parameter is not None; in such case, + type of returned tensors will match the type of received tensors. + + Returns + ------- + logits, state + Logits vector of shape (n_vocab); state for the next step. + """ + + assert self._valid, 'Model was freed' + + use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy) + + if state_in is not None: + self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count) + + state_in_ptr = self._get_data_ptr(state_in) + else: + state_in_ptr = 0 + + if state_out is not None: + self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count) + else: + state_out = self._zeros_float32(self._state_buffer_element_count, use_numpy) + + if logits_out is not None: + self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count) + else: + logits_out = self._zeros_float32(self._logits_buffer_element_count, use_numpy) + + self._library.rwkv_eval_sequence_in_chunks( + self._ctx, + tokens, + chunk_size, + state_in_ptr, + self._get_data_ptr(state_out), + self._get_data_ptr(logits_out) + ) + + return logits_out, state_out + + def free(self) -> None: + """ + Frees all allocated resources. + In case of any error, this method will throw an exception. + The object must not be used anymore after calling this method. + """ + + assert self._valid, 'Already freed' + + self._valid = False + + self._library.rwkv_free(self._ctx) + + def __del__(self) -> None: + # Free the context on GC in case user forgot to call free() explicitly. + if hasattr(self, '_valid') and self._valid: + self.free() + + def _is_pytorch_tensor(self, tensor: NumpyArrayOrPyTorchTensor) -> bool: + return hasattr(tensor, '__module__') and tensor.__module__ == 'torch' + + def _detect_numpy_usage(self, tensors: List[Optional[NumpyArrayOrPyTorchTensor]], use_numpy_by_default: bool) -> bool: + for tensor in tensors: + if tensor is not None: + return False if self._is_pytorch_tensor(tensor) else True + + return use_numpy_by_default + + def _validate_tensor(self, tensor: NumpyArrayOrPyTorchTensor, name: str, size: int) -> None: + if self._is_pytorch_tensor(tensor): + tensor: torch.Tensor = tensor + assert tensor.device == torch.device('cpu'), f'{name} is not on CPU' + assert tensor.dtype == torch.float32, f'{name} is not of type float32' + assert tensor.shape == (size,), f'{name} has invalid shape {tensor.shape}, expected ({size})' + assert tensor.is_contiguous(), f'{name} is not contiguous' + else: + import numpy as np + tensor: np.ndarray = tensor + assert tensor.dtype == np.float32, f'{name} is not of type float32' + assert tensor.shape == (size,), f'{name} has invalid shape {tensor.shape}, expected ({size})' + assert tensor.data.contiguous, f'{name} is not contiguous' + + def _get_data_ptr(self, tensor: NumpyArrayOrPyTorchTensor): + if self._is_pytorch_tensor(tensor): + return tensor.data_ptr() + else: + return tensor.ctypes.data + + def _zeros_float32(self, element_count: int, use_numpy: bool) -> NumpyArrayOrPyTorchTensor: + if use_numpy: + import numpy as np + return np.zeros(element_count, dtype=np.float32) + else: + return torch.zeros(element_count, dtype=torch.float32, device='cpu') diff --git a/backend-python/rwkv_pip/cpp/rwkv_cpp_shared_library.py b/backend-python/rwkv_pip/cpp/rwkv_cpp_shared_library.py new file mode 100644 index 0000000..60f387f --- /dev/null +++ b/backend-python/rwkv_pip/cpp/rwkv_cpp_shared_library.py @@ -0,0 +1,444 @@ +import os +import sys +import ctypes +import pathlib +import platform +from typing import Optional, List, Tuple, Callable + +QUANTIZED_FORMAT_NAMES: Tuple[str, str, str, str, str] = ( + 'Q4_0', + 'Q4_1', + 'Q5_0', + 'Q5_1', + 'Q8_0' +) + +P_FLOAT = ctypes.POINTER(ctypes.c_float) +P_INT = ctypes.POINTER(ctypes.c_int32) + +class RWKVContext: + + def __init__(self, ptr: ctypes.pointer) -> None: + self.ptr: ctypes.pointer = ptr + +class RWKVSharedLibrary: + """ + Python wrapper around rwkv.cpp shared library. + """ + + def __init__(self, shared_library_path: str) -> None: + """ + Loads the shared library from specified file. + In case of any error, this method will throw an exception. + + Parameters + ---------- + shared_library_path : str + Path to rwkv.cpp shared library. On Windows, it would look like 'rwkv.dll'. On UNIX, 'rwkv.so'. + """ + # When Python is greater than 3.8, we need to reprocess the custom dll + # according to the documentation to prevent loading failure errors. + # https://docs.python.org/3/whatsnew/3.8.html#ctypes + if platform.system().lower() == 'windows': + self.library = ctypes.CDLL(shared_library_path, winmode=0) + else: + self.library = ctypes.cdll.LoadLibrary(shared_library_path) + + self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32] + self.library.rwkv_init_from_file.restype = ctypes.c_void_p + + self.library.rwkv_gpu_offload_layers.argtypes = [ctypes.c_void_p, ctypes.c_uint32] + self.library.rwkv_gpu_offload_layers.restype = ctypes.c_bool + + self.library.rwkv_eval.argtypes = [ + ctypes.c_void_p, # ctx + ctypes.c_int32, # token + P_FLOAT, # state_in + P_FLOAT, # state_out + P_FLOAT # logits_out + ] + self.library.rwkv_eval.restype = ctypes.c_bool + + self.library.rwkv_eval_sequence.argtypes = [ + ctypes.c_void_p, # ctx + P_INT, # tokens + ctypes.c_size_t, # token count + P_FLOAT, # state_in + P_FLOAT, # state_out + P_FLOAT # logits_out + ] + self.library.rwkv_eval_sequence.restype = ctypes.c_bool + + self.library.rwkv_eval_sequence_in_chunks.argtypes = [ + ctypes.c_void_p, # ctx + P_INT, # tokens + ctypes.c_size_t, # token count + ctypes.c_size_t, # chunk size + P_FLOAT, # state_in + P_FLOAT, # state_out + P_FLOAT # logits_out + ] + self.library.rwkv_eval_sequence_in_chunks.restype = ctypes.c_bool + + self.library.rwkv_get_n_vocab.argtypes = [ctypes.c_void_p] + self.library.rwkv_get_n_vocab.restype = ctypes.c_size_t + + self.library.rwkv_get_n_embed.argtypes = [ctypes.c_void_p] + self.library.rwkv_get_n_embed.restype = ctypes.c_size_t + + self.library.rwkv_get_n_layer.argtypes = [ctypes.c_void_p] + self.library.rwkv_get_n_layer.restype = ctypes.c_size_t + + self.library.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p] + self.library.rwkv_get_state_buffer_element_count.restype = ctypes.c_uint32 + + self.library.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p] + self.library.rwkv_get_logits_buffer_element_count.restype = ctypes.c_uint32 + + self.library.rwkv_free.argtypes = [ctypes.c_void_p] + self.library.rwkv_free.restype = None + + self.library.rwkv_free.argtypes = [ctypes.c_void_p] + self.library.rwkv_free.restype = None + + self.library.rwkv_quantize_model_file.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p] + self.library.rwkv_quantize_model_file.restype = ctypes.c_bool + + self.library.rwkv_get_system_info_string.argtypes = [] + self.library.rwkv_get_system_info_string.restype = ctypes.c_char_p + + self.nullptr = ctypes.cast(0, ctypes.c_void_p) + + def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVContext: + """ + Loads the model from a file and prepares it for inference. + Throws an exception in case of any error. Error messages would be printed to stderr. + + Parameters + ---------- + model_file_path : str + Path to model file in ggml format. + thread_count : int + Count of threads to use, must be positive. + """ + + ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count)) + + assert ptr is not None, 'rwkv_init_from_file failed, check stderr' + + return RWKVContext(ptr) + + def rwkv_gpu_offload_layers(self, ctx: RWKVContext, layer_count: int) -> bool: + """ + Offloads specified count of model layers onto the GPU. Offloaded layers are evaluated using cuBLAS or CLBlast. + For the purposes of this function, model head (unembedding matrix) is treated as an additional layer: + - pass `rwkv_get_n_layer(ctx)` to offload all layers except model head + - pass `rwkv_get_n_layer(ctx) + 1` to offload all layers, including model head + Returns true if at least one layer was offloaded. + If rwkv.cpp was compiled without cuBLAS and CLBlast support, this function is a no-op and always returns false. + + Parameters + ---------- + ctx : RWKVContext + RWKV context obtained from rwkv_init_from_file. + layer_count : int + Count of layers to offload onto the GPU, must be >= 0. + """ + + assert layer_count >= 0, 'Layer count must be >= 0' + + return self.library.rwkv_gpu_offload_layers(ctx.ptr, ctypes.c_uint32(layer_count)) + + def rwkv_eval( + self, + ctx: RWKVContext, + token: int, + state_in_address: Optional[int], + state_out_address: int, + logits_out_address: int + ) -> None: + """ + Evaluates the model for a single token. + Throws an exception in case of any error. Error messages would be printed to stderr. + Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread. + + Parameters + ---------- + ctx : RWKVContext + RWKV context obtained from rwkv_init_from_file. + token : int + Next token index, in range 0 <= token < n_vocab. + state_in_address : int + Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass. + state_out_address : int + Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to. + logits_out_address : int + Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to. + """ + + assert self.library.rwkv_eval( + ctx.ptr, + ctypes.c_int32(token), + ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT), + ctypes.cast(state_out_address, P_FLOAT), + ctypes.cast(logits_out_address, P_FLOAT) + ), 'rwkv_eval failed, check stderr' + + def rwkv_eval_sequence( + self, + ctx: RWKVContext, + tokens: List[int], + state_in_address: Optional[int], + state_out_address: int, + logits_out_address: int + ) -> None: + """ + Evaluates the model for a sequence of tokens. + Uses a faster algorithm than `rwkv_eval` if you do not need the state and logits for every token. Best used with sequence lengths of 64 or so. + Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length. + + NOTE ON GGML NODE LIMIT + + ggml has a hard-coded limit on max amount of nodes in a computation graph. The sequence graph is built in a way that quickly exceedes + this limit when using large models and/or large sequence lengths. + Fortunately, rwkv.cpp's fork of ggml has increased limit which was tested to work for sequence lengths up to 64 for 14B models. + + If you get `GGML_ASSERT: ...\\ggml.c:16941: cgraph->n_nodes < GGML_MAX_NODES`, this means you've exceeded the limit. + To get rid of the assertion failure, reduce the model size and/or sequence length. + + Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread. + Throws an exception in case of any error. Error messages would be printed to stderr. + + Parameters + ---------- + ctx : RWKVContext + RWKV context obtained from rwkv_init_from_file. + tokens : List[int] + Next token indices, in range 0 <= token < n_vocab. + state_in_address : int + Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass. + state_out_address : int + Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to. + logits_out_address : int + Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to. + """ + + assert self.library.rwkv_eval_sequence( + ctx.ptr, + ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT), + ctypes.c_size_t(len(tokens)), + ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT), + ctypes.cast(state_out_address, P_FLOAT), + ctypes.cast(logits_out_address, P_FLOAT) + ), 'rwkv_eval_sequence failed, check stderr' + + def rwkv_eval_sequence_in_chunks( + self, + ctx: RWKVContext, + tokens: List[int], + chunk_size: int, + state_in_address: Optional[int], + state_out_address: int, + logits_out_address: int + ) -> None: + """ + Evaluates the model for a sequence of tokens using `rwkv_eval_sequence`, splitting a potentially long sequence into fixed-length chunks. + This function is useful for processing complete prompts and user input in chat & role-playing use-cases. + It is recommended to use this function instead of `rwkv_eval_sequence` to avoid mistakes and get maximum performance. + + Chunking allows processing sequences of thousands of tokens, while not reaching the ggml's node limit and not consuming too much memory. + A reasonable and recommended value of chunk size is 16. If you want maximum performance, try different chunk sizes in range [2..64] + and choose one that works the best in your use case. + + Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread. + Throws an exception in case of any error. Error messages would be printed to stderr. + + Parameters + ---------- + ctx : RWKVContext + RWKV context obtained from rwkv_init_from_file. + tokens : List[int] + Next token indices, in range 0 <= token < n_vocab. + chunk_size : int + Size of each chunk in tokens, must be positive. + state_in_address : int + Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass. + state_out_address : int + Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to. + logits_out_address : int + Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to. + """ + + assert self.library.rwkv_eval_sequence_in_chunks( + ctx.ptr, + ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT), + ctypes.c_size_t(len(tokens)), + ctypes.c_size_t(chunk_size), + ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT), + ctypes.cast(state_out_address, P_FLOAT), + ctypes.cast(logits_out_address, P_FLOAT) + ), 'rwkv_eval_sequence_in_chunks failed, check stderr' + + def rwkv_get_n_vocab(self, ctx: RWKVContext) -> int: + """ + Returns the number of tokens in the given model's vocabulary. + Useful for telling 20B_tokenizer models (n_vocab = 50277) apart from World models (n_vocab = 65536). + + Parameters + ---------- + ctx : RWKVContext + RWKV context obtained from rwkv_init_from_file. + """ + + return self.library.rwkv_get_n_vocab(ctx.ptr) + + def rwkv_get_n_embed(self, ctx: RWKVContext) -> int: + """ + Returns the number of elements in the given model's embedding. + Useful for reading individual fields of a model's hidden state. + + Parameters + ---------- + ctx : RWKVContext + RWKV context obtained from rwkv_init_from_file. + """ + + return self.library.rwkv_get_n_embed(ctx.ptr) + + def rwkv_get_n_layer(self, ctx: RWKVContext) -> int: + """ + Returns the number of layers in the given model. + A layer is a pair of RWKV and FFN operations, stacked multiple times throughout the model. + Embedding matrix and model head (unembedding matrix) are NOT counted in `n_layer`. + Useful for always offloading the entire model to GPU. + + Parameters + ---------- + ctx : RWKVContext + RWKV context obtained from rwkv_init_from_file. + """ + + return self.library.rwkv_get_n_layer(ctx.ptr) + + def rwkv_get_state_buffer_element_count(self, ctx: RWKVContext) -> int: + """ + Returns count of FP32 elements in state buffer. + + Parameters + ---------- + ctx : RWKVContext + RWKV context obtained from rwkv_init_from_file. + """ + + return self.library.rwkv_get_state_buffer_element_count(ctx.ptr) + + def rwkv_get_logits_buffer_element_count(self, ctx: RWKVContext) -> int: + """ + Returns count of FP32 elements in logits buffer. + + Parameters + ---------- + ctx : RWKVContext + RWKV context obtained from rwkv_init_from_file. + """ + + return self.library.rwkv_get_logits_buffer_element_count(ctx.ptr) + + def rwkv_free(self, ctx: RWKVContext) -> None: + """ + Frees all allocated memory and the context. + + Parameters + ---------- + ctx : RWKVContext + RWKV context obtained from rwkv_init_from_file. + """ + + self.library.rwkv_free(ctx.ptr) + + ctx.ptr = self.nullptr + + def rwkv_quantize_model_file(self, model_file_path_in: str, model_file_path_out: str, format_name: str) -> None: + """ + Quantizes FP32 or FP16 model to one of INT4 formats. + Throws an exception in case of any error. Error messages would be printed to stderr. + + Parameters + ---------- + model_file_path_in : str + Path to model file in ggml format, must be either FP32 or FP16. + model_file_path_out : str + Quantized model will be written here. + format_name : str + One of QUANTIZED_FORMAT_NAMES. + """ + + assert format_name in QUANTIZED_FORMAT_NAMES, f'Unknown format name {format_name}, use one of {QUANTIZED_FORMAT_NAMES}' + + assert self.library.rwkv_quantize_model_file( + model_file_path_in.encode('utf-8'), + model_file_path_out.encode('utf-8'), + format_name.encode('utf-8') + ), 'rwkv_quantize_model_file failed, check stderr' + + def rwkv_get_system_info_string(self) -> str: + """ + Returns system information string. + """ + + return self.library.rwkv_get_system_info_string().decode('utf-8') + +def load_rwkv_shared_library() -> RWKVSharedLibrary: + """ + Attempts to find rwkv.cpp shared library and load it. + To specify exact path to the library, create an instance of RWKVSharedLibrary explicitly. + """ + + file_name: str + + if 'win32' in sys.platform or 'cygwin' in sys.platform: + file_name = 'rwkv.dll' + elif 'darwin' in sys.platform: + file_name = 'librwkv.dylib' + else: + file_name = 'librwkv.so' + + # Possible sub-paths to the library relative to the repo dir. + child_paths: List[Callable[[pathlib.Path], pathlib.Path]] = [ + # No lookup for Debug config here. + # I assume that if a user wants to debug the library, + # they will be able to find the library and set the exact path explicitly. + lambda p: p / 'backend-python' / 'rwkv_pip' / 'cpp' / file_name, + lambda p: p / 'bin' / 'Release' / file_name, + lambda p: p / 'bin' / file_name, + # Some people prefer to build in the "build" subdirectory. + lambda p: p / 'build' / 'bin' / 'Release' / file_name, + lambda p: p / 'build' / 'bin' / file_name, + lambda p: p / 'build' / file_name, + # Fallback. + lambda p: p / file_name + ] + + working_dir: pathlib.Path = pathlib.Path(os.path.abspath(os.getcwd())) + + parent_paths: List[pathlib.Path] = [ + # Possible repo dirs relative to the working dir. + # ./python/rwkv_cpp + working_dir.parent.parent, + # ./python + working_dir.parent, + # . + working_dir, + # Repo dir relative to this Python file. + pathlib.Path(os.path.abspath(__file__)).parent.parent.parent + ] + + for parent_path in parent_paths: + for child_path in child_paths: + full_path: pathlib.Path = child_path(parent_path) + + if os.path.isfile(full_path): + return RWKVSharedLibrary(str(full_path)) + + assert False, (f'Failed to find {file_name} automatically; ' + f'you need to find the library and create RWKVSharedLibrary specifying the path to it') diff --git a/backend-python/rwkv_pip/utils.py b/backend-python/rwkv_pip/utils.py index b09f230..a0f8ea4 100644 --- a/backend-python/rwkv_pip/utils.py +++ b/backend-python/rwkv_pip/utils.py @@ -78,12 +78,22 @@ class PIPELINE: def decode(self, x): return self.tokenizer.decode(x) + def np_softmax(self, x: np.ndarray, axis: int): + x -= x.max(axis=axis, keepdims=True) + e: np.ndarray = np.exp(x) + return e / e.sum(axis=axis, keepdims=True) + def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0): - probs = F.softmax(logits.float(), dim=-1) + np_logits = type(logits) == np.ndarray + if np_logits: + probs = self.np_softmax(logits, axis=-1) + else: + probs = F.softmax(logits.float(), dim=-1) top_k = int(top_k) # 'privateuseone' is the type of custom devices like `torch_directml.device()` - if probs.device.type in ["cpu", "privateuseone"]: - probs = probs.cpu().numpy() + if np_logits or probs.device.type in ["cpu", "privateuseone"]: + if not np_logits: + probs = probs.cpu().numpy() sorted_ids = np.argsort(probs) sorted_probs = probs[sorted_ids][::-1] cumulative_probs = np.cumsum(sorted_probs) diff --git a/backend-python/utils/rwkv.py b/backend-python/utils/rwkv.py index d6ca77e..ed60e81 100644 --- a/backend-python/utils/rwkv.py +++ b/backend-python/utils/rwkv.py @@ -510,15 +510,22 @@ def get_tokenizer(tokenizer_len: int): def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV: rwkv_beta = global_var.get(global_var.Args).rwkv_beta + rwkv_cpp = getattr(global_var.get(global_var.Args), "rwkv.cpp") if "midi" in model.lower() or "abc" in model.lower(): os.environ["RWKV_RESCALE_LAYER"] = "999" # dynamic import to make RWKV_CUDA_ON work if rwkv_beta: + print("Using rwkv-beta") from rwkv_pip.beta.model import ( RWKV as Model, ) + elif rwkv_cpp: + print("Using rwkv.cpp, strategy is ignored") + from rwkv_pip.cpp.model import ( + RWKV as Model, + ) else: from rwkv_pip.model import ( RWKV as Model, diff --git a/frontend/src/_locales/ja/main.json b/frontend/src/_locales/ja/main.json index 326c02b..8efdec8 100644 --- a/frontend/src/_locales/ja/main.json +++ b/frontend/src/_locales/ja/main.json @@ -128,7 +128,7 @@ "Chinese Kongfu": "中国武術", "Allow external access to the API (service must be restarted)": "APIへの外部アクセスを許可する (サービスを再起動する必要があります)", "Custom": "カスタム", - "CUDA (Beta, Faster)": "CUDA (ベータ、高速)", + "CUDA (Beta, Faster)": "CUDA (Beta, 高速)", "Reset All Configs": "すべての設定をリセット", "Cancel": "キャンセル", "Confirm": "確認", @@ -313,5 +313,8 @@ "Music": "音楽", "Other": "その他", "Import MIDI": "MIDIをインポート", - "Current Instrument": "現在の楽器" + "Current Instrument": "現在の楽器", + "Please convert model to GGML format first": "モデルをGGML形式に変換してください", + "Convert To GGML Format": "GGML形式に変換", + "CPU (rwkv.cpp, Faster)": "CPU (rwkv.cpp, 高速)" } \ No newline at end of file diff --git a/frontend/src/_locales/zh-hans/main.json b/frontend/src/_locales/zh-hans/main.json index f39640b..1862911 100644 --- a/frontend/src/_locales/zh-hans/main.json +++ b/frontend/src/_locales/zh-hans/main.json @@ -313,5 +313,8 @@ "Music": "音乐", "Other": "其他", "Import MIDI": "导入MIDI", - "Current Instrument": "当前乐器" + "Current Instrument": "当前乐器", + "Please convert model to GGML format first": "请先将模型转换为GGML格式", + "Convert To GGML Format": "转换为GGML格式", + "CPU (rwkv.cpp, Faster)": "CPU (rwkv.cpp, 更快)" } \ No newline at end of file diff --git a/frontend/src/components/RunButton.tsx b/frontend/src/components/RunButton.tsx index 4e5867d..8cd5adf 100644 --- a/frontend/src/components/RunButton.tsx +++ b/frontend/src/components/RunButton.tsx @@ -17,7 +17,8 @@ import { ToolTipButton } from './ToolTipButton'; import { Play16Regular, Stop16Regular } from '@fluentui/react-icons'; import { useNavigate } from 'react-router'; import { WindowShow } from '../../wailsjs/runtime'; -import { convertToSt } from '../utils/convert-to-st'; +import { convertToGGML, convertToSt } from '../utils/convert-model'; +import { Precision } from '../types/configs'; const mainButtonText = { [ModelStatus.Offline]: 'Run', @@ -47,6 +48,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean const modelConfig = commonStore.getCurrentModelConfig(); const webgpu = modelConfig.modelParameters.device === 'WebGPU'; + const cpp = modelConfig.modelParameters.device === 'CPU (rwkv.cpp)'; let modelName = ''; let modelPath = ''; if (modelConfig && modelConfig.modelParameters) { @@ -112,6 +114,30 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean return; } + if (cpp) { + if (!['.bin'].some(ext => modelPath.endsWith(ext))) { + const precision: Precision = modelConfig.modelParameters.precision === 'Q5_1' ? 'Q5_1' : 'fp16'; + const ggmlModelPath = modelPath.replace(/\.pth$/, `-${precision}.bin`); + if (await FileExists(ggmlModelPath)) { + modelPath = ggmlModelPath; + } else if (!await FileExists(modelPath)) { + showDownloadPrompt(t('Model file not found'), modelName); + commonStore.setStatus({ status: ModelStatus.Offline }); + return; + } else if (!currentModelSource?.isComplete) { + showDownloadPrompt(t('Model file download is not complete'), modelName); + commonStore.setStatus({ status: ModelStatus.Offline }); + return; + } else { + toastWithButton(t('Please convert model to GGML format first'), t('Convert'), () => { + convertToGGML(modelConfig, navigate); + }); + commonStore.setStatus({ status: ModelStatus.Offline }); + return; + } + } + } + if (!await FileExists(modelPath)) { showDownloadPrompt(t('Model file not found'), modelName); commonStore.setStatus({ status: ModelStatus.Offline }); @@ -142,7 +168,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean const isUsingCudaBeta = modelConfig.modelParameters.device === 'CUDA-Beta'; startServer(commonStore.settings.customPythonPath, port, commonStore.settings.host !== '127.0.0.1' ? '0.0.0.0' : '127.0.0.1', - !!modelConfig.enableWebUI, isUsingCudaBeta + !!modelConfig.enableWebUI, isUsingCudaBeta, cpp ).catch((e) => { const errMsg = e.message || e; if (errMsg.includes('path contains space')) diff --git a/frontend/src/pages/Configs.tsx b/frontend/src/pages/Configs.tsx index 0a787dd..d9bdc76 100644 --- a/frontend/src/pages/Configs.tsx +++ b/frontend/src/pages/Configs.tsx @@ -27,18 +27,19 @@ import { Page } from '../components/Page'; import { useNavigate } from 'react-router'; import { RunButton } from '../components/RunButton'; import { updateConfig } from '../apis'; -import { ConvertModel, FileExists, GetPyError } from '../../wailsjs/go/backend_golang/App'; -import { checkDependencies, getStrategy } from '../utils'; +import { getStrategy } from '../utils'; import { useTranslation } from 'react-i18next'; -import { WindowShow } from '../../wailsjs/runtime'; import strategyImg from '../assets/images/strategy.jpg'; import strategyZhImg from '../assets/images/strategy_zh.jpg'; import { ResetConfigsButton } from '../components/ResetConfigsButton'; import { useMediaQuery } from 'usehooks-ts'; import { ApiParameters, Device, ModelParameters, Precision } from '../types/configs'; -import { convertToSt } from '../utils/convert-to-st'; +import { convertModel, convertToGGML, convertToSt } from '../utils/convert-model'; -const ConfigSelector: FC<{ selectedIndex: number, updateSelectedIndex: (i: number) => void }> = observer(({ selectedIndex, updateSelectedIndex }) => { +const ConfigSelector: FC<{ + selectedIndex: number, + updateSelectedIndex: (i: number) => void +}> = observer(({ selectedIndex, updateSelectedIndex }) => { return ( { } /> { selectedConfig.modelParameters.device !== 'WebGPU' ? - { - if (commonStore.platform === 'darwin') { - toast(t('MacOS is not yet supported for performing this operation, please do it manually.') + ' (backend-python/convert_model.py)', { type: 'info' }); - return; - } else if (commonStore.platform === 'linux') { - toast(t('Linux is not yet supported for performing this operation, please do it manually.') + ' (backend-python/convert_model.py)', { type: 'info' }); - return; - } - - const ok = await checkDependencies(navigate); - if (!ok) - return; - - const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`; - if (await FileExists(modelPath)) { - const strategy = getStrategy(selectedConfig); - const newModelPath = modelPath + '-' + strategy.replace(/[:> *+]/g, '-'); - toast(t('Start Converting'), { autoClose: 1000, type: 'info' }); - ConvertModel(commonStore.settings.customPythonPath, modelPath, strategy, newModelPath).then(async () => { - if (!await FileExists(newModelPath + '.pth')) { - toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' }); - } else { - toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' }); - } - }).catch(e => { - const errMsg = e.message || e; - if (errMsg.includes('path contains space')) - toast(`${t('Convert Failed')} - ${t('File Path Cannot Contain Space')}`, { type: 'error' }); - else - toast(`${t('Convert Failed')} - ${e.message || e}`, { type: 'error' }); - }); - setTimeout(WindowShow, 1000); - } else { - toast(`${t('Model Not Found')} - ${modelPath}`, { type: 'error' }); - } - }} /> : - convertModel(selectedConfig, navigate)} /> : + convertToGGML(selectedConfig, navigate)} />) + : convertToSt(selectedConfig)} /> } @@ -299,6 +269,7 @@ const Configs: FC = observer(() => { } }}> + {commonStore.platform === 'darwin' && } @@ -322,9 +293,11 @@ const Configs: FC = observer(() => { }}> {selectedConfig.modelParameters.device !== 'CPU' && selectedConfig.modelParameters.device !== 'MPS' && } - + {selectedConfig.modelParameters.device !== 'CPU (rwkv.cpp)' && } {selectedConfig.modelParameters.device === 'WebGPU' && } - {selectedConfig.modelParameters.device !== 'WebGPU' && } + {selectedConfig.modelParameters.device !== 'CPU (rwkv.cpp)' && selectedConfig.modelParameters.device !== 'WebGPU' && + } + {selectedConfig.modelParameters.device === 'CPU (rwkv.cpp)' && } } /> } diff --git a/frontend/src/types/configs.ts b/frontend/src/types/configs.ts index 0814f66..ff156e4 100644 --- a/frontend/src/types/configs.ts +++ b/frontend/src/types/configs.ts @@ -6,8 +6,8 @@ export type ApiParameters = { presencePenalty: number; frequencyPenalty: number; } -export type Device = 'CPU' | 'CUDA' | 'CUDA-Beta' | 'WebGPU' | 'MPS' | 'Custom'; -export type Precision = 'fp16' | 'int8' | 'fp32' | 'nf4'; +export type Device = 'CPU' | 'CPU (rwkv.cpp)' | 'CUDA' | 'CUDA-Beta' | 'WebGPU' | 'MPS' | 'Custom'; +export type Precision = 'fp16' | 'int8' | 'fp32' | 'nf4' | 'Q5_1'; export type ModelParameters = { // different models can not have the same name modelName: string; diff --git a/frontend/src/utils/convert-model.ts b/frontend/src/utils/convert-model.ts new file mode 100644 index 0000000..6e9519f --- /dev/null +++ b/frontend/src/utils/convert-model.ts @@ -0,0 +1,107 @@ +import { toast } from 'react-toastify'; +import commonStore from '../stores/commonStore'; +import { t } from 'i18next'; +import { + ConvertGGML, + ConvertModel, + ConvertSafetensors, + FileExists, + GetPyError +} from '../../wailsjs/go/backend_golang/App'; +import { WindowShow } from '../../wailsjs/runtime'; +import { ModelConfig, Precision } from '../types/configs'; +import { checkDependencies, getStrategy } from './index'; +import { NavigateFunction } from 'react-router'; + +export const convertModel = async (selectedConfig: ModelConfig, navigate: NavigateFunction) => { + if (commonStore.platform === 'darwin') { + toast(t('MacOS is not yet supported for performing this operation, please do it manually.') + ' (backend-python/convert_model.py)', { type: 'info' }); + return; + } else if (commonStore.platform === 'linux') { + toast(t('Linux is not yet supported for performing this operation, please do it manually.') + ' (backend-python/convert_model.py)', { type: 'info' }); + return; + } + + const ok = await checkDependencies(navigate); + if (!ok) + return; + + const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`; + if (await FileExists(modelPath)) { + const strategy = getStrategy(selectedConfig); + const newModelPath = modelPath + '-' + strategy.replace(/[:> *+]/g, '-'); + toast(t('Start Converting'), { autoClose: 2000, type: 'info' }); + ConvertModel(commonStore.settings.customPythonPath, modelPath, strategy, newModelPath).then(async () => { + if (!await FileExists(newModelPath + '.pth')) { + toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' }); + } else { + toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' }); + } + }).catch(e => { + const errMsg = e.message || e; + if (errMsg.includes('path contains space')) + toast(`${t('Convert Failed')} - ${t('File Path Cannot Contain Space')}`, { type: 'error' }); + else + toast(`${t('Convert Failed')} - ${e.message || e}`, { type: 'error' }); + }); + setTimeout(WindowShow, 1000); + } else { + toast(`${t('Model Not Found')} - ${modelPath}`, { type: 'error' }); + } +}; + + +export const convertToSt = async (selectedConfig: ModelConfig) => { + const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`; + if (await FileExists(modelPath)) { + toast(t('Start Converting'), { autoClose: 2000, type: 'info' }); + const newModelPath = modelPath.replace(/\.pth$/, '.st'); + ConvertSafetensors(modelPath, newModelPath).then(async () => { + if (!await FileExists(newModelPath)) { + if (commonStore.platform === 'windows' || commonStore.platform === 'linux') + toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' }); + } else { + toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' }); + } + }).catch(e => { + const errMsg = e.message || e; + if (errMsg.includes('path contains space')) + toast(`${t('Convert Failed')} - ${t('File Path Cannot Contain Space')}`, { type: 'error' }); + else + toast(`${t('Convert Failed')} - ${e.message || e}`, { type: 'error' }); + }); + setTimeout(WindowShow, 1000); + } else { + toast(`${t('Model Not Found')} - ${modelPath}`, { type: 'error' }); + } +}; + +export const convertToGGML = async (selectedConfig: ModelConfig, navigate: NavigateFunction) => { + const ok = await checkDependencies(navigate); + if (!ok) + return; + + const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`; + if (await FileExists(modelPath)) { + toast(t('Start Converting'), { autoClose: 2000, type: 'info' }); + const precision: Precision = selectedConfig.modelParameters.precision === 'Q5_1' ? 'Q5_1' : 'fp16'; + const newModelPath = modelPath.replace(/\.pth$/, `-${precision}.bin`); + ConvertGGML(commonStore.settings.customPythonPath, modelPath, newModelPath, precision === 'Q5_1').then(async () => { + if (!await FileExists(newModelPath)) { + if (commonStore.platform === 'windows' || commonStore.platform === 'linux') + toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' }); + } else { + toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' }); + } + }).catch(e => { + const errMsg = e.message || e; + if (errMsg.includes('path contains space')) + toast(`${t('Convert Failed')} - ${t('File Path Cannot Contain Space')}`, { type: 'error' }); + else + toast(`${t('Convert Failed')} - ${e.message || e}`, { type: 'error' }); + }); + setTimeout(WindowShow, 1000); + } else { + toast(`${t('Model Not Found')} - ${modelPath}`, { type: 'error' }); + } +}; \ No newline at end of file diff --git a/frontend/src/utils/convert-to-st.ts b/frontend/src/utils/convert-to-st.ts deleted file mode 100644 index 32cf68b..0000000 --- a/frontend/src/utils/convert-to-st.ts +++ /dev/null @@ -1,31 +0,0 @@ -import { toast } from 'react-toastify'; -import commonStore from '../stores/commonStore'; -import { t } from 'i18next'; -import { ConvertSafetensors, FileExists, GetPyError } from '../../wailsjs/go/backend_golang/App'; -import { WindowShow } from '../../wailsjs/runtime'; -import { ModelConfig } from '../types/configs'; - -export const convertToSt = async (selectedConfig: ModelConfig) => { - const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`; - if (await FileExists(modelPath)) { - toast(t('Start Converting'), { autoClose: 2000, type: 'info' }); - const newModelPath = modelPath.replace(/\.pth$/, '.st'); - ConvertSafetensors(modelPath, newModelPath).then(async () => { - if (!await FileExists(newModelPath)) { - if (commonStore.platform === 'windows' || commonStore.platform === 'linux') - toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' }); - } else { - toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' }); - } - }).catch(e => { - const errMsg = e.message || e; - if (errMsg.includes('path contains space')) - toast(`${t('Convert Failed')} - ${t('File Path Cannot Contain Space')}`, { type: 'error' }); - else - toast(`${t('Convert Failed')} - ${e.message || e}`, { type: 'error' }); - }); - setTimeout(WindowShow, 1000); - } else { - toast(`${t('Model Not Found')} - ${modelPath}`, { type: 'error' }); - } -}; \ No newline at end of file diff --git a/frontend/src/utils/index.tsx b/frontend/src/utils/index.tsx index 7869cfe..8ae72eb 100644 --- a/frontend/src/utils/index.tsx +++ b/frontend/src/utils/index.tsx @@ -63,7 +63,7 @@ export async function refreshBuiltInModels(readCache: boolean = false) { return cache; } -const modelSuffix = ['.pth', '.st', '.safetensors']; +const modelSuffix = ['.pth', '.st', '.safetensors', '.bin']; export async function refreshLocalModels(cache: { models: ModelSourceItem[] diff --git a/frontend/wailsjs/go/backend_golang/App.d.ts b/frontend/wailsjs/go/backend_golang/App.d.ts index de1ff57..10e6503 100755 --- a/frontend/wailsjs/go/backend_golang/App.d.ts +++ b/frontend/wailsjs/go/backend_golang/App.d.ts @@ -10,6 +10,8 @@ export function ContinueDownload(arg1:string):Promise; export function ConvertData(arg1:string,arg2:string,arg3:string,arg4:string):Promise; +export function ConvertGGML(arg1:string,arg2:string,arg3:string,arg4:boolean):Promise; + export function ConvertModel(arg1:string,arg2:string,arg3:string,arg4:string):Promise; export function ConvertSafetensors(arg1:string,arg2:string):Promise; @@ -58,7 +60,7 @@ export function RestartApp():Promise; export function SaveJson(arg1:string,arg2:any):Promise; -export function StartServer(arg1:string,arg2:number,arg3:string,arg4:boolean,arg5:boolean):Promise; +export function StartServer(arg1:string,arg2:number,arg3:string,arg4:boolean,arg5:boolean,arg6:boolean):Promise; export function StartWebGPUServer(arg1:number,arg2:string):Promise; diff --git a/frontend/wailsjs/go/backend_golang/App.js b/frontend/wailsjs/go/backend_golang/App.js index e5d93c5..8d16582 100755 --- a/frontend/wailsjs/go/backend_golang/App.js +++ b/frontend/wailsjs/go/backend_golang/App.js @@ -18,6 +18,10 @@ export function ConvertData(arg1, arg2, arg3, arg4) { return window['go']['backend_golang']['App']['ConvertData'](arg1, arg2, arg3, arg4); } +export function ConvertGGML(arg1, arg2, arg3, arg4) { + return window['go']['backend_golang']['App']['ConvertGGML'](arg1, arg2, arg3, arg4); +} + export function ConvertModel(arg1, arg2, arg3, arg4) { return window['go']['backend_golang']['App']['ConvertModel'](arg1, arg2, arg3, arg4); } @@ -114,8 +118,8 @@ export function SaveJson(arg1, arg2) { return window['go']['backend_golang']['App']['SaveJson'](arg1, arg2); } -export function StartServer(arg1, arg2, arg3, arg4, arg5) { - return window['go']['backend_golang']['App']['StartServer'](arg1, arg2, arg3, arg4, arg5); +export function StartServer(arg1, arg2, arg3, arg4, arg5, arg6) { + return window['go']['backend_golang']['App']['StartServer'](arg1, arg2, arg3, arg4, arg5, arg6); } export function StartWebGPUServer(arg1, arg2) { diff --git a/vendor.yml b/vendor.yml index bc47abd..8b4553c 100644 --- a/vendor.yml +++ b/vendor.yml @@ -3,6 +3,7 @@ - ^backend-python/get-pip\.py - ^backend-python/convert_model\.py - ^backend-python/convert_safetensors\.py +- ^backend-python/convert_pytorch_to_ggml\.py linguist-vendored - ^backend-python/utils/midi\.py - ^build/ - ^finetune/lora/