RWKV-Runner/backend-python/rwkv_pip/webgpu/model.py

32 lines
1.0 KiB
Python
Raw Normal View History

from typing import Any, List, Union
try:
import web_rwkv_py as wrp
except ModuleNotFoundError:
try:
from . import web_rwkv_py as wrp
except ImportError:
raise ModuleNotFoundError(
"web_rwkv_py not found, install it from https://github.com/cryscan/web-rwkv-py"
)
class RWKV:
    """Adapter around a ``web_rwkv_py`` v5 model exposing a ``forward`` call.

    Attributes are set in ``__init__``:
      * ``model`` - the ``wrp.v5.Model`` built from the given path.
      * ``w`` - a placeholder ("fake") weight dict holding only
        ``"emb.weight"``, a list of zeros whose length is the vocabulary
        size reported by ``wrp.peek_info``.
    """

    def __init__(self, model_path: str, strategy: Union[str, None] = None):
        """Load the model stored at *model_path*.

        Args:
            model_path: Path passed to ``wrp.v5.Model`` and ``wrp.peek_info``.
            strategy: Free-form strategy string. If it contains ``"i8"``,
                ``quant=32`` is passed to the model; if it contains ``"i4"``,
                ``quant_nf4=26`` is passed. ``None`` (the default) is treated
                as an empty string, i.e. neither option is enabled.
        """
        # The default is None, and `"i8" in None` raises TypeError; normalise
        # it so the substring checks below are always performed on a str.
        if strategy is None:
            strategy = ""

        self.model = wrp.v5.Model(
            model_path,
            turbo=True,
            # The numeric values are forwarded as-is; their meaning is defined
            # by web_rwkv_py.
            quant=32 if "i8" in strategy else None,
            quant_nf4=26 if "i4" in strategy else None,
        )
        self.w = {}  # fake weight
        # NOTE(review): presumably callers only inspect the length of this
        # entry to learn the vocabulary size - confirm against the callers.
        self.w["emb.weight"] = [0] * wrp.peek_info(model_path).num_vocab

    def forward(self, tokens: List[int], state: Union[Any, None] = None):
        """Run the model on *tokens* and return ``wrp.v5.run_one``'s result.

        Args:
            tokens: Token ids passed to ``wrp.v5.run_one``.
            state: Optional state. An object whose type is named
                ``BackedState`` is first loaded into a freshly created
                ``wrp.v5.ModelState``; any other value (including ``None``)
                is passed through unchanged.

        Returns:
            Whatever ``wrp.v5.run_one(self.model, tokens, state)`` returns.
        """
        # The type is matched by name rather than with isinstance(); the
        # state class itself is not referenced anywhere in this module.
        if type(state).__name__ == "BackedState":  # memory state
            gpu_state = wrp.v5.ModelState(self.model, 1)
            gpu_state.load(state)
        else:
            gpu_state = state
        return wrp.v5.run_one(self.model, tokens, gpu_state)