# Vendored module (upstream listing: 51 lines, 1.6 KiB, Python).
from typing import Any, List, Union

# Prefer an installed top-level ``web_rwkv_py``; otherwise fall back to a
# ``web_rwkv_py`` inside this module's own package. The relative import raises
# ImportError (not only ModuleNotFoundError) when this file is not part of a
# package, so the inner handler catches the broader ImportError.
try:
    import web_rwkv_py as wrp
except ModuleNotFoundError:
    try:
        from . import web_rwkv_py as wrp
    except ImportError as err:
        raise ModuleNotFoundError(
            "web_rwkv_py not found, install it from https://github.com/cryscan/web-rwkv-py"
        ) from err
class RWKV:
    """Adapter exposing a ``web_rwkv_py`` model through an RWKV-style API.

    The recurrent state is kept inside ``self.model``, which is driven through
    its ``clear_state`` / ``load_state`` / ``run`` methods. ``forward``
    therefore returns a placeholder string rather than a real state object.
    """

    def __init__(self, model_path: str, strategy: Union[str, None] = None):
        """Build a ``wrp.Model`` for ``model_path`` configured by ``strategy``.

        Args:
            model_path: Passed to ``wrp.Model`` as ``path``.
            strategy: Options separated by whitespace and/or commas. ``None``
                is treated as an empty strategy.

                - ``i8`` / ``i4`` anywhere in the string enable the ``quant`` /
                  ``quant_nf4`` model arguments. Both are 0 when absent.
                - ``layerN`` supplies those values, presumably a quantised layer
                  count (TODO confirm against web-rwkv-py). The values are
                  consumed in order: the first ``layerN`` goes to ``quant`` and
                  the next to ``quant_nf4``. The defaults are 31 and 26.
                - ``chunkN`` sets ``token_chunk_size`` (default 32), which
                  ``forward`` passes to ``model.run``.

        Raises:
            ValueError: if a ``layer``/``chunk`` token is not followed by an
                integer.
        """
        # A None default used to crash on ``strategy.split()``.
        strategy = strategy or ""

        layers = self._strategy_values(strategy, "layer")
        chunk_sizes = self._strategy_values(strategy, "chunk")
        self.token_chunk_size = next(chunk_sizes, 32)

        # Dict literals evaluate in order, so ``quant`` takes the first
        # ``layerN`` value and ``quant_nf4`` takes the next one.
        args = {
            "path": model_path,
            "quant": next(layers, 31) if "i8" in strategy else 0,
            "quant_nf4": next(layers, 26) if "i4" in strategy else 0,
        }
        self.model = wrp.Model(**args)
        self.info = self.model.info()

        # Fake weight: a zero list sized to the vocabulary. Presumably this
        # lets callers infer the vocab size from "emb.weight"; no real
        # weights are exposed.
        self.w = {}
        self.w["emb.weight"] = [0] * self.info.num_vocab

        # Assumes str(info.version) looks like "V5" / "v5.2"
        # (NOTE(review): confirm), turning it into 5.0 / 5.2.
        self.version = float(str(self.info.version).lower().replace("v", ""))

    @staticmethod
    def _strategy_values(strategy: str, prefix: str):
        """Yield, in order, the integer after ``prefix`` in each strategy token.

        Tokens are produced by splitting on whitespace and then on commas.
        Unlike ``str.lstrip``, only the exact prefix is removed.
        """
        for word in strategy.split():
            for token in word.split(","):
                if token.startswith(prefix):
                    yield int(token[len(prefix):])

    def forward(self, tokens: List[int], state: Union[Any, None] = None):
        """Run ``tokens`` through the model and return ``(logits, state)``.

        Args:
            tokens: Token ids passed to ``model.run``.
            state: ``None`` clears the model's internal state first. An object
                whose type is named ``State_Cpu`` is loaded into the model. Any
                other value (e.g. a previously returned ``"State_Gpu"``) is
                ignored, and the model continues from the state it holds.

        Returns:
            The result of ``model.run(tokens, self.token_chunk_size)`` and the
            placeholder string ``"State_Gpu"``.
        """
        if state is None:
            self.model.clear_state()
        # Matched by type name because the concrete State_Cpu class lives in
        # the native extension module.
        elif type(state).__name__ == "State_Cpu":
            self.model.load_state(state)
        logits = self.model.run(tokens, self.token_chunk_size)
        ret_state = "State_Gpu"
        return logits, ret_state