Add nf4 precision option for WebGPU (maps to the fp16i4 strategy)

This commit is contained in:
josc146 2023-11-20 21:10:10 +08:00
parent d435436525
commit 48fef0235b
3 changed files with 7 additions and 5 deletions

View File

@ -340,8 +340,10 @@ const Configs: FC = observer(() => {
}); });
} }
}}> }}>
<Option>fp16</Option> {selectedConfig.modelParameters.device !== 'CPU' && selectedConfig.modelParameters.device !== 'MPS' &&
<Option>fp16</Option>}
<Option>int8</Option> <Option>int8</Option>
{selectedConfig.modelParameters.device === 'WebGPU' && <Option>nf4</Option>}
{selectedConfig.modelParameters.device !== 'WebGPU' && <Option>fp32</Option>} {selectedConfig.modelParameters.device !== 'WebGPU' && <Option>fp32</Option>}
</Dropdown> </Dropdown>
} /> } />

View File

@ -7,7 +7,7 @@ export type ApiParameters = {
frequencyPenalty: number; frequencyPenalty: number;
} }
export type Device = 'CPU' | 'CUDA' | 'CUDA-Beta' | 'WebGPU' | 'MPS' | 'Custom'; export type Device = 'CPU' | 'CUDA' | 'CUDA-Beta' | 'WebGPU' | 'MPS' | 'Custom';
export type Precision = 'fp16' | 'int8' | 'fp32'; export type Precision = 'fp16' | 'int8' | 'fp32' | 'nf4';
export type ModelParameters = { export type ModelParameters = {
// different models can not have the same name // different models can not have the same name
modelName: string; modelName: string;

View File

@ -178,14 +178,14 @@ export const getStrategy = (modelConfig: ModelConfig | undefined = undefined) =>
strategy += params.precision === 'int8' ? 'fp32i8' : 'fp32'; strategy += params.precision === 'int8' ? 'fp32i8' : 'fp32';
break; break;
case 'WebGPU': case 'WebGPU':
strategy += params.precision === 'int8' ? 'fp16i8' : 'fp16'; strategy += params.precision === 'nf4' ? 'fp16i4' : params.precision === 'int8' ? 'fp16i8' : 'fp16';
break; break;
case 'CUDA': case 'CUDA':
case 'CUDA-Beta': case 'CUDA-Beta':
if (avoidOverflow) if (avoidOverflow)
strategy = params.useCustomCuda ? 'cuda fp16 *1 -> ' : 'cuda fp32 *1 -> '; strategy = params.useCustomCuda ? 'cuda fp16 *1 -> ' : 'cuda fp32 *1 -> ';
strategy += 'cuda '; strategy += 'cuda ';
strategy += params.precision === 'fp16' ? 'fp16' : params.precision === 'int8' ? 'fp16i8' : 'fp32'; strategy += params.precision === 'int8' ? 'fp16i8' : params.precision === 'fp32' ? 'fp32' : 'fp16';
if (params.storedLayers < params.maxStoredLayers) if (params.storedLayers < params.maxStoredLayers)
strategy += ` *${params.storedLayers}+`; strategy += ` *${params.storedLayers}+`;
break; break;
@ -193,7 +193,7 @@ export const getStrategy = (modelConfig: ModelConfig | undefined = undefined) =>
if (avoidOverflow) if (avoidOverflow)
strategy = 'mps fp32 *1 -> '; strategy = 'mps fp32 *1 -> ';
strategy += 'mps '; strategy += 'mps ';
strategy += params.precision === 'fp16' ? 'fp16' : params.precision === 'int8' ? 'fp16i8' : 'fp32'; strategy += params.precision === 'int8' ? 'fp32i8' : 'fp32';
break; break;
case 'Custom': case 'Custom':
strategy = params.customStrategy || ''; strategy = params.customStrategy || '';