add role: "system" support

This commit is contained in:
josc146 2023-05-24 14:01:22 +08:00
parent 1176dba282
commit bcb38d991a

View File

@ -11,6 +11,10 @@ import global_var
router = APIRouter()
interface = ":"
user = "Bob"
bot = "Alice"
class Message(BaseModel):
role: str
@ -40,11 +44,36 @@ async def chat_completions(body: ChatCompletionBody, request: Request):
else:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "no question found")
completion_text = ""
completion_text = f"""
The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. \
{bot} is very intelligent, creative and friendly. \
{bot} is unlikely to disagree with {user}, and {bot} doesn't like to ask {user} questions. \
{bot} likes to tell {user} a lot about herself and her opinions. \
{bot} usually gives {user} kind, helpful and informative advices.\n
"""
for message in body.messages:
if message.role == "user":
if message.role == "system":
completion_text = (
f"The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. "
+ message.content.replace("\\n", "\n")
.replace("\r\n", "\n")
.replace("\n\n", "\n")
.replace("\n", " ")
.strip()
.replace("You are", f"{bot} is")
.replace("you are", f"{bot} is")
.replace("You're", f"{bot} is")
.replace("you're", f"{bot} is")
.replace("You", f"{bot}")
.replace("you", f"{bot}")
.replace("Your", f"{bot}'s")
.replace("your", f"{bot}'s")
.replace("", f"{bot}")
+ "\n\n"
)
elif message.role == "user":
completion_text += (
"Bob: "
f"{user}{interface} "
+ message.content.replace("\\n", "\n")
.replace("\r\n", "\n")
.replace("\n\n", "\n")
@ -53,14 +82,14 @@ async def chat_completions(body: ChatCompletionBody, request: Request):
)
elif message.role == "assistant":
completion_text += (
"Alice: "
f"{bot}{interface} "
+ message.content.replace("\\n", "\n")
.replace("\r\n", "\n")
.replace("\n\n", "\n")
.strip()
+ "\n\n"
)
completion_text += "Alice:"
completion_text += f"{bot}{interface}"
async def eval_rwkv():
while completion_lock.locked():
@ -73,7 +102,7 @@ async def chat_completions(body: ChatCompletionBody, request: Request):
for response, delta in rwkv_generate(
model,
completion_text,
stop="\n\nBob" if body.stop is None else body.stop,
stop=f"\n\n{user}" if body.stop is None else body.stop,
):
if await request.is_disconnected():
break
@ -113,7 +142,7 @@ async def chat_completions(body: ChatCompletionBody, request: Request):
for response, delta in rwkv_generate(
model,
completion_text,
stop="\n\nBob" if body.stop is None else body.stop,
stop=f"\n\n{user}" if body.stop is None else body.stop,
):
if await request.is_disconnected():
break