2023-05-17 03:39:00 +00:00
|
|
|
import asyncio
|
2023-05-07 09:27:54 +00:00
|
|
|
import json
|
2023-05-17 03:39:00 +00:00
|
|
|
from threading import Lock
|
2023-07-17 04:59:45 +00:00
|
|
|
from typing import List, Union
|
2023-08-04 14:27:55 +00:00
|
|
|
from enum import Enum
|
2023-06-19 14:51:06 +00:00
|
|
|
import base64
|
2023-05-07 09:27:54 +00:00
|
|
|
|
|
|
|
from fastapi import APIRouter, Request, status, HTTPException
|
|
|
|
from sse_starlette.sse import EventSourceResponse
|
2023-08-04 14:27:55 +00:00
|
|
|
from pydantic import BaseModel, Field
|
2023-06-19 14:51:06 +00:00
|
|
|
import tiktoken
|
2023-05-07 09:27:54 +00:00
|
|
|
from utils.rwkv import *
|
2023-06-03 09:12:59 +00:00
|
|
|
from utils.log import quick_log
|
2023-05-07 09:27:54 +00:00
|
|
|
import global_var
|
|
|
|
|
|
|
|
# Router collecting the completion, chat-completion and embedding endpoints
# defined in this module.
router: APIRouter = APIRouter()
|
|
|
|
|
|
|
|
|
2023-08-04 14:27:55 +00:00
|
|
|
class Role(Enum):
    """Author of a chat message.

    Each member's value is the ``role`` string used in request and response
    payloads (e.g. ``Role.Assistant.value`` in chat responses).
    """

    # The human asking questions.
    User = "user"
    # The model's replies.
    Assistant = "assistant"
    # Instructions / context for the conversation.
    System = "system"
|
|
|
|
|
|
|
|
|
2023-05-07 09:27:54 +00:00
|
|
|
class Message(BaseModel):
    """One turn of a chat conversation in a ``ChatCompletionBody``."""

    role: Role
    # Empty content is allowed (min_length=0).
    content: str = Field(min_length=0)
    # When False, chat_template_old normalizes line breaks in the content.
    raw: bool = Field(
        default=False, description="Whether to treat content as raw text"
    )
|
|
|
|
|
|
|
|
|
|
|
|
# Default stop sequences for chat requests: a blank line followed by any of
# the common speaker labels below.
default_stop = [
    "\n\n" + speaker
    for speaker in (
        "User",
        "Question",
        "Q",
        "Human",
        "Bob",
        "Assistant",
        "Answer",
        "A",
        "Bot",
        "Alice",
    )
]
|
2023-05-07 09:27:54 +00:00
|
|
|
|
|
|
|
|
2023-05-22 03:18:37 +00:00
|
|
|
class ChatCompletionBody(ModelConfigBody):
    """Request body for the OpenAI-compatible ``/chat/completions`` endpoint.

    Fields that appear in the example but are not declared here
    (``max_tokens``, ``temperature``, ...) come from ``ModelConfigBody``.
    """

    # Conversation to render into the prompt; the handler rejects None/empty.
    messages: Union[List[Message], None]
    model: Union[str, None] = "rwkv"
    stream: bool = False
    stop: Union[str, List[str], None] = default_stop
    # Optional overrides for the speaker labels used in the rendered prompt.
    user_name: Union[str, None] = Field(
        default=None, description="Internal user name", min_length=1
    )
    assistant_name: Union[str, None] = Field(
        default=None, description="Internal assistant name", min_length=1
    )
    system_name: Union[str, None] = Field(
        default=None, description="Internal system name", min_length=1
    )
    presystem: bool = Field(
        default=True,
        description="Whether to insert default system prompt at the beginning",
    )

    model_config = {
        "json_schema_extra": {
            "example": {
                "messages": [
                    {"role": Role.User.value, "content": "hello", "raw": False}
                ],
                "model": "rwkv",
                "stream": False,
                "stop": None,
                "user_name": None,
                "assistant_name": None,
                "system_name": None,
                "presystem": True,
                "max_tokens": 1000,
                "temperature": 1,
                "top_p": 0.3,
                "presence_penalty": 0,
                "frequency_penalty": 1,
            }
        }
    }
