diff --git a/backend-python/utils/rwkv.py b/backend-python/utils/rwkv.py index b3defc8..fe8edb8 100644 --- a/backend-python/utils/rwkv.py +++ b/backend-python/utils/rwkv.py @@ -378,14 +378,14 @@ class TextRWKV(AbstractRWKV): dd = self.pipeline.encode(i) assert len(dd) == 1 self.AVOID_REPEAT_TOKENS.add(dd[0]) - self.AVOID_PENALTY_TOKENS = set() - AVOID_PENALTY = ( - "\n" # \n,.:?!,。:?!"“”<>[]{}/\\|;;~`@#$%^&*()_+-=0123456789 - ) - for i in AVOID_PENALTY: - dd = self.pipeline.encode(i) - assert len(dd) == 1 - self.AVOID_PENALTY_TOKENS.add(dd[0]) + # self.AVOID_PENALTY_TOKENS = set() + # AVOID_PENALTY = ( + # "\n" # \n,.:?!,。:?!"“”<>[]{}/\\|;;~`@#$%^&*()_+-=0123456789 + # ) + # for i in AVOID_PENALTY: + # dd = self.pipeline.encode(i) + # assert len(dd) == 1 + # self.AVOID_PENALTY_TOKENS.add(dd[0]) self.__preload() @@ -399,11 +399,11 @@ class TextRWKV(AbstractRWKV): def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int): for n in occurrence: - if n not in self.AVOID_PENALTY_TOKENS: - logits[n] -= ( - self.penalty_alpha_presence - + occurrence[n] * self.penalty_alpha_frequency - ) + # if n not in self.AVOID_PENALTY_TOKENS: + logits[n] -= ( + self.penalty_alpha_presence + + occurrence[n] * self.penalty_alpha_frequency + ) if i == 0: for token in self.model_tokens: