lora finetune (need to be refactored)
This commit is contained in:
@@ -189,5 +189,37 @@
|
||||
"user": "用户",
|
||||
"assistant": "AI",
|
||||
"system": "系统",
|
||||
"Regenerate": "重新生成"
|
||||
"Regenerate": "重新生成",
|
||||
"LoRA Finetune": "LoRA微调",
|
||||
"Command Stopped": "命令已终止",
|
||||
"Please convert data first.": "请先转换数据",
|
||||
"Ubuntu is not installed, do you want to install it?": "Ubuntu未安装,是否安装?",
|
||||
"Install Ubuntu": "安装Ubuntu",
|
||||
"Please install Ubuntu using Microsoft Store": "请用Microsoft Store安装Ubuntu",
|
||||
"WSL is not enabled, do you want to enable it?": "WSL未启用,是否启用?",
|
||||
"Enable WSL": "启用WSL",
|
||||
"After installation, please restart your computer to enable WSL": "安装完成后,请重启电脑以启用WSL",
|
||||
"Data Process": "数据处理",
|
||||
"Data Path": "数据路径",
|
||||
"Vocab Path": "词表路径",
|
||||
"Train Parameters": "训练参数",
|
||||
"Base Model": "基底模型",
|
||||
"LoRA Model": "LoRA模型",
|
||||
"Merge Model": "合并模型",
|
||||
"Devices": "显卡数量",
|
||||
"Gradient Checkpoint": "梯度检查点标志",
|
||||
"Context Length": "上下文长度",
|
||||
"Epoch Steps": "每轮训练步数",
|
||||
"Epoch Count": "训练轮次",
|
||||
"Epoch Begin": "起始轮次",
|
||||
"Epoch Save": "保存间隔轮次",
|
||||
"Learning Rate Init": "初始学习率",
|
||||
"Learning Rate Final": "最终学习率",
|
||||
"Micro Batch Size": "微批次大小",
|
||||
"Accumulate Gradient Batches": "梯度累积批次",
|
||||
"Warmup Steps": "学习率预热步数",
|
||||
"Pre-FFN": "前馈网络预处理",
|
||||
"None": "空",
|
||||
"Merge model successfully": "合并模型成功",
|
||||
"Convert Data successfully": "数据转换成功"
|
||||
}
|
||||
@@ -1,13 +1,542 @@
|
||||
import React, { FC } from 'react';
|
||||
import { Text } from '@fluentui/react-components';
|
||||
import React, { FC, ReactElement, useEffect, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { Button, Dropdown, Input, Option, Select, Switch, Tab, TabList } from '@fluentui/react-components';
|
||||
import {
|
||||
ConvertData,
|
||||
FileExists,
|
||||
MergeLora,
|
||||
OpenFileFolder,
|
||||
WslCommand,
|
||||
WslEnable,
|
||||
WslInstallUbuntu,
|
||||
WslIsEnabled,
|
||||
WslStart,
|
||||
WslStop
|
||||
} from '../../wailsjs/go/backend_golang/App';
|
||||
import { toast } from 'react-toastify';
|
||||
import commonStore from '../stores/commonStore';
|
||||
import { observer } from 'mobx-react-lite';
|
||||
import { SelectTabEventHandler } from '@fluentui/react-tabs';
|
||||
import { refreshLocalModels, toastWithButton } from '../utils';
|
||||
import { Section } from '../components/Section';
|
||||
import { Labeled } from '../components/Labeled';
|
||||
import { ToolTipButton } from '../components/ToolTipButton';
|
||||
import { DataUsageSettings20Regular, Folder20Regular } from '@fluentui/react-icons';
|
||||
import { useNavigate } from 'react-router';
|
||||
import { Precision } from './Configs';
|
||||
import {
|
||||
CategoryScale,
|
||||
Chart as ChartJS,
|
||||
Legend,
|
||||
LinearScale,
|
||||
LineElement,
|
||||
PointElement,
|
||||
Title,
|
||||
Tooltip
|
||||
} from 'chart.js';
|
||||
import { Line } from 'react-chartjs-2';
|
||||
import { ChartJSOrUndefined } from 'react-chartjs-2/dist/types';
|
||||
|
||||
ChartJS.register(
|
||||
CategoryScale,
|
||||
LinearScale,
|
||||
PointElement,
|
||||
LineElement,
|
||||
Tooltip,
|
||||
Title,
|
||||
Legend
|
||||
);
|
||||
|
||||
const parseLossData = (data: string) => {
|
||||
const regex = /Epoch (\d+):\s+(\d+%)\|[\s\S]*\| (\d+)\/(\d+) \[(\d+:\d+)<(\d+:\d+),\s+(\d+.\d+it\/s), loss=(\d+.\d+),[\s\S]*\]/g;
|
||||
const matches = Array.from(data.matchAll(regex));
|
||||
if (matches.length === 0)
|
||||
return;
|
||||
const lastMatch = matches[matches.length - 1];
|
||||
const epoch = parseInt(lastMatch[1]);
|
||||
const loss = parseFloat(lastMatch[8]);
|
||||
commonStore.setChartTitle(`Epoch ${epoch}: ${lastMatch[2]} - ${lastMatch[3]}/${lastMatch[4]} - ${lastMatch[5]}/${lastMatch[6]} - ${lastMatch[7]} Loss=${loss}`);
|
||||
addLossDataToChart(epoch, loss);
|
||||
};
|
||||
|
||||
let chartLine: ChartJSOrUndefined<'line', (number | null)[], string>;
|
||||
|
||||
const addLossDataToChart = (epoch: number, loss: number) => {
|
||||
const epochIndex = commonStore.chartData.labels!.findIndex(l => l.includes(epoch.toString()));
|
||||
if (epochIndex === -1) {
|
||||
if (epoch === 0) {
|
||||
commonStore.chartData.labels!.push('Init');
|
||||
commonStore.chartData.datasets[0].data = [...commonStore.chartData.datasets[0].data, loss];
|
||||
}
|
||||
commonStore.chartData.labels!.push('Epoch ' + epoch.toString());
|
||||
commonStore.chartData.datasets[0].data = [...commonStore.chartData.datasets[0].data, loss];
|
||||
} else {
|
||||
if (chartLine) {
|
||||
const newData = [...commonStore.chartData.datasets[0].data];
|
||||
newData[epochIndex] = loss;
|
||||
chartLine.data.datasets[0].data = newData;
|
||||
chartLine.update();
|
||||
}
|
||||
}
|
||||
commonStore.setChartData(commonStore.chartData);
|
||||
};
|
||||
|
||||
export type DataProcessParameters = {
|
||||
dataPath: string;
|
||||
vocabPath: string;
|
||||
}
|
||||
|
||||
export type LoraFinetunePrecision = 'bf16' | 'fp16' | 'fp32' | 'tf32';
|
||||
|
||||
export type LoraFinetuneParameters = {
|
||||
baseModel: string;
|
||||
ctxLen: number;
|
||||
epochSteps: number;
|
||||
epochCount: number;
|
||||
epochBegin: number;
|
||||
epochSave: number;
|
||||
microBsz: number;
|
||||
accumGradBatches: number;
|
||||
preFfn: boolean;
|
||||
headQk: boolean;
|
||||
lrInit: string;
|
||||
lrFinal: string;
|
||||
warmupSteps: number;
|
||||
beta1: number;
|
||||
beta2: number;
|
||||
adamEps: string;
|
||||
devices: number;
|
||||
precision: LoraFinetunePrecision;
|
||||
gradCp: boolean;
|
||||
loraR: number;
|
||||
loraAlpha: number;
|
||||
loraDropout: number;
|
||||
loraLoad: string
|
||||
}
|
||||
|
||||
const loraFinetuneParametersOptions: Array<[key: keyof LoraFinetuneParameters, type: string, name: string]> = [
|
||||
['devices', 'number', 'Devices'],
|
||||
['precision', 'LoraFinetunePrecision', 'Precision'],
|
||||
['gradCp', 'boolean', 'Gradient Checkpoint'],
|
||||
['ctxLen', 'number', 'Context Length'],
|
||||
['epochSteps', 'number', 'Epoch Steps'],
|
||||
['epochCount', 'number', 'Epoch Count'],
|
||||
['epochBegin', 'number', 'Epoch Begin'],
|
||||
['epochSave', 'number', 'Epoch Save'],
|
||||
['lrInit', 'string', 'Learning Rate Init'],
|
||||
['lrFinal', 'string', 'Learning Rate Final'],
|
||||
['microBsz', 'number', 'Micro Batch Size'],
|
||||
['accumGradBatches', 'number', 'Accumulate Gradient Batches'],
|
||||
['warmupSteps', 'number', 'Warmup Steps'],
|
||||
['adamEps', 'string', 'Adam Epsilon'],
|
||||
['beta1', 'number', 'Beta 1'],
|
||||
['beta2', 'number', 'Beta 2'],
|
||||
['loraR', 'number', 'LoRA R'],
|
||||
['loraAlpha', 'number', 'LoRA Alpha'],
|
||||
['loraDropout', 'number', 'LoRA Dropout'],
|
||||
['beta1', 'any', ''],
|
||||
['preFfn', 'boolean', 'Pre-FFN'],
|
||||
['headQk', 'boolean', 'Head QK']
|
||||
];
|
||||
|
||||
export const wslHandler = (data: string) => {
|
||||
if (data) {
|
||||
addWslMessage(data);
|
||||
parseLossData(data);
|
||||
}
|
||||
};
|
||||
|
||||
const addWslMessage = (message: string) => {
|
||||
const newData = commonStore.wslStdout + '\n' + message;
|
||||
let lines = newData.split('\n');
|
||||
const result = lines.slice(-100).join('\n');
|
||||
commonStore.setWslStdout(result);
|
||||
};
|
||||
|
||||
const TerminalDisplay: FC = observer(() => {
|
||||
const bodyRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const scrollToBottom = () => {
|
||||
if (bodyRef.current)
|
||||
bodyRef.current.scrollTop = bodyRef.current.scrollHeight;
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
scrollToBottom();
|
||||
});
|
||||
|
||||
return (
|
||||
<div ref={bodyRef} className="grow overflow-x-hidden overflow-y-auto border-gray-500 border-2 rounded-md">
|
||||
<div className="whitespace-pre-line">
|
||||
{commonStore.wslStdout}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
});
|
||||
|
||||
const Terminal: FC = observer(() => {
|
||||
const { t } = useTranslation();
|
||||
const [input, setInput] = useState('');
|
||||
|
||||
const handleKeyDown = (e: any) => {
|
||||
e.stopPropagation();
|
||||
if (e.keyCode === 13) {
|
||||
e.preventDefault();
|
||||
if (!input) return;
|
||||
|
||||
WslStart().then(() => {
|
||||
addWslMessage('WSL> ' + input);
|
||||
setInput('');
|
||||
WslCommand(input).catch((e) => {
|
||||
toast((e.message || e), { type: 'error' });
|
||||
});
|
||||
}).catch((e) => {
|
||||
toast((e.message || e), { type: 'error' });
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex flex-col h-full gap-4">
|
||||
<TerminalDisplay />
|
||||
<div className="flex gap-2 items-center">
|
||||
WSL:
|
||||
<Input className="grow" value={input} onChange={(e) => {
|
||||
setInput(e.target.value);
|
||||
}} onKeyDown={handleKeyDown}></Input>
|
||||
<Button onClick={() => {
|
||||
WslStop().then(() => {
|
||||
toast(t('Command Stopped'), { type: 'success' });
|
||||
}).catch((e) => {
|
||||
toast((e.message || e), { type: 'error' });
|
||||
});
|
||||
}}>
|
||||
{t('Stop')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
});
|
||||
|
||||
const LoraFinetune: FC = observer(() => {
|
||||
const { t } = useTranslation();
|
||||
const navigate = useNavigate();
|
||||
const chartRef = useRef<ChartJSOrUndefined<'line', (number | null)[], string>>(null);
|
||||
|
||||
const dataParams = commonStore.dataProcessParams;
|
||||
const loraParams = commonStore.loraFinetuneParams;
|
||||
|
||||
if (chartRef.current)
|
||||
chartLine = chartRef.current;
|
||||
|
||||
const setDataParams = (newParams: Partial<DataProcessParameters>) => {
|
||||
commonStore.setDataProcessParams({
|
||||
...dataParams,
|
||||
...newParams
|
||||
});
|
||||
};
|
||||
|
||||
const setLoraParams = (newParams: Partial<LoraFinetuneParameters>) => {
|
||||
commonStore.setLoraFinetuneParameters({
|
||||
...loraParams,
|
||||
...newParams
|
||||
});
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (loraParams.baseModel === '')
|
||||
setLoraParams({
|
||||
baseModel: commonStore.modelSourceList.find(m => m.isComplete)?.name || ''
|
||||
});
|
||||
}, []);
|
||||
|
||||
const StartLoraFinetune = () => {
|
||||
WslIsEnabled().then(() => {
|
||||
WslStart().then(async () => {
|
||||
const convertedDataPath = `./finetune/json2binidx_tool/data/${dataParams.dataPath.split('/').pop()!.split('.')[0]}_text_document`;
|
||||
if (!await FileExists(convertedDataPath + '.idx')) {
|
||||
toast(t('Please convert data first.'), { type: 'error' });
|
||||
return;
|
||||
}
|
||||
|
||||
commonStore.setChartData({
|
||||
labels: [],
|
||||
datasets: [
|
||||
{
|
||||
label: 'Loss',
|
||||
data: [],
|
||||
borderColor: 'rgb(53, 162, 235)',
|
||||
backgroundColor: 'rgba(53, 162, 235, 0.5)'
|
||||
}
|
||||
]
|
||||
});
|
||||
WslCommand(`export cnMirror=${commonStore.settings.cnMirror ? '1' : '0'} ` +
|
||||
`&& export loadModel=models/${loraParams.baseModel} ` +
|
||||
`&& chmod +x finetune/install-wsl-dep-and-train.sh && ./finetune/install-wsl-dep-and-train.sh ` +
|
||||
(loraParams.baseModel ? `--load_model models/${loraParams.baseModel} ` : '') +
|
||||
(loraParams.loraLoad ? `--lora_load lora-models/${loraParams.loraLoad} ` : '') +
|
||||
`--data_file ${convertedDataPath} ` +
|
||||
`--vocab_size ${loraParams.baseModel.toLowerCase().includes('world') ? '65536' : '50277'} ` +
|
||||
`--ctx_len ${loraParams.ctxLen} --epoch_steps ${loraParams.epochSteps} --epoch_count ${loraParams.epochCount} ` +
|
||||
`--epoch_begin ${loraParams.epochBegin} --epoch_save ${loraParams.epochSave} ` +
|
||||
`--micro_bsz ${loraParams.microBsz} --accumulate_grad_batches ${loraParams.accumGradBatches} ` +
|
||||
`--pre_ffn ${loraParams.preFfn ? '1' : '0'} --head_qk ${loraParams.headQk ? '1' : '0'} --lr_init ${loraParams.lrInit} --lr_final ${loraParams.lrFinal} ` +
|
||||
`--warmup_steps ${loraParams.warmupSteps} ` +
|
||||
`--beta1 ${loraParams.beta1} --beta2 ${loraParams.beta2} --adam_eps ${loraParams.adamEps} ` +
|
||||
`--devices ${loraParams.devices} --precision ${loraParams.precision} ` +
|
||||
`--grad_cp ${loraParams.gradCp ? '1' : '0'} ` +
|
||||
`--lora_r ${loraParams.loraR} --lora_alpha ${loraParams.loraAlpha} --lora_dropout ${loraParams.loraDropout}`).catch((e) => {
|
||||
toast((e.message || e), { type: 'error' });
|
||||
});
|
||||
}).catch(e => {
|
||||
const msg = e.message || e;
|
||||
if (msg === 'ubuntu not found') {
|
||||
toastWithButton(t('Ubuntu is not installed, do you want to install it?'), t('Install Ubuntu'), () => {
|
||||
WslInstallUbuntu().then(() => {
|
||||
toast(t('Please install Ubuntu using Microsoft Store'), { type: 'info', autoClose: 6000 });
|
||||
});
|
||||
});
|
||||
}
|
||||
});
|
||||
}).catch(e => {
|
||||
const msg = e.message || e;
|
||||
|
||||
const enableWsl = (forceMode: boolean) => {
|
||||
toastWithButton(t('WSL is not enabled, do you want to enable it?'), t('Enable WSL'), () => {
|
||||
WslEnable(forceMode).then(() => {
|
||||
toast(t('After installation, please restart your computer to enable WSL'), {
|
||||
type: 'info',
|
||||
autoClose: false
|
||||
});
|
||||
}).catch(e => {
|
||||
toast((e.message || e), { type: 'error' });
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
if (msg === 'wsl is not enabled') {
|
||||
enableWsl(false);
|
||||
} else if (msg.includes('wsl.state: The system cannot find the file')) {
|
||||
enableWsl(true);
|
||||
} else {
|
||||
toast(msg, { type: 'error' });
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex flex-col h-full w-full gap-2">
|
||||
{(commonStore.wslStdout.length > 0 || commonStore.chartData.labels!.length !== 0) &&
|
||||
<div className="flex" style={{ height: '35%' }}>
|
||||
{commonStore.wslStdout.length > 0 && commonStore.chartData.labels!.length === 0 && <TerminalDisplay />}
|
||||
{commonStore.chartData.labels!.length !== 0 &&
|
||||
<Line ref={chartRef} data={commonStore.chartData} options={{
|
||||
responsive: true,
|
||||
showLine: true,
|
||||
plugins: {
|
||||
legend: {
|
||||
position: 'right',
|
||||
align: 'start'
|
||||
},
|
||||
title: {
|
||||
display: true,
|
||||
text: commonStore.chartTitle
|
||||
}
|
||||
},
|
||||
scales: {
|
||||
y: {
|
||||
beginAtZero: true
|
||||
}
|
||||
},
|
||||
maintainAspectRatio: false
|
||||
}} style={{ width: '100%' }} />}
|
||||
</div>
|
||||
}
|
||||
<div>
|
||||
<Section
|
||||
title={t('Data Process')}
|
||||
content={
|
||||
<div className="flex flex-col gap-2">
|
||||
<Labeled flex label={t('Data Path')}
|
||||
content={
|
||||
<div className="grow flex gap-2">
|
||||
<Input className="grow ml-2" value={dataParams.dataPath}
|
||||
onChange={(e, data) => {
|
||||
setDataParams({ dataPath: data.value });
|
||||
}} />
|
||||
<ToolTipButton desc={t('Open Folder')} icon={<Folder20Regular />} onClick={() => {
|
||||
OpenFileFolder(dataParams.dataPath, false);
|
||||
}} />
|
||||
</div>
|
||||
} />
|
||||
<div className="flex gap-2 items-center">
|
||||
{t('Vocab Path')}
|
||||
<Input className="grow" style={{ minWidth: 0 }} value={dataParams.vocabPath}
|
||||
onChange={(e, data) => {
|
||||
setDataParams({ vocabPath: data.value });
|
||||
}} />
|
||||
<Button appearance="secondary" size="large" onClick={() => {
|
||||
ConvertData(commonStore.settings.customPythonPath, dataParams.dataPath,
|
||||
'./finetune/json2binidx_tool/data/' + dataParams.dataPath.split('/').pop()!.split('.')[0],
|
||||
dataParams.vocabPath).then(() => {
|
||||
toast(t('Convert Data successfully'), { type: 'success' });
|
||||
}).catch((e) => {
|
||||
toast((e.message || e), { type: 'error' });
|
||||
});
|
||||
}}>{t('Convert')}</Button>
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
<Section
|
||||
title={t('Train Parameters')}
|
||||
content={
|
||||
<div className="grid grid-cols-1 sm:grid-cols-2 gap-2">
|
||||
<div className="flex gap-2 items-center">
|
||||
{t('Base Model')}
|
||||
<Select style={{ minWidth: 0 }} className="grow"
|
||||
value={loraParams.baseModel}
|
||||
onChange={(e, data) => {
|
||||
setLoraParams({
|
||||
baseModel: data.value
|
||||
});
|
||||
}}>
|
||||
{commonStore.modelSourceList.map((modelItem, index) =>
|
||||
modelItem.isComplete && <option key={index} value={modelItem.name}>{modelItem.name}</option>
|
||||
)}
|
||||
</Select>
|
||||
<ToolTipButton desc={t('Manage Models')} icon={<DataUsageSettings20Regular />} onClick={() => {
|
||||
navigate({ pathname: '/models' });
|
||||
}} />
|
||||
</div>
|
||||
<div className="flex gap-2 items-center">
|
||||
{t('LoRA Model')}
|
||||
<Select style={{ minWidth: 0 }} className="grow"
|
||||
value={loraParams.loraLoad}
|
||||
onChange={(e, data) => {
|
||||
setLoraParams({
|
||||
loraLoad: data.value
|
||||
});
|
||||
}}>
|
||||
<option value="">{t('None')}</option>
|
||||
{commonStore.loraModels.map((name, index) =>
|
||||
<option key={index} value={name}>{name}</option>
|
||||
)}
|
||||
</Select>
|
||||
<Button onClick={() => {
|
||||
MergeLora(commonStore.settings.customPythonPath, true, loraParams.loraAlpha,
|
||||
'models/' + loraParams.baseModel, 'lora-models/' + loraParams.loraLoad,
|
||||
`models/${loraParams.baseModel}-LoRA-${loraParams.loraLoad}`).then(() => {
|
||||
toast(t('Merge model successfully'), { type: 'success' });
|
||||
refreshLocalModels({ models: commonStore.modelSourceList }, false);
|
||||
}).catch((e) => {
|
||||
toast((e.message || e), { type: 'error' });
|
||||
});
|
||||
}}>{t('Merge Model')}</Button>
|
||||
</div>
|
||||
{
|
||||
loraFinetuneParametersOptions.map(([key, type, name], index) => {
|
||||
return (
|
||||
<Labeled key={index} label={t(name)} content={
|
||||
type === 'number' ?
|
||||
<Input type="number" className="grow" value={loraParams[key].toString()}
|
||||
onChange={(e, data) => {
|
||||
setLoraParams({
|
||||
[key]: Number(data.value)
|
||||
});
|
||||
}} /> :
|
||||
type === 'boolean' ?
|
||||
<Switch className="grow" checked={loraParams[key] as boolean}
|
||||
onChange={(e, data) => {
|
||||
setLoraParams({
|
||||
[key]: data.checked
|
||||
});
|
||||
}} /> :
|
||||
type === 'string' ?
|
||||
<Input className="grow" value={loraParams[key].toString()}
|
||||
onChange={(e, data) => {
|
||||
setLoraParams({
|
||||
[key]: data.value
|
||||
});
|
||||
}} /> :
|
||||
type === 'LoraFinetunePrecision' ?
|
||||
<Dropdown style={{ minWidth: 0 }} className="grow"
|
||||
value={loraParams[key].toString()}
|
||||
selectedOptions={[loraParams[key].toString()]}
|
||||
onOptionSelect={(_, data) => {
|
||||
if (data.optionText) {
|
||||
setLoraParams({
|
||||
precision: data.optionText as LoraFinetunePrecision
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
<Option>bf16</Option>
|
||||
<Option>fp16</Option>
|
||||
<Option>fp32</Option>
|
||||
<Option>tf32</Option>
|
||||
</Dropdown>
|
||||
: <div />
|
||||
} />
|
||||
);
|
||||
})
|
||||
}
|
||||
</div>
|
||||
}
|
||||
/>
|
||||
<div className="grow" />
|
||||
<div className="flex gap-2">
|
||||
<div className="grow" />
|
||||
<Button appearance="secondary" size="large" onClick={() => {
|
||||
WslStop().then(() => {
|
||||
toast(t('Command Stopped'), { type: 'success' });
|
||||
}).catch((e) => {
|
||||
toast((e.message || e), { type: 'error' });
|
||||
});
|
||||
}}>{t('Stop')}</Button>
|
||||
<Button appearance="primary" size="large" onClick={StartLoraFinetune}>{t('Train')}</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
});
|
||||
|
||||
type TrainNavigationItem = {
|
||||
element: ReactElement;
|
||||
};
|
||||
|
||||
const pages: { [label: string]: TrainNavigationItem } = {
|
||||
'LoRA Finetune': {
|
||||
element: <LoraFinetune />
|
||||
},
|
||||
WSL: {
|
||||
element: <Terminal />
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
export const Train: FC = () => {
|
||||
const { t } = useTranslation();
|
||||
const [tab, setTab] = useState('LoRA Finetune');
|
||||
|
||||
return (
|
||||
<div className="flex flex-col box-border gap-5 p-2">
|
||||
<Text size={600}>{t('In Development')}</Text>
|
||||
const selectTab: SelectTabEventHandler = (e, data) =>
|
||||
typeof data.value === 'string' ? setTab(data.value) : null;
|
||||
|
||||
return <div className="flex flex-col gap-2 w-full h-full">
|
||||
<TabList
|
||||
size="small"
|
||||
appearance="subtle"
|
||||
selectedValue={tab}
|
||||
onTabSelect={selectTab}
|
||||
>
|
||||
{Object.entries(pages).map(([label]) => (
|
||||
<Tab key={label} value={label}>
|
||||
{t(label)}
|
||||
</Tab>
|
||||
))}
|
||||
</TabList>
|
||||
<div className="grow overflow-hidden">
|
||||
{pages[tab].element}
|
||||
</div>
|
||||
);
|
||||
</div>;
|
||||
};
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import commonStore, { Platform } from './stores/commonStore';
|
||||
import { GetPlatform, ReadJson } from '../wailsjs/go/backend_golang/App';
|
||||
import { GetPlatform, ListDirFiles, ReadJson } from '../wailsjs/go/backend_golang/App';
|
||||
import { Cache, checkUpdate, downloadProgramFiles, LocalConfig, refreshModels } from './utils';
|
||||
import { getStatus } from './apis';
|
||||
import { EventsOn } from '../wailsjs/runtime';
|
||||
import manifest from '../../manifest.json';
|
||||
import { defaultModelConfigs, defaultModelConfigsMac } from './pages/defaultModelConfigs';
|
||||
import { Preset } from './pages/PresetsManager/PresetsButton';
|
||||
import { wslHandler } from './pages/Train';
|
||||
|
||||
export async function startup() {
|
||||
downloadProgramFiles();
|
||||
@@ -13,9 +14,14 @@ export async function startup() {
|
||||
if (data)
|
||||
commonStore.setDownloadList(data);
|
||||
});
|
||||
EventsOn('wsl', wslHandler);
|
||||
EventsOn('wslerr', (e) => {
|
||||
console.log(e);
|
||||
});
|
||||
initLoraModels();
|
||||
|
||||
initPresets();
|
||||
|
||||
|
||||
await GetPlatform().then(p => commonStore.setPlatform(p as Platform));
|
||||
await initConfig();
|
||||
|
||||
@@ -50,6 +56,9 @@ async function initConfig() {
|
||||
if (configData.settings)
|
||||
commonStore.setSettings(configData.settings, false);
|
||||
|
||||
if (configData.dataProcessParams)
|
||||
commonStore.setDataProcessParams(configData.dataProcessParams, false);
|
||||
|
||||
if (configData.modelConfigs && Array.isArray(configData.modelConfigs))
|
||||
commonStore.setModelConfigs(configData.modelConfigs, false);
|
||||
else throw new Error('Invalid config.json');
|
||||
@@ -76,3 +85,24 @@ async function initPresets() {
|
||||
}).catch(() => {
|
||||
});
|
||||
}
|
||||
|
||||
async function initLoraModels() {
|
||||
const refreshLoraModels = () => {
|
||||
ListDirFiles('lora-models').then((data) => {
|
||||
if (!data) return;
|
||||
const loraModels = [];
|
||||
for (const f of data) {
|
||||
if (!f.isDir && f.name.endsWith('.pth')) {
|
||||
loraModels.push(f.name);
|
||||
}
|
||||
}
|
||||
commonStore.setLoraModels(loraModels);
|
||||
});
|
||||
};
|
||||
|
||||
refreshLoraModels();
|
||||
EventsOn('fsnotify', (data: string) => {
|
||||
if (data.includes('lora-models'))
|
||||
refreshLoraModels();
|
||||
});
|
||||
}
|
||||
|
||||
@@ -14,6 +14,8 @@ import { CompletionPreset } from '../pages/Completion';
|
||||
import { defaultModelConfigs, defaultModelConfigsMac } from '../pages/defaultModelConfigs';
|
||||
import commonStore from './commonStore';
|
||||
import { Preset } from '../pages/PresetsManager/PresetsButton';
|
||||
import { DataProcessParameters, LoraFinetuneParameters } from '../pages/Train';
|
||||
import { ChartData } from 'chart.js';
|
||||
|
||||
export enum ModelStatus {
|
||||
Offline,
|
||||
@@ -30,6 +32,8 @@ export type Status = {
|
||||
|
||||
export type Platform = 'windows' | 'darwin' | 'linux';
|
||||
|
||||
const labels = ['January', 'February', 'March', 'April', 'May', 'June', 'July'];
|
||||
|
||||
class CommonStore {
|
||||
// global
|
||||
status: Status = {
|
||||
@@ -62,6 +66,40 @@ class CommonStore {
|
||||
// downloads
|
||||
downloadList: DownloadStatus[] = [];
|
||||
lastUnfinishedModelDownloads: DownloadStatus[] = [];
|
||||
// train
|
||||
wslStdout: string = '';
|
||||
chartTitle: string = '';
|
||||
chartData: ChartData<'line', (number | null)[], string> = { labels: [], datasets: [] };
|
||||
loraModels: string[] = [];
|
||||
dataProcessParams: DataProcessParameters = {
|
||||
dataPath: 'finetune/data/sample.jsonl',
|
||||
vocabPath: 'backend-python/rwkv_pip/rwkv_vocab_v20230424.txt'
|
||||
};
|
||||
loraFinetuneParams: LoraFinetuneParameters = {
|
||||
baseModel: '',
|
||||
ctxLen: 1024,
|
||||
epochSteps: 1000,
|
||||
epochCount: 20,
|
||||
epochBegin: 0,
|
||||
epochSave: 5,
|
||||
microBsz: 1,
|
||||
accumGradBatches: 8,
|
||||
preFfn: false,
|
||||
headQk: false,
|
||||
lrInit: '5e-5',
|
||||
lrFinal: '5e-5',
|
||||
warmupSteps: 0,
|
||||
beta1: 0.9,
|
||||
beta2: 0.999,
|
||||
adamEps: '1e-8',
|
||||
devices: 1,
|
||||
precision: 'bf16',
|
||||
gradCp: false,
|
||||
loraR: 8,
|
||||
loraAlpha: 32,
|
||||
loraDropout: 0.01,
|
||||
loraLoad: ''
|
||||
};
|
||||
// settings
|
||||
advancedCollapsed: boolean = true;
|
||||
settings: SettingsType = {
|
||||
@@ -228,6 +266,34 @@ class CommonStore {
|
||||
setCompletionSubmittedPrompt(value: string) {
|
||||
this.completionSubmittedPrompt = value;
|
||||
}
|
||||
|
||||
setWslStdout(value: string) {
|
||||
this.wslStdout = value;
|
||||
}
|
||||
|
||||
setDataProcessParams(value: DataProcessParameters, saveConfig: boolean = true) {
|
||||
this.dataProcessParams = value;
|
||||
if (saveConfig)
|
||||
saveConfigs();
|
||||
}
|
||||
|
||||
setLoraFinetuneParameters(value: LoraFinetuneParameters, saveConfig: boolean = true) {
|
||||
this.loraFinetuneParams = value;
|
||||
if (saveConfig)
|
||||
saveConfigs();
|
||||
}
|
||||
|
||||
setChartTitle(value: string) {
|
||||
this.chartTitle = value;
|
||||
}
|
||||
|
||||
setChartData(value: ChartData<'line', (number | null)[], string>) {
|
||||
this.chartData = value;
|
||||
}
|
||||
|
||||
setLoraModels(value: string[]) {
|
||||
this.loraModels = value;
|
||||
}
|
||||
}
|
||||
|
||||
export default new CommonStore();
|
||||
@@ -17,6 +17,7 @@ import { Language, Languages, SettingsType } from '../pages/Settings';
|
||||
import { ModelSourceItem } from '../pages/Models';
|
||||
import { ModelConfig, ModelParameters } from '../pages/Configs';
|
||||
import { DownloadStatus } from '../pages/Downloads';
|
||||
import { DataProcessParameters, LoraFinetuneParameters } from '../pages/Train';
|
||||
|
||||
export type Cache = {
|
||||
version: string
|
||||
@@ -28,7 +29,9 @@ export type LocalConfig = {
|
||||
modelSourceManifestList: string
|
||||
currentModelConfigIndex: number
|
||||
modelConfigs: ModelConfig[]
|
||||
settings: SettingsType
|
||||
settings: SettingsType,
|
||||
dataProcessParams: DataProcessParameters,
|
||||
loraFinetuneParams: LoraFinetuneParameters
|
||||
}
|
||||
|
||||
export async function refreshBuiltInModels(readCache: boolean = false) {
|
||||
@@ -194,7 +197,9 @@ export const saveConfigs = async () => {
|
||||
modelSourceManifestList: commonStore.modelSourceManifestList,
|
||||
currentModelConfigIndex: commonStore.currentModelConfigIndex,
|
||||
modelConfigs: commonStore.modelConfigs,
|
||||
settings: commonStore.settings
|
||||
settings: commonStore.settings,
|
||||
dataProcessParams: commonStore.dataProcessParams,
|
||||
loraFinetuneParams: commonStore.loraFinetuneParams
|
||||
};
|
||||
return SaveJson('config.json', data);
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user