refactor completions api

This commit is contained in:
josc146 2023-06-18 20:16:52 +08:00
parent fcdda71b46
commit 967be6f88f
2 changed files with 165 additions and 257 deletions

View File

@ -40,11 +40,159 @@ class ChatCompletionBody(ModelConfigBody):
}
class CompletionBody(ModelConfigBody):
prompt: str
model: str = "rwkv"
stream: bool = False
stop: str = None
class Config:
schema_extra = {
"example": {
"prompt": "The following is an epic science fiction masterpiece that is immortalized, "
+ "with delicate descriptions and grand depictions of interstellar civilization wars.\nChapter 1.\n",
"model": "rwkv",
"stream": False,
"stop": None,
"max_tokens": 100,
"temperature": 1.2,
"top_p": 0.5,
"presence_penalty": 0.4,
"frequency_penalty": 0.4,
}
}
completion_lock = Lock()
requests_num = 0
async def eval_rwkv(
model: RWKV,
request: Request,
body: ModelConfigBody,
prompt: str,
stream: bool,
stop: str,
chat_mode: bool,
):
global requests_num
requests_num = requests_num + 1
quick_log(request, None, "Start Waiting. RequestsNum: " + str(requests_num))
while completion_lock.locked():
if await request.is_disconnected():
requests_num = requests_num - 1
print(f"{request.client} Stop Waiting (Lock)")
quick_log(
request,
None,
"Stop Waiting (Lock). RequestsNum: " + str(requests_num),
)
return
await asyncio.sleep(0.1)
else:
completion_lock.acquire()
if await request.is_disconnected():
completion_lock.release()
requests_num = requests_num - 1
print(f"{request.client} Stop Waiting (Lock)")
quick_log(
request,
None,
"Stop Waiting (Lock). RequestsNum: " + str(requests_num),
)
return
set_rwkv_config(model, global_var.get(global_var.Model_Config))
set_rwkv_config(model, body)
response = ""
for response, delta in model.generate(
prompt,
stop=stop,
):
if await request.is_disconnected():
break
if stream:
yield json.dumps(
{
"response": response,
"model": "rwkv",
"choices": [
{
"delta": {"content": delta},
"index": 0,
"finish_reason": None,
}
if chat_mode
else {
"text": delta,
"index": 0,
"finish_reason": None,
}
],
}
)
# torch_gc()
requests_num = requests_num - 1
completion_lock.release()
if await request.is_disconnected():
print(f"{request.client} Stop Waiting")
quick_log(
request,
body,
response + "\nStop Waiting. RequestsNum: " + str(requests_num),
)
return
quick_log(
request,
body,
response + "\nFinished. RequestsNum: " + str(requests_num),
)
if stream:
yield json.dumps(
{
"response": response,
"model": "rwkv",
"choices": [
{
"delta": {},
"index": 0,
"finish_reason": "stop",
}
if chat_mode
else {
"text": "",
"index": 0,
"finish_reason": "stop",
}
],
}
)
yield "[DONE]"
else:
yield {
"response": response,
"model": "rwkv",
"choices": [
{
"message": {
"role": "assistant",
"content": response,
},
"index": 0,
"finish_reason": "stop",
}
if chat_mode
else {
"text": response,
"index": 0,
"finish_reason": "stop",
}
],
}
@router.post("/v1/chat/completions")
@router.post("/chat/completions")
async def chat_completions(body: ChatCompletionBody, request: Request):
@ -77,7 +225,8 @@ The following is a coherent verbose detailed conversation between a girl named {
{bot} usually gives {user} kind, helpful and informative advices.\n
"""
if user == "Bob"
else f"{user}{interface} hi\n\n{bot}{interface} Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n"
else f"{user}{interface} hi\n\n{bot}{interface} Hi. "
+ "I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n"
)
for message in body.messages:
if message.role == "system":
@ -123,156 +272,20 @@ The following is a coherent verbose detailed conversation between a girl named {
)
completion_text += f"{bot}{interface}"
async def eval_rwkv():
global requests_num
requests_num = requests_num + 1
quick_log(request, None, "Start Waiting. RequestsNum: " + str(requests_num))
while completion_lock.locked():
if await request.is_disconnected():
requests_num = requests_num - 1
print(f"{request.client} Stop Waiting (Lock)")
quick_log(
request,
None,
"Stop Waiting (Lock). RequestsNum: " + str(requests_num),
)
return
await asyncio.sleep(0.1)
else:
completion_lock.acquire()
if await request.is_disconnected():
completion_lock.release()
requests_num = requests_num - 1
print(f"{request.client} Stop Waiting (Lock)")
quick_log(
request,
None,
"Stop Waiting (Lock). RequestsNum: " + str(requests_num),
)
return
set_rwkv_config(model, global_var.get(global_var.Model_Config))
set_rwkv_config(model, body)
stop = f"\n\n{user}" if body.stop is None else body.stop
if body.stream:
response = ""
for response, delta in model.generate(
completion_text,
stop=f"\n\n{user}" if body.stop is None else body.stop,
):
if await request.is_disconnected():
break
yield json.dumps(
{
"response": response,
"model": "rwkv",
"choices": [
{
"delta": {"content": delta},
"index": 0,
"finish_reason": None,
}
],
}
return EventSourceResponse(
eval_rwkv(model, request, body, completion_text, body.stream, stop, True)
)
# torch_gc()
requests_num = requests_num - 1
completion_lock.release()
if await request.is_disconnected():
print(f"{request.client} Stop Waiting")
quick_log(
request,
body,
response + "\nStop Waiting. RequestsNum: " + str(requests_num),
)
return
quick_log(
request,
body,
response + "\nFinished. RequestsNum: " + str(requests_num),
)
yield json.dumps(
{
"response": response,
"model": "rwkv",
"choices": [
{
"delta": {},
"index": 0,
"finish_reason": "stop",
}
],
}
)
yield "[DONE]"
else:
response = ""
for response, delta in model.generate(
completion_text,
stop=f"\n\n{user}" if body.stop is None else body.stop,
):
if await request.is_disconnected():
break
# torch_gc()
requests_num = requests_num - 1
completion_lock.release()
if await request.is_disconnected():
print(f"{request.client} Stop Waiting")
quick_log(
request,
body,
response + "\nStop Waiting. RequestsNum: " + str(requests_num),
)
return
quick_log(
request,
body,
response + "\nFinished. RequestsNum: " + str(requests_num),
)
yield {
"response": response,
"model": "rwkv",
"choices": [
{
"message": {
"role": "assistant",
"content": response,
},
"index": 0,
"finish_reason": "stop",
}
],
}
if body.stream:
return EventSourceResponse(eval_rwkv())
else:
try:
return await eval_rwkv().__anext__()
return await eval_rwkv(
model, request, body, completion_text, body.stream, stop, True
).__anext__()
except StopAsyncIteration:
return None
class CompletionBody(ModelConfigBody):
prompt: str
model: str = "rwkv"
stream: bool = False
stop: str = None
class Config:
schema_extra = {
"example": {
"prompt": "The following is an epic science fiction masterpiece that is immortalized, with delicate descriptions and grand depictions of interstellar civilization wars.\nChapter 1.\n",
"model": "rwkv",
"stream": False,
"stop": None,
"max_tokens": 100,
"temperature": 1.2,
"top_p": 0.5,
"presence_penalty": 0.4,
"frequency_penalty": 0.4,
}
}
@router.post("/v1/completions")
@router.post("/completions")
async def completions(body: CompletionBody, request: Request):
@ -283,120 +296,14 @@ async def completions(body: CompletionBody, request: Request):
if body.prompt is None or body.prompt == "":
raise HTTPException(status.HTTP_400_BAD_REQUEST, "prompt not found")
async def eval_rwkv():
global requests_num
requests_num = requests_num + 1
quick_log(request, None, "Start Waiting. RequestsNum: " + str(requests_num))
while completion_lock.locked():
if await request.is_disconnected():
requests_num = requests_num - 1
print(f"{request.client} Stop Waiting (Lock)")
quick_log(
request,
None,
"Stop Waiting (Lock). RequestsNum: " + str(requests_num),
)
return
await asyncio.sleep(0.1)
else:
completion_lock.acquire()
if await request.is_disconnected():
completion_lock.release()
requests_num = requests_num - 1
print(f"{request.client} Stop Waiting (Lock)")
quick_log(
request,
None,
"Stop Waiting (Lock). RequestsNum: " + str(requests_num),
)
return
set_rwkv_config(model, global_var.get(global_var.Model_Config))
set_rwkv_config(model, body)
if body.stream:
response = ""
for response, delta in model.generate(body.prompt, stop=body.stop):
if await request.is_disconnected():
break
yield json.dumps(
{
"response": response,
"model": "rwkv",
"choices": [
{
"text": delta,
"index": 0,
"finish_reason": None,
}
],
}
return EventSourceResponse(
eval_rwkv(model, request, body, body.prompt, body.stream, body.stop, False)
)
# torch_gc()
requests_num = requests_num - 1
completion_lock.release()
if await request.is_disconnected():
print(f"{request.client} Stop Waiting")
quick_log(
request,
body,
response + "\nStop Waiting. RequestsNum: " + str(requests_num),
)
return
quick_log(
request,
body,
response + "\nFinished. RequestsNum: " + str(requests_num),
)
yield json.dumps(
{
"response": response,
"model": "rwkv",
"choices": [
{
"text": "",
"index": 0,
"finish_reason": "stop",
}
],
}
)
yield "[DONE]"
else:
response = ""
for response, delta in model.generate(body.prompt, stop=body.stop):
if await request.is_disconnected():
break
# torch_gc()
requests_num = requests_num - 1
completion_lock.release()
if await request.is_disconnected():
print(f"{request.client} Stop Waiting")
quick_log(
request,
body,
response + "\nStop Waiting. RequestsNum: " + str(requests_num),
)
return
quick_log(
request,
body,
response + "\nFinished. RequestsNum: " + str(requests_num),
)
yield {
"response": response,
"model": "rwkv",
"choices": [
{
"text": response,
"index": 0,
"finish_reason": "stop",
}
],
}
if body.stream:
return EventSourceResponse(eval_rwkv())
else:
try:
return await eval_rwkv().__anext__()
return await eval_rwkv(
model, request, body, body.prompt, body.stream, body.stop, False
).__anext__()
except StopAsyncIteration:
return None

View File

@ -64,7 +64,8 @@ The following is a coherent verbose detailed conversation between a girl named {
{bot} usually gives {user} kind, helpful and informative advices.\n
"""
if self.user == "Bob"
else f"{user}{interface} hi\n\n{bot}{interface} Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n"
else f"{user}{interface} hi\n\n{bot}{interface} Hi. "
+ "I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n"
)
logits = self.run_rnn(self.fix_tokens(self.pipeline.encode(preset_system)))
try: