add rwkv version field

This commit is contained in:
josc146
2024-03-24 22:29:28 +08:00
parent 1d5d012ce4
commit a93610e574
8 changed files with 189 additions and 98 deletions

View File

@@ -9,6 +9,9 @@ class RWKV:
self.model = rwkv_cpp_model.RWKVModel(self.library, model_path)
self.w = {} # fake weight
self.w["emb.weight"] = [0] * self.model.n_vocab
self.version = (
self.model.arch_version_major + self.model.arch_version_minor / 10
)
def forward(self, tokens: List[int], state: Union[Any, None] = None):
return self.model.eval_sequence_in_chunks(tokens, state, use_numpy=True)