improve occurrence[token] condition

This commit is contained in:
josc146 2024-03-01 13:18:03 +08:00
parent ba9aab920e
commit e3baa0da86

View File

@ -403,16 +403,12 @@ class TextRWKV(AbstractRWKV):
+ occurrence[n] * self.penalty_alpha_frequency + occurrence[n] * self.penalty_alpha_frequency
) )
# comment the codes below to get the same generated results as the official RWKV Gradio
if i == 0: if i == 0:
for token in self.model_tokens: for token in self.model_tokens:
token = int(token) token = int(token)
for xxx in occurrence:
occurrence[xxx] *= self.penalty_decay
if token not in self.AVOID_PENALTY_TOKENS: if token not in self.AVOID_PENALTY_TOKENS:
if token not in occurrence: self.adjust_occurrence(occurrence, token)
occurrence[token] = 1
else:
occurrence[token] += 1
# Model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end # Model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end
def fix_tokens(self, tokens) -> List[int]: def fix_tokens(self, tokens) -> List[int]: