preliminary usable features
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Response, status
|
||||
from pydantic import BaseModel
|
||||
@@ -11,19 +10,14 @@ import global_var
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class UpdateConfigBody(BaseModel):
|
||||
model: str = None
|
||||
strategy: str = None
|
||||
max_response_token: int = None
|
||||
temperature: float = None
|
||||
top_p: float = None
|
||||
presence_penalty: float = None
|
||||
count_penalty: float = None
|
||||
class SwitchModelBody(BaseModel):
|
||||
model: str
|
||||
strategy: str
|
||||
|
||||
|
||||
@router.post("/update-config")
|
||||
def update_config(body: UpdateConfigBody, response: Response):
|
||||
if (global_var.get(global_var.Model_Status) is global_var.ModelStatus.Loading):
|
||||
@router.post("/switch-model")
|
||||
def switch_model(body: SwitchModelBody, response: Response):
|
||||
if global_var.get(global_var.Model_Status) is global_var.ModelStatus.Loading:
|
||||
response.status_code = status.HTTP_304_NOT_MODIFIED
|
||||
return
|
||||
|
||||
@@ -33,15 +27,34 @@ def update_config(body: UpdateConfigBody, response: Response):
|
||||
|
||||
global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
|
||||
try:
|
||||
global_var.set(global_var.Model, RWKV(
|
||||
model=sys.argv[2],
|
||||
strategy=sys.argv[1],
|
||||
tokens_path=f"{pathlib.Path(__file__).parent.parent.resolve()}/20B_tokenizer.json"
|
||||
))
|
||||
global_var.set(
|
||||
global_var.Model,
|
||||
RWKV(
|
||||
model=body.model,
|
||||
strategy=body.strategy,
|
||||
tokens_path=f"{pathlib.Path(__file__).parent.parent.resolve()}/20B_tokenizer.json",
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
|
||||
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "failed to load")
|
||||
|
||||
if global_var.get(global_var.Model_Config) is None:
|
||||
global_var.set(
|
||||
global_var.Model_Config, get_rwkv_config(global_var.get(global_var.Model))
|
||||
)
|
||||
global_var.set(global_var.Model_Status, global_var.ModelStatus.Working)
|
||||
|
||||
return "success"
|
||||
|
||||
|
||||
@router.post("/update-config")
|
||||
def update_config(body: ModelConfigBody):
|
||||
"""
|
||||
Will not update the model config immediately, but set it when completion called to avoid modifications during generation
|
||||
"""
|
||||
|
||||
print(body)
|
||||
global_var.set(global_var.Model_Config, body)
|
||||
|
||||
return "success"
|
||||
|
||||
Reference in New Issue
Block a user