improve state cache performance
This commit is contained in:
parent
e083f2c629
commit
7f3cfd54b0
@ -94,28 +94,19 @@ def add_state(body: AddStateBody):
|
|||||||
state: Union[Any, None] = None
|
state: Union[Any, None] = None
|
||||||
|
|
||||||
if body.state is not None:
|
if body.state is not None:
|
||||||
if type(body.state) == list or type(body.state) == np.ndarray:
|
if type(state) == list and hasattr(state[0], "device"): # torch
|
||||||
devices = [
|
devices = [tensor.device for tensor in body.state]
|
||||||
(
|
state = [tensor.cpu() for tensor in body.state]
|
||||||
tensor.device
|
elif type(body.state) == np.ndarray: # rwkv.cpp
|
||||||
if hasattr(tensor, "device")
|
state = body.state
|
||||||
else torch.device("cpu")
|
else: # WebGPU
|
||||||
)
|
state = body.state.back()
|
||||||
for tensor in body.state
|
|
||||||
]
|
|
||||||
state = (
|
|
||||||
[tensor.cpu() for tensor in body.state]
|
|
||||||
if hasattr(body.state[0], "device")
|
|
||||||
else copy.deepcopy(body.state)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
state = body.state.back() # WebGPU
|
|
||||||
|
|
||||||
id: int = trie.insert(body.prompt)
|
id: int = trie.insert(body.prompt)
|
||||||
dtrie[id] = {
|
dtrie[id] = {
|
||||||
"tokens": copy.deepcopy(body.tokens),
|
"tokens": body.tokens,
|
||||||
"state": state,
|
"state": state,
|
||||||
"logits": copy.deepcopy(body.logits),
|
"logits": body.logits,
|
||||||
"devices": devices,
|
"devices": devices,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -199,12 +190,12 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
if id != -1:
|
if id != -1:
|
||||||
|
prompt: str = trie[id]
|
||||||
v = dtrie[id]
|
v = dtrie[id]
|
||||||
devices: List[torch.device] = v["devices"]
|
devices: List[torch.device] = v["devices"]
|
||||||
prompt: str = trie[id]
|
|
||||||
state: Union[Any, None] = v["state"]
|
state: Union[Any, None] = v["state"]
|
||||||
|
|
||||||
if state is not None and type(state) == list and hasattr(state[0], "device"):
|
if type(state) == list and hasattr(state[0], "device"): # torch
|
||||||
state = [tensor.to(devices[i]) for i, tensor in enumerate(state)]
|
state = [tensor.to(devices[i]) for i, tensor in enumerate(state)]
|
||||||
|
|
||||||
quick_log(request, body, "Hit:\n" + prompt)
|
quick_log(request, body, "Hit:\n" + prompt)
|
||||||
|
@ -239,14 +239,9 @@ class AbstractRWKV(ABC):
|
|||||||
self.model_tokens = []
|
self.model_tokens = []
|
||||||
else:
|
else:
|
||||||
delta_prompt = prompt[len(cache["prompt"]) :]
|
delta_prompt = prompt[len(cache["prompt"]) :]
|
||||||
state = cache["state"]
|
self.model_state = cache["state"]
|
||||||
self.model_state = (
|
self.model_tokens = cache["tokens"]
|
||||||
copy.deepcopy(state)
|
logits = cache["logits"]
|
||||||
if type(state) == list or type(state) == np.ndarray
|
|
||||||
else state
|
|
||||||
)
|
|
||||||
self.model_tokens = copy.deepcopy(cache["tokens"])
|
|
||||||
logits = copy.deepcopy(cache["logits"])
|
|
||||||
|
|
||||||
prompt_token_len = 0
|
prompt_token_len = 0
|
||||||
if delta_prompt != "":
|
if delta_prompt != "":
|
||||||
|
Loading…
Reference in New Issue
Block a user