diff --git a/backend-python/routes/midi.py b/backend-python/routes/midi.py index 0d1280a..554751b 100644 --- a/backend-python/routes/midi.py +++ b/backend-python/routes/midi.py @@ -37,10 +37,14 @@ def text_to_midi(body: TextToMidiBody): async def midi_to_text(file_data: UploadFile): vocab_config = "backend-python/utils/midi_vocab_config.json" cfg = VocabConfig.from_json(vocab_config) + filter_config = "backend-python/utils/midi_filter_config.json" + filter_cfg = FilterConfig.from_json(filter_config) mid = mido.MidiFile(file=file_data.file) - text = convert_midi_to_str(cfg, mid) + output_list = convert_midi_to_str(cfg, filter_cfg, mid) + if len(output_list) == 0: + raise HTTPException(status.HTTP_400_BAD_REQUEST, "bad midi file") - return {"text": text} + return {"text": output_list[0]} class TxtToMidiBody(BaseModel): diff --git a/backend-python/utils/midi.py b/backend-python/utils/midi.py index 6993de6..ffee6ed 100644 --- a/backend-python/utils/midi.py +++ b/backend-python/utils/midi.py @@ -52,6 +52,8 @@ class VocabConfig: bin_name_to_program_name: Dict[str, str] # Mapping from program number to instrument name. instrument_names: Dict[str, str] + # Manual override for velocity bins. Each element is the max velocity value for that bin by index. + velocity_bins_override: Optional[List[int]] = None def __post_init__(self): self.validate() @@ -116,6 +118,12 @@ class VocabConfig: raise ValueError("velocity_bins must be at least 2") if len(self.bin_instrument_names) > 16: raise ValueError("bin_instruments must have at most 16 values") + if self.velocity_bins_override: + print("VocabConfig is using velocity_bins_override. Ignoring velocity_exp.") + if len(self.velocity_bins_override) != self.velocity_bins: + raise ValueError( + "velocity_bins_override must have same length as velocity_bins" + ) if ( self.ch10_instrument_bin_name and self.ch10_instrument_bin_name not in self.bin_instrument_names @@ -156,6 +164,11 @@ class VocabUtils: def velocity_to_bin(self, velocity: float) -> int: velocity = max(0, min(velocity, self.cfg.velocity_events - 1)) + if self.cfg.velocity_bins_override: + for i, v in enumerate(self.cfg.velocity_bins_override): + if velocity <= v: + return i + return 0 binsize = self.cfg.velocity_events / (self.cfg.velocity_bins - 1) if self.cfg.velocity_exp == 1.0: return ceil(velocity / binsize) @@ -176,6 +189,8 @@ class VocabUtils: ) def bin_to_velocity(self, bin: int) -> int: + if self.cfg.velocity_bins_override: + return self.cfg.velocity_bins_override[bin] binsize = self.cfg.velocity_events / (self.cfg.velocity_bins - 1) if self.cfg.velocity_exp == 1.0: return max(0, ceil(bin * binsize - 1)) @@ -358,13 +373,32 @@ class AugmentConfig: ) +@dataclass +class FilterConfig: + # Whether to filter out MIDI files with duplicate MD5 hashes. + deduplicate_md5: bool + # Minimum time delay between notes in a file before splitting into multiple documents. + piece_split_delay: float + # Minimum length of a piece in milliseconds. + min_piece_length: float + + @classmethod + def from_json(cls, path: str): + with open(path, "r") as f: + config = json.load(f) + return cls(**config) + + def mix_volume(velocity: int, volume: int, expression: int) -> float: return velocity * (volume / 127.0) * (expression / 127.0) def convert_midi_to_str( - cfg: VocabConfig, mid: mido.MidiFile, augment: AugmentValues = None -) -> str: + cfg: VocabConfig, + filter_cfg: FilterConfig, + mid: mido.MidiFile, + augment: AugmentValues = None, +) -> List[str]: utils = VocabUtils(cfg) if augment is None: augment = AugmentValues.default() @@ -390,7 +424,9 @@ def convert_midi_to_str( } # {channel: {(note, program) -> True}} started_flag = False + output_list = [] output = [""] + output_length_ms = 0.0 token_data_buffer: List[ Tuple[int, int, int, float] ] = [] # need to sort notes between wait tokens @@ -432,16 +468,33 @@ def convert_midi_to_str( token_data_buffer = [] def consume_note_program_data(prog: int, chan: int, note: int, vel: float): - nonlocal output, started_flag, delta_time_ms, cfg, utils, token_data_buffer + nonlocal output, output_length_ms, started_flag, delta_time_ms, cfg, utils, token_data_buffer is_token_valid = ( utils.prog_data_to_token_data(prog, chan, note, vel) is not None ) if not is_token_valid: return + + if delta_time_ms > filter_cfg.piece_split_delay * 1000.0: + # check if any notes are still held + silent = True + for channel in channel_notes.keys(): + if len(channel_notes[channel]) > 0: + silent = False + break + if silent: + flush_token_data_buffer() + output.append("") + if output_length_ms > filter_cfg.min_piece_length * 1000.0: + output_list.append(" ".join(output)) + output = [""] + output_length_ms = 0.0 + started_flag = False if started_flag: wait_tokens = utils.data_to_wait_tokens(delta_time_ms) if len(wait_tokens) > 0: flush_token_data_buffer() + output_length_ms += delta_time_ms output += wait_tokens delta_time_ms = 0.0 token_data_buffer.append((prog, chan, note, vel * augment.velocity_mod_factor)) @@ -510,7 +563,9 @@ def convert_midi_to_str( flush_token_data_buffer() output.append("") - return " ".join(output) + if output_length_ms > filter_cfg.min_piece_length * 1000.0: + output_list.append(" ".join(output)) + return output_list def generate_program_change_messages(cfg: VocabConfig): @@ -633,10 +688,10 @@ def token_to_midi_message( if utils.cfg.decode_fix_repeated_notes: if (channel, note) in state.active_notes: del state.active_notes[(channel, note)] - yield mido.Message( - "note_off", note=note, time=ticks, channel=channel - ), state - ticks = 0 + yield mido.Message( + "note_off", note=note, time=ticks, channel=channel + ), state + ticks = 0 state.active_notes[(channel, note)] = state.total_time yield mido.Message( "note_on", note=note, velocity=velocity, time=ticks, channel=channel diff --git a/backend-python/utils/midi_filter_config.json b/backend-python/utils/midi_filter_config.json new file mode 100644 index 0000000..7763cb0 --- /dev/null +++ b/backend-python/utils/midi_filter_config.json @@ -0,0 +1,5 @@ +{ + "deduplicate_md5": true, + "piece_split_delay": 10.0, + "min_piece_length": 30.0 +} \ No newline at end of file