RWKV-Runner/backend-python/routes/completion.py

483 lines
17 KiB
Python
Raw Normal View History

2023-05-17 11:39:00 +08:00
import asyncio
2023-05-07 17:27:54 +08:00
import json
2023-05-17 11:39:00 +08:00
from threading import Lock
2023-07-17 12:59:45 +08:00
from typing import List, Union
2023-08-04 22:27:55 +08:00
from enum import Enum
import base64
2023-05-07 17:27:54 +08:00
from fastapi import APIRouter, Request, status, HTTPException
from sse_starlette.sse import EventSourceResponse
2023-08-04 22:27:55 +08:00
from pydantic import BaseModel, Field
import numpy as np
import tiktoken
2023-05-07 17:27:54 +08:00
from utils.rwkv import *
2023-06-03 17:12:59 +08:00
from utils.log import quick_log
2023-05-07 17:27:54 +08:00
import global_var
# Router collecting the completion, chat-completion and embedding endpoints of this module.
router = APIRouter()
2023-08-04 22:27:55 +08:00
class Role(Enum):
    """Author of a chat message, using the OpenAI chat role strings."""

    # Turn written by the human side of the conversation.
    User = "user"
    # Turn written by the model.
    Assistant = "assistant"
    # Instruction text; a leading system message becomes the chat prompt preamble.
    System = "system"
2023-05-07 17:27:54 +08:00
class Message(BaseModel):
    """A single chat turn as sent in ``ChatCompletionBody.messages``."""

    # Who wrote this turn; validated against the Role enum values.
    role: Role
    # Turn text; empty strings are rejected by validation (min_length=1).
    content: str = Field(min_length=1)
2023-05-07 17:27:54 +08:00
2023-05-22 11:18:37 +08:00
class ChatCompletionBody(ModelConfigBody):
    """Request body for the OpenAI-compatible chat completion endpoints.

    Sampling settings (max_tokens, temperature, top_p, ...) presumably come from
    ``ModelConfigBody``, which is defined outside this file -- TODO confirm.
    """

    # Conversation so far; a leading system message becomes the prompt preamble.
    messages: Union[List[Message], None]
    # Accepted for OpenAI API compatibility; not read anywhere in this file.
    model: str = "rwkv"
    # When True the reply is streamed as server-sent events.
    stream: bool = False
    # Stop sequence(s); the chat endpoint additionally adds "\n\n{user}" and "\n\n{bot}".
    stop: Union[str, List[str], None] = [
        "\n\nUser",
        "\n\nQuestion",
        "\n\nQ",
        "\n\nHuman",
        "\n\nBob",
    ]
    # Optional overrides for the model's default user / assistant names in the prompt.
    user_name: Union[str, None] = None
    assistant_name: Union[str, None] = None

    class Config:
        # Example request shown in the generated OpenAPI docs.
        schema_extra = {
            "example": {
                "messages": [{"role": Role.User.value, "content": "hello"}],
                "model": "rwkv",
                "stream": False,
                "stop": None,
                "user_name": None,
                "assistant_name": None,
                "max_tokens": 1000,
                "temperature": 1.2,
                "top_p": 0.5,
                "presence_penalty": 0.4,
                "frequency_penalty": 0.4,
            }
        }
2023-05-17 11:39:00 +08:00
2023-06-18 20:16:52 +08:00
class CompletionBody(ModelConfigBody):
    """Request body for the OpenAI-compatible text completion endpoints."""

    # Text to continue; when a list is given only its first element is used.
    prompt: Union[str, List[str], None]
    # Accepted for OpenAI API compatibility; not read anywhere in this file.
    model: str = "rwkv"
    # When True the completion is streamed as server-sent events.
    stream: bool = False
    # Stop sequence(s) forwarded unchanged to the model's generate().
    stop: Union[str, List[str], None] = None

    class Config:
        # Example request shown in the generated OpenAPI docs.
        schema_extra = {
            "example": {
                "prompt": "The following is an epic science fiction masterpiece that is immortalized, "
                + "with delicate descriptions and grand depictions of interstellar civilization wars.\nChapter 1.\n",
                "model": "rwkv",
                "stream": False,
                "stop": None,
                "max_tokens": 100,
                "temperature": 1.2,
                "top_p": 0.5,
                "presence_penalty": 0.4,
                "frequency_penalty": 0.4,
            }
        }
