This commit is contained in:
josc146 2023-06-19 22:32:02 +08:00
parent d32351c130
commit 377f71b16b

View File

@@ -48,8 +48,8 @@ def add_state(body: AddStateBody):
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
try:
id = trie.insert(body.prompt)
device = body.state[0].device
id: int = trie.insert(body.prompt)
device: torch.device = body.state[0].device
dtrie[id] = {
"tokens": copy.deepcopy(body.tokens),
"state": [tensor.cpu() for tensor in body.state]
@@ -110,7 +110,7 @@ def _get_a_dtrie_buff_size(dtrie_v):
# print(dtrie_v["logits"][0].element_size())
# print(dtrie_v["logits"].nelement())
# print(dtrie_v["logits"][0].element_size() * dtrie_v["logits"].nelement())
return 54 * len(dtrie_v["tokens"]) + 491520 + 262144 + 28
return 54 * len(dtrie_v["tokens"]) + 491520 + 262144 + 28 # TODO
@router.post("/longest-prefix-state")
@@ -127,8 +127,9 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
pass
if id != -1:
v = dtrie[id]
device = v["device"]
prompt = trie[id]
device: torch.device = v["device"]
prompt: str = trie[id]
quick_log(request, body, "Hit:\n" + prompt)
return {
"prompt": prompt,
@@ -137,7 +138,7 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
if device != torch.device("cpu")
else v["state"],
"logits": v["logits"],
"device": device,
"device": device.type,
}
else:
return {