diff --git a/backend-python/routes/state_cache.py b/backend-python/routes/state_cache.py index 647ffac..6178daa 100644 --- a/backend-python/routes/state_cache.py +++ b/backend-python/routes/state_cache.py @@ -94,28 +94,19 @@ def add_state(body: AddStateBody): state: Union[Any, None] = None if body.state is not None: - if type(body.state) == list or type(body.state) == np.ndarray: - devices = [ - ( - tensor.device - if hasattr(tensor, "device") - else torch.device("cpu") - ) - for tensor in body.state - ] - state = ( - [tensor.cpu() for tensor in body.state] - if hasattr(body.state[0], "device") - else copy.deepcopy(body.state) - ) - else: - state = body.state.back() # WebGPU + if type(body.state) == list and hasattr(body.state[0], "device"): # torch + devices = [tensor.device for tensor in body.state] + state = [tensor.cpu() for tensor in body.state] + elif type(body.state) == np.ndarray: # rwkv.cpp + state = body.state + else: # WebGPU + state = body.state.back() id: int = trie.insert(body.prompt) dtrie[id] = { - "tokens": copy.deepcopy(body.tokens), + "tokens": body.tokens, "state": state, - "logits": copy.deepcopy(body.logits), + "logits": body.logits, "devices": devices, } @@ -199,12 +190,12 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request): except: pass if id != -1: + prompt: str = trie[id] v = dtrie[id] devices: List[torch.device] = v["devices"] - prompt: str = trie[id] state: Union[Any, None] = v["state"] - if state is not None and type(state) == list and hasattr(state[0], "device"): + if type(state) == list and hasattr(state[0], "device"): # torch state = [tensor.to(devices[i]) for i, tensor in enumerate(state)] quick_log(request, body, "Hit:\n" + prompt) diff --git a/backend-python/utils/rwkv.py b/backend-python/utils/rwkv.py index 9109c47..2317a20 100644 --- a/backend-python/utils/rwkv.py +++ b/backend-python/utils/rwkv.py @@ -239,14 +239,9 @@ class AbstractRWKV(ABC): self.model_tokens = [] else: delta_prompt = prompt[len(cache["prompt"]) 
:] - state = cache["state"] - self.model_state = ( - copy.deepcopy(state) - if type(state) == list or type(state) == np.ndarray - else state - ) - self.model_tokens = copy.deepcopy(cache["tokens"]) - logits = copy.deepcopy(cache["logits"]) + self.model_state = cache["state"] + self.model_tokens = cache["tokens"] + logits = cache["logits"] prompt_token_len = 0 if delta_prompt != "":