fix the issue where state cache could be modified leading to inconsistent hit results
parent e3baa0da86
commit c9513822c9
@@ -76,6 +76,31 @@ class AddStateBody(BaseModel):
     logits: Any


+def copy_tensor_to_cpu(tensors):
+    import torch
+    import numpy as np
+
+    devices: List[torch.device] = []
+    copied: Union[Any, None] = None
+
+    tensors_type = type(tensors)
+    if tensors_type == list:
+        if hasattr(tensors[0], "device"):  # torch state
+            devices = [tensor.device for tensor in tensors]
+            copied = [tensor.cpu() for tensor in tensors]
+        else:  # WebGPU logits
+            copied = tensors
+    elif tensors_type == torch.Tensor:  # torch logits
+        devices = [tensors.device]
+        copied = tensors.cpu()
+    elif tensors_type == np.ndarray:  # rwkv.cpp
+        copied = tensors
+    else:  # WebGPU state
+        copied = tensors.back()
+
+    return copied, devices
+
+
 # @router.post("/add-state", tags=["State Cache"])
 def add_state(body: AddStateBody):
     global trie, dtrie, loop_del_trie_id
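
For orientation, a rough usage sketch of the new helper (illustrative only, not from the patch; the device and shapes are made up, and the GPU case assumes a CUDA build of torch):

import torch

# A torch-style state: a list of tensors living on some device.
gpu_state = [torch.zeros(4, device="cuda:0"), torch.zeros(4, device="cuda:0")]

copied, devices = copy_tensor_to_cpu(gpu_state)
# `copied` holds CPU copies that are safe to stash in the cache;
# `devices` records the original placement (here cuda:0 twice) so the
# hit path can later move the tensors back with tensor.to(devices[i]).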
@@ -91,23 +116,24 @@ def add_state(body: AddStateBody):

     try:
         devices: List[torch.device] = []
+        logits_device: Union[torch.device, None] = None
         state: Union[Any, None] = None
+        logits: Union[Any, None] = None

         if body.state is not None:
-            if type(body.state) == list and hasattr(body.state[0], "device"):  # torch
-                devices = [tensor.device for tensor in body.state]
-                state = [tensor.cpu() for tensor in body.state]
-            elif type(body.state) == np.ndarray:  # rwkv.cpp
-                state = body.state
-            else:  # WebGPU
-                state = body.state.back()
+            state, devices = copy_tensor_to_cpu(body.state)
+        if body.logits is not None:
+            logits, logits_devices = copy_tensor_to_cpu(body.logits)
+            if len(logits_devices) > 0:
+                logits_device = logits_devices[0]

         id: int = trie.insert(body.prompt)
         dtrie[id] = {
             "tokens": body.tokens,
             "state": state,
-            "logits": body.logits,
+            "logits": logits,
             "devices": devices,
+            "logits_device": logits_device,
         }

         if len(trie) >= max_trie_len:
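
The practical effect in add_state is that the cache entry no longer aliases the caller's live tensors for anything that was not already CPU-resident. A minimal check of that property (illustrative only, not from the patch; hypothetical size, CUDA assumed):

import torch

live_logits = torch.randn(1024, device="cuda:0")

cached_logits, logits_devices = copy_tensor_to_cpu(live_logits)
live_logits.add_(1.0)  # inference keeps mutating the live tensor...

# ...while the cached CPU copy still holds the values seen at add_state time.
assert not torch.equal(cached_logits, live_logits.cpu())
print(logits_devices)  # [device(type='cuda', index=0)]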
@@ -125,6 +151,7 @@ def add_state(body: AddStateBody):
         )
         return "success"
     except Exception as e:
+        print(e)  # should not happen
         raise HTTPException(
             status.HTTP_400_BAD_REQUEST, f"insert failed, bad prompt.\n{e}"
         )
@@ -192,18 +219,33 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
     if id != -1:
         prompt: str = trie[id]
         v = dtrie[id]
+        tokens: List[Union[str, int]] = copy.deepcopy(v["tokens"])
         devices: List[torch.device] = v["devices"]
+        logits_device: Union[torch.device, None] = v["logits_device"]
         state: Union[Any, None] = v["state"]
+        logits: Union[Any, None] = v["logits"]

         if type(state) == list and hasattr(state[0], "device"):  # torch
-            state = [tensor.to(devices[i]) for i, tensor in enumerate(state)]
+            state = [
+                tensor.to(devices[i])
+                if devices[i] != torch.device("cpu")
+                else tensor.clone()
+                for i, tensor in enumerate(state)
+            ]
+            logits = (
+                logits.to(logits_device)
+                if logits_device != torch.device("cpu")
+                else logits.clone()
+            )
+        else:  # rwkv.cpp, WebGPU
+            logits = np.copy(logits)

         quick_log(request, body, "Hit:\n" + prompt)
         return {
             "prompt": prompt,
-            "tokens": v["tokens"],
+            "tokens": tokens,
             "state": state,
-            "logits": v["logits"],
+            "logits": logits,
         }
     else:
         return {"prompt": "", "tokens": [], "state": None, "logits": None}
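
Taken together, the hit path now hands out per-request copies instead of the objects stored in dtrie: tokens are deep-copied, torch tensors are moved back to their recorded device or cloned when that device is the CPU (tensor.to("cpu") on a CPU tensor returns the same object rather than a copy), and rwkv.cpp / WebGPU logits go through np.copy. A minimal sketch of the aliasing bug this closes (illustrative only, not from the patch; plain CPU tensors and a hypothetical cache dict):

import torch

cache_entry = {"state": [torch.zeros(3)]}

# Old behaviour: return the cached list itself.
hit = cache_entry["state"]
hit[0] += 1                      # caller updates "its" state in place
print(cache_entry["state"][0])   # tensor([1., 1., 1.]) -- the cache changed too

# New behaviour for CPU tensors: clone before returning.
safe = [t.clone() for t in cache_entry["state"]]
safe[0] += 1
print(cache_entry["state"][0])   # still tensor([1., 1., 1.]) -- this caller left the cache alone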