RWKV-Runner/backend-python/utils/rwkv.py

132 lines
4.0 KiB
Python
Raw Normal View History

import os
import pathlib
2023-05-28 12:53:14 +08:00
from typing import Dict, List, Optional
2023-05-17 11:39:00 +08:00
from pydantic import BaseModel
2023-05-28 12:53:14 +08:00
from rwkv_pip.utils import PIPELINE
# Token ids that RWKV.run_rnn / RWKV.generate below treat specially.
END_OF_TEXT = 0  # sampling this token ends RWKV.generate
END_OF_LINE = 187  # run_rnn adds `newline_adj` to this logit (its comment calls it "\n")
# NOTE(review): both ids look tokenizer-specific; confirm they also hold
# when a "rwkv_vocab" tokenizer file is used.

# Build/cache torch extensions in the parent of this file's directory.
os.environ["TORCH_EXTENSIONS_DIR"] = str(pathlib.Path(__file__).parent.parent.resolve())
class RWKV:
    """Wrap an RWKV model and its tokenizer pipeline with a sampling loop.

    Keeps the recurrent state between forward passes (``model_state``), the
    tokens fed so far (``model_tokens``) and the sampling settings that
    :meth:`generate` reads.
    """

    def __init__(self, model: str, strategy: str, tokens_path: str) -> None:
        """Load the model and tokenizer.

        Args:
            model: model path passed to ``rwkv.model.RWKV``.
            strategy: strategy string passed to ``rwkv.model.RWKV``.
            tokens_path: tokenizer file passed to ``PIPELINE``; when it contains
                "rwkv_vocab" the role names "Human"/"Bot" are used instead of
                "Bob"/"Alice".
        """
        # Imported here, not at module level ("dynamic import to make
        # RWKV_CUDA_ON work") -- presumably so RWKV_CUDA_ON set before
        # construction is seen when the model module is first loaded.
        from rwkv.model import RWKV as Model

        self.model = Model(model, strategy)
        self.pipeline = PIPELINE(self.model, tokens_path)
        self.model_state = None
        self.model_tokens = []

        self.CHUNK_LEN = 256  # max tokens passed to a single model.forward call
        self.max_tokens_per_generation = 500
        self.temperature = 1
        self.top_p = 0.5
        self.penalty_alpha_presence = 0.4
        self.penalty_alpha_frequency = 0.4

        self.interface = ":"
        if "rwkv_vocab" in tokens_path:
            self.user = "Human"
            self.bot = "Bot"
        else:
            self.user = "Bob"
            self.bot = "Alice"

        # Characters whose token may not be sampled twice in a row (see
        # run_rnn); each must encode to exactly one token. Currently empty.
        self.AVOID_REPEAT_TOKENS = []
        AVOID_REPEAT = ""
        for ch in AVOID_REPEAT:
            encoded = self.pipeline.encode(ch)
            if len(encoded) != 1:
                raise ValueError(
                    f"AVOID_REPEAT character {ch!r} must encode to exactly one token, got {encoded!r}"
                )
            self.AVOID_REPEAT_TOKENS += encoded

    def run_rnn(self, _tokens: List[int], newline_adj: int = 0):
        """Feed tokens through the model and return the next-token logits.

        The tokens are appended to ``model_tokens`` and fed in chunks of at
        most ``CHUNK_LEN``, threading ``model_state`` through each call.

        Args:
            _tokens: token ids to feed; each is passed through ``int()``.
            newline_adj: amount added to the logit at index ``END_OF_LINE``.

        Returns:
            The output of the last ``model.forward`` call, with the
            ``END_OF_LINE`` adjustment applied and, if the last fed token is in
            ``AVOID_REPEAT_TOKENS``, that token's logit forced very low.

        Raises:
            ValueError: if ``_tokens`` is empty -- no forward pass runs, so
                there are no logits to return.
        """
        tokens = [int(x) for x in _tokens]
        if not tokens:
            # Previously this fell through to an UnboundLocalError on `out`.
            raise ValueError("run_rnn requires at least one token")
        self.model_tokens += tokens

        out = None
        for start in range(0, len(tokens), self.CHUNK_LEN):
            out, self.model_state = self.model.forward(
                tokens[start : start + self.CHUNK_LEN], self.model_state
            )

        out[END_OF_LINE] += newline_adj  # adjust \n probability
        last_token = self.model_tokens[-1]
        if last_token in self.AVOID_REPEAT_TOKENS:
            out[last_token] = -999999999
        return out

    def generate(self, prompt: str, stop: Optional[str] = None):
        """Sample a completion of ``prompt``, yielding progress step by step.

        The model state is reset, the prompt is fed, and then up to
        ``max_tokens_per_generation`` tokens are sampled. Before each sample,
        every previously sampled token's logit is lowered by
        ``penalty_alpha_presence + count * penalty_alpha_frequency``.
        Sampling stops when ``END_OF_TEXT`` is drawn or ``stop`` appears in the
        decoded text.

        Args:
            prompt: text to condition on; must encode to at least one token.
            stop: optional string; when it appears, the response is cut before
                its first occurrence and generation ends.

        Yields:
            ``(response, delta)``: ``response`` is all text decoded so far and
            ``delta`` is the text added in this step. On a stop hit, ``delta``
            is the part of this step's text that precedes the stop string
            (``""`` if the stop began in an earlier step).

        Raises:
            ValueError: if ``prompt`` encodes to no tokens.
        """
        self.model_state = None
        self.model_tokens = []
        logits = self.run_rnn(self.pipeline.encode(prompt))
        begin = len(self.model_tokens)
        out_last = begin  # first token index not yet decoded into `response`
        occurrence: Dict[int, int] = {}
        response = ""

        for i in range(self.max_tokens_per_generation):
            for tok, count in occurrence.items():
                logits[tok] -= (
                    self.penalty_alpha_presence + count * self.penalty_alpha_frequency
                )
            token = self.pipeline.sample_logits(
                logits, temperature=self.temperature, top_p=self.top_p
            )
            if token == END_OF_TEXT:
                break
            occurrence[token] = occurrence.get(token, 0) + 1

            logits = self.run_rnn([token])
            delta: str = self.pipeline.decode(self.model_tokens[out_last:])
            # An incomplete UTF-8 sequence decodes to U+FFFD; keep the pending
            # tokens and retry once more tokens have been sampled.
            if "\ufffd" in delta:
                continue

            prev_len = len(response)
            response += delta
            if stop is not None and stop in response:
                response = response.split(stop)[0]
                # Emit the text from this step that precedes the stop string,
                # so streaming consumers do not lose it.
                yield response, response[prev_len:]
                break
            out_last = begin + i + 1
            yield response, delta
2023-05-17 11:39:00 +08:00
class ModelConfigBody(BaseModel):
    """Pydantic model holding the sampling settings of an ``RWKV`` instance.

    Every field is optional; ``None`` means "not provided". The fields are
    read by ``set_rwkv_config`` and filled by ``get_rwkv_config``.
    """

    # Explicit Optional[...] so an explicit null is accepted on pydantic v2 as
    # well as v1 (v1 treated `T = None` as implicitly optional; v2 does not).
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
def set_rwkv_config(model: "RWKV", body: "ModelConfigBody") -> None:
    """Copy the provided sampling settings from ``body`` onto ``model``.

    ``max_tokens``, ``temperature`` and ``top_p`` are applied only when truthy,
    so ``None`` or ``0`` leaves the model's current value in place.
    ``presence_penalty`` and ``frequency_penalty`` are applied whenever they
    are not ``None``, so an explicit ``0`` turns that penalty off.

    Args:
        model: the model whose sampling attributes are updated in place.
        body: settings to apply; ``None`` fields are ignored.
    """
    # NOTE(review): 0 is still ignored for these three; confirm the sampler
    # accepts temperature/top_p == 0 before switching them to `is not None`.
    if body.max_tokens:
        model.max_tokens_per_generation = body.max_tokens
    if body.temperature:
        model.temperature = body.temperature
    if body.top_p:
        model.top_p = body.top_p
    # 0 is a meaningful penalty ("no penalty"), so test for None, not truthiness.
    if body.presence_penalty is not None:
        model.penalty_alpha_presence = body.presence_penalty
    if body.frequency_penalty is not None:
        model.penalty_alpha_frequency = body.frequency_penalty
def get_rwkv_config(model: RWKV) -> ModelConfigBody:
    """Snapshot ``model``'s current sampling settings as a ``ModelConfigBody``.

    Reads the same attributes that ``set_rwkv_config`` writes, one body field
    per attribute.
    """
    settings = {
        "max_tokens": model.max_tokens_per_generation,
        "temperature": model.temperature,
        "top_p": model.top_p,
        "presence_penalty": model.penalty_alpha_presence,
        "frequency_penalty": model.penalty_alpha_frequency,
    }
    return ModelConfigBody(**settings)