upgrade rwkv pip (0.8.13)

This commit is contained in:
josc146
2023-10-03 13:33:55 +08:00
parent bd4de12e05
commit 79851433f8
18 changed files with 2922 additions and 737 deletions

View File

@@ -16,6 +16,7 @@ class PIPELINE_ARGS:
top_k=0,
alpha_frequency=0.2,
alpha_presence=0.2,
alpha_decay=0.996,
token_ban=[],
token_stop=[],
chunk_len=256,
@@ -25,6 +26,7 @@ class PIPELINE_ARGS:
self.top_k = top_k
self.alpha_frequency = alpha_frequency # Frequency Penalty (as in GPT-3)
self.alpha_presence = alpha_presence # Presence Penalty (as in GPT-3)
self.alpha_decay = alpha_decay # gradually decay the penalty
self.token_ban = token_ban # ban the generation of some tokens
self.token_stop = token_stop # stop generation whenever you see any token here
self.chunk_len = (
@@ -84,7 +86,7 @@ class PIPELINE:
sorted_ids = np.argsort(probs)
sorted_probs = probs[sorted_ids][::-1]
cumulative_probs = np.cumsum(sorted_probs)
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
cutoff = float(sorted_probs[np.argmax(cumulative_probs >= top_p)])
probs[probs < cutoff] = 0
if top_k < len(probs) and top_k > 0:
probs[sorted_ids[:-top_k]] = 0
@@ -98,7 +100,7 @@ class PIPELINE:
sorted_probs = probs[sorted_ids]
sorted_probs = torch.flip(sorted_probs, dims=(0,))
cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
cutoff = float(sorted_probs[np.argmax(cumulative_probs >= top_p)])
probs[probs < cutoff] = 0
if top_k < len(probs) and top_k > 0:
probs[sorted_ids[:-top_k]] = 0
@@ -133,10 +135,13 @@ class PIPELINE:
if token in args.token_stop:
break
all_tokens += [token]
for xxx in occurrence:
occurrence[xxx] *= args.alpha_decay
if token not in occurrence:
occurrence[token] = 1
else:
occurrence[token] += 1
# print(occurrence) # debug
# output
tmp = self.decode(all_tokens[out_last:])