From 994fc7c828f8f6538f7cb889dc235e6d36522387 Mon Sep 17 00:00:00 2001 From: josc146 Date: Tue, 11 Jul 2023 11:20:12 +0800 Subject: [PATCH] fix cross-device state cache exception --- backend-python/routes/config.py | 8 ++++++++ backend-python/routes/state_cache.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/backend-python/routes/config.py b/backend-python/routes/config.py index f6f6fc2..d4513f9 100644 --- a/backend-python/routes/config.py +++ b/backend-python/routes/config.py @@ -52,6 +52,14 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request): if body.model == "": return "success" + if "->" in body.strategy: + state_cache.disable_state_cache() + else: + try: + state_cache.enable_state_cache() + except HTTPException: + pass + os.environ["RWKV_CUDA_ON"] = "1" if body.customCuda else "0" global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading) diff --git a/backend-python/routes/state_cache.py b/backend-python/routes/state_cache.py index 11d7026..ec6d81b 100644 --- a/backend-python/routes/state_cache.py +++ b/backend-python/routes/state_cache.py @@ -34,6 +34,32 @@ def init(): print("cyac not found") +@router.post("/disable-state-cache") +def disable_state_cache(): + global trie, dtrie + + trie = None + dtrie = {} + gc.collect() + + return "success" + + +@router.post("/enable-state-cache") +def enable_state_cache(): + global trie, dtrie + try: + import cyac + + trie = cyac.Trie() + dtrie = {} + gc.collect() + + return "success" + except ModuleNotFoundError: + raise HTTPException(status.HTTP_400_BAD_REQUEST, "cyac not found") + + class AddStateBody(BaseModel): prompt: str tokens: List[str] @@ -85,6 +111,8 @@ def reset_state(): if trie is None: raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded") + import cyac + trie = cyac.Trie() dtrie = {} gc.collect()