diff --git a/frontend/src/components/RunButton.tsx b/frontend/src/components/RunButton.tsx index cd430e7..76a9733 100644 --- a/frontend/src/components/RunButton.tsx +++ b/frontend/src/components/RunButton.tsx @@ -167,9 +167,10 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean frequency_penalty: modelConfig.apiParameters.frequencyPenalty }); + const strategy = getStrategy(modelConfig); let customCudaFile = ''; if ((modelConfig.modelParameters.device === 'CUDA' || modelConfig.modelParameters.device === 'Custom') - && modelConfig.modelParameters.useCustomCuda && modelConfig.modelParameters.precision != 'fp32') { + && modelConfig.modelParameters.useCustomCuda && !strategy.includes('fp32')) { if (commonStore.platform === 'windows') { customCudaFile = getSupportedCustomCudaFile(); if (customCudaFile) { @@ -194,7 +195,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean switchModel({ model: modelPath, - strategy: getStrategy(modelConfig), + strategy: strategy, customCuda: customCudaFile !== '' }).then(async (r) => { if (r.ok) { diff --git a/frontend/src/utils/index.tsx b/frontend/src/utils/index.tsx index 8294e52..4707aee 100644 --- a/frontend/src/utils/index.tsx +++ b/frontend/src/utils/index.tsx @@ -126,19 +126,26 @@ export const getStrategy = (modelConfig: ModelConfig | undefined = undefined) => let params: ModelParameters; if (modelConfig) params = modelConfig.modelParameters; else params = commonStore.getCurrentModelConfig().modelParameters; + const avoidOverflow = params.modelName.toLowerCase().includes('world') && params.precision !== 'fp32'; let strategy = ''; switch (params.device) { case 'CPU': + if (avoidOverflow) + strategy = 'cpu fp32 *1 -> '; strategy += 'cpu '; strategy += params.precision === 'int8' ? 'fp32i8' : 'fp32'; break; case 'CUDA': + if (avoidOverflow) + strategy = 'cuda fp32 *1 -> '; strategy += 'cuda '; strategy += params.precision === 'fp16' ? 'fp16' : params.precision === 'int8' ? 'fp16i8' : 'fp32'; if (params.storedLayers < params.maxStoredLayers) strategy += ` *${params.storedLayers}+`; break; case 'MPS': + if (avoidOverflow) + strategy = 'mps fp32 *1 -> '; strategy += 'mps '; strategy += params.precision === 'fp16' ? 'fp16' : params.precision === 'int8' ? 'fp16i8' : 'fp32'; break;