feat: use model state cache to achieve 5x - 50x faster preparation time for generation

josc146 committed 2023-05-28 23:52:38 +08:00
parent 822f2d729c
commit 3e11128c9d
7 changed files with 160 additions and 5 deletions
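
The speedup comes from caching the model's RNN state keyed by prompt: when a new request's prompt extends a previously processed prompt, the server restores the cached state, tokens, and logits, and only runs the model over the unseen suffix. A minimal sketch of the idea (a standalone illustration, not code from this commit):

    # If a new prompt extends a cached one, only the "delta" suffix has to be
    # re-encoded; the cached RNN state already covers the shared prefix.
    cached_prompt = "User: hello\nAssistant:"
    new_prompt = "User: hello\nAssistant: Hi!\nUser: how are you?"
    delta_prompt = new_prompt[len(cached_prompt):]  # the only text left to process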

View File

@@ -1,3 +1,4 @@
+import cyac
 import GPUtil
 import torch
 import rwkv

View File

@@ -11,7 +11,7 @@ import uvicorn
 from utils.rwkv import *
 from utils.torch import *
 from utils.ngrok import *
-from routes import completion, config
+from routes import completion, config, state_cache
 import global_var

 app = FastAPI()
@@ -26,11 +26,13 @@ app.add_middleware(
 app.include_router(completion.router)
 app.include_router(config.router)
+app.include_router(state_cache.router)


 @app.on_event("startup")
 def init():
     global_var.init()
+    state_cache.init()
     set_torch()
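
The state_cache routes are registered without a prefix, so they are also reachable over HTTP. A hedged probe from a client (the address is an assumption; note that cache hits carry torch tensors, so only the cold-cache response shown here is safely JSON-serializable):

    import requests

    # Assumed local server address; adjust to the actual deployment.
    r = requests.post(
        "http://127.0.0.1:8000/longest-prefix-state",
        json={"prompt": "Hello"},
    )
    print(r.json())
    # Cold cache -> {"prompt": "", "tokens": [], "state": None, "logits": None}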

Binary file not shown.

View File

@@ -14,7 +14,7 @@ router = APIRouter()
 def get_tokens_path(model_path: str):
     model_path = model_path.lower()
     default_tokens_path = (
-        f"{pathlib.Path(__file__).parent.parent.resolve()}/20B_tokenizer.json"
+        f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/20B_tokenizer.json"
     )
     if "raven" in model_path:
         return default_tokens_path

View File

@@ -0,0 +1,98 @@
+from typing import Any, Dict
+from fastapi import APIRouter, HTTPException, Response, status
+from pydantic import BaseModel
+import gc
+import copy
+
+router = APIRouter()
+
+trie = None
+dtrie: Dict = {}
+
+
+def init():
+    global trie
+    try:
+        import cyac
+        import mmap
+        import os
+
+        if os.path.exists("state_cache.trie"):
+            with open("state_cache.trie", "r") as bf:
+                buff_object = mmap.mmap(bf.fileno(), 0, access=mmap.ACCESS_READ)
+                trie = cyac.Trie.from_buff(buff_object, copy=False)
+        else:
+            trie = cyac.Trie()
+    except ModuleNotFoundError:
+        print("cyac not found")
+
+
+class AddStateBody(BaseModel):
+    prompt: str
+    tokens: list[str]
+    state: Any
+    logits: Any
+
+
+@router.post("/add-state")
+def add_state(body: AddStateBody):
+    global trie, dtrie
+    if trie is None:
+        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
+
+    id = trie.insert(body.prompt)
+    dtrie[id] = {
+        "tokens": copy.deepcopy(body.tokens),
+        "state": copy.deepcopy(body.state),
+        "logits": copy.deepcopy(body.logits),
+    }
+
+    return "success"
+
+
+@router.post("/reset-state")
+def reset_state():
+    global trie
+    if trie is None:
+        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
+
+    trie = cyac.Trie()
+    gc.collect()
+
+    return "success"
+
+
+class LongestPrefixStateBody(BaseModel):
+    prompt: str
+
+
+@router.post("/longest-prefix-state")
+def longest_prefix_state(body: LongestPrefixStateBody):
+    global trie
+    if trie is None:
+        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
+
+    id = -1
+    for id, len in trie.prefix(body.prompt):
+        pass
+    if id != -1:
+        v = dtrie[id]
+        return {
+            "prompt": trie[id],
+            "tokens": v["tokens"],
+            "state": v["state"],
+            "logits": v["logits"],
+        }
+    else:
+        return {"prompt": "", "tokens": [], "state": None, "logits": None}
+
+
+@router.post("/save-state")
+def save_state():
+    global trie
+    if trie is None:
+        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
+
+    trie.save("state_cache.trie")
+
+    return "success"

View File

@@ -1,8 +1,11 @@
 import os
 import pathlib
+import copy
 from typing import Dict, List
+from fastapi import HTTPException
 from pydantic import BaseModel
 from rwkv_pip.utils import PIPELINE
+from routes import state_cache


 END_OF_TEXT = 0
@@ -61,9 +64,37 @@ class RWKV:
         return out

     def generate(self, prompt: str, stop: str = None):
-        self.model_state = None
-        self.model_tokens = []
-        logits = self.run_rnn(self.pipeline.encode(prompt))
+        cache = None
+        delta_prompt = prompt
+        try:
+            cache = state_cache.longest_prefix_state(
+                state_cache.LongestPrefixStateBody(prompt=prompt)
+            )
+        except HTTPException:
+            pass
+        if cache is None or cache["prompt"] == "":
+            self.model_state = None
+            self.model_tokens = []
+        else:
+            delta_prompt = prompt[len(cache["prompt"]) :]
+            self.model_state = copy.deepcopy(cache["state"])
+            self.model_tokens = copy.deepcopy(cache["tokens"])
+            logits = copy.deepcopy(cache["logits"])
+
+        if delta_prompt != "":
+            logits = self.run_rnn(self.pipeline.encode(delta_prompt))
+            try:
+                state_cache.add_state(
+                    state_cache.AddStateBody(
+                        prompt=prompt,
+                        tokens=self.model_tokens,
+                        state=self.model_state,
+                        logits=logits,
+                    )
+                )
+            except HTTPException:
+                pass

         begin = len(self.model_tokens)
         out_last = begin
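
The copy.deepcopy calls are presumably defensive: generation mutates self.model_state in place, so storing or restoring the live object would let later tokens corrupt the cache entry. A toy illustration with lists standing in for the real tensors:

    import copy

    live_state = [[0.0, 0.0]]           # stand-in for the RNN state tensors
    cached = copy.deepcopy(live_state)  # snapshot decoupled from the live state
    live_state[0][0] = 1.0              # generation mutates the live state
    print(cached)                       # [[0.0, 0.0]] -- cache entry intact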
@@ -94,9 +125,32 @@ class RWKV:
             if stop is not None:
                 if stop in response:
                     response = response.split(stop)[0]
+                    try:
+                        state_cache.add_state(
+                            state_cache.AddStateBody(
+                                prompt=prompt + response,
+                                tokens=self.model_tokens,
+                                state=self.model_state,
+                                logits=logits,
+                            )
+                        )
+                    except HTTPException:
+                        pass
                     yield response, ""
                     break
             out_last = begin + i + 1
+            if i == self.max_tokens_per_generation - 1:
+                try:
+                    state_cache.add_state(
+                        state_cache.AddStateBody(
+                            prompt=prompt + response,
+                            tokens=self.model_tokens,
+                            state=self.model_state,
+                            logits=logits,
+                        )
+                    )
+                except HTTPException:
+                    pass
+
             yield response, delta
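
With these hooks, generate() consults the cache before encoding and re-registers the state whenever it stops (on the stop string or at the token limit), so a follow-up prompt that extends prompt + response gets a cache hit. A hedged in-process illustration of the same cache API (toy values; real entries hold token lists, torch states, and logits):

    from routes import state_cache

    state_cache.init()  # needs the optional cyac dependency
    state_cache.add_state(state_cache.AddStateBody(
        prompt="Hello", tokens=["Hello"], state=[0.1], logits=[0.2],
    ))
    hit = state_cache.longest_prefix_state(
        state_cache.LongestPrefixStateBody(prompt="Hello, world")
    )
    print(hit["prompt"])  # -> "Hello"; only ", world" would still be encoded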