rwkv.cpp(ggml) support
This commit is contained in:
parent
6e29f97881
commit
b14fbc29b7
1
.gitattributes
vendored
1
.gitattributes
vendored
@ -3,6 +3,7 @@ backend-python/wkv_cuda_utils/** linguist-vendored
|
||||
backend-python/get-pip.py linguist-vendored
|
||||
backend-python/convert_model.py linguist-vendored
|
||||
backend-python/convert_safetensors.py linguist-vendored
|
||||
backend-python/convert_pytorch_to_ggml.py linguist-vendored
|
||||
backend-python/utils/midi.py linguist-vendored
|
||||
build/** linguist-vendored
|
||||
finetune/lora/** linguist-vendored
|
||||
|
6
.github/workflows/release.yml
vendored
6
.github/workflows/release.yml
vendored
@ -65,6 +65,8 @@ jobs:
|
||||
Copy-Item -Path "${{ steps.cp310.outputs.python-path }}/../libs" -Destination "py310/libs" -Recurse
|
||||
./py310/python -m pip install cyac==1.9
|
||||
go install github.com/wailsapp/wails/v2/cmd/wails@latest
|
||||
del ./backend-python/rwkv_pip/cpp/librwkv.dylib
|
||||
del ./backend-python/rwkv_pip/cpp/librwkv.so
|
||||
(Get-Content -Path ./backend-golang/app.go) -replace "//go:custom_build windows ", "" | Set-Content -Path ./backend-golang/app.go
|
||||
make
|
||||
Rename-Item -Path "build/bin/RWKV-Runner.exe" -NewName "RWKV-Runner_windows_x64.exe"
|
||||
@ -93,6 +95,8 @@ jobs:
|
||||
rm ./backend-python/rwkv_pip/rwkv6.pyd
|
||||
rm ./backend-python/rwkv_pip/beta/wkv_cuda.pyd
|
||||
rm ./backend-python/get-pip.py
|
||||
rm ./backend-python/rwkv_pip/cpp/librwkv.dylib
|
||||
rm ./backend-python/rwkv_pip/cpp/rwkv.dll
|
||||
make
|
||||
mv build/bin/RWKV-Runner build/bin/RWKV-Runner_linux_x64
|
||||
|
||||
@ -117,6 +121,8 @@ jobs:
|
||||
rm ./backend-python/rwkv_pip/rwkv6.pyd
|
||||
rm ./backend-python/rwkv_pip/beta/wkv_cuda.pyd
|
||||
rm ./backend-python/get-pip.py
|
||||
rm ./backend-python/rwkv_pip/cpp/rwkv.dll
|
||||
rm ./backend-python/rwkv_pip/cpp/librwkv.so
|
||||
make
|
||||
cp build/darwin/Readme_Install.txt build/bin/Readme_Install.txt
|
||||
cp build/bin/RWKV-Runner.app/Contents/MacOS/RWKV-Runner build/bin/RWKV-Runner_darwin_universal
|
||||
|
@ -10,7 +10,7 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (a *App) StartServer(python string, port int, host string, webui bool, rwkvBeta bool) (string, error) {
|
||||
func (a *App) StartServer(python string, port int, host string, webui bool, rwkvBeta bool, rwkvcpp bool) (string, error) {
|
||||
var err error
|
||||
if python == "" {
|
||||
python, err = GetPython()
|
||||
@ -25,6 +25,9 @@ func (a *App) StartServer(python string, port int, host string, webui bool, rwkv
|
||||
if rwkvBeta {
|
||||
args = append(args, "--rwkv-beta")
|
||||
}
|
||||
if rwkvcpp {
|
||||
args = append(args, "--rwkv.cpp")
|
||||
}
|
||||
args = append(args, "--port", strconv.Itoa(port), "--host", host)
|
||||
return Cmd(args...)
|
||||
}
|
||||
@ -52,6 +55,21 @@ func (a *App) ConvertSafetensors(modelPath string, outPath string) (string, erro
|
||||
return Cmd(args...)
|
||||
}
|
||||
|
||||
func (a *App) ConvertGGML(python string, modelPath string, outPath string, Q51 bool) (string, error) {
|
||||
var err error
|
||||
if python == "" {
|
||||
python, err = GetPython()
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
dataType := "FP16"
|
||||
if Q51 {
|
||||
dataType = "Q5_1"
|
||||
}
|
||||
return Cmd(python, "./backend-python/convert_pytorch_to_ggml.py", modelPath, outPath, dataType)
|
||||
}
|
||||
|
||||
func (a *App) ConvertData(python string, input string, outputPrefix string, vocab string) (string, error) {
|
||||
var err error
|
||||
if python == "" {
|
||||
|
169
backend-python/convert_pytorch_to_ggml.py
vendored
Normal file
169
backend-python/convert_pytorch_to_ggml.py
vendored
Normal file
@ -0,0 +1,169 @@
|
||||
# Converts an RWKV model checkpoint in PyTorch format to an rwkv.cpp compatible file.
|
||||
# Usage: python convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M-FP16.bin FP16
|
||||
# Get model checkpoints from https://huggingface.co/BlinkDL
|
||||
# See FILE_FORMAT.md for the documentation on the file format.
|
||||
|
||||
import argparse
|
||||
import struct
|
||||
import torch
|
||||
from typing import Dict
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert an RWKV model checkpoint in PyTorch format to an rwkv.cpp compatible file"
|
||||
)
|
||||
parser.add_argument("src_path", help="Path to PyTorch checkpoint file")
|
||||
parser.add_argument(
|
||||
"dest_path", help="Path to rwkv.cpp checkpoint file, will be overwritten"
|
||||
)
|
||||
parser.add_argument(
|
||||
"data_type",
|
||||
help="Data type, FP16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0",
|
||||
type=str,
|
||||
choices=[
|
||||
"FP16",
|
||||
"Q4_0",
|
||||
"Q4_1",
|
||||
"Q5_0",
|
||||
"Q5_1",
|
||||
"Q8_0",
|
||||
],
|
||||
default="FP16",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def get_layer_count(state_dict: Dict[str, torch.Tensor]) -> int:
|
||||
n_layer: int = 0
|
||||
|
||||
while f"blocks.{n_layer}.ln1.weight" in state_dict:
|
||||
n_layer += 1
|
||||
|
||||
assert n_layer > 0
|
||||
|
||||
return n_layer
|
||||
|
||||
|
||||
def write_state_dict(
|
||||
state_dict: Dict[str, torch.Tensor], dest_path: str, data_type: str
|
||||
) -> None:
|
||||
emb_weight: torch.Tensor = state_dict["emb.weight"]
|
||||
|
||||
n_layer: int = get_layer_count(state_dict)
|
||||
n_vocab: int = emb_weight.shape[0]
|
||||
n_embed: int = emb_weight.shape[1]
|
||||
|
||||
is_v5_1_or_2: bool = "blocks.0.att.ln_x.weight" in state_dict
|
||||
is_v5_2: bool = "blocks.0.att.gate.weight" in state_dict
|
||||
|
||||
if is_v5_2:
|
||||
print("Detected RWKV v5.2")
|
||||
elif is_v5_1_or_2:
|
||||
print("Detected RWKV v5.1")
|
||||
else:
|
||||
print("Detected RWKV v4")
|
||||
|
||||
with open(dest_path, "wb") as out_file:
|
||||
is_FP16: bool = data_type == "FP16" or data_type == "float16"
|
||||
|
||||
out_file.write(
|
||||
struct.pack(
|
||||
# Disable padding with '='
|
||||
"=iiiiii",
|
||||
# Magic: 'ggmf' in hex
|
||||
0x67676D66,
|
||||
101,
|
||||
n_vocab,
|
||||
n_embed,
|
||||
n_layer,
|
||||
1 if is_FP16 else 0,
|
||||
)
|
||||
)
|
||||
|
||||
for k in state_dict.keys():
|
||||
tensor: torch.Tensor = state_dict[k].float()
|
||||
|
||||
if ".time_" in k:
|
||||
tensor = tensor.squeeze()
|
||||
|
||||
if is_v5_1_or_2:
|
||||
if ".time_decay" in k:
|
||||
if is_v5_2:
|
||||
tensor = torch.exp(-torch.exp(tensor)).unsqueeze(-1)
|
||||
else:
|
||||
tensor = torch.exp(-torch.exp(tensor)).reshape(-1, 1, 1)
|
||||
|
||||
if ".time_first" in k:
|
||||
tensor = torch.exp(tensor).reshape(-1, 1, 1)
|
||||
|
||||
if ".time_faaaa" in k:
|
||||
tensor = tensor.unsqueeze(-1)
|
||||
else:
|
||||
if ".time_decay" in k:
|
||||
tensor = -torch.exp(tensor)
|
||||
|
||||
# Keep 1-dim vectors and small matrices in FP32
|
||||
if is_FP16 and len(tensor.shape) > 1 and ".time_" not in k:
|
||||
tensor = tensor.half()
|
||||
|
||||
shape = tensor.shape
|
||||
|
||||
print(f"Writing {k}, shape {shape}, type {tensor.dtype}")
|
||||
|
||||
k_encoded: bytes = k.encode("utf-8")
|
||||
|
||||
out_file.write(
|
||||
struct.pack(
|
||||
"=iii",
|
||||
len(shape),
|
||||
len(k_encoded),
|
||||
1 if tensor.dtype == torch.float16 else 0,
|
||||
)
|
||||
)
|
||||
|
||||
# Dimension order is reversed here:
|
||||
# * PyTorch shape is (x rows, y columns)
|
||||
# * ggml shape is (y elements in a row, x elements in a column)
|
||||
# Both shapes represent the same tensor.
|
||||
for dim in reversed(tensor.shape):
|
||||
out_file.write(struct.pack("=i", dim))
|
||||
|
||||
out_file.write(k_encoded)
|
||||
|
||||
tensor.numpy().tofile(out_file)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
print(f"Reading {args.src_path}")
|
||||
|
||||
state_dict: Dict[str, torch.Tensor] = torch.load(args.src_path, map_location="cpu")
|
||||
|
||||
temp_output: str = args.dest_path
|
||||
if args.data_type.startswith("Q"):
|
||||
import re
|
||||
|
||||
temp_output = re.sub(r"Q[4,5,8]_[0,1]", "fp16", temp_output)
|
||||
write_state_dict(state_dict, temp_output, "FP16")
|
||||
if args.data_type.startswith("Q"):
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
|
||||
from rwkv_pip.cpp import rwkv_cpp_shared_library
|
||||
|
||||
library = rwkv_cpp_shared_library.load_rwkv_shared_library()
|
||||
library.rwkv_quantize_model_file(temp_output, args.dest_path, args.data_type)
|
||||
|
||||
print("Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
with open("error.txt", "w") as f:
|
||||
f.write(str(e))
|
@ -32,6 +32,11 @@ def get_args(args: Union[Sequence[str], None] = None):
|
||||
action="store_true",
|
||||
help="whether to use rwkv-beta (default: False)",
|
||||
)
|
||||
group.add_argument(
|
||||
"--rwkv.cpp",
|
||||
action="store_true",
|
||||
help="whether to use rwkv.cpp (default: False)",
|
||||
)
|
||||
args = parser.parse_args(args)
|
||||
|
||||
return args
|
||||
|
@ -49,19 +49,13 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
|
||||
if body.model == "":
|
||||
return "success"
|
||||
|
||||
STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps|dml) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$"
|
||||
if not re.match(STRATEGY_REGEX, body.strategy):
|
||||
raise HTTPException(
|
||||
Status.HTTP_400_BAD_REQUEST,
|
||||
"Invalid strategy. Please read https://pypi.org/project/rwkv/",
|
||||
)
|
||||
devices = set(
|
||||
[
|
||||
x.strip().split(" ")[0].replace("cuda:0", "cuda")
|
||||
for x in body.strategy.split("->")
|
||||
]
|
||||
)
|
||||
print(f"Devices: {devices}")
|
||||
print(f"Strategy Devices: {devices}")
|
||||
# if len(devices) > 1:
|
||||
# state_cache.disable_state_cache()
|
||||
# else:
|
||||
|
@ -90,10 +90,15 @@ def add_state(body: AddStateBody):
|
||||
|
||||
try:
|
||||
id: int = trie.insert(body.prompt)
|
||||
devices: List[torch.device] = [tensor.device for tensor in body.state]
|
||||
devices: List[torch.device] = [
|
||||
(tensor.device if hasattr(tensor, "device") else torch.device("cpu"))
|
||||
for tensor in body.state
|
||||
]
|
||||
dtrie[id] = {
|
||||
"tokens": copy.deepcopy(body.tokens),
|
||||
"state": [tensor.cpu() for tensor in body.state],
|
||||
"state": [tensor.cpu() for tensor in body.state]
|
||||
if hasattr(body.state[0], "device")
|
||||
else copy.deepcopy(body.state),
|
||||
"logits": copy.deepcopy(body.logits),
|
||||
"devices": devices,
|
||||
}
|
||||
@ -185,7 +190,9 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
|
||||
return {
|
||||
"prompt": prompt,
|
||||
"tokens": v["tokens"],
|
||||
"state": [tensor.to(devices[i]) for i, tensor in enumerate(v["state"])],
|
||||
"state": [tensor.to(devices[i]) for i, tensor in enumerate(v["state"])]
|
||||
if hasattr(v["state"][0], "device")
|
||||
else v["state"],
|
||||
"logits": v["logits"],
|
||||
}
|
||||
else:
|
||||
|
BIN
backend-python/rwkv_pip/cpp/librwkv.dylib
vendored
Normal file
BIN
backend-python/rwkv_pip/cpp/librwkv.dylib
vendored
Normal file
Binary file not shown.
BIN
backend-python/rwkv_pip/cpp/librwkv.so
vendored
Normal file
BIN
backend-python/rwkv_pip/cpp/librwkv.so
vendored
Normal file
Binary file not shown.
14
backend-python/rwkv_pip/cpp/model.py
vendored
Normal file
14
backend-python/rwkv_pip/cpp/model.py
vendored
Normal file
@ -0,0 +1,14 @@
|
||||
from typing import Any, List
|
||||
from . import rwkv_cpp_model
|
||||
from . import rwkv_cpp_shared_library
|
||||
|
||||
|
||||
class RWKV:
|
||||
def __init__(self, model_path: str, strategy=None):
|
||||
self.library = rwkv_cpp_shared_library.load_rwkv_shared_library()
|
||||
self.model = rwkv_cpp_model.RWKVModel(self.library, model_path)
|
||||
self.w = {} # fake weight
|
||||
self.w["emb.weight"] = [0] * self.model.n_vocab
|
||||
|
||||
def forward(self, tokens: List[int], state: Any | None):
|
||||
return self.model.eval_sequence_in_chunks(tokens, state, use_numpy=True)
|
BIN
backend-python/rwkv_pip/cpp/rwkv.dll
vendored
Normal file
BIN
backend-python/rwkv_pip/cpp/rwkv.dll
vendored
Normal file
Binary file not shown.
369
backend-python/rwkv_pip/cpp/rwkv_cpp_model.py
vendored
Normal file
369
backend-python/rwkv_pip/cpp/rwkv_cpp_model.py
vendored
Normal file
@ -0,0 +1,369 @@
|
||||
import os
|
||||
import multiprocessing
|
||||
|
||||
# Pre-import PyTorch, if available.
|
||||
# This fixes "OSError: [WinError 127] The specified procedure could not be found".
|
||||
try:
|
||||
import torch
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
# I'm sure this is not strictly correct, but let's keep this crutch for now.
|
||||
try:
|
||||
import rwkv_cpp_shared_library
|
||||
except ModuleNotFoundError:
|
||||
from . import rwkv_cpp_shared_library
|
||||
|
||||
from typing import TypeVar, Optional, Tuple, List
|
||||
|
||||
# A value of this type is either a numpy's ndarray or a PyTorch's Tensor.
|
||||
NumpyArrayOrPyTorchTensor: TypeVar = TypeVar('NumpyArrayOrPyTorchTensor')
|
||||
|
||||
class RWKVModel:
|
||||
"""
|
||||
An RWKV model managed by rwkv.cpp library.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shared_library: rwkv_cpp_shared_library.RWKVSharedLibrary,
|
||||
model_path: str,
|
||||
thread_count: int = max(1, multiprocessing.cpu_count() // 2),
|
||||
gpu_layer_count: int = 0,
|
||||
**kwargs
|
||||
) -> None:
|
||||
"""
|
||||
Loads the model and prepares it for inference.
|
||||
In case of any error, this method will throw an exception.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
shared_library : RWKVSharedLibrary
|
||||
rwkv.cpp shared library.
|
||||
model_path : str
|
||||
Path to RWKV model file in ggml format.
|
||||
thread_count : int
|
||||
Thread count to use. If not set, defaults to CPU count / 2.
|
||||
gpu_layer_count : int
|
||||
Count of layers to offload onto the GPU, must be >= 0.
|
||||
See documentation of `gpu_offload_layers` for details about layer offloading.
|
||||
"""
|
||||
|
||||
if 'gpu_layers_count' in kwargs:
|
||||
gpu_layer_count = kwargs['gpu_layers_count']
|
||||
|
||||
assert os.path.isfile(model_path), f'{model_path} is not a file'
|
||||
assert thread_count > 0, 'Thread count must be > 0'
|
||||
assert gpu_layer_count >= 0, 'GPU layer count must be >= 0'
|
||||
|
||||
self._library: rwkv_cpp_shared_library.RWKVSharedLibrary = shared_library
|
||||
|
||||
self._ctx: rwkv_cpp_shared_library.RWKVContext = self._library.rwkv_init_from_file(model_path, thread_count)
|
||||
|
||||
if gpu_layer_count > 0:
|
||||
self.gpu_offload_layers(gpu_layer_count)
|
||||
|
||||
self._state_buffer_element_count: int = self._library.rwkv_get_state_buffer_element_count(self._ctx)
|
||||
self._logits_buffer_element_count: int = self._library.rwkv_get_logits_buffer_element_count(self._ctx)
|
||||
|
||||
self._valid: bool = True
|
||||
|
||||
def gpu_offload_layers(self, layer_count: int) -> bool:
|
||||
"""
|
||||
Offloads specified count of model layers onto the GPU. Offloaded layers are evaluated using cuBLAS or CLBlast.
|
||||
For the purposes of this function, model head (unembedding matrix) is treated as an additional layer:
|
||||
- pass `model.n_layer` to offload all layers except model head
|
||||
- pass `model.n_layer + 1` to offload all layers, including model head
|
||||
|
||||
Returns true if at least one layer was offloaded.
|
||||
If rwkv.cpp was compiled without cuBLAS and CLBlast support, this function is a no-op and always returns false.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
layer_count : int
|
||||
Count of layers to offload onto the GPU, must be >= 0.
|
||||
"""
|
||||
|
||||
assert layer_count >= 0, 'Layer count must be >= 0'
|
||||
|
||||
return self._library.rwkv_gpu_offload_layers(self._ctx, layer_count)
|
||||
|
||||
@property
|
||||
def n_vocab(self) -> int:
|
||||
return self._library.rwkv_get_n_vocab(self._ctx)
|
||||
|
||||
@property
|
||||
def n_embed(self) -> int:
|
||||
return self._library.rwkv_get_n_embed(self._ctx)
|
||||
|
||||
@property
|
||||
def n_layer(self) -> int:
|
||||
return self._library.rwkv_get_n_layer(self._ctx)
|
||||
|
||||
def eval(
|
||||
self,
|
||||
token: int,
|
||||
state_in: Optional[NumpyArrayOrPyTorchTensor],
|
||||
state_out: Optional[NumpyArrayOrPyTorchTensor] = None,
|
||||
logits_out: Optional[NumpyArrayOrPyTorchTensor] = None,
|
||||
use_numpy: bool = False
|
||||
) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]:
|
||||
"""
|
||||
Evaluates the model for a single token.
|
||||
In case of any error, this method will throw an exception.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
token : int
|
||||
Index of next token to be seen by the model. Must be in range 0 <= token < n_vocab.
|
||||
state_in : Optional[NumpyArrayOrTorchTensor]
|
||||
State from previous call of this method. If this is a first pass, set it to None.
|
||||
state_out : Optional[NumpyArrayOrTorchTensor]
|
||||
Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count).
|
||||
logits_out : Optional[NumpyArrayOrTorchTensor]
|
||||
Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count).
|
||||
use_numpy : bool
|
||||
If set to True, numpy's ndarrays will be created instead of PyTorch's Tensors.
|
||||
This parameter is ignored if any tensor parameter is not None; in such case,
|
||||
type of returned tensors will match the type of received tensors.
|
||||
|
||||
Returns
|
||||
-------
|
||||
logits, state
|
||||
Logits vector of shape (n_vocab); state for the next step.
|
||||
"""
|
||||
|
||||
assert self._valid, 'Model was freed'
|
||||
|
||||
use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
|
||||
|
||||
if state_in is not None:
|
||||
self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count)
|
||||
|
||||
state_in_ptr = self._get_data_ptr(state_in)
|
||||
else:
|
||||
state_in_ptr = 0
|
||||
|
||||
if state_out is not None:
|
||||
self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count)
|
||||
else:
|
||||
state_out = self._zeros_float32(self._state_buffer_element_count, use_numpy)
|
||||
|
||||
if logits_out is not None:
|
||||
self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count)
|
||||
else:
|
||||
logits_out = self._zeros_float32(self._logits_buffer_element_count, use_numpy)
|
||||
|
||||
self._library.rwkv_eval(
|
||||
self._ctx,
|
||||
token,
|
||||
state_in_ptr,
|
||||
self._get_data_ptr(state_out),
|
||||
self._get_data_ptr(logits_out)
|
||||
)
|
||||
|
||||
return logits_out, state_out
|
||||
|
||||
def eval_sequence(
|
||||
self,
|
||||
tokens: List[int],
|
||||
state_in: Optional[NumpyArrayOrPyTorchTensor],
|
||||
state_out: Optional[NumpyArrayOrPyTorchTensor] = None,
|
||||
logits_out: Optional[NumpyArrayOrPyTorchTensor] = None,
|
||||
use_numpy: bool = False
|
||||
) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]:
|
||||
"""
|
||||
Evaluates the model for a sequence of tokens.
|
||||
|
||||
NOTE ON GGML NODE LIMIT
|
||||
|
||||
ggml has a hard-coded limit on max amount of nodes in a computation graph. The sequence graph is built in a way that quickly exceedes
|
||||
this limit when using large models and/or large sequence lengths.
|
||||
Fortunately, rwkv.cpp's fork of ggml has increased limit which was tested to work for sequence lengths up to 64 for 14B models.
|
||||
|
||||
If you get `GGML_ASSERT: ...\\ggml.c:16941: cgraph->n_nodes < GGML_MAX_NODES`, this means you've exceeded the limit.
|
||||
To get rid of the assertion failure, reduce the model size and/or sequence length.
|
||||
|
||||
In case of any error, this method will throw an exception.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tokens : List[int]
|
||||
Indices of the next tokens to be seen by the model. Must be in range 0 <= token < n_vocab.
|
||||
state_in : Optional[NumpyArrayOrTorchTensor]
|
||||
State from previous call of this method. If this is a first pass, set it to None.
|
||||
state_out : Optional[NumpyArrayOrTorchTensor]
|
||||
Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count).
|
||||
logits_out : Optional[NumpyArrayOrTorchTensor]
|
||||
Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count).
|
||||
use_numpy : bool
|
||||
If set to True, numpy's ndarrays will be created instead of PyTorch's Tensors.
|
||||
This parameter is ignored if any tensor parameter is not None; in such case,
|
||||
type of returned tensors will match the type of received tensors.
|
||||
|
||||
Returns
|
||||
-------
|
||||
logits, state
|
||||
Logits vector of shape (n_vocab); state for the next step.
|
||||
"""
|
||||
|
||||
assert self._valid, 'Model was freed'
|
||||
|
||||
use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
|
||||
|
||||
if state_in is not None:
|
||||
self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count)
|
||||
|
||||
state_in_ptr = self._get_data_ptr(state_in)
|
||||
else:
|
||||
state_in_ptr = 0
|
||||
|
||||
if state_out is not None:
|
||||
self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count)
|
||||
else:
|
||||
state_out = self._zeros_float32(self._state_buffer_element_count, use_numpy)
|
||||
|
||||
if logits_out is not None:
|
||||
self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count)
|
||||
else:
|
||||
logits_out = self._zeros_float32(self._logits_buffer_element_count, use_numpy)
|
||||
|
||||
self._library.rwkv_eval_sequence(
|
||||
self._ctx,
|
||||
tokens,
|
||||
state_in_ptr,
|
||||
self._get_data_ptr(state_out),
|
||||
self._get_data_ptr(logits_out)
|
||||
)
|
||||
|
||||
return logits_out, state_out
|
||||
|
||||
def eval_sequence_in_chunks(
|
||||
self,
|
||||
tokens: List[int],
|
||||
state_in: Optional[NumpyArrayOrPyTorchTensor],
|
||||
state_out: Optional[NumpyArrayOrPyTorchTensor] = None,
|
||||
logits_out: Optional[NumpyArrayOrPyTorchTensor] = None,
|
||||
chunk_size: int = 16,
|
||||
use_numpy: bool = False
|
||||
) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]:
|
||||
"""
|
||||
Evaluates the model for a sequence of tokens using `eval_sequence`, splitting a potentially long sequence into fixed-length chunks.
|
||||
This function is useful for processing complete prompts and user input in chat & role-playing use-cases.
|
||||
It is recommended to use this function instead of `eval_sequence` to avoid mistakes and get maximum performance.
|
||||
|
||||
Chunking allows processing sequences of thousands of tokens, while not reaching the ggml's node limit and not consuming too much memory.
|
||||
A reasonable and recommended value of chunk size is 16. If you want maximum performance, try different chunk sizes in range [2..64]
|
||||
and choose one that works the best in your use case.
|
||||
|
||||
In case of any error, this method will throw an exception.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tokens : List[int]
|
||||
Indices of the next tokens to be seen by the model. Must be in range 0 <= token < n_vocab.
|
||||
chunk_size : int
|
||||
Size of each chunk in tokens, must be positive.
|
||||
state_in : Optional[NumpyArrayOrTorchTensor]
|
||||
State from previous call of this method. If this is a first pass, set it to None.
|
||||
state_out : Optional[NumpyArrayOrTorchTensor]
|
||||
Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count).
|
||||
logits_out : Optional[NumpyArrayOrTorchTensor]
|
||||
Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count).
|
||||
use_numpy : bool
|
||||
If set to True, numpy's ndarrays will be created instead of PyTorch's Tensors.
|
||||
This parameter is ignored if any tensor parameter is not None; in such case,
|
||||
type of returned tensors will match the type of received tensors.
|
||||
|
||||
Returns
|
||||
-------
|
||||
logits, state
|
||||
Logits vector of shape (n_vocab); state for the next step.
|
||||
"""
|
||||
|
||||
assert self._valid, 'Model was freed'
|
||||
|
||||
use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
|
||||
|
||||
if state_in is not None:
|
||||
self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count)
|
||||
|
||||
state_in_ptr = self._get_data_ptr(state_in)
|
||||
else:
|
||||
state_in_ptr = 0
|
||||
|
||||
if state_out is not None:
|
||||
self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count)
|
||||
else:
|
||||
state_out = self._zeros_float32(self._state_buffer_element_count, use_numpy)
|
||||
|
||||
if logits_out is not None:
|
||||
self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count)
|
||||
else:
|
||||
logits_out = self._zeros_float32(self._logits_buffer_element_count, use_numpy)
|
||||
|
||||
self._library.rwkv_eval_sequence_in_chunks(
|
||||
self._ctx,
|
||||
tokens,
|
||||
chunk_size,
|
||||
state_in_ptr,
|
||||
self._get_data_ptr(state_out),
|
||||
self._get_data_ptr(logits_out)
|
||||
)
|
||||
|
||||
return logits_out, state_out
|
||||
|
||||
def free(self) -> None:
|
||||
"""
|
||||
Frees all allocated resources.
|
||||
In case of any error, this method will throw an exception.
|
||||
The object must not be used anymore after calling this method.
|
||||
"""
|
||||
|
||||
assert self._valid, 'Already freed'
|
||||
|
||||
self._valid = False
|
||||
|
||||
self._library.rwkv_free(self._ctx)
|
||||
|
||||
def __del__(self) -> None:
|
||||
# Free the context on GC in case user forgot to call free() explicitly.
|
||||
if hasattr(self, '_valid') and self._valid:
|
||||
self.free()
|
||||
|
||||
def _is_pytorch_tensor(self, tensor: NumpyArrayOrPyTorchTensor) -> bool:
|
||||
return hasattr(tensor, '__module__') and tensor.__module__ == 'torch'
|
||||
|
||||
def _detect_numpy_usage(self, tensors: List[Optional[NumpyArrayOrPyTorchTensor]], use_numpy_by_default: bool) -> bool:
|
||||
for tensor in tensors:
|
||||
if tensor is not None:
|
||||
return False if self._is_pytorch_tensor(tensor) else True
|
||||
|
||||
return use_numpy_by_default
|
||||
|
||||
def _validate_tensor(self, tensor: NumpyArrayOrPyTorchTensor, name: str, size: int) -> None:
|
||||
if self._is_pytorch_tensor(tensor):
|
||||
tensor: torch.Tensor = tensor
|
||||
assert tensor.device == torch.device('cpu'), f'{name} is not on CPU'
|
||||
assert tensor.dtype == torch.float32, f'{name} is not of type float32'
|
||||
assert tensor.shape == (size,), f'{name} has invalid shape {tensor.shape}, expected ({size})'
|
||||
assert tensor.is_contiguous(), f'{name} is not contiguous'
|
||||
else:
|
||||
import numpy as np
|
||||
tensor: np.ndarray = tensor
|
||||
assert tensor.dtype == np.float32, f'{name} is not of type float32'
|
||||
assert tensor.shape == (size,), f'{name} has invalid shape {tensor.shape}, expected ({size})'
|
||||
assert tensor.data.contiguous, f'{name} is not contiguous'
|
||||
|
||||
def _get_data_ptr(self, tensor: NumpyArrayOrPyTorchTensor):
|
||||
if self._is_pytorch_tensor(tensor):
|
||||
return tensor.data_ptr()
|
||||
else:
|
||||
return tensor.ctypes.data
|
||||
|
||||
def _zeros_float32(self, element_count: int, use_numpy: bool) -> NumpyArrayOrPyTorchTensor:
|
||||
if use_numpy:
|
||||
import numpy as np
|
||||
return np.zeros(element_count, dtype=np.float32)
|
||||
else:
|
||||
return torch.zeros(element_count, dtype=torch.float32, device='cpu')
|
444
backend-python/rwkv_pip/cpp/rwkv_cpp_shared_library.py
vendored
Normal file
444
backend-python/rwkv_pip/cpp/rwkv_cpp_shared_library.py
vendored
Normal file
@ -0,0 +1,444 @@
|
||||
import os
|
||||
import sys
|
||||
import ctypes
|
||||
import pathlib
|
||||
import platform
|
||||
from typing import Optional, List, Tuple, Callable
|
||||
|
||||
QUANTIZED_FORMAT_NAMES: Tuple[str, str, str, str, str] = (
|
||||
'Q4_0',
|
||||
'Q4_1',
|
||||
'Q5_0',
|
||||
'Q5_1',
|
||||
'Q8_0'
|
||||
)
|
||||
|
||||
P_FLOAT = ctypes.POINTER(ctypes.c_float)
|
||||
P_INT = ctypes.POINTER(ctypes.c_int32)
|
||||
|
||||
class RWKVContext:
|
||||
|
||||
def __init__(self, ptr: ctypes.pointer) -> None:
|
||||
self.ptr: ctypes.pointer = ptr
|
||||
|
||||
class RWKVSharedLibrary:
|
||||
"""
|
||||
Python wrapper around rwkv.cpp shared library.
|
||||
"""
|
||||
|
||||
def __init__(self, shared_library_path: str) -> None:
|
||||
"""
|
||||
Loads the shared library from specified file.
|
||||
In case of any error, this method will throw an exception.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
shared_library_path : str
|
||||
Path to rwkv.cpp shared library. On Windows, it would look like 'rwkv.dll'. On UNIX, 'rwkv.so'.
|
||||
"""
|
||||
# When Python is greater than 3.8, we need to reprocess the custom dll
|
||||
# according to the documentation to prevent loading failure errors.
|
||||
# https://docs.python.org/3/whatsnew/3.8.html#ctypes
|
||||
if platform.system().lower() == 'windows':
|
||||
self.library = ctypes.CDLL(shared_library_path, winmode=0)
|
||||
else:
|
||||
self.library = ctypes.cdll.LoadLibrary(shared_library_path)
|
||||
|
||||
self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32]
|
||||
self.library.rwkv_init_from_file.restype = ctypes.c_void_p
|
||||
|
||||
self.library.rwkv_gpu_offload_layers.argtypes = [ctypes.c_void_p, ctypes.c_uint32]
|
||||
self.library.rwkv_gpu_offload_layers.restype = ctypes.c_bool
|
||||
|
||||
self.library.rwkv_eval.argtypes = [
|
||||
ctypes.c_void_p, # ctx
|
||||
ctypes.c_int32, # token
|
||||
P_FLOAT, # state_in
|
||||
P_FLOAT, # state_out
|
||||
P_FLOAT # logits_out
|
||||
]
|
||||
self.library.rwkv_eval.restype = ctypes.c_bool
|
||||
|
||||
self.library.rwkv_eval_sequence.argtypes = [
|
||||
ctypes.c_void_p, # ctx
|
||||
P_INT, # tokens
|
||||
ctypes.c_size_t, # token count
|
||||
P_FLOAT, # state_in
|
||||
P_FLOAT, # state_out
|
||||
P_FLOAT # logits_out
|
||||
]
|
||||
self.library.rwkv_eval_sequence.restype = ctypes.c_bool
|
||||
|
||||
self.library.rwkv_eval_sequence_in_chunks.argtypes = [
|
||||
ctypes.c_void_p, # ctx
|
||||
P_INT, # tokens
|
||||
ctypes.c_size_t, # token count
|
||||
ctypes.c_size_t, # chunk size
|
||||
P_FLOAT, # state_in
|
||||
P_FLOAT, # state_out
|
||||
P_FLOAT # logits_out
|
||||
]
|
||||
self.library.rwkv_eval_sequence_in_chunks.restype = ctypes.c_bool
|
||||
|
||||
self.library.rwkv_get_n_vocab.argtypes = [ctypes.c_void_p]
|
||||
self.library.rwkv_get_n_vocab.restype = ctypes.c_size_t
|
||||
|
||||
self.library.rwkv_get_n_embed.argtypes = [ctypes.c_void_p]
|
||||
self.library.rwkv_get_n_embed.restype = ctypes.c_size_t
|
||||
|
||||
self.library.rwkv_get_n_layer.argtypes = [ctypes.c_void_p]
|
||||
self.library.rwkv_get_n_layer.restype = ctypes.c_size_t
|
||||
|
||||
self.library.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p]
|
||||
self.library.rwkv_get_state_buffer_element_count.restype = ctypes.c_uint32
|
||||
|
||||
self.library.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p]
|
||||
self.library.rwkv_get_logits_buffer_element_count.restype = ctypes.c_uint32
|
||||
|
||||
self.library.rwkv_free.argtypes = [ctypes.c_void_p]
|
||||
self.library.rwkv_free.restype = None
|
||||
|
||||
self.library.rwkv_free.argtypes = [ctypes.c_void_p]
|
||||
self.library.rwkv_free.restype = None
|
||||
|
||||
self.library.rwkv_quantize_model_file.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p]
|
||||
self.library.rwkv_quantize_model_file.restype = ctypes.c_bool
|
||||
|
||||
self.library.rwkv_get_system_info_string.argtypes = []
|
||||
self.library.rwkv_get_system_info_string.restype = ctypes.c_char_p
|
||||
|
||||
self.nullptr = ctypes.cast(0, ctypes.c_void_p)
|
||||
|
||||
def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVContext:
|
||||
"""
|
||||
Loads the model from a file and prepares it for inference.
|
||||
Throws an exception in case of any error. Error messages would be printed to stderr.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_file_path : str
|
||||
Path to model file in ggml format.
|
||||
thread_count : int
|
||||
Count of threads to use, must be positive.
|
||||
"""
|
||||
|
||||
ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count))
|
||||
|
||||
assert ptr is not None, 'rwkv_init_from_file failed, check stderr'
|
||||
|
||||
return RWKVContext(ptr)
|
||||
|
||||
def rwkv_gpu_offload_layers(self, ctx: RWKVContext, layer_count: int) -> bool:
|
||||
"""
|
||||
Offloads specified count of model layers onto the GPU. Offloaded layers are evaluated using cuBLAS or CLBlast.
|
||||
For the purposes of this function, model head (unembedding matrix) is treated as an additional layer:
|
||||
- pass `rwkv_get_n_layer(ctx)` to offload all layers except model head
|
||||
- pass `rwkv_get_n_layer(ctx) + 1` to offload all layers, including model head
|
||||
Returns true if at least one layer was offloaded.
|
||||
If rwkv.cpp was compiled without cuBLAS and CLBlast support, this function is a no-op and always returns false.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ctx : RWKVContext
|
||||
RWKV context obtained from rwkv_init_from_file.
|
||||
layer_count : int
|
||||
Count of layers to offload onto the GPU, must be >= 0.
|
||||
"""
|
||||
|
||||
assert layer_count >= 0, 'Layer count must be >= 0'
|
||||
|
||||
return self.library.rwkv_gpu_offload_layers(ctx.ptr, ctypes.c_uint32(layer_count))
|
||||
|
||||
def rwkv_eval(
|
||||
self,
|
||||
ctx: RWKVContext,
|
||||
token: int,
|
||||
state_in_address: Optional[int],
|
||||
state_out_address: int,
|
||||
logits_out_address: int
|
||||
) -> None:
|
||||
"""
|
||||
Evaluates the model for a single token.
|
||||
Throws an exception in case of any error. Error messages would be printed to stderr.
|
||||
Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ctx : RWKVContext
|
||||
RWKV context obtained from rwkv_init_from_file.
|
||||
token : int
|
||||
Next token index, in range 0 <= token < n_vocab.
|
||||
state_in_address : int
|
||||
Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass.
|
||||
state_out_address : int
|
||||
Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
|
||||
logits_out_address : int
|
||||
Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
|
||||
"""
|
||||
|
||||
assert self.library.rwkv_eval(
|
||||
ctx.ptr,
|
||||
ctypes.c_int32(token),
|
||||
ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
|
||||
ctypes.cast(state_out_address, P_FLOAT),
|
||||
ctypes.cast(logits_out_address, P_FLOAT)
|
||||
), 'rwkv_eval failed, check stderr'
|
||||
|
||||
def rwkv_eval_sequence(
|
||||
self,
|
||||
ctx: RWKVContext,
|
||||
tokens: List[int],
|
||||
state_in_address: Optional[int],
|
||||
state_out_address: int,
|
||||
logits_out_address: int
|
||||
) -> None:
|
||||
"""
|
||||
Evaluates the model for a sequence of tokens.
|
||||
Uses a faster algorithm than `rwkv_eval` if you do not need the state and logits for every token. Best used with sequence lengths of 64 or so.
|
||||
Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length.
|
||||
|
||||
NOTE ON GGML NODE LIMIT
|
||||
|
||||
ggml has a hard-coded limit on max amount of nodes in a computation graph. The sequence graph is built in a way that quickly exceedes
|
||||
this limit when using large models and/or large sequence lengths.
|
||||
Fortunately, rwkv.cpp's fork of ggml has increased limit which was tested to work for sequence lengths up to 64 for 14B models.
|
||||
|
||||
If you get `GGML_ASSERT: ...\\ggml.c:16941: cgraph->n_nodes < GGML_MAX_NODES`, this means you've exceeded the limit.
|
||||
To get rid of the assertion failure, reduce the model size and/or sequence length.
|
||||
|
||||
Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread.
|
||||
Throws an exception in case of any error. Error messages would be printed to stderr.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ctx : RWKVContext
|
||||
RWKV context obtained from rwkv_init_from_file.
|
||||
tokens : List[int]
|
||||
Next token indices, in range 0 <= token < n_vocab.
|
||||
state_in_address : int
|
||||
Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass.
|
||||
state_out_address : int
|
||||
Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
|
||||
logits_out_address : int
|
||||
Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
|
||||
"""
|
||||
|
||||
assert self.library.rwkv_eval_sequence(
|
||||
ctx.ptr,
|
||||
ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT),
|
||||
ctypes.c_size_t(len(tokens)),
|
||||
ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
|
||||
ctypes.cast(state_out_address, P_FLOAT),
|
||||
ctypes.cast(logits_out_address, P_FLOAT)
|
||||
), 'rwkv_eval_sequence failed, check stderr'
|
||||
|
||||
def rwkv_eval_sequence_in_chunks(
|
||||
self,
|
||||
ctx: RWKVContext,
|
||||
tokens: List[int],
|
||||
chunk_size: int,
|
||||
state_in_address: Optional[int],
|
||||
state_out_address: int,
|
||||
logits_out_address: int
|
||||
) -> None:
|
||||
"""
|
||||
Evaluates the model for a sequence of tokens using `rwkv_eval_sequence`, splitting a potentially long sequence into fixed-length chunks.
|
||||
This function is useful for processing complete prompts and user input in chat & role-playing use-cases.
|
||||
It is recommended to use this function instead of `rwkv_eval_sequence` to avoid mistakes and get maximum performance.
|
||||
|
||||
Chunking allows processing sequences of thousands of tokens, while not reaching the ggml's node limit and not consuming too much memory.
|
||||
A reasonable and recommended value of chunk size is 16. If you want maximum performance, try different chunk sizes in range [2..64]
|
||||
and choose one that works the best in your use case.
|
||||
|
||||
Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread.
|
||||
Throws an exception in case of any error. Error messages would be printed to stderr.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ctx : RWKVContext
|
||||
RWKV context obtained from rwkv_init_from_file.
|
||||
tokens : List[int]
|
||||
Next token indices, in range 0 <= token < n_vocab.
|
||||
chunk_size : int
|
||||
Size of each chunk in tokens, must be positive.
|
||||
state_in_address : int
|
||||
Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass.
|
||||
state_out_address : int
|
||||
Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
|
||||
logits_out_address : int
|
||||
Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
|
||||
"""
|
||||
|
||||
assert self.library.rwkv_eval_sequence_in_chunks(
|
||||
ctx.ptr,
|
||||
ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT),
|
||||
ctypes.c_size_t(len(tokens)),
|
||||
ctypes.c_size_t(chunk_size),
|
||||
ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
|
||||
ctypes.cast(state_out_address, P_FLOAT),
|
||||
ctypes.cast(logits_out_address, P_FLOAT)
|
||||
), 'rwkv_eval_sequence_in_chunks failed, check stderr'
|
||||
|
||||
def rwkv_get_n_vocab(self, ctx: RWKVContext) -> int:
|
||||
"""
|
||||
Returns the number of tokens in the given model's vocabulary.
|
||||
Useful for telling 20B_tokenizer models (n_vocab = 50277) apart from World models (n_vocab = 65536).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ctx : RWKVContext
|
||||
RWKV context obtained from rwkv_init_from_file.
|
||||
"""
|
||||
|
||||
return self.library.rwkv_get_n_vocab(ctx.ptr)
|
||||
|
||||
def rwkv_get_n_embed(self, ctx: RWKVContext) -> int:
|
||||
"""
|
||||
Returns the number of elements in the given model's embedding.
|
||||
Useful for reading individual fields of a model's hidden state.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ctx : RWKVContext
|
||||
RWKV context obtained from rwkv_init_from_file.
|
||||
"""
|
||||
|
||||
return self.library.rwkv_get_n_embed(ctx.ptr)
|
||||
|
||||
def rwkv_get_n_layer(self, ctx: RWKVContext) -> int:
|
||||
"""
|
||||
Returns the number of layers in the given model.
|
||||
A layer is a pair of RWKV and FFN operations, stacked multiple times throughout the model.
|
||||
Embedding matrix and model head (unembedding matrix) are NOT counted in `n_layer`.
|
||||
Useful for always offloading the entire model to GPU.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ctx : RWKVContext
|
||||
RWKV context obtained from rwkv_init_from_file.
|
||||
"""
|
||||
|
||||
return self.library.rwkv_get_n_layer(ctx.ptr)
|
||||
|
||||
def rwkv_get_state_buffer_element_count(self, ctx: RWKVContext) -> int:
|
||||
"""
|
||||
Returns count of FP32 elements in state buffer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ctx : RWKVContext
|
||||
RWKV context obtained from rwkv_init_from_file.
|
||||
"""
|
||||
|
||||
return self.library.rwkv_get_state_buffer_element_count(ctx.ptr)
|
||||
|
||||
def rwkv_get_logits_buffer_element_count(self, ctx: RWKVContext) -> int:
|
||||
"""
|
||||
Returns count of FP32 elements in logits buffer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ctx : RWKVContext
|
||||
RWKV context obtained from rwkv_init_from_file.
|
||||
"""
|
||||
|
||||
return self.library.rwkv_get_logits_buffer_element_count(ctx.ptr)
|
||||
|
||||
def rwkv_free(self, ctx: RWKVContext) -> None:
|
||||
"""
|
||||
Frees all allocated memory and the context.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ctx : RWKVContext
|
||||
RWKV context obtained from rwkv_init_from_file.
|
||||
"""
|
||||
|
||||
self.library.rwkv_free(ctx.ptr)
|
||||
|
||||
ctx.ptr = self.nullptr
|
||||
|
||||
def rwkv_quantize_model_file(self, model_file_path_in: str, model_file_path_out: str, format_name: str) -> None:
|
||||
"""
|
||||
Quantizes FP32 or FP16 model to one of INT4 formats.
|
||||
Throws an exception in case of any error. Error messages would be printed to stderr.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_file_path_in : str
|
||||
Path to model file in ggml format, must be either FP32 or FP16.
|
||||
model_file_path_out : str
|
||||
Quantized model will be written here.
|
||||
format_name : str
|
||||
One of QUANTIZED_FORMAT_NAMES.
|
||||
"""
|
||||
|
||||
assert format_name in QUANTIZED_FORMAT_NAMES, f'Unknown format name {format_name}, use one of {QUANTIZED_FORMAT_NAMES}'
|
||||
|
||||
assert self.library.rwkv_quantize_model_file(
|
||||
model_file_path_in.encode('utf-8'),
|
||||
model_file_path_out.encode('utf-8'),
|
||||
format_name.encode('utf-8')
|
||||
), 'rwkv_quantize_model_file failed, check stderr'
|
||||
|
||||
def rwkv_get_system_info_string(self) -> str:
|
||||
"""
|
||||
Returns system information string.
|
||||
"""
|
||||
|
||||
return self.library.rwkv_get_system_info_string().decode('utf-8')
|
||||
|
||||
def load_rwkv_shared_library() -> RWKVSharedLibrary:
|
||||
"""
|
||||
Attempts to find rwkv.cpp shared library and load it.
|
||||
To specify exact path to the library, create an instance of RWKVSharedLibrary explicitly.
|
||||
"""
|
||||
|
||||
file_name: str
|
||||
|
||||
if 'win32' in sys.platform or 'cygwin' in sys.platform:
|
||||
file_name = 'rwkv.dll'
|
||||
elif 'darwin' in sys.platform:
|
||||
file_name = 'librwkv.dylib'
|
||||
else:
|
||||
file_name = 'librwkv.so'
|
||||
|
||||
# Possible sub-paths to the library relative to the repo dir.
|
||||
child_paths: List[Callable[[pathlib.Path], pathlib.Path]] = [
|
||||
# No lookup for Debug config here.
|
||||
# I assume that if a user wants to debug the library,
|
||||
# they will be able to find the library and set the exact path explicitly.
|
||||
lambda p: p / 'backend-python' / 'rwkv_pip' / 'cpp' / file_name,
|
||||
lambda p: p / 'bin' / 'Release' / file_name,
|
||||
lambda p: p / 'bin' / file_name,
|
||||
# Some people prefer to build in the "build" subdirectory.
|
||||
lambda p: p / 'build' / 'bin' / 'Release' / file_name,
|
||||
lambda p: p / 'build' / 'bin' / file_name,
|
||||
lambda p: p / 'build' / file_name,
|
||||
# Fallback.
|
||||
lambda p: p / file_name
|
||||
]
|
||||
|
||||
working_dir: pathlib.Path = pathlib.Path(os.path.abspath(os.getcwd()))
|
||||
|
||||
parent_paths: List[pathlib.Path] = [
|
||||
# Possible repo dirs relative to the working dir.
|
||||
# ./python/rwkv_cpp
|
||||
working_dir.parent.parent,
|
||||
# ./python
|
||||
working_dir.parent,
|
||||
# .
|
||||
working_dir,
|
||||
# Repo dir relative to this Python file.
|
||||
pathlib.Path(os.path.abspath(__file__)).parent.parent.parent
|
||||
]
|
||||
|
||||
for parent_path in parent_paths:
|
||||
for child_path in child_paths:
|
||||
full_path: pathlib.Path = child_path(parent_path)
|
||||
|
||||
if os.path.isfile(full_path):
|
||||
return RWKVSharedLibrary(str(full_path))
|
||||
|
||||
assert False, (f'Failed to find {file_name} automatically; '
|
||||
f'you need to find the library and create RWKVSharedLibrary specifying the path to it')
|
16
backend-python/rwkv_pip/utils.py
vendored
16
backend-python/rwkv_pip/utils.py
vendored
@ -78,12 +78,22 @@ class PIPELINE:
|
||||
def decode(self, x):
|
||||
return self.tokenizer.decode(x)
|
||||
|
||||
def np_softmax(self, x: np.ndarray, axis: int):
|
||||
x -= x.max(axis=axis, keepdims=True)
|
||||
e: np.ndarray = np.exp(x)
|
||||
return e / e.sum(axis=axis, keepdims=True)
|
||||
|
||||
def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
|
||||
probs = F.softmax(logits.float(), dim=-1)
|
||||
np_logits = type(logits) == np.ndarray
|
||||
if np_logits:
|
||||
probs = self.np_softmax(logits, axis=-1)
|
||||
else:
|
||||
probs = F.softmax(logits.float(), dim=-1)
|
||||
top_k = int(top_k)
|
||||
# 'privateuseone' is the type of custom devices like `torch_directml.device()`
|
||||
if probs.device.type in ["cpu", "privateuseone"]:
|
||||
probs = probs.cpu().numpy()
|
||||
if np_logits or probs.device.type in ["cpu", "privateuseone"]:
|
||||
if not np_logits:
|
||||
probs = probs.cpu().numpy()
|
||||
sorted_ids = np.argsort(probs)
|
||||
sorted_probs = probs[sorted_ids][::-1]
|
||||
cumulative_probs = np.cumsum(sorted_probs)
|
||||
|
@ -510,15 +510,22 @@ def get_tokenizer(tokenizer_len: int):
|
||||
|
||||
def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV:
|
||||
rwkv_beta = global_var.get(global_var.Args).rwkv_beta
|
||||
rwkv_cpp = getattr(global_var.get(global_var.Args), "rwkv.cpp")
|
||||
|
||||
if "midi" in model.lower() or "abc" in model.lower():
|
||||
os.environ["RWKV_RESCALE_LAYER"] = "999"
|
||||
|
||||
# dynamic import to make RWKV_CUDA_ON work
|
||||
if rwkv_beta:
|
||||
print("Using rwkv-beta")
|
||||
from rwkv_pip.beta.model import (
|
||||
RWKV as Model,
|
||||
)
|
||||
elif rwkv_cpp:
|
||||
print("Using rwkv.cpp, strategy is ignored")
|
||||
from rwkv_pip.cpp.model import (
|
||||
RWKV as Model,
|
||||
)
|
||||
else:
|
||||
from rwkv_pip.model import (
|
||||
RWKV as Model,
|
||||
|
@ -128,7 +128,7 @@
|
||||
"Chinese Kongfu": "中国武術",
|
||||
"Allow external access to the API (service must be restarted)": "APIへの外部アクセスを許可する (サービスを再起動する必要があります)",
|
||||
"Custom": "カスタム",
|
||||
"CUDA (Beta, Faster)": "CUDA (ベータ、高速)",
|
||||
"CUDA (Beta, Faster)": "CUDA (Beta, 高速)",
|
||||
"Reset All Configs": "すべての設定をリセット",
|
||||
"Cancel": "キャンセル",
|
||||
"Confirm": "確認",
|
||||
@ -313,5 +313,8 @@
|
||||
"Music": "音楽",
|
||||
"Other": "その他",
|
||||
"Import MIDI": "MIDIをインポート",
|
||||
"Current Instrument": "現在の楽器"
|
||||
"Current Instrument": "現在の楽器",
|
||||
"Please convert model to GGML format first": "モデルをGGML形式に変換してください",
|
||||
"Convert To GGML Format": "GGML形式に変換",
|
||||
"CPU (rwkv.cpp, Faster)": "CPU (rwkv.cpp, 高速)"
|
||||
}
|
@ -313,5 +313,8 @@
|
||||
"Music": "音乐",
|
||||
"Other": "其他",
|
||||
"Import MIDI": "导入MIDI",
|
||||
"Current Instrument": "当前乐器"
|
||||
"Current Instrument": "当前乐器",
|
||||
"Please convert model to GGML format first": "请先将模型转换为GGML格式",
|
||||
"Convert To GGML Format": "转换为GGML格式",
|
||||
"CPU (rwkv.cpp, Faster)": "CPU (rwkv.cpp, 更快)"
|
||||
}
|
@ -17,7 +17,8 @@ import { ToolTipButton } from './ToolTipButton';
|
||||
import { Play16Regular, Stop16Regular } from '@fluentui/react-icons';
|
||||
import { useNavigate } from 'react-router';
|
||||
import { WindowShow } from '../../wailsjs/runtime';
|
||||
import { convertToSt } from '../utils/convert-to-st';
|
||||
import { convertToGGML, convertToSt } from '../utils/convert-model';
|
||||
import { Precision } from '../types/configs';
|
||||
|
||||
const mainButtonText = {
|
||||
[ModelStatus.Offline]: 'Run',
|
||||
@ -47,6 +48,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
|
||||
|
||||
const modelConfig = commonStore.getCurrentModelConfig();
|
||||
const webgpu = modelConfig.modelParameters.device === 'WebGPU';
|
||||
const cpp = modelConfig.modelParameters.device === 'CPU (rwkv.cpp)';
|
||||
let modelName = '';
|
||||
let modelPath = '';
|
||||
if (modelConfig && modelConfig.modelParameters) {
|
||||
@ -112,6 +114,30 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
|
||||
return;
|
||||
}
|
||||
|
||||
if (cpp) {
|
||||
if (!['.bin'].some(ext => modelPath.endsWith(ext))) {
|
||||
const precision: Precision = modelConfig.modelParameters.precision === 'Q5_1' ? 'Q5_1' : 'fp16';
|
||||
const ggmlModelPath = modelPath.replace(/\.pth$/, `-${precision}.bin`);
|
||||
if (await FileExists(ggmlModelPath)) {
|
||||
modelPath = ggmlModelPath;
|
||||
} else if (!await FileExists(modelPath)) {
|
||||
showDownloadPrompt(t('Model file not found'), modelName);
|
||||
commonStore.setStatus({ status: ModelStatus.Offline });
|
||||
return;
|
||||
} else if (!currentModelSource?.isComplete) {
|
||||
showDownloadPrompt(t('Model file download is not complete'), modelName);
|
||||
commonStore.setStatus({ status: ModelStatus.Offline });
|
||||
return;
|
||||
} else {
|
||||
toastWithButton(t('Please convert model to GGML format first'), t('Convert'), () => {
|
||||
convertToGGML(modelConfig, navigate);
|
||||
});
|
||||
commonStore.setStatus({ status: ModelStatus.Offline });
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!await FileExists(modelPath)) {
|
||||
showDownloadPrompt(t('Model file not found'), modelName);
|
||||
commonStore.setStatus({ status: ModelStatus.Offline });
|
||||
@ -142,7 +168,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
|
||||
const isUsingCudaBeta = modelConfig.modelParameters.device === 'CUDA-Beta';
|
||||
|
||||
startServer(commonStore.settings.customPythonPath, port, commonStore.settings.host !== '127.0.0.1' ? '0.0.0.0' : '127.0.0.1',
|
||||
!!modelConfig.enableWebUI, isUsingCudaBeta
|
||||
!!modelConfig.enableWebUI, isUsingCudaBeta, cpp
|
||||
).catch((e) => {
|
||||
const errMsg = e.message || e;
|
||||
if (errMsg.includes('path contains space'))
|
||||
|
@ -27,18 +27,19 @@ import { Page } from '../components/Page';
|
||||
import { useNavigate } from 'react-router';
|
||||
import { RunButton } from '../components/RunButton';
|
||||
import { updateConfig } from '../apis';
|
||||
import { ConvertModel, FileExists, GetPyError } from '../../wailsjs/go/backend_golang/App';
|
||||
import { checkDependencies, getStrategy } from '../utils';
|
||||
import { getStrategy } from '../utils';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { WindowShow } from '../../wailsjs/runtime';
|
||||
import strategyImg from '../assets/images/strategy.jpg';
|
||||
import strategyZhImg from '../assets/images/strategy_zh.jpg';
|
||||
import { ResetConfigsButton } from '../components/ResetConfigsButton';
|
||||
import { useMediaQuery } from 'usehooks-ts';
|
||||
import { ApiParameters, Device, ModelParameters, Precision } from '../types/configs';
|
||||
import { convertToSt } from '../utils/convert-to-st';
|
||||
import { convertModel, convertToGGML, convertToSt } from '../utils/convert-model';
|
||||
|
||||
const ConfigSelector: FC<{ selectedIndex: number, updateSelectedIndex: (i: number) => void }> = observer(({ selectedIndex, updateSelectedIndex }) => {
|
||||
const ConfigSelector: FC<{
|
||||
selectedIndex: number,
|
||||
updateSelectedIndex: (i: number) => void
|
||||
}> = observer(({ selectedIndex, updateSelectedIndex }) => {
|
||||
return (
|
||||
<Dropdown style={{ minWidth: 0 }} className="grow" value={commonStore.modelConfigs[selectedIndex].name}
|
||||
selectedOptions={[selectedIndex.toString()]}
|
||||
@ -246,45 +247,14 @@ const Configs: FC = observer(() => {
|
||||
} />
|
||||
{
|
||||
selectedConfig.modelParameters.device !== 'WebGPU' ?
|
||||
<ToolTipButton text={t('Convert')}
|
||||
desc={t('Convert model with these configs. Using a converted model will greatly improve the loading speed, but model parameters of the converted model cannot be modified.')}
|
||||
onClick={async () => {
|
||||
if (commonStore.platform === 'darwin') {
|
||||
toast(t('MacOS is not yet supported for performing this operation, please do it manually.') + ' (backend-python/convert_model.py)', { type: 'info' });
|
||||
return;
|
||||
} else if (commonStore.platform === 'linux') {
|
||||
toast(t('Linux is not yet supported for performing this operation, please do it manually.') + ' (backend-python/convert_model.py)', { type: 'info' });
|
||||
return;
|
||||
}
|
||||
|
||||
const ok = await checkDependencies(navigate);
|
||||
if (!ok)
|
||||
return;
|
||||
|
||||
const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`;
|
||||
if (await FileExists(modelPath)) {
|
||||
const strategy = getStrategy(selectedConfig);
|
||||
const newModelPath = modelPath + '-' + strategy.replace(/[:> *+]/g, '-');
|
||||
toast(t('Start Converting'), { autoClose: 1000, type: 'info' });
|
||||
ConvertModel(commonStore.settings.customPythonPath, modelPath, strategy, newModelPath).then(async () => {
|
||||
if (!await FileExists(newModelPath + '.pth')) {
|
||||
toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' });
|
||||
} else {
|
||||
toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' });
|
||||
}
|
||||
}).catch(e => {
|
||||
const errMsg = e.message || e;
|
||||
if (errMsg.includes('path contains space'))
|
||||
toast(`${t('Convert Failed')} - ${t('File Path Cannot Contain Space')}`, { type: 'error' });
|
||||
else
|
||||
toast(`${t('Convert Failed')} - ${e.message || e}`, { type: 'error' });
|
||||
});
|
||||
setTimeout(WindowShow, 1000);
|
||||
} else {
|
||||
toast(`${t('Model Not Found')} - ${modelPath}`, { type: 'error' });
|
||||
}
|
||||
}} /> :
|
||||
<ToolTipButton text={t('Convert To Safe Tensors Format')}
|
||||
(selectedConfig.modelParameters.device !== 'CPU (rwkv.cpp)' ?
|
||||
<ToolTipButton text={t('Convert')}
|
||||
desc={t('Convert model with these configs. Using a converted model will greatly improve the loading speed, but model parameters of the converted model cannot be modified.')}
|
||||
onClick={() => convertModel(selectedConfig, navigate)} /> :
|
||||
<ToolTipButton text={t('Convert To GGML Format')}
|
||||
desc=""
|
||||
onClick={() => convertToGGML(selectedConfig, navigate)} />)
|
||||
: <ToolTipButton text={t('Convert To Safe Tensors Format')}
|
||||
desc=""
|
||||
onClick={() => convertToSt(selectedConfig)} />
|
||||
}
|
||||
@ -299,6 +269,7 @@ const Configs: FC = observer(() => {
|
||||
}
|
||||
}}>
|
||||
<Option value="CPU">CPU</Option>
|
||||
<Option value="CPU (rwkv.cpp)">{t('CPU (rwkv.cpp, Faster)')!}</Option>
|
||||
{commonStore.platform === 'darwin' && <Option value="MPS">MPS</Option>}
|
||||
<Option value="CUDA">CUDA</Option>
|
||||
<Option value="CUDA-Beta">{t('CUDA (Beta, Faster)')!}</Option>
|
||||
@ -322,9 +293,11 @@ const Configs: FC = observer(() => {
|
||||
}}>
|
||||
{selectedConfig.modelParameters.device !== 'CPU' && selectedConfig.modelParameters.device !== 'MPS' &&
|
||||
<Option>fp16</Option>}
|
||||
<Option>int8</Option>
|
||||
{selectedConfig.modelParameters.device !== 'CPU (rwkv.cpp)' && <Option>int8</Option>}
|
||||
{selectedConfig.modelParameters.device === 'WebGPU' && <Option>nf4</Option>}
|
||||
{selectedConfig.modelParameters.device !== 'WebGPU' && <Option>fp32</Option>}
|
||||
{selectedConfig.modelParameters.device !== 'CPU (rwkv.cpp)' && selectedConfig.modelParameters.device !== 'WebGPU' &&
|
||||
<Option>fp32</Option>}
|
||||
{selectedConfig.modelParameters.device === 'CPU (rwkv.cpp)' && <Option>Q5_1</Option>}
|
||||
</Dropdown>
|
||||
} />
|
||||
}
|
||||
|
@ -6,8 +6,8 @@ export type ApiParameters = {
|
||||
presencePenalty: number;
|
||||
frequencyPenalty: number;
|
||||
}
|
||||
export type Device = 'CPU' | 'CUDA' | 'CUDA-Beta' | 'WebGPU' | 'MPS' | 'Custom';
|
||||
export type Precision = 'fp16' | 'int8' | 'fp32' | 'nf4';
|
||||
export type Device = 'CPU' | 'CPU (rwkv.cpp)' | 'CUDA' | 'CUDA-Beta' | 'WebGPU' | 'MPS' | 'Custom';
|
||||
export type Precision = 'fp16' | 'int8' | 'fp32' | 'nf4' | 'Q5_1';
|
||||
export type ModelParameters = {
|
||||
// different models can not have the same name
|
||||
modelName: string;
|
||||
|
107
frontend/src/utils/convert-model.ts
Normal file
107
frontend/src/utils/convert-model.ts
Normal file
@ -0,0 +1,107 @@
|
||||
import { toast } from 'react-toastify';
|
||||
import commonStore from '../stores/commonStore';
|
||||
import { t } from 'i18next';
|
||||
import {
|
||||
ConvertGGML,
|
||||
ConvertModel,
|
||||
ConvertSafetensors,
|
||||
FileExists,
|
||||
GetPyError
|
||||
} from '../../wailsjs/go/backend_golang/App';
|
||||
import { WindowShow } from '../../wailsjs/runtime';
|
||||
import { ModelConfig, Precision } from '../types/configs';
|
||||
import { checkDependencies, getStrategy } from './index';
|
||||
import { NavigateFunction } from 'react-router';
|
||||
|
||||
export const convertModel = async (selectedConfig: ModelConfig, navigate: NavigateFunction) => {
|
||||
if (commonStore.platform === 'darwin') {
|
||||
toast(t('MacOS is not yet supported for performing this operation, please do it manually.') + ' (backend-python/convert_model.py)', { type: 'info' });
|
||||
return;
|
||||
} else if (commonStore.platform === 'linux') {
|
||||
toast(t('Linux is not yet supported for performing this operation, please do it manually.') + ' (backend-python/convert_model.py)', { type: 'info' });
|
||||
return;
|
||||
}
|
||||
|
||||
const ok = await checkDependencies(navigate);
|
||||
if (!ok)
|
||||
return;
|
||||
|
||||
const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`;
|
||||
if (await FileExists(modelPath)) {
|
||||
const strategy = getStrategy(selectedConfig);
|
||||
const newModelPath = modelPath + '-' + strategy.replace(/[:> *+]/g, '-');
|
||||
toast(t('Start Converting'), { autoClose: 2000, type: 'info' });
|
||||
ConvertModel(commonStore.settings.customPythonPath, modelPath, strategy, newModelPath).then(async () => {
|
||||
if (!await FileExists(newModelPath + '.pth')) {
|
||||
toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' });
|
||||
} else {
|
||||
toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' });
|
||||
}
|
||||
}).catch(e => {
|
||||
const errMsg = e.message || e;
|
||||
if (errMsg.includes('path contains space'))
|
||||
toast(`${t('Convert Failed')} - ${t('File Path Cannot Contain Space')}`, { type: 'error' });
|
||||
else
|
||||
toast(`${t('Convert Failed')} - ${e.message || e}`, { type: 'error' });
|
||||
});
|
||||
setTimeout(WindowShow, 1000);
|
||||
} else {
|
||||
toast(`${t('Model Not Found')} - ${modelPath}`, { type: 'error' });
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
export const convertToSt = async (selectedConfig: ModelConfig) => {
|
||||
const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`;
|
||||
if (await FileExists(modelPath)) {
|
||||
toast(t('Start Converting'), { autoClose: 2000, type: 'info' });
|
||||
const newModelPath = modelPath.replace(/\.pth$/, '.st');
|
||||
ConvertSafetensors(modelPath, newModelPath).then(async () => {
|
||||
if (!await FileExists(newModelPath)) {
|
||||
if (commonStore.platform === 'windows' || commonStore.platform === 'linux')
|
||||
toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' });
|
||||
} else {
|
||||
toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' });
|
||||
}
|
||||
}).catch(e => {
|
||||
const errMsg = e.message || e;
|
||||
if (errMsg.includes('path contains space'))
|
||||
toast(`${t('Convert Failed')} - ${t('File Path Cannot Contain Space')}`, { type: 'error' });
|
||||
else
|
||||
toast(`${t('Convert Failed')} - ${e.message || e}`, { type: 'error' });
|
||||
});
|
||||
setTimeout(WindowShow, 1000);
|
||||
} else {
|
||||
toast(`${t('Model Not Found')} - ${modelPath}`, { type: 'error' });
|
||||
}
|
||||
};
|
||||
|
||||
export const convertToGGML = async (selectedConfig: ModelConfig, navigate: NavigateFunction) => {
|
||||
const ok = await checkDependencies(navigate);
|
||||
if (!ok)
|
||||
return;
|
||||
|
||||
const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`;
|
||||
if (await FileExists(modelPath)) {
|
||||
toast(t('Start Converting'), { autoClose: 2000, type: 'info' });
|
||||
const precision: Precision = selectedConfig.modelParameters.precision === 'Q5_1' ? 'Q5_1' : 'fp16';
|
||||
const newModelPath = modelPath.replace(/\.pth$/, `-${precision}.bin`);
|
||||
ConvertGGML(commonStore.settings.customPythonPath, modelPath, newModelPath, precision === 'Q5_1').then(async () => {
|
||||
if (!await FileExists(newModelPath)) {
|
||||
if (commonStore.platform === 'windows' || commonStore.platform === 'linux')
|
||||
toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' });
|
||||
} else {
|
||||
toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' });
|
||||
}
|
||||
}).catch(e => {
|
||||
const errMsg = e.message || e;
|
||||
if (errMsg.includes('path contains space'))
|
||||
toast(`${t('Convert Failed')} - ${t('File Path Cannot Contain Space')}`, { type: 'error' });
|
||||
else
|
||||
toast(`${t('Convert Failed')} - ${e.message || e}`, { type: 'error' });
|
||||
});
|
||||
setTimeout(WindowShow, 1000);
|
||||
} else {
|
||||
toast(`${t('Model Not Found')} - ${modelPath}`, { type: 'error' });
|
||||
}
|
||||
};
|
@ -1,31 +0,0 @@
|
||||
import { toast } from 'react-toastify';
|
||||
import commonStore from '../stores/commonStore';
|
||||
import { t } from 'i18next';
|
||||
import { ConvertSafetensors, FileExists, GetPyError } from '../../wailsjs/go/backend_golang/App';
|
||||
import { WindowShow } from '../../wailsjs/runtime';
|
||||
import { ModelConfig } from '../types/configs';
|
||||
|
||||
export const convertToSt = async (selectedConfig: ModelConfig) => {
|
||||
const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`;
|
||||
if (await FileExists(modelPath)) {
|
||||
toast(t('Start Converting'), { autoClose: 2000, type: 'info' });
|
||||
const newModelPath = modelPath.replace(/\.pth$/, '.st');
|
||||
ConvertSafetensors(modelPath, newModelPath).then(async () => {
|
||||
if (!await FileExists(newModelPath)) {
|
||||
if (commonStore.platform === 'windows' || commonStore.platform === 'linux')
|
||||
toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' });
|
||||
} else {
|
||||
toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' });
|
||||
}
|
||||
}).catch(e => {
|
||||
const errMsg = e.message || e;
|
||||
if (errMsg.includes('path contains space'))
|
||||
toast(`${t('Convert Failed')} - ${t('File Path Cannot Contain Space')}`, { type: 'error' });
|
||||
else
|
||||
toast(`${t('Convert Failed')} - ${e.message || e}`, { type: 'error' });
|
||||
});
|
||||
setTimeout(WindowShow, 1000);
|
||||
} else {
|
||||
toast(`${t('Model Not Found')} - ${modelPath}`, { type: 'error' });
|
||||
}
|
||||
};
|
@ -63,7 +63,7 @@ export async function refreshBuiltInModels(readCache: boolean = false) {
|
||||
return cache;
|
||||
}
|
||||
|
||||
const modelSuffix = ['.pth', '.st', '.safetensors'];
|
||||
const modelSuffix = ['.pth', '.st', '.safetensors', '.bin'];
|
||||
|
||||
export async function refreshLocalModels(cache: {
|
||||
models: ModelSourceItem[]
|
||||
|
4
frontend/wailsjs/go/backend_golang/App.d.ts
generated
vendored
4
frontend/wailsjs/go/backend_golang/App.d.ts
generated
vendored
@ -10,6 +10,8 @@ export function ContinueDownload(arg1:string):Promise<void>;
|
||||
|
||||
export function ConvertData(arg1:string,arg2:string,arg3:string,arg4:string):Promise<string>;
|
||||
|
||||
export function ConvertGGML(arg1:string,arg2:string,arg3:string,arg4:boolean):Promise<string>;
|
||||
|
||||
export function ConvertModel(arg1:string,arg2:string,arg3:string,arg4:string):Promise<string>;
|
||||
|
||||
export function ConvertSafetensors(arg1:string,arg2:string):Promise<string>;
|
||||
@ -58,7 +60,7 @@ export function RestartApp():Promise<void>;
|
||||
|
||||
export function SaveJson(arg1:string,arg2:any):Promise<void>;
|
||||
|
||||
export function StartServer(arg1:string,arg2:number,arg3:string,arg4:boolean,arg5:boolean):Promise<string>;
|
||||
export function StartServer(arg1:string,arg2:number,arg3:string,arg4:boolean,arg5:boolean,arg6:boolean):Promise<string>;
|
||||
|
||||
export function StartWebGPUServer(arg1:number,arg2:string):Promise<string>;
|
||||
|
||||
|
8
frontend/wailsjs/go/backend_golang/App.js
generated
8
frontend/wailsjs/go/backend_golang/App.js
generated
@ -18,6 +18,10 @@ export function ConvertData(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['backend_golang']['App']['ConvertData'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
|
||||
export function ConvertGGML(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['backend_golang']['App']['ConvertGGML'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
|
||||
export function ConvertModel(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['backend_golang']['App']['ConvertModel'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
@ -114,8 +118,8 @@ export function SaveJson(arg1, arg2) {
|
||||
return window['go']['backend_golang']['App']['SaveJson'](arg1, arg2);
|
||||
}
|
||||
|
||||
export function StartServer(arg1, arg2, arg3, arg4, arg5) {
|
||||
return window['go']['backend_golang']['App']['StartServer'](arg1, arg2, arg3, arg4, arg5);
|
||||
export function StartServer(arg1, arg2, arg3, arg4, arg5, arg6) {
|
||||
return window['go']['backend_golang']['App']['StartServer'](arg1, arg2, arg3, arg4, arg5, arg6);
|
||||
}
|
||||
|
||||
export function StartWebGPUServer(arg1, arg2) {
|
||||
|
@ -3,6 +3,7 @@
|
||||
- ^backend-python/get-pip\.py
|
||||
- ^backend-python/convert_model\.py
|
||||
- ^backend-python/convert_safetensors\.py
|
||||
- ^backend-python/convert_pytorch_to_ggml\.py linguist-vendored
|
||||
- ^backend-python/utils/midi\.py
|
||||
- ^build/
|
||||
- ^finetune/lora/
|
||||
|
Loading…
Reference in New Issue
Block a user