RWKV-Runner/backend-python/routes/completion.py

373 lines
14 KiB
Python
Raw Normal View History

2023-05-17 11:39:00 +08:00
import asyncio
import json
from threading import Lock
from typing import List, Optional

from fastapi import APIRouter, Request, status, HTTPException
from sse_starlette.sse import EventSourceResponse
from pydantic import BaseModel

from utils.rwkv import *
from utils.log import quick_log
import global_var
# Router on which the completion endpoints defined in this module are registered.
router = APIRouter()
class Message(BaseModel):
    """A single chat turn.

    ``role`` names the speaker; the chat endpoint in this module acts on
    the values "system", "user" and "assistant". ``content`` is the text
    of the turn.
    """

    role: str
    content: str
2023-05-22 11:18:37 +08:00
class ChatCompletionBody(ModelConfigBody):
    """Request body accepted by the chat-completion endpoints.

    Extends ``ModelConfigBody`` (declared elsewhere) with the chat-specific
    fields below.
    """

    # Conversation history, oldest first; the endpoint validates the tail.
    messages: List[Message]
    # Model name; the handler does not read it.
    model: str = "rwkv"
    # When true the reply is sent as a server-sent event stream.
    stream: bool = False
    # Custom stop sequence; None lets the endpoint pick its own default.
    # Declared Optional so an explicit JSON null is a valid value too.
    stop: Optional[str] = None
2023-05-17 11:39:00 +08:00
# Lock each completion request takes before it runs the model and releases
# once generation has ended; shared by both endpoints in this module.
completion_lock = Lock()
2023-05-07 17:27:54 +08:00
2023-06-03 17:12:59 +08:00
# Number of requests currently waiting for, or running, a generation.
# Incremented when a request starts waiting and decremented when it ends;
# the value is included in the log messages.
requests_num = 0
2023-05-07 17:27:54 +08:00
def _ends_with_question(messages: List[Message]) -> bool:
    """Tell whether the conversation ends with something to answer.

    True when the last message is from the user, or when the last message
    is a system message directly preceded by a user message. Empty and
    too-short lists yield False instead of raising IndexError.
    """
    if not messages:
        return False
    last = messages[-1]
    if last.role == "user":
        return True
    if last.role == "system":
        return len(messages) >= 2 and messages[-2].role == "user"
    return False


def _normalize_chat_text(text: str) -> str:
    """Unescape literal backslash-n, unify line endings, collapse blank lines, trim."""
    return (
        text.replace("\\n", "\n")
        .replace("\r\n", "\n")
        .replace("\n\n", "\n")
        .strip()
    )


def _system_message_to_prompt(content: str, bot: str, user: str) -> str:
    """Flatten a system message to one line and rewrite its second person.

    When ``user`` is "Bob" the text is rewritten in the third person about
    ``bot``; otherwise it is rewritten in the first person.
    """
    as_bob = user == "Bob"
    text = _normalize_chat_text(content).replace("\n", " ").strip()
    # Longer patterns must come first: replacing "You" before "Your" would
    # turn "Your" into "Ir" and the possessive rule would never match.
    # These are plain substring replacements, so words that merely contain
    # "you" are rewritten as well.
    rewrites = (
        ("You are", f"{bot} is" if as_bob else "I am"),
        ("you are", f"{bot} is" if as_bob else "I am"),
        ("You're", f"{bot} is" if as_bob else "I'm"),
        ("you're", f"{bot} is" if as_bob else "I'm"),
        ("Your", f"{bot}'s" if as_bob else "My"),
        ("your", f"{bot}'s" if as_bob else "my"),
        ("You", f"{bot}" if as_bob else "I"),
        ("you", f"{bot}" if as_bob else "I"),
        # NOTE(review): in the reviewed text this rule had an empty pattern
        # and an empty first-person replacement, apparently non-ASCII
        # literals lost in extraction (an empty pattern makes str.replace
        # insert the replacement between every character). Restored as the
        # Chinese second/first person pronouns to mirror the English rules;
        # confirm against the repository.
        ("你", f"{bot}" if as_bob else "我"),
    )
    for old, new in rewrites:
        text = text.replace(old, new)
    return text


