avoid misoperations of state_cache

This commit is contained in:
josc146 2023-06-12 12:32:50 +08:00
parent fa0fcc2c89
commit 5990567a79

View File

@ -16,15 +16,16 @@ def init():
global trie
try:
import cyac
import mmap
import os
if os.path.exists("state_cache.trie"):
with open("state_cache.trie", "r") as bf:
buff_object = mmap.mmap(bf.fileno(), 0, access=mmap.ACCESS_READ)
trie = cyac.Trie.from_buff(buff_object, copy=False)
else:
trie = cyac.Trie()
# import mmap
# import os
#
# if os.path.exists("state_cache.trie"):
# with open("state_cache.trie", "r") as bf:
# buff_object = mmap.mmap(bf.fileno(), 0, access=mmap.ACCESS_READ)
# trie = cyac.Trie.from_buff(buff_object, copy=False)
# else:
trie = cyac.Trie()
except ModuleNotFoundError:
print("cyac not found")
@ -58,11 +59,12 @@ def add_state(body: AddStateBody):
@router.post("/reset-state")
def reset_state():
global trie
global trie, dtrie
if trie is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
trie = cyac.Trie()
dtrie = {}
gc.collect()
return "success"
@ -96,7 +98,13 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
"device": device,
}
else:
return {"prompt": "", "tokens": [], "state": None, "logits": None}
return {
"prompt": "",
"tokens": [],
"state": None,
"logits": None,
"device": None,
}
@router.post("/save-state")
@ -105,6 +113,6 @@ def save_state():
if trie is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
trie.save("state_cache.trie")
# trie.save("state_cache.trie")
return "success"
return "not implemented"