add rwkv version field
This commit is contained in:
57
backend-python/rwkv_pip/cpp/rwkv_cpp_model.py
vendored
57
backend-python/rwkv_pip/cpp/rwkv_cpp_model.py
vendored
@@ -52,9 +52,14 @@ class RWKVModel:
|
||||
if 'gpu_layers_count' in kwargs:
|
||||
gpu_layer_count = kwargs['gpu_layers_count']
|
||||
|
||||
assert os.path.isfile(model_path), f'{model_path} is not a file'
|
||||
assert thread_count > 0, 'Thread count must be > 0'
|
||||
assert gpu_layer_count >= 0, 'GPU layer count must be >= 0'
|
||||
if not os.path.isfile(model_path):
|
||||
raise ValueError(f'{model_path} is not a file')
|
||||
|
||||
if not (thread_count > 0):
|
||||
raise ValueError('Thread count must be > 0')
|
||||
|
||||
if not (gpu_layer_count >= 0):
|
||||
raise ValueError('GPU layer count must be >= 0')
|
||||
|
||||
self._library: rwkv_cpp_shared_library.RWKVSharedLibrary = shared_library
|
||||
|
||||
@@ -84,10 +89,19 @@ class RWKVModel:
|
||||
Count of layers to offload onto the GPU, must be >= 0.
|
||||
"""
|
||||
|
||||
assert layer_count >= 0, 'Layer count must be >= 0'
|
||||
if not (layer_count >= 0):
|
||||
raise ValueError('Layer count must be >= 0')
|
||||
|
||||
return self._library.rwkv_gpu_offload_layers(self._ctx, layer_count)
|
||||
|
||||
@property
def arch_version_major(self) -> int:
    """Value returned by the library's ``rwkv_get_arch_version_major`` for this model's context."""
    library = self._library
    return library.rwkv_get_arch_version_major(self._ctx)
|
||||
|
||||
@property
def arch_version_minor(self) -> int:
    """Value returned by the library's ``rwkv_get_arch_version_minor`` for this model's context."""
    library = self._library
    return library.rwkv_get_arch_version_minor(self._ctx)
|
||||
|
||||
@property
def n_vocab(self) -> int:
    """Value returned by the library's ``rwkv_get_n_vocab`` for this model's context."""
    library = self._library
    return library.rwkv_get_n_vocab(self._ctx)
|
||||
@@ -133,7 +147,8 @@ class RWKVModel:
|
||||
Logits vector of shape (n_vocab); state for the next step.
|
||||
"""
|
||||
|
||||
assert self._valid, 'Model was freed'
|
||||
if not self._valid:
|
||||
raise ValueError('Model was freed')
|
||||
|
||||
use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
|
||||
|
||||
@@ -207,7 +222,8 @@ class RWKVModel:
|
||||
Logits vector of shape (n_vocab); state for the next step.
|
||||
"""
|
||||
|
||||
assert self._valid, 'Model was freed'
|
||||
if not self._valid:
|
||||
raise ValueError('Model was freed')
|
||||
|
||||
use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
|
||||
|
||||
@@ -281,7 +297,8 @@ class RWKVModel:
|
||||
Logits vector of shape (n_vocab); state for the next step.
|
||||
"""
|
||||
|
||||
assert self._valid, 'Model was freed'
|
||||
if not self._valid:
|
||||
raise ValueError('Model was freed')
|
||||
|
||||
use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
|
||||
|
||||
@@ -320,7 +337,8 @@ class RWKVModel:
|
||||
The object must not be used anymore after calling this method.
|
||||
"""
|
||||
|
||||
assert self._valid, 'Already freed'
|
||||
if not self._valid:
|
||||
raise ValueError('Already freed')
|
||||
|
||||
self._valid = False
|
||||
|
||||
@@ -344,16 +362,25 @@ class RWKVModel:
|
||||
def _validate_tensor(self, tensor: NumpyArrayOrPyTorchTensor, name: str, size: int) -> None:
|
||||
if self._is_pytorch_tensor(tensor):
|
||||
tensor: torch.Tensor = tensor
|
||||
assert tensor.device == torch.device('cpu'), f'{name} is not on CPU'
|
||||
assert tensor.dtype == torch.float32, f'{name} is not of type float32'
|
||||
assert tensor.shape == (size,), f'{name} has invalid shape {tensor.shape}, expected ({size})'
|
||||
assert tensor.is_contiguous(), f'{name} is not contiguous'
|
||||
|
||||
if tensor.device != torch.device('cpu'):
|
||||
raise ValueError(f'{name} is not on CPU')
|
||||
if tensor.dtype != torch.float32:
|
||||
raise ValueError(f'{name} is not of type float32')
|
||||
if tensor.shape != (size,):
|
||||
raise ValueError(f'{name} has invalid shape {tensor.shape}, expected ({size})')
|
||||
if not tensor.is_contiguous():
|
||||
raise ValueError(f'{name} is not contiguous')
|
||||
else:
|
||||
import numpy as np
|
||||
tensor: np.ndarray = tensor
|
||||
assert tensor.dtype == np.float32, f'{name} is not of type float32'
|
||||
assert tensor.shape == (size,), f'{name} has invalid shape {tensor.shape}, expected ({size})'
|
||||
assert tensor.data.contiguous, f'{name} is not contiguous'
|
||||
|
||||
if tensor.dtype != np.float32:
|
||||
raise ValueError(f'{name} is not of type float32')
|
||||
if tensor.shape != (size,):
|
||||
raise ValueError(f'{name} has invalid shape {tensor.shape}, expected ({size})')
|
||||
if not tensor.data.contiguous:
|
||||
raise ValueError(f'{name} is not contiguous')
|
||||
|
||||
def _get_data_ptr(self, tensor: NumpyArrayOrPyTorchTensor):
|
||||
if self._is_pytorch_tensor(tensor):
|
||||
|
||||
Reference in New Issue
Block a user