2023-07-03 09:41:47 +00:00
|
|
|
import React, { FC, ReactElement, useEffect, useRef, useState } from 'react';
|
2023-05-22 02:52:06 +00:00
|
|
|
import { useTranslation } from 'react-i18next';
|
2023-07-03 09:41:47 +00:00
|
|
|
import { Button, Dropdown, Input, Option, Select, Switch, Tab, TabList } from '@fluentui/react-components';
|
|
|
|
import {
|
|
|
|
ConvertData,
|
|
|
|
FileExists,
|
|
|
|
MergeLora,
|
|
|
|
OpenFileFolder,
|
|
|
|
WslCommand,
|
|
|
|
WslEnable,
|
|
|
|
WslInstallUbuntu,
|
|
|
|
WslIsEnabled,
|
|
|
|
WslStart,
|
|
|
|
WslStop
|
|
|
|
} from '../../wailsjs/go/backend_golang/App';
|
|
|
|
import { toast } from 'react-toastify';
|
|
|
|
import commonStore from '../stores/commonStore';
|
|
|
|
import { observer } from 'mobx-react-lite';
|
|
|
|
import { SelectTabEventHandler } from '@fluentui/react-tabs';
|
2023-07-03 13:40:16 +00:00
|
|
|
import { checkDependencies, refreshLocalModels, toastWithButton } from '../utils';
|
2023-07-03 09:41:47 +00:00
|
|
|
import { Section } from '../components/Section';
|
|
|
|
import { Labeled } from '../components/Labeled';
|
|
|
|
import { ToolTipButton } from '../components/ToolTipButton';
|
|
|
|
import { DataUsageSettings20Regular, Folder20Regular } from '@fluentui/react-icons';
|
|
|
|
import { useNavigate } from 'react-router';
|
|
|
|
import { Precision } from './Configs';
|
|
|
|
import {
|
|
|
|
CategoryScale,
|
|
|
|
Chart as ChartJS,
|
|
|
|
Legend,
|
|
|
|
LinearScale,
|
|
|
|
LineElement,
|
|
|
|
PointElement,
|
|
|
|
Title,
|
|
|
|
Tooltip
|
|
|
|
} from 'chart.js';
|
|
|
|
import { Line } from 'react-chartjs-2';
|
|
|
|
import { ChartJSOrUndefined } from 'react-chartjs-2/dist/types';
|
2023-07-03 13:40:16 +00:00
|
|
|
import { WindowShow } from '../../wailsjs/runtime';
|
2023-05-05 15:23:34 +00:00
|
|
|
|
2023-07-03 09:41:47 +00:00
|
|
|
// Chart.js is tree-shakeable, so every piece the loss <Line> chart relies on must be
// registered once at module load: the category and linear scales, point/line elements,
// and the tooltip, title and legend plugins.
ChartJS.register(CategoryScale, LinearScale, PointElement, LineElement, Tooltip, Title, Legend);
|
|
|
|
|
|
|
|
/**
 * Extracts the most recent training progress line from a chunk of WSL output and
 * pushes it to the UI: the chart title gets the full progress summary and the loss
 * is recorded for the epoch via addLossDataToChart.
 *
 * A progress line looks like:
 *   Epoch 0:  10%|███       | 20/200 [00:10<01:30,  2.00it/s, loss=2.345, v_num=0]
 *
 * @param data raw output chunk; chunks without a progress line are ignored.
 */