2023-05-17 11:39:00 +08:00
# Count of requests currently waiting for or holding the model; in this file it
# is only reported in log messages.
requests_num: int = 0

# Held for the whole duration of a generation or embedding run, so only one
# request at a time drives the single loaded model.
completion_lock = Lock()
2023-05-07 17:27:54 +08:00
2023-06-18 20:16:52 +08:00
async def eval_rwkv(
    model: AbstractRWKV,
    request: Request,
    body: ModelConfigBody,
    prompt: str,
    stream: bool,
    stop: Union[str, List[str], None],
    chat_mode: bool,
):
    """Run one generation on ``model`` and yield OpenAI-style results.

    Requests are serialized through ``completion_lock``. While waiting for it,
    the client connection is polled every 0.1 s, and the request is abandoned
    (nothing is yielded) if the client disconnects.

    Args:
        model: the loaded model. ``model.generate(prompt, stop=stop)`` is
            iterated for ``(response, delta, prompt_tokens, completion_tokens)``.
        request: incoming request, used for disconnect checks and logging.
        body: per-request settings, applied after the global ``Model_Config``.
        prompt: full prompt text to continue.
        stream: if True, yield one JSON string per delta, then a final
            ``finish_reason: "stop"`` chunk and ``"[DONE]"``. Otherwise yield a
            single dict with the whole response and token usage.
        stop: stop sequence(s) forwarded to ``model.generate``.
        chat_mode: selects chat-completion vs text-completion payload shapes.

    ``requests_num`` is incremented on entry and decremented exactly once on
    every exit path. This includes the generator being cancelled or closed by
    the caller (e.g. a streaming client disconnecting) and exceptions raised
    by the model.
    """
    global requests_num
    requests_num = requests_num + 1
    quick_log(request, None, "Start Waiting. RequestsNum: " + str(requests_num))
    # Set once this request has been removed from requests_num, so the finally
    # clause does not decrement it a second time.
    released = False
    try:
        while completion_lock.locked():
            if await request.is_disconnected():
                requests_num = requests_num - 1
                released = True
                print(f"{request.client} Stop Waiting (Lock)")
                quick_log(
                    request,
                    None,
                    "Stop Waiting (Lock). RequestsNum: " + str(requests_num),
                )
                return
            await asyncio.sleep(0.1)

        with completion_lock:
            if await request.is_disconnected():
                requests_num = requests_num - 1
                released = True
                print(f"{request.client} Stop Waiting (Lock)")
                quick_log(
                    request,
                    None,
                    "Stop Waiting (Lock). RequestsNum: " + str(requests_num),
                )
                return
            # Global defaults first, then the per-request overrides.
            set_rwkv_config(model, global_var.get(global_var.Model_Config))
            set_rwkv_config(model, body)

            response, prompt_tokens, completion_tokens = "", 0, 0
            for response, delta, prompt_tokens, completion_tokens in model.generate(
                prompt,
                stop=stop,
            ):
                if await request.is_disconnected():
                    break
                if stream:
                    yield json.dumps(
                        {
                            "object": "chat.completion.chunk"
                            if chat_mode
                            else "text_completion",
                            "model": model.name,
                            "choices": [
                                {
                                    "delta": {"content": delta},
                                    "index": 0,
                                    "finish_reason": None,
                                }
                                if chat_mode
                                else {
                                    "text": delta,
                                    "index": 0,
                                    "finish_reason": None,
                                }
                            ],
                        }
                    )

            requests_num = requests_num - 1
            released = True
            if await request.is_disconnected():
                print(f"{request.client} Stop Waiting")
                quick_log(
                    request,
                    body,
                    response + "\nStop Waiting. RequestsNum: " + str(requests_num),
                )
                return
            quick_log(
                request,
                body,
                response + "\nFinished. RequestsNum: " + str(requests_num),
            )
            if stream:
                yield json.dumps(
                    {
                        "object": "chat.completion.chunk"
                        if chat_mode
                        else "text_completion",
                        "model": model.name,
                        "choices": [
                            {
                                "delta": {},
                                "index": 0,
                                "finish_reason": "stop",
                            }
                            if chat_mode
                            else {
                                "text": "",
                                "index": 0,
                                "finish_reason": "stop",
                            }
                        ],
                    }
                )
                yield "[DONE]"
            else:
                yield {
                    "object": "chat.completion" if chat_mode else "text_completion",
                    "model": model.name,
                    "usage": {
                        "prompt_tokens": prompt_tokens,
                        "completion_tokens": completion_tokens,
                        "total_tokens": prompt_tokens + completion_tokens,
                    },
                    "choices": [
                        {
                            "message": {
                                "role": Role.Assistant.value,
                                "content": response,
                            },
                            "index": 0,
                            "finish_reason": "stop",
                        }
                        if chat_mode
                        else {
                            "text": response,
                            "index": 0,
                            "finish_reason": "stop",
                        }
                    ],
                }
    finally:
        # Cancellation (CancelledError), aclose() (GeneratorExit) or a model
        # exception skipped the explicit decrements above.
        if not released:
            requests_num = requests_num - 1
2023-07-26 22:24:26 +08:00
@router.post("/v1/chat/completions", tags=["Completions"])
@router.post("/chat/completions", tags=["Completions"])
async def chat_completions(body: ChatCompletionBody, request: Request):
    """OpenAI-compatible chat completion endpoint.

    Flattens ``body.messages`` into a single RWKV prompt made of
    ``{user}{interface} ...`` / ``{bot}{interface} ...`` turns separated by
    blank lines, ending with an open ``{bot}{interface}`` turn. A leading
    system message is rewritten into third person (Raven models) or first
    person (other models) and placed in the preamble. The user/bot turn
    markers are added to the stop sequences, and generation is delegated to
    ``eval_rwkv``.

    Returns:
        An ``EventSourceResponse`` when ``body.stream`` is True. Otherwise the
        single result dict from ``eval_rwkv``, or None if it yielded nothing
        (client disconnected).

    Raises:
        HTTPException(400): if no model is loaded or ``messages`` is empty.
    """
    model: TextRWKV = global_var.get(global_var.Model)
    if model is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")

    if not body.messages:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "messages not found")
    basic_system: str = ""
    if body.messages[0].role == Role.System:
        basic_system = body.messages[0].content

    interface = model.interface
    user = model.user if body.user_name is None else body.user_name
    bot = model.bot if body.assistant_name is None else body.assistant_name
    is_raven = model.rwkv_type == RWKVType.Raven

    completion_text: str = ""
    if basic_system == "":
        completion_text = (
            f"""
The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. \
{bot} is very intelligent, creative and friendly. \
{bot} is unlikely to disagree with {user}, and {bot} doesn't like to ask {user} questions. \
{bot} likes to tell {user} a lot about herself and her opinions. \
{bot} usually gives {user} kind, helpful and informative advices.\n
"""
            if is_raven
            else (
                f"{user}{interface} hi\n\n{bot}{interface} Hi. "
                + "I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n"
            )
        )
    else:
        # Collapse the system prompt onto a single line.
        system_text = (
            basic_system.replace("\r\n", "\n")
            .replace("\r", "\n")
            .replace("\n\n", "\n")
            .replace("\n", " ")
            .strip()
        )
        # Rewrite second-person instructions about the assistant. Longer phrases
        # must come before their prefixes: "Your" is handled before "You",
        # otherwise "Your" would already have been turned into "Ir" / "{bot}r".
        # The last rule is the Chinese "你" (you) -> "我" (I).
        for second_person, raven_form, first_person in (
            ("You are", f"{bot} is", "I am"),
            ("you are", f"{bot} is", "I am"),
            ("You're", f"{bot} is", "I'm"),
            ("you're", f"{bot} is", "I'm"),
            ("Your", f"{bot}'s", "My"),
            ("your", f"{bot}'s", "my"),
            ("You", f"{bot}", "I"),
            ("you", f"{bot}", "I"),
            ("你", f"{bot}", "我"),
        ):
            system_text = system_text.replace(
                second_person, raven_form if is_raven else first_person
            )
        completion_text = (
            (
                f"The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. "
                if is_raven
                else f"{user}{interface} hi\n\n{bot}{interface} Hi. "
            )
            + system_text
            + "\n\n"
        )

    # A leading system message has already been folded into the preamble.
    history = body.messages[1:] if basic_system != "" else body.messages
    for message in history:
        append_message: str = ""
        if message.role == Role.User:
            append_message = f"{user}{interface} " + message.content
        elif message.role == Role.Assistant:
            append_message = f"{bot}{interface} " + message.content
        elif message.role == Role.System:
            append_message = message.content
        completion_text += (
            append_message.replace("\r\n", "\n")
            .replace("\r", "\n")
            .replace("\n\n", "\n")
            .strip()
            + "\n\n"
        )
    completion_text += f"{bot}{interface}"

    # `stop` may be a string, a list, or an explicit null. Normalize it to a
    # list so the turn markers can always be added. A null previously crashed
    # on `.append`, including for the documented schema example.
    if body.stop is None:
        stop_list: List[str] = []
    elif isinstance(body.stop, str):
        stop_list = [body.stop]
    else:
        stop_list = list(body.stop)
    body.stop = stop_list + [f"\n\n{user}", f"\n\n{bot}"]

    generator = eval_rwkv(
        model, request, body, completion_text, body.stream, body.stop, True
    )
    if body.stream:
        return EventSourceResponse(generator)
    try:
        return await generator.__anext__()
    except StopAsyncIteration:
        return None
    finally:
        # eval_rwkv yields its single non-stream result while still holding
        # completion_lock. Close it now so the lock is released immediately,
        # rather than whenever the abandoned generator gets finalized.
        await generator.aclose()
