This commit is contained in:
josc146
2023-05-15 21:55:57 +08:00
parent 80bfb09972
commit 83f0bb503c
13 changed files with 388 additions and 138 deletions

View File

@@ -45,14 +45,14 @@ async def completions(body: CompletionBody, request: Request):
async def eval_rwkv():
    """Run the model on the prompt and yield JSON-encoded result payloads.

    Reads ``body``, ``request``, ``model`` and ``completion_text`` from the
    enclosing scope.  Generation is always invoked with ``stop="Bob:"``.

    * Streaming (``body.stream`` truthy): yields one JSON document per
      generated piece, carrying the accumulated ``response`` and the new
      ``delta``, then the literal ``"[DONE]"`` terminator.  The client
      connection is polled before each piece so generation stops early
      when the client has disconnected.
    * Non-streaming: drains the generator and yields a single JSON document
      with the final ``response`` (``None`` if nothing was generated).
    """
    if body.stream:
        for response, delta in rwkv_generate(model, completion_text, stop="Bob:"):
            # Do not keep generating for a client that is no longer listening.
            if await request.is_disconnected():
                break
            yield json.dumps(
                {
                    "response": response,
                    "choices": [{"delta": {"content": delta}}],
                    "model": "rwkv",
                }
            )
        # Terminator is emitted even after an early break, as before.
        yield "[DONE]"
    else:
        response = None
        # Only the final accumulated text is needed; deltas are discarded.
        for response, _ in rwkv_generate(model, completion_text, stop="Bob:"):
            pass
        yield json.dumps({"response": response, "model": "rwkv"})
    # torch_gc()

View File

@@ -2,7 +2,7 @@ from typing import Dict
from langchain.llms import RWKV
def rwkv_generate(model: RWKV, prompt: str):
def rwkv_generate(model: RWKV, prompt: str, stop: str = None):
model.model_state = None
model.model_tokens = []
logits = model.run_rnn(model.tokenizer.encode(prompt).ids)
@@ -34,6 +34,11 @@ def rwkv_generate(model: RWKV, prompt: str):
delta: str = model.tokenizer.decode(model.model_tokens[out_last:])
if "\ufffd" not in delta: # avoid utf-8 display issues
response += delta
if stop is not None:
if stop in response:
response = response.split(stop)[0]
yield response, ""
break
yield response, delta
out_last = begin + i + 1
if i >= model.max_tokens_per_generation - 100: