RWKV-Runner/backend-python/routes/config.py

126 lines
3.7 KiB
Python
Raw Normal View History

2023-05-07 09:27:54 +00:00
import pathlib
from utils.log import quick_log
2023-05-07 09:27:54 +00:00
from fastapi import APIRouter, HTTPException, Request, Response, status as Status
2023-05-07 09:27:54 +00:00
from pydantic import BaseModel
from utils.rwkv import *
from utils.torch import *
import global_var
# FastAPI router that the configuration endpoints in this module
# (/switch-model, /update-config, /status) are registered on.
router = APIRouter()
2023-05-17 03:39:00 +00:00
class SwitchModelBody(BaseModel):
    # Request payload for POST /switch-model.
    # NOTE: documented with comments rather than a class docstring, because
    # pydantic would copy a docstring into the JSON schema description.

    # Model file path; the endpoint treats an empty string as "unload only".
    model: str
    # Device/precision strategy string, validated by the endpoint.
    strategy: str
    # Passed through to the RWKV wrapper as its `tokenizer` argument.
    tokenizer: Union[str, None] = None
    # Controls the RWKV_CUDA_ON environment variable set before loading.
    customCuda: bool = False
    deploy: bool = Field(
        default=False,
        description="Deploy mode. If success, will disable /switch-model, /exit and other dangerous APIs (state cache APIs, part of midi APIs)",
    )

    # Example request shown in the generated OpenAPI schema.
    model_config = dict(
        json_schema_extra=dict(
            example=dict(
                model="models/RWKV-4-World-3B-v1-20230619-ctx4096.pth",
                strategy="cuda fp16",
                tokenizer="",
                customCuda=False,
                deploy=False,
            )
        )
    )
2023-06-15 13:52:22 +00:00
2023-05-07 09:27:54 +00:00
2023-07-26 14:24:26 +00:00
# Accepted strategy strings, e.g. "cuda fp16" or "cuda fp16 -> cpu fp32"
# (see https://pypi.org/project/rwkv/). Compiled once at import time rather
# than on every request.
_STRATEGY_PATTERN = re.compile(
    r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps|dml) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$"
)


@router.post("/switch-model", tags=["Configs"])
def switch_model(body: SwitchModelBody, response: Response, request: Request):
    """Unload the current model and, unless ``body.model`` is empty, load a new one.

    Returns ``"success"`` on success. While another load is already in
    progress, responds with status 304 and no body.

    Raises:
        HTTPException: 403 when deploy mode is active; 400 when the strategy
            string is malformed (checked before the current model is
            unloaded); 500 when loading the new model fails.
    """
    if global_var.get(global_var.Deploy_Mode) is True:
        raise HTTPException(Status.HTTP_403_FORBIDDEN)

    if global_var.get(global_var.Model_Status) is global_var.ModelStatus.Loading:
        # A load is already running; report that nothing changed.
        response.status_code = Status.HTTP_304_NOT_MODIFIED
        return

    # Validate before tearing anything down, so that a malformed request does
    # not leave the server without a model. An empty model name is a pure
    # unload request and does not need a strategy. fullmatch (unlike match
    # with `$`) also rejects a trailing newline.
    if body.model != "" and not _STRATEGY_PATTERN.fullmatch(body.strategy):
        raise HTTPException(
            Status.HTTP_400_BAD_REQUEST,
            "Invalid strategy. Please read https://pypi.org/project/rwkv/",
        )

    global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
    global_var.set(global_var.Model, None)
    torch_gc()

    if body.model == "":
        return "success"

    # First token of each "->" segment is the device; "cuda:0" is normalized
    # to "cuda". NOTE: currently only logged; per-device state-cache
    # handling (disable when spanning several devices) is not active.
    devices = {
        segment.strip().split(" ")[0].replace("cuda:0", "cuda")
        for segment in body.strategy.split("->")
    }
    print(f"Devices: {devices}")

    # Best effort: an HTTPException from enable_state_cache is deliberately
    # ignored so that model loading still proceeds (presumably raised when
    # the state cache is unavailable — TODO confirm in state_cache).
    try:
        state_cache.enable_state_cache()
    except HTTPException:
        pass

    os.environ["RWKV_CUDA_ON"] = "1" if body.customCuda else "0"

    global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
    try:
        global_var.set(
            global_var.Model,
            RWKV(model=body.model, strategy=body.strategy, tokenizer=body.tokenizer),
        )
    except Exception as e:
        # Top-level API boundary: log, reset the status, and surface the
        # error to the client with the original exception chained.
        print(e)
        quick_log(request, body, f"Exception: {e}")
        global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
        raise HTTPException(
            Status.HTTP_500_INTERNAL_SERVER_ERROR, f"failed to load: {e}"
        ) from e

    if body.deploy:
        global_var.set(global_var.Deploy_Mode, True)

    # Seed the generation config from the model only when none exists yet.
    # A config that is already stored (e.g. via /update-config) is kept
    # across model switches.
    if global_var.get(global_var.Model_Config) is None:
        global_var.set(
            global_var.Model_Config, get_rwkv_config(global_var.get(global_var.Model))
        )
    global_var.set(global_var.Model_Status, global_var.ModelStatus.Working)
    return "success"
2023-05-17 03:39:00 +00:00
2023-07-26 14:24:26 +00:00
@router.post("/update-config", tags=["Configs"])
def update_config(body: ModelConfigBody):
    """
    Store a new model config without applying it right away. It is applied when a completion is called, so a generation already in progress is not modified.
    """
    config_key = global_var.Model_Config
    print(body)
    global_var.set(config_key, body)
    return "success"
2023-05-19 07:59:04 +00:00
2023-07-26 14:24:26 +00:00
@router.get("/status", tags=["Configs"])
def status():
    """Report the model status, this server's process id and a device name.

    ``device_name`` is the name of the first GPU that GPUtil reports. It is
    ``"CPU"`` when GPUtil reports no GPUs or cannot be imported.
    """
    # GPUtil is only needed for the device name. If it cannot be imported
    # (missing package, or its own imports failing — GPUtil 1.4 imports
    # distutils, which Python 3.12 removed), report CPU instead of failing
    # the whole status request.
    try:
        import GPUtil
    except ImportError:
        gpus = []
    else:
        gpus = GPUtil.getGPUs()

    device_name = gpus[0].name if gpus else "CPU"
    return {
        "status": global_var.get(global_var.Model_Status),
        "pid": os.getpid(),
        "device_name": device_name,
    }