2023-11-07 11:27:21 +00:00
import React , { FC , useEffect , useRef , useState } from 'react' ;
2023-05-22 02:52:06 +00:00
import { useTranslation } from 'react-i18next' ;
2023-07-03 09:41:47 +00:00
import { Button , Dropdown , Input , Option , Select , Switch , Tab , TabList } from '@fluentui/react-components' ;
import {
ConvertData ,
FileExists ,
2023-07-07 12:16:35 +00:00
GetPyError ,
2023-07-03 09:41:47 +00:00
MergeLora ,
OpenFileFolder ,
WslCommand ,
WslEnable ,
WslInstallUbuntu ,
WslIsEnabled ,
WslStart ,
WslStop
} from '../../wailsjs/go/backend_golang/App' ;
import { toast } from 'react-toastify' ;
import commonStore from '../stores/commonStore' ;
import { observer } from 'mobx-react-lite' ;
import { SelectTabEventHandler } from '@fluentui/react-tabs' ;
2023-07-07 12:16:35 +00:00
import { checkDependencies , toastWithButton } from '../utils' ;
2023-07-03 09:41:47 +00:00
import { Section } from '../components/Section' ;
import { Labeled } from '../components/Labeled' ;
import { ToolTipButton } from '../components/ToolTipButton' ;
import { DataUsageSettings20Regular , Folder20Regular } from '@fluentui/react-icons' ;
import { useNavigate } from 'react-router' ;
import {
CategoryScale ,
Chart as ChartJS ,
Legend ,
LinearScale ,
LineElement ,
PointElement ,
Title ,
Tooltip
} from 'chart.js' ;
import { Line } from 'react-chartjs-2' ;
import { ChartJSOrUndefined } from 'react-chartjs-2/dist/types' ;
2023-07-03 13:40:16 +00:00
import { WindowShow } from '../../wailsjs/runtime' ;
2023-07-07 11:10:31 +00:00
import { t } from 'i18next' ;
2023-07-07 13:57:01 +00:00
import { DialogButton } from '../components/DialogButton' ;
2023-11-07 11:27:21 +00:00
import {
DataProcessParameters ,
LoraFinetuneParameters ,
LoraFinetunePrecision ,
TrainNavigationItem
} from '../types/train' ;
2023-05-05 15:23:34 +00:00
2023-07-03 09:41:47 +00:00
// Chart.js only draws with components that were registered explicitly; register the scales,
// elements and plugins used by the loss line chart on this page.
ChartJS.register(CategoryScale, LinearScale, PointElement, LineElement, Tooltip, Title, Legend);
/**
 * Scan a chunk of training output for progress-bar updates and publish the newest one.
 *
 * Publishing the update does three things:
 * - sets the chart title (epoch, percentage, step counters, the two bracketed times —
 *   presumably elapsed/remaining — rate, and loss);
 * - adds the loss to the loss chart;
 * - warns (de-duplicated through a fixed toastId) when the loss exceeds 5.
 *
 * @param data raw stdout chunk received from WSL
 * @returns true if the chunk contained at least one progress update, false otherwise
 */
const parseLossData = (data: string) => {
  // Non-greedy `[\s\S]*?` keeps every match inside a single progress update. A greedy
  // `[\s\S]*` would merge a chunk holding several updates into ONE match. That match would take
  // its epoch and percentage from the first update but its counters and loss from the last one.
  const regex = /Epoch (\d+):\s+(\d+%)\|[\s\S]*?\| (\d+)\/(\d+) \[(\S+)<(\S+),\s+(\S+), loss=(\S+),[\s\S]*?\]/g;
  const matches = Array.from(data.matchAll(regex));
  if (matches.length === 0)
    return false;

  // Only the most recent update is displayed.
  const lastMatch = matches[matches.length - 1];
  const epoch = parseInt(lastMatch[1], 10);
  const loss = parseFloat(lastMatch[8]);
  commonStore.setChartTitle(`Epoch ${epoch}: ${lastMatch[2]} - ${lastMatch[3]}/${lastMatch[4]} - ${lastMatch[5]}/${lastMatch[6]} - ${lastMatch[7]} Loss=${loss}`);
  addLossDataToChart(epoch, loss);
  if (loss > 5)
    toast(t('Loss is too high, please check the training data, and ensure your gpu driver is up to date.'), {
      type: 'warning',
      toastId: 'train_loss_high'
    });
  return true;
};
// Loss chart instance currently mounted by the LoraFinetune page, or undefined when none is.
// It is used to redraw a modified point in place.
let chartLine: ChartJSOrUndefined<'line', (number | null)[], string>;

