improve state cache performance

This commit is contained in:
josc146 2023-12-28 22:15:31 +08:00
parent e083f2c629
commit 7f3cfd54b0
2 changed files with 14 additions and 28 deletions

View File

@ -94,28 +94,19 @@ def add_state(body: AddStateBody):
state: Union[Any, None] = None state: Union[Any, None] = None
if body.state is not None: if body.state is not None:
if type(body.state) == list or type(body.state) == np.ndarray: if type(state) == list and hasattr(state[0], "device"): # torch
devices = [ devices = [tensor.device for tensor in body.state]
( state = [tensor.cpu() for tensor in body.state]
tensor.device elif type(body.state) == np.ndarray: # rwkv.cpp
if hasattr(tensor, "device") state = body.state
else torch.device("cpu") else: # WebGPU
) state = body.state.back()
for tensor in body.state
]
state = (
[tensor.cpu() for tensor in body.state]
if hasattr(body.state[0], "device")
else copy.deepcopy(body.state)
)
else:
state = body.state.back() # WebGPU
id: int = trie.insert(body.prompt) id: int = trie.insert(body.prompt)
dtrie[id] = { dtrie[id] = {
"tokens": copy.deepcopy(body.tokens), "tokens": body.tokens,
"state": state, "state": state,
"logits": copy.deepcopy(body.logits), "logits": body.logits,
"devices": devices, "devices": devices,
} }
@ -199,12 +190,12 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
except: except:
pass pass
if id != -1: if id != -1:
prompt: str = trie[id]
v = dtrie[id] v = dtrie[id]
devices: List[torch.device] = v["devices"] devices: List[torch.device] = v["devices"]
prompt: str = trie[id]
state: Union[Any, None] = v["state"] state: Union[Any, None] = v["state"]
if state is not None and type(state) == list and hasattr(state[0], "device"): if type(state) == list and hasattr(state[0], "device"): # torch
state = [tensor.to(devices[i]) for i, tensor in enumerate(state)] state = [tensor.to(devices[i]) for i, tensor in enumerate(state)]
quick_log(request, body, "Hit:\n" + prompt) quick_log(request, body, "Hit:\n" + prompt)

View File

@ -239,14 +239,9 @@ class AbstractRWKV(ABC):
self.model_tokens = [] self.model_tokens = []
else: else:
delta_prompt = prompt[len(cache["prompt"]) :] delta_prompt = prompt[len(cache["prompt"]) :]
state = cache["state"] self.model_state = cache["state"]
self.model_state = ( self.model_tokens = cache["tokens"]
copy.deepcopy(state) logits = cache["logits"]
if type(state) == list or type(state) == np.ndarray
else state
)
self.model_tokens = copy.deepcopy(cache["tokens"])
logits = copy.deepcopy(cache["logits"])
prompt_token_len = 0 prompt_token_len = 0
if delta_prompt != "": if delta_prompt != "":