support for rwkv-4-world
This commit is contained in:
142
backend-python/rwkv_pip/utils.py
Normal file
142
backend-python/rwkv_pip/utils.py
Normal file
@@ -0,0 +1,142 @@
|
||||
########################################################################################################
|
||||
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
||||
########################################################################################################
|
||||
|
||||
import os, sys
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class PIPELINE_ARGS:
    """Sampling and chunking options read by ``PIPELINE.generate``.

    Attributes:
        temperature: Sampling temperature handed to the sampler.
        top_p: Nucleus (cumulative probability) threshold handed to the sampler.
        top_k: Keep only the ``top_k`` most likely tokens; ``0`` disables it.
        alpha_frequency: Frequency Penalty (as in GPT-3).
        alpha_presence: Presence Penalty (as in GPT-3).
        token_ban: Token ids whose generation is banned.
        token_stop: Token ids that stop generation as soon as one is sampled.
        chunk_len: Split input into chunks of this many tokens to save VRAM
            (shorter -> slower).
    """

    def __init__(
        self,
        temperature=1.0,
        top_p=0.85,
        top_k=0,
        alpha_frequency=0.2,
        alpha_presence=0.2,
        token_ban=None,
        token_stop=None,
        chunk_len=256,
    ):
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.alpha_frequency = alpha_frequency  # Frequency Penalty (as in GPT-3)
        self.alpha_presence = alpha_presence  # Presence Penalty (as in GPT-3)
        # ``None`` stands for "no tokens". A fresh list is created per instance
        # so that mutating one instance's list never leaks into another
        # (a literal ``[]`` default would be shared by every instance).
        self.token_ban = [] if token_ban is None else token_ban
        self.token_stop = [] if token_stop is None else token_stop
        # split input into chunks to save VRAM (shorter -> slower)
        self.chunk_len = chunk_len
