add rwkv version field

This commit is contained in:
josc146 2024-03-24 22:29:28 +08:00
parent 1d5d012ce4
commit a93610e574
8 changed files with 189 additions and 98 deletions

Binary file not shown.

Binary file not shown.

View File

@ -9,6 +9,9 @@ class RWKV:
self.model = rwkv_cpp_model.RWKVModel(self.library, model_path) self.model = rwkv_cpp_model.RWKVModel(self.library, model_path)
self.w = {} # fake weight self.w = {} # fake weight
self.w["emb.weight"] = [0] * self.model.n_vocab self.w["emb.weight"] = [0] * self.model.n_vocab
self.version = (
self.model.arch_version_major + self.model.arch_version_minor / 10
)
def forward(self, tokens: List[int], state: Union[Any, None] = None): def forward(self, tokens: List[int], state: Union[Any, None] = None):
return self.model.eval_sequence_in_chunks(tokens, state, use_numpy=True) return self.model.eval_sequence_in_chunks(tokens, state, use_numpy=True)

Binary file not shown.

View File

@ -52,9 +52,14 @@ class RWKVModel:
if 'gpu_layers_count' in kwargs: if 'gpu_layers_count' in kwargs:
gpu_layer_count = kwargs['gpu_layers_count'] gpu_layer_count = kwargs['gpu_layers_count']
assert os.path.isfile(model_path), f'{model_path} is not a file' if not os.path.isfile(model_path):
assert thread_count > 0, 'Thread count must be > 0' raise ValueError(f'{model_path} is not a file')
assert gpu_layer_count >= 0, 'GPU layer count must be >= 0'
if not (thread_count > 0):
raise ValueError('Thread count must be > 0')
if not (gpu_layer_count >= 0):
raise ValueError('GPU layer count must be >= 0')
self._library: rwkv_cpp_shared_library.RWKVSharedLibrary = shared_library self._library: rwkv_cpp_shared_library.RWKVSharedLibrary = shared_library
@ -84,10 +89,19 @@ class RWKVModel:
Count of layers to offload onto the GPU, must be >= 0. Count of layers to offload onto the GPU, must be >= 0.
""" """
assert layer_count >= 0, 'Layer count must be >= 0' if not (layer_count >= 0):
raise ValueError('Layer count must be >= 0')
return self._library.rwkv_gpu_offload_layers(self._ctx, layer_count) return self._library.rwkv_gpu_offload_layers(self._ctx, layer_count)
@property
def arch_version_major(self) -> int:
    """
    Major architecture version of the loaded model, as reported by the shared library
    for this model's context.
    """

    library = self._library
    return library.rwkv_get_arch_version_major(self._ctx)
@property
def arch_version_minor(self) -> int:
    """
    Minor architecture version of the loaded model, as reported by the shared library
    for this model's context.
    """

    library = self._library
    return library.rwkv_get_arch_version_minor(self._ctx)
