avoid overflow
This commit is contained in:
parent
3fe9ef4546
commit
cf0972ba52
@ -167,9 +167,10 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
|
|||||||
frequency_penalty: modelConfig.apiParameters.frequencyPenalty
|
frequency_penalty: modelConfig.apiParameters.frequencyPenalty
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const strategy = getStrategy(modelConfig);
|
||||||
let customCudaFile = '';
|
let customCudaFile = '';
|
||||||
if ((modelConfig.modelParameters.device === 'CUDA' || modelConfig.modelParameters.device === 'Custom')
|
if ((modelConfig.modelParameters.device === 'CUDA' || modelConfig.modelParameters.device === 'Custom')
|
||||||
&& modelConfig.modelParameters.useCustomCuda && modelConfig.modelParameters.precision != 'fp32') {
|
&& modelConfig.modelParameters.useCustomCuda && !strategy.includes('fp32')) {
|
||||||
if (commonStore.platform === 'windows') {
|
if (commonStore.platform === 'windows') {
|
||||||
customCudaFile = getSupportedCustomCudaFile();
|
customCudaFile = getSupportedCustomCudaFile();
|
||||||
if (customCudaFile) {
|
if (customCudaFile) {
|
||||||
@ -194,7 +195,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
|
|||||||
|
|
||||||
switchModel({
|
switchModel({
|
||||||
model: modelPath,
|
model: modelPath,
|
||||||
strategy: getStrategy(modelConfig),
|
strategy: strategy,
|
||||||
customCuda: customCudaFile !== ''
|
customCuda: customCudaFile !== ''
|
||||||
}).then(async (r) => {
|
}).then(async (r) => {
|
||||||
if (r.ok) {
|
if (r.ok) {
|
||||||
|
@ -126,19 +126,26 @@ export const getStrategy = (modelConfig: ModelConfig | undefined = undefined) =>
|
|||||||
let params: ModelParameters;
|
let params: ModelParameters;
|
||||||
if (modelConfig) params = modelConfig.modelParameters;
|
if (modelConfig) params = modelConfig.modelParameters;
|
||||||
else params = commonStore.getCurrentModelConfig().modelParameters;
|
else params = commonStore.getCurrentModelConfig().modelParameters;
|
||||||
|
const avoidOverflow = params.modelName.toLowerCase().includes('world') && params.precision !== 'fp32';
|
||||||
let strategy = '';
|
let strategy = '';
|
||||||
switch (params.device) {
|
switch (params.device) {
|
||||||
case 'CPU':
|
case 'CPU':
|
||||||
|
if (avoidOverflow)
|
||||||
|
strategy = 'cpu fp32 *1 -> ';
|
||||||
strategy += 'cpu ';
|
strategy += 'cpu ';
|
||||||
strategy += params.precision === 'int8' ? 'fp32i8' : 'fp32';
|
strategy += params.precision === 'int8' ? 'fp32i8' : 'fp32';
|
||||||
break;
|
break;
|
||||||
case 'CUDA':
|
case 'CUDA':
|
||||||
|
if (avoidOverflow)
|
||||||
|
strategy = 'cuda fp32 *1 -> ';
|
||||||
strategy += 'cuda ';
|
strategy += 'cuda ';
|
||||||
strategy += params.precision === 'fp16' ? 'fp16' : params.precision === 'int8' ? 'fp16i8' : 'fp32';
|
strategy += params.precision === 'fp16' ? 'fp16' : params.precision === 'int8' ? 'fp16i8' : 'fp32';
|
||||||
if (params.storedLayers < params.maxStoredLayers)
|
if (params.storedLayers < params.maxStoredLayers)
|
||||||
strategy += ` *${params.storedLayers}+`;
|
strategy += ` *${params.storedLayers}+`;
|
||||||
break;
|
break;
|
||||||
case 'MPS':
|
case 'MPS':
|
||||||
|
if (avoidOverflow)
|
||||||
|
strategy = 'mps fp32 *1 -> ';
|
||||||
strategy += 'mps ';
|
strategy += 'mps ';
|
||||||
strategy += params.precision === 'fp16' ? 'fp16' : params.precision === 'int8' ? 'fp16i8' : 'fp32';
|
strategy += params.precision === 'fp16' ? 'fp16' : params.precision === 'int8' ? 'fp16i8' : 'fp32';
|
||||||
break;
|
break;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user