From d9e25ad69fdb2e3fb1356a15eb3bfc1c49facb17 Mon Sep 17 00:00:00 2001
From: josc146
Date: Fri, 8 Dec 2023 15:28:33 +0800
Subject: [PATCH] better state cache

---
 backend-python/routes/config.py      | 27 ++++++++++++++++++++-------
 backend-python/routes/state_cache.py | 26 +++++++++-----------------
 2 files changed, 29 insertions(+), 24 deletions(-)

diff --git a/backend-python/routes/config.py b/backend-python/routes/config.py
index 6862723..da5bf28 100644
--- a/backend-python/routes/config.py
+++ b/backend-python/routes/config.py
@@ -49,13 +49,26 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
     if body.model == "":
         return "success"
 
-    if "->" in body.strategy:
-        state_cache.disable_state_cache()
-    else:
-        try:
-            state_cache.enable_state_cache()
-        except HTTPException:
-            pass
+    STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps|dml) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$"
+    if not re.match(STRATEGY_REGEX, body.strategy):
+        raise HTTPException(
+            Status.HTTP_400_BAD_REQUEST,
+            "Invalid strategy. Please read https://pypi.org/project/rwkv/",
+        )
+    devices = set(
+        [
+            x.strip().split(" ")[0].replace("cuda:0", "cuda")
+            for x in body.strategy.split("->")
+        ]
+    )
+    print(f"Devices: {devices}")
+    # if len(devices) > 1:
+    #     state_cache.disable_state_cache()
+    # else:
+    try:
+        state_cache.enable_state_cache()
+    except HTTPException:
+        pass
 
     os.environ["RWKV_CUDA_ON"] = "1" if body.customCuda else "0"
 
diff --git a/backend-python/routes/state_cache.py b/backend-python/routes/state_cache.py
index 4b7ee8b..5d02f57 100644
--- a/backend-python/routes/state_cache.py
+++ b/backend-python/routes/state_cache.py
@@ -44,6 +44,7 @@ def disable_state_cache():
     dtrie = {}
     gc.collect()
 
+    print("state cache disabled")
     return "success"
 
 
@@ -61,8 +62,10 @@ def enable_state_cache():
         dtrie = {}
         gc.collect()
 
+        print("state cache enabled")
         return "success"
     except ModuleNotFoundError:
+        print("state cache disabled")
         raise HTTPException(status.HTTP_400_BAD_REQUEST, "cyac not found")
 
 
@@ -87,14 +90,12 @@ def add_state(body: AddStateBody):
 
     try:
         id: int = trie.insert(body.prompt)
-        device: torch.device = body.state[0].device
+        devices: List[torch.device] = [tensor.device for tensor in body.state]
         dtrie[id] = {
             "tokens": copy.deepcopy(body.tokens),
-            "state": [tensor.cpu() for tensor in body.state]
-            if device != torch.device("cpu")
-            else copy.deepcopy(body.state),
+            "state": [tensor.cpu() for tensor in body.state],
             "logits": copy.deepcopy(body.logits),
-            "device": device,
+            "devices": devices,
         }
 
         if len(trie) >= max_trie_len:
@@ -177,27 +178,18 @@
             pass
 
     if id != -1:
         v = dtrie[id]
-        device: torch.device = v["device"]
+        devices: List[torch.device] = v["devices"]
         prompt: str = trie[id]
         quick_log(request, body, "Hit:\n" + prompt)
         return {
             "prompt": prompt,
             "tokens": v["tokens"],
-            "state": [tensor.to(device) for tensor in v["state"]]
-            if device != torch.device("cpu")
-            else v["state"],
+            "state": [tensor.to(devices[i]) for i, tensor in enumerate(v["state"])],
             "logits": v["logits"],
-            "device": device.type,
         }
     else:
-        return {
-            "prompt": "",
-            "tokens": [],
-            "state": None,
-            "logits": None,
-            "device": None,
-        }
+        return {"prompt": "", "tokens": [], "state": None, "logits": None}
 
 
 # @router.post("/save-state", tags=["State Cache"])