RWKV-Runner/backend-python/rwkv_pip/utils.py
2024-02-21 23:50:05 +08:00

195 lines
6.9 KiB
Python
Vendored

########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import os, sys
import numpy as np
import torch
from torch.nn import functional as F
class PIPELINE_ARGS:
    """Sampling and generation settings read by ``PIPELINE.generate``.

    Args:
        temperature: Sampling temperature passed to the sampler.
        top_p: Nucleus-sampling probability mass passed to the sampler.
        top_k: Keep only the ``top_k`` most likely tokens (0 disables it).
        alpha_frequency: Frequency penalty (as in GPT-3).
        alpha_presence: Presence penalty (as in GPT-3).
        alpha_decay: Factor applied to accumulated penalty counts after each
            generated token, so penalties fade gradually.
        token_ban: Token ids that must never be generated. ``None`` (the
            default) means an empty list.
        token_stop: Token ids that end generation as soon as one is sampled.
            ``None`` (the default) means an empty list.
        chunk_len: The input is fed to the model in chunks of this many
            tokens to save VRAM (shorter -> slower).
    """

    def __init__(
        self,
        temperature=1.0,
        top_p=0.85,
        top_k=0,
        alpha_frequency=0.2,
        alpha_presence=0.2,
        alpha_decay=0.996,
        token_ban=None,
        token_stop=None,
        chunk_len=256,
    ):
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.alpha_frequency = alpha_frequency  # Frequency Penalty (as in GPT-3)
        self.alpha_presence = alpha_presence  # Presence Penalty (as in GPT-3)
        self.alpha_decay = alpha_decay  # gradually decay the penalty
        # Build a fresh list per instance. A literal ``[]`` default is created
        # once at definition time and would be shared (and mutated) by every
        # instance that relies on the default. A caller-supplied list is
        # stored as-is, exactly as before.
        self.token_ban = [] if token_ban is None else token_ban
        self.token_stop = [] if token_stop is None else token_stop
        # split input into chunks to save VRAM (shorter -> slower)
        self.chunk_len = chunk_len
class ABC_TOKENIZER:
    """Character-level tokenizer: each token id is a Unicode code point.

    Ids up to and including ``eos_token_id`` (3) are treated as special and
    produce no text when decoded.
    """

    def __init__(self):
        self.pad_token_id = 0
        self.bos_token_id = 2
        self.eos_token_id = 3

    def encode(self, text):
        """Return the list of code points of ``text``, one id per character."""
        return list(map(ord, text))

    def decode(self, ids):
        """Return the text for ``ids``, skipping every id <= ``eos_token_id``."""
        pieces = []
        for token_id in ids:
            # Special ids (pad / bos / eos and anything below eos) map to
            # nothing; only ids above eos_token_id carry a character.
            if token_id > self.eos_token_id:
                pieces.append(chr(token_id))
        return "".join(pieces)
