import React, { FC, ReactElement, useEffect, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { Button, Dropdown, Input, Option, Select, Switch, Tab, TabList } from '@fluentui/react-components';
import {
  ConvertData,
  FileExists,
  MergeLora,
  OpenFileFolder,
  WslCommand,
  WslEnable,
  WslInstallUbuntu,
  WslIsEnabled,
  WslStart,
  WslStop
} from '../../wailsjs/go/backend_golang/App';
import { toast } from 'react-toastify';
import commonStore from '../stores/commonStore';
import { observer } from 'mobx-react-lite';
import { SelectTabEventHandler } from '@fluentui/react-tabs';
import { checkDependencies, refreshLocalModels, toastWithButton } from '../utils';
import { Section } from '../components/Section';
import { Labeled } from '../components/Labeled';
import { ToolTipButton } from '../components/ToolTipButton';
import { DataUsageSettings20Regular, Folder20Regular } from '@fluentui/react-icons';
import { useNavigate } from 'react-router';
import { Precision } from './Configs';
import { CategoryScale, Chart as ChartJS, Legend, LinearScale, LineElement, PointElement, Title, Tooltip } from 'chart.js';
import { Line } from 'react-chartjs-2';
import { ChartJSOrUndefined } from 'react-chartjs-2/dist/types';
import { WindowShow } from '../../wailsjs/runtime';

ChartJS.register(
  CategoryScale,
  LinearScale,
  PointElement,
  LineElement,
  Tooltip,
  Title,
  Legend
);

/**
 * Matches one training progress-bar update, e.g.
 * `Epoch 0:  50%|█████     | 100/200 [01:40<01:40,  1.00it/s, loss=2.345, v_num=0]`
 * (presumably printed by a tqdm-style progress bar in the training script — confirm).
 *
 * Capture groups: 1 epoch, 2 percent, 3 current step, 4 total steps,
 * 5 elapsed, 6 remaining, 7 rate, 8 loss.
 *
 * - The rate may be `it/s` or `s/it`; tqdm switches to `s/it` when one step
 *   takes longer than a second.
 * - Timings may be `mm:ss` or `h:mm:ss`.
 * - The bar and the trailing postfix are matched with classes that stop at
 *   `|` / `]` and line breaks. Several updates arriving in one chunk therefore
 *   produce separate matches instead of one greedy match mixing their fields.
 */
const lossProgressRegex = /Epoch (\d+):\s+(\d+%)\|[^|\r\n]*\| (\d+)\/(\d+) \[(\d+(?::\d+)+)<(\d+(?::\d+)+),\s+(\d+\.\d+(?:it\/s|s\/it)), loss=(\d+\.\d+)(?:,[^\]\r\n]*)?\]/g;

/**
 * Extracts the most recent progress update from a chunk of WSL output.
 *
 * Publishes that update as the chart title and records its loss for its epoch.
 * Chunks without a recognizable progress update are ignored.
 */
const parseLossData = (data: string) => {
  const matches = Array.from(data.matchAll(lossProgressRegex));
  if (matches.length === 0)
    return;

  // Only the newest update in the chunk is relevant for the chart.
  const [, epochText, percent, step, totalSteps, elapsed, remaining, rate, lossText] = matches[matches.length - 1];
  const epoch = parseInt(epochText, 10);
  const loss = parseFloat(lossText);
  commonStore.setChartTitle(`Epoch ${epoch}: ${percent} - ${step}/${totalSteps} - ${elapsed}/${remaining} - ${rate} Loss=${loss}`);
  addLossDataToChart(epoch, loss);
};

/**
 * Chart instance of the mounted loss chart.
 *
 * It is captured from the LoraFinetune component's ref during render and is
 * undefined until then. It is used to overwrite an already-charted epoch's
 * point in place.
 */
let chartLine: ChartJSOrUndefined<'line', (number | null)[], string>;

/**
 * Records `loss` for `epoch` in the shared chart data.
 *
 * - The first time an epoch is seen, a new labelled point is appended.
 *   Epoch 0 additionally gets a leading 'Init' point with the same loss.
 * - Later values for an already-charted epoch replace that epoch's point on
 *   the mounted chart instance, if there is one.
 */
