RWKV-Runner/backend-python/routes/state_cache.py

234 lines
6.4 KiB
Python
Raw Normal View History

2023-08-13 21:27:29 +08:00
from typing import Any, Dict, List, Union
from utils.log import quick_log
from fastapi import APIRouter, HTTPException, Request, Response, status
from pydantic import BaseModel
import gc
import copy
import global_var
# Router on which the state-cache endpoints in this module are registered.
router = APIRouter()

# Prompt trie; stays None until init()/enable_state_cache() creates one
# (and is reset to None by disable_state_cache()).
trie = None

# Cached payload for each trie id (tokens, state, logits, devices).
dtrie: Dict = {}

# Once the trie holds this many prompts, add_state() starts evicting.
max_trie_len = 300

# First id the eviction cursor may point at.
loop_start_id = 1  # to prevent preloaded prompts from being deleted

# Eviction cursor: trie id that add_state() removes next.
loop_del_trie_id = loop_start_id
def init():
    """Create the module-level prompt trie when ``cyac`` is installed.

    If ``cyac`` cannot be imported, a notice is printed and ``trie`` is left
    unchanged, so the cache-dependent functions keep reporting
    "trie not loaded".
    """
    global trie

    try:
        import cyac

        # Restoring a saved trie from "state_cache.trie" (mmap +
        # cyac.Trie.from_buff) is currently not enabled, so an empty trie is
        # always created here.
        trie = cyac.Trie()
    except ModuleNotFoundError:
        print("cyac not found")
2023-07-26 22:24:26 +08:00
# Turn the state cache off: drop the trie and every cached payload, then
# force a garbage-collection pass. Rejected with 403 in deploy mode.
@router.post("/disable-state-cache", tags=["State Cache"])
def disable_state_cache():
    global trie, dtrie

    deploy_mode = global_var.get(global_var.Deploy_Mode)
    if deploy_mode is True:
        raise HTTPException(status.HTTP_403_FORBIDDEN)

    # A None trie is what the other functions treat as "cache not loaded".
    trie, dtrie = None, {}
    gc.collect()

    print("state cache disabled")
    return "success"
2023-07-26 22:24:26 +08:00
# Turn the state cache on with a fresh, empty trie and payload table.
# Rejected with 403 in deploy mode; answers 400 when cyac is not installed.
@router.post("/enable-state-cache", tags=["State Cache"])
def enable_state_cache():
    global trie, dtrie

    deploy_mode = global_var.get(global_var.Deploy_Mode)
    if deploy_mode is True:
        raise HTTPException(status.HTTP_403_FORBIDDEN)

    try:
        import cyac

        fresh_trie = cyac.Trie()
    except ModuleNotFoundError:
        print("state cache disabled")
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "cyac not found")

    trie = fresh_trie
    dtrie = {}
    gc.collect()

    print("state cache enabled")
    return "success"
# Payload accepted by add_state().
class AddStateBody(BaseModel):
    # Prompt text used as the trie key.
    prompt: str
    # Tokens stored alongside the prompt (strings or integer ids).
    tokens: List[Union[str, int]]
    # Model state snapshot; add_state() only caches list / numpy.ndarray values.
    state: Any
    # Logits stored with the snapshot; deep-copied by add_state().
    logits: Any
2023-11-17 21:32:11 +08:00
# @router.post("/add-state", tags=["State Cache"])
def add_state(body: AddStateBody):
    """Store the tokens, state and logits for ``body.prompt`` in the cache.

    The prompt is inserted into the module-level trie and its payload is kept
    in ``dtrie`` under the id returned by the trie. Tensor states (elements
    exposing ``.device``) are moved to CPU and their original devices are
    recorded so ``longest_prefix_state`` can move them back. Other list /
    ndarray states are deep-copied. Any other state type is not cached and
    ``None`` is stored in its place.

    Once the trie holds ``max_trie_len`` prompts, one older entry is evicted
    per call, cycling through ids starting at ``loop_start_id``.

    Args:
        body: Prompt plus the payload to cache.

    Returns:
        The string ``"success"``.

    Raises:
        HTTPException: 400 when the trie is not loaded, or when any step of
            the insertion fails (the original error text is included).
    """
    global trie, dtrie, loop_del_trie_id

    # if global_var.get(global_var.Deploy_Mode) is True:
    #     raise HTTPException(status.HTTP_403_FORBIDDEN)

    if trie is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")

    import torch
    import numpy as np

    try:
        devices: List[torch.device] = []
        state: Union[Any, None] = None

        # isinstance() also rejects None, so a missing state stays None.
        if isinstance(body.state, (list, np.ndarray)):
            devices = [
                tensor.device if hasattr(tensor, "device") else torch.device("cpu")
                for tensor in body.state
            ]
            # Guard the [0] probe so an empty state does not raise IndexError
            # and get reported as a bad prompt.
            if len(body.state) > 0 and hasattr(body.state[0], "device"):
                state = [tensor.cpu() for tensor in body.state]
            else:
                state = copy.deepcopy(body.state)
        # Any other state type (e.g. WebGPU) is not cached.

        trie_id: int = trie.insert(body.prompt)
        entry = {
            "tokens": copy.deepcopy(body.tokens),
            "state": state,
            "logits": copy.deepcopy(body.logits),
            "devices": devices,
        }
        dtrie[trie_id] = entry

        if len(trie) >= max_trie_len:
            # Evict the entry under the rotating cursor; the cursor wraps back
            # to loop_start_id so lower ids are never evicted.
            del_prompt = trie[loop_del_trie_id]
            trie.remove(del_prompt)
            dtrie[loop_del_trie_id] = None
            loop_del_trie_id = loop_del_trie_id + 1
            if loop_del_trie_id >= max_trie_len:
                loop_del_trie_id = loop_start_id

        # Size the log line from the local entry: dtrie[trie_id] may already
        # be None if the cursor just evicted the id that was inserted.
        quick_log(
            None,
            None,
            f"New Trie Id: {trie_id}\nTrie Len: {len(trie)}\nTrie Buff Size: {trie.buff_size()}\nDtrie Buff Size Of Id: {__get_a_dtrie_buff_size(entry)}",
        )
        return "success"
    except Exception as e:
        raise HTTPException(
            status.HTTP_400_BAD_REQUEST, f"insert failed, bad prompt.\n{e}"
        ) from e