class PIPELINE:
    """Bundle a model with a tokenizer and a token sampler for generation.

    Args:
        model: Object exposing ``forward(tokens, state) -> (logits, state)``.
        WORD_NAME: Selects the tokenizer:
            ``"cl100k_base"`` -> tiktoken encoding of that name;
            ``"rwkv_vocab_v20230424"`` -> TRIE_TOKENIZER on the vocab file
            next to this module;
            ``"abc_tokenizer"`` -> ``ABC_TOKENIZER``;
            any other value ending in ``.txt`` -> TRIE_TOKENIZER on that path;
            anything else -> ``tokenizers.Tokenizer.from_file(WORD_NAME)``.
    """

    def __init__(self, model, WORD_NAME: str):
        self.model = model
        if WORD_NAME == "cl100k_base":
            import tiktoken

            self.tokenizer = tiktoken.get_encoding(WORD_NAME)
        elif WORD_NAME == "rwkv_vocab_v20230424":
            self.tokenizer = self._load_trie_tokenizer(
                os.path.join(
                    os.path.dirname(os.path.abspath(__file__)),
                    "rwkv_vocab_v20230424.txt",
                )
            )
        elif WORD_NAME == "abc_tokenizer":
            self.tokenizer = ABC_TOKENIZER()
        elif WORD_NAME.endswith(".txt"):
            self.tokenizer = self._load_trie_tokenizer(WORD_NAME)
        else:
            from tokenizers import Tokenizer

            self.tokenizer = Tokenizer.from_file(WORD_NAME)

    @staticmethod
    def _load_trie_tokenizer(vocab_path):
        """Import ``rwkv_tokenizer`` from this module's directory and load ``vocab_path``."""
        # rwkv_tokenizer lives next to this file; put that directory first on
        # sys.path so the import resolves to it.
        sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
        from rwkv_tokenizer import TRIE_TOKENIZER

        return TRIE_TOKENIZER(vocab_path)

    @staticmethod
    def _nucleus_cutoff(sorted_probs, cumulative_probs, top_p):
        """Return the smallest probability that is still inside the top-p nucleus.

        ``sorted_probs`` is in descending order and ``cumulative_probs`` is its
        running sum as a NumPy array. If the running sum never reaches
        ``top_p`` (e.g. ``top_p=1.0`` with floating-point rounding), 0.0 is
        returned so that no token is filtered out.
        """
        reached = cumulative_probs >= top_p
        if not reached.any():
            # np.argmax on an all-False mask returns 0, which would select the
            # single most likely token and silently turn sampling greedy.
            return 0.0
        return float(sorted_probs[int(np.argmax(reached))])

    def refine_context(self, context):
        """Normalise a prompt: trim every line, drop empty lines, prefix a newline.

        The result always starts with ``"\\n"``; an input with no non-empty
        lines yields exactly ``"\\n"``.
        """
        stripped = (
            line.strip().strip("\u3000").strip("\r")
            for line in context.strip().split("\n")
        )
        kept = [line for line in stripped if line != ""]
        return "\n" + "\n".join(kept).strip()

    def encode(self, x):
        """Tokenize ``x`` into a list of token ids."""
        # Tokenizers whose type name contains "Tokenizer" (the HuggingFace
        # `tokenizers` class) return an Encoding object; the ids are on `.ids`.
        if "Tokenizer" in str(type(self.tokenizer)):
            return self.tokenizer.encode(x).ids
        return self.tokenizer.encode(x)

    def decode(self, x):
        """Turn a sequence of token ids back into text."""
        return self.tokenizer.decode(x)

    def np_softmax(self, x: np.ndarray, axis: int):
        """Return the softmax of ``x`` along ``axis`` without modifying ``x``."""
        # Subtract the max for numerical stability. This is done out of place
        # so the caller's logits array is left untouched.
        shifted = x - x.max(axis=axis, keepdims=True)
        e: np.ndarray = np.exp(shifted)
        return e / e.sum(axis=axis, keepdims=True)

    def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
        """Sample one token id from ``logits``.

        Args:
            logits: 1-D scores as a list, a NumPy array or a torch tensor.
            temperature: Probabilities are raised to ``1 / temperature`` after
                filtering. ``0`` selects greedy decoding (most likely token).
            top_p: Keep the most likely tokens until their cumulative
                probability reaches ``top_p``.
            top_k: If ``0 < top_k < len(probs)``, keep only the ``top_k`` most
                likely tokens.

        Returns:
            The sampled token id as an ``int``.
        """
        if temperature == 0:
            # 1 / temperature is undefined; treat 0 as "always take the most
            # likely token" by shrinking the nucleus to the top entry.
            temperature = 1.0
            top_p = 0
        if isinstance(logits, list):
            logits = np.array(logits)
        np_logits = isinstance(logits, np.ndarray)
        if np_logits:
            probs = self.np_softmax(logits, axis=-1)
        else:
            probs = F.softmax(logits.float(), dim=-1)
        top_k = int(top_k)
        # 'privateuseone' is the type of custom devices like `torch_directml.device()`
        if np_logits or probs.device.type in ["cpu", "privateuseone"]:
            if not np_logits:
                probs = probs.cpu().numpy()
            sorted_ids = np.argsort(probs)  # ascending
            sorted_probs = probs[sorted_ids][::-1]  # descending
            cumulative_probs = np.cumsum(sorted_probs)
            cutoff = self._nucleus_cutoff(sorted_probs, cumulative_probs, top_p)
            probs[probs < cutoff] = 0
            if top_k < len(probs) and top_k > 0:
                # sorted_ids is ascending, so all but the last top_k are dropped
                probs[sorted_ids[:-top_k]] = 0
            if temperature != 1.0:
                probs = probs ** (1.0 / temperature)
            probs = probs / np.sum(probs)
            out = np.random.choice(a=len(probs), p=probs)
            return int(out)
        else:
            sorted_ids = torch.argsort(probs)  # ascending
            sorted_probs = torch.flip(probs[sorted_ids], dims=(0,))  # descending
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
            cutoff = self._nucleus_cutoff(sorted_probs, cumulative_probs, top_p)
            probs[probs < cutoff] = 0
            if top_k < len(probs) and top_k > 0:
                probs[sorted_ids[:-top_k]] = 0
            if temperature != 1.0:
                probs = probs ** (1.0 / temperature)
            # torch.multinomial accepts unnormalised non-negative weights
            out = torch.multinomial(probs, num_samples=1)[0]
            return int(out)

    def generate(
        self, ctx, token_count=100, args=None, callback=None, state=None
    ):
        """Generate up to ``token_count`` tokens continuing ``ctx``.

        Args:
            ctx: Prompt text; it must encode to at least one token.
            token_count: Maximum number of tokens to sample.
            args: ``PIPELINE_ARGS`` with sampling settings; a fresh default
                instance is created when ``None``.
            callback: Optional callable invoked with each newly completed
                piece of decoded text.
            state: Optional model state passed to the first ``forward`` call.

        Returns:
            The generated text. Trailing tokens that decode to an incomplete
            UTF-8 sequence (containing U+FFFD) are held back and are not part
            of the result unless later tokens complete them.

        Raises:
            ValueError: If ``ctx`` encodes to no tokens, since the model then
                produces no logits to sample from.
        """
        if args is None:
            # Created per call: a default built in the signature would be one
            # shared instance for every call.
            args = PIPELINE_ARGS()

        all_tokens = []
        out_last = 0
        out_str = ""
        occurrence = {}
        token = None
        for i in range(token_count):
            # forward & adjust prob.
            tokens = self.encode(ctx) if i == 0 else [token]
            if len(tokens) == 0:
                raise ValueError("ctx must encode to at least one token")
            while len(tokens) > 0:
                out, state = self.model.forward(tokens[: args.chunk_len], state)
                tokens = tokens[args.chunk_len :]

            for n in args.token_ban:
                out[n] = -float("inf")
            for n, count in occurrence.items():
                out[n] -= args.alpha_presence + count * args.alpha_frequency

            # sampler
            token = self.sample_logits(
                out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k
            )
            if token in args.token_stop:
                break
            all_tokens.append(token)

            # Fade the accumulated penalties before recording the new token.
            for seen in occurrence:
                occurrence[seen] *= args.alpha_decay

            # Tokens whose decoded text is a substring of " \t0123456789"
            # (this includes the empty string) add no penalty weight.
            piece = self.decode([token])
            weight = 0 if piece in " \t0123456789" else 1
            occurrence[token] = occurrence.get(token, 0) + weight

            # output
            tmp = self.decode(all_tokens[out_last:])
            if "\ufffd" not in tmp:  # is valid utf-8 string?
                if callback:
                    callback(tmp)
                out_str += tmp
                out_last = i + 1
        return out_str