improve occurrence[token] condition

This commit is contained in:
josc146 2024-02-29 17:54:33 +08:00
parent c13b28561d
commit b0f2ef65d9

View File

@ -378,14 +378,12 @@ class TextRWKV(AbstractRWKV):
dd = self.pipeline.encode(i)
assert len(dd) == 1
self.AVOID_REPEAT_TOKENS.add(dd[0])
# self.AVOID_PENALTY_TOKENS = set()
# AVOID_PENALTY = (
# "\n" # \n,.:?!,。:?!"“”<>[]{}/\\|;~`@#$%^&*()_+-=0123456789
# )
# for i in AVOID_PENALTY:
# dd = self.pipeline.encode(i)
# assert len(dd) == 1
# self.AVOID_PENALTY_TOKENS.add(dd[0])
self.AVOID_PENALTY_TOKENS = set()
AVOID_PENALTY = '\n,.:?!,。:?!"“”<>[]{}/\\|;~`@#$%^&*()_+-=0123456789 '
for i in AVOID_PENALTY:
dd = self.pipeline.encode(i)
if len(dd) == 1:
self.AVOID_PENALTY_TOKENS.add(dd[0])
self.__preload()
@ -410,10 +408,11 @@ class TextRWKV(AbstractRWKV):
token = int(token)
for xxx in occurrence:
occurrence[xxx] *= self.penalty_decay
if token not in occurrence:
occurrence[token] = 1
else:
occurrence[token] += 1
if token not in self.AVOID_PENALTY_TOKENS:
if token not in occurrence:
occurrence[token] = 1
else:
occurrence[token] += 1
# Model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end
def fix_tokens(self, tokens) -> List[int]: