add customCudaFile support

This commit is contained in:
josc146
2023-05-23 14:04:06 +08:00
parent 65d92d5da1
commit 4eca1537a7
6 changed files with 56 additions and 3 deletions

View File

@@ -99,5 +99,7 @@
"Model Config Exception": "模型配置异常",
"Use Gitee Updates Source": "使用Gitee更新源",
"Use Custom CUDA kernel to Accelerate": "使用自定义CUDA算子加速",
"Enabling this option can greatly improve inference speed, but there may be compatibility issues. If it fails to start, please turn off this option.": "开启这个选项能大大提升推理速度,但可能存在兼容性问题,如果启动失败,请关闭此选项"
"Enabling this option can greatly improve inference speed, but there may be compatibility issues. If it fails to start, please turn off this option.": "开启这个选项能大大提升推理速度,但可能存在兼容性问题,如果启动失败,请关闭此选项",
"Supported custom cuda file not found": "没有找到支持的自定义cuda文件",
"Failed to copy custom cuda file": "自定义cuda文件复制失败"
}

View File

@@ -2,6 +2,7 @@ import React, { FC, MouseEventHandler, ReactElement } from 'react';
import commonStore, { ModelStatus } from '../stores/commonStore';
import {
AddToDownloadList,
CopyFile,
DepCheck,
FileExists,
InstallPyDep,
@@ -12,7 +13,7 @@ import { observer } from 'mobx-react-lite';
import { exit, getStatus, readRoot, switchModel, updateConfig } from '../apis';
import { toast } from 'react-toastify';
import manifest from '../../../manifest.json';
import { getStrategy, saveCache, toastWithButton } from '../utils';
import { getStrategy, getSupportedCustomCudaFile, saveCache, toastWithButton } from '../utils';
import { useTranslation } from 'react-i18next';
import { ToolTipButton } from './ToolTipButton';
import { Play16Regular, Stop16Regular } from '@fluentui/react-icons';
@@ -83,6 +84,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
return;
}
commonStore.setDepComplete(true);
CopyFile('./backend-python/wkv_cuda_utils/wkv_cuda_model.py', './py310/Lib/site-packages/rwkv/model.py');
saveCache();
}
@@ -132,10 +134,23 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
presence_penalty: modelConfig.apiParameters.presencePenalty,
frequency_penalty: modelConfig.apiParameters.frequencyPenalty
});
let customCudaFile = '';
if (modelConfig.modelParameters.useCustomCuda) {
customCudaFile = getSupportedCustomCudaFile();
if (customCudaFile)
await CopyFile(customCudaFile, './py310/Lib/site-packages/rwkv/wkv_cuda.pyd').catch(() => {
customCudaFile = '';
toast(t('Failed to copy custom cuda file'), { type: 'error' });
});
else
toast(t('Supported custom cuda file not found'), { type: 'warning' });
}
switchModel({
model: `${manifest.localModelDir}/${modelConfig.modelParameters.modelName}`,
strategy: getStrategy(modelConfig),
customCuda: !!modelConfig.modelParameters.useCustomCuda
customCuda: customCudaFile !== ''
}).then((r) => {
if (r.ok) {
commonStore.setStatus({ modelStatus: ModelStatus.Working });

View File

@@ -267,4 +267,13 @@ export function toastWithButton(text: string, buttonText: string, onClickButton:
type: 'info',
...options
});
}
export function getSupportedCustomCudaFile() {
if ([' 10', ' 20', ' 30'].some(v => commonStore.status.device_name.includes(v)))
return './backend-python/wkv_cuda_utils/wkv_cuda10_30.pyd';
else if ([' 40'].some(v => commonStore.status.device_name.includes(v)))
return './backend-python/wkv_cuda_utils/wkv_cuda40.pyd';
else
return '';
}