Avoid overflow: prepend an fp32 `*1` stage to the strategy for "world" models when precision is not fp32

This commit is contained in:
josc146 2023-06-21 22:02:42 +08:00
parent 3fe9ef4546
commit cf0972ba52
2 changed files with 10 additions and 2 deletions

View File

@ -167,9 +167,10 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
frequency_penalty: modelConfig.apiParameters.frequencyPenalty frequency_penalty: modelConfig.apiParameters.frequencyPenalty
}); });
const strategy = getStrategy(modelConfig);
let customCudaFile = ''; let customCudaFile = '';
if ((modelConfig.modelParameters.device === 'CUDA' || modelConfig.modelParameters.device === 'Custom') if ((modelConfig.modelParameters.device === 'CUDA' || modelConfig.modelParameters.device === 'Custom')
&& modelConfig.modelParameters.useCustomCuda && modelConfig.modelParameters.precision != 'fp32') { && modelConfig.modelParameters.useCustomCuda && !strategy.includes('fp32')) {
if (commonStore.platform === 'windows') { if (commonStore.platform === 'windows') {
customCudaFile = getSupportedCustomCudaFile(); customCudaFile = getSupportedCustomCudaFile();
if (customCudaFile) { if (customCudaFile) {
@ -194,7 +195,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
switchModel({ switchModel({
model: modelPath, model: modelPath,
strategy: getStrategy(modelConfig), strategy: strategy,
customCuda: customCudaFile !== '' customCuda: customCudaFile !== ''
}).then(async (r) => { }).then(async (r) => {
if (r.ok) { if (r.ok) {

View File

@ -126,19 +126,26 @@ export const getStrategy = (modelConfig: ModelConfig | undefined = undefined) =>
let params: ModelParameters; let params: ModelParameters;
if (modelConfig) params = modelConfig.modelParameters; if (modelConfig) params = modelConfig.modelParameters;
else params = commonStore.getCurrentModelConfig().modelParameters; else params = commonStore.getCurrentModelConfig().modelParameters;
const avoidOverflow = params.modelName.toLowerCase().includes('world') && params.precision !== 'fp32';
let strategy = ''; let strategy = '';
switch (params.device) { switch (params.device) {
case 'CPU': case 'CPU':
if (avoidOverflow)
strategy = 'cpu fp32 *1 -> ';
strategy += 'cpu '; strategy += 'cpu ';
strategy += params.precision === 'int8' ? 'fp32i8' : 'fp32'; strategy += params.precision === 'int8' ? 'fp32i8' : 'fp32';
break; break;
case 'CUDA': case 'CUDA':
if (avoidOverflow)
strategy = 'cuda fp32 *1 -> ';
strategy += 'cuda '; strategy += 'cuda ';
strategy += params.precision === 'fp16' ? 'fp16' : params.precision === 'int8' ? 'fp16i8' : 'fp32'; strategy += params.precision === 'fp16' ? 'fp16' : params.precision === 'int8' ? 'fp16i8' : 'fp32';
if (params.storedLayers < params.maxStoredLayers) if (params.storedLayers < params.maxStoredLayers)
strategy += ` *${params.storedLayers}+`; strategy += ` *${params.storedLayers}+`;
break; break;
case 'MPS': case 'MPS':
if (avoidOverflow)
strategy = 'mps fp32 *1 -> ';
strategy += 'mps '; strategy += 'mps ';
strategy += params.precision === 'fp16' ? 'fp16' : params.precision === 'int8' ? 'fp16i8' : 'fp32'; strategy += params.precision === 'fp16' ? 'fp16' : params.precision === 'int8' ? 'fp16i8' : 'fp32';
break; break;