feat: use model state cache to achieve 5x - 50x faster preparation time for generation
This commit is contained in:
@@ -14,7 +14,7 @@ router = APIRouter()
|
||||
def get_tokens_path(model_path: str):
|
||||
model_path = model_path.lower()
|
||||
default_tokens_path = (
|
||||
f"{pathlib.Path(__file__).parent.parent.resolve()}/20B_tokenizer.json"
|
||||
f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/20B_tokenizer.json"
|
||||
)
|
||||
if "raven" in model_path:
|
||||
return default_tokens_path
|
||||
|
||||
98
backend-python/routes/state_cache.py
Normal file
98
backend-python/routes/state_cache.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from typing import Any, Dict
|
||||
from fastapi import APIRouter, HTTPException, Response, status
|
||||
from pydantic import BaseModel
|
||||
import gc
|
||||
import copy
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
trie = None
|
||||
dtrie: Dict = {}
|
||||
|
||||
|
||||
def init():
|
||||
global trie
|
||||
try:
|
||||
import cyac
|
||||
import mmap
|
||||
import os
|
||||
|
||||
if os.path.exists("state_cache.trie"):
|
||||
with open("state_cache.trie", "r") as bf:
|
||||
buff_object = mmap.mmap(bf.fileno(), 0, access=mmap.ACCESS_READ)
|
||||
trie = cyac.Trie.from_buff(buff_object, copy=False)
|
||||
else:
|
||||
trie = cyac.Trie()
|
||||
except ModuleNotFoundError:
|
||||
print("cyac not found")
|
||||
|
||||
|
||||
class AddStateBody(BaseModel):
|
||||
prompt: str
|
||||
tokens: list[str]
|
||||
state: Any
|
||||
logits: Any
|
||||
|
||||
|
||||
@router.post("/add-state")
|
||||
def add_state(body: AddStateBody):
|
||||
global trie, dtrie
|
||||
if trie is None:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
||||
|
||||
id = trie.insert(body.prompt)
|
||||
dtrie[id] = {
|
||||
"tokens": copy.deepcopy(body.tokens),
|
||||
"state": copy.deepcopy(body.state),
|
||||
"logits": copy.deepcopy(body.logits),
|
||||
}
|
||||
|
||||
return "success"
|
||||
|
||||
|
||||
@router.post("/reset-state")
|
||||
def reset_state():
|
||||
global trie
|
||||
if trie is None:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
||||
|
||||
trie = cyac.Trie()
|
||||
gc.collect()
|
||||
|
||||
return "success"
|
||||
|
||||
|
||||
class LongestPrefixStateBody(BaseModel):
|
||||
prompt: str
|
||||
|
||||
|
||||
@router.post("/longest-prefix-state")
|
||||
def longest_prefix_state(body: LongestPrefixStateBody):
|
||||
global trie
|
||||
if trie is None:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
||||
|
||||
id = -1
|
||||
for id, len in trie.prefix(body.prompt):
|
||||
pass
|
||||
if id != -1:
|
||||
v = dtrie[id]
|
||||
return {
|
||||
"prompt": trie[id],
|
||||
"tokens": v["tokens"],
|
||||
"state": v["state"],
|
||||
"logits": v["logits"],
|
||||
}
|
||||
else:
|
||||
return {"prompt": "", "tokens": [], "state": None, "logits": None}
|
||||
|
||||
|
||||
@router.post("/save-state")
|
||||
def save_state():
|
||||
global trie
|
||||
if trie is None:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
||||
|
||||
trie.save("state_cache.trie")
|
||||
|
||||
return "success"
|
||||
Reference in New Issue
Block a user