import asyncio
import json
from threading import Lock
from typing import List
from fastapi import APIRouter, Request, status, HTTPException
from sse_starlette.sse import EventSourceResponse
from pydantic import BaseModel
from utils.rwkv import *
from utils.log import quick_log
import global_var

router = APIRouter()
class Message(BaseModel):
    """A single chat message in an OpenAI-style conversation."""

    # Speaker of the message: "user", "assistant" or "system".
    role: str
    # Plain-text body of the message.
    content: str
class ChatCompletionBody(ModelConfigBody):
    """Request body for the chat completion endpoints.

    Sampling fields shown in the example (max_tokens, temperature, top_p,
    presence_penalty, frequency_penalty) are presumably declared on
    ModelConfigBody — confirm against utils.rwkv.
    """

    # Conversation history; the last entry is expected to be the user's question.
    messages: List[Message]
    # Accepted for OpenAI API compatibility; the server always runs its loaded model.
    model: str = "rwkv"
    # When True the endpoint streams SSE chunks instead of one JSON response.
    stream: bool = False
    # Stop string for generation; None lets the endpoint pick its default stop.
    stop: str = None

    class Config:
        # Example payload surfaced in the OpenAPI schema docs.
        schema_extra = {
            "example": {
                "messages": [{"role": "user", "content": "hello"}],
                "model": "rwkv",
                "stream": False,
                "stop": None,
                "max_tokens": 1000,
                "temperature": 1.2,
                "top_p": 0.5,
                "presence_penalty": 0.4,
                "frequency_penalty": 0.4,
            }
        }
# Serializes inference: only one request may drive the loaded model at a time.
# Waiters poll this lock (see eval_rwkv) instead of blocking on acquire, so the
# event loop can notice client disconnects while queued.
completion_lock = Lock()

# Count of in-flight requests (including those waiting on the lock); used for logging.
requests_num = 0
@router.post ( " /v1/chat/completions " )
@router.post ( " /chat/completions " )
2023-05-22 11:18:37 +08:00
async def chat_completions ( body : ChatCompletionBody , request : Request ) :
2023-05-17 11:39:00 +08:00
model : RWKV = global_var . get ( global_var . Model )
if model is None :
2023-05-07 22:48:52 +08:00
raise HTTPException ( status . HTTP_400_BAD_REQUEST , " model not loaded " )
2023-05-07 17:27:54 +08:00
question = body . messages [ - 1 ]
2023-05-17 11:39:00 +08:00
if question . role == " user " :
2023-05-07 17:27:54 +08:00
question = question . content
2023-05-29 22:26:22 +08:00
elif question . role == " system " :
question = body . messages [ - 2 ]
if question . role == " user " :
question = question . content
else :
raise HTTPException ( status . HTTP_400_BAD_REQUEST , " no question found " )
2023-05-07 17:27:54 +08:00
else :
raise HTTPException ( status . HTTP_400_BAD_REQUEST , " no question found " )
2023-05-28 12:53:14 +08:00
interface = model . interface
user = model . user
bot = model . bot
completion_text = (
f """
2023-05-24 14:01:22 +08:00
The following is a coherent verbose detailed conversation between a girl named { bot } and her friend { user } . \
{ bot } is very intelligent , creative and friendly . \
{ bot } is unlikely to disagree with { user } , and { bot } doesn ' t like to ask {user} questions. \
{ bot } likes to tell { user } a lot about herself and her opinions . \
{ bot } usually gives { user } kind , helpful and informative advices . \n
"""
2023-05-28 12:53:14 +08:00
if user == " Bob "
2023-05-31 12:46:06 +08:00
else f " { user } { interface } hi \n \n { bot } { interface } Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it. \n \n "
2023-05-28 12:53:14 +08:00
)
2023-05-07 17:27:54 +08:00
for message in body . messages :
2023-05-24 14:01:22 +08:00
if message . role == " system " :
completion_text = (
f " The following is a coherent verbose detailed conversation between a girl named { bot } and her friend { user } . "
2023-05-28 12:53:14 +08:00
if user == " Bob "
2023-05-31 12:46:06 +08:00
else f " { user } { interface } hi \n \n { bot } { interface } Hi. "
2023-05-24 14:01:22 +08:00
+ message . content . replace ( " \\ n " , " \n " )
. replace ( " \r \n " , " \n " )
. replace ( " \n \n " , " \n " )
. replace ( " \n " , " " )
. strip ( )
2023-05-31 12:46:06 +08:00
. replace ( " You are " , f " { bot } is " if user == " Bob " else " I am " )
. replace ( " you are " , f " { bot } is " if user == " Bob " else " I am " )
. replace ( " You ' re " , f " { bot } is " if user == " Bob " else " I ' m " )
. replace ( " you ' re " , f " { bot } is " if user == " Bob " else " I ' m " )
. replace ( " You " , f " { bot } " if user == " Bob " else " I " )
. replace ( " you " , f " { bot } " if user == " Bob " else " I " )
. replace ( " Your " , f " { bot } ' s " if user == " Bob " else " My " )
. replace ( " your " , f " { bot } ' s " if user == " Bob " else " my " )
. replace ( " 你 " , f " { bot } " if user == " Bob " else " 我 " )
2023-05-24 14:01:22 +08:00
+ " \n \n "
)
2023-05-29 22:26:22 +08:00
break
for message in body . messages :
if message . role == " user " :
2023-05-21 23:25:58 +08:00
completion_text + = (
2023-05-24 14:01:22 +08:00
f " { user } { interface } "
2023-05-21 23:25:58 +08:00
+ message . content . replace ( " \\ n " , " \n " )
. replace ( " \r \n " , " \n " )
. replace ( " \n \n " , " \n " )
. strip ( )
+ " \n \n "
)
2023-05-17 11:39:00 +08:00
elif message . role == " assistant " :
2023-05-21 23:25:58 +08:00
completion_text + = (
2023-05-24 14:01:22 +08:00
f " { bot } { interface } "
2023-05-21 23:25:58 +08:00
+ message . content . replace ( " \\ n " , " \n " )
. replace ( " \r \n " , " \n " )
. replace ( " \n \n " , " \n " )
. strip ( )
+ " \n \n "
)
2023-05-24 14:01:22 +08:00
completion_text + = f " { bot } { interface } "
2023-05-07 17:27:54 +08:00
async def eval_rwkv ( ) :
2023-06-03 17:12:59 +08:00
global requests_num
requests_num = requests_num + 1
quick_log ( request , None , " Start Waiting. RequestsNum: " + str ( requests_num ) )
2023-05-17 11:39:00 +08:00
while completion_lock . locked ( ) :
2023-05-27 15:18:12 +08:00
if await request . is_disconnected ( ) :
2023-06-03 17:12:59 +08:00
requests_num = requests_num - 1
2023-06-03 17:36:50 +08:00
print ( f " { request . client } Stop Waiting (Lock) " )
2023-06-03 17:12:59 +08:00
quick_log (
2023-06-03 17:36:50 +08:00
request ,
None ,
" Stop Waiting (Lock). RequestsNum: " + str ( requests_num ) ,
2023-06-03 17:12:59 +08:00
)
2023-05-27 15:18:12 +08:00
return
2023-05-17 11:39:00 +08:00
await asyncio . sleep ( 0.1 )
2023-05-07 17:27:54 +08:00
else :
2023-05-21 13:46:54 +08:00
completion_lock . acquire ( )
2023-06-03 19:28:37 +08:00
if await request . is_disconnected ( ) :
completion_lock . release ( )
requests_num = requests_num - 1
print ( f " { request . client } Stop Waiting (Lock) " )
quick_log (
request ,
None ,
" Stop Waiting (Lock). RequestsNum: " + str ( requests_num ) ,
)
return
2023-05-21 13:46:54 +08:00
set_rwkv_config ( model , global_var . get ( global_var . Model_Config ) )
set_rwkv_config ( model , body )
if body . stream :
2023-06-08 13:30:34 +08:00
response = " "
2023-05-28 12:53:14 +08:00
for response , delta in model . generate (
2023-05-22 11:24:57 +08:00
completion_text ,
2023-05-24 14:01:22 +08:00
stop = f " \n \n { user } " if body . stop is None else body . stop ,
2023-05-21 13:46:54 +08:00
) :
if await request . is_disconnected ( ) :
break
2023-05-17 11:39:00 +08:00
yield json . dumps (
{
" response " : response ,
" model " : " rwkv " ,
" choices " : [
{
2023-05-21 13:46:54 +08:00
" delta " : { " content " : delta } ,
2023-05-17 11:39:00 +08:00
" index " : 0 ,
2023-05-21 13:46:54 +08:00
" finish_reason " : None ,
2023-05-17 11:39:00 +08:00
}
] ,
}
)
2023-05-24 11:45:55 +08:00
# torch_gc()
2023-06-03 17:12:59 +08:00
requests_num = requests_num - 1
2023-05-24 11:45:55 +08:00
completion_lock . release ( )
2023-05-21 13:46:54 +08:00
if await request . is_disconnected ( ) :
2023-06-03 17:12:59 +08:00
print ( f " { request . client } Stop Waiting " )
quick_log (
request ,
body ,
response + " \n Stop Waiting. RequestsNum: " + str ( requests_num ) ,
)
2023-05-21 13:46:54 +08:00
return
2023-06-03 17:12:59 +08:00
quick_log (
request ,
body ,
response + " \n Finished. RequestsNum: " + str ( requests_num ) ,
)
2023-05-21 13:46:54 +08:00
yield json . dumps (
{
2023-05-17 11:39:00 +08:00
" response " : response ,
" model " : " rwkv " ,
" choices " : [
{
2023-05-21 13:46:54 +08:00
" delta " : { } ,
2023-05-17 11:39:00 +08:00
" index " : 0 ,
" finish_reason " : " stop " ,
}
] ,
}
2023-05-21 13:46:54 +08:00
)
yield " [DONE] "
else :
2023-06-08 13:30:34 +08:00
response = " "
2023-05-28 12:53:14 +08:00
for response , delta in model . generate (
2023-05-22 11:24:57 +08:00
completion_text ,
2023-05-24 14:01:22 +08:00
stop = f " \n \n { user } " if body . stop is None else body . stop ,
2023-05-21 13:46:54 +08:00
) :
if await request . is_disconnected ( ) :
break
2023-05-24 11:45:55 +08:00
# torch_gc()
2023-06-03 17:12:59 +08:00
requests_num = requests_num - 1
2023-06-03 17:36:50 +08:00
completion_lock . release ( )
if await request . is_disconnected ( ) :
print ( f " { request . client } Stop Waiting " )
quick_log (
request ,
body ,
response + " \n Stop Waiting. RequestsNum: " + str ( requests_num ) ,
)
return
2023-06-03 17:12:59 +08:00
quick_log (
request ,
body ,
response + " \n Finished. RequestsNum: " + str ( requests_num ) ,
)
2023-05-21 13:46:54 +08:00
yield {
" response " : response ,
" model " : " rwkv " ,
" choices " : [
{
" message " : {
" role " : " assistant " ,
" content " : response ,
} ,
" index " : 0 ,
" finish_reason " : " stop " ,
}
] ,
}
2023-05-07 17:27:54 +08:00
2023-05-17 11:39:00 +08:00
if body . stream :
return EventSourceResponse ( eval_rwkv ( ) )
else :
2023-06-03 17:12:59 +08:00
try :
return await eval_rwkv ( ) . __anext__ ( )
except StopAsyncIteration :
return None
class CompletionBody(ModelConfigBody):
    """Request body for the raw text completion endpoints.

    Sampling fields shown in the example (max_tokens, temperature, top_p,
    presence_penalty, frequency_penalty) are presumably declared on
    ModelConfigBody — confirm against utils.rwkv.
    """

    # Raw prompt text to continue; must be non-empty.
    prompt: str
    # Accepted for OpenAI API compatibility; the server always runs its loaded model.
    model: str = "rwkv"
    # When True the endpoint streams SSE chunks instead of one JSON response.
    stream: bool = False
    # Stop string for generation; None means no explicit stop string.
    stop: str = None

    class Config:
        # Example payload surfaced in the OpenAPI schema docs.
        schema_extra = {
            "example": {
                "prompt": "The following is an epic science fiction masterpiece that is immortalized, with delicate descriptions and grand depictions of interstellar civilization wars.\nChapter 1.\n",
                "model": "rwkv",
                "stream": False,
                "stop": None,
                "max_tokens": 100,
                "temperature": 1.2,
                "top_p": 0.5,
                "presence_penalty": 0.4,
                "frequency_penalty": 0.4,
            }
        }