const addLossDataToChart = (epoch: number, loss: number) => {
  const labels = commonStore.chartData.labels!;
  const dataset = commonStore.chartData.datasets[0];
  const epochLabel = 'Epoch ' + epoch.toString();

  // Exact comparison: a substring test would also accept e.g. 'Epoch 10' when looking for epoch 1.
  const epochIndex = labels.findIndex(l => l === epochLabel);

  if (epochIndex === -1) {
    if (epoch === 0) {
      labels.push('Init');
      dataset.data = [...dataset.data, loss];
    }
    labels.push(epochLabel);
    dataset.data = [...dataset.data, loss];
  } else if (chartLine) {
    const newData = [...dataset.data];
    newData[epochIndex] = loss;
    chartLine.data.datasets[0].data = newData;
    chartLine.update();
  }

  commonStore.setChartData(commonStore.chartData);
};

export type DataProcessParameters = {
  dataPath: string;
  vocabPath: string;
}

export type LoraFinetunePrecision = 'bf16' | 'fp16' | 'fp32' | 'tf32';

export type LoraFinetuneParameters = {
  baseModel: string;
  ctxLen: number;
  epochSteps: number;
  epochCount: number;
  epochBegin: number;
  epochSave: number;
  microBsz: number;
  accumGradBatches: number;
  preFfn: boolean;
  headQk: boolean;
  lrInit: string;
  lrFinal: string;
  warmupSteps: number;
  beta1: number;
  beta2: number;
  adamEps: string;
  devices: number;
  precision: LoraFinetunePrecision;
  gradCp: boolean;
  loraR: number;
  loraAlpha: number;
  loraDropout: number;
  loraLoad: string
}

/**
 * Form fields shown for the LoRA finetune parameters, in display order.
 *
 * Each entry is a tuple of [parameter key, input kind, untranslated label].
 * The input kind selects the editor: 'number', 'boolean', 'string' or
 * 'LoraFinetunePrecision'.
 */
const loraFinetuneParametersOptions: Array<[key: keyof LoraFinetuneParameters, type: string, name: string]> = [
  ['devices', 'number', 'Devices'],
  ['precision', 'LoraFinetunePrecision', 'Precision'],
  ['gradCp', 'boolean', 'Gradient Checkpoint'],
  ['ctxLen', 'number', 'Context Length'],
  ['epochSteps', 'number', 'Epoch Steps'],
  ['epochCount', 'number', 'Epoch Count'],
  ['epochBegin', 'number', 'Epoch Begin'],
  ['epochSave', 'number', 'Epoch Save'],
  ['lrInit', 'string', 'Learning Rate Init'],
  ['lrFinal', 'string', 'Learning Rate Final'],
  ['microBsz', 'number', 'Micro Batch Size'],
  ['accumGradBatches', 'number', 'Accumulate Gradient Batches'],
  ['warmupSteps', 'number', 'Warmup Steps'],
  ['adamEps', 'string', 'Adam Epsilon'],
  ['beta1', 'number', 'Beta 1'],
  ['beta2', 'number', 'Beta 2'],
  ['loraR', 'number', 'LoRA R'],
  ['loraAlpha', 'number', 'LoRA Alpha'],
  ['loraDropout', 'number', 'LoRA Dropout'],
  // NOTE(review): placeholder entry. 'any' matches none of the typed editor branches,
  // so it presumably renders a filler cell to keep the grid aligned; confirm.
  ['beta1', 'any', ''],
  ['preFfn', 'boolean', 'Pre-FFN'],
  ['headQk', 'boolean', 'Head QK']
];

/**
 * Handles a chunk of WSL output.
 *
 * Non-empty chunks are appended to the terminal buffer and scanned for
 * training progress.
 */
export const wslHandler = (data: string) => {
  if (data) {
    addWslMessage(data);
    parseLossData(data);
  }
};

/** Appends `message` on a new line of the WSL terminal buffer, keeping only the last 100 lines. */
const addWslMessage = (message: string) => {
  const newData = commonStore.wslStdout + '\n' + message;
  const lines = newData.split('\n');
  const result = lines.slice(-100).join('\n');
  commonStore.setWslStdout(result);
};

/** Read-only view of the WSL terminal buffer that keeps itself scrolled to the bottom. */
const TerminalDisplay: FC = observer(() => {
  const bodyRef = useRef<HTMLDivElement>(null);

  const scrollToBottom = () => {
    if (bodyRef.current)
      bodyRef.current.scrollTop = bodyRef.current.scrollHeight;
  };

  // No dependency list: runs after every render, so new output is always visible.
  useEffect(() => {
    scrollToBottom();
  });

  return (
{commonStore.wslStdout}
  );
});

