move state cache to memory (todo: state cache db)

This commit is contained in:
josc146 2023-06-02 21:33:57 +08:00
parent b63370928d
commit b41a2e7039
3 changed files with 13 additions and 2 deletions

View File

@ -73,6 +73,7 @@ body.json:
- [x] CUDA operator int8 acceleration
- [ ] macOS support
- [ ] Linux support
- [ ] Local State Cache DB
## Related Repositories:

View File

@ -73,6 +73,7 @@ body.json:
- [x] CUDA算子int8提速
- [ ] macOS支持
- [ ] linux支持
- [ ] 本地状态缓存数据库
## 相关仓库:

View File

@ -3,6 +3,7 @@ from fastapi import APIRouter, HTTPException, Response, status
from pydantic import BaseModel
import gc
import copy
import torch
router = APIRouter()
@ -41,10 +42,14 @@ def add_state(body: AddStateBody):
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
id = trie.insert(body.prompt)
device = body.state[0].device
dtrie[id] = {
"tokens": copy.deepcopy(body.tokens),
"state": copy.deepcopy(body.state),
"state": [tensor.cpu() for tensor in body.state]
if device != torch.device("cpu")
else copy.deepcopy(body.state),
"logits": copy.deepcopy(body.logits),
"device": device,
}
return "success"
@ -77,11 +82,15 @@ def longest_prefix_state(body: LongestPrefixStateBody):
pass
if id != -1:
v = dtrie[id]
device = v["device"]
return {
"prompt": trie[id],
"tokens": v["tokens"],
"state": v["state"],
"state": [tensor.to(device) for tensor in v["state"]]
if device != torch.device("cpu")
else v["state"],
"logits": v["logits"],
"device": device,
}
else:
return {"prompt": "", "tokens": [], "state": None, "logits": None}