|
2023-06-15 13:52:22 +00:00
|
|
|
|
2023-05-17 03:39:00 +00:00
|
|
|
|
2023-06-18 12:16:52 +00:00
|
|
|
class CompletionBody(ModelConfigBody):
    """Request body for the OpenAI-compatible ``/completions`` endpoint.

    Fields that appear in the example but are not declared here
    (``max_tokens``, ``temperature``, ...) come from ``ModelConfigBody``.
    """

    # A prompt string, or a list of prompts of which the handler uses the first.
    prompt: Union[str, List[str], None]
    model: Union[str, None] = "rwkv"
    stream: bool = False
    # Plain completions have no default stop sequences.
    stop: Union[str, List[str], None] = None

    model_config = {
        "json_schema_extra": {
            "example": {
                "prompt": (
                    "The following is an epic science fiction masterpiece that is immortalized, "
                    "with delicate descriptions and grand depictions of interstellar civilization wars.\nChapter 1.\n"
                ),
                "model": "rwkv",
                "stream": False,
                "stop": None,
                "max_tokens": 100,
                "temperature": 1,
                "top_p": 0.3,
                "presence_penalty": 0,
                "frequency_penalty": 1,
            }
        }
    }
|
2023-06-18 12:16:52 +00:00
|
|
|
|
|
|
|
|
2023-05-17 03:39:00 +00:00
|
|
|
# Held while a request runs generation or embedding on the loaded model, so
# only one such request proceeds at a time; others poll until it is free.
completion_lock = Lock()

# Number of requests currently waiting for or holding completion_lock.
# Within this module it is only reported in log messages.
requests_num: int = 0
|
|
|
|
|
2023-05-07 09:27:54 +00:00
|
|
|
|
2023-06-18 12:16:52 +00:00
|
|
|
def _stream_chunk(
    model_name: str, chat_mode: bool, delta: str = "", final: bool = False
) -> str:
    """Serialize one streaming chunk in the OpenAI completion format.

    Args:
        model_name: Value reported in the ``model`` field.
        chat_mode: ``True`` for ``chat.completion.chunk`` payloads (``delta``
            object); ``False`` for ``text_completion`` payloads (``text`` field).
        delta: Newly generated text carried by this chunk.
        final: Build the terminating chunk instead: an empty delta (``{}`` in
            chat mode, ``""`` otherwise) with ``finish_reason="stop"``.

    Returns:
        The chunk encoded as a JSON string.
    """
    finish_reason = "stop" if final else None
    if chat_mode:
        choice = {
            "delta": {} if final else {"content": delta},
            "index": 0,
            "finish_reason": finish_reason,
        }
    else:
        choice = {"text": delta, "index": 0, "finish_reason": finish_reason}
    return json.dumps(
        {
            "object": "chat.completion.chunk" if chat_mode else "text_completion",
            "model": model_name,
            "choices": [choice],
        }
    )


