diff --git a/backend-python/routes/config.py b/backend-python/routes/config.py
index f6f6fc2..d4513f9 100644
--- a/backend-python/routes/config.py
+++ b/backend-python/routes/config.py
@@ -52,6 +52,18 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
     if body.model == "":
         return "success"
 
+    # A strategy containing "->" turns the state cache off; any other
+    # strategy (re)creates it.
+    if "->" in body.strategy:
+        state_cache.disable_state_cache()
+    else:
+        try:
+            state_cache.enable_state_cache()
+        except HTTPException:
+            # Best effort: enable_state_cache() answers 400 when the cyac
+            # module cannot be imported; the model switch proceeds regardless.
+            pass
+
     os.environ["RWKV_CUDA_ON"] = "1" if body.customCuda else "0"
 
     global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
diff --git a/backend-python/routes/state_cache.py b/backend-python/routes/state_cache.py
index 11d7026..ec6d81b 100644
--- a/backend-python/routes/state_cache.py
+++ b/backend-python/routes/state_cache.py
@@ -34,6 +34,54 @@ def init():
         print("cyac not found")
 
 
+@router.post("/disable-state-cache")
+def disable_state_cache():
+    """Turn the state cache off and drop everything it holds.
+
+    Sets the module-level ``trie`` to ``None`` and ``dtrie`` to an empty
+    dict, then forces a garbage collection.
+
+    Returns:
+        The string ``"success"``.
+    """
+    global trie, dtrie
+
+    trie = None
+    dtrie = {}
+    gc.collect()
+
+    return "success"
+
+
+@router.post("/enable-state-cache")
+def enable_state_cache():
+    """Turn the state cache on, starting from an empty cache.
+
+    Rebinds the module-level ``trie`` to a fresh ``cyac.Trie`` and ``dtrie``
+    to an empty dict, then forces a garbage collection.
+
+    Returns:
+        The string ``"success"``.
+
+    Raises:
+        HTTPException: 400 "cyac not found" if ``cyac`` cannot be imported.
+    """
+    global trie, dtrie
+
+    # Only the import is guarded, so the "cyac not found" answer is reserved
+    # for the module really being absent.
+    try:
+        import cyac
+    except ModuleNotFoundError as err:
+        raise HTTPException(status.HTTP_400_BAD_REQUEST, "cyac not found") from err
+
+    trie = cyac.Trie()
+    dtrie = {}
+    gc.collect()
+
+    return "success"
+
+
 class AddStateBody(BaseModel):
     prompt: str
     tokens: List[str]
@@ -85,6 +133,8 @@ def reset_state():
     if trie is None:
         raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
 
+    import cyac
+
     trie = cyac.Trie()
     dtrie = {}
     gc.collect()