From 53502a8c3df00a80c22d5ea4847531b126189d44 Mon Sep 17 00:00:00 2001 From: josc146 Date: Tue, 16 May 2023 13:22:57 +0800 Subject: [PATCH] preliminary usable features --- frontend/src/components/RunButton.tsx | 61 +++++++++++++++++++ frontend/src/pages/Configs.tsx | 42 ++++++++++--- frontend/src/pages/Home.tsx | 85 ++++----------------------- frontend/src/stores/commonStore.ts | 35 +++++++++-- 4 files changed, 137 insertions(+), 86 deletions(-) create mode 100644 frontend/src/components/RunButton.tsx diff --git a/frontend/src/components/RunButton.tsx b/frontend/src/components/RunButton.tsx new file mode 100644 index 0000000..e49d2c2 --- /dev/null +++ b/frontend/src/components/RunButton.tsx @@ -0,0 +1,61 @@ +import React, {FC} from 'react'; +import commonStore, {ModelStatus} from '../stores/commonStore'; +import {StartServer} from '../../wailsjs/go/backend_golang/App'; +import {Button} from '@fluentui/react-components'; +import {observer} from 'mobx-react-lite'; + +const mainButtonText = { + [ModelStatus.Offline]: 'Run', + [ModelStatus.Starting]: 'Starting', + [ModelStatus.Loading]: 'Loading', + [ModelStatus.Working]: 'Stop' +}; + +const onClickMainButton = async () => { + if (commonStore.modelStatus === ModelStatus.Offline) { + commonStore.setModelStatus(ModelStatus.Starting); + StartServer(commonStore.getStrategy(), `models\\${commonStore.getCurrentModelConfig().modelParameters.modelName}`); + + let timeoutCount = 5; + let loading = false; + const intervalId = setInterval(() => { + fetch('http://127.0.0.1:8000') + .then(r => { + if (r.ok && !loading) { + clearInterval(intervalId); + commonStore.setModelStatus(ModelStatus.Loading); + loading = true; + fetch('http://127.0.0.1:8000/update-config', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({}) + }).then(async (r) => { + if (r.ok) + commonStore.setModelStatus(ModelStatus.Working); + }); + } + }).catch(() => { + if (timeoutCount <= 0) { + clearInterval(intervalId); + commonStore.setModelStatus(ModelStatus.Offline); + } + }); + + timeoutCount--; + }, 1000); + } else { + commonStore.setModelStatus(ModelStatus.Offline); + fetch('http://127.0.0.1:8000/exit', {method: 'POST'}); + } +}; + +export const RunButton: FC = observer(() => { + return ( + + ); +}); diff --git a/frontend/src/pages/Configs.tsx b/frontend/src/pages/Configs.tsx index 63c721c..4a3e18b 100644 --- a/frontend/src/pages/Configs.tsx +++ b/frontend/src/pages/Configs.tsx @@ -1,16 +1,17 @@ -import {Button, Dropdown, Input, Label, Option, Select, Slider, Switch} from '@fluentui/react-components'; +import {Dropdown, Input, Label, Option, Select, Switch} from '@fluentui/react-components'; import {AddCircle20Regular, DataUsageSettings20Regular, Delete20Regular, Save20Regular} from '@fluentui/react-icons'; import React, {FC} from 'react'; import {Section} from '../components/Section'; import {Labeled} from '../components/Labeled'; import {ToolTipButton} from '../components/ToolTipButton'; -import commonStore, {ApiParameters, ModelParameters} from '../stores/commonStore'; +import commonStore, {ApiParameters, Device, ModelParameters, Precision} from '../stores/commonStore'; import {observer} from 'mobx-react-lite'; import {toast} from 'react-toastify'; import {ValuedSlider} from '../components/ValuedSlider'; import {NumberInput} from '../components/NumberInput'; import {Page} from '../components/Page'; import {useNavigate} from 'react-router'; +import {RunButton} from '../components/RunButton'; export const Configs: FC = observer(() => { const [selectedIndex, setSelectedIndex] = React.useState(commonStore.currentModelConfigIndex); @@ -164,30 +165,57 @@ export const Configs: FC = observer(() => { }/> + { + if (data.optionText) { + setSelectedConfigModelParams({ + device: data.optionText as Device + }); + } + }}> }/> + { + if (data.optionText) { + setSelectedConfigModelParams({ + precision: data.optionText as Precision + }); + } + }}> }/> + { + setSelectedConfigModelParams({ + streamedLayers: data.value + }); + }}/> }/> + { + setSelectedConfigModelParams({ + enableHighPrecisionForLastLayer: data.checked + }); + }}/> }/> } />
- +
}/> diff --git a/frontend/src/pages/Home.tsx b/frontend/src/pages/Home.tsx index 837d19d..d815df1 100644 --- a/frontend/src/pages/Home.tsx +++ b/frontend/src/pages/Home.tsx @@ -1,4 +1,4 @@ -import {Button, CompoundButton, Dropdown, Link, Option, Text} from '@fluentui/react-components'; +import {CompoundButton, Dropdown, Link, Option, Text} from '@fluentui/react-components'; import React, {FC, ReactElement} from 'react'; import banner from '../assets/images/banner.jpg'; import { @@ -8,9 +8,9 @@ import { Storage20Regular } from '@fluentui/react-icons'; import {useNavigate} from 'react-router'; -import commonStore, {ModelStatus} from '../stores/commonStore'; +import commonStore from '../stores/commonStore'; import {observer} from 'mobx-react-lite'; -import {StartServer} from '../../wailsjs/go/backend_golang/App'; +import {RunButton} from '../components/RunButton'; type NavCard = { label: string; @@ -19,7 +19,7 @@ type NavCard = { icon: ReactElement; }; -export const navCards: NavCard[] = [ +const navCards: NavCard[] = [ { label: 'Chat', desc: 'Go to chat page', @@ -46,62 +46,13 @@ export const navCards: NavCard[] = [ } ]; -const mainButtonText = { - [ModelStatus.Offline]: 'Run', - [ModelStatus.Starting]: 'Starting', - [ModelStatus.Loading]: 'Loading', - [ModelStatus.Working]: 'Stop' -}; - export const Home: FC = observer(() => { - const [selectedConfig, setSelectedConfig] = React.useState('RWKV-3B-4G MEM'); - const navigate = useNavigate(); const onClickNavCard = (path: string) => { navigate({pathname: path}); }; - const onClickMainButton = async () => { - if (commonStore.modelStatus === ModelStatus.Offline) { - commonStore.setModelStatus(ModelStatus.Starting); - StartServer('cuda fp16', 'models\\RWKV-4-Raven-1B5-v8-Eng-20230408-ctx4096.pth'); - - let timeoutCount = 5; - let loading = false; - const intervalId = setInterval(() => { - fetch('http://127.0.0.1:8000') - .then(r => { - if (r.ok && !loading) { - clearInterval(intervalId); - commonStore.setModelStatus(ModelStatus.Loading); - loading = true; - fetch('http://127.0.0.1:8000/update-config', { - method: 'POST', - headers: { - 'Content-Type': 'application/json' - }, - body: JSON.stringify({}) - }).then(async (r) => { - if (r.ok) - commonStore.setModelStatus(ModelStatus.Working); - }); - } - }).catch(() => { - if (timeoutCount <= 0) { - clearInterval(intervalId); - commonStore.setModelStatus(ModelStatus.Offline); - } - }); - - timeoutCount--; - }, 1000); - } else { - commonStore.setModelStatus(ModelStatus.Offline); - fetch('http://127.0.0.1:8000/exit', {method: 'POST'}); - } - }; - return (
@@ -128,30 +79,18 @@ export const Home: FC = observer(() => {
- { if (data.optionValue) - setSelectedConfig(data.optionValue); + commonStore.setCurrentConfigIndex(Number(data.optionValue)); }}> - - - - + {commonStore.modelConfigs.map((config, index) => + + )} - +
diff --git a/frontend/src/stores/commonStore.ts b/frontend/src/stores/commonStore.ts index 7ea3a3b..1ac74c9 100644 --- a/frontend/src/stores/commonStore.ts +++ b/frontend/src/stores/commonStore.ts @@ -30,12 +30,16 @@ export type ApiParameters = { countPenalty: number; } +export type Device = 'CPU' | 'CUDA'; +export type Precision = 'fp16' | 'int8' | 'fp32'; + export type ModelParameters = { // different models can not have the same name modelName: string; - device: string; - precision: string; + device: Device; + precision: Precision; streamedLayers: number; + maxStreamedLayers: number; enableHighPrecisionForLastLayer: boolean; } @@ -58,10 +62,11 @@ export const defaultModelConfigs: ModelConfig[] = [ countPenalty: 0 }, modelParameters: { - modelName: '124M', - device: 'CPU', - precision: 'fp32', - streamedLayers: 1, + modelName: 'RWKV-4-Raven-1B5-v11-Eng99%-Other1%-20230425-ctx4096.pth', + device: 'CUDA', + precision: 'fp16', + streamedLayers: 25, + maxStreamedLayers: 25, enableHighPrecisionForLastLayer: false } } @@ -86,6 +91,24 @@ class CommonStore { }); } + getStrategy(modelConfig: ModelConfig | undefined = undefined) { + let params: ModelParameters; + if (modelConfig) params = modelConfig.modelParameters; + else params = this.getCurrentModelConfig().modelParameters; + let strategy = ''; + strategy += (params.device === 'CPU' ? 'cpu' : 'cuda') + ' '; + strategy += (params.precision === 'fp16' ? 'fp16' : params.precision === 'int8' ? 'fp16i8' : 'fp32'); + if (params.streamedLayers < params.maxStreamedLayers) + strategy += ` *${params.streamedLayers}+`; + if (params.enableHighPrecisionForLastLayer) + strategy += ' -> cpu fp32 *1'; + return strategy; + } + + getCurrentModelConfig = () => { + return this.modelConfigs[this.currentModelConfigIndex]; + }; + setModelStatus = (status: ModelStatus) => { this.modelStatus = status; };