########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import os, sys
import numpy as np
import torch
from torch.nn import functional as F


class PIPELINE_ARGS:
    """Sampling / generation settings read by ``PIPELINE.generate``.

    ``token_ban`` and ``token_stop`` default to ``None`` and are replaced by a
    fresh empty list per instance. A literal ``[]`` default would be a single
    list object shared by every instance, so ``args.token_ban.append(...)`` on
    one instance would silently leak into all others. Lists passed explicitly
    are stored as given (not copied).
    """

    def __init__(
        self,
        temperature=1.0,
        top_p=0.85,
        top_k=0,
        alpha_frequency=0.2,
        alpha_presence=0.2,
        token_ban=None,
        token_stop=None,
        chunk_len=256,
    ):
        self.temperature = temperature  # sampling temperature passed to sample_logits
        self.top_p = top_p  # nucleus (top-p) threshold passed to sample_logits
        self.top_k = top_k  # keep only the k most likely tokens; 0 disables it
        self.alpha_frequency = alpha_frequency  # Frequency Penalty (as in GPT-3)
        self.alpha_presence = alpha_presence  # Presence Penalty (as in GPT-3)
        # ban the generation of some tokens
        self.token_ban = [] if token_ban is None else token_ban
        # stop generation whenever you see any token here
        self.token_stop = [] if token_stop is None else token_stop
        # split input into chunks to save VRAM (shorter -> slower)
        self.chunk_len = chunk_len


class PIPELINE:
    """Tokenize prompts, run an RWKV model, and sample text from its logits.

    ``model`` is used only through ``model.forward(tokens, state)``, which is
    expected to return ``(logits, new_state)``. ``WORD_NAME`` selects the
    tokenizer backend (see ``__init__``).
    """

    def __init__(self, model, WORD_NAME):
        """Store ``model`` and load the tokenizer selected by ``WORD_NAME``.

        * ``"cl100k_base"``: ``tiktoken.get_encoding("cl100k_base")``.
        * ``"rwkv_vocab_v20230424"``: ``TRIE_TOKENIZER`` from the sibling
          ``rwkv_tokenizer`` module, loading the vocab file that sits next to
          this source file.
        * anything else: treated as a path for ``tokenizers.Tokenizer.from_file``.
        """
        self.model = model
        if WORD_NAME == "cl100k_base":
            import tiktoken

            self.tokenizer = tiktoken.get_encoding(WORD_NAME)
        elif WORD_NAME == "rwkv_vocab_v20230424":
            here = os.path.dirname(os.path.abspath(__file__))
            # Make the sibling rwkv_tokenizer module importable without adding a
            # duplicate sys.path entry every time a PIPELINE is constructed.
            if here not in sys.path:
                sys.path.insert(0, here)
            from rwkv_tokenizer import TRIE_TOKENIZER

            self.tokenizer = TRIE_TOKENIZER(
                os.path.join(here, "rwkv_vocab_v20230424.txt")
            )
        else:
            from tokenizers import Tokenizer

            self.tokenizer = Tokenizer.from_file(WORD_NAME)

    def refine_context(self, context):
        """Normalize a prompt string.

        Every line is stripped of surrounding whitespace (which includes the
        ideographic space U+3000 and ``\\r``), blank lines are dropped, and the
        result is prefixed with a single ``"\\n"``. Empty or all-blank input
        yields ``"\\n"``.
        """
        stripped = (line.strip() for line in context.split("\n"))
        return "\n" + "\n".join(line for line in stripped if line)

    def encode(self, x):
        """Encode text ``x`` into a list of token ids."""
        encoded = self.tokenizer.encode(x)
        # HuggingFace `tokenizers` returns an Encoding object exposing `.ids`;
        # tiktoken returns the id list directly. Duck-typing on the result is
        # sturdier than matching the tokenizer's class name.
        return encoded.ids if hasattr(encoded, "ids") else encoded

    def decode(self, x):
        """Decode a sequence of token ids back into text."""
        return self.tokenizer.decode(x)

    @staticmethod
    def _cutoff_index(cumulative_probs, top_p):
        """Return the index (into the descending-sorted probabilities) of the
        least likely token that nucleus sampling keeps."""
        reached = cumulative_probs >= top_p
        # Float rounding can leave the last cumulative value just below 1.0,
        # so top_p >= 1.0 may never be "reached". np.argmax of an all-False
        # mask is 0, which would silently turn sampling into greedy decoding;
        # keep every token instead.
        if not reached.any():
            return len(cumulative_probs) - 1
        return int(np.argmax(reached))

    def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
        """Sample one token id from ``logits``.

        Args:
            logits: tensor of per-token scores; assumed 1-D (one entry per
                vocabulary item). Softmax is taken over the last dimension.
            temperature: ``probs ** (1 / temperature)`` is applied after
                filtering. ``0`` means greedy decoding; negative values are
                rejected.
            top_p: nucleus threshold. Tokens less likely than the smallest
                token needed to reach cumulative probability ``top_p`` are
                dropped. ``top_p >= 1`` keeps every token.
            top_k: if ``0 < top_k < len(probs)``, only the ``top_k`` most likely
                tokens survive.

        Returns:
            The sampled token id as an ``int``. CPU tensors are sampled with
            numpy's global RNG (``np.random.choice``); other devices use
            ``torch.multinomial``.

        Raises:
            ValueError: if ``temperature`` is negative.
        """
        if temperature < 0:
            raise ValueError("temperature must be >= 0")
        if temperature == 0:
            # Greedy: top_p = 0 keeps only the most likely token(s).
            temperature = 1.0
            top_p = 0.0
        probs = F.softmax(logits.float(), dim=-1)
        top_k = int(top_k)

        if probs.device == torch.device("cpu"):
            probs = probs.numpy()
            sorted_ids = np.argsort(probs)
            sorted_probs = probs[sorted_ids][::-1]  # descending
            cumulative_probs = np.cumsum(sorted_probs)
            cutoff = float(sorted_probs[self._cutoff_index(cumulative_probs, top_p)])
            probs[probs < cutoff] = 0
            if 0 < top_k < len(probs):
                probs[sorted_ids[:-top_k]] = 0
            if temperature != 1.0:
                # Rescale so the max is 1 before exponentiating. This gives the
                # same distribution after normalization, but small temperatures
                # can no longer underflow every weight to 0 (-> NaN probs).
                probs = (probs / probs.max()) ** (1.0 / temperature)
            probs = probs / np.sum(probs)
            return int(np.random.choice(a=len(probs), p=probs))

        sorted_ids = torch.argsort(probs)
        sorted_probs = torch.flip(probs[sorted_ids], dims=(0,))  # descending
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
        cutoff = float(sorted_probs[self._cutoff_index(cumulative_probs, top_p)])
        probs[probs < cutoff] = 0
        if 0 < top_k < len(probs):
            probs[sorted_ids[:-top_k]] = 0
        if temperature != 1.0:
            # Same underflow guard as the CPU path.
            probs = (probs / probs.max()) ** (1.0 / temperature)
        # torch.multinomial accepts unnormalized non-negative weights.
        return int(torch.multinomial(probs, num_samples=1)[0])

    def generate(
        self, ctx, token_count=100, args=None, callback=None, state=None
    ):
        """Generate up to ``token_count`` tokens of text continuing ``ctx``.

        The encoded ``ctx`` is fed to ``self.model.forward`` in chunks of
        ``args.chunk_len`` tokens. After that, each sampled token is fed back
        one at a time.

        Before each sample:
        * every id in ``args.token_ban`` gets logit ``-inf``;
        * every previously generated id is penalized by
          ``alpha_presence + count * alpha_frequency``.

        Generation stops early when a token in ``args.token_stop`` is sampled.
        That token is neither emitted nor fed to the model.

        Decoded text is emitted only once the pending tokens decode without a
        U+FFFD replacement character. Trailing tokens that never form valid
        text are therefore not included.

        Args:
            ctx: prompt text.
            token_count: maximum number of tokens to sample.
            args: a ``PIPELINE_ARGS``; a fresh default instance is created
                per call when ``None``.
            callback: optional callable, called with each newly decoded text
                fragment.
            state: initial model state passed to the first ``forward`` call.

        Returns:
            The generated text. The updated model state is not returned.

        Raises:
            ValueError: if ``args.chunk_len`` is not positive, or ``ctx``
                encodes to no tokens (there would be no logits to sample from).
        """
        if args is None:
            args = PIPELINE_ARGS()
        chunk_len = args.chunk_len
        if chunk_len <= 0:
            raise ValueError("args.chunk_len must be a positive integer")

        all_tokens = []
        out_last = 0  # index in all_tokens of the first not-yet-emitted token
        out_str = ""
        occurrence = {}  # token id -> number of times generated so far
        token = None
        for i in range(token_count):
            # forward & adjust prob.
            tokens = self.encode(ctx) if i == 0 else [token]
            if not tokens:
                raise ValueError("ctx encodes to zero tokens; nothing to feed the model")
            for start in range(0, len(tokens), chunk_len):
                out, state = self.model.forward(tokens[start : start + chunk_len], state)

            for n in args.token_ban:
                out[n] = -float("inf")
            for n, count in occurrence.items():
                out[n] -= args.alpha_presence + count * args.alpha_frequency

            # sampler
            token = self.sample_logits(
                out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k
            )
            if token in args.token_stop:
                break
            all_tokens.append(token)
            occurrence[token] = occurrence.get(token, 0) + 1

            # output: hold back tokens until they decode to valid UTF-8 text
            tmp = self.decode(all_tokens[out_last:])
            if "\ufffd" not in tmp:
                if callback:
                    callback(tmp)
                out_str += tmp
                out_last = len(all_tokens)
        return out_str