This commit is contained in:
josc146
2023-05-15 21:55:57 +08:00
parent 80bfb09972
commit 83f0bb503c
13 changed files with 388 additions and 138 deletions

View File

@@ -2,7 +2,7 @@ from typing import Dict
from langchain.llms import RWKV
def rwkv_generate(model: RWKV, prompt: str):
def rwkv_generate(model: RWKV, prompt: str, stop: str = None):
model.model_state = None
model.model_tokens = []
logits = model.run_rnn(model.tokenizer.encode(prompt).ids)
@@ -34,6 +34,11 @@ def rwkv_generate(model: RWKV, prompt: str):
delta: str = model.tokenizer.decode(model.model_tokens[out_last:])
if "\ufffd" not in delta: # avoid utf-8 display issues
response += delta
if stop is not None:
if stop in response:
response = response.split(stop)[0]
yield response, ""
break
yield response, delta
out_last = begin + i + 1
if i >= model.max_tokens_per_generation - 100: