140 lines
4.5 KiB
Python
Vendored
140 lines
4.5 KiB
Python
Vendored
import json, time, random, os
|
|
import numpy as np
|
|
import torch
|
|
from torch.nn import functional as F
|
|
|
|
# Best (smallest) elapsed time recorded so far for each label, in seconds.
time_slot = {}
# Reference timestamp in nanoseconds, captured once when this module executes.
time_ref = time.time_ns()


def record_time(name):
    """Record the elapsed time for ``name``, keeping the minimum seen so far.

    The elapsed time is measured in seconds from ``time_ref`` to now. The
    value stored in ``time_slot[name]`` is only overwritten when the new
    measurement is strictly smaller than the stored one; a label that has
    never been seen starts from the sentinel ``1e20``.
    """
    elapsed_seconds = (time.time_ns() - time_ref) / 1e9
    best_so_far = time_slot.setdefault(name, 1e20)
    if elapsed_seconds < best_so_far:
        time_slot[name] = elapsed_seconds
|
|
|
|
|
|
class TOKENIZER:
    """Tokenizer wrapper that also provides top-p (nucleus) sampling.

    The mode is chosen from the type of ``WORD_NAME``:

    * a string: ``WORD_NAME + ".json"`` is read as a UTF-16 JSON table that
      maps index -> token, and ``charMode`` is True;
    * anything else (normally a two-element list): a ``transformers`` fast
      tokenizer is built from ``WORD_NAME[0]`` / ``WORD_NAME[1]`` and
      ``charMode`` is False.
    """

    def __init__(self, WORD_NAME, UNKNOWN_CHAR="\ue083"):
        """Load the vocabulary.

        Args:
            WORD_NAME: Either a string path prefix (``".json"`` is appended)
                or a two-element sequence of tokenizer file paths. When both
                elements are equal, the single file is loaded with
                ``PreTrainedTokenizerFast``; otherwise both are passed to
                ``GPT2TokenizerFast``.
            UNKNOWN_CHAR: Token looked up in the character table to obtain
                ``self.UNKNOWN_CHAR`` (character mode only).

        Raises:
            KeyError: In character mode, if ``UNKNOWN_CHAR`` is not present
                in the loaded table.
        """
        # An explicit type check replaces the fragile substring match on
        # str(type(...)); every non-string still takes the tokenizer branch.
        if not isinstance(WORD_NAME, str):
            self.charMode = False
            if WORD_NAME[0] == WORD_NAME[1]:
                from transformers import PreTrainedTokenizerFast

                self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0])
            else:
                from transformers import GPT2TokenizerFast

                self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
            self.vocab_size = len(self.tokenizer)
        else:
            self.charMode = True
            with open(WORD_NAME + ".json", "r", encoding="utf-16") as result_file:
                self.word_table = json.load(result_file)

            self.vocab_size = len(self.word_table)

            # JSON keys are strings, so convert them to int indices.
            self.stoi = {v: int(k) for k, v in self.word_table.items()}
            self.itos = {int(k): v for k, v in self.word_table.items()}

            self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]

    def refine_context(self, context):
        """Normalize a prompt: trim every line, drop empty lines, prefix "\\n".

        ``str.strip()`` with no argument already removes the ideographic
        space (U+3000) and carriage returns, so one call per line suffices.
        The result always starts with a newline; an empty or whitespace-only
        input yields exactly ``"\\n"``.
        """
        stripped_lines = (line.strip() for line in context.strip().split("\n"))
        return "\n" + "\n".join(line for line in stripped_lines if line)

    @staticmethod
    def _top_p_cutoff(sorted_probs, cumulative_probs, top_p):
        """Return the smallest probability kept by top-p, or None to keep all.

        ``sorted_probs`` is in descending order and ``cumulative_probs`` is
        its running sum as a NumPy array. None is returned when no prefix sum
        exceeds ``top_p``; taking ``argmax`` of an all-False mask would
        otherwise return index 0 and silently reduce sampling to greedy.
        """
        exceeds = cumulative_probs > top_p
        if not exceeds.any():
            return None
        return float(sorted_probs[int(np.argmax(exceeds))])

    def sample_logits(
        self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None
    ):
        """Sample one token index from the logits ``out``.

        Args:
            out: Logits; softmax is taken over the last dimension.
                Assumed to be a 1-D torch tensor -- TODO confirm with callers.
            x: Token sequence so far; only ``x[-1]`` is read.
            ctx_len: Unused; kept for interface compatibility.
            temperature: Probabilities are raised to ``1 / temperature``
                after top-p truncation when this is not 1.0.
            top_p_usual: Nucleus threshold used by default. None disables
                truncation.
            top_p_newline: Nucleus threshold used in character mode when the
                previous token is "\\n". None disables truncation.

        Returns:
            The sampled index: a NumPy integer on the CPU path, a 0-d torch
            tensor otherwise.

        The path is selected by the ``RWKV_RUN_DEVICE`` environment variable;
        when it is unset, the device of the tensor itself is used.
        """
        last_token = int(x[-1])

        probs = F.softmax(out, dim=-1)

        if self.charMode and self.itos[last_token] == "\n":
            top_p = top_p_newline
        else:
            top_p = top_p_usual

        run_device = os.environ.get("RWKV_RUN_DEVICE")
        if run_device is None:
            run_device = probs.device.type

        if run_device == "cpu":
            probs = probs.numpy()
            if top_p is not None:
                sorted_probs = np.sort(probs)[::-1]
                cutoff = self._top_p_cutoff(
                    sorted_probs, np.cumsum(sorted_probs), top_p
                )
                if cutoff is not None:
                    probs[probs < cutoff] = 0
            if temperature != 1.0:
                # NumPy arrays have no .pow() method; use the ** operator.
                probs = probs ** (1.0 / temperature)
            # np.random.choice requires probabilities that sum to 1.
            probs = probs / np.sum(probs)
            return np.random.choice(a=len(probs), p=probs)

        if top_p is not None:
            sorted_probs = torch.sort(probs, descending=True)[0]
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
            cutoff = self._top_p_cutoff(sorted_probs, cumulative_probs, top_p)
            if cutoff is not None:
                probs[probs < cutoff] = 0
        if temperature != 1.0:
            probs = probs.pow(1.0 / temperature)
        # torch.multinomial accepts unnormalized non-negative weights.
        return torch.multinomial(probs, num_samples=1)[0]
