new chat template for /chat/completions (better system support)

This commit is contained in:
josc146 2024-03-25 12:52:40 +08:00
parent a93610e574
commit 16f2201d9f

View File

@ -53,6 +53,9 @@ class ChatCompletionBody(ModelConfigBody):
assistant_name: Union[str, None] = Field( assistant_name: Union[str, None] = Field(
None, description="Internal assistant name", min_length=1 None, description="Internal assistant name", min_length=1
) )
system_name: Union[str, None] = Field(
None, description="Internal system name", min_length=1
)
presystem: bool = Field( presystem: bool = Field(
True, description="Whether to insert default system prompt at the beginning" True, description="Whether to insert default system prompt at the beginning"
) )
@ -68,6 +71,7 @@ class ChatCompletionBody(ModelConfigBody):
"stop": None, "stop": None,
"user_name": None, "user_name": None,
"assistant_name": None, "assistant_name": None,
"system_name": None,
"presystem": True, "presystem": True,
"max_tokens": 1000, "max_tokens": 1000,
"temperature": 1, "temperature": 1,
@ -252,20 +256,9 @@ async def eval_rwkv(
} }
@router.post("/v1/chat/completions", tags=["Completions"]) def chat_template_old(
@router.post("/chat/completions", tags=["Completions"]) model: TextRWKV, body: ChatCompletionBody, interface: str, user: str, bot: str
async def chat_completions(body: ChatCompletionBody, request: Request): ):
model: TextRWKV = global_var.get(global_var.Model)
if model is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")
if body.messages is None or body.messages == []:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "messages not found")
interface = model.interface
user = model.user if body.user_name is None else body.user_name
bot = model.bot if body.assistant_name is None else body.assistant_name
is_raven = model.rwkv_type == RWKVType.Raven is_raven = model.rwkv_type == RWKVType.Raven
completion_text: str = "" completion_text: str = ""
@ -334,6 +327,53 @@ The following is a coherent verbose detailed conversation between a girl named {
completion_text += append_message + "\n\n" completion_text += append_message + "\n\n"
completion_text += f"{bot}{interface}" completion_text += f"{bot}{interface}"
return completion_text
def chat_template(
    model: TextRWKV, body: ChatCompletionBody, interface: str, user: str, bot: str
):
    """Render the request's chat messages into a single prompt string.

    Each message becomes one turn of the form ``<name><interface> <content>``
    followed by a blank line. The prompt always ends with ``<bot><interface>``
    so that generation continues as the assistant.

    Args:
        model: The loaded model. Not read by this template.
        body: The request body; ``presystem``, ``system_name`` and
            ``messages`` are read from it.
        interface: Separator placed directly after each speaker name.
        user: Speaker name used for user turns.
        bot: Speaker name used for assistant turns.

    Returns:
        The assembled prompt text.
    """
    fragments = []

    # Optional canned opening exchange that introduces the assistant.
    if body.presystem:
        fragments.append(
            f"{user}{interface} hi\n\n{bot}{interface} Hi. "
            + "I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n"
        )

    # System turns are labelled "System" unless the request overrides the name.
    system = body.system_name if body.system_name is not None else "System"

    for message in body.messages:
        role = message.role
        if role == Role.User:
            speaker = user
        elif role == Role.Assistant:
            speaker = bot
        elif role == Role.System:
            speaker = system
        else:
            speaker = None

        # A message with any other role contributes only the turn separator.
        turn = "" if speaker is None else f"{speaker}{interface} " + message.content
        fragments.append(turn + "\n\n")

    # Leave the prompt open on the assistant's turn.
    fragments.append(f"{bot}{interface}")
    return "".join(fragments)
@router.post("/v1/chat/completions", tags=["Completions"])
@router.post("/chat/completions", tags=["Completions"])
async def chat_completions(body: ChatCompletionBody, request: Request):
model: TextRWKV = global_var.get(global_var.Model)
if model is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")
if body.messages is None or body.messages == []:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "messages not found")
interface = model.interface
user = model.user if body.user_name is None else body.user_name
bot = model.bot if body.assistant_name is None else body.assistant_name
if model.version < 5:
completion_text = chat_template_old(model, body, interface, user, bot)
else:
completion_text = chat_template(model, body, interface, user, bot)
user_code = model.pipeline.decode([model.pipeline.encode(user)[0]]) user_code = model.pipeline.decode([model.pipeline.encode(user)[0]])
bot_code = model.pipeline.decode([model.pipeline.encode(bot)[0]]) bot_code = model.pipeline.decode([model.pipeline.encode(bot)[0]])
if type(body.stop) == str: if type(body.stop) == str:
@ -343,8 +383,8 @@ The following is a coherent verbose detailed conversation between a girl named {
body.stop.append(f"\n\n{bot_code}") body.stop.append(f"\n\n{bot_code}")
elif body.stop is None: elif body.stop is None:
body.stop = default_stop body.stop = default_stop
if not body.presystem: # if not body.presystem:
body.stop.append("\n\n") # body.stop.append("\n\n")
if body.stream: if body.stream:
return EventSourceResponse( return EventSourceResponse(