/**
 * Record `loss` for `epoch` in the shared chart data (commonStore.chartData).
 *
 * - The first report for an epoch appends a point labelled "Epoch N". Epoch 0 additionally
 *   seeds an "Init" point with the same loss.
 * - Later reports for the same epoch overwrite that point, both in the store and, when a chart
 *   is mounted, in the live chart.
 */
const addLossDataToChart = (epoch: number, loss: number) => {
  const label = 'Epoch ' + epoch.toString();
  const { labels, datasets } = commonStore.chartData;
  // Exact label match: a substring test (`includes`) could hit the wrong epoch,
  // e.g. "Epoch 12" contains both "1" and "2".
  const epochIndex = labels!.indexOf(label);
  if (epochIndex === -1) {
    if (epoch === 0) {
      labels!.push('Init');
      datasets[0].data = [...datasets[0].data, loss];
    }
    labels!.push(label);
    datasets[0].data = [...datasets[0].data, loss];
  } else {
    const newData = [...datasets[0].data];
    newData[epochIndex] = loss;
    // Write the new value to the store too. Otherwise a remounted chart would show the first loss
    // reported for this epoch, and updates arriving while no chart is mounted would be lost.
    datasets[0].data = newData;
    if (chartLine) {
      chartLine.data.datasets[0].data = newData;
      chartLine.update();
    }
  }
  commonStore.setChartData(commonStore.chartData);
};
// Editable LoRA training parameters, rendered in this order in the "Train Parameters" grid.
// Each row is [key into LoraFinetuneParameters, widget kind, untranslated label].
// The render code maps widget kinds as follows:
//   'number'                -> numeric Input
//   'string'                -> text Input
//   'boolean'               -> Switch
//   'LoraFinetunePrecision' -> precision Dropdown
//   anything else           -> an empty <div/>
const loraFinetuneParametersOptions: [key: keyof LoraFinetuneParameters, type: string, name: string][] = [
  ['devices',          'number',                'Devices'],
  ['precision',        'LoraFinetunePrecision', 'Precision'],
  ['gradCp',           'boolean',               'Gradient Checkpoint'],
  ['ctxLen',           'number',                'Context Length'],
  ['epochSteps',       'number',                'Epoch Steps'],
  ['epochCount',       'number',                'Epoch Count'],
  ['epochBegin',       'number',                'Epoch Begin'],
  ['epochSave',        'number',                'Epoch Save'],
  ['lrInit',           'string',                'Learning Rate Init'],
  ['lrFinal',          'string',                'Learning Rate Final'],
  ['microBsz',         'number',                'Micro Batch Size'],
  ['accumGradBatches', 'number',                'Accumulate Gradient Batches'],
  ['warmupSteps',      'number',                'Warmup Steps'],
  ['adamEps',          'string',                'Adam Epsilon'],
  ['beta1',            'number',                'Beta 1'],
  ['beta2',            'number',                'Beta 2'],
  ['loraR',            'number',                'LoRA R'],
  ['loraAlpha',        'number',                'LoRA Alpha'],
  ['loraDropout',      'number',                'LoRA Dropout'],
  // NOTE(review): filler row. Kind 'any' renders an empty cell with an empty label, and its key
  // is never read. It presumably keeps the two-column grid balanced; confirm before removing.
  ['beta1',            'any',                   '']
  // ['preFfn', 'boolean', 'Pre-FFN'],
  // ['headQk', 'boolean', 'Head QK']
];
2023-07-07 11:10:31 +00:00
/**
 * Show an error toast for an Error-like object or a plain message string.
 *
 * The message text is passed through i18next. "wsl not running" is replaced by a longer hint.
 * All other messages share the toastId 'train_error', so at most one such toast is shown at a time.
 */
