2023-05-07 17:27:54 +08:00
|
|
|
import pathlib
|
2023-06-09 20:46:19 +08:00
|
|
|
from utils.log import quick_log
|
2023-05-07 17:27:54 +08:00
|
|
|
|
2023-06-09 20:46:19 +08:00
|
|
|
from fastapi import APIRouter, HTTPException, Request, Response, status as Status
|
2023-05-07 17:27:54 +08:00
|
|
|
from pydantic import BaseModel
|
|
|
|
from utils.rwkv import *
|
|
|
|
from utils.torch import *
|
|
|
|
import global_var
|
|
|
|
|
|
|
|
# Router holding the configuration endpoints defined below
# (/switch-model, /update-config, /status), all tagged "Configs".
router = APIRouter()
|
|
|
|
|
|
|
|
|
2023-05-17 11:39:00 +08:00
|
|
|
class SwitchModelBody(BaseModel):
    """Request payload accepted by ``POST /switch-model``."""

    # Path of the weights file to load; an empty string only unloads the current model.
    model: str
    # RWKV strategy string, for example "cuda fp16".
    strategy: str
    # Optional tokenizer argument forwarded to RWKV(); None when omitted.
    tokenizer: Union[str, None] = None
    # When True, switch_model exports RWKV_CUDA_ON=1 before loading.
    customCuda: bool = False

    class Config:
        # Example request body attached to the model's JSON schema.
        json_schema_extra = {
            "example": dict(
                model="models/RWKV-4-World-3B-v1-20230619-ctx4096.pth",
                strategy="cuda fp16",
                tokenizer=None,
                customCuda=False,
            )
        }
|
|
|
|
|
2023-05-07 17:27:54 +08:00
|
|
|
|
2023-07-26 22:24:26 +08:00
|
|
|
@router.post("/switch-model", tags=["Configs"])
def switch_model(body: SwitchModelBody, response: Response, request: Request):
    """Unload the current model and, unless ``body.model`` is empty, load a new one.

    If a load is already in progress the request is answered with HTTP 304 and
    nothing is changed. An empty ``body.model`` only unloads the current model.

    Returns:
        "success" once the model has been unloaded (and, if requested, loaded).

    Raises:
        HTTPException: 500 if constructing the model or deriving its default
            config fails. In that case the model slot is cleared and the status
            is reset to Offline, so later calls are not blocked by a stale
            Loading status.
    """
    # Refuse to start a second load while one is already running.
    if global_var.get(global_var.Model_Status) is global_var.ModelStatus.Loading:
        response.status_code = Status.HTTP_304_NOT_MODIFIED
        return

    # Drop the current model before loading another one.
    # NOTE(review): torch_gc() presumably releases cached torch memory here.
    global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
    global_var.set(global_var.Model, None)
    torch_gc()

    if body.model == "":
        return "success"

    # Strategies containing "->" get the state cache disabled.
    # NOTE(review): presumably these are split/multi-device strategies; confirm.
    if "->" in body.strategy:
        state_cache.disable_state_cache()
    else:
        # Best-effort: enabling the cache may raise HTTPException; loading
        # continues without it in that case.
        try:
            state_cache.enable_state_cache()
        except HTTPException:
            pass

    os.environ["RWKV_CUDA_ON"] = "1" if body.customCuda else "0"

    global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
    # Everything that can fail after entering Loading lives inside this try,
    # so the status can never be left stuck at Loading (which would make every
    # subsequent call return 304).
    try:
        global_var.set(
            global_var.Model,
            RWKV(model=body.model, strategy=body.strategy, tokenizer=body.tokenizer),
        )
        # Only derive a default config when none has been stored yet
        # (e.g. via /update-config).
        if global_var.get(global_var.Model_Config) is None:
            global_var.set(
                global_var.Model_Config,
                get_rwkv_config(global_var.get(global_var.Model)),
            )
    except Exception as e:
        print(e)
        quick_log(request, body, f"Exception: {e}")
        # Do not keep a half-initialised model around after a failure.
        global_var.set(global_var.Model, None)
        global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
        raise HTTPException(
            Status.HTTP_500_INTERNAL_SERVER_ERROR, f"failed to load: {e}"
        ) from e

    global_var.set(global_var.Model_Status, global_var.ModelStatus.Working)
    return "success"
|
2023-05-17 11:39:00 +08:00
|
|
|
|
|
|
|
|
2023-07-26 22:24:26 +08:00
|
|
|
@router.post("/update-config", tags=["Configs"])
def update_config(body: ModelConfigBody):
    """
    Store ``body`` as the pending model config.

    The config is not applied here; it is picked up when a completion is
    requested, so settings never change in the middle of a generation.
    """
    print(body)
    global_var.set(global_var.Model_Config, body)
    return "success"
|
2023-05-19 15:59:04 +08:00
|
|
|
|
|
|
|
|
2023-07-26 22:24:26 +08:00
|
|
|
@router.get("/status", tags=["Configs"])
def status():
    """Report the model status, this process's PID and the first GPU's name ("CPU" if none)."""
    # Imported lazily so the module loads even where GPUtil is unavailable.
    import GPUtil

    gpus = GPUtil.getGPUs()
    device_name = gpus[0].name if gpus else "CPU"

    return {
        "status": global_var.get(global_var.Model_Status),
        "pid": os.getpid(),
        "device_name": device_name,
    }
|