add logs for state cache and switch-model

This commit is contained in:
josc146
2023-06-09 20:46:19 +08:00
parent b7c34b0d42
commit cea1d8b4d1
4 changed files with 12 additions and 7 deletions

View File

@@ -1,6 +1,7 @@
import pathlib
from utils.log import quick_log
from fastapi import APIRouter, HTTPException, Response, status as Status
from fastapi import APIRouter, HTTPException, Request, Response, status as Status
from pydantic import BaseModel
from utils.rwkv import *
from utils.torch import *
@@ -30,7 +31,7 @@ class SwitchModelBody(BaseModel):
@router.post("/switch-model")
def switch_model(body: SwitchModelBody, response: Response):
def switch_model(body: SwitchModelBody, response: Response, request: Request):
if global_var.get(global_var.Model_Status) is global_var.ModelStatus.Loading:
response.status_code = Status.HTTP_304_NOT_MODIFIED
return
@@ -53,6 +54,7 @@ def switch_model(body: SwitchModelBody, response: Response):
)
except Exception as e:
print(e)
quick_log(request, body, f"Exception: {e}")
global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
raise HTTPException(Status.HTTP_500_INTERNAL_SERVER_ERROR, "failed to load")

View File

@@ -1,5 +1,6 @@
from typing import Any, Dict
from fastapi import APIRouter, HTTPException, Response, status
from utils.log import quick_log
from fastapi import APIRouter, HTTPException, Request, Response, status
from pydantic import BaseModel
import gc
import copy
@@ -72,7 +73,7 @@ class LongestPrefixStateBody(BaseModel):
@router.post("/longest-prefix-state")
def longest_prefix_state(body: LongestPrefixStateBody):
def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
global trie
if trie is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
@@ -83,8 +84,10 @@ def longest_prefix_state(body: LongestPrefixStateBody):
if id != -1:
v = dtrie[id]
device = v["device"]
prompt = trie[id]
quick_log(request, body, "Hit: " + prompt)
return {
"prompt": trie[id],
"prompt": prompt,
"tokens": v["tokens"],
"state": [tensor.to(device) for tensor in v["state"]]
if device != torch.device("cpu")