diff --git a/backend-python/routes/state_cache.py b/backend-python/routes/state_cache.py
index a150419..11d7026 100644
--- a/backend-python/routes/state_cache.py
+++ b/backend-python/routes/state_cache.py
@@ -48,8 +48,8 @@ def add_state(body: AddStateBody):
         raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
 
     try:
-        id = trie.insert(body.prompt)
-        device = body.state[0].device
+        id: int = trie.insert(body.prompt)
+        device: torch.device = body.state[0].device
         dtrie[id] = {
             "tokens": copy.deepcopy(body.tokens),
             "state": [tensor.cpu() for tensor in body.state]
@@ -110,7 +110,7 @@ def _get_a_dtrie_buff_size(dtrie_v):
     # print(dtrie_v["logits"][0].element_size())
     # print(dtrie_v["logits"].nelement())
     # print(dtrie_v["logits"][0].element_size() * dtrie_v["logits"].nelement())
-    return 54 * len(dtrie_v["tokens"]) + 491520 + 262144 + 28
+    return 54 * len(dtrie_v["tokens"]) + 491520 + 262144 + 28  # TODO
 
 
 @router.post("/longest-prefix-state")
@@ -127,8 +127,9 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
         pass
 
     if id != -1:
         v = dtrie[id]
-        device = v["device"]
-        prompt = trie[id]
+        device: torch.device = v["device"]
+        prompt: str = trie[id]
+        quick_log(request, body, "Hit:\n" + prompt)
         return {
             "prompt": prompt,
@@ -137,7 +138,7 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
             if device != torch.device("cpu")
             else v["state"],
             "logits": v["logits"],
-            "device": device,
+            "device": device.type,
         }
     else:
         return {
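
A note on the last hunk (not part of the patch): switching the response field from the raw `torch.device` object to `device.type` is presumably about serialization, since FastAPI JSON-encodes response bodies and `torch.device` is not JSON-serializable, while `.type` is a plain string such as `"cpu"` or `"cuda"`. A minimal sketch of the difference, assuming a hypothetical caller that rebuilds the device from the returned string:

```python
import json

import torch

# Pick whatever device is available; the point is the same either way.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# `.type` is a plain string, so it survives JSON encoding.
payload = {"device": device.type}
print(json.dumps(payload))  # e.g. {"device": "cpu"}

# By contrast, encoding the raw object fails:
# json.dumps({"device": device})
# TypeError: Object of type device is not JSON serializable

# Hypothetical caller side: rebuild a torch.device from the string.
restored = torch.device(payload["device"])
assert restored.type == device.type
```

One trade-off worth noting: `.type` drops the device index (`cuda:0` becomes `cuda`); if the index matters to the caller, `str(device)` would preserve it (`"cuda:0"`).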