"""FastAPI routes for an in-memory prompt -> model-state cache.

Prompts are indexed in a ``cyac.Trie``; the integer id the trie assigns to a
prompt keys ``dtrie``, which holds the cached tokens, state tensors, logits and
the device the state originally lived on.
"""

import copy
import gc
from typing import Any, Dict

import torch
from fastapi import APIRouter, HTTPException, Response, status
from pydantic import BaseModel

router = APIRouter()

# cyac.Trie mapping prompt -> integer id; stays None until init() succeeds.
trie = None
# Trie id -> {"tokens", "state", "logits", "device"}. Held in memory only:
# save_state() persists the trie but not these entries.
dtrie: Dict = {}

# On-disk location of the persisted trie (read by init(), written by save_state()).
_TRIE_PATH = "state_cache.trie"


def init():
    """Create the global prompt trie, loading it from disk when possible.

    If ``state_cache.trie`` exists and is non-empty, it is memory-mapped and
    loaded with ``cyac.Trie.from_buff(..., copy=False)``; otherwise an empty
    trie is created. If ``cyac`` is not installed, a message is printed and
    ``trie`` stays ``None``, so every route answers 400 "trie not loaded".
    """
    global trie
    try:
        import mmap
        import os

        import cyac
    except ModuleNotFoundError:
        print("cyac not found")
        return

    # mmap refuses zero-length files (ValueError), so only map a non-empty one.
    if os.path.exists(_TRIE_PATH) and os.path.getsize(_TRIE_PATH) > 0:
        # Binary mode: the file holds raw trie bytes and is never decoded as text.
        with open(_TRIE_PATH, "rb") as bf:
            buff_object = mmap.mmap(bf.fileno(), 0, access=mmap.ACCESS_READ)
            # NOTE(review): copy=False keeps the trie backed by a read-only
            # mapping; confirm cyac allows insert() on a trie loaded this way.
            trie = cyac.Trie.from_buff(buff_object, copy=False)
    else:
        trie = cyac.Trie()


class AddStateBody(BaseModel):
    prompt: str
    tokens: list[str]
    # Presumably a list of torch tensors (indexed and iterated below) -- TODO confirm.
    state: Any
    logits: Any


@router.post("/add-state")
def add_state(body: AddStateBody):
    """Cache ``body``'s tokens, state and logits under ``body.prompt``.

    State tensors are moved to CPU (or deep-copied when already on CPU). The
    original device is recorded so longest_prefix_state() can move them back.
    Tokens and logits are deep-copied so later mutation by the caller cannot
    alter the cache.

    Raises:
        HTTPException: 400 if the trie is not loaded.
    """
    if trie is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")

    device = body.state[0].device
    entry = {
        "tokens": copy.deepcopy(body.tokens),
        "state": [tensor.cpu() for tensor in body.state]
        if device != torch.device("cpu")
        else copy.deepcopy(body.state),
        "logits": copy.deepcopy(body.logits),
        "device": device,
    }
    # Insert into the trie only once the entry is fully built, so a failure
    # above cannot leave a trie id with no cached state behind it.
    key = trie.insert(body.prompt)
    dtrie[key] = entry
    return "success"


@router.post("/reset-state")
def reset_state():
    """Drop every cached prompt and state, replacing the trie with an empty one.

    Raises:
        HTTPException: 400 if the trie is not loaded.
    """
    global trie
    if trie is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
    # cyac is imported only inside init(); a non-None trie means that import succeeded.
    import cyac

    trie = cyac.Trie()
    # A fresh trie may hand out ids that are still dtrie keys, so stale entries
    # must go too; clearing them also releases the cached tensors for gc below.
    dtrie.clear()
    gc.collect()
    return "success"


class LongestPrefixStateBody(BaseModel):
    prompt: str


@router.post("/longest-prefix-state")
def longest_prefix_state(body: LongestPrefixStateBody):
    """Return the cached state for the longest cached prefix of ``body.prompt``.

    Among the prefixes reported by ``trie.prefix`` that have a cached entry,
    the last one yielded is used (presumably the longest -- assumes cyac yields
    matches in increasing length order). State tensors are moved back to their
    recorded device. Everything returned is a copy, so callers may mutate it
    without corrupting the cache.

    Returns:
        A dict with "prompt", "tokens", "state", "logits" and "device"; when no
        prefix is cached, ``{"prompt": "", "tokens": [], "state": None,
        "logits": None}``.

    Raises:
        HTTPException: 400 if the trie is not loaded.
    """
    if trie is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")

    match_id = None
    for prefix_id, _prefix_len in trie.prefix(body.prompt):
        # A trie reloaded from disk has ids with no cached state (dtrie is
        # never persisted); skip them instead of failing with KeyError.
        if prefix_id in dtrie:
            match_id = prefix_id

    if match_id is None:
        return {"prompt": "", "tokens": [], "state": None, "logits": None}

    v = dtrie[match_id]
    device = v["device"]
    if device != torch.device("cpu"):
        state = [tensor.to(device) for tensor in v["state"]]
    else:
        state = copy.deepcopy(v["state"])
    return {
        "prompt": trie[match_id],
        "tokens": copy.deepcopy(v["tokens"]),
        "state": state,
        "logits": copy.deepcopy(v["logits"]),
        "device": device,
    }


@router.post("/save-state")
def save_state():
    """Persist the prompt trie to ``state_cache.trie``.

    Only the trie is written; the tokens/states/logits in ``dtrie`` are not,
    so prompts reloaded by init() have no cached state until re-added.

    Raises:
        HTTPException: 400 if the trie is not loaded.
    """
    if trie is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
    # NOTE(review): init() may still hold a read-only mmap of this same file;
    # confirm overwriting it while mapped is safe on all target platforms.
    trie.save(_TRIE_PATH)
    return "success"