Compare commits

...

72 Commits

Author SHA1 Message Date
josc146
30b6d66a2d release v1.4.1 2023-07-28 22:14:53 +08:00
josc146
9d89b6f4db fix params 2023-07-28 22:13:19 +08:00
josc146
d2928e54f7 fix failed to build cyac 2023-07-28 21:40:17 +08:00
josc146
49ba5c97f7 update readme 2023-07-28 13:13:14 +08:00
github-actions[bot]
4054fac359 release v1.4.0 2023-07-28 05:06:42 +00:00
josc146
dfae1d9645 release v1.4.0 2023-07-28 13:05:55 +08:00
josc146
0f16a0dd1b remove LoraFinetunePrecision fp32 2023-07-28 12:53:41 +08:00
josc146
cb05a8a2ae update manifest 2023-07-28 12:50:39 +08:00
josc146
a51385173c add CPU-120M-Music config 2023-07-28 12:45:31 +08:00
josc146
4e18222a35 improve RunButton prompt 2023-07-28 12:45:13 +08:00
josc146
daabcf58a0 add Composition Page (RWKV-Music) 2023-07-28 12:30:05 +08:00
josc146
d0fd480bd6 chore 2023-07-26 22:24:26 +08:00
josc146
1df345b5eb improve embeddings API results 2023-07-25 20:30:43 +08:00
josc146
77868c798b chore 2023-07-25 16:37:06 +08:00
josc146
f56748a941 improve python backend startup speed 2023-07-25 16:14:29 +08:00
josc146
29c5b1d804 add midi api 2023-07-25 16:11:17 +08:00
josc146
34095a6c36 support for stop array 2023-07-25 16:10:22 +08:00
josc146
05b9b42b56 add support for MIDI RWKV 2023-07-25 16:09:31 +08:00
josc146
211ae342af improve sse fetch 2023-07-25 15:59:37 +08:00
josc146
5ae683e915 update presets 2023-07-25 15:53:25 +08:00
josc146
dc59fb39c7 update readme 2023-07-18 14:21:09 +08:00
josc146
49960774ee update readme 2023-07-18 14:16:50 +08:00
github-actions[bot]
b718452618 release v1.3.9 2023-07-17 05:05:17 +00:00
josc146
15ae312b37 release v1.3.9 2023-07-17 13:03:32 +08:00
josc146
6938b5b20e change chinese translation of completion 2023-07-17 13:03:11 +08:00
josc146
9b3b06ab04 fix input with array type (#96, #107) 2023-07-17 12:59:45 +08:00
josc146
e2a7c93753 fix always show Convert Failed when converting model 2023-07-16 16:54:18 +08:00
github-actions[bot]
34349aee0b release v1.3.8 2023-07-15 14:29:14 +00:00
josc146
8e79370e95 release v1.3.8 2023-07-15 22:28:49 +08:00
josc146
652c35322b save conversation as txt (originally in md) 2023-07-15 22:12:59 +08:00
josc146
e2fc57ac24 training: fix data EOL format 2023-07-11 12:19:39 +08:00
josc146
994fc7c828 fix cross-device state cache exception 2023-07-11 11:20:12 +08:00
josc146
b9a960d984 update readme 2023-07-10 23:06:19 +08:00
josc146
3baf260f4d update readme 2023-07-10 22:59:22 +08:00
github-actions[bot]
d037ded146 release v1.3.7 2023-07-10 13:50:05 +00:00
josc146
622287f3da release v1.3.7 2023-07-10 21:49:33 +08:00
josc146
5d12bf74f6 update presets 2023-07-10 21:43:58 +08:00
josc146
c88f9321f5 update manifest 2023-07-10 20:49:31 +08:00
josc146
f9f1d5c9fc improve /completions api compatibility 2023-07-10 20:45:08 +08:00
josc146
0edec68376 improve training data path compatibility 2023-07-10 20:44:09 +08:00
josc146
ee63dc25f4 update readme 2023-07-09 13:56:36 +08:00
josc146
fee8fe73f2 fix loss parser 2023-07-09 13:33:06 +08:00
github-actions[bot]
1689f9e7e7 release v1.3.6 2023-07-09 04:41:11 +00:00
josc146
a1ed0cb2e9 release v1.3.6 2023-07-09 12:40:42 +08:00
josc146
5ee5fa7e6e fix load_state_dict crash 2023-07-09 12:33:29 +08:00
josc146
d8c70453ec format 2023-07-09 12:32:50 +08:00
josc146
e930eb5967 extra vc check 2023-07-09 12:18:51 +08:00
josc146
aec6ad636a chore 2023-07-09 12:10:14 +08:00
josc146
750c91bd3e update logo 2023-07-09 11:59:23 +08:00
josc146
fcc3886db1 improve error messages for training 2023-07-09 11:39:44 +08:00
josc146
22afc98be5 fix loss parser 2023-07-09 11:32:05 +08:00
josc146
5b1a9448e6 fix jsonl data when using directory as training data 2023-07-09 11:31:07 +08:00
github-actions[bot]
07d89e3eeb release v1.3.5 2023-07-07 13:58:33 +00:00
josc146
96e97d9c1e release v1.3.5 2023-07-07 21:58:08 +08:00
josc146
bcb125e168 support using directory as training data 2023-07-07 21:57:01 +08:00
josc146
6fbb86667c improve python script error messages 2023-07-07 20:16:35 +08:00
josc146
2d545604f4 refresh local models in real-time (#98) 2023-07-07 20:14:55 +08:00
josc146
7210a7481e improve finetune guide 2023-07-07 19:10:31 +08:00
josc146
55210c89e2 improve wsl dependencies installation 2023-07-07 18:57:51 +08:00
josc146
c725d11dd9 fix loss parser 2023-07-07 13:56:08 +08:00
josc146
ba2a6bd06c update Related Repositories 2023-07-07 13:54:57 +08:00
josc146
57b80c6ed0 fix build for macos and linux 2023-07-07 13:54:07 +08:00
josc146
115c59d5e1 chore 2023-07-07 13:53:39 +08:00
github-actions[bot]
543ff468b7 release v1.3.4 2023-07-03 14:32:06 +00:00
josc146
96ae47989e release v1.3.4 2023-07-03 22:31:37 +08:00
josc146
368932a610 improve finetune compatibility 2023-07-03 22:28:01 +08:00
josc146
f2cd531fcb fix build for macos and linux 2023-07-03 22:22:55 +08:00
josc146
511652b71c improve finetune compatibility 2023-07-03 22:19:20 +08:00
github-actions[bot]
525fb132d6 release v1.3.3 2023-07-03 13:40:51 +00:00
josc146
5acb1fd958 release v1.3.3 2023-07-03 21:40:22 +08:00
josc146
76761ee453 improve lora finetune process (need to be refactored) 2023-07-03 21:40:16 +08:00
github-actions[bot]
134b2884e6 release v1.3.2 2023-07-03 09:43:01 +00:00
65 changed files with 25203 additions and 921 deletions

1
.gitattributes vendored
View File

@@ -2,6 +2,7 @@ backend-python/rwkv_pip/** linguist-vendored
backend-python/wkv_cuda_utils/** linguist-vendored
backend-python/get-pip.py linguist-vendored
backend-python/convert_model.py linguist-vendored
backend-python/utils/midi.py linguist-vendored
build/** linguist-vendored
finetune/lora/** linguist-vendored
finetune/json2binidx_tool/** linguist-vendored

View File

@@ -11,7 +11,7 @@ env:
jobs:
create-draft:
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
steps:
- run: echo "VERSION=${GITHUB_REF_NAME#v}" >> $GITHUB_ENV
- uses: actions/checkout@v3
@@ -35,7 +35,7 @@ jobs:
gh release create ${{github.ref_name}} -d -F CURRENT_CHANGE.md -t ${{github.ref_name}}
windows:
runs-on: windows-latest
runs-on: windows-2022
needs: create-draft
steps:
- uses: actions/checkout@v3
@@ -56,10 +56,10 @@ jobs:
Expand-Archive ./python-3.10.11-embed-amd64.zip -DestinationPath ./py310
$content=Get-Content "./py310/python310._pth"; $content | ForEach-Object {if ($_.ReadCount -eq 3) {"Lib\\site-packages"} else {$_}} | Set-Content ./py310/python310._pth
./py310/python ./backend-python/get-pip.py
./py310/python -m pip install Cython
./py310/python -m pip install Cython==0.29.36
Copy-Item -Path "${{ steps.cp310.outputs.python-path }}/../include" -Destination "py310/include" -Recurse
Copy-Item -Path "${{ steps.cp310.outputs.python-path }}/../libs" -Destination "py310/libs" -Recurse
./py310/python -m pip install cyac
./py310/python -m pip install cyac==1.7
go install github.com/wailsapp/wails/v2/cmd/wails@latest
make
Rename-Item -Path "build/bin/RWKV-Runner.exe" -NewName "RWKV-Runner_windows_x64.exe"
@@ -83,6 +83,9 @@ jobs:
go install github.com/wailsapp/wails/v2/cmd/wails@latest
rm -rf ./backend-python/wkv_cuda_utils
rm ./backend-python/get-pip.py
sed -i '1,2d' ./backend-golang/wsl_not_windows.go
rm ./backend-golang/wsl.go
mv ./backend-golang/wsl_not_windows.go ./backend-golang/wsl.go
make
mv build/bin/RWKV-Runner build/bin/RWKV-Runner_linux_x64
@@ -102,6 +105,9 @@ jobs:
go install github.com/wailsapp/wails/v2/cmd/wails@latest
rm -rf ./backend-python/wkv_cuda_utils
rm ./backend-python/get-pip.py
sed -i '' '1,2d' ./backend-golang/wsl_not_windows.go
rm ./backend-golang/wsl.go
mv ./backend-golang/wsl_not_windows.go ./backend-golang/wsl.go
make
cp build/darwin/Readme_Install.txt build/bin/Readme_Install.txt
cp build/bin/RWKV-Runner.app/Contents/MacOS/RWKV-Runner build/bin/RWKV-Runner_darwin_universal
@@ -110,7 +116,7 @@ jobs:
- run: gh release upload ${{github.ref_name}} build/bin/RWKV-Runner_macos_universal.zip build/bin/RWKV-Runner_darwin_universal
publish-release:
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
needs: [ windows, linux, macos ]
steps:
- uses: actions/checkout@v3

1
.gitignore vendored
View File

@@ -23,3 +23,4 @@ __pycache__
*.log
train_log.txt
finetune/json2binidx_tool/data
/wsl.state

View File

@@ -1,7 +1,18 @@
## Changes
- lora finetune (need to be refactored)
- reduce package size for linux and macos
- add Composition Page (RWKV-Music)
- improve RunButton prompt
- support for `stop` array api params
- improve embeddings API results
- improve python backend startup speed
- add support for MIDI RWKV
- add midi api
- add CPU-120M-Music config
- improve sse fetch
- update manifest (a lot of new models)
- update presets
- remove LoraFinetunePrecision fp32
- chore
## Install

View File

@@ -49,7 +49,7 @@ English | [简体中文](README_ZH.md) | [日本語](README_JA.md)
#### Default configs has enabled custom CUDA kernel acceleration, which is much faster and consumes much less VRAM. If you encounter possible compatibility issues, go to the Configs page and turn off `Use Custom CUDA kernel to Accelerate`.
#### If Windows Defender claims this is a virus, you can try downloading [v1.0.8](https://github.com/josStorer/RWKV-Runner/releases/tag/v1.0.8)/[v1.0.9](https://github.com/josStorer/RWKV-Runner/releases/tag/v1.0.9) and letting it update automatically to the latest version, or add it to the trusted list.
#### If Windows Defender claims this is a virus, you can try downloading [v1.3.7_win.zip](https://github.com/josStorer/RWKV-Runner/releases/download/v1.3.7/RWKV-Runner_win.zip) and letting it update automatically to the latest version, or add it to the trusted list (`Windows Security` -> `Virus & threat protection` -> `Manage settings` -> `Exclusions` -> `Add or remove exclusions` -> `Add an exclusion` -> `Folder` -> `RWKV-Runner`).
#### For different tasks, adjusting API parameters can achieve better results. For example, for translation tasks, you can try setting Temperature to 1 and Top_P to 0.3.
@@ -64,6 +64,8 @@ English | [简体中文](README_ZH.md) | [日本語](README_JA.md)
- Easy-to-understand and operate parameter configuration
- Built-in model conversion tool
- Built-in download management and remote model inspection
- Built-in one-click LoRA Finetune
- Can also be used as an OpenAI ChatGPT and GPT-Playground client
- Multilingual localization
- Theme switching
- Automatic updates
@@ -89,6 +91,9 @@ body.json:
## Embeddings API Example
Note: v1.4.0 has improved the quality of embeddings API. The generated results are not compatible
with previous versions. If you are using embeddings API to generate knowledge bases or similar, please regenerate.
If you are using langchain, just use `OpenAIEmbeddings(openai_api_base="http://127.0.0.1:8000", openai_api_key="sk-")`
```python
@@ -126,46 +131,49 @@ for i in np.argsort(embeddings_cos_sim)[::-1]:
print(f"{embeddings_cos_sim[i]:.10f} - {values[i]}")
```
## Todo
- [ ] Model training functionality
- [x] CUDA operator int8 acceleration
- [x] macOS support
- [x] Linux support
- [ ] Local State Cache DB
## Related Repositories:
- RWKV-4-World: https://huggingface.co/BlinkDL/rwkv-4-world/tree/main
- RWKV-4-Raven: https://huggingface.co/BlinkDL/rwkv-4-raven/tree/main
- ChatRWKV: https://github.com/BlinkDL/ChatRWKV
- RWKV-LM: https://github.com/BlinkDL/RWKV-LM
- RWKV-LM-LoRA: https://github.com/Blealtan/RWKV-LM-LoRA
- MIDI-LLM-tokenizer: https://github.com/briansemrau/MIDI-LLM-tokenizer
## Preview
### Homepage
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/60efbb65-29e3-4346-a597-5bdcd099251c)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/d7f24d80-f382-428d-8b28-edf87e1549e2)
### Chat
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/6cde9c45-51bb-4dee-b1fe-746862448520)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/80009872-528f-4932-aeb2-f724fa892e7c)
### Completion
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/52f47f92-d21d-4cd7-b04e-d6f9af937a97)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/bf49de8e-3b89-4543-b1ef-7cd4b19a1836)
### Composition
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/e8ad908d-3fd2-4e92-bcdb-96815cb836ee)
### Configuration
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/93270a68-9d6d-4247-b6a3-e543c65a876b)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/48befdc6-e03c-4851-9bee-22f77ee2640e)
### Model Management
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/6f96fdd3-fdf5-4b78-af80-2afbd1ad173b)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/367fe4f8-cc12-475f-9371-3cf62cdbf293)
### Download Management
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/6982e7ee-bace-4a88-bb47-92379185bf9d)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/c8153cf9-c8cb-4618-8268-60c82a5be539)
### LoRA Finetune
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/4715045a-683e-4d2a-9b0e-090c7a5df63f)
### Settings
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/b3b2ab46-344c-4f04-b066-1503f776eeb9)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/1067e635-8c07-4217-86a8-e48a5fcbb075)

View File

@@ -24,22 +24,32 @@
[FAQs](https://github.com/josStorer/RWKV-Runner/wiki/FAQs) | [プレビュー](#Preview) | [ダウンロード][download-url] | [サーバーデプロイ例](https://github.com/josStorer/RWKV-Runner/tree/master/deploy-examples)
[license-image]: http://img.shields.io/badge/license-MIT-blue.svg
[license-url]: https://github.com/josStorer/RWKV-Runner/blob/master/LICENSE
[release-image]: https://img.shields.io/github/release/josStorer/RWKV-Runner.svg
[release-url]: https://github.com/josStorer/RWKV-Runner/releases/latest
[download-url]: https://github.com/josStorer/RWKV-Runner/releases
[Windows-image]: https://img.shields.io/badge/-Windows-blue?logo=windows
[Windows-url]: https://github.com/josStorer/RWKV-Runner/blob/master/build/windows/Readme_Install.txt
[MacOS-image]: https://img.shields.io/badge/-MacOS-black?logo=apple
[MacOS-url]: https://github.com/josStorer/RWKV-Runner/blob/master/build/darwin/Readme_Install.txt
[Linux-image]: https://img.shields.io/badge/-Linux-black?logo=linux
[Linux-url]: https://github.com/josStorer/RWKV-Runner/blob/master/build/linux/Readme_Install.txt
</div>
#### デフォルトの設定はカスタム CUDA カーネルアクセラレーションを有効にしています。互換性の問題が発生する可能性がある場合は、コンフィグページに移動し、`Use Custom CUDA kernel to Accelerate` をオフにしてください。
#### Windows Defender がこれをウイルスだと主張する場合は、[v1.0.8](https://github.com/josStorer/RWKV-Runner/releases/tag/v1.0.8) / [v1.0.9](https://github.com/josStorer/RWKV-Runner/releases/tag/v1.0.9) をダウンロードして最新版に自動更新させるか、信頼済みリストに追加してみてください
#### Windows Defender がこれをウイルスだと主張する場合は、[v1.3.7_win.zip](https://github.com/josStorer/RWKV-Runner/releases/download/v1.3.7/RWKV-Runner_win.zip) をダウンロードして最新版に自動更新させるか、信頼済みリストに追加してみてください (`Windows Security` -> `Virus & threat protection` -> `Manage settings` -> `Exclusions` -> `Add or remove exclusions` -> `Add an exclusion` -> `Folder` -> `RWKV-Runner`)
#### 異なるタスクについては、API パラメータを調整することで、より良い結果を得ることができます。例えば、翻訳タスクの場合、Temperature を 1 に、Top_P を 0.3 に設定してみてください。
@@ -54,6 +64,8 @@
- 分かりやすく操作しやすいパラメータ設定
- 内蔵モデル変換ツール
- ダウンロード管理とリモートモデル検査機能内蔵
- 内蔵のLoRA微調整機能を搭載しています
- このプログラムは、OpenAI ChatGPTとGPT Playgroundのクライアントとしても使用できます
- 多言語ローカライズ
- テーマ切り替え
- 自動アップデート
@@ -79,7 +91,11 @@ body.json:
## 埋め込み API の例
LangChain を使用している場合は、`OpenAIEmbeddings(openai_api_base="http://127.0.0.1:8000", openai_api_key="sk-")`を使用してください
Note: v1.4.0 has improved the quality of embeddings API. The generated results are not compatible
with previous versions. If you are using embeddings API to generate knowledge bases or similar, please regenerate.
LangChain を使用している場合は、`OpenAIEmbeddings(openai_api_base="http://127.0.0.1:8000", openai_api_key="sk-")`
を使用してください
```python
import numpy as np
@@ -116,46 +132,49 @@ for i in np.argsort(embeddings_cos_sim)[::-1]:
print(f"{embeddings_cos_sim[i]:.10f} - {values[i]}")
```
## Todo
- [ ] モデル学習機能
- [x] CUDA オペレータ int8 アクセラレーション
- [x] macOS サポート
- [x] Linux サポート
- [ ] ローカルステートキャッシュ DB
## 関連リポジトリ:
- RWKV-4-World: https://huggingface.co/BlinkDL/rwkv-4-world/tree/main
- RWKV-4-Raven: https://huggingface.co/BlinkDL/rwkv-4-raven/tree/main
- ChatRWKV: https://github.com/BlinkDL/ChatRWKV
- RWKV-LM: https://github.com/BlinkDL/RWKV-LM
- RWKV-LM-LoRA: https://github.com/Blealtan/RWKV-LM-LoRA
- MIDI-LLM-tokenizer: https://github.com/briansemrau/MIDI-LLM-tokenizer
## プレビュー
### ホームページ
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/60efbb65-29e3-4346-a597-5bdcd099251c)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/d7f24d80-f382-428d-8b28-edf87e1549e2)
### チャット
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/6cde9c45-51bb-4dee-b1fe-746862448520)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/80009872-528f-4932-aeb2-f724fa892e7c)
### 補完
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/52f47f92-d21d-4cd7-b04e-d6f9af937a97)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/bf49de8e-3b89-4543-b1ef-7cd4b19a1836)
### 作曲
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/e8ad908d-3fd2-4e92-bcdb-96815cb836ee)
### コンフィグ
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/93270a68-9d6d-4247-b6a3-e543c65a876b)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/48befdc6-e03c-4851-9bee-22f77ee2640e)
### モデル管理
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/6f96fdd3-fdf5-4b78-af80-2afbd1ad173b)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/367fe4f8-cc12-475f-9371-3cf62cdbf293)
### ダウンロード管理
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/6982e7ee-bace-4a88-bb47-92379185bf9d)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/c8153cf9-c8cb-4618-8268-60c82a5be539)
### LoRA Finetune
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/4715045a-683e-4d2a-9b0e-090c7a5df63f)
### 設定
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/b3b2ab46-344c-4f04-b066-1503f776eeb9)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/1067e635-8c07-4217-86a8-e48a5fcbb075)

View File

@@ -20,7 +20,7 @@ API兼容的接口这意味着一切ChatGPT客户端都是RWKV客户端。
[![MacOS][MacOS-image]][MacOS-url]
[![Linux][Linux-image]][Linux-url]
[视频演示](https://www.bilibili.com/video/BV1hM4y1v76R) | [疑难解答](https://www.bilibili.com/read/cv23921171) | [预览](#Preview) | [下载][download-url] | [懒人包](https://pan.baidu.com/s/1wchIUHgne3gncIiLIeKBEQ?pwd=1111) | [服务器部署示例](https://github.com/josStorer/RWKV-Runner/tree/master/deploy-examples)
[视频演示](https://www.bilibili.com/video/BV1hM4y1v76R) | [疑难解答](https://www.bilibili.com/read/cv23921171) | [预览](#Preview) | [下载][download-url] | [懒人包](https://pan.baidu.com/s/1zdzZ_a0uM3gDqi6pXIZVAA?pwd=1111) | [服务器部署示例](https://github.com/josStorer/RWKV-Runner/tree/master/deploy-examples)
[license-image]: http://img.shields.io/badge/license-MIT-blue.svg
@@ -46,11 +46,9 @@ API兼容的接口这意味着一切ChatGPT客户端都是RWKV客户端。
</div>
#### 注意 目前RWKV中文模型质量一般推荐使用英文模型或World(全球语言)体验实际RWKV能力
#### 预设配置已经开启自定义CUDA算子加速速度更快且显存消耗更少。如果你遇到可能的兼容性问题前往配置页面关闭`使用自定义CUDA算子加速`
#### 如果Windows Defender说这是一个病毒你可以尝试下载[v1.0.8](https://github.com/josStorer/RWKV-Runner/releases/tag/v1.0.8)/[v1.0.9](https://github.com/josStorer/RWKV-Runner/releases/tag/v1.0.9)然后让其自动更新到最新版,或添加信任
#### 如果Windows Defender说这是一个病毒你可以尝试下载[v1.3.7_win.zip](https://github.com/josStorer/RWKV-Runner/releases/download/v1.3.7/RWKV-Runner_win.zip),然后让其自动更新到最新版,或添加信任 (`Windows Security` -> `Virus & threat protection` -> `Manage settings` -> `Exclusions` -> `Add or remove exclusions` -> `Add an exclusion` -> `Folder` -> `RWKV-Runner`)
#### 对于不同的任务调整API参数会获得更好的效果例如对于翻译任务你可以尝试设置Temperature为1Top_P为0.3
@@ -60,10 +58,12 @@ API兼容的接口这意味着一切ChatGPT客户端都是RWKV客户端。
- 与OpenAI API完全兼容一切ChatGPT客户端都是RWKV客户端。启动模型后打开 http://127.0.0.1:8000/docs 查看详细内容
- 全自动依赖安装,你只需要一个轻巧的可执行程序
- 预设了2G至32G显存的配置几乎在各种电脑上工作良好
- 自带用户友好的聊天和补全交互页面
- 自带用户友好的聊天和续写交互页面
- 易于理解和操作的参数配置
- 内置模型转换工具
- 内置下载管理和远程模型检视
- 内置一键LoRA微调
- 也可用作 OpenAI ChatGPT 和 GPT Playground 客户端
- 多语言本地化
- 主题切换
- 自动更新
@@ -89,6 +89,8 @@ body.json:
## Embeddings API 示例
注意: 1.4.0 版本对embeddings API质量进行了改善生成结果与之前的版本不兼容如果你正在使用此API生成知识库等请重新生成
如果你在用langchain, 直接使用 `OpenAIEmbeddings(openai_api_base="http://127.0.0.1:8000", openai_api_key="sk-")`
```python
@@ -126,46 +128,49 @@ for i in np.argsort(embeddings_cos_sim)[::-1]:
print(f"{embeddings_cos_sim[i]:.10f} - {values[i]}")
```
## Todo
- [ ] 模型训练功能
- [x] CUDA算子int8提速
- [x] macOS支持
- [x] linux支持
- [ ] 本地状态缓存数据库
## 相关仓库:
- RWKV-4-World: https://huggingface.co/BlinkDL/rwkv-4-world/tree/main
- RWKV-4-Raven: https://huggingface.co/BlinkDL/rwkv-4-raven/tree/main
- ChatRWKV: https://github.com/BlinkDL/ChatRWKV
- RWKV-LM: https://github.com/BlinkDL/RWKV-LM
- RWKV-LM-LoRA: https://github.com/Blealtan/RWKV-LM-LoRA
- MIDI-LLM-tokenizer: https://github.com/briansemrau/MIDI-LLM-tokenizer
## Preview
### 主页
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/9d25380a-a17b-443f-b823-86c754ebebf0)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/ff2b1eef-dd3b-4cbf-98fb-b5a1ecee43e1)
### 聊天
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/0e66d5fa-f34a-409f-9cd4-d880815733f3)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/9570e73b-dca2-4316-9e92-09961f3c48c4)
### 补全
### 续写
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/d4178ee9-a188-4878-9777-25c916872c29)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/69f9ba7a-2fe8-4a5e-94cb-aa655aa409e2)
### 作曲
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/95b34893-80c2-4706-87f9-bc141032ed4b)
### 配置
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/ad9921fc-7248-40a3-9e18-03445b86e4bf)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/59460f69-b172-4c7a-86cb-573262543076)
### 模型管理
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/7c36f15f-3e77-49cd-a16d-99a29f870bdf)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/551121ee-1bfe-421b-a9d1-24125126ab4b)
### 下载管理
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/32fde30b-11dd-43b9-9667-ad6975be2106)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/cc076038-2a91-4d36-bd39-266020e8ea87)
### LoRA微调
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/31939b8f-9546-4f44-b434-295b492ec625)
### 设置
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/e8a0f746-9da7-48e3-b3fc-e1453ac50de2)
![image](https://github.com/josStorer/RWKV-Runner/assets/13366013/9652d7cc-ac33-4587-a8fb-03e5a6f5ea77)

Binary file not shown.

View File

@@ -0,0 +1,116 @@
# https://github.com/magenta/magenta-js/issues/164
import json
import os
import urllib.request
def get_pitches_array(min_pitch, max_pitch):
return list(range(min_pitch, max_pitch + 1))
base_url = 'https://storage.googleapis.com/magentadata/js/soundfonts'
soundfont_path = 'sgm_plus'
soundfont_json_url = f"{base_url}/{soundfont_path}/soundfont.json"
# Download soundfont.json
soundfont_json = ""
if not os.path.exists('soundfont.json'):
try:
with urllib.request.urlopen(soundfont_json_url) as response:
soundfont_json = response.read()
# Save soundfont.json
with open('soundfont.json', 'wb') as file:
file.write(soundfont_json)
except:
print("Failed to download soundfont.json")
else:
# If file exists, get it from the file system
with open('soundfont.json', 'rb') as file:
soundfont_json = file.read()
# Parse soundfont.json
soundfont_data = json.loads(soundfont_json)
if soundfont_data is not None:
# Iterate over each instrument
for instrument_id, instrument_name in soundfont_data['instruments'].items():
if not os.path.isdir(instrument_name):
# Create instrument directory if it doesn't exist
os.makedirs(instrument_name)
instrument_json = ""
instrument_path = f"{soundfont_path}/{instrument_name}"
if not os.path.exists(f"{instrument_name}/instrument.json"):
# Download instrument.json
instrument_json_url = f"{base_url}/{instrument_path}/instrument.json"
try:
with urllib.request.urlopen(instrument_json_url) as response:
instrument_json = response.read()
# Save instrument.json
with open(f"{instrument_name}/instrument.json", 'wb') as file:
file.write(instrument_json)
except:
print(f"Failed to download {instrument_name}/instrument.json")
else:
# If file exists, get it from the file system
with open(f"{instrument_name}/instrument.json", 'rb') as file:
instrument_json = file.read()
# Parse instrument.json
instrument_data = json.loads(instrument_json)
if instrument_data is not None:
# Iterate over each pitch and velocity
for velocity in instrument_data['velocities']:
pitches = get_pitches_array(instrument_data['minPitch'], instrument_data['maxPitch'])
for pitch in pitches:
# Create the file name
file_name = f'p{pitch}_v{velocity}.mp3'
# Check if the file already exists
if os.path.exists(f"{instrument_name}/{file_name}"):
pass
#print(f"Skipping {instrument_name}/{file_name} - File already exists")
else:
# Download pitch/velocity file
file_url = f"{base_url}/{instrument_path}/{file_name}"
try:
with urllib.request.urlopen(file_url) as response:
file_contents = response.read()
# Save pitch/velocity file
with open(f"{instrument_name}/{file_name}", 'wb') as file:
file.write(file_contents)
print(f"Downloaded {instrument_name}/{file_name}")
except:
print(f"Failed to download {instrument_name}/{file_name}")
else:
print(f"Failed to parse instrument.json for {instrument_name}")
else:
print('Failed to parse soundfont.json')

View File

@@ -0,0 +1,134 @@
{
"name": "sgm_plus",
"instruments": {
"0": "acoustic_grand_piano",
"1": "bright_acoustic_piano",
"2": "electric_grand_piano",
"3": "honkytonk_piano",
"4": "electric_piano_1",
"5": "electric_piano_2",
"6": "harpsichord",
"7": "clavichord",
"8": "celesta",
"9": "glockenspiel",
"10": "music_box",
"11": "vibraphone",
"12": "marimba",
"13": "xylophone",
"14": "tubular_bells",
"15": "dulcimer",
"16": "drawbar_organ",
"17": "percussive_organ",
"18": "rock_organ",
"19": "church_organ",
"20": "reed_organ",
"21": "accordion",
"22": "harmonica",
"23": "tango_accordion",
"24": "acoustic_guitar_nylon",
"25": "acoustic_guitar_steel",
"26": "electric_guitar_jazz",
"27": "electric_guitar_clean",
"28": "electric_guitar_muted",
"29": "overdriven_guitar",
"30": "distortion_guitar",
"31": "guitar_harmonics",
"32": "acoustic_bass",
"33": "electric_bass_finger",
"34": "electric_bass_pick",
"35": "fretless_bass",
"36": "slap_bass_1",
"37": "slap_bass_2",
"38": "synth_bass_1",
"39": "synth_bass_2",
"40": "violin",
"41": "viola",
"42": "cello",
"43": "contrabass",
"44": "tremolo_strings",
"45": "pizzicato_strings",
"46": "orchestral_harp",
"47": "timpani",
"48": "string_ensemble_1",
"49": "string_ensemble_2",
"50": "synthstrings_1",
"51": "synthstrings_2",
"52": "choir_aahs",
"53": "voice_oohs",
"54": "synth_voice",
"55": "orchestra_hit",
"56": "trumpet",
"57": "trombone",
"58": "tuba",
"59": "muted_trumpet",
"60": "french_horn",
"61": "brass_section",
"62": "synthbrass_1",
"63": "synthbrass_2",
"64": "soprano_sax",
"65": "alto_sax",
"66": "tenor_sax",
"67": "baritone_sax",
"68": "oboe",
"69": "english_horn",
"70": "bassoon",
"71": "clarinet",
"72": "piccolo",
"73": "flute",
"74": "recorder",
"75": "pan_flute",
"76": "blown_bottle",
"77": "shakuhachi",
"78": "whistle",
"79": "ocarina",
"80": "lead_1_square",
"81": "lead_2_sawtooth",
"82": "lead_3_calliope",
"83": "lead_4_chiff",
"84": "lead_5_charang",
"85": "lead_6_voice",
"86": "lead_7_fifths",
"87": "lead_8_bass_lead",
"88": "pad_1_new_age",
"89": "pad_2_warm",
"90": "pad_3_polysynth",
"91": "pad_4_choir",
"92": "pad_5_bowed",
"93": "pad_6_metallic",
"94": "pad_7_halo",
"95": "pad_8_sweep",
"96": "fx_1_rain",
"97": "fx_2_soundtrack",
"98": "fx_3_crystal",
"99": "fx_4_atmosphere",
"100": "fx_5_brightness",
"101": "fx_6_goblins",
"102": "fx_7_echoes",
"103": "fx_8_scifi",
"104": "sitar",
"105": "banjo",
"106": "shamisen",
"107": "koto",
"108": "kalimba",
"109": "bag_pipe",
"110": "fiddle",
"111": "shanai",
"112": "tinkle_bell",
"113": "agogo",
"114": "steel_drums",
"115": "woodblock",
"116": "taiko_drum",
"117": "melodic_tom",
"118": "synth_drum",
"119": "reverse_cymbal",
"120": "guitar_fret_noise",
"121": "breath_noise",
"122": "seashore",
"123": "bird_tweet",
"124": "telephone_ring",
"125": "helicopter",
"126": "applause",
"127": "gunshot",
"drums": "percussion"
}
}

469
assets/soundfont_builder.rb Normal file
View File

@@ -0,0 +1,469 @@
#!/usr/bin/env ruby
#
# JavaScript Soundfont Builder for MIDI.js
# Author: 0xFE <mohit@muthanna.com>
# edited by Valentijn Nieman <valentijnnieman@gmail.com>
#
# Requires:
#
# FluidSynth
# Lame
# Ruby Gems: midilib parallel
#
# $ brew install fluidsynth lame (on OSX)
# $ gem install midilib parallel
#
# You'll need to download a GM soundbank to generate audio.
#
# Usage:
#
# 1) Install the above dependencies.
# 2) Edit BUILD_DIR, SOUNDFONT, and INSTRUMENTS as required.
# 3) Run without any argument.
require 'base64'
require 'digest/sha1'
require 'etc'
require 'fileutils'
require 'midilib'
require 'parallel'
require 'zlib'
require 'json'
include FileUtils
BUILD_DIR = "./sound-font" # Output path
SOUNDFONT = "./default_sound_font.sf2" # Soundfont file path
# This script will generate MIDI.js-compatible instrument JS files for
# all instruments in the below array. Add or remove as necessary.
INSTRUMENTS = [
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30,
31,
32,
33,
34,
35,
36,
37,
38,
39,
40,
41,
42,
43,
44,
45,
46,
47,
48,
49,
50,
51,
52,
53,
54,
55,
56,
57,
58,
59,
60,
61,
62,
63,
64,
65,
66,
67,
68,
69,
70,
71,
72,
73,
74,
75,
76,
77,
78,
79,
80,
81,
82,
83,
84,
85,
86,
87,
88,
89,
90,
91,
92,
93,
94,
95,
96,
97,
98,
99,
100,
101,
102,
103,
104,
105,
106,
107,
108,
109,
110,
111,
112,
113,
114,
115,
116,
117,
118,
119,
120,
121,
122,
123,
124,
125,
126,
127
]
# It was found that midilib uses names that are incompatible with MIDI.js
# For example, midilib uses "SynthBrass 1" -> https://github.com/jimm/midilib/blob/6c8e481ae72cd9f00a38eb3700ddfca6b549f153/lib/midilib/consts.rb#L280
# and the MIDI association uses "SynthBrass 1" -> https://www.midi.org/specifications-old/item/gm-level-1-sound-set
# but the MIDI.js calls this "Synth Brass 1" -> https://github.com/mudcube/MIDI.js/blob/a8a84257afa70721ae462448048a87301fc1554a/js/midi/gm.js#L44
# there are others like "Bag pipe" vs "Bagpipe", etc.
# here, we use the MIDI.js definitions because that is how most users will interact with the generated soundfonts.
MIDIJS_PATCH_NAMES = [
"Acoustic Grand Piano",
"Bright Acoustic Piano",
"Electric Grand Piano",
"Honky-tonk Piano",
"Electric Piano 1",
"Electric Piano 2",
"Harpsichord",
"Clavinet",
"Celesta",
"Glockenspiel",
"Music Box",
"Vibraphone",
"Marimba",
"Xylophone",
"Tubular Bells",
"Dulcimer",
"Drawbar Organ",
"Percussive Organ",
"Rock Organ",
"Church Organ",
"Reed Organ",
"Accordion",
"Harmonica",
"Tango Accordion",
"Acoustic Guitar (nylon)",
"Acoustic Guitar (steel)",
"Electric Guitar (jazz)",
"Electric Guitar (clean)",
"Electric Guitar (muted)",
"Overdriven Guitar",
"Distortion Guitar",
"Guitar Harmonics",
"Acoustic Bass",
"Electric Bass (finger)",
"Electric Bass (pick)",
"Fretless Bass",
"Slap Bass 1",
"Slap Bass 2",
"Synth Bass 1",
"Synth Bass 2",
"Violin",
"Viola",
"Cello",
"Contrabass",
"Tremolo Strings",
"Pizzicato Strings",
"Orchestral Harp",
"Timpani",
"String Ensemble 1",
"String Ensemble 2",
"Synth Strings 1",
"Synth Strings 2",
"Choir Aahs",
"Voice Oohs",
"Synth Choir",
"Orchestra Hit",
"Trumpet",
"Trombone",
"Tuba",
"Muted Trumpet",
"French Horn",
"Brass Section",
"Synth Brass 1",
"Synth Brass 2",
"Soprano Sax",
"Alto Sax",
"Tenor Sax",
"Baritone Sax",
"Oboe",
"English Horn",
"Bassoon",
"Clarinet",
"Piccolo",
"Flute",
"Recorder",
"Pan Flute",
"Blown Bottle",
"Shakuhachi",
"Whistle",
"Ocarina",
"Lead 1 (square)",
"Lead 2 (sawtooth)",
"Lead 3 (calliope)",
"Lead 4 (chiff)",
"Lead 5 (charang)",
"Lead 6 (voice)",
"Lead 7 (fifths)",
"Lead 8 (bass + lead)",
"Pad 1 (new age)",
"Pad 2 (warm)",
"Pad 3 (polysynth)",
"Pad 4 (choir)",
"Pad 5 (bowed)",
"Pad 6 (metallic)",
"Pad 7 (halo)",
"Pad 8 (sweep)",
"FX 1 (rain)",
"FX 2 (soundtrack)",
"FX 3 (crystal)",
"FX 4 (atmosphere)",
"FX 5 (brightness)",
"FX 6 (goblins)",
"FX 7 (echoes)",
"FX 8 (sci-fi)",
"Sitar",
"Banjo",
"Shamisen",
"Koto",
"Kalimba",
"Bagpipe",
"Fiddle",
"Shanai",
"Tinkle Bell",
"Agogo",
"Steel Drums",
"Woodblock",
"Taiko Drum",
"Melodic Tom",
"Synth Drum",
"Reverse Cymbal",
"Guitar Fret Noise",
"Breath Noise",
"Seashore",
"Bird Tweet",
"Telephone Ring",
"Helicopter",
"Applause",
"Gunshot"
]
# The encoders and tools are expected in your PATH. You can supply alternate
# paths by changing the constants below.
LAME = "lame" # `which lame`.chomp
FLUIDSYNTH = "fluidsynth" # `which fluidsynth`.chomp
puts "Building the following instruments using font: " + SOUNDFONT
# Display instrument names.
INSTRUMENTS.each do |i|
puts " #{i}: " + MIDIJS_PATCH_NAMES[i]
end
puts
puts "Using MP3 encoder: " + LAME
puts "Using FluidSynth encoder: " + FLUIDSYNTH
puts
puts "Sending output to: " + BUILD_DIR
puts
raise "Can't find soundfont: #{SOUNDFONT}" unless File.exist? SOUNDFONT
raise "Can't find 'lame' command" if LAME.empty?
raise "Can't find 'fluidsynth' command" if FLUIDSYNTH.empty?
raise "Output directory does not exist: #{BUILD_DIR}" unless File.exist?(BUILD_DIR)
puts "Hit return to begin."
$stdin.readline
NOTES = {
"C" => 0,
"Db" => 1,
"D" => 2,
"Eb" => 3,
"E" => 4,
"F" => 5,
"Gb" => 6,
"G" => 7,
"Ab" => 8,
"A" => 9,
"Bb" => 10,
"B" => 11
}
MIDI_C0 = 12
VELOCITY = 100
DURATION = Integer(3000)
TEMP_FILE = "#{BUILD_DIR}/%s%stemp.midi"
FLUIDSYNTH_RAW = "%s.wav"
def deflate(string, level)
z = Zlib::Deflate.new(level)
dst = z.deflate(string, Zlib::FINISH)
z.close
dst
end
def note_to_int(note, octave)
value = NOTES[note]
increment = MIDI_C0 * octave
return value + increment
end
def int_to_note(value)
raise "Bad Value" if value < MIDI_C0
reverse_notes = NOTES.invert
value -= MIDI_C0
octave = value / 12
note = value % 12
return { key: reverse_notes[note],
octave: octave }
end
# Run a quick table validation
MIDI_C0.upto(100) do |x|
note = int_to_note x
#raise "Broken table" unless note_to_int(note[:key], note[:octave]) == x
end
def generate_midi(program, note_value, file)
include MIDI
seq = Sequence.new()
track = Track.new(seq)
seq.tracks << track
track.events << ProgramChange.new(0, Integer(program))
track.events << NoteOn.new(0, note_value, VELOCITY, 0) # channel, note, velocity, delta
track.events << NoteOff.new(0, note_value, VELOCITY, DURATION)
File.open(file, 'wb') { | file | seq.write(file) }
end
def run_command(cmd)
puts "Running: " + cmd
`#{cmd}`
end
def midi_to_audio(source, target)
run_command "#{FLUIDSYNTH} -C no -R no -g 0.5 -F #{target} #{SOUNDFONT} #{source}"
run_command "#{LAME} -v -b 8 -B 64 #{target}"
rm target
end
def open_js_file(instrument_key, type)
js_file = File.open("#{BUILD_DIR}/#{instrument_key}-#{type}.js", "w")
js_file.write(
"""
if (typeof(MIDI) === 'undefined') var MIDI = {};
if (typeof(MIDI.Soundfont) === 'undefined') MIDI.Soundfont = {};
MIDI.Soundfont.#{instrument_key} = {
""")
return js_file
end
def close_js_file(file)
file.write("\n}\n")
file.close
end
def base64js(note, file, type)
output = '"' + note + '": '
output += '"' + "data:audio/#{type};base64,"
output += Base64.strict_encode64(File.read(file)) + '"'
return output
end
def generate_audio(program)
instrument = MIDIJS_PATCH_NAMES[program]
instrument_key = instrument.downcase.gsub(/[^a-z0-9 ]/, "").gsub(/[ ]/, "_")
puts "Generating audio for: " + instrument + "(#{instrument_key})"
mkdir_p "#{BUILD_DIR}/#{instrument_key}"
note_to_int("A", 0).upto(note_to_int("C", 8)) do |note_value|
output_name = "p#{note_value}_v#{VELOCITY}"
output_path_prefix = BUILD_DIR + "/#{instrument_key}" + output_name
puts "Generating: #{output_name}"
temp_file_specific = TEMP_FILE % [output_name, instrument_key]
generate_midi(program, note_value, temp_file_specific)
midi_to_audio(temp_file_specific, output_path_prefix + ".wav")
mv output_path_prefix + ".mp3", "#{BUILD_DIR}/#{instrument_key}/#{output_name}.mp3"
rm temp_file_specific
end
tempHash = {
"name" => instrument_key,
"minPitch" => 0,
"maxPitch" => 127,
"durationSeconds" => 3.0,
"releaseSeconds" => 1.0,
"percussive": false,
"velocities": [100]
}
File.open("#{BUILD_DIR}/#{instrument_key}/instrument.json", "w") do |f|
f.write(tempHash.to_json)
end
end
Parallel.each(INSTRUMENTS, :in_processes=>Etc.nprocessors){|i| generate_audio(i)}

View File

@@ -41,6 +41,14 @@ func (a *App) OnStartup(ctx context.Context) {
a.cmdPrefix = "cd " + a.exDir + " && "
}
os.Mkdir(a.exDir+"models", os.ModePerm)
os.Mkdir(a.exDir+"lora-models", os.ModePerm)
os.Mkdir(a.exDir+"finetune/json2binidx_tool/data", os.ModePerm)
f, err := os.Create(a.exDir + "lora-models/train_log.txt")
if err == nil {
f.Close()
}
a.downloadLoop()
watcher, err := fsnotify.NewWatcher()

View File

@@ -122,6 +122,10 @@ func (a *App) CopyFile(src string, dst string) error {
}
func (a *App) OpenSaveFileDialog(filterPattern string, defaultFileName string, savedContent string) (string, error) {
return a.OpenSaveFileDialogBytes(filterPattern, defaultFileName, []byte(savedContent))
}
func (a *App) OpenSaveFileDialogBytes(filterPattern string, defaultFileName string, savedContent []byte) (string, error) {
path, err := wruntime.SaveFileDialog(a.ctx, wruntime.SaveDialogOptions{
DefaultFilename: defaultFileName,
Filters: []wruntime.FileFilter{{
@@ -135,7 +139,7 @@ func (a *App) OpenSaveFileDialog(filterPattern string, defaultFileName string, s
if path == "" {
return "", nil
}
if err := os.WriteFile(path, []byte(savedContent), 0644); err != nil {
if err := os.WriteFile(path, savedContent, 0644); err != nil {
return "", err
}
return path, nil

View File

@@ -1,6 +1,7 @@
package backend_golang
import (
"encoding/json"
"errors"
"os"
"os/exec"
@@ -43,6 +44,39 @@ func (a *App) ConvertData(python string, input string, outputPrefix string, voca
if strings.Contains(vocab, "rwkv_vocab_v20230424") {
tokenizerType = "RWKVTokenizer"
}
input = strings.TrimSuffix(input, "/")
if fi, err := os.Stat(input); err == nil && fi.IsDir() {
files, err := os.ReadDir(input)
if err != nil {
return "", err
}
jsonlFile, err := os.Create(outputPrefix + ".jsonl")
if err != nil {
return "", err
}
defer jsonlFile.Close()
for _, file := range files {
if file.IsDir() || !strings.HasSuffix(file.Name(), ".txt") {
continue
}
textContent, err := os.ReadFile(input + "/" + file.Name())
if err != nil {
return "", err
}
textJson, err := json.Marshal(map[string]string{"text": strings.ReplaceAll(strings.ReplaceAll(string(textContent), "\r\n", "\n"), "\r", "\n")})
if err != nil {
return "", err
}
if _, err := jsonlFile.WriteString(string(textJson) + "\n"); err != nil {
return "", err
}
}
input = outputPrefix + ".jsonl"
} else if err != nil {
return "", err
}
return Cmd(python, "./finetune/json2binidx_tool/tools/preprocess_data.py", "--input", input, "--output-prefix", outputPrefix, "--vocab", vocab,
"--tokenizer-type", tokenizerType, "--dataset-impl", "mmap", "--append-eod")
}
@@ -113,3 +147,11 @@ func (a *App) InstallPyDep(python string, cnMirror bool) (string, error) {
return Cmd(python, "-m", "pip", "install", "-r", "./backend-python/requirements_without_cyac.txt")
}
}
func (a *App) GetPyError() string {
content, err := os.ReadFile("./error.txt")
if err != nil {
return ""
}
return string(content)
}

View File

@@ -1,3 +1,5 @@
//go:build windows
package backend_golang
import (
@@ -8,7 +10,6 @@ import (
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
@@ -37,10 +38,6 @@ func isWslRunning() (bool, error) {
}
func (a *App) WslStart() error {
if runtime.GOOS != "windows" {
return errors.New("wsl not supported")
}
running, err := isWslRunning()
if err != nil {
return err
@@ -100,10 +97,6 @@ func (a *App) WslStart() error {
}
func (a *App) WslCommand(command string) error {
if runtime.GOOS != "windows" {
return errors.New("wsl not supported")
}
running, err := isWslRunning()
if err != nil {
return err
@@ -119,10 +112,6 @@ func (a *App) WslCommand(command string) error {
}
func (a *App) WslStop() error {
if runtime.GOOS != "windows" {
return errors.New("wsl not supported")
}
running, err := isWslRunning()
if err != nil {
return err
@@ -130,8 +119,10 @@ func (a *App) WslStop() error {
if !running {
return errors.New("wsl not running")
}
err = cmd.Process.Kill()
cmd = nil
if cmd != nil {
err = cmd.Process.Kill()
cmd = nil
}
// stdin.Close()
stdin = nil
distro = nil
@@ -142,10 +133,6 @@ func (a *App) WslStop() error {
}
func (a *App) WslIsEnabled() error {
if runtime.GOOS != "windows" {
return errors.New("wsl not supported")
}
ex, err := os.Executable()
if err != nil {
return err
@@ -177,10 +164,6 @@ func (a *App) WslIsEnabled() error {
}
func (a *App) WslEnable(forceMode bool) error {
if runtime.GOOS != "windows" {
return errors.New("wsl not supported")
}
cmd := `/online /enable-feature /featurename:Microsoft-Windows-Subsystem-Linux`
_, err := su.ShellExecute(su.RUNAS, "dism", cmd, `C:\`)
if err != nil {
@@ -193,10 +176,6 @@ func (a *App) WslEnable(forceMode bool) error {
}
func (a *App) WslInstallUbuntu() error {
if runtime.GOOS != "windows" {
return errors.New("wsl not supported")
}
exec.Command("start", "ms-windows-store://pdp/?ProductId=9PN20MSR04DW").Start()
return nil
_, err := Cmd("ms-windows-store://pdp/?ProductId=9PN20MSR04DW")
return err
}

View File

@@ -0,0 +1,31 @@
//go:build darwin || linux
package backend_golang
import (
"errors"
)
func (a *App) WslStart() error {
return errors.New("wsl not supported")
}
func (a *App) WslCommand(command string) error {
return errors.New("wsl not supported")
}
func (a *App) WslStop() error {
return errors.New("wsl not supported")
}
func (a *App) WslIsEnabled() error {
return errors.New("wsl not supported")
}
func (a *App) WslEnable(forceMode bool) error {
return errors.New("wsl not supported")
}
func (a *App) WslInstallUbuntu() error {
return errors.New("wsl not supported")
}

View File

@@ -219,13 +219,17 @@ def get_args():
return p.parse_args()
args = get_args()
if not args.quiet:
print(f"** {args}")
try:
args = get_args()
if not args.quiet:
print(f"** {args}")
RWKV(
getattr(args, "in"),
args.strategy,
verbose=not args.quiet,
convert_and_save_and_exit=args.out,
)
RWKV(
getattr(args, "in"),
args.strategy,
verbose=not args.quiet,
convert_and_save_and_exit=args.out,
)
except Exception as e:
with open("error.txt", "w") as f:
f.write(str(e))

View File

@@ -1,3 +1,5 @@
import midi2audio
import mido
import lm_dataformat
import ftfy
import tqdm

View File

@@ -12,7 +12,7 @@ from utils.rwkv import *
from utils.torch import *
from utils.ngrok import *
from utils.log import log_middleware
from routes import completion, config, state_cache
from routes import completion, config, state_cache, midi
import global_var
app = FastAPI(dependencies=[Depends(log_middleware)])
@@ -27,6 +27,7 @@ app.add_middleware(
app.include_router(completion.router)
app.include_router(config.router)
app.include_router(midi.router)
app.include_router(state_cache.router)
@@ -41,12 +42,12 @@ def init():
ngrok_connect()
@app.get("/")
@app.get("/", tags=["Root"])
def read_root():
return {"Hello": "World!"}
@app.post("/exit")
@app.post("/exit", tags=["Root"])
def exit():
parent_pid = os.getpid()
parent = psutil.Process(parent_pid)
@@ -55,20 +56,9 @@ def exit():
parent.kill()
def debug():
model = RWKV(
model="../models/RWKV-4-Raven-7B-v11-Eng49%-Chn49%-Jpn1%-Other1%-20230430-ctx8192.pth",
strategy="cuda fp16",
tokens_path="20B_tokenizer.json",
)
d = model.pipeline.decode([])
print(d)
if __name__ == "__main__":
uvicorn.run(
"main:app",
port=8000 if len(sys.argv) < 2 else int(sys.argv[1]),
host="127.0.0.1" if len(sys.argv) < 3 else sys.argv[2],
)
# debug()

Binary file not shown.

View File

@@ -1,7 +1,7 @@
import asyncio
import json
from threading import Lock
from typing import List
from typing import List, Union
import base64
from fastapi import APIRouter, Request, status, HTTPException
@@ -25,7 +25,7 @@ class ChatCompletionBody(ModelConfigBody):
messages: List[Message]
model: str = "rwkv"
stream: bool = False
stop: str = None
stop: Union[str, List[str]] = None
class Config:
schema_extra = {
@@ -44,10 +44,10 @@ class ChatCompletionBody(ModelConfigBody):
class CompletionBody(ModelConfigBody):
prompt: str
prompt: Union[str, List[str]]
model: str = "rwkv"
stream: bool = False
stop: str = None
stop: Union[str, List[str]] = None
class Config:
schema_extra = {
@@ -72,7 +72,7 @@ requests_num = 0
async def eval_rwkv(
model: RWKV,
model: AbstractRWKV,
request: Request,
body: ModelConfigBody,
prompt: str,
@@ -206,10 +206,10 @@ async def eval_rwkv(
}
@router.post("/v1/chat/completions")
@router.post("/chat/completions")
@router.post("/v1/chat/completions", tags=["Completions"])
@router.post("/chat/completions", tags=["Completions"])
async def chat_completions(body: ChatCompletionBody, request: Request):
model: RWKV = global_var.get(global_var.Model)
model: TextRWKV = global_var.get(global_var.Model)
if model is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")
@@ -299,16 +299,19 @@ The following is a coherent verbose detailed conversation between a girl named {
return None
@router.post("/v1/completions")
@router.post("/completions")
@router.post("/v1/completions", tags=["Completions"])
@router.post("/completions", tags=["Completions"])
async def completions(body: CompletionBody, request: Request):
model: RWKV = global_var.get(global_var.Model)
model: AbstractRWKV = global_var.get(global_var.Model)
if model is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")
if body.prompt is None or body.prompt == "":
if body.prompt is None or body.prompt == "" or body.prompt == []:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "prompt not found")
if type(body.prompt) == list:
body.prompt = body.prompt[0] # TODO: support multiple prompts
if body.stream:
return EventSourceResponse(
eval_rwkv(model, request, body, body.prompt, body.stream, body.stop, False)
@@ -323,7 +326,7 @@ async def completions(body: CompletionBody, request: Request):
class EmbeddingsBody(BaseModel):
input: str or List[str] or List[List[int]]
input: Union[str, List[str], List[List[int]]]
model: str = "rwkv"
encoding_format: str = None
fast_mode: bool = False
@@ -343,12 +346,12 @@ def embedding_base64(embedding: List[float]) -> str:
return base64.b64encode(np.array(embedding).astype(np.float32)).decode("utf-8")
@router.post("/v1/embeddings")
@router.post("/embeddings")
@router.post("/v1/engines/text-embedding-ada-002/embeddings")
@router.post("/engines/text-embedding-ada-002/embeddings")
@router.post("/v1/embeddings", tags=["Embeddings"])
@router.post("/embeddings", tags=["Embeddings"])
@router.post("/v1/engines/text-embedding-ada-002/embeddings", tags=["Embeddings"])
@router.post("/engines/text-embedding-ada-002/embeddings", tags=["Embeddings"])
async def embeddings(body: EmbeddingsBody, request: Request):
model: RWKV = global_var.get(global_var.Model)
model: AbstractRWKV = global_var.get(global_var.Model)
if model is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")

View File

@@ -13,13 +13,16 @@ router = APIRouter()
def get_tokens_path(model_path: str):
model_path = model_path.lower()
default_tokens_path = (
f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/20B_tokenizer.json"
)
tokenizer_dir = f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/"
default_tokens_path = tokenizer_dir + "20B_tokenizer.json"
if "raven" in model_path:
return default_tokens_path
elif "world" in model_path:
return "rwkv_vocab_v20230424"
elif "midi" in model_path:
return tokenizer_dir + "tokenizer-midi.json"
else:
return default_tokens_path
@@ -39,7 +42,7 @@ class SwitchModelBody(BaseModel):
}
@router.post("/switch-model")
@router.post("/switch-model", tags=["Configs"])
def switch_model(body: SwitchModelBody, response: Response, request: Request):
if global_var.get(global_var.Model_Status) is global_var.ModelStatus.Loading:
response.status_code = Status.HTTP_304_NOT_MODIFIED
@@ -52,13 +55,27 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
if body.model == "":
return "success"
if "->" in body.strategy:
state_cache.disable_state_cache()
else:
try:
state_cache.enable_state_cache()
except HTTPException:
pass
os.environ["RWKV_CUDA_ON"] = "1" if body.customCuda else "0"
global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
try:
global_var.set(
global_var.Model,
RWKV(
TextRWKV(
model=body.model,
strategy=body.strategy,
tokens_path=get_tokens_path(body.model),
)
if "midi" not in body.model.lower()
else MusicRWKV(
model=body.model,
strategy=body.strategy,
tokens_path=get_tokens_path(body.model),
@@ -81,7 +98,7 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
return "success"
@router.post("/update-config")
@router.post("/update-config", tags=["Configs"])
def update_config(body: ModelConfigBody):
"""
Will not update the model config immediately, but set it when completion called to avoid modifications during generation
@@ -93,7 +110,7 @@ def update_config(body: ModelConfigBody):
return "success"
@router.get("/status")
@router.get("/status", tags=["Configs"])
def status():
gpus = GPUtil.getGPUs()
if len(gpus) == 0:

View File

@@ -0,0 +1,131 @@
import io
from fastapi import APIRouter, HTTPException, status
from starlette.responses import StreamingResponse
from pydantic import BaseModel
from utils.midi import *
from midi2audio import FluidSynth
router = APIRouter()
class TextToMidiBody(BaseModel):
text: str
class Config:
schema_extra = {
"example": {
"text": "p:24:a p:2a:a p:31:a p:39:a p:3b:a p:45:a b:26:a g:3e:a g:3e:a g:42:a g:42:a g:45:a g:45:a pi:3e:a pi:42:a pi:45:a t14 p:24:0 p:2a:0 p:31:0 p:39:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 p:45:0 b:26:0 g:3e:0 g:3e:0 g:42:0 g:42:0 g:45:0 g:45:0 pi:3e:0 pi:42:0 pi:45:0 t2 p:2e:a p:3b:a p:45:a b:26:a g:3e:a g:3e:a g:42:a g:42:a g:45:a g:45:a pi:3e:a pi:42:a pi:45:a t14 p:2e:0 p:3b:0 p:45:0 g:3e:0 g:3e:0 g:42:0 g:42:0 g:45:0 g:45:0 pi:3e:0 pi:42:0 pi:45:0 t2 p:2e:a p:3b:a p:45:a g:3e:a g:3e:a g:42:a g:42:a g:45:a g:45:a pi:3e:a pi:42:a pi:45:a t14 p:2e:0 p:3b:0 p:45:0 b:26:0 g:3e:0 g:3e:0 g:42:0 g:42:0 g:45:0 g:45:0 pi:3e:0 pi:42:0 pi:45:0 t2 p:26:a p:2a:a p:3b:a p:45:a t14 p:26:0 p:2a:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a b:26:a g:3e:a g:3e:a g:42:a g:42:a g:45:a g:45:a pi:3e:a pi:42:a pi:45:a t14 p:2a:0 p:3b:0 p:45:0 b:26:0 t2 p:24:a p:2a:a p:3b:a p:45:a b:2d:a t14 p:24:0 p:2a:0 p:3b:0 p:45:0 b:2d:0 g:3e:0 g:3e:0 g:42:0 g:42:0 g:45:0 g:45:0 pi:3e:0 pi:42:0 pi:45:0 t2 p:24:a p:2a:a p:3b:a p:45:a b:21:a g:39:a g:39:a g:3d:a g:3d:a g:40:a g:40:a pi:39:a pi:3d:a pi:40:a t14 p:24:0 p:2a:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 p:45:0 b:21:0 g:39:0 g:39:0 g:3d:0 g:3d:0 g:40:0 g:40:0 pi:39:0 pi:3d:0 pi:40:0 t2 p:24:a p:2e:a p:3b:a p:45:a b:21:a g:39:a g:39:a g:3d:a g:3d:a g:40:a g:40:a pi:39:a pi:3d:a pi:40:a t14 p:24:0 p:2e:0 p:3b:0 p:45:0 b:21:0 g:39:0 g:39:0 g:3d:0 g:3d:0 g:40:0 g:40:0 pi:39:0 pi:3d:0 pi:40:0 t2 p:24:a p:2a:a p:3b:a p:45:a b:21:a g:39:a g:39:a g:3d:a g:3d:a g:40:a g:40:a pi:39:a pi:3d:a pi:40:a t14 p:24:0 p:2a:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 p:45:0 b:21:0 g:39:0 g:39:0 g:3d:0 g:3d:0 g:40:0 g:40:0 pi:39:0 pi:3d:0 pi:40:0 t2 p:26:a p:2a:a p:3b:a p:45:a b:21:a g:39:a g:39:a g:3d:a g:3d:a g:40:a g:40:a pi:39:a pi:3d:a pi:40:a t14 p:26:0 p:2a:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 p:45:0 b:21:0 g:39:0 g:39:0 g:3d:0 g:3d:0 g:40:0 g:40:0 pi:39:0 pi:3d:0 pi:40:0 t2 p:26:a p:2e:a p:31:a p:39:a p:3b:a p:45:a b:21:a g:39:a g:39:a g:3d:a g:3d:a g:40:a g:40:a pi:39:a pi:3d:a pi:40:a t14 p:26:0 p:2e:0 p:31:0 p:39:0 p:3b:0 p:45:0 b:21:0 t2 p:26:a p:2e:a p:31:a p:39:a p:3b:a p:45:a b:21:a t14 p:26:0 p:2e:0 p:31:0 p:39:0 p:3b:0 p:45:0 b:21:0 g:39:0 g:39:0 g:3d:0 g:3d:0 g:40:0 g:40:0 pi:39:0 pi:3d:0 pi:40:0 t2 p:24:a p:2a:a p:31:a p:39:a p:3b:a p:45:a b:1f:a g:3b:a g:3b:a g:3e:a g:3e:a g:43:a g:43:a pi:3b:a pi:3e:a pi:43:a t14 p:24:0 p:2a:0 p:31:0 p:39:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 p:45:0 b:1f:0 g:3b:0 g:3b:0 g:3e:0 g:3e:0 g:43:0 g:43:0 pi:3b:0 pi:3e:0 pi:43:0 t2 p:2e:a p:3b:a p:45:a b:1f:a g:3b:a g:3b:a g:3e:a g:3e:a g:43:a g:43:a pi:3b:a pi:3e:a pi:43:a t14 p:2e:0 p:3b:0 p:45:0 g:3b:0 g:3b:0 g:3e:0 g:3e:0 g:43:0 g:43:0 pi:3b:0 pi:3e:0 pi:43:0 t2 p:2e:a p:3b:a p:45:a g:3b:a g:3b:a g:3e:a g:3e:a g:43:a g:43:a pi:3b:a pi:3e:a pi:43:a t14 p:2e:0 p:3b:0 p:45:0 b:1f:0 g:3b:0 g:3b:0 g:3e:0 g:3e:0 g:43:0 g:43:0 pi:3b:0 pi:3e:0 pi:43:0 t2 p:26:a p:2a:a p:3b:a p:45:a t14 p:26:0 p:2a:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a b:1f:a g:3b:a g:3b:a g:3e:a g:3e:a g:43:a g:43:a pi:3b:a pi:3e:a pi:43:a t14 p:2a:0 p:3b:0 p:45:0 b:1f:0 t2 p:24:a p:2a:a p:3b:a p:45:a b:1f:a t14 p:24:0 p:2a:0 p:3b:0 p:45:0 b:1f:0 g:3b:0 g:3b:0 g:3e:0 g:3e:0 g:43:0 g:43:0 pi:3b:0 pi:3e:0 pi:43:0 t2 p:24:a p:2e:a p:3b:a p:45:a b:26:a g:39:a g:39:a g:3e:a g:3e:a g:42:a g:42:a pi:39:a pi:3e:a pi:42:a t14 p:24:0 p:2e:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0",
}
}
@router.post("/text-to-midi", tags=["MIDI"])
def text_to_midi(body: TextToMidiBody):
vocab_config = "backend-python/utils/midi_vocab_config.json"
cfg = VocabConfig.from_json(vocab_config)
mid = convert_str_to_midi(cfg, body.text.strip())
mid_data = io.BytesIO()
mid.save(None, mid_data)
mid_data.seek(0)
return StreamingResponse(mid_data, media_type="audio/midi")
class TxtToMidiBody(BaseModel):
txt_path: str
midi_path: str
class Config:
schema_extra = {
"example": {
"txt_path": "midi/sample.txt",
"midi_path": "midi/sample.mid",
}
}
@router.post("/txt-to-midi", tags=["MIDI"])
def txt_to_midi(body: TxtToMidiBody):
if not body.midi_path.startswith("midi/"):
raise HTTPException(status.HTTP_400_BAD_REQUEST, "bad output path")
vocab_config = "backend-python/utils/midi_vocab_config.json"
cfg = VocabConfig.from_json(vocab_config)
with open(body.txt_path, "r") as f:
text = f.read()
text = text.strip()
mid = convert_str_to_midi(cfg, text)
mid.save(body.midi_path)
return "success"
class MidiToWavBody(BaseModel):
midi_path: str
wav_path: str
sound_font_path: str = "assets/default_sound_font.sf2"
class Config:
schema_extra = {
"example": {
"midi_path": "midi/sample.mid",
"wav_path": "midi/sample.wav",
"sound_font_path": "assets/default_sound_font.sf2",
}
}
@router.post("/midi-to-wav", tags=["MIDI"])
def midi_to_wav(body: MidiToWavBody):
"""
Install fluidsynth first, see more: https://github.com/FluidSynth/fluidsynth/wiki/Download#distributions
"""
if not body.wav_path.startswith("midi/"):
raise HTTPException(status.HTTP_400_BAD_REQUEST, "bad output path")
fs = FluidSynth(body.sound_font_path)
fs.midi_to_audio(body.midi_path, body.wav_path)
return "success"
class TextToWavBody(BaseModel):
text: str
wav_name: str
sound_font_path: str = "assets/default_sound_font.sf2"
class Config:
schema_extra = {
"example": {
"text": "p:24:a p:2a:a p:31:a p:39:a p:3b:a p:45:a b:26:a g:3e:a g:3e:a g:42:a g:42:a g:45:a g:45:a pi:3e:a pi:42:a pi:45:a t14 p:24:0 p:2a:0 p:31:0 p:39:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 p:45:0 b:26:0 g:3e:0 g:3e:0 g:42:0 g:42:0 g:45:0 g:45:0 pi:3e:0 pi:42:0 pi:45:0 t2 p:2e:a p:3b:a p:45:a b:26:a g:3e:a g:3e:a g:42:a g:42:a g:45:a g:45:a pi:3e:a pi:42:a pi:45:a t14 p:2e:0 p:3b:0 p:45:0 g:3e:0 g:3e:0 g:42:0 g:42:0 g:45:0 g:45:0 pi:3e:0 pi:42:0 pi:45:0 t2 p:2e:a p:3b:a p:45:a g:3e:a g:3e:a g:42:a g:42:a g:45:a g:45:a pi:3e:a pi:42:a pi:45:a t14 p:2e:0 p:3b:0 p:45:0 b:26:0 g:3e:0 g:3e:0 g:42:0 g:42:0 g:45:0 g:45:0 pi:3e:0 pi:42:0 pi:45:0 t2 p:26:a p:2a:a p:3b:a p:45:a t14 p:26:0 p:2a:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a b:26:a g:3e:a g:3e:a g:42:a g:42:a g:45:a g:45:a pi:3e:a pi:42:a pi:45:a t14 p:2a:0 p:3b:0 p:45:0 b:26:0 t2 p:24:a p:2a:a p:3b:a p:45:a b:2d:a t14 p:24:0 p:2a:0 p:3b:0 p:45:0 b:2d:0 g:3e:0 g:3e:0 g:42:0 g:42:0 g:45:0 g:45:0 pi:3e:0 pi:42:0 pi:45:0 t2 p:24:a p:2a:a p:3b:a p:45:a b:21:a g:39:a g:39:a g:3d:a g:3d:a g:40:a g:40:a pi:39:a pi:3d:a pi:40:a t14 p:24:0 p:2a:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 p:45:0 b:21:0 g:39:0 g:39:0 g:3d:0 g:3d:0 g:40:0 g:40:0 pi:39:0 pi:3d:0 pi:40:0 t2 p:24:a p:2e:a p:3b:a p:45:a b:21:a g:39:a g:39:a g:3d:a g:3d:a g:40:a g:40:a pi:39:a pi:3d:a pi:40:a t14 p:24:0 p:2e:0 p:3b:0 p:45:0 b:21:0 g:39:0 g:39:0 g:3d:0 g:3d:0 g:40:0 g:40:0 pi:39:0 pi:3d:0 pi:40:0 t2 p:24:a p:2a:a p:3b:a p:45:a b:21:a g:39:a g:39:a g:3d:a g:3d:a g:40:a g:40:a pi:39:a pi:3d:a pi:40:a t14 p:24:0 p:2a:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 p:45:0 b:21:0 g:39:0 g:39:0 g:3d:0 g:3d:0 g:40:0 g:40:0 pi:39:0 pi:3d:0 pi:40:0 t2 p:26:a p:2a:a p:3b:a p:45:a b:21:a g:39:a g:39:a g:3d:a g:3d:a g:40:a g:40:a pi:39:a pi:3d:a pi:40:a t14 p:26:0 p:2a:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 p:45:0 b:21:0 g:39:0 g:39:0 g:3d:0 g:3d:0 g:40:0 g:40:0 pi:39:0 pi:3d:0 pi:40:0 t2 p:26:a p:2e:a p:31:a p:39:a p:3b:a p:45:a b:21:a g:39:a g:39:a g:3d:a g:3d:a g:40:a g:40:a pi:39:a pi:3d:a pi:40:a t14 p:26:0 p:2e:0 p:31:0 p:39:0 p:3b:0 p:45:0 b:21:0 t2 p:26:a p:2e:a p:31:a p:39:a p:3b:a p:45:a b:21:a t14 p:26:0 p:2e:0 p:31:0 p:39:0 p:3b:0 p:45:0 b:21:0 g:39:0 g:39:0 g:3d:0 g:3d:0 g:40:0 g:40:0 pi:39:0 pi:3d:0 pi:40:0 t2 p:24:a p:2a:a p:31:a p:39:a p:3b:a p:45:a b:1f:a g:3b:a g:3b:a g:3e:a g:3e:a g:43:a g:43:a pi:3b:a pi:3e:a pi:43:a t14 p:24:0 p:2a:0 p:31:0 p:39:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 p:45:0 b:1f:0 g:3b:0 g:3b:0 g:3e:0 g:3e:0 g:43:0 g:43:0 pi:3b:0 pi:3e:0 pi:43:0 t2 p:2e:a p:3b:a p:45:a b:1f:a g:3b:a g:3b:a g:3e:a g:3e:a g:43:a g:43:a pi:3b:a pi:3e:a pi:43:a t14 p:2e:0 p:3b:0 p:45:0 g:3b:0 g:3b:0 g:3e:0 g:3e:0 g:43:0 g:43:0 pi:3b:0 pi:3e:0 pi:43:0 t2 p:2e:a p:3b:a p:45:a g:3b:a g:3b:a g:3e:a g:3e:a g:43:a g:43:a pi:3b:a pi:3e:a pi:43:a t14 p:2e:0 p:3b:0 p:45:0 b:1f:0 g:3b:0 g:3b:0 g:3e:0 g:3e:0 g:43:0 g:43:0 pi:3b:0 pi:3e:0 pi:43:0 t2 p:26:a p:2a:a p:3b:a p:45:a t14 p:26:0 p:2a:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a b:1f:a g:3b:a g:3b:a g:3e:a g:3e:a g:43:a g:43:a pi:3b:a pi:3e:a pi:43:a t14 p:2a:0 p:3b:0 p:45:0 b:1f:0 t2 p:24:a p:2a:a p:3b:a p:45:a b:1f:a t14 p:24:0 p:2a:0 p:3b:0 p:45:0 b:1f:0 g:3b:0 g:3b:0 g:3e:0 g:3e:0 g:43:0 g:43:0 pi:3b:0 pi:3e:0 pi:43:0 t2 p:24:a p:2e:a p:3b:a p:45:a b:26:a g:39:a g:39:a g:3e:a g:3e:a g:42:a g:42:a pi:39:a pi:3e:a pi:42:a t14 p:24:0 p:2e:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0",
"wav_name": "sample",
"sound_font_path": "assets/default_sound_font.sf2",
}
}
@router.post("/text-to-wav", tags=["MIDI"])
def text_to_wav(body: TextToWavBody):
"""
Install fluidsynth first, see more: https://github.com/FluidSynth/fluidsynth/wiki/Download#distributions
"""
text = body.text.strip()
if not text.startswith("<start>"):
text = "<start> " + text
if not text.endswith("<end>"):
text = text + " <end>"
txt_path = f"midi/{body.wav_name}.txt"
midi_path = f"midi/{body.wav_name}.mid"
wav_path = f"midi/{body.wav_name}.wav"
with open(txt_path, "w") as f:
f.write(text)
txt_to_midi(TxtToMidiBody(txt_path=txt_path, midi_path=midi_path))
midi_to_wav(
MidiToWavBody(
midi_path=midi_path, wav_path=wav_path, sound_font_path=body.sound_font_path
)
)
return "success"

View File

@@ -4,8 +4,6 @@ from fastapi import APIRouter, HTTPException, Request, Response, status
from pydantic import BaseModel
import gc
import copy
import sys
import torch
router = APIRouter()
@@ -34,6 +32,32 @@ def init():
print("cyac not found")
@router.post("/disable-state-cache", tags=["State Cache"])
def disable_state_cache():
global trie, dtrie
trie = None
dtrie = {}
gc.collect()
return "success"
@router.post("/enable-state-cache", tags=["State Cache"])
def enable_state_cache():
global trie, dtrie
try:
import cyac
trie = cyac.Trie()
dtrie = {}
gc.collect()
return "success"
except ModuleNotFoundError:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "cyac not found")
class AddStateBody(BaseModel):
prompt: str
tokens: List[str]
@@ -41,12 +65,14 @@ class AddStateBody(BaseModel):
logits: Any
@router.post("/add-state")
@router.post("/add-state", tags=["State Cache"])
def add_state(body: AddStateBody):
global trie, dtrie, loop_del_trie_id
if trie is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
import torch
try:
id: int = trie.insert(body.prompt)
device: torch.device = body.state[0].device
@@ -79,12 +105,14 @@ def add_state(body: AddStateBody):
)
@router.post("/reset-state")
@router.post("/reset-state", tags=["State Cache"])
def reset_state():
global trie, dtrie
if trie is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
import cyac
trie = cyac.Trie()
dtrie = {}
gc.collect()
@@ -113,12 +141,14 @@ def _get_a_dtrie_buff_size(dtrie_v):
return 54 * len(dtrie_v["tokens"]) + 491520 + 262144 + 28 # TODO
@router.post("/longest-prefix-state")
@router.post("/longest-prefix-state", tags=["State Cache"])
def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
global trie
if trie is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
import torch
id = -1
try:
for id, len in trie.prefix(body.prompt):
@@ -150,7 +180,7 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
}
@router.post("/save-state")
@router.post("/save-state", tags=["State Cache"])
def save_state():
global trie
if trie is None:

20144
backend-python/rwkv_pip/tokenizer-midi.json vendored Normal file

File diff suppressed because it is too large Load Diff

685
backend-python/utils/midi.py vendored Normal file
View File

@@ -0,0 +1,685 @@
# https://github.com/briansemrau/MIDI-LLM-tokenizer
# MIT License
# Copyright (c) 2023 Brian Semrau
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import json
import random
from dataclasses import dataclass
from functools import lru_cache
from math import ceil, floor, log
from typing import Dict, Iterator, List, Optional, Tuple
import mido
@dataclass
class VocabConfig:
# Number of note events. Should be 128.
note_events: int
# Number of wait events. Configurable, must evenly divide max_wait_time.
wait_events: int
# Max wait time in milliseconds to be represented by a single token.
max_wait_time: int
# Number of velocity events. Should be 128 (or 100? need to check midi standard)
velocity_events: int
# Number of bins to quantize velocity into. Should evenly divide velocity_events.
velocity_bins: int
# Exponential scaling factor for velocity bin sizes. 1.0 = linear scaling.
velocity_exp: float
# Whether to sort tokens by instrument, note. This should improve data reducibility.
do_token_sorting: bool
# Whether tokens should be represented as combined instrument/note/velocity tokens, or separate tokens for each.
unrolled_tokens: bool
# If non-zero, notes held for this many seconds will be automatically released during str->midi decoding.
decode_end_held_note_delay: float
# If true, repeated notes will be automatically released before playing again during str->midi decoding.
decode_fix_repeated_notes: bool
# List of instrument names to use for binning. Must have at most 16 values.
bin_instrument_names: List[str]
# Indicates which bin name represents percussion instruments on MIDI channel 10.
ch10_instrument_bin_name: str
# Mapping from instrument name to bin name.
program_name_to_bin_name: Dict[str, str]
# Mapping from bin name to program name.
bin_name_to_program_name: Dict[str, str]
# Mapping from program number to instrument name.
instrument_names: Dict[str, str]
def __post_init__(self):
self.validate()
self._instrument_names_str_to_int = {
name: int(i) for i, name in self.instrument_names.items()
}
self._instrument_names_int_to_str = {
int(i): name for i, name in self.instrument_names.items()
}
self._bin_str_to_int = {
name: int(i) for i, name in enumerate(self.bin_instrument_names)
}
self._bin_int_to_instrument_int = [
self._instrument_names_str_to_int[self.bin_name_to_program_name[name]]
if name != self.ch10_instrument_bin_name
else 0
for name in self.bin_instrument_names
]
self._instrument_int_to_bin_int = [
self._bin_str_to_int[self.program_name_to_bin_name[instr]]
if self.program_name_to_bin_name[instr] != ""
else -1
for instr in self.program_name_to_bin_name.keys()
]
self._ch10_bin_int = (
self._bin_str_to_int[self.ch10_instrument_bin_name]
if self.ch10_instrument_bin_name
else -1
)
self.short_instr_bin_names = []
for instr in self.bin_instrument_names:
i = min(1, len(instr))
while instr[:i] in self.short_instr_bin_names:
i += 1
self.short_instr_bin_names.append(instr[:i])
self._short_instrument_names_str_to_int = {
name: int(i) for i, name in enumerate(self.short_instr_bin_names)
}
range_excluding_ch10 = [
(i if i < 9 else i + 1) for i in range(len(self.bin_instrument_names))
]
bins_excluding_ch10 = [
n for n in self.bin_instrument_names if n != self.ch10_instrument_bin_name
]
self.bin_channel_map = {
bin: channel
for channel, bin in zip(range_excluding_ch10, bins_excluding_ch10)
}
if self.ch10_instrument_bin_name:
self.bin_channel_map[self.ch10_instrument_bin_name] = 9
def validate(self):
if self.max_wait_time % self.wait_events != 0:
raise ValueError("max_wait_time must be exactly divisible by wait_events")
if self.velocity_bins < 2:
raise ValueError("velocity_bins must be at least 2")
if len(self.bin_instrument_names) > 16:
raise ValueError("bin_instruments must have at most 16 values")
if (
self.ch10_instrument_bin_name
and self.ch10_instrument_bin_name not in self.bin_instrument_names
):
raise ValueError("ch10_instrument_bin_name must be in bin_instruments")
if self.velocity_exp <= 0:
raise ValueError("velocity_exp must be greater than 0")
@classmethod
def from_json(cls, path: str):
with open(path, "r") as f:
config = json.load(f)
return cls(**config)
class VocabUtils:
def __init__(self, cfg: VocabConfig) -> None:
self.cfg = cfg
@lru_cache(maxsize=128)
def format_wait_token(self, wait: int) -> str:
return f"t{wait}"
@lru_cache(maxsize=128)
def format_note_token(
self, instrument_bin: int, note: int, velocity_bin: int
) -> str:
return f"{self.cfg.short_instr_bin_names[instrument_bin]}:{note:x}:{velocity_bin:x}"
def format_unrolled_note(self, note: int) -> str:
return f"n{note:x}"
def format_unrolled_velocity(self, velocity_bin: int) -> str:
return f"v{velocity_bin:x}"
def format_unrolled_instrument_bin(self, instrument_bin: int) -> str:
return f"i{self.cfg.short_instr_bin_names[instrument_bin]}"
def velocity_to_bin(self, velocity: float) -> int:
velocity = max(0, min(velocity, self.cfg.velocity_events - 1))
binsize = self.cfg.velocity_events / (self.cfg.velocity_bins - 1)
if self.cfg.velocity_exp == 1.0:
return ceil(velocity / binsize)
else:
return ceil(
(
self.cfg.velocity_events
* (
(
self.cfg.velocity_exp
** (velocity / self.cfg.velocity_events)
- 1.0
)
/ (self.cfg.velocity_exp - 1.0)
)
)
/ binsize
)
def bin_to_velocity(self, bin: int) -> int:
binsize = self.cfg.velocity_events / (self.cfg.velocity_bins - 1)
if self.cfg.velocity_exp == 1.0:
return max(0, ceil(bin * binsize - 1))
else:
return max(
0,
ceil(
self.cfg.velocity_events
* log(
((self.cfg.velocity_exp - 1) * binsize * bin)
/ self.cfg.velocity_events
+ 1,
self.cfg.velocity_exp,
)
- 1
),
)
def delta_to_wait_ids(self, delta_ms: float) -> Iterator[int]:
def roundi(f: float):
return ceil(f - 0.5)
max_wait_ms = self.cfg.max_wait_time
div = max_wait_ms / self.cfg.wait_events
# if delta_ms // max_wait_ms > 512: # arbitrary limit to avoid excessive time_shifts
# raise ValueError("delta_time is too large")
if delta_ms > max_wait_ms * 10:
delta_ms = max_wait_ms * 10 # truncate time
for _ in range(floor(delta_ms / max_wait_ms)):
yield roundi(max_wait_ms / div)
leftover_time_shift = roundi((delta_ms % max_wait_ms) / div)
if leftover_time_shift > 0:
yield leftover_time_shift
def prog_data_to_token_data(
self, program: int, channel: int, note: int, velocity: float
) -> Optional[Tuple[int, int, int]]:
if channel == 9:
if self.cfg._ch10_bin_int == -1:
return None
return self.cfg._ch10_bin_int, note, self.velocity_to_bin(velocity)
instrument_bin = self.cfg._instrument_int_to_bin_int[program]
if instrument_bin != -1:
return instrument_bin, note, self.velocity_to_bin(velocity)
return None
def prog_data_list_to_token_data_list(
self, data: List[Tuple[int, int, int, float]]
) -> Iterator[Tuple[int, int, int]]:
for d in data:
token_data = self.prog_data_to_token_data(*d)
if token_data is not None:
yield token_data
def sort_token_data(
self, data: List[Tuple[int, int, int]]
) -> List[Tuple[int, int, int]]:
# ensure order is preserved for tokens with the same instrument, note
data = [(i, n, v, x) for x, (i, n, v) in enumerate(data)]
data.sort(key=lambda x: (x[0] != self.cfg._ch10_bin_int, x[0], x[1], x[3]))
return [(i, n, v) for i, n, v, _ in data]
def data_to_wait_tokens(self, delta_ms: float) -> List[str]:
if delta_ms == 0.0:
return []
return [self.format_wait_token(i) for i in self.delta_to_wait_ids(delta_ms)]
def wait_token_to_delta(self, token: str) -> float:
return self.cfg.max_wait_time / self.cfg.wait_events * int(token[1:])
def note_token_to_data(self, token: str) -> Tuple[int, int, int]:
instr_str, note_str, velocity_str = token.strip().split(":")
instr_bin = self.cfg._short_instrument_names_str_to_int[instr_str]
note = int(note_str, base=16)
velocity = self.bin_to_velocity(int(velocity_str, base=16))
return instr_bin, note, velocity
@dataclass
class AugmentValues:
instrument_bin_remap: Dict[int, int]
velocity_mod_factor: float
transpose_semitones: int
time_stretch_factor: float
@classmethod
def default(cls) -> "AugmentValues":
return cls(
instrument_bin_remap={},
velocity_mod_factor=1.0,
transpose_semitones=0,
time_stretch_factor=1.0,
)
@dataclass
class AugmentConfig:
# The number of times to augment each MIDI file. The dataset size will be multiplied by this number.
augment_data_factor: int
# A list of instrument names to randomly swap with each other.
instrument_mixups: List[List[str]]
# A list of percentages to change the note velocity by. 0.0 = no change. 0 is included by default.
velocity_mod_pct: List[float]
# A list of semitones to transpose by. 0 is included by default.
transpose_semitones: List[int]
# A list of percentages to stretch the tempo by. 0.0 = no stretch. 0 is included by default.
time_stretch_pct: List[float]
# Random seed to use for reproducibility.
seed: int
cfg: VocabConfig
def __post_init__(self):
self.validate()
if len(self.velocity_mod_pct) == 0:
self.velocity_mod_pct = [0.0]
if len(self.transpose_semitones) == 0:
self.transpose_semitones = [0]
if len(self.time_stretch_pct) == 0:
self.time_stretch_pct = [0.0]
self._instrument_mixups_int = [
[self.cfg._bin_str_to_int[i] for i in l if i in self.cfg._bin_str_to_int]
for l in self.instrument_mixups
]
self._instrument_mixups_int = [
l for l in self._instrument_mixups_int if len(l) > 0
] # remove empty lists
self._instrument_pool_assignments = {}
self._mixup_pools = []
for pool_i, mixup_list in enumerate(self._instrument_mixups_int):
pool = set()
for i in mixup_list:
pool.add(i)
self._instrument_pool_assignments[i] = pool_i
self._mixup_pools.append(pool)
def validate(self):
if self.augment_data_factor < 1:
raise ValueError("augment_data_factor must be at least 1")
used_instruments = set()
for mixup_list in self.instrument_mixups:
for n in mixup_list:
if n in used_instruments:
raise ValueError(f"Duplicate instrument name: {n}")
used_instruments.add(n)
@classmethod
def from_json(cls, path: str, cfg: VocabConfig):
with open(path, "r") as f:
config = json.load(f)
config["cfg"] = cfg
if "seed" not in config:
config["seed"] = random.randint(0, 2**32 - 1)
return cls(**config)
def get_augment_values(self, filename: str) -> Iterator[AugmentValues]:
# first yield default values
yield AugmentValues.default()
rng = random.Random(self.seed + hash(filename))
for _ in range(int(self.augment_data_factor - 1)):
# randomize order for each pool
randomized_pools = [list(pool) for pool in self._mixup_pools]
for pool in randomized_pools:
rng.shuffle(pool)
# distribute reassignments
instrument_bin_remap = {}
for i, pool in enumerate(randomized_pools):
for j, instrument in enumerate(pool):
instrument_bin_remap[instrument] = randomized_pools[i - 1][j]
yield AugmentValues(
instrument_bin_remap=instrument_bin_remap,
velocity_mod_factor=1.0 + rng.choice(self.velocity_mod_pct),
transpose_semitones=rng.choice(self.transpose_semitones),
time_stretch_factor=1.0 + rng.choice(self.time_stretch_pct),
)
def mix_volume(velocity: int, volume: int, expression: int) -> float:
return velocity * (volume / 127.0) * (expression / 127.0)
def convert_midi_to_str(
cfg: VocabConfig, mid: mido.MidiFile, augment: AugmentValues = None
) -> str:
utils = VocabUtils(cfg)
if augment is None:
augment = AugmentValues.default()
# filter out unknown meta messages before merge (https://github.com/mido/mido/pull/286)
for i in range(len(mid.tracks)):
mid.tracks[i] = [msg for msg in mid.tracks[i] if msg.type != "unknown_meta"]
if len(mid.tracks) > 1:
mid.tracks = [mido.merge_tracks(mid.tracks)]
delta_time_ms = 0.0
tempo = 500000
channel_program = {i: 0 for i in range(16)}
channel_volume = {i: 127 for i in range(16)}
channel_expression = {
i: 127 for i in range(16)
} # unlikely to be useful. expression usually modifies an already played note.
channel_notes = {i: {} for i in range(16)}
channel_pedal_on = {i: False for i in range(16)}
channel_pedal_events = {
i: {} for i in range(16)
} # {channel: {(note, program) -> True}}
started_flag = False
output = ["<start>"]
token_data_buffer: List[
Tuple[int, int, int, float]
] = [] # need to sort notes between wait tokens
def flush_token_data_buffer():
nonlocal token_data_buffer, output, cfg, utils, augment
token_data = [
x for x in utils.prog_data_list_to_token_data_list(token_data_buffer)
]
if augment.instrument_bin_remap or augment.transpose_semitones:
# TODO put transpose in a real function
raw_transpose = (
lambda bin, n: n + augment.transpose_semitones
if bin != cfg._ch10_bin_int
else n
)
octave_shift_if_oob = (
lambda n: n + 12 if n < 0 else n - 12 if n >= cfg.note_events else n
)
# TODO handle ranges beyond 12
# octave_shift_if_oob = lambda n: 0 if n < 0 else (n - cfg.note_events) % 12 + cfg.note_events if n >= cfg.note_events else n
transpose = lambda bin, n: octave_shift_if_oob(raw_transpose(bin, n))
token_data = [
(augment.instrument_bin_remap.get(i, i), transpose(i, n), v)
for i, n, v in token_data
]
if cfg.do_token_sorting:
token_data = utils.sort_token_data(token_data)
if cfg.unrolled_tokens:
for t in token_data:
output += [
utils.format_unrolled_instrument_bin(t[0]),
utils.format_unrolled_note(t[1]),
utils.format_unrolled_velocity(t[2]),
]
else:
output += [utils.format_note_token(*t) for t in token_data]
token_data_buffer = []
def consume_note_program_data(prog: int, chan: int, note: int, vel: float):
nonlocal output, started_flag, delta_time_ms, cfg, utils, token_data_buffer
is_token_valid = (
utils.prog_data_to_token_data(prog, chan, note, vel) is not None
)
if not is_token_valid:
return
if started_flag:
wait_tokens = utils.data_to_wait_tokens(delta_time_ms)
if len(wait_tokens) > 0:
flush_token_data_buffer()
output += wait_tokens
delta_time_ms = 0.0
token_data_buffer.append((prog, chan, note, vel * augment.velocity_mod_factor))
started_flag = True
for msg in mid.tracks[0]:
time_ms = mido.tick2second(msg.time, mid.ticks_per_beat, tempo) * 1000.0
delta_time_ms += time_ms
t = msg.type
if msg.is_meta:
if t == "set_tempo":
tempo = msg.tempo * augment.time_stretch_factor
continue
def handle_note_off(ch, prog, n):
if channel_pedal_on[ch]:
channel_pedal_events[ch][(n, prog)] = True
else:
consume_note_program_data(prog, ch, n, 0)
if n in channel_notes[ch]:
del channel_notes[ch][n]
if t == "program_change":
channel_program[msg.channel] = msg.program
elif t == "note_on":
if msg.velocity == 0:
handle_note_off(msg.channel, channel_program[msg.channel], msg.note)
else:
if (msg.note, channel_program[msg.channel]) in channel_pedal_events[
msg.channel
]:
del channel_pedal_events[msg.channel][
(msg.note, channel_program[msg.channel])
]
consume_note_program_data(
channel_program[msg.channel],
msg.channel,
msg.note,
mix_volume(
msg.velocity,
channel_volume[msg.channel],
channel_expression[msg.channel],
),
)
channel_notes[msg.channel][msg.note] = True
elif t == "note_off":
handle_note_off(msg.channel, channel_program[msg.channel], msg.note)
elif t == "control_change":
if msg.control == 7 or msg.control == 39: # volume
channel_volume[msg.channel] = msg.value
elif msg.control == 11: # expression
channel_expression[msg.channel] = msg.value
elif msg.control == 64: # sustain pedal
channel_pedal_on[msg.channel] = msg.value >= 64
if not channel_pedal_on[msg.channel]:
for note, program in channel_pedal_events[msg.channel]:
handle_note_off(msg.channel, program, note)
channel_pedal_events[msg.channel] = {}
elif msg.control == 123: # all notes off
for channel in channel_notes.keys():
for note in list(channel_notes[channel]).copy():
handle_note_off(channel, channel_program[channel], note)
else:
pass
flush_token_data_buffer()
output.append("<end>")
return " ".join(output)
def generate_program_change_messages(cfg: VocabConfig):
for bin_name, channel in cfg.bin_channel_map.items():
if channel == 9:
continue
program = cfg._instrument_names_str_to_int[
cfg.bin_name_to_program_name[bin_name]
]
yield mido.Message("program_change", program=program, time=0, channel=channel)
yield mido.Message("program_change", program=0, time=0, channel=9)
@dataclass
class DecodeState:
total_time: float # milliseconds
delta_accum: float # milliseconds
current_bin: int
current_note: int
active_notes: Dict[Tuple[int, int], float] # { (channel, note): time started, ... }
def token_to_midi_message(
utils: VocabUtils, token: str, state: DecodeState, end_token_pause: float = 3.0
) -> Iterator[Tuple[Optional[mido.Message], DecodeState]]:
if state is None:
state = DecodeState(
total_time=0.0,
delta_accum=0.0,
current_bin=utils.cfg._short_instrument_names_str_to_int[
utils.cfg.short_instr_bin_names[0]
],
current_note=0,
active_notes={},
)
token = token.strip()
if not token:
yield None, state
return
if token == "<end>":
d = end_token_pause * 1000.0
state.delta_accum += d
state.total_time += d
if utils.cfg.decode_end_held_note_delay != 0.0:
# end held notes
for (channel, note), start_time in list(state.active_notes.items()).copy():
ticks = int(mido.second2tick(state.delta_accum / 1000.0, 480, 500000))
state.delta_accum = 0.0
del state.active_notes[(channel, note)]
yield mido.Message(
"note_off", note=note, time=ticks, channel=channel
), state
yield None, state
return
if token.startswith("<"):
yield None, state
return
if utils.cfg.unrolled_tokens:
if token[0] == "t":
d = utils.wait_token_to_delta(token)
state.delta_accum += d
state.total_time += d
elif token[0] == "n":
state.current_note = int(token[1:], base=16)
elif token[0] == "i":
state.current_bin = utils.cfg._short_instrument_names_str_to_int[token[1:]]
elif token[0] == "v":
current_velocity = utils.bin_to_velocity(int(token[1:], base=16))
channel = utils.cfg.bin_channel_map[
utils.cfg.bin_instrument_names[state.current_bin]
]
ticks = int(mido.second2tick(state.delta_accum / 1000.0, 480, 500000))
state.delta_accum = 0.0
if current_velocity > 0:
yield mido.Message(
"note_on",
note=state.current_note,
velocity=current_velocity,
time=ticks,
channel=channel,
), state
else:
yield mido.Message(
"note_off",
note=state.current_note,
velocity=0,
time=ticks,
channel=channel,
), state
else:
if token[0] == "t" and token[1].isdigit(): # wait token
d = utils.wait_token_to_delta(token)
state.delta_accum += d
state.total_time += d
if utils.cfg.decode_end_held_note_delay != 0.0:
# remove notes that have been held for too long
for (channel, note), start_time in list(
state.active_notes.items()
).copy():
if (
state.total_time - start_time
> utils.cfg.decode_end_held_note_delay * 1000.0
):
ticks = int(
mido.second2tick(state.delta_accum / 1000.0, 480, 500000)
)
state.delta_accum = 0.0
del state.active_notes[(channel, note)]
yield mido.Message(
"note_off", note=note, time=ticks, channel=channel
), state
return
else: # note token
bin, note, velocity = utils.note_token_to_data(token)
channel = utils.cfg.bin_channel_map[utils.cfg.bin_instrument_names[bin]]
ticks = int(mido.second2tick(state.delta_accum / 1000.0, 480, 500000))
state.delta_accum = 0.0
if velocity > 0:
if utils.cfg.decode_fix_repeated_notes:
if (channel, note) in state.active_notes:
del state.active_notes[(channel, note)]
yield mido.Message(
"note_off", note=note, time=ticks, channel=channel
), state
ticks = 0
state.active_notes[(channel, note)] = state.total_time
yield mido.Message(
"note_on", note=note, velocity=velocity, time=ticks, channel=channel
), state
return
else:
if (channel, note) in state.active_notes:
del state.active_notes[(channel, note)]
yield mido.Message(
"note_off", note=note, time=ticks, channel=channel
), state
return
yield None, state
def str_to_midi_messages(utils: VocabUtils, data: str) -> Iterator[mido.Message]:
state = None
for token in data.split(" "):
for msg, new_state in token_to_midi_message(utils, token, state):
state = new_state
if msg is not None:
yield msg
def convert_str_to_midi(
cfg: VocabConfig, data: str, meta_text: str = "Generated by MIDI-LLM-tokenizer"
) -> mido.MidiFile:
utils = VocabUtils(cfg)
mid = mido.MidiFile()
track = mido.MidiTrack()
mid.tracks.append(track)
tempo = 500000
if meta_text:
track.append(mido.MetaMessage("text", text=meta_text, time=0))
track.append(mido.MetaMessage("set_tempo", tempo=tempo, time=0))
for msg in generate_program_change_messages(cfg):
track.append(msg)
# data = data.replace("<start>", "").replace("<end>", "").replace("<pad>", "").strip()
for msg in str_to_midi_messages(utils, data):
track.append(msg)
track.append(mido.MetaMessage("end_of_track", time=0))
return mid

View File

@@ -0,0 +1,303 @@
{
"note_events": 128,
"wait_events": 125,
"max_wait_time": 1000,
"velocity_events": 128,
"velocity_bins": 12,
"velocity_exp": 0.5,
"do_token_sorting": true,
"unrolled_tokens": false,
"decode_end_held_note_delay": 5.0,
"decode_fix_repeated_notes": true,
"bin_instrument_names": [
"percussion",
"drum",
"tuba",
"marimba",
"bass",
"guitar",
"violin",
"trumpet",
"piano",
"sax",
"flute",
"lead",
"pad"
],
"ch10_instrument_bin_name": "percussion",
"program_name_to_bin_name": {
"Acoustic Grand Piano": "piano",
"Bright Acoustic Piano": "piano",
"Electric Grand Piano": "piano",
"Honky-tonk Piano": "piano",
"Electric Piano 1 (Rhodes Piano)": "piano",
"Electric Piano 2 (Chorused Piano)": "piano",
"Harpsichord": "piano",
"Clavinet": "piano",
"Celesta": "marimba",
"Glockenspiel": "marimba",
"Music Box": "marimba",
"Vibraphone": "marimba",
"Marimba": "marimba",
"Xylophone": "marimba",
"Tubular Bells": "marimba",
"Dulcimer (Santur)": "marimba",
"Drawbar Organ (Hammond)": "marimba",
"Percussive Organ": "piano",
"Rock Organ": "piano",
"Church Organ": "piano",
"Reed Organ": "piano",
"Accordion (French)": "piano",
"Harmonica": "piano",
"Tango Accordion (Band neon)": "piano",
"Acoustic Guitar (nylon)": "guitar",
"Acoustic Guitar (steel)": "guitar",
"Electric Guitar (jazz)": "guitar",
"Electric Guitar (clean)": "guitar",
"Electric Guitar (muted)": "guitar",
"Overdriven Guitar": "guitar",
"Distortion Guitar": "guitar",
"Guitar harmonics": "guitar",
"Acoustic Bass": "bass",
"Electric Bass (fingered)": "bass",
"Electric Bass (picked)": "bass",
"Fretless Bass": "bass",
"Slap Bass 1": "bass",
"Slap Bass 2": "bass",
"Synth Bass 1": "bass",
"Synth Bass 2": "bass",
"Violin": "violin",
"Viola": "violin",
"Cello": "bass",
"Contrabass": "bass",
"Tremolo Strings": "violin",
"Pizzicato Strings": "violin",
"Orchestral Harp": "violin",
"Timpani": "drum",
"String Ensemble 1 (strings)": "violin",
"String Ensemble 2 (slow strings)": "violin",
"SynthStrings 1": "violin",
"SynthStrings 2": "violin",
"Choir Aahs": "violin",
"Voice Oohs": "violin",
"Synth Voice": "violin",
"Orchestra Hit": "",
"Trumpet": "trumpet",
"Trombone": "tuba",
"Tuba": "tuba",
"Muted Trumpet": "trumpet",
"French Horn": "trumpet",
"Brass Section": "trumpet",
"SynthBrass 1": "trumpet",
"SynthBrass 2": "trumpet",
"Soprano Sax": "sax",
"Alto Sax": "sax",
"Tenor Sax": "sax",
"Baritone Sax": "sax",
"Oboe": "sax",
"English Horn": "trumpet",
"Bassoon": "sax",
"Clarinet": "sax",
"Piccolo": "flute",
"Flute": "flute",
"Recorder": "flute",
"Pan Flute": "flute",
"Blown Bottle": "flute",
"Shakuhachi": "flute",
"Whistle": "flute",
"Ocarina": "flute",
"Lead 1 (square wave)": "lead",
"Lead 2 (sawtooth wave)": "lead",
"Lead 3 (calliope)": "lead",
"Lead 4 (chiffer)": "lead",
"Lead 5 (charang)": "lead",
"Lead 6 (voice solo)": "violin",
"Lead 7 (fifths)": "lead",
"Lead 8 (bass + lead)": "lead",
"Pad 1 (new age Fantasia)": "pad",
"Pad 2 (warm)": "pad",
"Pad 3 (polysynth)": "pad",
"Pad 4 (choir space voice)": "violin",
"Pad 5 (bowed glass)": "pad",
"Pad 6 (metallic pro)": "pad",
"Pad 7 (halo)": "pad",
"Pad 8 (sweep)": "pad",
"FX 1 (rain)": "",
"FX 2 (soundtrack)": "",
"FX 3 (crystal)": "",
"FX 4 (atmosphere)": "",
"FX 5 (brightness)": "",
"FX 6 (goblins)": "",
"FX 7 (echoes, drops)": "",
"FX 8 (sci-fi, star theme)": "",
"Sitar": "guitar",
"Banjo": "guitar",
"Shamisen": "guitar",
"Koto": "guitar",
"Kalimba": "guitar",
"Bag pipe": "sax",
"Fiddle": "violin",
"Shanai": "sax",
"Tinkle Bell": "marimba",
"Agogo": "marimba",
"Steel Drums": "marimba",
"Woodblock": "marimba",
"Taiko Drum": "drum",
"Melodic Tom": "drum",
"Synth Drum": "drum",
"Reverse Cymbal": "",
"Guitar Fret Noise": "",
"Breath Noise": "",
"Seashore": "",
"Bird Tweet": "",
"Telephone Ring": "",
"Helicopter": "",
"Applause": "",
"Gunshot": ""
},
"bin_name_to_program_name": {
"piano": "Acoustic Grand Piano",
"marimba": "Marimba",
"drum": "Synth Drum",
"guitar": "Acoustic Guitar (steel)",
"bass": "Acoustic Bass",
"violin": "Violin",
"percussion": "",
"trumpet": "Trumpet",
"tuba": "Tuba",
"sax": "Tenor Sax",
"flute": "Flute",
"lead": "Lead 1 (square wave)",
"pad": "Pad 1 (new age Fantasia)"
},
"instrument_names": {
"0": "Acoustic Grand Piano",
"1": "Bright Acoustic Piano",
"2": "Electric Grand Piano",
"3": "Honky-tonk Piano",
"4": "Electric Piano 1 (Rhodes Piano)",
"5": "Electric Piano 2 (Chorused Piano)",
"6": "Harpsichord",
"7": "Clavinet",
"8": "Celesta",
"9": "Glockenspiel",
"10": "Music Box",
"11": "Vibraphone",
"12": "Marimba",
"13": "Xylophone",
"14": "Tubular Bells",
"15": "Dulcimer (Santur)",
"16": "Drawbar Organ (Hammond)",
"17": "Percussive Organ",
"18": "Rock Organ",
"19": "Church Organ",
"20": "Reed Organ",
"21": "Accordion (French)",
"22": "Harmonica",
"23": "Tango Accordion (Band neon)",
"24": "Acoustic Guitar (nylon)",
"25": "Acoustic Guitar (steel)",
"26": "Electric Guitar (jazz)",
"27": "Electric Guitar (clean)",
"28": "Electric Guitar (muted)",
"29": "Overdriven Guitar",
"30": "Distortion Guitar",
"31": "Guitar harmonics",
"32": "Acoustic Bass",
"33": "Electric Bass (fingered)",
"34": "Electric Bass (picked)",
"35": "Fretless Bass",
"36": "Slap Bass 1",
"37": "Slap Bass 2",
"38": "Synth Bass 1",
"39": "Synth Bass 2",
"40": "Violin",
"41": "Viola",
"42": "Cello",
"43": "Contrabass",
"44": "Tremolo Strings",
"45": "Pizzicato Strings",
"46": "Orchestral Harp",
"47": "Timpani",
"48": "String Ensemble 1 (strings)",
"49": "String Ensemble 2 (slow strings)",
"50": "SynthStrings 1",
"51": "SynthStrings 2",
"52": "Choir Aahs",
"53": "Voice Oohs",
"54": "Synth Voice",
"55": "Orchestra Hit",
"56": "Trumpet",
"57": "Trombone",
"58": "Tuba",
"59": "Muted Trumpet",
"60": "French Horn",
"61": "Brass Section",
"62": "SynthBrass 1",
"63": "SynthBrass 2",
"64": "Soprano Sax",
"65": "Alto Sax",
"66": "Tenor Sax",
"67": "Baritone Sax",
"68": "Oboe",
"69": "English Horn",
"70": "Bassoon",
"71": "Clarinet",
"72": "Piccolo",
"73": "Flute",
"74": "Recorder",
"75": "Pan Flute",
"76": "Blown Bottle",
"77": "Shakuhachi",
"78": "Whistle",
"79": "Ocarina",
"80": "Lead 1 (square wave)",
"81": "Lead 2 (sawtooth wave)",
"82": "Lead 3 (calliope)",
"83": "Lead 4 (chiffer)",
"84": "Lead 5 (charang)",
"85": "Lead 6 (voice solo)",
"86": "Lead 7 (fifths)",
"87": "Lead 8 (bass + lead)",
"88": "Pad 1 (new age Fantasia)",
"89": "Pad 2 (warm)",
"90": "Pad 3 (polysynth)",
"91": "Pad 4 (choir space voice)",
"92": "Pad 5 (bowed glass)",
"93": "Pad 6 (metallic pro)",
"94": "Pad 7 (halo)",
"95": "Pad 8 (sweep)",
"96": "FX 1 (rain)",
"97": "FX 2 (soundtrack)",
"98": "FX 3 (crystal)",
"99": "FX 4 (atmosphere)",
"100": "FX 5 (brightness)",
"101": "FX 6 (goblins)",
"102": "FX 7 (echoes, drops)",
"103": "FX 8 (sci-fi, star theme)",
"104": "Sitar",
"105": "Banjo",
"106": "Shamisen",
"107": "Koto",
"108": "Kalimba",
"109": "Bag pipe",
"110": "Fiddle",
"111": "Shanai",
"112": "Tinkle Bell",
"113": "Agogo",
"114": "Steel Drums",
"115": "Woodblock",
"116": "Taiko Drum",
"117": "Melodic Tom",
"118": "Synth Drum",
"119": "Reverse Cymbal",
"120": "Guitar Fret Noise",
"121": "Breath Noise",
"122": "Seashore",
"123": "Bird Tweet",
"124": "Telephone Ring",
"125": "Helicopter",
"126": "Applause",
"127": "Gunshot"
}
}

View File

@@ -1,13 +1,13 @@
from abc import ABC, abstractmethod
import os
import pathlib
import copy
from typing import Dict, List, Tuple
import re
from typing import Dict, Iterable, List, Tuple, Union
from utils.log import quick_log
from fastapi import HTTPException
from pydantic import BaseModel, Field
import torch
import numpy as np
from rwkv_pip.utils import PIPELINE
from routes import state_cache
@@ -18,9 +18,10 @@ END_OF_LINE_DOUBLE = 535
os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}"
class RWKV:
def __init__(self, model: str, strategy: str, tokens_path: str) -> None:
class AbstractRWKV(ABC):
def __init__(self, model: str, strategy: str, tokens_path: str):
from rwkv.model import RWKV as Model # dynamic import to make RWKV_CUDA_ON work
from rwkv_pip.utils import PIPELINE
filename, _ = os.path.splitext(os.path.basename(model))
self.name = filename
@@ -29,101 +30,52 @@ class RWKV:
self.model_state = None
self.model_tokens = []
self.CHUNK_LEN = 256
self.max_tokens_per_generation = 500
self.temperature = 1
self.top_p = 0.5
self.penalty_alpha_presence = 0.4
self.penalty_alpha_frequency = 0.4
self.top_p = 0.3
self.top_k = 0
self.penalty_alpha_presence = 0
self.penalty_alpha_frequency = 1
self.interface = ":"
if "world" in self.name.lower():
self.user = "Question"
self.bot = "Answer"
self.END_OF_LINE = 11
else:
self.user = "Bob"
self.bot = "Alice"
self.END_OF_LINE = 187
@abstractmethod
def adjust_occurrence(self, occurrence: Dict, token: int):
pass
self.AVOID_REPEAT_TOKENS = []
AVOID_REPEAT = ""
for i in AVOID_REPEAT:
dd = self.pipeline.encode(i)
assert len(dd) == 1
self.AVOID_REPEAT_TOKENS += dd
self.preload()
def preload(self):
interface = self.interface
user = self.user
bot = self.bot
preset_system = (
f"""
The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. \
{bot} is very intelligent, creative and friendly. \
{bot} is unlikely to disagree with {user}, and {bot} doesn't like to ask {user} questions. \
{bot} likes to tell {user} a lot about herself and her opinions. \
{bot} usually gives {user} kind, helpful and informative advices.\n
"""
if self.user == "Bob"
else f"{user}{interface} hi\n\n{bot}{interface} Hi. "
+ "I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n"
)
logits, _ = self.run_rnn(self.fix_tokens(self.pipeline.encode(preset_system)))
try:
state_cache.add_state(
state_cache.AddStateBody(
prompt=preset_system,
tokens=self.model_tokens,
state=self.model_state,
logits=logits,
)
)
except HTTPException:
pass
@abstractmethod
def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int):
pass
# Model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end
def fix_tokens(self, tokens):
if "world" in self.name.lower():
return tokens
if len(tokens) > 0 and tokens[-1] == END_OF_LINE_DOUBLE:
tokens = tokens[:-1] + [self.END_OF_LINE, self.END_OF_LINE]
return tokens
@abstractmethod
def fix_tokens(self, tokens) -> List[int]:
pass
def run_rnn(self, _tokens: List[str], newline_adj: int = 0):
tokens = [int(x) for x in _tokens]
token_len = len(tokens)
self.model_tokens += tokens
@abstractmethod
def run_rnn(
self, _tokens: List[str], newline_adj: int = 0
) -> Tuple[List[float], int]:
pass
while len(tokens) > 0:
out, self.model_state = self.model.forward(
tokens[: self.CHUNK_LEN], self.model_state
)
tokens = tokens[self.CHUNK_LEN :]
out[self.END_OF_LINE] += newline_adj # adjust \n probability
if self.model_tokens[-1] in self.AVOID_REPEAT_TOKENS:
out[self.model_tokens[-1]] = -999999999
return out, token_len
@abstractmethod
def delta_postprocess(self, delta: str) -> str:
pass
def get_embedding(self, input: str, fast_mode: bool) -> Tuple[List[float], int]:
if fast_mode:
embedding, token_len = self.fast_embedding(
embedding, token_len = self.__fast_embedding(
self.fix_tokens(self.pipeline.encode(input)), None
)
else:
self.model_state = None
self.model_tokens = []
_, token_len = self.run_rnn(self.fix_tokens(self.pipeline.encode(input)))
embedding = self.model_state[-5].tolist()
embedding = self.model_state[-11].tolist()
embedding = (embedding / np.linalg.norm(embedding)).tolist()
return embedding, token_len
def fast_embedding(self, tokens: List[str], state):
def __fast_embedding(self, tokens: List[str], state):
import torch
tokens = [int(x) for x in tokens]
token_len = len(tokens)
self = self.model
@@ -260,7 +212,9 @@ The following is a coherent verbose detailed conversation between a girl named {
return state[0].tolist(), token_len
def generate(self, prompt: str, stop: str = None):
def generate(
self, prompt: str, stop: Union[str, List[str]] = None
) -> Iterable[Tuple[str, str, int, int]]:
quick_log(None, None, "Generation Prompt:\n" + prompt)
cache = None
delta_prompt = prompt
@@ -304,46 +258,60 @@ The following is a coherent verbose detailed conversation between a girl named {
completion_token_len = 0
response = ""
for i in range(self.max_tokens_per_generation):
for n in occurrence:
logits[n] -= (
self.penalty_alpha_presence
+ occurrence[n] * self.penalty_alpha_frequency
)
self.adjust_forward_logits(logits, occurrence, i)
token = self.pipeline.sample_logits(
logits, temperature=self.temperature, top_p=self.top_p
logits, temperature=self.temperature, top_p=self.top_p, top_k=self.top_k
)
if token == END_OF_TEXT:
yield response, "", prompt_token_len, completion_token_len
break
for xxx in occurrence:
occurrence[xxx] *= 0.996
if token not in occurrence:
occurrence[token] = 1
else:
occurrence[token] += 1
self.adjust_occurrence(occurrence, token)
logits, _ = self.run_rnn([token])
completion_token_len = completion_token_len + 1
delta: str = self.pipeline.decode(self.model_tokens[out_last:])
delta: str = self.delta_postprocess(
self.pipeline.decode(self.model_tokens[out_last:])
)
if "\ufffd" not in delta: # avoid utf-8 display issues
response += delta
if stop is not None:
if stop in response:
try:
state_cache.add_state(
state_cache.AddStateBody(
prompt=prompt + response,
tokens=self.model_tokens,
state=self.model_state,
logits=logits,
if type(stop) == str:
if stop in response:
try:
state_cache.add_state(
state_cache.AddStateBody(
prompt=prompt + response,
tokens=self.model_tokens,
state=self.model_state,
logits=logits,
)
)
)
except HTTPException:
pass
response = response.split(stop)[0]
yield response, "", prompt_token_len, completion_token_len
break
except HTTPException:
pass
response = response.split(stop)[0]
yield response, "", prompt_token_len, completion_token_len
break
elif type(stop) == list:
stop_exist_regex = "|".join(stop)
matched = re.search(stop_exist_regex, response)
if matched:
try:
state_cache.add_state(
state_cache.AddStateBody(
prompt=prompt + response,
tokens=self.model_tokens,
state=self.model_state,
logits=logits,
)
)
except HTTPException:
pass
response = response.split(matched.group())[0]
yield response, "", prompt_token_len, completion_token_len
break
out_last = begin + i + 1
if i == self.max_tokens_per_generation - 1:
try:
@@ -360,6 +328,153 @@ The following is a coherent verbose detailed conversation between a girl named {
yield response, delta, prompt_token_len, completion_token_len
class TextRWKV(AbstractRWKV):
def __init__(self, model: str, strategy: str, tokens_path: str) -> None:
super().__init__(model, strategy, tokens_path)
self.CHUNK_LEN = 256
self.max_tokens_per_generation = 500
self.temperature = 1
self.top_p = 0.3
self.top_k = 0
self.penalty_alpha_presence = 0
self.penalty_alpha_frequency = 1
self.interface = ":"
if "world" in self.name.lower():
self.user = "Question"
self.bot = "Answer"
self.END_OF_LINE = 11
else:
self.user = "Bob"
self.bot = "Alice"
self.END_OF_LINE = 187
self.AVOID_REPEAT_TOKENS = []
AVOID_REPEAT = ""
for i in AVOID_REPEAT:
dd = self.pipeline.encode(i)
assert len(dd) == 1
self.AVOID_REPEAT_TOKENS += dd
self.__preload()
def adjust_occurrence(self, occurrence: Dict, token: int):
for xxx in occurrence:
occurrence[xxx] *= 0.996
if token not in occurrence:
occurrence[token] = 1
else:
occurrence[token] += 1
def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int):
for n in occurrence:
logits[n] -= (
self.penalty_alpha_presence
+ occurrence[n] * self.penalty_alpha_frequency
)
# Model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end
def fix_tokens(self, tokens) -> List[int]:
if "world" in self.name.lower():
return tokens
if len(tokens) > 0 and tokens[-1] == END_OF_LINE_DOUBLE:
tokens = tokens[:-1] + [self.END_OF_LINE, self.END_OF_LINE]
return tokens
def run_rnn(
self, _tokens: List[str], newline_adj: int = 0
) -> Tuple[List[float], int]:
tokens = [int(x) for x in _tokens]
token_len = len(tokens)
self.model_tokens += tokens
while len(tokens) > 0:
out, self.model_state = self.model.forward(
tokens[: self.CHUNK_LEN], self.model_state
)
tokens = tokens[self.CHUNK_LEN :]
out[self.END_OF_LINE] += newline_adj # adjust \n probability
if self.model_tokens[-1] in self.AVOID_REPEAT_TOKENS:
out[self.model_tokens[-1]] = -999999999
return out, token_len
def delta_postprocess(self, delta: str) -> str:
return delta
def __preload(self):
interface = self.interface
user = self.user
bot = self.bot
preset_system = (
f"""
The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. \
{bot} is very intelligent, creative and friendly. \
{bot} is unlikely to disagree with {user}, and {bot} doesn't like to ask {user} questions. \
{bot} likes to tell {user} a lot about herself and her opinions. \
{bot} usually gives {user} kind, helpful and informative advices.\n
"""
if self.user == "Bob"
else f"{user}{interface} hi\n\n{bot}{interface} Hi. "
+ "I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n"
)
logits, _ = self.run_rnn(self.fix_tokens(self.pipeline.encode(preset_system)))
try:
state_cache.add_state(
state_cache.AddStateBody(
prompt=preset_system,
tokens=self.model_tokens,
state=self.model_state,
logits=logits,
)
)
except HTTPException:
pass
class MusicRWKV(AbstractRWKV):
def __init__(self, model: str, strategy: str, tokens_path: str):
super().__init__(model, strategy, tokens_path)
self.max_tokens_per_generation = 500
self.temperature = 1
self.top_p = 0.8
self.top_k = 8
def adjust_occurrence(self, occurrence: Dict, token: int):
for n in occurrence:
occurrence[n] *= 0.997 #### decay repetition penalty
if token >= 128 or token == 127:
occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)
else:
occurrence[token] = 0.3 + (occurrence[token] if token in occurrence else 0)
def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int):
for n in occurrence:
logits[n] -= 0 + occurrence[n] * 0.5
logits[0] += (i - 2000) / 500 # try not to be too short or too long
logits[127] -= 1 # avoid "t125"
def fix_tokens(self, tokens) -> List[int]:
return tokens
def run_rnn(
self, _tokens: List[str], newline_adj: int = 0
) -> Tuple[List[float], int]:
tokens = [int(x) for x in _tokens]
token_len = len(tokens)
self.model_tokens += tokens
out, self.model_state = self.model.forward(tokens, self.model_state)
return out, token_len
def delta_postprocess(self, delta: str) -> str:
return " " + delta
class ModelConfigBody(BaseModel):
max_tokens: int = Field(default=None, gt=0, le=102400)
temperature: float = Field(default=None, ge=0, le=2)
@@ -379,7 +494,7 @@ class ModelConfigBody(BaseModel):
}
def set_rwkv_config(model: RWKV, body: ModelConfigBody):
def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody):
if body.max_tokens is not None:
model.max_tokens_per_generation = body.max_tokens
if body.temperature is not None:
@@ -395,7 +510,7 @@ def set_rwkv_config(model: RWKV, body: ModelConfigBody):
model.penalty_alpha_frequency = body.frequency_penalty
def get_rwkv_config(model: RWKV) -> ModelConfigBody:
def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
return ModelConfigBody(
max_tokens=model.max_tokens_per_generation,
temperature=model.temperature,

BIN
build/appicon.png vendored

Binary file not shown.

Before

Width:  |  Height:  |  Size: 120 KiB

After

Width:  |  Height:  |  Size: 83 KiB

BIN
build/windows/icon.ico vendored

Binary file not shown.

Before

Width:  |  Height:  |  Size: 167 KiB

After

Width:  |  Height:  |  Size: 175 KiB

View File

@@ -8,26 +8,32 @@ if [[ ${cnMirror} == 1 ]]; then
fi
fi
if dpkg -s "gcc" >/dev/null 2>&1; then
echo "gcc installed"
else
sudo apt -y install gcc
fi
if dpkg -s "python3-pip" >/dev/null 2>&1; then
echo "pip installed"
else
sudo apt install python3-pip
sudo apt -y install python3-pip
fi
if dpkg -s "ninja-build" >/dev/null 2>&1; then
echo "ninja installed"
else
sudo apt install ninja-build
sudo apt -y install ninja-build
fi
if dpkg -s "cuda" >/dev/null 2>&1; then
echo "cuda installed"
if dpkg -s "cuda" >/dev/null 2>&1 && dpkg -s "cuda" | grep Version | awk '{print $2}' | grep -q "12"; then
echo "cuda 12 installed"
else
wget https://developer.download.nvidia.com/compute/cuda/repos/wsl-ubuntu/x86_64/cuda-wsl-ubuntu.pin
wget -N https://developer.download.nvidia.com/compute/cuda/repos/wsl-ubuntu/x86_64/cuda-wsl-ubuntu.pin
sudo mv cuda-wsl-ubuntu.pin /etc/apt/preferences.d/cuda-repository-pin-600
wget https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda-repo-wsl-ubuntu-11-7-local_11.7.0-1_amd64.deb
sudo dpkg -i cuda-repo-wsl-ubuntu-11-7-local_11.7.0-1_amd64.deb
sudo cp /var/cuda-repo-wsl-ubuntu-11-7-local/cuda-*-keyring.gpg /usr/share/keyrings/
wget -N https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda-repo-wsl-ubuntu-12-2-local_12.2.0-1_amd64.deb
sudo dpkg -i cuda-repo-wsl-ubuntu-12-2-local_12.2.0-1_amd64.deb
sudo cp /var/cuda-repo-wsl-ubuntu-12-2-local/cuda-*-keyring.gpg /usr/share/keyrings/
sudo apt-get update
sudo apt-get -y install cuda
fi

View File

@@ -17,11 +17,14 @@
"""Processing data for pretraining."""
import argparse
import multiprocessing
import os
import sys
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
import argparse
import multiprocessing
import lm_dataformat as lmd
import numpy as np
@@ -240,4 +243,8 @@ def main():
if __name__ == "__main__":
main()
try:
main()
except Exception as e:
with open("error.txt", "w") as f:
f.write(str(e))

View File

@@ -5,49 +5,64 @@ from typing import Dict
import typing
import torch
if '-h' in sys.argv or '--help' in sys.argv:
print(f'Usage: python3 {sys.argv[0]} [--use-gpu] <lora_alpha> <base_model.pth> <lora_checkpoint.pth> <output.pth>')
try:
if "-h" in sys.argv or "--help" in sys.argv:
print(
f"Usage: python3 {sys.argv[0]} [--use-gpu] <lora_alpha> <base_model.pth> <lora_checkpoint.pth> <output.pth>"
)
if sys.argv[1] == '--use-gpu':
device = 'cuda'
lora_alpha, base_model, lora, output = float(sys.argv[2]), sys.argv[3], sys.argv[4], sys.argv[5]
else:
device = 'cpu'
lora_alpha, base_model, lora, output = float(sys.argv[1]), sys.argv[2], sys.argv[3], sys.argv[4]
if sys.argv[1] == "--use-gpu":
device = "cuda"
lora_alpha, base_model, lora, output = (
float(sys.argv[2]),
sys.argv[3],
sys.argv[4],
sys.argv[5],
)
else:
device = "cpu"
lora_alpha, base_model, lora, output = (
float(sys.argv[1]),
sys.argv[2],
sys.argv[3],
sys.argv[4],
)
with torch.no_grad():
w: Dict[str, torch.Tensor] = torch.load(base_model, map_location="cpu")
# merge LoRA-only slim checkpoint into the main weights
w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location="cpu")
for k in w_lora.keys():
w[k] = w_lora[k]
output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict()
# merge LoRA weights
keys = list(w.keys())
for k in keys:
if k.endswith(".weight"):
prefix = k[: -len(".weight")]
lora_A = prefix + ".lora_A"
lora_B = prefix + ".lora_B"
if lora_A in keys:
assert lora_B in keys
print(f"merging {lora_A} and {lora_B} into {k}")
assert w[lora_B].shape[1] == w[lora_A].shape[0]
lora_r = w[lora_B].shape[1]
w[k] = w[k].to(device=device)
w[lora_A] = w[lora_A].to(device=device)
w[lora_B] = w[lora_B].to(device=device)
w[k] += w[lora_B] @ w[lora_A] * (lora_alpha / lora_r)
output_w[k] = w[k].to(device="cpu", copy=True)
del w[k]
del w[lora_A]
del w[lora_B]
continue
with torch.no_grad():
w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu')
# merge LoRA-only slim checkpoint into the main weights
w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu')
for k in w_lora.keys():
w[k] = w_lora[k]
output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict()
# merge LoRA weights
keys = list(w.keys())
for k in keys:
if k.endswith('.weight'):
prefix = k[:-len('.weight')]
lora_A = prefix + '.lora_A'
lora_B = prefix + '.lora_B'
if lora_A in keys:
assert lora_B in keys
print(f'merging {lora_A} and {lora_B} into {k}')
assert w[lora_B].shape[1] == w[lora_A].shape[0]
lora_r = w[lora_B].shape[1]
w[k] = w[k].to(device=device)
w[lora_A] = w[lora_A].to(device=device)
w[lora_B] = w[lora_B].to(device=device)
w[k] += w[lora_B] @ w[lora_A] * (lora_alpha / lora_r)
output_w[k] = w[k].to(device='cpu', copy=True)
if "lora" not in k:
print(f"retaining {k}")
output_w[k] = w[k].clone()
del w[k]
del w[lora_A]
del w[lora_B]
continue
if 'lora' not in k:
print(f'retaining {k}')
output_w[k] = w[k].clone()
del w[k]
torch.save(output_w, output)
torch.save(output_w, output)
except Exception as e:
with open("error.txt", "w") as f:
f.write(str(e))

203
finetune/lora/train.py vendored
View File

@@ -50,52 +50,84 @@ if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--load_model", default="", type=str) # full path, with .pth
parser.add_argument("--wandb", default="", type=str) # wandb project name. if "" then don't use wandb
parser.add_argument(
"--wandb", default="", type=str
) # wandb project name. if "" then don't use wandb
parser.add_argument("--proj_dir", default="out", type=str)
parser.add_argument("--random_seed", default="-1", type=int)
parser.add_argument("--data_file", default="", type=str)
parser.add_argument("--data_type", default="utf-8", type=str)
parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data)
parser.add_argument(
"--vocab_size", default=0, type=int
) # vocab_size = 0 means auto (for char-level LM and .txt data)
parser.add_argument("--ctx_len", default=1024, type=int)
parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has [epoch_steps] steps
parser.add_argument("--epoch_count", default=500, type=int) # train for this many "epochs". will continue afterwards with lr = lr_final
parser.add_argument("--epoch_begin", default=0, type=int) # if you load a model trained for x "epochs", set epoch_begin = x
parser.add_argument("--epoch_save", default=5, type=int) # save the model every [epoch_save] "epochs"
parser.add_argument(
"--epoch_steps", default=1000, type=int
) # a mini "epoch" has [epoch_steps] steps
parser.add_argument(
"--epoch_count", default=500, type=int
) # train for this many "epochs". will continue afterwards with lr = lr_final
parser.add_argument(
"--epoch_begin", default=0, type=int
) # if you load a model trained for x "epochs", set epoch_begin = x
parser.add_argument(
"--epoch_save", default=5, type=int
) # save the model every [epoch_save] "epochs"
parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU)
parser.add_argument(
"--micro_bsz", default=12, type=int
) # micro batch size (batch size per GPU)
parser.add_argument("--n_layer", default=6, type=int)
parser.add_argument("--n_embd", default=512, type=int)
parser.add_argument("--dim_att", default=0, type=int)
parser.add_argument("--dim_ffn", default=0, type=int)
parser.add_argument("--pre_ffn", default=0, type=int) # replace first att layer by ffn (sometimes better)
parser.add_argument(
"--pre_ffn", default=0, type=int
) # replace first att layer by ffn (sometimes better)
parser.add_argument("--head_qk", default=0, type=int) # my headQK trick
parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim
parser.add_argument("--tiny_att_layer", default=-999, type=int) # tiny attention @ which layer
parser.add_argument(
"--tiny_att_layer", default=-999, type=int
) # tiny attention @ which layer
parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
parser.add_argument(
"--lr_init", default=6e-4, type=float
) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
parser.add_argument("--lr_final", default=1e-5, type=float)
parser.add_argument("--warmup_steps", default=0, type=int) # try 50 if you load a model
parser.add_argument(
"--warmup_steps", default=0, type=int
) # try 50 if you load a model
parser.add_argument("--beta1", default=0.9, type=float)
parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence
parser.add_argument(
"--beta2", default=0.99, type=float
) # use 0.999 when your model is close to convergence
parser.add_argument("--adam_eps", default=1e-8, type=float)
parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
parser.add_argument(
"--grad_cp", default=0, type=int
) # gradient checkpt: saves VRAM, but slower
parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode
parser.add_argument("--my_pile_shift", default=-1, type=int) # my special pile mode - text shift
parser.add_argument(
"--my_pile_shift", default=-1, type=int
) # my special pile mode - text shift
parser.add_argument("--my_pile_edecay", default=0, type=int)
parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s)
parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough
parser.add_argument(
"--layerwise_lr", default=1, type=int
) # layerwise lr for faster convergence (but slower it/s)
parser.add_argument(
"--ds_bucket_mb", default=200, type=int
) # deepspeed bucket size in MB. 200 seems enough
# parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful)
parser.add_argument("--my_img_version", default=0, type=str)
parser.add_argument("--my_img_size", default=0, type=int)
parser.add_argument("--my_img_bit", default=0, type=int)
parser.add_argument("--my_img_clip", default='x', type=str)
parser.add_argument("--my_img_clip", default="x", type=str)
parser.add_argument("--my_img_clip_scale", default=1, type=float)
parser.add_argument("--my_img_l1_scale", default=0, type=float)
parser.add_argument("--my_img_encoder", default='x', type=str)
parser.add_argument("--my_img_encoder", default="x", type=str)
# parser.add_argument("--my_img_noise_scale", default=0, type=float)
parser.add_argument("--my_sample_len", default=0, type=int)
parser.add_argument("--my_ffn_shift", default=1, type=int)
@@ -104,7 +136,7 @@ if __name__ == "__main__":
parser.add_argument("--load_partial", default=0, type=int)
parser.add_argument("--magic_prime", default=0, type=int)
parser.add_argument("--my_qa_mask", default=0, type=int)
parser.add_argument("--my_testing", default='', type=str)
parser.add_argument("--my_testing", default="", type=str)
parser.add_argument("--lora", action="store_true")
parser.add_argument("--lora_load", default="", type=str)
@@ -122,18 +154,26 @@ if __name__ == "__main__":
import numpy as np
import torch
from torch.utils.data import DataLoader
if "deepspeed" in args.strategy:
import deepspeed
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
if args.random_seed >= 0:
print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3)
print(
f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n"
* 3
)
seed_everything(args.random_seed)
np.set_printoptions(precision=4, suppress=True, linewidth=200)
warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
warnings.filterwarnings(
"ignore", ".*Consider increasing the value of the `num_workers` argument*"
)
warnings.filterwarnings(
"ignore", ".*The progress bar already tracks a metric with the*"
)
# os.environ["WDS_SHOW_SEED"] = "1"
args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
@@ -158,7 +198,9 @@ if __name__ == "__main__":
args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"
args.proj_dir = f"{args.proj_dir}-{args.run_name}"
else:
args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
args.run_name = (
f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
)
if not os.path.exists(args.proj_dir):
os.makedirs(args.proj_dir)
@@ -240,24 +282,40 @@ if __name__ == "__main__":
)
rank_zero_info(str(vars(args)) + "\n")
assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"]
assert args.data_type in [
"utf-8",
"utf-16le",
"numpy",
"binidx",
"dummy",
"wds_img",
"uint16",
]
if args.lr_final == 0 or args.lr_init == 0:
rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")
rank_zero_info(
"\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n"
)
assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
os.environ["RWKV_FLOAT_MODE"] = args.precision
if args.precision == "fp32":
for i in range(10):
rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n")
rank_zero_info(
"\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n"
)
if args.precision == "fp16":
rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")
rank_zero_info(
"\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n"
)
os.environ["RWKV_JIT_ON"] = "1"
if "deepspeed_stage_3" in args.strategy:
os.environ["RWKV_JIT_ON"] = "0"
if args.lora and args.grad_cp == 1:
print('!!!!! LoRA Warning: Gradient Checkpointing requires JIT off, disabling it')
print(
"!!!!! LoRA Warning: Gradient Checkpointing requires JIT off, disabling it"
)
os.environ["RWKV_JIT_ON"] = "0"
torch.backends.cudnn.benchmark = True
@@ -284,20 +342,22 @@ if __name__ == "__main__":
train_data = MyDataset(args)
args.vocab_size = train_data.vocab_size
if args.data_type == 'wds_img':
if args.data_type == "wds_img":
from src.model_img import RWKV_IMG
assert args.lora, "LoRA not yet supported for RWKV_IMG"
model = RWKV_IMG(args)
else:
from src.model import RWKV, LORA_CONFIG, LoraLinear
if args.lora:
assert args.lora_r > 0, "LoRA should have its `r` > 0"
LORA_CONFIG["r"] = args.lora_r
LORA_CONFIG["alpha"] = args.lora_alpha
LORA_CONFIG["dropout"] = args.lora_dropout
LORA_CONFIG["parts"] = set(str(args.lora_parts).split(','))
enable_time_finetune = 'time' in LORA_CONFIG["parts"]
enable_ln_finetune = 'ln' in LORA_CONFIG["parts"]
LORA_CONFIG["parts"] = set(str(args.lora_parts).split(","))
enable_time_finetune = "time" in LORA_CONFIG["parts"]
enable_ln_finetune = "ln" in LORA_CONFIG["parts"]
model = RWKV(args)
# only train lora parameters
if args.lora:
@@ -305,20 +365,24 @@ if __name__ == "__main__":
for name, module in model.named_modules():
# have to check param name since it may have been wrapped by torchscript
if any(n.startswith("lora_") for n, _ in module.named_parameters()):
print(f' LoRA training module {name}')
print(f" LoRA training module {name}")
for pname, param in module.named_parameters():
param.requires_grad = 'lora_' in pname
elif enable_ln_finetune and '.ln' in name:
print(f' LoRA additionally training module {name}')
param.requires_grad = "lora_" in pname
elif enable_ln_finetune and ".ln" in name:
print(f" LoRA additionally training module {name}")
for param in module.parameters():
param.requires_grad = True
elif enable_time_finetune and any(n.startswith("time") for n, _ in module.named_parameters()):
elif enable_time_finetune and any(
n.startswith("time") for n, _ in module.named_parameters()
):
for pname, param in module.named_parameters():
if pname.startswith("time"):
print(f' LoRA additionally training parameter {pname}')
print(f" LoRA additionally training parameter {pname}")
param.requires_grad = True
if len(args.load_model) == 0 or args.my_pile_stage == 1: # shall we build the initial weights?
if (
len(args.load_model) == 0 or args.my_pile_stage == 1
): # shall we build the initial weights?
init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
generate_init_weight(model, init_weight_name) # save initial weights
args.load_model = init_weight_name
@@ -326,6 +390,7 @@ if __name__ == "__main__":
rank_zero_info(f"########## Loading {args.load_model}... ##########")
try:
load_dict = torch.load(args.load_model, map_location="cpu")
model.load_state_dict(load_dict, strict=(not args.lora))
except:
rank_zero_info(f"Bad checkpoint {args.load_model}")
if args.my_pile_stage >= 2: # try again using another checkpoint
@@ -337,36 +402,50 @@ if __name__ == "__main__":
args.epoch_begin = max_p + 1
rank_zero_info(f"Trying {args.load_model}")
load_dict = torch.load(args.load_model, map_location="cpu")
model.load_state_dict(load_dict, strict=(not args.lora))
if args.load_partial == 1:
load_keys = load_dict.keys()
for k in model.state_dict():
if k not in load_keys:
load_dict[k] = model.state_dict()[k]
model.load_state_dict(load_dict, strict=(not args.lora))
# If using LoRA, the LoRA keys might be missing in the original model
model.load_state_dict(load_dict, strict=(not args.lora))
# model.load_state_dict(load_dict, strict=(not args.lora))
if os.path.isfile(args.lora_load):
model.load_state_dict(torch.load(args.lora_load, map_location="cpu"),
strict=False)
model.load_state_dict(
torch.load(args.lora_load, map_location="cpu"), strict=False
)
trainer: Trainer = Trainer.from_argparse_args(
args,
callbacks=[train_callback(args)],
)
if (args.lr_init > 1e-4 or trainer.world_size * args.micro_bsz * trainer.accumulate_grad_batches < 8):
if 'I_KNOW_WHAT_IM_DOING' in os.environ:
if (
args.lr_init > 1e-4
or trainer.world_size * args.micro_bsz * trainer.accumulate_grad_batches < 8
):
if "I_KNOW_WHAT_IM_DOING" in os.environ:
if trainer.global_rank == 0:
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
print(f' WARNING: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)')
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
print(
f" WARNING: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)"
)
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
else:
if trainer.global_rank == 0:
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
print(f' ERROR: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)')
print(f' Unless you are sure this is what you want, adjust them accordingly')
print(f' (to suppress this, set environment variable "I_KNOW_WHAT_IM_DOING")')
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
print(
f" ERROR: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)"
)
print(
f" Unless you are sure this is what you want, adjust them accordingly"
)
print(
f' (to suppress this, set environment variable "I_KNOW_WHAT_IM_DOING")'
)
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
exit(0)
if trainer.global_rank == 0:
@@ -379,10 +458,22 @@ if __name__ == "__main__":
print(f"{str(shape[0]).ljust(5)} {n}")
if "deepspeed" in args.strategy:
trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = (
args.ds_bucket_mb * 1000 * 1000
)
trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = (
args.ds_bucket_mb * 1000 * 1000
)
# must set shuffle=False, persistent_workers=False (because worker is in another thread)
data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
data_loader = DataLoader(
train_data,
shuffle=False,
pin_memory=True,
batch_size=args.micro_bsz,
num_workers=1,
persistent_workers=False,
drop_last=True,
)
trainer.fit(model, data_loader)

File diff suppressed because it is too large Load Diff

View File

@@ -11,11 +11,13 @@
"dependencies": {
"@fluentui/react-components": "^9.20.0",
"@fluentui/react-icons": "^2.0.201",
"@magenta/music": "^1.23.1",
"@microsoft/fetch-event-source": "^2.0.1",
"@primer/octicons-react": "^19.1.0",
"chart.js": "^4.3.0",
"classnames": "^2.3.2",
"github-markdown-css": "^5.2.0",
"html-midi-player": "^1.5.0",
"i18next": "^22.4.15",
"mobx": "^6.9.0",
"mobx-react-lite": "^3.4.3",

View File

@@ -104,7 +104,7 @@
"Supported custom cuda file not found": "没有找到支持的自定义cuda文件",
"Failed to copy custom cuda file": "自定义cuda文件复制失败",
"Downloading update, please wait. If it is not completed, please manually download the program from GitHub and replace the original program.": "正在下载更新请等待。如果一直未完成请从Github手动下载并覆盖原程序",
"Completion": "补全",
"Completion": "续写",
"Parameters": "参数",
"Stop Sequences": "停止词",
"When this content appears in the response result, the generation will end.": "响应结果出现该内容时就结束生成",
@@ -113,17 +113,17 @@
"Writer": "写作",
"Translator": "翻译",
"Catgirl": "猫娘",
"Explain Code": "代码解释",
"Code Generation": "代码生成",
"Werewolf": "狼人杀",
"Instruction": "指令",
"Blank": "空白",
"The following is an epic science fiction masterpiece that is immortalized, with delicate descriptions and grand depictions of interstellar civilization wars.\nChapter 1.\n": "《背影》\n我与父亲不相见已二年余了我最不能忘记的是他的背影。\n那年冬天祖母死了父亲的差使也交卸了正是祸不单行的日子。我从北京到徐州打算",
"The following is a conversation between a cat girl and her owner. The cat girl is a humanized creature that behaves like a cat but is humanoid. At the end of each sentence in the dialogue, she will add \"Meow~\". In the following content, Bob represents the owner and Alice represents the cat girl.\n\nBob: Hello.\n\nAlice: I'm here, meow~.\n\nBob: Can you tell jokes?": "以下是一位猫娘的主人和猫娘的对话内容,猫娘是一种拟人化的生物,其行为似猫但类人,在每一句对话末尾都会加上\"喵~\"。以下内容中,Bob代表主人Alice代表猫娘。\n\nBob: 你好\n\nAlice: 主人我在哦,喵~\n\nBob: 你会讲笑话吗?",
"The following is a conversation between a cat girl and her owner. The cat girl is a humanized creature that behaves like a cat but is humanoid. At the end of each sentence in the dialogue, she will add \"Meow~\". In the following content, User represents the owner and Assistant represents the cat girl.\n\nUser: Hello.\n\nAssistant: I'm here, meow~.\n\nUser: Can you tell jokes?": "以下是一位猫娘的主人和猫娘的对话内容,猫娘是一种拟人化的生物,其行为似猫但类人,在每一句对话末尾都会加上\"喵~\"。以下内容中,User代表主人Assistant代表猫娘。\n\nUser: 你好\n\nAssistant: 主人我在哦,喵~\n\nUser: 你会讲笑话吗?",
"When response finished, inject this content.": "响应结束时,插入此内容到末尾",
"Inject start text": "起始注入文本",
"Inject end text": "结尾注入文本",
"Before the response starts, inject this content.": "响应开始前,在开头插入此内容",
"There is currently a game of Werewolf with six players, including a Seer (who can check identities at night), two Werewolves (who can choose someone to kill at night), a Bodyguard (who can choose someone to protect at night), two Villagers (with no special abilities), and a game host. Bob will play as Player 1, Alice will play as Players 2-6 and the game host, and they will begin playing together. Every night, the host will ask Bob for his action and simulate the actions of the other players. During the day, the host will oversee the voting process and ask Bob for his vote. \n\nAlice: Next, I will act as the game host and assign everyone their roles, including randomly assigning yours. Then, I will simulate the actions of Players 2-6 and let you know what happens each day. Based on your assigned role, you can tell me your actions and I will let you know the corresponding results each day.\n\nBob: Okay, I understand. Let's begin. Please assign me a role. Am I the Seer, Werewolf, Villager, or Bodyguard?\n\nAlice: You are the Seer. Now that night has fallen, please choose a player to check his identity.\n\nBob: Tonight, I want to check Player 2 and find out his role.": "现在有一场六人狼人杀游戏,包括一名预言家(可以在夜晚查验身份),两名狼人(可以在夜晚选择杀人),一名守卫(可以在夜晚选择要守护的人),两名平民(无技能),一名主持人,以下内容中Bob将扮演其中的1号玩家Alice来扮演2-6号玩家以及主持人并开始与Bob进行游戏,主持人每晚都会询问Bob的行动,并模拟其他人的行动,在白天则要主持投票,并同样询问Bob投票对象,公布投票结果。\n\nAlice: 接下来我将首先作为主持人进行角色分配并给你赋予随机的角色之后我将模拟2-6号玩家进行行动告知你每天的动态根据你被分配的角色你可以回复我你做的行动我会告诉你每天对应的结果\n\nBob: 好的,我明白了,那么开始吧。请先给我一个角色身份。我是预言家,狼人,平民,守卫中的哪一个呢?\n\nAlice: 你的身份是预言家。现在夜晚降临,请选择你要查验的玩家。\n\nBob: 今晚我要验2号玩家他是什么身份",
"There is currently a game of Werewolf with six players, including a Seer (who can check identities at night), two Werewolves (who can choose someone to kill at night), a Bodyguard (who can choose someone to protect at night), two Villagers (with no special abilities), and a game host. User will play as Player 1, Assistant will play as Players 2-6 and the game host, and they will begin playing together. Every night, the host will ask User for his action and simulate the actions of the other players. During the day, the host will oversee the voting process and ask User for his vote. \n\nAssistant: Next, I will act as the game host and assign everyone their roles, including randomly assigning yours. Then, I will simulate the actions of Players 2-6 and let you know what happens each day. Based on your assigned role, you can tell me your actions and I will let you know the corresponding results each day.\n\nUser: Okay, I understand. Let's begin. Please assign me a role. Am I the Seer, Werewolf, Villager, or Bodyguard?\n\nAssistant: You are the Seer. Now that night has fallen, please choose a player to check his identity.\n\nUser: Tonight, I want to check Player 2 and find out his role.": "现在有一场六人狼人杀游戏,包括一名预言家(可以在夜晚查验身份),两名狼人(可以在夜晚选择杀人),一名守卫(可以在夜晚选择要守护的人),两名平民(无技能),一名主持人,以下内容中User将扮演其中的1号玩家Assistant来扮演2-6号玩家以及主持人并开始与User进行游戏,主持人每晚都会询问User的行动,并模拟其他人的行动,在白天则要主持投票,并同样询问User投票对象,公布投票结果。\n\nAssistant: 接下来我将首先作为主持人进行角色分配并给你赋予随机的角色之后我将模拟2-6号玩家进行行动告知你每天的动态根据你被分配的角色你可以回复我你做的行动我会告诉你每天对应的结果\n\nUser: 好的,我明白了,那么开始吧。请先给我一个角色身份。我是预言家,狼人,平民,守卫中的哪一个呢?\n\nAssistant: 你的身份是预言家。现在夜晚降临,请选择你要查验的玩家。\n\nUser: 今晚我要验2号玩家他是什么身份",
"Writer, Translator, Role-playing": "写作,翻译,角色扮演",
"Chinese Kongfu": "情境冒险",
"Allow external access to the API (service must be restarted)": "允许外部访问API (必须重启服务)",
@@ -153,7 +153,7 @@
"Restart the app to apply DPI Scaling.": "重启应用以使显示缩放生效",
"Restart": "重启",
"API Chat Model Name": "API聊天模型名",
"API Completion Model Name": "API补全模型名",
"API Completion Model Name": "API续写模型名",
"Localhost": "本地",
"Retry": "重试",
"Delete": "删除",
@@ -195,7 +195,7 @@
"Please convert data first.": "请先转换数据",
"Ubuntu is not installed, do you want to install it?": "Ubuntu未安装是否安装",
"Install Ubuntu": "安装Ubuntu",
"Please install Ubuntu using Microsoft Store": "请用Microsoft Store安装Ubuntu",
"Please install Ubuntu using Microsoft Store, after installation click the Open button in Microsoft Store and then click the Train button": "请用Microsoft Store安装Ubuntu安装完成后点击Microsoft Store界面的“打开”按钮然后点击“训练”按钮",
"WSL is not enabled, do you want to enable it?": "WSL未启用是否启用",
"Enable WSL": "启用WSL",
"After installation, please restart your computer to enable WSL": "安装完成后请重启电脑以启用WSL",
@@ -221,5 +221,24 @@
"Pre-FFN": "前馈网络预处理",
"None": "空",
"Merge model successfully": "合并模型成功",
"Convert Data successfully": "数据转换成功"
"Convert Data successfully": "数据转换成功",
"Please select a LoRA model": "请选择一个LoRA模型",
"You are using sample data for training. For formal training, please make sure to create your own jsonl file.": "你正在使用示例数据训练对于正式训练场合请务必创建你自己的jsonl训练数据",
"WSL is not running, please retry. If it keeps happening, it means you may be using an outdated version of WSL, run \"wsl --update\" to update.": "WSL没有运行请重试。如果一直出现此错误意味着你可能正在使用旧版本的WSL请在cmd执行\"wsl --update\"以更新",
"Memory is not enough, try to increase the virtual memory or use a smaller base model.": "内存不足,尝试增加虚拟内存,或使用一个更小规模的基底模型",
"VRAM is not enough": "显存不足",
"Training data is not enough, reduce context length or add more data for training": "训练数据不足,请减小上下文长度或增加训练数据",
"You are using WSL 1 for training, please upgrade to WSL 2. e.g. Run \"wsl --set-version Ubuntu-22.04 2\"": "你正在使用WSL 1进行训练请升级到WSL 2。例如运行\"wsl --set-version Ubuntu-22.04 2\"",
"Matched CUDA is not installed": "未安装匹配的CUDA",
"Failed to convert data": "数据转换失败",
"Failed to merge model": "合并模型失败",
"The data path should be a directory or a file in jsonl format (more formats will be supported in the future).\n\nWhen you provide a directory path, all the txt files within that directory will be automatically converted into training data. This is commonly used for large-scale training in writing, code generation, or knowledge bases.\n\nThe jsonl format file can be referenced at https://github.com/Abel2076/json2binidx_tool/blob/main/sample.jsonl.\nYou can also write it similar to OpenAI's playground format, as shown in https://platform.openai.com/playground/p/default-chat.\nEven for multi-turn conversations, they must be written in a single line using `\\n` to indicate line breaks. If they are different dialogues or topics, they should be written in separate lines.": "数据路径必须是一个文件夹或者jsonl格式文件 (未来会支持更多格式)\n\n当你填写的路径是一个文件夹时该文件夹内的所有txt文件会被自动转换为训练数据通常这用于大批量训练写作代码生成或知识库\n\njsonl文件的格式参考 https://github.com/Abel2076/json2binidx_tool/blob/main/sample.jsonl\n你也可以仿照openai的playground编写参考 https://platform.openai.com/playground/p/default-chat\n即使是多轮对话也必须写在一行用`\\n`表示换行,如果是不同对话或主题,则另起一行",
"Size mismatch for blocks. You are attempting to continue training from the LoRA model, but it does not match the base model. Please set LoRA model to None.": "尺寸不匹配块。你正在尝试从LoRA模型继续训练但该LoRA模型与基底模型不匹配请将LoRA模型设为空",
"Instruction: Write a story using the following information\n\nInput: A man named Alex chops a tree down\n\nResponse:": "Instruction: Write a story using the following information\n\nInput: 艾利克斯砍倒了一棵树\n\nResponse:",
"Composition": "作曲",
"Use Local Sound Font": "使用本地音色资源",
"Auto Play At The End": "结束时自动播放",
"No File to save": "无文件可保存",
"File Saved": "文件已保存",
"Failed to load local sound font, please check if the files exist - assets/sound-font": "加载本地音色资源失败,请检查文件是否存在 - assets/sound-font"
}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 36 KiB

View File

@@ -11,6 +11,7 @@ import {
} from '@fluentui/react-components';
import { ToolTipButton } from './ToolTipButton';
import { useTranslation } from 'react-i18next';
import MarkdownRender from './MarkdownRender';
export const DialogButton: FC<{
text?: string | null
@@ -19,12 +20,13 @@ export const DialogButton: FC<{
className?: string,
title: string,
contentText: string,
onConfirm: () => void,
markdown?: boolean,
onConfirm?: () => void,
size?: 'small' | 'medium' | 'large',
shape?: 'rounded' | 'circular' | 'square',
appearance?: 'secondary' | 'primary' | 'outline' | 'subtle' | 'transparent',
}> = ({
text, icon, tooltip, className, title, contentText,
text, icon, tooltip, className, title, contentText, markdown,
onConfirm, size, shape, appearance
}) => {
const { t } = useTranslation();
@@ -41,7 +43,11 @@ export const DialogButton: FC<{
<DialogBody>
<DialogTitle>{title}</DialogTitle>
<DialogContent>
{contentText}
{
markdown ?
<MarkdownRender>{contentText}</MarkdownRender> :
contentText
}
</DialogContent>
<DialogActions>
<DialogTrigger disableButtonEnhancement>

View File

@@ -4,7 +4,7 @@ import { useTranslation } from 'react-i18next';
import { ArrowReset20Regular } from '@fluentui/react-icons';
import commonStore from '../stores/commonStore';
import { defaultModelConfigs, defaultModelConfigsMac } from '../pages/defaultModelConfigs';
import { defaultModelConfigs, defaultModelConfigsMac } from '../pages/defaultConfigs';
export const ResetConfigsButton: FC<{ afterConfirm?: () => void }> = ({ afterConfirm }) => {
const { t } = useTranslation();

View File

@@ -1,23 +1,16 @@
import React, { FC, MouseEventHandler, ReactElement } from 'react';
import commonStore, { ModelStatus } from '../stores/commonStore';
import {
AddToDownloadList,
CopyFile,
DepCheck,
FileExists,
InstallPyDep,
StartServer
} from '../../wailsjs/go/backend_golang/App';
import { AddToDownloadList, CopyFile, FileExists, StartServer } from '../../wailsjs/go/backend_golang/App';
import { Button } from '@fluentui/react-components';
import { observer } from 'mobx-react-lite';
import { exit, getStatus, readRoot, switchModel, updateConfig } from '../apis';
import { toast } from 'react-toastify';
import { getStrategy, getSupportedCustomCudaFile, toastWithButton } from '../utils';
import { checkDependencies, getStrategy, getSupportedCustomCudaFile, toastWithButton } from '../utils';
import { useTranslation } from 'react-i18next';
import { ToolTipButton } from './ToolTipButton';
import { Play16Regular, Stop16Regular } from '@fluentui/react-icons';
import { useNavigate } from 'react-router';
import { BrowserOpenURL, WindowShow } from '../../wailsjs/runtime/runtime';
import { WindowShow } from '../../wailsjs/runtime/runtime';
const mainButtonText = {
[ModelStatus.Offline]: 'Run',
@@ -57,52 +50,9 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
return;
}
if (!commonStore.depComplete) {
let depErrorMsg = '';
await DepCheck(commonStore.settings.customPythonPath).catch((e) => {
depErrorMsg = e.message || e;
WindowShow();
if (depErrorMsg === 'python zip not found') {
toastWithButton(t('Python target not found, would you like to download it?'), t('Download'), () => {
toastWithButton(`${t('Downloading')} Python`, t('Check'), () => {
navigate({ pathname: '/downloads' });
}, { autoClose: 3000 });
AddToDownloadList('python-3.10.11-embed-amd64.zip', 'https://www.python.org/ftp/python/3.10.11/python-3.10.11-embed-amd64.zip');
});
} else if (depErrorMsg.includes('DepCheck Error')) {
if (depErrorMsg.includes('vc_redist')) {
toastWithButton(t('Microsoft Visual C++ Redistributable is not installed, would you like to download it?'), t('Download'), () => {
BrowserOpenURL('https://aka.ms/vs/16/release/vc_redist.x64.exe');
});
} else {
toast(depErrorMsg, { type: 'info', position: 'bottom-left' });
if (commonStore.platform != 'linux')
toastWithButton(t('Python dependencies are incomplete, would you like to install them?'), t('Install'), () => {
InstallPyDep(commonStore.settings.customPythonPath, commonStore.settings.cnMirror).catch((e) => {
const errMsg = e.message || e;
toast(t('Error') + ' - ' + errMsg, { type: 'error' });
});
setTimeout(WindowShow, 1000);
}, {
autoClose: 8000
});
else
toastWithButton(t('On Linux system, you must manually install python dependencies.'), t('Check'), () => {
BrowserOpenURL('https://github.com/josStorer/RWKV-Runner/blob/master/build/linux/Readme_Install.txt');
});
}
} else {
toast(depErrorMsg, { type: 'error' });
}
});
if (depErrorMsg) {
commonStore.setStatus({ status: ModelStatus.Offline });
return;
}
commonStore.setDepComplete(true);
if (commonStore.platform === 'windows')
CopyFile('./backend-python/wkv_cuda_utils/wkv_cuda_model.py', './py310/Lib/site-packages/rwkv/model.py');
}
const ok = await checkDependencies(navigate);
if (!ok)
return;
const currentModelSource = commonStore.modelSourceList.find(item => item.name === modelName);
@@ -199,9 +149,16 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
}).then(async (r) => {
if (r.ok) {
commonStore.setStatus({ status: ModelStatus.Working });
toastWithButton(t('Startup Completed'), t('Chat'), () => {
navigate({ pathname: '/chat' });
}, { type: 'success', autoClose: 3000 });
let buttonNameMap = {
'novel': 'Completion',
'midi': 'Composition'
};
let buttonName = 'Chat';
buttonName = Object.entries(buttonNameMap).find(([key, value]) => modelName.toLowerCase().includes(key))?.[1] || buttonName;
const buttonFn = () => {
navigate({ pathname: '/' + buttonName.toLowerCase() });
};
toastWithButton(t('Startup Completed'), t(buttonName), buttonFn, { type: 'success', autoClose: 3000 });
} else if (r.status === 304) {
toast(t('Loading Model'), { type: 'info' });
} else {

View File

@@ -6,6 +6,7 @@ import App from './App';
import { HashRouter } from 'react-router-dom';
import { startup } from './startup';
import './_locales/i18n-react';
import 'html-midi-player';
import { WindowShow } from '../wailsjs/runtime';
startup().then(() => {

View File

@@ -7,7 +7,7 @@ import { v4 as uuid } from 'uuid';
import classnames from 'classnames';
import { fetchEventSource } from '@microsoft/fetch-event-source';
import { KebabHorizontalIcon, PencilIcon, SyncIcon, TrashIcon } from '@primer/octicons-react';
import logo from '../assets/images/logo.jpg';
import logo from '../assets/images/logo.png';
import MarkdownRender from '../components/MarkdownRender';
import { ToolTipButton } from '../components/ToolTipButton';
import { ArrowCircleUp28Regular, Delete28Regular, RecordStop28Regular, Save28Regular } from '@fluentui/react-icons';
@@ -184,7 +184,9 @@ const ChatPanel: FC = observer(() => {
const bodyRef = useRef<HTMLDivElement>(null);
const inputRef = useRef<HTMLTextAreaElement>(null);
const mq = useMediaQuery('(min-width: 640px)');
const port = commonStore.getCurrentModelConfig().apiParameters.apiPort;
const currentConfig = commonStore.getCurrentModelConfig();
const apiParams = currentConfig.apiParameters;
const port = apiParams.apiPort;
let lastMessageId: string;
let generating: boolean = false;
@@ -308,12 +310,14 @@ const ChatPanel: FC = observer(() => {
body: JSON.stringify({
messages,
stream: true,
model: commonStore.settings.apiChatModelName // 'gpt-3.5-turbo'
model: commonStore.settings.apiChatModelName, // 'gpt-3.5-turbo'
temperature: apiParams.temperature,
top_p: apiParams.topP
}),
signal: chatSseController?.signal,
onmessage(e) {
scrollToBottom();
if (e.data === '[DONE]') {
if (e.data.trim() === '[DONE]') {
commonStore.conversation[answerId!].done = true;
commonStore.conversation[answerId!].content = commonStore.conversation[answerId!].content.trim();
commonStore.setConversation(commonStore.conversation);
@@ -421,7 +425,7 @@ const ChatPanel: FC = observer(() => {
}
});
OpenSaveFileDialog('*.md', 'conversation.md', savedContent).then((path) => {
OpenSaveFileDialog('*.txt', 'conversation.txt', savedContent).then((path) => {
if (path)
toastWithButton(t('Conversation Saved'), t('Open'), () => {
OpenFileFolder(path, false);

View File

@@ -13,6 +13,7 @@ import { DialogButton } from '../components/DialogButton';
import { PresetsButton } from './PresetsManager/PresetsButton';
import { ToolTipButton } from '../components/ToolTipButton';
import { ArrowSync20Regular } from '@fluentui/react-icons';
import { defaultPresets } from './defaultConfigs';
export type CompletionParams = Omit<ApiParameters, 'apiPort'> & {
stop: string,
@@ -26,113 +27,6 @@ export type CompletionPreset = {
params: CompletionParams
}
export const defaultPresets: CompletionPreset[] = [{
name: 'Writer',
prompt: 'The following is an epic science fiction masterpiece that is immortalized, with delicate descriptions and grand depictions of interstellar civilization wars.\nChapter 1.\n',
params: {
maxResponseToken: 500,
temperature: 1.2,
topP: 0.5,
presencePenalty: 0.4,
frequencyPenalty: 0.4,
stop: '\\n\\nBob',
injectStart: '',
injectEnd: ''
}
}, {
name: 'Translator',
prompt: 'Translate this into Chinese.\n\nEnglish: What rooms do you have available?',
params: {
maxResponseToken: 500,
temperature: 1,
topP: 0.3,
presencePenalty: 0.4,
frequencyPenalty: 0.4,
stop: '\\nEnglish',
injectStart: '\\nChinese: ',
injectEnd: '\\nEnglish: '
}
}, {
name: 'Catgirl',
prompt: 'The following is a conversation between a cat girl and her owner. The cat girl is a humanized creature that behaves like a cat but is humanoid. At the end of each sentence in the dialogue, she will add \"Meow~\". In the following content, Bob represents the owner and Alice represents the cat girl.\n\nBob: Hello.\n\nAlice: I\'m here, meow~.\n\nBob: Can you tell jokes?',
params: {
maxResponseToken: 500,
temperature: 1.2,
topP: 0.5,
presencePenalty: 0.4,
frequencyPenalty: 0.4,
stop: '\\n\\nBob',
injectStart: '\\n\\nAlice: ',
injectEnd: '\\n\\nBob: '
}
}, {
name: 'Chinese Kongfu',
prompt: 'Bob: 请你扮演一个文本冒险游戏,我是游戏主角。这是一个玄幻修真世界,有四大门派。我输入我的行动,请你显示行动结果,并具体描述环境。我的第一个行动是“醒来”,请开始故事。',
params: {
maxResponseToken: 500,
temperature: 1.1,
topP: 0.7,
presencePenalty: 0.3,
frequencyPenalty: 0.3,
stop: '\\n\\nBob',
injectStart: '\\n\\nAlice: ',
injectEnd: '\\n\\nBob: '
}
}, {
// }, {
// name: 'Explain Code',
// prompt: 'export async function startup() {\n FileExists(\'cache.json\').then((exists) => {\n if (exists)\n downloadProgramFiles();\n else {\n deleteDynamicProgramFiles().then(downloadProgramFiles);\n }\n });\n EventsOn(\'downloadList\', (data) => {\n if (data)\n commonStore.setDownloadList(data);\n });\n\n initCache().then(initRemoteText);\n\n await initConfig();\n\n if (commonStore.settings.autoUpdatesCheck) // depends on config settings\n checkUpdate();\n\n getStatus(1000).then(status => { // depends on config api port\n if (status)\n commonStore.setStatus(status);\n });\n}\n\n\"\"\"\nHere\'s what the above code is doing, explained in a concise way:\n',
// params: {
// maxResponseToken: 500,
// temperature: 0.8,
// topP: 0.7,
// presencePenalty: 0.4,
// frequencyPenalty: 0.4,
// stop: '\\n\\n',
// injectStart: '',
// injectEnd: ''
// }
// }, {
name: 'Werewolf',
prompt: 'There is currently a game of Werewolf with six players, including a Seer (who can check identities at night), two Werewolves (who can choose someone to kill at night), a Bodyguard (who can choose someone to protect at night), two Villagers (with no special abilities), and a game host. Bob will play as Player 1, Alice will play as Players 2-6 and the game host, and they will begin playing together. Every night, the host will ask Bob for his action and simulate the actions of the other players. During the day, the host will oversee the voting process and ask Bob for his vote. \n\nAlice: Next, I will act as the game host and assign everyone their roles, including randomly assigning yours. Then, I will simulate the actions of Players 2-6 and let you know what happens each day. Based on your assigned role, you can tell me your actions and I will let you know the corresponding results each day.\n\nBob: Okay, I understand. Let\'s begin. Please assign me a role. Am I the Seer, Werewolf, Villager, or Bodyguard?\n\nAlice: You are the Seer. Now that night has fallen, please choose a player to check his identity.\n\nBob: Tonight, I want to check Player 2 and find out his role.',
params: {
maxResponseToken: 500,
temperature: 1.2,
topP: 0.4,
presencePenalty: 0.5,
frequencyPenalty: 0.5,
stop: '\\n\\nBob',
injectStart: '\\n\\nAlice: ',
injectEnd: '\\n\\nBob: '
}
}, {
name: 'Instruction',
prompt: 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n# Instruction:\nWrite a story using the following information\n\n# Input:\nA man named Alex chops a tree down\n\n# Response:\n',
params: {
maxResponseToken: 500,
temperature: 1.2,
topP: 0.5,
presencePenalty: 0.4,
frequencyPenalty: 0.4,
stop: '',
injectStart: '',
injectEnd: ''
}
}, {
name: 'Blank',
prompt: '',
params: {
maxResponseToken: 500,
temperature: 1,
topP: 0.5,
presencePenalty: 0.4,
frequencyPenalty: 0.4,
stop: '',
injectStart: '',
injectEnd: ''
}
}];
let completionSseController: AbortController | null = null;
const CompletionPanel: FC = observer(() => {
@@ -220,7 +114,7 @@ const CompletionPanel: FC = observer(() => {
signal: completionSseController?.signal,
onmessage(e) {
scrollToBottom();
if (e.data === '[DONE]') {
if (e.data.trim() === '[DONE]') {
commonStore.setCompletionGenerating(false);
return;
}
@@ -232,8 +126,8 @@ const CompletionPanel: FC = observer(() => {
return;
}
if (data.choices && Array.isArray(data.choices) && data.choices.length > 0) {
answer += data.choices[0].text;
setPrompt(prompt + answer.trim() + params.injectEnd.replaceAll('\\n', '\n'));
answer += data.choices[0]?.text || data.choices[0]?.delta?.content || '';
setPrompt(prompt + answer.replace(/\s+$/, '') + params.injectEnd.replaceAll('\\n', '\n'));
}
},
async onopen(response) {

View File

@@ -0,0 +1,345 @@
import React, { FC, useEffect, useRef } from 'react';
import { observer } from 'mobx-react-lite';
import { WorkHeader } from '../components/WorkHeader';
import { Button, Checkbox, Textarea } from '@fluentui/react-components';
import { Labeled } from '../components/Labeled';
import { ValuedSlider } from '../components/ValuedSlider';
import { useTranslation } from 'react-i18next';
import commonStore, { ModelStatus } from '../stores/commonStore';
import { fetchEventSource } from '@microsoft/fetch-event-source';
import { toast } from 'react-toastify';
import { DialogButton } from '../components/DialogButton';
import { ToolTipButton } from '../components/ToolTipButton';
import { ArrowSync20Regular, Save28Regular } from '@fluentui/react-icons';
import { PlayerElement, VisualizerElement } from 'html-midi-player';
import * as mm from '@magenta/music/esm/core.js';
import { NoteSequence } from '@magenta/music/esm/protobuf.js';
import { defaultCompositionPrompt } from './defaultConfigs';
import { FileExists, OpenFileFolder, OpenSaveFileDialogBytes } from '../../wailsjs/go/backend_golang/App';
import { toastWithButton } from '../utils';
export type CompositionParams = {
prompt: string,
maxResponseToken: number,
temperature: number,
topP: number,
autoPlay: boolean,
useLocalSoundFont: boolean,
midi: ArrayBuffer | null,
ns: NoteSequence | null
}
let compositionSseController: AbortController | null = null;
const CompositionPanel: FC = observer(() => {
const { t } = useTranslation();
const inputRef = useRef<HTMLTextAreaElement>(null);
const port = commonStore.getCurrentModelConfig().apiParameters.apiPort;
const visualizerRef = useRef<VisualizerElement>(null);
const playerRef = useRef<PlayerElement>(null);
const scrollToBottom = () => {
if (inputRef.current)
inputRef.current.scrollTop = inputRef.current.scrollHeight;
};
const params = commonStore.compositionParams;
const setParams = (newParams: Partial<CompositionParams>) => {
commonStore.setCompositionParams({
...commonStore.compositionParams,
...newParams
});
};
const setPrompt = (prompt: string) => {
setParams({
prompt
});
if (!commonStore.compositionGenerating)
generateNs(false);
};
const updateNs = (ns: NoteSequence | null) => {
if (playerRef.current) {
playerRef.current.noteSequence = ns;
playerRef.current.reload();
}
if (visualizerRef.current) {
visualizerRef.current.noteSequence = ns;
visualizerRef.current.reload();
}
};
const setSoundFont = async () => {
let soundUrl: string;
if (commonStore.compositionParams.useLocalSoundFont)
soundUrl = 'assets/sound-font';
else
soundUrl = !commonStore.settings.giteeUpdatesSource ?
`https://raw.githubusercontent.com/josStorer/sgm_plus/master` :
`https://gitee.com/josc146/sgm_plus/raw/master`;
const fallbackUrl = 'https://cdn.jsdelivr.net/gh/josstorer/sgm_plus';
await fetch(soundUrl + '/soundfont.json').then(r => {
if (!r.ok)
soundUrl = fallbackUrl;
}).catch(() => soundUrl = fallbackUrl);
if (playerRef.current) {
playerRef.current.soundFont = soundUrl;
}
};
useEffect(() => {
if (inputRef.current)
inputRef.current.style.height = '100%';
scrollToBottom();
if (playerRef.current && visualizerRef.current) {
playerRef.current.addVisualizer(visualizerRef.current);
playerRef.current.addEventListener('start', () => {
visualizerRef.current?.reload();
});
setSoundFont().then(() => {
updateNs(params.ns);
});
const button = playerRef.current.shadowRoot?.querySelector('.controls .play') as HTMLElement | null;
if (button)
button.style.background = '#f2f5f6';
}
}, []);
const generateNs = (autoPlay: boolean) => {
fetch(commonStore.settings.apiUrl ?
commonStore.settings.apiUrl + '/text-to-midi' :
`http://127.0.0.1:${port}/text-to-midi`, {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({
'text': commonStore.compositionParams.prompt.replaceAll(/<pad>|<start>|<end>/g, '').replaceAll(' ', ' ').trim()
})
}).then(r => {
r.arrayBuffer().then(midi => {
const ns = mm.midiToSequenceProto(midi);
setParams({
midi,
ns
});
updateNs(ns);
if (autoPlay) {
playerRef.current?.start();
}
});
});
};
const onSubmit = (prompt: string) => {
commonStore.setCompositionSubmittedPrompt(prompt);
if (commonStore.status.status === ModelStatus.Offline && !commonStore.settings.apiUrl) {
toast(t('Please click the button in the top right corner to start the model'), { type: 'warning' });
commonStore.setCompositionGenerating(false);
return;
}
let answer = '';
compositionSseController = new AbortController();
fetchEventSource( // https://api.openai.com/v1/completions || http://127.0.0.1:${port}/completions
commonStore.settings.apiUrl ?
commonStore.settings.apiUrl + '/v1/completions' :
`http://127.0.0.1:${port}/completions`,
{
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${commonStore.settings.apiKey}`
},
body: JSON.stringify({
prompt,
stream: true,
model: commonStore.settings.apiCompletionModelName, // 'text-davinci-003'
max_tokens: params.maxResponseToken,
temperature: params.temperature,
top_p: params.topP
}),
signal: compositionSseController?.signal,
onmessage(e) {
scrollToBottom();
if (e.data.trim() === '[DONE]') {
commonStore.setCompositionGenerating(false);
generateNs(commonStore.compositionParams.autoPlay);
return;
}
let data;
try {
data = JSON.parse(e.data);
} catch (error) {
console.debug('json error', error);
return;
}
if (data.choices && Array.isArray(data.choices) && data.choices.length > 0) {
answer += data.choices[0]?.text || data.choices[0]?.delta?.content || '';
setPrompt(prompt + answer.replace(/\s+$/, ''));
}
},
async onopen(response) {
if (response.status !== 200) {
toast(response.statusText + '\n' + (await response.text()), {
type: 'error'
});
}
},
onclose() {
console.log('Connection closed');
},
onerror(err) {
err = err.message || err;
if (err && !err.includes('ReadableStreamDefaultReader'))
toast(err, {
type: 'error'
});
commonStore.setCompositionGenerating(false);
throw err;
}
});
};
return (
<div className="flex flex-col gap-2 overflow-hidden grow">
<div className="flex flex-col sm:flex-row gap-2 overflow-hidden grow">
<Textarea
ref={inputRef}
className="grow"
value={params.prompt}
onChange={(e) => {
commonStore.setCompositionSubmittedPrompt(e.target.value);
setPrompt(e.target.value);
}}
/>
<div className="flex flex-col gap-1 max-h-48 sm:max-w-sm sm:max-h-full overflow-x-hidden overflow-y-auto p-1">
<Labeled flex breakline label={t('Max Response Token')}
desc={t('By default, the maximum number of tokens that can be answered in a single response, it can be changed by the user by specifying API parameters.')}
content={
<ValuedSlider value={params.maxResponseToken} min={100} max={4100}
step={100}
input
onChange={(e, data) => {
setParams({
maxResponseToken: data.value
});
}} />
} />
<Labeled flex breakline label={t('Temperature')}
desc={t('Sampling temperature, it\'s like giving alcohol to a model, the higher the stronger the randomness and creativity, while the lower, the more focused and deterministic it will be.')}
content={
<ValuedSlider value={params.temperature} min={0} max={2} step={0.1}
input
onChange={(e, data) => {
setParams({
temperature: data.value
});
}} />
} />
<Labeled flex breakline label={t('Top_P')}
desc={t('Just like feeding sedatives to the model. Consider the results of the top n% probability mass, 0.1 considers the top 10%, with higher quality but more conservative, 1 considers all results, with lower quality but more diverse.')}
content={
<ValuedSlider value={params.topP} min={0} max={1} step={0.1} input
onChange={(e, data) => {
setParams({
topP: data.value
});
}} />
} />
<div className="grow" />
<Checkbox className="select-none"
size="large" label={t('Use Local Sound Font')} checked={params.useLocalSoundFont}
onChange={async (_, data) => {
if (data.checked) {
if (!await FileExists('assets/sound-font/accordion/instrument.json')) {
toast(t('Failed to load local sound font, please check if the files exist - assets/sound-font'),
{ type: 'warning' });
return;
}
}
setParams({
useLocalSoundFont: data.checked as boolean
});
setSoundFont();
}} />
<Checkbox className="select-none"
size="large" label={t('Auto Play At The End')} checked={params.autoPlay} onChange={(_, data) => {
setParams({
autoPlay: data.checked as boolean
});
}} />
<div className="flex justify-between gap-2">
<ToolTipButton desc={t('Regenerate')} icon={<ArrowSync20Regular />} onClick={() => {
compositionSseController?.abort();
commonStore.setCompositionGenerating(true);
setPrompt(commonStore.compositionSubmittedPrompt);
onSubmit(commonStore.compositionSubmittedPrompt);
}} />
<DialogButton className="grow" text={t('Reset')} title={t('Reset')}
contentText={t('Are you sure you want to reset this page? It cannot be undone.')}
onConfirm={() => {
commonStore.setCompositionSubmittedPrompt(defaultCompositionPrompt);
setPrompt(defaultCompositionPrompt);
}} />
<Button className="grow" appearance="primary" onClick={() => {
if (commonStore.compositionGenerating) {
compositionSseController?.abort();
commonStore.setCompositionGenerating(false);
generateNs(params.autoPlay);
} else {
commonStore.setCompositionGenerating(true);
onSubmit(params.prompt);
}
}}>{!commonStore.compositionGenerating ? t('Generate') : t('Stop')}</Button>
</div>
</div>
</div>
<div className="flex flex-col">
<div className="ml-auto mr-auto">
<midi-visualizer
ref={visualizerRef}
type="waterfall"
/>
</div>
<div className="flex">
<midi-player
ref={playerRef}
style={{ width: '100%' }}
/>
<Button icon={<Save28Regular />}
onClick={() => {
if (params.midi) {
OpenSaveFileDialogBytes('*.mid', 'music.mid', Array.from(new Uint8Array(params.midi))).then((path) => {
if (path)
toastWithButton(t('File Saved'), t('Open'), () => {
OpenFileFolder(path, false);
});
}).catch((e: any) => {
toast(t('Error') + ' - ' + (e.message || e), { type: 'error', autoClose: 2500 });
});
} else {
toast(t('No File to save'), { type: 'warning', autoClose: 1500 });
}
}}
>
{t('Save')}
</Button>
</div>
</div>
</div>
);
});
export const Composition: FC = observer(() => {
return (
<div className="flex flex-col gap-1 p-2 h-full overflow-hidden">
<WorkHeader />
<CompositionPanel />
</div>
);
});

View File

@@ -13,8 +13,8 @@ import { Page } from '../components/Page';
import { useNavigate } from 'react-router';
import { RunButton } from '../components/RunButton';
import { updateConfig } from '../apis';
import { ConvertModel, FileExists } from '../../wailsjs/go/backend_golang/App';
import { getStrategy, refreshLocalModels } from '../utils';
import { ConvertModel, FileExists, GetPyError } from '../../wailsjs/go/backend_golang/App';
import { getStrategy } from '../utils';
import { useTranslation } from 'react-i18next';
import { WindowShow } from '../../wailsjs/runtime/runtime';
import strategyImg from '../assets/images/strategy.jpg';
@@ -253,9 +253,12 @@ export const Configs: FC = observer(() => {
const strategy = getStrategy(selectedConfig);
const newModelPath = modelPath + '-' + strategy.replace(/[:> *+]/g, '-');
toast(t('Start Converting'), { autoClose: 1000, type: 'info' });
ConvertModel(commonStore.settings.customPythonPath, modelPath, strategy, newModelPath).then(() => {
toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' });
refreshLocalModels({ models: commonStore.modelSourceList }, false);
ConvertModel(commonStore.settings.customPythonPath, modelPath, strategy, newModelPath).then(async () => {
if (!await FileExists(newModelPath + '.pth')) {
toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' });
} else {
toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' });
}
}).catch(e => {
const errMsg = e.message || e;
if (errMsg.includes('path contains space'))

View File

@@ -1,10 +1,10 @@
import React, { FC, useEffect } from 'react';
import React, { FC } from 'react';
import { useTranslation } from 'react-i18next';
import { Page } from '../components/Page';
import { observer } from 'mobx-react-lite';
import commonStore from '../stores/commonStore';
import { Divider, Field, ProgressBar } from '@fluentui/react-components';
import { bytesToGb, bytesToKb, bytesToMb, refreshLocalModels } from '../utils';
import { bytesToGb, bytesToKb, bytesToMb } from '../utils';
import { ToolTipButton } from '../components/ToolTipButton';
import { Folder20Regular, Pause20Regular, Play20Regular } from '@fluentui/react-icons';
import { AddToDownloadList, OpenFileFolder, PauseDownload } from '../../wailsjs/go/backend_golang/App';
@@ -23,12 +23,6 @@ export type DownloadStatus = {
export const Downloads: FC = observer(() => {
const { t } = useTranslation();
const finishedModelsLen = commonStore.downloadList.filter((status) => status.done && status.name.endsWith('.pth')).length;
useEffect(() => {
if (finishedModelsLen > 0)
refreshLocalModels({ models: commonStore.modelSourceList }, false);
console.log('finishedModelsLen:', finishedModelsLen);
}, [finishedModelsLen]);
let displayList = commonStore.downloadList.slice();
const downloadListNames = displayList.map(s => s.name);

View File

@@ -29,7 +29,7 @@ import { botName, Conversation, ConversationMessage, MessageType, userName } fro
import { SelectTabEventHandler } from '@fluentui/react-tabs';
import { Labeled } from '../../components/Labeled';
import commonStore from '../../stores/commonStore';
import logo from '../../assets/images/logo.jpg';
import logo from '../../assets/images/logo.png';
import { observer } from 'mobx-react-lite';
import { MessagesEditor } from './MessagesEditor';
import { ClipboardGetText, ClipboardSetText } from '../../../wailsjs/runtime';

View File

@@ -4,6 +4,7 @@ import { Button, Dropdown, Input, Option, Select, Switch, Tab, TabList } from '@
import {
ConvertData,
FileExists,
GetPyError,
MergeLora,
OpenFileFolder,
WslCommand,
@@ -17,7 +18,7 @@ import { toast } from 'react-toastify';
import commonStore from '../stores/commonStore';
import { observer } from 'mobx-react-lite';
import { SelectTabEventHandler } from '@fluentui/react-tabs';
import { refreshLocalModels, toastWithButton } from '../utils';
import { checkDependencies, toastWithButton } from '../utils';
import { Section } from '../components/Section';
import { Labeled } from '../components/Labeled';
import { ToolTipButton } from '../components/ToolTipButton';
@@ -36,6 +37,9 @@ import {
} from 'chart.js';
import { Line } from 'react-chartjs-2';
import { ChartJSOrUndefined } from 'react-chartjs-2/dist/types';
import { WindowShow } from '../../wailsjs/runtime';
import { t } from 'i18next';
import { DialogButton } from '../components/DialogButton';
ChartJS.register(
CategoryScale,
@@ -48,15 +52,16 @@ ChartJS.register(
);
const parseLossData = (data: string) => {
const regex = /Epoch (\d+):\s+(\d+%)\|[\s\S]*\| (\d+)\/(\d+) \[(\d+:\d+)<(\d+:\d+),\s+(\d+.\d+it\/s), loss=(\d+.\d+),[\s\S]*\]/g;
const regex = /Epoch (\d+):\s+(\d+%)\|[\s\S]*\| (\d+)\/(\d+) \[(\S+)<(\S+),\s+(\S+), loss=(\S+),[\s\S]*\]/g;
const matches = Array.from(data.matchAll(regex));
if (matches.length === 0)
return;
return false;
const lastMatch = matches[matches.length - 1];
const epoch = parseInt(lastMatch[1]);
const loss = parseFloat(lastMatch[8]);
commonStore.setChartTitle(`Epoch ${epoch}: ${lastMatch[2]} - ${lastMatch[3]}/${lastMatch[4]} - ${lastMatch[5]}/${lastMatch[6]} - ${lastMatch[7]} Loss=${loss}`);
addLossDataToChart(epoch, loss);
return true;
};
let chartLine: ChartJSOrUndefined<'line', (number | null)[], string>;
@@ -86,7 +91,7 @@ export type DataProcessParameters = {
vocabPath: string;
}
export type LoraFinetunePrecision = 'bf16' | 'fp16' | 'fp32' | 'tf32';
export type LoraFinetunePrecision = 'bf16' | 'fp16' | 'tf32';
export type LoraFinetuneParameters = {
baseModel: string;
@@ -139,10 +144,37 @@ const loraFinetuneParametersOptions: Array<[key: keyof LoraFinetuneParameters, t
['headQk', 'boolean', 'Head QK']
];
const showError = (e: any) => {
const msg = e.message || e;
if (msg === 'wsl not running') {
toast(t('WSL is not running, please retry. If it keeps happening, it means you may be using an outdated version of WSL, run "wsl --update" to update.'), { type: 'error' });
} else {
toast(t(msg), { type: 'error', toastId: 'train_error' });
}
};
const errorsMap = Object.entries({
'python3 ./finetune/lora/train.py': 'Memory is not enough, try to increase the virtual memory or use a smaller base model.',
'cuda out of memory': 'VRAM is not enough',
'valueerror: high <= 0': 'Training data is not enough, reduce context length or add more data for training',
'+= \'+ptx\'': 'You are using WSL 1 for training, please upgrade to WSL 2. e.g. Run "wsl --set-version Ubuntu-22.04 2"',
'size mismatch for blocks': 'Size mismatch for blocks. You are attempting to continue training from the LoRA model, but it does not match the base model. Please set LoRA model to None.',
'cuda_home environment variable is not set': 'Matched CUDA is not installed',
'unsupported gpu architecture': 'Matched CUDA is not installed',
'error building extension \'fused_adam\'': 'Matched CUDA is not installed'
});
export const wslHandler = (data: string) => {
if (data) {
addWslMessage(data);
parseLossData(data);
const ok = parseLossData(data);
if (!ok)
for (const [key, value] of errorsMap) {
if (data.toLowerCase().includes(key)) {
showError(value);
return;
}
}
}
};
@@ -187,12 +219,8 @@ const Terminal: FC = observer(() => {
WslStart().then(() => {
addWslMessage('WSL> ' + input);
setInput('');
WslCommand(input).catch((e) => {
toast((e.message || e), { type: 'error' });
});
}).catch((e) => {
toast((e.message || e), { type: 'error' });
});
WslCommand(input).catch(showError);
}).catch(showError);
}
};
@@ -207,9 +235,7 @@ const Terminal: FC = observer(() => {
<Button onClick={() => {
WslStop().then(() => {
toast(t('Command Stopped'), { type: 'success' });
}).catch((e) => {
toast((e.message || e), { type: 'error' });
});
}).catch(showError);
}}>
{t('Stop')}
</Button>
@@ -250,13 +276,30 @@ const LoraFinetune: FC = observer(() => {
});
}, []);
const StartLoraFinetune = () => {
const StartLoraFinetune = async () => {
const ok = await checkDependencies(navigate);
if (!ok)
return;
const convertedDataPath = './finetune/json2binidx_tool/data/' +
dataParams.dataPath.replace(/[\/\\]$/, '').split(/[\/\\]/).pop()!.split('.')[0] +
'_text_document';
if (!await FileExists(convertedDataPath + '.idx')) {
toast(t('Please convert data first.'), { type: 'error' });
return;
}
WslIsEnabled().then(() => {
WslStart().then(async () => {
const convertedDataPath = `./finetune/json2binidx_tool/data/${dataParams.dataPath.split('/').pop()!.split('.')[0]}_text_document`;
if (!await FileExists(convertedDataPath + '.idx')) {
toast(t('Please convert data first.'), { type: 'error' });
return;
WslStart().then(() => {
setTimeout(WindowShow, 1000);
let ctxLen = loraParams.ctxLen;
if (dataParams.dataPath === 'finetune/data/sample.jsonl') {
ctxLen = 150;
toast(t('You are using sample data for training. For formal training, please make sure to create your own jsonl file.'), {
type: 'info',
autoClose: 6000
});
}
commonStore.setChartData({
@@ -272,12 +315,13 @@ const LoraFinetune: FC = observer(() => {
});
WslCommand(`export cnMirror=${commonStore.settings.cnMirror ? '1' : '0'} ` +
`&& export loadModel=models/${loraParams.baseModel} ` +
`&& sed -i 's/\\r$//' finetune/install-wsl-dep-and-train.sh ` +
`&& chmod +x finetune/install-wsl-dep-and-train.sh && ./finetune/install-wsl-dep-and-train.sh ` +
(loraParams.baseModel ? `--load_model models/${loraParams.baseModel} ` : '') +
(loraParams.loraLoad ? `--lora_load lora-models/${loraParams.loraLoad} ` : '') +
`--data_file ${convertedDataPath} ` +
`--vocab_size ${loraParams.baseModel.toLowerCase().includes('world') ? '65536' : '50277'} ` +
`--ctx_len ${loraParams.ctxLen} --epoch_steps ${loraParams.epochSteps} --epoch_count ${loraParams.epochCount} ` +
`--ctx_len ${ctxLen} --epoch_steps ${loraParams.epochSteps} --epoch_count ${loraParams.epochCount} ` +
`--epoch_begin ${loraParams.epochBegin} --epoch_save ${loraParams.epochSave} ` +
`--micro_bsz ${loraParams.microBsz} --accumulate_grad_batches ${loraParams.accumGradBatches} ` +
`--pre_ffn ${loraParams.preFfn ? '1' : '0'} --head_qk ${loraParams.headQk ? '1' : '0'} --lr_init ${loraParams.lrInit} --lr_final ${loraParams.lrFinal} ` +
@@ -285,15 +329,18 @@ const LoraFinetune: FC = observer(() => {
`--beta1 ${loraParams.beta1} --beta2 ${loraParams.beta2} --adam_eps ${loraParams.adamEps} ` +
`--devices ${loraParams.devices} --precision ${loraParams.precision} ` +
`--grad_cp ${loraParams.gradCp ? '1' : '0'} ` +
`--lora_r ${loraParams.loraR} --lora_alpha ${loraParams.loraAlpha} --lora_dropout ${loraParams.loraDropout}`).catch((e) => {
toast((e.message || e), { type: 'error' });
});
`--lora_r ${loraParams.loraR} --lora_alpha ${loraParams.loraAlpha} --lora_dropout ${loraParams.loraDropout}`).catch(showError);
}).catch(e => {
const msg = e.message || e;
if (msg === 'ubuntu not found') {
WindowShow();
toastWithButton(t('Ubuntu is not installed, do you want to install it?'), t('Install Ubuntu'), () => {
WslInstallUbuntu().then(() => {
toast(t('Please install Ubuntu using Microsoft Store'), { type: 'info', autoClose: 6000 });
WindowShow();
toast(t('Please install Ubuntu using Microsoft Store, after installation click the Open button in Microsoft Store and then click the Train button'), {
type: 'info',
autoClose: 10000
});
});
});
}
@@ -302,15 +349,15 @@ const LoraFinetune: FC = observer(() => {
const msg = e.message || e;
const enableWsl = (forceMode: boolean) => {
WindowShow();
toastWithButton(t('WSL is not enabled, do you want to enable it?'), t('Enable WSL'), () => {
WslEnable(forceMode).then(() => {
WindowShow();
toast(t('After installation, please restart your computer to enable WSL'), {
type: 'info',
autoClose: false
});
}).catch(e => {
toast((e.message || e), { type: 'error' });
});
}).catch(showError);
});
};
@@ -319,7 +366,7 @@ const LoraFinetune: FC = observer(() => {
} else if (msg.includes('wsl.state: The system cannot find the file')) {
enableWsl(true);
} else {
toast(msg, { type: 'error' });
showError(msg);
}
});
};
@@ -357,32 +404,46 @@ const LoraFinetune: FC = observer(() => {
title={t('Data Process')}
content={
<div className="flex flex-col gap-2">
<Labeled flex label={t('Data Path')}
content={
<div className="grow flex gap-2">
<Input className="grow ml-2" value={dataParams.dataPath}
onChange={(e, data) => {
setDataParams({ dataPath: data.value });
}} />
<ToolTipButton desc={t('Open Folder')} icon={<Folder20Regular />} onClick={() => {
OpenFileFolder(dataParams.dataPath, false);
}} />
</div>
} />
<div className="flex gap-2 items-center">
{t('Data Path')}
<Input className="grow" style={{ minWidth: 0 }} value={dataParams.dataPath}
onChange={(e, data) => {
setDataParams({ dataPath: data.value });
}} />
<DialogButton text={t('Help')} title={t('Help')} markdown
contentText={t('The data path should be a directory or a file in jsonl format (more formats will be supported in the future).\n\n' +
'When you provide a directory path, all the txt files within that directory will be automatically converted into training data. ' +
'This is commonly used for large-scale training in writing, code generation, or knowledge bases.\n\n' +
'The jsonl format file can be referenced at https://github.com/Abel2076/json2binidx_tool/blob/main/sample.jsonl.\n' +
'You can also write it similar to OpenAI\'s playground format, as shown in https://platform.openai.com/playground/p/default-chat.\n' +
'Even for multi-turn conversations, they must be written in a single line using `\\n` to indicate line breaks. ' +
'If they are different dialogues or topics, they should be written in separate lines.')} />
<ToolTipButton desc={t('Open Folder')} icon={<Folder20Regular />} onClick={() => {
OpenFileFolder(dataParams.dataPath, false);
}} />
</div>
<div className="flex gap-2 items-center">
{t('Vocab Path')}
<Input className="grow" style={{ minWidth: 0 }} value={dataParams.vocabPath}
onChange={(e, data) => {
setDataParams({ vocabPath: data.value });
}} />
<Button appearance="secondary" size="large" onClick={() => {
ConvertData(commonStore.settings.customPythonPath, dataParams.dataPath,
'./finetune/json2binidx_tool/data/' + dataParams.dataPath.split('/').pop()!.split('.')[0],
dataParams.vocabPath).then(() => {
toast(t('Convert Data successfully'), { type: 'success' });
}).catch((e) => {
toast((e.message || e), { type: 'error' });
});
<Button appearance="secondary" onClick={async () => {
const ok = await checkDependencies(navigate);
if (!ok)
return;
const outputPrefix = './finetune/json2binidx_tool/data/' +
dataParams.dataPath.replace(/[\/\\]$/, '').split(/[\/\\]/).pop()!.split('.')[0];
ConvertData(commonStore.settings.customPythonPath,
dataParams.dataPath.replaceAll('\\', '/'),
outputPrefix,
dataParams.vocabPath).then(async () => {
if (!await FileExists(outputPrefix + '_text_document.idx')) {
toast(t('Failed to convert data') + ' - ' + await GetPyError(), { type: 'error' });
} else {
toast(t('Convert Data successfully'), { type: 'success' });
}
}).catch(showError);
}}>{t('Convert')}</Button>
</div>
</div>
@@ -424,15 +485,24 @@ const LoraFinetune: FC = observer(() => {
<option key={index} value={name}>{name}</option>
)}
</Select>
<Button onClick={() => {
MergeLora(commonStore.settings.customPythonPath, true, loraParams.loraAlpha,
'models/' + loraParams.baseModel, 'lora-models/' + loraParams.loraLoad,
`models/${loraParams.baseModel}-LoRA-${loraParams.loraLoad}`).then(() => {
toast(t('Merge model successfully'), { type: 'success' });
refreshLocalModels({ models: commonStore.modelSourceList }, false);
}).catch((e) => {
toast((e.message || e), { type: 'error' });
});
<Button onClick={async () => {
const ok = await checkDependencies(navigate);
if (!ok)
return;
if (loraParams.loraLoad) {
const outputPath = `models/${loraParams.baseModel}-LoRA-${loraParams.loraLoad}`;
MergeLora(commonStore.settings.customPythonPath, true, loraParams.loraAlpha,
'models/' + loraParams.baseModel, 'lora-models/' + loraParams.loraLoad,
outputPath).then(async () => {
if (!await FileExists(outputPath)) {
toast(t('Failed to merge model') + ' - ' + await GetPyError(), { type: 'error' });
} else {
toast(t('Merge model successfully'), { type: 'success' });
}
}).catch(showError);
} else {
toast(t('Please select a LoRA model'), { type: 'info' });
}
}}>{t('Merge Model')}</Button>
</div>
{
@@ -474,7 +544,6 @@ const LoraFinetune: FC = observer(() => {
>
<Option>bf16</Option>
<Option>fp16</Option>
<Option>fp32</Option>
<Option>tf32</Option>
</Dropdown>
: <div />
@@ -491,9 +560,7 @@ const LoraFinetune: FC = observer(() => {
<Button appearance="secondary" size="large" onClick={() => {
WslStop().then(() => {
toast(t('Command Stopped'), { type: 'success' });
}).catch((e) => {
toast((e.message || e), { type: 'error' });
});
}).catch(showError);
}}>{t('Stop')}</Button>
<Button appearance="primary" size="large" onClick={StartLoraFinetune}>{t('Train')}</Button>
</div>

View File

@@ -8,6 +8,7 @@ import {
DocumentSettings20Regular,
Home20Regular,
Info20Regular,
MusicNote220Regular,
Settings20Regular,
Storage20Regular
} from '@fluentui/react-icons';
@@ -19,6 +20,7 @@ import { Settings } from './Settings';
import { About } from './About';
import { Downloads } from './Downloads';
import { Completion } from './Completion';
import { Composition } from './Composition';
type NavigationItem = {
label: string;
@@ -50,6 +52,13 @@ export const pages: NavigationItem[] = [
element: <Completion />,
top: true
},
{
label: 'Composition',
path: '/composition',
icon: <MusicNote220Regular />,
element: <Composition />,
top: true
},
{
label: 'Configs',
path: '/configs',

View File

@@ -1,10 +1,10 @@
import commonStore, { Platform } from './stores/commonStore';
import { GetPlatform, ListDirFiles, ReadJson } from '../wailsjs/go/backend_golang/App';
import { Cache, checkUpdate, downloadProgramFiles, LocalConfig, refreshModels } from './utils';
import { Cache, checkUpdate, downloadProgramFiles, LocalConfig, refreshLocalModels, refreshModels } from './utils';
import { getStatus } from './apis';
import { EventsOn } from '../wailsjs/runtime';
import manifest from '../../manifest.json';
import { defaultModelConfigs, defaultModelConfigsMac } from './pages/defaultModelConfigs';
import { defaultModelConfigs, defaultModelConfigsMac } from './pages/defaultConfigs';
import { Preset } from './pages/PresetsManager/PresetsButton';
import { wslHandler } from './pages/Train';
@@ -18,6 +18,7 @@ export async function startup() {
EventsOn('wslerr', (e) => {
console.log(e);
});
initLocalModelsNotify();
initLoraModels();
initPresets();
@@ -59,6 +60,9 @@ async function initConfig() {
if (configData.dataProcessParams)
commonStore.setDataProcessParams(configData.dataProcessParams, false);
if (configData.loraFinetuneParams)
commonStore.setLoraFinetuneParameters(configData.loraFinetuneParams, false);
if (configData.modelConfigs && Array.isArray(configData.modelConfigs))
commonStore.setModelConfigs(configData.modelConfigs, false);
else throw new Error('Invalid config.json');
@@ -106,3 +110,10 @@ async function initLoraModels() {
refreshLoraModels();
});
}
async function initLocalModelsNotify() {
EventsOn('fsnotify', (data: string) => {
if (data.includes('models') && !data.includes('lora-models'))
refreshLocalModels({ models: commonStore.modelSourceList }, false); //TODO fix bug that only add models
});
}

View File

@@ -11,11 +11,12 @@ import { IntroductionContent } from '../pages/Home';
import { AboutContent } from '../pages/About';
import i18n from 'i18next';
import { CompletionPreset } from '../pages/Completion';
import { defaultModelConfigs, defaultModelConfigsMac } from '../pages/defaultModelConfigs';
import { defaultCompositionPrompt, defaultModelConfigs, defaultModelConfigsMac } from '../pages/defaultConfigs';
import commonStore from './commonStore';
import { Preset } from '../pages/PresetsManager/PresetsButton';
import { DataProcessParameters, LoraFinetuneParameters } from '../pages/Train';
import { ChartData } from 'chart.js';
import { CompositionParams } from '../pages/Composition';
export enum ModelStatus {
Offline,
@@ -57,6 +58,19 @@ class CommonStore {
completionPreset: CompletionPreset | null = null;
completionGenerating: boolean = false;
completionSubmittedPrompt: string = '';
// composition
compositionParams: CompositionParams = {
prompt: defaultCompositionPrompt,
maxResponseToken: 200,
temperature: 1,
topP: 0.8,
autoPlay: true,
useLocalSoundFont: false,
midi: null,
ns: null
};
compositionGenerating: boolean = false;
compositionSubmittedPrompt: string = defaultCompositionPrompt;
// configs
currentModelConfigIndex: number = 0;
modelConfigs: ModelConfig[] = [];
@@ -78,10 +92,10 @@ class CommonStore {
loraFinetuneParams: LoraFinetuneParameters = {
baseModel: '',
ctxLen: 1024,
epochSteps: 1000,
epochSteps: 200,
epochCount: 20,
epochBegin: 0,
epochSave: 5,
epochSave: 2,
microBsz: 1,
accumGradBatches: 8,
preFfn: false,
@@ -267,6 +281,18 @@ class CommonStore {
this.completionSubmittedPrompt = value;
}
setCompositionParams(value: CompositionParams) {
this.compositionParams = value;
}
setCompositionGenerating(value: boolean) {
this.compositionGenerating = value;
}
setCompositionSubmittedPrompt(value: string) {
this.compositionSubmittedPrompt = value;
}
setWslStdout(value: string) {
this.wslStdout = value;
}

View File

@@ -28,6 +28,7 @@ body {
/* Works on Chrome, Edge, and Safari */
*::-webkit-scrollbar {
width: 9px;
height: 9px;
}
*::-webkit-scrollbar-thumb {
@@ -92,3 +93,22 @@ body {
}
}
}
midi-player {
&::part(control-panel) {
background: none;
}
}
midi-visualizer {
$instrument-colors: #007bff, #20c997, #dc3545, #6610f2, #ffc107, #e83e8c, #17a2b8, #fd7e14, #28a745;
svg {
@for $i from 0 to 200 {
$color: nth($instrument-colors, ($i % length($instrument-colors)) + 1);
rect.note[data-instrument="#{$i}"] {
fill: $color;
}
}
}
}

View File

@@ -0,0 +1,9 @@
declare module JSX {
import { PlayerElement } from 'html-midi-player';
import { VisualizerElement } from 'html-midi-player';
interface IntrinsicElements {
'midi-player': PlayerElement;
'midi-visualizer': VisualizerElement;
}
}

View File

@@ -1,6 +1,9 @@
import {
AddToDownloadList,
CopyFile,
DeleteFile,
DepCheck,
InstallPyDep,
ListDirFiles,
ReadFileInfo,
ReadJson,
@@ -8,7 +11,7 @@ import {
UpdateApp
} from '../../wailsjs/go/backend_golang/App';
import manifest from '../../../manifest.json';
import commonStore from '../stores/commonStore';
import commonStore, { ModelStatus } from '../stores/commonStore';
import { toast } from 'react-toastify';
import { t } from 'i18next';
import { ToastOptions } from 'react-toastify/dist/types';
@@ -18,6 +21,8 @@ import { ModelSourceItem } from '../pages/Models';
import { ModelConfig, ModelParameters } from '../pages/Configs';
import { DownloadStatus } from '../pages/Downloads';
import { DataProcessParameters, LoraFinetuneParameters } from '../pages/Train';
import { BrowserOpenURL, WindowShow } from '../../wailsjs/runtime';
import { NavigateFunction } from 'react-router';
export type Cache = {
version: string
@@ -347,6 +352,56 @@ export async function checkUpdate(notifyEvenLatest: boolean = false) {
});
}
export const checkDependencies = async (navigate: NavigateFunction) => {
if (!commonStore.depComplete) {
let depErrorMsg = '';
await DepCheck(commonStore.settings.customPythonPath).catch((e) => {
depErrorMsg = e.message || e;
WindowShow();
if (depErrorMsg === 'python zip not found') {
toastWithButton(t('Python target not found, would you like to download it?'), t('Download'), () => {
toastWithButton(`${t('Downloading')} Python`, t('Check'), () => {
navigate({ pathname: '/downloads' });
}, { autoClose: 3000 });
AddToDownloadList('python-3.10.11-embed-amd64.zip', 'https://www.python.org/ftp/python/3.10.11/python-3.10.11-embed-amd64.zip');
});
} else if (depErrorMsg.includes('DepCheck Error')) {
if (depErrorMsg.includes('vc_redist') || depErrorMsg.includes('DLL load failed while importing')) {
toastWithButton(t('Microsoft Visual C++ Redistributable is not installed, would you like to download it?'), t('Download'), () => {
BrowserOpenURL('https://aka.ms/vs/16/release/vc_redist.x64.exe');
});
} else {
toast(depErrorMsg, { type: 'info', position: 'bottom-left' });
if (commonStore.platform != 'linux')
toastWithButton(t('Python dependencies are incomplete, would you like to install them?'), t('Install'), () => {
InstallPyDep(commonStore.settings.customPythonPath, commonStore.settings.cnMirror).catch((e) => {
const errMsg = e.message || e;
toast(t('Error') + ' - ' + errMsg, { type: 'error' });
});
setTimeout(WindowShow, 1000);
}, {
autoClose: 8000
});
else
toastWithButton(t('On Linux system, you must manually install python dependencies.'), t('Check'), () => {
BrowserOpenURL('https://github.com/josStorer/RWKV-Runner/blob/master/build/linux/Readme_Install.txt');
});
}
} else {
toast(depErrorMsg, { type: 'error' });
}
});
if (depErrorMsg) {
commonStore.setStatus({ status: ModelStatus.Offline });
return false;
}
commonStore.setDepComplete(true);
if (commonStore.platform === 'windows')
CopyFile('./backend-python/wkv_cuda_utils/wkv_cuda_model.py', './py310/Lib/site-packages/rwkv/model.py');
}
return true;
};
export function toastWithButton(text: string, buttonText: string, onClickButton: () => void, options?: ToastOptions) {
let triggered = false;
const id = toast(

View File

@@ -22,6 +22,8 @@ export function FileExists(arg1:string):Promise<boolean>;
export function GetPlatform():Promise<string>;
export function GetPyError():Promise<string>;
export function InstallPyDep(arg1:string,arg2:boolean):Promise<string>;
export function ListDirFiles(arg1:string):Promise<Array<backend_golang.FileInfo>>;
@@ -32,6 +34,8 @@ export function OpenFileFolder(arg1:string,arg2:boolean):Promise<void>;
export function OpenSaveFileDialog(arg1:string,arg2:string,arg3:string):Promise<string>;
export function OpenSaveFileDialogBytes(arg1:string,arg2:string,arg3:Array<number>):Promise<string>;
export function PauseDownload(arg1:string):Promise<void>;
export function ReadFileInfo(arg1:string):Promise<backend_golang.FileInfo>;

View File

@@ -42,6 +42,10 @@ export function GetPlatform() {
return window['go']['backend_golang']['App']['GetPlatform']();
}
export function GetPyError() {
return window['go']['backend_golang']['App']['GetPyError']();
}
export function InstallPyDep(arg1, arg2) {
return window['go']['backend_golang']['App']['InstallPyDep'](arg1, arg2);
}
@@ -62,6 +66,10 @@ export function OpenSaveFileDialog(arg1, arg2, arg3) {
return window['go']['backend_golang']['App']['OpenSaveFileDialog'](arg1, arg2, arg3);
}
export function OpenSaveFileDialogBytes(arg1, arg2, arg3) {
return window['go']['backend_golang']['App']['OpenSaveFileDialogBytes'](arg1, arg2, arg3);
}
export function PauseDownload(arg1) {
return window['go']['backend_golang']['App']['PauseDownload'](arg1);
}

41
main.go
View File

@@ -2,6 +2,8 @@ package main
import (
"embed"
"fmt"
"net/http"
"os"
"runtime/debug"
"strings"
@@ -14,6 +16,27 @@ import (
"github.com/wailsapp/wails/v2/pkg/options/windows"
)
type FileLoader struct {
http.Handler
}
func NewFileLoader() *FileLoader {
return &FileLoader{}
}
func (h *FileLoader) ServeHTTP(res http.ResponseWriter, req *http.Request) {
var err error
requestedFilename := strings.TrimPrefix(req.URL.Path, "/")
println("Requesting file:", requestedFilename)
fileData, err := os.ReadFile(requestedFilename)
if err != nil {
res.WriteHeader(http.StatusBadRequest)
res.Write([]byte(fmt.Sprintf("Could not load file %s", requestedFilename)))
}
res.Write(fileData)
}
//go:embed all:frontend/dist
var assets embed.FS
@@ -29,19 +52,20 @@ var py embed.FS
//go:embed finetune
var finetune embed.FS
//go:embed midi
var midi embed.FS
//go:embed assets/sound-font
var midiAssets embed.FS
func main() {
if buildInfo, ok := debug.ReadBuildInfo(); !ok || strings.Contains(buildInfo.String(), "-ldflags") {
backend.CopyEmbed(cyac)
backend.CopyEmbed(cyacInfo)
backend.CopyEmbed(py)
backend.CopyEmbed(finetune)
os.Mkdir("models", os.ModePerm)
os.Mkdir("lora-models", os.ModePerm)
}
f, err := os.Create("lora-models/train_log.txt")
if err == nil {
f.Close()
backend.CopyEmbed(midi)
backend.CopyEmbed(midiAssets)
}
// Create an instance of the app structure
@@ -71,7 +95,8 @@ func main() {
IsZoomControlEnabled: true,
},
AssetServer: &assetserver.Options{
Assets: assets,
Assets: assets,
Handler: NewFileLoader(),
},
OnStartup: app.OnStartup,
Bind: []any{

View File

@@ -1,12 +1,12 @@
{
"version": "1.3.1",
"version": "1.4.0",
"introduction": {
"en": "RWKV is an open-source, commercially usable large language model with high flexibility and great potential for development.\n### About This Tool\nThis tool aims to lower the barrier of entry for using large language models, making it accessible to everyone. It provides fully automated dependency and model management. You simply need to click and run, following the instructions, to deploy a local large language model. The tool itself is very compact and only requires a single executable file for one-click deployment.\nAdditionally, this tool offers an interface that is fully compatible with the OpenAI API. This means you can use any ChatGPT client as a client for RWKV, enabling capability expansion beyond just chat functionality.\n### Preset Configuration Rules at the Bottom\nThis tool comes with a series of preset configurations to reduce complexity. The naming rules for each configuration represent the following in order: device - required VRAM/memory - model size - model language.\nFor example, \"GPU-8G-3B-EN\" indicates that this configuration is for a graphics card with 8GB of VRAM, a model size of 3 billion parameters, and it uses an English language model.\nLarger model sizes have higher performance and VRAM requirements. Among configurations with the same model size, those with higher VRAM usage will have faster runtime.\nFor example, if you have 12GB of VRAM but running the \"GPU-12G-7B-EN\" configuration is slow, you can downgrade to \"GPU-8G-3B-EN\" for a significant speed improvement.\n### About RWKV\nRWKV is an RNN with Transformer-level LLM performance, which can also be directly trained like a GPT transformer (parallelizable). And it's 100% attention-free. You only need the hidden state at position t to compute the state at position t+1. You can use the \"GPT\" mode to quickly compute the hidden state for the \"RNN\" mode.<br/>So it's combining the best of RNN and transformer - great performance, fast inference, saves VRAM, fast training, \"infinite\" ctx_len, and free sentence embedding (using the final hidden state).",
"zh": "RWKV是一个开源且允许商用的大语言模型灵活性很高且极具发展潜力。\n### 关于本工具\n本工具旨在降低大语言模型的使用门槛做到人人可用本工具提供了全自动化的依赖和模型管理你只需要直接点击运行跟随引导即可完成本地大语言模型的部署工具本身体积极小只需要一个exe即可完成一键部署。\n此外本工具提供了与OpenAI API完全兼容的接口这意味着你可以把任意ChatGPT客户端用作RWKV的客户端实现能力拓展而不局限于聊天。\n### 底部的预设配置规则\n本工具内置了一系列预设配置以降低使用难度每个配置名的规则依次代表着设备-所需显存/内存-模型规模-模型语言。\n例如GPU-8G-3B-CN表示该配置用于显卡需要8G显存模型规模为30亿参数使用的是中文模型。\n模型规模越大性能要求越高显存要求也越高而同样模型规模的配置中显存占用越高的运行速度越快。\n例如当你有12G显存但运行GPU-12G-7B-CN配置速度比较慢可降级成GPU-8G-3B-CN将会大幅提速。\n### 关于RWKV\nRWKV是具有Transformer级别LLM性能的RNN也可以像GPT Transformer一样直接进行训练可并行化。而且它是100% attention-free的。你只需在位置t处获得隐藏状态即可计算位置t + 1处的状态。你可以使用“GPT”模式快速计算用于“RNN”模式的隐藏状态。\n因此它将RNN和Transformer的优点结合起来 - 高性能、快速推理、节省显存、快速训练、“无限”上下文长度以及免费的语句嵌入(使用最终隐藏状态)。"
},
"about": {
"en": "<div align=\"center\">\n\nProject Source Code:\nhttps://github.com/josStorer/RWKV-Runner\nAuthor: [@josStorer](https://github.com/josStorer)\nFAQs: https://github.com/josStorer/RWKV-Runner/wiki/FAQs\n\nRelated Repositories:\nRWKV-4-Raven: https://huggingface.co/BlinkDL/rwkv-4-raven/tree/main\nChatRWKV: https://github.com/BlinkDL/ChatRWKV\nRWKV-LM: https://github.com/BlinkDL/RWKV-LM\n\n</div>",
"zh": "<div align=\"center\">\n\n本项目源码:\nhttps://github.com/josStorer/RWKV-Runner\n作者: [@josStorer](https://github.com/josStorer)\n演示与常见问题说明视频: https://www.bilibili.com/video/BV1hM4y1v76R\n疑难解答: https://www.bilibili.com/read/cv23921171\n\n相关仓库:\nRWKV-4-Raven: https://huggingface.co/BlinkDL/rwkv-4-raven/tree/main\nChatRWKV: https://github.com/BlinkDL/ChatRWKV\nRWKV-LM: https://github.com/BlinkDL/RWKV-LM\n\n</div>"
"en": "<div align=\"center\">\n\nProject Source Code:\nhttps://github.com/josStorer/RWKV-Runner\nAuthor: [@josStorer](https://github.com/josStorer)\nFAQs: https://github.com/josStorer/RWKV-Runner/wiki/FAQs\n\nRelated Repositories:\nRWKV-4-World: https://huggingface.co/BlinkDL/rwkv-4-world/tree/main\nRWKV-4-Raven: https://huggingface.co/BlinkDL/rwkv-4-raven/tree/main\nChatRWKV: https://github.com/BlinkDL/ChatRWKV\nRWKV-LM: https://github.com/BlinkDL/RWKV-LM\nRWKV-LM-LoRA: https://github.com/Blealtan/RWKV-LM-LoRA\nMIDI-LLM-tokenizer: https://github.com/briansemrau/MIDI-LLM-tokenizer\n\n</div>",
"zh": "<div align=\"center\">\n\n本项目源码:\nhttps://github.com/josStorer/RWKV-Runner\n作者: [@josStorer](https://github.com/josStorer)\n演示与常见问题说明视频: https://www.bilibili.com/video/BV1hM4y1v76R\n疑难解答: https://www.bilibili.com/read/cv23921171\n\n相关仓库:\nRWKV-4-World: https://huggingface.co/BlinkDL/rwkv-4-world/tree/main\nRWKV-4-Raven: https://huggingface.co/BlinkDL/rwkv-4-raven/tree/main\nChatRWKV: https://github.com/BlinkDL/ChatRWKV\nRWKV-LM: https://github.com/BlinkDL/RWKV-LM\nRWKV-LM-LoRA: https://github.com/Blealtan/RWKV-LM-LoRA\nMIDI-LLM-tokenizer: https://github.com/briansemrau/MIDI-LLM-tokenizer\n\n</div>"
},
"programFiles": [
{
@@ -292,6 +292,42 @@
"url": "https://huggingface.co/BlinkDL/rwkv-4-world/blob/main/RWKV-4-World-7B-v1-20230626-ctx4096.pth",
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-7B-v1-20230626-ctx4096.pth"
},
{
"name": "RWKV-4-World-CHNtuned-7B-v1-20230709-ctx4096.pth",
"desc": {
"en": "Global Languages 7B v1 Enhanced Chinese",
"zh": "全球语言 7B v1 中文增强"
},
"size": 15035393458,
"SHA256": "52d33e8352a40158d21425fee4f68df1515d6324056f788d2c78a366ef578ffa",
"lastUpdated": "2023-07-09T18:23:33",
"url": "https://huggingface.co/BlinkDL/rwkv-4-world/blob/main/RWKV-4-World-CHNtuned-7B-v1-20230709-ctx4096.pth",
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-CHNtuned-7B-v1-20230709-ctx4096.pth"
},
{
"name": "Readflow-RWKV-4-World-CHNtuned-7B-v1-20230709-ctx32k.pth",
"desc": {
"en": "Global Languages 7B v1 Enhanced Chinese Ctx32k Summary Ability",
"zh": "全球语言 7B v1 中文增强 32k上下文 总结能力"
},
"size": 15035391543,
"SHA256": "1bd1de8cdbd56b67e1374588fe5d202884049c71278ffcb12f5c4efbdb422ee1",
"lastUpdated": "2023-07-20T06:11:29",
"url": "https://huggingface.co/xiaol/readflow-rwkv-4-world-ctx32k/blob/main/Readflow-RWKV-4-World-CHNtuned-7B-v1-20230709-ctx32k.pth",
"downloadUrl": "https://huggingface.co/xiaol/readflow-rwkv-4-world-ctx32k/resolve/main/Readflow-RWKV-4-World-CHNtuned-7B-v1-20230709-ctx32k.pth"
},
{
"name": "RWKV-4-World-JPNtuned-7B-v1-20230718-ctx4096.pth",
"desc": {
"en": "Global Languages 7B v1 Enhanced Japanese",
"zh": "全球语言 7B v1 日文增强"
},
"size": 15035393458,
"SHA256": "3e4c7664ce893ac1f6bb59cd76664fb5c872cb076bb82dbd534db0555b6e9fa5",
"lastUpdated": "2023-07-18T20:01:12",
"url": "https://huggingface.co/BlinkDL/rwkv-4-world/blob/main/RWKV-4-World-JPNtuned-7B-v1-20230718-ctx4096.pth",
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-JPNtuned-7B-v1-20230718-ctx4096.pth"
},
{
"name": "RWKV-4-Novel-7B-v1-ChnEng-ChnPro-20230410-ctx4096.pth",
"desc": {
@@ -514,6 +550,30 @@
"lastUpdated": "2023-05-23T11:22:41",
"url": "https://huggingface.co/BlinkDL/rwkv-4-raven/blob/main/RWKV-4-Raven-14B-v12-Eng98%25-Other2%25-20230523-ctx8192.pth",
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-raven/resolve/main/RWKV-4-Raven-14B-v12-Eng98%25-Other2%25-20230523-ctx8192.pth"
},
{
"name": "RWKV-4-MIDI-120M-v1-20230714-ctx4096.pth",
"desc": {
"en": "Music 120M v1",
"zh": "作曲 120M v1"
},
"size": 239224753,
"SHA256": "161d27dcf50d0958d230601ba1e0f8e7dd9c236105e92d2b833496412ace430c",
"lastUpdated": "2023-07-15T08:03:36",
"url": "https://huggingface.co/BlinkDL/rwkv-4-music/blob/main/RWKV-4-MIDI-120M-v1-20230714-ctx4096.pth",
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-music/resolve/main/RWKV-4-MIDI-120M-v1-20230714-ctx4096.pth"
},
{
"name": "RWKV-4-MIDI-560M-v1-20230717-ctx4096.pth",
"desc": {
"en": "Music 560M v1",
"zh": "作曲 560M v1"
},
"size": 1130577457,
"SHA256": "62b21841b24af38ef176e9e9d895d9fff730cea8aa0623f53a1784d74ce828d6",
"lastUpdated": "2023-07-17T15:02:08",
"url": "https://huggingface.co/BlinkDL/rwkv-4-music/blob/main/RWKV-4-MIDI-560M-v1-20230717-ctx4096.pth",
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-music/resolve/main/RWKV-4-MIDI-560M-v1-20230717-ctx4096.pth"
}
]
}

1
midi/sample.txt Normal file
View File

@@ -0,0 +1 @@
<start> p:24:a p:2a:a p:31:a p:39:a p:3b:a p:45:a b:26:a g:3e:a g:3e:a g:42:a g:42:a g:45:a g:45:a pi:3e:a pi:42:a pi:45:a t14 p:24:0 p:2a:0 p:31:0 p:39:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 p:45:0 b:26:0 g:3e:0 g:3e:0 g:42:0 g:42:0 g:45:0 g:45:0 pi:3e:0 pi:42:0 pi:45:0 t2 p:2e:a p:3b:a p:45:a b:26:a g:3e:a g:3e:a g:42:a g:42:a g:45:a g:45:a pi:3e:a pi:42:a pi:45:a t14 p:2e:0 p:3b:0 p:45:0 g:3e:0 g:3e:0 g:42:0 g:42:0 g:45:0 g:45:0 pi:3e:0 pi:42:0 pi:45:0 t2 p:2e:a p:3b:a p:45:a g:3e:a g:3e:a g:42:a g:42:a g:45:a g:45:a pi:3e:a pi:42:a pi:45:a t14 p:2e:0 p:3b:0 p:45:0 b:26:0 g:3e:0 g:3e:0 g:42:0 g:42:0 g:45:0 g:45:0 pi:3e:0 pi:42:0 pi:45:0 t2 p:26:a p:2a:a p:3b:a p:45:a t14 p:26:0 p:2a:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a b:26:a g:3e:a g:3e:a g:42:a g:42:a g:45:a g:45:a pi:3e:a pi:42:a pi:45:a t14 p:2a:0 p:3b:0 p:45:0 b:26:0 t2 p:24:a p:2a:a p:3b:a p:45:a b:2d:a t14 p:24:0 p:2a:0 p:3b:0 p:45:0 b:2d:0 g:3e:0 g:3e:0 g:42:0 g:42:0 g:45:0 g:45:0 pi:3e:0 pi:42:0 pi:45:0 t2 p:24:a p:2a:a p:3b:a p:45:a b:21:a g:39:a g:39:a g:3d:a g:3d:a g:40:a g:40:a pi:39:a pi:3d:a pi:40:a t14 p:24:0 p:2a:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 p:45:0 b:21:0 g:39:0 g:39:0 g:3d:0 g:3d:0 g:40:0 g:40:0 pi:39:0 pi:3d:0 pi:40:0 t2 p:24:a p:2e:a p:3b:a p:45:a b:21:a g:39:a g:39:a g:3d:a g:3d:a g:40:a g:40:a pi:39:a pi:3d:a pi:40:a t14 p:24:0 p:2e:0 p:3b:0 p:45:0 b:21:0 g:39:0 g:39:0 g:3d:0 g:3d:0 g:40:0 g:40:0 pi:39:0 pi:3d:0 pi:40:0 t2 p:24:a p:2a:a p:3b:a p:45:a b:21:a g:39:a g:39:a g:3d:a g:3d:a g:40:a g:40:a pi:39:a pi:3d:a pi:40:a t14 p:24:0 p:2a:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 p:45:0 b:21:0 g:39:0 g:39:0 g:3d:0 g:3d:0 g:40:0 g:40:0 pi:39:0 pi:3d:0 pi:40:0 t2 p:26:a p:2a:a p:3b:a p:45:a b:21:a g:39:a g:39:a g:3d:a g:3d:a g:40:a g:40:a pi:39:a pi:3d:a pi:40:a t14 p:26:0 p:2a:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 p:45:0 b:21:0 g:39:0 g:39:0 g:3d:0 g:3d:0 g:40:0 g:40:0 pi:39:0 pi:3d:0 pi:40:0 t2 p:26:a p:2e:a p:31:a p:39:a p:3b:a p:45:a b:21:a g:39:a g:39:a g:3d:a g:3d:a g:40:a g:40:a pi:39:a pi:3d:a pi:40:a t14 p:26:0 p:2e:0 p:31:0 p:39:0 p:3b:0 p:45:0 b:21:0 t2 p:26:a p:2e:a p:31:a p:39:a p:3b:a p:45:a b:21:a t14 p:26:0 p:2e:0 p:31:0 p:39:0 p:3b:0 p:45:0 b:21:0 g:39:0 g:39:0 g:3d:0 g:3d:0 g:40:0 g:40:0 pi:39:0 pi:3d:0 pi:40:0 t2 p:24:a p:2a:a p:31:a p:39:a p:3b:a p:45:a b:1f:a g:3b:a g:3b:a g:3e:a g:3e:a g:43:a g:43:a pi:3b:a pi:3e:a pi:43:a t14 p:24:0 p:2a:0 p:31:0 p:39:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 p:45:0 b:1f:0 g:3b:0 g:3b:0 g:3e:0 g:3e:0 g:43:0 g:43:0 pi:3b:0 pi:3e:0 pi:43:0 t2 p:2e:a p:3b:a p:45:a b:1f:a g:3b:a g:3b:a g:3e:a g:3e:a g:43:a g:43:a pi:3b:a pi:3e:a pi:43:a t14 p:2e:0 p:3b:0 p:45:0 g:3b:0 g:3b:0 g:3e:0 g:3e:0 g:43:0 g:43:0 pi:3b:0 pi:3e:0 pi:43:0 t2 p:2e:a p:3b:a p:45:a g:3b:a g:3b:a g:3e:a g:3e:a g:43:a g:43:a pi:3b:a pi:3e:a pi:43:a t14 p:2e:0 p:3b:0 p:45:0 b:1f:0 g:3b:0 g:3b:0 g:3e:0 g:3e:0 g:43:0 g:43:0 pi:3b:0 pi:3e:0 pi:43:0 t2 p:26:a p:2a:a p:3b:a p:45:a t14 p:26:0 p:2a:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a b:1f:a g:3b:a g:3b:a g:3e:a g:3e:a g:43:a g:43:a pi:3b:a pi:3e:a pi:43:a t14 p:2a:0 p:3b:0 p:45:0 b:1f:0 t2 p:24:a p:2a:a p:3b:a p:45:a b:1f:a t14 p:24:0 p:2a:0 p:3b:0 p:45:0 b:1f:0 g:3b:0 g:3b:0 g:3e:0 g:3e:0 g:43:0 g:43:0 pi:3b:0 pi:3e:0 pi:43:0 t2 p:24:a p:2e:a p:3b:a p:45:a b:26:a g:39:a g:39:a g:3e:a g:3e:a g:42:a g:42:a pi:39:a pi:3e:a pi:42:a t14 p:24:0 p:2e:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 <end>

View File

@@ -2,5 +2,8 @@
- ^backend-python/wkv_cuda_utils/
- ^backend-python/get-pip\.py
- ^backend-python/convert_model\.py
- ^backend-python/utils/midi\.py
- ^build/
- ^finetune/lora/
- ^finetune/json2binidx_tool/
- ^frontend/wailsjs/