max_trie_len

This commit is contained in:
josc146 2023-06-12 15:22:17 +08:00
parent 8431b5d24f
commit 5896593951

View File

@ -4,12 +4,16 @@ from fastapi import APIRouter, HTTPException, Request, Response, status
from pydantic import BaseModel
import gc
import copy
import sys
import torch
router = APIRouter()
trie = None
dtrie: Dict = {}
max_trie_len = 3000
loop_start_id = 1 # to prevent preloaded prompts from being deleted
loop_del_trie_id = loop_start_id
def init():
@ -39,10 +43,18 @@ class AddStateBody(BaseModel):
@router.post("/add-state")
def add_state(body: AddStateBody):
global trie, dtrie
global trie, dtrie, loop_del_trie_id
if trie is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
if len(trie) >= max_trie_len:
del_prompt = trie[loop_del_trie_id]
trie.remove(del_prompt)
dtrie[loop_del_trie_id] = None
loop_del_trie_id = loop_del_trie_id + 1
if loop_del_trie_id >= max_trie_len:
loop_del_trie_id = loop_start_id
id = trie.insert(body.prompt)
device = body.state[0].device
dtrie[id] = {
@ -54,6 +66,12 @@ def add_state(body: AddStateBody):
"device": device,
}
quick_log(
None,
None,
f"New Trie Id: {id}\nTrie Len: {len(trie)}\nTrie Buff Size: {trie.buff_size()}\nDtrie Buff Size Of Id: {_get_a_dtrie_buff_size(dtrie[id])}",
)
return "success"
@ -74,6 +92,23 @@ class LongestPrefixStateBody(BaseModel):
prompt: str
def _get_a_dtrie_buff_size(dtrie_v):
# print(sys.getsizeof(dtrie_v["tokens"][0])) # str
# print(sys.getsizeof(dtrie_v["tokens"][0]) * len(dtrie_v["tokens"]))
# print(dtrie_v["state"][0][0].element_size())
# print(dtrie_v["state"][0].nelement())
# print(len(dtrie_v["state"]))
# print(
# len(dtrie_v["state"])
# * dtrie_v["state"][0].nelement()
# * dtrie_v["state"][0][0].element_size()
# )
# print(dtrie_v["logits"][0].element_size())
# print(dtrie_v["logits"].nelement())
# print(dtrie_v["logits"][0].element_size() * dtrie_v["logits"].nelement())
return 54 * len(dtrie_v["tokens"]) + 491520 + 262144 + 28
@router.post("/longest-prefix-state")
def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
global trie