RWKV-Runner/backend-python/routes/config.py

92 lines
2.5 KiB
Python
Raw Normal View History

2023-05-07 17:27:54 +08:00
import pathlib
2023-05-30 11:52:33 +08:00
from fastapi import APIRouter, HTTPException, Response, status as Status
2023-05-07 17:27:54 +08:00
from pydantic import BaseModel
from utils.rwkv import *
from utils.torch import *
import global_var
2023-05-23 12:13:12 +08:00
import GPUtil
2023-05-07 17:27:54 +08:00
# Router on which this module registers its endpoints (/switch-model, /update-config, /status).
router = APIRouter()
2023-05-28 12:53:14 +08:00
def get_tokens_path(model_path: str):
    """Choose the tokenizer to use for the model at ``model_path``.

    The choice is a case-insensitive substring match on the whole path:

    * a path containing "world" but not "raven" uses the built-in
      ``rwkv_vocab_v20230424`` vocabulary name;
    * every other path, including any that mentions "raven", uses the
      ``20B_tokenizer.json`` file under ``rwkv_pip`` two directories above
      this module.

    Args:
        model_path: Model file path or model name supplied by the client.

    Returns:
        Either the vocabulary name ``"rwkv_vocab_v20230424"`` or the absolute
        path of ``20B_tokenizer.json``, joined with "/".
    """
    lowered = model_path.lower()
    # "raven" takes precedence over "world" when both appear. A separate
    # "raven" branch would only duplicate the fallback, so it is folded into
    # this condition instead.
    if "world" in lowered and "raven" not in lowered:
        return "rwkv_vocab_v20230424"
    backend_root = pathlib.Path(__file__).parent.parent.resolve()
    return f"{backend_root}/rwkv_pip/20B_tokenizer.json"
2023-05-17 11:39:00 +08:00
class SwitchModelBody(BaseModel):
    """Request body accepted by the ``/switch-model`` endpoint."""

    # Model path/name passed to RWKV(model=...); also used to pick the tokenizer.
    model: str
    # Strategy string passed through unchanged to RWKV(strategy=...).
    strategy: str
    # When True, switch_model sets RWKV_CUDA_ON="1" before loading; otherwise "0".
    customCuda: bool = False
2023-05-07 17:27:54 +08:00
2023-05-17 11:39:00 +08:00
@router.post("/switch-model")
def switch_model(body: SwitchModelBody, response: Response):
    """Unload the current model and load the one described by ``body``.

    If a load is already in progress, the response status is set to
    304 Not Modified and nothing else happens.

    Otherwise the current model is released and ``torch_gc()`` is called.
    ``RWKV_CUDA_ON`` is then set from ``body.customCuda``, and a new ``RWKV``
    instance is built using the tokenizer chosen by ``get_tokens_path``. If no
    model config has been stored yet, one is derived from the new model.

    Returns:
        ``"success"`` once the model is loaded and the status is ``Working``.

    Raises:
        HTTPException: 500 "failed to load" if constructing the model or
            deriving its config fails. The status is reset to ``Offline`` so
            later calls are not blocked by a stale ``Loading`` state.
    """
    import os  # explicit import; the module otherwise only gets `os` via star imports

    if global_var.get(global_var.Model_Status) is global_var.ModelStatus.Loading:
        response.status_code = Status.HTTP_304_NOT_MODIFIED
        return

    global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
    global_var.set(global_var.Model, None)
    torch_gc()

    # NOTE(review): this only has an effect if the RWKV code reads the variable
    # at construction time rather than at import time. Confirm.
    os.environ["RWKV_CUDA_ON"] = "1" if body.customCuda else "0"

    global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
    try:
        global_var.set(
            global_var.Model,
            RWKV(
                model=body.model,
                strategy=body.strategy,
                tokens_path=get_tokens_path(body.model),
            ),
        )
        # Deriving the config is inside the try block. A failure here must
        # also leave the status Offline; otherwise it stays Loading forever
        # and every later call short-circuits with 304.
        if global_var.get(global_var.Model_Config) is None:
            global_var.set(
                global_var.Model_Config,
                get_rwkv_config(global_var.get(global_var.Model)),
            )
    except Exception as e:
        print(e)
        # Drop a partially initialised model so it is not kept alive while Offline.
        global_var.set(global_var.Model, None)
        global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
        raise HTTPException(
            Status.HTTP_500_INTERNAL_SERVER_ERROR, "failed to load"
        ) from e

    global_var.set(global_var.Model_Status, global_var.ModelStatus.Working)
    return "success"
2023-05-17 11:39:00 +08:00
@router.post("/update-config")
def update_config(body: ModelConfigBody):
    """
    Store ``body`` as the model config without applying it right away.

    The stored value is picked up when a completion is requested, so a
    generation that is already running is not modified mid-flight.
    """
    print(body)
    config_slot = global_var.Model_Config
    global_var.set(config_slot, body)
    return "success"
2023-05-19 15:59:04 +08:00
@router.get("/status")
def status():
    """Report the model status, server process id, and compute device.

    Returns:
        A dict with:

        * ``"status"``: the value stored under ``global_var.Model_Status``;
        * ``"pid"``: the id of this server process;
        * ``"device_name"``: the name of the first GPU reported by GPUtil,
          or ``"CPU"`` when GPUtil reports no GPUs.
    """
    import os  # explicit import; the module otherwise only gets `os` via star imports

    gpus = GPUtil.getGPUs()
    device_name = gpus[0].name if gpus else "CPU"
    return {
        "status": global_var.get(global_var.Model_Status),
        "pid": os.getpid(),
        "device_name": device_name,
    }