import json
import os
import random
import time

import numpy as np
import torch
from torch.nn import functional as F

time_slot = {}
time_ref = time.time_ns()


def record_time(name):
    # Keep the fastest (minimum) elapsed time since time_ref for each named slot.
    if name not in time_slot:
        time_slot[name] = 1e20
    tt = (time.time_ns() - time_ref) / 1e9
    if tt < time_slot[name]:
        time_slot[name] = tt


class TOKENIZER:
    def __init__(self, WORD_NAME, UNKNOWN_CHAR="\ue083"):
        if "list" in str(type(WORD_NAME)):
            self.charMode = False
            if WORD_NAME[0] == WORD_NAME[1]:
                from transformers import PreTrainedTokenizerFast

                self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0])
            else:
                from transformers import GPT2TokenizerFast

                self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
            self.vocab_size = len(self.tokenizer)
        else:
            self.charMode = True
            with open(WORD_NAME + ".json", "r", encoding="utf-16") as result_file:
                self.word_table = json.load(result_file)

            self.vocab_size = len(self.word_table)

            self.stoi = {v: int(k) for k, v in self.word_table.items()}
            self.itos = {int(k): v for k, v in self.word_table.items()}

            self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]

    def refine_context(self, context):
        context = context.strip().split("\n")
        for c in range(len(context)):
            # Also strip full-width spaces (\u3000) and stray carriage returns.
            context[c] = context[c].strip().strip("\u3000").strip("\r")
        context = list(filter(lambda c: c != "", context))
        context = "\n" + ("\n".join(context)).strip()
        if context == "":
            context = "\n"
        return context

    def sample_logits(
        self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None
    ):
        # out[self.UNKNOWN_CHAR] = -float('Inf')
        lastChar = int(x[-1])

        probs = F.softmax(out, dim=-1)

        # In char mode, use a separate top-p right after a newline so that
        # line breaks can be sampled with a different amount of diversity.
        if self.charMode:
            if self.itos[lastChar] == "\n":
                top_p = top_p_newline
            else:
                top_p = top_p_usual
        else:
            top_p = top_p_usual

        if os.environ["RWKV_RUN_DEVICE"] == "cpu":
            probs = probs.numpy()
            sorted_probs = np.sort(probs)[::-1]
            cumulative_probs = np.cumsum(sorted_probs)
            # Zero out everything below the top-p cutoff.
            cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
            probs[probs < cutoff] = 0
            if temperature != 1.0:
                # numpy arrays have no .pow(); use np.power instead.
                probs = np.power(probs, 1.0 / temperature)
            probs = probs / np.sum(probs)
            out = np.random.choice(a=len(probs), p=probs)
            return out
        else:
            sorted_probs = torch.sort(probs, descending=True)[0]
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
            cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
            probs[probs < cutoff] = 0
            if temperature != 1.0:
                probs = probs.pow(1.0 / temperature)
            # torch.multinomial renormalizes internally, so no division needed.
            out = torch.multinomial(probs, num_samples=1)[0]
            return out


def MaybeIsPrime(number):
    if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number):
        return True
    else:
        return False


def FermatPrimalityTest(number):
    if number > 1:
        for _ in range(3):  # loop variable renamed so it does not shadow the time module
            randomNumber = random.randint(2, number) - 1
            if pow(randomNumber, number - 1, number) != 1:
                return False
        return True
    else:
        return False


def MillerRabinPrimalityTest(number):
    if number == 2:
        return True
    elif number == 1 or number % 2 == 0:
        return False

    # Factor number - 1 as (2 ** timesTwoDividNumber) * oddPartOfNumber.
    oddPartOfNumber = number - 1
    timesTwoDividNumber = 0
    while oddPartOfNumber % 2 == 0:
        oddPartOfNumber = oddPartOfNumber // 2
        timesTwoDividNumber = timesTwoDividNumber + 1

    for _ in range(3):  # loop variable renamed so it does not shadow the time module
        while True:
            randomNumber = random.randint(2, number) - 1
            if randomNumber != 0 and randomNumber != 1:
                break

        randomNumberWithPower = pow(randomNumber, oddPartOfNumber, number)

        if (randomNumberWithPower != 1) and (randomNumberWithPower != number - 1):
            iterationNumber = 1

            while (iterationNumber <= timesTwoDividNumber - 1) and (
                randomNumberWithPower != number - 1
            ):
                randomNumberWithPower = pow(randomNumberWithPower, 2, number)
                iterationNumber = iterationNumber + 1

            if randomNumberWithPower != (number - 1):
                return False

    return True
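

# --- Usage sketch (an addition, not part of the original module) -------------
# A minimal, self-contained demo of the two halves of this file. It writes a
# tiny throwaway char-mode vocabulary (hypothetical contents) to a temp file so
# TOKENIZER.__init__ can run without a real tokenizer JSON on disk.
if __name__ == "__main__":
    import tempfile

    os.environ["RWKV_RUN_DEVICE"] = "cpu"

    # MaybeIsPrime is probabilistic (3 Fermat rounds + 3 Miller-Rabin rounds):
    # a composite can in principle slip through, but a true prime always passes.
    print([n for n in range(2, 50) if MaybeIsPrime(n)])

    # Throwaway 4-character vocabulary; the UNKNOWN_CHAR "\ue083" must be
    # present because __init__ looks it up in stoi.
    vocab = {str(i): ch for i, ch in enumerate(["a", "b", "\n", "\ue083"])}
    with tempfile.NamedTemporaryFile(
        "w", suffix=".json", delete=False, encoding="utf-16"
    ) as f:
        json.dump(vocab, f)
    tokenizer = TOKENIZER(f.name[: -len(".json")])  # __init__ appends ".json"

    # Sample one token id from dummy logits; float64 keeps the normalized
    # probabilities close enough to 1.0 for np.random.choice's sanity check.
    logits = torch.tensor([2.0, 1.0, 0.5, -9.0], dtype=torch.float64)
    token = tokenizer.sample_logits(
        logits, x=[0], ctx_len=1024,
        temperature=1.0, top_p_usual=0.9, top_p_newline=0.95,
    )
    print("sampled token:", repr(tokenizer.itos[int(token)]))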