This commit is contained in:
josc146
2023-12-14 18:37:07 +08:00
parent 01c95f5bc4
commit 0ddd2e9fea
16 changed files with 155 additions and 34 deletions

View File

@@ -48,6 +48,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
const modelConfig = commonStore.getCurrentModelConfig();
const webgpu = modelConfig.modelParameters.device === 'WebGPU';
const webgpuPython = modelConfig.modelParameters.device === 'WebGPU (Python)';
const cpp = modelConfig.modelParameters.device === 'CPU (rwkv.cpp)';
let modelName = '';
let modelPath = '';
@@ -77,7 +78,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
});
};
if (webgpu) {
if (webgpu || webgpuPython) {
if (!['.st', '.safetensors'].some(ext => modelPath.endsWith(ext))) {
const stModelPath = modelPath.replace(/\.pth$/, '.st');
if (await FileExists(stModelPath)) {
@@ -92,7 +93,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
return;
} else {
toastWithButton(t('Please convert model to safe tensors format first'), t('Convert'), () => {
convertToSt(modelConfig);
convertToSt(modelConfig, navigate);
});
commonStore.setStatus({ status: ModelStatus.Offline });
return;
@@ -100,7 +101,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
}
}
if (!webgpu) {
if (!webgpu && !webgpuPython) {
if (['.st', '.safetensors'].some(ext => modelPath.endsWith(ext))) {
toast(t('Please change Strategy to WebGPU to use safetensors format'), { type: 'error' });
commonStore.setStatus({ status: ModelStatus.Offline });
@@ -176,7 +177,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
const isUsingCudaBeta = modelConfig.modelParameters.device === 'CUDA-Beta';
startServer(commonStore.settings.customPythonPath, port, commonStore.settings.host !== '127.0.0.1' ? '0.0.0.0' : '127.0.0.1',
!!modelConfig.enableWebUI, isUsingCudaBeta, cpp
!!modelConfig.enableWebUI, isUsingCudaBeta, cpp, webgpuPython
).catch((e) => {
const errMsg = e.message || e;
if (errMsg.includes('path contains space'))
@@ -216,7 +217,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
const strategy = getStrategy(modelConfig);
let customCudaFile = '';
if ((modelConfig.modelParameters.device.includes('CUDA') || modelConfig.modelParameters.device === 'Custom')
if ((modelConfig.modelParameters.device.startsWith('CUDA') || modelConfig.modelParameters.device === 'Custom')
&& modelConfig.modelParameters.useCustomCuda
&& !strategy.split('->').some(s => ['cuda', 'fp32'].every(v => s.includes(v)))) {
if (commonStore.platform === 'windows') {
@@ -264,7 +265,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
navigate({ pathname: '/' + buttonName.toLowerCase() });
};
if ((modelConfig.modelParameters.device === 'CUDA' || modelConfig.modelParameters.device === 'CUDA-Beta') &&
if (modelConfig.modelParameters.device.startsWith('CUDA') &&
modelConfig.modelParameters.storedLayers < modelConfig.modelParameters.maxStoredLayers &&
commonStore.monitorData && commonStore.monitorData.totalVram !== 0 &&
(commonStore.monitorData.usedVram / commonStore.monitorData.totalVram) < 0.9)