RWKV-Runner/backend-python/rwkv_helper.py
2023-05-06 20:17:39 +08:00

41 lines
1.2 KiB
Python

from typing import Dict
from langchain.llms import RWKV
def rwkv_generate(model: RWKV, prompt: str):
    """Stream a completion for ``prompt`` from a langchain ``RWKV`` wrapper.

    The model's recurrent state and token history are reset, the prompt is
    fed through ``model.run_rnn``, and then up to
    ``model.max_tokens_per_generation`` tokens are sampled one at a time.
    Before each sample, a presence penalty plus a count-scaled frequency
    penalty is subtracted from the logit of every token generated so far.

    Generation stops when the end-of-text token (id 0) is sampled or when the
    token budget is exhausted.

    Args:
        model: the RWKV wrapper. Its ``model_state`` and ``model_tokens``
            attributes are overwritten by this call.
        prompt: text to condition the generation on.

    Yields:
        ``(response, delta)`` tuples. ``response`` is all text produced so
        far; ``delta`` is the newly decoded piece. A tuple is yielded only
        when the pending tokens decode without U+FFFD, so a character whose
        bytes span several tokens is emitted whole. Tokens that are still
        undecodable when generation stops are not emitted.
    """
    end_of_text = 0  # sampled token id that terminates generation

    # Start from a clean state so earlier calls do not leak into this one.
    model.model_state = None
    model.model_tokens = []
    logits = model.run_rnn(model.tokenizer.encode(prompt).ids)
    # run_rnn records the tokens it consumes in model.model_tokens, so its
    # length marks where the not-yet-emitted tokens begin.
    out_last = len(model.model_tokens)

    occurrence: Dict = {}  # token id -> number of times it has been sampled
    response = ""
    # The range itself bounds the number of sampled tokens; no extra early
    # cut-off is applied, so small budgets are honored in full.
    for _ in range(model.max_tokens_per_generation):
        for n, count in occurrence.items():
            logits[n] -= (
                model.penalty_alpha_presence
                + count * model.penalty_alpha_frequency
            )
        token = model.pipeline.sample_logits(
            logits, temperature=model.temperature, top_p=model.top_p
        )
        if token == end_of_text:
            break
        occurrence[token] = occurrence.get(token, 0) + 1

        logits = model.run_rnn([token])
        delta: str = model.tokenizer.decode(model.model_tokens[out_last:])
        # A trailing partial UTF-8 sequence decodes to U+FFFD; keep those
        # tokens pending until later tokens complete the character.
        if "\ufffd" not in delta:
            response += delta
            yield response, delta
            out_last = len(model.model_tokens)