add deployment mode. If /switch-model with deploy: true, will disable /switch-model, /exit and other dangerous APIs (state cache APIs, part of midi APIs)

This commit is contained in:
josc146
2023-11-08 23:29:42 +08:00
parent 0594290b92
commit 7235e1067b
5 changed files with 50 additions and 1 deletions

View File

@@ -4,6 +4,7 @@ from fastapi import APIRouter, HTTPException, Request, Response, status
from pydantic import BaseModel
import gc
import copy
import global_var
router = APIRouter()
@@ -36,6 +37,9 @@ def init():
def disable_state_cache():
global trie, dtrie
if global_var.get(global_var.Deploy_Mode) is True:
raise HTTPException(status.HTTP_403_FORBIDDEN)
trie = None
dtrie = {}
gc.collect()
@@ -46,6 +50,10 @@ def disable_state_cache():
@router.post("/enable-state-cache", tags=["State Cache"])
def enable_state_cache():
global trie, dtrie
if global_var.get(global_var.Deploy_Mode) is True:
raise HTTPException(status.HTTP_403_FORBIDDEN)
try:
import cyac
@@ -68,6 +76,10 @@ class AddStateBody(BaseModel):
@router.post("/add-state", tags=["State Cache"])
def add_state(body: AddStateBody):
global trie, dtrie, loop_del_trie_id
if global_var.get(global_var.Deploy_Mode) is True:
raise HTTPException(status.HTTP_403_FORBIDDEN)
if trie is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
@@ -108,6 +120,10 @@ def add_state(body: AddStateBody):
@router.post("/reset-state", tags=["State Cache"])
def reset_state():
global trie, dtrie
if global_var.get(global_var.Deploy_Mode) is True:
raise HTTPException(status.HTTP_403_FORBIDDEN)
if trie is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
@@ -144,6 +160,10 @@ def __get_a_dtrie_buff_size(dtrie_v):
@router.post("/longest-prefix-state", tags=["State Cache"])
def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
global trie
if global_var.get(global_var.Deploy_Mode) is True:
raise HTTPException(status.HTTP_403_FORBIDDEN)
if trie is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
@@ -183,6 +203,10 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
@router.post("/save-state", tags=["State Cache"])
def save_state():
global trie
if global_var.get(global_var.Deploy_Mode) is True:
raise HTTPException(status.HTTP_403_FORBIDDEN)
if trie is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")