backend api

This commit is contained in:
josc146
2023-05-07 17:27:54 +08:00
parent 4795514e8f
commit 0e852daf43
8 changed files with 188 additions and 102 deletions

View File

@@ -0,0 +1,58 @@
import json
from typing import List
from fastapi import APIRouter, Request, status, HTTPException
from sse_starlette.sse import EventSourceResponse
from pydantic import BaseModel
from utils.rwkv import *
import global_var
router = APIRouter()
class Message(BaseModel):
role: str
content: str
class CompletionBody(BaseModel):
messages: List[Message]
model: str
stream: bool
max_tokens: int
@router.post("/v1/chat/completions")
@router.post("/chat/completions")
async def completions(body: CompletionBody, request: Request):
model = global_var.get(global_var.Model)
question = body.messages[-1]
if question.role == 'user':
question = question.content
else:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "no question found")
completion_text = ""
for message in body.messages:
if message.role == 'user':
completion_text += "Bob: " + message.content + "\n\n"
elif message.role == 'assistant':
completion_text += "Alice: " + message.content + "\n\n"
completion_text += "Alice:"
async def eval_rwkv():
if body.stream:
for response, delta in rwkv_generate(model, completion_text):
if await request.is_disconnected():
break
yield json.dumps({"response": response, "choices": [{"delta": {"content": delta}}], "model": "rwkv"})
yield "[DONE]"
else:
response = None
for response, delta in rwkv_generate(model, completion_text):
pass
yield json.dumps({"response": response, "model": "rwkv"})
# torch_gc()
return EventSourceResponse(eval_rwkv())

View File

@@ -0,0 +1,46 @@
import pathlib
import sys
from fastapi import APIRouter, HTTPException, status
from pydantic import BaseModel
from langchain.llms import RWKV
from utils.rwkv import *
from utils.torch import *
import global_var
router = APIRouter()
class UpdateConfigBody(BaseModel):
model: str = None
strategy: str = None
max_response_token: int = None
temperature: float = None
top_p: float = None
presence_penalty: float = None
count_penalty: float = None
@router.post("/update-config")
def update_config(body: UpdateConfigBody):
if (global_var.get(global_var.Model_Status) is global_var.ModelStatus.Loading):
return "loading"
global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
global_var.set(global_var.Model, None)
torch_gc()
global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
try:
global_var.set(global_var.Model, RWKV(
model=sys.argv[2],
strategy=sys.argv[1],
tokens_path=f"{pathlib.Path(__file__).parent.parent.resolve()}/20B_tokenizer.json"
))
except Exception:
global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "failed to load")
global_var.set(global_var.Model_Status, global_var.ModelStatus.Working)
return "success"