add customCudaFile support
This commit is contained in:
parent
65d92d5da1
commit
4eca1537a7
@ -3,6 +3,7 @@ package backend_golang
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@ -92,6 +93,26 @@ func (a *App) DeleteFile(path string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *App) CopyFile(src string, dst string) error {
|
||||||
|
sourceFile, err := os.Open(src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer sourceFile.Close()
|
||||||
|
|
||||||
|
destFile, err := os.Create(dst)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer destFile.Close()
|
||||||
|
|
||||||
|
_, err = io.Copy(sourceFile, destFile)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (a *App) OpenFileFolder(path string) error {
|
func (a *App) OpenFileFolder(path string) error {
|
||||||
absPath, err := filepath.Abs(path)
|
absPath, err := filepath.Abs(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -99,5 +99,7 @@
|
|||||||
"Model Config Exception": "模型配置异常",
|
"Model Config Exception": "模型配置异常",
|
||||||
"Use Gitee Updates Source": "使用Gitee更新源",
|
"Use Gitee Updates Source": "使用Gitee更新源",
|
||||||
"Use Custom CUDA kernel to Accelerate": "使用自定义CUDA算子加速",
|
"Use Custom CUDA kernel to Accelerate": "使用自定义CUDA算子加速",
|
||||||
"Enabling this option can greatly improve inference speed, but there may be compatibility issues. If it fails to start, please turn off this option.": "开启这个选项能大大提升推理速度,但可能存在兼容性问题,如果启动失败,请关闭此选项"
|
"Enabling this option can greatly improve inference speed, but there may be compatibility issues. If it fails to start, please turn off this option.": "开启这个选项能大大提升推理速度,但可能存在兼容性问题,如果启动失败,请关闭此选项",
|
||||||
|
"Supported custom cuda file not found": "没有找到支持的自定义cuda文件",
|
||||||
|
"Failed to copy custom cuda file": "自定义cuda文件复制失败"
|
||||||
}
|
}
|
@ -2,6 +2,7 @@ import React, { FC, MouseEventHandler, ReactElement } from 'react';
|
|||||||
import commonStore, { ModelStatus } from '../stores/commonStore';
|
import commonStore, { ModelStatus } from '../stores/commonStore';
|
||||||
import {
|
import {
|
||||||
AddToDownloadList,
|
AddToDownloadList,
|
||||||
|
CopyFile,
|
||||||
DepCheck,
|
DepCheck,
|
||||||
FileExists,
|
FileExists,
|
||||||
InstallPyDep,
|
InstallPyDep,
|
||||||
@ -12,7 +13,7 @@ import { observer } from 'mobx-react-lite';
|
|||||||
import { exit, getStatus, readRoot, switchModel, updateConfig } from '../apis';
|
import { exit, getStatus, readRoot, switchModel, updateConfig } from '../apis';
|
||||||
import { toast } from 'react-toastify';
|
import { toast } from 'react-toastify';
|
||||||
import manifest from '../../../manifest.json';
|
import manifest from '../../../manifest.json';
|
||||||
import { getStrategy, saveCache, toastWithButton } from '../utils';
|
import { getStrategy, getSupportedCustomCudaFile, saveCache, toastWithButton } from '../utils';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { ToolTipButton } from './ToolTipButton';
|
import { ToolTipButton } from './ToolTipButton';
|
||||||
import { Play16Regular, Stop16Regular } from '@fluentui/react-icons';
|
import { Play16Regular, Stop16Regular } from '@fluentui/react-icons';
|
||||||
@ -83,6 +84,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
commonStore.setDepComplete(true);
|
commonStore.setDepComplete(true);
|
||||||
|
CopyFile('./backend-python/wkv_cuda_utils/wkv_cuda_model.py', './py310/Lib/site-packages/rwkv/model.py');
|
||||||
saveCache();
|
saveCache();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -132,10 +134,23 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
|
|||||||
presence_penalty: modelConfig.apiParameters.presencePenalty,
|
presence_penalty: modelConfig.apiParameters.presencePenalty,
|
||||||
frequency_penalty: modelConfig.apiParameters.frequencyPenalty
|
frequency_penalty: modelConfig.apiParameters.frequencyPenalty
|
||||||
});
|
});
|
||||||
|
|
||||||
|
let customCudaFile = '';
|
||||||
|
if (modelConfig.modelParameters.useCustomCuda) {
|
||||||
|
customCudaFile = getSupportedCustomCudaFile();
|
||||||
|
if (customCudaFile)
|
||||||
|
await CopyFile(customCudaFile, './py310/Lib/site-packages/rwkv/wkv_cuda.pyd').catch(() => {
|
||||||
|
customCudaFile = '';
|
||||||
|
toast(t('Failed to copy custom cuda file'), { type: 'error' });
|
||||||
|
});
|
||||||
|
else
|
||||||
|
toast(t('Supported custom cuda file not found'), { type: 'warning' });
|
||||||
|
}
|
||||||
|
|
||||||
switchModel({
|
switchModel({
|
||||||
model: `${manifest.localModelDir}/${modelConfig.modelParameters.modelName}`,
|
model: `${manifest.localModelDir}/${modelConfig.modelParameters.modelName}`,
|
||||||
strategy: getStrategy(modelConfig),
|
strategy: getStrategy(modelConfig),
|
||||||
customCuda: !!modelConfig.modelParameters.useCustomCuda
|
customCuda: customCudaFile !== ''
|
||||||
}).then((r) => {
|
}).then((r) => {
|
||||||
if (r.ok) {
|
if (r.ok) {
|
||||||
commonStore.setStatus({ modelStatus: ModelStatus.Working });
|
commonStore.setStatus({ modelStatus: ModelStatus.Working });
|
||||||
|
@ -267,4 +267,13 @@ export function toastWithButton(text: string, buttonText: string, onClickButton:
|
|||||||
type: 'info',
|
type: 'info',
|
||||||
...options
|
...options
|
||||||
});
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getSupportedCustomCudaFile() {
|
||||||
|
if ([' 10', ' 20', ' 30'].some(v => commonStore.status.device_name.includes(v)))
|
||||||
|
return './backend-python/wkv_cuda_utils/wkv_cuda10_30.pyd';
|
||||||
|
else if ([' 40'].some(v => commonStore.status.device_name.includes(v)))
|
||||||
|
return './backend-python/wkv_cuda_utils/wkv_cuda40.pyd';
|
||||||
|
else
|
||||||
|
return '';
|
||||||
}
|
}
|
2
frontend/wailsjs/go/backend_golang/App.d.ts
vendored
2
frontend/wailsjs/go/backend_golang/App.d.ts
vendored
@ -8,6 +8,8 @@ export function ContinueDownload(arg1:string):Promise<void>;
|
|||||||
|
|
||||||
export function ConvertModel(arg1:string,arg2:string,arg3:string):Promise<string>;
|
export function ConvertModel(arg1:string,arg2:string,arg3:string):Promise<string>;
|
||||||
|
|
||||||
|
export function CopyFile(arg1:string,arg2:string):Promise<void>;
|
||||||
|
|
||||||
export function DeleteFile(arg1:string):Promise<void>;
|
export function DeleteFile(arg1:string):Promise<void>;
|
||||||
|
|
||||||
export function DepCheck():Promise<void>;
|
export function DepCheck():Promise<void>;
|
||||||
|
@ -14,6 +14,10 @@ export function ConvertModel(arg1, arg2, arg3) {
|
|||||||
return window['go']['backend_golang']['App']['ConvertModel'](arg1, arg2, arg3);
|
return window['go']['backend_golang']['App']['ConvertModel'](arg1, arg2, arg3);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function CopyFile(arg1, arg2) {
|
||||||
|
return window['go']['backend_golang']['App']['CopyFile'](arg1, arg2);
|
||||||
|
}
|
||||||
|
|
||||||
export function DeleteFile(arg1) {
|
export function DeleteFile(arg1) {
|
||||||
return window['go']['backend_golang']['App']['DeleteFile'](arg1);
|
return window['go']['backend_golang']['App']['DeleteFile'](arg1);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user