From 19eeeab1e182a9f65d7995d6d09165fafc711946 Mon Sep 17 00:00:00 2001 From: josc146 Date: Sun, 4 Feb 2024 16:49:46 +0800 Subject: [PATCH] add AVOID_PENALTY_TOKENS --- backend-python/utils/rwkv.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/backend-python/utils/rwkv.py b/backend-python/utils/rwkv.py index c9781b9..745591f 100644 --- a/backend-python/utils/rwkv.py +++ b/backend-python/utils/rwkv.py @@ -378,6 +378,12 @@ class TextRWKV(AbstractRWKV): dd = self.pipeline.encode(i) assert len(dd) == 1 self.AVOID_REPEAT_TOKENS += dd + self.AVOID_PENALTY_TOKENS = [] + AVOID_PENALTY = "\n,.:,。:<>[]{}()/\\|;;" # \n,.:?!,。:?!"“”<>[]{}/\\|;;~`@#$%^&*()_+-=0123456789 + for i in AVOID_PENALTY: + dd = self.pipeline.encode(i) + assert len(dd) == 1 + self.AVOID_PENALTY_TOKENS += dd self.__preload() @@ -391,10 +397,11 @@ class TextRWKV(AbstractRWKV): def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int): for n in occurrence: - logits[n] -= ( - self.penalty_alpha_presence - + occurrence[n] * self.penalty_alpha_frequency - ) + if n not in self.AVOID_PENALTY_TOKENS: + logits[n] -= ( + self.penalty_alpha_presence + + occurrence[n] * self.penalty_alpha_frequency + ) if i == 0: for token in self.model_tokens: @@ -639,9 +646,11 @@ def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV tokenizer_name = os.path.splitext(os.path.basename(tokenizer))[0] global_var.set( global_var.Midi_Vocab_Config_Type, - global_var.MidiVocabConfig.Piano - if tokenizer_name == "tokenizer-midipiano" - else global_var.MidiVocabConfig.Default, + ( + global_var.MidiVocabConfig.Piano + if tokenizer_name == "tokenizer-midipiano" + else global_var.MidiVocabConfig.Default + ), ) rwkv: AbstractRWKV if tokenizer_name in rwkv_map: