feat: import midi file

This commit is contained in:
josc146 2023-12-10 22:38:31 +08:00
parent b5623cb9c2
commit 9b7b651ef9
5 changed files with 166 additions and 47 deletions

View File

@ -1,6 +1,6 @@
import io import io
import global_var import global_var
from fastapi import APIRouter, HTTPException, status from fastapi import APIRouter, HTTPException, UploadFile, status
from starlette.responses import StreamingResponse from starlette.responses import StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
from utils.midi import * from utils.midi import *
@ -33,6 +33,16 @@ def text_to_midi(body: TextToMidiBody):
return StreamingResponse(mid_data, media_type="audio/midi") return StreamingResponse(mid_data, media_type="audio/midi")
@router.post("/midi-to-text", tags=["MIDI"])
async def midi_to_text(file_data: UploadFile):
    """Convert an uploaded MIDI file into its text-token representation.

    Returns a JSON object ``{"text": ...}`` holding the tokenized form
    produced by ``convert_midi_to_str``.
    """
    # NOTE(review): path is relative to the process working directory — TODO confirm
    vocab_config = "backend-python/utils/midi_vocab_config.json"
    cfg = VocabConfig.from_json(vocab_config)
    try:
        # mido raises (IOError/EOFError/ValueError...) on a corrupt or
        # non-MIDI upload; surface that as a client error, not a 500.
        mid = mido.MidiFile(file=file_data.file)
    except Exception:
        raise HTTPException(
            status.HTTP_400_BAD_REQUEST, "invalid or corrupted MIDI file"
        )
    text = convert_midi_to_str(cfg, mid)
    return {"text": text}
class TxtToMidiBody(BaseModel): class TxtToMidiBody(BaseModel):
txt_path: str txt_path: str
midi_path: str midi_path: str

View File

@ -311,5 +311,6 @@
"CN": "中国語", "CN": "中国語",
"JP": "日本語", "JP": "日本語",
"Music": "音楽", "Music": "音楽",
"Other": "その他" "Other": "その他",
"Import MIDI": "MIDIをインポート"
} }

View File

@ -311,5 +311,6 @@
"CN": "中文", "CN": "中文",
"JP": "日文", "JP": "日文",
"Music": "音乐", "Music": "音乐",
"Other": "其他" "Other": "其他",
"Import MIDI": "导入MIDI"
} }

View File