2023-07-26 22:24:26 +08:00
# Empty the state cache while keeping it enabled: the trie is replaced by a
# new one and all cached payloads are dropped. Rejected with 403 in deploy
# mode; answers 400 when no trie is currently loaded.
@router.post("/reset-state", tags=["State Cache"])
def reset_state():
    global trie, dtrie

    deploy_mode = global_var.get(global_var.Deploy_Mode)
    if deploy_mode is True:
        raise HTTPException(status.HTTP_403_FORBIDDEN)

    if trie is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")

    import cyac

    trie, dtrie = cyac.Trie(), {}
    gc.collect()

    return "success"
# Payload accepted by longest_prefix_state().
class LongestPrefixStateBody(BaseModel):
    # Prompt whose cached prefixes are looked up in the trie.
    prompt: str
2023-08-24 22:48:54 +08:00
def __get_a_dtrie_buff_size(dtrie_v):
    """Return a rough size estimate for one cached ``dtrie`` entry.

    Only the number of tokens is actually measured; everything else is a
    fixed constant (see the TODO below).
    """
    # NOTE(review): the constants look like hard-coded results of the
    # measurements that used to be printed here -- confirm before relying
    # on them:
    #   sys.getsizeof(dtrie_v["tokens"][0]) * len(dtrie_v["tokens"])
    #   len(dtrie_v["state"]) * dtrie_v["state"][0].nelement()
    #       * dtrie_v["state"][0][0].element_size()
    #   dtrie_v["logits"][0].element_size() * dtrie_v["logits"].nelement()
    per_token = 54
    fixed_part = 491520 + 262144 + 28  # TODO

    token_count = len(dtrie_v["tokens"])
    return per_token * token_count + fixed_part
2023-06-12 15:22:17 +08:00
2023-11-17 21:32:11 +08:00
# @router.post("/longest-prefix-state", tags=["State Cache"])
def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
    """Look up the cached entry for a stored prefix of ``body.prompt``.

    The last match yielded by ``trie.prefix`` is used (presumably the longest
    one -- confirm against the cyac documentation). Tensor states are moved
    back to the devices recorded when the entry was cached.

    Args:
        body: Carries the prompt to look up.
        request: Incoming request, forwarded to ``quick_log`` on a hit.

    Returns:
        A dict with ``prompt``, ``tokens``, ``state`` and ``logits``. On a
        miss the prompt is ``""``, tokens is ``[]`` and the rest are ``None``.

    Raises:
        HTTPException: 400 when the trie is not loaded.
    """
    global trie

    # if global_var.get(global_var.Deploy_Mode) is True:
    #     raise HTTPException(status.HTTP_403_FORBIDDEN)

    if trie is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")

    import torch

    match_id = -1
    try:
        for match_id, _match_len in trie.prefix(body.prompt):
            pass
    except Exception:
        # Deliberate best effort: a failed lookup is treated like the matches
        # seen so far (possibly none). Unlike a bare ``except``, this lets
        # KeyboardInterrupt / SystemExit propagate.
        pass

    # A None payload is the tombstone left by eviction; treat it, and a
    # missing key, as a miss instead of crashing.
    entry = dtrie.get(match_id) if match_id != -1 else None
    if entry is None:
        return {"prompt": "", "tokens": [], "state": None, "logits": None}

    devices: List[torch.device] = entry["devices"]
    prompt: str = trie[match_id]
    state: Union[Any, None] = entry["state"]

    # Only non-empty lists of tensor-like objects were moved to CPU when
    # cached, so only those need moving back.
    if isinstance(state, list) and state and hasattr(state[0], "device"):
        state = [tensor.to(device) for tensor, device in zip(state, devices)]

    quick_log(request, body, "Hit:\n" + prompt)
    return {
        "prompt": prompt,
        "tokens": entry["tokens"],
        "state": state,
        "logits": entry["logits"],
    }
2023-11-17 21:32:11 +08:00
# @router.post("/save-state", tags=["State Cache"])
def save_state():
    """Placeholder for persisting the trie; currently saves nothing.

    Returns:
        The string ``"not implemented"``.

    Raises:
        HTTPException: 400 when the trie is not loaded.
    """
    global trie

    # if global_var.get(global_var.Deploy_Mode) is True:
    #     raise HTTPException(status.HTTP_403_FORBIDDEN)

    trie_missing = trie is None
    if trie_missing:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")

    # Persistence is switched off for now:
    # trie.save("state_cache.trie")

    return "not implemented"