@property @property
def n_vocab(self) -> int: def n_vocab(self) -> int:
return self._library.rwkv_get_n_vocab(self._ctx) return self._library.rwkv_get_n_vocab(self._ctx)
@ -133,7 +147,8 @@ class RWKVModel:
Logits vector of shape (n_vocab); state for the next step. Logits vector of shape (n_vocab); state for the next step.
""" """
assert self._valid, 'Model was freed' if not self._valid:
raise ValueError('Model was freed')
use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy) use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
@ -207,7 +222,8 @@ class RWKVModel:
Logits vector of shape (n_vocab); state for the next step. Logits vector of shape (n_vocab); state for the next step.
""" """
assert self._valid, 'Model was freed' if not self._valid:
raise ValueError('Model was freed')
use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy) use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
@ -281,7 +297,8 @@ class RWKVModel:
Logits vector of shape (n_vocab); state for the next step. Logits vector of shape (n_vocab); state for the next step.
""" """
assert self._valid, 'Model was freed' if not self._valid:
raise ValueError('Model was freed')
use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy) use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
@ -320,7 +337,8 @@ class RWKVModel:
The object must not be used anymore after calling this method. The object must not be used anymore after calling this method.
""" """
assert self._valid, 'Already freed' if not self._valid:
raise ValueError('Already freed')
self._valid = False self._valid = False
@ -344,16 +362,25 @@ class RWKVModel:
def _validate_tensor(self, tensor: NumpyArrayOrPyTorchTensor, name: str, size: int) -> None: def _validate_tensor(self, tensor: NumpyArrayOrPyTorchTensor, name: str, size: int) -> None:
if self._is_pytorch_tensor(tensor): if self._is_pytorch_tensor(tensor):
tensor: torch.Tensor = tensor tensor: torch.Tensor = tensor
assert tensor.device == torch.device('cpu'), f'{name} is not on CPU'
assert tensor.dtype == torch.float32, f'{name} is not of type float32' if tensor.device != torch.device('cpu'):
assert tensor.shape == (size,), f'{name} has invalid shape {tensor.shape}, expected ({size})' raise ValueError(f'{name} is not on CPU')
assert tensor.is_contiguous(), f'{name} is not contiguous' if tensor.dtype != torch.float32:
raise ValueError(f'{name} is not of type float32')
if tensor.shape != (size,):
raise ValueError(f'{name} has invalid shape {tensor.shape}, expected ({size})')
if not tensor.is_contiguous():
raise ValueError(f'{name} is not contiguous')
else: else:
import numpy as np import numpy as np
tensor: np.ndarray = tensor tensor: np.ndarray = tensor
assert tensor.dtype == np.float32, f'{name} is not of type float32'
assert tensor.shape == (size,), f'{name} has invalid shape {tensor.shape}, expected ({size})' if tensor.dtype != np.float32:
assert tensor.data.contiguous, f'{name} is not contiguous' raise ValueError(f'{name} is not of type float32')
if tensor.shape != (size,):
raise ValueError(f'{name} has invalid shape {tensor.shape}, expected ({size})')
if not tensor.data.contiguous:
raise ValueError(f'{name} is not contiguous')
def _get_data_ptr(self, tensor: NumpyArrayOrPyTorchTensor): def _get_data_ptr(self, tensor: NumpyArrayOrPyTorchTensor):
if self._is_pytorch_tensor(tensor): if self._is_pytorch_tensor(tensor):

View File

@ -6,21 +6,22 @@ import platform
from typing import Optional, List, Tuple, Callable from typing import Optional, List, Tuple, Callable
QUANTIZED_FORMAT_NAMES: Tuple[str, str, str, str, str] = ( QUANTIZED_FORMAT_NAMES: Tuple[str, str, str, str, str] = (
'Q4_0', "Q4_0",
'Q4_1', "Q4_1",
'Q5_0', "Q5_0",
'Q5_1', "Q5_1",
'Q8_0' "Q8_0",
) )
P_FLOAT = ctypes.POINTER(ctypes.c_float) P_FLOAT = ctypes.POINTER(ctypes.c_float)
P_INT = ctypes.POINTER(ctypes.c_int32) P_INT = ctypes.POINTER(ctypes.c_int32)
class RWKVContext:
class RWKVContext:
def __init__(self, ptr: ctypes.pointer) -> None: def __init__(self, ptr: ctypes.pointer) -> None:
self.ptr: ctypes.pointer = ptr self.ptr: ctypes.pointer = ptr
class RWKVSharedLibrary: class RWKVSharedLibrary:
""" """
Python wrapper around rwkv.cpp shared library. Python wrapper around rwkv.cpp shared library.
@ -39,7 +40,7 @@ class RWKVSharedLibrary:
# When Python is greater than 3.8, we need to reprocess the custom dll # When Python is greater than 3.8, we need to reprocess the custom dll
# according to the documentation to prevent loading failure errors. # according to the documentation to prevent loading failure errors.
# https://docs.python.org/3/whatsnew/3.8.html#ctypes # https://docs.python.org/3/whatsnew/3.8.html#ctypes
if platform.system().lower() == 'windows': if platform.system().lower() == "windows":
self.library = ctypes.CDLL(shared_library_path, winmode=0) self.library = ctypes.CDLL(shared_library_path, winmode=0)
else: else:
self.library = ctypes.cdll.LoadLibrary(shared_library_path) self.library = ctypes.cdll.LoadLibrary(shared_library_path)
@ -47,7 +48,10 @@ class RWKVSharedLibrary:
self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32] self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32]
self.library.rwkv_init_from_file.restype = ctypes.c_void_p self.library.rwkv_init_from_file.restype = ctypes.c_void_p
self.library.rwkv_gpu_offload_layers.argtypes = [ctypes.c_void_p, ctypes.c_uint32] self.library.rwkv_gpu_offload_layers.argtypes = [
ctypes.c_void_p,
ctypes.c_uint32,
]
self.library.rwkv_gpu_offload_layers.restype = ctypes.c_bool self.library.rwkv_gpu_offload_layers.restype = ctypes.c_bool
self.library.rwkv_eval.argtypes = [ self.library.rwkv_eval.argtypes = [
@ -55,7 +59,7 @@ class RWKVSharedLibrary:
ctypes.c_int32, # token ctypes.c_int32, # token
P_FLOAT, # state_in P_FLOAT, # state_in
P_FLOAT, # state_out P_FLOAT, # state_out
P_FLOAT # logits_out P_FLOAT, # logits_out
] ]
self.library.rwkv_eval.restype = ctypes.c_bool self.library.rwkv_eval.restype = ctypes.c_bool
@ -65,7 +69,7 @@ class RWKVSharedLibrary:
ctypes.c_size_t, # token count ctypes.c_size_t, # token count
P_FLOAT, # state_in P_FLOAT, # state_in
P_FLOAT, # state_out P_FLOAT, # state_out
P_FLOAT # logits_out P_FLOAT, # logits_out
] ]
self.library.rwkv_eval_sequence.restype = ctypes.c_bool self.library.rwkv_eval_sequence.restype = ctypes.c_bool
@ -76,10 +80,16 @@ class RWKVSharedLibrary:
ctypes.c_size_t, # chunk size ctypes.c_size_t, # chunk size
P_FLOAT, # state_in P_FLOAT, # state_in
P_FLOAT, # state_out P_FLOAT, # state_out
P_FLOAT # logits_out P_FLOAT, # logits_out
] ]
self.library.rwkv_eval_sequence_in_chunks.restype = ctypes.c_bool self.library.rwkv_eval_sequence_in_chunks.restype = ctypes.c_bool
self.library.rwkv_get_arch_version_major.argtypes = [ctypes.c_void_p]
self.library.rwkv_get_arch_version_major.restype = ctypes.c_uint32
self.library.rwkv_get_arch_version_minor.argtypes = [ctypes.c_void_p]
self.library.rwkv_get_arch_version_minor.restype = ctypes.c_uint32
self.library.rwkv_get_n_vocab.argtypes = [ctypes.c_void_p] self.library.rwkv_get_n_vocab.argtypes = [ctypes.c_void_p]
self.library.rwkv_get_n_vocab.restype = ctypes.c_size_t self.library.rwkv_get_n_vocab.restype = ctypes.c_size_t
@ -101,7 +111,11 @@ class RWKVSharedLibrary:
self.library.rwkv_free.argtypes = [ctypes.c_void_p] self.library.rwkv_free.argtypes = [ctypes.c_void_p]
self.library.rwkv_free.restype = None self.library.rwkv_free.restype = None
self.library.rwkv_quantize_model_file.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p] self.library.rwkv_quantize_model_file.argtypes = [
ctypes.c_char_p,
ctypes.c_char_p,
ctypes.c_char_p,
]
self.library.rwkv_quantize_model_file.restype = ctypes.c_bool self.library.rwkv_quantize_model_file.restype = ctypes.c_bool
self.library.rwkv_get_system_info_string.argtypes = [] self.library.rwkv_get_system_info_string.argtypes = []
@ -109,7 +123,9 @@ class RWKVSharedLibrary:
self.nullptr = ctypes.cast(0, ctypes.c_void_p) self.nullptr = ctypes.cast(0, ctypes.c_void_p)
def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVContext: def rwkv_init_from_file(
self, model_file_path: str, thread_count: int
) -> RWKVContext:
""" """
Loads the model from a file and prepares it for inference. Loads the model from a file and prepares it for inference.
Throws an exception in case of any error. Error messages would be printed to stderr. Throws an exception in case of any error. Error messages would be printed to stderr.
@ -122,9 +138,12 @@ class RWKVSharedLibrary:
Count of threads to use, must be positive. Count of threads to use, must be positive.
""" """
ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count)) ptr = self.library.rwkv_init_from_file(
model_file_path.encode("utf-8"), ctypes.c_uint32(thread_count)
)
assert ptr is not None, 'rwkv_init_from_file failed, check stderr' if ptr is None:
raise ValueError("rwkv_init_from_file failed, check stderr")
return RWKVContext(ptr) return RWKVContext(ptr)
@ -145,9 +164,12 @@ class RWKVSharedLibrary:
Count of layers to offload onto the GPU, must be >= 0. Count of layers to offload onto the GPU, must be >= 0.
""" """
assert layer_count >= 0, 'Layer count must be >= 0' if not (layer_count >= 0):
raise ValueError("Layer count must be >= 0")
return self.library.rwkv_gpu_offload_layers(ctx.ptr, ctypes.c_uint32(layer_count)) return self.library.rwkv_gpu_offload_layers(
ctx.ptr, ctypes.c_uint32(layer_count)
)
def rwkv_eval( def rwkv_eval(
self, self,
@ -155,7 +177,7 @@ class RWKVSharedLibrary:
token: int, token: int,
state_in_address: Optional[int], state_in_address: Optional[int],
state_out_address: int, state_out_address: int,
logits_out_address: int logits_out_address: int,
) -> None: ) -> None:
""" """
Evaluates the model for a single token. Evaluates the model for a single token.
@ -176,13 +198,14 @@ class RWKVSharedLibrary:
Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to. Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
""" """
assert self.library.rwkv_eval( if not self.library.rwkv_eval(
ctx.ptr, ctx.ptr,
ctypes.c_int32(token), ctypes.c_int32(token),
ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT), ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
ctypes.cast(state_out_address, P_FLOAT), ctypes.cast(state_out_address, P_FLOAT),
ctypes.cast(logits_out_address, P_FLOAT) ctypes.cast(logits_out_address, P_FLOAT),
), 'rwkv_eval failed, check stderr' ):
raise ValueError("rwkv_eval failed, check stderr")
def rwkv_eval_sequence( def rwkv_eval_sequence(
self, self,
@ -190,7 +213,7 @@ class RWKVSharedLibrary:
tokens: List[int], tokens: List[int],
state_in_address: Optional[int], state_in_address: Optional[int],
state_out_address: int, state_out_address: int,
logits_out_address: int logits_out_address: int,
) -> None: ) -> None:
""" """
Evaluates the model for a sequence of tokens. Evaluates the model for a sequence of tokens.
@ -223,14 +246,15 @@ class RWKVSharedLibrary:
Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to. Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
""" """
assert self.library.rwkv_eval_sequence( if not self.library.rwkv_eval_sequence(
ctx.ptr, ctx.ptr,
ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT), ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT),
ctypes.c_size_t(len(tokens)), ctypes.c_size_t(len(tokens)),
ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT), ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
ctypes.cast(state_out_address, P_FLOAT), ctypes.cast(state_out_address, P_FLOAT),
ctypes.cast(logits_out_address, P_FLOAT) ctypes.cast(logits_out_address, P_FLOAT),
), 'rwkv_eval_sequence failed, check stderr' ):
raise ValueError("rwkv_eval_sequence failed, check stderr")
def rwkv_eval_sequence_in_chunks( def rwkv_eval_sequence_in_chunks(
self, self,
@ -239,7 +263,7 @@ class RWKVSharedLibrary:
chunk_size: int, chunk_size: int,
state_in_address: Optional[int], state_in_address: Optional[int],
state_out_address: int, state_out_address: int,
logits_out_address: int logits_out_address: int,
) -> None: ) -> None:
""" """
Evaluates the model for a sequence of tokens using `rwkv_eval_sequence`, splitting a potentially long sequence into fixed-length chunks. Evaluates the model for a sequence of tokens using `rwkv_eval_sequence`, splitting a potentially long sequence into fixed-length chunks.
@ -269,15 +293,40 @@ class RWKVSharedLibrary:
Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to. Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
""" """
assert self.library.rwkv_eval_sequence_in_chunks( if not self.library.rwkv_eval_sequence_in_chunks(
ctx.ptr, ctx.ptr,
ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT), ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT),
ctypes.c_size_t(len(tokens)), ctypes.c_size_t(len(tokens)),
ctypes.c_size_t(chunk_size), ctypes.c_size_t(chunk_size),
ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT), ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
ctypes.cast(state_out_address, P_FLOAT), ctypes.cast(state_out_address, P_FLOAT),
ctypes.cast(logits_out_address, P_FLOAT) ctypes.cast(logits_out_address, P_FLOAT),
), 'rwkv_eval_sequence_in_chunks failed, check stderr' ):
raise ValueError("rwkv_eval_sequence_in_chunks failed, check stderr")
def rwkv_get_arch_version_major(self, ctx: RWKVContext) -> int:
    """
    Queries the shared library for the major architecture version of the model
    behind the given context.

    Parameters
    ----------
    ctx : RWKVContext
        RWKV context obtained from rwkv_init_from_file.
    """

    native_getter = self.library.rwkv_get_arch_version_major
    return native_getter(ctx.ptr)
