add WebGPU Python Mode (https://github.com/cryscan/web-rwkv-py)
This commit is contained in:
2
backend-python/rwkv_pip/utils.py
vendored
2
backend-python/rwkv_pip/utils.py
vendored
@@ -84,6 +84,8 @@ class PIPELINE:
|
||||
return e / e.sum(axis=axis, keepdims=True)
|
||||
|
||||
def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
|
||||
if type(logits) == list:
|
||||
logits = np.array(logits)
|
||||
np_logits = type(logits) == np.ndarray
|
||||
if np_logits:
|
||||
probs = self.np_softmax(logits, axis=-1)
|
||||
|
||||
21
backend-python/rwkv_pip/webgpu/model.py
vendored
Normal file
21
backend-python/rwkv_pip/webgpu/model.py
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
from typing import Any, List, Union
|
||||
|
||||
try:
|
||||
import web_rwkv_py as wrp
|
||||
except ModuleNotFoundError:
|
||||
try:
|
||||
from . import web_rwkv_py as wrp
|
||||
except ImportError:
|
||||
raise ModuleNotFoundError(
|
||||
"web_rwkv_py not found, install it from https://github.com/cryscan/web-rwkv-py"
|
||||
)
|
||||
|
||||
|
||||
class RWKV:
|
||||
def __init__(self, model_path: str, strategy=None):
|
||||
self.model = wrp.v5.Model(model_path, turbo=False)
|
||||
self.w = {} # fake weight
|
||||
self.w["emb.weight"] = [0] * wrp.peek_info(model_path).num_vocab
|
||||
|
||||
def forward(self, tokens: List[int], state: Union[Any, None] = None):
|
||||
return wrp.v5.run_one(self.model, tokens, state)
|
||||
Reference in New Issue
Block a user