const parseLossData = (data: string) => {
  // - [^\r\n]*? keeps every match inside a single progress line (tqdm redraws with \r),
  //   so a chunk holding several updates yields several matches instead of one merged one.
  // - elapsed/remaining may be H:MM:SS once a run exceeds an hour.
  // - tqdm reports the rate as "s/it" instead of "it/s" when an iteration takes > 1s.
  // - loss may be followed by more postfix entries or be the last one before ']'.
  const regex = /Epoch (\d+):\s+(\d+%)\|[^\r\n]*?\| (\d+)\/(\d+) \[(\d+(?::\d+)+)<(\d+(?::\d+)+),\s+(\d+(?:\.\d+)?(?:it\/s|s\/it)), loss=(\d+(?:\.\d+)?)(?:,[^\r\n]*?)?\]/g;
  const matches = Array.from(data.matchAll(regex));
  if (matches.length === 0)
    return;

  const lastMatch = matches[matches.length - 1];
  const epoch = parseInt(lastMatch[1], 10);
  const loss = parseFloat(lastMatch[8]);
  commonStore.setChartTitle(`Epoch ${epoch}: ${lastMatch[2]} - ${lastMatch[3]}/${lastMatch[4]} - ${lastMatch[5]}/${lastMatch[6]} - ${lastMatch[7]} Loss=${loss}`);
  addLossDataToChart(epoch, loss);
};
|
|
|
|
|
|
|
|
// Handle to the currently mounted loss chart (set by the LoraFinetune component);
// undefined while no chart is mounted.
let chartLine: ChartJSOrUndefined<'line', (number | null)[], string>;

/**
 * Records the latest loss for an epoch in `commonStore.chartData`.
 *
 * A previously unseen epoch appends a new "Epoch N" point (epoch 0 additionally gets
 * an "Init" point carrying the same loss). A known epoch has its point overwritten
 * with the newer loss, both in the store and, when mounted, directly in the chart
 * instance so it redraws immediately.
 *
 * @param epoch epoch number parsed from the progress output
 * @param loss  loss value parsed from the progress output
 */
const addLossDataToChart = (epoch: number, loss: number) => {
  const chartData = commonStore.chartData;
  const labels = chartData.labels!;
  const dataset = chartData.datasets[0];
  const epochLabel = 'Epoch ' + epoch.toString();
  // Exact comparison: a substring test would let e.g. epoch 1 match "Epoch 12".
  const epochIndex = labels.findIndex(l => l === epochLabel);

  if (epochIndex === -1) {
    if (epoch === 0) {
      labels.push('Init');
      dataset.data = [...dataset.data, loss];
    }
    labels.push(epochLabel);
    dataset.data = [...dataset.data, loss];
  } else {
    const newData = [...dataset.data];
    newData[epochIndex] = loss;
    // Keep the store as the source of truth so the value survives the chart
    // re-syncing from the store or being re-mounted.
    dataset.data = newData;
    if (chartLine) {
      chartLine.data.datasets[0].data = newData;
      chartLine.update();
    }
  }
  commonStore.setChartData(chartData);
};
|
|
|
|
|
|
|
|
/** Inputs of the "Data Process" step that converts a data file for training. */
export type DataProcessParameters = {
  /** Path of the source data file (the UI suggests a .jsonl file). */
  dataPath: string;
  /** Path of the vocabulary file handed to the conversion. */
  vocabPath: string;
};

/** Values offered for the trainer's `--precision` flag (passed through verbatim). */
export type LoraFinetunePrecision = 'bf16' | 'fp16' | 'fp32' | 'tf32';

/**
 * State of the LoRA finetune form. When training starts, each field is interpolated
 * into the corresponding command-line flag of the training script.
 *
 * lrInit, lrFinal and adamEps are kept as strings and inserted as-is
 * (presumably so notations such as "1e-5" survive untouched).
 */
