add better custom tokenizer support and tokenizer-midipiano.json

This commit is contained in:
josc146
2024-02-03 13:04:13 +08:00
parent 947e127e34
commit 5f94d86558
4 changed files with 318 additions and 5 deletions

View File

@@ -546,8 +546,10 @@ class MusicAbcRWKV(AbstractRWKV):
def get_tokenizer(tokenizer_len: int):
tokenizer_dir = f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/"
if tokenizer_len < 20096:
if tokenizer_len < 2176:
return "abc_tokenizer"
if tokenizer_len < 20096:
return tokenizer_dir + "tokenizer-midipiano.json"
if tokenizer_len < 50277:
return tokenizer_dir + "tokenizer-midi.json"
elif tokenizer_len < 65536:
@@ -630,14 +632,27 @@ def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV
"20B_tokenizer": TextRWKV,
"rwkv_vocab_v20230424": TextRWKV,
"tokenizer-midi": MusicMidiRWKV,
"tokenizer-midipiano": MusicMidiRWKV,
"abc_tokenizer": MusicAbcRWKV,
}
tokenizer_name = os.path.splitext(os.path.basename(tokenizer))[0]
global_var.set(
global_var.Midi_Vocab_Config_Type,
global_var.MidiVocabConfig.Piano
if tokenizer_name == "tokenizer-midipiano"
else global_var.MidiVocabConfig.Default,
)
rwkv: AbstractRWKV
if tokenizer_name in rwkv_map:
rwkv = rwkv_map[tokenizer_name](model, pipeline)
else:
rwkv = TextRWKV(model, pipeline)
tokenizer_name = tokenizer_name.lower()
if "music" in tokenizer_name or "midi" in tokenizer_name:
rwkv = MusicMidiRWKV(model, pipeline)
elif "abc" in tokenizer_name:
rwkv = MusicAbcRWKV(model, pipeline)
else:
rwkv = TextRWKV(model, pipeline)
rwkv.name = filename
return rwkv