bump MIDI-LLM-tokenizer (fix note off)
This commit is contained in:
		
							parent
							
								
									f328e84ea7
								
							
						
					
					
						commit
						e0bf44d82f
					
				@ -37,10 +37,14 @@ def text_to_midi(body: TextToMidiBody):
 | 
			
		||||
async def midi_to_text(file_data: UploadFile):
 | 
			
		||||
    vocab_config = "backend-python/utils/midi_vocab_config.json"
 | 
			
		||||
    cfg = VocabConfig.from_json(vocab_config)
 | 
			
		||||
    filter_config = "backend-python/utils/midi_filter_config.json"
 | 
			
		||||
    filter_cfg = FilterConfig.from_json(filter_config)
 | 
			
		||||
    mid = mido.MidiFile(file=file_data.file)
 | 
			
		||||
    text = convert_midi_to_str(cfg, mid)
 | 
			
		||||
    output_list = convert_midi_to_str(cfg, filter_cfg, mid)
 | 
			
		||||
    if len(output_list) == 0:
 | 
			
		||||
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "bad midi file")
 | 
			
		||||
 | 
			
		||||
    return {"text": text}
 | 
			
		||||
    return {"text": output_list[0]}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TxtToMidiBody(BaseModel):
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										71
									
								
								backend-python/utils/midi.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										71
									
								
								backend-python/utils/midi.py
									
									
									
									
										vendored
									
									
								
							@ -52,6 +52,8 @@ class VocabConfig:
 | 
			
		||||
    bin_name_to_program_name: Dict[str, str]
 | 
			
		||||
    # Mapping from program number to instrument name.
 | 
			
		||||
    instrument_names: Dict[str, str]
 | 
			
		||||
    # Manual override for velocity bins. Each element is the max velocity value for that bin by index.
 | 
			
		||||
    velocity_bins_override: Optional[List[int]] = None
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        self.validate()
 | 
			
		||||
@ -116,6 +118,12 @@ class VocabConfig:
 | 
			
		||||
            raise ValueError("velocity_bins must be at least 2")
 | 
			
		||||
        if len(self.bin_instrument_names) > 16:
 | 
			
		||||
            raise ValueError("bin_instruments must have at most 16 values")
 | 
			
		||||
        if self.velocity_bins_override:
 | 
			
		||||
            print("VocabConfig is using velocity_bins_override. Ignoring velocity_exp.")
 | 
			
		||||
            if len(self.velocity_bins_override) != self.velocity_bins:
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    "velocity_bins_override must have same length as velocity_bins"
 | 
			
		||||
                )
 | 
			
		||||
        if (
 | 
			
		||||
            self.ch10_instrument_bin_name
 | 
			
		||||
            and self.ch10_instrument_bin_name not in self.bin_instrument_names
 | 
			
		||||
@ -156,6 +164,11 @@ class VocabUtils:
 | 
			
		||||
 | 
			
		||||
    def velocity_to_bin(self, velocity: float) -> int:
 | 
			
		||||
        velocity = max(0, min(velocity, self.cfg.velocity_events - 1))
 | 
			
		||||
        if self.cfg.velocity_bins_override:
 | 
			
		||||
            for i, v in enumerate(self.cfg.velocity_bins_override):
 | 
			
		||||
                if velocity <= v:
 | 
			
		||||
                    return i
 | 
			
		||||
            return 0
 | 
			
		||||
        binsize = self.cfg.velocity_events / (self.cfg.velocity_bins - 1)
 | 
			
		||||
        if self.cfg.velocity_exp == 1.0:
 | 
			
		||||
            return ceil(velocity / binsize)
 | 
			
		||||
@ -176,6 +189,8 @@ class VocabUtils:
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def bin_to_velocity(self, bin: int) -> int:
 | 
			
		||||
        if self.cfg.velocity_bins_override:
 | 
			
		||||
            return self.cfg.velocity_bins_override[bin]
 | 
			
		||||
        binsize = self.cfg.velocity_events / (self.cfg.velocity_bins - 1)
 | 
			
		||||
        if self.cfg.velocity_exp == 1.0:
 | 
			
		||||
            return max(0, ceil(bin * binsize - 1))
 | 
			
		||||
@ -358,13 +373,32 @@ class AugmentConfig:
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
class FilterConfig:
    """Filtering options used when converting MIDI files to token text.

    Both time-valued fields are expressed in seconds: the converter
    multiplies them by 1000.0 before comparing against its millisecond
    counters.
    """

    # Whether to filter out MIDI files with duplicate MD5 hashes.
    deduplicate_md5: bool
    # Minimum delay between notes, in seconds, before a file is split into
    # multiple documents. The split only happens when no notes are still held.
    piece_split_delay: float
    # Minimum length of a piece, in seconds. Pieces whose accumulated length
    # does not exceed this value are left out of the converter's output list.
    min_piece_length: float

    @classmethod
    def from_json(cls, path: str):
        """Build a FilterConfig from a JSON file.

        Args:
            path: Location of a JSON document whose top-level object holds
                exactly the dataclass field names as keys.

        Returns:
            A new instance populated from the file.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            TypeError: If the keys do not match the dataclass fields.
        """
        # JSON is UTF-8 by specification; do not rely on the platform default.
        with open(path, "r", encoding="utf-8") as f:
            config = json.load(f)
        return cls(**config)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def mix_volume(velocity: int, volume: int, expression: int) -> float:
    """Scale a note velocity by the volume and expression values.

    Volume and expression are each divided by 127.0 and the velocity is
    multiplied by both resulting factors.
    """
    volume_scale = volume / 127.0
    expression_scale = expression / 127.0
    return velocity * volume_scale * expression_scale
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_midi_to_str(
 | 
			
		||||
    cfg: VocabConfig, mid: mido.MidiFile, augment: AugmentValues = None
 | 
			
		||||
) -> str:
 | 
			
		||||
    cfg: VocabConfig,
 | 
			
		||||
    filter_cfg: FilterConfig,
 | 
			
		||||
    mid: mido.MidiFile,
 | 
			
		||||
    augment: AugmentValues = None,
 | 
			
		||||
) -> List[str]:
 | 
			
		||||
    utils = VocabUtils(cfg)
 | 
			
		||||
    if augment is None:
 | 
			
		||||
        augment = AugmentValues.default()
 | 
			
		||||
