add WebGPU Python Mode (https://github.com/cryscan/web-rwkv-py)
This commit is contained in:
@@ -8,7 +8,6 @@ from typing import Dict, Iterable, List, Tuple, Union, Type
|
||||
from utils.log import quick_log
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
import numpy as np
|
||||
from routes import state_cache
|
||||
import global_var
|
||||
|
||||
@@ -68,6 +67,8 @@ class AbstractRWKV(ABC):
|
||||
pass
|
||||
|
||||
def get_embedding(self, input: str, fast_mode: bool) -> Tuple[List[float], int]:
|
||||
import numpy as np
|
||||
|
||||
if fast_mode:
|
||||
embedding, token_len = self.__fast_embedding(
|
||||
self.fix_tokens(self.pipeline.encode(input)), None
|
||||
@@ -222,6 +223,8 @@ class AbstractRWKV(ABC):
|
||||
def generate(
|
||||
self, prompt: str, stop: Union[str, List[str], None] = None
|
||||
) -> Iterable[Tuple[str, str, int, int]]:
|
||||
import numpy as np
|
||||
|
||||
quick_log(None, None, "Generation Prompt:\n" + prompt)
|
||||
cache = None
|
||||
delta_prompt = prompt
|
||||
@@ -231,7 +234,7 @@ class AbstractRWKV(ABC):
|
||||
)
|
||||
except HTTPException:
|
||||
pass
|
||||
if cache is None or cache["prompt"] == "":
|
||||
if cache is None or cache["prompt"] == "" or cache["state"] is None:
|
||||
self.model_state = None
|
||||
self.model_tokens = []
|
||||
else:
|
||||
@@ -511,6 +514,7 @@ def get_tokenizer(tokenizer_len: int):
|
||||
def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV:
|
||||
rwkv_beta = global_var.get(global_var.Args).rwkv_beta
|
||||
rwkv_cpp = getattr(global_var.get(global_var.Args), "rwkv.cpp")
|
||||
webgpu = global_var.get(global_var.Args).webgpu
|
||||
|
||||
if "midi" in model.lower() or "abc" in model.lower():
|
||||
os.environ["RWKV_RESCALE_LAYER"] = "999"
|
||||
@@ -526,6 +530,11 @@ def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV
|
||||
from rwkv_pip.cpp.model import (
|
||||
RWKV as Model,
|
||||
)
|
||||
elif webgpu:
|
||||
print("Using webgpu")
|
||||
from rwkv_pip.webgpu.model import (
|
||||
RWKV as Model,
|
||||
)
|
||||
else:
|
||||
from rwkv_pip.model import (
|
||||
RWKV as Model,
|
||||
|
||||
Reference in New Issue
Block a user