const showError = (e: any) => {
  const message = e.message || e;
  if (message !== 'wsl not running') {
    toast(t(message), { type: 'error', toastId: 'train_error' });
    return;
  }
  toast(t('WSL is not running, please retry. If it keeps happening, it means you may be using an outdated version of WSL, run "wsl --update" to update.'), { type: 'error' });
};
2024-01-31 13:31:03 +00:00
// Known failure signatures in training output, paired with the (untranslated) hint shown to the
// user. wslHandler lowercases the output before matching, so every signature is lowercased
// here as well. Pairs are checked in this order, and the first signature found wins.
const errorsMap: [string, string][] = Object.entries({
  'python3 ./finetune/lora/$modelInfo': 'Memory is not enough, try to increase the virtual memory (Swap of WSL) or use a smaller base model.',
  'cuda out of memory': 'VRAM is not enough',
  'valueerror: high <= 0': 'Training data is not enough, reduce context length or add more data for training',
  '+= \'+ptx\'': 'Can not find an Nvidia GPU. Perhaps the gpu driver of windows is too old, or you are using WSL 1 for training, please upgrade to WSL 2. e.g. Run "wsl --set-version Ubuntu-22.04 2"',
  'size mismatch for blocks': 'Size mismatch for blocks. You are attempting to continue training from the LoRA model, but it does not match the base model. Please set LoRA model to None.',
  'cuda_home environment variable is not set': 'Matched CUDA is not installed',
  'unsupported gpu architecture': 'Matched CUDA is not installed',
  'error building extension \'fused_adam\'': 'Matched CUDA is not installed',
  'rwkv{version} is not supported': 'This version of RWKV is not supported yet.',
  'no such file': 'Failed to find the base model, please try to change your base model.',
  'modelinfo is invalid': 'Failed to load model, try to increase the virtual memory (Swap of WSL) or use a smaller base model.'
}).map(([signature, hint]): [string, string] => [signature.toLowerCase(), hint]);
2023-07-03 09:41:47 +00:00
/**
 * Handle one chunk of WSL stdout.
 *
 * Empty chunks are ignored. Otherwise the chunk is appended to the terminal buffer. If it carries
 * training progress, the progress goes to the loss chart. If it does not, the chunk is checked
 * against the known error signatures, and the first match is reported.
 */
export const wslHandler = (data: string) => {
  if (!data)
    return;

  addWslMessage(data);
  if (parseLossData(data))
    return;

  const lowered = data.toLowerCase();
  const knownError = errorsMap.find(([signature]) => lowered.includes(signature));
  if (knownError)
    showError(knownError[1]);
};
/** Append `message` as a new line of the terminal buffer, keeping only the last 100 lines. */
const addWslMessage = (message: string) => {
  const combined = commonStore.wslStdout + '\n' + message;
  const lastLines = combined.split('\n').slice(-100);
  commonStore.setWslStdout(lastLines.join('\n'));
};
/** Read-only view of the WSL output buffer, scrolled to the bottom after each render. */
const TerminalDisplay: FC = observer(() => {
  const bodyRef = useRef<HTMLDivElement>(null);

  // No dependency list on purpose: re-pin to the newest line after every render.
  useEffect(() => {
    const body = bodyRef.current;
    if (!body)
      return;
    body.scrollTop = body.scrollHeight;
  });

  return (
    <div ref={bodyRef} className="grow overflow-x-hidden overflow-y-auto border-gray-500 border-2 rounded-md">
      <div className="whitespace-pre-line">
        {commonStore.wslStdout}
      </div>
    </div>
  );
});
/**
 * "WSL" tab: shows the WSL output buffer, runs the typed command when Enter is pressed, and
 * offers a Stop button that calls WslStop.
 */
