from abc import ABC, abstractmethod
from enum import Enum, auto
import os
import pathlib
import copy
import re
from typing import Dict, Iterable, List, Tuple, Union, Type
from utils.log import quick_log
from fastapi import HTTPException
from pydantic import BaseModel, Field
import numpy as np
from routes import state_cache
import global_var


END_OF_TEXT = 0
END_OF_LINE_DOUBLE = 535
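# Token 0 is the end-of-text token in the supported tokenizers; 535 is the single
# token the 20B tokenizer emits for a trailing '\n\n' (see fix_tokens below).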


os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}"


class RWKVType(Enum):
    NoneType = auto()
    Raven = auto()
    World = auto()
    Music = auto()


class AbstractRWKV(ABC):
    def __init__(self, model, pipeline):
        self.name = "rwkv"
        self.model = model
        self.pipeline = pipeline
        self.model_state = None
        self.model_tokens = []
        self.rwkv_type: RWKVType = RWKVType.NoneType
        self.tokenizer_len = len(model.w["emb.weight"])  # vocab size, inferred from the embedding matrix

        self.max_tokens_per_generation = 500
        self.temperature = 1
        self.top_p = 0.3
        self.top_k = 0
        self.penalty_alpha_presence = 0
        self.penalty_alpha_frequency = 1

    @abstractmethod
    def adjust_occurrence(self, occurrence: Dict, token: int):
        pass

    @abstractmethod
    def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int):
        pass

    # The model was trained seeing '\n\n' as [187, 187], but the tokenizer emits [535] for a trailing '\n\n'
    @abstractmethod
    def fix_tokens(self, tokens) -> List[int]:
        pass

    @abstractmethod
    def run_rnn(
        self, _tokens: List[int], newline_adj: int = 0
    ) -> Tuple[List[float], int]:
        pass

    @abstractmethod
    def delta_postprocess(self, delta: str) -> str:
        pass

    def get_embedding(self, input: str, fast_mode: bool) -> Tuple[List[float], int]:
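        """Return an L2-normalized embedding for `input` plus its token count.

        A hypothetical usage sketch (assumes `rwkv` is an initialized TextRWKV):

            vec, n_tokens = rwkv.get_embedding("some text", fast_mode=True)
            # vec has unit length, suitable for cosine-similarity comparisons
        """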
        if fast_mode:
            embedding, token_len = self.__fast_embedding(
                self.fix_tokens(self.pipeline.encode(input)), None
            )
        else:
            self.model_state = None
            self.model_tokens = []
            _, token_len = self.run_rnn(self.fix_tokens(self.pipeline.encode(input)))
            embedding = self.model_state[-11].tolist()
        embedding = (np.asarray(embedding) / np.linalg.norm(embedding)).tolist()
        return embedding, token_len

    def __fast_embedding(self, tokens: List[int], state):
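        # A trimmed-down copy of rwkv_pip's forward pass: it initializes and runs
        # only the first attention block, returning that block's att_xx state as a
        # cheap "fast" embedding instead of doing a full-depth forward pass.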
        import torch

        tokens = [int(x) for x in tokens]
        token_len = len(tokens)
        self = self.model  # rebind self to the inner model so the copied forward-pass code below reads unchanged

        with torch.no_grad():
            w = self.w
            args = self.args

            if state is None:
                state = [None] * args.n_layer * 5
                for i in range(
                    args.n_layer
                ):  # state: 0=att_xx 1=att_aa 2=att_bb 3=att_pp 4=ffn_xx
                    dd = self.strategy[i]
                    dev = dd.device
                    atype = dd.atype
                    state[i * 5 + 0] = torch.zeros(
                        args.n_embd, dtype=atype, requires_grad=False, device=dev
                    ).contiguous()
                    state[i * 5 + 1] = torch.zeros(
                        args.n_embd, dtype=torch.float, requires_grad=False, device=dev
                    ).contiguous()
                    state[i * 5 + 2] = torch.zeros(
                        args.n_embd, dtype=torch.float, requires_grad=False, device=dev
                    ).contiguous()
                    state[i * 5 + 3] = (
                        torch.zeros(
                            args.n_embd,
                            dtype=torch.float,
                            requires_grad=False,
                            device=dev,
                        ).contiguous()
                        - 1e30
                    )
                    state[i * 5 + 4] = torch.zeros(
                        args.n_embd, dtype=atype, requires_grad=False, device=dev
                    ).contiguous()

                    break  # only layer 0's state is needed for the fast embedding

            seq_mode = len(tokens) > 1

            x = w["emb.weight"][tokens if seq_mode else tokens[0]]

            for i in range(args.n_layer):
                bbb = f"blocks.{i}."
                att = f"blocks.{i}.att."
                ffn = f"blocks.{i}.ffn."
                dd = self.strategy[i]
                dev = dd.device
                atype = dd.atype
                wtype = dd.wtype
                if seq_mode:
                    if "cuda" in str(dev) and os.environ["RWKV_CUDA_ON"] == "1":
                        ATT = (
                            self.cuda_att_seq
                            if wtype != torch.uint8
                            else self.cuda_att_seq_i8
                        )
                    else:
                        ATT = self.att_seq if wtype != torch.uint8 else self.att_seq_i8
                    FFN = self.ffn_seq if wtype != torch.uint8 else self.ffn_seq_i8
                else:
                    ATT = self.att_one if wtype != torch.uint8 else self.att_one_i8
                    FFN = self.ffn_one if wtype != torch.uint8 else self.ffn_one_i8

                x = x.to(dtype=atype, device=dev)

                kw = w[f"{att}key.weight"]
                vw = w[f"{att}value.weight"]
                rw = w[f"{att}receptance.weight"]
                ow = w[f"{att}output.weight"]
                if dd.stream:
                    kw = kw.to(device=dev, non_blocking=True)
                    vw = vw.to(device=dev, non_blocking=True)
                    rw = rw.to(device=dev, non_blocking=True)
                    ow = ow.to(device=dev, non_blocking=True)
                kmx = w[f"{att}key.weight_mx"] if wtype == torch.uint8 else x
                krx = w[f"{att}key.weight_rx"] if wtype == torch.uint8 else x
                kmy = w[f"{att}key.weight_my"] if wtype == torch.uint8 else x
                kry = w[f"{att}key.weight_ry"] if wtype == torch.uint8 else x
                vmx = w[f"{att}value.weight_mx"] if wtype == torch.uint8 else x
                vrx = w[f"{att}value.weight_rx"] if wtype == torch.uint8 else x
                vmy = w[f"{att}value.weight_my"] if wtype == torch.uint8 else x
                vry = w[f"{att}value.weight_ry"] if wtype == torch.uint8 else x
                rmx = w[f"{att}receptance.weight_mx"] if wtype == torch.uint8 else x
                rrx = w[f"{att}receptance.weight_rx"] if wtype == torch.uint8 else x
                rmy = w[f"{att}receptance.weight_my"] if wtype == torch.uint8 else x
                rry = w[f"{att}receptance.weight_ry"] if wtype == torch.uint8 else x
                omx = w[f"{att}output.weight_mx"] if wtype == torch.uint8 else x
                orx = w[f"{att}output.weight_rx"] if wtype == torch.uint8 else x
                omy = w[f"{att}output.weight_my"] if wtype == torch.uint8 else x
                ory = w[f"{att}output.weight_ry"] if wtype == torch.uint8 else x
                (
                    x,
                    state[i * 5 + 0],
                    state[i * 5 + 1],
                    state[i * 5 + 2],
                    state[i * 5 + 3],
                ) = ATT(
                    x,
                    state[i * 5 + 0],
                    state[i * 5 + 1],
                    state[i * 5 + 2],
                    state[i * 5 + 3],
                    w[f"{bbb}ln1.weight"],
                    w[f"{bbb}ln1.bias"],
                    w[f"{att}time_mix_k"],
                    w[f"{att}time_mix_v"],
                    w[f"{att}time_mix_r"],
                    w[f"{att}time_decay"],
                    w[f"{att}time_first"],
                    kw,
                    vw,
                    rw,
                    ow,
                    kmx,
                    krx,
                    kmy,
                    kry,
                    vmx,
                    vrx,
                    vmy,
                    vry,
                    rmx,
                    rrx,
                    rmy,
                    rry,
                    omx,
                    orx,
                    omy,
                    ory,
                )

                # NOTE: returns inside the first loop iteration on purpose; only
                # layer 0's att_xx state is used as the fast embedding
                return state[0].tolist(), token_len

    def generate(
        self, prompt: str, stop: Union[str, List[str], None] = None
    ) -> Iterable[Tuple[str, str, int, int]]:
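        r"""Stream a completion for `prompt`, reusing cached state when possible.

        Yields (response_so_far, delta, prompt_token_len, completion_token_len)
        tuples. A minimal consumption sketch (assumes `rwkv` is an initialized
        instance):

            for response, delta, _, _ in rwkv.generate("Hello", stop="\n\n"):
                print(delta, end="", flush=True)
        """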
        quick_log(None, None, "Generation Prompt:\n" + prompt)
        cache = None
        delta_prompt = prompt
        try:
            cache = state_cache.longest_prefix_state(
                state_cache.LongestPrefixStateBody(prompt=prompt), None
            )
        except HTTPException:
            pass
        if cache is None or cache["prompt"] == "":
            self.model_state = None
            self.model_tokens = []
        else:
            delta_prompt = prompt[len(cache["prompt"]) :]
            self.model_state = copy.deepcopy(cache["state"])
            self.model_tokens = copy.deepcopy(cache["tokens"])
            logits = copy.deepcopy(cache["logits"])

        prompt_token_len = 0
        if delta_prompt != "":
            logits, prompt_token_len = self.run_rnn(
                self.fix_tokens(self.pipeline.encode(delta_prompt))
            )
            try:
                state_cache.add_state(
                    state_cache.AddStateBody(
                        prompt=prompt,
                        tokens=self.model_tokens,
                        state=self.model_state,
                        logits=logits,
                    )
                )
            except HTTPException:
                pass

        begin = len(self.model_tokens)
        out_last = begin

        occurrence: Dict = {}

        completion_token_len = 0
        response = ""
        for i in range(self.max_tokens_per_generation):
            self.adjust_forward_logits(logits, occurrence, i)

            token = self.pipeline.sample_logits(
                logits, temperature=self.temperature, top_p=self.top_p, top_k=self.top_k
            )

            if token == END_OF_TEXT:
                yield response, "", prompt_token_len, completion_token_len
                break

            self.adjust_occurrence(occurrence, token)

            logits, _ = self.run_rnn([token])
            completion_token_len += 1
            delta: str = self.delta_postprocess(
                self.pipeline.decode(self.model_tokens[out_last:])
            )
            if "\ufffd" not in delta:  # avoid utf-8 display issues
                response += delta
                if stop is not None:
                    if isinstance(stop, str):
                        if stop in response:
                            try:
                                state_cache.add_state(
                                    state_cache.AddStateBody(
                                        prompt=prompt + response,
                                        tokens=self.model_tokens,
                                        state=self.model_state,
                                        logits=logits,
                                    )
                                )
                            except HTTPException:
                                pass
                            response = response.split(stop)[0]
                            yield response, "", prompt_token_len, completion_token_len
                            break
                    elif isinstance(stop, list):
                        stop_exist_regex = "|".join(re.escape(s) for s in stop)
                        matched = re.search(stop_exist_regex, response)
                        if matched:
                            try:
                                state_cache.add_state(
                                    state_cache.AddStateBody(
                                        prompt=prompt + response,
                                        tokens=self.model_tokens,
                                        state=self.model_state,
                                        logits=logits,
                                    )
                                )
                            except HTTPException:
                                pass
                            response = response.split(matched.group())[0]
                            yield response, "", prompt_token_len, completion_token_len
                            break
                out_last = begin + i + 1
                if i == self.max_tokens_per_generation - 1:
                    try:
                        state_cache.add_state(
                            state_cache.AddStateBody(
                                prompt=prompt + response,
                                tokens=self.model_tokens,
                                state=self.model_state,
                                logits=logits,
                            )
                        )
                    except HTTPException:
                        pass
                yield response, delta, prompt_token_len, completion_token_len


class TextRWKV(AbstractRWKV):
    def __init__(self, model, pipeline) -> None:
        super().__init__(model, pipeline)

        self.CHUNK_LEN = 256

        self.max_tokens_per_generation = 500
        self.temperature = 1
        self.top_p = 0.3
        self.top_k = 0
        self.penalty_alpha_presence = 0
        self.penalty_alpha_frequency = 1

        self.interface = ":"
        if self.tokenizer_len < 65536:
            self.rwkv_type = RWKVType.Raven
            self.user = "Bob"
            self.bot = "Alice"
            self.END_OF_LINE = 187
        else:
            self.rwkv_type = RWKVType.World
            self.user = "User"
            self.bot = "Assistant"
            self.END_OF_LINE = 11

        self.AVOID_REPEAT_TOKENS = []
        AVOID_REPEAT = ",:?!"
        for i in AVOID_REPEAT:
            dd = self.pipeline.encode(i)
            assert len(dd) == 1
            self.AVOID_REPEAT_TOKENS += dd

        self.__preload()

    def adjust_occurrence(self, occurrence: Dict, token: int):
        for xxx in occurrence:
            occurrence[xxx] *= 0.996
        if token not in occurrence:
            occurrence[token] = 1
        else:
            occurrence[token] += 1

    def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int):
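        # Penalty applied per seen token n: presence + count_n * frequency, where
        # count_n decays by a factor of 0.996 on every generated token (see
        # adjust_occurrence). E.g. with presence=0 and frequency=1, a token sampled
        # 10 steps ago is still penalized by about 0.996**10 ~= 0.96.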
        for n in occurrence:
            logits[n] -= (
                self.penalty_alpha_presence
                + occurrence[n] * self.penalty_alpha_frequency
            )

        if i == 0:
            for token in self.model_tokens:
                token = int(token)
                for xxx in occurrence:
                    occurrence[xxx] *= 0.996
                if token not in occurrence:
                    occurrence[token] = 1
                else:
                    occurrence[token] += 1

    # The model was trained seeing '\n\n' as [187, 187], but the tokenizer emits [535] for a trailing '\n\n'
    def fix_tokens(self, tokens) -> List[int]:
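        # e.g. (20B tokenizer) a prompt ending in '\n\n' encodes with a final token
        # 535; Raven models were trained on [187, 187] instead, so the trailing 535
        # is expanded back into two END_OF_LINE (187) tokens.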
        if self.rwkv_type == RWKVType.World:
            return tokens
        if len(tokens) > 0 and tokens[-1] == END_OF_LINE_DOUBLE:
            tokens = tokens[:-1] + [self.END_OF_LINE, self.END_OF_LINE]
        return tokens

    def run_rnn(
        self, _tokens: List[int], newline_adj: int = 0
    ) -> Tuple[List[float], int]:
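        """Feed tokens through the model in CHUNK_LEN-sized chunks, carrying state.

        Illustrative chunking (the numbers are hypothetical): a 600-token prompt
        with CHUNK_LEN = 256 becomes three forward calls of 256, 256, and 88
        tokens, each resuming from the state produced by the previous call.
        """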
        tokens = [int(x) for x in _tokens]
        token_len = len(tokens)
        self.model_tokens += tokens

        while len(tokens) > 0:
            out, self.model_state = self.model.forward(
                tokens[: self.CHUNK_LEN], self.model_state
            )
            tokens = tokens[self.CHUNK_LEN :]

        out[self.END_OF_LINE] += newline_adj  # adjust \n probability

        if self.model_tokens[-1] in self.AVOID_REPEAT_TOKENS:
            out[self.model_tokens[-1]] = -999999999
        return out, token_len

    def delta_postprocess(self, delta: str) -> str:
        return delta

    def __preload(self):
        interface = self.interface
        user = self.user
        bot = self.bot
        preset_system = (
            f"""
The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. \
{bot} is very intelligent, creative and friendly. \
{bot} is unlikely to disagree with {user}, and {bot} doesn't like to ask {user} questions. \
{bot} likes to tell {user} a lot about herself and her opinions. \
{bot} usually gives {user} kind, helpful and informative advices.\n
"""
            if self.rwkv_type == RWKVType.Raven
            else (
                f"{user}{interface} hi\n\n{bot}{interface} Hi. "
                + "I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n"
            )
        )
        logits, _ = self.run_rnn(self.fix_tokens(self.pipeline.encode(preset_system)))
        try:
            state_cache.add_state(
                state_cache.AddStateBody(
                    prompt=preset_system,
                    tokens=self.model_tokens,
                    state=self.model_state,
                    logits=logits,
                )
            )
        except HTTPException:
            pass


class MusicRWKV(AbstractRWKV):
    def __init__(self, model, pipeline):
        super().__init__(model, pipeline)

        self.max_tokens_per_generation = 500
        self.temperature = 1
        self.top_p = 0.8
        self.top_k = 8

        self.rwkv_type = RWKVType.Music

    def adjust_occurrence(self, occurrence: Dict, token: int):
        for n in occurrence:
            occurrence[n] *= 0.997  # decay repetition penalty
        if token >= 128 or token == 127:
            occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)
        else:
            occurrence[token] = 0.3 + (occurrence[token] if token in occurrence else 0)

    def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int):
        for n in occurrence:
            logits[n] -= occurrence[n] * 0.5  # presence penalty 0, frequency weight 0.5

        logits[0] += (i - 2000) / 500  # try not to be too short or too long
        logits[127] -= 1  # avoid "t125"

    def fix_tokens(self, tokens) -> List[int]:
        return tokens

    def run_rnn(
        self, _tokens: List[int], newline_adj: int = 0
    ) -> Tuple[List[float], int]:
        tokens = [int(x) for x in _tokens]
        token_len = len(tokens)
        self.model_tokens += tokens
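        # unlike TextRWKV.run_rnn, there is no chunking and newline_adj is unused:
        # the whole sequence goes through a single forward call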
        out, self.model_state = self.model.forward(tokens, self.model_state)
        return out, token_len

    def delta_postprocess(self, delta: str) -> str:
        return " " + delta


def get_tokenizer(tokenizer_len: int):
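    # Heuristic mapping from vocab size to tokenizer: < 50277 -> MIDI tokenizer,
    # 50277..65535 -> the GPT-NeoX-style 20B tokenizer, otherwise the World vocab
    # name, which rwkv_pip's PIPELINE resolves internally.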
    tokenizer_dir = f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/"
    if tokenizer_len < 50277:
        return tokenizer_dir + "tokenizer-midi.json"
    elif tokenizer_len < 65536:
        return tokenizer_dir + "20B_tokenizer.json"
    else:
        return "rwkv_vocab_v20230424"


def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV:
    rwkv_beta = global_var.get(global_var.Args).rwkv_beta

    if "midi" in model.lower() or "abc" in model.lower():
        os.environ["RWKV_RESCALE_LAYER"] = "999"

    # dynamic import to make RWKV_CUDA_ON work
    if rwkv_beta:
        from rwkv_pip.beta.model import (
            RWKV as Model,
        )
    else:
        from rwkv_pip.model import (
            RWKV as Model,
        )
    from rwkv_pip.utils import PIPELINE

    filename, _ = os.path.splitext(os.path.basename(model))
    model = Model(model, strategy)
    if not tokenizer:
        tokenizer = get_tokenizer(len(model.w["emb.weight"]))
    pipeline = PIPELINE(model, tokenizer)

    rwkv_map: Dict[str, Type[AbstractRWKV]] = {
        "20B_tokenizer": TextRWKV,
        "rwkv_vocab_v20230424": TextRWKV,
        "tokenizer-midi": MusicRWKV,
    }
    tokenizer_name = os.path.splitext(os.path.basename(tokenizer))[0]
    rwkv: AbstractRWKV
    if tokenizer_name in rwkv_map:
        rwkv = rwkv_map[tokenizer_name](model, pipeline)
    else:
        rwkv = TextRWKV(model, pipeline)
    rwkv.name = filename

    return rwkv
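
# A minimal usage sketch (the checkpoint path and strategy below are illustrative,
# not shipped with this repo):
#
#     rwkv = RWKV(
#         model="models/RWKV-4-World-3B.pth",  # hypothetical path
#         strategy="cuda fp16",
#         tokenizer=None,  # auto-selected from the embedding matrix size
#     )
#     for response, delta, _, _ in rwkv.generate("Hello"):
#         print(delta, end="")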


class ModelConfigBody(BaseModel):
    max_tokens: Union[int, None] = Field(default=None, gt=0, le=102400)
    temperature: Union[float, None] = Field(default=None, ge=0, le=2)
    top_p: Union[float, None] = Field(default=None, ge=0, le=1)
    presence_penalty: Union[float, None] = Field(default=None, ge=-2, le=2)
    frequency_penalty: Union[float, None] = Field(default=None, ge=-2, le=2)

    model_config = {
        "json_schema_extra": {
            "example": {
                "max_tokens": 1000,
                "temperature": 1.2,
                "top_p": 0.5,
                "presence_penalty": 0.4,
                "frequency_penalty": 0.4,
            }
        }
    }
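
    # Fields left as None mean "leave the model's current value unchanged":
    # set_rwkv_config below only applies fields that are not None.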


def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody):
    if body.max_tokens is not None:
        model.max_tokens_per_generation = body.max_tokens
    if body.temperature is not None:
        # clamp very low temperatures to 0.1 to avoid degenerate sampling
        model.temperature = max(0.1, body.temperature)
    if body.top_p is not None:
        model.top_p = body.top_p
    if body.presence_penalty is not None:
        model.penalty_alpha_presence = body.presence_penalty
    if body.frequency_penalty is not None:
        model.penalty_alpha_frequency = body.frequency_penalty


def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
    return ModelConfigBody(
        max_tokens=model.max_tokens_per_generation,
        temperature=model.temperature,
        top_p=model.top_p,
        presence_penalty=model.penalty_alpha_presence,
        frequency_penalty=model.penalty_alpha_frequency,
    )
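

# A minimal round-trip sketch (assumes `rwkv` is an initialized AbstractRWKV):
#
#     set_rwkv_config(rwkv, ModelConfigBody(temperature=1.2, top_p=0.5))
#     assert get_rwkv_config(rwkv).temperature == 1.2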