diff --git a/backend-golang/rwkv.go b/backend-golang/rwkv.go index c399738..e47dde1 100644 --- a/backend-golang/rwkv.go +++ b/backend-golang/rwkv.go @@ -113,3 +113,11 @@ func (a *App) InstallPyDep(python string, cnMirror bool) (string, error) { return Cmd(python, "-m", "pip", "install", "-r", "./backend-python/requirements_without_cyac.txt") } } + +func (a *App) GetPyError() string { + content, err := os.ReadFile("./error.txt") + if err != nil { + return "" + } + return string(content) +} diff --git a/backend-python/convert_model.py b/backend-python/convert_model.py index 66661bd..bfba74b 100644 --- a/backend-python/convert_model.py +++ b/backend-python/convert_model.py @@ -219,13 +219,17 @@ def get_args(): return p.parse_args() -args = get_args() -if not args.quiet: - print(f"** {args}") +try: + args = get_args() + if not args.quiet: + print(f"** {args}") -RWKV( - getattr(args, "in"), - args.strategy, - verbose=not args.quiet, - convert_and_save_and_exit=args.out, -) + RWKV( + getattr(args, "in"), + args.strategy, + verbose=not args.quiet, + convert_and_save_and_exit=args.out, + ) +except Exception as e: + with open("error.txt", "w") as f: + f.write(str(e)) diff --git a/finetune/json2binidx_tool/tools/preprocess_data.py b/finetune/json2binidx_tool/tools/preprocess_data.py index b276494..932bb21 100644 --- a/finetune/json2binidx_tool/tools/preprocess_data.py +++ b/finetune/json2binidx_tool/tools/preprocess_data.py @@ -243,4 +243,8 @@ def main(): if __name__ == "__main__": - main() + try: + main() + except Exception as e: + with open("error.txt", "w") as f: + f.write(str(e)) diff --git a/finetune/lora/merge_lora.py b/finetune/lora/merge_lora.py index e43141b..f2be275 100644 --- a/finetune/lora/merge_lora.py +++ b/finetune/lora/merge_lora.py @@ -5,49 +5,64 @@ from typing import Dict import typing import torch -if '-h' in sys.argv or '--help' in sys.argv: - print(f'Usage: python3 {sys.argv[0]} [--use-gpu] ') +try: + if "-h" in sys.argv or "--help" in sys.argv: + print( + f"Usage: python3 {sys.argv[0]} [--use-gpu] " + ) -if sys.argv[1] == '--use-gpu': - device = 'cuda' - lora_alpha, base_model, lora, output = float(sys.argv[2]), sys.argv[3], sys.argv[4], sys.argv[5] -else: - device = 'cpu' - lora_alpha, base_model, lora, output = float(sys.argv[1]), sys.argv[2], sys.argv[3], sys.argv[4] + if sys.argv[1] == "--use-gpu": + device = "cuda" + lora_alpha, base_model, lora, output = ( + float(sys.argv[2]), + sys.argv[3], + sys.argv[4], + sys.argv[5], + ) + else: + device = "cpu" + lora_alpha, base_model, lora, output = ( + float(sys.argv[1]), + sys.argv[2], + sys.argv[3], + sys.argv[4], + ) + with torch.no_grad(): + w: Dict[str, torch.Tensor] = torch.load(base_model, map_location="cpu") + # merge LoRA-only slim checkpoint into the main weights + w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location="cpu") + for k in w_lora.keys(): + w[k] = w_lora[k] + output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict() + # merge LoRA weights + keys = list(w.keys()) + for k in keys: + if k.endswith(".weight"): + prefix = k[: -len(".weight")] + lora_A = prefix + ".lora_A" + lora_B = prefix + ".lora_B" + if lora_A in keys: + assert lora_B in keys + print(f"merging {lora_A} and {lora_B} into {k}") + assert w[lora_B].shape[1] == w[lora_A].shape[0] + lora_r = w[lora_B].shape[1] + w[k] = w[k].to(device=device) + w[lora_A] = w[lora_A].to(device=device) + w[lora_B] = w[lora_B].to(device=device) + w[k] += w[lora_B] @ w[lora_A] * (lora_alpha / lora_r) + output_w[k] = w[k].to(device="cpu", copy=True) + del w[k] + del w[lora_A] + del w[lora_B] + continue -with torch.no_grad(): - w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu') - # merge LoRA-only slim checkpoint into the main weights - w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu') - for k in w_lora.keys(): - w[k] = w_lora[k] - output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict() - # merge LoRA weights - keys = list(w.keys()) - for k in keys: - if k.endswith('.weight'): - prefix = k[:-len('.weight')] - lora_A = prefix + '.lora_A' - lora_B = prefix + '.lora_B' - if lora_A in keys: - assert lora_B in keys - print(f'merging {lora_A} and {lora_B} into {k}') - assert w[lora_B].shape[1] == w[lora_A].shape[0] - lora_r = w[lora_B].shape[1] - w[k] = w[k].to(device=device) - w[lora_A] = w[lora_A].to(device=device) - w[lora_B] = w[lora_B].to(device=device) - w[k] += w[lora_B] @ w[lora_A] * (lora_alpha / lora_r) - output_w[k] = w[k].to(device='cpu', copy=True) + if "lora" not in k: + print(f"retaining {k}") + output_w[k] = w[k].clone() del w[k] - del w[lora_A] - del w[lora_B] - continue - if 'lora' not in k: - print(f'retaining {k}') - output_w[k] = w[k].clone() - del w[k] - - torch.save(output_w, output) + torch.save(output_w, output) +except Exception as e: + with open("error.txt", "w") as f: + f.write(str(e)) diff --git a/frontend/src/_locales/zh-hans/main.json b/frontend/src/_locales/zh-hans/main.json index 13b0c6b..f715244 100644 --- a/frontend/src/_locales/zh-hans/main.json +++ b/frontend/src/_locales/zh-hans/main.json @@ -229,5 +229,7 @@ "VRAM is not enough": "显存不足", "Training data is not enough, reduce context length or add more data for training": "训练数据不足,请减小上下文长度或增加训练数据", "You are using WSL 1 for training, please upgrade to WSL 2. e.g. Run \"wsl --set-version Ubuntu-22.04 2\"": "你正在使用WSL 1进行训练,请升级到WSL 2。例如,运行\"wsl --set-version Ubuntu-22.04 2\"", - "Matched CUDA is not installed": "未安装匹配的CUDA" + "Matched CUDA is not installed": "未安装匹配的CUDA", + "Failed to convert data": "数据转换失败", + "Failed to merge model": "合并模型失败" } \ No newline at end of file diff --git a/frontend/src/pages/Configs.tsx b/frontend/src/pages/Configs.tsx index ac77a65..b06dd7d 100644 --- a/frontend/src/pages/Configs.tsx +++ b/frontend/src/pages/Configs.tsx @@ -13,8 +13,8 @@ import { Page } from '../components/Page'; import { useNavigate } from 'react-router'; import { RunButton } from '../components/RunButton'; import { updateConfig } from '../apis'; -import { ConvertModel, FileExists } from '../../wailsjs/go/backend_golang/App'; -import { getStrategy, refreshLocalModels } from '../utils'; +import { ConvertModel, FileExists, GetPyError } from '../../wailsjs/go/backend_golang/App'; +import { getStrategy } from '../utils'; import { useTranslation } from 'react-i18next'; import { WindowShow } from '../../wailsjs/runtime/runtime'; import strategyImg from '../assets/images/strategy.jpg'; @@ -253,9 +253,12 @@ export const Configs: FC = observer(() => { const strategy = getStrategy(selectedConfig); const newModelPath = modelPath + '-' + strategy.replace(/[:> *+]/g, '-'); toast(t('Start Converting'), { autoClose: 1000, type: 'info' }); - ConvertModel(commonStore.settings.customPythonPath, modelPath, strategy, newModelPath).then(() => { - toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' }); - refreshLocalModels({ models: commonStore.modelSourceList }, false); + ConvertModel(commonStore.settings.customPythonPath, modelPath, strategy, newModelPath).then(async () => { + if (!await FileExists(newModelPath)) { + toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' }); + } else { + toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' }); + } }).catch(e => { const errMsg = e.message || e; if (errMsg.includes('path contains space')) diff --git a/frontend/src/pages/Train.tsx b/frontend/src/pages/Train.tsx index 8a20d1e..ce216a0 100644 --- a/frontend/src/pages/Train.tsx +++ b/frontend/src/pages/Train.tsx @@ -4,6 +4,7 @@ import { Button, Dropdown, Input, Option, Select, Switch, Tab, TabList } from '@ import { ConvertData, FileExists, + GetPyError, MergeLora, OpenFileFolder, WslCommand, @@ -17,7 +18,7 @@ import { toast } from 'react-toastify'; import commonStore from '../stores/commonStore'; import { observer } from 'mobx-react-lite'; import { SelectTabEventHandler } from '@fluentui/react-tabs'; -import { checkDependencies, refreshLocalModels, toastWithButton } from '../utils'; +import { checkDependencies, toastWithButton } from '../utils'; import { Section } from '../components/Section'; import { Labeled } from '../components/Labeled'; import { ToolTipButton } from '../components/ToolTipButton'; @@ -421,10 +422,13 @@ const LoraFinetune: FC = observer(() => { const ok = await checkDependencies(navigate); if (!ok) return; - ConvertData(commonStore.settings.customPythonPath, dataParams.dataPath, - './finetune/json2binidx_tool/data/' + dataParams.dataPath.split(/[\/\\]/).pop()!.split('.')[0], - dataParams.vocabPath).then(() => { - toast(t('Convert Data successfully'), { type: 'success' }); + const outputPrefix = './finetune/json2binidx_tool/data/' + dataParams.dataPath.split(/[\/\\]/).pop()!.split('.')[0]; + ConvertData(commonStore.settings.customPythonPath, dataParams.dataPath, outputPrefix, dataParams.vocabPath).then(async () => { + if (!await FileExists(outputPrefix + '_text_document.idx')) { + toast(t('Failed to convert data') + ' - ' + await GetPyError(), { type: 'error' }); + } else { + toast(t('Convert Data successfully'), { type: 'success' }); + } }).catch(showError); }}>{t('Convert')} @@ -472,11 +476,15 @@ const LoraFinetune: FC = observer(() => { if (!ok) return; if (loraParams.loraLoad) { + const outputPath = `models/${loraParams.baseModel}-LoRA-${loraParams.loraLoad}`; MergeLora(commonStore.settings.customPythonPath, true, loraParams.loraAlpha, 'models/' + loraParams.baseModel, 'lora-models/' + loraParams.loraLoad, - `models/${loraParams.baseModel}-LoRA-${loraParams.loraLoad}`).then(() => { - toast(t('Merge model successfully'), { type: 'success' }); - refreshLocalModels({ models: commonStore.modelSourceList }, false); + outputPath).then(async () => { + if (!await FileExists(outputPath)) { + toast(t('Failed to merge model') + ' - ' + await GetPyError(), { type: 'error' }); + } else { + toast(t('Merge model successfully'), { type: 'success' }); + } }).catch(showError); } else { toast(t('Please select a LoRA model'), { type: 'info' }); diff --git a/frontend/wailsjs/go/backend_golang/App.d.ts b/frontend/wailsjs/go/backend_golang/App.d.ts index 172529a..024ba7a 100755 --- a/frontend/wailsjs/go/backend_golang/App.d.ts +++ b/frontend/wailsjs/go/backend_golang/App.d.ts @@ -22,6 +22,8 @@ export function FileExists(arg1:string):Promise; export function GetPlatform():Promise; +export function GetPyError():Promise; + export function InstallPyDep(arg1:string,arg2:boolean):Promise; export function ListDirFiles(arg1:string):Promise>; diff --git a/frontend/wailsjs/go/backend_golang/App.js b/frontend/wailsjs/go/backend_golang/App.js index d1c629e..63f79e2 100755 --- a/frontend/wailsjs/go/backend_golang/App.js +++ b/frontend/wailsjs/go/backend_golang/App.js @@ -42,6 +42,10 @@ export function GetPlatform() { return window['go']['backend_golang']['App']['GetPlatform'](); } +export function GetPyError() { + return window['go']['backend_golang']['App']['GetPyError'](); +} + export function InstallPyDep(arg1, arg2) { return window['go']['backend_golang']['App']['InstallPyDep'](arg1, arg2); }