RWKV-Runner/backend-python/routes/completion.py

116 lines
3.8 KiB
Python
Raw Normal View History

2023-05-17 03:39:00 +00:00
import asyncio
2023-05-07 09:27:54 +00:00
import json
2023-05-17 03:39:00 +00:00
from threading import Lock
2023-05-07 09:27:54 +00:00
from typing import List
from fastapi import APIRouter, Request, status, HTTPException
from sse_starlette.sse import EventSourceResponse
from pydantic import BaseModel
from utils.rwkv import *
import global_var
# Router that collects the chat-completion endpoints declared in this module.
router = APIRouter()
class Message(BaseModel):
    """One turn of the chat history sent by the client."""

    # Speaker of the turn. The completion handler only renders "user" and
    # "assistant" turns into the prompt; any other role is skipped.
    role: str
    # Text of the turn, inserted verbatim into the prompt.
    content: str
2023-05-17 03:39:00 +00:00
class CompletionBody(ModelConfigBody):
    """Request payload accepted by the chat-completion endpoints.

    Extends ModelConfigBody (not defined in this module; presumably the
    per-request sampling settings from utils.rwkv -- confirm there), so the
    whole body can be handed to set_rwkv_config as a configuration override.
    """

    # Conversation so far; the handler requires the last entry to be a
    # "user" message.
    messages: List[Message]
    # Model name supplied by the client. The responses built in this module
    # hard-code "rwkv" and do not read this field.
    model: str = "rwkv"
    # True -> reply as a server-sent event stream; False -> single JSON body.
    stream: bool = False
2023-05-17 03:39:00 +00:00
# Module-wide lock held by the completion handler for the whole duration of a
# generation, so that only one request configures and runs the model at a time.
completion_lock = Lock()
2023-05-07 09:27:54 +00:00
@router.post("/v1/chat/completions")
@router.post("/chat/completions")
async def completions(body: CompletionBody, request: Request):
    """Answer a chat-completion request with the currently loaded RWKV model.

    The message history is flattened into a transcript in which "user" turns
    are written as "Bob: ..." and "assistant" turns as "Alice: ...", each
    followed by a blank line, and which ends with an open "Alice:" turn for
    the model to continue. Messages with any other role are ignored.

    Args:
        body: Parsed request payload; it is also applied to the model as a
            configuration override on top of the global model configuration.
        request: Incoming request, polled while streaming to detect a client
            disconnect.

    Returns:
        An EventSourceResponse that streams JSON-encoded chunks followed by
        "[DONE]" when ``body.stream`` is true, otherwise a dict holding the
        complete assistant reply.

    Raises:
        HTTPException: 400 if no model is loaded, or if ``body.messages`` is
            empty or does not end with a "user" message.
    """
    model: RWKV = global_var.get(global_var.Model)
    if model is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")

    # An empty history has no final message to inspect; reject it the same
    # way as a history that does not end with a user turn instead of letting
    # the index lookup fail with a server error.
    if not body.messages or body.messages[-1].role != "user":
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "no question found")

    speakers = {"user": "Bob", "assistant": "Alice"}
    turns = [
        speakers[message.role] + ": " + message.content + "\n\n"
        for message in body.messages
        if message.role in speakers
    ]
    completion_text = "".join(turns) + "Alice:"

    async def eval_rwkv():
        """Yield the reply chunks (streaming) or the single reply dict."""
        # Poll instead of blocking on acquire(): completion_lock is a
        # threading.Lock, and blocking on it would stall the event loop.
        # There is no await between the last check and the acquisition
        # below, so no other coroutine can take the lock in between.
        while completion_lock.locked():
            await asyncio.sleep(0.1)

        with completion_lock:
            # Global configuration first, then the per-request overrides.
            set_rwkv_config(model, global_var.get(global_var.Model_Config))
            set_rwkv_config(model, body)

            # Pre-bind so the closing chunk is well defined even when the
            # generator produces nothing.
            response = None
            if body.stream:
                for response, delta in rwkv_generate(
                    model, completion_text, stop="\n\nBob"
                ):
                    if await request.is_disconnected():
                        break
                    yield json.dumps(
                        {
                            "response": response,
                            "model": "rwkv",
                            "choices": [
                                {
                                    "delta": {"content": delta},
                                    "index": 0,
                                    "finish_reason": None,
                                }
                            ],
                        }
                    )
                yield json.dumps(
                    {
                        "response": response,
                        "model": "rwkv",
                        "choices": [
                            {
                                "delta": {},
                                "index": 0,
                                "finish_reason": "stop",
                            }
                        ],
                    }
                )
                yield "[DONE]"
            else:
                # Drain the generator; only the final accumulated text is
                # returned to the client.
                for response, _ in rwkv_generate(
                    model, completion_text, stop="\n\nBob"
                ):
                    pass
                yield {
                    "response": response,
                    "model": "rwkv",
                    "choices": [
                        {
                            "message": {
                                "role": "assistant",
                                "content": response,
                            },
                            "index": 0,
                            "finish_reason": "stop",
                        }
                    ],
                }

    if body.stream:
        return EventSourceResponse(eval_rwkv())

    generator = eval_rwkv()
    try:
        return await generator.__anext__()
    finally:
        # The generator is suspended inside `with completion_lock`; close it
        # explicitly so the lock is released now rather than whenever the
        # abandoned generator happens to be finalized.
        await generator.aclose()