From e01897b24dddcdc9e6561907a59ae8c77f1f475d Mon Sep 17 00:00:00 2001 From: josc146 Date: Fri, 24 Nov 2023 19:21:14 +0800 Subject: [PATCH] improve launch flow of webgpu mode --- frontend/src/components/RunButton.tsx | 38 +++++++++++++++------------ 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/frontend/src/components/RunButton.tsx b/frontend/src/components/RunButton.tsx index 209aa14..4303a38 100644 --- a/frontend/src/components/RunButton.tsx +++ b/frontend/src/components/RunButton.tsx @@ -58,11 +58,32 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean return; } + const currentModelSource = commonStore.modelSourceList.find(item => item.name === modelName); + + const showDownloadPrompt = (promptInfo: string, downloadName: string) => { + toastWithButton(promptInfo, t('Download'), () => { + const downloadUrl = currentModelSource?.downloadUrl; + if (downloadUrl) { + toastWithButton(`${t('Downloading')} ${downloadName}`, t('Check'), () => { + navigate({ pathname: '/downloads' }); + }, + { autoClose: 3000 }); + AddToDownloadList(modelPath, getHfDownloadUrl(downloadUrl)); + } else { + toast(t('Can not find download url'), { type: 'error' }); + } + }); + }; + if (webgpu) { if (!['.st', '.safetensors'].some(ext => modelPath.endsWith(ext))) { const stModelPath = modelPath.replace(/\.pth$/, '.st'); if (await FileExists(stModelPath)) { modelPath = stModelPath; + } else if (!await FileExists(modelPath)) { + showDownloadPrompt(t('Model file not found'), modelName); + commonStore.setStatus({ status: ModelStatus.Offline }); + return; } else { toastWithButton(t('Please convert model to safe tensors format first'), t('Convert'), () => { convertToSt(navigate, modelConfig); @@ -87,23 +108,6 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean return; } - const currentModelSource = commonStore.modelSourceList.find(item => item.name === modelName); - - const showDownloadPrompt = (promptInfo: string, downloadName: string) => { - toastWithButton(promptInfo, t('Download'), () => { - const downloadUrl = currentModelSource?.downloadUrl; - if (downloadUrl) { - toastWithButton(`${t('Downloading')} ${downloadName}`, t('Check'), () => { - navigate({ pathname: '/downloads' }); - }, - { autoClose: 3000 }); - AddToDownloadList(modelPath, getHfDownloadUrl(downloadUrl)); - } else { - toast(t('Can not find download url'), { type: 'error' }); - } - }); - }; - if (!await FileExists(modelPath)) { showDownloadPrompt(t('Model file not found'), modelName); commonStore.setStatus({ status: ModelStatus.Offline });