Compare commits

...

145 Commits

Author SHA1 Message Date
josc146
e2b086e2f7 release v1.6.2 2023-12-12 23:50:56 +08:00
josc146
da632565d5 fix windows cmd waiting 2023-12-12 23:48:32 +08:00
josc146
556b667cc0 improve prompts 2023-12-12 23:27:19 +08:00
josc146
82c9825da8 rwkv.cpp python38 compatibility 2023-12-12 23:19:18 +08:00
josc146
26b30f0dbe add load failed traceback 2023-12-12 23:16:48 +08:00
josc146
be3b69c65c fix v1.6.1 CmdHelper 2023-12-12 23:04:24 +08:00
github-actions[bot]
07cab6949e release v1.6.1 2023-12-12 14:38:47 +00:00
josc146
18d58ce124 release v1.6.1 2023-12-12 22:38:18 +08:00
josc146
b8f8837a8f allow overriding Core API URL 2023-12-12 22:37:36 +08:00
josc146
0c796c8cfc allow playing mid with external player 2023-12-12 22:13:09 +08:00
josc146
b14fbc29b7 rwkv.cpp(ggml) support 2023-12-12 20:29:55 +08:00
josc146
6e29f97881 update readme 2023-12-11 17:23:09 +08:00
josc146
a164939161 add crash.log 2023-12-11 12:02:24 +08:00
github-actions[bot]
09ab11ef01 release v1.6.0 2023-12-10 15:46:50 +00:00
josc146
ac34edec7f release v1.6.0 2023-12-10 23:46:25 +08:00
josc146
6dd8ffa037 bump to wails v2.7.1 2023-12-10 23:43:40 +08:00
josc146
eaed3f40a2 improve current instrument display 2023-12-10 23:37:23 +08:00
josc146
e48f39375e add midi tracks to webUI 2023-12-10 23:08:44 +08:00
josc146
9b7b651ef9 feat: import midi file 2023-12-10 22:38:31 +08:00
josc146
b5623cb9c2 fix generation instrumentType 2023-12-10 22:32:06 +08:00
josc146
144d12b463 chore 2023-12-10 21:13:36 +08:00
josc146
fa452f5518 bump to wails v2.7.0 2023-12-09 14:56:48 +08:00
josc146
a159d21d45 Update README_JA.md 2023-12-09 13:09:53 +08:00
josc146
3a00bbf44d update readme 2023-12-09 12:56:15 +08:00
github-actions[bot]
9f5e94fa8f release v1.5.9 2023-12-08 11:22:21 +00:00
josc146
87e1daa733 release v1.5.9 2023-12-08 19:22:01 +08:00
josc146
f5900179e0 model tags classifier 2023-12-08 18:17:53 +08:00
josc146
51e162970e always reset to activePreset 2023-12-08 17:10:23 +08:00
josc146
0b339ad0f6 improve ConfigSelector performance of Configs page 2023-12-08 16:36:15 +08:00
josc146
60693d6a29 improve presets interaction 2023-12-08 15:36:53 +08:00
josc146
eea53a6e9e add available tag for model downloaded configs 2023-12-08 15:34:45 +08:00
josc146
8a19181a38 chore 2023-12-08 15:30:46 +08:00
josc146
94d835c7ae better customCuda condition 2023-12-08 15:30:05 +08:00
josc146
d9e25ad69f better state cache 2023-12-08 15:28:33 +08:00
josc146
75244fbd8b disable hashed assets 2023-12-08 11:22:31 +08:00
josc146
5ce84edc3d add web-rwkv-converter (Safetensors Convert no longer depends on Python) 2023-12-07 23:26:39 +08:00
josc146
1c683087f4 update ci webgpu components 2023-12-07 23:04:56 +08:00
josc146
85a3b39cbc fix webWails undefined functions 2023-12-06 23:19:56 +08:00
josc146
cc6c24f0c3 add python-3.10.11-embed-amd64.zip cnMirror and chore 2023-12-06 23:19:22 +08:00
josc146
c733b6419c for devices that gpu is not supported, use cpu to merge lora 2023-12-06 23:17:13 +08:00
josc146
c853c5b60b chore 2023-12-06 23:09:39 +08:00
josc146
053a08f5b7 update convert_safetensors.py 2023-12-06 23:08:40 +08:00
josc146
f7227cd1c1 update ci webgpu components 2023-12-06 23:08:20 +08:00
josc146
861e245062 RWKV_RESCALE_LAYER 999 for music model 2023-12-04 17:51:21 +08:00
josc146
8f0fc7db56 update README_ZH.md 2023-11-30 22:07:16 +08:00
josc146
3dd06fa70e update README_ZH.md 2023-11-30 21:49:31 +08:00
josc146
86a855e7bc fix damaged logo 2023-11-30 21:48:14 +08:00
github-actions[bot]
b3110d4ad8 release v1.5.8 2023-11-30 05:04:31 +00:00
josc146
602004ad34 release v1.5.8 2023-11-30 13:04:02 +08:00
josc146
a8b4f0bb7e lora finetune version check 2023-11-30 13:01:38 +08:00
josc146
24cc8be085 add high loss warning 2023-11-30 12:40:16 +08:00
josc146
a96d7aef8d display mainInstrument of track 2023-11-30 12:36:03 +08:00
josc146
cbe299583b improve details of MIDI Input 2023-11-30 11:57:52 +08:00
josc146
68c70a362b darkmode of midi tracks 2023-11-30 11:56:45 +08:00
josc146
a78c346371 fix NoteOff ElapsedTime of MIDI Tracks 2023-11-30 11:55:10 +08:00
github-actions[bot]
102763b94d release v1.5.7 2023-11-29 15:01:26 +00:00
josc146
ad65765ba8 release v1.5.7 2023-11-29 22:59:47 +08:00
josc146
d04fd7cb87 fix lib 2023-11-29 22:59:42 +08:00
github-actions[bot]
b398cbb591 release v1.5.6 2023-11-29 13:22:21 +00:00
josc146
19b97e985c release v1.5.6 2023-11-29 21:21:50 +08:00
josc146
93bf74a320 fix NoteOff 2023-11-29 21:21:42 +08:00
josc146
7daae23bbb update defaultConfigs 2023-11-29 21:21:29 +08:00
josc146
0d0a3f15cc chore 2023-11-29 21:21:14 +08:00
github-actions[bot]
04fbb38861 release v1.5.5 2023-11-29 11:32:40 +00:00
josc146
d666c6032b release v1.5.5 2023-11-29 19:31:56 +08:00
josc146
93e8660d69 add instruments i18n 2023-11-29 19:31:52 +08:00
josc146
e687cf02bb try to use local soundfont by default 2023-11-29 19:17:19 +08:00
josc146
e858f1477a update locales 2023-11-29 19:10:01 +08:00
josc146
a2062ae9cc feat: save MIDI tracks to generation area; playing tracks and audio preview are still under development 2023-11-29 19:04:41 +08:00
josc146
34112c79c7 fix autoPlayed midi cannot be stopped 2023-11-29 15:28:43 +08:00
josc146
b625b8a6d1 MIDI Recording and details improvement 2023-11-29 14:05:58 +08:00
josc146
14a13d5768 basic MIDI Input Audio Tracks 2023-11-28 15:34:06 +08:00
josc146
7ce464ecda improve details 2023-11-26 22:54:59 +08:00
github-actions[bot]
2c1f89383f release v1.5.4 2023-11-24 11:22:42 +00:00
josc146
e666c50f77 release v1.5.4 2023-11-24 19:22:07 +08:00
josc146
1b441752b0 chore 2023-11-24 19:21:58 +08:00
josc146
e01897b24d improve launch flow of webgpu mode 2023-11-24 19:21:14 +08:00
josc146
6146d910b4 improve launch flow of webgpu mode 2023-11-24 18:36:44 +08:00
josc146
0063c171f3 upgrade to rwkv 0.8.22 (rwkv6 support) 2023-11-24 17:55:16 +08:00
josc146
bea3c29c1c update defaultConfigs 2023-11-24 17:13:22 +08:00
josc146
5f543c2545 update manifest 2023-11-24 16:35:21 +08:00
josc146
177b2c54d9 allow reading attachments even if the model is offline 2023-11-24 16:25:21 +08:00
josc146
645e8e2f44 chore 2023-11-24 15:58:53 +08:00
josc146
f2d0dda2ff allow safetensors converter on macOS 2023-11-21 22:32:25 +08:00
josc146
3a449e7b46 fix fs watcher of macOS 2023-11-21 22:30:42 +08:00
github-actions[bot]
18d2ecb7a7 release v1.5.3 2023-11-20 16:22:32 +00:00
josc146
bb3a93b419 release v1.5.3 2023-11-21 00:21:09 +08:00
josc146
1334f0e5ba chore 2023-11-21 00:20:54 +08:00
josc146
8781416cfb add hf-mirror for cn users 2023-11-21 00:04:23 +08:00
josc146
a9819139b8 add sidePanel for Chat page 2023-11-20 23:47:39 +08:00
josc146
66e43c9d9b display lastModelName at the top (WorkHeader) 2023-11-20 23:27:44 +08:00
josc146
41e5bd5eb8 change ValuedSlider's step to 100 2023-11-20 23:25:39 +08:00
josc146
48fef0235b add webgpu nf4 2023-11-20 21:10:10 +08:00
josc146
d435436525 improve finetune error 2023-11-20 20:39:00 +08:00
josc146
cd7a9896dc improve styles 2023-11-20 20:16:55 +08:00
josc146
bbcc6b07b6 improve precision description 2023-11-20 20:13:30 +08:00
josc146
646bcd81c0 fix webgpu permission for macos 2023-11-20 20:12:20 +08:00
josc146
dbf0dccc9d add tokenizer(/switch-model) to /docs 2023-11-20 20:11:45 +08:00
josc146
437de2be20 improve lazy loading ui 2023-11-18 13:59:37 +08:00
josc146
f739c61197 fix a finetune bug 2023-11-17 22:37:21 +08:00
josc146
01d3c89ea4 add rwkv API URL Option; update OpenAI models Option 2023-11-17 22:16:49 +08:00
josc146
d18218f21a use local API when it's working, even if a custom API URL is provided 2023-11-17 21:53:29 +08:00
josc146
c8470e77fd fix state_cache of deploy mode 2023-11-17 21:32:11 +08:00
josc146
9ede7d7c6d strict default_stop 2023-11-17 21:18:52 +08:00
josc146
a59c4436c8 macos: change default webgpu backend to aarch64-apple-darwin 2023-11-17 21:16:08 +08:00
josc146
068be2bfc4 update setup comments 2023-11-17 20:47:33 +08:00
josc146
94a5dc4fb7 update setup.sh comments 2023-11-14 17:38:24 +08:00
github-actions[bot]
9f288de951 release v1.5.2 2023-11-09 14:11:40 +00:00
josc146
3d5c3dcd31 release v1.5.2 2023-11-09 22:11:05 +08:00
josc146
0a4876a564 improve user guide 2023-11-09 22:07:01 +08:00
josc146
4f0558ae34 add client upgrade progress 2023-11-09 21:38:02 +08:00
josc146
f03c9cf25f improve mobile view 2023-11-09 12:21:01 +08:00
josc146
07797537d1 add RWKV-Runner WebUI to Server-Deploy-Examples 2023-11-09 00:21:02 +08:00
github-actions[bot]
0c3a50cb07 release v1.5.1 2023-11-08 15:41:53 +00:00
josc146
c7dcff52a1 release v1.5.1 2023-11-08 23:41:17 +08:00
josc146
c6ef32958e when client webUI enabled, set server into deployment mode 2023-11-08 23:31:13 +08:00
josc146
7235e1067b add deployment mode. If /switch-model with deploy: true, will disable /switch-model, /exit and other dangerous APIs (state cache APIs, part of midi APIs) 2023-11-08 23:29:42 +08:00
josc146
0594290b92 disable WebUI Option of WebGPU Mode (webgpu not supported yet) 2023-11-08 23:05:59 +08:00
josc146
d249a4c29a print error.txt 2023-11-08 22:57:38 +08:00
josc146
02ba37fab4 improve api url getter 2023-11-08 22:25:41 +08:00
josc146
b5a6f8a425 set deepspeed to 0.11.2 to avoid finetune error 2023-11-08 22:20:11 +08:00
josc146
1ad86d737c chore 2023-11-08 22:18:49 +08:00
josc146
cfa3669f6f fix /docs default api params (Pydantic v2) 2023-11-07 22:53:11 +08:00
josc146
26d4c9f0ed chore 2023-11-07 22:28:13 +08:00
josc146
3ddcf9f62e add webui entry 2023-11-07 22:24:06 +08:00
josc146
e734fce64f create webui assets 2023-11-07 22:23:26 +08:00
josc146
150beb578c chore 2023-11-07 22:23:00 +08:00
josc146
db6fbe8366 add python webui server 2023-11-07 22:22:29 +08:00
josc146
46f52923c3 improve webui 2023-11-07 22:21:41 +08:00
josc146
893be5cf43 webui build 2023-11-07 19:27:21 +08:00
github-actions[bot]
384e4ce4d0 release v1.5.0 2023-11-05 13:10:50 +00:00
josc146
b8712e0b89 release v1.5.0 2023-11-05 21:10:21 +08:00
josc146
37dda4333d chat attachment is now related to single message 2023-11-05 21:05:06 +08:00
josc146
64826b9af7 fix log encoding error 2023-11-05 21:00:31 +08:00
josc146
47b0c35441 update ngrok_connect 2023-11-04 20:22:28 +08:00
josc146
1dcda47013 improve startup process 2023-11-04 20:21:55 +08:00
josc146
1f81a1e5a8 upgrade to rwkv 0.8.20 2023-11-03 23:27:14 +08:00
josc146
35e92d2aef chore 2023-11-03 23:22:52 +08:00
josc146
0d99e5549e port occupied detection 2023-11-03 21:18:42 +08:00
josc146
fed1594ddc fix stop button status of Chat page 2023-10-30 21:09:23 +08:00
josc146
14b90bb36b improve dml mode performance (20% faster, https://github.com/BlinkDL/ChatRWKV/pull/181) 2023-10-30 20:24:57 +08:00
josc146
f86b7f1f08 python38 compatibility 2023-10-29 14:11:11 +08:00
josc146
54355d5a7a improve the compatibility between frontend presets and chatgpt api 2023-10-28 23:06:19 +08:00
josc146
ff7306349a improve memory usage of state cache 2023-10-28 23:04:49 +08:00
github-actions[bot]
77df56cddc release v1.4.9 2023-10-27 06:04:00 +00:00
110 changed files with 7496 additions and 2521 deletions

1
.gitattributes vendored
View File

@@ -3,6 +3,7 @@ backend-python/wkv_cuda_utils/** linguist-vendored
backend-python/get-pip.py linguist-vendored
backend-python/convert_model.py linguist-vendored
backend-python/convert_safetensors.py linguist-vendored
backend-python/convert_pytorch_to_ggml.py linguist-vendored
backend-python/utils/midi.py linguist-vendored
build/** linguist-vendored
finetune/lora/** linguist-vendored

View File

@@ -48,16 +48,13 @@ jobs:
id: cp310
with:
python-version: '3.10'
- uses: actions-rs/toolchain@v1
with:
toolchain: stable
override: true
target: wasm32-unknown-unknown
- uses: crazy-max/ghaction-chocolatey@v2
with:
args: install upx
- run: |
Start-BitsTransfer https://github.com/josStorer/LibreHardwareMonitor.Console/releases/download/v0.1.0/LibreHardwareMonitor.Console.zip ./LibreHardwareMonitor.Console.zip
Start-BitsTransfer https://github.com/josStorer/ai00_rwkv_server/releases/latest/download/webgpu_server_windows_x86_64.exe ./backend-rust/webgpu_server.exe
Start-BitsTransfer https://github.com/josStorer/web-rwkv-converter/releases/latest/download/web-rwkv-converter_windows_x86_64.exe ./backend-rust/web-rwkv-converter.exe
Start-BitsTransfer https://github.com/josStorer/LibreHardwareMonitor.Console/releases/latest/download/LibreHardwareMonitor.Console.zip ./LibreHardwareMonitor.Console.zip
Expand-Archive ./LibreHardwareMonitor.Console.zip -DestinationPath ./components/LibreHardwareMonitor.Console
Start-BitsTransfer https://www.python.org/ftp/python/3.10.11/python-3.10.11-embed-amd64.zip ./python-3.10.11-embed-amd64.zip
Expand-Archive ./python-3.10.11-embed-amd64.zip -DestinationPath ./py310
@@ -67,13 +64,11 @@ jobs:
Copy-Item -Path "${{ steps.cp310.outputs.python-path }}/../include" -Destination "py310/include" -Recurse
Copy-Item -Path "${{ steps.cp310.outputs.python-path }}/../libs" -Destination "py310/libs" -Recurse
./py310/python -m pip install cyac==1.9
git clone https://github.com/josStorer/ai00_rwkv_server --depth=1
cd ai00_rwkv_server
cargo build --release
mv ./target/release/ai00_server.exe ../backend-rust/webgpu_server.exe
cd ..
go install github.com/wailsapp/wails/v2/cmd/wails@latest
del ./backend-python/rwkv_pip/cpp/librwkv.dylib
del ./backend-python/rwkv_pip/cpp/librwkv.so
(Get-Content -Path ./backend-golang/app.go) -replace "//go:custom_build windows ", "" | Set-Content -Path ./backend-golang/app.go
(Get-Content -Path ./backend-golang/utils.go) -replace "//go:custom_build windows ", "" | Set-Content -Path ./backend-golang/utils.go
make
Rename-Item -Path "build/bin/RWKV-Runner.exe" -NewName "RWKV-Runner_windows_x64.exe"
@@ -89,28 +84,20 @@ jobs:
- uses: actions/setup-go@v4
with:
go-version: '1.20.5'
- uses: actions-rs/toolchain@v1
with:
toolchain: stable
override: true
target: wasm32-unknown-unknown
- run: |
wget https://github.com/josStorer/ai00_rwkv_server/releases/latest/download/webgpu_server_linux_x86_64 -O ./backend-rust/webgpu_server
wget https://github.com/josStorer/web-rwkv-converter/releases/latest/download/web-rwkv-converter_linux_x86_64 -O ./backend-rust/web-rwkv-converter
sudo apt-get update
sudo apt-get install upx
sudo apt-get install build-essential libgtk-3-dev libwebkit2gtk-4.0-dev
git clone https://github.com/josStorer/ai00_rwkv_server --depth=1
cd ai00_rwkv_server
sudo apt-get install libudev-dev
sudo apt-get install libasound2-dev
rustup target add x86_64-unknown-linux-gnu
cargo build --release --target x86_64-unknown-linux-gnu
mv ./target/x86_64-unknown-linux-gnu/release/ai00_server ../backend-rust/webgpu_server
cd ..
sudo apt-get install build-essential libgtk-3-dev libwebkit2gtk-4.0-dev libasound2-dev
go install github.com/wailsapp/wails/v2/cmd/wails@latest
rm ./backend-python/rwkv_pip/wkv_cuda.pyd
rm ./backend-python/rwkv_pip/rwkv5.pyd
rm ./backend-python/rwkv_pip/rwkv6.pyd
rm ./backend-python/rwkv_pip/beta/wkv_cuda.pyd
rm ./backend-python/get-pip.py
rm ./backend-python/rwkv_pip/cpp/librwkv.dylib
rm ./backend-python/rwkv_pip/cpp/rwkv.dll
make
mv build/bin/RWKV-Runner build/bin/RWKV-Runner_linux_x64
@@ -126,22 +113,17 @@ jobs:
- uses: actions/setup-go@v4
with:
go-version: '1.20.5'
- uses: actions-rs/toolchain@v1
with:
toolchain: stable
override: true
target: wasm32-unknown-unknown
- run: |
git clone https://github.com/josStorer/ai00_rwkv_server --depth=1
cd ai00_rwkv_server
cargo build --release
mv ./target/release/ai00_server ../backend-rust/webgpu_server
cd ..
wget https://github.com/josStorer/ai00_rwkv_server/releases/latest/download/webgpu_server_darwin_aarch64 -O ./backend-rust/webgpu_server
wget https://github.com/josStorer/web-rwkv-converter/releases/latest/download/web-rwkv-converter_darwin_aarch64 -O ./backend-rust/web-rwkv-converter
go install github.com/wailsapp/wails/v2/cmd/wails@latest
rm ./backend-python/rwkv_pip/wkv_cuda.pyd
rm ./backend-python/rwkv_pip/rwkv5.pyd
rm ./backend-python/rwkv_pip/rwkv6.pyd
rm ./backend-python/rwkv_pip/beta/wkv_cuda.pyd
rm ./backend-python/get-pip.py
rm ./backend-python/rwkv_pip/cpp/rwkv.dll
rm ./backend-python/rwkv_pip/cpp/librwkv.so
make
cp build/darwin/Readme_Install.txt build/bin/Readme_Install.txt
cp build/bin/RWKV-Runner.app/Contents/MacOS/RWKV-Runner build/bin/RWKV-Runner_darwin_universal

3
.gitignore vendored
View File

@@ -8,6 +8,7 @@ __pycache__
*.st
*.safetensors
*.bin
*.mid
/config.json
/cache.json
/presets.json
@@ -18,7 +19,7 @@ __pycache__
/cmd-helper.bat
/install-py-dep.bat
/backend-python/wkv_cuda
/backend-python/rwkv5
/backend-python/rwkv*
*.exe
*.old
.DS_Store

View File

@@ -1,45 +1,13 @@
## Changes
### Features
- allow conversation with some document (.pdf, .txt) (Experimental)
- add `/file-to-text` api
- allow avatarImg to be local absolute path
- base64 preset support
### Upgrades
- upgrade to rwkv 0.8.16 (DirectML support; rwkv 5.2 no longer needs to ensure custom cuda kernel enabled)
- upgrade to webgpu 0.2.2 (WebGPU Mode is now recommended for AMD and Intel
Users) (https://github.com/josStorer/ai00_rwkv_server)
- upgrade python packages
### Improvements
- improve cuda kernel compatibility (compute compatibility 5.3, Jetson Nano, Nvidia 10 Series+)
- RWKVType now no longer relies on the file name (use emb)
- improve message interruption and retry for Chat page
- update sample.jsonl of LoRA finetune
- update api stop strategy for better custom user_name and assistant_name support
- edited chat message now is marked as Normal
- change default World series prefix to User/Assistant
### Chores
- update manifest.json (RWKV-5)
- update readme and client text description
- add pip --no-warn-script-location
- mark rwkv raven series as old model
- chore
### Fixes
- fix linux kernel (partial revert 68228a45)
- fix the `make` command on Linux and macOS, no longer need manual operations on the wsl.go file. (#158, #173, #207)
- rwkv.cpp python38 compatibility
- improve rwkv.cpp operation prompts
- add load failed traceback
- fix windows cmd waiting
## Install
- Windows: https://github.com/josStorer/RWKV-Runner/blob/master/build/windows/Readme_Install.txt
- MacOS: https://github.com/josStorer/RWKV-Runner/blob/master/build/darwin/Readme_Install.txt
- Linux: https://github.com/josStorer/RWKV-Runner/blob/master/build/linux/Readme_Install.txt
- Server-Deploy-Examples: https://github.com/josStorer/RWKV-Runner/tree/master/deploy-examples
- Server-Deploy-Examples: https://github.com/josStorer/RWKV-Runner/tree/master/deploy-examples

View File

@@ -8,16 +8,26 @@ endif
build-windows:
@echo ---- build for windows
wails build -upx -ldflags "-s -w" -platform windows/amd64
wails build -upx -ldflags '-s -w -extldflags "-static"' -platform windows/amd64
build-macos:
@echo ---- build for macos
wails build -ldflags "-s -w" -platform darwin/universal
wails build -ldflags '-s -w' -platform darwin/universal
build-linux:
@echo ---- build for linux
wails build -upx -ldflags "-s -w" -platform linux/amd64
wails build -upx -ldflags '-s -w' -platform linux/amd64
build-web:
@echo ---- build for web
cd frontend && npm run build
dev:
wails dev
dev-web:
cd frontend && npm run dev
preview:
cd frontend && npm run preview

113
README.md
View File

@@ -21,7 +21,7 @@ English | [简体中文](README_ZH.md) | [日本語](README_JA.md)
[![MacOS][MacOS-image]][MacOS-url]
[![Linux][Linux-image]][Linux-url]
[FAQs](https://github.com/josStorer/RWKV-Runner/wiki/FAQs) | [Preview](#Preview) | [Download][download-url] | [Server-Deploy-Examples](https://github.com/josStorer/RWKV-Runner/tree/master/deploy-examples)
[FAQs](https://github.com/josStorer/RWKV-Runner/wiki/FAQs) | [Preview](#Preview) | [Download][download-url] | [Simple Deploy Example](#Simple-Deploy-Example) | [Server Deploy Examples](https://github.com/josStorer/RWKV-Runner/tree/master/deploy-examples) | [MIDI Hardware Input](#MIDI-Input)
[license-image]: http://img.shields.io/badge/license-MIT-blue.svg
@@ -57,20 +57,49 @@ English | [简体中文](README_ZH.md) | [日本語](README_JA.md)
## Features
- RWKV model management and one-click startup
- Fully compatible with the OpenAI API, making every ChatGPT client an RWKV client. After starting the model,
- RWKV model management and one-click startup.
- Front-end and back-end separation, if you don't want to use the client, also allows for separately deploying the
front-end service, or the back-end inference service, or the back-end inference service with a WebUI.
[Simple Deploy Example](#Simple-Deploy-Example) | [Server Deploy Examples](https://github.com/josStorer/RWKV-Runner/tree/master/deploy-examples)
- Compatible with the OpenAI API, making every ChatGPT client an RWKV client. After starting the model,
open http://127.0.0.1:8000/docs to view more details.
- Automatic dependency installation, requiring only a lightweight executable program
- Configs with 2G to 32G VRAM are included, works well on almost all computers
- User-friendly chat and completion interaction interface included
- Easy-to-understand and operate parameter configuration
- Built-in model conversion tool
- Built-in download management and remote model inspection
- Built-in one-click LoRA Finetune
- Can also be used as an OpenAI ChatGPT and GPT-Playground client
- Multilingual localization
- Theme switching
- Automatic updates
- Automatic dependency installation, requiring only a lightweight executable program.
- Pre-set multi-level VRAM configs, works well on almost all computers. In Configs page, switch Strategy to WebGPU, it
can also run on AMD, Intel, and other graphics cards.
- User-friendly chat, completion, and composition interaction interface included. Also supports chat presets, attachment
uploads, MIDI hardware input, and track editing.
[Preview](#Preview) | [MIDI Hardware Input](#MIDI-Input)
- Built-in WebUI option, one-click start of Web service, sharing your hardware resources.
- Easy-to-understand and operate parameter configuration, along with various operation guidance prompts.
- Built-in model conversion tool.
- Built-in download management and remote model inspection.
- Built-in one-click LoRA Finetune. (Windows Only)
- Can also be used as an OpenAI ChatGPT and GPT-Playground client. (Fill in the API URL and API Key in Settings page)
- Multilingual localization.
- Theme switching.
- Automatic updates.
## Simple Deploy Example
```bash
git clone https://github.com/josStorer/RWKV-Runner
# Then
cd RWKV-Runner
python ./backend-python/main.py #The backend inference service has been started, request /switch-model API to load the model, refer to the API documentation: http://127.0.0.1:8000/docs
# Or
cd RWKV-Runner/frontend
npm ci
npm run build #Compile the frontend
cd ..
python ./backend-python/webui_server.py #Start the frontend service separately
# Or
python ./backend-python/main.py --webui #Start the frontend and backend service at the same time
# Help Info
python ./backend-python/main.py -h
```
## API Concurrency Stress Testing
@@ -133,6 +162,48 @@ for i in np.argsort(embeddings_cos_sim)[::-1]:
print(f"{embeddings_cos_sim[i]:.10f} - {values[i]}")
```
## MIDI Input
Tip: You can download https://github.com/josStorer/sgm_plus and unzip it to the program's `assets/sound-font` directory
to use it as an offline sound source. Please note that if you are compiling the program from source code, do not place
it in the source code directory.
### USB MIDI Connection
- USB MIDI devices are plug-and-play, and you can select your input device in the Composition page
- ![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/13bb92c3-4504-482d-ab82-026ac6c31095)
### Mac MIDI Bluetooth Connection
- For Mac users who want to use Bluetooth input,
please install [Bluetooth MIDI Connect](https://apps.apple.com/us/app/bluetooth-midi-connect/id1108321791), then click
the tray icon to connect after launching,
afterwards, you can select your input device in the Composition page.
- ![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/c079a109-1e3d-45c1-bbf5-eed85da1550e)
### Windows MIDI Bluetooth Connection
- Windows seems to have implemented Bluetooth MIDI support only for UWP (Universal Windows Platform) apps. Therefore, it
requires multiple steps to establish a connection. We need to create a local virtual MIDI device and then launch a UWP
application. Through this UWP application, we will redirect Bluetooth MIDI input to the virtual MIDI device, and then
this software will listen to the input from the virtual MIDI device.
- So, first, you need to
download [loopMIDI](https://www.tobias-erichsen.de/wp-content/uploads/2020/01/loopMIDISetup_1_0_16_27.zip)
to create a virtual MIDI device. Click the plus sign in the bottom left corner to create the device.
- ![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/b75998ff-115c-4ddd-b97c-deeb5c106255)
- Next, you need to download [Bluetooth LE Explorer](https://apps.microsoft.com/detail/9N0ZTKF1QD98) to discover and
connect to Bluetooth MIDI devices. Click "Start" to search for devices, and then click "Pair" to bind the MIDI device.
- ![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/c142c3ea-a973-4531-9807-4c385d640a2b)
- Finally, you need to install [MIDIberry](https://apps.microsoft.com/detail/9N39720H2M05),
This UWP application can redirect Bluetooth MIDI input to the virtual MIDI device. After launching it, double-click
your actual Bluetooth MIDI device name in the input field, and in the output field, double-click the virtual MIDI
device name we created earlier.
- ![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/5ad6a1d9-4f68-4d95-ae17-4296107d1669)
- Now, you can select the virtual MIDI device as the input in the Composition page. Bluetooth LE Explorer no longer
needs to run, and you can also close the loopMIDI window, it will run automatically in the background. Just keep
MIDIberry open.
- ![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/1c371821-c7b7-4c18-8e42-9e315efbe427)
## Related Repositories:
- RWKV-4-World: https://huggingface.co/BlinkDL/rwkv-4-world/tree/main
@@ -146,27 +217,35 @@ for i in np.argsort(embeddings_cos_sim)[::-1]:
### Homepage
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/d7f24d80-f382-428d-8b28-edf87e1549e2)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/c9b9cdd0-63f9-4319-9f74-5bf5d7df5a67)
### Chat
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/80009872-528f-4932-aeb2-f724fa892e7c)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/e98c9038-3323-47b0-8edb-d639fafd37b2)
### Completion
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/bf49de8e-3b89-4543-b1ef-7cd4b19a1836)
### Composition
Tip: You can download https://github.com/josStorer/sgm_plus and unzip it to the program's `assets/sound-font` directory
to use it as an offline sound source. Please note that if you are compiling the program from source code, do not place
it in the source code directory.
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/e8ad908d-3fd2-4e92-bcdb-96815cb836ee)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/b2ce4761-9e75-477e-a182-d0255fb8ac76)
### Configuration
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/48befdc6-e03c-4851-9bee-22f77ee2640e)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/f41060dc-5517-44af-bb3f-8ef71720016d)
### Model Management
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/367fe4f8-cc12-475f-9371-3cf62cdbf293)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/b1581147-a6ce-4493-8010-e33c0ddeca0a)
### Download Management

View File

@@ -21,7 +21,7 @@
[![MacOS][MacOS-image]][MacOS-url]
[![Linux][Linux-image]][Linux-url]
[FAQs](https://github.com/josStorer/RWKV-Runner/wiki/FAQs) | [プレビュー](#Preview) | [ダウンロード][download-url] | [サーバーデプロイ例](https://github.com/josStorer/RWKV-Runner/tree/master/deploy-examples)
[FAQs](https://github.com/josStorer/RWKV-Runner/wiki/FAQs) | [プレビュー](#Preview) | [ダウンロード][download-url] | [シンプルなデプロイの例](#Simple-Deploy-Example) | [サーバーデプロイ例](https://github.com/josStorer/RWKV-Runner/tree/master/deploy-examples) | [MIDIハードウェア入力](#MIDI-Input)
[license-image]: http://img.shields.io/badge/license-MIT-blue.svg
@@ -58,20 +58,47 @@
## 特徴
- RWKV モデル管理とワンクリック起動
- OpenAI API と完全に互換性があり、すべての ChatGPT クライアントを RWKV クライアントにします。モデル起動後、
- フロントエンドとバックエンドの分離は、クライアントを使用しない場合でも、フロントエンドサービス、またはバックエンド推論サービス、またはWebUIを備えたバックエンド推論サービスを個別に展開することを可能にします。
[シンプルなデプロイの例](#Simple-Deploy-Example) | [サーバーデプロイ例](https://github.com/josStorer/RWKV-Runner/tree/master/deploy-examples)
- OpenAI API と互換性があり、すべての ChatGPT クライアントを RWKV クライアントにします。モデル起動後、
http://127.0.0.1:8000/docs を開いて詳細をご覧ください。
- 依存関係の自動インストールにより、軽量な実行プログラムのみを必要とします
- 2G から 32G の VRAM のコンフィグが含まれており、ほとんどのコンピュータで動作します
- ユーザーフレンドリーなチャット完成インタラクションインターフェースを搭載
- 分かりやすく操作しやすいパラメータ設定
- 事前設定された多段階のVRAM設定、ほとんどのコンピュータで動作します。配置ページで、ストラテジーをWebGPUに切り替えると、AMD、インテル、その他のグラフィックカードでも動作します
- ユーザーフレンドリーなチャット完成、および作曲インターフェイスが含まれています。また、チャットプリセット、添付ファイルのアップロード、MIDIハードウェア入力、トラック編集もサポートしています。
[プレビュー](#Preview) | [MIDIハードウェア入力](#MIDI-Input)
- 内蔵WebUIオプション、Webサービスのワンクリック開始、ハードウェアリソースの共有
- 分かりやすく操作しやすいパラメータ設定、各種操作ガイダンスプロンプトとともに
- 内蔵モデル変換ツール
- ダウンロード管理とリモートモデル検査機能内蔵
- 内蔵のLoRA微調整機能を搭載しています
- このプログラムは、OpenAI ChatGPTとGPT Playgroundのクライアントとしても使用できます
- 内蔵のLoRA微調整機能を搭載しています (Windowsのみ)
- このプログラムは、OpenAI ChatGPTとGPT Playgroundのクライアントとしても使用できます(設定ページで `API URL``API Key`
を入力してください)
- 多言語ローカライズ
- テーマ切り替え
- 自動アップデート
## Simple Deploy Example
```bash
git clone https://github.com/josStorer/RWKV-Runner
# Then
cd RWKV-Runner
python ./backend-python/main.py #The backend inference service has been started, request /switch-model API to load the model, refer to the API documentation: http://127.0.0.1:8000/docs
# Or
cd RWKV-Runner/frontend
npm ci
npm run build #Compile the frontend
cd ..
python ./backend-python/webui_server.py #Start the frontend service separately
# Or
python ./backend-python/main.py --webui #Start the frontend and backend service at the same time
# Help Info
python ./backend-python/main.py -h
```
## API 同時実行ストレステスト
```bash
@@ -134,6 +161,48 @@ for i in np.argsort(embeddings_cos_sim)[::-1]:
print(f"{embeddings_cos_sim[i]:.10f} - {values[i]}")
```
## MIDI Input
Tip: You can download https://github.com/josStorer/sgm_plus and unzip it to the program's `assets/sound-font` directory
to use it as an offline sound source. Please note that if you are compiling the program from source code, do not place
it in the source code directory.
### USB MIDI Connection
- USB MIDI devices are plug-and-play, and you can select your input device in the Composition page
- ![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/13bb92c3-4504-482d-ab82-026ac6c31095)
### Mac MIDI Bluetooth Connection
- For Mac users who want to use Bluetooth input,
please install [Bluetooth MIDI Connect](https://apps.apple.com/us/app/bluetooth-midi-connect/id1108321791), then click
the tray icon to connect after launching,
afterwards, you can select your input device in the Composition page.
- ![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/c079a109-1e3d-45c1-bbf5-eed85da1550e)
### Windows MIDI Bluetooth Connection
- Windows seems to have implemented Bluetooth MIDI support only for UWP (Universal Windows Platform) apps. Therefore, it
requires multiple steps to establish a connection. We need to create a local virtual MIDI device and then launch a UWP
application. Through this UWP application, we will redirect Bluetooth MIDI input to the virtual MIDI device, and then
this software will listen to the input from the virtual MIDI device.
- So, first, you need to
download [loopMIDI](https://www.tobias-erichsen.de/wp-content/uploads/2020/01/loopMIDISetup_1_0_16_27.zip)
to create a virtual MIDI device. Click the plus sign in the bottom left corner to create the device.
- ![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/b75998ff-115c-4ddd-b97c-deeb5c106255)
- Next, you need to download [Bluetooth LE Explorer](https://apps.microsoft.com/detail/9N0ZTKF1QD98) to discover and
connect to Bluetooth MIDI devices. Click "Start" to search for devices, and then click "Pair" to bind the MIDI device.
- ![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/c142c3ea-a973-4531-9807-4c385d640a2b)
- Finally, you need to install [MIDIberry](https://apps.microsoft.com/detail/9N39720H2M05),
This UWP application can redirect Bluetooth MIDI input to the virtual MIDI device. After launching it, double-click
your actual Bluetooth MIDI device name in the input field, and in the output field, double-click the virtual MIDI
device name we created earlier.
- ![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/5ad6a1d9-4f68-4d95-ae17-4296107d1669)
- Now, you can select the virtual MIDI device as the input in the Composition page. Bluetooth LE Explorer no longer
needs to run, and you can also close the loopMIDI window, it will run automatically in the background. Just keep
MIDIberry open.
- ![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/1c371821-c7b7-4c18-8e42-9e315efbe427)
## 関連リポジトリ:
- RWKV-4-World: https://huggingface.co/BlinkDL/rwkv-4-world/tree/main
@@ -143,31 +212,39 @@ for i in np.argsort(embeddings_cos_sim)[::-1]:
- RWKV-LM-LoRA: https://github.com/Blealtan/RWKV-LM-LoRA
- MIDI-LLM-tokenizer: https://github.com/briansemrau/MIDI-LLM-tokenizer
## プレビュー
## Preview
### ホームページ
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/d7f24d80-f382-428d-8b28-edf87e1549e2)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/c9b9cdd0-63f9-4319-9f74-5bf5d7df5a67)
### チャット
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/80009872-528f-4932-aeb2-f724fa892e7c)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/e98c9038-3323-47b0-8edb-d639fafd37b2)
### 補完
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/bf49de8e-3b89-4543-b1ef-7cd4b19a1836)
### 作曲
Tip: You can download https://github.com/josStorer/sgm_plus and unzip it to the program's `assets/sound-font` directory
to use it as an offline sound source. Please note that if you are compiling the program from source code, do not place
it in the source code directory.
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/e8ad908d-3fd2-4e92-bcdb-96815cb836ee)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/b2ce4761-9e75-477e-a182-d0255fb8ac76)
### コンフィグ
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/48befdc6-e03c-4851-9bee-22f77ee2640e)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/f41060dc-5517-44af-bb3f-8ef71720016d)
### モデル管理
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/367fe4f8-cc12-475f-9371-3cf62cdbf293)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/b1581147-a6ce-4493-8010-e33c0ddeca0a)
### ダウンロード管理

View File

@@ -20,7 +20,7 @@ API兼容的接口这意味着一切ChatGPT客户端都是RWKV客户端。
[![MacOS][MacOS-image]][MacOS-url]
[![Linux][Linux-image]][Linux-url]
[视频演示](https://www.bilibili.com/video/BV1hM4y1v76R) | [疑难解答](https://www.bilibili.com/read/cv23921171) | [预览](#Preview) | [下载][download-url] | [懒人包](https://pan.baidu.com/s/1zdzZ_a0uM3gDqi6pXIZVAA?pwd=1111) | [服务器部署示例](https://github.com/josStorer/RWKV-Runner/tree/master/deploy-examples)
[视频演示](https://www.bilibili.com/video/BV1hM4y1v76R) | [疑难解答](https://www.bilibili.com/read/cv23921171) | [预览](#Preview) | [下载][download-url] | [懒人包](https://pan.baidu.com/s/1zdzZ_a0uM3gDqi6pXIZVAA?pwd=1111) | [简明服务部署示例](#Simple-Deploy-Example) | [服务器部署示例](https://github.com/josStorer/RWKV-Runner/tree/master/deploy-examples) | [MIDI硬件输入](#MIDI-Input)
[license-image]: http://img.shields.io/badge/license-MIT-blue.svg
@@ -57,19 +57,45 @@ API兼容的接口这意味着一切ChatGPT客户端都是RWKV客户端。
## 功能
- RWKV模型管理一键启动
- 与OpenAI API完全兼容一切ChatGPT客户端都是RWKV客户端。启动模型后打开 http://127.0.0.1:8000/docs 查看详细内容
- 前后端分离如果你不想使用客户端也允许单独部署前端服务或后端推理服务或具有WebUI的后端推理服务。
[简明服务部署示例](#Simple-Deploy-Example) | [服务器部署示例](https://github.com/josStorer/RWKV-Runner/tree/master/deploy-examples)
- 与OpenAI API兼容一切ChatGPT客户端都是RWKV客户端。启动模型后打开 http://127.0.0.1:8000/docs 查看API文档
- 全自动依赖安装,你只需要一个轻巧的可执行程序
- 预设了2G至32G显存配置,几乎在各种电脑上工作良好
- 自带用户友好的聊天续写交互页面
- 易于理解和操作的参数配置
- 预设多级显存配置,几乎在各种电脑上工作良好。通过配置页面切换Strategy到WebGPU还可以在AMDIntel等显卡上运行
- 自带用户友好的聊天续写,作曲交互页面。支持聊天预设附件上传MIDI硬件输入及音轨编辑。
[预览](#Preview) | [MIDI硬件输入](#MIDI-Input)
- 内置WebUI选项一键启动Web服务共享硬件资源
- 易于理解和操作的参数配置,及各类操作引导提示
- 内置模型转换工具
- 内置下载管理和远程模型检视
- 内置一键LoRA微调
- 也可用作 OpenAI ChatGPT 和 GPT Playground 客户端
- 内置一键LoRA微调 (仅限Windows)
- 也可用作 OpenAI ChatGPT 和 GPT Playground 客户端 (在设置内填写API URL和API Key)
- 多语言本地化
- 主题切换
- 自动更新
## Simple Deploy Example
```bash
git clone https://github.com/josStorer/RWKV-Runner
# 然后
cd RWKV-Runner
python ./backend-python/main.py #后端推理服务已启动, 调用/switch-model载入模型, 参考API文档: http://127.0.0.1:8000/docs
# 或者
cd RWKV-Runner/frontend
npm ci
npm run build #编译前端
cd ..
python ./backend-python/webui_server.py #单独启动前端服务
# 或者
python ./backend-python/main.py --webui #同时启动前后端服务
# 帮助参数
python ./backend-python/main.py -h
```
## API并发压力测试
```bash
@@ -130,6 +156,40 @@ for i in np.argsort(embeddings_cos_sim)[::-1]:
print(f"{embeddings_cos_sim[i]:.10f} - {values[i]}")
```
## MIDI Input
小贴士: 你可以下载 https://github.com/josStorer/sgm_plus, 并解压到程序的`assets/sound-font`目录, 以使用离线音源. 注意,
如果你正在从源码编译程序, 请不要将其放置在源码目录中
### USB MIDI 连接
- USB MIDI设备是即插即用的, 你能够在作曲页面选择你的输入设备
- ![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/a448c34a-56d8-46eb-8dc2-dd11e8e0c4ce)
### Mac MIDI 蓝牙连接
- 对于想要使用蓝牙输入的Mac用户,
请安装[Bluetooth MIDI Connect](https://apps.apple.com/us/app/bluetooth-midi-connect/id1108321791), 启动后点击托盘连接,
之后你可以在作曲页面选择你的输入设备
- ![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/c079a109-1e3d-45c1-bbf5-eed85da1550e)
### Windows MIDI 蓝牙连接
- Windows似乎只为UWP实现了蓝牙MIDI支持, 因此需要多个步骤进行连接, 我们需要创建一个本地的虚拟MIDI设备, 然后启动一个UWP应用,
通过此UWP应用将蓝牙MIDI输入重定向到虚拟MIDI设备, 然后本软件监听虚拟MIDI设备的输入
- 因此, 首先你需要下载[loopMIDI](https://www.tobias-erichsen.de/wp-content/uploads/2020/01/loopMIDISetup_1_0_16_27.zip),
用于创建虚拟MIDI设备, 点击左下角的加号创建设备
- ![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/b75998ff-115c-4ddd-b97c-deeb5c106255)
- 然后, 你需要下载[Bluetooth LE Explorer](https://apps.microsoft.com/detail/9N0ZTKF1QD98), 以发现并连接蓝牙MIDI设备,
点击Start搜索设备, 然后点击Pair绑定MIDI设备
- ![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/c142c3ea-a973-4531-9807-4c385d640a2b)
- 最后, 你需要安装[MIDIberry](https://apps.microsoft.com/detail/9N39720H2M05), 这个UWP应用能将MIDI蓝牙输入重定向到虚拟MIDI设备,
启动后, 在输入栏, 双击你实际的蓝牙MIDI设备名称, 在输出栏, 双击我们先前创建的虚拟MIDI设备名称
- ![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/5ad6a1d9-4f68-4d95-ae17-4296107d1669)
- 现在, 你可以在作曲页面选择虚拟MIDI设备作为输入. Bluetooth LE Explorer不再需要运行, loopMIDI窗口也可以退出, 它会自动在后台运行,
仅保持MIDIberry打开即可
- ![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/6460c355-884e-4b28-a2eb-8ab7a2e3a01a)
## 相关仓库:
- RWKV-4-World: https://huggingface.co/BlinkDL/rwkv-4-world/tree/main
@@ -143,27 +203,34 @@ for i in np.argsort(embeddings_cos_sim)[::-1]:
### 主页
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/ff2b1eef-dd3b-4cbf-98fb-b5a1ecee43e1)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/cd82674e-3ee3-4175-bd9c-a11d45437327)
### 聊天
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/9570e73b-dca2-4316-9e92-09961f3c48c4)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/54bb0e2b-cdc4-4ea0-8d16-9beaf57c232c)
### 续写
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/69f9ba7a-2fe8-4a5e-94cb-aa655aa409e2)
### 作曲
小贴士: 你可以下载 https://github.com/josStorer/sgm_plus, 并解压到程序的`assets/sound-font`目录, 以使用离线音源. 注意,
如果你正在从源码编译程序, 请不要将其放置在源码目录中
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/95b34893-80c2-4706-87f9-bc141032ed4b)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/3cb31ca8-d708-42f1-8768-1605fb0b2174)
### 配置
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/59460f69-b172-4c7a-86cb-573262543076)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/0f4d4f21-8abe-4f4d-8c4f-cd7d5607f20e)
### 模型管理
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/551121ee-1bfe-421b-a9d1-24125126ab4b)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/871f2d2a-7e41-4be7-9b32-be1b3e00dc3e)
### 下载管理

View File

@@ -4,12 +4,14 @@ import (
"bufio"
"context"
"errors"
"io"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"syscall"
"time"
"github.com/fsnotify/fsnotify"
"github.com/minio/selfupdate"
@@ -43,7 +45,8 @@ func (a *App) OnStartup(ctx context.Context) {
a.cmdPrefix = "cd " + a.exDir + " && "
}
os.Chmod("./backend-rust/webgpu_server", 0777)
os.Chmod(a.exDir+"backend-rust/webgpu_server", 0777)
os.Chmod(a.exDir+"backend-rust/web-rwkv-converter", 0777)
os.Mkdir(a.exDir+"models", os.ModePerm)
os.Mkdir(a.exDir+"lora-models", os.ModePerm)
os.Mkdir(a.exDir+"finetune/json2binidx_tool/data", os.ModePerm)
@@ -53,6 +56,7 @@ func (a *App) OnStartup(ctx context.Context) {
}
a.downloadLoop()
a.midiLoop()
a.watchFs()
a.monitorHardware()
}
@@ -67,8 +71,8 @@ func (a *App) OnBeforeClose(ctx context.Context) bool {
func (a *App) watchFs() {
watcher, err := fsnotify.NewWatcher()
if err == nil {
watcher.Add("./lora-models")
watcher.Add("./models")
watcher.Add(a.exDir + "./lora-models")
watcher.Add(a.exDir + "./models")
go func() {
for {
select {
@@ -118,13 +122,50 @@ func (a *App) monitorHardware() {
monitor.Start()
}
type ProgressReader struct {
reader io.Reader
total int64
err error
}
func (pr *ProgressReader) Read(p []byte) (n int, err error) {
n, err = pr.reader.Read(p)
pr.err = err
pr.total += int64(n)
return
}
func (a *App) UpdateApp(url string) (broken bool, err error) {
resp, err := http.Get(url)
if err != nil {
return false, err
}
defer resp.Body.Close()
err = selfupdate.Apply(resp.Body, selfupdate.Options{})
pr := &ProgressReader{reader: resp.Body}
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
go func() {
for {
<-ticker.C
wruntime.EventsEmit(a.ctx, "updateApp", &DownloadStatus{
Name: filepath.Base(url),
Path: "",
Url: url,
Transferred: pr.total,
Size: resp.ContentLength,
Speed: 0,
Progress: 100 * (float64(pr.total) / float64(resp.ContentLength)),
Downloading: pr.err == nil && pr.total < resp.ContentLength,
Done: pr.total == resp.ContentLength,
})
if pr.err != nil || pr.total == resp.ContentLength {
break
}
}
}()
err = selfupdate.Apply(pr, selfupdate.Options{})
if err != nil {
if rerr := selfupdate.RollbackError(err); rerr != nil {
return true, rerr

View File

@@ -33,9 +33,9 @@ type DownloadStatus struct {
var downloadList []*DownloadStatus
func existsInDownloadList(url string) bool {
func existsInDownloadList(path string, url string) bool {
for _, ds := range downloadList {
if ds.Url == url {
if ds.Path == path || ds.Url == url {
return true
}
}
@@ -88,7 +88,7 @@ func (a *App) ContinueDownload(url string) {
}
func (a *App) AddToDownloadList(path string, url string) {
if !existsInDownloadList(url) {
if !existsInDownloadList(a.exDir+path, url) {
downloadList = append(downloadList, &DownloadStatus{
resp: nil,
Name: filepath.Base(path),

View File

@@ -14,6 +14,13 @@ import (
wruntime "github.com/wailsapp/wails/v2/pkg/runtime"
)
func (a *App) SaveFile(path string, savedContent []byte) error {
if err := os.WriteFile(a.exDir+path, savedContent, 0644); err != nil {
return err
}
return nil
}
func (a *App) SaveJson(fileName string, jsonData any) error {
text, err := json.MarshalIndent(jsonData, "", " ")
if err != nil {
@@ -195,3 +202,12 @@ func (a *App) OpenFileFolder(path string, relative bool) error {
}
return errors.New("unsupported OS")
}
func (a *App) StartFile(path string) error {
cmd, err := CmdHelper(true, path)
if err != nil {
return err
}
err = cmd.Start()
return err
}

170
backend-golang/midi.go Normal file
View File

@@ -0,0 +1,170 @@
package backend_golang
import (
"errors"
"fmt"
"time"
"github.com/mattrtaylor/go-rtmidi"
"github.com/wailsapp/wails/v2/pkg/runtime"
)
type Port struct {
Name string `json:"name"`
}
type MIDIMessage struct {
MessageType string `json:"messageType"`
Channel int `json:"channel"`
Note int `json:"note"`
Velocity int `json:"velocity"`
Control int `json:"control"`
Value int `json:"value"`
}
var ports []Port
var input rtmidi.MIDIIn
var out rtmidi.MIDIOut
var activeIndex int = -1
var lastNoteTime time.Time
func (a *App) midiLoop() {
var err error
input, err = rtmidi.NewMIDIInDefault()
if err != nil {
runtime.EventsEmit(a.ctx, "midiError", err.Error())
return
}
out, err = rtmidi.NewMIDIOutDefault()
if err != nil {
runtime.EventsEmit(a.ctx, "midiError", err.Error())
}
err = out.OpenPort(0, "")
if err != nil {
runtime.EventsEmit(a.ctx, "midiError", err.Error())
}
ticker := time.NewTicker(500 * time.Millisecond)
go func() {
for {
<-ticker.C
count, err := input.PortCount()
if err != nil {
continue
}
ports = make([]Port, count)
for i := 0; i < count; i++ {
name, err := input.PortName(i)
if err == nil {
ports[i].Name = name
}
}
runtime.EventsEmit(a.ctx, "midiPorts", &ports)
}
}()
}
func (a *App) OpenMidiPort(index int) error {
if input == nil {
return errors.New("failed to initialize MIDI input")
}
if activeIndex == index {
return nil
}
input.Destroy()
var err error
input, err = rtmidi.NewMIDIInDefault()
if err != nil {
return err
}
err = input.SetCallback(func(msg rtmidi.MIDIIn, bytes []byte, t float64) {
// https://www.midi.org/specifications-old/item/table-1-summary-of-midi-message
// https://www.rfc-editor.org/rfc/rfc6295.html
//
// msgType channel
// 1001 0000
//
msgType := bytes[0] >> 4
channel := bytes[0] & 0x0f
switch msgType {
case 0x8:
elapsed := time.Since(lastNoteTime)
lastNoteTime = time.Now()
runtime.EventsEmit(a.ctx, "midiMessage", &MIDIMessage{
MessageType: "ElapsedTime",
Value: int(elapsed.Milliseconds()),
})
note := bytes[1]
runtime.EventsEmit(a.ctx, "midiMessage", &MIDIMessage{
MessageType: "NoteOff",
Channel: int(channel),
Note: int(note),
})
case 0x9:
elapsed := time.Since(lastNoteTime)
lastNoteTime = time.Now()
runtime.EventsEmit(a.ctx, "midiMessage", &MIDIMessage{
MessageType: "ElapsedTime",
Value: int(elapsed.Milliseconds()),
})
note := bytes[1]
velocity := bytes[2]
runtime.EventsEmit(a.ctx, "midiMessage", &MIDIMessage{
MessageType: "NoteOn",
Channel: int(channel),
Note: int(note),
Velocity: int(velocity),
})
case 0xb:
// control 12 => K1 knob, control 13 => K2 knob
control := bytes[1]
value := bytes[2]
runtime.EventsEmit(a.ctx, "midiMessage", &MIDIMessage{
MessageType: "ControlChange",
Channel: int(channel),
Control: int(control),
Value: int(value),
})
default:
fmt.Printf("Unknown midi message: %v\n", bytes)
}
})
if err != nil {
return err
}
err = input.OpenPort(index, "")
if err != nil {
return err
}
activeIndex = index
lastNoteTime = time.Now()
return nil
}
func (a *App) CloseMidiPort() error {
if input == nil {
return errors.New("failed to initialize MIDI input")
}
if activeIndex == -1 {
return nil
}
activeIndex = -1
input.Destroy()
var err error
input, err = rtmidi.NewMIDIInDefault()
if err != nil {
return err
}
return nil
}
func (a *App) PlayNote(msg MIDIMessage) error {
if out == nil {
return errors.New("failed to initialize MIDI output")
}
channelByte := byte(msg.Channel)
if msg.MessageType == "NoteOn" {
out.SendMessage([]byte{0x90 | channelByte, byte(msg.Note), byte(msg.Velocity)})
} else if msg.MessageType == "NoteOff" {
out.SendMessage([]byte{0x80 | channelByte, byte(msg.Note), byte(msg.Velocity)})
}
return nil
}

View File

@@ -10,7 +10,7 @@ import (
"strings"
)
func (a *App) StartServer(python string, port int, host string, rwkvBeta bool) (string, error) {
func (a *App) StartServer(python string, port int, host string, webui bool, rwkvBeta bool, rwkvcpp bool) (string, error) {
var err error
if python == "" {
python, err = GetPython()
@@ -19,9 +19,15 @@ func (a *App) StartServer(python string, port int, host string, rwkvBeta bool) (
return "", err
}
args := []string{python, "./backend-python/main.py"}
if webui {
args = append(args, "--webui")
}
if rwkvBeta {
args = append(args, "--rwkv-beta")
}
if rwkvcpp {
args = append(args, "--rwkv.cpp")
}
args = append(args, "--port", strconv.Itoa(port), "--host", host)
return Cmd(args...)
}
@@ -43,7 +49,13 @@ func (a *App) ConvertModel(python string, modelPath string, strategy string, out
return Cmd(python, "./backend-python/convert_model.py", "--in", modelPath, "--out", outPath, "--strategy", strategy)
}
func (a *App) ConvertSafetensors(python string, modelPath string, outPath string) (string, error) {
func (a *App) ConvertSafetensors(modelPath string, outPath string) (string, error) {
args := []string{"./backend-rust/web-rwkv-converter"}
args = append(args, "--input", modelPath, "--output", outPath)
return Cmd(args...)
}
func (a *App) ConvertGGML(python string, modelPath string, outPath string, Q51 bool) (string, error) {
var err error
if python == "" {
python, err = GetPython()
@@ -51,7 +63,11 @@ func (a *App) ConvertSafetensors(python string, modelPath string, outPath string
if err != nil {
return "", err
}
return Cmd(python, "./backend-python/convert_safetensors.py", "--input", modelPath, "--output", outPath)
dataType := "FP16"
if Q51 {
dataType = "Q5_1"
}
return Cmd(python, "./backend-python/convert_pytorch_to_ggml.py", modelPath, outPath, dataType)
}
func (a *App) ConvertData(python string, input string, outputPrefix string, vocab string) (string, error) {

View File

@@ -5,40 +5,60 @@ import (
"bufio"
"embed"
"errors"
"fmt"
"io"
"io/fs"
"net"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
"syscall"
)
func CmdHelper(hideWindow bool, args ...string) (*exec.Cmd, error) {
if runtime.GOOS != "windows" {
return nil, errors.New("unsupported OS")
}
filename := "./cmd-helper.bat"
_, err := os.Stat(filename)
if err != nil {
if err := os.WriteFile(filename, []byte("start %*"), 0644); err != nil {
return nil, err
}
}
cmdHelper, err := filepath.Abs(filename)
if err != nil {
return nil, err
}
if strings.Contains(cmdHelper, " ") {
for _, arg := range args {
if strings.Contains(arg, " ") {
return nil, errors.New("path contains space") // golang bug https://github.com/golang/go/issues/17149#issuecomment-473976818
}
}
}
cmd := exec.Command(cmdHelper, args...)
cmd.SysProcAttr = &syscall.SysProcAttr{}
//go:custom_build windows cmd.SysProcAttr.HideWindow = hideWindow
return cmd, nil
}
func Cmd(args ...string) (string, error) {
switch platform := runtime.GOOS; platform {
case "windows":
if err := os.WriteFile("./cmd-helper.bat", []byte("start %*"), 0644); err != nil {
return "", err
}
cmdHelper, err := filepath.Abs("./cmd-helper")
cmd, err := CmdHelper(true, args...)
if err != nil {
return "", err
}
if strings.Contains(cmdHelper, " ") {
for _, arg := range args {
if strings.Contains(arg, " ") {
return "", errors.New("path contains space") // golang bug https://github.com/golang/go/issues/17149#issuecomment-473976818
}
}
}
cmd := exec.Command(cmdHelper, args...)
out, err := cmd.CombinedOutput()
_, err = cmd.CombinedOutput()
if err != nil {
return "", err
}
return string(out), nil
return "", nil
case "darwin":
ex, err := os.Executable()
if err != nil {
@@ -205,3 +225,12 @@ func Unzip(source, destination string) error {
}
return nil
}
func (a *App) IsPortAvailable(port int) bool {
l, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%s", strconv.Itoa(port)))
if err != nil {
return false
}
defer l.Close()
return true
}

View File

@@ -231,5 +231,6 @@ try:
convert_and_save_and_exit=args.out,
)
except Exception as e:
print(e)
with open("error.txt", "w") as f:
f.write(str(e))

View File

@@ -0,0 +1,169 @@
# Converts an RWKV model checkpoint in PyTorch format to an rwkv.cpp compatible file.
# Usage: python convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M-FP16.bin FP16
# Get model checkpoints from https://huggingface.co/BlinkDL
# See FILE_FORMAT.md for the documentation on the file format.
import argparse
import struct
import torch
from typing import Dict
def parse_args():
parser = argparse.ArgumentParser(
description="Convert an RWKV model checkpoint in PyTorch format to an rwkv.cpp compatible file"
)
parser.add_argument("src_path", help="Path to PyTorch checkpoint file")
parser.add_argument(
"dest_path", help="Path to rwkv.cpp checkpoint file, will be overwritten"
)
parser.add_argument(
"data_type",
help="Data type, FP16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0",
type=str,
choices=[
"FP16",
"Q4_0",
"Q4_1",
"Q5_0",
"Q5_1",
"Q8_0",
],
default="FP16",
)
return parser.parse_args()
def get_layer_count(state_dict: Dict[str, torch.Tensor]) -> int:
n_layer: int = 0
while f"blocks.{n_layer}.ln1.weight" in state_dict:
n_layer += 1
assert n_layer > 0
return n_layer
def write_state_dict(
state_dict: Dict[str, torch.Tensor], dest_path: str, data_type: str
) -> None:
emb_weight: torch.Tensor = state_dict["emb.weight"]
n_layer: int = get_layer_count(state_dict)
n_vocab: int = emb_weight.shape[0]
n_embed: int = emb_weight.shape[1]
is_v5_1_or_2: bool = "blocks.0.att.ln_x.weight" in state_dict
is_v5_2: bool = "blocks.0.att.gate.weight" in state_dict
if is_v5_2:
print("Detected RWKV v5.2")
elif is_v5_1_or_2:
print("Detected RWKV v5.1")
else:
print("Detected RWKV v4")
with open(dest_path, "wb") as out_file:
is_FP16: bool = data_type == "FP16" or data_type == "float16"
out_file.write(
struct.pack(
# Disable padding with '='
"=iiiiii",
# Magic: 'ggmf' in hex
0x67676D66,
101,
n_vocab,
n_embed,
n_layer,
1 if is_FP16 else 0,
)
)
for k in state_dict.keys():
tensor: torch.Tensor = state_dict[k].float()
if ".time_" in k:
tensor = tensor.squeeze()
if is_v5_1_or_2:
if ".time_decay" in k:
if is_v5_2:
tensor = torch.exp(-torch.exp(tensor)).unsqueeze(-1)
else:
tensor = torch.exp(-torch.exp(tensor)).reshape(-1, 1, 1)
if ".time_first" in k:
tensor = torch.exp(tensor).reshape(-1, 1, 1)
if ".time_faaaa" in k:
tensor = tensor.unsqueeze(-1)
else:
if ".time_decay" in k:
tensor = -torch.exp(tensor)
# Keep 1-dim vectors and small matrices in FP32
if is_FP16 and len(tensor.shape) > 1 and ".time_" not in k:
tensor = tensor.half()
shape = tensor.shape
print(f"Writing {k}, shape {shape}, type {tensor.dtype}")
k_encoded: bytes = k.encode("utf-8")
out_file.write(
struct.pack(
"=iii",
len(shape),
len(k_encoded),
1 if tensor.dtype == torch.float16 else 0,
)
)
# Dimension order is reversed here:
# * PyTorch shape is (x rows, y columns)
# * ggml shape is (y elements in a row, x elements in a column)
# Both shapes represent the same tensor.
for dim in reversed(tensor.shape):
out_file.write(struct.pack("=i", dim))
out_file.write(k_encoded)
tensor.numpy().tofile(out_file)
def main() -> None:
args = parse_args()
print(f"Reading {args.src_path}")
state_dict: Dict[str, torch.Tensor] = torch.load(args.src_path, map_location="cpu")
temp_output: str = args.dest_path
if args.data_type.startswith("Q"):
import re
temp_output = re.sub(r"Q[4,5,8]_[0,1]", "fp16", temp_output)
write_state_dict(state_dict, temp_output, "FP16")
if args.data_type.startswith("Q"):
import sys
import os
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
from rwkv_pip.cpp import rwkv_cpp_shared_library
library = rwkv_cpp_shared_library.load_rwkv_shared_library()
library.rwkv_quantize_model_file(temp_output, args.dest_path, args.data_type)
print("Done")
if __name__ == "__main__":
try:
main()
except Exception as e:
print(e)
with open("error.txt", "w") as f:
f.write(str(e))

View File

@@ -25,7 +25,7 @@ def rename_key(rename, name):
return name
def convert_file(pt_filename: str, sf_filename: str, transpose_names=[], rename={}):
def convert_file(pt_filename: str, sf_filename: str, rename={}, transpose_names=[]):
loaded = torch.load(pt_filename, map_location="cpu")
if "state_dict" in loaded:
loaded = loaded["state_dict"]
@@ -34,12 +34,14 @@ def convert_file(pt_filename: str, sf_filename: str, transpose_names=[], rename=
# for k, v in loaded.items():
# print(f'{k}\t{v.shape}\t{v.dtype}')
loaded = {rename_key(rename, k).lower(): v.contiguous() for k, v in loaded.items()}
# For tensors to be contiguous
for k, v in loaded.items():
for transpose_name in transpose_names:
if transpose_name in k:
loaded[k] = v.transpose(0, 1)
loaded = {rename_key(rename, k).lower(): v.contiguous() for k, v in loaded.items()}
loaded = {k: v.clone().half().contiguous() for k, v in loaded.items()}
for k, v in loaded.items():
print(f"{k}\t{v.shape}\t{v.dtype}")
@@ -60,10 +62,21 @@ if __name__ == "__main__":
convert_file(
args.input,
args.output,
["lora_A"],
{"time_faaaa": "time_first", "lora_A": "lora.0", "lora_B": "lora.1"},
rename={
"time_faaaa": "time_first",
"time_maa": "time_mix",
"lora_A": "lora.0",
"lora_B": "lora.1",
},
transpose_names=[
"time_mix_w1",
"time_mix_w2",
"time_decay_w1",
"time_decay_w2",
],
)
print(f"Saved to {args.output}")
except Exception as e:
print(e)
with open("error.txt", "w") as f:
f.write(str(e))

View File

@@ -4,6 +4,7 @@ Args = "args"
Model = "model"
Model_Status = "model_status"
Model_Config = "model_config"
Deploy_Mode = "deploy_mode"
class ModelStatus(Enum):
@@ -16,6 +17,7 @@ def init():
global GLOBALS
GLOBALS = {}
set(Model_Status, ModelStatus.Offline)
set(Deploy_Mode, False)
def set(key, value):

View File

@@ -2,17 +2,58 @@ import time
start_time = time.time()
import setuptools # avoid warnings
import argparse
from typing import Union, Sequence
def get_args(args: Union[Sequence[str], None] = None):
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="server arguments")
group.add_argument(
"--port",
type=int,
default=8000,
help="port to run the server on (default: 8000)",
)
group.add_argument(
"--host",
type=str,
default="127.0.0.1",
help="host to run the server on (default: 127.0.0.1)",
)
group = parser.add_argument_group(title="mode arguments")
group.add_argument(
"--webui",
action="store_true",
help="whether to enable WebUI (default: False)",
)
group.add_argument(
"--rwkv-beta",
action="store_true",
help="whether to use rwkv-beta (default: False)",
)
group.add_argument(
"--rwkv.cpp",
action="store_true",
help="whether to use rwkv.cpp (default: False)",
)
args = parser.parse_args(args)
return args
if __name__ == "__main__":
args = get_args()
import os
import sys
import argparse
from typing import Sequence
from contextlib import asynccontextmanager
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
import psutil
from fastapi import Depends, FastAPI
from contextlib import asynccontextmanager
from fastapi import Depends, FastAPI, status
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
@@ -48,6 +89,35 @@ app.include_router(misc.router)
app.include_router(state_cache.router)
@app.post("/exit", tags=["Root"])
def exit():
if global_var.get(global_var.Deploy_Mode) is True:
raise HTTPException(status.HTTP_403_FORBIDDEN)
parent_pid = os.getpid()
parent = psutil.Process(parent_pid)
for child in parent.children(recursive=True):
child.kill()
parent.kill()
try:
if (
"RWKV_RUNNER_PARAMS" in os.environ
and "--webui" in os.environ["RWKV_RUNNER_PARAMS"].split(" ")
) or args.webui:
from webui_server import webui_server
app.mount("/", webui_server)
except NameError:
pass
@app.get("/", tags=["Root"])
def read_root():
return {"Hello": "World!"}
def init():
global_var.init()
cmd_params = os.environ["RWKV_RUNNER_PARAMS"]
@@ -63,48 +133,7 @@ def init():
ngrok_connect()
@app.get("/", tags=["Root"])
def read_root():
return {"Hello": "World!"}
@app.post("/exit", tags=["Root"])
def exit():
parent_pid = os.getpid()
parent = psutil.Process(parent_pid)
for child in parent.children(recursive=True):
child.kill()
parent.kill()
def get_args(args: Union[Sequence[str], None] = None):
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="server arguments")
group.add_argument(
"--port",
type=int,
default=8000,
help="port to run the server on (default: 8000)",
)
group.add_argument(
"--host",
type=str,
default="127.0.0.1",
help="host to run the server on (default: 127.0.0.1)",
)
group = parser.add_argument_group(title="mode arguments")
group.add_argument(
"--rwkv-beta",
action="store_true",
help="whether to use rwkv-beta (default: False)",
)
args = parser.parse_args(args)
return args
if __name__ == "__main__":
args = get_args()
os.environ["RWKV_RUNNER_PARAMS"] = " ".join(sys.argv[1:])
print("--- %s seconds ---" % (time.time() - start_time))
uvicorn.run("main:app", port=args.port, host=args.host, workers=1)

Binary file not shown.

View File

@@ -35,6 +35,11 @@ default_stop = [
"\n\nQ",
"\n\nHuman",
"\n\nBob",
"\n\nAssistant",
"\n\nAnswer",
"\n\nA",
"\n\nBot",
"\n\nAlice",
]
@@ -53,8 +58,8 @@ class ChatCompletionBody(ModelConfigBody):
True, description="Whether to insert default system prompt at the beginning"
)
class Config:
json_schema_extra = {
model_config = {
"json_schema_extra": {
"example": {
"messages": [
{"role": Role.User.value, "content": "hello", "raw": False}
@@ -72,6 +77,7 @@ class ChatCompletionBody(ModelConfigBody):
"frequency_penalty": 0.4,
}
}
}
class CompletionBody(ModelConfigBody):
@@ -80,8 +86,8 @@ class CompletionBody(ModelConfigBody):
stream: bool = False
stop: Union[str, List[str], None] = None
class Config:
json_schema_extra = {
model_config = {
"json_schema_extra": {
"example": {
"prompt": "The following is an epic science fiction masterpiece that is immortalized, "
+ "with delicate descriptions and grand depictions of interstellar civilization wars.\nChapter 1.\n",
@@ -95,6 +101,7 @@ class CompletionBody(ModelConfigBody):
"frequency_penalty": 0.4,
}
}
}
completion_lock = Lock()
@@ -376,8 +383,8 @@ class EmbeddingsBody(BaseModel):
encoding_format: str = None
fast_mode: bool = False
class Config:
json_schema_extra = {
model_config = {
"json_schema_extra": {
"example": {
"input": "a big apple",
"model": "rwkv",
@@ -385,6 +392,7 @@ class EmbeddingsBody(BaseModel):
"fast_mode": False,
}
}
}
def embedding_base64(embedding: List[float]) -> str:

View File

@@ -15,20 +15,29 @@ class SwitchModelBody(BaseModel):
strategy: str
tokenizer: Union[str, None] = None
customCuda: bool = False
deploy: bool = Field(
False,
description="Deploy mode. If success, will disable /switch-model, /exit and other dangerous APIs (state cache APIs, part of midi APIs)",
)
class Config:
json_schema_extra = {
model_config = {
"json_schema_extra": {
"example": {
"model": "models/RWKV-4-World-3B-v1-20230619-ctx4096.pth",
"strategy": "cuda fp16",
"tokenizer": None,
"tokenizer": "",
"customCuda": False,
"deploy": False,
}
}
}
@router.post("/switch-model", tags=["Configs"])
def switch_model(body: SwitchModelBody, response: Response, request: Request):
if global_var.get(global_var.Deploy_Mode) is True:
raise HTTPException(Status.HTTP_403_FORBIDDEN)
if global_var.get(global_var.Model_Status) is global_var.ModelStatus.Loading:
response.status_code = Status.HTTP_304_NOT_MODIFIED
return
@@ -40,13 +49,20 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
if body.model == "":
return "success"
if "->" in body.strategy:
state_cache.disable_state_cache()
else:
try:
state_cache.enable_state_cache()
except HTTPException:
pass
devices = set(
[
x.strip().split(" ")[0].replace("cuda:0", "cuda")
for x in body.strategy.split("->")
]
)
print(f"Strategy Devices: {devices}")
# if len(devices) > 1:
# state_cache.disable_state_cache()
# else:
try:
state_cache.enable_state_cache()
except HTTPException:
pass
os.environ["RWKV_CUDA_ON"] = "1" if body.customCuda else "0"
@@ -58,12 +74,18 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
)
except Exception as e:
print(e)
import traceback
print(traceback.format_exc())
quick_log(request, body, f"Exception: {e}")
global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
raise HTTPException(
Status.HTTP_500_INTERNAL_SERVER_ERROR, f"failed to load: {e}"
)
if body.deploy:
global_var.set(global_var.Deploy_Mode, True)
if global_var.get(global_var.Model_Config) is None:
global_var.set(
global_var.Model_Config, get_rwkv_config(global_var.get(global_var.Model))

View File

@@ -1,5 +1,6 @@
import io
from fastapi import APIRouter, HTTPException, status
import global_var
from fastapi import APIRouter, HTTPException, UploadFile, status
from starlette.responses import StreamingResponse
from pydantic import BaseModel
from utils.midi import *
@@ -11,12 +12,13 @@ router = APIRouter()
class TextToMidiBody(BaseModel):
text: str
class Config:
json_schema_extra = {
model_config = {
"json_schema_extra": {
"example": {
"text": "p:24:a p:2a:a p:31:a p:39:a p:3b:a p:45:a b:26:a g:3e:a g:3e:a g:42:a g:42:a g:45:a g:45:a pi:3e:a pi:42:a pi:45:a t14 p:24:0 p:2a:0 p:31:0 p:39:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 p:45:0 b:26:0 g:3e:0 g:3e:0 g:42:0 g:42:0 g:45:0 g:45:0 pi:3e:0 pi:42:0 pi:45:0 t2 p:2e:a p:3b:a p:45:a b:26:a g:3e:a g:3e:a g:42:a g:42:a g:45:a g:45:a pi:3e:a pi:42:a pi:45:a t14 p:2e:0 p:3b:0 p:45:0 g:3e:0 g:3e:0 g:42:0 g:42:0 g:45:0 g:45:0 pi:3e:0 pi:42:0 pi:45:0 t2 p:2e:a p:3b:a p:45:a g:3e:a g:3e:a g:42:a g:42:a g:45:a g:45:a pi:3e:a pi:42:a pi:45:a t14 p:2e:0 p:3b:0 p:45:0 b:26:0 g:3e:0 g:3e:0 g:42:0 g:42:0 g:45:0 g:45:0 pi:3e:0 pi:42:0 pi:45:0 t2 p:26:a p:2a:a p:3b:a p:45:a t14 p:26:0 p:2a:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a b:26:a g:3e:a g:3e:a g:42:a g:42:a g:45:a g:45:a pi:3e:a pi:42:a pi:45:a t14 p:2a:0 p:3b:0 p:45:0 b:26:0 t2 p:24:a p:2a:a p:3b:a p:45:a b:2d:a t14 p:24:0 p:2a:0 p:3b:0 p:45:0 b:2d:0 g:3e:0 g:3e:0 g:42:0 g:42:0 g:45:0 g:45:0 pi:3e:0 pi:42:0 pi:45:0 t2 p:24:a p:2a:a p:3b:a p:45:a b:21:a g:39:a g:39:a g:3d:a g:3d:a g:40:a g:40:a pi:39:a pi:3d:a pi:40:a t14 p:24:0 p:2a:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 p:45:0 b:21:0 g:39:0 g:39:0 g:3d:0 g:3d:0 g:40:0 g:40:0 pi:39:0 pi:3d:0 pi:40:0 t2 p:24:a p:2e:a p:3b:a p:45:a b:21:a g:39:a g:39:a g:3d:a g:3d:a g:40:a g:40:a pi:39:a pi:3d:a pi:40:a t14 p:24:0 p:2e:0 p:3b:0 p:45:0 b:21:0 g:39:0 g:39:0 g:3d:0 g:3d:0 g:40:0 g:40:0 pi:39:0 pi:3d:0 pi:40:0 t2 p:24:a p:2a:a p:3b:a p:45:a b:21:a g:39:a g:39:a g:3d:a g:3d:a g:40:a g:40:a pi:39:a pi:3d:a pi:40:a t14 p:24:0 p:2a:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 p:45:0 b:21:0 g:39:0 g:39:0 g:3d:0 g:3d:0 g:40:0 g:40:0 pi:39:0 pi:3d:0 pi:40:0 t2 p:26:a p:2a:a p:3b:a p:45:a b:21:a g:39:a g:39:a g:3d:a g:3d:a g:40:a g:40:a pi:39:a pi:3d:a pi:40:a t14 p:26:0 p:2a:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 p:45:0 b:21:0 g:39:0 g:39:0 g:3d:0 g:3d:0 g:40:0 g:40:0 pi:39:0 pi:3d:0 pi:40:0 t2 p:26:a p:2e:a p:31:a p:39:a p:3b:a p:45:a b:21:a g:39:a g:39:a g:3d:a g:3d:a g:40:a g:40:a pi:39:a pi:3d:a pi:40:a t14 p:26:0 p:2e:0 p:31:0 p:39:0 p:3b:0 p:45:0 b:21:0 t2 p:26:a p:2e:a p:31:a p:39:a p:3b:a p:45:a b:21:a t14 p:26:0 p:2e:0 p:31:0 p:39:0 p:3b:0 p:45:0 b:21:0 g:39:0 g:39:0 g:3d:0 g:3d:0 g:40:0 g:40:0 pi:39:0 pi:3d:0 pi:40:0 t2 p:24:a p:2a:a p:31:a p:39:a p:3b:a p:45:a b:1f:a g:3b:a g:3b:a g:3e:a g:3e:a g:43:a g:43:a pi:3b:a pi:3e:a pi:43:a t14 p:24:0 p:2a:0 p:31:0 p:39:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 p:45:0 b:1f:0 g:3b:0 g:3b:0 g:3e:0 g:3e:0 g:43:0 g:43:0 pi:3b:0 pi:3e:0 pi:43:0 t2 p:2e:a p:3b:a p:45:a b:1f:a g:3b:a g:3b:a g:3e:a g:3e:a g:43:a g:43:a pi:3b:a pi:3e:a pi:43:a t14 p:2e:0 p:3b:0 p:45:0 g:3b:0 g:3b:0 g:3e:0 g:3e:0 g:43:0 g:43:0 pi:3b:0 pi:3e:0 pi:43:0 t2 p:2e:a p:3b:a p:45:a g:3b:a g:3b:a g:3e:a g:3e:a g:43:a g:43:a pi:3b:a pi:3e:a pi:43:a t14 p:2e:0 p:3b:0 p:45:0 b:1f:0 g:3b:0 g:3b:0 g:3e:0 g:3e:0 g:43:0 g:43:0 pi:3b:0 pi:3e:0 pi:43:0 t2 p:26:a p:2a:a p:3b:a p:45:a t14 p:26:0 p:2a:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a b:1f:a g:3b:a g:3b:a g:3e:a g:3e:a g:43:a g:43:a pi:3b:a pi:3e:a pi:43:a t14 p:2a:0 p:3b:0 p:45:0 b:1f:0 t2 p:24:a p:2a:a p:3b:a p:45:a b:1f:a t14 p:24:0 p:2a:0 p:3b:0 p:45:0 b:1f:0 g:3b:0 g:3b:0 g:3e:0 g:3e:0 g:43:0 g:43:0 pi:3b:0 pi:3e:0 pi:43:0 t2 p:24:a p:2e:a p:3b:a p:45:a b:26:a g:39:a g:39:a g:3e:a g:3e:a g:42:a g:42:a pi:39:a pi:3e:a pi:42:a t14 p:24:0 p:2e:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0",
}
}
}
@router.post("/text-to-midi", tags=["MIDI"])
@@ -31,21 +33,35 @@ def text_to_midi(body: TextToMidiBody):
return StreamingResponse(mid_data, media_type="audio/midi")
@router.post("/midi-to-text", tags=["MIDI"])
async def midi_to_text(file_data: UploadFile):
vocab_config = "backend-python/utils/midi_vocab_config.json"
cfg = VocabConfig.from_json(vocab_config)
mid = mido.MidiFile(file=file_data.file)
text = convert_midi_to_str(cfg, mid)
return {"text": text}
class TxtToMidiBody(BaseModel):
txt_path: str
midi_path: str
class Config:
json_schema_extra = {
model_config = {
"json_schema_extra": {
"example": {
"txt_path": "midi/sample.txt",
"midi_path": "midi/sample.mid",
}
}
}
@router.post("/txt-to-midi", tags=["MIDI"])
def txt_to_midi(body: TxtToMidiBody):
if global_var.get(global_var.Deploy_Mode) is True:
raise HTTPException(status.HTTP_403_FORBIDDEN)
if not body.midi_path.startswith("midi/"):
raise HTTPException(status.HTTP_400_BAD_REQUEST, "bad output path")
@@ -65,14 +81,15 @@ class MidiToWavBody(BaseModel):
wav_path: str
sound_font_path: str = "assets/default_sound_font.sf2"
class Config:
json_schema_extra = {
model_config = {
"json_schema_extra": {
"example": {
"midi_path": "midi/sample.mid",
"wav_path": "midi/sample.wav",
"sound_font_path": "assets/default_sound_font.sf2",
}
}
}
@router.post("/midi-to-wav", tags=["MIDI"])
@@ -81,6 +98,9 @@ def midi_to_wav(body: MidiToWavBody):
Install fluidsynth first, see more: https://github.com/FluidSynth/fluidsynth/wiki/Download#distributions
"""
if global_var.get(global_var.Deploy_Mode) is True:
raise HTTPException(status.HTTP_403_FORBIDDEN)
if not body.wav_path.startswith("midi/"):
raise HTTPException(status.HTTP_400_BAD_REQUEST, "bad output path")
@@ -95,14 +115,15 @@ class TextToWavBody(BaseModel):
wav_name: str
sound_font_path: str = "assets/default_sound_font.sf2"
class Config:
json_schema_extra = {
model_config = {
"json_schema_extra": {
"example": {
"text": "p:24:a p:2a:a p:31:a p:39:a p:3b:a p:45:a b:26:a g:3e:a g:3e:a g:42:a g:42:a g:45:a g:45:a pi:3e:a pi:42:a pi:45:a t14 p:24:0 p:2a:0 p:31:0 p:39:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 p:45:0 b:26:0 g:3e:0 g:3e:0 g:42:0 g:42:0 g:45:0 g:45:0 pi:3e:0 pi:42:0 pi:45:0 t2 p:2e:a p:3b:a p:45:a b:26:a g:3e:a g:3e:a g:42:a g:42:a g:45:a g:45:a pi:3e:a pi:42:a pi:45:a t14 p:2e:0 p:3b:0 p:45:0 g:3e:0 g:3e:0 g:42:0 g:42:0 g:45:0 g:45:0 pi:3e:0 pi:42:0 pi:45:0 t2 p:2e:a p:3b:a p:45:a g:3e:a g:3e:a g:42:a g:42:a g:45:a g:45:a pi:3e:a pi:42:a pi:45:a t14 p:2e:0 p:3b:0 p:45:0 b:26:0 g:3e:0 g:3e:0 g:42:0 g:42:0 g:45:0 g:45:0 pi:3e:0 pi:42:0 pi:45:0 t2 p:26:a p:2a:a p:3b:a p:45:a t14 p:26:0 p:2a:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a b:26:a g:3e:a g:3e:a g:42:a g:42:a g:45:a g:45:a pi:3e:a pi:42:a pi:45:a t14 p:2a:0 p:3b:0 p:45:0 b:26:0 t2 p:24:a p:2a:a p:3b:a p:45:a b:2d:a t14 p:24:0 p:2a:0 p:3b:0 p:45:0 b:2d:0 g:3e:0 g:3e:0 g:42:0 g:42:0 g:45:0 g:45:0 pi:3e:0 pi:42:0 pi:45:0 t2 p:24:a p:2a:a p:3b:a p:45:a b:21:a g:39:a g:39:a g:3d:a g:3d:a g:40:a g:40:a pi:39:a pi:3d:a pi:40:a t14 p:24:0 p:2a:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 p:45:0 b:21:0 g:39:0 g:39:0 g:3d:0 g:3d:0 g:40:0 g:40:0 pi:39:0 pi:3d:0 pi:40:0 t2 p:24:a p:2e:a p:3b:a p:45:a b:21:a g:39:a g:39:a g:3d:a g:3d:a g:40:a g:40:a pi:39:a pi:3d:a pi:40:a t14 p:24:0 p:2e:0 p:3b:0 p:45:0 b:21:0 g:39:0 g:39:0 g:3d:0 g:3d:0 g:40:0 g:40:0 pi:39:0 pi:3d:0 pi:40:0 t2 p:24:a p:2a:a p:3b:a p:45:a b:21:a g:39:a g:39:a g:3d:a g:3d:a g:40:a g:40:a pi:39:a pi:3d:a pi:40:a t14 p:24:0 p:2a:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 p:45:0 b:21:0 g:39:0 g:39:0 g:3d:0 g:3d:0 g:40:0 g:40:0 pi:39:0 pi:3d:0 pi:40:0 t2 p:26:a p:2a:a p:3b:a p:45:a b:21:a g:39:a g:39:a g:3d:a g:3d:a g:40:a g:40:a pi:39:a pi:3d:a pi:40:a t14 p:26:0 p:2a:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 p:45:0 b:21:0 g:39:0 g:39:0 g:3d:0 g:3d:0 g:40:0 g:40:0 pi:39:0 pi:3d:0 pi:40:0 t2 p:26:a p:2e:a p:31:a p:39:a p:3b:a p:45:a b:21:a g:39:a g:39:a g:3d:a g:3d:a g:40:a g:40:a pi:39:a pi:3d:a pi:40:a t14 p:26:0 p:2e:0 p:31:0 p:39:0 p:3b:0 p:45:0 b:21:0 t2 p:26:a p:2e:a p:31:a p:39:a p:3b:a p:45:a b:21:a t14 p:26:0 p:2e:0 p:31:0 p:39:0 p:3b:0 p:45:0 b:21:0 g:39:0 g:39:0 g:3d:0 g:3d:0 g:40:0 g:40:0 pi:39:0 pi:3d:0 pi:40:0 t2 p:24:a p:2a:a p:31:a p:39:a p:3b:a p:45:a b:1f:a g:3b:a g:3b:a g:3e:a g:3e:a g:43:a g:43:a pi:3b:a pi:3e:a pi:43:a t14 p:24:0 p:2a:0 p:31:0 p:39:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 p:45:0 b:1f:0 g:3b:0 g:3b:0 g:3e:0 g:3e:0 g:43:0 g:43:0 pi:3b:0 pi:3e:0 pi:43:0 t2 p:2e:a p:3b:a p:45:a b:1f:a g:3b:a g:3b:a g:3e:a g:3e:a g:43:a g:43:a pi:3b:a pi:3e:a pi:43:a t14 p:2e:0 p:3b:0 p:45:0 g:3b:0 g:3b:0 g:3e:0 g:3e:0 g:43:0 g:43:0 pi:3b:0 pi:3e:0 pi:43:0 t2 p:2e:a p:3b:a p:45:a g:3b:a g:3b:a g:3e:a g:3e:a g:43:a g:43:a pi:3b:a pi:3e:a pi:43:a t14 p:2e:0 p:3b:0 p:45:0 b:1f:0 g:3b:0 g:3b:0 g:3e:0 g:3e:0 g:43:0 g:43:0 pi:3b:0 pi:3e:0 pi:43:0 t2 p:26:a p:2a:a p:3b:a p:45:a t14 p:26:0 p:2a:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a b:1f:a g:3b:a g:3b:a g:3e:a g:3e:a g:43:a g:43:a pi:3b:a pi:3e:a pi:43:a t14 p:2a:0 p:3b:0 p:45:0 b:1f:0 t2 p:24:a p:2a:a p:3b:a p:45:a b:1f:a t14 p:24:0 p:2a:0 p:3b:0 p:45:0 b:1f:0 g:3b:0 g:3b:0 g:3e:0 g:3e:0 g:43:0 g:43:0 pi:3b:0 pi:3e:0 pi:43:0 t2 p:24:a p:2e:a p:3b:a p:45:a b:26:a g:39:a g:39:a g:3e:a g:3e:a g:42:a g:42:a pi:39:a pi:3e:a pi:42:a t14 p:24:0 p:2e:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0",
"wav_name": "sample",
"sound_font_path": "assets/default_sound_font.sf2",
}
}
}
@router.post("/text-to-wav", tags=["MIDI"])
@@ -111,6 +132,9 @@ def text_to_wav(body: TextToWavBody):
Install fluidsynth first, see more: https://github.com/FluidSynth/fluidsynth/wiki/Download#distributions
"""
if global_var.get(global_var.Deploy_Mode) is True:
raise HTTPException(status.HTTP_403_FORBIDDEN)
text = body.text.strip()
if not text.startswith("<start>"):
text = "<start> " + text

View File

@@ -4,12 +4,13 @@ from fastapi import APIRouter, HTTPException, Request, Response, status
from pydantic import BaseModel
import gc
import copy
import global_var
router = APIRouter()
trie = None
dtrie: Dict = {}
max_trie_len = 3000
max_trie_len = 300
loop_start_id = 1 # to prevent preloaded prompts from being deleted
loop_del_trie_id = loop_start_id
@@ -36,16 +37,24 @@ def init():
def disable_state_cache():
global trie, dtrie
if global_var.get(global_var.Deploy_Mode) is True:
raise HTTPException(status.HTTP_403_FORBIDDEN)
trie = None
dtrie = {}
gc.collect()
print("state cache disabled")
return "success"
@router.post("/enable-state-cache", tags=["State Cache"])
def enable_state_cache():
global trie, dtrie
if global_var.get(global_var.Deploy_Mode) is True:
raise HTTPException(status.HTTP_403_FORBIDDEN)
try:
import cyac
@@ -53,8 +62,10 @@ def enable_state_cache():
dtrie = {}
gc.collect()
print("state cache enabled")
return "success"
except ModuleNotFoundError:
print("state cache disabled")
raise HTTPException(status.HTTP_400_BAD_REQUEST, "cyac not found")
@@ -65,9 +76,13 @@ class AddStateBody(BaseModel):
logits: Any
@router.post("/add-state", tags=["State Cache"])
# @router.post("/add-state", tags=["State Cache"])
def add_state(body: AddStateBody):
global trie, dtrie, loop_del_trie_id
# if global_var.get(global_var.Deploy_Mode) is True:
# raise HTTPException(status.HTTP_403_FORBIDDEN)
if trie is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
@@ -75,14 +90,17 @@ def add_state(body: AddStateBody):
try:
id: int = trie.insert(body.prompt)
device: torch.device = body.state[0].device
devices: List[torch.device] = [
(tensor.device if hasattr(tensor, "device") else torch.device("cpu"))
for tensor in body.state
]
dtrie[id] = {
"tokens": copy.deepcopy(body.tokens),
"state": [tensor.cpu() for tensor in body.state]
if device != torch.device("cpu")
if hasattr(body.state[0], "device")
else copy.deepcopy(body.state),
"logits": copy.deepcopy(body.logits),
"device": device,
"devices": devices,
}
if len(trie) >= max_trie_len:
@@ -108,6 +126,10 @@ def add_state(body: AddStateBody):
@router.post("/reset-state", tags=["State Cache"])
def reset_state():
global trie, dtrie
if global_var.get(global_var.Deploy_Mode) is True:
raise HTTPException(status.HTTP_403_FORBIDDEN)
if trie is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
@@ -141,9 +163,13 @@ def __get_a_dtrie_buff_size(dtrie_v):
return 54 * len(dtrie_v["tokens"]) + 491520 + 262144 + 28 # TODO
@router.post("/longest-prefix-state", tags=["State Cache"])
# @router.post("/longest-prefix-state", tags=["State Cache"])
def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
global trie
# if global_var.get(global_var.Deploy_Mode) is True:
# raise HTTPException(status.HTTP_403_FORBIDDEN)
if trie is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
@@ -157,32 +183,29 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
pass
if id != -1:
v = dtrie[id]
device: torch.device = v["device"]
devices: List[torch.device] = v["devices"]
prompt: str = trie[id]
quick_log(request, body, "Hit:\n" + prompt)
return {
"prompt": prompt,
"tokens": v["tokens"],
"state": [tensor.to(device) for tensor in v["state"]]
if device != torch.device("cpu")
"state": [tensor.to(devices[i]) for i, tensor in enumerate(v["state"])]
if hasattr(v["state"][0], "device")
else v["state"],
"logits": v["logits"],
"device": device.type,
}
else:
return {
"prompt": "",
"tokens": [],
"state": None,
"logits": None,
"device": None,
}
return {"prompt": "", "tokens": [], "state": None, "logits": None}
@router.post("/save-state", tags=["State Cache"])
# @router.post("/save-state", tags=["State Cache"])
def save_state():
global trie
# if global_var.get(global_var.Deploy_Mode) is True:
# raise HTTPException(status.HTTP_403_FORBIDDEN)
if trie is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")

Binary file not shown.

BIN
backend-python/rwkv_pip/cpp/librwkv.so vendored Normal file

Binary file not shown.

14
backend-python/rwkv_pip/cpp/model.py vendored Normal file
View File

@@ -0,0 +1,14 @@
from typing import Any, List, Union
from . import rwkv_cpp_model
from . import rwkv_cpp_shared_library
class RWKV:
def __init__(self, model_path: str, strategy=None):
self.library = rwkv_cpp_shared_library.load_rwkv_shared_library()
self.model = rwkv_cpp_model.RWKVModel(self.library, model_path)
self.w = {} # fake weight
self.w["emb.weight"] = [0] * self.model.n_vocab
def forward(self, tokens: List[int], state: Union[Any, None] = None):
return self.model.eval_sequence_in_chunks(tokens, state, use_numpy=True)

BIN
backend-python/rwkv_pip/cpp/rwkv.dll vendored Normal file

Binary file not shown.

View File

@@ -0,0 +1,369 @@
import os
import multiprocessing
# Pre-import PyTorch, if available.
# This fixes "OSError: [WinError 127] The specified procedure could not be found".
try:
import torch
except ModuleNotFoundError:
pass
# I'm sure this is not strictly correct, but let's keep this crutch for now.
try:
import rwkv_cpp_shared_library
except ModuleNotFoundError:
from . import rwkv_cpp_shared_library
from typing import TypeVar, Optional, Tuple, List
# A value of this type is either a numpy's ndarray or a PyTorch's Tensor.
NumpyArrayOrPyTorchTensor: TypeVar = TypeVar('NumpyArrayOrPyTorchTensor')
class RWKVModel:
"""
An RWKV model managed by rwkv.cpp library.
"""
def __init__(
self,
shared_library: rwkv_cpp_shared_library.RWKVSharedLibrary,
model_path: str,
thread_count: int = max(1, multiprocessing.cpu_count() // 2),
gpu_layer_count: int = 0,
**kwargs
) -> None:
"""
Loads the model and prepares it for inference.
In case of any error, this method will throw an exception.
Parameters
----------
shared_library : RWKVSharedLibrary
rwkv.cpp shared library.
model_path : str
Path to RWKV model file in ggml format.
thread_count : int
Thread count to use. If not set, defaults to CPU count / 2.
gpu_layer_count : int
Count of layers to offload onto the GPU, must be >= 0.
See documentation of `gpu_offload_layers` for details about layer offloading.
"""
if 'gpu_layers_count' in kwargs:
gpu_layer_count = kwargs['gpu_layers_count']
assert os.path.isfile(model_path), f'{model_path} is not a file'
assert thread_count > 0, 'Thread count must be > 0'
assert gpu_layer_count >= 0, 'GPU layer count must be >= 0'
self._library: rwkv_cpp_shared_library.RWKVSharedLibrary = shared_library
self._ctx: rwkv_cpp_shared_library.RWKVContext = self._library.rwkv_init_from_file(model_path, thread_count)
if gpu_layer_count > 0:
self.gpu_offload_layers(gpu_layer_count)
self._state_buffer_element_count: int = self._library.rwkv_get_state_buffer_element_count(self._ctx)
self._logits_buffer_element_count: int = self._library.rwkv_get_logits_buffer_element_count(self._ctx)
self._valid: bool = True
def gpu_offload_layers(self, layer_count: int) -> bool:
"""
Offloads specified count of model layers onto the GPU. Offloaded layers are evaluated using cuBLAS or CLBlast.
For the purposes of this function, model head (unembedding matrix) is treated as an additional layer:
- pass `model.n_layer` to offload all layers except model head
- pass `model.n_layer + 1` to offload all layers, including model head
Returns true if at least one layer was offloaded.
If rwkv.cpp was compiled without cuBLAS and CLBlast support, this function is a no-op and always returns false.
Parameters
----------
layer_count : int
Count of layers to offload onto the GPU, must be >= 0.
"""
assert layer_count >= 0, 'Layer count must be >= 0'
return self._library.rwkv_gpu_offload_layers(self._ctx, layer_count)
@property
def n_vocab(self) -> int:
return self._library.rwkv_get_n_vocab(self._ctx)
@property
def n_embed(self) -> int:
return self._library.rwkv_get_n_embed(self._ctx)
@property
def n_layer(self) -> int:
return self._library.rwkv_get_n_layer(self._ctx)
def eval(
self,
token: int,
state_in: Optional[NumpyArrayOrPyTorchTensor],
state_out: Optional[NumpyArrayOrPyTorchTensor] = None,
logits_out: Optional[NumpyArrayOrPyTorchTensor] = None,
use_numpy: bool = False
) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]:
"""
Evaluates the model for a single token.
In case of any error, this method will throw an exception.
Parameters
----------
token : int
Index of next token to be seen by the model. Must be in range 0 <= token < n_vocab.
state_in : Optional[NumpyArrayOrTorchTensor]
State from previous call of this method. If this is a first pass, set it to None.
state_out : Optional[NumpyArrayOrTorchTensor]
Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count).
logits_out : Optional[NumpyArrayOrTorchTensor]
Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count).
use_numpy : bool
If set to True, numpy's ndarrays will be created instead of PyTorch's Tensors.
This parameter is ignored if any tensor parameter is not None; in such case,
type of returned tensors will match the type of received tensors.
Returns
-------
logits, state
Logits vector of shape (n_vocab); state for the next step.
"""
assert self._valid, 'Model was freed'
use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
if state_in is not None:
self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count)
state_in_ptr = self._get_data_ptr(state_in)
else:
state_in_ptr = 0
if state_out is not None:
self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count)
else:
state_out = self._zeros_float32(self._state_buffer_element_count, use_numpy)
if logits_out is not None:
self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count)
else:
logits_out = self._zeros_float32(self._logits_buffer_element_count, use_numpy)
self._library.rwkv_eval(
self._ctx,
token,
state_in_ptr,
self._get_data_ptr(state_out),
self._get_data_ptr(logits_out)
)
return logits_out, state_out
def eval_sequence(
self,
tokens: List[int],
state_in: Optional[NumpyArrayOrPyTorchTensor],
state_out: Optional[NumpyArrayOrPyTorchTensor] = None,
logits_out: Optional[NumpyArrayOrPyTorchTensor] = None,
use_numpy: bool = False
) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]:
"""
Evaluates the model for a sequence of tokens.
NOTE ON GGML NODE LIMIT
ggml has a hard-coded limit on max amount of nodes in a computation graph. The sequence graph is built in a way that quickly exceedes
this limit when using large models and/or large sequence lengths.
Fortunately, rwkv.cpp's fork of ggml has increased limit which was tested to work for sequence lengths up to 64 for 14B models.
If you get `GGML_ASSERT: ...\\ggml.c:16941: cgraph->n_nodes < GGML_MAX_NODES`, this means you've exceeded the limit.
To get rid of the assertion failure, reduce the model size and/or sequence length.
In case of any error, this method will throw an exception.
Parameters
----------
tokens : List[int]
Indices of the next tokens to be seen by the model. Must be in range 0 <= token < n_vocab.
state_in : Optional[NumpyArrayOrTorchTensor]
State from previous call of this method. If this is a first pass, set it to None.
state_out : Optional[NumpyArrayOrTorchTensor]
Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count).
logits_out : Optional[NumpyArrayOrTorchTensor]
Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count).
use_numpy : bool
If set to True, numpy's ndarrays will be created instead of PyTorch's Tensors.
This parameter is ignored if any tensor parameter is not None; in such case,
type of returned tensors will match the type of received tensors.
Returns
-------
logits, state
Logits vector of shape (n_vocab); state for the next step.
"""
assert self._valid, 'Model was freed'
use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
if state_in is not None:
self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count)
state_in_ptr = self._get_data_ptr(state_in)
else:
state_in_ptr = 0
if state_out is not None:
self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count)
else:
state_out = self._zeros_float32(self._state_buffer_element_count, use_numpy)
if logits_out is not None:
self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count)
else:
logits_out = self._zeros_float32(self._logits_buffer_element_count, use_numpy)
self._library.rwkv_eval_sequence(
self._ctx,
tokens,
state_in_ptr,
self._get_data_ptr(state_out),
self._get_data_ptr(logits_out)
)
return logits_out, state_out
def eval_sequence_in_chunks(
self,
tokens: List[int],
state_in: Optional[NumpyArrayOrPyTorchTensor],
state_out: Optional[NumpyArrayOrPyTorchTensor] = None,
logits_out: Optional[NumpyArrayOrPyTorchTensor] = None,
chunk_size: int = 16,
use_numpy: bool = False
) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]:
"""
Evaluates the model for a sequence of tokens using `eval_sequence`, splitting a potentially long sequence into fixed-length chunks.
This function is useful for processing complete prompts and user input in chat & role-playing use-cases.
It is recommended to use this function instead of `eval_sequence` to avoid mistakes and get maximum performance.
Chunking allows processing sequences of thousands of tokens, while not reaching the ggml's node limit and not consuming too much memory.
A reasonable and recommended value of chunk size is 16. If you want maximum performance, try different chunk sizes in range [2..64]
and choose one that works the best in your use case.
In case of any error, this method will throw an exception.
Parameters
----------
tokens : List[int]
Indices of the next tokens to be seen by the model. Must be in range 0 <= token < n_vocab.
chunk_size : int
Size of each chunk in tokens, must be positive.
state_in : Optional[NumpyArrayOrTorchTensor]
State from previous call of this method. If this is a first pass, set it to None.
state_out : Optional[NumpyArrayOrTorchTensor]
Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count).
logits_out : Optional[NumpyArrayOrTorchTensor]
Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count).
use_numpy : bool
If set to True, numpy's ndarrays will be created instead of PyTorch's Tensors.
This parameter is ignored if any tensor parameter is not None; in such case,
type of returned tensors will match the type of received tensors.
Returns
-------
logits, state
Logits vector of shape (n_vocab); state for the next step.
"""
assert self._valid, 'Model was freed'
use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
if state_in is not None:
self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count)
state_in_ptr = self._get_data_ptr(state_in)
else:
state_in_ptr = 0
if state_out is not None:
self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count)
else:
state_out = self._zeros_float32(self._state_buffer_element_count, use_numpy)
if logits_out is not None:
self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count)
else:
logits_out = self._zeros_float32(self._logits_buffer_element_count, use_numpy)
self._library.rwkv_eval_sequence_in_chunks(
self._ctx,
tokens,
chunk_size,
state_in_ptr,
self._get_data_ptr(state_out),
self._get_data_ptr(logits_out)
)
return logits_out, state_out
def free(self) -> None:
"""
Frees all allocated resources.
In case of any error, this method will throw an exception.
The object must not be used anymore after calling this method.
"""
assert self._valid, 'Already freed'
self._valid = False
self._library.rwkv_free(self._ctx)
def __del__(self) -> None:
# Free the context on GC in case user forgot to call free() explicitly.
if hasattr(self, '_valid') and self._valid:
self.free()
def _is_pytorch_tensor(self, tensor: NumpyArrayOrPyTorchTensor) -> bool:
return hasattr(tensor, '__module__') and tensor.__module__ == 'torch'
def _detect_numpy_usage(self, tensors: List[Optional[NumpyArrayOrPyTorchTensor]], use_numpy_by_default: bool) -> bool:
for tensor in tensors:
if tensor is not None:
return False if self._is_pytorch_tensor(tensor) else True
return use_numpy_by_default
def _validate_tensor(self, tensor: NumpyArrayOrPyTorchTensor, name: str, size: int) -> None:
if self._is_pytorch_tensor(tensor):
tensor: torch.Tensor = tensor
assert tensor.device == torch.device('cpu'), f'{name} is not on CPU'
assert tensor.dtype == torch.float32, f'{name} is not of type float32'
assert tensor.shape == (size,), f'{name} has invalid shape {tensor.shape}, expected ({size})'
assert tensor.is_contiguous(), f'{name} is not contiguous'
else:
import numpy as np
tensor: np.ndarray = tensor
assert tensor.dtype == np.float32, f'{name} is not of type float32'
assert tensor.shape == (size,), f'{name} has invalid shape {tensor.shape}, expected ({size})'
assert tensor.data.contiguous, f'{name} is not contiguous'
def _get_data_ptr(self, tensor: NumpyArrayOrPyTorchTensor):
if self._is_pytorch_tensor(tensor):
return tensor.data_ptr()
else:
return tensor.ctypes.data
def _zeros_float32(self, element_count: int, use_numpy: bool) -> NumpyArrayOrPyTorchTensor:
if use_numpy:
import numpy as np
return np.zeros(element_count, dtype=np.float32)
else:
return torch.zeros(element_count, dtype=torch.float32, device='cpu')

View File

@@ -0,0 +1,444 @@
import os
import sys
import ctypes
import pathlib
import platform
from typing import Optional, List, Tuple, Callable
QUANTIZED_FORMAT_NAMES: Tuple[str, str, str, str, str] = (
'Q4_0',
'Q4_1',
'Q5_0',
'Q5_1',
'Q8_0'
)
P_FLOAT = ctypes.POINTER(ctypes.c_float)
P_INT = ctypes.POINTER(ctypes.c_int32)
class RWKVContext:
def __init__(self, ptr: ctypes.pointer) -> None:
self.ptr: ctypes.pointer = ptr
class RWKVSharedLibrary:
"""
Python wrapper around rwkv.cpp shared library.
"""
def __init__(self, shared_library_path: str) -> None:
"""
Loads the shared library from specified file.
In case of any error, this method will throw an exception.
Parameters
----------
shared_library_path : str
Path to rwkv.cpp shared library. On Windows, it would look like 'rwkv.dll'. On UNIX, 'rwkv.so'.
"""
# When Python is greater than 3.8, we need to reprocess the custom dll
# according to the documentation to prevent loading failure errors.
# https://docs.python.org/3/whatsnew/3.8.html#ctypes
if platform.system().lower() == 'windows':
self.library = ctypes.CDLL(shared_library_path, winmode=0)
else:
self.library = ctypes.cdll.LoadLibrary(shared_library_path)
self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32]
self.library.rwkv_init_from_file.restype = ctypes.c_void_p
self.library.rwkv_gpu_offload_layers.argtypes = [ctypes.c_void_p, ctypes.c_uint32]
self.library.rwkv_gpu_offload_layers.restype = ctypes.c_bool
self.library.rwkv_eval.argtypes = [
ctypes.c_void_p, # ctx
ctypes.c_int32, # token
P_FLOAT, # state_in
P_FLOAT, # state_out
P_FLOAT # logits_out
]
self.library.rwkv_eval.restype = ctypes.c_bool
self.library.rwkv_eval_sequence.argtypes = [
ctypes.c_void_p, # ctx
P_INT, # tokens
ctypes.c_size_t, # token count
P_FLOAT, # state_in
P_FLOAT, # state_out
P_FLOAT # logits_out
]
self.library.rwkv_eval_sequence.restype = ctypes.c_bool
self.library.rwkv_eval_sequence_in_chunks.argtypes = [
ctypes.c_void_p, # ctx
P_INT, # tokens
ctypes.c_size_t, # token count
ctypes.c_size_t, # chunk size
P_FLOAT, # state_in
P_FLOAT, # state_out
P_FLOAT # logits_out
]
self.library.rwkv_eval_sequence_in_chunks.restype = ctypes.c_bool
self.library.rwkv_get_n_vocab.argtypes = [ctypes.c_void_p]
self.library.rwkv_get_n_vocab.restype = ctypes.c_size_t
self.library.rwkv_get_n_embed.argtypes = [ctypes.c_void_p]
self.library.rwkv_get_n_embed.restype = ctypes.c_size_t
self.library.rwkv_get_n_layer.argtypes = [ctypes.c_void_p]
self.library.rwkv_get_n_layer.restype = ctypes.c_size_t
self.library.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p]
self.library.rwkv_get_state_buffer_element_count.restype = ctypes.c_uint32
self.library.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p]
self.library.rwkv_get_logits_buffer_element_count.restype = ctypes.c_uint32
self.library.rwkv_free.argtypes = [ctypes.c_void_p]
self.library.rwkv_free.restype = None
self.library.rwkv_free.argtypes = [ctypes.c_void_p]
self.library.rwkv_free.restype = None
self.library.rwkv_quantize_model_file.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p]
self.library.rwkv_quantize_model_file.restype = ctypes.c_bool
self.library.rwkv_get_system_info_string.argtypes = []
self.library.rwkv_get_system_info_string.restype = ctypes.c_char_p
self.nullptr = ctypes.cast(0, ctypes.c_void_p)
def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVContext:
"""
Loads the model from a file and prepares it for inference.
Throws an exception in case of any error. Error messages would be printed to stderr.
Parameters
----------
model_file_path : str
Path to model file in ggml format.
thread_count : int
Count of threads to use, must be positive.
"""
ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count))
assert ptr is not None, 'rwkv_init_from_file failed, check stderr'
return RWKVContext(ptr)
def rwkv_gpu_offload_layers(self, ctx: RWKVContext, layer_count: int) -> bool:
"""
Offloads specified count of model layers onto the GPU. Offloaded layers are evaluated using cuBLAS or CLBlast.
For the purposes of this function, model head (unembedding matrix) is treated as an additional layer:
- pass `rwkv_get_n_layer(ctx)` to offload all layers except model head
- pass `rwkv_get_n_layer(ctx) + 1` to offload all layers, including model head
Returns true if at least one layer was offloaded.
If rwkv.cpp was compiled without cuBLAS and CLBlast support, this function is a no-op and always returns false.
Parameters
----------
ctx : RWKVContext
RWKV context obtained from rwkv_init_from_file.
layer_count : int
Count of layers to offload onto the GPU, must be >= 0.
"""
assert layer_count >= 0, 'Layer count must be >= 0'
return self.library.rwkv_gpu_offload_layers(ctx.ptr, ctypes.c_uint32(layer_count))
def rwkv_eval(
self,
ctx: RWKVContext,
token: int,
state_in_address: Optional[int],
state_out_address: int,
logits_out_address: int
) -> None:
"""
Evaluates the model for a single token.
Throws an exception in case of any error. Error messages would be printed to stderr.
Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
Parameters
----------
ctx : RWKVContext
RWKV context obtained from rwkv_init_from_file.
token : int
Next token index, in range 0 <= token < n_vocab.
state_in_address : int
Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass.
state_out_address : int
Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
logits_out_address : int
Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
"""
assert self.library.rwkv_eval(
ctx.ptr,
ctypes.c_int32(token),
ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
ctypes.cast(state_out_address, P_FLOAT),
ctypes.cast(logits_out_address, P_FLOAT)
), 'rwkv_eval failed, check stderr'
def rwkv_eval_sequence(
self,
ctx: RWKVContext,
tokens: List[int],
state_in_address: Optional[int],
state_out_address: int,
logits_out_address: int
) -> None:
"""
Evaluates the model for a sequence of tokens.
Uses a faster algorithm than `rwkv_eval` if you do not need the state and logits for every token. Best used with sequence lengths of 64 or so.
Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length.
NOTE ON GGML NODE LIMIT
ggml has a hard-coded limit on max amount of nodes in a computation graph. The sequence graph is built in a way that quickly exceedes
this limit when using large models and/or large sequence lengths.
Fortunately, rwkv.cpp's fork of ggml has increased limit which was tested to work for sequence lengths up to 64 for 14B models.
If you get `GGML_ASSERT: ...\\ggml.c:16941: cgraph->n_nodes < GGML_MAX_NODES`, this means you've exceeded the limit.
To get rid of the assertion failure, reduce the model size and/or sequence length.
Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread.
Throws an exception in case of any error. Error messages would be printed to stderr.
Parameters
----------
ctx : RWKVContext
RWKV context obtained from rwkv_init_from_file.
tokens : List[int]
Next token indices, in range 0 <= token < n_vocab.
state_in_address : int
Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass.
state_out_address : int
Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
logits_out_address : int
Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
"""
assert self.library.rwkv_eval_sequence(
ctx.ptr,
ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT),
ctypes.c_size_t(len(tokens)),
ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
ctypes.cast(state_out_address, P_FLOAT),
ctypes.cast(logits_out_address, P_FLOAT)
), 'rwkv_eval_sequence failed, check stderr'
def rwkv_eval_sequence_in_chunks(
self,
ctx: RWKVContext,
tokens: List[int],
chunk_size: int,
state_in_address: Optional[int],
state_out_address: int,
logits_out_address: int
) -> None:
"""
Evaluates the model for a sequence of tokens using `rwkv_eval_sequence`, splitting a potentially long sequence into fixed-length chunks.
This function is useful for processing complete prompts and user input in chat & role-playing use-cases.
It is recommended to use this function instead of `rwkv_eval_sequence` to avoid mistakes and get maximum performance.
Chunking allows processing sequences of thousands of tokens, while not reaching the ggml's node limit and not consuming too much memory.
A reasonable and recommended value of chunk size is 16. If you want maximum performance, try different chunk sizes in range [2..64]
and choose one that works the best in your use case.
Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread.
Throws an exception in case of any error. Error messages would be printed to stderr.
Parameters
----------
ctx : RWKVContext
RWKV context obtained from rwkv_init_from_file.
tokens : List[int]
Next token indices, in range 0 <= token < n_vocab.
chunk_size : int
Size of each chunk in tokens, must be positive.
state_in_address : int
Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass.
state_out_address : int
Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
logits_out_address : int
Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
"""
assert self.library.rwkv_eval_sequence_in_chunks(
ctx.ptr,
ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT),
ctypes.c_size_t(len(tokens)),
ctypes.c_size_t(chunk_size),
ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
ctypes.cast(state_out_address, P_FLOAT),
ctypes.cast(logits_out_address, P_FLOAT)
), 'rwkv_eval_sequence_in_chunks failed, check stderr'
def rwkv_get_n_vocab(self, ctx: RWKVContext) -> int:
"""
Returns the number of tokens in the given model's vocabulary.
Useful for telling 20B_tokenizer models (n_vocab = 50277) apart from World models (n_vocab = 65536).
Parameters
----------
ctx : RWKVContext
RWKV context obtained from rwkv_init_from_file.
"""
return self.library.rwkv_get_n_vocab(ctx.ptr)
def rwkv_get_n_embed(self, ctx: RWKVContext) -> int:
"""
Returns the number of elements in the given model's embedding.
Useful for reading individual fields of a model's hidden state.
Parameters
----------
ctx : RWKVContext
RWKV context obtained from rwkv_init_from_file.
"""
return self.library.rwkv_get_n_embed(ctx.ptr)
def rwkv_get_n_layer(self, ctx: RWKVContext) -> int:
"""
Returns the number of layers in the given model.
A layer is a pair of RWKV and FFN operations, stacked multiple times throughout the model.
Embedding matrix and model head (unembedding matrix) are NOT counted in `n_layer`.
Useful for always offloading the entire model to GPU.
Parameters
----------
ctx : RWKVContext
RWKV context obtained from rwkv_init_from_file.
"""
return self.library.rwkv_get_n_layer(ctx.ptr)
def rwkv_get_state_buffer_element_count(self, ctx: RWKVContext) -> int:
"""
Returns count of FP32 elements in state buffer.
Parameters
----------
ctx : RWKVContext
RWKV context obtained from rwkv_init_from_file.
"""
return self.library.rwkv_get_state_buffer_element_count(ctx.ptr)
def rwkv_get_logits_buffer_element_count(self, ctx: RWKVContext) -> int:
"""
Returns count of FP32 elements in logits buffer.
Parameters
----------
ctx : RWKVContext
RWKV context obtained from rwkv_init_from_file.
"""
return self.library.rwkv_get_logits_buffer_element_count(ctx.ptr)
def rwkv_free(self, ctx: RWKVContext) -> None:
"""
Frees all allocated memory and the context.
Parameters
----------
ctx : RWKVContext
RWKV context obtained from rwkv_init_from_file.
"""
self.library.rwkv_free(ctx.ptr)
ctx.ptr = self.nullptr
def rwkv_quantize_model_file(self, model_file_path_in: str, model_file_path_out: str, format_name: str) -> None:
"""
Quantizes FP32 or FP16 model to one of INT4 formats.
Throws an exception in case of any error. Error messages would be printed to stderr.
Parameters
----------
model_file_path_in : str
Path to model file in ggml format, must be either FP32 or FP16.
model_file_path_out : str
Quantized model will be written here.
format_name : str
One of QUANTIZED_FORMAT_NAMES.
"""
assert format_name in QUANTIZED_FORMAT_NAMES, f'Unknown format name {format_name}, use one of {QUANTIZED_FORMAT_NAMES}'
assert self.library.rwkv_quantize_model_file(
model_file_path_in.encode('utf-8'),
model_file_path_out.encode('utf-8'),
format_name.encode('utf-8')
), 'rwkv_quantize_model_file failed, check stderr'
def rwkv_get_system_info_string(self) -> str:
"""
Returns system information string.
"""
return self.library.rwkv_get_system_info_string().decode('utf-8')
def load_rwkv_shared_library() -> RWKVSharedLibrary:
"""
Attempts to find rwkv.cpp shared library and load it.
To specify exact path to the library, create an instance of RWKVSharedLibrary explicitly.
"""
file_name: str
if 'win32' in sys.platform or 'cygwin' in sys.platform:
file_name = 'rwkv.dll'
elif 'darwin' in sys.platform:
file_name = 'librwkv.dylib'
else:
file_name = 'librwkv.so'
# Possible sub-paths to the library relative to the repo dir.
child_paths: List[Callable[[pathlib.Path], pathlib.Path]] = [
# No lookup for Debug config here.
# I assume that if a user wants to debug the library,
# they will be able to find the library and set the exact path explicitly.
lambda p: p / 'backend-python' / 'rwkv_pip' / 'cpp' / file_name,
lambda p: p / 'bin' / 'Release' / file_name,
lambda p: p / 'bin' / file_name,
# Some people prefer to build in the "build" subdirectory.
lambda p: p / 'build' / 'bin' / 'Release' / file_name,
lambda p: p / 'build' / 'bin' / file_name,
lambda p: p / 'build' / file_name,
# Fallback.
lambda p: p / file_name
]
working_dir: pathlib.Path = pathlib.Path(os.path.abspath(os.getcwd()))
parent_paths: List[pathlib.Path] = [
# Possible repo dirs relative to the working dir.
# ./python/rwkv_cpp
working_dir.parent.parent,
# ./python
working_dir.parent,
# .
working_dir,
# Repo dir relative to this Python file.
pathlib.Path(os.path.abspath(__file__)).parent.parent.parent
]
for parent_path in parent_paths:
for child_path in child_paths:
full_path: pathlib.Path = child_path(parent_path)
if os.path.isfile(full_path):
return RWKVSharedLibrary(str(full_path))
assert False, (f'Failed to find {file_name} automatically; '
f'you need to find the library and create RWKVSharedLibrary specifying the path to it')

View File

@@ -3,6 +3,8 @@
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#define CUBLAS_CHECK(condition) \
for (cublasStatus_t _cublas_check_status = (condition); \
@@ -18,26 +20,13 @@
"CUDA error " + std::string(cudaGetErrorString(_cuda_check_status)) + \
" at " + std::to_string(__LINE__));
cublasHandle_t get_cublas_handle() {
static cublasHandle_t cublas_handle = []() {
cublasHandle_t handle = nullptr;
CUBLAS_CHECK(cublasCreate(&handle));
#if CUDA_VERSION < 11000
CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#else
CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
#endif // CUDA_VERSION < 11000
return handle;
}();
return cublas_handle;
}
/*
NOTE: blas gemm is column-major by default, but we need row-major output.
The data of row-major, transposed matrix is exactly the same as the
column-major, non-transposed matrix, and C = A * B ---> C^T = B^T * A^T
*/
void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
const auto cuda_data_type = CUDA_R_16F;
const auto cuda_c_data_type =
c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F;
@@ -55,7 +44,7 @@ void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
const int cublas_lda = m;
const int cublas_ldb = k;
const int cublas_ldc = m;
cublasHandle_t cublas_handle = get_cublas_handle();
cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle();
#if CUDA_VERSION >= 11000
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;

View File

@@ -1,5 +1,6 @@
#include <torch/extension.h>
#include "ATen/ATen.h"
#include <c10/cuda/CUDAGuard.h>
typedef at::BFloat16 bf16;
typedef at::Half fp16;
typedef float fp32;
@@ -9,12 +10,15 @@ void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);
void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
}
void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
}

87
backend-python/rwkv_pip/cuda/rwkv6.cu vendored Normal file
View File

@@ -0,0 +1,87 @@
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
typedef at::BFloat16 bf16;
typedef at::Half fp16;
typedef float fp32;
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H, float *__restrict__ _state,
const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
F *__restrict__ const _y)
{
const int b = blockIdx.x / H;
const int h = blockIdx.x % H;
const int i = threadIdx.x;
_u += h*_N_;
_state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!
__shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
float state[_N_];
#pragma unroll
for (int j = 0; j < _N_; j++)
state[j] = _state[j];
__syncthreads();
u[i] = float(_u[i]);
__syncthreads();
for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
{
__syncthreads();
w[i] = _w[t];
r[i] = float(_r[t]);
k[i] = float(_k[t]);
__syncthreads();
const float v = float(_v[t]);
float y = 0;
#pragma unroll
for (int j = 0; j < _N_; j+=4)
{
const float4& r_ = (float4&)(r[j]);
const float4& k_ = (float4&)(k[j]);
const float4& w_ = (float4&)(w[j]);
const float4& u_ = (float4&)(u[j]);
float4& s = (float4&)(state[j]);
float4 x;
x.x = k_.x * v;
x.y = k_.y * v;
x.z = k_.z * v;
x.w = k_.w * v;
y += r_.x * (u_.x * x.x + s.x);
y += r_.y * (u_.y * x.y + s.y);
y += r_.z * (u_.z * x.z + s.z);
y += r_.w * (u_.w * x.w + s.w);
s.x = s.x * w_.x + x.x;
s.y = s.y * w_.y + x.y;
s.z = s.z * w_.z + x.z;
s.w = s.w * w_.w + x.w;
}
_y[t] = F(y);
}
#pragma unroll
for (int j = 0; j < _N_; j++)
_state[j] = state[j];
}
void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
{
assert(H*_N_ == C);
kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y)
{
assert(H*_N_ == C);
kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y)
{
assert(H*_N_ == C);
kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}

View File

@@ -0,0 +1,34 @@
#include <torch/extension.h>
#include "ATen/ATen.h"
#include <c10/cuda/CUDAGuard.h>
typedef at::BFloat16 bf16;
typedef at::Half fp16;
typedef float fp32;
void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y);
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);
void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
}
void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward_bf16", &forward_bf16, "rwkv6 forward_bf16");
m.def("forward_fp16", &forward_fp16, "rwkv6 forward_fp16");
m.def("forward_fp32", &forward_fp32, "rwkv6 forward_fp32");
}
TORCH_LIBRARY(rwkv6, m) {
m.def("forward_bf16", forward_bf16);
m.def("forward_fp16", forward_fp16);
m.def("forward_fp32", forward_fp32);
}

File diff suppressed because it is too large Load Diff

Binary file not shown.

BIN
backend-python/rwkv_pip/rwkv6.pyd vendored Normal file

Binary file not shown.

View File

@@ -78,11 +78,22 @@ class PIPELINE:
def decode(self, x):
return self.tokenizer.decode(x)
def np_softmax(self, x: np.ndarray, axis: int):
x -= x.max(axis=axis, keepdims=True)
e: np.ndarray = np.exp(x)
return e / e.sum(axis=axis, keepdims=True)
def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
probs = F.softmax(logits.float(), dim=-1)
np_logits = type(logits) == np.ndarray
if np_logits:
probs = self.np_softmax(logits, axis=-1)
else:
probs = F.softmax(logits.float(), dim=-1)
top_k = int(top_k)
if probs.device == torch.device("cpu"):
probs = probs.numpy()
# 'privateuseone' is the type of custom devices like `torch_directml.device()`
if np_logits or probs.device.type in ["cpu", "privateuseone"]:
if not np_logits:
probs = probs.cpu().numpy()
sorted_ids = np.argsort(probs)
sorted_probs = probs[sorted_ids][::-1]
cumulative_probs = np.cumsum(sorted_probs)

Binary file not shown.

View File

@@ -10,7 +10,7 @@ logger = logging.getLogger()
logger.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(levelname)s\n%(message)s")
fh = logging.handlers.RotatingFileHandler(
"api.log", mode="a", maxBytes=3 * 1024 * 1024, backupCount=3
"api.log", mode="a", maxBytes=3 * 1024 * 1024, backupCount=3, encoding="utf-8"
)
fh.setFormatter(formatter)
logger.addHandler(fh)

View File

@@ -1,11 +1,13 @@
import os
import sys
import global_var
def ngrok_connect():
from pyngrok import ngrok, conf
conf.set_default(conf.PyngrokConfig(ngrok_path="./ngrok"))
conf.set_default(
conf.PyngrokConfig(ngrok_path="./ngrok.exe" if os.name == "nt" else "./ngrok")
)
ngrok.set_auth_token(os.environ["ngrok_token"])
http_tunnel = ngrok.connect(8000 if len(sys.argv) == 1 else int(sys.argv[1]))
print(http_tunnel.public_url)
http_tunnel = ngrok.connect(global_var.get(global_var.Args).port)
print(f"ngrok url: {http_tunnel.public_url}")

View File

@@ -510,12 +510,22 @@ def get_tokenizer(tokenizer_len: int):
def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV:
rwkv_beta = global_var.get(global_var.Args).rwkv_beta
rwkv_cpp = getattr(global_var.get(global_var.Args), "rwkv.cpp")
if "midi" in model.lower() or "abc" in model.lower():
os.environ["RWKV_RESCALE_LAYER"] = "999"
# dynamic import to make RWKV_CUDA_ON work
if rwkv_beta:
print("Using rwkv-beta")
from rwkv_pip.beta.model import (
RWKV as Model,
)
elif rwkv_cpp:
print("Using rwkv.cpp, strategy is ignored")
from rwkv_pip.cpp.model import (
RWKV as Model,
)
else:
from rwkv_pip.model import (
RWKV as Model,
@@ -551,8 +561,8 @@ class ModelConfigBody(BaseModel):
presence_penalty: float = Field(default=None, ge=-2, le=2)
frequency_penalty: float = Field(default=None, ge=-2, le=2)
class Config:
json_schema_extra = {
model_config = {
"json_schema_extra": {
"example": {
"max_tokens": 1000,
"temperature": 1.2,
@@ -561,6 +571,7 @@ class ModelConfigBody(BaseModel):
"frequency_penalty": 0.4,
}
}
}
def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody):

View File

@@ -0,0 +1,14 @@
from fastapi import FastAPI
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.staticfiles import StaticFiles
import uvicorn
webui_server = FastAPI()
webui_server.add_middleware(GZipMiddleware, minimum_size=1000)
webui_server.mount(
"/", StaticFiles(directory="frontend/dist", html=True), name="static"
)
if __name__ == "__main__":
uvicorn.run("webui_server:webui_server")

View File

@@ -9,7 +9,7 @@ cd RWKV-Next-Web
git clone https://github.com/josStorer/RWKV-Runner --depth=1
python3 -m pip install torch torchvision torchaudio
python3 -m pip install -r RWKV-Runner/backend-python/requirements.txt
python3 ./RWKV-Runner/backend-python/main.py > log.txt &
python3 ./RWKV-Runner/backend-python/main.py > log.txt & # this is only an example, you should use screen or other tools to run it in background
if [ ! -d RWKV-Runner/models ]; then
mkdir RWKV-Runner/models
@@ -22,6 +22,6 @@ yarn install
yarn build
export PROXY_URL=""
export BASE_URL=http://127.0.0.1:8000
yarn start &
yarn start & # this is only an example, you should use screen or other tools to run it in background
curl http://127.0.0.1:8000/switch-model -X POST -H "Content-Type: application/json" -d '{"model":"./RWKV-Runner/models/RWKV-4-World-0.1B-v1-20230520-ctx4096.pth","strategy":"cpu fp32"}'

View File

@@ -0,0 +1,19 @@
: install git python3.10 npm by yourself
: change model and strategy according to your hardware
git clone https://github.com/josStorer/RWKV-Runner --depth=1
python -m pip install torch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 --index-url https://download.pytorch.org/whl/cu117
python -m pip install -r RWKV-Runner/backend-python/requirements.txt
cd RWKV-Runner/frontend
call npm ci
call npm run build
cd ..
: optional: set ngrok_token=YOUR_NGROK_TOKEN
start python ./backend-python/main.py --webui
start "C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe" "http://127.0.0.1:8000"
powershell -Command "(Test-Path ./models) -or (mkdir models)"
powershell -Command "Import-Module BitsTransfer"
powershell -Command "(Test-Path ./models/RWKV-4-World-1.5B-v1-fixed-20230612-ctx4096.pth) -or (Start-BitsTransfer https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-1.5B-v1-fixed-20230612-ctx4096.pth ./models/RWKV-4-World-1.5B-v1-fixed-20230612-ctx4096.pth)"
powershell -Command "Invoke-WebRequest http://127.0.0.1:8000/switch-model -Method POST -ContentType 'application/json' -Body '{\"model\":\"./models/RWKV-4-World-1.5B-v1-fixed-20230612-ctx4096.pth\",\"strategy\":\"cuda fp32 *20+\",\"deploy\":\"true\"}'"

View File

@@ -0,0 +1,22 @@
# install git python3.10 npm by yourself
# change model and strategy according to your hardware
sudo apt install python3-dev
git clone https://github.com/josStorer/RWKV-Runner --depth=1
python3 -m pip install torch torchvision torchaudio
python3 -m pip install -r RWKV-Runner/backend-python/requirements.txt
cd RWKV-Runner/frontend
npm ci
npm run build
cd ..
# optional: export ngrok_token=YOUR_NGROK_TOKEN
python3 ./backend-python/main.py --webui > log.txt & # this is only an example, you should use screen or other tools to run it in background
if [ ! -d models ]; then
mkdir models
fi
wget -N https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-0.1B-v1-20230520-ctx4096.pth -P models/
curl http://127.0.0.1:8000/switch-model -X POST -H "Content-Type: application/json" -d '{"model":"./models/RWKV-4-World-0.1B-v1-20230520-ctx4096.pth","strategy":"cpu fp32","deploy":"true"}'

View File

@@ -23,6 +23,7 @@ def file_cleaner(file):
return cleaner
expected_max_version = float(sys.argv[2]) if len(sys.argv) > 2 else 100
model_file = open(sys.argv[1], "rb")
cleaner = file_cleaner(model_file)
cleaner_thread = threading.Thread(target=cleaner, daemon=True)
@@ -34,8 +35,23 @@ gc.collect()
n_embd = w["emb.weight"].shape[1]
n_layer = 0
keys = list(w.keys())
version = 4
for x in keys:
layer_id = int(x.split(".")[1]) if ("blocks." in x) else 0
n_layer = max(n_layer, layer_id + 1)
print(f"--n_layer {n_layer} --n_embd {n_embd}", end="")
if "ln_x" in x:
version = max(5, version)
if "gate.weight" in x:
version = max(5.1, version)
if int(version) == 5 and "att.time_decay" in x:
if len(w[x].shape) > 1:
if w[x].shape[1] > 1:
version = max(5.2, version)
if "time_maa" in x:
version = max(6, version)
if version <= expected_max_version:
print(f"--n_layer {n_layer} --n_embd {n_embd}", end="")
else:
raise Exception(f"RWKV{version} is not supported")

View File

@@ -47,8 +47,12 @@ else
fi
echo "loading $loadModel"
modelInfo=$(python3 ./finetune/get_layer_and_embd.py $loadModel)
modelInfo=$(python3 ./finetune/get_layer_and_embd.py $loadModel 4)
echo $modelInfo
python3 ./finetune/lora/train.py $modelInfo $@ --proj_dir lora-models --data_type binidx --lora \
--lora_parts=att,ffn,time,ln --strategy deepspeed_stage_2 --accelerator gpu
if [[ $modelInfo =~ "--n_layer" ]]; then
python3 ./finetune/lora/train.py $modelInfo $@ --proj_dir lora-models --data_type binidx --lora \
--lora_parts=att,ffn,time,ln --strategy deepspeed_stage_2 --accelerator gpu
else
echo "modelInfo is invalid"
exit 1
fi

View File

@@ -246,5 +246,6 @@ if __name__ == "__main__":
try:
main()
except Exception as e:
print(e)
with open("error.txt", "w") as f:
f.write(str(e))

View File

@@ -64,5 +64,6 @@ try:
torch.save(output_w, output)
except Exception as e:
print(e)
with open("error.txt", "w") as f:
f.write(str(e))

View File

@@ -264,7 +264,7 @@ if __name__ == "__main__":
#
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
#
# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1} (will continue afterwards), save every {args.epoch_save} epoch
# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1}, save every {args.epoch_save} epoch
#
# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
#

View File

@@ -1,3 +1,3 @@
torch==1.13.1
pytorch_lightning==1.9.5
deepspeed
deepspeed==0.11.2

View File

@@ -1,9 +1,10 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8"/>
<meta content="width=device-width, initial-scale=1.0" name="viewport"/>
<title>RWKV-Runner</title>
<meta charset="UTF-8" />
<meta content="width=device-width, initial-scale=1.0" name="viewport" />
<title>RWKV-Runner</title>
<link href="./src/assets/images/logo.png" rel="icon" type="image/x-icon">
</head>
<body>
<div id="root"></div>

File diff suppressed because it is too large Load Diff

View File

@@ -16,15 +16,17 @@
"@primer/octicons-react": "^19.1.0",
"chart.js": "^4.3.0",
"classnames": "^2.3.2",
"github-markdown-css": "^5.2.0",
"file-saver": "^2.0.5",
"html-midi-player": "^1.5.0",
"i18next": "^22.4.15",
"mobx": "^6.9.0",
"mobx-react-lite": "^3.4.3",
"pdfjs-dist": "^4.0.189",
"react": "^18.2.0",
"react-beautiful-dnd": "^13.1.1",
"react-chartjs-2": "^5.2.0",
"react-dom": "^18.2.0",
"react-draggable": "^4.4.6",
"react-i18next": "^12.2.2",
"react-markdown": "^8.0.7",
"react-router": "^6.11.1",
@@ -38,6 +40,7 @@
"uuid": "^9.0.0"
},
"devDependencies": {
"@types/file-saver": "^2.0.7",
"@types/react": "^18.2.6",
"@types/react-beautiful-dnd": "^13.1.4",
"@types/react-dom": "^18.2.4",
@@ -49,6 +52,7 @@
"sass": "^1.62.1",
"tailwindcss": "^3.3.2",
"typescript": "^5.0.4",
"vite": "^4.3.6"
"vite": "^4.3.6",
"vite-plugin-top-level-await": "^1.3.1"
}
}

View File

@@ -26,18 +26,22 @@
import { FluentProvider, Tab, TabList, webDarkTheme, webLightTheme } from '@fluentui/react-components';
import { FC, useEffect, useState } from 'react';
import { Route, Routes, useLocation, useNavigate } from 'react-router';
import { pages } from './pages';
import { pages as clientPages } from './pages';
import { useMediaQuery } from 'usehooks-ts';
import commonStore from './stores/commonStore';
import { observer } from 'mobx-react-lite';
import { useTranslation } from 'react-i18next';
import { CustomToastContainer } from './components/CustomToastContainer';
import { LazyImportComponent } from './components/LazyImportComponent';
const App: FC = observer(() => {
const { t } = useTranslation();
const navigate = useNavigate();
const location = useLocation();
const mq = useMediaQuery('(min-width: 640px)');
const pages = commonStore.platform === 'web' ? clientPages.filter(page =>
!['/configs', '/models', '/downloads', '/train', '/about'].some(path => page.path === path)
) : clientPages;
const [path, setPath] = useState<string>(pages[0].path);
@@ -47,10 +51,10 @@ const App: FC = observer(() => {
useEffect(() => setPath(location.pathname), [location]);
return (
<FluentProvider className="h-screen"
<FluentProvider
theme={commonStore.settings.darkMode ? webDarkTheme : webLightTheme}
data-theme={commonStore.settings.darkMode ? 'dark' : 'light'}>
<div className="flex h-full">
<div className="flex h-screen">
<div className="flex flex-col w-16 sm:w-48 p-2 justify-between">
<TabList
size="large"
@@ -82,7 +86,7 @@ const App: FC = observer(() => {
<div className="h-full w-full p-2 box-border overflow-y-hidden">
<Routes>
{pages.map(({ path, element }, index) => (
<Route key={`${path}-${index}`} path={path} element={element} />
<Route key={`${path}-${index}`} path={path} element={<LazyImportComponent lazyChildren={element} />} />
))}
</Routes>
</div>

View File

@@ -13,7 +13,7 @@
"Working": "動作中",
"Stop": "停止",
"Enable High Precision For Last Layer": "最後の層で高精度を有効にする",
"Stored Layers": "メモリ層読み込み",
"Stored Layers": "保存されるレイヤー",
"Precision": "精度",
"Device": "デバイス",
"Convert model with these configs. Using a converted model will greatly improve the loading speed, but model parameters of the converted model cannot be modified.": "これらの設定でモデルを変換します。変換されたモデルを使用すると、読み込み速度が大幅に向上しますが、変換したモデルのパラメータを変更することはできません。",
@@ -82,7 +82,7 @@
"Just like feeding sedatives to the model. Consider the results of the top n% probability mass, 0.1 considers the top 10%, with higher quality but more conservative, 1 considers all results, with lower quality but more diverse.": "モデルに鎮静剤を与えるようなもの。上位nの確率質量の結果を考えてみてください。0.1は上位10を考えており、質が高いが保守的で、1は全ての結果を考慮しており、質は低いが多様性があります。",
"Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.": "ポジティヴ値は、新しいトークンが今までのテキストに出現していたかどうかに基づいてこれらをペナルティとし、新しいトピックについて話す可能性を増加させます。",
"Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.": "ポジティブ値は、新しいトークンが既存のテキストでどれだけ頻繁に使われているかに基づいてペナルティを与え、モデルが同じ行を完全に繰り返す可能性を減らします。",
"int8 uses less VRAM, but has slightly lower quality. fp16 has higher quality, and fp32 has the best quality.": "int8はVRAMの使用量が少ないですが、質が若干低いです。fp16は高品質、fp32は最高品質です。",
"int8 uses less VRAM, but has slightly lower quality. fp16 has higher quality.": "int8はVRAMの使用量が少ないですが、質が若干低いです。fp16は高品質。",
"Number of the neural network layers loaded into VRAM, the more you load, the faster the speed, but it consumes more VRAM. (If your VRAM is not enough, it will fail to load)": "VRAMにロードされるニューラルネットワークの層の数。ロードする量が多いほど速度は速くなりますが、VRAMを多く消費します。(VRAMが不足している場合、ロードに失敗します)",
"Whether to use CPU to calculate the last output layer of the neural network with FP32 precision to obtain better quality.": "ネットワークの最終出力層をFP32精度で計算するためにCPUを使用するかどうか。",
"Downloads": "ダウンロード",
@@ -128,7 +128,7 @@
"Chinese Kongfu": "中国武術",
"Allow external access to the API (service must be restarted)": "APIへの外部アクセスを許可する (サービスを再起動する必要があります)",
"Custom": "カスタム",
"CUDA (Beta, Faster)": "CUDA (ベータ、高速)",
"CUDA (Beta, Faster)": "CUDA (Beta, 高速)",
"Reset All Configs": "すべての設定をリセット",
"Cancel": "キャンセル",
"Confirm": "確認",
@@ -229,7 +229,7 @@
"Memory is not enough, try to increase the virtual memory (Swap of WSL) or use a smaller base model.": "メモリが不足しています、仮想メモリ (WSL Swap) を増やすか小さなベースモデルを使用してみてください。",
"VRAM is not enough": "ビデオRAMが不足しています",
"Training data is not enough, reduce context length or add more data for training": "トレーニングデータが不足しています、コンテキストの長さを減らすか、トレーニング用のデータをさらに追加してください",
"You are using WSL 1 for training, please upgrade to WSL 2. e.g. Run \"wsl --set-version Ubuntu-22.04 2\"": "トレーニングにWSL 1を使用していますWSL 2にアップグレードしてください。例:\"wsl --set-version Ubuntu-22.04 2\"を実行する",
"Can not find an Nvidia GPU. Perhaps the gpu driver of windows is too old, or you are using WSL 1 for training, please upgrade to WSL 2. e.g. Run \"wsl --set-version Ubuntu-22.04 2\"": "Nvidia GPUが見つかりません。WindowsのGPUドライバが古すぎるか、トレーニングにWSL 1を使用している可能性がありますWSL 2にアップグレードしてください。例\"wsl --set-version Ubuntu-22.04 2\"を実行してください",
"Matched CUDA is not installed": "対応するCUDAがインストールされていません",
"Failed to convert data": "データの変換に失敗しました",
"Failed to merge model": "モデルのマージに失敗しました",
@@ -250,16 +250,75 @@
"VRAM": "VRAM",
"GPU Usage": "GPU使用率",
"Use Custom Tokenizer": "カスタムトークナイザーを使用する",
"Tokenizer Path (e.g. backend-python/rwkv_pip/20B_tokenizer.json)": "トークナイザーパス (例: backend-python/rwkv_pip/20B_tokenizer.json)",
"Tokenizer Path (e.g. backend-python/rwkv_pip/20B_tokenizer.json or rwkv_vocab_v20230424.txt)": "トークナイザーパス (例: backend-python/rwkv_pip/20B_tokenizer.json または rwkv_vocab_v20230424.txt)",
"User Name": "ユーザー名",
"Assistant Name": "アシスタント名",
"Insert default system prompt at the beginning": "最初にデフォルトのシステムプロンプトを挿入",
"Format Content": "内容フォーマットの規格化",
"Add An Attachment (Accepts pdf, txt)": "添付ファイルを追加 (pdf, txtを受け付けます)",
"Uploading Attachment": "添付ファイルアップロード中",
"Processing Attachment": "添付ファイルを処理中",
"Remove Attachment": "添付ファイルを削除",
"The content of file": "ファイル",
"is as follows. When replying to me, consider the file content and respond accordingly:": "の内容は以下の通りです。私に返信する際は、ファイルの内容を考慮して適切に返信してください:",
"What's the file name": "ファイル名は何ですか",
"The file name is: ": "ファイル名は次のとおりです: "
"The file name is: ": "ファイル名は次のとおりです: ",
"Port is occupied. Change it in Configs page or close the program that occupies the port.": "ポートが占有されています。設定ページで変更するか、ポートを占有しているプログラムを終了してください。",
"Loading...": "読み込み中...",
"Hello, what can I do for you?": "こんにちは、何かお手伝いできますか?",
"Enable WebUI": "WebUIを有効化",
"Server is working on deployment mode, please close the terminal window manually": "サーバーはデプロイモードで動作しています、ターミナルウィンドウを手動で閉じてください",
"Server is working on deployment mode, please exit the program manually to stop the server": "サーバーはデプロイモードで動作しています、サーバーを停止するにはプログラムを手動で終了してください",
"You can increase the number of stored layers in Configs page to improve performance": "パフォーマンスを向上させるために、保存されるレイヤーの数を設定ページで増やすことができます",
"Failed to load model, try to increase the virtual memory (Swap of WSL) or use a smaller base model.": "モデルの読み込みに失敗しました、仮想メモリ (WSL Swap) を増やすか小さなベースモデルを使用してみてください。",
"Save Conversation": "会話を保存",
"Use Hugging Face Mirror": "Hugging Faceミラーを使用",
"File is empty": "ファイルが空です",
"Open MIDI Input Audio Tracks": "MIDI入力オーディオトラックを開く",
"Track": "トラック",
"Play All": "すべて再生",
"Clear All": "すべてクリア",
"Scale View": "スケールビュー",
"Record": "録音",
"Play": "再生",
"New Track": "新規トラック",
"Select a track to preview the content": "トラックを選択して内容をプレビュー",
"Save to generation area": "生成エリアに保存",
"Piano": "ピアノ",
"Percussion": "パーカッション",
"Drum": "ドラム",
"Tuba": "チューバ",
"Marimba": "マリンバ",
"Bass": "ベース",
"Guitar": "ギター",
"Violin": "バイオリン",
"Trumpet": "トランペット",
"Sax": "サックス",
"Flute": "フルート",
"Lead": "リード",
"Pad": "パッド",
"MIDI Input": "MIDI入力",
"Select the MIDI input device to be used.": "使用するMIDI入力デバイスを選択します。",
"Start Time": "開始時間",
"Content Duration": "内容の長さ",
"Please select a MIDI device first": "まずMIDIデバイスを選択してください",
"Piano is the main instrument": "ピアノはメインの楽器です",
"Loss is too high, please check the training data, and ensure your gpu driver is up to date.": "Lossが大きすぎます、トレーニングデータを確認し、GPUドライバが最新であることを確認してください。",
"This version of RWKV is not supported yet.": "このバージョンのRWKVはまだサポートされていません。",
"Main": "メイン",
"Finetuned": "微調整",
"Global": "グローバル",
"Local": "ローカル",
"CN": "中国語",
"JP": "日本語",
"Music": "音楽",
"Other": "その他",
"Import MIDI": "MIDIをインポート",
"Current Instrument": "現在の楽器",
"Please convert model to GGML format first": "モデルをGGML形式に変換してください",
"Convert To GGML Format": "GGML形式に変換",
"CPU (rwkv.cpp, Faster)": "CPU (rwkv.cpp, 高速)",
"Play With External Player": "外部プレーヤーで再生",
"Core API URL": "コアAPI URL",
"Override core API URL(/chat/completions and /completions). If you don't know what this is, leave it blank.": "コアAPI URLを上書きします(/chat/completions と /completions)。何であるかわからない場合は空白のままにしてください。",
"Please change Strategy to CPU (rwkv.cpp) to use ggml format": "StrategyをCPU (rwkv.cpp)に変更して、ggml形式を使用してください"
}

View File

@@ -82,7 +82,7 @@
"Just like feeding sedatives to the model. Consider the results of the top n% probability mass, 0.1 considers the top 10%, with higher quality but more conservative, 1 considers all results, with lower quality but more diverse.": "就像给模型喂镇静剂. 考虑前 n% 概率质量的结果, 0.1 考虑前 10%, 质量更高, 但更保守, 1 考虑所有质量结果, 质量降低, 但更多样",
"Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.": "存在惩罚. 正值根据新token在至今的文本中是否出现过, 来对其进行惩罚, 从而增加了模型涉及新话题的可能性",
"Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.": "频率惩罚. 正值根据新token在至今的文本中出现的频率/次数, 来对其进行惩罚, 从而减少模型原封不动地重复相同句子的可能性",
"int8 uses less VRAM, but has slightly lower quality. fp16 has higher quality, and fp32 has the best quality.": "int8占用显存更低, 但质量略微下降. fp16质量更好, fp32质量最好",
"int8 uses less VRAM, but has slightly lower quality. fp16 has higher quality.": "int8占用显存更低, 但质量略微下降. fp16质量更好",
"Number of the neural network layers loaded into VRAM, the more you load, the faster the speed, but it consumes more VRAM. (If your VRAM is not enough, it will fail to load)": "载入显存的神经网络层数, 载入越多, 速度越快, 但显存消耗越大 (如果你的显存不够, 会载入失败)",
"Whether to use CPU to calculate the last output layer of the neural network with FP32 precision to obtain better quality.": "是否使用cpu以fp32精度计算神经网络的最后一层输出层, 以获得更好的质量",
"Downloads": "下载",
@@ -229,7 +229,7 @@
"Memory is not enough, try to increase the virtual memory (Swap of WSL) or use a smaller base model.": "内存不足,尝试增加虚拟内存(WSL Swap),或使用一个更小规模的基底模型",
"VRAM is not enough": "显存不足",
"Training data is not enough, reduce context length or add more data for training": "训练数据不足,请减小上下文长度或增加训练数据",
"You are using WSL 1 for training, please upgrade to WSL 2. e.g. Run \"wsl --set-version Ubuntu-22.04 2\"": "你正在使用WSL 1进行训练请升级到WSL 2。例如行\"wsl --set-version Ubuntu-22.04 2\"",
"Can not find an Nvidia GPU. Perhaps the gpu driver of windows is too old, or you are using WSL 1 for training, please upgrade to WSL 2. e.g. Run \"wsl --set-version Ubuntu-22.04 2\"": "没有找到Nvidia显卡。可能是因为你的windows显卡驱动太旧或者你正在使用WSL 1进行训练请升级到WSL 2。例如行\"wsl --set-version Ubuntu-22.04 2\"",
"Matched CUDA is not installed": "未安装匹配的CUDA",
"Failed to convert data": "数据转换失败",
"Failed to merge model": "合并模型失败",
@@ -250,16 +250,75 @@
"VRAM": "显存",
"GPU Usage": "GPU占用",
"Use Custom Tokenizer": "使用自定义Tokenizer",
"Tokenizer Path (e.g. backend-python/rwkv_pip/20B_tokenizer.json)": "Tokenizer路径 (例如: backend-python/rwkv_pip/20B_tokenizer.json)",
"Tokenizer Path (e.g. backend-python/rwkv_pip/20B_tokenizer.json or rwkv_vocab_v20230424.txt)": "Tokenizer路径 (例如: backend-python/rwkv_pip/20B_tokenizer.json 或 rwkv_vocab_v20230424.txt)",
"User Name": "用户名称",
"Assistant Name": "AI名称",
"Insert default system prompt at the beginning": "在开头自动插入默认系统提示",
"Format Content": "规范格式",
"Add An Attachment (Accepts pdf, txt)": "添加一个附件 (支持pdf, txt)",
"Uploading Attachment": "正在上传附件",
"Processing Attachment": "正在处理附件",
"Remove Attachment": "移除附件",
"The content of file": "文件",
"is as follows. When replying to me, consider the file content and respond accordingly:": "内容如下。回复时考虑文件内容并做出相应回复:",
"What's the file name": "文件名是什么",
"The file name is: ": "文件名是:"
"The file name is: ": "文件名是:",
"Port is occupied. Change it in Configs page or close the program that occupies the port.": "端口被占用。请在配置页面更改端口,或关闭占用端口的程序",
"Loading...": "加载中...",
"Hello, what can I do for you?": "你好,有什么要我帮忙的吗?",
"Enable WebUI": "启用WebUI",
"Server is working on deployment mode, please close the terminal window manually": "服务器正在部署模式下运行,请手动关闭终端窗口",
"Server is working on deployment mode, please exit the program manually to stop the server": "服务器正在部署模式下运行,请手动退出程序以停止服务器",
"You can increase the number of stored layers in Configs page to improve performance": "你可以在配置页面增加载入显存层数以提升性能",
"Failed to load model, try to increase the virtual memory (Swap of WSL) or use a smaller base model.": "模型载入失败,尝试增加虚拟内存(WSL Swap),或使用一个更小规模的基底模型",
"Save Conversation": "保存对话",
"Use Hugging Face Mirror": "使用Hugging Face镜像源",
"File is empty": "文件为空",
"Open MIDI Input Audio Tracks": "打开MIDI输入音轨",
"Track": "音轨",
"Play All": "播放全部",
"Clear All": "清空",
"Scale View": "缩放视图",
"Record": "录制",
"Play": "播放",
"New Track": "新建音轨",
"Select a track to preview the content": "选择一个音轨以预览内容",
"Save to generation area": "保存到生成区",
"Piano": "钢琴",
"Percussion": "打击乐",
"Drum": "鼓",
"Tuba": "大号",
"Marimba": "马林巴",
"Bass": "贝斯",
"Guitar": "吉他",
"Violin": "小提琴",
"Trumpet": "小号",
"Sax": "萨克斯",
"Flute": "长笛",
"Lead": "主音",
"Pad": "和音",
"MIDI Input": "MIDI输入",
"Select the MIDI input device to be used.": "选择要使用的MIDI输入设备",
"Start Time": "开始时间",
"Content Duration": "内容时长",
"Please select a MIDI device first": "请先选择一个MIDI设备",
"Piano is the main instrument": "钢琴为主",
"Loss is too high, please check the training data, and ensure your gpu driver is up to date.": "Loss过高请检查训练数据并确保你的显卡驱动是最新的",
"This version of RWKV is not supported yet.": "暂不支持此版本的RWKV",
"Main": "主干",
"Finetuned": "微调",
"Global": "全球",
"Local": "本地",
"CN": "中文",
"JP": "日文",
"Music": "音乐",
"Other": "其他",
"Import MIDI": "导入MIDI",
"Current Instrument": "当前乐器",
"Please convert model to GGML format first": "请先将模型转换为GGML格式",
"Convert To GGML Format": "转换为GGML格式",
"CPU (rwkv.cpp, Faster)": "CPU (rwkv.cpp, 更快)",
"Play With External Player": "使用外部播放器播放",
"Core API URL": "核心 API URL",
"Override core API URL(/chat/completions and /completions). If you don't know what this is, leave it blank.": "覆盖核心的 API URL (/chat/completions 和 /completions)。如果你不知道这是什么,请留空",
"Please change Strategy to CPU (rwkv.cpp) to use ggml format": "请将Strategy改为CPU (rwkv.cpp)以使用ggml格式"
}

View File

@@ -1,10 +1,10 @@
import { FC } from 'react';
import { observer } from 'mobx-react-lite';
import { Dropdown, Option } from '@fluentui/react-components';
import { Dropdown, Option, PresenceBadge } from '@fluentui/react-components';
import commonStore from '../stores/commonStore';
export const ConfigSelector: FC<{ size?: 'small' | 'medium' | 'large' }> = observer(({ size }) => {
return <Dropdown size={size} style={{ minWidth: 0 }} listbox={{ style: { minWidth: 0 } }}
return <Dropdown size={size} style={{ minWidth: 0 }} listbox={{ style: { minWidth: 'fit-content' } }}
value={commonStore.getCurrentModelConfig().name}
selectedOptions={[commonStore.currentModelConfigIndex.toString()]}
onOptionSelect={(_, data) => {
@@ -12,7 +12,13 @@ export const ConfigSelector: FC<{ size?: 'small' | 'medium' | 'large' }> = obser
commonStore.setCurrentConfigIndex(Number(data.optionValue));
}}>
{commonStore.modelConfigs.map((config, index) =>
<Option key={index} value={index.toString()}>{config.name}</Option>
<Option key={index} value={index.toString()} text={config.name}>
<div className="flex justify-between grow">
{config.name}
{commonStore.modelSourceList.find(item => item.name === config.modelParameters.modelName)?.isComplete
&& <PresenceBadge status="available" />}
</div>
</Option>
)}
</Dropdown>;
});

View File

@@ -1,4 +1,4 @@
import { FC, ReactElement } from 'react';
import React, { FC, ReactElement } from 'react';
import {
Button,
Dialog,
@@ -11,7 +11,9 @@ import {
} from '@fluentui/react-components';
import { ToolTipButton } from './ToolTipButton';
import { useTranslation } from 'react-i18next';
import MarkdownRender from './MarkdownRender';
import { LazyImportComponent } from './LazyImportComponent';
const MarkdownRender = React.lazy(() => import('./MarkdownRender'));
export const DialogButton: FC<{
text?: string | null
@@ -45,7 +47,9 @@ export const DialogButton: FC<{
<DialogContent>
{
markdown ?
<MarkdownRender>{contentText}</MarkdownRender> :
<LazyImportComponent lazyChildren={MarkdownRender}>
{contentText}
</LazyImportComponent> :
contentText
}
</DialogContent>

View File

@@ -0,0 +1,24 @@
import { FC, LazyExoticComponent, ReactNode, Suspense } from 'react';
import { useTranslation } from 'react-i18next';
import { Spinner } from '@fluentui/react-components';
interface LazyImportComponentProps {
lazyChildren: LazyExoticComponent<FC<any>>;
lazyProps?: any;
children?: ReactNode;
}
export const LazyImportComponent: FC<LazyImportComponentProps> = (props) => {
const { t } = useTranslation();
return (
<Suspense fallback={
<div className="flex justify-center items-center h-full w-full">
<Spinner size="huge" label={t('Loading...')} />
</div>}>
<props.lazyChildren {...props.lazyProps}>
{props.children}
</props.lazyChildren>
</Suspense>
);
};

View File

@@ -21,7 +21,7 @@ const Hyperlink: FC<any> = ({ href, children }) => {
);
};
export const MarkdownRender: FC<ReactMarkdownOptions> = (props) => {
const MarkdownRender: FC<ReactMarkdownOptions> = (props) => {
return (
<div dir="auto" className="markdown-body">
<ReactMarkdown

View File

@@ -40,6 +40,8 @@ export const ReadButton: FC<{
voice = voices.find((v) => v.name.toLowerCase().includes('microsoft aria'));
else if (lang === 'zh')
voice = voices.find((v) => v.name.toLowerCase().includes('xiaoyi'));
else if (lang === 'ja')
voice = voices.find((v) => v.name.toLowerCase().includes('nanami'));
if (!voice) voice = voices.find((v) => v.lang.substring(0, 2) === lang);
if (!voice) voice = voices.find((v) => v.lang === navigator.language);

View File

@@ -1,16 +1,24 @@
import React, { FC, MouseEventHandler, ReactElement } from 'react';
import commonStore, { ModelStatus } from '../stores/commonStore';
import { AddToDownloadList, FileExists, StartServer, StartWebGPUServer } from '../../wailsjs/go/backend_golang/App';
import {
AddToDownloadList,
FileExists,
IsPortAvailable,
StartServer,
StartWebGPUServer
} from '../../wailsjs/go/backend_golang/App';
import { Button } from '@fluentui/react-components';
import { observer } from 'mobx-react-lite';
import { exit, getStatus, readRoot, switchModel, updateConfig } from '../apis';
import { toast } from 'react-toastify';
import { checkDependencies, getStrategy, toastWithButton } from '../utils';
import { checkDependencies, getHfDownloadUrl, getStrategy, toastWithButton } from '../utils';
import { useTranslation } from 'react-i18next';
import { ToolTipButton } from './ToolTipButton';
import { Play16Regular, Stop16Regular } from '@fluentui/react-icons';
import { useNavigate } from 'react-router';
import { WindowShow } from '../../wailsjs/runtime/runtime';
import { WindowShow } from '../../wailsjs/runtime';
import { convertToGGML, convertToSt } from '../utils/convert-model';
import { Precision } from '../types/configs';
const mainButtonText = {
[ModelStatus.Offline]: 'Run',
@@ -40,6 +48,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
const modelConfig = commonStore.getCurrentModelConfig();
const webgpu = modelConfig.modelParameters.device === 'WebGPU';
const cpp = modelConfig.modelParameters.device === 'CPU (rwkv.cpp)';
let modelName = '';
let modelPath = '';
if (modelConfig && modelConfig.modelParameters) {
@@ -51,13 +60,40 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
return;
}
const currentModelSource = commonStore.modelSourceList.find(item => item.name === modelName);
const showDownloadPrompt = (promptInfo: string, downloadName: string) => {
toastWithButton(promptInfo, t('Download'), () => {
const downloadUrl = currentModelSource?.downloadUrl;
if (downloadUrl) {
toastWithButton(`${t('Downloading')} ${downloadName}`, t('Check'), () => {
navigate({ pathname: '/downloads' });
},
{ autoClose: 3000 });
AddToDownloadList(modelPath, getHfDownloadUrl(downloadUrl));
} else {
toast(t('Can not find download url'), { type: 'error' });
}
});
};
if (webgpu) {
if (!['.st', '.safetensors'].some(ext => modelPath.endsWith(ext))) {
const stModelPath = modelPath.replace(/\.pth$/, '.st');
if (await FileExists(stModelPath)) {
modelPath = stModelPath;
} else if (!await FileExists(modelPath)) {
showDownloadPrompt(t('Model file not found'), modelName);
commonStore.setStatus({ status: ModelStatus.Offline });
return;
} else if (!currentModelSource?.isComplete) {
showDownloadPrompt(t('Model file download is not complete'), modelName);
commonStore.setStatus({ status: ModelStatus.Offline });
return;
} else {
toast(t('Please convert model to safe tensors format first'), { type: 'error' });
toastWithButton(t('Please convert model to safe tensors format first'), t('Convert'), () => {
convertToSt(modelConfig);
});
commonStore.setStatus({ status: ModelStatus.Offline });
return;
}
@@ -78,28 +114,45 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
return;
}
const currentModelSource = commonStore.modelSourceList.find(item => item.name === modelName);
const showDownloadPrompt = (promptInfo: string, downloadName: string) => {
toastWithButton(promptInfo, t('Download'), () => {
const downloadUrl = currentModelSource?.downloadUrl;
if (downloadUrl) {
toastWithButton(`${t('Downloading')} ${downloadName}`, t('Check'), () => {
navigate({ pathname: '/downloads' });
},
{ autoClose: 3000 });
AddToDownloadList(modelPath, downloadUrl);
if (cpp) {
if (!['.bin'].some(ext => modelPath.endsWith(ext))) {
const precision: Precision = modelConfig.modelParameters.precision === 'Q5_1' ? 'Q5_1' : 'fp16';
const ggmlModelPath = modelPath.replace(/\.pth$/, `-${precision}.bin`);
if (await FileExists(ggmlModelPath)) {
modelPath = ggmlModelPath;
} else if (!await FileExists(modelPath)) {
showDownloadPrompt(t('Model file not found'), modelName);
commonStore.setStatus({ status: ModelStatus.Offline });
return;
} else if (!currentModelSource?.isComplete) {
showDownloadPrompt(t('Model file download is not complete'), modelName);
commonStore.setStatus({ status: ModelStatus.Offline });
return;
} else {
toast(t('Can not find download url'), { type: 'error' });
toastWithButton(t('Please convert model to GGML format first'), t('Convert'), () => {
convertToGGML(modelConfig, navigate);
});
commonStore.setStatus({ status: ModelStatus.Offline });
return;
}
});
};
}
}
if (!cpp) {
if (['.bin'].some(ext => modelPath.endsWith(ext))) {
toast(t('Please change Strategy to CPU (rwkv.cpp) to use ggml format'), { type: 'error' });
commonStore.setStatus({ status: ModelStatus.Offline });
return;
}
}
if (!await FileExists(modelPath)) {
showDownloadPrompt(t('Model file not found'), modelName);
commonStore.setStatus({ status: ModelStatus.Offline });
return;
} else if (!currentModelSource?.isComplete) {
} else // If the user selects the .pth model with WebGPU mode, modelPath will be set to the .st model.
// However, if the .pth model is deleted, modelPath will exist and isComplete will be false.
if (!currentModelSource?.isComplete && modelPath.endsWith('.pth')) {
showDownloadPrompt(t('Model file download is not complete'), modelName);
commonStore.setStatus({ status: ModelStatus.Offline });
return;
@@ -107,8 +160,15 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
const port = modelConfig.apiParameters.apiPort;
await exit(1000).catch(() => {
});
if (!await IsPortAvailable(port)) {
await exit(1000).catch(() => {
});
if (!await IsPortAvailable(port)) {
toast(t('Port is occupied. Change it in Configs page or close the program that occupies the port.'), { type: 'error' });
commonStore.setStatus({ status: ModelStatus.Offline });
return;
}
}
const startServer = webgpu ?
(_: string, port: number, host: string) => StartWebGPUServer(port, host)
@@ -116,7 +176,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
const isUsingCudaBeta = modelConfig.modelParameters.device === 'CUDA-Beta';
startServer(commonStore.settings.customPythonPath, port, commonStore.settings.host !== '127.0.0.1' ? '0.0.0.0' : '127.0.0.1',
isUsingCudaBeta
!!modelConfig.enableWebUI, isUsingCudaBeta, cpp
).catch((e) => {
const errMsg = e.message || e;
if (errMsg.includes('path contains space'))
@@ -125,6 +185,8 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
toast(t('Error') + ' - ' + errMsg, { type: 'error' });
});
setTimeout(WindowShow, 1000);
setTimeout(WindowShow, 2000);
setTimeout(WindowShow, 3000);
let timeoutCount = 6;
let loading = false;
@@ -141,7 +203,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
});
}
commonStore.setStatus({ status: ModelStatus.Loading });
toast(t('Loading Model'), { type: 'info' });
const loadingId = toast(t('Loading Model'), { type: 'info', autoClose: false });
if (!webgpu) {
updateConfig({
max_tokens: modelConfig.apiParameters.maxResponseToken,
@@ -155,7 +217,8 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
const strategy = getStrategy(modelConfig);
let customCudaFile = '';
if ((modelConfig.modelParameters.device.includes('CUDA') || modelConfig.modelParameters.device === 'Custom')
&& modelConfig.modelParameters.useCustomCuda && !strategy.includes('fp32')) {
&& modelConfig.modelParameters.useCustomCuda
&& !strategy.split('->').some(s => ['cuda', 'fp32'].every(v => s.includes(v)))) {
if (commonStore.platform === 'windows') {
// this part is currently unused because there's no longer a need to use different kernels for different GPUs, but it might still be needed in the future
//
@@ -186,7 +249,8 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
model: modelPath,
strategy: strategy,
tokenizer: modelConfig.modelParameters.useCustomTokenizer ? modelConfig.modelParameters.customTokenizer : undefined,
customCuda: customCudaFile !== ''
customCuda: customCudaFile !== '',
deploy: modelConfig.enableWebUI
}).then(async (r) => {
if (r.ok) {
commonStore.setStatus({ status: ModelStatus.Working });
@@ -199,6 +263,12 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
const buttonFn = () => {
navigate({ pathname: '/' + buttonName.toLowerCase() });
};
if ((modelConfig.modelParameters.device === 'CUDA' || modelConfig.modelParameters.device === 'CUDA-Beta') &&
modelConfig.modelParameters.storedLayers < modelConfig.modelParameters.maxStoredLayers &&
commonStore.monitorData && commonStore.monitorData.totalVram !== 0 &&
(commonStore.monitorData.usedVram / commonStore.monitorData.totalVram) < 0.9)
toast(t('You can increase the number of stored layers in Configs page to improve performance'), { type: 'info' });
toastWithButton(t('Startup Completed'), t(buttonName), buttonFn, { type: 'success', autoClose: 3000 });
} else if (r.status === 304) {
toast(t('Loading Model'), { type: 'info' });
@@ -211,7 +281,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
'invalid header or archive is corrupted': 'The model file is corrupted, please download again.',
'no NVIDIA driver': 'Found no NVIDIA driver, please install the latest driver.',
'CUDA out of memory': 'VRAM is not enough, please reduce stored layers or use a lower precision in Configs page.',
'Ninja is required to load C++ extensions': 'Failed to enable custom CUDA kernel, ninja is required to load C++ extensions. You may be using the CPU version of PyTorch, please reinstall PyTorch with CUDA. Or if you are using a custom Python interpreter, you must compile the CUDA kernel by yourself or disable Custom CUDA kernel acceleration.',
'Ninja is required to load C++ extensions': 'Failed to enable custom CUDA kernel, ninja is required to load C++ extensions. You may be using the CPU version of PyTorch, please reinstall PyTorch with CUDA. Or if you are using a custom Python interpreter, you must compile the CUDA kernel by yourself or disable Custom CUDA kernel acceleration.'
};
const matchedError = Object.entries(errorsMap).find(([key, _]) => error.includes(key));
const message = matchedError ? t(matchedError[1]) : error;
@@ -220,6 +290,8 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
}).catch((e) => {
commonStore.setStatus({ status: ModelStatus.Offline });
toast(t('Failed to switch model') + ' - ' + (e.message || e), { type: 'error' });
}).finally(() => {
toast.dismiss(loadingId);
});
}
}).catch(() => {
@@ -233,7 +305,13 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
}, 1000);
} else {
commonStore.setStatus({ status: ModelStatus.Offline });
exit();
exit().then(r => {
if (r.status === 403)
if (commonStore.platform !== 'linux')
toast(t('Server is working on deployment mode, please close the terminal window manually'), { type: 'info' });
else
toast(t('Server is working on deployment mode, please exit the program manually to stop the server'), { type: 'info' });
});
}
};

View File

@@ -6,6 +6,7 @@ import { ConfigSelector } from './ConfigSelector';
import { RunButton } from './RunButton';
import { PresenceBadgeStatus } from '@fluentui/react-badge';
import { useTranslation } from 'react-i18next';
import { useMediaQuery } from 'usehooks-ts';
const statusText = {
[ModelStatus.Offline]: 'Offline',
@@ -23,15 +24,21 @@ const badgeStatus: { [modelStatus: number]: PresenceBadgeStatus } = {
export const WorkHeader: FC = observer(() => {
const { t } = useTranslation();
const mq = useMediaQuery('(min-width: 640px)');
const port = commonStore.getCurrentModelConfig().apiParameters.apiPort;
return (
return commonStore.platform === 'web' ?
<div /> :
<div className="flex flex-col gap-1">
<div className="flex justify-between items-center">
<div className="flex items-center gap-2">
<PresenceBadge status={badgeStatus[commonStore.status.status]} />
<Text size={100}>{t('Model Status') + ': ' + t(statusText[commonStore.status.status])}</Text>
</div>
{commonStore.lastModelName && mq &&
<Text size={100}>
{commonStore.lastModelName}
</Text>}
<div className="flex items-center gap-2">
<ConfigSelector size="small" />
<RunButton iconMode />
@@ -42,5 +49,5 @@ export const WorkHeader: FC = observer(() => {
</Text>
<Divider style={{ flexGrow: 0 }} />
</div>
);
;
});

View File

@@ -1,3 +1,4 @@
import './webWails';
import React from 'react';
import { createRoot } from 'react-dom/client';
import './style.scss';
@@ -6,7 +7,6 @@ import App from './App';
import { HashRouter } from 'react-router-dom';
import { startup } from './startup';
import './_locales/i18n-react';
import 'html-midi-player';
import { WindowShow } from '../wailsjs/runtime';
startup().then(() => {

View File

@@ -5,9 +5,7 @@ import MarkdownRender from '../components/MarkdownRender';
import { observer } from 'mobx-react-lite';
import commonStore from '../stores/commonStore';
export type AboutContent = { [lang: string]: string }
export const About: FC = observer(() => {
const About: FC = observer(() => {
const { t } = useTranslation();
const lang: string = commonStore.settings.language;
@@ -21,3 +19,5 @@ export const About: FC = observer(() => {
} />
);
});
export default About;

View File

@@ -0,0 +1,40 @@
import React, { FC, lazy } from 'react';
import { useTranslation } from 'react-i18next';
import { Button, Dialog, DialogBody, DialogContent, DialogSurface, DialogTrigger } from '@fluentui/react-components';
import { CustomToastContainer } from '../../components/CustomToastContainer';
import { LazyImportComponent } from '../../components/LazyImportComponent';
import { flushMidiRecordingContent } from '../../utils';
import commonStore from '../../stores/commonStore';
const AudiotrackEditor = lazy(() => import('./AudiotrackEditor'));
export const AudiotrackButton: FC<{
size?: 'small' | 'medium' | 'large',
shape?: 'rounded' | 'circular' | 'square';
appearance?: 'secondary' | 'primary' | 'outline' | 'subtle' | 'transparent';
setPrompt: (prompt: string) => void;
}> = ({ size, shape, appearance, setPrompt }) => {
const { t } = useTranslation();
return <Dialog onOpenChange={(e, data) => {
if (!data.open) {
flushMidiRecordingContent();
commonStore.setRecordingTrackId('');
commonStore.setPlayingTrackId('');
}
}}>
<DialogTrigger disableButtonEnhancement>
<Button size={size} shape={shape} appearance={appearance}>
{t('Open MIDI Input Audio Tracks')}
</Button>
</DialogTrigger>
<DialogSurface style={{ paddingTop: 0, maxWidth: '90vw', width: 'fit-content' }}>
<DialogBody>
<DialogContent className="overflow-hidden">
<CustomToastContainer />
<LazyImportComponent lazyChildren={AudiotrackEditor} lazyProps={{ setPrompt }} />
</DialogContent>
</DialogBody>
</DialogSurface>
</Dialog>;
};

View File

@@ -0,0 +1,601 @@
import React, { FC, useEffect, useRef, useState } from 'react';
import { observer } from 'mobx-react-lite';
import { useTranslation } from 'react-i18next';
import Draggable from 'react-draggable';
import { ToolTipButton } from '../../components/ToolTipButton';
import { v4 as uuid } from 'uuid';
import {
Add16Regular,
ArrowAutofitWidth20Regular,
ArrowUpload16Regular,
Delete16Regular,
MusicNote220Regular,
Pause16Regular,
Play16Filled,
Play16Regular,
Record16Regular,
Stop16Filled
} from '@fluentui/react-icons';
import { Button, Card, DialogTrigger, Slider, Text, Tooltip } from '@fluentui/react-components';
import { useWindowSize } from 'usehooks-ts';
import commonStore, { ModelStatus } from '../../stores/commonStore';
import classnames from 'classnames';
import {
InstrumentType,
InstrumentTypeNameMap,
InstrumentTypeTokenMap,
MidiMessage,
tracksMinimalTotalTime
} from '../../types/composition';
import { toast } from 'react-toastify';
import {
absPathAsset,
flushMidiRecordingContent,
getMidiRawContentMainInstrument,
getMidiRawContentTime,
getServerRoot,
refreshTracksTotalTime
} from '../../utils';
import { OpenOpenFileDialog, PlayNote } from '../../../wailsjs/go/backend_golang/App';
const snapValue = 25;
const minimalMoveTime = 8; // 1000/125=8ms wait_events=125
const scaleMin = 0.05;
const scaleMax = 3;
const baseMoveTime = Math.round(minimalMoveTime / scaleMin);
const velocityEvents = 128;
const velocityBins = 12;
const velocityExp = 0.5;
const minimalTrackWidth = 80;
const trackInitOffsetPx = 10;
const pixelFix = 0.5;
const topToArrowIcon = 19;
const arrowIconToTracks = 23;
const velocityToBin = (velocity: number) => {
velocity = Math.max(0, Math.min(velocity, velocityEvents - 1));
const binsize = velocityEvents / (velocityBins - 1);
return Math.ceil((velocityEvents * ((Math.pow(velocityExp, (velocity / velocityEvents)) - 1.0) / (velocityExp - 1.0))) / binsize);
};
const binToVelocity = (bin: number) => {
const binsize = velocityEvents / (velocityBins - 1);
return Math.max(0, Math.ceil(velocityEvents * (Math.log(((velocityExp - 1) * binsize * bin) / velocityEvents + 1) / Math.log(velocityExp)) - 1));
};
const tokenToMidiMessage = (token: string): MidiMessage | null => {
if (token.startsWith('<')) return null;
if (token.startsWith('t') && !token.startsWith('t:')) {
return {
messageType: 'ElapsedTime',
value: parseInt(token.substring(1)) * minimalMoveTime,
channel: 0,
note: 0,
velocity: 0,
control: 0,
instrument: 0
};
}
const instrument: InstrumentType = InstrumentTypeTokenMap.findIndex(t => token.startsWith(t + ':'));
if (instrument >= 0) {
const parts = token.split(':');
if (parts.length !== 3) return null;
const note = parseInt(parts[1], 16);
const velocity = parseInt(parts[2], 16);
if (velocity < 0 || velocity > 127) return null;
if (velocity === 0) return {
messageType: 'NoteOff',
note: note,
instrument: instrument,
channel: 0,
velocity: 0,
control: 0,
value: 0
};
return {
messageType: 'NoteOn',
note: note,
velocity: binToVelocity(velocity),
instrument: instrument,
channel: 0,
control: 0,
value: 0
} as MidiMessage;
}
return null;
};
const midiMessageToToken = (msg: MidiMessage) => {
if (msg.messageType === 'NoteOn' || msg.messageType === 'NoteOff') {
const instrument = InstrumentTypeTokenMap[msg.instrument];
const note = msg.note.toString(16);
const velocity = velocityToBin(msg.velocity).toString(16);
return `${instrument}:${note}:${velocity} `;
} else if (msg.messageType === 'ElapsedTime') {
let time = Math.round(msg.value / minimalMoveTime);
const num = Math.floor(time / 125); // wait_events=125
time -= num * 125;
let ret = '';
for (let i = 0; i < num; i++) {
ret += 't125 ';
}
if (time > 0)
ret += `t${time} `;
return ret;
} else
return '';
};
let dropRecordingTime = false;
export const midiMessageHandler = async (data: MidiMessage) => {
if (data.messageType === 'ControlChange') {
commonStore.setInstrumentType(Math.round(data.value / 127 * (InstrumentTypeNameMap.length - 1)));
return;
}
if (commonStore.recordingTrackId) {
if (dropRecordingTime && data.messageType === 'ElapsedTime') {
dropRecordingTime = false;
return;
}
data = {
...data,
instrument: commonStore.instrumentType
};
commonStore.setRecordingRawContent([...commonStore.recordingRawContent, data]);
commonStore.setRecordingContent(commonStore.recordingContent + midiMessageToToken(data));
//TODO data.channel = data.instrument;
PlayNote(data);
}
};
type TrackProps = {
id: string;
right: number;
scale: number;
isSelected: boolean;
onSelect: (id: string) => void;
};
const Track: React.FC<TrackProps> = observer(({
id,
right,
scale,
isSelected,
onSelect
}) => {
const { t } = useTranslation();
const trackIndex = commonStore.tracks.findIndex(t => t.id === id)!;
const track = commonStore.tracks[trackIndex];
const trackClass = isSelected ? 'bg-blue-600' : (commonStore.settings.darkMode ? 'bg-blue-900' : 'bg-gray-700');
const controlX = useRef(0);
let trackName = t('Track') + ' ' + id;
if (track.mainInstrument)
trackName = t('Track') + ' - ' + t('Piano is the main instrument')!.replace(t('Piano')!, t(track.mainInstrument)) + (track.content && (' - ' + track.content));
else if (track.content)
trackName = t('Track') + ' - ' + track.content;
return (
<Draggable
axis="x"
bounds={{ left: 0, right }}
grid={[snapValue, snapValue]}
position={{
x: (track.offsetTime - commonStore.trackCurrentTime) / (baseMoveTime * scale) * snapValue,
y: 0
}}
onStart={(e, data) => {
controlX.current = data.lastX;
}}
onStop={(e, data) => {
const delta = data.lastX - controlX.current;
let offsetTime = Math.round(Math.round(delta / snapValue * baseMoveTime * scale) / minimalMoveTime) * minimalMoveTime;
offsetTime = Math.min(Math.max(
offsetTime,
-track.offsetTime), commonStore.trackTotalTime - track.offsetTime);
const tracks = commonStore.tracks.slice();
tracks[trackIndex].offsetTime += offsetTime;
commonStore.setTracks(tracks);
refreshTracksTotalTime();
}}
>
<div
className={`p-1 cursor-move rounded whitespace-nowrap overflow-hidden ${trackClass}`}
style={{
width: `${Math.max(minimalTrackWidth,
track.contentTime / (baseMoveTime * scale) * snapValue
)}px`
}}
onClick={() => onSelect(id)}
>
<span className="text-white">{trackName}</span>
</div>
</Draggable>
);
});
const AudiotrackEditor: FC<{ setPrompt: (prompt: string) => void }> = observer(({ setPrompt }) => {
const { t } = useTranslation();
const viewControlsContainerRef = useRef<HTMLDivElement>(null);
const currentTimeControlRef = useRef<HTMLDivElement>(null);
const playStartTimeControlRef = useRef<HTMLDivElement>(null);
const tracksEndLineRef = useRef<HTMLDivElement>(null);
const tracksRef = useRef<HTMLDivElement>(null);
const toolbarRef = useRef<HTMLDivElement>(null);
const toolbarButtonRef = useRef<HTMLDivElement>(null);
const toolbarSliderRef = useRef<HTMLInputElement>(null);
const contentPreviewRef = useRef<HTMLDivElement>(null);
const [refreshRef, setRefreshRef] = useState(false);
const windowSize = useWindowSize();
const scale = (scaleMin + scaleMax) - commonStore.trackScale;
const [selectedTrackId, setSelectedTrackId] = useState<string>('');
const playStartTimeControlX = useRef(0);
const selectedTrack = selectedTrackId ? commonStore.tracks.find(t => t.id === selectedTrackId) : undefined;
useEffect(() => {
if (toolbarSliderRef.current && toolbarSliderRef.current.parentElement)
toolbarSliderRef.current.parentElement.style.removeProperty('--fui-Slider--steps-percent');
}, []);
const scrollContentToBottom = () => {
if (contentPreviewRef.current)
contentPreviewRef.current.scrollTop = contentPreviewRef.current.scrollHeight;
};
useEffect(() => {
scrollContentToBottom();
}, [commonStore.recordingContent]);
useEffect(() => {
setRefreshRef(!refreshRef);
}, [windowSize, commonStore.tracks]);
const viewControlsContainerWidth = (toolbarRef.current && toolbarButtonRef.current && toolbarSliderRef.current) ?
toolbarRef.current.clientWidth - toolbarButtonRef.current.clientWidth - toolbarSliderRef.current.clientWidth - 16 // 16 = ml-2 mr-2
: 0;
const tracksWidth = viewControlsContainerWidth;
const timeOfTracksWidth = Math.floor(tracksWidth / snapValue) // number of moves
* baseMoveTime * scale;
const currentTimeControlWidth = (timeOfTracksWidth < commonStore.trackTotalTime)
? timeOfTracksWidth / commonStore.trackTotalTime * viewControlsContainerWidth
: 0;
const playStartTimeControlPosition = (commonStore.trackPlayStartTime - commonStore.trackCurrentTime) / (baseMoveTime * scale) * snapValue;
const tracksEndPosition = (commonStore.trackTotalTime - commonStore.trackCurrentTime) / (baseMoveTime * scale) * snapValue;
const moveableTracksWidth = (tracksEndLineRef.current && viewControlsContainerRef.current &&
((tracksEndLineRef.current.getBoundingClientRect().left - (viewControlsContainerRef.current.getBoundingClientRect().left + trackInitOffsetPx)) > 0))
? tracksEndLineRef.current.getBoundingClientRect().left - (viewControlsContainerRef.current.getBoundingClientRect().left + trackInitOffsetPx)
: Infinity;
return (
<div className="flex flex-col gap-2 overflow-hidden" style={{ width: '80vw', height: '80vh' }}>
<div className="mx-auto">
<Text size={100}>{`${commonStore.trackPlayStartTime} ms / ${commonStore.trackTotalTime} ms`}</Text>
</div>
<div className="flex pb-2 border-b" ref={toolbarRef}>
<div className="flex gap-2" ref={toolbarButtonRef}>
<ToolTipButton disabled desc={t('Play All') + ' (Unavailable)'} icon={<Play16Regular />} />
<ToolTipButton desc={t('Clear All')} icon={<Delete16Regular />} onClick={() => {
commonStore.setTracks([]);
commonStore.setTrackScale(1);
commonStore.setTrackTotalTime(tracksMinimalTotalTime);
commonStore.setTrackCurrentTime(0);
commonStore.setTrackPlayStartTime(0);
}} />
</div>
<div className="grow">
<div className="flex flex-col ml-2 mr-2" ref={viewControlsContainerRef}>
<div className="relative">
<Tooltip content={`${commonStore.trackTotalTime} ms`} showDelay={0} hideDelay={0}
relationship="description">
<div className="border-l absolute"
ref={tracksEndLineRef}
style={{
height: (tracksRef.current && commonStore.tracks.length > 0)
? tracksRef.current.clientHeight - arrowIconToTracks
: 0,
top: `${topToArrowIcon + arrowIconToTracks}px`,
left: `${tracksEndPosition + trackInitOffsetPx - pixelFix}px`
}} />
</Tooltip>
</div>
<Draggable axis="x" bounds={{
left: 0,
right: viewControlsContainerWidth - currentTimeControlWidth
}}
position={{
x: commonStore.trackCurrentTime / commonStore.trackTotalTime * viewControlsContainerWidth,
y: 0
}}
onDrag={(e, data) => {
setTimeout(() => {
let offset = 0;
if (currentTimeControlRef.current) {
const match = currentTimeControlRef.current.style.transform.match(/translate\((.+)px,/);
if (match)
offset = parseFloat(match[1]);
}
const offsetTime = commonStore.trackTotalTime / viewControlsContainerWidth * offset;
commonStore.setTrackCurrentTime(offsetTime);
}, 1);
}}
>
<div ref={currentTimeControlRef}
className={classnames('h-2 cursor-move rounded', commonStore.settings.darkMode ? 'bg-neutral-600' : 'bg-gray-700')}
style={{ width: currentTimeControlWidth }} />
</Draggable>
<div className={classnames(
'flex',
(playStartTimeControlPosition < 0 || playStartTimeControlPosition > viewControlsContainerWidth)
&& 'hidden'
)}>
<Draggable axis="x" bounds={{
left: 0,
right: (playStartTimeControlRef.current)
? Math.min(viewControlsContainerWidth - playStartTimeControlRef.current.clientWidth, moveableTracksWidth)
: 0
}}
grid={[snapValue, snapValue]}
position={{ x: playStartTimeControlPosition, y: 0 }}
onStart={(e, data) => {
playStartTimeControlX.current = data.lastX;
}}
onStop={(e, data) => {
const delta = data.lastX - playStartTimeControlX.current;
let offsetTime = Math.round(Math.round(delta / snapValue * baseMoveTime * scale) / minimalMoveTime) * minimalMoveTime;
offsetTime = Math.min(Math.max(
offsetTime,
-commonStore.trackPlayStartTime), commonStore.trackTotalTime - commonStore.trackPlayStartTime);
commonStore.setTrackPlayStartTime(commonStore.trackPlayStartTime + offsetTime);
}}
>
<div className="relative cursor-move"
ref={playStartTimeControlRef}>
<ArrowAutofitWidth20Regular />
<div
className={classnames('border-l absolute', commonStore.settings.darkMode ? 'border-white' : 'border-gray-700')}
style={{
height: (tracksRef.current && commonStore.tracks.length > 0)
? tracksRef.current.clientHeight
: 0,
top: '50%',
left: `calc(50% - ${pixelFix}px)`
}} />
</div>
</Draggable>
</div>
</div>
</div>
<Tooltip content={t('Scale View')! + ': ' + commonStore.trackScale} showDelay={0} hideDelay={0}
relationship="description">
<Slider ref={toolbarSliderRef} value={commonStore.trackScale} step={scaleMin} max={scaleMax} min={scaleMin}
onChange={(e, data) => {
commonStore.setTrackScale(data.value);
}}
/>
</Tooltip>
</div>
<div className="flex flex-col overflow-y-auto gap-1" ref={tracksRef}>
{commonStore.tracks.map(track =>
<div key={track.id} className="flex gap-2 pb-1 border-b">
<div className="flex gap-1 border-r h-7">
<ToolTipButton desc={commonStore.recordingTrackId === track.id ? t('Stop') : t('Record')}
disabled={commonStore.platform === 'web'}
icon={commonStore.recordingTrackId === track.id ? <Stop16Filled /> : <Record16Regular />}
size="small" shape="circular" appearance="subtle"
onClick={() => {
flushMidiRecordingContent();
commonStore.setPlayingTrackId('');
if (commonStore.recordingTrackId === track.id) {
commonStore.setRecordingTrackId('');
} else {
if (commonStore.activeMidiDeviceIndex === -1) {
toast(t('Please select a MIDI device first'), { type: 'warning' });
return;
}
dropRecordingTime = true;
setSelectedTrackId(track.id);
commonStore.setRecordingTrackId(track.id);
commonStore.setRecordingContent(track.content);
commonStore.setRecordingRawContent(track.rawContent.slice());
}
}} />
<ToolTipButton disabled
desc={commonStore.playingTrackId === track.id ? t('Stop') : t('Play') + ' (Unavailable)'}
icon={commonStore.playingTrackId === track.id ? <Pause16Regular /> : <Play16Filled />}
size="small" shape="circular" appearance="subtle"
onClick={() => {
flushMidiRecordingContent();
commonStore.setRecordingTrackId('');
if (commonStore.playingTrackId === track.id) {
commonStore.setPlayingTrackId('');
} else {
setSelectedTrackId(track.id);
commonStore.setPlayingTrackId(track.id);
}
}} />
<ToolTipButton desc={t('Delete')} icon={<Delete16Regular />} size="small" shape="circular"
appearance="subtle" onClick={() => {
const tracks = commonStore.tracks.slice().filter(t => t.id !== track.id);
commonStore.setTracks(tracks);
refreshTracksTotalTime();
}} />
</div>
<div className="relative grow overflow-hidden">
<div className="absolute" style={{ left: -0 }}>
<Track
id={track.id}
scale={scale}
right={Math.min(tracksWidth, moveableTracksWidth)}
isSelected={selectedTrackId === track.id}
onSelect={setSelectedTrackId}
/>
</div>
</div>
</div>)}
<div className="flex justify-between items-center">
<div className="flex gap-1">
<Button icon={<Add16Regular />} size="small" shape="circular"
appearance="subtle"
disabled={commonStore.platform === 'web'}
onClick={() => {
commonStore.setTracks([...commonStore.tracks, {
id: uuid(),
mainInstrument: '',
content: '',
rawContent: [],
offsetTime: 0,
contentTime: 0
}]);
}}>
{t('New Track')}
</Button>
<Button icon={<ArrowUpload16Regular />} size="small" shape="circular"
appearance="subtle"
onClick={() => {
if (commonStore.status.status === ModelStatus.Offline && !commonStore.settings.apiUrl && commonStore.platform !== 'web') {
toast(t('Please click the button in the top right corner to start the model'), { type: 'warning' });
return;
}
OpenOpenFileDialog('*.mid').then(async filePath => {
if (!filePath)
return;
let blob: Blob;
if (commonStore.platform === 'web')
blob = (filePath as unknown as { blob: Blob }).blob;
else
blob = await fetch(absPathAsset(filePath)).then(r => r.blob());
const bodyForm = new FormData();
bodyForm.append('file_data', blob);
fetch(getServerRoot(commonStore.getCurrentModelConfig().apiParameters.apiPort) + '/midi-to-text', {
method: 'POST',
body: bodyForm
}).then(async r => {
if (r.status === 200) {
const text = (await r.json()).text as string;
const rawContent = text.split(' ').map(tokenToMidiMessage).filter(m => m) as MidiMessage[];
const tracks = commonStore.tracks.slice();
tracks.push({
id: uuid(),
mainInstrument: getMidiRawContentMainInstrument(rawContent),
content: text,
rawContent: rawContent,
offsetTime: 0,
contentTime: getMidiRawContentTime(rawContent)
});
commonStore.setTracks(tracks);
refreshTracksTotalTime();
} else {
toast(r.statusText + '\n' + (await r.text()), {
type: 'error'
});
}
}
).catch(e => {
toast(t('Error') + ' - ' + (e.message || e), { type: 'error', autoClose: 2500 });
});
}).catch(e => {
toast(t('Error') + ' - ' + (e.message || e), { type: 'error', autoClose: 2500 });
});
}}>
{t('Import MIDI')}
</Button>
</div>
<Text size={100}>
{t('Select a track to preview the content')}
</Text>
</div>
</div>
<div className="grow"></div>
{selectedTrack &&
<Card size="small" appearance="outline" style={{ minHeight: '150px', maxHeight: '200px' }}>
<div className="flex flex-col gap-1 overflow-hidden">
<Text size={100}>{`${t('Start Time')}: ${selectedTrack.offsetTime} ms`}</Text>
<Text size={100}>{`${t('Content Duration')}: ${selectedTrack.contentTime} ms`}</Text>
<div className="overflow-y-auto overflow-x-hidden" ref={contentPreviewRef}>
{selectedTrackId === commonStore.recordingTrackId
? commonStore.recordingContent
: selectedTrack.content}
</div>
</div>
</Card>
}
{
commonStore.platform !== 'web' &&
<div className="flex gap-2 items-end mx-auto">
{t('Current Instrument') + ':'}
{InstrumentTypeNameMap.map((name, i) =>
<Text key={name} style={{ whiteSpace: 'nowrap' }}
className={commonStore.instrumentType === i ? 'text-blue-600' : ''}
weight={commonStore.instrumentType === i ? 'bold' : 'regular'}
size={commonStore.instrumentType === i ? 300 : 100}
>{t(name)}</Text>)}
</div>
}
<DialogTrigger disableButtonEnhancement>
<Button icon={<MusicNote220Regular />} style={{ minHeight: '32px' }} onClick={() => {
flushMidiRecordingContent();
commonStore.setRecordingTrackId('');
commonStore.setPlayingTrackId('');
const timestamp = [];
const sortedTracks = commonStore.tracks.slice().sort((a, b) => a.offsetTime - b.offsetTime);
for (const track of sortedTracks) {
timestamp.push(track.offsetTime);
let accContentTime = 0;
for (const msg of track.rawContent) {
if (msg.messageType === 'ElapsedTime') {
accContentTime += msg.value;
timestamp.push(track.offsetTime + accContentTime);
}
}
}
const sortedTimestamp = timestamp.slice().sort((a, b) => a - b);
const globalMessages: MidiMessage[] = sortedTimestamp.reduce((messages, current, i) =>
[...messages, {
messageType: 'ElapsedTime',
value: current - (i === 0 ? 0 : sortedTimestamp[i - 1])
} as MidiMessage]
, [] as MidiMessage[]);
for (const track of sortedTracks) {
let currentTime = track.offsetTime;
let accContentTime = 0;
for (const msg of track.rawContent) {
if (msg.messageType === 'ElapsedTime') {
accContentTime += msg.value;
currentTime = track.offsetTime + accContentTime;
} else if (msg.messageType === 'NoteOn' || msg.messageType === 'NoteOff') {
const insertIndex = sortedTimestamp.findIndex(t => t >= currentTime);
globalMessages.splice(insertIndex + 1, 0, msg);
sortedTimestamp.splice(insertIndex + 1, 0, 0); // placeholder
}
}
}
const result = ('<pad> ' + globalMessages.map(midiMessageToToken).join('')).trim();
commonStore.setCompositionSubmittedPrompt(result);
setPrompt(result);
}}>
{t('Save to generation area')}
</Button>
</DialogTrigger>
</div>
);
});
export default AudiotrackEditor;

View File

@@ -16,8 +16,11 @@ import {
Attach16Regular,
Delete28Regular,
Dismiss16Regular,
Dismiss24Regular,
RecordStop28Regular,
Save28Regular
SaveRegular,
TextAlignJustify24Regular,
TextAlignJustifyRotate9024Regular
} from '@fluentui/react-icons';
import { CopyButton } from '../components/CopyButton';
import { ReadButton } from '../components/ReadButton';
@@ -25,49 +28,22 @@ import { toast } from 'react-toastify';
import { WorkHeader } from '../components/WorkHeader';
import { DialogButton } from '../components/DialogButton';
import { OpenFileFolder, OpenOpenFileDialog, OpenSaveFileDialog } from '../../wailsjs/go/backend_golang/App';
import { absPathAsset, bytesToReadable, toastWithButton } from '../utils';
import { PresetsButton } from './PresetsManager/PresetsButton';
import { absPathAsset, bytesToReadable, getServerRoot, setActivePreset, toastWithButton } from '../utils';
import { useMediaQuery } from 'usehooks-ts';
import { botName, ConversationMessage, MessageType, userName, welcomeUuid } from '../types/chat';
import { Labeled } from '../components/Labeled';
import { ValuedSlider } from '../components/ValuedSlider';
import { PresetsButton } from './PresetsManager/PresetsButton';
import { webOpenOpenFileDialog } from '../utils/web-file-operations';
export const userName = 'M E';
export const botName = 'A I';
let chatSseControllers: {
[id: string]: AbortController
} = {};
export const welcomeUuid = 'welcome';
export enum MessageType {
Normal,
Error
}
export type Side = 'left' | 'right'
export type Color = 'neutral' | 'brand' | 'colorful'
export type MessageItem = {
sender: string,
type: MessageType,
color: Color,
avatarImg?: string,
time: string,
content: string,
side: Side,
done: boolean
}
export type Conversation = {
[uuid: string]: MessageItem
}
export type Role = 'assistant' | 'user' | 'system';
export type ConversationMessage = {
role: Role;
content: string;
}
let chatSseControllers: { [id: string]: AbortController } = {};
const MoreUtilsButton: FC<{ uuid: string, setEditing: (editing: boolean) => void }> = observer(({
const MoreUtilsButton: FC<{
uuid: string,
setEditing: (editing: boolean) => void
}> = observer(({
uuid,
setEditing
}) => {
@@ -91,13 +67,15 @@ const MoreUtilsButton: FC<{ uuid: string, setEditing: (editing: boolean) => void
onClick={() => {
commonStore.conversationOrder.splice(commonStore.conversationOrder.indexOf(uuid), 1);
delete commonStore.conversation[uuid];
commonStore.setAttachment(uuid, null);
}} />
</MenuPopover>
</Menu>;
});
const ChatMessageItem: FC<{
uuid: string, onSubmit: (message: string | null, answerId: string | null,
uuid: string,
onSubmit: (message: string | null, answerId: string | null,
startUuid: string | null, endUuid: string | null, includeEndUuid: boolean) => void
}> = observer(({ uuid, onSubmit }) => {
const { t } = useTranslation();
@@ -157,7 +135,21 @@ const ChatMessageItem: FC<{
)}
>
{!editing ?
<MarkdownRender>{messageItem.content}</MarkdownRender> :
<div className="flex flex-col">
<MarkdownRender>{messageItem.content}</MarkdownRender>
{uuid in commonStore.attachments &&
<div className="flex grow">
<div className="grow" />
<ToolTipButton className="whitespace-nowrap"
text={
commonStore.attachments[uuid][0].name.replace(
new RegExp('(^[^\\.]{5})[^\\.]+'), '$1...')
}
desc={`${commonStore.attachments[uuid][0].name} (${bytesToReadable(commonStore.attachments[uuid][0].size)})`}
size="small" shape="circular" appearance="secondary" />
</div>
}
</div> :
<Textarea ref={textareaRef}
className="grow"
style={{ minWidth: 0 }}
@@ -202,11 +194,125 @@ const ChatMessageItem: FC<{
</div>;
});
const SidePanel: FC = observer(() => {
const [t] = useTranslation();
const mq = useMediaQuery('(min-width: 640px)');
const params = commonStore.chatParams;
return <div
className={classnames(
'flex flex-col gap-1 h-full flex-shrink-0 transition-width duration-300 ease-in-out',
commonStore.sidePanelCollapsed ? 'w-0' : (mq ? 'w-64' : 'w-full'),
!commonStore.sidePanelCollapsed && 'ml-1')
}>
<div className="flex m-1">
<div className="grow" />
<PresetsButton tab="Chat" size="medium" shape="circular" appearance="subtle" />
<Button size="medium" shape="circular" appearance="subtle" icon={<Dismiss24Regular />}
onClick={() => commonStore.setSidePanelCollapsed(true)}
/>
</div>
<div className="flex flex-col gap-1 overflow-x-hidden overflow-y-auto p-1">
<Labeled flex breakline label={t('Max Response Token')}
desc={t('By default, the maximum number of tokens that can be answered in a single response, it can be changed by the user by specifying API parameters.')}
content={
<ValuedSlider value={params.maxResponseToken} min={100} max={8100}
step={100}
input
onChange={(e, data) => {
commonStore.setChatParams({
maxResponseToken: data.value
});
}} />
} />
<Labeled flex breakline label={t('Temperature')}
desc={t('Sampling temperature, it\'s like giving alcohol to a model, the higher the stronger the randomness and creativity, while the lower, the more focused and deterministic it will be.')}
content={
<ValuedSlider value={params.temperature} min={0} max={2} step={0.1}
input
onChange={(e, data) => {
commonStore.setChatParams({
temperature: data.value
});
}} />
} />
<Labeled flex breakline label={t('Top_P')}
desc={t('Just like feeding sedatives to the model. Consider the results of the top n% probability mass, 0.1 considers the top 10%, with higher quality but more conservative, 1 considers all results, with lower quality but more diverse.')}
content={
<ValuedSlider value={params.topP} min={0} max={1} step={0.1} input
onChange={(e, data) => {
commonStore.setChatParams({
topP: data.value
});
}} />
} />
<Labeled flex breakline label={t('Presence Penalty')}
desc={t('Positive values penalize new tokens based on whether they appear in the text so far, increasing the model\'s likelihood to talk about new topics.')}
content={
<ValuedSlider value={params.presencePenalty} min={0} max={2}
step={0.1} input
onChange={(e, data) => {
commonStore.setChatParams({
presencePenalty: data.value
});
}} />
} />
<Labeled flex breakline label={t('Frequency Penalty')}
desc={t('Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model\'s likelihood to repeat the same line verbatim.')}
content={
<ValuedSlider value={params.frequencyPenalty} min={0} max={2}
step={0.1} input
onChange={(e, data) => {
commonStore.setChatParams({
frequencyPenalty: data.value
});
}} />
} />
</div>
<div className="grow" />
{/*<Button*/}
{/* icon={<FolderOpenVerticalRegular />}*/}
{/* onClick={() => {*/}
{/* }}>*/}
{/* {t('Load Conversation')}*/}
{/*</Button>*/}
<Button
icon={<SaveRegular />}
onClick={() => {
let savedContent: string = '';
const isWorldModel = commonStore.getCurrentModelConfig().modelParameters.modelName.toLowerCase().includes('world');
const user = isWorldModel ? 'User' : 'Bob';
const bot = isWorldModel ? 'Assistant' : 'Alice';
commonStore.conversationOrder.forEach((uuid) => {
if (uuid === welcomeUuid)
return;
const messageItem = commonStore.conversation[uuid];
if (messageItem.type !== MessageType.Error) {
savedContent += `${messageItem.sender === userName ? user : bot}: ${messageItem.content}\n\n`;
}
});
OpenSaveFileDialog('*.txt', 'conversation.txt', savedContent).then((path) => {
if (path)
toastWithButton(t('Conversation Saved'), t('Open'), () => {
OpenFileFolder(path, false);
});
}).catch(e => {
toast(t('Error') + ' - ' + (e.message || e), { type: 'error', autoClose: 2500 });
});
}}>
{t('Save Conversation')}
</Button>
</div>;
});
const ChatPanel: FC = observer(() => {
const { t } = useTranslation();
const bodyRef = useRef<HTMLDivElement>(null);
const inputRef = useRef<HTMLTextAreaElement>(null);
const mq = useMediaQuery('(min-width: 640px)');
if (commonStore.sidePanelCollapsed === 'auto')
commonStore.setSidePanelCollapsed(!mq);
const currentConfig = commonStore.getCurrentModelConfig();
const apiParams = currentConfig.apiParameters;
const port = apiParams.apiPort;
@@ -228,7 +334,7 @@ const ChatPanel: FC = observer(() => {
color: 'colorful',
avatarImg: logo,
time: new Date().toISOString(),
content: t('Hello! I\'m RWKV, an open-source and commercially usable large language model.'),
content: commonStore.platform === 'web' ? t('Hello, what can I do for you?') : t('Hello! I\'m RWKV, an open-source and commercially usable large language model.'),
side: 'left',
done: true
}
@@ -245,7 +351,7 @@ const ChatPanel: FC = observer(() => {
e.stopPropagation();
if (e.type === 'click' || (e.keyCode === 13 && !e.shiftKey)) {
e.preventDefault();
if (commonStore.status.status === ModelStatus.Offline && !commonStore.settings.apiUrl) {
if (commonStore.status.status === ModelStatus.Offline && !commonStore.settings.apiUrl && commonStore.platform !== 'web') {
toast(t('Please click the button in the top right corner to start the model'), { type: 'warning' });
return;
}
@@ -275,6 +381,11 @@ const ChatPanel: FC = observer(() => {
commonStore.setConversation(commonStore.conversation);
commonStore.conversationOrder.push(newId);
commonStore.setConversationOrder(commonStore.conversationOrder);
if (commonStore.currentTempAttachment) {
commonStore.setAttachment(newId, [commonStore.currentTempAttachment]);
commonStore.setCurrentTempAttachment(null);
}
}
let startIndex = startUuid ? commonStore.conversationOrder.indexOf(startUuid) : 0;
@@ -282,20 +393,21 @@ const ChatPanel: FC = observer(() => {
let targetRange = commonStore.conversationOrder.slice(startIndex, endIndex);
const messages: ConversationMessage[] = [];
if (commonStore.attachmentContent) {
messages.push({
role: 'user',
content: t('The content of file') + ` "${commonStore.attachmentName}" `
+ t('is as follows. When replying to me, consider the file content and respond accordingly:')
+ '\n\n' + commonStore.attachmentContent
});
messages.push({ role: 'user', content: t('What\'s the file name') });
messages.push({ role: 'assistant', content: t('The file name is: ') + commonStore.attachmentName });
}
targetRange.forEach((uuid, index) => {
if (uuid === welcomeUuid)
return;
const messageItem = commonStore.conversation[uuid];
if (uuid in commonStore.attachments) {
const attachment = commonStore.attachments[uuid][0];
messages.push({
role: 'user',
content: t('The content of file') + ` "${attachment.name}" `
+ t('is as follows. When replying to me, consider the file content and respond accordingly:')
+ '\n\n' + attachment.content
});
messages.push({ role: 'user', content: t('What\'s the file name') });
messages.push({ role: 'assistant', content: t('The file name is: ') + attachment.name });
}
if (messageItem.done && messageItem.type === MessageType.Normal && messageItem.sender === userName) {
messages.push({ role: 'user', content: messageItem.content });
} else if (messageItem.done && messageItem.type === MessageType.Normal && messageItem.sender === botName) {
@@ -323,10 +435,8 @@ const ChatPanel: FC = observer(() => {
let answer = '';
const chatSseController = new AbortController();
chatSseControllers[answerId] = chatSseController;
fetchEventSource( // https://api.openai.com/v1/chat/completions || http://127.0.0.1:${port}/chat/completions
commonStore.settings.apiUrl ?
commonStore.settings.apiUrl + '/v1/chat/completions' :
`http://127.0.0.1:${port}/chat/completions`,
fetchEventSource( // https://api.openai.com/v1/chat/completions || http://127.0.0.1:${port}/v1/chat/completions
getServerRoot(port, true) + '/v1/chat/completions',
{
method: 'POST',
headers: {
@@ -337,16 +447,20 @@ const ChatPanel: FC = observer(() => {
messages,
stream: true,
model: commonStore.settings.apiChatModelName, // 'gpt-3.5-turbo'
temperature: apiParams.temperature,
top_p: apiParams.topP,
user_name: commonStore.activePreset?.userName,
assistant_name: commonStore.activePreset?.assistantName,
presystem: commonStore.activePreset?.presystem
temperature: commonStore.chatParams.temperature,
top_p: commonStore.chatParams.topP,
presence_penalty: commonStore.chatParams.presencePenalty,
frequency_penalty: commonStore.chatParams.frequencyPenalty,
user_name: commonStore.activePreset?.userName || undefined,
assistant_name: commonStore.activePreset?.assistantName || undefined,
presystem: commonStore.activePreset?.presystem && undefined
}),
signal: chatSseController?.signal,
onmessage(e) {
scrollToBottom();
if (e.data.trim() === '[DONE]') {
if (answerId! in chatSseControllers)
delete chatSseControllers[answerId!];
commonStore.conversation[answerId!].done = true;
commonStore.conversation[answerId!].content = commonStore.conversation[answerId!].content.trim();
commonStore.setConversation(commonStore.conversation);
@@ -360,6 +474,8 @@ const ChatPanel: FC = observer(() => {
console.debug('json error', error);
return;
}
if (data.model)
commonStore.setLastModelName(data.model);
if (data.choices && Array.isArray(data.choices) && data.choices.length > 0) {
answer += data.choices[0]?.delta?.content || '';
commonStore.conversation[answerId!].content = answer;
@@ -381,6 +497,8 @@ const ChatPanel: FC = observer(() => {
console.log('Connection closed');
},
onerror(err) {
if (answerId! in chatSseControllers)
delete chatSseControllers[answerId!];
commonStore.conversation[answerId!].type = MessageType.Error;
commonStore.conversation[answerId!].done = true;
err = err.message || err;
@@ -395,172 +513,175 @@ const ChatPanel: FC = observer(() => {
}, []);
return (
<div className="flex flex-col w-full grow gap-4 pt-4 overflow-hidden">
<div ref={bodyRef} className="grow overflow-y-scroll overflow-x-hidden pr-2">
{commonStore.conversationOrder.map(uuid =>
<ChatMessageItem key={uuid} uuid={uuid} onSubmit={onSubmit} />
)}
</div>
<div className={classnames('flex items-end', mq ? 'gap-2' : '')}>
<PresetsButton tab="Chat" size={mq ? 'large' : 'small'} shape="circular" appearance="subtle" />
<DialogButton tooltip={t('Clear')}
icon={<Delete28Regular />}
size={mq ? 'large' : 'small'} shape="circular" appearance="subtle" title={t('Clear')}
contentText={t('Are you sure you want to clear the conversation? It cannot be undone.')}
onConfirm={() => {
if (generating) {
for (const id in chatSseControllers) {
chatSseControllers[id].abort();
<div className="flex h-full grow pt-4 overflow-hidden">
<div className="relative flex flex-col w-full grow gap-4 overflow-hidden">
<Button className="absolute top-1 right-1" size="medium" shape="circular" appearance="subtle"
style={{ zIndex: 1 }}
icon={commonStore.sidePanelCollapsed ? <TextAlignJustify24Regular /> : <TextAlignJustifyRotate9024Regular />}
onClick={() => commonStore.setSidePanelCollapsed(!commonStore.sidePanelCollapsed)} />
<div ref={bodyRef} className="grow overflow-y-scroll overflow-x-hidden pr-2">
{commonStore.conversationOrder.map(uuid =>
<ChatMessageItem key={uuid} uuid={uuid} onSubmit={onSubmit} />
)}
</div>
<div className={classnames('flex items-end', mq ? 'gap-2' : '')}>
<DialogButton tooltip={t('Clear')}
icon={<Delete28Regular />}
size={mq ? 'large' : 'small'} shape="circular" appearance="subtle" title={t('Clear')}
contentText={t('Are you sure you want to clear the conversation? It cannot be undone.')}
onConfirm={() => {
if (generating) {
for (const id in chatSseControllers) {
chatSseControllers[id].abort();
}
chatSseControllers = {};
}
chatSseControllers = {};
}
commonStore.setConversation({});
commonStore.setConversationOrder([]);
}} />
<div className="relative flex grow">
<Textarea
ref={inputRef}
style={{ minWidth: 0 }}
className="grow"
resize="vertical"
placeholder={t('Type your message here')!}
value={commonStore.currentInput}
onChange={(e) => commonStore.setCurrentInput(e.target.value)}
onKeyDown={handleKeyDownOrClick}
/>
<div className="absolute right-2 bottom-2">
{!commonStore.attachmentContent ?
<ToolTipButton
desc={commonStore.attachmentUploading ?
t('Uploading Attachment') :
t('Add An Attachment (Accepts pdf, txt)')}
icon={commonStore.attachmentUploading ?
<ArrowClockwise16Regular className="animate-spin" />
: <Attach16Regular />}
size="small" shape="circular" appearance="secondary"
onClick={() => {
if (commonStore.status.status === ModelStatus.Offline && !commonStore.settings.apiUrl) {
toast(t('Please click the button in the top right corner to start the model'), { type: 'warning' });
return;
}
if (commonStore.attachmentUploading)
return;
OpenOpenFileDialog('*.txt;*.pdf').then(async filePath => {
if (!filePath)
setActivePreset(commonStore.activePreset);
}} />
<div className="relative flex grow">
<Textarea
ref={inputRef}
style={{ minWidth: 0 }}
className="grow"
resize="vertical"
placeholder={t('Type your message here')!}
value={commonStore.currentInput}
onChange={(e) => commonStore.setCurrentInput(e.target.value)}
onKeyDown={handleKeyDownOrClick}
/>
<div className="absolute right-2 bottom-2">
{!commonStore.currentTempAttachment ?
<ToolTipButton
desc={commonStore.attachmentUploading ?
t('Processing Attachment') :
t('Add An Attachment (Accepts pdf, txt)')}
icon={commonStore.attachmentUploading ?
<ArrowClockwise16Regular className="animate-spin" />
: <Attach16Regular />}
size="small" shape="circular" appearance="secondary"
onClick={() => {
if (commonStore.attachmentUploading)
return;
commonStore.setAttachmentUploading(true);
// Both are slow. Communication between frontend and backend is slow. Use AssetServer Handler to read the file.
// const blob = new Blob([atob(info.content as unknown as string)]); // await fetch(`data:application/octet-stream;base64,${info.content}`).then(r => r.blob());
const blob = await fetch(absPathAsset(filePath)).then(r => r.blob());
const attachmentName = filePath.split(/[\\/]/).pop();
const urlPath = `/file-to-text?file_name=${attachmentName}`;
const bodyForm = new FormData();
bodyForm.append('file_data', blob, attachmentName);
fetch(commonStore.settings.apiUrl ?
commonStore.settings.apiUrl + urlPath :
`http://127.0.0.1:${port}${urlPath}`, {
method: 'POST',
body: bodyForm
}).then(async r => {
if (r.status === 200) {
const pages = (await r.json()).pages as any[];
let attachmentContent: string;
if (pages.length === 1)
attachmentContent = pages[0].page_content;
else
attachmentContent = pages.map((p, i) => `Page ${i + 1}:\n${p.page_content}`).join('\n\n');
commonStore.setAttachmentName(attachmentName!);
commonStore.setAttachmentSize(blob.size);
commonStore.setAttachmentContent(attachmentContent);
} else {
toast(r.statusText + '\n' + (await r.text()), {
type: 'error'
const filterPattern = '*.txt;*.pdf';
const setUploading = () => commonStore.setAttachmentUploading(true);
// actually, status of web platform is always Offline
if (commonStore.platform === 'web' || commonStore.status.status === ModelStatus.Offline || currentConfig.modelParameters.device === 'WebGPU') {
webOpenOpenFileDialog(filterPattern, setUploading).then(webReturn => {
if (webReturn.content)
commonStore.setCurrentTempAttachment(
{
name: webReturn.blob.name,
size: webReturn.blob.size,
content: webReturn.content
});
else
toast(t('File is empty'), {
type: 'info',
autoClose: 1000
});
}
}).catch(e => {
toast(t('Error') + ' - ' + (e.message || e), { type: 'error', autoClose: 2500 });
}).finally(() => {
commonStore.setAttachmentUploading(false);
}
).catch(e => {
commonStore.setAttachmentUploading(false);
toast(t('Error') + ' - ' + (e.message || e), { type: 'error', autoClose: 2500 });
});
}).catch(e => {
toast(t('Error') + ' - ' + (e.message || e), { type: 'error', autoClose: 2500 });
});
}}
/> :
<div>
<ToolTipButton
text={
commonStore.attachmentName.replace(
new RegExp('(^[^\\.]{5})[^\\.]+'), '$1...')
}
desc={`${commonStore.attachmentName} (${bytesToReadable(commonStore.attachmentSize)})`}
size="small" shape="circular" appearance="secondary" />
<ToolTipButton desc={t('Remove Attachment')}
icon={<Dismiss16Regular />}
size="small" shape="circular" appearance="subtle"
onClick={() => {
commonStore.setAttachmentName('');
commonStore.setAttachmentSize(0);
commonStore.setAttachmentContent('');
}} />
</div>
}
</div>
</div>
<ToolTipButton desc={generating ? t('Stop') : t('Send')}
icon={generating ? <RecordStop28Regular /> : <ArrowCircleUp28Regular />}
size={mq ? 'large' : 'small'} shape="circular" appearance="subtle"
onClick={(e) => {
if (generating) {
for (const id in chatSseControllers) {
chatSseControllers[id].abort();
commonStore.conversation[id].type = MessageType.Error;
commonStore.conversation[id].done = true;
}
chatSseControllers = {};
commonStore.setConversation(commonStore.conversation);
commonStore.setConversationOrder([...commonStore.conversationOrder]);
} else {
handleKeyDownOrClick(e);
}
}} />
<ToolTipButton desc={t('Save')}
icon={<Save28Regular />}
size={mq ? 'large' : 'small'} shape="circular" appearance="subtle"
onClick={() => {
let savedContent: string = '';
const isWorldModel = commonStore.getCurrentModelConfig().modelParameters.modelName.toLowerCase().includes('world');
const user = isWorldModel ? 'User' : 'Bob';
const bot = isWorldModel ? 'Assistant' : 'Alice';
commonStore.conversationOrder.forEach((uuid) => {
if (uuid === welcomeUuid)
return;
const messageItem = commonStore.conversation[uuid];
if (messageItem.type !== MessageType.Error) {
savedContent += `${messageItem.sender === userName ? user : bot}: ${messageItem.content}\n\n`;
}
});
});
} else {
OpenOpenFileDialog(filterPattern).then(async filePath => {
if (!filePath)
return;
OpenSaveFileDialog('*.txt', 'conversation.txt', savedContent).then((path) => {
if (path)
toastWithButton(t('Conversation Saved'), t('Open'), () => {
OpenFileFolder(path, false);
});
}).catch(e => {
toast(t('Error') + ' - ' + (e.message || e), { type: 'error', autoClose: 2500 });
});
}} />
setUploading();
// Both are slow. Communication between frontend and backend is slow. Use AssetServer Handler to read the file.
// const blob = new Blob([atob(info.content as unknown as string)]); // await fetch(`data:application/octet-stream;base64,${info.content}`).then(r => r.blob());
const blob = await fetch(absPathAsset(filePath)).then(r => r.blob());
const attachmentName = filePath.split(/[\\/]/).pop();
const urlPath = `/file-to-text?file_name=${attachmentName}`;
const bodyForm = new FormData();
bodyForm.append('file_data', blob, attachmentName);
fetch(getServerRoot(port) + urlPath, {
method: 'POST',
body: bodyForm
}).then(async r => {
if (r.status === 200) {
const pages = (await r.json()).pages as any[];
let attachmentContent: string;
if (pages.length === 1)
attachmentContent = pages[0].page_content;
else
attachmentContent = pages.map((p, i) => `Page ${i + 1}:\n${p.page_content}`).join('\n\n');
if (attachmentContent)
commonStore.setCurrentTempAttachment(
{
name: attachmentName!,
size: blob.size,
content: attachmentContent!
});
else
toast(t('File is empty'), {
type: 'info',
autoClose: 1000
});
} else {
toast(r.statusText + '\n' + (await r.text()), {
type: 'error'
});
}
}
).catch(e => {
toast(t('Error') + ' - ' + (e.message || e), { type: 'error', autoClose: 2500 });
}).finally(() => {
commonStore.setAttachmentUploading(false);
});
}).catch(e => {
toast(t('Error') + ' - ' + (e.message || e), { type: 'error', autoClose: 2500 });
});
}
}}
/> :
<div>
<ToolTipButton
text={
commonStore.currentTempAttachment.name.replace(
new RegExp('(^[^\\.]{5})[^\\.]+'), '$1...')
}
desc={`${commonStore.currentTempAttachment.name} (${bytesToReadable(commonStore.currentTempAttachment.size)})`}
size="small" shape="circular" appearance="secondary" />
<ToolTipButton desc={t('Remove Attachment')}
icon={<Dismiss16Regular />}
size="small" shape="circular" appearance="subtle"
onClick={() => {
commonStore.setCurrentTempAttachment(null);
}} />
</div>
}
</div>
</div>
<ToolTipButton desc={generating ? t('Stop') : t('Send')}
icon={generating ? <RecordStop28Regular /> : <ArrowCircleUp28Regular />}
size={mq ? 'large' : 'small'} shape="circular" appearance="subtle"
onClick={(e) => {
if (generating) {
for (const id in chatSseControllers) {
chatSseControllers[id].abort();
commonStore.conversation[id].type = MessageType.Error;
commonStore.conversation[id].done = true;
}
chatSseControllers = {};
commonStore.setConversation(commonStore.conversation);
commonStore.setConversationOrder([...commonStore.conversationOrder]);
} else {
handleKeyDownOrClick(e);
}
}} />
</div>
</div>
<SidePanel />
</div>
);
});
export const Chat: FC = observer(() => {
const Chat: FC = observer(() => {
return (
<div className="flex flex-col gap-1 p-2 h-full overflow-hidden">
<WorkHeader />
@@ -568,3 +689,5 @@ export const Chat: FC = observer(() => {
</div>
);
});
export default Chat;

View File

@@ -5,7 +5,6 @@ import { Button, Dropdown, Input, Option, Textarea } from '@fluentui/react-compo
import { Labeled } from '../components/Labeled';
import { ValuedSlider } from '../components/ValuedSlider';
import { useTranslation } from 'react-i18next';
import { ApiParameters } from './Configs';
import commonStore, { ModelStatus } from '../stores/commonStore';
import { fetchEventSource } from '@microsoft/fetch-event-source';
import { toast } from 'react-toastify';
@@ -14,18 +13,8 @@ import { PresetsButton } from './PresetsManager/PresetsButton';
import { ToolTipButton } from '../components/ToolTipButton';
import { ArrowSync20Regular } from '@fluentui/react-icons';
import { defaultPresets } from './defaultConfigs';
export type CompletionParams = Omit<ApiParameters, 'apiPort'> & {
stop: string,
injectStart: string,
injectEnd: string
};
export type CompletionPreset = {
name: string,
prompt: string,
params: CompletionParams
}
import { CompletionParams, CompletionPreset } from '../types/completion';
import { getServerRoot } from '../utils';
let completionSseController: AbortController | null = null;
@@ -40,8 +29,10 @@ const CompletionPanel: FC = observer(() => {
};
useEffect(() => {
if (inputRef.current)
if (inputRef.current) {
inputRef.current.style.height = '100%';
inputRef.current.style.maxHeight = '100%';
}
scrollToBottom();
}, []);
@@ -80,7 +71,7 @@ const CompletionPanel: FC = observer(() => {
const onSubmit = (prompt: string) => {
commonStore.setCompletionSubmittedPrompt(prompt);
if (commonStore.status.status === ModelStatus.Offline && !commonStore.settings.apiUrl) {
if (commonStore.status.status === ModelStatus.Offline && !commonStore.settings.apiUrl && commonStore.platform !== 'web') {
toast(t('Please click the button in the top right corner to start the model'), { type: 'warning' });
commonStore.setCompletionGenerating(false);
return;
@@ -90,10 +81,8 @@ const CompletionPanel: FC = observer(() => {
let answer = '';
completionSseController = new AbortController();
fetchEventSource( // https://api.openai.com/v1/completions || http://127.0.0.1:${port}/completions
commonStore.settings.apiUrl ?
commonStore.settings.apiUrl + '/v1/completions' :
`http://127.0.0.1:${port}/completions`,
fetchEventSource( // https://api.openai.com/v1/completions || http://127.0.0.1:${port}/v1/completions
getServerRoot(port, true) + '/v1/completions',
{
method: 'POST',
headers: {
@@ -125,6 +114,8 @@ const CompletionPanel: FC = observer(() => {
console.debug('json error', error);
return;
}
if (data.model)
commonStore.setLastModelName(data.model);
if (data.choices && Array.isArray(data.choices) && data.choices.length > 0) {
answer += data.choices[0]?.text || data.choices[0]?.delta?.content || '';
setPrompt(prompt + answer.replace(/\s+$/, '') + params.injectEnd.replaceAll('\\n', '\n'));
@@ -186,7 +177,7 @@ const CompletionPanel: FC = observer(() => {
desc={t('By default, the maximum number of tokens that can be answered in a single response, it can be changed by the user by specifying API parameters.')}
content={
<ValuedSlider value={params.maxResponseToken} min={100} max={8100}
step={400}
step={100}
input
onChange={(e, data) => {
setParams({
@@ -269,7 +260,7 @@ const CompletionPanel: FC = observer(() => {
} />
</div>
<div className="grow" />
<div className="flex justify-between gap-2">
<div className="hidden justify-between gap-2 sm:flex">
<Button className="grow" onClick={() => {
const newPrompt = prompt.replace(/\n+\ /g, '\n').split('\n').map((line) => line.trim()).join('\n');
setPrompt(newPrompt);
@@ -303,7 +294,7 @@ const CompletionPanel: FC = observer(() => {
);
});
export const Completion: FC = observer(() => {
const Completion: FC = observer(() => {
return (
<div className="flex flex-col gap-1 p-2 h-full overflow-hidden">
<WorkHeader />
@@ -311,3 +302,5 @@ export const Completion: FC = observer(() => {
</div>
);
});
export default Completion;

View File

@@ -1,7 +1,8 @@
import 'html-midi-player';
import React, { FC, useEffect, useRef } from 'react';
import { observer } from 'mobx-react-lite';
import { WorkHeader } from '../components/WorkHeader';
import { Button, Checkbox, Textarea } from '@fluentui/react-components';
import { Button, Checkbox, Dropdown, Option, Textarea } from '@fluentui/react-components';
import { Labeled } from '../components/Labeled';
import { ValuedSlider } from '../components/ValuedSlider';
import { useTranslation } from 'react-i18next';
@@ -15,24 +16,25 @@ import { PlayerElement, VisualizerElement } from 'html-midi-player';
import * as mm from '@magenta/music/esm/core.js';
import { NoteSequence } from '@magenta/music/esm/protobuf.js';
import { defaultCompositionPrompt } from './defaultConfigs';
import { FileExists, OpenFileFolder, OpenSaveFileDialogBytes } from '../../wailsjs/go/backend_golang/App';
import { toastWithButton } from '../utils';
export type CompositionParams = {
prompt: string,
maxResponseToken: number,
temperature: number,
topP: number,
autoPlay: boolean,
useLocalSoundFont: boolean,
midi: ArrayBuffer | null,
ns: NoteSequence | null
}
import {
CloseMidiPort,
FileExists,
OpenFileFolder,
OpenMidiPort,
OpenSaveFileDialogBytes,
SaveFile,
StartFile
} from '../../wailsjs/go/backend_golang/App';
import { getServerRoot, getSoundFont, toastWithButton } from '../utils';
import { CompositionParams } from '../types/composition';
import { useMediaQuery } from 'usehooks-ts';
import { AudiotrackButton } from './AudiotrackManager/AudiotrackButton';
let compositionSseController: AbortController | null = null;
const CompositionPanel: FC = observer(() => {
const { t } = useTranslation();
const mq = useMediaQuery('(min-width: 640px)');
const inputRef = useRef<HTMLTextAreaElement>(null);
const port = commonStore.getCurrentModelConfig().apiParameters.apiPort;
const visualizerRef = useRef<VisualizerElement>(null);
@@ -71,26 +73,16 @@ const CompositionPanel: FC = observer(() => {
};
const setSoundFont = async () => {
let soundUrl: string;
if (commonStore.compositionParams.useLocalSoundFont)
soundUrl = 'assets/sound-font';
else
soundUrl = !commonStore.settings.giteeUpdatesSource ?
`https://raw.githubusercontent.com/josStorer/sgm_plus/master` :
`https://gitee.com/josc146/sgm_plus/raw/master`;
const fallbackUrl = 'https://cdn.jsdelivr.net/gh/josstorer/sgm_plus';
await fetch(soundUrl + '/soundfont.json').then(r => {
if (!r.ok)
soundUrl = fallbackUrl;
}).catch(() => soundUrl = fallbackUrl);
if (playerRef.current) {
playerRef.current.soundFont = soundUrl;
playerRef.current.soundFont = await getSoundFont();
}
};
useEffect(() => {
if (inputRef.current)
if (inputRef.current) {
inputRef.current.style.height = '100%';
inputRef.current.style.maxHeight = '100%';
}
scrollToBottom();
if (playerRef.current && visualizerRef.current) {
@@ -108,10 +100,40 @@ const CompositionPanel: FC = observer(() => {
}
}, []);
const externalPlayListener = () => {
const params = commonStore.compositionParams;
const saveAndPlay = async (midi: ArrayBuffer, path: string) => {
await SaveFile(path, Array.from(new Uint8Array(midi)));
StartFile(path);
};
if (params.externalPlay) {
if (params.midi) {
setTimeout(() => {
playerRef.current?.stop();
});
saveAndPlay(params.midi, './midi/last.mid').catch((e: string) => {
if (e.includes('being used'))
saveAndPlay(params.midi!, './midi/last-2.mid');
});
}
}
};
useEffect(() => {
playerRef.current?.addEventListener('start', externalPlayListener);
return () => {
playerRef.current?.removeEventListener('start', externalPlayListener);
};
}, [params.externalPlay]);
useEffect(() => {
if (!(commonStore.activeMidiDeviceIndex in commonStore.midiPorts)) {
commonStore.setActiveMidiDeviceIndex(-1);
CloseMidiPort();
}
}, [commonStore.midiPorts]);
const generateNs = (autoPlay: boolean) => {
fetch(commonStore.settings.apiUrl ?
commonStore.settings.apiUrl + '/text-to-midi' :
`http://127.0.0.1:${port}/text-to-midi`, {
fetch(getServerRoot(port) + '/text-to-midi', {
method: 'POST',
headers: {
'Content-Type': 'application/json'
@@ -128,7 +150,12 @@ const CompositionPanel: FC = observer(() => {
});
updateNs(ns);
if (autoPlay) {
playerRef.current?.start();
if (commonStore.compositionParams.externalPlay)
externalPlayListener();
else
setTimeout(() => {
playerRef.current?.start();
});
}
});
});
@@ -137,7 +164,7 @@ const CompositionPanel: FC = observer(() => {
const onSubmit = (prompt: string) => {
commonStore.setCompositionSubmittedPrompt(prompt);
if (commonStore.status.status === ModelStatus.Offline && !commonStore.settings.apiUrl) {
if (commonStore.status.status === ModelStatus.Offline && !commonStore.settings.apiUrl && commonStore.platform !== 'web') {
toast(t('Please click the button in the top right corner to start the model'), { type: 'warning' });
commonStore.setCompositionGenerating(false);
return;
@@ -145,10 +172,8 @@ const CompositionPanel: FC = observer(() => {
let answer = '';
compositionSseController = new AbortController();
fetchEventSource( // https://api.openai.com/v1/completions || http://127.0.0.1:${port}/completions
commonStore.settings.apiUrl ?
commonStore.settings.apiUrl + '/v1/completions' :
`http://127.0.0.1:${port}/completions`,
fetchEventSource( // https://api.openai.com/v1/completions || http://127.0.0.1:${port}/v1/completions
getServerRoot(port, true) + '/v1/completions',
{
method: 'POST',
headers: {
@@ -178,6 +203,8 @@ const CompositionPanel: FC = observer(() => {
console.debug('json error', error);
return;
}
if (data.model)
commonStore.setLastModelName(data.model);
if (data.choices && Array.isArray(data.choices) && data.choices.length > 0) {
answer += data.choices[0]?.text || data.choices[0]?.delta?.content || '';
setPrompt(prompt + answer.replace(/\s+$/, ''));
@@ -217,62 +244,110 @@ const CompositionPanel: FC = observer(() => {
setPrompt(e.target.value);
}}
/>
<div className="flex flex-col gap-1 max-h-48 sm:max-w-sm sm:max-h-full overflow-x-hidden overflow-y-auto p-1">
<Labeled flex breakline label={t('Max Response Token')}
desc={t('By default, the maximum number of tokens that can be answered in a single response, it can be changed by the user by specifying API parameters.')}
content={
<ValuedSlider value={params.maxResponseToken} min={100} max={4100}
step={100}
input
onChange={(e, data) => {
<div className="flex flex-col gap-1 max-h-48 sm:max-w-sm sm:max-h-full">
<div className="flex flex-col gap-1 grow overflow-x-hidden overflow-y-auto p-1">
<Labeled flex breakline label={t('Max Response Token')}
desc={t('By default, the maximum number of tokens that can be answered in a single response, it can be changed by the user by specifying API parameters.')}
content={
<ValuedSlider value={params.maxResponseToken} min={100} max={4100}
step={100}
input
onChange={(e, data) => {
setParams({
maxResponseToken: data.value
});
}} />
} />
<Labeled flex breakline label={t('Temperature')}
desc={t('Sampling temperature, it\'s like giving alcohol to a model, the higher the stronger the randomness and creativity, while the lower, the more focused and deterministic it will be.')}
content={
<ValuedSlider value={params.temperature} min={0} max={2} step={0.1}
input
onChange={(e, data) => {
setParams({
temperature: data.value
});
}} />
} />
<Labeled flex breakline label={t('Top_P')}
desc={t('Just like feeding sedatives to the model. Consider the results of the top n% probability mass, 0.1 considers the top 10%, with higher quality but more conservative, 1 considers all results, with lower quality but more diverse.')}
content={
<ValuedSlider value={params.topP} min={0} max={1} step={0.1} input
onChange={(e, data) => {
setParams({
topP: data.value
});
}} />
} />
<div className="grow" />
{
commonStore.platform !== 'web' &&
<Checkbox className="select-none"
size="large" label={t('Use Local Sound Font')} checked={params.useLocalSoundFont}
onChange={async (_, data) => {
if (data.checked) {
if (!await FileExists('assets/sound-font/accordion/instrument.json')) {
toast(t('Failed to load local sound font, please check if the files exist - assets/sound-font'),
{ type: 'warning' });
return;
}
}
setParams({
maxResponseToken: data.value
useLocalSoundFont: data.checked as boolean
});
setSoundFont();
}} />
}
{
commonStore.platform === 'windows' &&
<Checkbox className="select-none"
size="large" label={t('Play With External Player')} checked={params.externalPlay}
onChange={async (_, data) => {
setParams({
externalPlay: data.checked as boolean
});
}} />
} />
<Labeled flex breakline label={t('Temperature')}
desc={t('Sampling temperature, it\'s like giving alcohol to a model, the higher the stronger the randomness and creativity, while the lower, the more focused and deterministic it will be.')}
content={
<ValuedSlider value={params.temperature} min={0} max={2} step={0.1}
input
onChange={(e, data) => {
setParams({
temperature: data.value
});
}} />
} />
<Labeled flex breakline label={t('Top_P')}
desc={t('Just like feeding sedatives to the model. Consider the results of the top n% probability mass, 0.1 considers the top 10%, with higher quality but more conservative, 1 considers all results, with lower quality but more diverse.')}
content={
<ValuedSlider value={params.topP} min={0} max={1} step={0.1} input
onChange={(e, data) => {
setParams({
topP: data.value
});
}} />
} />
<div className="grow" />
<Checkbox className="select-none"
size="large" label={t('Use Local Sound Font')} checked={params.useLocalSoundFont}
onChange={async (_, data) => {
if (data.checked) {
if (!await FileExists('assets/sound-font/accordion/instrument.json')) {
toast(t('Failed to load local sound font, please check if the files exist - assets/sound-font'),
{ type: 'warning' });
return;
}
}
}
<Checkbox className="select-none"
size="large" label={t('Auto Play At The End')} checked={params.autoPlay} onChange={(_, data) => {
setParams({
useLocalSoundFont: data.checked as boolean
autoPlay: data.checked as boolean
});
setSoundFont();
}} />
<Checkbox className="select-none"
size="large" label={t('Auto Play At The End')} checked={params.autoPlay} onChange={(_, data) => {
setParams({
autoPlay: data.checked as boolean
});
}} />
<Labeled flex breakline label={t('MIDI Input')}
desc={t('Select the MIDI input device to be used.')}
content={
<div className="flex flex-col gap-1">
{
commonStore.platform !== 'web' &&
<Dropdown style={{ minWidth: 0 }}
value={(commonStore.activeMidiDeviceIndex === -1 || !(commonStore.activeMidiDeviceIndex in commonStore.midiPorts))
? t('None')!
: commonStore.midiPorts[commonStore.activeMidiDeviceIndex].name}
selectedOptions={[commonStore.activeMidiDeviceIndex.toString()]}
onOptionSelect={(_, data) => {
if (data.optionValue) {
const index = Number(data.optionValue);
let action = (index === -1)
? () => CloseMidiPort()
: () => OpenMidiPort(index);
action().then(() => {
commonStore.setActiveMidiDeviceIndex(index);
}).catch((e) => {
toast(t('Error') + ' - ' + (e.message || e), { type: 'error', autoClose: 2500 });
});
}
}}>
<Option value={'-1'}>{t('None')!}</Option>
{commonStore.midiPorts.map((p, i) =>
<Option key={i} value={i.toString()}>{p.name}</Option>)
}
</Dropdown>
}
<AudiotrackButton setPrompt={setPrompt} />
</div>
} />
</div>
<div className="flex justify-between gap-2">
<ToolTipButton desc={t('Regenerate')} icon={<ArrowSync20Regular />} onClick={() => {
compositionSseController?.abort();
@@ -311,7 +386,7 @@ const CompositionPanel: FC = observer(() => {
ref={playerRef}
style={{ width: '100%' }}
/>
<Button icon={<Save28Regular />}
<Button icon={<Save28Regular />} size={mq ? 'large' : 'medium'} appearance={mq ? 'secondary' : 'subtle'}
onClick={() => {
if (params.midi) {
OpenSaveFileDialogBytes('*.mid', 'music.mid', Array.from(new Uint8Array(params.midi))).then((path) => {
@@ -327,7 +402,7 @@ const CompositionPanel: FC = observer(() => {
}
}}
>
{t('Save')}
{mq ? t('Save') : ''}
</Button>
</div>
</div>
@@ -335,7 +410,7 @@ const CompositionPanel: FC = observer(() => {
);
});
export const Composition: FC = observer(() => {
const Composition: FC = observer(() => {
return (
<div className="flex flex-col gap-1 p-2 h-full overflow-hidden">
<WorkHeader />
@@ -343,3 +418,5 @@ export const Composition: FC = observer(() => {
</div>
);
});
export default Composition;

View File

@@ -8,12 +8,13 @@ import {
Input,
Label,
Option,
PresenceBadge,
Select,
Switch,
Text
} from '@fluentui/react-components';
import { AddCircle20Regular, DataUsageSettings20Regular, Delete20Regular, Save20Regular } from '@fluentui/react-icons';
import React, { FC, useEffect, useRef } from 'react';
import React, { FC, useCallback, useEffect, useRef } from 'react';
import { Section } from '../components/Section';
import { Labeled } from '../components/Labeled';
import { ToolTipButton } from '../components/ToolTipButton';
@@ -26,48 +27,41 @@ import { Page } from '../components/Page';
import { useNavigate } from 'react-router';
import { RunButton } from '../components/RunButton';
import { updateConfig } from '../apis';
import { ConvertModel, ConvertSafetensors, FileExists, GetPyError } from '../../wailsjs/go/backend_golang/App';
import { checkDependencies, getStrategy } from '../utils';
import { getStrategy } from '../utils';
import { useTranslation } from 'react-i18next';
import { WindowShow } from '../../wailsjs/runtime/runtime';
import strategyImg from '../assets/images/strategy.jpg';
import strategyZhImg from '../assets/images/strategy_zh.jpg';
import { ResetConfigsButton } from '../components/ResetConfigsButton';
import { useMediaQuery } from 'usehooks-ts';
import { ApiParameters, Device, ModelParameters, Precision } from '../types/configs';
import { convertModel, convertToGGML, convertToSt } from '../utils/convert-model';
export type ApiParameters = {
apiPort: number
maxResponseToken: number;
temperature: number;
topP: number;
presencePenalty: number;
frequencyPenalty: number;
}
const ConfigSelector: FC<{
selectedIndex: number,
updateSelectedIndex: (i: number) => void
}> = observer(({ selectedIndex, updateSelectedIndex }) => {
return (
<Dropdown style={{ minWidth: 0 }} className="grow" value={commonStore.modelConfigs[selectedIndex].name}
selectedOptions={[selectedIndex.toString()]}
onOptionSelect={(_, data) => {
if (data.optionValue) {
updateSelectedIndex(Number(data.optionValue));
}
}}>
{commonStore.modelConfigs.map((config, index) => <Option key={index} value={index.toString()}
text={config.name}>
<div className="flex justify-between grow">
{config.name}
{commonStore.modelSourceList.find(item => item.name === config.modelParameters.modelName)?.isComplete
&& <PresenceBadge status="available" />}
</div>
</Option>
)}
</Dropdown>
);
});
export type Device = 'CPU' | 'CUDA' | 'CUDA-Beta' | 'WebGPU' | 'MPS' | 'Custom';
export type Precision = 'fp16' | 'int8' | 'fp32';
export type ModelParameters = {
// different models can not have the same name
modelName: string;
device: Device;
precision: Precision;
storedLayers: number;
maxStoredLayers: number;
useCustomCuda?: boolean;
customStrategy?: string;
useCustomTokenizer?: boolean;
customTokenizer?: string;
}
export type ModelConfig = {
// different configs can have the same name
name: string;
apiParameters: ApiParameters
modelParameters: ModelParameters
}
export const Configs: FC = observer(() => {
const Configs: FC = observer(() => {
const { t } = useTranslation();
const [selectedIndex, setSelectedIndex] = React.useState(commonStore.currentModelConfigIndex);
const [selectedConfig, setSelectedConfig] = React.useState(commonStore.modelConfigs[selectedIndex]);
@@ -82,13 +76,13 @@ export const Configs: FC = observer(() => {
(advancedHeaderRef.current.firstElementChild as HTMLElement).style.padding = '0';
}, []);
const updateSelectedIndex = (newIndex: number) => {
const updateSelectedIndex = useCallback((newIndex: number) => {
setSelectedIndex(newIndex);
setSelectedConfig(commonStore.modelConfigs[newIndex]);
// if you don't want to update the config used by the current startup in real time, comment out this line
commonStore.setCurrentConfigIndex(newIndex);
};
}, []);
const setSelectedConfigName = (newName: string) => {
setSelectedConfig({ ...selectedConfig, name: newName });
@@ -128,17 +122,7 @@ export const Configs: FC = observer(() => {
<Page title={t('Configs')} content={
<div className="flex flex-col gap-2 overflow-hidden">
<div className="flex gap-2 items-center">
<Dropdown style={{ minWidth: 0 }} className="grow" value={commonStore.modelConfigs[selectedIndex].name}
selectedOptions={[selectedIndex.toString()]}
onOptionSelect={(_, data) => {
if (data.optionValue) {
updateSelectedIndex(Number(data.optionValue));
}
}}>
{commonStore.modelConfigs.map((config, index) =>
<Option key={index} value={index.toString()}>{config.name}</Option>
)}
</Dropdown>
<ConfigSelector selectedIndex={selectedIndex} updateSelectedIndex={updateSelectedIndex} />
<ToolTipButton desc={t('New Config')} icon={<AddCircle20Regular />} onClick={() => {
commonStore.createModelConfig();
updateSelectedIndex(commonStore.modelConfigs.length - 1);
@@ -181,7 +165,7 @@ export const Configs: FC = observer(() => {
desc={t('By default, the maximum number of tokens that can be answered in a single response, it can be changed by the user by specifying API parameters.')}
content={
<ValuedSlider value={selectedConfig.apiParameters.maxResponseToken} min={100} max={8100}
step={400}
step={100}
input
onChange={(e, data) => {
setSelectedConfigApiParams({
@@ -263,81 +247,16 @@ export const Configs: FC = observer(() => {
} />
{
selectedConfig.modelParameters.device !== 'WebGPU' ?
<ToolTipButton text={t('Convert')}
desc={t('Convert model with these configs. Using a converted model will greatly improve the loading speed, but model parameters of the converted model cannot be modified.')}
onClick={async () => {
if (commonStore.platform === 'darwin') {
toast(t('MacOS is not yet supported for performing this operation, please do it manually.') + ' (backend-python/convert_model.py)', { type: 'info' });
return;
} else if (commonStore.platform === 'linux') {
toast(t('Linux is not yet supported for performing this operation, please do it manually.') + ' (backend-python/convert_model.py)', { type: 'info' });
return;
}
const ok = await checkDependencies(navigate);
if (!ok)
return;
const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`;
if (await FileExists(modelPath)) {
const strategy = getStrategy(selectedConfig);
const newModelPath = modelPath + '-' + strategy.replace(/[:> *+]/g, '-');
toast(t('Start Converting'), { autoClose: 1000, type: 'info' });
ConvertModel(commonStore.settings.customPythonPath, modelPath, strategy, newModelPath).then(async () => {
if (!await FileExists(newModelPath + '.pth')) {
toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' });
} else {
toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' });
}
}).catch(e => {
const errMsg = e.message || e;
if (errMsg.includes('path contains space'))
toast(`${t('Convert Failed')} - ${t('File Path Cannot Contain Space')}`, { type: 'error' });
else
toast(`${t('Convert Failed')} - ${e.message || e}`, { type: 'error' });
});
setTimeout(WindowShow, 1000);
} else {
toast(`${t('Model Not Found')} - ${modelPath}`, { type: 'error' });
}
}} /> :
<ToolTipButton text={t('Convert To Safe Tensors Format')}
(selectedConfig.modelParameters.device !== 'CPU (rwkv.cpp)' ?
<ToolTipButton text={t('Convert')}
desc={t('Convert model with these configs. Using a converted model will greatly improve the loading speed, but model parameters of the converted model cannot be modified.')}
onClick={() => convertModel(selectedConfig, navigate)} /> :
<ToolTipButton text={t('Convert To GGML Format')}
desc=""
onClick={() => convertToGGML(selectedConfig, navigate)} />)
: <ToolTipButton text={t('Convert To Safe Tensors Format')}
desc=""
onClick={async () => {
if (commonStore.platform === 'darwin') {
toast(t('MacOS is not yet supported for performing this operation, please do it manually.') + ' (backend-python/convert_safetensors.py)', { type: 'info' });
return;
} else if (commonStore.platform === 'linux') {
toast(t('Linux is not yet supported for performing this operation, please do it manually.') + ' (backend-python/convert_safetensors.py)', { type: 'info' });
return;
}
const ok = await checkDependencies(navigate);
if (!ok)
return;
const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`;
if (await FileExists(modelPath)) {
toast(t('Start Converting'), { autoClose: 1000, type: 'info' });
const newModelPath = modelPath.replace(/\.pth$/, '.st');
ConvertSafetensors(commonStore.settings.customPythonPath, modelPath, newModelPath).then(async () => {
if (!await FileExists(newModelPath)) {
toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' });
} else {
toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' });
}
}).catch(e => {
const errMsg = e.message || e;
if (errMsg.includes('path contains space'))
toast(`${t('Convert Failed')} - ${t('File Path Cannot Contain Space')}`, { type: 'error' });
else
toast(`${t('Convert Failed')} - ${e.message || e}`, { type: 'error' });
});
setTimeout(WindowShow, 1000);
} else {
toast(`${t('Model Not Found')} - ${modelPath}`, { type: 'error' });
}
}} />
onClick={() => convertToSt(selectedConfig)} />
}
<Labeled label={t('Strategy')} content={
<Dropdown style={{ minWidth: 0 }} className="grow" value={t(selectedConfig.modelParameters.device)!}
@@ -350,6 +269,7 @@ export const Configs: FC = observer(() => {
}
}}>
<Option value="CPU">CPU</Option>
<Option value="CPU (rwkv.cpp)">{t('CPU (rwkv.cpp, Faster)')!}</Option>
{commonStore.platform === 'darwin' && <Option value="MPS">MPS</Option>}
<Option value="CUDA">CUDA</Option>
<Option value="CUDA-Beta">{t('CUDA (Beta, Faster)')!}</Option>
@@ -359,7 +279,7 @@ export const Configs: FC = observer(() => {
} />
{
selectedConfig.modelParameters.device !== 'Custom' && <Labeled label={t('Precision')}
desc={t('int8 uses less VRAM, but has slightly lower quality. fp16 has higher quality, and fp32 has the best quality.')}
desc={t('int8 uses less VRAM, but has slightly lower quality. fp16 has higher quality.')}
content={
<Dropdown style={{ minWidth: 0 }} className="grow"
value={selectedConfig.modelParameters.precision}
@@ -371,9 +291,13 @@ export const Configs: FC = observer(() => {
});
}
}}>
<Option>fp16</Option>
<Option>int8</Option>
{selectedConfig.modelParameters.device !== 'WebGPU' && <Option>fp32</Option>}
{selectedConfig.modelParameters.device !== 'CPU' && selectedConfig.modelParameters.device !== 'MPS' &&
<Option>fp16</Option>}
{selectedConfig.modelParameters.device !== 'CPU (rwkv.cpp)' && <Option>int8</Option>}
{selectedConfig.modelParameters.device === 'WebGPU' && <Option>nf4</Option>}
{selectedConfig.modelParameters.device !== 'CPU (rwkv.cpp)' && selectedConfig.modelParameters.device !== 'WebGPU' &&
<Option>fp32</Option>}
{selectedConfig.modelParameters.device === 'CPU (rwkv.cpp)' && <Option>Q5_1</Option>}
</Dropdown>
} />
}
@@ -454,7 +378,7 @@ export const Configs: FC = observer(() => {
});
}} />
<Input className="grow"
placeholder={t('Tokenizer Path (e.g. backend-python/rwkv_pip/20B_tokenizer.json)')!}
placeholder={t('Tokenizer Path (e.g. backend-python/rwkv_pip/20B_tokenizer.json or rwkv_vocab_v20230424.txt)')!}
value={selectedConfig.modelParameters.customTokenizer}
onChange={(e, data) => {
setSelectedConfigModelParams({
@@ -472,9 +396,23 @@ export const Configs: FC = observer(() => {
/>
</div>
<div className="flex flex-row-reverse sm:fixed bottom-2 right-2">
<RunButton onClickRun={onClickSave} />
<div className="flex gap-2">
{selectedConfig.modelParameters.device !== 'WebGPU'
&& <Checkbox className="select-none"
size="large" label={t('Enable WebUI')}
checked={selectedConfig.enableWebUI}
onChange={(_, data) => {
setSelectedConfig({
...selectedConfig,
enableWebUI: data.checked as boolean
});
}} />}
<RunButton onClickRun={onClickSave} />
</div>
</div>
</div>
} />
);
});
export default Configs;

View File

@@ -9,19 +9,7 @@ import { ToolTipButton } from '../components/ToolTipButton';
import { Folder20Regular, Pause20Regular, Play20Regular } from '@fluentui/react-icons';
import { AddToDownloadList, OpenFileFolder, PauseDownload } from '../../wailsjs/go/backend_golang/App';
export type DownloadStatus = {
name: string;
path: string;
url: string;
transferred: number;
size: number;
speed: number;
progress: number;
downloading: boolean;
done: boolean;
}
export const Downloads: FC = observer(() => {
const Downloads: FC = observer(() => {
const { t } = useTranslation();
const finishedModelsLen = commonStore.downloadList.filter((status) => status.done && status.name.endsWith('.pth')).length;
useEffect(() => {
@@ -91,3 +79,5 @@ export const Downloads: FC = observer(() => {
} />
);
});
export default Downloads;

View File

@@ -1,11 +1,13 @@
import { CompoundButton, Link, Text } from '@fluentui/react-components';
import React, { FC, ReactElement } from 'react';
import React, { FC } from 'react';
import banner from '../assets/images/banner.jpg';
import {
Chat20Regular,
ClipboardEdit20Regular,
DataUsageSettings20Regular,
DocumentSettings20Regular
DocumentSettings20Regular,
MusicNote220Regular,
Settings20Regular
} from '@fluentui/react-icons';
import { useNavigate } from 'react-router';
import { observer } from 'mobx-react-lite';
@@ -14,21 +16,13 @@ import manifest from '../../../manifest.json';
import { BrowserOpenURL } from '../../wailsjs/runtime';
import { useTranslation } from 'react-i18next';
import { ConfigSelector } from '../components/ConfigSelector';
import MarkdownRender from '../components/MarkdownRender';
import commonStore from '../stores/commonStore';
import { Completion } from './Completion';
import { ResetConfigsButton } from '../components/ResetConfigsButton';
import { AdvancedGeneralSettings } from './Settings';
import { NavCard } from '../types/home';
import { LazyImportComponent } from '../components/LazyImportComponent';
export type IntroductionContent = { [lang: string]: string }
type NavCard = {
label: string;
desc: string;
path: string;
icon: ReactElement;
};
const navCards: NavCard[] = [
const clientNavCards: NavCard[] = [
{
label: 'Chat',
desc: 'Go to chat page',
@@ -55,7 +49,36 @@ const navCards: NavCard[] = [
}
];
export const Home: FC = observer(() => {
const webNavCards: NavCard[] = [
{
label: 'Chat',
desc: 'Go to chat page',
path: '/chat',
icon: <Chat20Regular />
},
{
label: 'Completion',
desc: 'Writer, Translator, Role-playing',
path: '/completion',
icon: <ClipboardEdit20Regular />
},
{
label: 'Composition',
desc: '',
path: '/composition',
icon: <MusicNote220Regular />
},
{
label: 'Settings',
desc: '',
path: '/settings',
icon: <Settings20Regular />
}
];
const MarkdownRender = React.lazy(() => import('../components/MarkdownRender'));
const Home: FC = observer(() => {
const { t } = useTranslation();
const navigate = useNavigate();
const lang: string = commonStore.settings.language;
@@ -64,39 +87,64 @@ export const Home: FC = observer(() => {
navigate({ pathname: path });
};
return (
<div className="flex flex-col justify-between h-full">
<img className="rounded-xl select-none hidden sm:block"
style={{ maxHeight: '40%', margin: '0 auto' }} src={banner} />
<div className="flex flex-col gap-2">
<Text size={600} weight="medium">{t('Introduction')}</Text>
<div className="h-40 overflow-y-auto overflow-x-hidden p-1">
<MarkdownRender>
{lang in commonStore.introduction ? commonStore.introduction[lang] : commonStore.introduction['en']}
</MarkdownRender>
return commonStore.platform === 'web' ?
(
<div className="flex flex-col gap-2 h-full overflow-x-hidden overflow-y-auto">
<img className="rounded-xl select-none object-cover grow"
style={{ maxHeight: '40%' }} src={banner} />
<div className="grow"></div>
<div className="grid grid-cols-2 sm:grid-cols-4 gap-5">
{webNavCards.map(({ label, path, icon, desc }, index) => (
<CompoundButton icon={icon} secondaryContent={t(desc)} key={`${path}-${index}`} value={path}
size="large" onClick={() => onClickNavCard(path)}>
{t(label)}
</CompoundButton>
))}
</div>
</div>
<div className="grid grid-cols-2 sm:grid-cols-4 gap-5">
{navCards.map(({ label, path, icon, desc }, index) => (
<CompoundButton icon={icon} secondaryContent={t(desc)} key={`${path}-${index}`} value={path}
size="large" onClick={() => onClickNavCard(path)}>
{t(label)}
</CompoundButton>
))}
</div>
<div className="flex flex-col gap-2">
<div className="flex flex-row-reverse sm:fixed bottom-2 right-2">
<div className="flex gap-3">
<ResetConfigsButton />
<ConfigSelector />
<RunButton />
<div className="flex flex-col gap-2">
<AdvancedGeneralSettings />
<div className="flex gap-4 items-end">
{t('Version')}: {manifest.version}
<Link onClick={() => BrowserOpenURL('https://github.com/josStorer/RWKV-Runner')}>{t('Help')}</Link>
</div>
</div>
<div className="flex gap-4 items-end">
{t('Version')}: {manifest.version}
<Link onClick={() => BrowserOpenURL('https://github.com/josStorer/RWKV-Runner')}>{t('Help')}</Link>
</div>
)
: (
<div className="flex flex-col justify-between h-full">
<img className="rounded-xl select-none object-cover hidden sm:block"
style={{ maxHeight: '40%' }} src={banner} />
<div className="flex flex-col gap-2">
<Text size={600} weight="medium">{t('Introduction')}</Text>
<div className="h-40 overflow-y-auto overflow-x-hidden p-1">
<LazyImportComponent lazyChildren={MarkdownRender}>
{lang in commonStore.introduction ? commonStore.introduction[lang] : commonStore.introduction['en']}
</LazyImportComponent>
</div>
</div>
<div className="grid grid-cols-2 sm:grid-cols-4 gap-5">
{clientNavCards.map(({ label, path, icon, desc }, index) => (
<CompoundButton icon={icon} secondaryContent={t(desc)} key={`${path}-${index}`} value={path}
size="large" onClick={() => onClickNavCard(path)}>
{t(label)}
</CompoundButton>
))}
</div>
<div className="flex flex-col gap-2">
<div className="flex flex-row-reverse sm:fixed bottom-2 right-2">
<div className="flex gap-3">
<ResetConfigsButton />
<ConfigSelector />
<RunButton />
</div>
</div>
<div className="flex gap-4 items-end">
{t('Version')}: {manifest.version}
<Link onClick={() => BrowserOpenURL('https://github.com/josStorer/RWKV-Runner')}>{t('Help')}</Link>
</div>
</div>
</div>
</div>
);
);
});
export default Home;

View File

@@ -1,5 +1,7 @@
import React, { FC } from 'react';
import React, { FC, useEffect, useState } from 'react';
import {
Button,
Checkbox,
createTableColumn,
DataGrid,
DataGridBody,
@@ -19,24 +21,10 @@ import commonStore from '../stores/commonStore';
import { BrowserOpenURL } from '../../wailsjs/runtime';
import { AddToDownloadList, OpenFileFolder } from '../../wailsjs/go/backend_golang/App';
import { Page } from '../components/Page';
import { bytesToGb, refreshModels, saveConfigs, toastWithButton } from '../utils';
import { bytesToGb, getHfDownloadUrl, refreshModels, saveConfigs, toastWithButton } from '../utils';
import { useTranslation } from 'react-i18next';
import { useNavigate } from 'react-router';
export type ModelSourceItem = {
name: string;
size: number;
lastUpdated: string;
desc?: { [lang: string]: string | undefined; };
SHA256?: string;
url?: string;
downloadUrl?: string;
isComplete?: boolean;
isLocal?: boolean;
localSize?: number;
lastUpdatedMs?: number;
hide?: boolean;
};
import { ModelSourceItem } from '../types/models';
const columns: TableColumnDefinition<ModelSourceItem>[] = [
createTableColumn<ModelSourceItem>({
@@ -153,7 +141,7 @@ const columns: TableColumnDefinition<ModelSourceItem>[] = [
navigate({ pathname: '/downloads' });
},
{ autoClose: 3000 });
AddToDownloadList(`${commonStore.settings.customModelsPath}/${item.name}`, item.downloadUrl!);
AddToDownloadList(`${commonStore.settings.customModelsPath}/${item.name}`, getHfDownloadUrl(item.downloadUrl!));
}} />}
{item.url && <ToolTipButton desc={t('Open Url')} icon={<Open20Regular />} onClick={() => {
BrowserOpenURL(item.url!);
@@ -165,8 +153,24 @@ const columns: TableColumnDefinition<ModelSourceItem>[] = [
})
];
export const Models: FC = observer(() => {
const Models: FC = observer(() => {
const { t } = useTranslation();
const [tags, setTags] = useState<Array<string>>([]);
const [modelSourceList, setModelSourceList] = useState<ModelSourceItem[]>(commonStore.modelSourceList);
useEffect(() => {
setTags(Array.from(new Set(
[...commonStore.modelSourceList.map(item => item.tags || []).flat()
.filter(i => !i.includes('Other') && !i.includes('Local'))
, 'Other', 'Local'])));
}, [commonStore.modelSourceList]);
useEffect(() => {
if (commonStore.activeModelListTags.length === 0)
setModelSourceList(commonStore.modelSourceList);
else
setModelSourceList(commonStore.modelSourceList.filter(item => commonStore.activeModelListTags.some(tag => item.tags?.includes(tag))));
}, [commonStore.modelSourceList, commonStore.activeModelListTags]);
return (
<Page title={t('Models')} content={
@@ -174,10 +178,21 @@ export const Models: FC = observer(() => {
<div className="flex flex-col gap-1">
<div className="flex justify-between items-center">
<Text weight="medium">{t('Model Source Manifest List')}</Text>
<ToolTipButton desc={t('Refresh')} icon={<ArrowClockwise20Regular />} onClick={() => {
refreshModels(false);
saveConfigs();
}} />
<div className="flex">
{commonStore.settings.language === 'zh' &&
<Checkbox className="select-none"
size="large" label={t('Use Hugging Face Mirror')}
checked={commonStore.settings.useHfMirror}
onChange={(_, data) => {
commonStore.setSettings({
useHfMirror: data.checked as boolean
});
}} />}
<ToolTipButton desc={t('Refresh')} icon={<ArrowClockwise20Regular />} onClick={() => {
refreshModels(false);
saveConfigs();
}} />
</div>
</div>
<Text size={100}>
{t('Provide JSON file URLs for the models manifest. Separate URLs with semicolons. The "models" field in JSON files will be parsed into the following table.')}
@@ -186,9 +201,24 @@ export const Models: FC = observer(() => {
value={commonStore.modelSourceManifestList}
onChange={(e, data) => commonStore.setModelSourceManifestList(data.value)} />
</div>
<div className="flex gap-2 flex-wrap overflow-y-auto" style={{ minHeight: '88px' }}>
{tags.map(tag =>
<div key={tag} className="mt-auto">
<Button
appearance={commonStore.activeModelListTags.includes(tag) ? 'primary' : 'secondary'} onClick={
() => {
if (commonStore.activeModelListTags.includes(tag))
commonStore.setActiveModelListTags(commonStore.activeModelListTags.filter(t => t !== tag));
else
commonStore.setActiveModelListTags([...commonStore.activeModelListTags, tag]);
}
}>{t(tag)}</Button>
</div>)
}
</div>
<div className="flex grow overflow-hidden">
<DataGrid
items={commonStore.modelSourceList}
items={modelSourceList}
columns={columns}
sortable={true}
defaultSortState={{ sortColumn: 'actions', sortDirection: 'ascending' }}
@@ -220,3 +250,5 @@ export const Models: FC = observer(() => {
} />
);
});
export default Models;

View File

@@ -1,14 +1,14 @@
import React, { FC, useState } from 'react';
import { DragDropContext, Draggable, Droppable, DropResult } from 'react-beautiful-dnd';
import commonStore from '../../stores/commonStore';
import { Preset } from './PresetsButton';
import { observer } from 'mobx-react-lite';
import { v4 as uuid } from 'uuid';
import { Button, Card, Dropdown, Option, Textarea } from '@fluentui/react-components';
import { useTranslation } from 'react-i18next';
import { ToolTipButton } from '../../components/ToolTipButton';
import { Delete20Regular, ReOrderDotsVertical20Regular } from '@fluentui/react-icons';
import { ConversationMessage, Role } from '../Chat';
import { Preset } from '../../types/presets';
import { ConversationMessage, Role } from '../../types/chat';
type Item = {
id: string;
@@ -31,7 +31,7 @@ const reorder = (list: Item[], startIndex: number, endIndex: number) => {
return result;
};
export const MessagesEditor: FC = observer(() => {
const MessagesEditor: FC = observer(() => {
const { t } = useTranslation();
const editingPreset = commonStore.editingPreset!;
@@ -108,7 +108,8 @@ export const MessagesEditor: FC = observer(() => {
style={{ borderTopRightRadius: 0, borderBottomRightRadius: 0 }}>
<ReOrderDotsVertical20Regular />
</Card>
<Dropdown style={{ minWidth: 0, borderRadius: 0 }} listbox={{ style: { minWidth: 0 } }}
<Dropdown style={{ minWidth: 0, borderRadius: 0 }}
listbox={{ style: { minWidth: 'fit-content' } }}
value={t(item.role)!}
selectedOptions={[item.role]}
onOptionSelect={(_, data) => {
@@ -152,3 +153,5 @@ export const MessagesEditor: FC = observer(() => {
</div>
);
});
export default MessagesEditor;

View File

@@ -1,6 +1,6 @@
// TODO refactor
import React, { FC, PropsWithChildren, ReactElement, useState } from 'react';
import React, { FC, lazy, PropsWithChildren, ReactElement, useState } from 'react';
import {
Button,
Dialog,
@@ -25,44 +25,19 @@ import {
} from '@fluentui/react-icons';
import { ToolTipButton } from '../../components/ToolTipButton';
import { useTranslation } from 'react-i18next';
import { botName, Conversation, ConversationMessage, MessageType, userName } from '../Chat';
import { SelectTabEventHandler } from '@fluentui/react-tabs';
import { Labeled } from '../../components/Labeled';
import commonStore from '../../stores/commonStore';
import logo from '../../assets/images/logo.png';
import { observer } from 'mobx-react-lite';
import { MessagesEditor } from './MessagesEditor';
import { ClipboardGetText, ClipboardSetText } from '../../../wailsjs/runtime';
import { toast } from 'react-toastify';
import { CustomToastContainer } from '../../components/CustomToastContainer';
import { v4 as uuid } from 'uuid';
import { absPathAsset } from '../../utils';
import { absPathAsset, setActivePreset } from '../../utils';
import { Preset, PresetsNavigationItem } from '../../types/presets';
import { LazyImportComponent } from '../../components/LazyImportComponent';
export type PresetType = 'chat' | 'completion' | 'chatInCompletion'
export type Preset = {
name: string,
tag: string,
// if name and sourceUrl are same, it will be overridden when importing
sourceUrl: string,
desc: string,
avatarImg: string,
type: PresetType,
// chat
welcomeMessage: string,
messages: ConversationMessage[],
displayPresetMessages: boolean,
// completion
prompt: string,
stop: string,
injectStart: string,
injectEnd: string,
presystem?: boolean,
userName?: string,
assistantName?: string
}
export const defaultPreset: Preset = {
const defaultPreset: Preset = {
name: 'RWKV',
tag: 'default',
sourceUrl: '',
@@ -75,33 +50,17 @@ export const defaultPreset: Preset = {
prompt: '',
stop: '',
injectStart: '',
injectEnd: ''
injectEnd: '',
presystem: true,
userName: '',
assistantName: ''
};
const setActivePreset = (preset: Preset) => {
commonStore.setActivePreset(preset);
//TODO if (preset.displayPresetMessages) {
const conversation: Conversation = {};
const conversationOrder: string[] = [];
for (const message of preset.messages) {
const newUuid = uuid();
conversationOrder.push(newUuid);
conversation[newUuid] = {
sender: message.role === 'user' ? userName : botName,
type: MessageType.Normal,
color: message.role === 'user' ? 'brand' : 'colorful',
time: new Date().toISOString(),
content: message.content,
side: message.role === 'user' ? 'right' : 'left',
done: true
};
}
commonStore.setConversation(conversation);
commonStore.setConversationOrder(conversationOrder);
//}
};
const MessagesEditor = lazy(() => import('./MessagesEditor'));
export const PresetCardFrame: FC<PropsWithChildren & { onClick?: () => void }> = (props) => {
const PresetCardFrame: FC<PropsWithChildren & {
onClick?: React.MouseEventHandler<HTMLButtonElement>
}> = (props) => {
return <Button
className="flex flex-col gap-1 w-32 h-56 break-all"
style={{ minWidth: 0, borderRadius: '0.75rem', justifyContent: 'unset' }}
@@ -111,7 +70,7 @@ export const PresetCardFrame: FC<PropsWithChildren & { onClick?: () => void }> =
</Button>;
};
export const PresetCard: FC<{
const PresetCard: FC<{
avatarImg: string,
name: string,
desc: string,
@@ -124,7 +83,10 @@ export const PresetCard: FC<{
}) => {
const { t } = useTranslation();
return <PresetCardFrame onClick={onClick}>
return <PresetCardFrame onClick={(e) => {
if (onClick && e.currentTarget.contains(e.target as Node))
onClick();
}}>
<img src={absPathAsset(avatarImg)} className="rounded-xl select-none ml-auto mr-auto h-28" />
<Text size={400}>{name}</Text>
<Text size={200} style={{
@@ -137,7 +99,8 @@ export const PresetCard: FC<{
{editable ?
<ChatPresetEditor presetIndex={presetIndex} triggerButton={
<ToolTipButton size="small" appearance="transparent" desc={t('Edit')} icon={<Edit20Regular />}
onClick={() => {
onClick={(e) => {
e.stopPropagation();
commonStore.setEditingPreset({ ...commonStore.presets[presetIndex] });
}} />
} />
@@ -147,7 +110,7 @@ export const PresetCard: FC<{
</PresetCardFrame>;
});
export const ChatPresetEditor: FC<{
const ChatPresetEditor: FC<{
triggerButton: ReactElement,
presetIndex: number
}> = observer(({ triggerButton, presetIndex }) => {
@@ -291,7 +254,7 @@ export const ChatPresetEditor: FC<{
});
}} />
} />
<MessagesEditor />
<LazyImportComponent lazyChildren={MessagesEditor} />
</div> :
<div className="flex flex-col gap-1 p-2 overflow-x-hidden overflow-y-auto">
<Labeled flex breakline label={`${t('Description')} (${t('Preview Only')})`}
@@ -356,7 +319,7 @@ export const ChatPresetEditor: FC<{
</Dialog>;
});
export const ChatPresets: FC = observer(() => {
const ChatPresets: FC = observer(() => {
const { t } = useTranslation();
return <div className="flex flex-wrap gap-2">
@@ -392,12 +355,9 @@ export const ChatPresets: FC = observer(() => {
</div>;
});
type PresetsNavigationItem = {
icon: ReactElement;
element: ReactElement;
};
const pages: { [label: string]: PresetsNavigationItem } = {
const pages: {
[label: string]: PresetsNavigationItem
} = {
Chat: {
icon: <Chat20Regular />,
element: <ChatPresets />
@@ -412,7 +372,9 @@ const pages: { [label: string]: PresetsNavigationItem } = {
}
};
export const PresetsManager: FC<{ initTab: string }> = ({ initTab }) => {
const PresetsManager: FC<{
initTab: string
}> = ({ initTab }) => {
const { t } = useTranslation();
const [tab, setTab] = useState(initTab);

View File

@@ -16,33 +16,191 @@ import { observer } from 'mobx-react-lite';
import { useTranslation } from 'react-i18next';
import { checkUpdate, toastWithButton } from '../utils';
import { RestartApp } from '../../wailsjs/go/backend_golang/App';
import { Language, Languages } from '../types/settings';
export const Languages = {
dev: 'English', // i18n default
zh: '简体中文',
ja: '日本語'
};
export const GeneralSettings: FC = observer(() => {
const { t } = useTranslation();
export type Language = keyof typeof Languages;
return <div className="flex flex-col gap-2">
<Labeled label={t('Language')} flex spaceBetween content={
<Dropdown style={{ minWidth: 0 }} listbox={{ style: { minWidth: 'fit-content' } }}
value={Languages[commonStore.settings.language]}
selectedOptions={[commonStore.settings.language]}
onOptionSelect={(_, data) => {
if (data.optionValue) {
const lang = data.optionValue as Language;
commonStore.setSettings({
language: lang
});
}
}}>
{
Object.entries(Languages).map(([langKey, desc]) =>
<Option key={langKey} value={langKey}>{desc}</Option>)
}
</Dropdown>
} />
{
commonStore.platform === 'windows' &&
<Labeled label={t('DPI Scaling')} flex spaceBetween content={
<Dropdown style={{ minWidth: 0 }} listbox={{ style: { minWidth: 'fit-content' } }}
value={commonStore.settings.dpiScaling + '%'}
selectedOptions={[commonStore.settings.dpiScaling.toString()]}
onOptionSelect={(_, data) => {
if (data.optionValue) {
commonStore.setSettings({
dpiScaling: Number(data.optionValue)
});
toastWithButton(t('Restart the app to apply DPI Scaling.'), t('Restart'), () => {
RestartApp();
}, {
autoClose: 5000
});
}
}}>
{
Array.from({ length: 7 }, (_, i) => (i + 2) * 25).map((v, i) =>
<Option key={i} value={v.toString()}>{v + '%'}</Option>)
}
</Dropdown>
} />
}
<Labeled label={t('Dark Mode')} flex spaceBetween content={
<Switch checked={commonStore.settings.darkMode}
onChange={(e, data) => {
commonStore.setSettings({
darkMode: data.checked
});
}} />
} />
</div>;
});
export type SettingsType = {
language: Language
darkMode: boolean
autoUpdatesCheck: boolean
giteeUpdatesSource: boolean
cnMirror: boolean
host: string
dpiScaling: number
customModelsPath: string
customPythonPath: string
apiUrl: string
apiKey: string
apiChatModelName: string
apiCompletionModelName: string
}
export const AdvancedGeneralSettings: FC = observer(() => {
const { t } = useTranslation();
export const Settings: FC = observer(() => {
const { t, i18n } = useTranslation();
return <div className="flex flex-col gap-2">
<Labeled label={'API URL'}
content={
<div className="flex gap-2">
<Input style={{ minWidth: 0 }} className="grow" value={commonStore.settings.apiUrl}
onChange={(e, data) => {
commonStore.setSettings({
apiUrl: data.value
});
}} />
<Dropdown style={{ minWidth: '33px' }} listbox={{ style: { minWidth: 'fit-content' } }}
value="..." selectedOptions={[]} expandIcon={null}
onOptionSelect={(_, data) => {
commonStore.setSettings({
apiUrl: data.optionValue
});
if (data.optionText === 'OpenAI') {
if (commonStore.settings.apiChatModelName === 'rwkv')
commonStore.setSettings({
apiChatModelName: 'gpt-3.5-turbo'
});
if (commonStore.settings.apiCompletionModelName === 'rwkv')
commonStore.setSettings({
apiCompletionModelName: 'gpt-3.5-turbo-instruct'
});
} else if (data.optionText === 'RWKV') {
if (commonStore.settings.apiChatModelName === 'gpt-3.5-turbo')
commonStore.setSettings({
apiChatModelName: 'rwkv'
});
if (commonStore.settings.apiCompletionModelName === 'gpt-3.5-turbo-instruct' || commonStore.settings.apiCompletionModelName === 'text-davinci-003')
commonStore.setSettings({
apiCompletionModelName: 'rwkv'
});
}
}}>
<Option value="">{t('Localhost')!}</Option>
<Option value="https://rwkv.ai-creator.net/chntuned">RWKV</Option>
<Option value="https://api.openai.com">OpenAI</Option>
</Dropdown>
</div>
} />
<Labeled label={'API Key'}
content={
<Input type="password" className="grow" placeholder="sk-" value={commonStore.settings.apiKey}
onChange={(e, data) => {
commonStore.setSettings({
apiKey: data.value
});
}} />
} />
<Labeled label={t('API Chat Model Name')}
content={
<div className="flex gap-2">
<Input style={{ minWidth: 0 }} className="grow" placeholder="rwkv"
value={commonStore.settings.apiChatModelName}
onChange={(e, data) => {
commonStore.setSettings({
apiChatModelName: data.value
});
}} />
<Dropdown style={{ minWidth: '33px' }} listbox={{ style: { minWidth: 'fit-content' } }}
value="..." selectedOptions={[]} expandIcon={null}
onOptionSelect={(_, data) => {
if (data.optionValue) {
commonStore.setSettings({
apiChatModelName: data.optionValue
});
}
}}>
{
['rwkv', 'gpt-4-1106-preview', 'gpt-4', 'gpt-4-32k', 'gpt-4-0613', 'gpt-4-32k-0613', 'gpt-3.5-turbo-1106', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k']
.map((v, i) =>
<Option key={i} value={v}>{v}</Option>
)
}
</Dropdown>
</div>
} />
<Labeled label={t('API Completion Model Name')}
content={
<div className="flex gap-2">
<Input style={{ minWidth: 0 }} className="grow" placeholder="rwkv"
value={commonStore.settings.apiCompletionModelName}
onChange={(e, data) => {
commonStore.setSettings({
apiCompletionModelName: data.value
});
}} />
<Dropdown style={{ minWidth: '33px' }} listbox={{ style: { minWidth: 'fit-content', minHeight: 0 } }}
value="..." selectedOptions={[]} expandIcon={null}
onOptionSelect={(_, data) => {
if (data.optionValue) {
commonStore.setSettings({
apiCompletionModelName: data.optionValue
});
}
}}>
{
['rwkv', 'gpt-3.5-turbo-instruct', 'text-davinci-003', 'text-davinci-002', 'code-davinci-002', 'text-curie-001', 'text-babbage-001', 'text-ada-001']
.map((v, i) =>
<Option key={i} value={v}>{v}</Option>
)
}
</Dropdown>
</div>
} />
<Labeled label={t('Core API URL')}
desc={t('Override core API URL(/chat/completions and /completions). If you don\'t know what this is, leave it blank.')}
content={
<Input style={{ minWidth: 0 }} className="grow" value={commonStore.settings.coreApiUrl}
onChange={(e, data) => {
commonStore.setSettings({
coreApiUrl: data.value
});
}} />
} />
</div>;
});
const Settings: FC = observer(() => {
const { t } = useTranslation();
const advancedHeaderRef = useRef<HTMLDivElement>(null);
useEffect(() => {
@@ -53,227 +211,101 @@ export const Settings: FC = observer(() => {
return (
<Page title={t('Settings')} content={
<div className="flex flex-col gap-2 overflow-y-auto overflow-x-hidden p-1">
<Labeled label={t('Language')} flex spaceBetween content={
<Dropdown style={{ minWidth: 0 }} listbox={{ style: { minWidth: 0 } }}
value={Languages[commonStore.settings.language]}
selectedOptions={[commonStore.settings.language]}
onOptionSelect={(_, data) => {
if (data.optionValue) {
const lang = data.optionValue as Language;
commonStore.setSettings({
language: lang
});
}
}}>
{
Object.entries(Languages).map(([langKey, desc]) =>
<Option key={langKey} value={langKey}>{desc}</Option>)
}
</Dropdown>
} />
{
commonStore.platform === 'windows' &&
<Labeled label={t('DPI Scaling')} flex spaceBetween content={
<Dropdown style={{ minWidth: 0 }} listbox={{ style: { minWidth: 0 } }}
value={commonStore.settings.dpiScaling + '%'}
selectedOptions={[commonStore.settings.dpiScaling.toString()]}
onOptionSelect={(_, data) => {
if (data.optionValue) {
commonStore.setSettings({
dpiScaling: Number(data.optionValue)
});
toastWithButton(t('Restart the app to apply DPI Scaling.'), t('Restart'), () => {
RestartApp();
}, {
autoClose: 5000
});
}
}}>
{
Array.from({ length: 7 }, (_, i) => (i + 2) * 25).map((v, i) =>
<Option key={i} value={v.toString()}>{v + '%'}</Option>)
}
</Dropdown>
} />
}
<Labeled label={t('Dark Mode')} flex spaceBetween content={
<Switch checked={commonStore.settings.darkMode}
onChange={(e, data) => {
commonStore.setSettings({
darkMode: data.checked
});
}} />
} />
<Labeled label={t('Automatic Updates Check')} flex spaceBetween content={
<Switch checked={commonStore.settings.autoUpdatesCheck}
onChange={(e, data) => {
commonStore.setSettings({
autoUpdatesCheck: data.checked
});
if (data.checked)
checkUpdate(true);
}} />
} />
{
commonStore.settings.language === 'zh' &&
<Labeled label={t('Use Gitee Updates Source')} flex spaceBetween content={
<Switch checked={commonStore.settings.giteeUpdatesSource}
onChange={(e, data) => {
commonStore.setSettings({
giteeUpdatesSource: data.checked
});
}} />
} />
}
{
commonStore.settings.language === 'zh' && commonStore.platform !== 'linux' &&
<Labeled label={t('Use Tsinghua Pip Mirrors')} flex spaceBetween content={
<Switch checked={commonStore.settings.cnMirror}
onChange={(e, data) => {
commonStore.setSettings({
cnMirror: data.checked
});
}} />
} />
}
<Labeled label={t('Allow external access to the API (service must be restarted)')} flex spaceBetween content={
<Switch checked={commonStore.settings.host !== '127.0.0.1'}
onChange={(e, data) => {
commonStore.setSettings({
host: data.checked ? '0.0.0.0' : '127.0.0.1'
});
}} />
} />
<Accordion collapsible openItems={!commonStore.advancedCollapsed && 'advanced'} onToggle={(e, data) => {
if (data.value === 'advanced')
commonStore.setAdvancedCollapsed(!commonStore.advancedCollapsed);
}}>
<AccordionItem value="advanced">
<AccordionHeader ref={advancedHeaderRef} size="large">{t('Advanced')}</AccordionHeader>
<AccordionPanel>
<div className="flex flex-col gap-2 overflow-hidden">
{commonStore.platform !== 'darwin' &&
<Labeled label={t('Custom Models Path')}
content={
<Input className="grow" placeholder="./models" value={commonStore.settings.customModelsPath}
onChange={(e, data) => {
commonStore.setSettings({
customModelsPath: data.value
});
}} />
} />
}
<Labeled label={t('Custom Python Path')} // if set, will not use precompiled cuda kernel
content={
<Input className="grow" placeholder="./py310/python" value={commonStore.settings.customPythonPath}
onChange={(e, data) => {
commonStore.setDepComplete(false);
commonStore.setSettings({
customPythonPath: data.value
});
}} />
} />
<Labeled label={'API URL'}
content={
<div className="flex gap-2">
<Input style={{ minWidth: 0 }} className="grow" value={commonStore.settings.apiUrl}
onChange={(e, data) => {
commonStore.setSettings({
apiUrl: data.value
});
}} />
<Dropdown style={{ minWidth: 0 }} listbox={{ style: { minWidth: 0 } }}
value="..." selectedOptions={[]} expandIcon={null}
onOptionSelect={(_, data) => {
commonStore.setSettings({
apiUrl: data.optionValue
});
if (data.optionText === 'OpenAI') {
if (commonStore.settings.apiChatModelName === 'rwkv')
commonStore.setSettings({
apiChatModelName: 'gpt-3.5-turbo'
});
if (commonStore.settings.apiCompletionModelName === 'rwkv')
commonStore.setSettings({
apiCompletionModelName: 'text-davinci-003'
});
}
}}>
<Option value="">{t('Localhost')!}</Option>
<Option value="https://api.openai.com">OpenAI</Option>
</Dropdown>
</div>
} />
<Labeled label={'API Key'}
content={
<Input className="grow" placeholder="sk-" value={commonStore.settings.apiKey}
onChange={(e, data) => {
commonStore.setSettings({
apiKey: data.value
});
}} />
} />
<Labeled label={t('API Chat Model Name')}
content={
<div className="flex gap-2">
<Input style={{ minWidth: 0 }} className="grow" placeholder="rwkv"
value={commonStore.settings.apiChatModelName}
onChange={(e, data) => {
commonStore.setSettings({
apiChatModelName: data.value
});
}} />
<Dropdown style={{ minWidth: 0 }} listbox={{ style: { minWidth: 0 } }}
value="..." selectedOptions={[]} expandIcon={null}
onOptionSelect={(_, data) => {
if (data.optionValue) {
commonStore.setSettings({
apiChatModelName: data.optionValue
});
}
}}>
{
['rwkv', 'gpt-4', 'gpt-4-0613', 'gpt-4-32k', 'gpt-4-32k-0613', 'gpt-3.5-turbo', 'gpt-3.5-turbo-0613', 'gpt-3.5-turbo-16k', 'gpt-3.5-turbo-16k-0613']
.map((v, i) =>
<Option key={i} value={v}>{v}</Option>
)
}
</Dropdown>
</div>
} />
<Labeled label={t('API Completion Model Name')}
content={
<div className="flex gap-2">
<Input style={{ minWidth: 0 }} className="grow" placeholder="rwkv"
value={commonStore.settings.apiCompletionModelName}
onChange={(e, data) => {
commonStore.setSettings({
apiCompletionModelName: data.value
});
}} />
<Dropdown style={{ minWidth: 0 }} listbox={{ style: { minWidth: 0 } }}
value="..." selectedOptions={[]} expandIcon={null}
onOptionSelect={(_, data) => {
if (data.optionValue) {
commonStore.setSettings({
apiCompletionModelName: data.optionValue
});
}
}}>
{
['rwkv', 'text-davinci-003', 'text-davinci-002', 'text-curie-001', 'text-babbage-001', 'text-ada-001']
.map((v, i) =>
<Option key={i} value={v}>{v}</Option>
)
}
</Dropdown>
</div>
} />
commonStore.platform === 'web' ?
(
<div className="flex flex-col gap-2">
<GeneralSettings />
<AdvancedGeneralSettings />
</div>
</AccordionPanel>
</AccordionItem>
</Accordion>
)
:
(
<div className="flex flex-col gap-2">
<GeneralSettings />
<Labeled label={t('Automatic Updates Check')} flex spaceBetween content={
<Switch checked={commonStore.settings.autoUpdatesCheck}
onChange={(e, data) => {
commonStore.setSettings({
autoUpdatesCheck: data.checked
});
if (data.checked)
checkUpdate(true);
}} />
} />
{
commonStore.settings.language === 'zh' &&
<Labeled label={t('Use Gitee Updates Source')} flex spaceBetween content={
<Switch checked={commonStore.settings.giteeUpdatesSource}
onChange={(e, data) => {
commonStore.setSettings({
giteeUpdatesSource: data.checked
});
}} />
} />
}
{
commonStore.settings.language === 'zh' && commonStore.platform !== 'linux' &&
<Labeled label={t('Use Tsinghua Pip Mirrors')} flex spaceBetween content={
<Switch checked={commonStore.settings.cnMirror}
onChange={(e, data) => {
commonStore.setSettings({
cnMirror: data.checked
});
}} />
} />
}
<Labeled label={t('Allow external access to the API (service must be restarted)')} flex spaceBetween
content={
<Switch checked={commonStore.settings.host !== '127.0.0.1'}
onChange={(e, data) => {
commonStore.setSettings({
host: data.checked ? '0.0.0.0' : '127.0.0.1'
});
}} />
} />
<Accordion collapsible openItems={!commonStore.advancedCollapsed && 'advanced'} onToggle={(e, data) => {
if (data.value === 'advanced')
commonStore.setAdvancedCollapsed(!commonStore.advancedCollapsed);
}}>
<AccordionItem value="advanced">
<AccordionHeader ref={advancedHeaderRef} size="large">{t('Advanced')}</AccordionHeader>
<AccordionPanel>
<div className="flex flex-col gap-2 overflow-hidden">
{commonStore.platform !== 'darwin' &&
<Labeled label={t('Custom Models Path')}
content={
<Input className="grow" placeholder="./models"
value={commonStore.settings.customModelsPath}
onChange={(e, data) => {
commonStore.setSettings({
customModelsPath: data.value
});
}} />
} />
}
<Labeled label={t('Custom Python Path')} // if set, will not use precompiled cuda kernel
content={
<Input className="grow" placeholder="./py310/python"
value={commonStore.settings.customPythonPath}
onChange={(e, data) => {
commonStore.setDepComplete(false);
commonStore.setSettings({
customPythonPath: data.value
});
}} />
} />
<AdvancedGeneralSettings />
</div>
</AccordionPanel>
</AccordionItem>
</Accordion>
</div>
)
}
</div>
} />
);
});
export default Settings;

View File

@@ -1,4 +1,4 @@
import React, { FC, ReactElement, useEffect, useRef, useState } from 'react';
import React, { FC, useEffect, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { Button, Dropdown, Input, Option, Select, Switch, Tab, TabList } from '@fluentui/react-components';
import {
@@ -24,7 +24,6 @@ import { Labeled } from '../components/Labeled';
import { ToolTipButton } from '../components/ToolTipButton';
import { DataUsageSettings20Regular, Folder20Regular } from '@fluentui/react-icons';
import { useNavigate } from 'react-router';
import { Precision } from './Configs';
import {
CategoryScale,
Chart as ChartJS,
@@ -40,6 +39,12 @@ import { ChartJSOrUndefined } from 'react-chartjs-2/dist/types';
import { WindowShow } from '../../wailsjs/runtime';
import { t } from 'i18next';
import { DialogButton } from '../components/DialogButton';
import {
DataProcessParameters,
LoraFinetuneParameters,
LoraFinetunePrecision,
TrainNavigationItem
} from '../types/train';
ChartJS.register(
CategoryScale,
@@ -61,6 +66,11 @@ const parseLossData = (data: string) => {
const loss = parseFloat(lastMatch[8]);
commonStore.setChartTitle(`Epoch ${epoch}: ${lastMatch[2]} - ${lastMatch[3]}/${lastMatch[4]} - ${lastMatch[5]}/${lastMatch[6]} - ${lastMatch[7]} Loss=${loss}`);
addLossDataToChart(epoch, loss);
if (loss > 5)
toast(t('Loss is too high, please check the training data, and ensure your gpu driver is up to date.'), {
type: 'warning',
toastId: 'train_loss_high'
});
return true;
};
@@ -86,39 +96,6 @@ const addLossDataToChart = (epoch: number, loss: number) => {
commonStore.setChartData(commonStore.chartData);
};
export type DataProcessParameters = {
dataPath: string;
vocabPath: string;
}
export type LoraFinetunePrecision = 'bf16' | 'fp16' | 'tf32';
export type LoraFinetuneParameters = {
baseModel: string;
ctxLen: number;
epochSteps: number;
epochCount: number;
epochBegin: number;
epochSave: number;
microBsz: number;
accumGradBatches: number;
preFfn: boolean;
headQk: boolean;
lrInit: string;
lrFinal: string;
warmupSteps: number;
beta1: number;
beta2: number;
adamEps: string;
devices: number;
precision: LoraFinetunePrecision;
gradCp: boolean;
loraR: number;
loraAlpha: number;
loraDropout: number;
loraLoad: string
}
const loraFinetuneParametersOptions: Array<[key: keyof LoraFinetuneParameters, type: string, name: string]> = [
['devices', 'number', 'Devices'],
['precision', 'LoraFinetunePrecision', 'Precision'],
@@ -157,11 +134,13 @@ const errorsMap = Object.entries({
'python3 ./finetune/lora/train.py': 'Memory is not enough, try to increase the virtual memory (Swap of WSL) or use a smaller base model.',
'cuda out of memory': 'VRAM is not enough',
'valueerror: high <= 0': 'Training data is not enough, reduce context length or add more data for training',
'+= \'+ptx\'': 'You are using WSL 1 for training, please upgrade to WSL 2. e.g. Run "wsl --set-version Ubuntu-22.04 2"',
'+= \'+ptx\'': 'Can not find an Nvidia GPU. Perhaps the gpu driver of windows is too old, or you are using WSL 1 for training, please upgrade to WSL 2. e.g. Run "wsl --set-version Ubuntu-22.04 2"',
'size mismatch for blocks': 'Size mismatch for blocks. You are attempting to continue training from the LoRA model, but it does not match the base model. Please set LoRA model to None.',
'cuda_home environment variable is not set': 'Matched CUDA is not installed',
'unsupported gpu architecture': 'Matched CUDA is not installed',
'error building extension \'fused_adam\'': 'Matched CUDA is not installed'
'error building extension \'fused_adam\'': 'Matched CUDA is not installed',
'rwkv{version} is not supported': 'This version of RWKV is not supported yet.',
'modelinfo is invalid': 'Failed to load model, try to increase the virtual memory (Swap of WSL) or use a smaller base model.'
});
export const wslHandler = (data: string) => {
@@ -491,11 +470,12 @@ const LoraFinetune: FC = observer(() => {
return;
if (loraParams.loraLoad) {
const outputPath = `models/${loraParams.baseModel}-LoRA-${loraParams.loraLoad}`;
MergeLora(commonStore.settings.customPythonPath, true, loraParams.loraAlpha,
MergeLora(commonStore.settings.customPythonPath, !!commonStore.monitorData && commonStore.monitorData.totalVram !== 0, loraParams.loraAlpha,
'models/' + loraParams.baseModel, 'lora-models/' + loraParams.loraLoad,
outputPath).then(async () => {
if (!await FileExists(outputPath)) {
toast(t('Failed to merge model') + ' - ' + await GetPyError(), { type: 'error' });
if (commonStore.platform === 'windows' || commonStore.platform === 'linux')
toast(t('Failed to merge model') + ' - ' + await GetPyError(), { type: 'error' });
} else {
toast(t('Merge model successfully'), { type: 'success' });
}
@@ -568,10 +548,6 @@ const LoraFinetune: FC = observer(() => {
);
});
type TrainNavigationItem = {
element: ReactElement;
};
const pages: { [label: string]: TrainNavigationItem } = {
'LoRA Finetune': {
element: <LoraFinetune />
@@ -582,7 +558,7 @@ const pages: { [label: string]: TrainNavigationItem } = {
};
export const Train: FC = () => {
const Train: FC = () => {
const { t } = useTranslation();
const [tab, setTab] = useState('LoRA Finetune');
@@ -607,3 +583,5 @@ export const Train: FC = () => {
</div>
</div>;
};
export default Train;

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,4 @@
import { ReactElement } from 'react';
import { Configs } from './Configs';
import { FC, lazy, LazyExoticComponent, ReactElement } from 'react';
import {
ArrowDownload20Regular,
Chat20Regular,
@@ -12,21 +11,12 @@ import {
Settings20Regular,
Storage20Regular
} from '@fluentui/react-icons';
import { Home } from './Home';
import { Chat } from './Chat';
import { Models } from './Models';
import { Train } from './Train';
import { Settings } from './Settings';
import { About } from './About';
import { Downloads } from './Downloads';
import { Completion } from './Completion';
import { Composition } from './Composition';
type NavigationItem = {
label: string;
path: string;
icon: ReactElement;
element: ReactElement;
element: LazyExoticComponent<FC>;
top: boolean;
};
@@ -35,70 +25,70 @@ export const pages: NavigationItem[] = [
label: 'Home',
path: '/',
icon: <Home20Regular />,
element: <Home />,
element: lazy(() => import('./Home')),
top: true
},
{
label: 'Chat',
path: '/chat',
icon: <Chat20Regular />,
element: <Chat />,
element: lazy(() => import('./Chat')),
top: true
},
{
label: 'Completion',
path: '/completion',
icon: <ClipboardEdit20Regular />,
element: <Completion />,
element: lazy(() => import('./Completion')),
top: true
},
{
label: 'Composition',
path: '/composition',
icon: <MusicNote220Regular />,
element: <Composition />,
element: lazy(() => import('./Composition')),
top: true
},
{
label: 'Configs',
path: '/configs',
icon: <DocumentSettings20Regular />,
element: <Configs />,
element: lazy(() => import('./Configs')),
top: true
},
{
label: 'Models',
path: '/models',
icon: <DataUsageSettings20Regular />,
element: <Models />,
element: lazy(() => import('./Models')),
top: true
},
{
label: 'Downloads',
path: '/downloads',
icon: <ArrowDownload20Regular />,
element: <Downloads />,
element: lazy(() => import('./Downloads')),
top: true
},
{
label: 'Train',
path: '/train',
icon: <Storage20Regular />,
element: <Train />,
element: lazy(() => import('./Train')),
top: true
},
{
label: 'Settings',
path: '/settings',
icon: <Settings20Regular />,
element: <Settings />,
element: lazy(() => import('./Settings')),
top: false
},
{
label: 'About',
path: '/about',
icon: <Info20Regular />,
element: <About />,
element: lazy(() => import('./About')),
top: false
}
];

View File

@@ -1,43 +1,49 @@
import commonStore, { Platform } from './stores/commonStore';
import { GetPlatform, ListDirFiles, ReadJson } from '../wailsjs/go/backend_golang/App';
import commonStore, { MonitorData, Platform } from './stores/commonStore';
import { FileExists, GetPlatform, ListDirFiles, ReadJson } from '../wailsjs/go/backend_golang/App';
import { Cache, checkUpdate, downloadProgramFiles, LocalConfig, refreshLocalModels, refreshModels } from './utils';
import { getStatus } from './apis';
import { EventsOn, WindowSetTitle } from '../wailsjs/runtime';
import manifest from '../../manifest.json';
import { defaultModelConfigs, defaultModelConfigsMac } from './pages/defaultConfigs';
import { Preset } from './pages/PresetsManager/PresetsButton';
import { wslHandler } from './pages/Train';
import { t } from 'i18next';
import { Preset } from './types/presets';
import { toast } from 'react-toastify';
import { MidiMessage, MidiPort } from './types/composition';
export async function startup() {
downloadProgramFiles();
EventsOn('downloadList', (data) => {
if (data)
commonStore.setDownloadList(data);
});
EventsOn('wsl', wslHandler);
EventsOn('wslerr', (e) => {
console.log(e);
});
initLocalModelsNotify();
initLoraModels();
initPresets();
initHardwareMonitor();
await GetPlatform().then(p => commonStore.setPlatform(p as Platform));
if (commonStore.platform !== 'web') {
downloadProgramFiles();
EventsOn('downloadList', (data) => {
if (data)
commonStore.setDownloadList(data);
});
EventsOn('wsl', (await import('./pages/Train')).wslHandler);
EventsOn('wslerr', (e) => {
console.log(e);
});
initLocalModelsNotify();
initLoraModels();
initHardwareMonitor();
initMidi();
}
await initConfig();
initCache(true).then(initRemoteText); // depends on config customModelsPath
if (commonStore.platform !== 'web') {
initCache(true).then(initRemoteText); // depends on config customModelsPath
if (commonStore.settings.autoUpdatesCheck) // depends on config settings
checkUpdate();
if (commonStore.settings.autoUpdatesCheck) // depends on config settings
checkUpdate();
getStatus(1000).then(status => { // depends on config api port
if (status)
commonStore.setStatus(status);
});
getStatus(1000).then(status => { // depends on config api port
if (status)
commonStore.setStatus(status);
});
}
}
async function initRemoteText() {
@@ -88,7 +94,8 @@ async function initCache(initUnfinishedModels: boolean) {
async function initPresets() {
await ReadJson('presets.json').then((presets: Preset[]) => {
commonStore.setPresets(presets, false);
if (Array.isArray(presets))
commonStore.setPresets(presets, false);
}).catch(() => {
});
}
@@ -121,19 +128,31 @@ async function initLocalModelsNotify() {
});
}
type monitorData = {
usedMemory: number;
totalMemory: number;
gpuUsage: number;
gpuPower: number;
usedVram: number;
totalVram: number;
}
async function initHardwareMonitor() {
EventsOn('monitor', (data: string) => {
const results: monitorData = JSON.parse(data);
if (results)
const results: MonitorData = JSON.parse(data);
if (results) {
commonStore.setMonitorData(results);
WindowSetTitle(`RWKV-Runner (${t('RAM')}: ${results.usedMemory.toFixed(1)}/${results.totalMemory.toFixed(1)} GB, ${t('VRAM')}: ${(results.usedVram / 1024).toFixed(1)}/${(results.totalVram / 1024).toFixed(1)} GB, ${t('GPU Usage')}: ${results.gpuUsage}%)`);
}
});
}
async function initMidi() {
EventsOn('midiError', (data: string) => {
if (commonStore.platform === 'windows')
toast('MIDI Error: ' + data, { type: 'error' });
});
EventsOn('midiPorts', (data: MidiPort[]) => {
commonStore.setMidiPorts(data);
});
EventsOn('midiMessage', async (data: MidiMessage) => {
await (await import('./pages/AudiotrackManager/AudiotrackEditor')).midiMessageHandler(data);
});
if (await FileExists('assets/sound-font/accordion/instrument.json')) {
commonStore.setCompositionParams({
...commonStore.compositionParams,
useLocalSoundFont: true
});
}
}

View File

@@ -2,21 +2,27 @@ import { makeAutoObservable } from 'mobx';
import { getUserLanguage, isSystemLightMode, saveCache, saveConfigs, savePresets } from '../utils';
import { WindowSetDarkTheme, WindowSetLightTheme } from '../../wailsjs/runtime';
import manifest from '../../../manifest.json';
import { ModelConfig } from '../pages/Configs';
import { Conversation } from '../pages/Chat';
import { ModelSourceItem } from '../pages/Models';
import { DownloadStatus } from '../pages/Downloads';
import { SettingsType } from '../pages/Settings';
import { IntroductionContent } from '../pages/Home';
import { AboutContent } from '../pages/About';
import i18n from 'i18next';
import { CompletionPreset } from '../pages/Completion';
import { defaultCompositionPrompt, defaultModelConfigs, defaultModelConfigsMac } from '../pages/defaultConfigs';
import commonStore from './commonStore';
import { Preset } from '../pages/PresetsManager/PresetsButton';
import { DataProcessParameters, LoraFinetuneParameters } from '../pages/Train';
import { ChartData } from 'chart.js';
import { CompositionParams } from '../pages/Composition';
import { Preset } from '../types/presets';
import { AboutContent } from '../types/about';
import { Attachment, ChatParams, Conversation } from '../types/chat';
import { CompletionPreset } from '../types/completion';
import {
CompositionParams,
InstrumentType,
MidiMessage,
MidiPort,
Track,
tracksMinimalTotalTime
} from '../types/composition';
import { ModelConfig } from '../types/configs';
import { DownloadStatus } from '../types/downloads';
import { IntroductionContent } from '../types/home';
import { ModelSourceItem } from '../types/models';
import { SettingsType } from '../types/settings';
import { DataProcessParameters, LoraFinetuneParameters } from '../types/train';
export enum ModelStatus {
Offline,
@@ -31,9 +37,16 @@ export type Status = {
device_name: string;
}
export type Platform = 'windows' | 'darwin' | 'linux';
export type MonitorData = {
usedMemory: number;
totalMemory: number;
gpuUsage: number;
gpuPower: number;
usedVram: number;
totalVram: number;
}
const labels = ['January', 'February', 'March', 'April', 'May', 'June', 'July'];
export type Platform = 'windows' | 'darwin' | 'linux' | 'web';
class CommonStore {
// global
@@ -42,8 +55,10 @@ class CommonStore {
pid: 0,
device_name: 'CPU'
};
monitorData: MonitorData | null = null;
depComplete: boolean = false;
platform: Platform = 'windows';
lastModelName: string = '';
// presets manager
editingPreset: Preset | null = null;
presets: Preset[] = [];
@@ -55,9 +70,18 @@ class CommonStore {
conversationOrder: string[] = [];
activePreset: Preset | null = null;
attachmentUploading: boolean = false;
attachmentName: string = '';
attachmentSize: number = 0;
attachmentContent: string = '';
attachments: {
[uuid: string]: Attachment[]
} = {};
currentTempAttachment: Attachment | null = null;
chatParams: ChatParams = {
maxResponseToken: 1000,
temperature: 1,
topP: 0.3,
presencePenalty: 0,
frequencyPenalty: 1
};
sidePanelCollapsed: boolean | 'auto' = 'auto';
// completion
completionPreset: CompletionPreset | null = null;
completionGenerating: boolean = false;
@@ -70,16 +94,32 @@ class CommonStore {
topP: 0.8,
autoPlay: true,
useLocalSoundFont: false,
externalPlay: false,
midi: null,
ns: null
};
compositionGenerating: boolean = false;
compositionSubmittedPrompt: string = defaultCompositionPrompt;
// composition midi device
midiPorts: MidiPort[] = [];
activeMidiDeviceIndex: number = -1;
instrumentType: InstrumentType = InstrumentType.Piano;
// composition tracks
tracks: Track[] = [];
trackScale: number = 1;
trackTotalTime: number = tracksMinimalTotalTime;
trackCurrentTime: number = 0;
trackPlayStartTime: number = 0;
playingTrackId: string = '';
recordingTrackId: string = '';
recordingContent: string = ''; // used to improve performance of midiMessageHandler, and I'm too lazy to maintain an ID dictionary for this (although that would be better for realtime effects)
recordingRawContent: MidiMessage[] = [];
// configs
currentModelConfigIndex: number = 0;
modelConfigs: ModelConfig[] = [];
modelParamsCollapsed: boolean = true;
// models
activeModelListTags: string[] = [];
modelSourceManifestList: string = 'https://cdn.jsdelivr.net/gh/josstorer/RWKV-Runner@master/manifest.json;';
modelSourceList: ModelSourceItem[] = [];
// downloads
@@ -100,7 +140,7 @@ class CommonStore {
epochSteps: 200,
epochCount: 20,
epochBegin: 0,
epochSave: 2,
epochSave: 1,
microBsz: 1,
accumGradBatches: 8,
preFfn: false,
@@ -127,14 +167,16 @@ class CommonStore {
autoUpdatesCheck: true,
giteeUpdatesSource: getUserLanguage() === 'zh',
cnMirror: getUserLanguage() === 'zh',
useHfMirror: false,
host: '127.0.0.1',
dpiScaling: 100,
customModelsPath: './models',
customPythonPath: '',
apiUrl: '',
apiKey: 'sk-',
apiKey: '',
apiChatModelName: 'rwkv',
apiCompletionModelName: 'rwkv'
apiCompletionModelName: 'rwkv',
coreApiUrl: ''
};
// about
about: AboutContent = manifest.about;
@@ -172,7 +214,7 @@ class CommonStore {
createModelConfig = (config: ModelConfig = defaultModelConfigs[0], saveConfig: boolean = true) => {
if (config.name === defaultModelConfigs[0].name) {
// deep copy
config = JSON.parse(JSON.stringify(commonStore.platform !== 'darwin' ? defaultModelConfigs[0] : defaultModelConfigsMac[0]));
config = JSON.parse(JSON.stringify(this.platform !== 'darwin' ? defaultModelConfigs[0] : defaultModelConfigsMac[0]));
config.name = new Date().toLocaleString();
}
this.modelConfigs.push(config);
@@ -226,6 +268,10 @@ class CommonStore {
this.about = value;
};
setLastModelName(value: string) {
this.lastModelName = value;
}
setDepComplete = (value: boolean, inSaveCache: boolean = true) => {
this.depComplete = value;
if (inSaveCache)
@@ -252,6 +298,10 @@ class CommonStore {
this.completionGenerating = value;
}
setMonitorData(value: MonitorData) {
this.monitorData = value;
}
setPlatform(value: Platform) {
this.platform = value;
}
@@ -282,7 +332,7 @@ class CommonStore {
savePresets();
}
setActivePreset(value: Preset) {
setActivePreset(value: Preset | null) {
this.activePreset = value;
}
@@ -334,16 +384,81 @@ class CommonStore {
this.attachmentUploading = value;
}
setAttachmentName(value: string) {
this.attachmentName = value;
setAttachments(value: {
[uuid: string]: Attachment[]
}) {
this.attachments = value;
}
setAttachmentSize(value: number) {
this.attachmentSize = value;
setAttachment(uuid: string, value: Attachment[] | null) {
if (value === null)
delete this.attachments[uuid];
else
this.attachments[uuid] = value;
}
setAttachmentContent(value: string) {
this.attachmentContent = value;
setCurrentTempAttachment(value: Attachment | null) {
this.currentTempAttachment = value;
}
setChatParams(value: Partial<ChatParams>) {
this.chatParams = { ...this.chatParams, ...value };
}
setSidePanelCollapsed(value: boolean | 'auto') {
this.sidePanelCollapsed = value;
}
setTracks(value: Track[]) {
this.tracks = value;
}
setTrackScale(value: number) {
this.trackScale = value;
}
setTrackTotalTime(value: number) {
this.trackTotalTime = value;
}
setTrackCurrentTime(value: number) {
this.trackCurrentTime = value;
}
setTrackPlayStartTime(value: number) {
this.trackPlayStartTime = value;
}
setMidiPorts(value: MidiPort[]) {
this.midiPorts = value;
}
setInstrumentType(value: InstrumentType) {
this.instrumentType = value;
}
setRecordingTrackId(value: string) {
this.recordingTrackId = value;
}
setActiveMidiDeviceIndex(value: number) {
this.activeMidiDeviceIndex = value;
}
setRecordingContent(value: string) {
this.recordingContent = value;
}
setRecordingRawContent(value: MidiMessage[]) {
this.recordingRawContent = value;
}
setPlayingTrackId(value: string) {
this.playingTrackId = value;
}
setActiveModelListTags(value: string[]) {
this.activeModelListTags = value;
}
}

View File

@@ -1,12 +1,10 @@
[data-theme='dark'] {
@import 'highlight.js/scss/github-dark.scss';
@import 'github-markdown-css/github-markdown-dark.css';
--color-neutral-muted: rgba(110, 118, 129, 0.4);
}
[data-theme='light'] {
@import 'highlight.js/scss/github.scss';
@import 'github-markdown-css/github-markdown-light.css';
--color-neutral-muted: rgba(150, 160, 170, 0.3);
}
@@ -16,7 +14,6 @@
body {
margin: 0;
overflow: hidden;
width: 100%;
height: 100%;
}

View File

@@ -0,0 +1 @@
export type AboutContent = { [lang: string]: string }

View File

@@ -0,0 +1,37 @@
import { ApiParameters } from './configs';
export const userName = 'M E';
export const botName = 'A I';
export const welcomeUuid = 'welcome';
export enum MessageType {
Normal,
Error
}
export type Side = 'left' | 'right'
export type Color = 'neutral' | 'brand' | 'colorful'
export type MessageItem = {
sender: string,
type: MessageType,
color: Color,
avatarImg?: string,
time: string,
content: string,
side: Side,
done: boolean
}
export type Conversation = {
[uuid: string]: MessageItem
}
export type Role = 'assistant' | 'user' | 'system';
export type ConversationMessage = {
role: Role;
content: string;
}
export type Attachment = {
name: string;
size: number;
content: string;
}
export type ChatParams = Omit<ApiParameters, 'apiPort'>

View File

@@ -0,0 +1,12 @@
import { ApiParameters } from './configs';
export type CompletionParams = Omit<ApiParameters, 'apiPort'> & {
stop: string,
injectStart: string,
injectEnd: string
};
export type CompletionPreset = {
name: string,
prompt: string,
params: CompletionParams
}

View File

@@ -0,0 +1,86 @@
import { NoteSequence } from '@magenta/music/esm/protobuf';
export const tracksMinimalTotalTime = 5000;
export type CompositionParams = {
prompt: string,
maxResponseToken: number,
temperature: number,
topP: number,
autoPlay: boolean,
useLocalSoundFont: boolean,
externalPlay: boolean,
midi: ArrayBuffer | null,
ns: NoteSequence | null
}
export type Track = {
id: string;
mainInstrument: string;
content: string;
rawContent: MidiMessage[];
offsetTime: number;
contentTime: number;
};
export type MidiPort = {
name: string;
}
export type MessageType = 'NoteOff' | 'NoteOn' | 'ElapsedTime' | 'ControlChange';
export type MidiMessage = {
messageType: MessageType;
channel: number;
note: number;
velocity: number;
control: number;
value: number;
instrument: InstrumentType;
}
export enum InstrumentType {
Piano,
Percussion,
Drum,
Tuba,
Marimba,
Bass,
Guitar,
Violin,
Trumpet,
Sax,
Flute,
Lead,
Pad,
}
export const InstrumentTypeNameMap = [
'Piano',
'Percussion',
'Drum',
'Tuba',
'Marimba',
'Bass',
'Guitar',
'Violin',
'Trumpet',
'Sax',
'Flute',
'Lead',
'Pad'
];
export const InstrumentTypeTokenMap = [
'pi',
'p',
'd',
't',
'm',
'b',
'g',
'v',
'tr',
's',
'f',
'l',
'pa'
];

View File

@@ -0,0 +1,29 @@
export type ApiParameters = {
apiPort: number
maxResponseToken: number;
temperature: number;
topP: number;
presencePenalty: number;
frequencyPenalty: number;
}
export type Device = 'CPU' | 'CPU (rwkv.cpp)' | 'CUDA' | 'CUDA-Beta' | 'WebGPU' | 'MPS' | 'Custom';
export type Precision = 'fp16' | 'int8' | 'fp32' | 'nf4' | 'Q5_1';
export type ModelParameters = {
// different models can not have the same name
modelName: string;
device: Device;
precision: Precision;
storedLayers: number;
maxStoredLayers: number;
useCustomCuda?: boolean;
customStrategy?: string;
useCustomTokenizer?: boolean;
customTokenizer?: string;
}
export type ModelConfig = {
// different configs can have the same name
name: string;
apiParameters: ApiParameters;
modelParameters: ModelParameters;
enableWebUI?: boolean;
}

View File

@@ -0,0 +1,11 @@
export type DownloadStatus = {
name: string;
path: string;
url: string;
transferred: number;
size: number;
speed: number;
progress: number;
downloading: boolean;
done: boolean;
}

View File

@@ -0,0 +1,11 @@
import { ReactElement } from 'react';
export type IntroductionContent = {
[lang: string]: string
}
export type NavCard = {
label: string;
desc: string;
path: string;
icon: ReactElement;
};

View File

@@ -0,0 +1,15 @@
export type ModelSourceItem = {
name: string;
size: number;
lastUpdated: string;
desc?: { [lang: string]: string | undefined; };
SHA256?: string;
url?: string;
downloadUrl?: string;
isComplete?: boolean;
isLocal?: boolean;
localSize?: number;
lastUpdatedMs?: number;
tags?: string[];
hide?: boolean;
};

View File

@@ -0,0 +1,30 @@
import { ReactElement } from 'react';
import { ConversationMessage } from './chat';
export type PresetType = 'chat' | 'completion' | 'chatInCompletion'
export type Preset = {
name: string,
tag: string,
// if name and sourceUrl are same, it will be overridden when importing
sourceUrl: string,
desc: string,
avatarImg: string,
type: PresetType,
// chat
welcomeMessage: string,
messages: ConversationMessage[],
displayPresetMessages: boolean,
// completion
prompt: string,
stop: string,
injectStart: string,
injectEnd: string,
presystem?: boolean,
userName?: string,
assistantName?: string
}
export type PresetsNavigationItem = {
icon: ReactElement;
element: ReactElement;
};

View File

@@ -0,0 +1,23 @@
export const Languages = {
dev: 'English', // i18n default
zh: '简体中文',
ja: '日本語'
};
export type Language = keyof typeof Languages;
export type SettingsType = {
language: Language
darkMode: boolean
autoUpdatesCheck: boolean
giteeUpdatesSource: boolean
cnMirror: boolean
useHfMirror: boolean
host: string
dpiScaling: number
customModelsPath: string
customPythonPath: string
apiUrl: string
apiKey: string
apiChatModelName: string
apiCompletionModelName: string
coreApiUrl: string
}

View File

@@ -0,0 +1,35 @@
import { ReactElement } from 'react';
export type DataProcessParameters = {
dataPath: string;
vocabPath: string;
}
export type LoraFinetunePrecision = 'bf16' | 'fp16' | 'tf32';
export type LoraFinetuneParameters = {
baseModel: string;
ctxLen: number;
epochSteps: number;
epochCount: number;
epochBegin: number;
epochSave: number;
microBsz: number;
accumGradBatches: number;
preFfn: boolean;
headQk: boolean;
lrInit: string;
lrFinal: string;
warmupSteps: number;
beta1: number;
beta2: number;
adamEps: string;
devices: number;
precision: LoraFinetunePrecision;
gradCp: boolean;
loraR: number;
loraAlpha: number;
loraDropout: number;
loraLoad: string
}
export type TrainNavigationItem = {
element: ReactElement;
};

View File

@@ -0,0 +1,107 @@
import { toast } from 'react-toastify';
import commonStore from '../stores/commonStore';
import { t } from 'i18next';
import {
ConvertGGML,
ConvertModel,
ConvertSafetensors,
FileExists,
GetPyError
} from '../../wailsjs/go/backend_golang/App';
import { WindowShow } from '../../wailsjs/runtime';
import { ModelConfig, Precision } from '../types/configs';
import { checkDependencies, getStrategy } from './index';
import { NavigateFunction } from 'react-router';
export const convertModel = async (selectedConfig: ModelConfig, navigate: NavigateFunction) => {
if (commonStore.platform === 'darwin') {
toast(t('MacOS is not yet supported for performing this operation, please do it manually.') + ' (backend-python/convert_model.py)', { type: 'info' });
return;
} else if (commonStore.platform === 'linux') {
toast(t('Linux is not yet supported for performing this operation, please do it manually.') + ' (backend-python/convert_model.py)', { type: 'info' });
return;
}
const ok = await checkDependencies(navigate);
if (!ok)
return;
const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`;
if (await FileExists(modelPath)) {
const strategy = getStrategy(selectedConfig);
const newModelPath = modelPath + '-' + strategy.replace(/[:> *+]/g, '-');
toast(t('Start Converting'), { autoClose: 2000, type: 'info' });
ConvertModel(commonStore.settings.customPythonPath, modelPath, strategy, newModelPath).then(async () => {
if (!await FileExists(newModelPath + '.pth')) {
toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' });
} else {
toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' });
}
}).catch(e => {
const errMsg = e.message || e;
if (errMsg.includes('path contains space'))
toast(`${t('Convert Failed')} - ${t('File Path Cannot Contain Space')}`, { type: 'error' });
else
toast(`${t('Convert Failed')} - ${e.message || e}`, { type: 'error' });
});
setTimeout(WindowShow, 1000);
} else {
toast(`${t('Model Not Found')} - ${modelPath}`, { type: 'error' });
}
};
export const convertToSt = async (selectedConfig: ModelConfig) => {
const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`;
if (await FileExists(modelPath)) {
toast(t('Start Converting'), { autoClose: 2000, type: 'info' });
const newModelPath = modelPath.replace(/\.pth$/, '.st');
ConvertSafetensors(modelPath, newModelPath).then(async () => {
if (!await FileExists(newModelPath)) {
if (commonStore.platform === 'windows' || commonStore.platform === 'linux')
toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' });
} else {
toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' });
}
}).catch(e => {
const errMsg = e.message || e;
if (errMsg.includes('path contains space'))
toast(`${t('Convert Failed')} - ${t('File Path Cannot Contain Space')}`, { type: 'error' });
else
toast(`${t('Convert Failed')} - ${e.message || e}`, { type: 'error' });
});
setTimeout(WindowShow, 1000);
} else {
toast(`${t('Model Not Found')} - ${modelPath}`, { type: 'error' });
}
};
export const convertToGGML = async (selectedConfig: ModelConfig, navigate: NavigateFunction) => {
const ok = await checkDependencies(navigate);
if (!ok)
return;
const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`;
if (await FileExists(modelPath)) {
toast(t('Start Converting'), { autoClose: 2000, type: 'info' });
const precision: Precision = selectedConfig.modelParameters.precision === 'Q5_1' ? 'Q5_1' : 'fp16';
const newModelPath = modelPath.replace(/\.pth$/, `-${precision}.bin`);
ConvertGGML(commonStore.settings.customPythonPath, modelPath, newModelPath, precision === 'Q5_1').then(async () => {
if (!await FileExists(newModelPath)) {
if (commonStore.platform === 'windows' || commonStore.platform === 'linux')
toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' });
} else {
toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' });
}
}).catch(e => {
const errMsg = e.message || e;
if (errMsg.includes('path contains space'))
toast(`${t('Convert Failed')} - ${t('File Path Cannot Contain Space')}`, { type: 'error' });
else
toast(`${t('Convert Failed')} - ${e.message || e}`, { type: 'error' });
});
setTimeout(WindowShow, 1000);
} else {
toast(`${t('Model Not Found')} - ${modelPath}`, { type: 'error' });
}
};

View File

@@ -15,13 +15,18 @@ import { toast } from 'react-toastify';
import { t } from 'i18next';
import { ToastOptions } from 'react-toastify/dist/types';
import { Button } from '@fluentui/react-components';
import { Language, Languages, SettingsType } from '../pages/Settings';
import { ModelSourceItem } from '../pages/Models';
import { ModelConfig, ModelParameters } from '../pages/Configs';
import { DownloadStatus } from '../pages/Downloads';
import { DataProcessParameters, LoraFinetuneParameters } from '../pages/Train';
import { BrowserOpenURL, WindowShow } from '../../wailsjs/runtime';
import { BrowserOpenURL, EventsOff, EventsOn, WindowShow } from '../../wailsjs/runtime';
import { NavigateFunction } from 'react-router';
import { ModelConfig, ModelParameters } from '../types/configs';
import { DownloadStatus } from '../types/downloads';
import { ModelSourceItem } from '../types/models';
import { Language, Languages, SettingsType } from '../types/settings';
import { DataProcessParameters, LoraFinetuneParameters } from '../types/train';
import { InstrumentTypeNameMap, MidiMessage, tracksMinimalTotalTime } from '../types/composition';
import logo from '../assets/images/logo.png';
import { Preset } from '../types/presets';
import { botName, Conversation, MessageType, userName } from '../types/chat';
import { v4 as uuid } from 'uuid';
export type Cache = {
version: string
@@ -39,7 +44,9 @@ export type LocalConfig = {
}
export async function refreshBuiltInModels(readCache: boolean = false) {
let cache: { models: ModelSourceItem[] } = { models: [] };
let cache: {
models: ModelSourceItem[]
} = { models: [] };
if (readCache)
await ReadJson('cache.json').then((cacheData: Cache) => {
if (cacheData.models)
@@ -56,7 +63,7 @@ export async function refreshBuiltInModels(readCache: boolean = false) {
return cache;
}
const modelSuffix = ['.pth', '.st', '.safetensors'];
const modelSuffix = ['.pth', '.st', '.safetensors', '.bin'];
export async function refreshLocalModels(cache: {
models: ModelSourceItem[]
@@ -72,7 +79,8 @@ export async function refreshLocalModels(cache: {
size: d.size,
lastUpdated: d.modTime,
isComplete: true,
isLocal: true
isLocal: true,
tags: ['Local']
}] as ModelSourceItem[];
return [];
}));
@@ -82,12 +90,15 @@ export async function refreshLocalModels(cache: {
for (let i = 0; i < cache.models.length; i++) {
if (!cache.models[i].lastUpdatedMs)
cache.models[i].lastUpdatedMs = Date.parse(cache.models[i].lastUpdated);
if (!cache.models[i].tags)
cache.models[i].tags = ['Other'];
for (let j = i + 1; j < cache.models.length; j++) {
if (!cache.models[j].lastUpdatedMs)
cache.models[j].lastUpdatedMs = Date.parse(cache.models[j].lastUpdated);
if (cache.models[i].name === cache.models[j].name) {
const tags = Array.from(new Set([...cache.models[i].tags as string[], ...cache.models[j].tags as string[]]));
if (cache.models[i].size <= cache.models[j].size) { // j is local file
if (cache.models[i].lastUpdatedMs! < cache.models[j].lastUpdatedMs!) {
cache.models[i] = Object.assign({}, cache.models[i], cache.models[j]);
@@ -97,6 +108,7 @@ export async function refreshLocalModels(cache: {
} // else is not complete local file
cache.models[i].isLocal = true;
cache.models[i].localSize = cache.models[j].size;
cache.models[i].tags = tags;
cache.models.splice(j, 1);
j--;
}
@@ -117,7 +129,7 @@ function initLastUnfinishedModelDownloads() {
{
name: item.name,
path: `${commonStore.settings.customModelsPath}/${item.name}`,
url: item.downloadUrl!,
url: getHfDownloadUrl(item.downloadUrl!),
transferred: item.localSize!,
size: item.size,
speed: 0,
@@ -131,7 +143,9 @@ function initLastUnfinishedModelDownloads() {
commonStore.setLastUnfinishedModelDownloads(list);
}
export async function refreshRemoteModels(cache: { models: ModelSourceItem[] }) {
export async function refreshRemoteModels(cache: {
models: ModelSourceItem[]
}) {
const manifestUrls = commonStore.modelSourceManifestList.split(/[,;\n]/);
const requests = manifestUrls.filter(url => url.endsWith('.json')).map(
url => fetch(url, { cache: 'no-cache' }).then(r => r.json()));
@@ -178,14 +192,14 @@ export const getStrategy = (modelConfig: ModelConfig | undefined = undefined) =>
strategy += params.precision === 'int8' ? 'fp32i8' : 'fp32';
break;
case 'WebGPU':
strategy += params.precision === 'int8' ? 'fp16i8' : 'fp16';
strategy += params.precision === 'nf4' ? 'fp16i4' : params.precision === 'int8' ? 'fp16i8' : 'fp16';
break;
case 'CUDA':
case 'CUDA-Beta':
if (avoidOverflow)
strategy = params.useCustomCuda ? 'cuda fp16 *1 -> ' : 'cuda fp32 *1 -> ';
strategy += 'cuda ';
strategy += params.precision === 'fp16' ? 'fp16' : params.precision === 'int8' ? 'fp16i8' : 'fp32';
strategy += params.precision === 'int8' ? 'fp16i8' : params.precision === 'fp32' ? 'fp32' : 'fp16';
if (params.storedLayers < params.maxStoredLayers)
strategy += ` *${params.storedLayers}+`;
break;
@@ -193,7 +207,7 @@ export const getStrategy = (modelConfig: ModelConfig | undefined = undefined) =>
if (avoidOverflow)
strategy = 'mps fp32 *1 -> ';
strategy += 'mps ';
strategy += params.precision === 'fp16' ? 'fp16' : params.precision === 'int8' ? 'fp16i8' : 'fp32';
strategy += params.precision === 'int8' ? 'fp32i8' : 'fp32';
break;
case 'Custom':
strategy = params.customStrategy || '';
@@ -289,7 +303,27 @@ export function bytesToReadable(size: number) {
else return bytesToGb(size) + ' GB';
}
export function getServerRoot(defaultLocalPort: number, isCore: boolean = false) {
const coreCustomApiUrl = commonStore.settings.coreApiUrl.trim().replace(/\/$/, '');
if (isCore && coreCustomApiUrl)
return coreCustomApiUrl;
const defaultRoot = `http://127.0.0.1:${defaultLocalPort}`;
if (commonStore.status.status !== ModelStatus.Offline)
return defaultRoot;
const customApiUrl = commonStore.settings.apiUrl.trim().replace(/\/$/, '');
if (customApiUrl)
return customApiUrl;
if (commonStore.platform === 'web')
return '';
return defaultRoot;
}
export function absPathAsset(path: string) {
if (commonStore.platform === 'web')
return path;
if (path === logo)
return path;
if ((path.length > 0 && path[0] === '/') ||
(path.length > 1 && path[1] === ':')) {
return '=>' + path;
@@ -322,27 +356,47 @@ export async function checkUpdate(notifyEvenLatest: boolean = false) {
`https://gitee.com/josc146/RWKV-Runner/releases/download/${versionTag}/${asset.name}`;
toastWithButton(t('New Version Available') + ': ' + versionTag, t('Update'), () => {
DeleteFile('cache.json');
toast(t('Downloading update, please wait. If it is not completed, please manually download the program from GitHub and replace the original program.'), {
type: 'info',
position: 'bottom-left',
autoClose: false
});
setTimeout(() => {
UpdateApp(updateUrl).then(() => {
toast(t('Update completed, please restart the program.'), {
type: 'success',
position: 'bottom-left',
autoClose: false
}
);
}).catch((e) => {
toast(t('Update Error') + ' - ' + (e.message || e), {
type: 'error',
const progressId = 'update_app';
const progressEvent = 'updateApp';
const updateProgress = (ds: DownloadStatus | null) => {
const content =
t('Downloading update, please wait. If it is not completed, please manually download the program from GitHub and replace the original program.')
+ (ds ? ` (${ds.progress.toFixed(2)}% ${bytesToReadable(ds.transferred)}/${bytesToReadable(ds.size)})` : '');
const options: ToastOptions = {
type: 'info',
position: 'bottom-left',
autoClose: false,
toastId: progressId,
hideProgressBar: false,
progress: ds ? ds.progress / 100 : 0
};
if (toast.isActive(progressId))
toast.update(progressId, {
render: content,
...options
});
else
toast(content, options);
};
updateProgress(null);
EventsOn(progressEvent, updateProgress);
UpdateApp(updateUrl).then(() => {
toast(t('Update completed, please restart the program.'), {
type: 'success',
position: 'bottom-left',
autoClose: false
});
}
);
}).catch((e) => {
toast(t('Update Error') + ' - ' + (e.message || e), {
type: 'error',
position: 'bottom-left',
autoClose: false
});
}, 500);
}).finally(() => {
toast.dismiss(progressId);
EventsOff(progressEvent);
});
}, {
autoClose: false,
position: 'bottom-left'
@@ -383,7 +437,10 @@ export const checkDependencies = async (navigate: NavigateFunction) => {
toastWithButton(`${t('Downloading')} Python`, t('Check'), () => {
navigate({ pathname: '/downloads' });
}, { autoClose: 3000 });
AddToDownloadList('python-3.10.11-embed-amd64.zip', 'https://www.python.org/ftp/python/3.10.11/python-3.10.11-embed-amd64.zip');
AddToDownloadList('python-3.10.11-embed-amd64.zip',
!commonStore.settings.cnMirror
? 'https://www.python.org/ftp/python/3.10.11/python-3.10.11-embed-amd64.zip'
: 'https://mirrors.huaweicloud.com/python/3.10.11/python-3.10.11-embed-amd64.zip');
});
} else if (depErrorMsg.includes('DepCheck Error')) {
if (depErrorMsg.includes('vc_redist') || depErrorMsg.includes('DLL load failed while importing')) {
@@ -440,6 +497,107 @@ export function toastWithButton(text: string, buttonText: string, onClickButton:
return id;
}
export function getHfDownloadUrl(url: string) {
if (commonStore.settings.useHfMirror && url.includes('huggingface.co') && url.includes('resolve'))
return url.replace('huggingface.co', 'hf-mirror.com');
return url;
}
export function refreshTracksTotalTime() {
if (commonStore.tracks.length === 0) {
commonStore.setTrackTotalTime(tracksMinimalTotalTime);
commonStore.setTrackCurrentTime(0);
commonStore.setTrackPlayStartTime(0);
return;
}
const endTimes = commonStore.tracks.map(t => t.offsetTime + t.contentTime);
const totalTime = Math.max(...endTimes) + tracksMinimalTotalTime;
if (commonStore.trackPlayStartTime > totalTime)
commonStore.setTrackPlayStartTime(totalTime);
commonStore.setTrackTotalTime(totalTime);
}
export function getMidiRawContentTime(rawContent: MidiMessage[]) {
return rawContent.reduce((sum, current) =>
sum + (current.messageType === 'ElapsedTime' ? current.value : 0)
, 0);
}
export function getMidiRawContentMainInstrument(rawContent: MidiMessage[]) {
const sortedInstrumentFrequency = Object.entries(rawContent
.filter(c => c.messageType === 'NoteOn')
.map(c => c.instrument)
.reduce((frequencyCount, current) => (frequencyCount[current] = (frequencyCount[current] || 0) + 1, frequencyCount)
, {} as {
[key: string]: number
}))
.sort((a, b) => b[1] - a[1]);
let mainInstrument: string = '';
if (sortedInstrumentFrequency.length > 0)
mainInstrument = InstrumentTypeNameMap[Number(sortedInstrumentFrequency[0][0])];
return mainInstrument;
}
export function flushMidiRecordingContent() {
const recordingTrackIndex = commonStore.tracks.findIndex(t => t.id === commonStore.recordingTrackId);
if (recordingTrackIndex >= 0) {
const recordingTrack = commonStore.tracks[recordingTrackIndex];
const tracks = commonStore.tracks.slice();
tracks[recordingTrackIndex] = {
...recordingTrack,
content: commonStore.recordingContent,
rawContent: commonStore.recordingRawContent,
contentTime: getMidiRawContentTime(commonStore.recordingRawContent),
mainInstrument: getMidiRawContentMainInstrument(commonStore.recordingRawContent)
};
commonStore.setTracks(tracks);
refreshTracksTotalTime();
}
commonStore.setRecordingContent('');
commonStore.setRecordingRawContent([]);
}
export async function getSoundFont() {
let soundUrl: string;
if (commonStore.compositionParams.useLocalSoundFont)
soundUrl = 'assets/sound-font';
else
soundUrl = !commonStore.settings.giteeUpdatesSource ?
`https://raw.githubusercontent.com/josStorer/sgm_plus/master` :
`https://cdn.jsdelivr.net/gh/josstorer/sgm_plus`;
const fallbackUrl = 'https://cdn.jsdelivr.net/gh/josstorer/sgm_plus';
await fetch(soundUrl + '/soundfont.json').then(r => {
if (!r.ok)
soundUrl = fallbackUrl;
}).catch(() => soundUrl = fallbackUrl);
return soundUrl;
}
export const setActivePreset = (preset: Preset | null) => {
commonStore.setActivePreset(preset);
//TODO if (preset.displayPresetMessages) {
const conversation: Conversation = {};
const conversationOrder: string[] = [];
if (preset)
for (const message of preset.messages) {
const newUuid = uuid();
conversationOrder.push(newUuid);
conversation[newUuid] = {
sender: message.role === 'user' ? userName : botName,
type: MessageType.Normal,
color: message.role === 'user' ? 'brand' : 'colorful',
time: new Date().toISOString(),
content: message.content,
side: message.role === 'user' ? 'right' : 'left',
done: true
};
}
commonStore.setConversation(conversation);
commonStore.setConversationOrder(conversationOrder);
//}
};
export function getSupportedCustomCudaFile(isBeta: boolean) {
if ([' 10', ' 16', ' 20', ' 30', 'MX', 'Tesla P', 'Quadro P', 'NVIDIA P', 'TITAN X', 'TITAN RTX', 'RTX A',
'Quadro RTX 4000', 'Quadro RTX 5000', 'Tesla T4', 'NVIDIA A10', 'NVIDIA A40'].some(v => commonStore.status.device_name.includes(v)))

View File

@@ -0,0 +1,77 @@
import { getDocument, GlobalWorkerOptions, PDFDocumentProxy } from 'pdfjs-dist';
import { TextItem } from 'pdfjs-dist/types/src/display/api';
export function webOpenOpenFileDialog(filterPattern: string, fnStartLoading: Function | undefined): Promise<{
blob: Blob,
content?: string
}> {
return new Promise((resolve, reject) => {
const input = document.createElement('input');
input.type = 'file';
input.accept = filterPattern
.replaceAll('*.txt', 'text/plain')
.replace('*.midi', 'audio/midi')
.replace('*.mid', 'audio/midi')
.replaceAll('*.', 'application/')
.replaceAll(';', ',');
input.onchange = async e => {
// @ts-ignore
const file: Blob = e.target?.files[0];
if (fnStartLoading && typeof fnStartLoading === 'function')
fnStartLoading();
if (!GlobalWorkerOptions.workerSrc && file.type === 'application/pdf')
// @ts-ignore
GlobalWorkerOptions.workerSrc = await import('pdfjs-dist/build/pdf.worker.min.mjs');
if (file.type === 'text/plain') {
const reader = new FileReader();
reader.readAsText(file, 'UTF-8');
reader.onload = event => {
const content = event.target?.result as string;
resolve({
blob: file,
content: content
});
};
reader.onerror = reject;
} else if (file.type === 'application/pdf') {
const readPDFPage = async (doc: PDFDocumentProxy, pageNo: number) => {
const page = await doc.getPage(pageNo);
const tokenizedText = await page.getTextContent();
return tokenizedText.items.map(token => (token as TextItem).str).join('');
};
let reader = new FileReader();
reader.readAsArrayBuffer(file);
reader.onload = async (event) => {
try {
const doc = await getDocument(event.target?.result!).promise;
const pageTextPromises = [];
for (let pageNo = 1; pageNo <= doc.numPages; pageNo++) {
pageTextPromises.push(readPDFPage(doc, pageNo));
}
const pageTexts = await Promise.all(pageTextPromises);
let content;
if (pageTexts.length === 1)
content = pageTexts[0];
else
content = pageTexts.map((p, i) => `Page ${i + 1}:\n${p}`).join('\n\n');
resolve({
blob: file,
content: content
});
} catch (err) {
reject(err);
}
};
reader.onerror = reject;
} else {
resolve({
blob: file
});
}
};
input.click();
});
}

Some files were not shown because too many files have changed in this diff Show More