add customCudaFile support
This commit is contained in:
		
							parent
							
								
									65d92d5da1
								
							
						
					
					
						commit
						4eca1537a7
					
				| @ -3,6 +3,7 @@ package backend_golang | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"os" | ||||
| 	"os/exec" | ||||
| 	"path/filepath" | ||||
| @ -92,6 +93,26 @@ func (a *App) DeleteFile(path string) error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (a *App) CopyFile(src string, dst string) error { | ||||
| 	sourceFile, err := os.Open(src) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	defer sourceFile.Close() | ||||
| 
 | ||||
| 	destFile, err := os.Create(dst) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	defer destFile.Close() | ||||
| 
 | ||||
| 	_, err = io.Copy(sourceFile, destFile) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (a *App) OpenFileFolder(path string) error { | ||||
| 	absPath, err := filepath.Abs(path) | ||||
| 	if err != nil { | ||||
|  | ||||
| @ -99,5 +99,7 @@ | ||||
|   "Model Config Exception": "模型配置异常", | ||||
|   "Use Gitee Updates Source": "使用Gitee更新源", | ||||
|   "Use Custom CUDA kernel to Accelerate": "使用自定义CUDA算子加速", | ||||
|   "Enabling this option can greatly improve inference speed, but there may be compatibility issues. If it fails to start, please turn off this option.": "开启这个选项能大大提升推理速度,但可能存在兼容性问题,如果启动失败,请关闭此选项" | ||||
|   "Enabling this option can greatly improve inference speed, but there may be compatibility issues. If it fails to start, please turn off this option.": "开启这个选项能大大提升推理速度,但可能存在兼容性问题,如果启动失败,请关闭此选项", | ||||
|   "Supported custom cuda file not found": "没有找到支持的自定义cuda文件", | ||||
|   "Failed to copy custom cuda file": "自定义cuda文件复制失败" | ||||
| } | ||||
| @ -2,6 +2,7 @@ import React, { FC, MouseEventHandler, ReactElement } from 'react'; | ||||
| import commonStore, { ModelStatus } from '../stores/commonStore'; | ||||
| import { | ||||
|   AddToDownloadList, | ||||
|   CopyFile, | ||||
|   DepCheck, | ||||
|   FileExists, | ||||
|   InstallPyDep, | ||||
| @ -12,7 +13,7 @@ import { observer } from 'mobx-react-lite'; | ||||
| import { exit, getStatus, readRoot, switchModel, updateConfig } from '../apis'; | ||||
| import { toast } from 'react-toastify'; | ||||
| import manifest from '../../../manifest.json'; | ||||
| import { getStrategy, saveCache, toastWithButton } from '../utils'; | ||||
| import { getStrategy, getSupportedCustomCudaFile, saveCache, toastWithButton } from '../utils'; | ||||
| import { useTranslation } from 'react-i18next'; | ||||
| import { ToolTipButton } from './ToolTipButton'; | ||||
| import { Play16Regular, Stop16Regular } from '@fluentui/react-icons'; | ||||
| @ -83,6 +84,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean | ||||
|           return; | ||||
|         } | ||||
|         commonStore.setDepComplete(true); | ||||
|         CopyFile('./backend-python/wkv_cuda_utils/wkv_cuda_model.py', './py310/Lib/site-packages/rwkv/model.py'); | ||||
|         saveCache(); | ||||
|       } | ||||
| 
 | ||||
| @ -132,10 +134,23 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean | ||||
|               presence_penalty: modelConfig.apiParameters.presencePenalty, | ||||
|               frequency_penalty: modelConfig.apiParameters.frequencyPenalty | ||||
|             }); | ||||
| 
 | ||||
|             let customCudaFile = ''; | ||||
|             if (modelConfig.modelParameters.useCustomCuda) { | ||||
|               customCudaFile = getSupportedCustomCudaFile(); | ||||
|               if (customCudaFile) | ||||
|                 await CopyFile(customCudaFile, './py310/Lib/site-packages/rwkv/wkv_cuda.pyd').catch(() => { | ||||
|                   customCudaFile = ''; | ||||
|                   toast(t('Failed to copy custom cuda file'), { type: 'error' }); | ||||
|                 }); | ||||
|               else | ||||
|                 toast(t('Supported custom cuda file not found'), { type: 'warning' }); | ||||
|             } | ||||
| 
 | ||||
|             switchModel({ | ||||
|               model: `${manifest.localModelDir}/${modelConfig.modelParameters.modelName}`, | ||||
|               strategy: getStrategy(modelConfig), | ||||
|               customCuda: !!modelConfig.modelParameters.useCustomCuda | ||||
|               customCuda: customCudaFile !== '' | ||||
|             }).then((r) => { | ||||
|               if (r.ok) { | ||||
|                 commonStore.setStatus({ modelStatus: ModelStatus.Working }); | ||||
|  | ||||
| @ -267,4 +267,13 @@ export function toastWithButton(text: string, buttonText: string, onClickButton: | ||||
|       type: 'info', | ||||
|       ...options | ||||
|     }); | ||||
| } | ||||
| 
 | ||||
| export function getSupportedCustomCudaFile() { | ||||
|   if ([' 10', ' 20', ' 30'].some(v => commonStore.status.device_name.includes(v))) | ||||
|     return './backend-python/wkv_cuda_utils/wkv_cuda10_30.pyd'; | ||||
|   else if ([' 40'].some(v => commonStore.status.device_name.includes(v))) | ||||
|     return './backend-python/wkv_cuda_utils/wkv_cuda40.pyd'; | ||||
|   else | ||||
|     return ''; | ||||
| } | ||||
							
								
								
									
										2
									
								
								frontend/wailsjs/go/backend_golang/App.d.ts
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								frontend/wailsjs/go/backend_golang/App.d.ts
									
									
									
									
										vendored
									
									
								
							| @ -8,6 +8,8 @@ export function ContinueDownload(arg1:string):Promise<void>; | ||||
| 
 | ||||
| export function ConvertModel(arg1:string,arg2:string,arg3:string):Promise<string>; | ||||
| 
 | ||||
| export function CopyFile(arg1:string,arg2:string):Promise<void>; | ||||
| 
 | ||||
| export function DeleteFile(arg1:string):Promise<void>; | ||||
| 
 | ||||
| export function DepCheck():Promise<void>; | ||||
|  | ||||
| @ -14,6 +14,10 @@ export function ConvertModel(arg1, arg2, arg3) { | ||||
|   return window['go']['backend_golang']['App']['ConvertModel'](arg1, arg2, arg3); | ||||
| } | ||||
| 
 | ||||
| export function CopyFile(arg1, arg2) { | ||||
|   return window['go']['backend_golang']['App']['CopyFile'](arg1, arg2); | ||||
| } | ||||
| 
 | ||||
| export function DeleteFile(arg1) { | ||||
|   return window['go']['backend_golang']['App']['DeleteFile'](arg1); | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user