bump MIDI-LLM-tokenizer (fix note off)
This commit is contained in:
parent
f328e84ea7
commit
e0bf44d82f
@ -37,10 +37,14 @@ def text_to_midi(body: TextToMidiBody):
|
|||||||
async def midi_to_text(file_data: UploadFile):
|
async def midi_to_text(file_data: UploadFile):
|
||||||
vocab_config = "backend-python/utils/midi_vocab_config.json"
|
vocab_config = "backend-python/utils/midi_vocab_config.json"
|
||||||
cfg = VocabConfig.from_json(vocab_config)
|
cfg = VocabConfig.from_json(vocab_config)
|
||||||
|
filter_config = "backend-python/utils/midi_filter_config.json"
|
||||||
|
filter_cfg = FilterConfig.from_json(filter_config)
|
||||||
mid = mido.MidiFile(file=file_data.file)
|
mid = mido.MidiFile(file=file_data.file)
|
||||||
text = convert_midi_to_str(cfg, mid)
|
output_list = convert_midi_to_str(cfg, filter_cfg, mid)
|
||||||
|
if len(output_list) == 0:
|
||||||
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "bad midi file")
|
||||||
|
|
||||||
return {"text": text}
|
return {"text": output_list[0]}
|
||||||
|
|
||||||
|
|
||||||
class TxtToMidiBody(BaseModel):
|
class TxtToMidiBody(BaseModel):
|
||||||
|
63
backend-python/utils/midi.py
vendored
63
backend-python/utils/midi.py
vendored
@ -52,6 +52,8 @@ class VocabConfig:
|
|||||||
bin_name_to_program_name: Dict[str, str]
|
bin_name_to_program_name: Dict[str, str]
|
||||||
# Mapping from program number to instrument name.
|
# Mapping from program number to instrument name.
|
||||||
instrument_names: Dict[str, str]
|
instrument_names: Dict[str, str]
|
||||||
|
# Manual override for velocity bins. Each element is the max velocity value for that bin by index.
|
||||||
|
velocity_bins_override: Optional[List[int]] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.validate()
|
self.validate()
|
||||||
@ -116,6 +118,12 @@ class VocabConfig:
|
|||||||
raise ValueError("velocity_bins must be at least 2")
|
raise ValueError("velocity_bins must be at least 2")
|
||||||
if len(self.bin_instrument_names) > 16:
|
if len(self.bin_instrument_names) > 16:
|
||||||
raise ValueError("bin_instruments must have at most 16 values")
|
raise ValueError("bin_instruments must have at most 16 values")
|
||||||
|
if self.velocity_bins_override:
|
||||||
|
print("VocabConfig is using velocity_bins_override. Ignoring velocity_exp.")
|
||||||
|
if len(self.velocity_bins_override) != self.velocity_bins:
|
||||||
|
raise ValueError(
|
||||||
|
"velocity_bins_override must have same length as velocity_bins"
|
||||||
|
)
|
||||||
if (
|
if (
|
||||||
self.ch10_instrument_bin_name
|
self.ch10_instrument_bin_name
|
||||||
and self.ch10_instrument_bin_name not in self.bin_instrument_names
|
and self.ch10_instrument_bin_name not in self.bin_instrument_names
|
||||||
@ -156,6 +164,11 @@ class VocabUtils:
|
|||||||
|
|
||||||
def velocity_to_bin(self, velocity: float) -> int:
|
def velocity_to_bin(self, velocity: float) -> int:
|
||||||
velocity = max(0, min(velocity, self.cfg.velocity_events - 1))
|
velocity = max(0, min(velocity, self.cfg.velocity_events - 1))
|
||||||
|
if self.cfg.velocity_bins_override:
|
||||||
|
for i, v in enumerate(self.cfg.velocity_bins_override):
|
||||||
|
if velocity <= v:
|
||||||
|
return i
|
||||||
|
return 0
|
||||||
binsize = self.cfg.velocity_events / (self.cfg.velocity_bins - 1)
|
binsize = self.cfg.velocity_events / (self.cfg.velocity_bins - 1)
|
||||||
if self.cfg.velocity_exp == 1.0:
|
if self.cfg.velocity_exp == 1.0:
|
||||||
return ceil(velocity / binsize)
|
return ceil(velocity / binsize)
|
||||||
@ -176,6 +189,8 @@ class VocabUtils:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def bin_to_velocity(self, bin: int) -> int:
|
def bin_to_velocity(self, bin: int) -> int:
|
||||||
|
if self.cfg.velocity_bins_override:
|
||||||
|
return self.cfg.velocity_bins_override[bin]
|
||||||
binsize = self.cfg.velocity_events / (self.cfg.velocity_bins - 1)
|
binsize = self.cfg.velocity_events / (self.cfg.velocity_bins - 1)
|
||||||
if self.cfg.velocity_exp == 1.0:
|
if self.cfg.velocity_exp == 1.0:
|
||||||
return max(0, ceil(bin * binsize - 1))
|
return max(0, ceil(bin * binsize - 1))
|
||||||
@ -358,13 +373,32 @@ class AugmentConfig:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FilterConfig:
|
||||||
|
# Whether to filter out MIDI files with duplicate MD5 hashes.
|
||||||
|
deduplicate_md5: bool
|
||||||
|
# Minimum time delay between notes in a file before splitting into multiple documents.
|
||||||
|
piece_split_delay: float
|
||||||
|
# Minimum length of a piece in milliseconds.
|
||||||
|
min_piece_length: float
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json(cls, path: str):
|
||||||
|
with open(path, "r") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
return cls(**config)
|
||||||
|
|
||||||
|
|
||||||
def mix_volume(velocity: int, volume: int, expression: int) -> float:
|
def mix_volume(velocity: int, volume: int, expression: int) -> float:
|
||||||
return velocity * (volume / 127.0) * (expression / 127.0)
|
return velocity * (volume / 127.0) * (expression / 127.0)
|
||||||
|
|
||||||
|
|
||||||
def convert_midi_to_str(
|
def convert_midi_to_str(
|
||||||
cfg: VocabConfig, mid: mido.MidiFile, augment: AugmentValues = None
|
cfg: VocabConfig,
|
||||||
) -> str:
|
filter_cfg: FilterConfig,
|
||||||
|
mid: mido.MidiFile,
|
||||||
|
augment: AugmentValues = None,
|
||||||
|
) -> List[str]:
|
||||||
utils = VocabUtils(cfg)
|
utils = VocabUtils(cfg)
|
||||||
if augment is None:
|
if augment is None:
|
||||||
augment = AugmentValues.default()
|
augment = AugmentValues.default()
|
||||||
@ -390,7 +424,9 @@ def convert_midi_to_str(
|
|||||||
} # {channel: {(note, program) -> True}}
|
} # {channel: {(note, program) -> True}}
|
||||||
started_flag = False
|
started_flag = False
|
||||||
|
|
||||||
|
output_list = []
|
||||||
output = ["<start>"]
|
output = ["<start>"]
|
||||||
|
output_length_ms = 0.0
|
||||||
token_data_buffer: List[
|
token_data_buffer: List[
|
||||||
Tuple[int, int, int, float]
|
Tuple[int, int, int, float]
|
||||||
] = [] # need to sort notes between wait tokens
|
] = [] # need to sort notes between wait tokens
|
||||||
@ -432,16 +468,33 @@ def convert_midi_to_str(
|
|||||||
token_data_buffer = []
|
token_data_buffer = []
|
||||||
|
|
||||||
def consume_note_program_data(prog: int, chan: int, note: int, vel: float):
|
def consume_note_program_data(prog: int, chan: int, note: int, vel: float):
|
||||||
nonlocal output, started_flag, delta_time_ms, cfg, utils, token_data_buffer
|
nonlocal output, output_length_ms, started_flag, delta_time_ms, cfg, utils, token_data_buffer
|
||||||
is_token_valid = (
|
is_token_valid = (
|
||||||
utils.prog_data_to_token_data(prog, chan, note, vel) is not None
|
utils.prog_data_to_token_data(prog, chan, note, vel) is not None
|
||||||
)
|
)
|
||||||
if not is_token_valid:
|
if not is_token_valid:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if delta_time_ms > filter_cfg.piece_split_delay * 1000.0:
|
||||||
|
# check if any notes are still held
|
||||||
|
silent = True
|
||||||
|
for channel in channel_notes.keys():
|
||||||
|
if len(channel_notes[channel]) > 0:
|
||||||
|
silent = False
|
||||||
|
break
|
||||||
|
if silent:
|
||||||
|
flush_token_data_buffer()
|
||||||
|
output.append("<end>")
|
||||||
|
if output_length_ms > filter_cfg.min_piece_length * 1000.0:
|
||||||
|
output_list.append(" ".join(output))
|
||||||
|
output = ["<start>"]
|
||||||
|
output_length_ms = 0.0
|
||||||
|
started_flag = False
|
||||||
if started_flag:
|
if started_flag:
|
||||||
wait_tokens = utils.data_to_wait_tokens(delta_time_ms)
|
wait_tokens = utils.data_to_wait_tokens(delta_time_ms)
|
||||||
if len(wait_tokens) > 0:
|
if len(wait_tokens) > 0:
|
||||||
flush_token_data_buffer()
|
flush_token_data_buffer()
|
||||||
|
output_length_ms += delta_time_ms
|
||||||
output += wait_tokens
|
output += wait_tokens
|
||||||
delta_time_ms = 0.0
|
delta_time_ms = 0.0
|
||||||
token_data_buffer.append((prog, chan, note, vel * augment.velocity_mod_factor))
|
token_data_buffer.append((prog, chan, note, vel * augment.velocity_mod_factor))
|
||||||
@ -510,7 +563,9 @@ def convert_midi_to_str(
|
|||||||
|
|
||||||
flush_token_data_buffer()
|
flush_token_data_buffer()
|
||||||
output.append("<end>")
|
output.append("<end>")
|
||||||
return " ".join(output)
|
if output_length_ms > filter_cfg.min_piece_length * 1000.0:
|
||||||
|
output_list.append(" ".join(output))
|
||||||
|
return output_list
|
||||||
|
|
||||||
|
|
||||||
def generate_program_change_messages(cfg: VocabConfig):
|
def generate_program_change_messages(cfg: VocabConfig):
|
||||||
|
5
backend-python/utils/midi_filter_config.json
Normal file
5
backend-python/utils/midi_filter_config.json
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
{
|
||||||
|
"deduplicate_md5": true,
|
||||||
|
"piece_split_delay": 10.0,
|
||||||
|
"min_piece_length": 30.0
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user