async def eval_rwkv(
    model: AbstractRWKV,
    request: Request,
    body: ModelConfigBody,
    prompt: str,
    stream: bool,
    stop: Union[str, List[str], None],
    chat_mode: bool,
):
    """Run one generation request on ``model`` and yield its output.

    Requests are serialized through ``completion_lock``. While another request
    holds it, this generator polls every 0.1 s and gives up if the client
    disconnects. The model is configured from the global model config, then
    overridden by ``body``.

    Args:
        model: Loaded model to generate with.
        request: Incoming HTTP request; used for disconnect checks and logging.
        body: Request body whose settings are applied with ``set_rwkv_config``.
        prompt: Fully rendered prompt text.
        stream: If true, yield one JSON chunk string per generated delta, then
            a terminating chunk and ``"[DONE]"``. Otherwise yield a single
            response dict.
        stop: Stop sequence(s) forwarded to ``model.generate``.
        chat_mode: Emit chat-completion shaped payloads instead of
            text-completion ones.

    If the client disconnects, generation stops and neither the terminating
    chunk nor the response dict is yielded.
    """
    global requests_num
    requests_num = requests_num + 1
    quick_log(request, None, "Start Waiting. RequestsNum: " + str(requests_num))
    # True while this request is still counted in requests_num. The finally
    # clause uses it so the counter is released exactly once, even when the
    # generator is cancelled, closed mid-stream, or generation raises.
    counted = True
    try:
        while completion_lock.locked():
            if await request.is_disconnected():
                requests_num = requests_num - 1
                counted = False
                print(f"{request.client} Stop Waiting (Lock)")
                quick_log(
                    request,
                    None,
                    "Stop Waiting (Lock). RequestsNum: " + str(requests_num),
                )
                return
            await asyncio.sleep(0.1)

        # There is no await between the locked() check above and this acquire,
        # so no other coroutine on this event loop can take the lock in between.
        with completion_lock:
            if await request.is_disconnected():
                requests_num = requests_num - 1
                counted = False
                print(f"{request.client} Stop Waiting (Lock)")
                quick_log(
                    request,
                    None,
                    "Stop Waiting (Lock). RequestsNum: " + str(requests_num),
                )
                return
            set_rwkv_config(model, global_var.get(global_var.Model_Config))
            set_rwkv_config(model, body)
            print(get_rwkv_config(model))

            response, prompt_tokens, completion_tokens = "", 0, 0
            for response, delta, prompt_tokens, completion_tokens in model.generate(
                prompt,
                stop=stop,
            ):
                if await request.is_disconnected():
                    break
                if stream:
                    yield _stream_chunk(model.name, chat_mode, delta)

            requests_num = requests_num - 1
            counted = False
            if await request.is_disconnected():
                print(f"{request.client} Stop Waiting")
                quick_log(
                    request,
                    body,
                    response + "\nStop Waiting. RequestsNum: " + str(requests_num),
                )
                return
            quick_log(
                request,
                body,
                response + "\nFinished. RequestsNum: " + str(requests_num),
            )
            if stream:
                yield _stream_chunk(model.name, chat_mode, final=True)
                yield "[DONE]"
            else:
                yield {
                    "object": "chat.completion" if chat_mode else "text_completion",
                    "model": model.name,
                    "usage": {
                        "prompt_tokens": prompt_tokens,
                        "completion_tokens": completion_tokens,
                        "total_tokens": prompt_tokens + completion_tokens,
                    },
                    "choices": [
                        (
                            {
                                "message": {
                                    "role": Role.Assistant.value,
                                    "content": response,
                                },
                                "index": 0,
                                "finish_reason": "stop",
                            }
                            if chat_mode
                            else {
                                "text": response,
                                "index": 0,
                                "finish_reason": "stop",
                            }
                        )
                    ],
                }
    finally:
        if counted:
            requests_num = requests_num - 1
|
|
|
|
|
|
|
|
|
2024-03-25 04:52:40 +00:00
|
|
|
def _normalize_newlines(text: str) -> str:
    """Convert CR/CRLF to LF, collapse every run of newlines to one, and strip.

    Chat turns are joined with a double newline, so message text must not
    contain one. A single ``replace`` of a double newline with a single one
    still leaves a double newline behind for runs of three or more, hence
    the loop.
    """
    text = text.replace("\r\n", "\n").replace("\r", "\n")
    while "\n\n" in text:
        text = text.replace("\n\n", "\n")
    return text.strip()


def _rewrite_second_person(text: str, bot: str, is_raven: bool) -> str:
    """Rewrite a second-person system prompt as a statement about the bot.

    Raven models get third person ("{bot} is ..."); other models get first
    person ("I am ..."). Possessives are replaced before the bare pronouns,
    because once "You" has been replaced, "Your" no longer exists in the text
    and would come out mangled (e.g. "Your name" -> "Ir name").
    """
    replacements = [
        ("You are", f"{bot} is" if is_raven else "I am"),
        ("you are", f"{bot} is" if is_raven else "I am"),
        ("You're", f"{bot} is" if is_raven else "I'm"),
        ("you're", f"{bot} is" if is_raven else "I'm"),
        ("Your", f"{bot}'s" if is_raven else "My"),
        ("your", f"{bot}'s" if is_raven else "my"),
        ("You", f"{bot}" if is_raven else "I"),
        ("you", f"{bot}" if is_raven else "I"),
        ("你", f"{bot}" if is_raven else "我"),
    ]
    for old, new in replacements:
        text = text.replace(old, new)
    return text


def chat_template_old(
    model: TextRWKV, body: ChatCompletionBody, interface: str, user: str, bot: str
):
    """Render ``body.messages`` into a prompt for models older than v5.

    With ``body.presystem``, a preamble is prepended. If the first message is
    a system message, its content (newline-normalized unless ``raw``, and
    rewritten from second person) is folded into that preamble and skipped
    in the turn list. Otherwise a built-in greeting is used. Each message is
    then emitted as ``"{speaker}{interface} {content}"`` followed by a blank
    line, and the prompt ends with ``"{bot}{interface}"`` for the model to
    continue.

    Args:
        model: Loaded model; ``rwkv_type`` selects the Raven-style wording.
        body: Chat request; ``messages`` must be non-empty.
        interface: Separator placed after speaker labels.
        user: Label for user turns.
        bot: Label for assistant turns.

    Returns:
        The rendered prompt string.
    """
    is_raven = model.rwkv_type == RWKVType.Raven

    completion_text: str = ""
    basic_system: Union[str, None] = None
    if body.presystem:
        if body.messages[0].role == Role.System:
            basic_system = body.messages[0].content

        if basic_system is None:
            completion_text = (
                f"""
The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. \
{bot} is very intelligent, creative and friendly. \
{bot} is unlikely to disagree with {user}, and {bot} doesn't like to ask {user} questions. \
{bot} likes to tell {user} a lot about herself and her opinions. \
{bot} usually gives {user} kind, helpful and informative advices.\n
"""
                if is_raven
                else (
                    f"{user}{interface} hi\n\n{bot}{interface} Hi. "
                    + "I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n"
                )
            )
        else:
            if not body.messages[0].raw:
                # The system prompt is flattened onto a single line.
                basic_system = _normalize_newlines(basic_system).replace("\n", " ")
            completion_text = (
                (
                    f"The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. "
                    if is_raven
                    else f"{user}{interface} hi\n\n{bot}{interface} Hi. "
                )
                + _rewrite_second_person(basic_system, bot, is_raven)
                + "\n\n"
            )

    # A system message already folded into the preamble is not repeated.
    for message in body.messages[(0 if basic_system is None else 1) :]:
        append_message: str = ""
        if message.role == Role.User:
            append_message = f"{user}{interface} " + message.content
        elif message.role == Role.Assistant:
            append_message = f"{bot}{interface} " + message.content
        elif message.role == Role.System:
            append_message = message.content
        if not message.raw:
            append_message = _normalize_newlines(append_message)
        completion_text += append_message + "\n\n"
    completion_text += f"{bot}{interface}"

    return completion_text
|
|
|
|
|
|
|
|
|
|
|
|
def chat_template(
    model: TextRWKV, body: ChatCompletionBody, interface: str, user: str, bot: str
):
    """Render ``body.messages`` into a prompt for v5+ models.

    Every message, including system messages, becomes a
    ``"{speaker}{interface} {content}"`` turn followed by a blank line. The
    system speaker is ``body.system_name`` or ``"System"``. With
    ``body.presystem`` a fixed greeting exchange is prepended. The prompt ends
    with ``"{bot}{interface}"`` so the model continues as the assistant.

    Returns:
        The rendered prompt string.
    """
    system = body.system_name if body.system_name is not None else "System"
    speaker_for_role = {
        Role.User: user,
        Role.Assistant: bot,
        Role.System: system,
    }

    pieces: List[str] = []
    if body.presystem:
        pieces.append(
            f"{user}{interface} hi\n\n{bot}{interface} Hi. "
            + "I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n"
        )

    for message in body.messages:
        speaker = speaker_for_role.get(message.role)
        turn = "" if speaker is None else f"{speaker}{interface} " + message.content
        pieces.append(turn + "\n\n")

    pieces.append(f"{bot}{interface}")
    return "".join(pieces)
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/v1/chat/completions", tags=["Completions"])
@router.post("/chat/completions", tags=["Completions"])
async def chat_completions(body: ChatCompletionBody, request: Request):
    """OpenAI-compatible chat completion endpoint.

    Renders ``body.messages`` into a prompt, using ``chat_template_old`` for
    models with ``version < 5`` and ``chat_template`` otherwise. Extends the
    stop sequences with the speaker labels, then runs generation through
    ``eval_rwkv``: streamed as server-sent events when ``body.stream`` is set,
    otherwise returned as a single response.

    Raises:
        HTTPException: 400 if no model is loaded or ``messages`` is missing
            or empty.
    """
    model: TextRWKV = global_var.get(global_var.Model)
    if model is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")

    if not body.messages:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "messages not found")

    interface = model.interface
    user = model.user if body.user_name is None else body.user_name
    bot = model.bot if body.assistant_name is None else body.assistant_name

    if model.version < 5:
        completion_text = chat_template_old(model, body, interface, user, bot)
    else:
        completion_text = chat_template(model, body, interface, user, bot)

    # Also stop on a blank line followed by the first token of either speaker
    # label.
    user_code = model.pipeline.decode([model.pipeline.encode(user)[0]])
    bot_code = model.pipeline.decode([model.pipeline.encode(bot)[0]])
    if isinstance(body.stop, str):
        body.stop = [body.stop, f"\n\n{user_code}", f"\n\n{bot_code}"]
    elif isinstance(body.stop, list):
        body.stop.append(f"\n\n{user_code}")
        body.stop.append(f"\n\n{bot_code}")
    elif body.stop is None:
        # Copy, so later mutation of body.stop cannot alter the module default.
        body.stop = list(default_stop)

    if body.stream:
        return EventSourceResponse(
            eval_rwkv(
                model, request, body, completion_text, body.stream, body.stop, True
            )
        )

    generator = eval_rwkv(
        model, request, body, completion_text, body.stream, body.stop, True
    )
    try:
        return await generator.__anext__()
    except StopAsyncIteration:
        # eval_rwkv yields nothing when the client disconnected.
        return None
    finally:
        # eval_rwkv yields the non-streaming response from inside its
        # `with completion_lock` block. Close it now so the lock is released
        # immediately rather than whenever the suspended generator is finalized.
        await generator.aclose()
