better state cache
This commit is contained in:
@@ -44,6 +44,7 @@ def disable_state_cache():
|
||||
dtrie = {}
|
||||
gc.collect()
|
||||
|
||||
print("state cache disabled")
|
||||
return "success"
|
||||
|
||||
|
||||
@@ -61,8 +62,10 @@ def enable_state_cache():
|
||||
dtrie = {}
|
||||
gc.collect()
|
||||
|
||||
print("state cache enabled")
|
||||
return "success"
|
||||
except ModuleNotFoundError:
|
||||
print("state cache disabled")
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "cyac not found")
|
||||
|
||||
|
||||
@@ -87,14 +90,12 @@ def add_state(body: AddStateBody):
|
||||
|
||||
try:
|
||||
id: int = trie.insert(body.prompt)
|
||||
device: torch.device = body.state[0].device
|
||||
devices: List[torch.device] = [tensor.device for tensor in body.state]
|
||||
dtrie[id] = {
|
||||
"tokens": copy.deepcopy(body.tokens),
|
||||
"state": [tensor.cpu() for tensor in body.state]
|
||||
if device != torch.device("cpu")
|
||||
else copy.deepcopy(body.state),
|
||||
"state": [tensor.cpu() for tensor in body.state],
|
||||
"logits": copy.deepcopy(body.logits),
|
||||
"device": device,
|
||||
"devices": devices,
|
||||
}
|
||||
|
||||
if len(trie) >= max_trie_len:
|
||||
@@ -177,27 +178,18 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
|
||||
pass
|
||||
if id != -1:
|
||||
v = dtrie[id]
|
||||
device: torch.device = v["device"]
|
||||
devices: List[torch.device] = v["devices"]
|
||||
prompt: str = trie[id]
|
||||
|
||||
quick_log(request, body, "Hit:\n" + prompt)
|
||||
return {
|
||||
"prompt": prompt,
|
||||
"tokens": v["tokens"],
|
||||
"state": [tensor.to(device) for tensor in v["state"]]
|
||||
if device != torch.device("cpu")
|
||||
else v["state"],
|
||||
"state": [tensor.to(devices[i]) for i, tensor in enumerate(v["state"])],
|
||||
"logits": v["logits"],
|
||||
"device": device.type,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"prompt": "",
|
||||
"tokens": [],
|
||||
"state": None,
|
||||
"logits": None,
|
||||
"device": None,
|
||||
}
|
||||
return {"prompt": "", "tokens": [], "state": None, "logits": None}
|
||||
|
||||
|
||||
# @router.post("/save-state", tags=["State Cache"])
|
||||
|
||||
Reference in New Issue
Block a user