bump MIDI-LLM-tokenizer (fix note off)
This commit is contained in:
parent
f328e84ea7
commit
e0bf44d82f
@ -37,10 +37,14 @@ def text_to_midi(body: TextToMidiBody):
|
||||
async def midi_to_text(file_data: UploadFile):
|
||||
vocab_config = "backend-python/utils/midi_vocab_config.json"
|
||||
cfg = VocabConfig.from_json(vocab_config)
|
||||
filter_config = "backend-python/utils/midi_filter_config.json"
|
||||
filter_cfg = FilterConfig.from_json(filter_config)
|
||||
mid = mido.MidiFile(file=file_data.file)
|
||||
text = convert_midi_to_str(cfg, mid)
|
||||
output_list = convert_midi_to_str(cfg, filter_cfg, mid)
|
||||
if len(output_list) == 0:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "bad midi file")
|
||||
|
||||
return {"text": text}
|
||||
return {"text": output_list[0]}
|
||||
|
||||
|
||||
class TxtToMidiBody(BaseModel):
|
||||
|
71
backend-python/utils/midi.py
vendored
71
backend-python/utils/midi.py
vendored
@ -52,6 +52,8 @@ class VocabConfig:
|
||||
bin_name_to_program_name: Dict[str, str]
|
||||
# Mapping from program number to instrument name.
|
||||
instrument_names: Dict[str, str]
|
||||
# Manual override for velocity bins. Each element is the max velocity value for that bin by index.
|
||||
velocity_bins_override: Optional[List[int]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
self.validate()
|
||||
@ -116,6 +118,12 @@ class VocabConfig:
|
||||
raise ValueError("velocity_bins must be at least 2")
|
||||
if len(self.bin_instrument_names) > 16:
|
||||
raise ValueError("bin_instruments must have at most 16 values")
|
||||
if self.velocity_bins_override:
|
||||
print("VocabConfig is using velocity_bins_override. Ignoring velocity_exp.")
|
||||
if len(self.velocity_bins_override) != self.velocity_bins:
|
||||
raise ValueError(
|
||||
"velocity_bins_override must have same length as velocity_bins"
|
||||
)
|
||||
if (
|
||||
self.ch10_instrument_bin_name
|
||||
and self.ch10_instrument_bin_name not in self.bin_instrument_names
|
||||
@ -156,6 +164,11 @@ class VocabUtils:
|
||||
|
||||
def velocity_to_bin(self, velocity: float) -> int:
|
||||
velocity = max(0, min(velocity, self.cfg.velocity_events - 1))
|
||||
if self.cfg.velocity_bins_override:
|
||||
for i, v in enumerate(self.cfg.velocity_bins_override):
|
||||
if velocity <= v:
|
||||
return i
|
||||
return 0
|
||||
binsize = self.cfg.velocity_events / (self.cfg.velocity_bins - 1)
|
||||
if self.cfg.velocity_exp == 1.0:
|
||||
return ceil(velocity / binsize)
|
||||
@ -176,6 +189,8 @@ class VocabUtils:
|
||||
)
|
||||
|
||||
def bin_to_velocity(self, bin: int) -> int:
|
||||
if self.cfg.velocity_bins_override:
|
||||
return self.cfg.velocity_bins_override[bin]
|
||||
binsize = self.cfg.velocity_events / (self.cfg.velocity_bins - 1)
|
||||
if self.cfg.velocity_exp == 1.0:
|
||||
return max(0, ceil(bin * binsize - 1))
|
||||
@ -358,13 +373,32 @@ class AugmentConfig:
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FilterConfig:
|
||||
# Whether to filter out MIDI files with duplicate MD5 hashes.
|
||||
deduplicate_md5: bool
|
||||
# Minimum time delay between notes in a file before splitting into multiple documents.
|
||||
piece_split_delay: float
|
||||
# Minimum length of a piece in milliseconds.
|
||||
min_piece_length: float
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, path: str):
|
||||
with open(path, "r") as f:
|
||||
config = json.load(f)
|
||||
return cls(**config)
|
||||
|
||||
|
||||
def mix_volume(velocity: int, volume: int, expression: int) -> float:
|
||||
return velocity * (volume / 127.0) * (expression / 127.0)
|
||||
|
||||
|
||||
def convert_midi_to_str(
|
||||
cfg: VocabConfig, mid: mido.MidiFile, augment: AugmentValues = None
|
||||
) -> str:
|
||||
cfg: VocabConfig,
|
||||
filter_cfg: FilterConfig,
|
||||
mid: mido.MidiFile,
|
||||
augment: AugmentValues = None,
|
||||
) -> List[str]:
|
||||
utils = VocabUtils(cfg)
|
||||
if augment is None:
|
||||
augment = AugmentValues.default()
|
||||
@ -390,7 +424,9 @@ def convert_midi_to_str(
|
||||
} # {channel: {(note, program) -> True}}
|
||||
started_flag = False
|
||||
|
||||
output_list = []
|
||||
output = ["<start>"]
|
||||
output_length_ms = 0.0
|
||||
token_data_buffer: List[
|
||||
Tuple[int, int, int, float]
|
||||
] = [] # need to sort notes between wait tokens
|
||||
@ -432,16 +468,33 @@ def convert_midi_to_str(
|
||||
token_data_buffer = []
|
||||
|
||||
def consume_note_program_data(prog: int, chan: int, note: int, vel: float):
|
||||
nonlocal output, started_flag, delta_time_ms, cfg, utils, token_data_buffer
|
||||
nonlocal output, output_length_ms, started_flag, delta_time_ms, cfg, utils, token_data_buffer
|
||||
is_token_valid = (
|
||||
utils.prog_data_to_token_data(prog, chan, note, vel) is not None
|
||||
)
|
||||
if not is_token_valid:
|
||||
return
|
||||
|
||||
if delta_time_ms > filter_cfg.piece_split_delay * 1000.0:
|
||||
# check if any notes are still held
|
||||
silent = True
|
||||
for channel in channel_notes.keys():
|
||||
if len(channel_notes[channel]) > 0:
|
||||
silent = False
|
||||
break
|
||||
if silent:
|
||||
flush_token_data_buffer()
|
||||
output.append("<end>")
|
||||
if output_length_ms > filter_cfg.min_piece_length * 1000.0:
|
||||
output_list.append(" ".join(output))
|
||||
output = ["<start>"]
|
||||
output_length_ms = 0.0
|
||||
started_flag = False
|
||||
if started_flag:
|
||||
wait_tokens = utils.data_to_wait_tokens(delta_time_ms)
|
||||
if len(wait_tokens) > 0:
|
||||
flush_token_data_buffer()
|
||||
output_length_ms += delta_time_ms
|
||||
output += wait_tokens
|
||||
delta_time_ms = 0.0
|
||||
token_data_buffer.append((prog, chan, note, vel * augment.velocity_mod_factor))
|
||||
@ -510,7 +563,9 @@ def convert_midi_to_str(
|
||||
|
||||
flush_token_data_buffer()
|
||||
output.append("<end>")
|
||||
return " ".join(output)
|
||||
if output_length_ms > filter_cfg.min_piece_length * 1000.0:
|
||||
output_list.append(" ".join(output))
|
||||
return output_list
|
||||
|
||||
|
||||
def generate_program_change_messages(cfg: VocabConfig):
|
||||
@ -633,10 +688,10 @@ def token_to_midi_message(
|
||||
if utils.cfg.decode_fix_repeated_notes:
|
||||
if (channel, note) in state.active_notes:
|
||||
del state.active_notes[(channel, note)]
|
||||
yield mido.Message(
|
||||
"note_off", note=note, time=ticks, channel=channel
|
||||
), state
|
||||
ticks = 0
|
||||
yield mido.Message(
|
||||
"note_off", note=note, time=ticks, channel=channel
|
||||
), state
|
||||
ticks = 0
|
||||
state.active_notes[(channel, note)] = state.total_time
|
||||
yield mido.Message(
|
||||
"note_on", note=note, velocity=velocity, time=ticks, channel=channel
|
||||
|
5
backend-python/utils/midi_filter_config.json
Normal file
5
backend-python/utils/midi_filter_config.json
Normal file
@ -0,0 +1,5 @@
|
||||
{
|
||||
"deduplicate_md5": true,
|
||||
"piece_split_delay": 10.0,
|
||||
"min_piece_length": 30.0
|
||||
}
|
Loading…
Reference in New Issue
Block a user