/** Interactive WSL terminal: output view plus a command input. */
const Terminal: FC = observer(() => {
  const { t } = useTranslation();
  const [input, setInput] = useState('');

  /**
   * On Enter (keyCode 13) with a non-empty input, starts WSL and then sends the
   * input as a command. The command is echoed to the buffer first. Failures
   * are reported as error toasts.
   */
  const handleKeyDown = (e: any) => {
    e.stopPropagation();
    if (e.keyCode === 13) {
      e.preventDefault();
      if (!input) return;

      WslStart().then(() => {
        addWslMessage('WSL> ' + input);
        setInput('');
        WslCommand(input).catch((e: any) => {
          toast((e.message || e), { type: 'error' });
        });
      }).catch((e: any) => {
        toast((e.message || e), { type: 'error' });
      });
    }
  };

  return (
WSL: { setInput(e.target.value); }} onKeyDown={handleKeyDown}>
  );
});

/**
 * Builds the shell command that installs the WSL dependencies and launches
 * LoRA training.
 *
 * @param params finetune parameters from the form
 * @param convertedDataPath path prefix of the converted dataset (passed as --data_file)
 * @param ctxLen context length to train with (may differ from params.ctxLen)
 * @param cnMirror whether to export cnMirror=1 for the install script
 *
 * NOTE(review): values are interpolated into the shell command unquoted. A model
 * name or data path containing spaces or shell metacharacters would break or
 * alter the command.
 */
const buildLoraFinetuneCommand = (params: LoraFinetuneParameters, convertedDataPath: string, ctxLen: number, cnMirror: boolean): string =>
  `export cnMirror=${cnMirror ? '1' : '0'} ` +
  `&& export loadModel=models/${params.baseModel} ` +
  `&& chmod +x finetune/install-wsl-dep-and-train.sh && ./finetune/install-wsl-dep-and-train.sh ` +
  (params.baseModel ? `--load_model models/${params.baseModel} ` : '') +
  (params.loraLoad ? `--lora_load lora-models/${params.loraLoad} ` : '') +
  `--data_file ${convertedDataPath} ` +
  // 'world' models use a 65536-entry vocabulary; everything else is sized as 50277.
  `--vocab_size ${params.baseModel.toLowerCase().includes('world') ? '65536' : '50277'} ` +
  `--ctx_len ${ctxLen} --epoch_steps ${params.epochSteps} --epoch_count ${params.epochCount} ` +
  `--epoch_begin ${params.epochBegin} --epoch_save ${params.epochSave} ` +
  `--micro_bsz ${params.microBsz} --accumulate_grad_batches ${params.accumGradBatches} ` +
  `--pre_ffn ${params.preFfn ? '1' : '0'} --head_qk ${params.headQk ? '1' : '0'} --lr_init ${params.lrInit} --lr_final ${params.lrFinal} ` +
  `--warmup_steps ${params.warmupSteps} ` +
  `--beta1 ${params.beta1} --beta2 ${params.beta2} --adam_eps ${params.adamEps} ` +
  `--devices ${params.devices} --precision ${params.precision} ` +
  `--grad_cp ${params.gradCp ? '1' : '0'} ` +
  `--lora_r ${params.loraR} --lora_alpha ${params.loraAlpha} --lora_dropout ${params.loraDropout}`;

