fix the state cache crash caused by bad prompts

This commit is contained in:
josc146 2023-06-15 22:37:00 +08:00
parent d99488f22f
commit 721653a812

View File

@ -47,32 +47,36 @@ def add_state(body: AddStateBody):
if trie is None: if trie is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded") raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
if len(trie) >= max_trie_len: try:
del_prompt = trie[loop_del_trie_id] id = trie.insert(body.prompt)
trie.remove(del_prompt) device = body.state[0].device
dtrie[loop_del_trie_id] = None dtrie[id] = {
loop_del_trie_id = loop_del_trie_id + 1 "tokens": copy.deepcopy(body.tokens),
if loop_del_trie_id >= max_trie_len: "state": [tensor.cpu() for tensor in body.state]
loop_del_trie_id = loop_start_id if device != torch.device("cpu")
else copy.deepcopy(body.state),
"logits": copy.deepcopy(body.logits),
"device": device,
}
id = trie.insert(body.prompt) if len(trie) >= max_trie_len:
device = body.state[0].device del_prompt = trie[loop_del_trie_id]
dtrie[id] = { trie.remove(del_prompt)
"tokens": copy.deepcopy(body.tokens), dtrie[loop_del_trie_id] = None
"state": [tensor.cpu() for tensor in body.state] loop_del_trie_id = loop_del_trie_id + 1
if device != torch.device("cpu") if loop_del_trie_id >= max_trie_len:
else copy.deepcopy(body.state), loop_del_trie_id = loop_start_id
"logits": copy.deepcopy(body.logits),
"device": device,
}
quick_log( quick_log(
None, None,
None, None,
f"New Trie Id: {id}\nTrie Len: {len(trie)}\nTrie Buff Size: {trie.buff_size()}\nDtrie Buff Size Of Id: {_get_a_dtrie_buff_size(dtrie[id])}", f"New Trie Id: {id}\nTrie Len: {len(trie)}\nTrie Buff Size: {trie.buff_size()}\nDtrie Buff Size Of Id: {_get_a_dtrie_buff_size(dtrie[id])}",
) )
return "success"
return "success" except Exception as e:
raise HTTPException(
status.HTTP_400_BAD_REQUEST, f"insert failed, bad prompt.\n{e}"
)
@router.post("/reset-state") @router.post("/reset-state")
@ -116,7 +120,10 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded") raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
id = -1 id = -1
for id, len in trie.prefix(body.prompt): try:
for id, len in trie.prefix(body.prompt):
pass
except:
pass pass
if id != -1: if id != -1:
v = dtrie[id] v = dtrie[id]