support for rwkv-4-world

This commit is contained in:
josc146
2023-05-28 12:53:14 +08:00
parent b7fb8ed898
commit 94971bb666
8 changed files with 65918 additions and 65 deletions

View File

@@ -11,10 +11,6 @@ import global_var
router = APIRouter()
interface = ":"
user = "Bob"
bot = "Alice"
class Message(BaseModel):
role: str
@@ -44,17 +40,27 @@ async def chat_completions(body: ChatCompletionBody, request: Request):
else:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "no question found")
completion_text = f"""
interface = model.interface
user = model.user
bot = model.bot
completion_text = (
f"""
The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. \
{bot} is very intelligent, creative and friendly. \
{bot} is unlikely to disagree with {user}, and {bot} doesn't like to ask {user} questions. \
{bot} likes to tell {user} a lot about herself and her opinions. \
{bot} usually gives {user} kind, helpful and informative advices.\n
"""
if user == "Bob"
else ""
)
for message in body.messages:
if message.role == "system":
completion_text = (
f"The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. "
if user == "Bob"
else ""
+ message.content.replace("\\n", "\n")
.replace("\r\n", "\n")
.replace("\n\n", "\n")
@@ -101,8 +107,7 @@ The following is a coherent verbose detailed conversation between a girl named {
set_rwkv_config(model, global_var.get(global_var.Model_Config))
set_rwkv_config(model, body)
if body.stream:
for response, delta in rwkv_generate(
model,
for response, delta in model.generate(
completion_text,
stop=f"\n\n{user}" if body.stop is None else body.stop,
):
@@ -141,8 +146,7 @@ The following is a coherent verbose detailed conversation between a girl named {
yield "[DONE]"
else:
response = None
for response, delta in rwkv_generate(
model,
for response, delta in model.generate(
completion_text,
stop=f"\n\n{user}" if body.stop is None else body.stop,
):
@@ -186,7 +190,7 @@ async def completions(body: CompletionBody, request: Request):
model: RWKV = global_var.get(global_var.Model)
if model is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")
if body.prompt is None or body.prompt == "":
raise HTTPException(status.HTTP_400_BAD_REQUEST, "prompt not found")
@@ -200,9 +204,7 @@ async def completions(body: CompletionBody, request: Request):
set_rwkv_config(model, global_var.get(global_var.Model_Config))
set_rwkv_config(model, body)
if body.stream:
for response, delta in rwkv_generate(
model, body.prompt, stop=body.stop
):
for response, delta in model.generate(body.prompt, stop=body.stop):
if await request.is_disconnected():
break
yield json.dumps(
@@ -238,9 +240,7 @@ async def completions(body: CompletionBody, request: Request):
yield "[DONE]"
else:
response = None
for response, delta in rwkv_generate(
model, body.prompt, stop=body.stop
):
for response, delta in model.generate(body.prompt, stop=body.stop):
if await request.is_disconnected():
break
# torch_gc()