From b0f2ef65d9ec08992d1547f38f714b66cf7f711f Mon Sep 17 00:00:00 2001 From: josc146 Date: Thu, 29 Feb 2024 17:54:33 +0800 Subject: [PATCH] improve occurrence[token] condition --- backend-python/utils/rwkv.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/backend-python/utils/rwkv.py b/backend-python/utils/rwkv.py index fe8edb8..957d536 100644 --- a/backend-python/utils/rwkv.py +++ b/backend-python/utils/rwkv.py @@ -378,14 +378,12 @@ class TextRWKV(AbstractRWKV): dd = self.pipeline.encode(i) assert len(dd) == 1 self.AVOID_REPEAT_TOKENS.add(dd[0]) - # self.AVOID_PENALTY_TOKENS = set() - # AVOID_PENALTY = ( - # "\n" # \n,.:?!,。:?!"“”<>[]{}/\\|;;~`@#$%^&*()_+-=0123456789 - # ) - # for i in AVOID_PENALTY: - # dd = self.pipeline.encode(i) - # assert len(dd) == 1 - # self.AVOID_PENALTY_TOKENS.add(dd[0]) + self.AVOID_PENALTY_TOKENS = set() + AVOID_PENALTY = '\n,.:?!,。:?!"“”<>[]{}/\\|;;~`@#$%^&*()_+-=0123456789 ' + for i in AVOID_PENALTY: + dd = self.pipeline.encode(i) + if len(dd) == 1: + self.AVOID_PENALTY_TOKENS.add(dd[0]) self.__preload() @@ -410,10 +408,11 @@ class TextRWKV(AbstractRWKV): token = int(token) for xxx in occurrence: occurrence[xxx] *= self.penalty_decay - if token not in occurrence: - occurrence[token] = 1 - else: - occurrence[token] += 1 + if token not in self.AVOID_PENALTY_TOKENS: + if token not in occurrence: + occurrence[token] = 1 + else: + occurrence[token] += 1 # Model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end def fix_tokens(self, tokens) -> List[int]: