This commit is contained in:
josc146
2023-06-03 17:12:59 +08:00
parent f2ec1067bf
commit 38b775c937
4 changed files with 102 additions and 6 deletions

View File

@@ -7,6 +7,7 @@ from fastapi import APIRouter, Request, status, HTTPException
from sse_starlette.sse import EventSourceResponse
from pydantic import BaseModel
from utils.rwkv import *
from utils.log import quick_log
import global_var
router = APIRouter()
@@ -26,6 +27,8 @@ class ChatCompletionBody(ModelConfigBody):
# Module-level lock shared by the completion endpoints: each eval_rwkv
# generator polls completion_lock.locked() (sleeping 0.1 s between checks)
# before proceeding, and calls completion_lock.release() once generation ends.
# NOTE(review): the Lock import and the acquire call are outside this view —
# confirm which Lock class is used and where it is acquired.
completion_lock = Lock()
# Counter of requests currently inside eval_rwkv (waiting or generating).
# It is incremented when eval_rwkv starts, decremented when the client
# disconnects while waiting or when generation finishes, and is included in
# the quick_log messages as "RequestsNum: <n>".
requests_num = 0
@router.post("/v1/chat/completions")
@router.post("/chat/completions")
@@ -106,8 +109,15 @@ The following is a coherent verbose detailed conversation between a girl named {
completion_text += f"{bot}{interface}"
async def eval_rwkv():
global requests_num
requests_num = requests_num + 1
quick_log(request, None, "Start Waiting. RequestsNum: " + str(requests_num))
while completion_lock.locked():
if await request.is_disconnected():
requests_num = requests_num - 1
quick_log(
request, None, "Stop Waiting. RequestsNum: " + str(requests_num)
)
return
await asyncio.sleep(0.1)
else:
@@ -135,9 +145,21 @@ The following is a coherent verbose detailed conversation between a girl named {
}
)
# torch_gc()
requests_num = requests_num - 1
completion_lock.release()
if await request.is_disconnected():
print(f"{request.client} Stop Waiting")
quick_log(
request,
body,
response + "\nStop Waiting. RequestsNum: " + str(requests_num),
)
return
quick_log(
request,
body,
response + "\nFinished. RequestsNum: " + str(requests_num),
)
yield json.dumps(
{
"response": response,
@@ -161,6 +183,12 @@ The following is a coherent verbose detailed conversation between a girl named {
if await request.is_disconnected():
break
# torch_gc()
requests_num = requests_num - 1
quick_log(
request,
body,
response + "\nFinished. RequestsNum: " + str(requests_num),
)
completion_lock.release()
if await request.is_disconnected():
return
@@ -182,7 +210,11 @@ The following is a coherent verbose detailed conversation between a girl named {
if body.stream:
return EventSourceResponse(eval_rwkv())
else:
return await eval_rwkv().__anext__()
try:
return await eval_rwkv().__anext__()
except StopAsyncIteration:
print(f"{request.client} Stop Waiting")
return None
class CompletionBody(ModelConfigBody):
@@ -203,8 +235,15 @@ async def completions(body: CompletionBody, request: Request):
raise HTTPException(status.HTTP_400_BAD_REQUEST, "prompt not found")
async def eval_rwkv():
global requests_num
requests_num = requests_num + 1
quick_log(request, None, "Start Waiting. RequestsNum: " + str(requests_num))
while completion_lock.locked():
if await request.is_disconnected():
requests_num = requests_num - 1
quick_log(
request, None, "Stop Waiting. RequestsNum: " + str(requests_num)
)
return
await asyncio.sleep(0.1)
else:
@@ -229,9 +268,21 @@ async def completions(body: CompletionBody, request: Request):
}
)
# torch_gc()
requests_num = requests_num - 1
completion_lock.release()
if await request.is_disconnected():
print(f"{request.client} Stop Waiting")
quick_log(
request,
body,
response + "\nStop Waiting. RequestsNum: " + str(requests_num),
)
return
quick_log(
request,
body,
response + "\nFinished. RequestsNum: " + str(requests_num),
)
yield json.dumps(
{
"response": response,
@@ -252,6 +303,12 @@ async def completions(body: CompletionBody, request: Request):
if await request.is_disconnected():
break
# torch_gc()
requests_num = requests_num - 1
quick_log(
request,
body,
response + "\nFinished. RequestsNum: " + str(requests_num),
)
completion_lock.release()
if await request.is_disconnected():
return
@@ -270,4 +327,8 @@ async def completions(body: CompletionBody, request: Request):
if body.stream:
return EventSourceResponse(eval_rwkv())
else:
return await eval_rwkv().__anext__()
try:
return await eval_rwkv().__anext__()
except StopAsyncIteration:
print(f"{request.client} Stop Waiting")
return None