feat: use model state cache to achieve 5x - 50x faster preparation time for generation

josc146 2023-05-28 23:52:38 +08:00
parent 822f2d729c
commit 3e11128c9d
7 changed files with 160 additions and 5 deletions
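In short: RWKV is an RNN, so the model state reached after reading a prompt can be snapshotted and reused. Cached prompts are kept in a trie; a new request looks up the longest cached prefix of its prompt and only the unseen suffix is fed through the model, which is where the faster preparation time comes from. A rough sketch of the idea (the helper names below are illustrative, not code from this commit):

cached_prompt, cached_state, cached_logits = lookup_longest_prefix(prompt)
delta = prompt[len(cached_prompt):]                # part of the prompt not seen before
state, logits = cached_state, cached_logits
if delta != "":
    state, logits = run_rnn(encode(delta), state)  # only the delta is processed
store(prompt, state, logits)                       # cache for the next request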

View File

@@ -1,3 +1,4 @@
import cyac
import GPUtil
import torch
import rwkv

View File

@@ -11,7 +11,7 @@ import uvicorn
from utils.rwkv import *
from utils.torch import *
from utils.ngrok import *
from routes import completion, config
from routes import completion, config, state_cache
import global_var
app = FastAPI()
@@ -26,11 +26,13 @@ app.add_middleware(
app.include_router(completion.router)
app.include_router(config.router)
app.include_router(state_cache.router)

@app.on_event("startup")
def init():
    global_var.init()
    state_cache.init()
    set_torch()

Binary file not shown.

View File

@@ -14,7 +14,7 @@ router = APIRouter()
def get_tokens_path(model_path: str):
    model_path = model_path.lower()
    default_tokens_path = (
        f"{pathlib.Path(__file__).parent.parent.resolve()}/20B_tokenizer.json"
        f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/20B_tokenizer.json"
    )
    if "raven" in model_path:
        return default_tokens_path

View File

@@ -0,0 +1,98 @@
from typing import Any, Dict
from fastapi import APIRouter, HTTPException, Response, status
from pydantic import BaseModel
import gc
import copy

router = APIRouter()

trie = None
dtrie: Dict = {}  # maps trie id -> cached tokens/state/logits for that prompt


def init():
    global trie
    try:
        import cyac
        import mmap
        import os

        if os.path.exists("state_cache.trie"):
            # load a previously saved trie without copying the underlying buffer
            with open("state_cache.trie", "r") as bf:
                buff_object = mmap.mmap(bf.fileno(), 0, access=mmap.ACCESS_READ)
                trie = cyac.Trie.from_buff(buff_object, copy=False)
        else:
            trie = cyac.Trie()
    except ModuleNotFoundError:
        print("cyac not found")


class AddStateBody(BaseModel):
    prompt: str
    tokens: list[int]
    state: Any
    logits: Any


@router.post("/add-state")
def add_state(body: AddStateBody):
    global trie, dtrie
    if trie is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")

    id = trie.insert(body.prompt)
    dtrie[id] = {
        "tokens": copy.deepcopy(body.tokens),
        "state": copy.deepcopy(body.state),
        "logits": copy.deepcopy(body.logits),
    }

    return "success"


@router.post("/reset-state")
def reset_state():
    global trie, dtrie
    if trie is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")

    import cyac

    trie = cyac.Trie()
    dtrie = {}
    gc.collect()

    return "success"


class LongestPrefixStateBody(BaseModel):
    prompt: str


@router.post("/longest-prefix-state")
def longest_prefix_state(body: LongestPrefixStateBody):
    global trie
    if trie is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")

    # trie.prefix() yields every cached prompt that is a prefix of body.prompt;
    # the last id yielded is the longest match
    id = -1
    for id, len in trie.prefix(body.prompt):
        pass
    if id != -1:
        v = dtrie[id]
        return {
            "prompt": trie[id],
            "tokens": v["tokens"],
            "state": v["state"],
            "logits": v["logits"],
        }
    else:
        return {"prompt": "", "tokens": [], "state": None, "logits": None}


@router.post("/save-state")
def save_state():
    global trie
    if trie is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")

    trie.save("state_cache.trie")

    return "success"

View File

@@ -1,8 +1,11 @@
import os
import pathlib
import copy
from typing import Dict, List
from fastapi import HTTPException
from pydantic import BaseModel
from rwkv_pip.utils import PIPELINE
from routes import state_cache
END_OF_TEXT = 0
@@ -61,9 +64,37 @@ class RWKV:
        return out

    def generate(self, prompt: str, stop: str = None):
        cache = None
        delta_prompt = prompt
        try:
            cache = state_cache.longest_prefix_state(
                state_cache.LongestPrefixStateBody(prompt=prompt)
            )
        except HTTPException:
            pass
        if cache is None or cache["prompt"] == "":
            # no cached prefix: run the whole prompt through the model
            self.model_state = None
            self.model_tokens = []
            logits = self.run_rnn(self.pipeline.encode(prompt))
        else:
            # reuse the cached state and only feed the part of the prompt
            # that follows the cached prefix
            delta_prompt = prompt[len(cache["prompt"]) :]
            self.model_state = copy.deepcopy(cache["state"])
            self.model_tokens = copy.deepcopy(cache["tokens"])
            logits = copy.deepcopy(cache["logits"])
            if delta_prompt != "":
                logits = self.run_rnn(self.pipeline.encode(delta_prompt))

        # cache the state reached at the end of the prompt, so the next request
        # sharing this prefix can skip re-processing it
        try:
            state_cache.add_state(
                state_cache.AddStateBody(
                    prompt=prompt,
                    tokens=self.model_tokens,
                    state=self.model_state,
                    logits=logits,
                )
            )
        except HTTPException:
            pass

        begin = len(self.model_tokens)
        out_last = begin
@@ -94,9 +125,32 @@ class RWKV:
            if stop is not None:
                if stop in response:
                    response = response.split(stop)[0]
                    # cache the state at the end of the completed response as well
                    try:
                        state_cache.add_state(
                            state_cache.AddStateBody(
                                prompt=prompt + response,
                                tokens=self.model_tokens,
                                state=self.model_state,
                                logits=logits,
                            )
                        )
                    except HTTPException:
                        pass
                    yield response, ""
                    break
            out_last = begin + i + 1
            if i == self.max_tokens_per_generation - 1:
                # generation hit the token limit without a stop string; still cache
                # the state so a follow-up request can continue from here
                try:
                    state_cache.add_state(
                        state_cache.AddStateBody(
                            prompt=prompt + response,
                            tokens=self.model_tokens,
                            state=self.model_state,
                            logits=logits,
                        )
                    )
                except HTTPException:
                    pass
            yield response, delta
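The practical effect on a chat-style workload: every request resends the whole conversation as the prompt, but only the newly appended part is actually run through the RNN. An illustrative example (the strings are made up):

cached = "Bob: hello\nAlice: hi, how can I help you?\n"   # processed and cached on the previous turn
prompt = cached + "Bob: what is RWKV?\nAlice:"            # new request repeats the history
delta = prompt[len(cached):]                              # only the new turn
# generate() encodes and runs just delta through run_rnn(), then re-caches the
# state for the full prompt, so the next turn can again resume from here.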