81 lines
2.4 KiB
Python
81 lines
2.4 KiB
Python
import os
import pathlib
from typing import Dict, Optional

from langchain.llms import RWKV
from pydantic import BaseModel
|
|
|
|
|
|
class ModelConfigBody(BaseModel):
    """Sampling settings exchanged with ``set_rwkv_config``/``get_rwkv_config``.

    Every field is optional and defaults to ``None``, meaning "not specified".
    The annotations are ``Optional[...]`` so that an explicit ``null`` value is
    accepted by validation, matching the ``None`` defaults.
    """

    # Maximum number of tokens to generate; maps to ``max_tokens_per_generation``.
    max_tokens: Optional[int] = None
    # Sampling temperature handed to the sampler.
    temperature: Optional[float] = None
    # Nucleus-sampling threshold handed to the sampler.
    top_p: Optional[float] = None
    # Flat amount subtracted from the logit of every token already generated.
    presence_penalty: Optional[float] = None
    # Amount subtracted per previous occurrence of a generated token.
    frequency_penalty: Optional[float] = None
|
|
|
|
|
|
def set_rwkv_config(model: RWKV, body: ModelConfigBody):
    """Apply the sampling settings present in *body* to *model* in place.

    A field left as ``None`` keeps the model's current value. ``max_tokens``,
    ``top_p`` and both penalties are compared against ``None`` rather than
    tested for truthiness. An explicit ``0`` (e.g. ``presence_penalty=0`` to
    switch that penalty off) is therefore honoured instead of silently dropped.

    Args:
        model: RWKV instance whose generation attributes are overwritten.
        body: Requested settings; ``None`` means "leave unchanged".
    """
    if body.max_tokens is not None:
        model.max_tokens_per_generation = body.max_tokens
    # Temperature keeps the original truthiness check, so 0 is still ignored.
    # NOTE(review): confirm the sampler accepts a zero temperature before
    # switching this to an `is not None` check.
    if body.temperature:
        model.temperature = body.temperature
    if body.top_p is not None:
        model.top_p = body.top_p
    if body.presence_penalty is not None:
        model.penalty_alpha_presence = body.presence_penalty
    if body.frequency_penalty is not None:
        model.penalty_alpha_frequency = body.frequency_penalty
|
|
|
|
|
|
def get_rwkv_config(model: RWKV) -> ModelConfigBody:
    """Return a ``ModelConfigBody`` snapshot of *model*'s sampling settings.

    This is the read-side counterpart of ``set_rwkv_config``: each RWKV
    attribute is reported under its request-body field name.
    """
    current = {
        "max_tokens": model.max_tokens_per_generation,
        "temperature": model.temperature,
        "top_p": model.top_p,
        "presence_penalty": model.penalty_alpha_presence,
        "frequency_penalty": model.penalty_alpha_frequency,
    }
    return ModelConfigBody(**current)
|
|
|
|
|
|
# Point torch's extension build directory (TORCH_EXTENSIONS_DIR) at the folder
# two levels above this file. This is done at import time, presumably so the
# setting is in place before any extension is compiled.
os.environ["TORCH_EXTENSIONS_DIR"] = str(
    pathlib.Path(__file__).parent.parent.resolve()
)
|
|
|
|
|
|
def rwkv_generate(model: RWKV, prompt: str, stop: Optional[str] = None):
    """Stream a completion of *prompt* from *model*, one decoded chunk at a time.

    ``model.model_state`` and ``model.model_tokens`` are cleared and the prompt
    is fed through ``model.run_rnn``. Tokens are then sampled one by one until
    one of the following happens:

    * the end-of-text token (id 0) is drawn;
    * ``model.max_tokens_per_generation`` tokens have been sampled;
    * *stop* appears in the output.

    Before each sample, every token generated so far has its logit reduced by
    ``penalty_alpha_presence + count * penalty_alpha_frequency``.

    Args:
        model: RWKV wrapper supplying the tokenizer, sampler and settings.
        prompt: Text to continue.
        stop: Optional stop sequence. When it appears, the output is cut just
            before its first occurrence and generation ends. ``None`` or an
            empty string disables stop detection.

    Yields:
        ``(response, delta)`` tuples. *response* is the full text generated
        so far and *delta* is the text added since the previous yield. Tokens
        whose decoding still contains U+FFFD (an incomplete character) are
        held back until a later step decodes them cleanly; any still pending
        when generation ends are dropped.
    """
    END_OF_TEXT = 0

    model.model_state = None
    model.model_tokens = []
    logits = model.run_rnn(model.tokenizer.encode(prompt).ids)
    begin = len(model.model_tokens)
    out_last = begin  # index of the first generated token not yet emitted

    occurrence: Dict[int, int] = {}  # token id -> times generated so far
    response = ""

    for i in range(model.max_tokens_per_generation):
        for token_id, count in occurrence.items():
            logits[token_id] -= (
                model.penalty_alpha_presence + count * model.penalty_alpha_frequency
            )
        token = model.pipeline.sample_logits(
            logits, temperature=model.temperature, top_p=model.top_p
        )
        if token == END_OF_TEXT:
            break
        occurrence[token] = occurrence.get(token, 0) + 1

        logits = model.run_rnn([token])
        delta: str = model.tokenizer.decode(model.model_tokens[out_last:])
        if "\ufffd" in delta:
            # Avoid utf-8 display issues: keep these tokens pending (out_last
            # is not advanced) so a later step decodes them together.
            continue

        previous = response
        response += delta
        if stop and stop in response:
            # Cut at the first occurrence. Still report the part of this
            # step's text that precedes the stop sequence, so streaming
            # consumers do not lose it (empty if the match began earlier).
            response = response.split(stop, 1)[0]
            yield response, response[len(previous):]
            break
        yield response, delta
        out_last = begin + i + 1
|