Move state cache to CPU memory (TODO: state cache DB)

This commit is contained in:
josc146
2023-06-02 21:33:57 +08:00
parent b63370928d
commit b41a2e7039
3 changed files with 13 additions and 2 deletions

View File

@@ -3,6 +3,7 @@ from fastapi import APIRouter, HTTPException, Response, status
from pydantic import BaseModel
import gc
import copy
import torch
router = APIRouter()
@@ -41,10 +42,14 @@ def add_state(body: AddStateBody):
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
id = trie.insert(body.prompt)
device = body.state[0].device
dtrie[id] = {
"tokens": copy.deepcopy(body.tokens),
"state": copy.deepcopy(body.state),
"state": [tensor.cpu() for tensor in body.state]
if device != torch.device("cpu")
else copy.deepcopy(body.state),
"logits": copy.deepcopy(body.logits),
"device": device,
}
return "success"
@@ -77,11 +82,15 @@ def longest_prefix_state(body: LongestPrefixStateBody):
pass
if id != -1:
v = dtrie[id]
device = v["device"]
return {
"prompt": trie[id],
"tokens": v["tokens"],
"state": v["state"],
"state": [tensor.to(device) for tensor in v["state"]]
if device != torch.device("cpu")
else v["state"],
"logits": v["logits"],
"device": device,
}
else:
return {"prompt": "", "tokens": [], "state": None, "logits": None}