add rwkv version field
This commit is contained in:
parent
1d5d012ce4
commit
a93610e574
BIN
backend-python/rwkv_pip/cpp/librwkv.dylib
vendored
BIN
backend-python/rwkv_pip/cpp/librwkv.dylib
vendored
Binary file not shown.
BIN
backend-python/rwkv_pip/cpp/librwkv.so
vendored
BIN
backend-python/rwkv_pip/cpp/librwkv.so
vendored
Binary file not shown.
3
backend-python/rwkv_pip/cpp/model.py
vendored
3
backend-python/rwkv_pip/cpp/model.py
vendored
@ -9,6 +9,9 @@ class RWKV:
|
|||||||
self.model = rwkv_cpp_model.RWKVModel(self.library, model_path)
|
self.model = rwkv_cpp_model.RWKVModel(self.library, model_path)
|
||||||
self.w = {} # fake weight
|
self.w = {} # fake weight
|
||||||
self.w["emb.weight"] = [0] * self.model.n_vocab
|
self.w["emb.weight"] = [0] * self.model.n_vocab
|
||||||
|
self.version = (
|
||||||
|
self.model.arch_version_major + self.model.arch_version_minor / 10
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, tokens: List[int], state: Union[Any, None] = None):
|
def forward(self, tokens: List[int], state: Union[Any, None] = None):
|
||||||
return self.model.eval_sequence_in_chunks(tokens, state, use_numpy=True)
|
return self.model.eval_sequence_in_chunks(tokens, state, use_numpy=True)
|
||||||
|
BIN
backend-python/rwkv_pip/cpp/rwkv.dll
vendored
BIN
backend-python/rwkv_pip/cpp/rwkv.dll
vendored
Binary file not shown.
57
backend-python/rwkv_pip/cpp/rwkv_cpp_model.py
vendored
57
backend-python/rwkv_pip/cpp/rwkv_cpp_model.py
vendored
@ -52,9 +52,14 @@ class RWKVModel:
|
|||||||
if 'gpu_layers_count' in kwargs:
|
if 'gpu_layers_count' in kwargs:
|
||||||
gpu_layer_count = kwargs['gpu_layers_count']
|
gpu_layer_count = kwargs['gpu_layers_count']
|
||||||
|
|
||||||
assert os.path.isfile(model_path), f'{model_path} is not a file'
|
if not os.path.isfile(model_path):
|
||||||
assert thread_count > 0, 'Thread count must be > 0'
|
raise ValueError(f'{model_path} is not a file')
|
||||||
assert gpu_layer_count >= 0, 'GPU layer count must be >= 0'
|
|
||||||
|
if not (thread_count > 0):
|
||||||
|
raise ValueError('Thread count must be > 0')
|
||||||
|
|
||||||
|
if not (gpu_layer_count >= 0):
|
||||||
|
raise ValueError('GPU layer count must be >= 0')
|
||||||
|
|
||||||
self._library: rwkv_cpp_shared_library.RWKVSharedLibrary = shared_library
|
self._library: rwkv_cpp_shared_library.RWKVSharedLibrary = shared_library
|
||||||
|
|
||||||
@ -84,10 +89,19 @@ class RWKVModel:
|
|||||||
Count of layers to offload onto the GPU, must be >= 0.
|
Count of layers to offload onto the GPU, must be >= 0.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert layer_count >= 0, 'Layer count must be >= 0'
|
if not (layer_count >= 0):
|
||||||
|
raise ValueError('Layer count must be >= 0')
|
||||||
|
|
||||||
return self._library.rwkv_gpu_offload_layers(self._ctx, layer_count)
|
return self._library.rwkv_gpu_offload_layers(self._ctx, layer_count)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def arch_version_major(self) -> int:
|
||||||
|
return self._library.rwkv_get_arch_version_major(self._ctx)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def arch_version_minor(self) -> int:
|
||||||
|
return self._library.rwkv_get_arch_version_minor(self._ctx)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def n_vocab(self) -> int:
|
def n_vocab(self) -> int:
|
||||||
return self._library.rwkv_get_n_vocab(self._ctx)
|
return self._library.rwkv_get_n_vocab(self._ctx)
|
||||||
@ -133,7 +147,8 @@ class RWKVModel:
|
|||||||
Logits vector of shape (n_vocab); state for the next step.
|
Logits vector of shape (n_vocab); state for the next step.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert self._valid, 'Model was freed'
|
if not self._valid:
|
||||||
|
raise ValueError('Model was freed')
|
||||||
|
|
||||||
use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
|
use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
|
||||||
|
|
||||||
@ -207,7 +222,8 @@ class RWKVModel:
|
|||||||
Logits vector of shape (n_vocab); state for the next step.
|
Logits vector of shape (n_vocab); state for the next step.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert self._valid, 'Model was freed'
|
if not self._valid:
|
||||||
|
raise ValueError('Model was freed')
|
||||||
|
|
||||||
use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
|
use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
|
||||||
|
|
||||||
@ -281,7 +297,8 @@ class RWKVModel:
|
|||||||
Logits vector of shape (n_vocab); state for the next step.
|
Logits vector of shape (n_vocab); state for the next step.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert self._valid, 'Model was freed'
|
if not self._valid:
|
||||||
|
raise ValueError('Model was freed')
|
||||||
|
|
||||||
use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
|
use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
|
||||||
|
|
||||||
@ -320,7 +337,8 @@ class RWKVModel:
|
|||||||
The object must not be used anymore after calling this method.
|
The object must not be used anymore after calling this method.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert self._valid, 'Already freed'
|
if not self._valid:
|
||||||
|
raise ValueError('Already freed')
|
||||||
|
|
||||||
self._valid = False
|
self._valid = False
|
||||||
|
|
||||||
@ -344,16 +362,25 @@ class RWKVModel:
|
|||||||
def _validate_tensor(self, tensor: NumpyArrayOrPyTorchTensor, name: str, size: int) -> None:
|
def _validate_tensor(self, tensor: NumpyArrayOrPyTorchTensor, name: str, size: int) -> None:
|
||||||
if self._is_pytorch_tensor(tensor):
|
if self._is_pytorch_tensor(tensor):
|
||||||
tensor: torch.Tensor = tensor
|
tensor: torch.Tensor = tensor
|
||||||
assert tensor.device == torch.device('cpu'), f'{name} is not on CPU'
|
|
||||||
assert tensor.dtype == torch.float32, f'{name} is not of type float32'
|
if tensor.device != torch.device('cpu'):
|
||||||
assert tensor.shape == (size,), f'{name} has invalid shape {tensor.shape}, expected ({size})'
|
raise ValueError(f'{name} is not on CPU')
|
||||||
assert tensor.is_contiguous(), f'{name} is not contiguous'
|
if tensor.dtype != torch.float32:
|
||||||
|
raise ValueError(f'{name} is not of type float32')
|
||||||
|
if tensor.shape != (size,):
|
||||||
|
raise ValueError(f'{name} has invalid shape {tensor.shape}, expected ({size})')
|
||||||
|
if not tensor.is_contiguous():
|
||||||
|
raise ValueError(f'{name} is not contiguous')
|
||||||
else:
|
else:
|
||||||
import numpy as np
|
import numpy as np
|
||||||
tensor: np.ndarray = tensor
|
tensor: np.ndarray = tensor
|
||||||
assert tensor.dtype == np.float32, f'{name} is not of type float32'
|
|
||||||
assert tensor.shape == (size,), f'{name} has invalid shape {tensor.shape}, expected ({size})'
|
if tensor.dtype != np.float32:
|
||||||
assert tensor.data.contiguous, f'{name} is not contiguous'
|
raise ValueError(f'{name} is not of type float32')
|
||||||
|
if tensor.shape != (size,):
|
||||||
|
raise ValueError(f'{name} has invalid shape {tensor.shape}, expected ({size})')
|
||||||
|
if not tensor.data.contiguous:
|
||||||
|
raise ValueError(f'{name} is not contiguous')
|
||||||
|
|
||||||
def _get_data_ptr(self, tensor: NumpyArrayOrPyTorchTensor):
|
def _get_data_ptr(self, tensor: NumpyArrayOrPyTorchTensor):
|
||||||
if self._is_pytorch_tensor(tensor):
|
if self._is_pytorch_tensor(tensor):
|
||||||
|
@ -6,21 +6,22 @@ import platform
|
|||||||
from typing import Optional, List, Tuple, Callable
|
from typing import Optional, List, Tuple, Callable
|
||||||
|
|
||||||
QUANTIZED_FORMAT_NAMES: Tuple[str, str, str, str, str] = (
|
QUANTIZED_FORMAT_NAMES: Tuple[str, str, str, str, str] = (
|
||||||
'Q4_0',
|
"Q4_0",
|
||||||
'Q4_1',
|
"Q4_1",
|
||||||
'Q5_0',
|
"Q5_0",
|
||||||
'Q5_1',
|
"Q5_1",
|
||||||
'Q8_0'
|
"Q8_0",
|
||||||
)
|
)
|
||||||
|
|
||||||
P_FLOAT = ctypes.POINTER(ctypes.c_float)
|
P_FLOAT = ctypes.POINTER(ctypes.c_float)
|
||||||
P_INT = ctypes.POINTER(ctypes.c_int32)
|
P_INT = ctypes.POINTER(ctypes.c_int32)
|
||||||
|
|
||||||
class RWKVContext:
|
|
||||||
|
|
||||||
|
class RWKVContext:
|
||||||
def __init__(self, ptr: ctypes.pointer) -> None:
|
def __init__(self, ptr: ctypes.pointer) -> None:
|
||||||
self.ptr: ctypes.pointer = ptr
|
self.ptr: ctypes.pointer = ptr
|
||||||
|
|
||||||
|
|
||||||
class RWKVSharedLibrary:
|
class RWKVSharedLibrary:
|
||||||
"""
|
"""
|
||||||
Python wrapper around rwkv.cpp shared library.
|
Python wrapper around rwkv.cpp shared library.
|
||||||
@ -39,7 +40,7 @@ class RWKVSharedLibrary:
|
|||||||
# When Python is greater than 3.8, we need to reprocess the custom dll
|
# When Python is greater than 3.8, we need to reprocess the custom dll
|
||||||
# according to the documentation to prevent loading failure errors.
|
# according to the documentation to prevent loading failure errors.
|
||||||
# https://docs.python.org/3/whatsnew/3.8.html#ctypes
|
# https://docs.python.org/3/whatsnew/3.8.html#ctypes
|
||||||
if platform.system().lower() == 'windows':
|
if platform.system().lower() == "windows":
|
||||||
self.library = ctypes.CDLL(shared_library_path, winmode=0)
|
self.library = ctypes.CDLL(shared_library_path, winmode=0)
|
||||||
else:
|
else:
|
||||||
self.library = ctypes.cdll.LoadLibrary(shared_library_path)
|
self.library = ctypes.cdll.LoadLibrary(shared_library_path)
|
||||||
@ -47,7 +48,10 @@ class RWKVSharedLibrary:
|
|||||||
self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32]
|
self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32]
|
||||||
self.library.rwkv_init_from_file.restype = ctypes.c_void_p
|
self.library.rwkv_init_from_file.restype = ctypes.c_void_p
|
||||||
|
|
||||||
self.library.rwkv_gpu_offload_layers.argtypes = [ctypes.c_void_p, ctypes.c_uint32]
|
self.library.rwkv_gpu_offload_layers.argtypes = [
|
||||||
|
ctypes.c_void_p,
|
||||||
|
ctypes.c_uint32,
|
||||||
|
]
|
||||||
self.library.rwkv_gpu_offload_layers.restype = ctypes.c_bool
|
self.library.rwkv_gpu_offload_layers.restype = ctypes.c_bool
|
||||||
|
|
||||||
self.library.rwkv_eval.argtypes = [
|
self.library.rwkv_eval.argtypes = [
|
||||||
@ -55,7 +59,7 @@ class RWKVSharedLibrary:
|
|||||||
ctypes.c_int32, # token
|
ctypes.c_int32, # token
|
||||||
P_FLOAT, # state_in
|
P_FLOAT, # state_in
|
||||||
P_FLOAT, # state_out
|
P_FLOAT, # state_out
|
||||||
P_FLOAT # logits_out
|
P_FLOAT, # logits_out
|
||||||
]
|
]
|
||||||
self.library.rwkv_eval.restype = ctypes.c_bool
|
self.library.rwkv_eval.restype = ctypes.c_bool
|
||||||
|
|
||||||
@ -65,7 +69,7 @@ class RWKVSharedLibrary:
|
|||||||
ctypes.c_size_t, # token count
|
ctypes.c_size_t, # token count
|
||||||
P_FLOAT, # state_in
|
P_FLOAT, # state_in
|
||||||
P_FLOAT, # state_out
|
P_FLOAT, # state_out
|
||||||
P_FLOAT # logits_out
|
P_FLOAT, # logits_out
|
||||||
]
|
]
|
||||||
self.library.rwkv_eval_sequence.restype = ctypes.c_bool
|
self.library.rwkv_eval_sequence.restype = ctypes.c_bool
|
||||||
|
|
||||||
@ -76,10 +80,16 @@ class RWKVSharedLibrary:
|
|||||||
ctypes.c_size_t, # chunk size
|
ctypes.c_size_t, # chunk size
|
||||||
P_FLOAT, # state_in
|
P_FLOAT, # state_in
|
||||||
P_FLOAT, # state_out
|
P_FLOAT, # state_out
|
||||||
P_FLOAT # logits_out
|
P_FLOAT, # logits_out
|
||||||
]
|
]
|
||||||
self.library.rwkv_eval_sequence_in_chunks.restype = ctypes.c_bool
|
self.library.rwkv_eval_sequence_in_chunks.restype = ctypes.c_bool
|
||||||
|
|
||||||
|
self.library.rwkv_get_arch_version_major.argtypes = [ctypes.c_void_p]
|
||||||
|
self.library.rwkv_get_arch_version_major.restype = ctypes.c_uint32
|
||||||
|
|
||||||
|
self.library.rwkv_get_arch_version_minor.argtypes = [ctypes.c_void_p]
|
||||||
|
self.library.rwkv_get_arch_version_minor.restype = ctypes.c_uint32
|
||||||
|
|
||||||
self.library.rwkv_get_n_vocab.argtypes = [ctypes.c_void_p]
|
self.library.rwkv_get_n_vocab.argtypes = [ctypes.c_void_p]
|
||||||
self.library.rwkv_get_n_vocab.restype = ctypes.c_size_t
|
self.library.rwkv_get_n_vocab.restype = ctypes.c_size_t
|
||||||
|
|
||||||
@ -101,7 +111,11 @@ class RWKVSharedLibrary:
|
|||||||
self.library.rwkv_free.argtypes = [ctypes.c_void_p]
|
self.library.rwkv_free.argtypes = [ctypes.c_void_p]
|
||||||
self.library.rwkv_free.restype = None
|
self.library.rwkv_free.restype = None
|
||||||
|
|
||||||
self.library.rwkv_quantize_model_file.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p]
|
self.library.rwkv_quantize_model_file.argtypes = [
|
||||||
|
ctypes.c_char_p,
|
||||||
|
ctypes.c_char_p,
|
||||||
|
ctypes.c_char_p,
|
||||||
|
]
|
||||||
self.library.rwkv_quantize_model_file.restype = ctypes.c_bool
|
self.library.rwkv_quantize_model_file.restype = ctypes.c_bool
|
||||||
|
|
||||||
self.library.rwkv_get_system_info_string.argtypes = []
|
self.library.rwkv_get_system_info_string.argtypes = []
|
||||||
@ -109,7 +123,9 @@ class RWKVSharedLibrary:
|
|||||||
|
|
||||||
self.nullptr = ctypes.cast(0, ctypes.c_void_p)
|
self.nullptr = ctypes.cast(0, ctypes.c_void_p)
|
||||||
|
|
||||||
def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVContext:
|
def rwkv_init_from_file(
|
||||||
|
self, model_file_path: str, thread_count: int
|
||||||
|
) -> RWKVContext:
|
||||||
"""
|
"""
|
||||||
Loads the model from a file and prepares it for inference.
|
Loads the model from a file and prepares it for inference.
|
||||||
Throws an exception in case of any error. Error messages would be printed to stderr.
|
Throws an exception in case of any error. Error messages would be printed to stderr.
|
||||||
@ -122,9 +138,12 @@ class RWKVSharedLibrary:
|
|||||||
Count of threads to use, must be positive.
|
Count of threads to use, must be positive.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count))
|
ptr = self.library.rwkv_init_from_file(
|
||||||
|
model_file_path.encode("utf-8"), ctypes.c_uint32(thread_count)
|
||||||
|
)
|
||||||
|
|
||||||
assert ptr is not None, 'rwkv_init_from_file failed, check stderr'
|
if ptr is None:
|
||||||
|
raise ValueError("rwkv_init_from_file failed, check stderr")
|
||||||
|
|
||||||
return RWKVContext(ptr)
|
return RWKVContext(ptr)
|
||||||
|
|
||||||
@ -145,9 +164,12 @@ class RWKVSharedLibrary:
|
|||||||
Count of layers to offload onto the GPU, must be >= 0.
|
Count of layers to offload onto the GPU, must be >= 0.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert layer_count >= 0, 'Layer count must be >= 0'
|
if not (layer_count >= 0):
|
||||||
|
raise ValueError("Layer count must be >= 0")
|
||||||
|
|
||||||
return self.library.rwkv_gpu_offload_layers(ctx.ptr, ctypes.c_uint32(layer_count))
|
return self.library.rwkv_gpu_offload_layers(
|
||||||
|
ctx.ptr, ctypes.c_uint32(layer_count)
|
||||||
|
)
|
||||||
|
|
||||||
def rwkv_eval(
|
def rwkv_eval(
|
||||||
self,
|
self,
|
||||||
@ -155,7 +177,7 @@ class RWKVSharedLibrary:
|
|||||||
token: int,
|
token: int,
|
||||||
state_in_address: Optional[int],
|
state_in_address: Optional[int],
|
||||||
state_out_address: int,
|
state_out_address: int,
|
||||||
logits_out_address: int
|
logits_out_address: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Evaluates the model for a single token.
|
Evaluates the model for a single token.
|
||||||
@ -176,13 +198,14 @@ class RWKVSharedLibrary:
|
|||||||
Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
|
Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert self.library.rwkv_eval(
|
if not self.library.rwkv_eval(
|
||||||
ctx.ptr,
|
ctx.ptr,
|
||||||
ctypes.c_int32(token),
|
ctypes.c_int32(token),
|
||||||
ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
|
ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
|
||||||
ctypes.cast(state_out_address, P_FLOAT),
|
ctypes.cast(state_out_address, P_FLOAT),
|
||||||
ctypes.cast(logits_out_address, P_FLOAT)
|
ctypes.cast(logits_out_address, P_FLOAT),
|
||||||
), 'rwkv_eval failed, check stderr'
|
):
|
||||||
|
raise ValueError("rwkv_eval failed, check stderr")
|
||||||
|
|
||||||
def rwkv_eval_sequence(
|
def rwkv_eval_sequence(
|
||||||
self,
|
self,
|
||||||
@ -190,7 +213,7 @@ class RWKVSharedLibrary:
|
|||||||
tokens: List[int],
|
tokens: List[int],
|
||||||
state_in_address: Optional[int],
|
state_in_address: Optional[int],
|
||||||
state_out_address: int,
|
state_out_address: int,
|
||||||
logits_out_address: int
|
logits_out_address: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Evaluates the model for a sequence of tokens.
|
Evaluates the model for a sequence of tokens.
|
||||||
@ -223,14 +246,15 @@ class RWKVSharedLibrary:
|
|||||||
Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
|
Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert self.library.rwkv_eval_sequence(
|
if not self.library.rwkv_eval_sequence(
|
||||||
ctx.ptr,
|
ctx.ptr,
|
||||||
ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT),
|
ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT),
|
||||||
ctypes.c_size_t(len(tokens)),
|
ctypes.c_size_t(len(tokens)),
|
||||||
ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
|
ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
|
||||||
ctypes.cast(state_out_address, P_FLOAT),
|
ctypes.cast(state_out_address, P_FLOAT),
|
||||||
ctypes.cast(logits_out_address, P_FLOAT)
|
ctypes.cast(logits_out_address, P_FLOAT),
|
||||||
), 'rwkv_eval_sequence failed, check stderr'
|
):
|
||||||
|
raise ValueError("rwkv_eval_sequence failed, check stderr")
|
||||||
|
|
||||||
def rwkv_eval_sequence_in_chunks(
|
def rwkv_eval_sequence_in_chunks(
|
||||||
self,
|
self,
|
||||||
@ -239,7 +263,7 @@ class RWKVSharedLibrary:
|
|||||||
chunk_size: int,
|
chunk_size: int,
|
||||||
state_in_address: Optional[int],
|
state_in_address: Optional[int],
|
||||||
state_out_address: int,
|
state_out_address: int,
|
||||||
logits_out_address: int
|
logits_out_address: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Evaluates the model for a sequence of tokens using `rwkv_eval_sequence`, splitting a potentially long sequence into fixed-length chunks.
|
Evaluates the model for a sequence of tokens using `rwkv_eval_sequence`, splitting a potentially long sequence into fixed-length chunks.
|
||||||
@ -269,15 +293,40 @@ class RWKVSharedLibrary:
|
|||||||
Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
|
Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert self.library.rwkv_eval_sequence_in_chunks(
|
if not self.library.rwkv_eval_sequence_in_chunks(
|
||||||
ctx.ptr,
|
ctx.ptr,
|
||||||
ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT),
|
ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT),
|
||||||
ctypes.c_size_t(len(tokens)),
|
ctypes.c_size_t(len(tokens)),
|
||||||
ctypes.c_size_t(chunk_size),
|
ctypes.c_size_t(chunk_size),
|
||||||
ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
|
ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
|
||||||
ctypes.cast(state_out_address, P_FLOAT),
|
ctypes.cast(state_out_address, P_FLOAT),
|
||||||
ctypes.cast(logits_out_address, P_FLOAT)
|
ctypes.cast(logits_out_address, P_FLOAT),
|
||||||
), 'rwkv_eval_sequence_in_chunks failed, check stderr'
|
):
|
||||||
|
raise ValueError("rwkv_eval_sequence_in_chunks failed, check stderr")
|
||||||
|
|
||||||
|
def rwkv_get_arch_version_major(self, ctx: RWKVContext) -> int:
|
||||||
|
"""
|
||||||
|
Returns the major version used by the given model.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
ctx : RWKVContext
|
||||||
|
RWKV context obtained from rwkv_init_from_file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return self.library.rwkv_get_arch_version_major(ctx.ptr)
|
||||||
|
|
||||||
|
def rwkv_get_arch_version_minor(self, ctx: RWKVContext) -> int:
|
||||||
|
"""
|
||||||
|
Returns the minor version used by the given model.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
ctx : RWKVContext
|
||||||
|
RWKV context obtained from rwkv_init_from_file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return self.library.rwkv_get_arch_version_minor(ctx.ptr)
|
||||||
|
|
||||||
def rwkv_get_n_vocab(self, ctx: RWKVContext) -> int:
|
def rwkv_get_n_vocab(self, ctx: RWKVContext) -> int:
|
||||||
"""
|
"""
|
||||||
@ -358,7 +407,9 @@ class RWKVSharedLibrary:
|
|||||||
|
|
||||||
ctx.ptr = self.nullptr
|
ctx.ptr = self.nullptr
|
||||||
|
|
||||||
def rwkv_quantize_model_file(self, model_file_path_in: str, model_file_path_out: str, format_name: str) -> None:
|
def rwkv_quantize_model_file(
|
||||||
|
self, model_file_path_in: str, model_file_path_out: str, format_name: str
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Quantizes FP32 or FP16 model to one of INT4 formats.
|
Quantizes FP32 or FP16 model to one of INT4 formats.
|
||||||
Throws an exception in case of any error. Error messages would be printed to stderr.
|
Throws an exception in case of any error. Error messages would be printed to stderr.
|
||||||
@ -373,20 +424,25 @@ class RWKVSharedLibrary:
|
|||||||
One of QUANTIZED_FORMAT_NAMES.
|
One of QUANTIZED_FORMAT_NAMES.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert format_name in QUANTIZED_FORMAT_NAMES, f'Unknown format name {format_name}, use one of {QUANTIZED_FORMAT_NAMES}'
|
if format_name not in QUANTIZED_FORMAT_NAMES:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown format name {format_name}, use one of {QUANTIZED_FORMAT_NAMES}"
|
||||||
|
)
|
||||||
|
|
||||||
assert self.library.rwkv_quantize_model_file(
|
if not self.library.rwkv_quantize_model_file(
|
||||||
model_file_path_in.encode('utf-8'),
|
model_file_path_in.encode("utf-8"),
|
||||||
model_file_path_out.encode('utf-8'),
|
model_file_path_out.encode("utf-8"),
|
||||||
format_name.encode('utf-8')
|
format_name.encode("utf-8"),
|
||||||
), 'rwkv_quantize_model_file failed, check stderr'
|
):
|
||||||
|
raise ValueError("rwkv_quantize_model_file failed, check stderr")
|
||||||
|
|
||||||
def rwkv_get_system_info_string(self) -> str:
|
def rwkv_get_system_info_string(self) -> str:
|
||||||
"""
|
"""
|
||||||
Returns system information string.
|
Returns system information string.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return self.library.rwkv_get_system_info_string().decode('utf-8')
|
return self.library.rwkv_get_system_info_string().decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
def load_rwkv_shared_library() -> RWKVSharedLibrary:
|
def load_rwkv_shared_library() -> RWKVSharedLibrary:
|
||||||
"""
|
"""
|
||||||
@ -396,27 +452,27 @@ def load_rwkv_shared_library() -> RWKVSharedLibrary:
|
|||||||
|
|
||||||
file_name: str
|
file_name: str
|
||||||
|
|
||||||
if 'win32' in sys.platform or 'cygwin' in sys.platform:
|
if "win32" in sys.platform or "cygwin" in sys.platform:
|
||||||
file_name = 'rwkv.dll'
|
file_name = "rwkv.dll"
|
||||||
elif 'darwin' in sys.platform:
|
elif "darwin" in sys.platform:
|
||||||
file_name = 'librwkv.dylib'
|
file_name = "librwkv.dylib"
|
||||||
else:
|
else:
|
||||||
file_name = 'librwkv.so'
|
file_name = "librwkv.so"
|
||||||
|
|
||||||
# Possible sub-paths to the library relative to the repo dir.
|
# Possible sub-paths to the library relative to the repo dir.
|
||||||
child_paths: List[Callable[[pathlib.Path], pathlib.Path]] = [
|
child_paths: List[Callable[[pathlib.Path], pathlib.Path]] = [
|
||||||
# No lookup for Debug config here.
|
# No lookup for Debug config here.
|
||||||
# I assume that if a user wants to debug the library,
|
# I assume that if a user wants to debug the library,
|
||||||
# they will be able to find the library and set the exact path explicitly.
|
# they will be able to find the library and set the exact path explicitly.
|
||||||
lambda p: p / 'backend-python' / 'rwkv_pip' / 'cpp' / file_name,
|
lambda p: p / "backend-python" / "rwkv_pip" / "cpp" / file_name,
|
||||||
lambda p: p / 'bin' / 'Release' / file_name,
|
lambda p: p / "bin" / "Release" / file_name,
|
||||||
lambda p: p / 'bin' / file_name,
|
lambda p: p / "bin" / file_name,
|
||||||
# Some people prefer to build in the "build" subdirectory.
|
# Some people prefer to build in the "build" subdirectory.
|
||||||
lambda p: p / 'build' / 'bin' / 'Release' / file_name,
|
lambda p: p / "build" / "bin" / "Release" / file_name,
|
||||||
lambda p: p / 'build' / 'bin' / file_name,
|
lambda p: p / "build" / "bin" / file_name,
|
||||||
lambda p: p / 'build' / file_name,
|
lambda p: p / "build" / file_name,
|
||||||
# Fallback.
|
# Fallback.
|
||||||
lambda p: p / file_name
|
lambda p: p / file_name,
|
||||||
]
|
]
|
||||||
|
|
||||||
working_dir: pathlib.Path = pathlib.Path(os.path.abspath(os.getcwd()))
|
working_dir: pathlib.Path = pathlib.Path(os.path.abspath(os.getcwd()))
|
||||||
@ -430,7 +486,7 @@ def load_rwkv_shared_library() -> RWKVSharedLibrary:
|
|||||||
# .
|
# .
|
||||||
working_dir,
|
working_dir,
|
||||||
# Repo dir relative to this Python file.
|
# Repo dir relative to this Python file.
|
||||||
pathlib.Path(os.path.abspath(__file__)).parent.parent.parent
|
pathlib.Path(os.path.abspath(__file__)).parent.parent.parent,
|
||||||
]
|
]
|
||||||
|
|
||||||
for parent_path in parent_paths:
|
for parent_path in parent_paths:
|
||||||
@ -440,5 +496,7 @@ def load_rwkv_shared_library() -> RWKVSharedLibrary:
|
|||||||
if os.path.isfile(full_path):
|
if os.path.isfile(full_path):
|
||||||
return RWKVSharedLibrary(str(full_path))
|
return RWKVSharedLibrary(str(full_path))
|
||||||
|
|
||||||
assert False, (f'Failed to find {file_name} automatically; '
|
raise ValueError(
|
||||||
f'you need to find the library and create RWKVSharedLibrary specifying the path to it')
|
f"Failed to find {file_name} automatically; "
|
||||||
|
f"you need to find the library and create RWKVSharedLibrary specifying the path to it"
|
||||||
|
)
|
||||||
|
1
backend-python/rwkv_pip/webgpu/model.py
vendored
1
backend-python/rwkv_pip/webgpu/model.py
vendored
@ -18,6 +18,7 @@ class RWKV:
|
|||||||
self.w["emb.weight"] = [0] * self.info.num_vocab
|
self.w["emb.weight"] = [0] * self.info.num_vocab
|
||||||
self.version = str(self.info.version).lower()
|
self.version = str(self.info.version).lower()
|
||||||
self.wrp = getattr(wrp, self.version)
|
self.wrp = getattr(wrp, self.version)
|
||||||
|
self.version = float(self.version.replace("v", ""))
|
||||||
|
|
||||||
layer = (
|
layer = (
|
||||||
int(s.lstrip("layer"))
|
int(s.lstrip("layer"))
|
||||||
|
@ -26,6 +26,7 @@ class AbstractRWKV(ABC):
|
|||||||
self.EOS_ID = 0
|
self.EOS_ID = 0
|
||||||
|
|
||||||
self.name = "rwkv"
|
self.name = "rwkv"
|
||||||
|
self.version = 4
|
||||||
self.model = model
|
self.model = model
|
||||||
self.pipeline = pipeline
|
self.pipeline = pipeline
|
||||||
self.model_state = None
|
self.model_state = None
|
||||||
@ -665,6 +666,7 @@ def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV
|
|||||||
else:
|
else:
|
||||||
rwkv = TextRWKV(model, pipeline)
|
rwkv = TextRWKV(model, pipeline)
|
||||||
rwkv.name = filename
|
rwkv.name = filename
|
||||||
|
rwkv.version = model.version
|
||||||
|
|
||||||
return rwkv
|
return rwkv
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user