|
2023-05-22 03:18:37 +00:00
|
|
|
|
|
|
|
|
2023-07-26 14:24:26 +00:00
|
|
|
@router.post("/v1/completions", tags=["Completions"])
@router.post("/completions", tags=["Completions"])
async def completions(body: CompletionBody, request: Request):
    """OpenAI-compatible text completion endpoint.

    Only the first prompt of a prompt list is used. Output is streamed as
    server-sent events when ``body.stream`` is set, otherwise returned as a
    single response.

    Raises:
        HTTPException: 400 if no model is loaded or ``prompt`` is missing or
            empty.
    """
    model: AbstractRWKV = global_var.get(global_var.Model)
    if model is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")

    if not body.prompt:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "prompt not found")

    if isinstance(body.prompt, list):
        body.prompt = body.prompt[0]  # TODO: support multiple prompts

    if body.stream:
        return EventSourceResponse(
            eval_rwkv(model, request, body, body.prompt, body.stream, body.stop, False)
        )

    generator = eval_rwkv(
        model, request, body, body.prompt, body.stream, body.stop, False
    )
    try:
        return await generator.__anext__()
    except StopAsyncIteration:
        # eval_rwkv yields nothing when the client disconnected.
        return None
    finally:
        # eval_rwkv yields the non-streaming response from inside its
        # `with completion_lock` block. Close it now so the lock is released
        # immediately rather than whenever the suspended generator is finalized.
        await generator.aclose()
|
2023-06-19 14:51:06 +00:00
|
|
|
|
|
|
|
|
|
|
|
class EmbeddingsBody(BaseModel):
    """Request body for the OpenAI-compatible ``/embeddings`` endpoints."""

    # A text, a list of texts, or a list of token-id lists. The handler
    # decodes token-id lists with the text-embedding-ada-002 tiktoken encoding.
    input: Union[str, List[str], List[List[int]], None]
    model: Union[str, None] = "rwkv"
    # "base64" returns each vector as base64-packed float32; any other value
    # returns plain lists. Optional, so an explicit JSON null (as in the
    # schema example below) passes validation.
    encoding_format: Union[str, None] = None
    fast_mode: bool = False

    model_config = {
        "json_schema_extra": {
            "example": {
                "input": "a big apple",
                "model": "rwkv",
                "encoding_format": None,
                "fast_mode": False,
            }
        }
    }
|
2023-06-19 14:51:06 +00:00
|
|
|
|
|
|
|
|
|
|
|
def embedding_base64(embedding: List[float]) -> str:
    """Return ``embedding`` packed as native-endian float32 values, base64-encoded.

    This is the representation used for ``encoding_format="base64"``.
    """
    import numpy as np  # deferred import, kept local to this helper

    packed = np.array(embedding).astype(np.float32).tobytes()
    encoded = base64.b64encode(packed)
    return encoded.decode("utf-8")
|
|
|
|
|
|
|
|
|
2023-07-26 14:24:26 +00:00
|
|
|
async def _collect_embeddings(
    model: AbstractRWKV, body: EmbeddingsBody, request: Request
):
    """Embed ``body.input`` and return ``(embeddings, prompt_tokens)``.

    A single string produces one embedding. List items are embedded in order,
    stopping early if the client disconnects. Token-id lists are first decoded
    with the ``text-embedding-ada-002`` tiktoken encoding. When
    ``encoding_format == "base64"``, each vector is passed through
    ``embedding_base64``.
    """
    base64_format = body.encoding_format == "base64"

    def finish(vector):
        return embedding_base64(vector) if base64_format else vector

    if not isinstance(body.input, list):
        vector, prompt_tokens = model.get_embedding(body.input, body.fast_mode)
        return [finish(vector)], prompt_tokens

    # The element type of the first item decides whether the batch is tokens.
    decode = None
    if isinstance(body.input[0], list):
        decode = tiktoken.model.encoding_for_model("text-embedding-ada-002").decode

    vectors = []
    prompt_tokens = 0
    for item in body.input:
        if await request.is_disconnected():
            break
        text = item if decode is None else decode(item)
        vector, token_len = model.get_embedding(text, body.fast_mode)
        prompt_tokens = prompt_tokens + token_len
        vectors.append(finish(vector))
    return vectors, prompt_tokens


@router.post("/v1/embeddings", tags=["Embeddings"])
@router.post("/embeddings", tags=["Embeddings"])
@router.post("/v1/engines/text-embedding-ada-002/embeddings", tags=["Embeddings"])
@router.post("/engines/text-embedding-ada-002/embeddings", tags=["Embeddings"])
async def embeddings(body: EmbeddingsBody, request: Request):
    """OpenAI-compatible embeddings endpoint.

    Waits for ``completion_lock`` (polling every 0.1 s and giving up if the
    client disconnects), embeds the input, and returns an OpenAI-style
    ``list`` object. Returns ``None`` if the client disconnects.

    Raises:
        HTTPException: 400 if no model is loaded or ``input`` is missing or
            empty.
    """
    global requests_num

    model: AbstractRWKV = global_var.get(global_var.Model)
    if model is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")

    if body.input is None or body.input == "" or body.input == [] or body.input == [[]]:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "input not found")

    requests_num = requests_num + 1
    quick_log(request, None, "Start Waiting. RequestsNum: " + str(requests_num))
    # True while this request is still counted in requests_num. The finally
    # clause uses it so the counter is released exactly once, even if
    # embedding raises or the handler is cancelled.
    counted = True
    try:
        while completion_lock.locked():
            if await request.is_disconnected():
                requests_num = requests_num - 1
                counted = False
                print(f"{request.client} Stop Waiting (Lock)")
                quick_log(
                    request,
                    None,
                    "Stop Waiting (Lock). RequestsNum: " + str(requests_num),
                )
                return
            await asyncio.sleep(0.1)

        # There is no await between the locked() check above and this acquire,
        # so no other coroutine on this event loop can take the lock in between.
        with completion_lock:
            if await request.is_disconnected():
                requests_num = requests_num - 1
                counted = False
                print(f"{request.client} Stop Waiting (Lock)")
                quick_log(
                    request,
                    None,
                    "Stop Waiting (Lock). RequestsNum: " + str(requests_num),
                )
                return

            embedding_list, prompt_tokens = await _collect_embeddings(
                model, body, request
            )

            requests_num = requests_num - 1
            counted = False
            if await request.is_disconnected():
                print(f"{request.client} Stop Waiting")
                quick_log(
                    request,
                    None,
                    "Stop Waiting. RequestsNum: " + str(requests_num),
                )
                return
            quick_log(
                request,
                None,
                "Finished. RequestsNum: " + str(requests_num),
            )

            return {
                "object": "list",
                "data": [
                    {
                        "object": "embedding",
                        "index": i,
                        "embedding": vector,
                    }
                    for i, vector in enumerate(embedding_list)
                ],
                "model": model.name,
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "total_tokens": prompt_tokens,
                },
            }
    finally:
        if counted:
            requests_num = requests_num - 1
|