improve lora finetune process (need to be refactored)

This commit is contained in:
josc146
2023-07-03 21:40:16 +08:00
parent 134b2884e6
commit 76761ee453
8 changed files with 160 additions and 112 deletions

View File

@@ -17,7 +17,7 @@ import { toast } from 'react-toastify';
import commonStore from '../stores/commonStore';
import { observer } from 'mobx-react-lite';
import { SelectTabEventHandler } from '@fluentui/react-tabs';
import { refreshLocalModels, toastWithButton } from '../utils';
import { checkDependencies, refreshLocalModels, toastWithButton } from '../utils';
import { Section } from '../components/Section';
import { Labeled } from '../components/Labeled';
import { ToolTipButton } from '../components/ToolTipButton';
@@ -36,6 +36,7 @@ import {
} from 'chart.js';
import { Line } from 'react-chartjs-2';
import { ChartJSOrUndefined } from 'react-chartjs-2/dist/types';
import { WindowShow } from '../../wailsjs/runtime';
ChartJS.register(
CategoryScale,
@@ -187,10 +188,10 @@ const Terminal: FC = observer(() => {
WslStart().then(() => {
addWslMessage('WSL> ' + input);
setInput('');
WslCommand(input).catch((e) => {
WslCommand(input).catch((e: any) => {
toast((e.message || e), { type: 'error' });
});
}).catch((e) => {
}).catch((e: any) => {
toast((e.message || e), { type: 'error' });
});
}
@@ -207,7 +208,7 @@ const Terminal: FC = observer(() => {
<Button onClick={() => {
WslStop().then(() => {
toast(t('Command Stopped'), { type: 'success' });
}).catch((e) => {
}).catch((e: any) => {
toast((e.message || e), { type: 'error' });
});
}}>
@@ -250,13 +251,28 @@ const LoraFinetune: FC = observer(() => {
});
}, []);
const StartLoraFinetune = () => {
const StartLoraFinetune = async () => {
const ok = await checkDependencies(navigate);
if (!ok)
return;
const convertedDataPath = `./finetune/json2binidx_tool/data/${dataParams.dataPath.split('/').pop()!.split('.')[0]}_text_document`;
if (!await FileExists(convertedDataPath + '.idx')) {
toast(t('Please convert data first.'), { type: 'error' });
return;
}
WslIsEnabled().then(() => {
WslStart().then(async () => {
const convertedDataPath = `./finetune/json2binidx_tool/data/${dataParams.dataPath.split('/').pop()!.split('.')[0]}_text_document`;
if (!await FileExists(convertedDataPath + '.idx')) {
toast(t('Please convert data first.'), { type: 'error' });
return;
WslStart().then(() => {
setTimeout(WindowShow, 1000);
let ctxLen = loraParams.ctxLen;
if (dataParams.dataPath === 'finetune/data/sample.jsonl') {
ctxLen = 150;
toast(t('You are using sample data for training. For formal training, please make sure to create your own jsonl file.'), {
type: 'info',
autoClose: 6000
});
}
commonStore.setChartData({
@@ -277,7 +293,7 @@ const LoraFinetune: FC = observer(() => {
(loraParams.loraLoad ? `--lora_load lora-models/${loraParams.loraLoad} ` : '') +
`--data_file ${convertedDataPath} ` +
`--vocab_size ${loraParams.baseModel.toLowerCase().includes('world') ? '65536' : '50277'} ` +
`--ctx_len ${loraParams.ctxLen} --epoch_steps ${loraParams.epochSteps} --epoch_count ${loraParams.epochCount} ` +
`--ctx_len ${ctxLen} --epoch_steps ${loraParams.epochSteps} --epoch_count ${loraParams.epochCount} ` +
`--epoch_begin ${loraParams.epochBegin} --epoch_save ${loraParams.epochSave} ` +
`--micro_bsz ${loraParams.microBsz} --accumulate_grad_batches ${loraParams.accumGradBatches} ` +
`--pre_ffn ${loraParams.preFfn ? '1' : '0'} --head_qk ${loraParams.headQk ? '1' : '0'} --lr_init ${loraParams.lrInit} --lr_final ${loraParams.lrFinal} ` +
@@ -285,15 +301,20 @@ const LoraFinetune: FC = observer(() => {
`--beta1 ${loraParams.beta1} --beta2 ${loraParams.beta2} --adam_eps ${loraParams.adamEps} ` +
`--devices ${loraParams.devices} --precision ${loraParams.precision} ` +
`--grad_cp ${loraParams.gradCp ? '1' : '0'} ` +
`--lora_r ${loraParams.loraR} --lora_alpha ${loraParams.loraAlpha} --lora_dropout ${loraParams.loraDropout}`).catch((e) => {
`--lora_r ${loraParams.loraR} --lora_alpha ${loraParams.loraAlpha} --lora_dropout ${loraParams.loraDropout}`).catch((e: any) => {
toast((e.message || e), { type: 'error' });
});
}).catch(e => {
const msg = e.message || e;
if (msg === 'ubuntu not found') {
WindowShow();
toastWithButton(t('Ubuntu is not installed, do you want to install it?'), t('Install Ubuntu'), () => {
WslInstallUbuntu().then(() => {
toast(t('Please install Ubuntu using Microsoft Store'), { type: 'info', autoClose: 6000 });
WindowShow();
toast(t('Please install Ubuntu using Microsoft Store, after installation click the Open button in Microsoft Store and then click the Train button'), {
type: 'info',
autoClose: 10000
});
});
});
}
@@ -302,8 +323,10 @@ const LoraFinetune: FC = observer(() => {
const msg = e.message || e;
const enableWsl = (forceMode: boolean) => {
WindowShow();
toastWithButton(t('WSL is not enabled, do you want to enable it?'), t('Enable WSL'), () => {
WslEnable(forceMode).then(() => {
WindowShow();
toast(t('After installation, please restart your computer to enable WSL'), {
type: 'info',
autoClose: false
@@ -380,7 +403,7 @@ const LoraFinetune: FC = observer(() => {
'./finetune/json2binidx_tool/data/' + dataParams.dataPath.split('/').pop()!.split('.')[0],
dataParams.vocabPath).then(() => {
toast(t('Convert Data successfully'), { type: 'success' });
}).catch((e) => {
}).catch((e: any) => {
toast((e.message || e), { type: 'error' });
});
}}>{t('Convert')}</Button>
@@ -424,15 +447,22 @@ const LoraFinetune: FC = observer(() => {
<option key={index} value={name}>{name}</option>
)}
</Select>
<Button onClick={() => {
MergeLora(commonStore.settings.customPythonPath, true, loraParams.loraAlpha,
'models/' + loraParams.baseModel, 'lora-models/' + loraParams.loraLoad,
`models/${loraParams.baseModel}-LoRA-${loraParams.loraLoad}`).then(() => {
toast(t('Merge model successfully'), { type: 'success' });
refreshLocalModels({ models: commonStore.modelSourceList }, false);
}).catch((e) => {
toast((e.message || e), { type: 'error' });
});
<Button onClick={async () => {
const ok = await checkDependencies(navigate);
if (!ok)
return;
if (loraParams.loraLoad) {
MergeLora(commonStore.settings.customPythonPath, true, loraParams.loraAlpha,
'models/' + loraParams.baseModel, 'lora-models/' + loraParams.loraLoad,
`models/${loraParams.baseModel}-LoRA-${loraParams.loraLoad}`).then(() => {
toast(t('Merge model successfully'), { type: 'success' });
refreshLocalModels({ models: commonStore.modelSourceList }, false);
}).catch((e: any) => {
toast((e.message || e), { type: 'error' });
});
} else {
toast(t('Please select a LoRA model'), { type: 'info' });
}
}}>{t('Merge Model')}</Button>
</div>
{
@@ -491,7 +521,7 @@ const LoraFinetune: FC = observer(() => {
<Button appearance="secondary" size="large" onClick={() => {
WslStop().then(() => {
toast(t('Command Stopped'), { type: 'success' });
}).catch((e) => {
}).catch((e: any) => {
toast((e.message || e), { type: 'error' });
});
}}>{t('Stop')}</Button>