2023-12-12 23:19:18 +08:00
|
|
|
from typing import Any, List, Union
|
2023-12-12 20:29:55 +08:00
|
|
|
from . import rwkv_cpp_model
|
|
|
|
from . import rwkv_cpp_shared_library
|
|
|
|
|
|
|
|
|
|
|
|
class RWKV:
    """Thin wrapper around ``rwkv_cpp_model.RWKVModel``.

    Attributes set by the constructor:
        library: object returned by ``load_rwkv_shared_library()``.
        model: the ``RWKVModel`` built from that library and ``model_path``.
        w: dict holding a single placeholder entry, ``"emb.weight"``.
        version: ``arch_version_major + arch_version_minor / 10`` of the model.
    """

    def __init__(self, model_path: str, strategy=None):
        """Load the shared library and open the model at ``model_path``.

        Args:
            model_path: path handed to ``rwkv_cpp_model.RWKVModel``.
            strategy: accepted but not used by this constructor.
                NOTE(review): presumably kept for signature parity with
                another backend -- confirm against callers.
        """
        shared_lib = rwkv_cpp_shared_library.load_rwkv_shared_library()
        self.library = shared_lib

        cpp_model = rwkv_cpp_model.RWKVModel(shared_lib, model_path)
        self.model = cpp_model

        # Fake weight table: no real tensor, just a list of zeros with one
        # entry per vocabulary item.
        self.w = {"emb.weight": [0] * cpp_model.n_vocab}

        # Fold the major/minor architecture numbers into a single number.
        major = cpp_model.arch_version_major
        minor = cpp_model.arch_version_minor
        self.version = major + minor / 10

    def forward(self, tokens: List[int], state: Union[Any, None] = None):
        """Evaluate ``tokens`` by delegating to the wrapped model.

        Args:
            tokens: token ids to evaluate.
            state: state object forwarded unchanged to the model; ``None``
                by default.

        Returns:
            Whatever ``self.model.eval_sequence_in_chunks`` returns when
            called with ``use_numpy=True``.
        """
        evaluate = self.model.eval_sequence_in_chunks
        return evaluate(tokens, state, use_numpy=True)
|