Add deployment mode. If `/switch-model` is called with `deploy: true`, it will disable `/switch-model`, `/exit` and other dangerous APIs (state cache APIs, part of the MIDI APIs).
This commit is contained in:
parent
0594290b92
commit
7235e1067b
@ -4,6 +4,7 @@ Args = "args"
|
|||||||
Model = "model"
|
Model = "model"
|
||||||
Model_Status = "model_status"
|
Model_Status = "model_status"
|
||||||
Model_Config = "model_config"
|
Model_Config = "model_config"
|
||||||
|
Deploy_Mode = "deploy_mode"
|
||||||
|
|
||||||
|
|
||||||
class ModelStatus(Enum):
|
class ModelStatus(Enum):
|
||||||
@ -16,6 +17,7 @@ def init():
|
|||||||
global GLOBALS
|
global GLOBALS
|
||||||
GLOBALS = {}
|
GLOBALS = {}
|
||||||
set(Model_Status, ModelStatus.Offline)
|
set(Model_Status, ModelStatus.Offline)
|
||||||
|
set(Deploy_Mode, False)
|
||||||
|
|
||||||
|
|
||||||
def set(key, value):
|
def set(key, value):
|
||||||
|
@ -48,7 +48,7 @@ sys.path.append(os.path.dirname(os.path.realpath(__file__)))
|
|||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from fastapi import Depends, FastAPI
|
from fastapi import Depends, FastAPI, status
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
@ -86,6 +86,9 @@ app.include_router(state_cache.router)
|
|||||||
|
|
||||||
@app.post("/exit", tags=["Root"])
|
@app.post("/exit", tags=["Root"])
|
||||||
def exit():
|
def exit():
|
||||||
|
if global_var.get(global_var.Deploy_Mode) is True:
|
||||||
|
raise HTTPException(status.HTTP_403_FORBIDDEN)
|
||||||
|
|
||||||
parent_pid = os.getpid()
|
parent_pid = os.getpid()
|
||||||
parent = psutil.Process(parent_pid)
|
parent = psutil.Process(parent_pid)
|
||||||
for child in parent.children(recursive=True):
|
for child in parent.children(recursive=True):
|
||||||
|
@ -15,6 +15,10 @@ class SwitchModelBody(BaseModel):
|
|||||||
strategy: str
|
strategy: str
|
||||||
tokenizer: Union[str, None] = None
|
tokenizer: Union[str, None] = None
|
||||||
customCuda: bool = False
|
customCuda: bool = False
|
||||||
|
deploy: bool = Field(
|
||||||
|
False,
|
||||||
|
description="Deploy mode. If success, will disable /switch-model, /exit and other dangerous APIs (state cache APIs, part of midi APIs)",
|
||||||
|
)
|
||||||
|
|
||||||
model_config = {
|
model_config = {
|
||||||
"json_schema_extra": {
|
"json_schema_extra": {
|
||||||
@ -23,6 +27,7 @@ class SwitchModelBody(BaseModel):
|
|||||||
"strategy": "cuda fp16",
|
"strategy": "cuda fp16",
|
||||||
"tokenizer": None,
|
"tokenizer": None,
|
||||||
"customCuda": False,
|
"customCuda": False,
|
||||||
|
"deploy": False,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -30,6 +35,9 @@ class SwitchModelBody(BaseModel):
|
|||||||
|
|
||||||
@router.post("/switch-model", tags=["Configs"])
|
@router.post("/switch-model", tags=["Configs"])
|
||||||
def switch_model(body: SwitchModelBody, response: Response, request: Request):
|
def switch_model(body: SwitchModelBody, response: Response, request: Request):
|
||||||
|
if global_var.get(global_var.Deploy_Mode) is True:
|
||||||
|
raise HTTPException(Status.HTTP_403_FORBIDDEN)
|
||||||
|
|
||||||
if global_var.get(global_var.Model_Status) is global_var.ModelStatus.Loading:
|
if global_var.get(global_var.Model_Status) is global_var.ModelStatus.Loading:
|
||||||
response.status_code = Status.HTTP_304_NOT_MODIFIED
|
response.status_code = Status.HTTP_304_NOT_MODIFIED
|
||||||
return
|
return
|
||||||
@ -65,6 +73,8 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
|
|||||||
Status.HTTP_500_INTERNAL_SERVER_ERROR, f"failed to load: {e}"
|
Status.HTTP_500_INTERNAL_SERVER_ERROR, f"failed to load: {e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if body.deploy:
|
||||||
|
global_var.set(global_var.Deploy_Mode, True)
|
||||||
if global_var.get(global_var.Model_Config) is None:
|
if global_var.get(global_var.Model_Config) is None:
|
||||||
global_var.set(
|
global_var.set(
|
||||||
global_var.Model_Config, get_rwkv_config(global_var.get(global_var.Model))
|
global_var.Model_Config, get_rwkv_config(global_var.get(global_var.Model))
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import io
|
import io
|
||||||
|
import global_var
|
||||||
from fastapi import APIRouter, HTTPException, status
|
from fastapi import APIRouter, HTTPException, status
|
||||||
from starlette.responses import StreamingResponse
|
from starlette.responses import StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@ -48,6 +49,9 @@ class TxtToMidiBody(BaseModel):
|
|||||||
|
|
||||||
@router.post("/txt-to-midi", tags=["MIDI"])
|
@router.post("/txt-to-midi", tags=["MIDI"])
|
||||||
def txt_to_midi(body: TxtToMidiBody):
|
def txt_to_midi(body: TxtToMidiBody):
|
||||||
|
if global_var.get(global_var.Deploy_Mode) is True:
|
||||||
|
raise HTTPException(status.HTTP_403_FORBIDDEN)
|
||||||
|
|
||||||
if not body.midi_path.startswith("midi/"):
|
if not body.midi_path.startswith("midi/"):
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "bad output path")
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "bad output path")
|
||||||
|
|
||||||
@ -84,6 +88,9 @@ def midi_to_wav(body: MidiToWavBody):
|
|||||||
Install fluidsynth first, see more: https://github.com/FluidSynth/fluidsynth/wiki/Download#distributions
|
Install fluidsynth first, see more: https://github.com/FluidSynth/fluidsynth/wiki/Download#distributions
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if global_var.get(global_var.Deploy_Mode) is True:
|
||||||
|
raise HTTPException(status.HTTP_403_FORBIDDEN)
|
||||||
|
|
||||||
if not body.wav_path.startswith("midi/"):
|
if not body.wav_path.startswith("midi/"):
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "bad output path")
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "bad output path")
|
||||||
|
|
||||||
@ -115,6 +122,9 @@ def text_to_wav(body: TextToWavBody):
|
|||||||
Install fluidsynth first, see more: https://github.com/FluidSynth/fluidsynth/wiki/Download#distributions
|
Install fluidsynth first, see more: https://github.com/FluidSynth/fluidsynth/wiki/Download#distributions
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if global_var.get(global_var.Deploy_Mode) is True:
|
||||||
|
raise HTTPException(status.HTTP_403_FORBIDDEN)
|
||||||
|
|
||||||
text = body.text.strip()
|
text = body.text.strip()
|
||||||
if not text.startswith("<start>"):
|
if not text.startswith("<start>"):
|
||||||
text = "<start> " + text
|
text = "<start> " + text
|
||||||
|
@ -4,6 +4,7 @@ from fastapi import APIRouter, HTTPException, Request, Response, status
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import gc
|
import gc
|
||||||
import copy
|
import copy
|
||||||
|
import global_var
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@ -36,6 +37,9 @@ def init():
|
|||||||
def disable_state_cache():
|
def disable_state_cache():
|
||||||
global trie, dtrie
|
global trie, dtrie
|
||||||
|
|
||||||
|
if global_var.get(global_var.Deploy_Mode) is True:
|
||||||
|
raise HTTPException(status.HTTP_403_FORBIDDEN)
|
||||||
|
|
||||||
trie = None
|
trie = None
|
||||||
dtrie = {}
|
dtrie = {}
|
||||||
gc.collect()
|
gc.collect()
|
||||||
@ -46,6 +50,10 @@ def disable_state_cache():
|
|||||||
@router.post("/enable-state-cache", tags=["State Cache"])
|
@router.post("/enable-state-cache", tags=["State Cache"])
|
||||||
def enable_state_cache():
|
def enable_state_cache():
|
||||||
global trie, dtrie
|
global trie, dtrie
|
||||||
|
|
||||||
|
if global_var.get(global_var.Deploy_Mode) is True:
|
||||||
|
raise HTTPException(status.HTTP_403_FORBIDDEN)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import cyac
|
import cyac
|
||||||
|
|
||||||
@ -68,6 +76,10 @@ class AddStateBody(BaseModel):
|
|||||||
@router.post("/add-state", tags=["State Cache"])
|
@router.post("/add-state", tags=["State Cache"])
|
||||||
def add_state(body: AddStateBody):
|
def add_state(body: AddStateBody):
|
||||||
global trie, dtrie, loop_del_trie_id
|
global trie, dtrie, loop_del_trie_id
|
||||||
|
|
||||||
|
if global_var.get(global_var.Deploy_Mode) is True:
|
||||||
|
raise HTTPException(status.HTTP_403_FORBIDDEN)
|
||||||
|
|
||||||
if trie is None:
|
if trie is None:
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
||||||
|
|
||||||
@ -108,6 +120,10 @@ def add_state(body: AddStateBody):
|
|||||||
@router.post("/reset-state", tags=["State Cache"])
|
@router.post("/reset-state", tags=["State Cache"])
|
||||||
def reset_state():
|
def reset_state():
|
||||||
global trie, dtrie
|
global trie, dtrie
|
||||||
|
|
||||||
|
if global_var.get(global_var.Deploy_Mode) is True:
|
||||||
|
raise HTTPException(status.HTTP_403_FORBIDDEN)
|
||||||
|
|
||||||
if trie is None:
|
if trie is None:
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
||||||
|
|
||||||
@ -144,6 +160,10 @@ def __get_a_dtrie_buff_size(dtrie_v):
|
|||||||
@router.post("/longest-prefix-state", tags=["State Cache"])
|
@router.post("/longest-prefix-state", tags=["State Cache"])
|
||||||
def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
|
def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
|
||||||
global trie
|
global trie
|
||||||
|
|
||||||
|
if global_var.get(global_var.Deploy_Mode) is True:
|
||||||
|
raise HTTPException(status.HTTP_403_FORBIDDEN)
|
||||||
|
|
||||||
if trie is None:
|
if trie is None:
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
||||||
|
|
||||||
@ -183,6 +203,10 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
|
|||||||
@router.post("/save-state", tags=["State Cache"])
|
@router.post("/save-state", tags=["State Cache"])
|
||||||
def save_state():
|
def save_state():
|
||||||
global trie
|
global trie
|
||||||
|
|
||||||
|
if global_var.get(global_var.Deploy_Mode) is True:
|
||||||
|
raise HTTPException(status.HTTP_403_FORBIDDEN)
|
||||||
|
|
||||||
if trie is None:
|
if trie is None:
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user