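Avoid overflow: for "world" models with non-fp32 precision, prepend an fp32 `*1 ->` stage to the strategy (and skip the custom CUDA kernel whenever the strategy includes fp32)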

This commit is contained in:
josc146 2023-06-21 22:02:42 +08:00
parent 3fe9ef4546
commit cf0972ba52
2 changed files with 10 additions and 2 deletions

View File

@@ -167,9 +167,10 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
frequency_penalty: modelConfig.apiParameters.frequencyPenalty
});
const strategy = getStrategy(modelConfig);
let customCudaFile = '';
if ((modelConfig.modelParameters.device === 'CUDA' || modelConfig.modelParameters.device === 'Custom')
&& modelConfig.modelParameters.useCustomCuda && modelConfig.modelParameters.precision != 'fp32') {
&& modelConfig.modelParameters.useCustomCuda && !strategy.includes('fp32')) {
if (commonStore.platform === 'windows') {
customCudaFile = getSupportedCustomCudaFile();
if (customCudaFile) {
@@ -194,7 +195,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
switchModel({
model: modelPath,
strategy: getStrategy(modelConfig),
strategy: strategy,
customCuda: customCudaFile !== ''
}).then(async (r) => {
if (r.ok) {

View File

@@ -126,19 +126,26 @@ export const getStrategy = (modelConfig: ModelConfig | undefined = undefined) =>
let params: ModelParameters;
if (modelConfig) params = modelConfig.modelParameters;
else params = commonStore.getCurrentModelConfig().modelParameters;
const avoidOverflow = params.modelName.toLowerCase().includes('world') && params.precision !== 'fp32';
let strategy = '';
switch (params.device) {
case 'CPU':
if (avoidOverflow)
strategy = 'cpu fp32 *1 -> ';
strategy += 'cpu ';
strategy += params.precision === 'int8' ? 'fp32i8' : 'fp32';
break;
case 'CUDA':
if (avoidOverflow)
strategy = 'cuda fp32 *1 -> ';
strategy += 'cuda ';
strategy += params.precision === 'fp16' ? 'fp16' : params.precision === 'int8' ? 'fp16i8' : 'fp32';
if (params.storedLayers < params.maxStoredLayers)
strategy += ` *${params.storedLayers}+`;
break;
case 'MPS':
if (avoidOverflow)
strategy = 'mps fp32 *1 -> ';
strategy += 'mps ';
strategy += params.precision === 'fp16' ? 'fp16' : params.precision === 'int8' ? 'fp16i8' : 'fp32';
break;