improve state cache performance
This commit is contained in:
parent
e083f2c629
commit
7f3cfd54b0
@ -94,28 +94,19 @@ def add_state(body: AddStateBody):
|
||||
state: Union[Any, None] = None
|
||||
|
||||
if body.state is not None:
|
||||
if type(body.state) == list or type(body.state) == np.ndarray:
|
||||
devices = [
|
||||
(
|
||||
tensor.device
|
||||
if hasattr(tensor, "device")
|
||||
else torch.device("cpu")
|
||||
)
|
||||
for tensor in body.state
|
||||
]
|
||||
state = (
|
||||
[tensor.cpu() for tensor in body.state]
|
||||
if hasattr(body.state[0], "device")
|
||||
else copy.deepcopy(body.state)
|
||||
)
|
||||
else:
|
||||
state = body.state.back() # WebGPU
|
||||
if type(state) == list and hasattr(state[0], "device"): # torch
|
||||
devices = [tensor.device for tensor in body.state]
|
||||
state = [tensor.cpu() for tensor in body.state]
|
||||
elif type(body.state) == np.ndarray: # rwkv.cpp
|
||||
state = body.state
|
||||
else: # WebGPU
|
||||
state = body.state.back()
|
||||
|
||||
id: int = trie.insert(body.prompt)
|
||||
dtrie[id] = {
|
||||
"tokens": copy.deepcopy(body.tokens),
|
||||
"tokens": body.tokens,
|
||||
"state": state,
|
||||
"logits": copy.deepcopy(body.logits),
|
||||
"logits": body.logits,
|
||||
"devices": devices,
|
||||
}
|
||||
|
||||
@ -199,12 +190,12 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
|
||||
except:
|
||||
pass
|
||||
if id != -1:
|
||||
prompt: str = trie[id]
|
||||
v = dtrie[id]
|
||||
devices: List[torch.device] = v["devices"]
|
||||
prompt: str = trie[id]
|
||||
state: Union[Any, None] = v["state"]
|
||||
|
||||
if state is not None and type(state) == list and hasattr(state[0], "device"):
|
||||
if type(state) == list and hasattr(state[0], "device"): # torch
|
||||
state = [tensor.to(devices[i]) for i, tensor in enumerate(state)]
|
||||
|
||||
quick_log(request, body, "Hit:\n" + prompt)
|
||||
|
@ -239,14 +239,9 @@ class AbstractRWKV(ABC):
|
||||
self.model_tokens = []
|
||||
else:
|
||||
delta_prompt = prompt[len(cache["prompt"]) :]
|
||||
state = cache["state"]
|
||||
self.model_state = (
|
||||
copy.deepcopy(state)
|
||||
if type(state) == list or type(state) == np.ndarray
|
||||
else state
|
||||
)
|
||||
self.model_tokens = copy.deepcopy(cache["tokens"])
|
||||
logits = copy.deepcopy(cache["logits"])
|
||||
self.model_state = cache["state"]
|
||||
self.model_tokens = cache["tokens"]
|
||||
logits = cache["logits"]
|
||||
|
||||
prompt_token_len = 0
|
||||
if delta_prompt != "":
|
||||
|
Loading…
Reference in New Issue
Block a user