2023-05-22 11:18:37 +08:00
2023-07-26 22:24:26 +08:00
@router.post("/v1/completions", tags=["Completions"])
@router.post("/completions", tags=["Completions"])
async def completions(body: CompletionBody, request: Request):
    """OpenAI-compatible text completion endpoint.

    Continues ``body.prompt`` with the loaded model via ``eval_rwkv``. When the
    prompt is a list, only its first element is used.

    Returns:
        An ``EventSourceResponse`` when ``body.stream`` is True. Otherwise the
        single result dict from ``eval_rwkv``, or None if it yielded nothing
        (client disconnected).

    Raises:
        HTTPException(400): if no model is loaded or the prompt is missing/empty.
    """
    model: AbstractRWKV = global_var.get(global_var.Model)
    if model is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")

    if body.prompt is None or body.prompt == "" or body.prompt == []:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "prompt not found")

    if isinstance(body.prompt, list):
        body.prompt = body.prompt[0]  # TODO: support multiple prompts

    generator = eval_rwkv(
        model, request, body, body.prompt, body.stream, body.stop, False
    )
    if body.stream:
        return EventSourceResponse(generator)
    try:
        return await generator.__anext__()
    except StopAsyncIteration:
        return None
    finally:
        # eval_rwkv yields its single non-stream result while still holding
        # completion_lock. Close it now so the lock is released immediately,
        # rather than whenever the abandoned generator gets finalized.
        await generator.aclose()
class EmbeddingsBody(BaseModel):
    """Request body for the OpenAI-compatible embeddings endpoints."""

    # A string, a list of strings, or a list of token-id lists (tiktoken
    # "text-embedding-ada-002" ids, which the endpoint decodes back to text).
    input: Union[str, List[str], List[List[int]], None]
    # Accepted for OpenAI API compatibility; not read anywhere in this file.
    model: str = "rwkv"
    # "base64" returns each embedding as base64-encoded float32 bytes;
    # any other value returns the embedding as produced by the model.
    encoding_format: Union[str, None] = None
    # Forwarded as the second argument of model.get_embedding().
    fast_mode: bool = False

    class Config:
        # Example request shown in the generated OpenAPI docs.
        schema_extra = {
            "example": {
                "input": "a big apple",
                "model": "rwkv",
                "encoding_format": None,
                "fast_mode": False,
            }
        }
