feat: import midi file
parent b5623cb9c2
commit 9b7b651ef9
@@ -1,6 +1,6 @@
import io
import global_var
from fastapi import APIRouter, HTTPException, status
from fastapi import APIRouter, HTTPException, UploadFile, status
from starlette.responses import StreamingResponse
from pydantic import BaseModel
from utils.midi import *
@@ -33,6 +33,16 @@ def text_to_midi(body: TextToMidiBody):
    return StreamingResponse(mid_data, media_type="audio/midi")


@router.post("/midi-to-text", tags=["MIDI"])
async def midi_to_text(file_data: UploadFile):
    vocab_config = "backend-python/utils/midi_vocab_config.json"
    cfg = VocabConfig.from_json(vocab_config)
    mid = mido.MidiFile(file=file_data.file)
    text = convert_midi_to_str(cfg, mid)

    return {"text": text}


class TxtToMidiBody(BaseModel):
    txt_path: str
    midi_path: str
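For reference, the new /midi-to-text route accepts a plain multipart upload under the field name file_data. A minimal sketch of exercising it outside the frontend follows; it is not part of this commit, and the host, port, and file path are placeholder assumptions (the real port comes from the configured apiPort).

```python
# Minimal sketch: POST a .mid file to the new /midi-to-text route.
# The host/port and the sample file name are assumptions for illustration only.
import requests

with open("example.mid", "rb") as f:
    r = requests.post(
        "http://127.0.0.1:8000/midi-to-text",
        files={"file_data": ("example.mid", f, "audio/midi")},
    )
r.raise_for_status()
print(r.json()["text"])  # space-separated MIDI token string
```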
@@ -311,5 +311,6 @@
  "CN": "中国語",
  "JP": "日本語",
  "Music": "音楽",
  "Other": "その他"
  "Other": "その他",
  "Import MIDI": "MIDIをインポート"
}
@@ -311,5 +311,6 @@
  "CN": "中文",
  "JP": "日文",
  "Music": "音乐",
  "Other": "其他"
  "Other": "其他",
  "Import MIDI": "导入MIDI"
}
@@ -7,6 +7,7 @@ import { v4 as uuid } from 'uuid';
import {
  Add16Regular,
  ArrowAutofitWidth20Regular,
  ArrowUpload16Regular,
  Delete16Regular,
  MusicNote220Regular,
  Pause16Regular,
@@ -20,6 +21,7 @@ import { useWindowSize } from 'usehooks-ts';
import commonStore from '../../stores/commonStore';
import classnames from 'classnames';
import {
  InstrumentType,
  InstrumentTypeNameMap,
  InstrumentTypeTokenMap,
  MidiMessage,
@@ -27,8 +29,15 @@ import {
} from '../../types/composition';
import { toast } from 'react-toastify';
import { ToastOptions } from 'react-toastify/dist/types';
import { flushMidiRecordingContent, refreshTracksTotalTime } from '../../utils';
import { PlayNote } from '../../../wailsjs/go/backend_golang/App';
import {
  absPathAsset,
  flushMidiRecordingContent,
  getMidiRawContentMainInstrument,
  getMidiRawContentTime,
  getServerRoot,
  refreshTracksTotalTime
} from '../../utils';
import { OpenOpenFileDialog, PlayNote } from '../../../wailsjs/go/backend_golang/App';
import { t } from 'i18next';

const snapValue = 25;
@@ -47,14 +56,6 @@ const pixelFix = 0.5;
const topToArrowIcon = 19;
const arrowIconToTracks = 23;

type TrackProps = {
  id: string;
  right: number;
  scale: number;
  isSelected: boolean;
  onSelect: (id: string) => void;
};

const displayCurrentInstrumentType = () => {
  const displayPanelId = 'instrument_panel_id';
  const content: React.ReactNode =
@@ -90,6 +91,53 @@ const velocityToBin = (velocity: number) => {
  return Math.ceil((velocityEvents * ((Math.pow(velocityExp, (velocity / velocityEvents)) - 1.0) / (velocityExp - 1.0))) / binsize);
};

const binToVelocity = (bin: number) => {
  const binsize = velocityEvents / (velocityBins - 1);
  return Math.max(0, Math.ceil(velocityEvents * (Math.log(((velocityExp - 1) * binsize * bin) / velocityEvents + 1) / Math.log(velocityExp)) - 1));
};

const tokenToMidiMessage = (token: string): MidiMessage | null => {
  if (token.startsWith('<')) return null;
  if (token.startsWith('t') && !token.startsWith('t:')) {
    return {
      messageType: 'ElapsedTime',
      value: parseInt(token.substring(1)) * minimalMoveTime,
      channel: 0,
      note: 0,
      velocity: 0,
      control: 0,
      instrument: 0
    };
  }
  const instrument: InstrumentType = InstrumentTypeTokenMap.findIndex(t => token.startsWith(t + ':'));
  if (instrument >= 0) {
    const parts = token.split(':');
    if (parts.length !== 3) return null;
    const note = parseInt(parts[1], 16);
    const velocity = parseInt(parts[2], 16);
    if (velocity < 0 || velocity > 127) return null;
    if (velocity === 0) return {
      messageType: 'NoteOff',
      note: note,
      instrument: instrument,
      channel: 0,
      velocity: 0,
      control: 0,
      value: 0
    };
    return {
      messageType: 'NoteOn',
      note: note,
      velocity: binToVelocity(velocity),
      instrument: instrument,
      channel: 0,
      control: 0,
      value: 0
    } as MidiMessage;
  }
  return null;
};

const midiMessageToToken = (msg: MidiMessage) => {
  if (msg.messageType === 'NoteOn' || msg.messageType === 'NoteOff') {
    const instrument = InstrumentTypeTokenMap[msg.instrument];
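A side note on the velocity mapping above: binToVelocity is intended as the approximate inverse of the exponential compression done by velocityToBin, so velocities decoded from imported tokens land close to the originals. The sketch below is a rough Python transcription for sanity-checking the round trip, not part of this commit; velocityEvents, velocityBins, and velocityExp are defined elsewhere in the component, so the constants here are assumed placeholders, and it assumes binsize is derived the same way in both directions.

```python
# Rough Python transcription of velocityToBin / binToVelocity for sanity checks.
# VELOCITY_EVENTS, VELOCITY_BINS and VELOCITY_EXP are assumed placeholder values;
# the real constants live in the component, outside this diff.
import math

VELOCITY_EVENTS = 128
VELOCITY_BINS = 12
VELOCITY_EXP = 0.5
BIN_SIZE = VELOCITY_EVENTS / (VELOCITY_BINS - 1)

def velocity_to_bin(velocity: int) -> int:
    # Exponential compression of a 0..127 velocity into a small bin index.
    scaled = VELOCITY_EVENTS * ((VELOCITY_EXP ** (velocity / VELOCITY_EVENTS) - 1.0) / (VELOCITY_EXP - 1.0))
    return math.ceil(scaled / BIN_SIZE)

def bin_to_velocity(bin_index: int) -> int:
    # Approximate inverse: expand a bin index back to a playable velocity.
    return max(0, math.ceil(
        VELOCITY_EVENTS * math.log((VELOCITY_EXP - 1) * BIN_SIZE * bin_index / VELOCITY_EVENTS + 1, VELOCITY_EXP) - 1
    ))

for v in (1, 32, 64, 100, 127):
    b = velocity_to_bin(v)
    print(f"velocity {v:3d} -> bin {b:2d} -> velocity {bin_to_velocity(b):3d}")
```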
@@ -136,6 +184,14 @@ export const midiMessageHandler = async (data: MidiMessage) => {
  }
};

type TrackProps = {
  id: string;
  right: number;
  scale: number;
  isSelected: boolean;
  onSelect: (id: string) => void;
};

const Track: React.FC<TrackProps> = observer(({
  id,
  right,
@@ -422,20 +478,66 @@ const AudiotrackEditor: FC<{ setPrompt: (prompt: string) => void }> = observer((
        </div>
      </div>)}
    <div className="flex justify-between items-center">
      <Button icon={<Add16Regular />} size="small" shape="circular"
        appearance="subtle"
        onClick={() => {
          commonStore.setTracks([...commonStore.tracks, {
            id: uuid(),
            mainInstrument: '',
            content: '',
            rawContent: [],
            offsetTime: 0,
            contentTime: 0
          }]);
        }}>
        {t('New Track')}
      </Button>
      <div className="flex gap-1">
        <Button icon={<Add16Regular />} size="small" shape="circular"
          appearance="subtle"
          onClick={() => {
            commonStore.setTracks([...commonStore.tracks, {
              id: uuid(),
              mainInstrument: '',
              content: '',
              rawContent: [],
              offsetTime: 0,
              contentTime: 0
            }]);
          }}>
          {t('New Track')}
        </Button>
        <Button icon={<ArrowUpload16Regular />} size="small" shape="circular"
          appearance="subtle"
          onClick={() => {
            OpenOpenFileDialog('*.mid').then(async filePath => {
              if (!filePath)
                return;

              const blob = await fetch(absPathAsset(filePath)).then(r => r.blob());
              const bodyForm = new FormData();
              bodyForm.append('file_data', blob);
              fetch(getServerRoot(commonStore.getCurrentModelConfig().apiParameters.apiPort) + '/midi-to-text', {
                method: 'POST',
                body: bodyForm
              }).then(async r => {
                  if (r.status === 200) {
                    const text = (await r.json()).text as string;
                    const rawContent = text.split(' ').map(tokenToMidiMessage).filter(m => m) as MidiMessage[];
                    const tracks = commonStore.tracks.slice();

                    tracks.push({
                      id: uuid(),
                      mainInstrument: getMidiRawContentMainInstrument(rawContent),
                      content: text,
                      rawContent: rawContent,
                      offsetTime: 0,
                      contentTime: getMidiRawContentTime(rawContent)
                    });
                    commonStore.setTracks(tracks);
                    refreshTracksTotalTime();
                  } else {
                    toast(r.statusText + '\n' + (await r.text()), {
                      type: 'error'
                    });
                  }
                }
              ).catch(e => {
                toast(t('Error') + ' - ' + (e.message || e), { type: 'error', autoClose: 2500 });
              });
            }).catch(e => {
              toast(t('Error') + ' - ' + (e.message || e), { type: 'error', autoClose: 2500 });
            });
          }}>
          {t('Import MIDI')}
        </Button>
      </div>
      <Text size={100}>
        {t('Select a track to preview the content')}
      </Text>
@@ -494,7 +596,7 @@ const AudiotrackEditor: FC<{ setPrompt: (prompt: string) => void }> = observer((
        }
      }
    }
    const result = ('<pad> ' + globalMessages.map(m => midiMessageToToken(m)).join('')).trim();
    const result = ('<pad> ' + globalMessages.map(midiMessageToToken).join('')).trim();
    commonStore.setCompositionSubmittedPrompt(result);
    setPrompt(result);
  }}>
@@ -22,7 +22,7 @@ import { DownloadStatus } from '../types/downloads';
import { ModelSourceItem } from '../types/models';
import { Language, Languages, SettingsType } from '../types/settings';
import { DataProcessParameters, LoraFinetuneParameters } from '../types/train';
import { InstrumentTypeNameMap, tracksMinimalTotalTime } from '../types/composition';
import { InstrumentTypeNameMap, MidiMessage, tracksMinimalTotalTime } from '../types/composition';
import logo from '../assets/images/logo.png';
import { Preset } from '../types/presets';
import { botName, Conversation, MessageType, userName } from '../types/chat';
@@ -513,34 +513,39 @@ export function refreshTracksTotalTime() {
  commonStore.setTrackTotalTime(totalTime);
}

export function getMidiRawContentTime(rawContent: MidiMessage[]) {
  return rawContent.reduce((sum, current) =>
      sum + (current.messageType === 'ElapsedTime' ? current.value : 0)
    , 0);
}

export function getMidiRawContentMainInstrument(rawContent: MidiMessage[]) {
  const sortedInstrumentFrequency = Object.entries(rawContent
    .filter(c => c.messageType === 'NoteOn')
    .map(c => c.instrument)
    .reduce((frequencyCount, current) => (frequencyCount[current] = (frequencyCount[current] || 0) + 1, frequencyCount)
      , {} as {
        [key: string]: number
      }))
    .sort((a, b) => b[1] - a[1]);
  let mainInstrument: string = '';
  if (sortedInstrumentFrequency.length > 0)
    mainInstrument = InstrumentTypeNameMap[Number(sortedInstrumentFrequency[0][0])];
  return mainInstrument;
}

export function flushMidiRecordingContent() {
  const recordingTrackIndex = commonStore.tracks.findIndex(t => t.id === commonStore.recordingTrackId);
  if (recordingTrackIndex >= 0) {
    const recordingTrack = commonStore.tracks[recordingTrackIndex];
    const tracks = commonStore.tracks.slice();
    const contentTime = commonStore.recordingRawContent
      .reduce((sum, current) =>
          sum + (current.messageType === 'ElapsedTime' ? current.value : 0)
        , 0);

    const sortedInstrumentFrequency = Object.entries(commonStore.recordingRawContent
      .filter(c => c.messageType === 'NoteOn')
      .map(c => c.instrument)
      .reduce((frequencyCount, current) => (frequencyCount[current] = (frequencyCount[current] || 0) + 1, frequencyCount)
        , {} as {
          [key: string]: number
        }))
      .sort((a, b) => b[1] - a[1]);
    let mainInstrument: string = '';
    if (sortedInstrumentFrequency.length > 0)
      mainInstrument = InstrumentTypeNameMap[Number(sortedInstrumentFrequency[0][0])];

    tracks[recordingTrackIndex] = {
      ...recordingTrack,
      content: commonStore.recordingContent,
      rawContent: commonStore.recordingRawContent,
      contentTime: contentTime,
      mainInstrument: mainInstrument
      contentTime: getMidiRawContentTime(commonStore.recordingRawContent),
      mainInstrument: getMidiRawContentMainInstrument(commonStore.recordingRawContent)
    };
    commonStore.setTracks(tracks);
    refreshTracksTotalTime();
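The frequency-counting reduce inside getMidiRawContentMainInstrument (and the duplicated logic it replaces in flushMidiRecordingContent) boils down to "most frequent instrument among NoteOn messages". A small Python sketch of the same idea follows, not part of this commit; the message dicts and the name map are simplified stand-ins for the real MidiMessage type and InstrumentTypeNameMap.

```python
# Same idea as getMidiRawContentMainInstrument, sketched with collections.Counter.
# The dict-based messages and INSTRUMENT_NAMES below are simplified stand-ins,
# not the real MidiMessage / InstrumentTypeNameMap definitions.
from collections import Counter

INSTRUMENT_NAMES = {0: "Piano", 1: "Percussion", 2: "Drum"}  # placeholder subset

def main_instrument(raw_content: list[dict]) -> str:
    counts = Counter(m["instrument"] for m in raw_content if m["messageType"] == "NoteOn")
    if not counts:
        return ""
    most_common_instrument, _ = counts.most_common(1)[0]
    return INSTRUMENT_NAMES.get(most_common_instrument, "")

print(main_instrument([
    {"messageType": "NoteOn", "instrument": 0},
    {"messageType": "NoteOn", "instrument": 0},
    {"messageType": "NoteOn", "instrument": 2},
    {"messageType": "ElapsedTime", "instrument": 0},
]))  # -> "Piano"
```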