rwkv.cpp(ggml) support
This commit is contained in:
14
backend-python/rwkv_pip/cpp/model.py
vendored
Normal file
14
backend-python/rwkv_pip/cpp/model.py
vendored
Normal file
@@ -0,0 +1,14 @@
|
||||
from typing import Any, List
|
||||
from . import rwkv_cpp_model
|
||||
from . import rwkv_cpp_shared_library
|
||||
|
||||
|
||||
class RWKV:
|
||||
def __init__(self, model_path: str, strategy=None):
|
||||
self.library = rwkv_cpp_shared_library.load_rwkv_shared_library()
|
||||
self.model = rwkv_cpp_model.RWKVModel(self.library, model_path)
|
||||
self.w = {} # fake weight
|
||||
self.w["emb.weight"] = [0] * self.model.n_vocab
|
||||
|
||||
def forward(self, tokens: List[int], state: Any | None):
|
||||
return self.model.eval_sequence_in_chunks(tokens, state, use_numpy=True)
|
||||
Reference in New Issue
Block a user