diff --git a/frontend/src/pages/Chat.tsx b/frontend/src/pages/Chat.tsx
index 6fed0ce..a1c030f 100644
--- a/frontend/src/pages/Chat.tsx
+++ b/frontend/src/pages/Chat.tsx
@@ -28,7 +28,7 @@ import { toast } from 'react-toastify';
import { WorkHeader } from '../components/WorkHeader';
import { DialogButton } from '../components/DialogButton';
import { OpenFileFolder, OpenOpenFileDialog, OpenSaveFileDialog } from '../../wailsjs/go/backend_golang/App';
-import { absPathAsset, bytesToReadable, getServerRoot, toastWithButton } from '../utils';
+import { absPathAsset, bytesToReadable, getServerRoot, setActivePreset, toastWithButton } from '../utils';
import { useMediaQuery } from 'usehooks-ts';
import { botName, ConversationMessage, MessageType, userName, welcomeUuid } from '../types/chat';
import { Labeled } from '../components/Labeled';
@@ -536,8 +536,7 @@ const ChatPanel: FC = observer(() => {
}
chatSseControllers = {};
}
- commonStore.setConversation({});
- commonStore.setConversationOrder([]);
+ setActivePreset(commonStore.activePreset);
}} />
;
});
-const pages: { [label: string]: PresetsNavigationItem } = {
+const pages: {
+ [label: string]: PresetsNavigationItem
+} = {
Chat: {
icon: ,
element:
@@ -395,7 +372,9 @@ const pages: { [label: string]: PresetsNavigationItem } = {
}
};
-const PresetsManager: FC<{ initTab: string }> = ({ initTab }) => {
+const PresetsManager: FC<{
+ initTab: string
+}> = ({ initTab }) => {
const { t } = useTranslation();
const [tab, setTab] = useState(initTab);
diff --git a/frontend/src/stores/commonStore.ts b/frontend/src/stores/commonStore.ts
index f3714ca..5be7a1a 100644
--- a/frontend/src/stores/commonStore.ts
+++ b/frontend/src/stores/commonStore.ts
@@ -70,7 +70,9 @@ class CommonStore {
conversationOrder: string[] = [];
activePreset: Preset | null = null;
attachmentUploading: boolean = false;
- attachments: { [uuid: string]: Attachment[] } = {};
+ attachments: {
+ [uuid: string]: Attachment[]
+ } = {};
currentTempAttachment: Attachment | null = null;
chatParams: ChatParams = {
maxResponseToken: 1000,
@@ -327,7 +329,7 @@ class CommonStore {
savePresets();
}
- setActivePreset(value: Preset) {
+ setActivePreset(value: Preset | null) {
this.activePreset = value;
}
@@ -379,7 +381,9 @@ class CommonStore {
this.attachmentUploading = value;
}
- setAttachments(value: { [uuid: string]: Attachment[] }) {
+ setAttachments(value: {
+ [uuid: string]: Attachment[]
+ }) {
this.attachments = value;
}
diff --git a/frontend/src/utils/index.tsx b/frontend/src/utils/index.tsx
index 58c4d61..674b43f 100644
--- a/frontend/src/utils/index.tsx
+++ b/frontend/src/utils/index.tsx
@@ -24,6 +24,9 @@ import { Language, Languages, SettingsType } from '../types/settings';
import { DataProcessParameters, LoraFinetuneParameters } from '../types/train';
import { InstrumentTypeNameMap, tracksMinimalTotalTime } from '../types/composition';
import logo from '../assets/images/logo.png';
+import { Preset } from '../types/presets';
+import { botName, Conversation, MessageType, userName } from '../types/chat';
+import { v4 as uuid } from 'uuid';
export type Cache = {
version: string
@@ -41,7 +44,9 @@ export type LocalConfig = {
}
export async function refreshBuiltInModels(readCache: boolean = false) {
- let cache: { models: ModelSourceItem[] } = { models: [] };
+ let cache: {
+ models: ModelSourceItem[]
+ } = { models: [] };
if (readCache)
await ReadJson('cache.json').then((cacheData: Cache) => {
if (cacheData.models)
@@ -133,7 +138,9 @@ function initLastUnfinishedModelDownloads() {
commonStore.setLastUnfinishedModelDownloads(list);
}
-export async function refreshRemoteModels(cache: { models: ModelSourceItem[] }) {
+export async function refreshRemoteModels(cache: {
+ models: ModelSourceItem[]
+}) {
const manifestUrls = commonStore.modelSourceManifestList.split(/[,,;;\n]/);
const requests = manifestUrls.filter(url => url.endsWith('.json')).map(
url => fetch(url, { cache: 'no-cache' }).then(r => r.json()));
@@ -515,7 +522,9 @@ export function flushMidiRecordingContent() {
.filter(c => c.messageType === 'NoteOn')
.map(c => c.instrument)
.reduce((frequencyCount, current) => (frequencyCount[current] = (frequencyCount[current] || 0) + 1, frequencyCount)
- , {} as { [key: string]: number }))
+ , {} as {
+ [key: string]: number
+ }))
.sort((a, b) => b[1] - a[1]);
let mainInstrument: string = '';
if (sortedInstrumentFrequency.length > 0)
@@ -551,6 +560,30 @@ export async function getSoundFont() {
return soundUrl;
}
+export const setActivePreset = (preset: Preset | null) => {
+ commonStore.setActivePreset(preset);
+ //TODO if (preset.displayPresetMessages) {
+ const conversation: Conversation = {};
+ const conversationOrder: string[] = [];
+ if (preset)
+ for (const message of preset.messages) {
+ const newUuid = uuid();
+ conversationOrder.push(newUuid);
+ conversation[newUuid] = {
+ sender: message.role === 'user' ? userName : botName,
+ type: MessageType.Normal,
+ color: message.role === 'user' ? 'brand' : 'colorful',
+ time: new Date().toISOString(),
+ content: message.content,
+ side: message.role === 'user' ? 'right' : 'left',
+ done: true
+ };
+ }
+ commonStore.setConversation(conversation);
+ commonStore.setConversationOrder(conversationOrder);
+ //}
+};
+
export function getSupportedCustomCudaFile(isBeta: boolean) {
if ([' 10', ' 16', ' 20', ' 30', 'MX', 'Tesla P', 'Quadro P', 'NVIDIA P', 'TITAN X', 'TITAN RTX', 'RTX A',
'Quadro RTX 4000', 'Quadro RTX 5000', 'Tesla T4', 'NVIDIA A10', 'NVIDIA A40'].some(v => commonStore.status.device_name.includes(v)))