RWKV-Runner/backend-python/routes/config.py
2023-05-19 15:59:04 +08:00

67 lines
1.9 KiB
Python

import pathlib
from fastapi import APIRouter, HTTPException, Response, status
from pydantic import BaseModel
from langchain.llms import RWKV
from utils.rwkv import *
from utils.torch import *
import global_var
# Router that every endpoint in this module registers on via @router.post / @router.get.
router = APIRouter()
class SwitchModelBody(BaseModel):
    """Request body for ``POST /switch-model``."""

    # Forwarded as ``model=`` to the langchain ``RWKV`` constructor
    # (presumably a path to the model weights file — TODO confirm).
    model: str
    # Forwarded as ``strategy=`` to the ``RWKV`` constructor
    # (presumably the RWKV device/precision strategy string — TODO confirm format).
    strategy: str
@router.post("/switch-model")
def switch_model(body: SwitchModelBody, response: Response):
    """Unload the current model and load the one described by ``body``.

    Returns ``"success"`` once the new model is loaded and the global status is
    ``Working``. If a load is already in progress, the request is ignored and
    answered with HTTP 304.

    Raises:
        HTTPException: 500 if constructing the model (or deriving its default
            config) fails; the global status is reset to ``Offline`` so later
            switch requests are not blocked.
    """
    # The module-level name ``status`` is rebound by the ``status`` route handler
    # defined further down in this module, so ``status.HTTP_*`` would resolve to
    # that function at call time. Import fastapi's status codes under a distinct
    # local name instead.
    from fastapi import status as http_status

    if global_var.get(global_var.Model_Status) is global_var.ModelStatus.Loading:
        response.status_code = http_status.HTTP_304_NOT_MODIFIED
        return

    # Drop the old model and reclaim its memory before loading the new one.
    global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
    global_var.set(global_var.Model, None)
    torch_gc()

    global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
    tokens_path = pathlib.Path(__file__).parent.parent.resolve() / "20B_tokenizer.json"
    try:
        model = RWKV(
            model=body.model,
            strategy=body.strategy,
            tokens_path=str(tokens_path),
        )
        global_var.set(global_var.Model, model)
        # Only seed the config from the model when none has been stored yet
        # (e.g. via /update-config), so a user-supplied config survives switches.
        if global_var.get(global_var.Model_Config) is None:
            global_var.set(global_var.Model_Config, get_rwkv_config(model))
    except Exception as e:
        print(e)
        # Any failure during loading must leave the server in a consistent
        # Offline state; otherwise the status would stay Loading forever and
        # every subsequent switch request would be answered with 304.
        global_var.set(global_var.Model, None)
        global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
        raise HTTPException(
            http_status.HTTP_500_INTERNAL_SERVER_ERROR, "failed to load"
        ) from e

    global_var.set(global_var.Model_Status, global_var.ModelStatus.Working)
    return "success"
@router.post("/update-config")
def update_config(body: ModelConfigBody):
    """
    Store a new model config without applying it right away.

    The config is only recorded in global state here; it is applied when a
    completion is requested, so a generation already in progress is never
    modified mid-run.
    """
    new_config = body
    print(new_config)  # echo the received config to stdout
    global_var.set(global_var.Model_Config, new_config)
    return "success"
@router.get("/status")
def status():
    """Report the current model lifecycle status from global state."""
    # NOTE(review): this definition rebinds the module-level name ``status``
    # (imported from fastapi at the top of the file). Any ``status.HTTP_*``
    # lookup in this module that runs after import resolves to this function,
    # not to fastapi's status-code module.
    current_status = global_var.get(global_var.Model_Status)
    return dict(status=current_status)