2023-05-28 23:52:38 +08:00
|
|
|
from typing import Any, Dict
|
|
|
|
from fastapi import APIRouter, HTTPException, Response, status
|
|
|
|
from pydantic import BaseModel
|
|
|
|
import gc
|
|
|
|
import copy
|
2023-06-02 21:33:57 +08:00
|
|
|
import torch
|
2023-05-28 23:52:38 +08:00
|
|
|
|
|
|
|
# Router that collects the state-cache endpoints defined in this module.
router = APIRouter()

# Prefix trie over the cached prompts. Stays None until init() builds or
# loads it; every endpoint below answers HTTP 400 while it is None.
trie = None

# Maps the integer id returned by trie.insert() to the record cached for
# that prompt (keys: "tokens", "state", "logits", "device").
dtrie: Dict = {}
|
|
|
|
|
|
|
|
|
|
|
|
def init():
    """Create the module-level prompt trie.

    When ``state_cache.trie`` exists in the working directory it is
    memory-mapped read-only and the trie is built on top of that buffer
    without copying; otherwise an empty trie is created. If a required module
    cannot be imported, the message "cyac not found" is printed and ``trie``
    is left untouched.
    """
    global trie

    cache_path = "state_cache.trie"
    try:
        import cyac
        import mmap
        import os

        if not os.path.exists(cache_path):
            trie = cyac.Trie()
            return

        with open(cache_path, "r") as cache_file:
            mapped = mmap.mmap(cache_file.fileno(), 0, access=mmap.ACCESS_READ)
        # copy=False: the trie reads straight from the mapping, so the mmap
        # object has to stay alive for as long as the trie is in use.
        trie = cyac.Trie.from_buff(mapped, copy=False)
    except ModuleNotFoundError:
        print("cyac not found")
|
|
|
|
|
|
|
|
|
|
|
|
class AddStateBody(BaseModel):
    # Payload accepted by the /add-state endpoint.

    # Prompt text; inserted into the trie as the lookup key.
    prompt: str
    # Tokens stored alongside the prompt.
    tokens: list[str]
    # Indexable collection of tensor-like objects: add_state reads
    # state[0].device and calls .cpu() on each element.
    state: Any
    # Stored as a deep copy; no structure is assumed by this module.
    logits: Any
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/add-state")
def add_state(body: AddStateBody):
    """Cache the tokens, state and logits that belong to ``body.prompt``.

    The prompt is inserted into the trie and the id it returns is used as the
    key of the stored record. The device of the first state tensor is
    remembered so the state can be moved back to it on lookup.

    Raises:
        HTTPException: 400 when the trie has not been initialised.
    """
    if trie is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")

    entry_id = trie.insert(body.prompt)
    device = body.state[0].device

    cached_tokens = copy.deepcopy(body.tokens)
    if device != torch.device("cpu"):
        # Keep the cached copy on the CPU whatever device the state came from.
        cached_state = [tensor.cpu() for tensor in body.state]
    else:
        cached_state = copy.deepcopy(body.state)
    cached_logits = copy.deepcopy(body.logits)

    dtrie[entry_id] = {
        "tokens": cached_tokens,
        "state": cached_state,
        "logits": cached_logits,
        "device": device,
    }

    return "success"
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/reset-state")
def reset_state():
    """Drop every cached prompt and its stored record.

    Replaces the trie with an empty one and empties ``dtrie`` so the cached
    tensors can actually be reclaimed by the garbage collector.

    Raises:
        HTTPException: 400 when the trie has not been initialised.
    """
    global trie
    if trie is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")

    # cyac is only imported inside init(), so it is not a module-level name;
    # import it here as well. A non-None trie means the import already
    # succeeded once, so this cannot fail with ModuleNotFoundError.
    import cyac

    trie = cyac.Trie()
    # The new trie hands out ids from scratch; records kept under old ids
    # would both leak memory and be matched against unrelated prompts.
    dtrie.clear()
    gc.collect()

    return "success"
|
|
|
|
|
|
|
|
|
|
|
|
class LongestPrefixStateBody(BaseModel):
    # Payload accepted by the /longest-prefix-state endpoint.

    # Prompt whose cached prefixes are looked up in the trie.
    prompt: str
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/longest-prefix-state")
def longest_prefix_state(body: LongestPrefixStateBody):
    """Return the cached record for the longest cached prefix of the prompt.

    Only prefixes that have a record in ``dtrie`` are considered. A trie
    loaded from ``state_cache.trie`` can contain prompts whose records are
    gone, because ``dtrie`` itself is never written to disk.

    Returns:
        A dict with "prompt", "tokens", "state", "logits" and "device" on a
        hit; ``{"prompt": "", "tokens": [], "state": None, "logits": None}``
        when no usable prefix exists.

    Raises:
        HTTPException: 400 when the trie has not been initialised.
    """
    if trie is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")

    # Keep the last matching id that has a stored record. This relies on
    # trie.prefix() yielding shorter prefixes before longer ones, which the
    # original "take the last match" loop also relied on.
    match_id = -1
    for candidate_id, _length in trie.prefix(body.prompt):
        if candidate_id in dtrie:
            match_id = candidate_id

    if match_id == -1:
        return {"prompt": "", "tokens": [], "state": None, "logits": None}

    record = dtrie[match_id]
    device = record["device"]
    if device != torch.device("cpu"):
        # The cache holds CPU tensors; move them back to the original device.
        state = [tensor.to(device) for tensor in record["state"]]
    else:
        # Hand out a copy so a caller that mutates the returned state in
        # place cannot alter the cached record.
        state = copy.deepcopy(record["state"])

    return {
        "prompt": trie[match_id],
        "tokens": copy.deepcopy(record["tokens"]),
        "state": state,
        "logits": copy.deepcopy(record["logits"]),
        "device": device,
    }
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/save-state")
def save_state():
    """Write the current trie to ``state_cache.trie``.

    Only the trie is saved; the records held in ``dtrie`` are not written.

    Raises:
        HTTPException: 400 when the trie has not been initialised.
    """
    if trie is not None:
        trie.save("state_cache.trie")
        return "success"

    raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|