disable AVOID_PENALTY_TOKENS

This commit is contained in:
josc146 2024-02-28 23:12:58 +08:00
parent 225abc5202
commit 18ab8b141f

View File

@ -378,14 +378,14 @@ class TextRWKV(AbstractRWKV):
dd = self.pipeline.encode(i) dd = self.pipeline.encode(i)
assert len(dd) == 1 assert len(dd) == 1
self.AVOID_REPEAT_TOKENS.add(dd[0]) self.AVOID_REPEAT_TOKENS.add(dd[0])
self.AVOID_PENALTY_TOKENS = set() # self.AVOID_PENALTY_TOKENS = set()
AVOID_PENALTY = ( # AVOID_PENALTY = (
"\n" # \n,.:?!,。:?!"“”<>[]{}/\\|;~`@#$%^&*()_+-=0123456789 # "\n" # \n,.:?!,。:?!"“”<>[]{}/\\|;~`@#$%^&*()_+-=0123456789
) # )
for i in AVOID_PENALTY: # for i in AVOID_PENALTY:
dd = self.pipeline.encode(i) # dd = self.pipeline.encode(i)
assert len(dd) == 1 # assert len(dd) == 1
self.AVOID_PENALTY_TOKENS.add(dd[0]) # self.AVOID_PENALTY_TOKENS.add(dd[0])
self.__preload() self.__preload()
@ -399,11 +399,11 @@ class TextRWKV(AbstractRWKV):
def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int): def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int):
for n in occurrence: for n in occurrence:
if n not in self.AVOID_PENALTY_TOKENS: # if n not in self.AVOID_PENALTY_TOKENS:
logits[n] -= ( logits[n] -= (
self.penalty_alpha_presence self.penalty_alpha_presence
+ occurrence[n] * self.penalty_alpha_frequency + occurrence[n] * self.penalty_alpha_frequency
) )
if i == 0: if i == 0:
for token in self.model_tokens: for token in self.model_tokens: