allow multiple systems

This commit is contained in:
josc146 2023-08-04 22:27:55 +08:00
parent 91e2828a95
commit e0b7453883
2 changed files with 75 additions and 61 deletions

View File

@ -2,11 +2,12 @@ import asyncio
import json import json
from threading import Lock from threading import Lock
from typing import List, Union from typing import List, Union
from enum import Enum
import base64 import base64
from fastapi import APIRouter, Request, status, HTTPException from fastapi import APIRouter, Request, status, HTTPException
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse
from pydantic import BaseModel from pydantic import BaseModel, Field
import numpy as np import numpy as np
import tiktoken import tiktoken
from utils.rwkv import * from utils.rwkv import *
@ -16,9 +17,15 @@ import global_var
router = APIRouter() router = APIRouter()
class Role(Enum):
User = "user"
Assistant = "assistant"
System = "system"
class Message(BaseModel): class Message(BaseModel):
role: str role: Role
content: str content: str = Field(min_length=1)
class ChatCompletionBody(ModelConfigBody): class ChatCompletionBody(ModelConfigBody):
@ -38,7 +45,7 @@ class ChatCompletionBody(ModelConfigBody):
class Config: class Config:
schema_extra = { schema_extra = {
"example": { "example": {
"messages": [{"role": "user", "content": "hello"}], "messages": [{"role": Role.User.value, "content": "hello"}],
"model": "rwkv", "model": "rwkv",
"stream": False, "stream": False,
"stop": None, "stop": None,
@ -200,7 +207,7 @@ async def eval_rwkv(
"choices": [ "choices": [
{ {
"message": { "message": {
"role": "assistant", "role": Role.Assistant.value,
"content": response, "content": response,
}, },
"index": 0, "index": 0,
@ -223,17 +230,12 @@ async def chat_completions(body: ChatCompletionBody, request: Request):
if model is None: if model is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded") raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")
question = body.messages[-1] if body.messages is None or body.messages == []:
if question.role == "user": raise HTTPException(status.HTTP_400_BAD_REQUEST, "messages not found")
question = question.content
elif question.role == "system": basic_system: str = ""
question = body.messages[-2] if body.messages[0].role == Role.System:
if question.role == "user": basic_system = body.messages[0].content
question = question.content
else:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "no question found")
else:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "no question found")
interface = model.interface interface = model.interface
user = model.user if body.user_name is None else body.user_name user = model.user if body.user_name is None else body.user_name
@ -241,6 +243,8 @@ async def chat_completions(body: ChatCompletionBody, request: Request):
is_raven = model.rwkv_type == RWKVType.Raven is_raven = model.rwkv_type == RWKVType.Raven
completion_text: str = ""
if basic_system == "":
completion_text = ( completion_text = (
f""" f"""
The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. \ The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. \
@ -250,17 +254,20 @@ The following is a coherent verbose detailed conversation between a girl named {
{bot} usually gives {user} kind, helpful and informative advices.\n {bot} usually gives {user} kind, helpful and informative advices.\n
""" """
if is_raven if is_raven
else f"{user}{interface} hi\n\n{bot}{interface} Hi. " else (
f"{user}{interface} hi\n\n{bot}{interface} Hi. "
+ "I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n" + "I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n"
) )
for message in body.messages: )
if message.role == "system": elif basic_system != "":
completion_text = ( completion_text = (
(
f"The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. " f"The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. "
if is_raven if is_raven
else f"{user}{interface} hi\n\n{bot}{interface} Hi. " else f"{user}{interface} hi\n\n{bot}{interface} Hi. "
+ message.content.replace("\\n", "\n") )
.replace("\r\n", "\n") + basic_system.replace("\r\n", "\n")
.replace("\r", "\n")
.replace("\n\n", "\n") .replace("\n\n", "\n")
.replace("\n", " ") .replace("\n", " ")
.strip() .strip()
@ -275,22 +282,18 @@ The following is a coherent verbose detailed conversation between a girl named {
.replace("", f"{bot}" if is_raven else "") .replace("", f"{bot}" if is_raven else "")
+ "\n\n" + "\n\n"
) )
break
for message in body.messages: for message in body.messages[(0 if basic_system == "" else 1) :]:
if message.role == "user": append_message: str = ""
if message.role == Role.User:
append_message = f"{user}{interface} " + message.content
elif message.role == Role.Assistant:
append_message = f"{bot}{interface} " + message.content
elif message.role == Role.System:
append_message = message.content
completion_text += ( completion_text += (
f"{user}{interface} " append_message.replace("\r\n", "\n")
+ message.content.replace("\\n", "\n") .replace("\r", "\n")
.replace("\r\n", "\n")
.replace("\n\n", "\n")
.strip()
+ "\n\n"
)
elif message.role == "assistant":
completion_text += (
f"{bot}{interface} "
+ message.content.replace("\\n", "\n")
.replace("\r\n", "\n")
.replace("\n\n", "\n") .replace("\n\n", "\n")
.strip() .strip()
+ "\n\n" + "\n\n"

View File

@ -2,6 +2,8 @@ import json
import logging import logging
from typing import Any from typing import Any
from fastapi import Request from fastapi import Request
from pydantic import BaseModel
from enum import Enum
logger = logging.getLogger() logger = logging.getLogger()
@ -14,12 +16,21 @@ fh.setFormatter(formatter)
logger.addHandler(fh) logger.addHandler(fh)
class ClsEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, BaseModel):
return obj.dict()
if isinstance(obj, Enum):
return obj.value
return super().default(obj)
def quick_log(request: Request, body: Any, response: str): def quick_log(request: Request, body: Any, response: str):
try: try:
logger.info( logger.info(
f"Client: {request.client if request else ''}\nUrl: {request.url if request else ''}\n" f"Client: {request.client if request else ''}\nUrl: {request.url if request else ''}\n"
+ ( + (
f"Body: {json.dumps(body.__dict__, default=vars, ensure_ascii=False)}\n" f"Body: {json.dumps(body.__dict__, ensure_ascii=False, cls=ClsEncoder)}\n"
if body if body
else "" else ""
) )