const Terminal: FC = observer(() => {
  const { t } = useTranslation();
  const [input, setInput] = useState('');

  const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
    e.stopPropagation();
    // `keyCode` is deprecated; `key` is the standard replacement. Also skip the Enter that only
    // confirms an IME composition, so CJK input is not submitted half-typed.
    if (e.key !== 'Enter' || e.nativeEvent.isComposing)
      return;
    e.preventDefault();
    if (!input) return;
    WslStart().then(() => {
      addWslMessage('WSL> ' + input);
      setInput('');
      WslCommand(input).then(WindowShow).catch(showError);
    }).catch(showError);
  };

  return (
    <div className="flex flex-col h-full gap-4">
      <TerminalDisplay />
      <div className="flex gap-2 items-center">
        WSL:
        <Input className="grow" value={input} onChange={(e) => {
          setInput(e.target.value);
        }} onKeyDown={handleKeyDown}></Input>
        <Button onClick={() => {
          WslStop().then(() => {
            toast(t('Command Stopped'), { type: 'success' });
          }).catch(showError);
        }}>
          {t('Stop')}
        </Button>
      </div>
    </div>
  );
});
/**
 * "LoRA Finetune" tab. It provides:
 * - data conversion;
 * - training parameters;
 * - LoRA merging;
 * - starting and stopping a training run inside WSL.
 *
 * When there is output, the top area shows the loss chart once loss points exist, and the raw
 * terminal output before that.
 */
const LoraFinetune: FC = observer(() => {
  const { t } = useTranslation();
  const navigate = useNavigate();
  const chartRef = useRef<ChartJSOrUndefined<'line', (number | null)[], string>>(null);
  const dataParams = commonStore.dataProcessParams;
  const loraParams = commonStore.loraFinetuneParams;

  // Publish the mounted chart to the module-level `chartLine` read by addLossDataToChart.
  // Assigning it during render could never clear it. A chart destroyed by a tab switch, or by
  // clearing the labels when a new run starts, would keep receiving update() calls.
  // The effect sets the reference after each commit and drops it in the cleanup.
  useEffect(() => {
    if (chartRef.current)
      chartLine = chartRef.current;
    return () => {
      chartLine = undefined;
    };
  });

  const setDataParams = (newParams: Partial<DataProcessParameters>) => {
    commonStore.setDataProcessParams({
      ...dataParams,
      ...newParams
    });
  };

  const setLoraParams = (newParams: Partial<LoraFinetuneParameters>) => {
    commonStore.setLoraFinetuneParameters({
      ...loraParams,
      ...newParams
    });
  };

  // Default the base model to the first fully downloaded model.
  useEffect(() => {
    if (loraParams.baseModel === '')
      setLoraParams({
        baseModel: commonStore.modelSourceList.find(m => m.isComplete)?.name || ''
      });
  }, []);

  // Output prefix used for the converted data. It is built from the last segment of the data
  // path, after stripping one trailing separator and everything from the first '.'.
  // Shared by the Convert button and StartLoraFinetune, so both point to the same files.
  const getOutputPrefix = () =>
    './finetune/json2binidx_tool/data/' +
    dataParams.dataPath.replace(/[\/\\]$/, '').split(/[\/\\]/).pop()!.split('.')[0];

  const StartLoraFinetune = async () => {
    const ok = await checkDependencies(navigate);
    if (!ok)
      return;

    const convertedDataPath = getOutputPrefix() + '_text_document';
    if (!await FileExists(convertedDataPath + '.idx')) {
      toast(t('Please convert data first.'), { type: 'error' });
      return;
    }

    WslIsEnabled().then(() => {
      WslStart().then(() => {
        setTimeout(WindowShow, 1000);
        let ctxLen = loraParams.ctxLen;
        if (dataParams.dataPath === 'finetune/data/sample.jsonl') {
          ctxLen = 150;
          toast(t('You are using sample data for training. For formal training, please make sure to create your own jsonl file.'), {
            type: 'info',
            autoClose: 6000
          });
        }
        commonStore.setChartData({
          labels: [],
          datasets: [
            {
              label: 'Loss',
              data: [],
              borderColor: 'rgb(53, 162, 235)',
              backgroundColor: 'rgba(53, 162, 235, 0.5)'
            }
          ]
        });
        WslCommand(`export cnMirror=${commonStore.settings.cnMirror ? '1' : '0'}` +
          ` && export loadModel=models/${loraParams.baseModel}` +
          ` && sed -i 's/\\r$//' finetune/install-wsl-dep-and-train.sh` +
          ` && chmod +x finetune/install-wsl-dep-and-train.sh && ./finetune/install-wsl-dep-and-train.sh` +
          (loraParams.baseModel ? ` --load_model models/${loraParams.baseModel}` : '') +
          (loraParams.loraLoad ? ` --lora_load lora-models/${loraParams.loraLoad}` : '') +
          ` --data_file ${convertedDataPath}` +
          ` --ctx_len ${ctxLen} --epoch_steps ${loraParams.epochSteps} --epoch_count ${loraParams.epochCount}` +
          ` --epoch_begin ${loraParams.epochBegin} --epoch_save ${loraParams.epochSave}` +
          ` --micro_bsz ${loraParams.microBsz} --accumulate_grad_batches ${loraParams.accumGradBatches}` +
          // Pre-FFN and Head-QK are always passed as 0. Their form fields are commented out in
          // loraFinetuneParametersOptions, so the stored values are ignored.
          ` --pre_ffn 0 --head_qk 0 --lr_init ${loraParams.lrInit} --lr_final ${loraParams.lrFinal}` +
          ` --warmup_steps ${loraParams.warmupSteps}` +
          ` --beta1 ${loraParams.beta1} --beta2 ${loraParams.beta2} --adam_eps ${loraParams.adamEps}` +
          ` --devices ${loraParams.devices} --precision ${loraParams.precision}` +
          ` --grad_cp ${loraParams.gradCp ? '1' : '0'}` +
          ` --lora_r ${loraParams.loraR} --lora_alpha ${loraParams.loraAlpha} --lora_dropout ${loraParams.loraDropout}`).catch(showError);
      }).catch(e => {
        const msg = e.message || e;
        if (msg === 'ubuntu not found') {
          WindowShow();
          toastWithButton(t('Ubuntu is not installed, do you want to install it?'), t('Install Ubuntu'), () => {
            WslInstallUbuntu().then(() => {
              WindowShow();
              toast(t('Please install Ubuntu using Microsoft Store, after installation click the Open button in Microsoft Store and then click the Train button'), {
                type: 'info',
                autoClose: 10000
              });
            }).catch(showError);
          });
        } else {
          // Any other start-up failure used to be swallowed silently, leaving the user with no feedback.
          showError(msg);
        }
      });
    }).catch(e => {
      const msg = e.message || e;
      const enableWsl = (forceMode: boolean) => {
        WindowShow();
        toastWithButton(t('WSL is not enabled, do you want to enable it?'), t('Enable WSL'), () => {
          WslEnable(forceMode).then(() => {
            WindowShow();
            toast(t('After installation, please restart your computer to enable WSL'), {
              type: 'info',
              autoClose: false
            });
          }).catch(showError);
        });
      };
      if (msg === 'wsl is not enabled' || msg.includes('wsl.state: The system cannot find the file'))
        enableWsl(true);
      else
        showError(msg);
    });
  };

  return (
    <div className="flex flex-col h-full w-full gap-2">
      {(commonStore.wslStdout.length > 0 || commonStore.chartData.labels!.length !== 0) &&
        <div className="flex" style={{ height: '35%' }}>
          {commonStore.wslStdout.length > 0 && commonStore.chartData.labels!.length === 0 && <TerminalDisplay />}
          {commonStore.chartData.labels!.length !== 0 &&
            <Line ref={chartRef} data={commonStore.chartData} options={{
              responsive: true,
              showLine: true,
              plugins: {
                legend: {
                  position: 'right',
                  align: 'start'
                },
                title: {
                  display: true,
                  text: commonStore.chartTitle
                }
              },
              scales: {
                y: {
                  beginAtZero: true
                }
              },
              maintainAspectRatio: false
            }} style={{ width: '100%' }} />}
        </div>
      }
      <div>
        <Section
          title={t('Data Process')}
          content={
            <div className="flex flex-col gap-2">
              <div className="flex gap-2 items-center">
                {t('Data Path')}
                <Input className="grow" style={{ minWidth: 0 }} value={dataParams.dataPath}
                  onChange={(e, data) => {
                    setDataParams({ dataPath: data.value });
                  }} />
                <DialogButton text={t('Help')} title={t('Help')} markdown
                  content={t('The data path should be a directory or a file in jsonl format (more formats will be supported in the future).\n\n' +
                    'When you provide a directory path, all the txt files within that directory will be automatically converted into training data. ' +
                    'This is commonly used for large-scale training in writing, code generation, or knowledge bases.\n\n' +
                    'The jsonl format file can be referenced at https://github.com/josStorer/RWKV-Runner/blob/master/finetune/data/sample.jsonl.\n' +
                    'You can also write it similar to OpenAI\'s playground format, as shown in https://platform.openai.com/playground/p/default-chat.\n' +
                    'Even for multi-turn conversations, they must be written in a single line using `\\n` to indicate line breaks. ' +
                    'If they are different dialogues or topics, they should be written in separate lines.')} />
                <ToolTipButton desc={t('Open Folder')} icon={<Folder20Regular />} onClick={() => {
                  OpenFileFolder(dataParams.dataPath);
                }} />
              </div>
              <div className="flex gap-2 items-center">
                {t('Vocab Path')}
                <Input className="grow" style={{ minWidth: 0 }} value={dataParams.vocabPath}
                  onChange={(e, data) => {
                    setDataParams({ vocabPath: data.value });
                  }} />
                <Button appearance="secondary" onClick={async () => {
                  const ok = await checkDependencies(navigate);
                  if (!ok)
                    return;
                  const outputPrefix = getOutputPrefix();
                  ConvertData(commonStore.settings.customPythonPath,
                    dataParams.dataPath.replaceAll('\\', '/'),
                    outputPrefix,
                    dataParams.vocabPath).then(async () => {
                    if (!await FileExists(outputPrefix + '_text_document.idx')) {
                      // NOTE(review): no failure toast on other platforms. Presumably the
                      // conversion finishes asynchronously there, so a missing file is not yet an
                      // error. Confirm.
                      if (commonStore.platform === 'windows' || commonStore.platform === 'linux')
                        toast(t('Failed to convert data') + ' - ' + await GetPyError(), { type: 'error' });
                    } else {
                      toast(t('Convert Data successfully'), { type: 'success' });
                    }
                  }).catch(showError);
                }}>{t('Convert')}</Button>
              </div>
            </div>
          }
        />
      </div>
      <Section
        title={t('Train Parameters')}
        content={
          <div className="grid grid-cols-1 sm:grid-cols-2 gap-2">
            <div className="flex gap-2 items-center">
              <div className="shrink-0">
                {t('Base Model')}
              </div>
              <Select style={{ minWidth: 0 }} className="grow"
                value={loraParams.baseModel}
                onChange={(e, data) => {
                  setLoraParams({
                    baseModel: data.value
                  });
                }}>
                {commonStore.modelSourceList.map((modelItem, index) =>
                  modelItem.isComplete && <option key={index} value={modelItem.name}>{modelItem.name}</option>
                )}
              </Select>
              <ToolTipButton desc={t('Manage Models')} icon={<DataUsageSettings20Regular />} onClick={() => {
                navigate({ pathname: '/models' });
              }} />
            </div>
            <div className="flex gap-2 items-center">
              {t('LoRA Model')}
              <Select style={{ minWidth: 0 }} className="grow"
                value={loraParams.loraLoad}
                onChange={(e, data) => {
                  setLoraParams({
                    loraLoad: data.value
                  });
                }}>
                <option value="">{t('None')}</option>
                {commonStore.loraModels.map((name, index) =>
                  <option key={index} value={name}>{name}</option>
                )}
              </Select>
              <Button onClick={async () => {
                const ok = await checkDependencies(navigate);
                if (!ok)
                  return;
                if (loraParams.loraLoad) {
                  const outputPath = `models/${loraParams.baseModel}-LoRA-${loraParams.loraLoad}`;
                  MergeLora(commonStore.settings.customPythonPath, !!commonStore.monitorData && commonStore.monitorData.totalVram !== 0, loraParams.loraAlpha,
                    'models/' + loraParams.baseModel, 'lora-models/' + loraParams.loraLoad,
                    outputPath).then(async () => {
                    if (!await FileExists(outputPath)) {
                      // NOTE(review): same platform restriction as the Convert button; confirm.
                      if (commonStore.platform === 'windows' || commonStore.platform === 'linux')
                        toast(t('Failed to merge model') + ' - ' + await GetPyError(), { type: 'error' });
                    } else {
                      toast(t('Merge model successfully'), { type: 'success' });
                    }
                  }).catch(showError);
                } else {
                  toast(t('Please select a LoRA model'), { type: 'info' });
                }
              }}>{t('Merge Model')}</Button>
            </div>
            {
              loraFinetuneParametersOptions.map(([key, type, name], index) => {
                return (
                  <Labeled key={index} label={t(name)} content={
                    type === 'number' ?
                      <Input type="number" className="grow" value={loraParams[key].toString()}
                        onChange={(e, data) => {
                          setLoraParams({
                            [key]: Number(data.value)
                          });
                        }} /> :
                      type === 'boolean' ?
                        <Switch className="grow" checked={loraParams[key] as boolean}
                          onChange={(e, data) => {
                            setLoraParams({
                              [key]: data.checked
                            });
                          }} /> :
                        type === 'string' ?
                          <Input className="grow" value={loraParams[key].toString()}
                            onChange={(e, data) => {
                              setLoraParams({
                                [key]: data.value
                              });
                            }} /> :
                          type === 'LoraFinetunePrecision' ?
                            <Dropdown style={{ minWidth: 0 }} className="grow"
                              value={loraParams[key].toString()}
                              selectedOptions={[loraParams[key].toString()]}
                              onOptionSelect={(_, data) => {
                                if (data.optionText) {
                                  setLoraParams({
                                    precision: data.optionText as LoraFinetunePrecision
                                  });
                                }
                              }}
                            >
                              <Option>bf16</Option>
                              <Option>fp16</Option>
                              <Option>tf32</Option>
                            </Dropdown>
                            : <div />
                  } />
                );
              })
            }
          </div>
        }
      />
      <div className="grow" />
      <div className="flex gap-2">
        <div className="grow" />
        <Button appearance="secondary" size="large" onClick={() => {
          WslStop().then(() => {
            toast(t('Command Stopped'), { type: 'success' });
          }).catch(showError);
        }}>{t('Stop')}</Button>
        <Button appearance="primary" size="large" onClick={StartLoraFinetune}>{t('Train')}</Button>
      </div>
    </div>
  );
});
// Tabs of the Train page, keyed by their untranslated label. The key order is the tab order.
const pages: Record<string, TrainNavigationItem> = {
  'LoRA Finetune': { element: <LoraFinetune /> },
  'WSL': { element: <Terminal /> }
};
2023-11-07 11:27:21 +00:00
/** Train page: a tab strip that switches between the entries of `pages`. */
const Train: FC = () => {
  const { t } = useTranslation();
  const [tab, setTab] = useState('LoRA Finetune');

  const selectTab: SelectTabEventHandler = (e, data) => {
    if (typeof data.value === 'string')
      setTab(data.value);
  };

  return (
    <div className="flex flex-col gap-2 w-full h-full">
      <TabList size="small" appearance="subtle" selectedValue={tab} onTabSelect={selectTab}>
        {Object.keys(pages).map((label) => (
          <Tab key={label} value={label}>
            {t(label)}
          </Tab>
        ))}
      </TabList>
      <div className="grow overflow-hidden">
        {pages[tab].element}
      </div>
    </div>
  );
};

export default Train;