update
This commit is contained in:
@@ -2,7 +2,7 @@ from typing import Dict
|
||||
from langchain.llms import RWKV
|
||||
|
||||
|
||||
def rwkv_generate(model: RWKV, prompt: str):
|
||||
def rwkv_generate(model: RWKV, prompt: str, stop: str = None):
|
||||
model.model_state = None
|
||||
model.model_tokens = []
|
||||
logits = model.run_rnn(model.tokenizer.encode(prompt).ids)
|
||||
@@ -34,6 +34,11 @@ def rwkv_generate(model: RWKV, prompt: str):
|
||||
delta: str = model.tokenizer.decode(model.model_tokens[out_last:])
|
||||
if "\ufffd" not in delta: # avoid utf-8 display issues
|
||||
response += delta
|
||||
if stop is not None:
|
||||
if stop in response:
|
||||
response = response.split(stop)[0]
|
||||
yield response, ""
|
||||
break
|
||||
yield response, delta
|
||||
out_last = begin + i + 1
|
||||
if i >= model.max_tokens_per_generation - 100:
|
||||
|
||||
Reference in New Issue
Block a user