Add deployment mode. If /switch-model is called with deploy: true, it will disable /switch-model, /exit, and other dangerous APIs (the state cache APIs and part of the MIDI APIs).

This commit is contained in:
josc146 2023-11-08 23:29:42 +08:00
parent 0594290b92
commit 7235e1067b
5 changed files with 50 additions and 1 deletions

View File

@ -4,6 +4,7 @@ Args = "args"
Model = "model" Model = "model"
Model_Status = "model_status" Model_Status = "model_status"
Model_Config = "model_config" Model_Config = "model_config"
Deploy_Mode = "deploy_mode"
class ModelStatus(Enum): class ModelStatus(Enum):
@ -16,6 +17,7 @@ def init():
global GLOBALS global GLOBALS
GLOBALS = {} GLOBALS = {}
set(Model_Status, ModelStatus.Offline) set(Model_Status, ModelStatus.Offline)
set(Deploy_Mode, False)
def set(key, value): def set(key, value):

View File

@ -48,7 +48,7 @@ sys.path.append(os.path.dirname(os.path.realpath(__file__)))
import psutil import psutil
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi import Depends, FastAPI from fastapi import Depends, FastAPI, status
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
import uvicorn import uvicorn
@ -86,6 +86,9 @@ app.include_router(state_cache.router)
@app.post("/exit", tags=["Root"]) @app.post("/exit", tags=["Root"])
def exit(): def exit():
if global_var.get(global_var.Deploy_Mode) is True:
raise HTTPException(status.HTTP_403_FORBIDDEN)
parent_pid = os.getpid() parent_pid = os.getpid()
parent = psutil.Process(parent_pid) parent = psutil.Process(parent_pid)
for child in parent.children(recursive=True): for child in parent.children(recursive=True):

View File