def rwkv_get_arch_version_minor(self, ctx: RWKVContext) -> int:
    """
    Queries the shared library for the minor architecture version of the model
    behind the given context.

    Parameters
    ----------
    ctx : RWKVContext
        RWKV context obtained from rwkv_init_from_file.
    """

    native_getter = self.library.rwkv_get_arch_version_minor
    return native_getter(ctx.ptr)
def rwkv_get_n_vocab(self, ctx: RWKVContext) -> int: def rwkv_get_n_vocab(self, ctx: RWKVContext) -> int:
""" """
@ -358,7 +407,9 @@ class RWKVSharedLibrary:
ctx.ptr = self.nullptr ctx.ptr = self.nullptr
def rwkv_quantize_model_file(self, model_file_path_in: str, model_file_path_out: str, format_name: str) -> None: def rwkv_quantize_model_file(
self, model_file_path_in: str, model_file_path_out: str, format_name: str
) -> None:
""" """
Quantizes FP32 or FP16 model to one of INT4 formats. Quantizes FP32 or FP16 model to one of INT4 formats.
Throws an exception in case of any error. Error messages would be printed to stderr. Throws an exception in case of any error. Error messages would be printed to stderr.
@ -373,20 +424,25 @@ class RWKVSharedLibrary:
One of QUANTIZED_FORMAT_NAMES. One of QUANTIZED_FORMAT_NAMES.
""" """
assert format_name in QUANTIZED_FORMAT_NAMES, f'Unknown format name {format_name}, use one of {QUANTIZED_FORMAT_NAMES}' if format_name not in QUANTIZED_FORMAT_NAMES:
raise ValueError(
f"Unknown format name {format_name}, use one of {QUANTIZED_FORMAT_NAMES}"
)
assert self.library.rwkv_quantize_model_file( if not self.library.rwkv_quantize_model_file(
model_file_path_in.encode('utf-8'), model_file_path_in.encode("utf-8"),
model_file_path_out.encode('utf-8'), model_file_path_out.encode("utf-8"),
format_name.encode('utf-8') format_name.encode("utf-8"),
), 'rwkv_quantize_model_file failed, check stderr' ):
raise ValueError("rwkv_quantize_model_file failed, check stderr")
def rwkv_get_system_info_string(self) -> str: def rwkv_get_system_info_string(self) -> str:
""" """
Returns system information string. Returns system information string.
""" """
return self.library.rwkv_get_system_info_string().decode('utf-8') return self.library.rwkv_get_system_info_string().decode("utf-8")
def load_rwkv_shared_library() -> RWKVSharedLibrary: def load_rwkv_shared_library() -> RWKVSharedLibrary:
""" """
@ -396,27 +452,27 @@ def load_rwkv_shared_library() -> RWKVSharedLibrary:
file_name: str file_name: str
if 'win32' in sys.platform or 'cygwin' in sys.platform: if "win32" in sys.platform or "cygwin" in sys.platform:
file_name = 'rwkv.dll' file_name = "rwkv.dll"
elif 'darwin' in sys.platform: elif "darwin" in sys.platform:
file_name = 'librwkv.dylib' file_name = "librwkv.dylib"
else: else:
file_name = 'librwkv.so' file_name = "librwkv.so"
# Possible sub-paths to the library relative to the repo dir. # Possible sub-paths to the library relative to the repo dir.
child_paths: List[Callable[[pathlib.Path], pathlib.Path]] = [ child_paths: List[Callable[[pathlib.Path], pathlib.Path]] = [
# No lookup for Debug config here. # No lookup for Debug config here.
# I assume that if a user wants to debug the library, # I assume that if a user wants to debug the library,
# they will be able to find the library and set the exact path explicitly. # they will be able to find the library and set the exact path explicitly.
lambda p: p / 'backend-python' / 'rwkv_pip' / 'cpp' / file_name, lambda p: p / "backend-python" / "rwkv_pip" / "cpp" / file_name,
lambda p: p / 'bin' / 'Release' / file_name, lambda p: p / "bin" / "Release" / file_name,
lambda p: p / 'bin' / file_name, lambda p: p / "bin" / file_name,
# Some people prefer to build in the "build" subdirectory. # Some people prefer to build in the "build" subdirectory.
lambda p: p / 'build' / 'bin' / 'Release' / file_name, lambda p: p / "build" / "bin" / "Release" / file_name,
lambda p: p / 'build' / 'bin' / file_name, lambda p: p / "build" / "bin" / file_name,
lambda p: p / 'build' / file_name, lambda p: p / "build" / file_name,
# Fallback. # Fallback.
lambda p: p / file_name lambda p: p / file_name,
] ]
working_dir: pathlib.Path = pathlib.Path(os.path.abspath(os.getcwd())) working_dir: pathlib.Path = pathlib.Path(os.path.abspath(os.getcwd()))
@ -430,7 +486,7 @@ def load_rwkv_shared_library() -> RWKVSharedLibrary:
# . # .
working_dir, working_dir,
# Repo dir relative to this Python file. # Repo dir relative to this Python file.
pathlib.Path(os.path.abspath(__file__)).parent.parent.parent pathlib.Path(os.path.abspath(__file__)).parent.parent.parent,
] ]
for parent_path in parent_paths: for parent_path in parent_paths:
@ -440,5 +496,7 @@ def load_rwkv_shared_library() -> RWKVSharedLibrary:
if os.path.isfile(full_path): if os.path.isfile(full_path):
return RWKVSharedLibrary(str(full_path)) return RWKVSharedLibrary(str(full_path))
assert False, (f'Failed to find {file_name} automatically; ' raise ValueError(
f'you need to find the library and create RWKVSharedLibrary specifying the path to it') f"Failed to find {file_name} automatically; "
f"you need to find the library and create RWKVSharedLibrary specifying the path to it"
)

View File

@ -18,6 +18,7 @@ class RWKV:
self.w["emb.weight"] = [0] * self.info.num_vocab self.w["emb.weight"] = [0] * self.info.num_vocab
self.version = str(self.info.version).lower() self.version = str(self.info.version).lower()
self.wrp = getattr(wrp, self.version) self.wrp = getattr(wrp, self.version)
self.version = float(self.version.replace("v", ""))
layer = ( layer = (
int(s.lstrip("layer")) int(s.lstrip("layer"))

View File

@ -26,6 +26,7 @@ class AbstractRWKV(ABC):
self.EOS_ID = 0 self.EOS_ID = 0
self.name = "rwkv" self.name = "rwkv"
self.version = 4
self.model = model self.model = model
self.pipeline = pipeline self.pipeline = pipeline
self.model_state = None self.model_state = None
@ -665,6 +666,7 @@ def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV
else: else:
rwkv = TextRWKV(model, pipeline) rwkv = TextRWKV(model, pipeline)
rwkv.name = filename rwkv.name = filename
rwkv.version = model.version
return rwkv return rwkv