improve launch flow of webgpu mode

This commit is contained in:
josc146 2023-11-24 19:21:14 +08:00
parent 6146d910b4
commit e01897b24d

View File

@ -58,11 +58,32 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
return;
}
const currentModelSource = commonStore.modelSourceList.find(item => item.name === modelName);
const showDownloadPrompt = (promptInfo: string, downloadName: string) => {
toastWithButton(promptInfo, t('Download'), () => {
const downloadUrl = currentModelSource?.downloadUrl;
if (downloadUrl) {
toastWithButton(`${t('Downloading')} ${downloadName}`, t('Check'), () => {
navigate({ pathname: '/downloads' });
},
{ autoClose: 3000 });
AddToDownloadList(modelPath, getHfDownloadUrl(downloadUrl));
} else {
toast(t('Can not find download url'), { type: 'error' });
}
});
};
if (webgpu) {
if (!['.st', '.safetensors'].some(ext => modelPath.endsWith(ext))) {
const stModelPath = modelPath.replace(/\.pth$/, '.st');
if (await FileExists(stModelPath)) {
modelPath = stModelPath;
} else if (!await FileExists(modelPath)) {
showDownloadPrompt(t('Model file not found'), modelName);
commonStore.setStatus({ status: ModelStatus.Offline });
return;
} else {
toastWithButton(t('Please convert model to safe tensors format first'), t('Convert'), () => {
convertToSt(navigate, modelConfig);
@ -87,23 +108,6 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
return;
}
const currentModelSource = commonStore.modelSourceList.find(item => item.name === modelName);
const showDownloadPrompt = (promptInfo: string, downloadName: string) => {
toastWithButton(promptInfo, t('Download'), () => {
const downloadUrl = currentModelSource?.downloadUrl;
if (downloadUrl) {
toastWithButton(`${t('Downloading')} ${downloadName}`, t('Check'), () => {
navigate({ pathname: '/downloads' });
},
{ autoClose: 3000 });
AddToDownloadList(modelPath, getHfDownloadUrl(downloadUrl));
} else {
toast(t('Can not find download url'), { type: 'error' });
}
});
};
if (!await FileExists(modelPath)) {
showDownloadPrompt(t('Model file not found'), modelName);
commonStore.setStatus({ status: ModelStatus.Offline });