def _build_chat_prompt(
    messages: List[Message], user: str, bot: str, interface: str
) -> str:
    """Assemble the text prompt fed to the model for a chat request.

    The prompt is a preamble followed by every user/assistant turn and a
    trailing bot tag for the model to continue from. The preamble is a
    built-in one unless the conversation contains a system message, in
    which case the first system message replaces it.
    """
    as_bob = user == "Bob"
    if as_bob:
        prompt = (
            f"\nThe following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. "
            f"{bot} is very intelligent, creative and friendly. "
            f"{bot} is unlikely to disagree with {user}, and {bot} doesn't like to ask {user} questions. "
            f"{bot} likes to tell {user} a lot about herself and her opinions. "
            f"{bot} usually gives {user} kind, helpful and informative advices.\n\n"
        )
    else:
        prompt = f"{user}{interface} hi\n\n{bot}{interface} Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n"

    system = next((m for m in messages if m.role == "system"), None)
    if system is not None:
        # The opening is chosen first and the system text appended after.
        # Written as one unparenthesised conditional expression, the "+"
        # binds to the else-branch only and the system text is silently
        # dropped whenever user == "Bob".
        opening = (
            f"The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. "
            if as_bob
            else f"{user}{interface} hi\n\n{bot}{interface} Hi. "
        )
        prompt = (
            opening
            + _system_message_to_prompt(system.content, bot, user)
            + "\n\n"
        )

    speakers = {"user": user, "assistant": bot}
    for message in messages:
        speaker = speakers.get(message.role)
        if speaker is None:
            # Any other role (e.g. system) is not part of the dialogue.
            continue
        prompt += (
            f"{speaker}{interface} "
            + _normalize_chat_text(message.content)
            + "\n\n"
        )
    return prompt + f"{bot}{interface}"


@router.post("/v1/chat/completions")
@router.post("/chat/completions")
async def chat_completions(body: ChatCompletionBody, request: Request):
    """Answer a chat conversation with the loaded model.

    Args:
        body: Conversation, streaming flag, optional stop sequence and the
            model configuration overrides applied for this request.
        request: Incoming request; polled to detect client disconnects.

    Returns:
        An ``EventSourceResponse`` when ``body.stream`` is true. Otherwise
        a dict holding the finished reply, or ``None`` if the client
        disconnected before a reply was produced.

    Raises:
        HTTPException: 400 when no model is loaded, or when the
            conversation does not end with a user question.
    """
    model: RWKV = global_var.get(global_var.Model)
    if model is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")

    if not _ends_with_question(body.messages):
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "no question found")

    user = model.user
    completion_text = _build_chat_prompt(
        body.messages, user, model.bot, model.interface
    )
    # By default stop as soon as the model starts writing the user's turn.
    stop = f"\n\n{user}" if body.stop is None else body.stop

    async def eval_rwkv():
        global requests_num
        requests_num = requests_num + 1
        quick_log(request, None, "Start Waiting. RequestsNum: " + str(requests_num))

        response = ""
        holding_lock = False
        gave_up = False
        try:
            # Poll with a non-blocking acquire so the event loop keeps
            # running while queued, and drop clients that hang up.
            while True:
                holding_lock = completion_lock.acquire(blocking=False)
                gave_up = await request.is_disconnected()
                if holding_lock or gave_up:
                    break
                await asyncio.sleep(0.1)

            if not gave_up:
                set_rwkv_config(model, global_var.get(global_var.Model_Config))
                set_rwkv_config(model, body)
                for response, delta in model.generate(completion_text, stop=stop):
                    if await request.is_disconnected():
                        break
                    if body.stream:
                        yield json.dumps(
                            {
                                "response": response,
                                "model": "rwkv",
                                "choices": [
                                    {
                                        "delta": {"content": delta},
                                        "index": 0,
                                        "finish_reason": None,
                                    }
                                ],
                            }
                        )
        finally:
            # Runs on errors and when the stream is closed or cancelled
            # mid-yield as well; otherwise the lock would stay held and
            # every later request would wait forever.
            requests_num = requests_num - 1
            if holding_lock:
                completion_lock.release()

        if gave_up:
            print(f"{request.client} Stop Waiting (Lock)")
            quick_log(
                request,
                None,
                "Stop Waiting (Lock). RequestsNum: " + str(requests_num),
            )
            return
        if await request.is_disconnected():
            print(f"{request.client} Stop Waiting")
            quick_log(
                request,
                body,
                response + "\nStop Waiting. RequestsNum: " + str(requests_num),
            )
            return
        quick_log(
            request,
            body,
            response + "\nFinished. RequestsNum: " + str(requests_num),
        )

        if body.stream:
            yield json.dumps(
                {
                    "response": response,
                    "model": "rwkv",
                    "choices": [
                        {
                            "delta": {},
                            "index": 0,
                            "finish_reason": "stop",
                        }
                    ],
                }
            )
            yield "[DONE]"
        else:
            yield {
                "response": response,
                "model": "rwkv",
                "choices": [
                    {
                        "message": {
                            "role": "assistant",
                            "content": response,
                        },
                        "index": 0,
                        "finish_reason": "stop",
                    }
                ],
            }

    if body.stream:
        return EventSourceResponse(eval_rwkv())
    try:
        # Non-streaming: the generator yields exactly one item, the reply.
        return await eval_rwkv().__anext__()
    except StopAsyncIteration:
        # The client went away before a reply was produced.
        return None