@ -15,6 +15,10 @@ class SwitchModelBody(BaseModel):
strategy: str strategy: str
tokenizer: Union[str, None] = None tokenizer: Union[str, None] = None
customCuda: bool = False customCuda: bool = False
deploy: bool = Field(
False,
description="Deploy mode. If success, will disable /switch-model, /exit and other dangerous APIs (state cache APIs, part of midi APIs)",
)
model_config = { model_config = {
"json_schema_extra": { "json_schema_extra": {
@ -23,6 +27,7 @@ class SwitchModelBody(BaseModel):
"strategy": "cuda fp16", "strategy": "cuda fp16",
"tokenizer": None, "tokenizer": None,
"customCuda": False, "customCuda": False,
"deploy": False,
} }
} }
} }
@ -30,6 +35,9 @@ class SwitchModelBody(BaseModel):
@router.post("/switch-model", tags=["Configs"]) @router.post("/switch-model", tags=["Configs"])
def switch_model(body: SwitchModelBody, response: Response, request: Request): def switch_model(body: SwitchModelBody, response: Response, request: Request):
if global_var.get(global_var.Deploy_Mode) is True:
raise HTTPException(Status.HTTP_403_FORBIDDEN)
if global_var.get(global_var.Model_Status) is global_var.ModelStatus.Loading: if global_var.get(global_var.Model_Status) is global_var.ModelStatus.Loading:
response.status_code = Status.HTTP_304_NOT_MODIFIED response.status_code = Status.HTTP_304_NOT_MODIFIED
return return
@ -65,6 +73,8 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
Status.HTTP_500_INTERNAL_SERVER_ERROR, f"failed to load: {e}" Status.HTTP_500_INTERNAL_SERVER_ERROR, f"failed to load: {e}"
) )
if body.deploy:
global_var.set(global_var.Deploy_Mode, True)
if global_var.get(global_var.Model_Config) is None: if global_var.get(global_var.Model_Config) is None:
global_var.set( global_var.set(
global_var.Model_Config, get_rwkv_config(global_var.get(global_var.Model)) global_var.Model_Config, get_rwkv_config(global_var.get(global_var.Model))

View File

@ -1,4 +1,5 @@
import io import io
import global_var
from fastapi import APIRouter, HTTPException, status from fastapi import APIRouter, HTTPException, status
from starlette.responses import StreamingResponse from starlette.responses import StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
@ -48,6 +49,9 @@ class TxtToMidiBody(BaseModel):
@router.post("/txt-to-midi", tags=["MIDI"]) @router.post("/txt-to-midi", tags=["MIDI"])
def txt_to_midi(body: TxtToMidiBody): def txt_to_midi(body: TxtToMidiBody):
if global_var.get(global_var.Deploy_Mode) is True:
raise HTTPException(status.HTTP_403_FORBIDDEN)
if not body.midi_path.startswith("midi/"): if not body.midi_path.startswith("midi/"):
raise HTTPException(status.HTTP_400_BAD_REQUEST, "bad output path") raise HTTPException(status.HTTP_400_BAD_REQUEST, "bad output path")
@ -84,6 +88,9 @@ def midi_to_wav(body: MidiToWavBody):
Install fluidsynth first, see more: https://github.com/FluidSynth/fluidsynth/wiki/Download#distributions Install fluidsynth first, see more: https://github.com/FluidSynth/fluidsynth/wiki/Download#distributions
""" """
if global_var.get(global_var.Deploy_Mode) is True:
raise HTTPException(status.HTTP_403_FORBIDDEN)
if not body.wav_path.startswith("midi/"): if not body.wav_path.startswith("midi/"):
raise HTTPException(status.HTTP_400_BAD_REQUEST, "bad output path") raise HTTPException(status.HTTP_400_BAD_REQUEST, "bad output path")
@ -115,6 +122,9 @@ def text_to_wav(body: TextToWavBody):
Install fluidsynth first, see more: https://github.com/FluidSynth/fluidsynth/wiki/Download#distributions Install fluidsynth first, see more: https://github.com/FluidSynth/fluidsynth/wiki/Download#distributions
""" """
if global_var.get(global_var.Deploy_Mode) is True:
raise HTTPException(status.HTTP_403_FORBIDDEN)
text = body.text.strip() text = body.text.strip()
if not text.startswith("<start>"): if not text.startswith("<start>"):
text = "<start> " + text text = "<start> " + text

View File

@ -4,6 +4,7 @@ from fastapi import APIRouter, HTTPException, Request, Response, status
from pydantic import BaseModel from pydantic import BaseModel
import gc import gc
import copy import copy
import global_var
router = APIRouter() router = APIRouter()
@ -36,6 +37,9 @@ def init():
def disable_state_cache(): def disable_state_cache():
global trie, dtrie global trie, dtrie
if global_var.get(global_var.Deploy_Mode) is True:
raise HTTPException(status.HTTP_403_FORBIDDEN)
trie = None trie = None
dtrie = {} dtrie = {}
gc.collect() gc.collect()
@ -46,6 +50,10 @@ def disable_state_cache():
@router.post("/enable-state-cache", tags=["State Cache"]) @router.post("/enable-state-cache", tags=["State Cache"])
def enable_state_cache(): def enable_state_cache():
global trie, dtrie global trie, dtrie
if global_var.get(global_var.Deploy_Mode) is True:
raise HTTPException(status.HTTP_403_FORBIDDEN)
try: try:
import cyac import cyac
@ -68,6 +76,10 @@ class AddStateBody(BaseModel):
@router.post("/add-state", tags=["State Cache"]) @router.post("/add-state", tags=["State Cache"])
def add_state(body: AddStateBody): def add_state(body: AddStateBody):
global trie, dtrie, loop_del_trie_id global trie, dtrie, loop_del_trie_id
if global_var.get(global_var.Deploy_Mode) is True:
raise HTTPException(status.HTTP_403_FORBIDDEN)
if trie is None: if trie is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded") raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
@ -108,6 +120,10 @@ def add_state(body: AddStateBody):
@router.post("/reset-state", tags=["State Cache"]) @router.post("/reset-state", tags=["State Cache"])
def reset_state(): def reset_state():
global trie, dtrie global trie, dtrie
if global_var.get(global_var.Deploy_Mode) is True:
raise HTTPException(status.HTTP_403_FORBIDDEN)
if trie is None: if trie is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded") raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
@ -144,6 +160,10 @@ def __get_a_dtrie_buff_size(dtrie_v):
@router.post("/longest-prefix-state", tags=["State Cache"]) @router.post("/longest-prefix-state", tags=["State Cache"])
def longest_prefix_state(body: LongestPrefixStateBody, request: Request): def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
global trie global trie
if global_var.get(global_var.Deploy_Mode) is True:
raise HTTPException(status.HTTP_403_FORBIDDEN)
if trie is None: if trie is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded") raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
@ -183,6 +203,10 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
@router.post("/save-state", tags=["State Cache"]) @router.post("/save-state", tags=["State Cache"])
def save_state(): def save_state():
global trie global trie
if global_var.get(global_var.Deploy_Mode) is True:
raise HTTPException(status.HTTP_403_FORBIDDEN)
if trie is None: if trie is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded") raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")