move state cache to memory (todo: state cache db)
parent b63370928d · commit b41a2e7039
@@ -73,6 +73,7 @@ body.json:
 - [x] CUDA operator int8 acceleration
 - [ ] macOS support
 - [ ] Linux support
+- [ ] Local State Cache DB

 ## Related Repositories:

@@ -73,6 +73,7 @@ body.json:
 - [x] CUDA operator int8 acceleration
 - [ ] macOS support
 - [ ] Linux support
+- [ ] Local State Cache DB

 ## Related Repositories:

@@ -3,6 +3,7 @@ from fastapi import APIRouter, HTTPException, Response, status
 from pydantic import BaseModel
 import gc
 import copy
+import torch

 router = APIRouter()

@@ -41,10 +42,14 @@ def add_state(body: AddStateBody):
         raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")

     id = trie.insert(body.prompt)
+    device = body.state[0].device
     dtrie[id] = {
         "tokens": copy.deepcopy(body.tokens),
-        "state": copy.deepcopy(body.state),
+        "state": [tensor.cpu() for tensor in body.state]
+        if device != torch.device("cpu")
+        else copy.deepcopy(body.state),
         "logits": copy.deepcopy(body.logits),
+        "device": device,
     }

     return "success"
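The hunk above stops deep-copying GPU state tensors into the cache: each tensor is copied to host memory with .cpu(), and the original device is recorded alongside the entry so a later lookup can restore it. A minimal standalone sketch of that store-side pattern, with a hypothetical cache_entry helper standing in for the dtrie insertion:

import copy

import torch


def cache_entry(tokens, state, logits):
    # Remember where the state tensors live so a later lookup can move them back.
    device = state[0].device
    return {
        "tokens": copy.deepcopy(tokens),
        # tensor.cpu() copies each state tensor into host RAM; if the state is
        # already on the CPU, .cpu() would return the original tensor, so a
        # plain deep copy keeps the cache independent of the caller's tensors.
        "state": [tensor.cpu() for tensor in state]
        if device != torch.device("cpu")
        else copy.deepcopy(state),
        "logits": copy.deepcopy(logits),
        "device": device,
    }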
@@ -77,11 +82,15 @@ def longest_prefix_state(body: LongestPrefixStateBody):
         pass
     if id != -1:
         v = dtrie[id]
+        device = v["device"]
         return {
             "prompt": trie[id],
             "tokens": v["tokens"],
-            "state": v["state"],
+            "state": [tensor.to(device) for tensor in v["state"]]
+            if device != torch.device("cpu")
+            else v["state"],
             "logits": v["logits"],
+            "device": device,
         }
     else:
         return {"prompt": "", "tokens": [], "state": None, "logits": None}
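On lookup, the stored tensors are moved back with .to(device), so callers receive the state on the device it was originally captured from while the cached copies stay in host RAM; in the CPU case the cached tensors are returned as-is rather than copied. A rough, self-contained illustration of the restore step (restore_entry and the toy entry below are hypothetical, not code from the repository):

import torch


def restore_entry(entry):
    # Move the cached CPU copies back to the device recorded at insert time.
    device = entry["device"]
    if device != torch.device("cpu"):
        return [tensor.to(device) for tensor in entry["state"]]
    return entry["state"]


# Toy round trip: a cache entry whose state sits in host RAM but originated on cuda:0.
if torch.cuda.is_available():
    entry = {
        "state": [torch.zeros(4) for _ in range(5)],  # cached copies on the CPU
        "device": torch.device("cuda", 0),            # placement recorded at insert time
    }
    restored = restore_entry(entry)
    assert all(tensor.device == entry["device"] for tensor in restored)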