add compatible /v1/completions API

This commit is contained in:
josc146 2023-05-22 11:18:37 +08:00
parent bbad153ecb
commit 85493da730

View File

@ -17,7 +17,7 @@ class Message(BaseModel):
content: str content: str
class CompletionBody(ModelConfigBody): class ChatCompletionBody(ModelConfigBody):
messages: List[Message] messages: List[Message]
model: str = "rwkv" model: str = "rwkv"
stream: bool = False stream: bool = False
@ -28,7 +28,7 @@ completion_lock = Lock()
@router.post("/v1/chat/completions") @router.post("/v1/chat/completions")
@router.post("/chat/completions") @router.post("/chat/completions")
async def completions(body: CompletionBody, request: Request): async def chat_completions(body: ChatCompletionBody, request: Request):
model: RWKV = global_var.get(global_var.Model) model: RWKV = global_var.get(global_var.Model)
if model is None: if model is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded") raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")
@ -135,3 +135,90 @@ async def completions(body: CompletionBody, request: Request):
return EventSourceResponse(eval_rwkv()) return EventSourceResponse(eval_rwkv())
else: else:
return await eval_rwkv().__anext__() return await eval_rwkv().__anext__()
class CompletionBody(ModelConfigBody):
    # Request body accepted by the /v1/completions and /completions routes.
    # It extends ModelConfigBody; the whole body is also handed to
    # set_rwkv_config() by the handler, so inherited fields act as
    # per-request model configuration.

    # Text handed to rwkv_generate() as the prompt.
    prompt: str
    # Defaults to "rwkv"; the completions handler does not read this field
    # directly and always reports "rwkv" in its responses.
    model: str = "rwkv"
    # True -> the handler returns an EventSourceResponse stream;
    # False -> the handler returns a single response dict.
    stream: bool = False
    # Forwarded to rwkv_generate() as its ``stop`` keyword argument.
    # NOTE(review): annotated as ``str`` but defaults to None, so the field
    # is effectively optional -- consider Optional[str] once confirmed that
    # ``Optional`` is imported in this module.
    stop: str = None
@router.post("/v1/completions")
@router.post("/completions")
async def completions(body: CompletionBody, request: Request):
    """Generate a text completion for ``body.prompt`` with the loaded model.

    The handler is registered on both ``/v1/completions`` and
    ``/completions``. Generation is serialised through ``completion_lock``
    so only one request drives the model at a time.

    Args:
        body: Request payload; ``prompt`` and ``stop`` are forwarded to
            ``rwkv_generate`` and the body itself is applied to the model via
            ``set_rwkv_config``.
        request: Incoming request, polled to detect client disconnects.

    Returns:
        When ``body.stream`` is true, an ``EventSourceResponse`` emitting one
        JSON chunk per generated delta, then a final chunk whose
        ``finish_reason`` is ``"stop"``, then the literal ``[DONE]``.
        Otherwise a single dict containing the full generated text, or
        ``None`` if the client disconnected before a result was produced.

    Raises:
        HTTPException: 400 if no model is stored under ``global_var.Model``.
    """
    model: RWKV = global_var.get(global_var.Model)
    if model is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")

    def completion_payload(response, text, finish_reason):
        # Single place that defines the response shape shared by streamed
        # chunks and the non-streamed result.
        return {
            "response": response,
            "model": "rwkv",
            "choices": [
                {
                    "text": text,
                    "index": 0,
                    "finish_reason": finish_reason,
                }
            ],
        }

    async def eval_rwkv():
        # Poll the lock with short sleeps so control returns to the event
        # loop between checks while another request holds it.
        while completion_lock.locked():
            await asyncio.sleep(0.1)
        completion_lock.acquire()
        try:
            # Apply the global configuration first, then let the request
            # body override it.
            set_rwkv_config(model, global_var.get(global_var.Model_Config))
            set_rwkv_config(model, body)
            # Pre-bind so the final payload is well-defined even when the
            # generator produces no items.
            response = None
            if body.stream:
                for response, delta in rwkv_generate(
                    model, body.prompt, stop=body.stop
                ):
                    if await request.is_disconnected():
                        break
                    yield json.dumps(completion_payload(response, delta, None))
                if await request.is_disconnected():
                    return
                yield json.dumps(completion_payload(response, "", "stop"))
                yield "[DONE]"
            else:
                for response, _delta in rwkv_generate(
                    model, body.prompt, stop=body.stop
                ):
                    if await request.is_disconnected():
                        break
                if await request.is_disconnected():
                    return
                yield completion_payload(response, response, "stop")
            # torch_gc()
        finally:
            # Runs on normal completion, early return, errors, and when the
            # generator is closed or cancelled while suspended at a yield,
            # so the lock can never be left held.
            completion_lock.release()

    if body.stream:
        return EventSourceResponse(eval_rwkv())

    generator = eval_rwkv()
    try:
        return await generator.__anext__()
    except StopAsyncIteration:
        # The generator finished without yielding: the client disconnected
        # before a result was produced.
        return None
    finally:
        # Only the first item is consumed; close the generator explicitly so
        # its ``finally`` clause releases the lock right away.
        await generator.aclose()