bump MIDI-LLM-tokenizer (fix note off)

This commit is contained in:
josc146 2023-12-14 13:33:27 +08:00
parent f328e84ea7
commit e0bf44d82f
3 changed files with 74 additions and 10 deletions

View File

@ -37,10 +37,14 @@ def text_to_midi(body: TextToMidiBody):
async def midi_to_text(file_data: UploadFile): async def midi_to_text(file_data: UploadFile):
vocab_config = "backend-python/utils/midi_vocab_config.json" vocab_config = "backend-python/utils/midi_vocab_config.json"
cfg = VocabConfig.from_json(vocab_config) cfg = VocabConfig.from_json(vocab_config)
filter_config = "backend-python/utils/midi_filter_config.json"
filter_cfg = FilterConfig.from_json(filter_config)
mid = mido.MidiFile(file=file_data.file) mid = mido.MidiFile(file=file_data.file)
text = convert_midi_to_str(cfg, mid) output_list = convert_midi_to_str(cfg, filter_cfg, mid)
if len(output_list) == 0:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "bad midi file")
return {"text": text} return {"text": output_list[0]}
class TxtToMidiBody(BaseModel): class TxtToMidiBody(BaseModel):

View File

@ -52,6 +52,8 @@ class VocabConfig:
bin_name_to_program_name: Dict[str, str] bin_name_to_program_name: Dict[str, str]
# Mapping from program number to instrument name. # Mapping from program number to instrument name.
instrument_names: Dict[str, str] instrument_names: Dict[str, str]
# Manual override for velocity bins. Each element is the max velocity value for that bin by index.
velocity_bins_override: Optional[List[int]] = None
def __post_init__(self): def __post_init__(self):
self.validate() self.validate()
@ -116,6 +118,12 @@ class VocabConfig:
raise ValueError("velocity_bins must be at least 2") raise ValueError("velocity_bins must be at least 2")
if len(self.bin_instrument_names) > 16: if len(self.bin_instrument_names) > 16:
raise ValueError("bin_instruments must have at most 16 values") raise ValueError("bin_instruments must have at most 16 values")
if self.velocity_bins_override:
print("VocabConfig is using velocity_bins_override. Ignoring velocity_exp.")
if len(self.velocity_bins_override) != self.velocity_bins:
raise ValueError(
"velocity_bins_override must have same length as velocity_bins"
)
if ( if (
self.ch10_instrument_bin_name self.ch10_instrument_bin_name
and self.ch10_instrument_bin_name not in self.bin_instrument_names and self.ch10_instrument_bin_name not in self.bin_instrument_names
@ -156,6 +164,11 @@ class VocabUtils:
def velocity_to_bin(self, velocity: float) -> int: def velocity_to_bin(self, velocity: float) -> int:
velocity = max(0, min(velocity, self.cfg.velocity_events - 1)) velocity = max(0, min(velocity, self.cfg.velocity_events - 1))
if self.cfg.velocity_bins_override:
for i, v in enumerate(self.cfg.velocity_bins_override):
if velocity <= v:
return i
return 0
binsize = self.cfg.velocity_events / (self.cfg.velocity_bins - 1) binsize = self.cfg.velocity_events / (self.cfg.velocity_bins - 1)
if self.cfg.velocity_exp == 1.0: if self.cfg.velocity_exp == 1.0:
return ceil(velocity / binsize) return ceil(velocity / binsize)
@ -176,6 +189,8 @@ class VocabUtils:
) )
def bin_to_velocity(self, bin: int) -> int: def bin_to_velocity(self, bin: int) -> int:
if self.cfg.velocity_bins_override:
return self.cfg.velocity_bins_override[bin]
binsize = self.cfg.velocity_events / (self.cfg.velocity_bins - 1) binsize = self.cfg.velocity_events / (self.cfg.velocity_bins - 1)
if self.cfg.velocity_exp == 1.0: if self.cfg.velocity_exp == 1.0:
return max(0, ceil(bin * binsize - 1)) return max(0, ceil(bin * binsize - 1))
@ -358,13 +373,32 @@ class AugmentConfig:
) )
@dataclass
class FilterConfig:
# Whether to filter out MIDI files with duplicate MD5 hashes.
deduplicate_md5: bool
# Minimum time delay between notes in a file before splitting into multiple documents.
piece_split_delay: float
# Minimum length of a piece in milliseconds.
min_piece_length: float
@classmethod
def from_json(cls, path: str):
with open(path, "r") as f:
config = json.load(f)
return cls(**config)
def mix_volume(velocity: int, volume: int, expression: int) -> float: def mix_volume(velocity: int, volume: int, expression: int) -> float:
return velocity * (volume / 127.0) * (expression / 127.0) return velocity * (volume / 127.0) * (expression / 127.0)
def convert_midi_to_str( def convert_midi_to_str(
cfg: VocabConfig, mid: mido.MidiFile, augment: AugmentValues = None cfg: VocabConfig,
) -> str: filter_cfg: FilterConfig,
mid: mido.MidiFile,
augment: AugmentValues = None,
) -> List[str]:
utils = VocabUtils(cfg) utils = VocabUtils(cfg)
if augment is None: if augment is None:
augment = AugmentValues.default() augment = AugmentValues.default()
@ -390,7 +424,9 @@ def convert_midi_to_str(
} # {channel: {(note, program) -> True}} } # {channel: {(note, program) -> True}}
started_flag = False started_flag = False
output_list = []
output = ["<start>"] output = ["<start>"]
output_length_ms = 0.0
token_data_buffer: List[ token_data_buffer: List[
Tuple[int, int, int, float] Tuple[int, int, int, float]
] = [] # need to sort notes between wait tokens ] = [] # need to sort notes between wait tokens
@ -432,16 +468,33 @@ def convert_midi_to_str(
token_data_buffer = [] token_data_buffer = []
def consume_note_program_data(prog: int, chan: int, note: int, vel: float): def consume_note_program_data(prog: int, chan: int, note: int, vel: float):
nonlocal output, started_flag, delta_time_ms, cfg, utils, token_data_buffer nonlocal output, output_length_ms, started_flag, delta_time_ms, cfg, utils, token_data_buffer
is_token_valid = ( is_token_valid = (
utils.prog_data_to_token_data(prog, chan, note, vel) is not None utils.prog_data_to_token_data(prog, chan, note, vel) is not None
) )
if not is_token_valid: if not is_token_valid:
return return
if delta_time_ms > filter_cfg.piece_split_delay * 1000.0:
# check if any notes are still held
silent = True
for channel in channel_notes.keys():
if len(channel_notes[channel]) > 0:
silent = False
break
if silent:
flush_token_data_buffer()
output.append("<end>")
if output_length_ms > filter_cfg.min_piece_length * 1000.0:
output_list.append(" ".join(output))
output = ["<start>"]
output_length_ms = 0.0
started_flag = False
if started_flag: if started_flag:
wait_tokens = utils.data_to_wait_tokens(delta_time_ms) wait_tokens = utils.data_to_wait_tokens(delta_time_ms)
if len(wait_tokens) > 0: if len(wait_tokens) > 0:
flush_token_data_buffer() flush_token_data_buffer()
output_length_ms += delta_time_ms
output += wait_tokens output += wait_tokens
delta_time_ms = 0.0 delta_time_ms = 0.0
token_data_buffer.append((prog, chan, note, vel * augment.velocity_mod_factor)) token_data_buffer.append((prog, chan, note, vel * augment.velocity_mod_factor))
@ -510,7 +563,9 @@ def convert_midi_to_str(
flush_token_data_buffer() flush_token_data_buffer()
output.append("<end>") output.append("<end>")
return " ".join(output) if output_length_ms > filter_cfg.min_piece_length * 1000.0:
output_list.append(" ".join(output))
return output_list
def generate_program_change_messages(cfg: VocabConfig): def generate_program_change_messages(cfg: VocabConfig):

View File

@ -0,0 +1,5 @@
{
"deduplicate_md5": true,
"piece_split_delay": 10.0,
"min_piece_length": 30.0
}