export type LoraFinetuneParameters = {
  // Model selection: base model file name and optional LoRA checkpoint to resume from.
  baseModel: string;
  loraLoad: string;

  // Schedule.
  ctxLen: number;
  epochSteps: number;
  epochCount: number;
  epochBegin: number;
  epochSave: number;

  // Batching.
  microBsz: number;
  accumGradBatches: number;

  // Architecture switches.
  preFfn: boolean;
  headQk: boolean;

  // Optimizer.
  lrInit: string;
  lrFinal: string;
  warmupSteps: number;
  beta1: number;
  beta2: number;
  adamEps: string;

  // Hardware.
  devices: number;
  precision: LoraFinetunePrecision;
  gradCp: boolean;

  // LoRA hyper-parameters.
  loraR: number;
  loraAlpha: number;
  loraDropout: number;
};
|
|
|
|
|
|
|
|
// Drives the generated "Train Parameters" form: one entry per field, rendered in this
// order into a grid. `type` selects the editor used by LoraFinetune ('number' -> numeric
// Input, 'boolean' -> Switch, 'string' -> text Input, 'LoraFinetunePrecision' -> Dropdown;
// any other value renders an empty <div />). `name` is the translation key for the label.
const loraFinetuneParametersOptions: Array<[key: keyof LoraFinetuneParameters, type: string, name: string]> = [
  ['devices', 'number', 'Devices'],
  ['precision', 'LoraFinetunePrecision', 'Precision'],
  ['gradCp', 'boolean', 'Gradient Checkpoint'],
  ['ctxLen', 'number', 'Context Length'],
  ['epochSteps', 'number', 'Epoch Steps'],
  ['epochCount', 'number', 'Epoch Count'],
  ['epochBegin', 'number', 'Epoch Begin'],
  ['epochSave', 'number', 'Epoch Save'],
  ['lrInit', 'string', 'Learning Rate Init'],
  ['lrFinal', 'string', 'Learning Rate Final'],
  ['microBsz', 'number', 'Micro Batch Size'],
  ['accumGradBatches', 'number', 'Accumulate Gradient Batches'],
  ['warmupSteps', 'number', 'Warmup Steps'],
  ['adamEps', 'string', 'Adam Epsilon'],
  ['beta1', 'number', 'Beta 1'],
  ['beta2', 'number', 'Beta 2'],
  ['loraR', 'number', 'LoRA R'],
  ['loraAlpha', 'number', 'LoRA Alpha'],
  ['loraDropout', 'number', 'LoRA Dropout'],
  // Placeholder cell: type 'any' renders an empty <div /> with an empty label,
  // apparently to pad the grid layout (the key is not edited through this entry).
  ['beta1', 'any', ''],
  ['preFfn', 'boolean', 'Pre-FFN'],
  ['headQk', 'boolean', 'Head QK']
];
|
|
|
|
|
|
|
|
/**
 * Consumer for output coming from the WSL backend: appends the chunk to the terminal
 * buffer and feeds it to the loss parser. Empty chunks are ignored.
 *
 * @param data raw output chunk
 */
export const wslHandler = (data: string) => {
  if (!data)
    return;
  addWslMessage(data);
  parseLossData(data);
};
|
|
|
|
|
|
|
|
/**
 * Appends a message to the WSL terminal buffer in `commonStore`, keeping only the
 * newest 100 lines.
 *
 * @param message text to append (may itself contain several lines)
 */
const addWslMessage = (message: string) => {
  const maxLines = 100;
  const previous = commonStore.wslStdout;
  // Only insert a separator when there is existing output; otherwise the buffer
  // would start with a blank line.
  const combined = previous ? previous + '\n' + message : message;
  const result = combined.split('\n').slice(-maxLines).join('\n');
  commonStore.setWslStdout(result);
};
|
|
|
|
|
|
|
|
/**
 * Read-only view of the WSL output buffer (`commonStore.wslStdout`), kept scrolled
 * to the newest output.
 */
const TerminalDisplay: FC = observer(() => {
  const containerRef = useRef<HTMLDivElement>(null);

  // No dependency list on purpose: after every render, pin the scroll position
  // to the bottom of the output.
  useEffect(() => {
    const container = containerRef.current;
    if (container)
      container.scrollTop = container.scrollHeight;
  });

  return (
    <div ref={containerRef} className="grow overflow-x-hidden overflow-y-auto border-gray-500 border-2 rounded-md">
      <div className="whitespace-pre-line">
        {commonStore.wslStdout}
      </div>
    </div>
  );
});
|
|
|
|
|
|
|
|
/**
 * "WSL" tab: shows the WSL output and lets the user run ad-hoc commands.
 * Pressing Enter starts WSL (WslStart), echoes the command to the output and sends it
 * with WslCommand; the Stop button calls WslStop. Failures are shown as error toasts.
 */
