avoid misoperations of state_cache
This commit is contained in:
parent
fa0fcc2c89
commit
5990567a79
@ -16,14 +16,15 @@ def init():
|
|||||||
global trie
|
global trie
|
||||||
try:
|
try:
|
||||||
import cyac
|
import cyac
|
||||||
import mmap
|
|
||||||
import os
|
|
||||||
|
|
||||||
if os.path.exists("state_cache.trie"):
|
# import mmap
|
||||||
with open("state_cache.trie", "r") as bf:
|
# import os
|
||||||
buff_object = mmap.mmap(bf.fileno(), 0, access=mmap.ACCESS_READ)
|
#
|
||||||
trie = cyac.Trie.from_buff(buff_object, copy=False)
|
# if os.path.exists("state_cache.trie"):
|
||||||
else:
|
# with open("state_cache.trie", "r") as bf:
|
||||||
|
# buff_object = mmap.mmap(bf.fileno(), 0, access=mmap.ACCESS_READ)
|
||||||
|
# trie = cyac.Trie.from_buff(buff_object, copy=False)
|
||||||
|
# else:
|
||||||
trie = cyac.Trie()
|
trie = cyac.Trie()
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
print("cyac not found")
|
print("cyac not found")
|
||||||
@ -58,11 +59,12 @@ def add_state(body: AddStateBody):
|
|||||||
|
|
||||||
@router.post("/reset-state")
|
@router.post("/reset-state")
|
||||||
def reset_state():
|
def reset_state():
|
||||||
global trie
|
global trie, dtrie
|
||||||
if trie is None:
|
if trie is None:
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
||||||
|
|
||||||
trie = cyac.Trie()
|
trie = cyac.Trie()
|
||||||
|
dtrie = {}
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
return "success"
|
return "success"
|
||||||
@ -96,7 +98,13 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
|
|||||||
"device": device,
|
"device": device,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
return {"prompt": "", "tokens": [], "state": None, "logits": None}
|
return {
|
||||||
|
"prompt": "",
|
||||||
|
"tokens": [],
|
||||||
|
"state": None,
|
||||||
|
"logits": None,
|
||||||
|
"device": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/save-state")
|
@router.post("/save-state")
|
||||||
@ -105,6 +113,6 @@ def save_state():
|
|||||||
if trie is None:
|
if trie is None:
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
||||||
|
|
||||||
trie.save("state_cache.trie")
|
# trie.save("state_cache.trie")
|
||||||
|
|
||||||
return "success"
|
return "not implemented"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user