########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import os, sys
import numpy as np
import torch
from torch.nn import functional as F


class PIPELINE_ARGS:
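    # Sampling and repetition-penalty settings consumed by PIPELINE.generate().
    # During generation, the logit of every token seen so far is reduced by
    # alpha_presence + count * alpha_frequency (GPT-3-style presence/frequency penalties),
    # and the per-token counts are multiplied by alpha_decay each step so the penalty
    # fades over time. temperature / top_p / top_k control the sampler; token_ban forbids
    # tokens, token_stop ends generation, and chunk_len bounds how many prompt tokens are
    # fed to the model per forward pass.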
    def __init__(
        self,
        temperature=1.0,
        top_p=0.85,
        top_k=0,
        alpha_frequency=0.2,
        alpha_presence=0.2,
        alpha_decay=0.996,
        token_ban=[],
        token_stop=[],
        chunk_len=256,
    ):
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.alpha_frequency = alpha_frequency  # Frequency Penalty (as in GPT-3)
        self.alpha_presence = alpha_presence  # Presence Penalty (as in GPT-3)
        self.alpha_decay = alpha_decay  # gradually decay the penalty
        self.token_ban = token_ban  # ban the generation of some tokens
        self.token_stop = token_stop  # stop generation whenever you see any token here
        self.chunk_len = (
            chunk_len  # split input into chunks to save VRAM (shorter -> slower)
        )


class PIPELINE:
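    # Tokenize / detokenize, sample logits, and run a streaming generation loop around a
    # model object that exposes forward(tokens, state) -> (logits, state).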
    def __init__(self, model, WORD_NAME: str):
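        # WORD_NAME selects the tokenizer:
        #   "cl100k_base"            -> tiktoken encoding of the same name
        #   "rwkv_vocab_v20230424"   -> TRIE_TOKENIZER with the bundled vocab file
        #   a path ending in ".txt"  -> TRIE_TOKENIZER with that vocab file
        #   anything else            -> a HuggingFace `tokenizers` JSON file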
        self.model = model
        if WORD_NAME == "cl100k_base":
            import tiktoken

            self.tokenizer = tiktoken.get_encoding(WORD_NAME)
        elif WORD_NAME == "rwkv_vocab_v20230424":
            sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
            from rwkv_tokenizer import TRIE_TOKENIZER

            self.tokenizer = TRIE_TOKENIZER(
                os.path.dirname(os.path.abspath(__file__)) + "/rwkv_vocab_v20230424.txt"
            )
        else:
            if WORD_NAME.endswith(".txt"):
                sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
                from rwkv_tokenizer import TRIE_TOKENIZER

                self.tokenizer = TRIE_TOKENIZER(WORD_NAME)
            else:
                from tokenizers import Tokenizer

                self.tokenizer = Tokenizer.from_file(WORD_NAME)

    def refine_context(self, context):
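        # Normalize a prompt: strip each line (including ideographic spaces and carriage
        # returns), drop empty lines, and return the result with a single leading newline.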
        context = context.strip().split("\n")
        for c in range(len(context)):
            context[c] = context[c].strip().strip("\u3000").strip("\r")
        context = list(filter(lambda c: c != "", context))
        context = "\n" + ("\n".join(context)).strip()
        if context == "":
            context = "\n"
        return context

    def encode(self, x):
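        # HuggingFace `tokenizers.Tokenizer` returns an Encoding object, so take `.ids`;
        # tiktoken and TRIE_TOKENIZER already return a plain list of token ids.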
        if "Tokenizer" in str(type(self.tokenizer)):
            return self.tokenizer.encode(x).ids
        else:
            return self.tokenizer.encode(x)

    def decode(self, x):
        return self.tokenizer.decode(x)

    def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
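        # Nucleus (top-p) sampling with optional top-k and temperature: probabilities are
        # sorted, the smallest set whose cumulative mass reaches top_p defines a cutoff,
        # everything below the cutoff (and outside the top_k) is zeroed, temperature
        # reshapes the surviving distribution, and one token id is drawn from it.
        # A NumPy path is used for CPU tensors and a torch path otherwise.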
        probs = F.softmax(logits.float(), dim=-1)
        top_k = int(top_k)
        if probs.device == torch.device("cpu"):
            probs = probs.numpy()
            sorted_ids = np.argsort(probs)
            sorted_probs = probs[sorted_ids][::-1]
            cumulative_probs = np.cumsum(sorted_probs)
            cutoff = float(sorted_probs[np.argmax(cumulative_probs >= top_p)])
            probs[probs < cutoff] = 0
            if top_k < len(probs) and top_k > 0:
                probs[sorted_ids[:-top_k]] = 0
            if temperature != 1.0:
                probs = probs ** (1.0 / temperature)
            probs = probs / np.sum(probs)
            out = np.random.choice(a=len(probs), p=probs)
            return int(out)
        else:
            sorted_ids = torch.argsort(probs)
            sorted_probs = probs[sorted_ids]
            sorted_probs = torch.flip(sorted_probs, dims=(0,))
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
            cutoff = float(sorted_probs[np.argmax(cumulative_probs >= top_p)])
            probs[probs < cutoff] = 0
            if top_k < len(probs) and top_k > 0:
                probs[sorted_ids[:-top_k]] = 0
            if temperature != 1.0:
                probs = probs ** (1.0 / temperature)
            out = torch.multinomial(probs, num_samples=1)[0]
            return int(out)

    def generate(
        self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, state=None
    ):
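        # Autoregressive sampling loop: feed the prompt (then each sampled token) through
        # the model in chunks of args.chunk_len, apply token bans and the decaying
        # presence/frequency penalties to the logits, sample the next token, stop on any
        # token in args.token_stop, and stream decoded text through `callback` (if given)
        # only once it forms valid UTF-8 (no replacement character). Returns the full string.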
        all_tokens = []
        out_last = 0
        out_str = ""
        occurrence = {}
        for i in range(token_count):
            # forward & adjust prob.
            tokens = self.encode(ctx) if i == 0 else [token]
            while len(tokens) > 0:
                out, state = self.model.forward(tokens[: args.chunk_len], state)
                tokens = tokens[args.chunk_len :]

            for n in args.token_ban:
                out[n] = -float("inf")
            for n in occurrence:
                out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency

            # sampler
            token = self.sample_logits(
                out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k
            )
            if token in args.token_stop:
                break
            all_tokens += [token]
            for xxx in occurrence:
                occurrence[xxx] *= args.alpha_decay
            if token not in occurrence:
                occurrence[token] = 1
            else:
                occurrence[token] += 1
            # print(occurrence) # debug

            # output
            tmp = self.decode(all_tokens[out_last:])
            if "\ufffd" not in tmp:  # is valid utf-8 string?
                if callback:
                    callback(tmp)
                out_str += tmp
                out_last = i + 1
        return out_str
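
# Example usage (a sketch kept in comments, not executed on import). It assumes this file
# ships inside the `rwkv` package next to `rwkv.model.RWKV`; the checkpoint path and
# strategy string below are placeholders to replace with your own.
#
#   from rwkv.model import RWKV
#   from rwkv.utils import PIPELINE, PIPELINE_ARGS
#
#   model = RWKV(model="/path/to/rwkv-checkpoint.pth", strategy="cpu fp32")
#   pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
#   args = PIPELINE_ARGS(temperature=1.0, top_p=0.85, alpha_frequency=0.2, alpha_presence=0.2)
#   print(pipeline.generate("\nIn a shocking finding,", token_count=100, args=args))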