|
||||
|
||||
|
||||
class PIPELINE:
    """Couple a model with a tokenizer and a token-by-token sampling loop.

    ``model`` is only used through ``model.forward(tokens, state)``, which
    must return a ``(logits, state)`` pair.
    """

    def __init__(self, model, WORD_NAME):
        """Store *model* and load the tokenizer selected by *WORD_NAME*.

        * ``"cl100k_base"`` -> ``tiktoken.get_encoding(WORD_NAME)``
        * ``"rwkv_vocab_v20230424"`` -> ``TRIE_TOKENIZER`` built from the
          ``rwkv_vocab_v20230424.txt`` file located next to this module
        * anything else -> passed to ``tokenizers.Tokenizer.from_file``
        """
        self.model = model
        if WORD_NAME == "cl100k_base":
            import tiktoken

            self.tokenizer = tiktoken.get_encoding(WORD_NAME)
        elif WORD_NAME == "rwkv_vocab_v20230424":
            # ``rwkv_tokenizer`` is imported as a top-level module, so this
            # file's directory has to be on ``sys.path`` first.
            module_dir = os.path.dirname(os.path.abspath(__file__))
            sys.path.insert(0, module_dir)
            from rwkv_tokenizer import TRIE_TOKENIZER

            self.tokenizer = TRIE_TOKENIZER(module_dir + "/rwkv_vocab_v20230424.txt")
        else:
            from tokenizers import Tokenizer

            self.tokenizer = Tokenizer.from_file(WORD_NAME)

    def refine_context(self, context):
        """Normalise *context*: trim every line, drop empty lines, and
        return the remaining lines joined by newlines with one leading newline.

        The result always starts with ``"\\n"``; empty or whitespace-only
        input therefore yields exactly ``"\\n"``.
        """
        stripped_lines = [
            line.strip().strip("\u3000").strip("\r")
            for line in context.strip().split("\n")
        ]
        kept_lines = [line for line in stripped_lines if line != ""]
        return "\n" + "\n".join(kept_lines).strip()

    def encode(self, x):
        """Return the token ids for text *x*.

        Tokenizers whose type name contains ``"Tokenizer"`` return an encoding
        object whose ids live in ``.ids``; the other tokenizers return the ids
        directly.
        """
        if "Tokenizer" in str(type(self.tokenizer)):
            return self.tokenizer.encode(x).ids
        else:
            return self.tokenizer.encode(x)

    def decode(self, x):
        """Return the text for the token ids *x*."""
        return self.tokenizer.decode(x)

    def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
        """Sample one token id from *logits*.

        Args:
            logits: Tensor of unnormalised scores; expected to be 1-D (the
                code sorts, flips and indexes along dimension 0).
            temperature: Probabilities are raised to ``1 / temperature``
                after filtering. ``0`` selects greedy decoding.
            top_p: Nucleus threshold; tokens less likely than the first token
                at which the descending cumulative probability exceeds
                ``top_p`` are discarded.
            top_k: When ``0 < top_k < vocabulary size`` only the ``top_k``
                most likely tokens are kept.

        Returns:
            The sampled token id as a Python ``int``.
        """
        if temperature == 0:
            # ``1.0 / temperature`` below would raise ZeroDivisionError.
            # top_p = 0 makes the cutoff equal to the largest probability, so
            # only the most likely token(s) survive: greedy decoding.
            temperature = 1.0
            top_p = 0
        probs = F.softmax(logits.float(), dim=-1)
        top_k = int(top_k)
        if probs.device == torch.device("cpu"):
            probs = probs.numpy()
            sorted_ids = np.argsort(probs)  # ascending
            sorted_probs = probs[sorted_ids][::-1]  # descending
            cumulative_probs = np.cumsum(sorted_probs)
            # NOTE: when no cumulative value exceeds top_p, argmax returns 0
            # and the cutoff becomes the largest probability.
            cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
            probs[probs < cutoff] = 0
            if 0 < top_k < len(probs):
                probs[sorted_ids[:-top_k]] = 0
            if temperature != 1.0:
                probs = probs ** (1.0 / temperature)
            probs = probs / np.sum(probs)
            out = np.random.choice(a=len(probs), p=probs)
            return int(out)
        else:
            sorted_ids = torch.argsort(probs)  # ascending
            sorted_probs = probs[sorted_ids]
            sorted_probs = torch.flip(sorted_probs, dims=(0,))  # descending
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
            cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
            probs[probs < cutoff] = 0
            if 0 < top_k < len(probs):
                probs[sorted_ids[:-top_k]] = 0
            if temperature != 1.0:
                probs = probs ** (1.0 / temperature)
            # torch.multinomial accepts unnormalised weights.
            out = torch.multinomial(probs, num_samples=1)[0]
            return int(out)

    def generate(
        self, ctx, token_count=100, args=None, callback=None, state=None
    ):
        """Generate up to *token_count* tokens continuing *ctx*.

        Args:
            ctx: Prompt text; it is encoded and fed to the model in chunks of
                ``args.chunk_len`` tokens.
            token_count: Maximum number of tokens to sample.
            args: Sampling options (``PIPELINE_ARGS``-like object). ``None``
                creates a fresh ``PIPELINE_ARGS()`` for this call.
            callback: Optional callable invoked with each decoded text piece.
            state: Initial model state forwarded to ``model.forward``.

        Returns:
            The concatenation of all decoded text pieces. A sampled stop
            token ends generation and is not part of the output.
        """
        if args is None:
            # Built per call instead of once at definition time, so default
            # options are never shared between calls.
            args = PIPELINE_ARGS()

        all_tokens = []
        out_last = 0
        pieces = []
        occurrence = {}
        for i in range(token_count):
            # forward & adjust prob.
            tokens = self.encode(ctx) if i == 0 else [token]
            while len(tokens) > 0:
                out, state = self.model.forward(tokens[: args.chunk_len], state)
                tokens = tokens[args.chunk_len :]

            for n in args.token_ban:
                out[n] = -float("inf")
            for n in occurrence:
                out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency

            # sampler
            token = self.sample_logits(
                out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k
            )
            if token in args.token_stop:
                break
            all_tokens += [token]
            occurrence[token] = occurrence.get(token, 0) + 1

            # output
            tmp = self.decode(all_tokens[out_last:])
            if "\ufffd" not in tmp:  # is valid utf-8 string?
                if callback:
                    callback(tmp)
                pieces.append(tmp)
                out_last = i + 1
            # otherwise keep the pending tokens and retry decoding them
            # together with the next sampled token
        return "".join(pieces)
|
||||
Reference in New Issue
Block a user