diff --git a/backend-python/routes/state_cache.py b/backend-python/routes/state_cache.py index e7ce015..c58cf08 100644 --- a/backend-python/routes/state_cache.py +++ b/backend-python/routes/state_cache.py @@ -4,12 +4,16 @@ from fastapi import APIRouter, HTTPException, Request, Response, status from pydantic import BaseModel import gc import copy +import sys import torch router = APIRouter() trie = None dtrie: Dict = {} +max_trie_len = 3000 +loop_start_id = 1 # to prevent preloaded prompts from being deleted +loop_del_trie_id = loop_start_id def init(): @@ -39,10 +43,18 @@ class AddStateBody(BaseModel): @router.post("/add-state") def add_state(body: AddStateBody): - global trie, dtrie + global trie, dtrie, loop_del_trie_id if trie is None: raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded") + if len(trie) >= max_trie_len: + del_prompt = trie[loop_del_trie_id] + trie.remove(del_prompt) + dtrie[loop_del_trie_id] = None + loop_del_trie_id = loop_del_trie_id + 1 + if loop_del_trie_id >= max_trie_len: + loop_del_trie_id = loop_start_id + id = trie.insert(body.prompt) device = body.state[0].device dtrie[id] = { @@ -54,6 +66,12 @@ def add_state(body: AddStateBody): "device": device, } + quick_log( + None, + None, + f"New Trie Id: {id}\nTrie Len: {len(trie)}\nTrie Buff Size: {trie.buff_size()}\nDtrie Buff Size Of Id: {_get_a_dtrie_buff_size(dtrie[id])}", + ) + return "success" @@ -74,6 +92,23 @@ class LongestPrefixStateBody(BaseModel): prompt: str +def _get_a_dtrie_buff_size(dtrie_v): + # print(sys.getsizeof(dtrie_v["tokens"][0])) # str + # print(sys.getsizeof(dtrie_v["tokens"][0]) * len(dtrie_v["tokens"])) + # print(dtrie_v["state"][0][0].element_size()) + # print(dtrie_v["state"][0].nelement()) + # print(len(dtrie_v["state"])) + # print( + # len(dtrie_v["state"]) + # * dtrie_v["state"][0].nelement() + # * dtrie_v["state"][0][0].element_size() + # ) + # print(dtrie_v["logits"][0].element_size()) + # print(dtrie_v["logits"].nelement()) + # print(dtrie_v["logits"][0].element_size() * dtrie_v["logits"].nelement()) + return 54 * len(dtrie_v["tokens"]) + 491520 + 262144 + 28 + + @router.post("/longest-prefix-state") def longest_prefix_state(body: LongestPrefixStateBody, request: Request): global trie