This commit is contained in:
josc146
2023-12-14 18:37:07 +08:00
parent 01c95f5bc4
commit 0ddd2e9fea
16 changed files with 155 additions and 34 deletions

View File

@@ -8,7 +8,6 @@ from typing import Dict, Iterable, List, Tuple, Union, Type
from utils.log import quick_log
from fastapi import HTTPException
from pydantic import BaseModel, Field
import numpy as np
from routes import state_cache
import global_var
@@ -68,6 +67,8 @@ class AbstractRWKV(ABC):
pass
def get_embedding(self, input: str, fast_mode: bool) -> Tuple[List[float], int]:
import numpy as np
if fast_mode:
embedding, token_len = self.__fast_embedding(
self.fix_tokens(self.pipeline.encode(input)), None
@@ -222,6 +223,8 @@ class AbstractRWKV(ABC):
def generate(
self, prompt: str, stop: Union[str, List[str], None] = None
) -> Iterable[Tuple[str, str, int, int]]:
import numpy as np
quick_log(None, None, "Generation Prompt:\n" + prompt)
cache = None
delta_prompt = prompt
@@ -231,7 +234,7 @@ class AbstractRWKV(ABC):
)
except HTTPException:
pass
if cache is None or cache["prompt"] == "":
if cache is None or cache["prompt"] == "" or cache["state"] is None:
self.model_state = None
self.model_tokens = []
else:
@@ -511,6 +514,7 @@ def get_tokenizer(tokenizer_len: int):
def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV:
rwkv_beta = global_var.get(global_var.Args).rwkv_beta
rwkv_cpp = getattr(global_var.get(global_var.Args), "rwkv.cpp")
webgpu = global_var.get(global_var.Args).webgpu
if "midi" in model.lower() or "abc" in model.lower():
os.environ["RWKV_RESCALE_LAYER"] = "999"
@@ -526,6 +530,11 @@ def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV
from rwkv_pip.cpp.model import (
RWKV as Model,
)
elif webgpu:
print("Using webgpu")
from rwkv_pip.webgpu.model import (
RWKV as Model,
)
else:
from rwkv_pip.model import (
RWKV as Model,