|
|
|
|
|
|
def MaybeIsPrime(number):
    """Return True only if ``number`` passes both probabilistic tests.

    The Fermat test runs first; the Miller-Rabin test is only evaluated when
    the Fermat test succeeds.
    """
    if not FermatPrimalityTest(number):
        return False
    return bool(MillerRabinPrimalityTest(number))
|
|
|
|
|
|
def FermatPrimalityTest(number):
    """Probabilistic Fermat primality test using three random bases.

    Args:
        number: Integer to test.

    Returns:
        False if ``number`` is below 2 or some base ``a`` gives
        ``a ** (number - 1) % number != 1``; True otherwise. A True result
        means "probably prime", not a proof.

    Bases are drawn from ``[2, number - 2]``. The bases 1 and ``number - 1``
    are excluded because 1 always satisfies the congruence and ``number - 1``
    satisfies it for every odd ``number``, so such rounds carry no evidence.
    """
    if number < 2:
        return False
    if number <= 3:
        # 2 and 3 are prime, and the base range [2, number - 2] would be empty.
        return True

    for _ in range(3):
        base = random.randint(2, number - 2)
        if pow(base, number - 1, number) != 1:
            return False
    return True
|
|
|
|
|
|
def MillerRabinPrimalityTest(number):
    """Probabilistic Miller-Rabin primality test using three random bases.

    Args:
        number: Integer to test.

    Returns:
        False if ``number`` is below 2, is an even number other than 2, or
        is shown composite by one of the random bases; True otherwise. A
        True result means "probably prime", not a proof.

    Values below 2 (including negatives) return False up front, so the
    random draw is never reached with an invalid range. Bases are drawn
    directly from ``[2, number - 2]``, which skips the trivial bases 1 and
    ``number - 1``.
    """
    if number < 2:
        return False
    if number in (2, 3):
        return True
    if number % 2 == 0:
        return False

    # Factor number - 1 as odd_part * 2 ** twos_exponent.
    odd_part = number - 1
    twos_exponent = 0
    while odd_part % 2 == 0:
        odd_part //= 2
        twos_exponent += 1

    for _round in range(3):
        base = random.randint(2, number - 2)
        residue = pow(base, odd_part, number)
        if residue in (1, number - 1):
            continue

        # Square up to twos_exponent - 1 times, looking for number - 1.
        for _ in range(twos_exponent - 1):
            residue = pow(residue, 2, number)
            if residue == number - 1:
                break
        else:
            # number - 1 was never reached: this base proves compositeness.
            return False

    return True
|