def embedding_base64(embedding: List[float]) -> str:
    """Return ``embedding`` as base64 text of its float32 bytes.

    The values are converted to a float32 array, and its raw bytes (native byte
    order) are base64-encoded and returned as a str. This is the
    ``encoding_format="base64"`` representation used by the embeddings endpoint.
    """
    raw_bytes = np.array(embedding).astype(np.float32).tobytes()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
2023-07-26 22:24:26 +08:00
@router.post("/v1/embeddings", tags=["Embeddings"])
@router.post("/embeddings", tags=["Embeddings"])
@router.post("/v1/engines/text-embedding-ada-002/embeddings", tags=["Embeddings"])
@router.post("/engines/text-embedding-ada-002/embeddings", tags=["Embeddings"])
async def embeddings(body: EmbeddingsBody, request: Request):
    """OpenAI-compatible embeddings endpoint.

    ``body.input`` may be a single string, a list of strings, or a list of
    token-id lists. Token ids are decoded with tiktoken's
    ``text-embedding-ada-002`` encoding, and the text is then embedded with
    ``model.get_embedding``. The work runs under ``completion_lock``, which is
    shared with the completion endpoints.

    Returns:
        An OpenAI-style ``{"object": "list", "data": [...], "model": ...,
        "usage": ...}`` dict, or None if the client disconnected first.

    Raises:
        HTTPException(400): if no model is loaded or the input is empty.
    """
    model: AbstractRWKV = global_var.get(global_var.Model)
    if model is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")
    if body.input is None or body.input == "" or body.input == [] or body.input == [[]]:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "input not found")

    global requests_num
    requests_num = requests_num + 1
    quick_log(request, None, "Start Waiting. RequestsNum: " + str(requests_num))
    # Set once this request has been removed from requests_num, so the finally
    # clause does not decrement it a second time.
    released = False
    try:
        while completion_lock.locked():
            if await request.is_disconnected():
                requests_num = requests_num - 1
                released = True
                print(f"{request.client} Stop Waiting (Lock)")
                quick_log(
                    request,
                    None,
                    "Stop Waiting (Lock). RequestsNum: " + str(requests_num),
                )
                return
            await asyncio.sleep(0.1)

        with completion_lock:
            if await request.is_disconnected():
                requests_num = requests_num - 1
                released = True
                print(f"{request.client} Stop Waiting (Lock)")
                quick_log(
                    request,
                    None,
                    "Stop Waiting (Lock). RequestsNum: " + str(requests_num),
                )
                return

            base64_format = body.encoding_format == "base64"
            vectors = []
            prompt_tokens = 0
            if isinstance(body.input, list):
                # Token-id lists are turned back into text before the RWKV model
                # embeds them; plain strings are used as-is.
                encoding = (
                    tiktoken.model.encoding_for_model("text-embedding-ada-002")
                    if isinstance(body.input[0], list)
                    else None
                )
                for item in body.input:
                    if await request.is_disconnected():
                        break
                    text = item if encoding is None else encoding.decode(item)
                    embedding, token_len = model.get_embedding(text, body.fast_mode)
                    prompt_tokens = prompt_tokens + token_len
                    vectors.append(
                        embedding_base64(embedding) if base64_format else embedding
                    )
            else:
                embedding, prompt_tokens = model.get_embedding(
                    body.input, body.fast_mode
                )
                vectors.append(
                    embedding_base64(embedding) if base64_format else embedding
                )

            requests_num = requests_num - 1
            released = True
            if await request.is_disconnected():
                print(f"{request.client} Stop Waiting")
                quick_log(
                    request,
                    None,
                    "Stop Waiting. RequestsNum: " + str(requests_num),
                )
                return
            quick_log(
                request,
                None,
                "Finished. RequestsNum: " + str(requests_num),
            )

            return {
                "object": "list",
                "data": [
                    {
                        "object": "embedding",
                        "index": index,
                        "embedding": vector,
                    }
                    for index, vector in enumerate(vectors)
                ],
                "model": model.name,
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "total_tokens": prompt_tokens,
                },
            }
    finally:
        # Cancellation or an exception from the tokenizer/model skipped the
        # explicit decrements above.
        if not released:
            requests_num = requests_num - 1