diff --git a/.gitignore b/.gitignore
index 9aeb554..7e6c48b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,4 +6,5 @@ __pycache__
.vs
package.json.md5
cache.json
-stats.html
\ No newline at end of file
+stats.html
+models
\ No newline at end of file
diff --git a/backend-golang/config.go b/backend-golang/config.go
deleted file mode 100644
index c9bce76..0000000
--- a/backend-golang/config.go
+++ /dev/null
@@ -1,18 +0,0 @@
-package backend_golang
-
-import (
- "encoding/json"
- "os"
-)
-
-func (a *App) SaveJson(fileName string, jsonData interface{}) string {
- text, err := json.MarshalIndent(jsonData, "", " ")
- if err != nil {
- return err.Error()
- }
-
- if err := os.WriteFile(fileName, text, 0644); err != nil {
- return err.Error()
- }
- return ""
-}
diff --git a/backend-golang/file.go b/backend-golang/file.go
new file mode 100644
index 0000000..73c62a6
--- /dev/null
+++ b/backend-golang/file.go
@@ -0,0 +1,63 @@
+package backend_golang
+
+import (
+ "encoding/json"
+ "os"
+
+ "github.com/cavaliergopher/grab/v3"
+)
+
+func (a *App) SaveJson(fileName string, jsonData any) error {
+ text, err := json.MarshalIndent(jsonData, "", " ")
+ if err != nil {
+ return err
+ }
+
+ if err := os.WriteFile(fileName, text, 0644); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (a *App) ReadJson(fileName string) (any, error) {
+ file, err := os.ReadFile(fileName)
+ if err != nil {
+ return nil, err
+ }
+
+ var data any
+ err = json.Unmarshal(file, &data)
+ if err != nil {
+ return nil, err
+ }
+
+ return data, nil
+}
+
+func (a *App) FileExists(fileName string) (bool, error) {
+ _, err := os.Stat(fileName)
+ if err == nil {
+ return true, nil
+ }
+ return false, err
+}
+
+func (a *App) FileInfo(fileName string) (any, error) {
+ info, err := os.Stat(fileName)
+ if err != nil {
+ return nil, err
+ }
+ return map[string]any{
+ "name": info.Name(),
+ "size": info.Size(),
+ "isDir": info.IsDir(),
+ }, nil
+}
+
+func (a *App) DownloadFile(path string, url string) error {
+ _, err := grab.Get(path, url)
+ if err != nil {
+ return err
+ }
+ return nil
+}
diff --git a/backend-golang/rwkv.go b/backend-golang/rwkv.go
index c883acc..8f6a01b 100644
--- a/backend-golang/rwkv.go
+++ b/backend-golang/rwkv.go
@@ -4,12 +4,12 @@ import (
"os/exec"
)
-func (a *App) StartServer(strategy string, modelPath string) string {
+func (a *App) StartServer(strategy string, modelPath string) (string, error) {
//cmd := exec.Command(`explorer`, `/select,`, `e:\RWKV-4-Raven-7B-v10-Eng49%25-Chn50%25-Other1%25-20230420-ctx4096.pth`)
cmd := exec.Command("cmd-helper", "python", "./backend-python/main.py", strategy, modelPath)
out, err := cmd.CombinedOutput()
if err != nil {
- return err.Error()
+ return "", err
}
- return string(out)
+ return string(out), nil
}
diff --git a/frontend/src/components/ToolTipButton.tsx b/frontend/src/components/ToolTipButton.tsx
index be5a98d..ab0c5fd 100644
--- a/frontend/src/components/ToolTipButton.tsx
+++ b/frontend/src/components/ToolTipButton.tsx
@@ -1,10 +1,17 @@
-import React, {FC, ReactElement} from 'react';
+import React, {FC, MouseEventHandler, ReactElement} from 'react';
import {Button, Tooltip} from '@fluentui/react-components';
-export const ToolTipButton: FC<{ text?: string, desc: string, icon?: ReactElement }> = ({text, desc, icon}) => {
+export const ToolTipButton: FC<{
+ text?: string, desc: string, icon?: ReactElement, onClick?: MouseEventHandler
+}> = ({
+ text,
+ desc,
+ icon,
+ onClick
+ }) => {
return (
- {text}
+ {text}
);
};
diff --git a/frontend/src/main.tsx b/frontend/src/main.tsx
index d3b0a06..39eb4ab 100644
--- a/frontend/src/main.tsx
+++ b/frontend/src/main.tsx
@@ -1,15 +1,18 @@
-import React from 'react'
-import {createRoot} from 'react-dom/client'
-import './style.css'
-import App from './App'
+import React from 'react';
+import {createRoot} from 'react-dom/client';
+import './style.css';
+import App from './App';
import {HashRouter} from 'react-router-dom';
+import {startup} from './startup';
-const container = document.getElementById('root')
+startup().then(() => {
+ const container = document.getElementById('root');
-const root = createRoot(container!)
+ const root = createRoot(container!);
-root.render(
+ root.render(
-
+
-)
+ );
+});
diff --git a/frontend/src/pages/Home.tsx b/frontend/src/pages/Home.tsx
index 0211362..082c5aa 100644
--- a/frontend/src/pages/Home.tsx
+++ b/frontend/src/pages/Home.tsx
@@ -1,6 +1,6 @@
import {Button, CompoundButton, Dropdown, Link, Option, Text} from '@fluentui/react-components';
import React, {FC, ReactElement} from 'react';
-import Banner from '../assets/images/banner.jpg';
+import banner from '../assets/images/banner.jpg';
import {
Chat20Regular,
DataUsageSettings20Regular,
@@ -64,7 +64,7 @@ export const Home: FC = observer(() => {
const onClickMainButton = async () => {
if (commonStore.modelStatus === ModelStatus.Offline) {
- commonStore.updateModelStatus(ModelStatus.Starting);
+ commonStore.setModelStatus(ModelStatus.Starting);
StartServer('cuda fp16i8', 'E:\\RWKV-4-Raven-3B-v10-Eng49%-Chn50%-Other1%-20230419-ctx4096.pth');
let timeoutCount = 5;
@@ -74,7 +74,7 @@ export const Home: FC = observer(() => {
.then(r => {
if (r.ok && !loading) {
clearInterval(intervalId);
- commonStore.updateModelStatus(ModelStatus.Loading);
+ commonStore.setModelStatus(ModelStatus.Loading);
loading = true;
fetch('http://127.0.0.1:8000/update-config', {
method: 'POST',
@@ -84,27 +84,27 @@ export const Home: FC = observer(() => {
body: JSON.stringify({})
}).then(async (r) => {
if (r.ok)
- commonStore.updateModelStatus(ModelStatus.Working);
+ commonStore.setModelStatus(ModelStatus.Working);
});
}
}).catch(() => {
if (timeoutCount <= 0) {
clearInterval(intervalId);
- commonStore.updateModelStatus(ModelStatus.Offline);
+ commonStore.setModelStatus(ModelStatus.Offline);
}
});
timeoutCount--;
}, 1000);
} else {
- commonStore.updateModelStatus(ModelStatus.Offline);
+ commonStore.setModelStatus(ModelStatus.Offline);
fetch('http://127.0.0.1:8000/exit', {method: 'POST'});
}
};
return (
-
+
Introduction
diff --git a/frontend/src/pages/Models.tsx b/frontend/src/pages/Models.tsx
index 18cf31a..bb5f0ed 100644
--- a/frontend/src/pages/Models.tsx
+++ b/frontend/src/pages/Models.tsx
@@ -1,4 +1,4 @@
-import React, {FC, useEffect} from 'react';
+import React, {FC} from 'react';
import {
createTableColumn,
DataGrid,
@@ -12,40 +12,18 @@ import {
Text,
Textarea
} from '@fluentui/react-components';
-import {EditRegular} from '@fluentui/react-icons/lib/fonts';
import {ToolTipButton} from '../components/ToolTipButton';
-import {ArrowClockwise20Regular} from '@fluentui/react-icons';
+import {ArrowClockwise20Regular, ArrowDownload20Regular, Open20Regular} from '@fluentui/react-icons';
+import {observer} from 'mobx-react-lite';
+import commonStore, {ModelSourceItem} from '../stores/commonStore';
+import {BrowserOpenURL} from '../../wailsjs/runtime';
+import {DownloadFile} from '../../wailsjs/go/backend_golang/App';
-type Operation = {
- icon: JSX.Element;
- desc: string
-}
-
-type Item = {
- filename: string;
- desc: string;
- size: number;
- lastUpdated: number;
- actions: Operation[];
- isLocal: boolean;
-};
-
-const items: Item[] = [
- {
- filename: 'RWKV-4-Raven-14B-v11x-Eng99%-Other1%-20230501-ctx8192.pth',
- desc: 'Mainly English language corpus',
- size: 28297309490,
- lastUpdated: 1,
- actions: [{icon:
, desc: 'Edit'}],
- isLocal: false
- }
-];
-
-const columns: TableColumnDefinition
- [] = [
- createTableColumn
- ({
+const columns: TableColumnDefinition
[] = [
+ createTableColumn({
columnId: 'file',
compare: (a, b) => {
- return a.filename.localeCompare(b.filename);
+ return a.name.localeCompare(b.name);
},
renderHeaderCell: () => {
return 'File';
@@ -53,15 +31,15 @@ const columns: TableColumnDefinition- [] = [
renderCell: (item) => {
return (
- {item.filename}
+ {item.name}
);
}
}),
- createTableColumn- ({
+ createTableColumn
({
columnId: 'desc',
compare: (a, b) => {
- return a.desc.localeCompare(b.desc);
+ return a.desc['en'].localeCompare(b.desc['en']);
},
renderHeaderCell: () => {
return 'Desc';
@@ -69,12 +47,12 @@ const columns: TableColumnDefinition- [] = [
renderCell: (item) => {
return (
- {item.desc}
+ {item.desc['en']}
);
}
}),
- createTableColumn- ({
+ createTableColumn
({
columnId: 'size',
compare: (a, b) => {
return a.size - b.size;
@@ -90,10 +68,14 @@ const columns: TableColumnDefinition- [] = [
);
}
}),
- createTableColumn
- ({
+ createTableColumn
({
columnId: 'lastUpdated',
compare: (a, b) => {
- return a.lastUpdated - b.lastUpdated;
+ if (!a.lastUpdatedMs)
+ a.lastUpdatedMs = Date.parse(a.lastUpdated);
+ if (!b.lastUpdatedMs)
+ b.lastUpdatedMs = Date.parse(b.lastUpdated);
+ return a.lastUpdatedMs - b.lastUpdatedMs;
},
renderHeaderCell: () => {
return 'Last updated';
@@ -103,10 +85,10 @@ const columns: TableColumnDefinition- [] = [
return new Date(item.lastUpdated).toLocaleString();
}
}),
- createTableColumn
- ({
+ createTableColumn
({
columnId: 'actions',
compare: (a, b) => {
- return a.isLocal === b.isLocal ? 0 : a.isLocal ? -1 : 1;
+ return a.isDownloading ? 0 : a.isLocal ? 1 : 2;
},
renderHeaderCell: () => {
return 'Actions';
@@ -114,54 +96,63 @@ const columns: TableColumnDefinition- [] = [
renderCell: (item) => {
return (
+
+ } onClick={() => {
+ DownloadFile(`./models/${item.name}`, item.downloadUrl);
+ }}/>
+ } onClick={() => {
+ BrowserOpenURL(item.url);
+ }}/>
+
);
}
})
];
-export const Models: FC = () => {
- useEffect(() => {
- fetch('https://cdn.jsdelivr.net/gh/josstorer/RWKV-Runner/manifest.json')
- .then(
- res => res.json().then(console.log)
- );
- }, []);
-
+export const Models: FC = observer(() => {
return (
-
-
In Development
+
+
Models
- Model Source Url List
+ Model Source Manifest List
}/>
-
description
+
Provide JSON file URLs for the models manifest. Separate URLs with semicolons. The "models"
+ field in JSON files will be parsed into the following table.
+ defaultValue={commonStore.modelSourceManifestList}
+ onChange={(e, data) => commonStore.setModelSourceManifestList(data.value)}/>
-
-
-
- {({renderHeaderCell}) => (
- {renderHeaderCell()}
- )}
-
-
- >
- {({item, rowId}) => (
- key={rowId}>
- {({renderCell}) => (
- {renderCell(item)}
+
+
+
+
+ {({renderHeaderCell}) => (
+ {renderHeaderCell()}
)}
- )}
-
-
+
+
+ >
+ {({item, rowId}) => (
+ key={rowId}>
+ {({renderCell}) => (
+ {renderCell(item)}
+ )}
+
+ )}
+
+
+
+
);
-};
+});
diff --git a/frontend/src/startup.ts b/frontend/src/startup.ts
new file mode 100644
index 0000000..753eec3
--- /dev/null
+++ b/frontend/src/startup.ts
@@ -0,0 +1,46 @@
+import commonStore, {ModelSourceItem} from './stores/commonStore';
+import {ReadJson, SaveJson} from '../wailsjs/go/backend_golang/App';
+import manifest from '../../manifest.json';
+
+export async function startup() {
+ initConfig();
+}
+
+type Cache = {
+ models: ModelSourceItem[]
+}
+
+async function initConfig() {
+ let cache: Cache = {models: []};
+ await ReadJson('cache.json').then((cacheData: Cache) => {
+ cache = cacheData;
+ }).catch(
+ () => {
+ cache = {models: manifest.models};
+ SaveJson('cache.json', cache).catch(() => {
+ });
+ }
+ );
+ commonStore.setModelSourceList(cache.models);
+
+ const manifestUrls = commonStore.modelSourceManifestList.split(/[,,;;\n]/);
+ const requests = manifestUrls.filter(url => url.endsWith('.json')).map(
+ url => fetch(url, {cache: 'no-cache'}).then(r => r.json()));
+
+ await Promise.allSettled(requests)
+ .then((data: PromiseSettledResult
[]) => {
+ cache.models.push(...data.flatMap(d => {
+ if (d.status === 'fulfilled')
+ return d.value.models;
+ return [];
+ }));
+ })
+ .catch(() => {
+ });
+ cache.models = cache.models.filter((model, index, self) => {
+ return model.name.endsWith('.pth') && index === self.findIndex(m => m.SHA256 === model.SHA256 && m.size === model.size);
+ });
+ commonStore.setModelSourceList(cache.models);
+ SaveJson('cache.json', cache).catch(() => {
+ });
+}
\ No newline at end of file
diff --git a/frontend/src/stores/commonStore.ts b/frontend/src/stores/commonStore.ts
index 1c5f223..f365c31 100644
--- a/frontend/src/stores/commonStore.ts
+++ b/frontend/src/stores/commonStore.ts
@@ -7,15 +7,101 @@ export enum ModelStatus {
Working,
}
+export type ModelSourceItem = {
+ name: string;
+ desc: { [lang: string]: string; };
+ size: number;
+ lastUpdated: string;
+ SHA256: string;
+ url: string;
+ downloadUrl: string;
+ isLocal?: boolean;
+ isDownloading?: boolean;
+ lastUpdatedMs?: number;
+};
+
+export type ApiParameters = {
+ apiPort: number
+ maxResponseToken: number;
+ temperature: number;
+ topP: number;
+ presencePenalty: number;
+ countPenalty: number;
+}
+
+export type ModelParameters = {
+ modelName: string;
+ device: string;
+ precision: string;
+ streamedLayers: number;
+ enableHighPrecisionForLastLayer: boolean;
+}
+
+export type ModelConfig = {
+ configName: string;
+ apiParameters: ApiParameters
+ modelParameters: ModelParameters
+}
+
+const defaultModelConfigs: ModelConfig[] = [
+ {
+ configName: 'Default',
+ apiParameters: {
+ apiPort: 8000,
+ maxResponseToken: 1000,
+ temperature: 1,
+ topP: 1,
+ presencePenalty: 0,
+ countPenalty: 0
+ },
+ modelParameters: {
+ modelName: '124M',
+ device: 'CPU',
+ precision: 'fp32',
+ streamedLayers: 1,
+ enableHighPrecisionForLastLayer: false
+ }
+ }
+];
+
class CommonStore {
constructor() {
makeAutoObservable(this);
}
modelStatus: ModelStatus = ModelStatus.Offline;
- updateModelStatus = (status: ModelStatus) => {
+ currentModelConfigIndex: number = 0;
+ modelConfigs: ModelConfig[] = defaultModelConfigs;
+ modelSourceManifestList: string = 'https://cdn.jsdelivr.net/gh/josstorer/RWKV-Runner/manifest.json;';
+ modelSourceList: ModelSourceItem[] = [];
+
+ setModelStatus = (status: ModelStatus) => {
this.modelStatus = status;
};
+
+ setCurrentConfigIndex = (index: number) => {
+ this.currentModelConfigIndex = index;
+ };
+
+ setModelConfig = (index: number, config: ModelConfig) => {
+ this.modelConfigs[index] = config;
+ };
+
+ createModelConfig = (config: ModelConfig = defaultModelConfigs[0]) => {
+ this.modelConfigs.push(config);
+ };
+
+ deleteModelConfig = (index: number) => {
+ this.modelConfigs.splice(index, 1);
+ };
+
+ setModelSourceManifestList = (value: string) => {
+ this.modelSourceManifestList = value;
+ };
+
+ setModelSourceList = (value: ModelSourceItem[]) => {
+ this.modelSourceList = value;
+ };
}
export default new CommonStore();
\ No newline at end of file
diff --git a/frontend/src/style.css b/frontend/src/style.css
index 41fcbcb..adbbda1 100644
--- a/frontend/src/style.css
+++ b/frontend/src/style.css
@@ -5,6 +5,7 @@
body {
margin: 0;
overflow: hidden;
+ height: 100%;
}
* {
diff --git a/frontend/wailsjs/go/backend_golang/App.d.ts b/frontend/wailsjs/go/backend_golang/App.d.ts
index 5356f2d..1cad9a3 100644
--- a/frontend/wailsjs/go/backend_golang/App.d.ts
+++ b/frontend/wailsjs/go/backend_golang/App.d.ts
@@ -1,6 +1,14 @@
// Cynhyrchwyd y ffeil hon yn awtomatig. PEIDIWCH Â MODIWL
// This file is automatically generated. DO NOT EDIT
-export function SaveJson(arg1:string,arg2:any):Promise;
+export function DownloadFile(arg1:string,arg2:string):Promise;
+
+export function FileExists(arg1:string):Promise;
+
+export function FileInfo(arg1:string):Promise;
+
+export function ReadJson(arg1:string):Promise;
+
+export function SaveJson(arg1:string,arg2:any):Promise;
export function StartServer(arg1:string,arg2:string):Promise;
diff --git a/frontend/wailsjs/go/backend_golang/App.js b/frontend/wailsjs/go/backend_golang/App.js
index 36dd033..690f0cb 100644
--- a/frontend/wailsjs/go/backend_golang/App.js
+++ b/frontend/wailsjs/go/backend_golang/App.js
@@ -2,6 +2,22 @@
// Cynhyrchwyd y ffeil hon yn awtomatig. PEIDIWCH Â MODIWL
// This file is automatically generated. DO NOT EDIT
+export function DownloadFile(arg1, arg2) {
+ return window['go']['backend_golang']['App']['DownloadFile'](arg1, arg2);
+}
+
+export function FileExists(arg1) {
+ return window['go']['backend_golang']['App']['FileExists'](arg1);
+}
+
+export function FileInfo(arg1) {
+ return window['go']['backend_golang']['App']['FileInfo'](arg1);
+}
+
+export function ReadJson(arg1) {
+ return window['go']['backend_golang']['App']['ReadJson'](arg1);
+}
+
export function SaveJson(arg1, arg2) {
return window['go']['backend_golang']['App']['SaveJson'](arg1, arg2);
}
diff --git a/go.mod b/go.mod
index 87c9886..694298d 100644
--- a/go.mod
+++ b/go.mod
@@ -2,7 +2,10 @@ module rwkv-runner
go 1.18
-require github.com/wailsapp/wails/v2 v2.4.1
+require (
+ github.com/cavaliergopher/grab/v3 v3.0.1
+ github.com/wailsapp/wails/v2 v2.4.1
+)
require (
github.com/bep/debounce v1.2.1 // indirect
diff --git a/go.sum b/go.sum
index 771a2af..a8a6d77 100644
--- a/go.sum
+++ b/go.sum
@@ -1,5 +1,7 @@
github.com/bep/debounce v1.2.1 h1:v67fRdBA9UQu2NhLFXrSg0Brw7CexQekrBwDMM8bzeY=
github.com/bep/debounce v1.2.1/go.mod h1:H8yggRPQKLUhUoqrJC1bO2xNya7vanpDl7xR3ISbCJ0=
+github.com/cavaliergopher/grab/v3 v3.0.1 h1:4z7TkBfmPjmLAAmkkAZNX/6QJ1nNFdv3SdIHXju0Fr4=
+github.com/cavaliergopher/grab/v3 v3.0.1/go.mod h1:1U/KNnD+Ft6JJiYoYBAimKH2XrYptb8Kl3DFGmsjpq4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
diff --git a/main.go b/main.go
index 10681bd..8c9da7f 100644
--- a/main.go
+++ b/main.go
@@ -28,7 +28,7 @@ func main() {
Assets: assets,
},
OnStartup: app.OnStartup,
- Bind: []interface{}{
+ Bind: []any{
app,
},
})