import os
import pathlib
from typing import Dict, Iterator, Optional, Tuple

from langchain.llms import RWKV
from pydantic import BaseModel


class ModelConfigBody(BaseModel):
    """Sampling settings for an RWKV model, each of which may be omitted.

    Every field defaults to ``None``, meaning the setting was not supplied.
    The fields are declared ``Optional`` so that an explicit ``null`` is also
    accepted (pydantic v2 does not treat ``x: int = None`` as optional).
    """

    # Maps to model.max_tokens_per_generation.
    max_tokens: Optional[int] = None
    # Maps to model.temperature.
    temperature: Optional[float] = None
    # Maps to model.top_p.
    top_p: Optional[float] = None
    # Maps to model.penalty_alpha_presence.
    presence_penalty: Optional[float] = None
    # Maps to model.penalty_alpha_frequency.
    frequency_penalty: Optional[float] = None


def set_rwkv_config(model: RWKV, body: ModelConfigBody):
    """Copy the sampling settings supplied in ``body`` onto ``model``.

    A field is applied only when it is not ``None``; ``None`` means "keep the
    model's current value". Zero is a real setting (for example a presence or
    frequency penalty of 0 turns that penalty off) and is applied like any
    other value, rather than being silently ignored as a falsy value.

    NOTE(review): whether ``temperature == 0`` is accepted depends on
    ``model.pipeline.sample_logits``, which is not visible here — confirm.
    """
    if body.max_tokens is not None:
        model.max_tokens_per_generation = body.max_tokens
    if body.temperature is not None:
        model.temperature = body.temperature
    if body.top_p is not None:
        model.top_p = body.top_p
    if body.presence_penalty is not None:
        model.penalty_alpha_presence = body.presence_penalty
    if body.frequency_penalty is not None:
        model.penalty_alpha_frequency = body.frequency_penalty


def get_rwkv_config(model: RWKV) -> ModelConfigBody:
    """Return ``model``'s current sampling settings as a ``ModelConfigBody``.

    Each request-side field name is paired with the model attribute that
    backs it.
    """
    current = {
        "max_tokens": model.max_tokens_per_generation,
        "temperature": model.temperature,
        "top_p": model.top_p,
        "presence_penalty": model.penalty_alpha_presence,
        "frequency_penalty": model.penalty_alpha_frequency,
    }
    return ModelConfigBody(**current)


# PyTorch's JIT C++/CUDA extension builder (torch.utils.cpp_extension) uses
# TORCH_EXTENSIONS_DIR as its build directory; point it at the directory two
# levels above this file. NOTE(review): this must run before any extension
# is compiled — presumably at model load time; confirm.
os.environ["TORCH_EXTENSIONS_DIR"] = str(
    pathlib.Path(__file__).parent.parent.resolve()
)


def rwkv_generate(
    model: RWKV, prompt: str, stop: Optional[str] = None
) -> Iterator[Tuple[str, str]]:
    """Stream a completion of ``prompt`` from ``model``.

    The model's ``model_state`` and ``model_tokens`` are reset, the prompt is
    fed through ``model.run_rnn``, and then up to
    ``model.max_tokens_per_generation`` tokens are sampled one at a time with
    ``model.temperature`` / ``model.top_p``. Before each sample, tokens that
    were already generated are penalised by ``model.penalty_alpha_presence``
    plus ``model.penalty_alpha_frequency`` times their count.

    Args:
        model: The RWKV wrapper; its recurrent state is overwritten.
        prompt: Text to condition generation on.
        stop: Optional stop sequence. Generation ends at its first occurrence
            and everything from that point on is dropped. ``None`` or ``""``
            disables it.

    Yields:
        ``(response, delta)`` tuples, where ``response`` is all text decoded
        so far and ``delta`` is the text added by this step. Concatenating
        every ``delta`` reproduces the final ``response``, except when the
        stop sequence began in text already yielded by an earlier step (that
        text cannot be retracted).
    """
    END_OF_TEXT = 0  # sampling this token id ends generation

    model.model_state = None
    model.model_tokens = []
    logits = model.run_rnn(model.tokenizer.encode(prompt).ids)
    begin = len(model.model_tokens)
    out_last = begin  # first generated token not yet emitted as text

    occurrence: Dict[int, int] = {}
    response = ""
    for i in range(model.max_tokens_per_generation):
        # Repetition penalty: a flat presence term plus a term that grows
        # with how many times each token has already been generated.
        for token_id, count in occurrence.items():
            logits[token_id] -= (
                model.penalty_alpha_presence
                + count * model.penalty_alpha_frequency
            )
        token = model.pipeline.sample_logits(
            logits, temperature=model.temperature, top_p=model.top_p
        )
        if token == END_OF_TEXT:
            break
        occurrence[token] = occurrence.get(token, 0) + 1

        logits = model.run_rnn([token])
        delta: str = model.tokenizer.decode(model.model_tokens[out_last:])
        if "\ufffd" in delta:
            # Decoded text contains U+FFFD (likely a multi-byte character
            # split across tokens): leave out_last unchanged so these tokens
            # are decoded again together with the next one.
            continue

        previous_len = len(response)
        response += delta
        out_last = begin + i + 1
        if stop:
            # The text before this step held no stop sequence, so any match
            # must end inside ``delta``; only that window needs searching.
            cut = response.find(stop, max(0, previous_len - len(stop) + 1))
            if cut != -1:
                response = response[:cut]
                # Still emit the part of this step's text before the stop.
                yield response, response[previous_len:]
                break
        yield response, delta