max_trie_len
This commit is contained in:
parent
8431b5d24f
commit
5896593951
@ -4,12 +4,16 @@ from fastapi import APIRouter, HTTPException, Request, Response, status
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import gc
|
import gc
|
||||||
import copy
|
import copy
|
||||||
|
import sys
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
# Prompt trie. Stays None until it has been loaded; add_state rejects
# requests with "trie not loaded" while it is None.
trie = None
# Per-prompt cached data, keyed by the id returned from trie.insert().
dtrie: Dict = {}
# Once len(trie) reaches this bound, add_state evicts one existing entry
# before each new insert.
max_trie_len = 3000
# First id eligible for eviction. The eviction cursor starts here and wraps
# back to here, so lower ids are never selected -- this is to prevent
# preloaded prompts from being deleted.
loop_start_id = 1
# Id of the next entry to evict. add_state advances it by one per eviction
# and resets it to loop_start_id when it reaches max_trie_len.
loop_del_trie_id = loop_start_id
|
||||||
|
|
||||||
|
|
||||||
def init():
|
def init():
|
||||||
@ -39,10 +43,18 @@ class AddStateBody(BaseModel):
|
|||||||
|
|
||||||
@router.post("/add-state")
|
@router.post("/add-state")
|
||||||
def add_state(body: AddStateBody):
|
def add_state(body: AddStateBody):
|
||||||
global trie, dtrie
|
global trie, dtrie, loop_del_trie_id
|
||||||
if trie is None:
|
if trie is None:
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
||||||
|
|
||||||
|
if len(trie) >= max_trie_len:
|
||||||
|
del_prompt = trie[loop_del_trie_id]
|
||||||
|
trie.remove(del_prompt)
|
||||||
|
dtrie[loop_del_trie_id] = None
|
||||||
|
loop_del_trie_id = loop_del_trie_id + 1
|
||||||
|
if loop_del_trie_id >= max_trie_len:
|
||||||
|
loop_del_trie_id = loop_start_id
|
||||||
|
|
||||||
id = trie.insert(body.prompt)
|
id = trie.insert(body.prompt)
|
||||||
device = body.state[0].device
|
device = body.state[0].device
|
||||||
dtrie[id] = {
|
dtrie[id] = {
|
||||||
@ -54,6 +66,12 @@ def add_state(body: AddStateBody):
|
|||||||
"device": device,
|
"device": device,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
quick_log(
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
f"New Trie Id: {id}\nTrie Len: {len(trie)}\nTrie Buff Size: {trie.buff_size()}\nDtrie Buff Size Of Id: {_get_a_dtrie_buff_size(dtrie[id])}",
|
||||||
|
)
|
||||||
|
|
||||||
return "success"
|
return "success"
|
||||||
|
|
||||||
|
|
||||||
@ -74,6 +92,23 @@ class LongestPrefixStateBody(BaseModel):
|
|||||||
prompt: str
|
prompt: str
|
||||||
|
|
||||||
|
|
||||||
|
def _get_a_dtrie_buff_size(dtrie_v):
|
||||||
|
# print(sys.getsizeof(dtrie_v["tokens"][0])) # str
|
||||||
|
# print(sys.getsizeof(dtrie_v["tokens"][0]) * len(dtrie_v["tokens"]))
|
||||||
|
# print(dtrie_v["state"][0][0].element_size())
|
||||||
|
# print(dtrie_v["state"][0].nelement())
|
||||||
|
# print(len(dtrie_v["state"]))
|
||||||
|
# print(
|
||||||
|
# len(dtrie_v["state"])
|
||||||
|
# * dtrie_v["state"][0].nelement()
|
||||||
|
# * dtrie_v["state"][0][0].element_size()
|
||||||
|
# )
|
||||||
|
# print(dtrie_v["logits"][0].element_size())
|
||||||
|
# print(dtrie_v["logits"].nelement())
|
||||||
|
# print(dtrie_v["logits"][0].element_size() * dtrie_v["logits"].nelement())
|
||||||
|
return 54 * len(dtrie_v["tokens"]) + 491520 + 262144 + 28
|
||||||
|
|
||||||
|
|
||||||
@router.post("/longest-prefix-state")
|
@router.post("/longest-prefix-state")
|
||||||
def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
|
def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
|
||||||
global trie
|
global trie
|
||||||
|
Loading…
Reference in New Issue
Block a user