feat: import midi file

This commit is contained in:
josc146 2023-12-10 22:38:31 +08:00
parent b5623cb9c2
commit 9b7b651ef9
5 changed files with 166 additions and 47 deletions

View File

@ -1,6 +1,6 @@
import io
import global_var
from fastapi import APIRouter, HTTPException, status
from fastapi import APIRouter, HTTPException, UploadFile, status
from starlette.responses import StreamingResponse
from pydantic import BaseModel
from utils.midi import *
@ -33,6 +33,16 @@ def text_to_midi(body: TextToMidiBody):
return StreamingResponse(mid_data, media_type="audio/midi")
@router.post("/midi-to-text", tags=["MIDI"])
async def midi_to_text(file_data: UploadFile):
vocab_config = "backend-python/utils/midi_vocab_config.json"
cfg = VocabConfig.from_json(vocab_config)
mid = mido.MidiFile(file=file_data.file)
text = convert_midi_to_str(cfg, mid)
return {"text": text}
class TxtToMidiBody(BaseModel):
txt_path: str
midi_path: str

View File

@ -311,5 +311,6 @@
"CN": "中国語",
"JP": "日本語",
"Music": "音楽",
"Other": "その他"
"Other": "その他",
"Import MIDI": "MIDIをインポート"
}

View File

@ -311,5 +311,6 @@
"CN": "中文",
"JP": "日文",
"Music": "音乐",
"Other": "其他"
"Other": "其他",
"Import MIDI": "导入MIDI"
}

View File

