add deployment mode. If /switch-model with deploy: true, will disable /switch-model, /exit and other dangerous APIs (state cache APIs, part of midi APIs)
This commit is contained in:
@@ -15,6 +15,10 @@ class SwitchModelBody(BaseModel):
|
||||
strategy: str
|
||||
tokenizer: Union[str, None] = None
|
||||
customCuda: bool = False
|
||||
deploy: bool = Field(
|
||||
False,
|
||||
description="Deploy mode. If success, will disable /switch-model, /exit and other dangerous APIs (state cache APIs, part of midi APIs)",
|
||||
)
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
@@ -23,6 +27,7 @@ class SwitchModelBody(BaseModel):
|
||||
"strategy": "cuda fp16",
|
||||
"tokenizer": None,
|
||||
"customCuda": False,
|
||||
"deploy": False,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -30,6 +35,9 @@ class SwitchModelBody(BaseModel):
|
||||
|
||||
@router.post("/switch-model", tags=["Configs"])
|
||||
def switch_model(body: SwitchModelBody, response: Response, request: Request):
|
||||
if global_var.get(global_var.Deploy_Mode) is True:
|
||||
raise HTTPException(Status.HTTP_403_FORBIDDEN)
|
||||
|
||||
if global_var.get(global_var.Model_Status) is global_var.ModelStatus.Loading:
|
||||
response.status_code = Status.HTTP_304_NOT_MODIFIED
|
||||
return
|
||||
@@ -65,6 +73,8 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
|
||||
Status.HTTP_500_INTERNAL_SERVER_ERROR, f"failed to load: {e}"
|
||||
)
|
||||
|
||||
if body.deploy:
|
||||
global_var.set(global_var.Deploy_Mode, True)
|
||||
if global_var.get(global_var.Model_Config) is None:
|
||||
global_var.set(
|
||||
global_var.Model_Config, get_rwkv_config(global_var.get(global_var.Model))
|
||||
|
||||
Reference in New Issue
Block a user