improve state cache performance

This commit is contained in:
josc146 2023-12-28 22:15:31 +08:00
parent e083f2c629
commit 7f3cfd54b0
2 changed files with 14 additions and 28 deletions

View File

@@ -94,28 +94,19 @@ def add_state(body: AddStateBody):
state: Union[Any, None] = None
if body.state is not None:
if type(body.state) == list or type(body.state) == np.ndarray:
devices = [
(
tensor.device
if hasattr(tensor, "device")
else torch.device("cpu")
)
for tensor in body.state
]
state = (
[tensor.cpu() for tensor in body.state]
if hasattr(body.state[0], "device")
else copy.deepcopy(body.state)
)
else:
state = body.state.back() # WebGPU
if type(state) == list and hasattr(state[0], "device"): # torch
devices = [tensor.device for tensor in body.state]
state = [tensor.cpu() for tensor in body.state]
elif type(body.state) == np.ndarray: # rwkv.cpp
state = body.state
else: # WebGPU
state = body.state.back()
id: int = trie.insert(body.prompt)
dtrie[id] = {
"tokens": copy.deepcopy(body.tokens),
"tokens": body.tokens,
"state": state,
"logits": copy.deepcopy(body.logits),
"logits": body.logits,
"devices": devices,
}
@@ -199,12 +190,12 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
except:
pass
if id != -1:
prompt: str = trie[id]
v = dtrie[id]
devices: List[torch.device] = v["devices"]
prompt: str = trie[id]
state: Union[Any, None] = v["state"]
if state is not None and type(state) == list and hasattr(state[0], "device"):
if type(state) == list and hasattr(state[0], "device"): # torch
state = [tensor.to(devices[i]) for i, tensor in enumerate(state)]
quick_log(request, body, "Hit:\n" + prompt)

View File

@@ -239,14 +239,9 @@ class AbstractRWKV(ABC):
self.model_tokens = []
else:
delta_prompt = prompt[len(cache["prompt"]) :]
state = cache["state"]
self.model_state = (
copy.deepcopy(state)
if type(state) == list or type(state) == np.ndarray
else state
)
self.model_tokens = copy.deepcopy(cache["tokens"])
logits = copy.deepcopy(cache["logits"])
self.model_state = cache["state"]
self.model_tokens = cache["tokens"]
logits = cache["logits"]
prompt_token_len = 0
if delta_prompt != "":