/** LoRA finetune page: data/model selection, training parameters, terminal output and loss chart. */
const LoraFinetune: FC = observer(() => {
  const { t } = useTranslation();
  const navigate = useNavigate();
  const chartRef = useRef<ChartJSOrUndefined<'line', (number | null)[], string>>(null);

  const dataParams = commonStore.dataProcessParams;
  const loraParams = commonStore.loraFinetuneParams;

  // Expose the mounted chart to addLossDataToChart, which runs outside React.
  if (chartRef.current)
    chartLine = chartRef.current;

  const setDataParams = (newParams: Partial<DataProcessParameters>) => {
    commonStore.setDataProcessParams({
      ...dataParams,
      ...newParams
    });
  };

  const setLoraParams = (newParams: Partial<LoraFinetuneParameters>) => {
    commonStore.setLoraFinetuneParameters({
      ...loraParams,
      ...newParams
    });
  };

  // On mount: default the base model to the first completely present model, if none is chosen yet.
  useEffect(() => {
    if (loraParams.baseModel === '')
      setLoraParams({
        baseModel: commonStore.modelSourceList.find(m => m.isComplete)?.name || ''
      });
  }, []);

  /**
   * Starts LoRA training in WSL.
   *
   * Steps:
   * 1. Checks dependencies and that the dataset has been converted.
   * 2. Ensures WSL is enabled and started.
   * 3. Resets the loss chart.
   * 4. Launches the training command.
   *
   * Missing WSL or Ubuntu is turned into an offer to install or enable it.
   * Every other failure is reported as an error toast.
   */
  const StartLoraFinetune = async () => {
    const ok = await checkDependencies(navigate);
    if (!ok)
      return;

    const convertedDataPath = `./finetune/json2binidx_tool/data/${dataParams.dataPath.split('/').pop()!.split('.')[0]}_text_document`;
    if (!await FileExists(convertedDataPath + '.idx')) {
      toast(t('Please convert data first.'), { type: 'error' });
      return;
    }

    WslIsEnabled().then(() => {
      WslStart().then(() => {
        setTimeout(WindowShow, 1000);

        let ctxLen = loraParams.ctxLen;
        if (dataParams.dataPath === 'finetune/data/sample.jsonl') {
          // The bundled sample data is trained with a fixed short context.
          ctxLen = 150;
          toast(t('You are using sample data for training. For formal training, please make sure to create your own jsonl file.'), {
            type: 'info',
            autoClose: 6000
          });
        }

        commonStore.setChartData({
          labels: [],
          datasets: [
            {
              label: 'Loss',
              data: [],
              borderColor: 'rgb(53, 162, 235)',
              backgroundColor: 'rgba(53, 162, 235, 0.5)'
            }
          ]
        });

        WslCommand(buildLoraFinetuneCommand(loraParams, convertedDataPath, ctxLen, commonStore.settings.cnMirror)).catch((e: any) => {
          toast((e.message || e), { type: 'error' });
        });
      }).catch(e => {
        const msg = e.message || e;
        if (msg === 'ubuntu not found') {
          WindowShow();
          toastWithButton(t('Ubuntu is not installed, do you want to install it?'), t('Install Ubuntu'), () => {
            WslInstallUbuntu().then(() => {
              WindowShow();
              toast(t('Please install Ubuntu using Microsoft Store, after installation click the Open button in Microsoft Store and then click the Train button'), {
                type: 'info',
                autoClose: 10000
              });
            }).catch((err: any) => {
              toast((err.message || err), { type: 'error' });
            });
          });
        } else {
          // Previously any other start failure was swallowed silently.
          toast(msg, { type: 'error' });
        }
      });
    }).catch(e => {
      const msg = e.message || e;

      const enableWsl = (forceMode: boolean) => {
        WindowShow();
        toastWithButton(t('WSL is not enabled, do you want to enable it?'), t('Enable WSL'), () => {
          WslEnable(forceMode).then(() => {
            WindowShow();
            toast(t('After installation, please restart your computer to enable WSL'), {
              type: 'info',
              autoClose: false
            });
          }).catch(e => {
            toast((e.message || e), { type: 'error' });
          });
        });
      };

      if (msg === 'wsl is not enabled') {
        enableWsl(false);
      } else if (msg.includes('wsl.state: The system cannot find the file')) {
        enableWsl(true);
      } else {
        toast(msg, { type: 'error' });
      }
    });
  };

  return (
{(commonStore.wslStdout.length > 0 || commonStore.chartData.labels!.length !== 0) &&
{commonStore.wslStdout.length > 0 && commonStore.chartData.labels!.length === 0 && } {commonStore.chartData.labels!.length !== 0 && }
}
{ setDataParams({ dataPath: data.value }); }} /> } onClick={() => { OpenFileFolder(dataParams.dataPath, false); }} />
} />
{t('Vocab Path')} { setDataParams({ vocabPath: data.value }); }} />
} />
{t('Base Model')} } onClick={() => { navigate({ pathname: '/models' }); }} />
{t('LoRA Model')}
{ loraFinetuneParametersOptions.map(([key, type, name], index) => { return ( { setLoraParams({ [key]: Number(data.value) }); }} /> : type === 'boolean' ? { setLoraParams({ [key]: data.checked }); }} /> : type === 'string' ? { setLoraParams({ [key]: data.value }); }} /> : type === 'LoraFinetunePrecision' ? { if (data.optionText) { setLoraParams({ precision: data.optionText as LoraFinetunePrecision }); } }} > :
} /> ); }) }
} />
  );
});

type TrainNavigationItem = {
  element: ReactElement;
};

/** Tab label (untranslated) → page rendered for that tab. */
const pages: { [label: string]: TrainNavigationItem } = { 'LoRA Finetune': { element: }, WSL: { element: } };

/** Training screen: a tab list switching between the pages above. */
export const Train: FC = () => {
  const { t } = useTranslation();
  const [tab, setTab] = useState('LoRA Finetune');

  const selectTab: SelectTabEventHandler = (e, data) =>
    typeof data.value === 'string' ? setTab(data.value) : null;

  return
{Object.entries(pages).map(([label]) => ( {t(label)} ))}
{pages[tab].element}
; };