type
This commit is contained in:
parent
d32351c130
commit
377f71b16b
@ -48,8 +48,8 @@ def add_state(body: AddStateBody):
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
||||
|
||||
try:
|
||||
id = trie.insert(body.prompt)
|
||||
device = body.state[0].device
|
||||
id: int = trie.insert(body.prompt)
|
||||
device: torch.device = body.state[0].device
|
||||
dtrie[id] = {
|
||||
"tokens": copy.deepcopy(body.tokens),
|
||||
"state": [tensor.cpu() for tensor in body.state]
|
||||
@ -110,7 +110,7 @@ def _get_a_dtrie_buff_size(dtrie_v):
|
||||
# print(dtrie_v["logits"][0].element_size())
|
||||
# print(dtrie_v["logits"].nelement())
|
||||
# print(dtrie_v["logits"][0].element_size() * dtrie_v["logits"].nelement())
|
||||
return 54 * len(dtrie_v["tokens"]) + 491520 + 262144 + 28
|
||||
return 54 * len(dtrie_v["tokens"]) + 491520 + 262144 + 28 # TODO
|
||||
|
||||
|
||||
@router.post("/longest-prefix-state")
|
||||
@ -127,8 +127,9 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
|
||||
pass
|
||||
if id != -1:
|
||||
v = dtrie[id]
|
||||
device = v["device"]
|
||||
prompt = trie[id]
|
||||
device: torch.device = v["device"]
|
||||
prompt: str = trie[id]
|
||||
|
||||
quick_log(request, body, "Hit:\n" + prompt)
|
||||
return {
|
||||
"prompt": prompt,
|
||||
@ -137,7 +138,7 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
|
||||
if device != torch.device("cpu")
|
||||
else v["state"],
|
||||
"logits": v["logits"],
|
||||
"device": device,
|
||||
"device": device.type,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
|
Loading…
Reference in New Issue
Block a user