diff --git a/backend-python/routes/state_cache.py b/backend-python/routes/state_cache.py
index 6c124be..17f373c 100644
--- a/backend-python/routes/state_cache.py
+++ b/backend-python/routes/state_cache.py
@@ -76,6 +76,31 @@ class AddStateBody(BaseModel):
     logits: Any
 
 
+def copy_tensor_to_cpu(tensors):
+    import torch
+    import numpy as np
+
+    devices: List[torch.device] = []
+    copied: Union[Any, None] = None
+
+    tensors_type = type(tensors)
+    if tensors_type == list:
+        if hasattr(tensors[0], "device"):  # torch state
+            devices = [tensor.device for tensor in tensors]
+            copied = [tensor.cpu() for tensor in tensors]
+        else:  # WebGPU logits
+            copied = tensors
+    elif tensors_type == torch.Tensor:  # torch logits
+        devices = [tensors.device]
+        copied = tensors.cpu()
+    elif tensors_type == np.ndarray:  # rwkv.cpp
+        copied = tensors
+    else:  # WebGPU state
+        copied = tensors.back()
+
+    return copied, devices
+
+
 # @router.post("/add-state", tags=["State Cache"])
 def add_state(body: AddStateBody):
     global trie, dtrie, loop_del_trie_id
@@ -91,23 +116,24 @@ def add_state(body: AddStateBody):
 
     try:
         devices: List[torch.device] = []
+        logits_device: Union[torch.device, None] = None
         state: Union[Any, None] = None
+        logits: Union[Any, None] = None
 
         if body.state is not None:
-            if type(body.state) == list and hasattr(body.state[0], "device"):  # torch
-                devices = [tensor.device for tensor in body.state]
-                state = [tensor.cpu() for tensor in body.state]
-            elif type(body.state) == np.ndarray:  # rwkv.cpp
-                state = body.state
-            else:  # WebGPU
-                state = body.state.back()
+            state, devices = copy_tensor_to_cpu(body.state)
+        if body.logits is not None:
+            logits, logits_devices = copy_tensor_to_cpu(body.logits)
+            if len(logits_devices) > 0:
+                logits_device = logits_devices[0]
 
         id: int = trie.insert(body.prompt)
         dtrie[id] = {
             "tokens": body.tokens,
             "state": state,
-            "logits": body.logits,
+            "logits": logits,
             "devices": devices,
+            "logits_device": logits_device,
         }
 
         if len(trie) >= max_trie_len:
@@ -125,6 +151,7 @@ def add_state(body: AddStateBody):
         )
         return "success"
     except Exception as e:
+        print(e)  # should not happen
        raise HTTPException(
            status.HTTP_400_BAD_REQUEST, f"insert failed, bad prompt.\n{e}"
        )
@@ -192,18 +219,33 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
     if id != -1:
         prompt: str = trie[id]
         v = dtrie[id]
+        tokens: List[Union[str, int]] = copy.deepcopy(v["tokens"])
         devices: List[torch.device] = v["devices"]
+        logits_device: Union[torch.device, None] = v["logits_device"]
         state: Union[Any, None] = v["state"]
+        logits: Union[Any, None] = v["logits"]
 
         if type(state) == list and hasattr(state[0], "device"):  # torch
-            state = [tensor.to(devices[i]) for i, tensor in enumerate(state)]
+            state = [
+                tensor.to(devices[i])
+                if devices[i] != torch.device("cpu")
+                else tensor.clone()
+                for i, tensor in enumerate(state)
+            ]
+            logits = (
+                logits.to(logits_device)
+                if logits_device != torch.device("cpu")
+                else logits.clone()
+            )
+        else:  # rwkv.cpp, WebGPU
+            logits = np.copy(logits)
 
         quick_log(request, body, "Hit:\n" + prompt)
         return {
             "prompt": prompt,
-            "tokens": v["tokens"],
+            "tokens": tokens,
             "state": state,
-            "logits": v["logits"],
+            "logits": logits,
         }
     else:
         return {"prompt": "", "tokens": [], "state": None, "logits": None}