2023-06-12 12:09:23 +08:00
|
|
|
from typing import Any, Dict, List
|
2023-06-09 20:46:19 +08:00
|
|
|
from utils.log import quick_log
|
|
|
|
from fastapi import APIRouter, HTTPException, Request, Response, status
|
2023-05-28 23:52:38 +08:00
|
|
|
from pydantic import BaseModel
|
|
|
|
import gc
|
|
|
|
import copy
|
2023-06-02 21:33:57 +08:00
|
|
|
import torch
|
2023-05-28 23:52:38 +08:00
|
|
|
|
|
|
|
# APIRouter holding the state-cache endpoints defined below.
router = APIRouter()

# Prefix trie of cached prompts: a cyac.Trie once init() succeeds, else None.
trie = None
# Cache entries keyed by the id that trie.insert() returns (filled by add_state).
dtrie: Dict = dict()
|
|
|
|
|
|
|
|
|
|
|
|
def init():
    """Initialise the module-level ``trie`` used by the state-cache routes.

    ``cyac`` is an optional dependency. When it is not installed, a message is
    printed and ``trie`` keeps its previous value (``None`` at import time).
    Every route in this module then answers HTTP 400 "trie not loaded".
    """
    global trie
    try:
        import cyac

        # NOTE: restoring a previously saved "state_cache.trie" (via mmap and
        # cyac.Trie.from_buff) was sketched here once but is disabled, so each
        # start begins with an empty trie.
        trie = cyac.Trie()
    except ModuleNotFoundError:
        print("cyac not found")
|
|
|
|
|
|
|
|
|
|
|
|
class AddStateBody(BaseModel):
    """Request body for ``/add-state``: a prompt plus the model state to cache for it."""

    # Key under which the state is cached in the prefix trie.
    prompt: str
    # NOTE(review): declared as strings; if callers pass integer token ids,
    # pydantic may coerce them to str. Confirm the intended element type.
    tokens: List[str]
    # Presumably a list of torch tensors: add_state reads ``state[0].device``
    # and calls ``.cpu()`` on each element.
    state: Any
    # Stored via copy.deepcopy; its type is not constrained here.
    logits: Any
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/add-state")
def add_state(body: AddStateBody):
    """Cache a model state under ``body.prompt`` for later prefix lookups.

    The stored entry holds a copy of the tokens, the state tensors (moved to
    the CPU when they live on another device), a copy of the logits, and the
    original device, so ``/longest-prefix-state`` can move the state back.

    Assumes ``body.state`` is a non-empty sequence of tensors. An empty
    sequence raises ``IndexError`` before anything is cached.

    Raises:
        HTTPException: 400 if the trie was never initialised (``cyac`` missing).
    """
    global trie, dtrie
    if trie is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")

    # Build the cache entry BEFORE touching the trie. If anything here raises,
    # the trie must not be left holding a prompt with no dtrie entry;
    # otherwise later prefix lookups hit a KeyError.
    device = body.state[0].device
    entry = {
        "tokens": copy.deepcopy(body.tokens),
        # Tensor.cpu() already copies off a non-CPU device, but returns the
        # same object for CPU tensors, hence the explicit deepcopy there.
        "state": [tensor.cpu() for tensor in body.state]
        if device != torch.device("cpu")
        else copy.deepcopy(body.state),
        "logits": copy.deepcopy(body.logits),
        "device": device,
    }

    state_id = trie.insert(body.prompt)
    dtrie[state_id] = entry

    return "success"
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/reset-state")
def reset_state():
    """Drop every cached state and start over with an empty trie.

    Raises:
        HTTPException: 400 if the trie was never initialised.
    """
    global trie, dtrie
    if trie is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")

    # ``cyac`` is imported only inside init(), so it is not a module-level
    # name; using it bare here raised NameError. A non-None trie means init()
    # already imported cyac successfully, so this import should succeed.
    import cyac

    trie = cyac.Trie()
    dtrie = {}
    # Collect the old entries (which hold state tensors) right away.
    gc.collect()

    return "success"
|
|
|
|
|
|
|
|
|
|
|
|
class LongestPrefixStateBody(BaseModel):
    """Request body for ``/longest-prefix-state``."""

    # Text whose longest cached prefix is looked up in the trie.
    prompt: str
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/longest-prefix-state")
def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
    """Return the cached state for the longest cached prefix of ``body.prompt``.

    On a hit, returns the matched prompt with copies of the cached tokens,
    state and logits. State tensors cached from a non-CPU device are moved
    back to that device. On a miss, returns an empty prompt and token list
    and ``None`` for the remaining fields.

    Raises:
        HTTPException: 400 if the trie was never initialised.
    """
    if trie is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")

    # trie.prefix yields (id, length) pairs for cached prompts that prefix the
    # request. Pick the longest explicitly instead of relying on yield order.
    # Skip ids without a dtrie entry, so a trie/dtrie mismatch degrades to a
    # shorter hit (or a miss) instead of a KeyError.
    hit_id = None
    hit_len = -1
    for prefix_id, prefix_len in trie.prefix(body.prompt):
        if prefix_id in dtrie and prefix_len > hit_len:
            hit_id, hit_len = prefix_id, prefix_len

    if hit_id is None:
        return {
            "prompt": "",
            "tokens": [],
            "state": None,
            "logits": None,
            "device": None,
        }

    entry = dtrie[hit_id]
    device = entry["device"]
    prompt = trie[hit_id]
    quick_log(request, body, "Hit: " + prompt)
    return {
        "prompt": prompt,
        # Hand out copies, mirroring the deepcopy in add_state, so a caller
        # that mutates the returned lists/tensors cannot corrupt the cache.
        "tokens": copy.deepcopy(entry["tokens"]),
        # Cached tensors live on the CPU; .to() a different device already
        # produces new tensors, so only the CPU path needs an explicit copy.
        "state": [tensor.to(device) for tensor in entry["state"]]
        if device != torch.device("cpu")
        else copy.deepcopy(entry["state"]),
        "logits": copy.deepcopy(entry["logits"]),
        "device": device,
    }
|
2023-05-28 23:52:38 +08:00
|
|
|
|
|
|
|
|
|
|
|
@router.post("/save-state")
def save_state():
    """Placeholder for persisting the state cache to disk; currently a no-op.

    Returns the string "not implemented" without writing anything.

    Raises:
        HTTPException: 400 if the trie was never initialised.
    """
    # Only reads the module-level trie, so no ``global`` declaration is needed.
    if trie is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
    # TODO: saving just the trie (e.g. trie.save("state_cache.trie")) was
    # considered. It would presumably not cover the tensors held in dtrie.
    return "not implemented"
|