add AVOID_PENALTY_TOKENS
This commit is contained in:
parent
78238c24cf
commit
19eeeab1e1
@ -378,6 +378,12 @@ class TextRWKV(AbstractRWKV):
|
|||||||
dd = self.pipeline.encode(i)
|
dd = self.pipeline.encode(i)
|
||||||
assert len(dd) == 1
|
assert len(dd) == 1
|
||||||
self.AVOID_REPEAT_TOKENS += dd
|
self.AVOID_REPEAT_TOKENS += dd
|
||||||
|
self.AVOID_PENALTY_TOKENS = []
|
||||||
|
AVOID_PENALTY = "\n,.:,。:<>[]{}()/\\|;;" # \n,.:?!,。:?!"“”<>[]{}/\\|;;~`@#$%^&*()_+-=0123456789
|
||||||
|
for i in AVOID_PENALTY:
|
||||||
|
dd = self.pipeline.encode(i)
|
||||||
|
assert len(dd) == 1
|
||||||
|
self.AVOID_PENALTY_TOKENS += dd
|
||||||
|
|
||||||
self.__preload()
|
self.__preload()
|
||||||
|
|
||||||
@ -391,10 +397,11 @@ class TextRWKV(AbstractRWKV):
|
|||||||
|
|
||||||
def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int):
|
def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int):
|
||||||
for n in occurrence:
|
for n in occurrence:
|
||||||
logits[n] -= (
|
if n not in self.AVOID_PENALTY_TOKENS:
|
||||||
self.penalty_alpha_presence
|
logits[n] -= (
|
||||||
+ occurrence[n] * self.penalty_alpha_frequency
|
self.penalty_alpha_presence
|
||||||
)
|
+ occurrence[n] * self.penalty_alpha_frequency
|
||||||
|
)
|
||||||
|
|
||||||
if i == 0:
|
if i == 0:
|
||||||
for token in self.model_tokens:
|
for token in self.model_tokens:
|
||||||
@ -639,9 +646,11 @@ def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV
|
|||||||
tokenizer_name = os.path.splitext(os.path.basename(tokenizer))[0]
|
tokenizer_name = os.path.splitext(os.path.basename(tokenizer))[0]
|
||||||
global_var.set(
|
global_var.set(
|
||||||
global_var.Midi_Vocab_Config_Type,
|
global_var.Midi_Vocab_Config_Type,
|
||||||
global_var.MidiVocabConfig.Piano
|
(
|
||||||
if tokenizer_name == "tokenizer-midipiano"
|
global_var.MidiVocabConfig.Piano
|
||||||
else global_var.MidiVocabConfig.Default,
|
if tokenizer_name == "tokenizer-midipiano"
|
||||||
|
else global_var.MidiVocabConfig.Default
|
||||||
|
),
|
||||||
)
|
)
|
||||||
rwkv: AbstractRWKV
|
rwkv: AbstractRWKV
|
||||||
if tokenizer_name in rwkv_map:
|
if tokenizer_name in rwkv_map:
|
||||||
|
Loading…
Reference in New Issue
Block a user