from typing import Any, List, Union

try:
    import web_rwkv_py as wrp
except ModuleNotFoundError:
    try:
        from . import web_rwkv_py as wrp
    except ImportError:
        raise ModuleNotFoundError(
            "web_rwkv_py not found, install it from https://github.com/cryscan/web-rwkv-py"
        )


class RWKV:
    """Thin adapter that loads a ``web_rwkv_py`` model and runs tokens through it.

    Attributes set by ``__init__``:
        info: value returned by ``wrp.peek_info(model_path)``.
        w: placeholder weight dict with a single ``"emb.weight"`` entry whose
            length equals ``info.num_vocab`` (no real weights are stored here).
        wrp: version-specific submodule of ``web_rwkv_py`` (e.g. ``wrp.v5``),
            selected by the lower-cased string form of ``info.version``.
        version: the model version as a float (``"v5"`` -> ``5.0``).
        model: the ``Model`` object built by the version-specific submodule.
    """

    def __init__(self, model_path: str, strategy: Union[str, None] = None):
        """Inspect the model file and build the backend model.

        Args:
            model_path: path passed to ``wrp.peek_info`` and to ``Model(file=...)``.
            strategy: optional whitespace/comma separated option string, e.g.
                ``"fp16i8 layer20 chunk32"``. ``None`` (the default) is treated
                as an empty strategy: no quantization and a chunk size of 32.

        Raises:
            ValueError: if a ``layerN`` / ``chunkN`` token that is actually
                consumed does not carry a valid integer suffix.
        """
        self.info = wrp.peek_info(model_path)
        # Fake weight: only its length (the vocabulary size) is meaningful.
        self.w = {"emb.weight": [0] * self.info.num_vocab}

        version_name = str(self.info.version).lower()
        self.wrp = getattr(wrp, version_name)
        self.version = float(version_name.replace("v", ""))

        self.model = self.wrp.Model(**self._build_model_args(model_path, strategy))

    @classmethod
    def _build_model_args(cls, model_path: str, strategy: Union[str, None]) -> dict:
        """Translate a strategy string into keyword arguments for ``Model``."""
        # The public signature defaults to None; fall back to an empty
        # strategy instead of failing on ``None.split()``.
        strategy = strategy or ""

        # One shared lazy iterator: "quant" consumes the first ``layerN``
        # value and "quant_nf4" the next one, exactly like the original
        # shared generator did. Values are only parsed when requested.
        layers = cls._iter_option_values(strategy, "layer")
        chunk_sizes = cls._iter_option_values(strategy, "chunk")

        # NOTE(review): presumably these are the number of layers to quantize
        # with int8 / NF4 -- confirm against the web-rwkv-py documentation.
        quant = next(layers, 31) if "i8" in strategy else 0
        quant_nf4 = next(layers, 26) if "i4" in strategy else 0

        return {
            "file": model_path,
            "turbo": True,
            "quant": quant,
            "quant_nf4": quant_nf4,
            "token_chunk_size": next(chunk_sizes, 32),
            "lora": None,
        }

    @staticmethod
    def _iter_option_values(strategy: str, prefix: str):
        """Lazily yield the integer suffix of every ``<prefix>N`` token."""
        for group in strategy.split():
            for token in group.split(","):
                if token.startswith(prefix):
                    # Slice the prefix off; ``str.lstrip(prefix)`` would strip
                    # a *set of characters*, not the literal prefix.
                    yield int(token[len(prefix):])

    def forward(self, tokens: List[int], state: Union[Any, None] = None):
        """Run ``tokens`` through the model and return ``run_one``'s result.

        Args:
            tokens: token ids to feed to the model.
            state: optional state. An object whose type is named
                ``"BackedState"`` is first loaded into a freshly created
                ``ModelState``; any other value (including ``None``) is
                passed to ``run_one`` unchanged.

        Returns:
            Whatever ``self.wrp.run_one(model, tokens, state)`` returns.
        """
        if type(state).__name__ == "BackedState":  # memory state
            gpu_state = self.wrp.ModelState(self.model, 1)
            gpu_state.load(state)
        else:
            gpu_state = state
        return self.wrp.run_one(self.model, tokens, gpu_state)