RWKV-Runner/backend-python/routes/config.py

48 lines
1.4 KiB
Python
Raw Normal View History

2023-05-07 17:27:54 +08:00
import pathlib
import sys
from typing import Optional

from fastapi import APIRouter, HTTPException, Response, status
from pydantic import BaseModel
from langchain.llms import RWKV

from utils.rwkv import *
from utils.torch import *
import global_var
# Router holding the configuration endpoints defined in this module.
router = APIRouter()


class UpdateConfigBody(BaseModel):
    """Request body for ``POST /update-config``.

    Every field is optional; ``None`` means the client did not supply a
    value for it.
    """

    # Value passed as ``model`` to langchain's RWKV (presumably a path to
    # the model weights — confirm against the frontend).
    model: Optional[str] = None
    # Value passed as ``strategy`` to langchain's RWKV (device/precision
    # spec string).
    strategy: Optional[str] = None
    # Upper bound on tokens generated per response.
    max_response_token: Optional[int] = None
    # Sampling parameters.
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    # Repetition penalties.
    presence_penalty: Optional[float] = None
    count_penalty: Optional[float] = None
@router.post("/update-config")
def update_config(body: UpdateConfigBody, response: Response):
    """Unload the current model and load a new RWKV model.

    The model and strategy come from the request body when provided.
    Otherwise they fall back to the command-line arguments
    (``sys.argv[2]`` for the model, ``sys.argv[1]`` for the strategy).
    Sampling fields set in the body are forwarded to langchain's ``RWKV``.
    Fields left as ``None`` keep that class's defaults.

    Args:
        body: Requested configuration; any field may be ``None``.
        response: Used to set a 304 status when a load is already running.

    Returns:
        ``"success"`` once the new model is loaded.
        ``None`` (with status 304) if a load is already in progress.

    Raises:
        HTTPException: 500 "failed to load" if building the model fails.
    """
    # Refuse to start a second load while one is already running.
    if global_var.get(global_var.Model_Status) is global_var.ModelStatus.Loading:
        response.status_code = status.HTTP_304_NOT_MODIFIED
        return

    # Drop the old model and free memory before loading the new one.
    global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
    global_var.set(global_var.Model, None)
    torch_gc()

    global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
    try:
        # Argument building stays inside the try.
        # A missing sys.argv entry is then reported as "failed to load"
        # and the status reset to Offline, instead of staying Loading.
        model_kwargs = {
            "model": body.model if body.model is not None else sys.argv[2],
            "strategy": body.strategy if body.strategy is not None else sys.argv[1],
            "tokens_path": f"{pathlib.Path(__file__).parent.parent.resolve()}/20B_tokenizer.json",
        }
        # Request field -> langchain RWKV field. Only values the client set
        # are forwarded, so RWKV's own defaults apply to the rest.
        sampling_fields = {
            "max_response_token": "max_tokens_per_generation",
            "temperature": "temperature",
            "top_p": "top_p",
            "presence_penalty": "penalty_alpha_presence",
            "count_penalty": "penalty_alpha_frequency",
        }
        for body_field, rwkv_field in sampling_fields.items():
            value = getattr(body, body_field)
            if value is not None:
                model_kwargs[rwkv_field] = value

        model = RWKV(**model_kwargs)
    except Exception as e:
        global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
        # Chain the cause so the real load error is visible in server logs.
        raise HTTPException(
            status.HTTP_500_INTERNAL_SERVER_ERROR, "failed to load"
        ) from e

    global_var.set(global_var.Model, model)
    global_var.set(global_var.Model_Status, global_var.ModelStatus.Working)
    return "success"