diff --git a/frontend/src/components/RunButton.tsx b/frontend/src/components/RunButton.tsx index 9c1cb97..f80c8b6 100644 --- a/frontend/src/components/RunButton.tsx +++ b/frontend/src/components/RunButton.tsx @@ -19,6 +19,7 @@ import { useNavigate } from 'react-router'; import { WindowShow } from '../../wailsjs/runtime'; import { convertToGGML, convertToSt } from '../utils/convert-model'; import { Precision } from '../types/configs'; +import { defaultCompositionABCPrompt, defaultCompositionPrompt } from '../pages/defaultConfigs'; const mainButtonText = { [ModelStatus.Offline]: 'Run', @@ -257,6 +258,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean commonStore.setStatus({ status: ModelStatus.Working }); let buttonNameMap = { 'novel': 'Completion', + 'abc': 'Composition', 'midi': 'Composition' }; let buttonName = 'Chat'; @@ -264,6 +266,13 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean const buttonFn = () => { navigate({ pathname: '/' + buttonName.toLowerCase() }); }; + if (modelName.toLowerCase().includes('abc') && commonStore.compositionParams.prompt === defaultCompositionPrompt) { + commonStore.setCompositionParams({ + ...commonStore.compositionParams, + prompt: defaultCompositionABCPrompt + }); + commonStore.setCompositionSubmittedPrompt(defaultCompositionABCPrompt); + } if (modelConfig.modelParameters.device.startsWith('CUDA') && modelConfig.modelParameters.storedLayers < modelConfig.modelParameters.maxStoredLayers && diff --git a/frontend/src/pages/Composition.tsx b/frontend/src/pages/Composition.tsx index 285d124..caaaf09 100644 --- a/frontend/src/pages/Composition.tsx +++ b/frontend/src/pages/Composition.tsx @@ -15,7 +15,7 @@ import { ArrowSync20Regular, Save28Regular } from '@fluentui/react-icons'; import { PlayerElement, VisualizerElement } from 'html-midi-player'; import * as mm from '@magenta/music/esm/core.js'; import { NoteSequence } from '@magenta/music/esm/protobuf.js'; -import { defaultCompositionPrompt } from './defaultConfigs'; +import { defaultCompositionABCPrompt, defaultCompositionPrompt } from './defaultConfigs'; import { CloseMidiPort, FileExists, @@ -370,11 +370,13 @@ const CompositionPanel: FC = observer(() => { { - commonStore.setCompositionSubmittedPrompt(defaultCompositionPrompt); + const isABC = commonStore.getCurrentModelConfig().modelParameters.modelName.toLowerCase().includes('abc'); + const defaultPrompt = isABC ? defaultCompositionABCPrompt : defaultCompositionPrompt; + commonStore.setCompositionSubmittedPrompt(defaultPrompt); setParams({ generationStartTime: 0 }); - setPrompt(defaultCompositionPrompt); + setPrompt(defaultPrompt); }} />