add compatible /v1/completions API
This commit is contained in:
		
							parent
							
								
									bbad153ecb
								
							
						
					
					
						commit
						85493da730
					
				@ -17,7 +17,7 @@ class Message(BaseModel):
 | 
			
		||||
    content: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CompletionBody(ModelConfigBody):
 | 
			
		||||
class ChatCompletionBody(ModelConfigBody):
 | 
			
		||||
    messages: List[Message]
 | 
			
		||||
    model: str = "rwkv"
 | 
			
		||||
    stream: bool = False
 | 
			
		||||
@ -28,7 +28,7 @@ completion_lock = Lock()
 | 
			
		||||
 | 
			
		||||
@router.post("/v1/chat/completions")
 | 
			
		||||
@router.post("/chat/completions")
 | 
			
		||||
async def completions(body: CompletionBody, request: Request):
 | 
			
		||||
async def chat_completions(body: ChatCompletionBody, request: Request):
 | 
			
		||||
    model: RWKV = global_var.get(global_var.Model)
 | 
			
		||||
    if model is None:
 | 
			
		||||
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")
 | 
			
		||||
@ -135,3 +135,90 @@ async def completions(body: CompletionBody, request: Request):
 | 
			
		||||
        return EventSourceResponse(eval_rwkv())
 | 
			
		||||
    else:
 | 
			
		||||
        return await eval_rwkv().__anext__()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CompletionBody(ModelConfigBody):
 | 
			
		||||
    prompt: str
 | 
			
		||||
    model: str = "rwkv"
 | 
			
		||||
    stream: bool = False
 | 
			
		||||
    stop: str = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@router.post("/v1/completions")
 | 
			
		||||
@router.post("/completions")
 | 
			
		||||
async def completions(body: CompletionBody, request: Request):
 | 
			
		||||
    model: RWKV = global_var.get(global_var.Model)
 | 
			
		||||
    if model is None:
 | 
			
		||||
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")
 | 
			
		||||
 | 
			
		||||
    async def eval_rwkv():
 | 
			
		||||
        while completion_lock.locked():
 | 
			
		||||
            await asyncio.sleep(0.1)
 | 
			
		||||
        else:
 | 
			
		||||
            completion_lock.acquire()
 | 
			
		||||
            set_rwkv_config(model, global_var.get(global_var.Model_Config))
 | 
			
		||||
            set_rwkv_config(model, body)
 | 
			
		||||
            if body.stream:
 | 
			
		||||
                for response, delta in rwkv_generate(
 | 
			
		||||
                    model, body.prompt, stop=body.stop
 | 
			
		||||
                ):
 | 
			
		||||
                    if await request.is_disconnected():
 | 
			
		||||
                        break
 | 
			
		||||
                    yield json.dumps(
 | 
			
		||||
                        {
 | 
			
		||||
                            "response": response,
 | 
			
		||||
                            "model": "rwkv",
 | 
			
		||||
                            "choices": [
 | 
			
		||||
                                {
 | 
			
		||||
                                    "text": delta,
 | 
			
		||||
                                    "index": 0,
 | 
			
		||||
                                    "finish_reason": None,
 | 
			
		||||
                                }
 | 
			
		||||
                            ],
 | 
			
		||||
                        }
 | 
			
		||||
                    )
 | 
			
		||||
                if await request.is_disconnected():
 | 
			
		||||
                    completion_lock.release()
 | 
			
		||||
                    return
 | 
			
		||||
                yield json.dumps(
 | 
			
		||||
                    {
 | 
			
		||||
                        "response": response,
 | 
			
		||||
                        "model": "rwkv",
 | 
			
		||||
                        "choices": [
 | 
			
		||||
                            {
 | 
			
		||||
                                "text": "",
 | 
			
		||||
                                "index": 0,
 | 
			
		||||
                                "finish_reason": "stop",
 | 
			
		||||
                            }
 | 
			
		||||
                        ],
 | 
			
		||||
                    }
 | 
			
		||||
                )
 | 
			
		||||
                yield "[DONE]"
 | 
			
		||||
            else:
 | 
			
		||||
                response = None
 | 
			
		||||
                for response, delta in rwkv_generate(
 | 
			
		||||
                    model, body.prompt, stop=body.stop
 | 
			
		||||
                ):
 | 
			
		||||
                    if await request.is_disconnected():
 | 
			
		||||
                        break
 | 
			
		||||
                if await request.is_disconnected():
 | 
			
		||||
                    completion_lock.release()
 | 
			
		||||
                    return
 | 
			
		||||
                yield {
 | 
			
		||||
                    "response": response,
 | 
			
		||||
                    "model": "rwkv",
 | 
			
		||||
                    "choices": [
 | 
			
		||||
                        {
 | 
			
		||||
                            "text": response,
 | 
			
		||||
                            "index": 0,
 | 
			
		||||
                            "finish_reason": "stop",
 | 
			
		||||
                        }
 | 
			
		||||
                    ],
 | 
			
		||||
                }
 | 
			
		||||
            # torch_gc()
 | 
			
		||||
            completion_lock.release()
 | 
			
		||||
 | 
			
		||||
    if body.stream:
 | 
			
		||||
        return EventSourceResponse(eval_rwkv())
 | 
			
		||||
    else:
 | 
			
		||||
        return await eval_rwkv().__anext__()
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user