type
This commit is contained in:
parent
d32351c130
commit
377f71b16b
@ -48,8 +48,8 @@ def add_state(body: AddStateBody):
|
|||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
id = trie.insert(body.prompt)
|
id: int = trie.insert(body.prompt)
|
||||||
device = body.state[0].device
|
device: torch.device = body.state[0].device
|
||||||
dtrie[id] = {
|
dtrie[id] = {
|
||||||
"tokens": copy.deepcopy(body.tokens),
|
"tokens": copy.deepcopy(body.tokens),
|
||||||
"state": [tensor.cpu() for tensor in body.state]
|
"state": [tensor.cpu() for tensor in body.state]
|
||||||
@ -110,7 +110,7 @@ def _get_a_dtrie_buff_size(dtrie_v):
|
|||||||
# print(dtrie_v["logits"][0].element_size())
|
# print(dtrie_v["logits"][0].element_size())
|
||||||
# print(dtrie_v["logits"].nelement())
|
# print(dtrie_v["logits"].nelement())
|
||||||
# print(dtrie_v["logits"][0].element_size() * dtrie_v["logits"].nelement())
|
# print(dtrie_v["logits"][0].element_size() * dtrie_v["logits"].nelement())
|
||||||
return 54 * len(dtrie_v["tokens"]) + 491520 + 262144 + 28
|
return 54 * len(dtrie_v["tokens"]) + 491520 + 262144 + 28 # TODO
|
||||||
|
|
||||||
|
|
||||||
@router.post("/longest-prefix-state")
|
@router.post("/longest-prefix-state")
|
||||||
@ -127,8 +127,9 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
|
|||||||
pass
|
pass
|
||||||
if id != -1:
|
if id != -1:
|
||||||
v = dtrie[id]
|
v = dtrie[id]
|
||||||
device = v["device"]
|
device: torch.device = v["device"]
|
||||||
prompt = trie[id]
|
prompt: str = trie[id]
|
||||||
|
|
||||||
quick_log(request, body, "Hit:\n" + prompt)
|
quick_log(request, body, "Hit:\n" + prompt)
|
||||||
return {
|
return {
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
@ -137,7 +138,7 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
|
|||||||
if device != torch.device("cpu")
|
if device != torch.device("cpu")
|
||||||
else v["state"],
|
else v["state"],
|
||||||
"logits": v["logits"],
|
"logits": v["logits"],
|
||||||
"device": device,
|
"device": device.type,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
return {
|
return {
|
||||||
|
Loading…
Reference in New Issue
Block a user