diff --git a/backend-golang/app.go b/backend-golang/app.go index 057d6e5..0a7a1ee 100644 --- a/backend-golang/app.go +++ b/backend-golang/app.go @@ -46,6 +46,7 @@ func (a *App) OnStartup(ctx context.Context) { } os.Chmod(a.exDir+"backend-rust/webgpu_server", 0777) + os.Chmod(a.exDir+"backend-rust/web-rwkv-converter", 0777) os.Mkdir(a.exDir+"models", os.ModePerm) os.Mkdir(a.exDir+"lora-models", os.ModePerm) os.Mkdir(a.exDir+"finetune/json2binidx_tool/data", os.ModePerm) diff --git a/backend-golang/rwkv.go b/backend-golang/rwkv.go index 78e0006..482bd9a 100644 --- a/backend-golang/rwkv.go +++ b/backend-golang/rwkv.go @@ -46,15 +46,10 @@ func (a *App) ConvertModel(python string, modelPath string, strategy string, out return Cmd(python, "./backend-python/convert_model.py", "--in", modelPath, "--out", outPath, "--strategy", strategy) } -func (a *App) ConvertSafetensors(python string, modelPath string, outPath string) (string, error) { - var err error - if python == "" { - python, err = GetPython() - } - if err != nil { - return "", err - } - return Cmd(python, "./backend-python/convert_safetensors.py", "--input", modelPath, "--output", outPath) +func (a *App) ConvertSafetensors(modelPath string, outPath string) (string, error) { + args := []string{"./backend-rust/web-rwkv-converter"} + args = append(args, "--input", modelPath, "--output", outPath) + return Cmd(args...) } func (a *App) ConvertData(python string, input string, outputPrefix string, vocab string) (string, error) { diff --git a/frontend/src/components/RunButton.tsx b/frontend/src/components/RunButton.tsx index e330c62..11d4eb9 100644 --- a/frontend/src/components/RunButton.tsx +++ b/frontend/src/components/RunButton.tsx @@ -90,7 +90,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean return; } else { toastWithButton(t('Please convert model to safe tensors format first'), t('Convert'), () => { - convertToSt(navigate, modelConfig); + convertToSt(modelConfig); }); commonStore.setStatus({ status: ModelStatus.Offline }); return; diff --git a/frontend/src/pages/Configs.tsx b/frontend/src/pages/Configs.tsx index 51ef338..99bf663 100644 --- a/frontend/src/pages/Configs.tsx +++ b/frontend/src/pages/Configs.tsx @@ -273,7 +273,7 @@ const Configs: FC = observer(() => { }} /> : convertToSt(navigate, selectedConfig)} /> + onClick={() => convertToSt(selectedConfig)} /> } { 'models/' + loraParams.baseModel, 'lora-models/' + loraParams.loraLoad, outputPath).then(async () => { if (!await FileExists(outputPath)) { - toast(t('Failed to merge model') + ' - ' + await GetPyError(), { type: 'error' }); + if (commonStore.platform === 'windows' || commonStore.platform === 'linux') + toast(t('Failed to merge model') + ' - ' + await GetPyError(), { type: 'error' }); } else { toast(t('Merge model successfully'), { type: 'success' }); } diff --git a/frontend/src/utils/convert-to-st.ts b/frontend/src/utils/convert-to-st.ts index 5b16ba5..32cf68b 100644 --- a/frontend/src/utils/convert-to-st.ts +++ b/frontend/src/utils/convert-to-st.ts @@ -1,29 +1,18 @@ import { toast } from 'react-toastify'; import commonStore from '../stores/commonStore'; import { t } from 'i18next'; -import { checkDependencies } from './index'; import { ConvertSafetensors, FileExists, GetPyError } from '../../wailsjs/go/backend_golang/App'; import { WindowShow } from '../../wailsjs/runtime'; -import { NavigateFunction } from 'react-router'; import { ModelConfig } from '../types/configs'; -export const convertToSt = async (navigate: NavigateFunction, selectedConfig: ModelConfig) => { - if (commonStore.platform === 'linux') { - toast(t('Linux is not yet supported for performing this operation, please do it manually.') + ' (backend-python/convert_safetensors.py)', { type: 'info' }); - return; - } - - const ok = await checkDependencies(navigate); - if (!ok) - return; - +export const convertToSt = async (selectedConfig: ModelConfig) => { const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`; if (await FileExists(modelPath)) { - toast(t('Start Converting'), { autoClose: 1000, type: 'info' }); + toast(t('Start Converting'), { autoClose: 2000, type: 'info' }); const newModelPath = modelPath.replace(/\.pth$/, '.st'); - ConvertSafetensors(commonStore.settings.customPythonPath, modelPath, newModelPath).then(async () => { + ConvertSafetensors(modelPath, newModelPath).then(async () => { if (!await FileExists(newModelPath)) { - if (commonStore.platform === 'windows') + if (commonStore.platform === 'windows' || commonStore.platform === 'linux') toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' }); } else { toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' }); diff --git a/frontend/wailsjs/go/backend_golang/App.d.ts b/frontend/wailsjs/go/backend_golang/App.d.ts index fe79aaf..de1ff57 100755 --- a/frontend/wailsjs/go/backend_golang/App.d.ts +++ b/frontend/wailsjs/go/backend_golang/App.d.ts @@ -12,7 +12,7 @@ export function ConvertData(arg1:string,arg2:string,arg3:string,arg4:string):Pro export function ConvertModel(arg1:string,arg2:string,arg3:string,arg4:string):Promise; -export function ConvertSafetensors(arg1:string,arg2:string,arg3:string):Promise; +export function ConvertSafetensors(arg1:string,arg2:string):Promise; export function CopyFile(arg1:string,arg2:string):Promise; diff --git a/frontend/wailsjs/go/backend_golang/App.js b/frontend/wailsjs/go/backend_golang/App.js index 9173a12..e5d93c5 100755 --- a/frontend/wailsjs/go/backend_golang/App.js +++ b/frontend/wailsjs/go/backend_golang/App.js @@ -22,8 +22,8 @@ export function ConvertModel(arg1, arg2, arg3, arg4) { return window['go']['backend_golang']['App']['ConvertModel'](arg1, arg2, arg3, arg4); } -export function ConvertSafetensors(arg1, arg2, arg3) { - return window['go']['backend_golang']['App']['ConvertSafetensors'](arg1, arg2, arg3); +export function ConvertSafetensors(arg1, arg2) { + return window['go']['backend_golang']['App']['ConvertSafetensors'](arg1, arg2); } export function CopyFile(arg1, arg2) {