@ -390,7 +424,9 @@ def convert_midi_to_str(
 | 
			
		||||
    }  # {channel: {(note, program) -> True}}
 | 
			
		||||
    started_flag = False
 | 
			
		||||
 | 
			
		||||
    output_list = []
 | 
			
		||||
    output = ["<start>"]
 | 
			
		||||
    output_length_ms = 0.0
 | 
			
		||||
    token_data_buffer: List[
 | 
			
		||||
        Tuple[int, int, int, float]
 | 
			
		||||
    ] = []  # need to sort notes between wait tokens
 | 
			
		||||
@ -432,16 +468,33 @@ def convert_midi_to_str(
 | 
			
		||||
        token_data_buffer = []
 | 
			
		||||
 | 
			
		||||
    def consume_note_program_data(prog: int, chan: int, note: int, vel: float):
 | 
			
		||||
        nonlocal output, started_flag, delta_time_ms, cfg, utils, token_data_buffer
 | 
			
		||||
        nonlocal output, output_length_ms, started_flag, delta_time_ms, cfg, utils, token_data_buffer
 | 
			
		||||
        is_token_valid = (
 | 
			
		||||
            utils.prog_data_to_token_data(prog, chan, note, vel) is not None
 | 
			
		||||
        )
 | 
			
		||||
        if not is_token_valid:
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        if delta_time_ms > filter_cfg.piece_split_delay * 1000.0:
 | 
			
		||||
            # check if any notes are still held
 | 
			
		||||
            silent = True
 | 
			
		||||
            for channel in channel_notes.keys():
 | 
			
		||||
                if len(channel_notes[channel]) > 0:
 | 
			
		||||
                    silent = False
 | 
			
		||||
                    break
 | 
			
		||||
            if silent:
 | 
			
		||||
                flush_token_data_buffer()
 | 
			
		||||
                output.append("<end>")
 | 
			
		||||
                if output_length_ms > filter_cfg.min_piece_length * 1000.0:
 | 
			
		||||
                    output_list.append(" ".join(output))
 | 
			
		||||
                output = ["<start>"]
 | 
			
		||||
                output_length_ms = 0.0
 | 
			
		||||
                started_flag = False
 | 
			
		||||
        if started_flag:
 | 
			
		||||
            wait_tokens = utils.data_to_wait_tokens(delta_time_ms)
 | 
			
		||||
            if len(wait_tokens) > 0:
 | 
			
		||||
                flush_token_data_buffer()
 | 
			
		||||
                output_length_ms += delta_time_ms
 | 
			
		||||
                output += wait_tokens
 | 
			
		||||
        delta_time_ms = 0.0
 | 
			
		||||
        token_data_buffer.append((prog, chan, note, vel * augment.velocity_mod_factor))
 | 
			
		||||
@ -510,7 +563,9 @@ def convert_midi_to_str(
 | 
			
		||||
 | 
			
		||||
    flush_token_data_buffer()
 | 
			
		||||
    output.append("<end>")
 | 
			
		||||
    return " ".join(output)
 | 
			
		||||
    if output_length_ms > filter_cfg.min_piece_length * 1000.0:
 | 
			
		||||
        output_list.append(" ".join(output))
 | 
			
		||||
    return output_list
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def generate_program_change_messages(cfg: VocabConfig):
 | 
			
		||||
@ -633,10 +688,10 @@ def token_to_midi_message(
 | 
			
		||||
                if utils.cfg.decode_fix_repeated_notes:
 | 
			
		||||
                    if (channel, note) in state.active_notes:
 | 
			
		||||
                        del state.active_notes[(channel, note)]
 | 
			
		||||
                    yield mido.Message(
 | 
			
		||||
                        "note_off", note=note, time=ticks, channel=channel
 | 
			
		||||
                    ), state
 | 
			
		||||
                    ticks = 0
 | 
			
		||||
                        yield mido.Message(
 | 
			
		||||
                            "note_off", note=note, time=ticks, channel=channel
 | 
			
		||||
                        ), state
 | 
			
		||||
                        ticks = 0
 | 
			
		||||
                state.active_notes[(channel, note)] = state.total_time
 | 
			
		||||
                yield mido.Message(
 | 
			
		||||
                    "note_on", note=note, velocity=velocity, time=ticks, channel=channel
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										5
									
								
								backend-python/utils/midi_filter_config.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								backend-python/utils/midi_filter_config.json
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,5 @@
 | 
			
		||||
{
 | 
			
		||||
    "deduplicate_md5": true,
 | 
			
		||||
    "piece_split_delay": 10.0,
 | 
			
		||||
    "min_piece_length": 30.0
 | 
			
		||||
}
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user