@router.post("/v1/completions")
@router.post("/completions")
async def completions(body: CompletionBody, request: Request):
    """OpenAI-compatible raw text completion endpoint.

    Continues body.prompt with the loaded RWKV model, either streaming SSE
    chunks (body.stream) or returning a single JSON response. Raises 400 if
    no model is loaded or the prompt is empty.
    """
    model: RWKV = global_var.get(global_var.Model)
    if model is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")

    if body.prompt is None or body.prompt == "":
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "prompt not found")

    async def eval_rwkv():
        """Async generator that waits for the inference lock, runs generation,
        and yields SSE payloads (stream) or one response dict (non-stream)."""
        global requests_num
        requests_num = requests_num + 1
        quick_log(request, None, "Start Waiting. RequestsNum: " + str(requests_num))
        # Poll the lock so the event loop can detect client disconnects while
        # queued (Lock is a threading.Lock; a blocking acquire would stall the loop).
        while completion_lock.locked():
            if await request.is_disconnected():
                requests_num = requests_num - 1
                print(f"{request.client} Stop Waiting (Lock)")
                quick_log(
                    request,
                    None,
                    "Stop Waiting (Lock). RequestsNum: " + str(requests_num),
                )
                return
            await asyncio.sleep(0.1)
        else:
            # while-else: runs once the lock is observed free (the waiting
            # branch exits via `return`, never `break`).
            completion_lock.acquire()
            # Final disconnect check before doing any expensive work.
            if await request.is_disconnected():
                completion_lock.release()
                requests_num = requests_num - 1
                print(f"{request.client} Stop Waiting (Lock)")
                quick_log(
                    request,
                    None,
                    "Stop Waiting (Lock). RequestsNum: " + str(requests_num),
                )
                return
            # Apply global model config first, then per-request overrides.
            set_rwkv_config(model, global_var.get(global_var.Model_Config))
            set_rwkv_config(model, body)
            if body.stream:
                response = ""
                for response, delta in model.generate(body.prompt, stop=body.stop):
                    if await request.is_disconnected():
                        break
                    yield json.dumps(
                        {
                            "response": response,
                            "model": "rwkv",
                            "choices": [
                                {
                                    "text": delta,
                                    "index": 0,
                                    "finish_reason": None,
                                }
                            ],
                        }
                    )
                # torch_gc()
                requests_num = requests_num - 1
                completion_lock.release()
                if await request.is_disconnected():
                    print(f"{request.client} Stop Waiting")
                    quick_log(
                        request,
                        body,
                        response + "\nStop Waiting. RequestsNum: " + str(requests_num),
                    )
                    return
                quick_log(
                    request,
                    body,
                    response + "\nFinished. RequestsNum: " + str(requests_num),
                )
                yield json.dumps(
                    {
                        "response": response,
                        "model": "rwkv",
                        "choices": [
                            {
                                "text": "",
                                "index": 0,
                                "finish_reason": "stop",
                            }
                        ],
                    }
                )
                yield "[DONE]"
            else:
                response = ""
                # Drain the generator; only the final accumulated response is used.
                for response, delta in model.generate(body.prompt, stop=body.stop):
                    if await request.is_disconnected():
                        break
                # torch_gc()
                requests_num = requests_num - 1
                completion_lock.release()
                if await request.is_disconnected():
                    print(f"{request.client} Stop Waiting")
                    quick_log(
                        request,
                        body,
                        response + "\nStop Waiting. RequestsNum: " + str(requests_num),
                    )
                    return
                quick_log(
                    request,
                    body,
                    response + "\nFinished. RequestsNum: " + str(requests_num),
                )
                yield {
                    "response": response,
                    "model": "rwkv",
                    "choices": [
                        {
                            "text": response,
                            "index": 0,
                            "finish_reason": "stop",
                        }
                    ],
                }

    if body.stream:
        return EventSourceResponse(eval_rwkv())
    else:
        # Non-stream: pull the single yielded dict out of the async generator.
        try:
            return await eval_rwkv().__anext__()
        except StopAsyncIteration:
            return None