2023-05-22 11:18:37 +08:00
class CompletionBody(ModelConfigBody):
    """Request body accepted by the plain text-completion endpoints.

    Extends ``ModelConfigBody`` (declared elsewhere) with the fields below.
    """

    # Text the model continues; the endpoint rejects an empty prompt.
    prompt: str
    # Model name; the handler does not read it.
    model: str = "rwkv"
    # When true the reply is sent as a server-sent event stream.
    stream: bool = False
    # Optional stop sequence, passed to the generator unchanged.
    # Declared Optional so an explicit JSON null is a valid value too.
    stop: Optional[str] = None
@router.post("/v1/completions")
@router.post("/completions")
async def completions(body: CompletionBody, request: Request):
    """Continue a raw text prompt with the loaded model.

    Args:
        body: Prompt, streaming flag, optional stop sequence and the model
            configuration overrides applied for this request.
        request: Incoming request; polled to detect client disconnects.

    Returns:
        An ``EventSourceResponse`` when ``body.stream`` is true. Otherwise
        a dict holding the finished text, or ``None`` if the client
        disconnected before a result was produced.

    Raises:
        HTTPException: 400 when no model is loaded or the prompt is empty.
    """
    model: RWKV = global_var.get(global_var.Model)
    if model is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")

    if not body.prompt:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "prompt not found")

    async def eval_rwkv():
        global requests_num
        requests_num = requests_num + 1
        quick_log(request, None, "Start Waiting. RequestsNum: " + str(requests_num))

        response = ""
        holding_lock = False
        gave_up = False
        try:
            # Poll with a non-blocking acquire so the event loop keeps
            # running while queued, and drop clients that hang up.
            while True:
                holding_lock = completion_lock.acquire(blocking=False)
                gave_up = await request.is_disconnected()
                if holding_lock or gave_up:
                    break
                await asyncio.sleep(0.1)

            if not gave_up:
                set_rwkv_config(model, global_var.get(global_var.Model_Config))
                set_rwkv_config(model, body)
                for response, delta in model.generate(body.prompt, stop=body.stop):
                    if await request.is_disconnected():
                        break
                    if body.stream:
                        yield json.dumps(
                            {
                                "response": response,
                                "model": "rwkv",
                                "choices": [
                                    {
                                        "text": delta,
                                        "index": 0,
                                        "finish_reason": None,
                                    }
                                ],
                            }
                        )
        finally:
            # Runs on errors and when the stream is closed or cancelled
            # mid-yield as well; otherwise the lock would stay held and
            # every later request would wait forever.
            requests_num = requests_num - 1
            if holding_lock:
                completion_lock.release()

        if gave_up:
            print(f"{request.client} Stop Waiting (Lock)")
            quick_log(
                request,
                None,
                "Stop Waiting (Lock). RequestsNum: " + str(requests_num),
            )
            return
        if await request.is_disconnected():
            print(f"{request.client} Stop Waiting")
            quick_log(
                request,
                body,
                response + "\nStop Waiting. RequestsNum: " + str(requests_num),
            )
            return
        quick_log(
            request,
            body,
            response + "\nFinished. RequestsNum: " + str(requests_num),
        )

        if body.stream:
            yield json.dumps(
                {
                    "response": response,
                    "model": "rwkv",
                    "choices": [
                        {
                            "text": "",
                            "index": 0,
                            "finish_reason": "stop",
                        }
                    ],
                }
            )
            yield "[DONE]"
        else:
            yield {
                "response": response,
                "model": "rwkv",
                "choices": [
                    {
                        "text": response,
                        "index": 0,
                        "finish_reason": "stop",
                    }
                ],
            }

    if body.stream:
        return EventSourceResponse(eval_rwkv())
    try:
        # Non-streaming: the generator yields exactly one item, the result.
        return await eval_rwkv().__anext__()
    except StopAsyncIteration:
        # The client went away before a result was produced.
        return None