add AVOID_PENALTY_TOKENS

This commit is contained in:
josc146 2024-02-04 16:49:46 +08:00
parent 78238c24cf
commit 19eeeab1e1

View File

@ -378,6 +378,12 @@ class TextRWKV(AbstractRWKV):
dd = self.pipeline.encode(i) dd = self.pipeline.encode(i)
assert len(dd) == 1 assert len(dd) == 1
self.AVOID_REPEAT_TOKENS += dd self.AVOID_REPEAT_TOKENS += dd
self.AVOID_PENALTY_TOKENS = []
AVOID_PENALTY = "\n,.:,。:<>[]{}()/\\|;" # \n,.:?!,。:?!"“”<>[]{}/\\|;~`@#$%^&*()_+-=0123456789
for i in AVOID_PENALTY:
dd = self.pipeline.encode(i)
assert len(dd) == 1
self.AVOID_PENALTY_TOKENS += dd
self.__preload() self.__preload()
@ -391,10 +397,11 @@ class TextRWKV(AbstractRWKV):
def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int): def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int):
for n in occurrence: for n in occurrence:
logits[n] -= ( if n not in self.AVOID_PENALTY_TOKENS:
self.penalty_alpha_presence logits[n] -= (
+ occurrence[n] * self.penalty_alpha_frequency self.penalty_alpha_presence
) + occurrence[n] * self.penalty_alpha_frequency
)
if i == 0: if i == 0:
for token in self.model_tokens: for token in self.model_tokens:
@ -639,9 +646,11 @@ def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV
tokenizer_name = os.path.splitext(os.path.basename(tokenizer))[0] tokenizer_name = os.path.splitext(os.path.basename(tokenizer))[0]
global_var.set( global_var.set(
global_var.Midi_Vocab_Config_Type, global_var.Midi_Vocab_Config_Type,
global_var.MidiVocabConfig.Piano (
if tokenizer_name == "tokenizer-midipiano" global_var.MidiVocabConfig.Piano
else global_var.MidiVocabConfig.Default, if tokenizer_name == "tokenizer-midipiano"
else global_var.MidiVocabConfig.Default
),
) )
rwkv: AbstractRWKV rwkv: AbstractRWKV
if tokenizer_name in rwkv_map: if tokenizer_name in rwkv_map: