fix the state cache crash caused by bad prompts
This commit is contained in:
parent
d99488f22f
commit
721653a812
@ -47,32 +47,36 @@ def add_state(body: AddStateBody):
|
|||||||
if trie is None:
|
if trie is None:
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
||||||
|
|
||||||
if len(trie) >= max_trie_len:
|
try:
|
||||||
del_prompt = trie[loop_del_trie_id]
|
id = trie.insert(body.prompt)
|
||||||
trie.remove(del_prompt)
|
device = body.state[0].device
|
||||||
dtrie[loop_del_trie_id] = None
|
dtrie[id] = {
|
||||||
loop_del_trie_id = loop_del_trie_id + 1
|
"tokens": copy.deepcopy(body.tokens),
|
||||||
if loop_del_trie_id >= max_trie_len:
|
"state": [tensor.cpu() for tensor in body.state]
|
||||||
loop_del_trie_id = loop_start_id
|
if device != torch.device("cpu")
|
||||||
|
else copy.deepcopy(body.state),
|
||||||
|
"logits": copy.deepcopy(body.logits),
|
||||||
|
"device": device,
|
||||||
|
}
|
||||||
|
|
||||||
id = trie.insert(body.prompt)
|
if len(trie) >= max_trie_len:
|
||||||
device = body.state[0].device
|
del_prompt = trie[loop_del_trie_id]
|
||||||
dtrie[id] = {
|
trie.remove(del_prompt)
|
||||||
"tokens": copy.deepcopy(body.tokens),
|
dtrie[loop_del_trie_id] = None
|
||||||
"state": [tensor.cpu() for tensor in body.state]
|
loop_del_trie_id = loop_del_trie_id + 1
|
||||||
if device != torch.device("cpu")
|
if loop_del_trie_id >= max_trie_len:
|
||||||
else copy.deepcopy(body.state),
|
loop_del_trie_id = loop_start_id
|
||||||
"logits": copy.deepcopy(body.logits),
|
|
||||||
"device": device,
|
|
||||||
}
|
|
||||||
|
|
||||||
quick_log(
|
quick_log(
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
f"New Trie Id: {id}\nTrie Len: {len(trie)}\nTrie Buff Size: {trie.buff_size()}\nDtrie Buff Size Of Id: {_get_a_dtrie_buff_size(dtrie[id])}",
|
f"New Trie Id: {id}\nTrie Len: {len(trie)}\nTrie Buff Size: {trie.buff_size()}\nDtrie Buff Size Of Id: {_get_a_dtrie_buff_size(dtrie[id])}",
|
||||||
)
|
)
|
||||||
|
return "success"
|
||||||
return "success"
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status.HTTP_400_BAD_REQUEST, f"insert failed, bad prompt.\n{e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/reset-state")
|
@router.post("/reset-state")
|
||||||
@ -116,7 +120,10 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
|
|||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
||||||
|
|
||||||
id = -1
|
id = -1
|
||||||
for id, len in trie.prefix(body.prompt):
|
try:
|
||||||
|
for id, len in trie.prefix(body.prompt):
|
||||||
|
pass
|
||||||
|
except:
|
||||||
pass
|
pass
|
||||||
if id != -1:
|
if id != -1:
|
||||||
v = dtrie[id]
|
v = dtrie[id]
|
||||||
|
Loading…
Reference in New Issue
Block a user