const Terminal: FC = observer(() => {
  const { t } = useTranslation();
  const [input, setInput] = useState('');

  const showError = (err: any) => {
    toast((err.message || err), { type: 'error' });
  };

  const handleKeyDown = (event: React.KeyboardEvent<HTMLInputElement>) => {
    // Keep key presses from reaching outer handlers.
    event.stopPropagation();
    // keyCode 13 is Enter.
    if (event.keyCode !== 13)
      return;
    event.preventDefault();
    if (!input)
      return;

    const command = input;
    WslStart()
      .then(() => {
        addWslMessage('WSL> ' + command);
        setInput('');
        WslCommand(command).catch(showError);
      })
      .catch(showError);
  };

  const stopCommand = () => {
    WslStop()
      .then(() => {
        toast(t('Command Stopped'), { type: 'success' });
      })
      .catch(showError);
  };

  return (
    <div className="flex flex-col h-full gap-4">
      <TerminalDisplay />
      <div className="flex gap-2 items-center">
        WSL:
        <Input className="grow" value={input} onChange={(e) => setInput(e.target.value)} onKeyDown={handleKeyDown}></Input>
        <Button onClick={stopCommand}>
          {t('Stop')}
        </Button>
      </div>
    </div>
  );
});
|
|
|
|
|
|
|
|
/**
 * "LoRA Finetune" tab.
 *
 * Renders:
 * - an output area (terminal text until loss data exists, then a loss chart);
 * - the "Data Process" section, which converts the data file with ConvertData;
 * - the "Train Parameters" form, generated from `loraFinetuneParametersOptions`;
 * - the Train / Stop / Merge Model actions, which run through the WSL bridge.
 *
 * Form state is read from and written to `commonStore`.
 */
