diff --git a/backend-golang/file.go b/backend-golang/file.go index 7e54b8b..02ebe34 100644 --- a/backend-golang/file.go +++ b/backend-golang/file.go @@ -3,6 +3,7 @@ package backend_golang import ( "encoding/json" "fmt" + "io" "os" "os/exec" "path/filepath" @@ -92,6 +93,26 @@ func (a *App) DeleteFile(path string) error { return nil } +func (a *App) CopyFile(src string, dst string) error { + sourceFile, err := os.Open(src) + if err != nil { + return err + } + defer sourceFile.Close() + + destFile, err := os.Create(dst) + if err != nil { + return err + } + defer destFile.Close() + + _, err = io.Copy(sourceFile, destFile) + if err != nil { + return err + } + return nil +} + func (a *App) OpenFileFolder(path string) error { absPath, err := filepath.Abs(path) if err != nil { diff --git a/frontend/src/_locales/zh-hans/main.json b/frontend/src/_locales/zh-hans/main.json index 5d5bcf8..4a9c33e 100644 --- a/frontend/src/_locales/zh-hans/main.json +++ b/frontend/src/_locales/zh-hans/main.json @@ -99,5 +99,7 @@ "Model Config Exception": "模型配置异常", "Use Gitee Updates Source": "使用Gitee更新源", "Use Custom CUDA kernel to Accelerate": "使用自定义CUDA算子加速", - "Enabling this option can greatly improve inference speed, but there may be compatibility issues. If it fails to start, please turn off this option.": "开启这个选项能大大提升推理速度,但可能存在兼容性问题,如果启动失败,请关闭此选项" + "Enabling this option can greatly improve inference speed, but there may be compatibility issues. If it fails to start, please turn off this option.": "开启这个选项能大大提升推理速度,但可能存在兼容性问题,如果启动失败,请关闭此选项", + "Supported custom cuda file not found": "没有找到支持的自定义cuda文件", + "Failed to copy custom cuda file": "自定义cuda文件复制失败" } \ No newline at end of file diff --git a/frontend/src/components/RunButton.tsx b/frontend/src/components/RunButton.tsx index aa94f2d..07d5474 100644 --- a/frontend/src/components/RunButton.tsx +++ b/frontend/src/components/RunButton.tsx @@ -2,6 +2,7 @@ import React, { FC, MouseEventHandler, ReactElement } from 'react'; import commonStore, { ModelStatus } from '../stores/commonStore'; import { AddToDownloadList, + CopyFile, DepCheck, FileExists, InstallPyDep, @@ -12,7 +13,7 @@ import { observer } from 'mobx-react-lite'; import { exit, getStatus, readRoot, switchModel, updateConfig } from '../apis'; import { toast } from 'react-toastify'; import manifest from '../../../manifest.json'; -import { getStrategy, saveCache, toastWithButton } from '../utils'; +import { getStrategy, getSupportedCustomCudaFile, saveCache, toastWithButton } from '../utils'; import { useTranslation } from 'react-i18next'; import { ToolTipButton } from './ToolTipButton'; import { Play16Regular, Stop16Regular } from '@fluentui/react-icons'; @@ -83,6 +84,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean return; } commonStore.setDepComplete(true); + CopyFile('./backend-python/wkv_cuda_utils/wkv_cuda_model.py', './py310/Lib/site-packages/rwkv/model.py'); saveCache(); } @@ -132,10 +134,23 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean presence_penalty: modelConfig.apiParameters.presencePenalty, frequency_penalty: modelConfig.apiParameters.frequencyPenalty }); + + let customCudaFile = ''; + if (modelConfig.modelParameters.useCustomCuda) { + customCudaFile = getSupportedCustomCudaFile(); + if (customCudaFile) + await CopyFile(customCudaFile, './py310/Lib/site-packages/rwkv/wkv_cuda.pyd').catch(() => { + customCudaFile = ''; + toast(t('Failed to copy custom cuda file'), { type: 'error' }); + }); + else + toast(t('Supported custom cuda file not found'), { type: 'warning' }); + } + switchModel({ model: `${manifest.localModelDir}/${modelConfig.modelParameters.modelName}`, strategy: getStrategy(modelConfig), - customCuda: !!modelConfig.modelParameters.useCustomCuda + customCuda: customCudaFile !== '' }).then((r) => { if (r.ok) { commonStore.setStatus({ modelStatus: ModelStatus.Working }); diff --git a/frontend/src/utils/index.tsx b/frontend/src/utils/index.tsx index 3dd3d10..6bcb856 100644 --- a/frontend/src/utils/index.tsx +++ b/frontend/src/utils/index.tsx @@ -267,4 +267,13 @@ export function toastWithButton(text: string, buttonText: string, onClickButton: type: 'info', ...options }); +} + +export function getSupportedCustomCudaFile() { + if ([' 10', ' 20', ' 30'].some(v => commonStore.status.device_name.includes(v))) + return './backend-python/wkv_cuda_utils/wkv_cuda10_30.pyd'; + else if ([' 40'].some(v => commonStore.status.device_name.includes(v))) + return './backend-python/wkv_cuda_utils/wkv_cuda40.pyd'; + else + return ''; } \ No newline at end of file diff --git a/frontend/wailsjs/go/backend_golang/App.d.ts b/frontend/wailsjs/go/backend_golang/App.d.ts index 9cf3c8c..242d251 100644 --- a/frontend/wailsjs/go/backend_golang/App.d.ts +++ b/frontend/wailsjs/go/backend_golang/App.d.ts @@ -8,6 +8,8 @@ export function ContinueDownload(arg1:string):Promise; export function ConvertModel(arg1:string,arg2:string,arg3:string):Promise; +export function CopyFile(arg1:string,arg2:string):Promise; + export function DeleteFile(arg1:string):Promise; export function DepCheck():Promise; diff --git a/frontend/wailsjs/go/backend_golang/App.js b/frontend/wailsjs/go/backend_golang/App.js index 68518cd..159bb67 100644 --- a/frontend/wailsjs/go/backend_golang/App.js +++ b/frontend/wailsjs/go/backend_golang/App.js @@ -14,6 +14,10 @@ export function ConvertModel(arg1, arg2, arg3) { return window['go']['backend_golang']['App']['ConvertModel'](arg1, arg2, arg3); } +export function CopyFile(arg1, arg2) { + return window['go']['backend_golang']['App']['CopyFile'](arg1, arg2); +} + export function DeleteFile(arg1) { return window['go']['backend_golang']['App']['DeleteFile'](arg1); }