backend api
backend-python/routes/completion.py (new file, 58 lines)
@@ -0,0 +1,58 @@
import json
from typing import List

from fastapi import APIRouter, Request, status, HTTPException
from sse_starlette.sse import EventSourceResponse
from pydantic import BaseModel
from utils.rwkv import *
import global_var

router = APIRouter()


class Message(BaseModel):
    role: str
    content: str


class CompletionBody(BaseModel):
    messages: List[Message]
    model: str
    stream: bool
    max_tokens: int


@router.post("/v1/chat/completions")
@router.post("/chat/completions")
async def completions(body: CompletionBody, request: Request):
    model = global_var.get(global_var.Model)

    # The last message must come from the user, otherwise there is nothing to answer.
    question = body.messages[-1]
    if question.role == 'user':
        question = question.content
    else:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "no question found")

    # Flatten the chat history into the Bob/Alice prompt format the model expects,
    # ending with "Alice:" so the model continues as the assistant.
    completion_text = ""
    for message in body.messages:
        if message.role == 'user':
            completion_text += "Bob: " + message.content + "\n\n"
        elif message.role == 'assistant':
            completion_text += "Alice: " + message.content + "\n\n"
    completion_text += "Alice:"

    async def eval_rwkv():
        if body.stream:
            # Stream incremental deltas as SSE events, stopping early if the client disconnects.
            for response, delta in rwkv_generate(model, completion_text):
                if await request.is_disconnected():
                    break
                yield json.dumps({"response": response, "choices": [{"delta": {"content": delta}}], "model": "rwkv"})
            yield "[DONE]"
        else:
            # Non-streaming: run generation to completion and emit only the final response.
            response = None
            for response, delta in rwkv_generate(model, completion_text):
                pass
            yield json.dumps({"response": response, "model": "rwkv"})
        # torch_gc()

    return EventSourceResponse(eval_rwkv())
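For reference, a minimal client sketch for the streaming endpoint above (the same handler is also mounted at /v1/chat/completions). The host, port, and use of the requests library are assumptions for illustration and are not part of this commit; all four body fields are sent because CompletionBody declares no defaults.

import json
import requests  # assumed HTTP client, not part of this commit

# Hypothetical local address; adjust to wherever the backend is actually served.
resp = requests.post(
    "http://127.0.0.1:8000/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Hello"}],
        "model": "rwkv",
        "stream": True,
        "max_tokens": 100,
    },
    stream=True,
)
for line in resp.iter_lines():
    if not line:
        continue
    data = line.decode()
    if data.startswith("data:"):
        data = data[len("data:"):].strip()
    if data == "[DONE]":
        break
    print(json.loads(data)["choices"][0]["delta"]["content"], end="", flush=True)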
backend-python/routes/config.py (new file, 46 lines)
@@ -0,0 +1,46 @@
import pathlib
import sys

from fastapi import APIRouter, HTTPException, status
from pydantic import BaseModel
from langchain.llms import RWKV
from utils.rwkv import *
from utils.torch import *
import global_var

router = APIRouter()


class UpdateConfigBody(BaseModel):
    model: str = None
    strategy: str = None
    max_response_token: int = None
    temperature: float = None
    top_p: float = None
    presence_penalty: float = None
    count_penalty: float = None


@router.post("/update-config")
def update_config(body: UpdateConfigBody):
    # Refuse to reconfigure while a previous load is still in progress.
    if global_var.get(global_var.Model_Status) is global_var.ModelStatus.Loading:
        return "loading"

    # Drop the current model and free GPU memory before reloading.
    global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
    global_var.set(global_var.Model, None)
    torch_gc()

    global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
    try:
        # Reload the model using the strategy and model path the backend was started with.
        global_var.set(global_var.Model, RWKV(
            model=sys.argv[2],
            strategy=sys.argv[1],
            tokens_path=f"{pathlib.Path(__file__).parent.parent.resolve()}/20B_tokenizer.json"
        ))
    except Exception:
        global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
        raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "failed to load")

    global_var.set(global_var.Model_Status, global_var.ModelStatus.Working)

    return "success"
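A quick client sketch for the config endpoint above; the host, port, and requests library are assumptions for illustration, not part of this commit. Note that, as committed, the handler accepts the body fields but does not yet apply them: the reload reads the strategy from sys.argv[1] and the model path from sys.argv[2], i.e. from the arguments the backend process was started with.

import requests  # assumed HTTP client, not part of this commit

# Hypothetical local address; adjust to wherever the backend is actually served.
r = requests.post(
    "http://127.0.0.1:8000/update-config",
    json={"max_response_token": 1000, "temperature": 1.2, "top_p": 0.5},
)
print(r.json())  # "loading" while a previous load is running, "success" on reload, or HTTP 500 "failed to load"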