const LoraFinetune: FC = observer(() => {
  const { t } = useTranslation();
  const navigate = useNavigate();
  const chartRef = useRef<ChartJSOrUndefined<'line', (number | null)[], string>>(null);

  const dataParams = commonStore.dataProcessParams;
  const loraParams = commonStore.loraFinetuneParams;

  // Expose the mounted chart to the module-level `chartLine` handle.
  // Refs are attached only after commit, so they are read in an effect rather than
  // during render. The cleanup clears the handle, so a destroyed chart instance is
  // never updated after the chart (or this tab) unmounts.
  useEffect(() => {
    chartLine = chartRef.current ?? undefined;
    return () => {
      chartLine = undefined;
    };
  });

  // Output prefix handed to ConvertData for the current data file.
  // Both '/' and '\' count as separators so Windows paths work; as before, the file
  // name is cut at its first '.'.
  const getConvertedDataPrefix = () =>
    './finetune/json2binidx_tool/data/' + dataParams.dataPath.split(/[\\/]/).pop()!.split('.')[0];

  const setDataParams = (newParams: Partial<DataProcessParameters>) => {
    commonStore.setDataProcessParams({
      ...dataParams,
      ...newParams
    });
  };

  const setLoraParams = (newParams: Partial<LoraFinetuneParameters>) => {
    commonStore.setLoraFinetuneParameters({
      ...loraParams,
      ...newParams
    });
  };

  // On first mount, default the base model to the first complete local model.
  useEffect(() => {
    if (loraParams.baseModel === '')
      setLoraParams({
        baseModel: commonStore.modelSourceList.find(m => m.isComplete)?.name || ''
      });
  }, []);

  /**
   * Train button handler.
   * 1. Checks dependencies.
   * 2. Requires that the data has already been converted.
   * 3. Makes sure WSL is enabled and started.
   * 4. Resets the loss chart and launches the training script with the form values.
   * If WSL or Ubuntu is missing, the user is offered a toast with an install/enable action.
   */
  const StartLoraFinetune = async () => {
    const ok = await checkDependencies(navigate);
    if (!ok)
      return;

    // Training reads the converted dataset from `<prefix>_text_document`
    // (presumably the suffix is added by the converter); its .idx file signals
    // that conversion has been done.
    const convertedDataPath = `${getConvertedDataPrefix()}_text_document`;
    if (!await FileExists(convertedDataPath + '.idx')) {
      toast(t('Please convert data first.'), { type: 'error' });
      return;
    }

    WslIsEnabled().then(() => {
      WslStart().then(() => {
        setTimeout(WindowShow, 1000);

        let ctxLen = loraParams.ctxLen;
        if (dataParams.dataPath === 'finetune/data/sample.jsonl') {
          ctxLen = 150;
          toast(t('You are using sample data for training. For formal training, please make sure to create your own jsonl file.'), {
            type: 'info',
            autoClose: 6000
          });
        }

        // Start every run with an empty loss chart.
        commonStore.setChartData({
          labels: [],
          datasets: [
            {
              label: 'Loss',
              data: [],
              borderColor: 'rgb(53, 162, 235)',
              backgroundColor: 'rgba(53, 162, 235, 0.5)'
            }
          ]
        });

        // NOTE(review): values are interpolated unquoted into a shell command line;
        // model/LoRA names containing spaces or shell metacharacters will break it.
        WslCommand(`export cnMirror=${commonStore.settings.cnMirror ? '1' : '0'} ` +
          `&& export loadModel=models/${loraParams.baseModel} ` +
          `&& chmod +x finetune/install-wsl-dep-and-train.sh && ./finetune/install-wsl-dep-and-train.sh ` +
          (loraParams.baseModel ? `--load_model models/${loraParams.baseModel} ` : '') +
          (loraParams.loraLoad ? `--lora_load lora-models/${loraParams.loraLoad} ` : '') +
          `--data_file ${convertedDataPath} ` +
          `--vocab_size ${loraParams.baseModel.toLowerCase().includes('world') ? '65536' : '50277'} ` +
          `--ctx_len ${ctxLen} --epoch_steps ${loraParams.epochSteps} --epoch_count ${loraParams.epochCount} ` +
          `--epoch_begin ${loraParams.epochBegin} --epoch_save ${loraParams.epochSave} ` +
          `--micro_bsz ${loraParams.microBsz} --accumulate_grad_batches ${loraParams.accumGradBatches} ` +
          `--pre_ffn ${loraParams.preFfn ? '1' : '0'} --head_qk ${loraParams.headQk ? '1' : '0'} --lr_init ${loraParams.lrInit} --lr_final ${loraParams.lrFinal} ` +
          `--warmup_steps ${loraParams.warmupSteps} ` +
          `--beta1 ${loraParams.beta1} --beta2 ${loraParams.beta2} --adam_eps ${loraParams.adamEps} ` +
          `--devices ${loraParams.devices} --precision ${loraParams.precision} ` +
          `--grad_cp ${loraParams.gradCp ? '1' : '0'} ` +
          `--lora_r ${loraParams.loraR} --lora_alpha ${loraParams.loraAlpha} --lora_dropout ${loraParams.loraDropout}`).catch((e: any) => {
          toast((e.message || e), { type: 'error' });
        });
      }).catch(e => {
        const msg = e.message || e;
        if (msg === 'ubuntu not found') {
          WindowShow();
          toastWithButton(t('Ubuntu is not installed, do you want to install it?'), t('Install Ubuntu'), () => {
            WslInstallUbuntu().then(() => {
              WindowShow();
              toast(t('Please install Ubuntu using Microsoft Store, after installation click the Open button in Microsoft Store and then click the Train button'), {
                type: 'info',
                autoClose: 10000
              });
            });
          });
        }
      });
    }).catch(e => {
      const msg = e.message || e;

      // Offer to enable WSL; forceMode is forwarded to WslEnable.
      const enableWsl = (forceMode: boolean) => {
        WindowShow();
        toastWithButton(t('WSL is not enabled, do you want to enable it?'), t('Enable WSL'), () => {
          WslEnable(forceMode).then(() => {
            WindowShow();
            toast(t('After installation, please restart your computer to enable WSL'), {
              type: 'info',
              autoClose: false
            });
          }).catch(e => {
            toast((e.message || e), { type: 'error' });
          });
        });
      };

      if (msg === 'wsl is not enabled') {
        enableWsl(false);
      } else if (msg.includes('wsl.state: The system cannot find the file')) {
        enableWsl(true);
      } else {
        toast(msg, { type: 'error' });
      }
    });
  };

  return (
    <div className="flex flex-col h-full w-full gap-2">
      {(commonStore.wslStdout.length > 0 || commonStore.chartData.labels!.length !== 0) &&
        <div className="flex" style={{ height: '35%' }}>
          {commonStore.wslStdout.length > 0 && commonStore.chartData.labels!.length === 0 && <TerminalDisplay />}
          {commonStore.chartData.labels!.length !== 0 &&
            <Line ref={chartRef} data={commonStore.chartData} options={{
              responsive: true,
              showLine: true,
              plugins: {
                legend: {
                  position: 'right',
                  align: 'start'
                },
                title: {
                  display: true,
                  text: commonStore.chartTitle
                }
              },
              scales: {
                y: {
                  beginAtZero: true
                }
              },
              maintainAspectRatio: false
            }} style={{ width: '100%' }} />}
        </div>
      }
      <div>
        <Section
          title={t('Data Process')}
          content={
            <div className="flex flex-col gap-2">
              <Labeled flex label={t('Data Path')}
                content={
                  <div className="grow flex gap-2">
                    <Input className="grow ml-2" value={dataParams.dataPath}
                      onChange={(e, data) => {
                        setDataParams({ dataPath: data.value });
                      }} />
                    <ToolTipButton desc={t('Open Folder')} icon={<Folder20Regular />} onClick={() => {
                      OpenFileFolder(dataParams.dataPath, false);
                    }} />
                  </div>
                } />
              <div className="flex gap-2 items-center">
                {t('Vocab Path')}
                <Input className="grow" style={{ minWidth: 0 }} value={dataParams.vocabPath}
                  onChange={(e, data) => {
                    setDataParams({ vocabPath: data.value });
                  }} />
                <Button appearance="secondary" size="large" onClick={() => {
                  ConvertData(commonStore.settings.customPythonPath, dataParams.dataPath,
                    getConvertedDataPrefix(),
                    dataParams.vocabPath).then(() => {
                    toast(t('Convert Data successfully'), { type: 'success' });
                  }).catch((e: any) => {
                    toast((e.message || e), { type: 'error' });
                  });
                }}>{t('Convert')}</Button>
              </div>
            </div>
          }
        />
      </div>
      <Section
        title={t('Train Parameters')}
        content={
          <div className="grid grid-cols-1 sm:grid-cols-2 gap-2">
            <div className="flex gap-2 items-center">
              {t('Base Model')}
              <Select style={{ minWidth: 0 }} className="grow"
                value={loraParams.baseModel}
                onChange={(e, data) => {
                  setLoraParams({
                    baseModel: data.value
                  });
                }}>
                {commonStore.modelSourceList.map((modelItem, index) =>
                  modelItem.isComplete && <option key={index} value={modelItem.name}>{modelItem.name}</option>
                )}
              </Select>
              <ToolTipButton desc={t('Manage Models')} icon={<DataUsageSettings20Regular />} onClick={() => {
                navigate({ pathname: '/models' });
              }} />
            </div>
            <div className="flex gap-2 items-center">
              {t('LoRA Model')}
              <Select style={{ minWidth: 0 }} className="grow"
                value={loraParams.loraLoad}
                onChange={(e, data) => {
                  setLoraParams({
                    loraLoad: data.value
                  });
                }}>
                <option value="">{t('None')}</option>
                {commonStore.loraModels.map((name, index) =>
                  <option key={index} value={name}>{name}</option>
                )}
              </Select>
              <Button onClick={async () => {
                const ok = await checkDependencies(navigate);
                if (!ok)
                  return;
                if (loraParams.loraLoad) {
                  MergeLora(commonStore.settings.customPythonPath, true, loraParams.loraAlpha,
                    'models/' + loraParams.baseModel, 'lora-models/' + loraParams.loraLoad,
                    `models/${loraParams.baseModel}-LoRA-${loraParams.loraLoad}`).then(() => {
                    toast(t('Merge model successfully'), { type: 'success' });
                    refreshLocalModels({ models: commonStore.modelSourceList }, false);
                  }).catch((e: any) => {
                    toast((e.message || e), { type: 'error' });
                  });
                } else {
                  toast(t('Please select a LoRA model'), { type: 'info' });
                }
              }}>{t('Merge Model')}</Button>
            </div>
            {
              loraFinetuneParametersOptions.map(([key, type, name], index) => {
                return (
                  <Labeled key={index} label={t(name)} content={
                    type === 'number' ?
                      <Input type="number" className="grow" value={loraParams[key].toString()}
                        onChange={(e, data) => {
                          setLoraParams({
                            [key]: Number(data.value)
                          });
                        }} /> :
                      type === 'boolean' ?
                        <Switch className="grow" checked={loraParams[key] as boolean}
                          onChange={(e, data) => {
                            setLoraParams({
                              [key]: data.checked
                            });
                          }} /> :
                        type === 'string' ?
                          <Input className="grow" value={loraParams[key].toString()}
                            onChange={(e, data) => {
                              setLoraParams({
                                [key]: data.value
                              });
                            }} /> :
                          type === 'LoraFinetunePrecision' ?
                            <Dropdown style={{ minWidth: 0 }} className="grow"
                              value={loraParams[key].toString()}
                              selectedOptions={[loraParams[key].toString()]}
                              onOptionSelect={(_, data) => {
                                if (data.optionText) {
                                  setLoraParams({
                                    precision: data.optionText as LoraFinetunePrecision
                                  });
                                }
                              }}
                            >
                              <Option>bf16</Option>
                              <Option>fp16</Option>
                              <Option>fp32</Option>
                              <Option>tf32</Option>
                            </Dropdown>
                            : <div />
                  } />
                );
              })
            }
          </div>
        }
      />
      <div className="grow" />
      <div className="flex gap-2">
        <div className="grow" />
        <Button appearance="secondary" size="large" onClick={() => {
          WslStop().then(() => {
            toast(t('Command Stopped'), { type: 'success' });
          }).catch((e: any) => {
            toast((e.message || e), { type: 'error' });
          });
        }}>{t('Stop')}</Button>
        <Button appearance="primary" size="large" onClick={StartLoraFinetune}>{t('Train')}</Button>
      </div>
    </div>
  );
});
|
|
|
|
|
|
|
|
/** One tab of the Train page: the element rendered while the tab is selected. */
type TrainNavigationItem = {
  element: ReactElement;
};

// Tab label -> page. The label doubles as the tab value and the translation key;
// insertion order is the order the tabs are rendered in.
const pages: { [label: string]: TrainNavigationItem } = {
  'LoRA Finetune': { element: <LoraFinetune /> },
  'WSL': { element: <Terminal /> }
};
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Train page: a tab bar built from `pages` above the selected page's content.
 * The "LoRA Finetune" tab is selected initially.
 */
export const Train: FC = () => {
  const { t } = useTranslation();
  const [tab, setTab] = useState('LoRA Finetune');

  const selectTab: SelectTabEventHandler = (_event, data) => {
    // Tab values are the string labels of `pages`; ignore anything else.
    if (typeof data.value === 'string')
      setTab(data.value);
  };

  const labels = Object.keys(pages);

  return (
    <div className="flex flex-col gap-2 w-full h-full">
      <TabList size="small" appearance="subtle" selectedValue={tab} onTabSelect={selectTab}>
        {labels.map(label => <Tab key={label} value={label}>{t(label)}</Tab>)}
      </TabList>
      <div className="grow overflow-hidden">
        {pages[tab].element}
      </div>
    </div>
  );
};
|