2023-12-14 18:37:07 +08:00
|
|
|
from typing import Any, List, Union
|
|
|
|
|
|
|
|
try:
|
|
|
|
import web_rwkv_py as wrp
|
|
|
|
except ModuleNotFoundError:
|
|
|
|
try:
|
|
|
|
from . import web_rwkv_py as wrp
|
|
|
|
except ImportError:
|
|
|
|
raise ModuleNotFoundError(
|
|
|
|
"web_rwkv_py not found, install it from https://github.com/cryscan/web-rwkv-py"
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
class RWKV:
    """RWKV model wrapper that delegates inference to ``web_rwkv_py``.

    The backend submodule is chosen from the model file's header info
    (``wrp.peek_info``). Inference runs through ``forward(tokens, state)``.
    """

    def __init__(self, model_path: str, strategy: Union[str, None] = None):
        """Load the model stored at *model_path*.

        Args:
            model_path: Path to the model file. It is passed to
                ``wrp.peek_info`` and to the backend ``Model`` as ``file``.
            strategy: Optional strategy string. Only substrings are checked:
                ``"i8"`` sets ``quant`` and ``"i4"`` sets ``quant_nf4`` on the
                backend ``Model``. ``None`` (the default) behaves like an
                empty string, meaning neither option is enabled.
        """
        # The declared default is None, but the ``in`` checks below need a
        # string; without this, RWKV(path) raised TypeError.
        if strategy is None:
            strategy = ""

        self.info = wrp.peek_info(model_path)
        # Fake weight dict: only "emb.weight" exists, as a zero list whose
        # length equals info.num_vocab. NOTE(review): presumably callers read
        # its length as the vocabulary size; confirm against callers.
        self.w = {"emb.weight": [0] * self.info.num_vocab}
        # The lower-cased version string names the wrp submodule to use.
        self.version = str(self.info.version).lower()
        self.wrp = getattr(wrp, self.version)

        args = {
            "file": model_path,
            "turbo": True,
            # NOTE(review): 31 / 26 look like the number of layers to
            # quantize (int8 / NF4); confirm against web_rwkv_py's Model.
            "quant": 31 if "i8" in strategy else 0,
            "quant_nf4": 26 if "i4" in strategy else 0,
            "token_chunk_size": 32,
            "lora": None,
        }
        self.model = self.wrp.Model(**args)

    def forward(self, tokens: List[int], state: Union[Any, None] = None):
        """Run the model on *tokens* and return ``self.wrp.run_one``'s result.

        If *state* is a ``BackedState`` (matched by class name), a new
        ``ModelState(self.model, 1)`` is created and loaded from it before
        running. Any other *state*, including ``None``, is passed through
        unchanged.
        """
        # NOTE(review): the check uses the class name rather than isinstance,
        # presumably so it matches BackedState from any version submodule.
        if type(state).__name__ == "BackedState":  # memory state
            gpu_state = self.wrp.ModelState(self.model, 1)
            gpu_state.load(state)
        else:
            gpu_state = state
        return self.wrp.run_one(self.model, tokens, gpu_state)