improve lora finetune process (need to be refactored)
```diff
@@ -17,7 +17,7 @@ import { toast } from 'react-toastify';
 import commonStore from '../stores/commonStore';
 import { observer } from 'mobx-react-lite';
 import { SelectTabEventHandler } from '@fluentui/react-tabs';
-import { refreshLocalModels, toastWithButton } from '../utils';
+import { checkDependencies, refreshLocalModels, toastWithButton } from '../utils';
 import { Section } from '../components/Section';
 import { Labeled } from '../components/Labeled';
 import { ToolTipButton } from '../components/ToolTipButton';
@@ -36,6 +36,7 @@ import {
 } from 'chart.js';
 import { Line } from 'react-chartjs-2';
 import { ChartJSOrUndefined } from 'react-chartjs-2/dist/types';
+import { WindowShow } from '../../wailsjs/runtime';
 
 ChartJS.register(
   CategoryScale,
@@ -187,10 +188,10 @@ const Terminal: FC = observer(() => {
       WslStart().then(() => {
         addWslMessage('WSL> ' + input);
         setInput('');
-        WslCommand(input).catch((e) => {
+        WslCommand(input).catch((e: any) => {
           toast((e.message || e), { type: 'error' });
         });
-      }).catch((e) => {
+      }).catch((e: any) => {
         toast((e.message || e), { type: 'error' });
       });
     }
```
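Every rejection handler in this hunk (and in the ones below) has the same body, `toast((e.message || e), { type: 'error' })`; the `(e: any)` annotations just spell out the handler's parameter type. Since the commit title already flags the code for refactoring, a shared handler is one obvious follow-up. A minimal sketch, with the `toastError` name being hypothetical and not part of this commit:

```typescript
import { toast } from 'react-toastify';

// Hypothetical helper (not in this commit): turn a rejected value, whether an
// Error or a plain string, into an error toast, as every catch block here does.
const toastError = (e: any) => {
  toast((e.message || e), { type: 'error' });
};

// Usage sketch with the Wails bindings used in this file:
// WslCommand(input).catch(toastError);
// WslStart().then(() => { /* ... */ }).catch(toastError);
```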
```diff
@@ -207,7 +208,7 @@ const Terminal: FC = observer(() => {
       <Button onClick={() => {
         WslStop().then(() => {
           toast(t('Command Stopped'), { type: 'success' });
-        }).catch((e) => {
+        }).catch((e: any) => {
           toast((e.message || e), { type: 'error' });
         });
       }}>
@@ -250,13 +251,28 @@ const LoraFinetune: FC = observer(() => {
     });
   }, []);
 
-  const StartLoraFinetune = () => {
+  const StartLoraFinetune = async () => {
+    const ok = await checkDependencies(navigate);
+    if (!ok)
+      return;
+
+    const convertedDataPath = `./finetune/json2binidx_tool/data/${dataParams.dataPath.split('/').pop()!.split('.')[0]}_text_document`;
+    if (!await FileExists(convertedDataPath + '.idx')) {
+      toast(t('Please convert data first.'), { type: 'error' });
+      return;
+    }
+
     WslIsEnabled().then(() => {
-      WslStart().then(async () => {
-        const convertedDataPath = `./finetune/json2binidx_tool/data/${dataParams.dataPath.split('/').pop()!.split('.')[0]}_text_document`;
-        if (!await FileExists(convertedDataPath + '.idx')) {
-          toast(t('Please convert data first.'), { type: 'error' });
-          return;
+      WslStart().then(() => {
+        setTimeout(WindowShow, 1000);
+
+        let ctxLen = loraParams.ctxLen;
+        if (dataParams.dataPath === 'finetune/data/sample.jsonl') {
+          ctxLen = 150;
+          toast(t('You are using sample data for training. For formal training, please make sure to create your own jsonl file.'), {
+            type: 'info',
+            autoClose: 6000
+          });
         }
 
         commonStore.setChartData({
```
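The reworked `StartLoraFinetune` now validates its preconditions before touching WSL: it awaits `checkDependencies(navigate)`, then checks that the converted training data exists, and only then enables and starts WSL. The converted path is derived from the data file's base name; a standalone sketch of that derivation, using the sample path that appears in this diff:

```typescript
// Sketch of the path derivation used above: take the data file's base name
// (without extension) and point at the json2binidx output prefix.
const convertedDataPrefix = (dataPath: string): string =>
  `./finetune/json2binidx_tool/data/${dataPath.split('/').pop()!.split('.')[0]}_text_document`;

// 'finetune/data/sample.jsonl' -> './finetune/json2binidx_tool/data/sample_text_document'
// The guard in the diff then requires `${prefix}.idx` to exist before training starts.
console.log(convertedDataPrefix('finetune/data/sample.jsonl'));
```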
```diff
@@ -277,7 +293,7 @@ const LoraFinetune: FC = observer(() => {
         (loraParams.loraLoad ? `--lora_load lora-models/${loraParams.loraLoad} ` : '') +
         `--data_file ${convertedDataPath} ` +
         `--vocab_size ${loraParams.baseModel.toLowerCase().includes('world') ? '65536' : '50277'} ` +
-        `--ctx_len ${loraParams.ctxLen} --epoch_steps ${loraParams.epochSteps} --epoch_count ${loraParams.epochCount} ` +
+        `--ctx_len ${ctxLen} --epoch_steps ${loraParams.epochSteps} --epoch_count ${loraParams.epochCount} ` +
         `--epoch_begin ${loraParams.epochBegin} --epoch_save ${loraParams.epochSave} ` +
         `--micro_bsz ${loraParams.microBsz} --accumulate_grad_batches ${loraParams.accumGradBatches} ` +
         `--pre_ffn ${loraParams.preFfn ? '1' : '0'} --head_qk ${loraParams.headQk ? '1' : '0'} --lr_init ${loraParams.lrInit} --lr_final ${loraParams.lrFinal} ` +
@@ -285,15 +301,20 @@ const LoraFinetune: FC = observer(() => {
         `--beta1 ${loraParams.beta1} --beta2 ${loraParams.beta2} --adam_eps ${loraParams.adamEps} ` +
         `--devices ${loraParams.devices} --precision ${loraParams.precision} ` +
         `--grad_cp ${loraParams.gradCp ? '1' : '0'} ` +
-        `--lora_r ${loraParams.loraR} --lora_alpha ${loraParams.loraAlpha} --lora_dropout ${loraParams.loraDropout}`).catch((e) => {
+        `--lora_r ${loraParams.loraR} --lora_alpha ${loraParams.loraAlpha} --lora_dropout ${loraParams.loraDropout}`).catch((e: any) => {
           toast((e.message || e), { type: 'error' });
         });
       }).catch(e => {
         const msg = e.message || e;
         if (msg === 'ubuntu not found') {
+          WindowShow();
           toastWithButton(t('Ubuntu is not installed, do you want to install it?'), t('Install Ubuntu'), () => {
             WslInstallUbuntu().then(() => {
-              toast(t('Please install Ubuntu using Microsoft Store'), { type: 'info', autoClose: 6000 });
+              WindowShow();
+              toast(t('Please install Ubuntu using Microsoft Store, after installation click the Open button in Microsoft Store and then click the Train button'), {
+                type: 'info',
+                autoClose: 10000
+              });
             });
           });
         }
@@ -302,8 +323,10 @@ const LoraFinetune: FC = observer(() => {
       const msg = e.message || e;
 
       const enableWsl = (forceMode: boolean) => {
+        WindowShow();
         toastWithButton(t('WSL is not enabled, do you want to enable it?'), t('Enable WSL'), () => {
           WslEnable(forceMode).then(() => {
+            WindowShow();
             toast(t('After installation, please restart your computer to enable WSL'), {
               type: 'info',
               autoClose: false
@@ -380,7 +403,7 @@ const LoraFinetune: FC = observer(() => {
               './finetune/json2binidx_tool/data/' + dataParams.dataPath.split('/').pop()!.split('.')[0],
               dataParams.vocabPath).then(() => {
               toast(t('Convert Data successfully'), { type: 'success' });
-            }).catch((e) => {
+            }).catch((e: any) => {
               toast((e.message || e), { type: 'error' });
             });
           }}>{t('Convert')}</Button>
@@ -424,15 +447,22 @@ const LoraFinetune: FC = observer(() => {
               <option key={index} value={name}>{name}</option>
             )}
           </Select>
-          <Button onClick={() => {
-            MergeLora(commonStore.settings.customPythonPath, true, loraParams.loraAlpha,
-              'models/' + loraParams.baseModel, 'lora-models/' + loraParams.loraLoad,
-              `models/${loraParams.baseModel}-LoRA-${loraParams.loraLoad}`).then(() => {
-              toast(t('Merge model successfully'), { type: 'success' });
-              refreshLocalModels({ models: commonStore.modelSourceList }, false);
-            }).catch((e) => {
-              toast((e.message || e), { type: 'error' });
-            });
+          <Button onClick={async () => {
+            const ok = await checkDependencies(navigate);
+            if (!ok)
+              return;
+            if (loraParams.loraLoad) {
+              MergeLora(commonStore.settings.customPythonPath, true, loraParams.loraAlpha,
+                'models/' + loraParams.baseModel, 'lora-models/' + loraParams.loraLoad,
+                `models/${loraParams.baseModel}-LoRA-${loraParams.loraLoad}`).then(() => {
+                toast(t('Merge model successfully'), { type: 'success' });
+                refreshLocalModels({ models: commonStore.modelSourceList }, false);
+              }).catch((e: any) => {
+                toast((e.message || e), { type: 'error' });
+              });
+            } else {
+              toast(t('Please select a LoRA model'), { type: 'info' });
+            }
           }}>{t('Merge Model')}</Button>
         </div>
         {
```
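The Train and Merge Model handlers now share the same preamble: `await checkDependencies(navigate)` followed by an early return when it fails. Given the "need to be refactored" note in the commit title, that duplication could later be folded into a small wrapper; a sketch under the assumption that `checkDependencies` resolves to a boolean as its usage above suggests, with `withDependencies` and `startMerge` as hypothetical names:

```typescript
import { checkDependencies } from '../utils';

// Hypothetical wrapper (not in this commit): run an action only after
// checkDependencies reports that the required tools are available.
const withDependencies = (navigate: any, action: () => void | Promise<void>) =>
  async () => {
    const ok = await checkDependencies(navigate);
    if (!ok)
      return;
    await action();
  };

// Usage sketch:
// <Button onClick={withDependencies(navigate, startMerge)}>{t('Merge Model')}</Button>
```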
```diff
@@ -491,7 +521,7 @@ const LoraFinetune: FC = observer(() => {
         <Button appearance="secondary" size="large" onClick={() => {
           WslStop().then(() => {
             toast(t('Command Stopped'), { type: 'success' });
-          }).catch((e) => {
+          }).catch((e: any) => {
             toast((e.message || e), { type: 'error' });
           });
         }}>{t('Stop')}</Button>
```