support for rwkv-4-world

This commit is contained in:
josc146
2023-05-28 12:53:14 +08:00
parent b7fb8ed898
commit 94971bb666
8 changed files with 65918 additions and 65 deletions

View File

@@ -11,10 +11,6 @@ import global_var
router = APIRouter()
interface = ":"
user = "Bob"
bot = "Alice"
class Message(BaseModel):
role: str
@@ -44,17 +40,27 @@ async def chat_completions(body: ChatCompletionBody, request: Request):
else:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "no question found")
completion_text = f"""
interface = model.interface
user = model.user
bot = model.bot
completion_text = (
f"""
The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. \
{bot} is very intelligent, creative and friendly. \
{bot} is unlikely to disagree with {user}, and {bot} doesn't like to ask {user} questions. \
{bot} likes to tell {user} a lot about herself and her opinions. \
{bot} usually gives {user} kind, helpful and informative advices.\n
"""
if user == "Bob"
else ""
)
for message in body.messages:
if message.role == "system":
completion_text = (
f"The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. "
if user == "Bob"
else ""
+ message.content.replace("\\n", "\n")
.replace("\r\n", "\n")
.replace("\n\n", "\n")
@@ -101,8 +107,7 @@ The following is a coherent verbose detailed conversation between a girl named {
set_rwkv_config(model, global_var.get(global_var.Model_Config))
set_rwkv_config(model, body)
if body.stream:
for response, delta in rwkv_generate(
model,
for response, delta in model.generate(
completion_text,
stop=f"\n\n{user}" if body.stop is None else body.stop,
):
@@ -141,8 +146,7 @@ The following is a coherent verbose detailed conversation between a girl named {
yield "[DONE]"
else:
response = None
for response, delta in rwkv_generate(
model,
for response, delta in model.generate(
completion_text,
stop=f"\n\n{user}" if body.stop is None else body.stop,
):
@@ -186,7 +190,7 @@ async def completions(body: CompletionBody, request: Request):
model: RWKV = global_var.get(global_var.Model)
if model is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")
if body.prompt is None or body.prompt == "":
raise HTTPException(status.HTTP_400_BAD_REQUEST, "prompt not found")
@@ -200,9 +204,7 @@ async def completions(body: CompletionBody, request: Request):
set_rwkv_config(model, global_var.get(global_var.Model_Config))
set_rwkv_config(model, body)
if body.stream:
for response, delta in rwkv_generate(
model, body.prompt, stop=body.stop
):
for response, delta in model.generate(body.prompt, stop=body.stop):
if await request.is_disconnected():
break
yield json.dumps(
@@ -238,9 +240,7 @@ async def completions(body: CompletionBody, request: Request):
yield "[DONE]"
else:
response = None
for response, delta in rwkv_generate(
model, body.prompt, stop=body.stop
):
for response, delta in model.generate(body.prompt, stop=body.stop):
if await request.is_disconnected():
break
# torch_gc()

View File

@@ -11,6 +11,19 @@ import GPUtil
router = APIRouter()
def get_tokens_path(model_path: str):
model_path = model_path.lower()
default_tokens_path = (
f"{pathlib.Path(__file__).parent.parent.resolve()}/20B_tokenizer.json"
)
if "raven" in model_path:
return default_tokens_path
elif "world" in model_path:
return "rwkv_vocab_v20230424"
else:
return default_tokens_path
class SwitchModelBody(BaseModel):
model: str
strategy: str
@@ -36,7 +49,7 @@ def switch_model(body: SwitchModelBody, response: Response):
RWKV(
model=body.model,
strategy=body.strategy,
tokens_path=f"{pathlib.Path(__file__).parent.parent.resolve()}/20B_tokenizer.json",
tokens_path=get_tokens_path(body.model),
),
)
except Exception as e: