max_trie_len

This commit is contained in:
josc146 2023-06-12 15:22:17 +08:00
parent 8431b5d24f
commit 5896593951

View File

@ -4,12 +4,16 @@ from fastapi import APIRouter, HTTPException, Request, Response, status
from pydantic import BaseModel from pydantic import BaseModel
import gc import gc
import copy import copy
import sys
import torch import torch
router = APIRouter() router = APIRouter()
trie = None trie = None
dtrie: Dict = {} dtrie: Dict = {}
# Upper bound on len(trie): once it is reached, add_state removes one existing
# prompt before inserting the new one.
max_trie_len = 3000
loop_start_id = 1 # to prevent preloaded prompts from being deleted
# Trie id of the next entry add_state will evict; it is incremented after each
# eviction and reset to loop_start_id when it reaches max_trie_len.
loop_del_trie_id = loop_start_id
def init(): def init():
@ -39,10 +43,18 @@ class AddStateBody(BaseModel):
@router.post("/add-state") @router.post("/add-state")
def add_state(body: AddStateBody): def add_state(body: AddStateBody):
global trie, dtrie global trie, dtrie, loop_del_trie_id
if trie is None: if trie is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded") raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
if len(trie) >= max_trie_len:
del_prompt = trie[loop_del_trie_id]
trie.remove(del_prompt)
dtrie[loop_del_trie_id] = None
loop_del_trie_id = loop_del_trie_id + 1
if loop_del_trie_id >= max_trie_len:
loop_del_trie_id = loop_start_id
id = trie.insert(body.prompt) id = trie.insert(body.prompt)
device = body.state[0].device device = body.state[0].device
dtrie[id] = { dtrie[id] = {
@ -54,6 +66,12 @@ def add_state(body: AddStateBody):
"device": device, "device": device,
} }
quick_log(
None,
None,
f"New Trie Id: {id}\nTrie Len: {len(trie)}\nTrie Buff Size: {trie.buff_size()}\nDtrie Buff Size Of Id: {_get_a_dtrie_buff_size(dtrie[id])}",
)
return "success" return "success"
@ -74,6 +92,23 @@ class LongestPrefixStateBody(BaseModel):
prompt: str prompt: str
def _get_a_dtrie_buff_size(dtrie_v):
# print(sys.getsizeof(dtrie_v["tokens"][0])) # str
# print(sys.getsizeof(dtrie_v["tokens"][0]) * len(dtrie_v["tokens"]))
# print(dtrie_v["state"][0][0].element_size())
# print(dtrie_v["state"][0].nelement())
# print(len(dtrie_v["state"]))
# print(
# len(dtrie_v["state"])
# * dtrie_v["state"][0].nelement()
# * dtrie_v["state"][0][0].element_size()
# )
# print(dtrie_v["logits"][0].element_size())
# print(dtrie_v["logits"].nelement())
# print(dtrie_v["logits"][0].element_size() * dtrie_v["logits"].nelement())
return 54 * len(dtrie_v["tokens"]) + 491520 + 262144 + 28
@router.post("/longest-prefix-state") @router.post("/longest-prefix-state")
def longest_prefix_state(body: LongestPrefixStateBody, request: Request): def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
global trie global trie