From 721653a81237737b4d79855c6cfca5fd4976f996 Mon Sep 17 00:00:00 2001 From: josc146 Date: Thu, 15 Jun 2023 22:37:00 +0800 Subject: [PATCH] fix the state cache crash caused by bad prompts --- backend-python/routes/state_cache.py | 57 ++++++++++++++++------------ 1 file changed, 32 insertions(+), 25 deletions(-) diff --git a/backend-python/routes/state_cache.py b/backend-python/routes/state_cache.py index c58cf08..a150419 100644 --- a/backend-python/routes/state_cache.py +++ b/backend-python/routes/state_cache.py @@ -47,32 +47,36 @@ def add_state(body: AddStateBody): if trie is None: raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded") - if len(trie) >= max_trie_len: - del_prompt = trie[loop_del_trie_id] - trie.remove(del_prompt) - dtrie[loop_del_trie_id] = None - loop_del_trie_id = loop_del_trie_id + 1 - if loop_del_trie_id >= max_trie_len: - loop_del_trie_id = loop_start_id + try: + id = trie.insert(body.prompt) + device = body.state[0].device + dtrie[id] = { + "tokens": copy.deepcopy(body.tokens), + "state": [tensor.cpu() for tensor in body.state] + if device != torch.device("cpu") + else copy.deepcopy(body.state), + "logits": copy.deepcopy(body.logits), + "device": device, + } - id = trie.insert(body.prompt) - device = body.state[0].device - dtrie[id] = { - "tokens": copy.deepcopy(body.tokens), - "state": [tensor.cpu() for tensor in body.state] - if device != torch.device("cpu") - else copy.deepcopy(body.state), - "logits": copy.deepcopy(body.logits), - "device": device, - } + if len(trie) >= max_trie_len: + del_prompt = trie[loop_del_trie_id] + trie.remove(del_prompt) + dtrie[loop_del_trie_id] = None + loop_del_trie_id = loop_del_trie_id + 1 + if loop_del_trie_id >= max_trie_len: + loop_del_trie_id = loop_start_id - quick_log( - None, - None, - f"New Trie Id: {id}\nTrie Len: {len(trie)}\nTrie Buff Size: {trie.buff_size()}\nDtrie Buff Size Of Id: {_get_a_dtrie_buff_size(dtrie[id])}", - ) - - return "success" + quick_log( + None, + None, + f"New Trie Id: {id}\nTrie Len: {len(trie)}\nTrie Buff Size: {trie.buff_size()}\nDtrie Buff Size Of Id: {_get_a_dtrie_buff_size(dtrie[id])}", + ) + return "success" + except Exception as e: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, f"insert failed, bad prompt.\n{e}" + ) @router.post("/reset-state") @@ -116,7 +120,10 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request): raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded") id = -1 - for id, len in trie.prefix(body.prompt): + try: + for id, len in trie.prefix(body.prompt): + pass + except: pass if id != -1: v = dtrie[id]