diff --git a/backend-python/utils/rwkv.py b/backend-python/utils/rwkv.py index 957d536..6a0f64a 100644 --- a/backend-python/utils/rwkv.py +++ b/backend-python/utils/rwkv.py @@ -403,16 +403,12 @@ class TextRWKV(AbstractRWKV): + occurrence[n] * self.penalty_alpha_frequency ) + # comment the codes below to get the same generated results as the official RWKV Gradio if i == 0: for token in self.model_tokens: token = int(token) - for xxx in occurrence: - occurrence[xxx] *= self.penalty_decay if token not in self.AVOID_PENALTY_TOKENS: - if token not in occurrence: - occurrence[token] = 1 - else: - occurrence[token] += 1 + self.adjust_occurrence(occurrence, token) # Model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end def fix_tokens(self, tokens) -> List[int]: