diff --git a/backend-python/global_var.py b/backend-python/global_var.py index 5f1215a..4553356 100644 --- a/backend-python/global_var.py +++ b/backend-python/global_var.py @@ -4,6 +4,7 @@ Args = "args" Model = "model" Model_Status = "model_status" Model_Config = "model_config" +Deploy_Mode = "deploy_mode" class ModelStatus(Enum): @@ -16,6 +17,7 @@ def init(): global GLOBALS GLOBALS = {} set(Model_Status, ModelStatus.Offline) + set(Deploy_Mode, False) def set(key, value): diff --git a/backend-python/main.py b/backend-python/main.py index 3f8e9d9..d098b1b 100644 --- a/backend-python/main.py +++ b/backend-python/main.py @@ -48,7 +48,7 @@ sys.path.append(os.path.dirname(os.path.realpath(__file__))) import psutil from contextlib import asynccontextmanager -from fastapi import Depends, FastAPI +from fastapi import Depends, FastAPI, status from fastapi.middleware.cors import CORSMiddleware import uvicorn @@ -86,6 +86,9 @@ app.include_router(state_cache.router) @app.post("/exit", tags=["Root"]) def exit(): + if global_var.get(global_var.Deploy_Mode) is True: + raise HTTPException(status.HTTP_403_FORBIDDEN) + parent_pid = os.getpid() parent = psutil.Process(parent_pid) for child in parent.children(recursive=True): diff --git a/backend-python/routes/config.py b/backend-python/routes/config.py index 7d5246c..ba489b8 100644 --- a/backend-python/routes/config.py +++ b/backend-python/routes/config.py @@ -15,6 +15,10 @@ class SwitchModelBody(BaseModel): strategy: str tokenizer: Union[str, None] = None customCuda: bool = False + deploy: bool = Field( + False, + description="Deploy mode. If success, will disable /switch-model, /exit and other dangerous APIs (state cache APIs, part of midi APIs)", + ) model_config = { "json_schema_extra": { @@ -23,6 +27,7 @@ class SwitchModelBody(BaseModel): "strategy": "cuda fp16", "tokenizer": None, "customCuda": False, + "deploy": False, } } } @@ -30,6 +35,9 @@ class SwitchModelBody(BaseModel): @router.post("/switch-model", tags=["Configs"]) def switch_model(body: SwitchModelBody, response: Response, request: Request): + if global_var.get(global_var.Deploy_Mode) is True: + raise HTTPException(Status.HTTP_403_FORBIDDEN) + if global_var.get(global_var.Model_Status) is global_var.ModelStatus.Loading: response.status_code = Status.HTTP_304_NOT_MODIFIED return @@ -65,6 +73,8 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request): Status.HTTP_500_INTERNAL_SERVER_ERROR, f"failed to load: {e}" ) + if body.deploy: + global_var.set(global_var.Deploy_Mode, True) if global_var.get(global_var.Model_Config) is None: global_var.set( global_var.Model_Config, get_rwkv_config(global_var.get(global_var.Model)) diff --git a/backend-python/routes/midi.py b/backend-python/routes/midi.py index c0b10b3..8c7bcc2 100644 --- a/backend-python/routes/midi.py +++ b/backend-python/routes/midi.py @@ -1,4 +1,5 @@ import io +import global_var from fastapi import APIRouter, HTTPException, status from starlette.responses import StreamingResponse from pydantic import BaseModel @@ -48,6 +49,9 @@ class TxtToMidiBody(BaseModel): @router.post("/txt-to-midi", tags=["MIDI"]) def txt_to_midi(body: TxtToMidiBody): + if global_var.get(global_var.Deploy_Mode) is True: + raise HTTPException(status.HTTP_403_FORBIDDEN) + if not body.midi_path.startswith("midi/"): raise HTTPException(status.HTTP_400_BAD_REQUEST, "bad output path") @@ -84,6 +88,9 @@ def midi_to_wav(body: MidiToWavBody): Install fluidsynth first, see more: https://github.com/FluidSynth/fluidsynth/wiki/Download#distributions """ + if global_var.get(global_var.Deploy_Mode) is True: + raise HTTPException(status.HTTP_403_FORBIDDEN) + if not body.wav_path.startswith("midi/"): raise HTTPException(status.HTTP_400_BAD_REQUEST, "bad output path") @@ -115,6 +122,9 @@ def text_to_wav(body: TextToWavBody): Install fluidsynth first, see more: https://github.com/FluidSynth/fluidsynth/wiki/Download#distributions """ + if global_var.get(global_var.Deploy_Mode) is True: + raise HTTPException(status.HTTP_403_FORBIDDEN) + text = body.text.strip() if not text.startswith(""): text = " " + text diff --git a/backend-python/routes/state_cache.py b/backend-python/routes/state_cache.py index 443329d..bc403ff 100644 --- a/backend-python/routes/state_cache.py +++ b/backend-python/routes/state_cache.py @@ -4,6 +4,7 @@ from fastapi import APIRouter, HTTPException, Request, Response, status from pydantic import BaseModel import gc import copy +import global_var router = APIRouter() @@ -36,6 +37,9 @@ def init(): def disable_state_cache(): global trie, dtrie + if global_var.get(global_var.Deploy_Mode) is True: + raise HTTPException(status.HTTP_403_FORBIDDEN) + trie = None dtrie = {} gc.collect() @@ -46,6 +50,10 @@ def disable_state_cache(): @router.post("/enable-state-cache", tags=["State Cache"]) def enable_state_cache(): global trie, dtrie + + if global_var.get(global_var.Deploy_Mode) is True: + raise HTTPException(status.HTTP_403_FORBIDDEN) + try: import cyac @@ -68,6 +76,10 @@ class AddStateBody(BaseModel): @router.post("/add-state", tags=["State Cache"]) def add_state(body: AddStateBody): global trie, dtrie, loop_del_trie_id + + if global_var.get(global_var.Deploy_Mode) is True: + raise HTTPException(status.HTTP_403_FORBIDDEN) + if trie is None: raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded") @@ -108,6 +120,10 @@ def add_state(body: AddStateBody): @router.post("/reset-state", tags=["State Cache"]) def reset_state(): global trie, dtrie + + if global_var.get(global_var.Deploy_Mode) is True: + raise HTTPException(status.HTTP_403_FORBIDDEN) + if trie is None: raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded") @@ -144,6 +160,10 @@ def __get_a_dtrie_buff_size(dtrie_v): @router.post("/longest-prefix-state", tags=["State Cache"]) def longest_prefix_state(body: LongestPrefixStateBody, request: Request): global trie + + if global_var.get(global_var.Deploy_Mode) is True: + raise HTTPException(status.HTTP_403_FORBIDDEN) + if trie is None: raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded") @@ -183,6 +203,10 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request): @router.post("/save-state", tags=["State Cache"]) def save_state(): global trie + + if global_var.get(global_var.Deploy_Mode) is True: + raise HTTPException(status.HTTP_403_FORBIDDEN) + if trie is None: raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")