267 lines
7.4 KiB
Python
267 lines
7.4 KiB
Python
from typing import Any, Dict, List, Union
|
|
from utils.log import quick_log
|
|
from fastapi import APIRouter, HTTPException, Request, Response, status
|
|
from pydantic import BaseModel
|
|
import gc
|
|
import copy
|
|
import global_var
|
|
|
|
router = APIRouter()
|
|
|
|
trie = None
|
|
dtrie: Dict = {}
|
|
max_trie_len = 300
|
|
loop_start_id = 1 # to prevent preloaded prompts from being deleted
|
|
loop_del_trie_id = loop_start_id
|
|
|
|
|
|
def init():
|
|
global trie
|
|
try:
|
|
import cyac
|
|
|
|
# import mmap
|
|
# import os
|
|
#
|
|
# if os.path.exists("state_cache.trie"):
|
|
# with open("state_cache.trie", "r") as bf:
|
|
# buff_object = mmap.mmap(bf.fileno(), 0, access=mmap.ACCESS_READ)
|
|
# trie = cyac.Trie.from_buff(buff_object, copy=False)
|
|
# else:
|
|
trie = cyac.Trie()
|
|
except ModuleNotFoundError:
|
|
print("cyac not found")
|
|
|
|
|
|
@router.post("/disable-state-cache", tags=["State Cache"])
|
|
def disable_state_cache():
|
|
global trie, dtrie
|
|
|
|
if global_var.get(global_var.Deploy_Mode) is True:
|
|
raise HTTPException(status.HTTP_403_FORBIDDEN)
|
|
|
|
trie = None
|
|
dtrie = {}
|
|
gc.collect()
|
|
|
|
print("state cache disabled")
|
|
return "success"
|
|
|
|
|
|
@router.post("/enable-state-cache", tags=["State Cache"])
|
|
def enable_state_cache():
|
|
global trie, dtrie
|
|
|
|
if global_var.get(global_var.Deploy_Mode) is True:
|
|
raise HTTPException(status.HTTP_403_FORBIDDEN)
|
|
|
|
try:
|
|
import cyac
|
|
|
|
trie = cyac.Trie()
|
|
dtrie = {}
|
|
gc.collect()
|
|
|
|
print("state cache enabled")
|
|
return "success"
|
|
except ModuleNotFoundError:
|
|
print("state cache disabled")
|
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "cyac not found")
|
|
|
|
|
|
class AddStateBody(BaseModel):
|
|
prompt: str
|
|
tokens: List[Union[str, int]]
|
|
state: Any
|
|
logits: Any
|
|
|
|
|
|
def copy_tensor_to_cpu(tensors):
|
|
import torch
|
|
import numpy as np
|
|
|
|
devices: List[torch.device] = []
|
|
copied: Union[Any, None] = None
|
|
|
|
tensors_type = type(tensors)
|
|
if tensors_type == list:
|
|
if hasattr(tensors[0], "device"): # torch state
|
|
devices = [tensor.device for tensor in tensors]
|
|
copied = [tensor.cpu() for tensor in tensors]
|
|
else: # WebGPU logits
|
|
copied = tensors
|
|
elif tensors_type == torch.Tensor: # torch logits
|
|
devices = [tensors.device]
|
|
copied = tensors.cpu()
|
|
elif tensors_type == np.ndarray: # rwkv.cpp
|
|
copied = tensors
|
|
else: # WebGPU state
|
|
copied = tensors.back()
|
|
|
|
return copied, devices
|
|
|
|
|
|
# @router.post("/add-state", tags=["State Cache"])
|
|
def add_state(body: AddStateBody):
|
|
global trie, dtrie, loop_del_trie_id
|
|
|
|
# if global_var.get(global_var.Deploy_Mode) is True:
|
|
# raise HTTPException(status.HTTP_403_FORBIDDEN)
|
|
|
|
if trie is None:
|
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
|
|
|
import torch
|
|
import numpy as np
|
|
|
|
try:
|
|
devices: List[torch.device] = []
|
|
logits_device: Union[torch.device, None] = None
|
|
state: Union[Any, None] = None
|
|
logits: Union[Any, None] = None
|
|
|
|
if body.state is not None:
|
|
state, devices = copy_tensor_to_cpu(body.state)
|
|
if body.logits is not None:
|
|
logits, logits_devices = copy_tensor_to_cpu(body.logits)
|
|
if len(logits_devices) > 0:
|
|
logits_device = logits_devices[0]
|
|
|
|
id: int = trie.insert(body.prompt)
|
|
dtrie[id] = {
|
|
"tokens": body.tokens,
|
|
"state": state,
|
|
"logits": logits,
|
|
"devices": devices,
|
|
"logits_device": logits_device,
|
|
}
|
|
|
|
if len(trie) >= max_trie_len:
|
|
del_prompt = trie[loop_del_trie_id]
|
|
trie.remove(del_prompt)
|
|
dtrie[loop_del_trie_id] = None
|
|
loop_del_trie_id = loop_del_trie_id + 1
|
|
if loop_del_trie_id >= max_trie_len:
|
|
loop_del_trie_id = loop_start_id
|
|
|
|
quick_log(
|
|
None,
|
|
None,
|
|
f"New Trie Id: {id}\nTrie Len: {len(trie)}\nTrie Buff Size: {trie.buff_size()}\nDtrie Buff Size Of Id: {__get_a_dtrie_buff_size(dtrie[id])}",
|
|
)
|
|
return "success"
|
|
except Exception as e:
|
|
print(e) # should not happen
|
|
raise HTTPException(
|
|
status.HTTP_400_BAD_REQUEST, f"insert failed, bad prompt.\n{e}"
|
|
)
|
|
|
|
|
|
@router.post("/reset-state", tags=["State Cache"])
|
|
def reset_state():
|
|
global trie, dtrie
|
|
|
|
if global_var.get(global_var.Deploy_Mode) is True:
|
|
raise HTTPException(status.HTTP_403_FORBIDDEN)
|
|
|
|
if trie is None:
|
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
|
|
|
import cyac
|
|
|
|
trie = cyac.Trie()
|
|
dtrie = {}
|
|
gc.collect()
|
|
|
|
return "success"
|
|
|
|
|
|
class LongestPrefixStateBody(BaseModel):
|
|
prompt: str
|
|
|
|
|
|
def __get_a_dtrie_buff_size(dtrie_v):
|
|
# print(sys.getsizeof(dtrie_v["tokens"][0])) # str
|
|
# print(sys.getsizeof(dtrie_v["tokens"][0]) * len(dtrie_v["tokens"]))
|
|
# print(dtrie_v["state"][0][0].element_size())
|
|
# print(dtrie_v["state"][0].nelement())
|
|
# print(len(dtrie_v["state"]))
|
|
# print(
|
|
# len(dtrie_v["state"])
|
|
# * dtrie_v["state"][0].nelement()
|
|
# * dtrie_v["state"][0][0].element_size()
|
|
# )
|
|
# print(dtrie_v["logits"][0].element_size())
|
|
# print(dtrie_v["logits"].nelement())
|
|
# print(dtrie_v["logits"][0].element_size() * dtrie_v["logits"].nelement())
|
|
return 54 * len(dtrie_v["tokens"]) + 491520 + 262144 + 28 # TODO
|
|
|
|
|
|
# @router.post("/longest-prefix-state", tags=["State Cache"])
|
|
def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
|
|
global trie
|
|
|
|
# if global_var.get(global_var.Deploy_Mode) is True:
|
|
# raise HTTPException(status.HTTP_403_FORBIDDEN)
|
|
|
|
if trie is None:
|
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
|
|
|
import torch
|
|
import numpy as np
|
|
|
|
id = -1
|
|
try:
|
|
for id, len in trie.prefix(body.prompt):
|
|
pass
|
|
except:
|
|
pass
|
|
if id != -1:
|
|
prompt: str = trie[id]
|
|
v = dtrie[id]
|
|
tokens: List[Union[str, int]] = copy.deepcopy(v["tokens"])
|
|
devices: List[torch.device] = v["devices"]
|
|
logits_device: Union[torch.device, None] = v["logits_device"]
|
|
state: Union[Any, None] = v["state"]
|
|
logits: Union[Any, None] = v["logits"]
|
|
|
|
if type(state) == list and hasattr(state[0], "device"): # torch
|
|
state = [
|
|
tensor.to(devices[i])
|
|
if devices[i] != torch.device("cpu")
|
|
else tensor.clone()
|
|
for i, tensor in enumerate(state)
|
|
]
|
|
logits = (
|
|
logits.to(logits_device)
|
|
if logits_device != torch.device("cpu")
|
|
else logits.clone()
|
|
)
|
|
else: # rwkv.cpp, WebGPU
|
|
logits = np.copy(logits)
|
|
|
|
quick_log(request, body, "Hit:\n" + prompt)
|
|
return {
|
|
"prompt": prompt,
|
|
"tokens": tokens,
|
|
"state": state,
|
|
"logits": logits,
|
|
}
|
|
else:
|
|
return {"prompt": "", "tokens": [], "state": None, "logits": None}
|
|
|
|
|
|
# @router.post("/save-state", tags=["State Cache"])
|
|
def save_state():
|
|
global trie
|
|
|
|
# if global_var.get(global_var.Deploy_Mode) is True:
|
|
# raise HTTPException(status.HTTP_403_FORBIDDEN)
|
|
|
|
if trie is None:
|
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
|
|
|
# trie.save("state_cache.trie")
|
|
|
|
return "not implemented"
|