diff --git a/backend-python/routes/completion.py b/backend-python/routes/completion.py index f7b41b9..3b5af55 100644 --- a/backend-python/routes/completion.py +++ b/backend-python/routes/completion.py @@ -17,7 +17,7 @@ class Message(BaseModel): content: str -class CompletionBody(ModelConfigBody): +class ChatCompletionBody(ModelConfigBody): messages: List[Message] model: str = "rwkv" stream: bool = False @@ -28,7 +28,7 @@ completion_lock = Lock() @router.post("/v1/chat/completions") @router.post("/chat/completions") -async def completions(body: CompletionBody, request: Request): +async def chat_completions(body: ChatCompletionBody, request: Request): model: RWKV = global_var.get(global_var.Model) if model is None: raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded") @@ -135,3 +135,90 @@ async def completions(body: CompletionBody, request: Request): return EventSourceResponse(eval_rwkv()) else: return await eval_rwkv().__anext__() + + +class CompletionBody(ModelConfigBody): + prompt: str + model: str = "rwkv" + stream: bool = False + stop: str = None + + +@router.post("/v1/completions") +@router.post("/completions") +async def completions(body: CompletionBody, request: Request): + model: RWKV = global_var.get(global_var.Model) + if model is None: + raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded") + + async def eval_rwkv(): + while completion_lock.locked(): + await asyncio.sleep(0.1) + else: + completion_lock.acquire() + set_rwkv_config(model, global_var.get(global_var.Model_Config)) + set_rwkv_config(model, body) + if body.stream: + for response, delta in rwkv_generate( + model, body.prompt, stop=body.stop + ): + if await request.is_disconnected(): + break + yield json.dumps( + { + "response": response, + "model": "rwkv", + "choices": [ + { + "text": delta, + "index": 0, + "finish_reason": None, + } + ], + } + ) + if await request.is_disconnected(): + completion_lock.release() + return + yield json.dumps( + { + "response": response, + "model": "rwkv", + "choices": [ + { + "text": "", + "index": 0, + "finish_reason": "stop", + } + ], + } + ) + yield "[DONE]" + else: + response = None + for response, delta in rwkv_generate( + model, body.prompt, stop=body.stop + ): + if await request.is_disconnected(): + break + if await request.is_disconnected(): + completion_lock.release() + return + yield { + "response": response, + "model": "rwkv", + "choices": [ + { + "text": response, + "index": 0, + "finish_reason": "stop", + } + ], + } + # torch_gc() + completion_lock.release() + + if body.stream: + return EventSourceResponse(eval_rwkv()) + else: + return await eval_rwkv().__anext__()