add compatible /v1/completions API
This commit is contained in:
parent
bbad153ecb
commit
85493da730
@ -17,7 +17,7 @@ class Message(BaseModel):
|
|||||||
content: str
|
content: str
|
||||||
|
|
||||||
|
|
||||||
class CompletionBody(ModelConfigBody):
|
class ChatCompletionBody(ModelConfigBody):
|
||||||
messages: List[Message]
|
messages: List[Message]
|
||||||
model: str = "rwkv"
|
model: str = "rwkv"
|
||||||
stream: bool = False
|
stream: bool = False
|
||||||
@ -28,7 +28,7 @@ completion_lock = Lock()
|
|||||||
|
|
||||||
@router.post("/v1/chat/completions")
|
@router.post("/v1/chat/completions")
|
||||||
@router.post("/chat/completions")
|
@router.post("/chat/completions")
|
||||||
async def completions(body: CompletionBody, request: Request):
|
async def chat_completions(body: ChatCompletionBody, request: Request):
|
||||||
model: RWKV = global_var.get(global_var.Model)
|
model: RWKV = global_var.get(global_var.Model)
|
||||||
if model is None:
|
if model is None:
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")
|
||||||
@ -135,3 +135,90 @@ async def completions(body: CompletionBody, request: Request):
|
|||||||
return EventSourceResponse(eval_rwkv())
|
return EventSourceResponse(eval_rwkv())
|
||||||
else:
|
else:
|
||||||
return await eval_rwkv().__anext__()
|
return await eval_rwkv().__anext__()
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionBody(ModelConfigBody):
|
||||||
|
prompt: str
|
||||||
|
model: str = "rwkv"
|
||||||
|
stream: bool = False
|
||||||
|
stop: str = None
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/v1/completions")
|
||||||
|
@router.post("/completions")
|
||||||
|
async def completions(body: CompletionBody, request: Request):
|
||||||
|
model: RWKV = global_var.get(global_var.Model)
|
||||||
|
if model is None:
|
||||||
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")
|
||||||
|
|
||||||
|
async def eval_rwkv():
|
||||||
|
while completion_lock.locked():
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
else:
|
||||||
|
completion_lock.acquire()
|
||||||
|
set_rwkv_config(model, global_var.get(global_var.Model_Config))
|
||||||
|
set_rwkv_config(model, body)
|
||||||
|
if body.stream:
|
||||||
|
for response, delta in rwkv_generate(
|
||||||
|
model, body.prompt, stop=body.stop
|
||||||
|
):
|
||||||
|
if await request.is_disconnected():
|
||||||
|
break
|
||||||
|
yield json.dumps(
|
||||||
|
{
|
||||||
|
"response": response,
|
||||||
|
"model": "rwkv",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"text": delta,
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if await request.is_disconnected():
|
||||||
|
completion_lock.release()
|
||||||
|
return
|
||||||
|
yield json.dumps(
|
||||||
|
{
|
||||||
|
"response": response,
|
||||||
|
"model": "rwkv",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"text": "",
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
yield "[DONE]"
|
||||||
|
else:
|
||||||
|
response = None
|
||||||
|
for response, delta in rwkv_generate(
|
||||||
|
model, body.prompt, stop=body.stop
|
||||||
|
):
|
||||||
|
if await request.is_disconnected():
|
||||||
|
break
|
||||||
|
if await request.is_disconnected():
|
||||||
|
completion_lock.release()
|
||||||
|
return
|
||||||
|
yield {
|
||||||
|
"response": response,
|
||||||
|
"model": "rwkv",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"text": response,
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
# torch_gc()
|
||||||
|
completion_lock.release()
|
||||||
|
|
||||||
|
if body.stream:
|
||||||
|
return EventSourceResponse(eval_rwkv())
|
||||||
|
else:
|
||||||
|
return await eval_rwkv().__anext__()
|
||||||
|
Loading…
Reference in New Issue
Block a user