diff --git a/backend-python/main.py b/backend-python/main.py index 20d3837..5acd5dd 100644 --- a/backend-python/main.py +++ b/backend-python/main.py @@ -42,7 +42,7 @@ def read_root(): @app.post("/exit") -def read_root(): +def exit(): parent_pid = os.getpid() parent = psutil.Process(parent_pid) for child in parent.children(recursive=True): diff --git a/backend-python/routes/completion.py b/backend-python/routes/completion.py index 7dd4307..767ced4 100644 --- a/backend-python/routes/completion.py +++ b/backend-python/routes/completion.py @@ -26,6 +26,8 @@ class CompletionBody(BaseModel): @router.post("/chat/completions") async def completions(body: CompletionBody, request: Request): model = global_var.get(global_var.Model) + if (model is None): + raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded") question = body.messages[-1] if question.role == 'user': diff --git a/backend-python/routes/config.py b/backend-python/routes/config.py index d65613a..e6ddb56 100644 --- a/backend-python/routes/config.py +++ b/backend-python/routes/config.py @@ -1,7 +1,7 @@ import pathlib import sys -from fastapi import APIRouter, HTTPException, status +from fastapi import APIRouter, HTTPException, Response, status from pydantic import BaseModel from langchain.llms import RWKV from utils.rwkv import * @@ -22,9 +22,10 @@ class UpdateConfigBody(BaseModel): @router.post("/update-config") -def update_config(body: UpdateConfigBody): +def update_config(body: UpdateConfigBody, response: Response): if (global_var.get(global_var.Model_Status) is global_var.ModelStatus.Loading): - return "loading" + response.status_code = status.HTTP_304_NOT_MODIFIED + return global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline) global_var.set(global_var.Model, None) diff --git a/frontend/package-lock.json b/frontend/package-lock.json index f9497a4..bd53bd1 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -10,10 +10,13 @@ "dependencies": { "@fluentui/react-components": "^9.19.1", "@fluentui/react-icons": "^2.0.201", + "mobx": "^6.9.0", + "mobx-react-lite": "^3.4.3", "react": "^18.2.0", "react-dom": "^18.2.0", "react-router": "^6.11.0", "react-router-dom": "^6.11.0", + "react-toastify": "^9.1.2", "usehooks-ts": "^2.9.1" }, "devDependencies": { @@ -1921,6 +1924,14 @@ "node": ">=12" } }, + "node_modules/clsx": { + "version": "1.2.1", + "resolved": "https://registry.npmmirror.com/clsx/-/clsx-1.2.1.tgz", + "integrity": "sha512-EcR6r5a8bj6pu3ycsa/E/cKVGuTgZJZdsyUYHOksG/UHIiKfjxzRxYJpyVBwYaQeOvghal9fcc4PidlgzugAQg==", + "engines": { + "node": ">=6" + } + }, "node_modules/color-convert": { "version": "1.9.3", "resolved": "https://registry.npmmirror.com/color-convert/-/color-convert-1.9.3.tgz", @@ -2784,6 +2795,28 @@ "node": "*" } }, + "node_modules/mobx": { + "version": "6.9.0", + "resolved": "https://registry.npmmirror.com/mobx/-/mobx-6.9.0.tgz", + "integrity": "sha512-HdKewQEREEJgsWnErClfbFoVebze6rGazxFLU/XUyrII8dORfVszN1V0BMRnQSzcgsNNtkX8DHj3nC6cdWE9YQ==" + }, + "node_modules/mobx-react-lite": { + "version": "3.4.3", + "resolved": "https://registry.npmmirror.com/mobx-react-lite/-/mobx-react-lite-3.4.3.tgz", + "integrity": "sha512-NkJREyFTSUXR772Qaai51BnE1voWx56LOL80xG7qkZr6vo8vEaLF3sz1JNUVh+rxmUzxYaqOhfuxTfqUh0FXUg==", + "peerDependencies": { + "mobx": "^6.1.0", + "react": "^16.8.0 || ^17 || ^18" + }, + "peerDependenciesMeta": { + "react-dom": { + "optional": true + }, + "react-native": { + "optional": true + } + } + }, "node_modules/ms": { "version": "2.1.2", "resolved": "https://registry.npmmirror.com/ms/-/ms-2.1.2.tgz", @@ -3081,6 +3114,18 @@ "react-dom": ">=16.8" } }, + "node_modules/react-toastify": { + "version": "9.1.2", + "resolved": "https://registry.npmmirror.com/react-toastify/-/react-toastify-9.1.2.tgz", + "integrity": "sha512-PBfzXO5jMGEtdYR5jxrORlNZZe/EuOkwvwKijMatsZZm8IZwLj01YvobeJYNjFcA6uy6CVrx2fzL9GWbhWPTDA==", + "dependencies": { + "clsx": "^1.1.1" + }, + "peerDependencies": { + "react": ">=16", + "react-dom": ">=16" + } + }, "node_modules/read-cache": { "version": "1.0.0", "resolved": "https://registry.npmmirror.com/read-cache/-/read-cache-1.0.0.tgz", diff --git a/frontend/package.json b/frontend/package.json index 6fd02ae..3560c22 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -11,10 +11,13 @@ "dependencies": { "@fluentui/react-components": "^9.19.1", "@fluentui/react-icons": "^2.0.201", + "mobx": "^6.9.0", + "mobx-react-lite": "^3.4.3", "react": "^18.2.0", "react-dom": "^18.2.0", "react-router": "^6.11.0", "react-router-dom": "^6.11.0", + "react-toastify": "^9.1.2", "usehooks-ts": "^2.9.1" }, "devDependencies": { diff --git a/frontend/src/pages/Home.tsx b/frontend/src/pages/Home.tsx index c593906..0211362 100644 --- a/frontend/src/pages/Home.tsx +++ b/frontend/src/pages/Home.tsx @@ -8,6 +8,8 @@ import { Storage20Regular } from '@fluentui/react-icons'; import {useNavigate} from 'react-router'; +import commonStore, {ModelStatus} from '../stores/commonStore'; +import {observer} from 'mobx-react-lite'; import {StartServer} from '../../wailsjs/go/backend_golang/App'; type NavCard = { @@ -44,7 +46,14 @@ export const navCards: NavCard[] = [ } ]; -export const Home: FC = () => { +const mainButtonText = { + [ModelStatus.Offline]: 'Run', + [ModelStatus.Starting]: 'Starting', + [ModelStatus.Loading]: 'Loading', + [ModelStatus.Working]: 'Stop' +}; + +export const Home: FC = observer(() => { const [selectedConfig, setSelectedConfig] = React.useState('RWKV-3B-4G MEM'); const navigate = useNavigate(); @@ -53,6 +62,46 @@ export const Home: FC = () => { navigate({pathname: path}); }; + const onClickMainButton = async () => { + if (commonStore.modelStatus === ModelStatus.Offline) { + commonStore.updateModelStatus(ModelStatus.Starting); + StartServer('cuda fp16i8', 'E:\\RWKV-4-Raven-3B-v10-Eng49%-Chn50%-Other1%-20230419-ctx4096.pth'); + + let timeoutCount = 5; + let loading = false; + const intervalId = setInterval(() => { + fetch('http://127.0.0.1:8000') + .then(r => { + if (r.ok && !loading) { + clearInterval(intervalId); + commonStore.updateModelStatus(ModelStatus.Loading); + loading = true; + fetch('http://127.0.0.1:8000/update-config', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({}) + }).then(async (r) => { + if (r.ok) + commonStore.updateModelStatus(ModelStatus.Working); + }); + } + }).catch(() => { + if (timeoutCount <= 0) { + clearInterval(intervalId); + commonStore.updateModelStatus(ModelStatus.Offline); + } + }); + + timeoutCount--; + }, 1000); + } else { + commonStore.updateModelStatus(ModelStatus.Offline); + fetch('http://127.0.0.1:8000/exit', {method: 'POST'}); + } + }; + return (