@ -7,6 +7,7 @@ import { v4 as uuid } from 'uuid';
import { import {
Add16Regular, Add16Regular,
ArrowAutofitWidth20Regular, ArrowAutofitWidth20Regular,
ArrowUpload16Regular,
Delete16Regular, Delete16Regular,
MusicNote220Regular, MusicNote220Regular,
Pause16Regular, Pause16Regular,
@ -20,6 +21,7 @@ import { useWindowSize } from 'usehooks-ts';
import commonStore from '../../stores/commonStore'; import commonStore from '../../stores/commonStore';
import classnames from 'classnames'; import classnames from 'classnames';
import { import {
InstrumentType,
InstrumentTypeNameMap, InstrumentTypeNameMap,
InstrumentTypeTokenMap, InstrumentTypeTokenMap,
MidiMessage, MidiMessage,
@ -27,8 +29,15 @@ import {
} from '../../types/composition'; } from '../../types/composition';
import { toast } from 'react-toastify'; import { toast } from 'react-toastify';
import { ToastOptions } from 'react-toastify/dist/types'; import { ToastOptions } from 'react-toastify/dist/types';
import { flushMidiRecordingContent, refreshTracksTotalTime } from '../../utils'; import {
import { PlayNote } from '../../../wailsjs/go/backend_golang/App'; absPathAsset,
flushMidiRecordingContent,
getMidiRawContentMainInstrument,
getMidiRawContentTime,
getServerRoot,
refreshTracksTotalTime
} from '../../utils';
import { OpenOpenFileDialog, PlayNote } from '../../../wailsjs/go/backend_golang/App';
import { t } from 'i18next'; import { t } from 'i18next';
const snapValue = 25; const snapValue = 25;
@ -47,14 +56,6 @@ const pixelFix = 0.5;
const topToArrowIcon = 19; const topToArrowIcon = 19;
const arrowIconToTracks = 23; const arrowIconToTracks = 23;
type TrackProps = {
id: string;
right: number;
scale: number;
isSelected: boolean;
onSelect: (id: string) => void;
};
const displayCurrentInstrumentType = () => { const displayCurrentInstrumentType = () => {
const displayPanelId = 'instrument_panel_id'; const displayPanelId = 'instrument_panel_id';
const content: React.ReactNode = const content: React.ReactNode =
@ -90,6 +91,53 @@ const velocityToBin = (velocity: number) => {
return Math.ceil((velocityEvents * ((Math.pow(velocityExp, (velocity / velocityEvents)) - 1.0) / (velocityExp - 1.0))) / binsize); return Math.ceil((velocityEvents * ((Math.pow(velocityExp, (velocity / velocityEvents)) - 1.0) / (velocityExp - 1.0))) / binsize);
}; };
const binToVelocity = (bin: number) => {
const binsize = velocityEvents / (velocityBins - 1);
return Math.max(0, Math.ceil(velocityEvents * (Math.log(((velocityExp - 1) * binsize * bin) / velocityEvents + 1) / Math.log(velocityExp)) - 1));
};
// Parse one text token back into a MidiMessage; returns null for control
// tokens ('<pad>', '<start>', ...) and for any malformed token.
const tokenToMidiMessage = (token: string): MidiMessage | null => {
  if (token.startsWith('<')) return null;
  // 'tN' encodes elapsed time in multiples of minimalMoveTime
  // ('t:' would be an instrument prefix, so exclude it).
  if (token.startsWith('t') && !token.startsWith('t:')) {
    const elapsed = parseInt(token.substring(1));
    if (Number.isNaN(elapsed)) return null; // malformed time token
    return {
      messageType: 'ElapsedTime',
      value: elapsed * minimalMoveTime,
      channel: 0,
      note: 0,
      velocity: 0,
      control: 0,
      instrument: 0
    };
  }
  // '<prefix>:<note hex>:<velocity hex>' — the prefix selects the instrument.
  // (callback param renamed from 't' to avoid shadowing the i18next 't' import)
  const instrument: InstrumentType = InstrumentTypeTokenMap.findIndex(prefix => token.startsWith(prefix + ':'));
  if (instrument >= 0) {
    const parts = token.split(':');
    if (parts.length !== 3) return null;
    const note = parseInt(parts[1], 16);
    const velocity = parseInt(parts[2], 16);
    // parseInt yields NaN on bad hex; NaN slips through both range checks
    // below (NaN < 0 and NaN > 127 are both false), so reject it explicitly.
    if (Number.isNaN(note) || Number.isNaN(velocity)) return null;
    if (velocity < 0 || velocity > 127) return null;
    // Velocity 0 conventionally means note-off.
    if (velocity === 0) return {
      messageType: 'NoteOff',
      note: note,
      instrument: instrument,
      channel: 0,
      velocity: 0,
      control: 0,
      value: 0
    };
    return {
      messageType: 'NoteOn',
      note: note,
      velocity: binToVelocity(velocity),
      instrument: instrument,
      channel: 0,
      control: 0,
      value: 0
    };
  }
  return null;
};
const midiMessageToToken = (msg: MidiMessage) => { const midiMessageToToken = (msg: MidiMessage) => {
if (msg.messageType === 'NoteOn' || msg.messageType === 'NoteOff') { if (msg.messageType === 'NoteOn' || msg.messageType === 'NoteOff') {
const instrument = InstrumentTypeTokenMap[msg.instrument]; const instrument = InstrumentTypeTokenMap[msg.instrument];
@ -136,6 +184,14 @@ export const midiMessageHandler = async (data: MidiMessage) => {
} }
}; };
// Props for a single Track row in the composition editor.
type TrackProps = {
  id: string;          // track identifier (uuid)
  right: number;       // right boundary in px — presumably the usable track width; verify against caller
  scale: number;       // horizontal zoom factor applied to note positions
  isSelected: boolean; // whether this track is the currently selected one
  onSelect: (id: string) => void; // callback invoked with this track's id on click
};
const Track: React.FC<TrackProps> = observer(({ const Track: React.FC<TrackProps> = observer(({
id, id,
right, right,
@ -422,20 +478,66 @@ const AudiotrackEditor: FC<{ setPrompt: (prompt: string) => void }> = observer((
</div> </div>
</div>)} </div>)}
<div className="flex justify-between items-center"> <div className="flex justify-between items-center">
<Button icon={<Add16Regular />} size="small" shape="circular" <div className="flex gap-1">
appearance="subtle" <Button icon={<Add16Regular />} size="small" shape="circular"
onClick={() => { appearance="subtle"
commonStore.setTracks([...commonStore.tracks, { onClick={() => {
id: uuid(), commonStore.setTracks([...commonStore.tracks, {
mainInstrument: '', id: uuid(),
content: '', mainInstrument: '',
rawContent: [], content: '',
offsetTime: 0, rawContent: [],
contentTime: 0 offsetTime: 0,
}]); contentTime: 0
}}> }]);
{t('New Track')} }}>
</Button> {t('New Track')}
</Button>
<Button icon={<ArrowUpload16Regular />} size="small" shape="circular"
appearance="subtle"
onClick={() => {
OpenOpenFileDialog('*.mid').then(async filePath => {
if (!filePath)
return;
const blob = await fetch(absPathAsset(filePath)).then(r => r.blob());
const bodyForm = new FormData();
bodyForm.append('file_data', blob);
fetch(getServerRoot(commonStore.getCurrentModelConfig().apiParameters.apiPort) + '/midi-to-text', {
method: 'POST',
body: bodyForm
}).then(async r => {
if (r.status === 200) {
const text = (await r.json()).text as string;
const rawContent = text.split(' ').map(tokenToMidiMessage).filter(m => m) as MidiMessage[];
const tracks = commonStore.tracks.slice();
tracks.push({
id: uuid(),
mainInstrument: getMidiRawContentMainInstrument(rawContent),
content: text,
rawContent: rawContent,
offsetTime: 0,
contentTime: getMidiRawContentTime(rawContent)
});
commonStore.setTracks(tracks);
refreshTracksTotalTime();
} else {
toast(r.statusText + '\n' + (await r.text()), {
type: 'error'
});
}
}
).catch(e => {
toast(t('Error') + ' - ' + (e.message || e), { type: 'error', autoClose: 2500 });
});
}).catch(e => {
toast(t('Error') + ' - ' + (e.message || e), { type: 'error', autoClose: 2500 });
});
}}>
{t('Import MIDI')}
</Button>
</div>
<Text size={100}> <Text size={100}>
{t('Select a track to preview the content')} {t('Select a track to preview the content')}
</Text> </Text>
@ -494,7 +596,7 @@ const AudiotrackEditor: FC<{ setPrompt: (prompt: string) => void }> = observer((
} }
} }
} }
const result = ('<pad> ' + globalMessages.map(m => midiMessageToToken(m)).join('')).trim(); const result = ('<pad> ' + globalMessages.map(midiMessageToToken).join('')).trim();
commonStore.setCompositionSubmittedPrompt(result); commonStore.setCompositionSubmittedPrompt(result);
setPrompt(result); setPrompt(result);
}}> }}>

View File

@ -22,7 +22,7 @@ import { DownloadStatus } from '../types/downloads';
import { ModelSourceItem } from '../types/models'; import { ModelSourceItem } from '../types/models';
import { Language, Languages, SettingsType } from '../types/settings'; import { Language, Languages, SettingsType } from '../types/settings';
import { DataProcessParameters, LoraFinetuneParameters } from '../types/train'; import { DataProcessParameters, LoraFinetuneParameters } from '../types/train';
import { InstrumentTypeNameMap, tracksMinimalTotalTime } from '../types/composition'; import { InstrumentTypeNameMap, MidiMessage, tracksMinimalTotalTime } from '../types/composition';
import logo from '../assets/images/logo.png'; import logo from '../assets/images/logo.png';
import { Preset } from '../types/presets'; import { Preset } from '../types/presets';
import { botName, Conversation, MessageType, userName } from '../types/chat'; import { botName, Conversation, MessageType, userName } from '../types/chat';
@ -513,34 +513,39 @@ export function refreshTracksTotalTime() {
commonStore.setTrackTotalTime(totalTime); commonStore.setTrackTotalTime(totalTime);
} }
// Total duration of a track's raw content: the sum of all ElapsedTime
// message values (every other message type contributes nothing).
export function getMidiRawContentTime(rawContent: MidiMessage[]) {
  let total = 0;
  for (const msg of rawContent) {
    if (msg.messageType === 'ElapsedTime')
      total += msg.value;
  }
  return total;
}
// Determine the track's dominant instrument: the one with the most NoteOn
// messages, resolved to its display name. Returns '' for empty content.
export function getMidiRawContentMainInstrument(rawContent: MidiMessage[]) {
  // Count NoteOn messages per instrument id.
  const frequencyCount: { [key: string]: number } = {};
  for (const msg of rawContent) {
    if (msg.messageType === 'NoteOn')
      frequencyCount[msg.instrument] = (frequencyCount[msg.instrument] || 0) + 1;
  }
  // Rank instruments by descending frequency and pick the winner, if any.
  const ranked = Object.entries(frequencyCount).sort((a, b) => b[1] - a[1]);
  let mainInstrument: string = '';
  if (ranked.length > 0)
    mainInstrument = InstrumentTypeNameMap[Number(ranked[0][0])];
  return mainInstrument;
}
export function flushMidiRecordingContent() { export function flushMidiRecordingContent() {
const recordingTrackIndex = commonStore.tracks.findIndex(t => t.id === commonStore.recordingTrackId); const recordingTrackIndex = commonStore.tracks.findIndex(t => t.id === commonStore.recordingTrackId);
if (recordingTrackIndex >= 0) { if (recordingTrackIndex >= 0) {
const recordingTrack = commonStore.tracks[recordingTrackIndex]; const recordingTrack = commonStore.tracks[recordingTrackIndex];
const tracks = commonStore.tracks.slice(); const tracks = commonStore.tracks.slice();
const contentTime = commonStore.recordingRawContent
.reduce((sum, current) =>
sum + (current.messageType === 'ElapsedTime' ? current.value : 0)
, 0);
const sortedInstrumentFrequency = Object.entries(commonStore.recordingRawContent
.filter(c => c.messageType === 'NoteOn')
.map(c => c.instrument)
.reduce((frequencyCount, current) => (frequencyCount[current] = (frequencyCount[current] || 0) + 1, frequencyCount)
, {} as {
[key: string]: number
}))
.sort((a, b) => b[1] - a[1]);
let mainInstrument: string = '';
if (sortedInstrumentFrequency.length > 0)
mainInstrument = InstrumentTypeNameMap[Number(sortedInstrumentFrequency[0][0])];
tracks[recordingTrackIndex] = { tracks[recordingTrackIndex] = {
...recordingTrack, ...recordingTrack,
content: commonStore.recordingContent, content: commonStore.recordingContent,
rawContent: commonStore.recordingRawContent, rawContent: commonStore.recordingRawContent,
contentTime: contentTime, contentTime: getMidiRawContentTime(commonStore.recordingRawContent),
mainInstrument: mainInstrument mainInstrument: getMidiRawContentMainInstrument(commonStore.recordingRawContent)
}; };
commonStore.setTracks(tracks); commonStore.setTracks(tracks);
refreshTracksTotalTime(); refreshTracksTotalTime();