improve state cache performance
This commit is contained in:
@@ -94,28 +94,19 @@ def add_state(body: AddStateBody):
|
||||
state: Union[Any, None] = None
|
||||
|
||||
if body.state is not None:
|
||||
if type(body.state) == list or type(body.state) == np.ndarray:
|
||||
devices = [
|
||||
(
|
||||
tensor.device
|
||||
if hasattr(tensor, "device")
|
||||
else torch.device("cpu")
|
||||
)
|
||||
for tensor in body.state
|
||||
]
|
||||
state = (
|
||||
[tensor.cpu() for tensor in body.state]
|
||||
if hasattr(body.state[0], "device")
|
||||
else copy.deepcopy(body.state)
|
||||
)
|
||||
else:
|
||||
state = body.state.back() # WebGPU
|
||||
if type(state) == list and hasattr(state[0], "device"): # torch
|
||||
devices = [tensor.device for tensor in body.state]
|
||||
state = [tensor.cpu() for tensor in body.state]
|
||||
elif type(body.state) == np.ndarray: # rwkv.cpp
|
||||
state = body.state
|
||||
else: # WebGPU
|
||||
state = body.state.back()
|
||||
|
||||
id: int = trie.insert(body.prompt)
|
||||
dtrie[id] = {
|
||||
"tokens": copy.deepcopy(body.tokens),
|
||||
"tokens": body.tokens,
|
||||
"state": state,
|
||||
"logits": copy.deepcopy(body.logits),
|
||||
"logits": body.logits,
|
||||
"devices": devices,
|
||||
}
|
||||
|
||||
@@ -199,12 +190,12 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
|
||||
except:
|
||||
pass
|
||||
if id != -1:
|
||||
prompt: str = trie[id]
|
||||
v = dtrie[id]
|
||||
devices: List[torch.device] = v["devices"]
|
||||
prompt: str = trie[id]
|
||||
state: Union[Any, None] = v["state"]
|
||||
|
||||
if state is not None and type(state) == list and hasattr(state[0], "device"):
|
||||
if type(state) == list and hasattr(state[0], "device"): # torch
|
||||
state = [tensor.to(devices[i]) for i, tensor in enumerate(state)]
|
||||
|
||||
quick_log(request, body, "Hit:\n" + prompt)
|
||||
|
||||
Reference in New Issue
Block a user