diff --git a/backend-python/dep_check.py b/backend-python/dep_check.py
index 87dde75..e9a17cf 100644
--- a/backend-python/dep_check.py
+++ b/backend-python/dep_check.py
@@ -1,3 +1,4 @@
+import cyac
 import GPUtil
 import torch
 import rwkv
diff --git a/backend-python/main.py b/backend-python/main.py
index 50d921f..311c042 100644
--- a/backend-python/main.py
+++ b/backend-python/main.py
@@ -11,7 +11,7 @@ import uvicorn
 from utils.rwkv import *
 from utils.torch import *
 from utils.ngrok import *
-from routes import completion, config
+from routes import completion, config, state_cache
 import global_var
 
 app = FastAPI()
@@ -26,11 +26,13 @@ app.add_middleware(
 
 app.include_router(completion.router)
 app.include_router(config.router)
+app.include_router(state_cache.router)
 
 
 @app.on_event("startup")
 def init():
     global_var.init()
+    state_cache.init()
     set_torch()
 
 
diff --git a/backend-python/requirements.txt b/backend-python/requirements.txt
index 30857ab..5ab0abc 100644
Binary files a/backend-python/requirements.txt and b/backend-python/requirements.txt differ
diff --git a/backend-python/requirements_versions.txt b/backend-python/requirements_versions.txt
index 6d94b05..df9e09e 100644
Binary files a/backend-python/requirements_versions.txt and b/backend-python/requirements_versions.txt differ
diff --git a/backend-python/routes/config.py b/backend-python/routes/config.py
index 598c945..10b48d7 100644
--- a/backend-python/routes/config.py
+++ b/backend-python/routes/config.py
@@ -14,7 +14,7 @@ router = APIRouter()
 def get_tokens_path(model_path: str):
     model_path = model_path.lower()
     default_tokens_path = (
-        f"{pathlib.Path(__file__).parent.parent.resolve()}/20B_tokenizer.json"
+        f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/20B_tokenizer.json"
     )
     if "raven" in model_path:
         return default_tokens_path
diff --git a/backend-python/routes/state_cache.py b/backend-python/routes/state_cache.py
new file mode 100644
index 0000000..16e5404
--- /dev/null
+++ b/backend-python/routes/state_cache.py
@@ -0,0 +1,98 @@
+from typing import Any, Dict
+from fastapi import APIRouter, HTTPException, Response, status
+from pydantic import BaseModel
+import gc
+import copy
+
+router = APIRouter()
+
+trie = None
+dtrie: Dict = {}
+
+
+def init():
+    global trie
+    try:
+        import cyac
+        import mmap
+        import os
+
+        if os.path.exists("state_cache.trie"):
+            with open("state_cache.trie", "r") as bf:
+                buff_object = mmap.mmap(bf.fileno(), 0, access=mmap.ACCESS_READ)
+                trie = cyac.Trie.from_buff(buff_object, copy=False)
+        else:
+            trie = cyac.Trie()
+    except ModuleNotFoundError:
+        print("cyac not found")
+
+
+class AddStateBody(BaseModel):
+    prompt: str
+    tokens: list[str]
+    state: Any
+    logits: Any
+
+
+@router.post("/add-state")
+def add_state(body: AddStateBody):
+    global trie, dtrie
+    if trie is None:
+        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
+
+    id = trie.insert(body.prompt)
+    dtrie[id] = {
+        "tokens": copy.deepcopy(body.tokens),
+        "state": copy.deepcopy(body.state),
+        "logits": copy.deepcopy(body.logits),
+    }
+
+    return "success"
+
+
+@router.post("/reset-state")
+def reset_state():
+    global trie
+    if trie is None:
+        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
+
+    trie = cyac.Trie()
+    gc.collect()
+
+    return "success"
+
+
+class LongestPrefixStateBody(BaseModel):
+    prompt: str
+
+
+@router.post("/longest-prefix-state")
+def longest_prefix_state(body: LongestPrefixStateBody):
+    global trie
+    if trie is None:
+        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
+
+    id = -1
+    for id, len in trie.prefix(body.prompt):
+        pass
+    if id != -1:
+        v = dtrie[id]
+        return {
+            "prompt": trie[id],
+            "tokens": v["tokens"],
+            "state": v["state"],
+            "logits": v["logits"],
+        }
+    else:
+        return {"prompt": "", "tokens": [], "state": None, "logits": None}
+
+
+@router.post("/save-state")
+def save_state():
+    global trie
+    if trie is None:
+        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
+
+    trie.save("state_cache.trie")
+
+    return "success"
diff --git a/backend-python/utils/rwkv.py b/backend-python/utils/rwkv.py
index c2c7380..6af3c34 100644
--- a/backend-python/utils/rwkv.py
+++ b/backend-python/utils/rwkv.py
@@ -1,8 +1,11 @@
 import os
 import pathlib
+import copy
 from typing import Dict, List
+from fastapi import HTTPException
 from pydantic import BaseModel
 from rwkv_pip.utils import PIPELINE
+from routes import state_cache
 
 END_OF_TEXT = 0
 
@@ -61,9 +64,37 @@ class RWKV:
         return out
 
     def generate(self, prompt: str, stop: str = None):
-        self.model_state = None
-        self.model_tokens = []
-        logits = self.run_rnn(self.pipeline.encode(prompt))
+        cache = None
+        delta_prompt = prompt
+        try:
+            cache = state_cache.longest_prefix_state(
+                state_cache.LongestPrefixStateBody(prompt=prompt)
+            )
+        except HTTPException:
+            pass
+        if cache is None or cache["prompt"] == "":
+            self.model_state = None
+            self.model_tokens = []
+        else:
+            delta_prompt = prompt[len(cache["prompt"]) :]
+            self.model_state = copy.deepcopy(cache["state"])
+            self.model_tokens = copy.deepcopy(cache["tokens"])
+            logits = copy.deepcopy(cache["logits"])
+
+        if delta_prompt != "":
+            logits = self.run_rnn(self.pipeline.encode(delta_prompt))
+            try:
+                state_cache.add_state(
+                    state_cache.AddStateBody(
+                        prompt=prompt,
+                        tokens=self.model_tokens,
+                        state=self.model_state,
+                        logits=logits,
+                    )
+                )
+            except HTTPException:
+                pass
+
         begin = len(self.model_tokens)
         out_last = begin
 
@@ -94,9 +125,32 @@
                 if stop is not None:
                     if stop in response:
                         response = response.split(stop)[0]
+                        try:
+                            state_cache.add_state(
+                                state_cache.AddStateBody(
+                                    prompt=prompt + response,
+                                    tokens=self.model_tokens,
+                                    state=self.model_state,
+                                    logits=logits,
+                                )
+                            )
+                        except HTTPException:
+                            pass
                         yield response, ""
                         break
                 out_last = begin + i + 1
 
+            if i == self.max_tokens_per_generation - 1:
+                try:
+                    state_cache.add_state(
+                        state_cache.AddStateBody(
+                            prompt=prompt + response,
+                            tokens=self.model_tokens,
+                            state=self.model_state,
+                            logits=logits,
+                        )
+                    )
+                except HTTPException:
+                    pass
             yield response, delta