@ -7,6 +7,7 @@ import { v4 as uuid } from 'uuid';
import {
Add16Regular,
ArrowAutofitWidth20Regular,
ArrowUpload16Regular,
Delete16Regular,
MusicNote220Regular,
Pause16Regular,
@ -20,6 +21,7 @@ import { useWindowSize } from 'usehooks-ts';
import commonStore from '../../stores/commonStore';
import classnames from 'classnames';
import {
InstrumentType,
InstrumentTypeNameMap,
InstrumentTypeTokenMap,
MidiMessage,
@ -27,8 +29,15 @@ import {
} from '../../types/composition';
import { toast } from 'react-toastify';
import { ToastOptions } from 'react-toastify/dist/types';
import { flushMidiRecordingContent, refreshTracksTotalTime } from '../../utils';
import { PlayNote } from '../../../wailsjs/go/backend_golang/App';
import {
absPathAsset,
flushMidiRecordingContent,
getMidiRawContentMainInstrument,
getMidiRawContentTime,
getServerRoot,
refreshTracksTotalTime
} from '../../utils';
import { OpenOpenFileDialog, PlayNote } from '../../../wailsjs/go/backend_golang/App';
import { t } from 'i18next';
const snapValue = 25;
@ -47,14 +56,6 @@ const pixelFix = 0.5;
const topToArrowIcon = 19;
const arrowIconToTracks = 23;
type TrackProps = {
id: string;
right: number;
scale: number;
isSelected: boolean;
onSelect: (id: string) => void;
};
const displayCurrentInstrumentType = () => {
const displayPanelId = 'instrument_panel_id';
const content: React.ReactNode =
@ -90,6 +91,53 @@ const velocityToBin = (velocity: number) => {
return Math.ceil((velocityEvents * ((Math.pow(velocityExp, (velocity / velocityEvents)) - 1.0) / (velocityExp - 1.0))) / binsize);
};
const binToVelocity = (bin: number) => {
const binsize = velocityEvents / (velocityBins - 1);
return Math.max(0, Math.ceil(velocityEvents * (Math.log(((velocityExp - 1) * binsize * bin) / velocityEvents + 1) / Math.log(velocityExp)) - 1));
};
const tokenToMidiMessage = (token: string): MidiMessage | null => {
if (token.startsWith('<')) return null;
if (token.startsWith('t') && !token.startsWith('t:')) {
return {
messageType: 'ElapsedTime',
value: parseInt(token.substring(1)) * minimalMoveTime,
channel: 0,
note: 0,
velocity: 0,
control: 0,
instrument: 0
};
}
const instrument: InstrumentType = InstrumentTypeTokenMap.findIndex(t => token.startsWith(t + ':'));
if (instrument >= 0) {
const parts = token.split(':');
if (parts.length !== 3) return null;
const note = parseInt(parts[1], 16);
const velocity = parseInt(parts[2], 16);
if (velocity < 0 || velocity > 127) return null;
if (velocity === 0) return {
messageType: 'NoteOff',
note: note,
instrument: instrument,
channel: 0,
velocity: 0,
control: 0,
value: 0
};
return {
messageType: 'NoteOn',
note: note,
velocity: binToVelocity(velocity),
instrument: instrument,
channel: 0,
control: 0,
value: 0
} as MidiMessage;
}
return null;
};
const midiMessageToToken = (msg: MidiMessage) => {
if (msg.messageType === 'NoteOn' || msg.messageType === 'NoteOff') {
const instrument = InstrumentTypeTokenMap[msg.instrument];
@ -136,6 +184,14 @@ export const midiMessageHandler = async (data: MidiMessage) => {
}
};
type TrackProps = {
id: string;
right: number;
scale: number;
isSelected: boolean;
onSelect: (id: string) => void;
};
const Track: React.FC<TrackProps> = observer(({
id,
right,
@ -422,20 +478,66 @@ const AudiotrackEditor: FC<{ setPrompt: (prompt: string) => void }> = observer((
</div>
</div>)}
<div className="flex justify-between items-center">
<Button icon={<Add16Regular />} size="small" shape="circular"
appearance="subtle"
onClick={() => {
commonStore.setTracks([...commonStore.tracks, {
id: uuid(),
mainInstrument: '',
content: '',
rawContent: [],
offsetTime: 0,
contentTime: 0
}]);
}}>
{t('New Track')}
</Button>
<div className="flex gap-1">
<Button icon={<Add16Regular />} size="small" shape="circular"
appearance="subtle"
onClick={() => {
commonStore.setTracks([...commonStore.tracks, {
id: uuid(),
mainInstrument: '',
content: '',
rawContent: [],
offsetTime: 0,
contentTime: 0
}]);
}}>
{t('New Track')}
</Button>
<Button icon={<ArrowUpload16Regular />} size="small" shape="circular"
appearance="subtle"
onClick={() => {
OpenOpenFileDialog('*.mid').then(async filePath => {
if (!filePath)
return;
const blob = await fetch(absPathAsset(filePath)).then(r => r.blob());
const bodyForm = new FormData();
bodyForm.append('file_data', blob);
fetch(getServerRoot(commonStore.getCurrentModelConfig().apiParameters.apiPort) + '/midi-to-text', {
method: 'POST',
body: bodyForm
}).then(async r => {
if (r.status === 200) {
const text = (await r.json()).text as string;
const rawContent = text.split(' ').map(tokenToMidiMessage).filter(m => m) as MidiMessage[];
const tracks = commonStore.tracks.slice();
tracks.push({
id: uuid(),
mainInstrument: getMidiRawContentMainInstrument(rawContent),
content: text,
rawContent: rawContent,
offsetTime: 0,
contentTime: getMidiRawContentTime(rawContent)
});
commonStore.setTracks(tracks);
refreshTracksTotalTime();
} else {
toast(r.statusText + '\n' + (await r.text()), {
type: 'error'
});
}
}
).catch(e => {
toast(t('Error') + ' - ' + (e.message || e), { type: 'error', autoClose: 2500 });
});
}).catch(e => {
toast(t('Error') + ' - ' + (e.message || e), { type: 'error', autoClose: 2500 });
});
}}>
{t('Import MIDI')}
</Button>
</div>
<Text size={100}>
{t('Select a track to preview the content')}
</Text>
@ -494,7 +596,7 @@ const AudiotrackEditor: FC<{ setPrompt: (prompt: string) => void }> = observer((
}
}
}
const result = ('<pad> ' + globalMessages.map(m => midiMessageToToken(m)).join('')).trim();
const result = ('<pad> ' + globalMessages.map(midiMessageToToken).join('')).trim();
commonStore.setCompositionSubmittedPrompt(result);
setPrompt(result);
}}>

View File

@ -22,7 +22,7 @@ import { DownloadStatus } from '../types/downloads';
import { ModelSourceItem } from '../types/models';
import { Language, Languages, SettingsType } from '../types/settings';
import { DataProcessParameters, LoraFinetuneParameters } from '../types/train';
import { InstrumentTypeNameMap, tracksMinimalTotalTime } from '../types/composition';
import { InstrumentTypeNameMap, MidiMessage, tracksMinimalTotalTime } from '../types/composition';
import logo from '../assets/images/logo.png';
import { Preset } from '../types/presets';
import { botName, Conversation, MessageType, userName } from '../types/chat';
@ -513,34 +513,39 @@ export function refreshTracksTotalTime() {
commonStore.setTrackTotalTime(totalTime);
}
export function getMidiRawContentTime(rawContent: MidiMessage[]) {
return rawContent.reduce((sum, current) =>
sum + (current.messageType === 'ElapsedTime' ? current.value : 0)
, 0);
}
export function getMidiRawContentMainInstrument(rawContent: MidiMessage[]) {
const sortedInstrumentFrequency = Object.entries(rawContent
.filter(c => c.messageType === 'NoteOn')
.map(c => c.instrument)
.reduce((frequencyCount, current) => (frequencyCount[current] = (frequencyCount[current] || 0) + 1, frequencyCount)
, {} as {
[key: string]: number
}))
.sort((a, b) => b[1] - a[1]);
let mainInstrument: string = '';
if (sortedInstrumentFrequency.length > 0)
mainInstrument = InstrumentTypeNameMap[Number(sortedInstrumentFrequency[0][0])];
return mainInstrument;
}
export function flushMidiRecordingContent() {
const recordingTrackIndex = commonStore.tracks.findIndex(t => t.id === commonStore.recordingTrackId);
if (recordingTrackIndex >= 0) {
const recordingTrack = commonStore.tracks[recordingTrackIndex];
const tracks = commonStore.tracks.slice();
const contentTime = commonStore.recordingRawContent
.reduce((sum, current) =>
sum + (current.messageType === 'ElapsedTime' ? current.value : 0)
, 0);
const sortedInstrumentFrequency = Object.entries(commonStore.recordingRawContent
.filter(c => c.messageType === 'NoteOn')
.map(c => c.instrument)
.reduce((frequencyCount, current) => (frequencyCount[current] = (frequencyCount[current] || 0) + 1, frequencyCount)
, {} as {
[key: string]: number
}))
.sort((a, b) => b[1] - a[1]);
let mainInstrument: string = '';
if (sortedInstrumentFrequency.length > 0)
mainInstrument = InstrumentTypeNameMap[Number(sortedInstrumentFrequency[0][0])];
tracks[recordingTrackIndex] = {
...recordingTrack,
content: commonStore.recordingContent,
rawContent: commonStore.recordingRawContent,
contentTime: contentTime,
mainInstrument: mainInstrument
contentTime: getMidiRawContentTime(commonStore.recordingRawContent),
mainInstrument: getMidiRawContentMainInstrument(commonStore.recordingRawContent)
};
commonStore.setTracks(tracks);
refreshTracksTotalTime();